# Make a multi-panel figure of SO2 SCD standard deviations (noise) from
# SNPP/OMPS and TEMPO over a specified region

from matplotlib.collections import PolyCollection
from matplotlib.colors import LinearSegmentedColormap
import sys
sys.path.insert(0,'../readers/')
from read_h5_TEMPOSO2 import read_h5_TEMPOSO2_TRU
from read_CF import read_GEOS_CF_MET
#import earthaccess
import cartopy.io.shapereader as shpreader
import re 
from matplotlib.colors import LogNorm
import matplotlib.colors
import h5py
import numpy as np
import glob
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cartopy.feature as cfeature
import cartopy.crs as ccrs
from pylab import *
from matplotlib import ticker


def plot_poly(axes, lon_corns, lat_corns, data, cmap, norm):
    """Draw one filled quadrilateral per valid pixel onto *axes*.

    Parameters
    ----------
    axes : matplotlib Axes
        Target axes; the polygons are attached with ``add_collection``.
    lon_corns, lat_corns : array_like, shape (N, >=4)
        Corner longitudes / latitudes per pixel. Only the first four
        columns are used, in column order, as the polygon vertices.
    data : array_like, shape (N,)
        One value per pixel. Pixels whose value is NaN are not drawn.
    cmap, norm
        Colormap and normalization used to color the polygons.

    Returns
    -------
    matplotlib.collections.PolyCollection
        The collection that was added to *axes*.
    """
    data = np.asarray(data)
    # (N, 4, 2): for pixel i, vertex j is (lon_corns[i, j], lat_corns[i, j]).
    verts = np.stack(
        (np.asarray(lon_corns)[:, :4], np.asarray(lat_corns)[:, :4]), axis=-1
    )

    valid = ~np.isnan(data)
    plot_data = data[valid]
    # array/cmap/norm are set once here; no need to re-apply them afterwards.
    coll = PolyCollection(
        verts[valid],
        array=plot_data,
        cmap=cmap,
        norm=norm,
        linewidth=0,
        antialiased=False,
        rasterized=True,
    )
    axes.add_collection(coll)
    return coll




# Make the base map for one panel
def make_map(axes, icol, nrow, ncol, draw_roads=False):
    """Add tick marks and border/lake/land outlines to one map panel.

    Parameters
    ----------
    axes : array of cartopy GeoAxes
        Panel array (as returned by ``plt.subplots``); only ``axes[icol]``
        is modified.
    icol : int
        Index of the panel to decorate. Latitude ticks are set only when
        ``icol == 0``; longitude ticks are set on every panel.
    nrow, ncol : int
        Not used; kept so existing call sites remain valid.
    draw_roads : bool, optional
        If True, overlay selected roads from the Natural Earth 10m roads
        layer (see ``_add_roads``). Default False, in which case the roads
        shapefile is not read at all.

    Returns
    -------
    tuple
        ``(fig, axes)``, where ``fig`` is the Figure that owns
        ``axes[icol]`` and ``axes`` is the array that was passed in.
    """
    ax = axes[icol]

    # Tick marks every 15 degrees, limited to valid lat/lon ranges.
    grid_ticks = 15
    if icol == 0:
        ax.set_yticks(np.arange(-90, 90 + grid_ticks, grid_ticks))
    ax.set_xticks(np.arange(-180, 180 + grid_ticks, grid_ticks))

    # Country borders, lake outlines and land (coastline) outlines.
    country = cfeature.NaturalEarthFeature(category='cultural', name='admin_0_boundary_lines_land', scale='10m', facecolor='none')
    ax.add_feature(country, edgecolor='k')
    lakes = cfeature.NaturalEarthFeature(category='physical', name='lakes', scale='10m', facecolor='None', lw=0.3)
    ax.add_feature(lakes, edgecolor='k')
    land = cfeature.NaturalEarthFeature(category='physical', name='land', scale='10m', facecolor='None', lw=0.3)
    ax.add_feature(land, edgecolor='k')

    # The roads layer is large; only read it when explicitly requested.
    if draw_roads:
        _add_roads(ax)

    # Derive the figure from the axes instead of relying on a global `fig`.
    return ax.figure, axes


def _add_roads(ax):
    """Overlay US Interstate major highways and non-US expressways on *ax*."""
    shpfilename = shpreader.natural_earth(resolution='10m', category='cultural', name='roads')
    geoms = []
    for road in shpreader.Reader(shpfilename).records():
        attrs = road.attributes
        is_us_interstate = (
            attrs['sov_a3'] == 'USA'
            and attrs['level'] == 'Interstate'
            and attrs['type'] == 'Major Highway'
        )
        is_foreign_expressway = attrs['sov_a3'] != 'USA' and attrs['expressway'] == 1
        if is_us_interstate or is_foreign_expressway:
            geoms.append(road.geometry)
    if geoms:
        # One artist for all selected roads rather than one per road.
        ax.add_geometries(geoms, ccrs.PlateCarree(), edgecolor='k', facecolor='None', lw=0.3, alpha=0.9)


# Here is the main plotting code
panel_header = np.array(['(a) SNPP/OMPS','(b) TEMPO w/o RMS filter','(c) TEMPO@OMPS Res.'])

period = '20240515_S010'

# Plotting boundary
min_lon=-135; min_lat=15; max_lon=-52; max_lat=60

# Input file, named from `period` so input and output names stay in sync.
infile = 'SNPP_TEMPO_2deg_noise_' + period + '.h5'
# The context manager closes the HDF5 handle even if a dataset is missing.
with h5py.File(infile, 'r') as f:
    Lat_center = f['/Latitude_Center'][:]
    Lon_center = f['/Longitude_Center'][:]
    Lat_edge = f['/Latitude_Edge'][:]
    Lon_edge = f['/Longitude_Edge'][:]
    sigma_NPP = f['/sigma_NPP'][:]
    sigma_TEMPO_nofilter = f['/sigma_TEMPO_nofilter'][:]
    sigma_TEMPO_filter = f['/sigma_TEMPO_filter'][:]
    sigma_TEMPO_filter_rebin = f['/sigma_TEMPO_filter_rebin'][:]

# Common mask: a cell is hidden in every panel if it is NaN in the filtered
# TEMPO, OMPS, or rebinned TEMPO field.
inds = np.isnan(sigma_TEMPO_filter) | np.isnan(sigma_NPP) | np.isnan(sigma_TEMPO_filter_rebin)

ncol = 3
nrow = 1
# Start map
fig, axes = plt.subplots(nrows=nrow, ncols=ncol, figsize=(13, 4), subplot_kw=dict(projection=ccrs.PlateCarree()))

# Custom colormap from the color table file, discretized over [cmin, cmax].
cmin = 0
cmax = 0.5
cdata = np.loadtxt('022_Hue_Sat_Value_2.dat', delimiter=',')
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", cdata)
bounds = np.linspace(cmin, cmax, 255)
norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N)

# One panel per field, in the same order as panel_header:
# a) OMPS, b) TEMPO without filter, c) TEMPO binned to OMPS resolution.
panel_data = (sigma_NPP, sigma_TEMPO_nofilter, sigma_TEMPO_filter_rebin)
for icol, (data, title) in enumerate(zip(panel_data, panel_header)):
    make_map(axes, icol, nrow, ncol)
    axes[icol].set_title(title, fontsize=14, loc='left')
    Zm = np.ma.array(data, mask=inds)
    im = axes[icol].pcolormesh(Lon_edge, Lat_edge, Zm, cmap=cmap, norm=norm)
    axes[icol].set_extent([min_lon, max_lon, min_lat, max_lat])

# All panels share `norm`, so the last mesh drives a single shared colorbar.
cax = fig.add_axes([0.33, 0.15, 0.33, 0.025])
cb = fig.colorbar(im, cax=cax, orientation='horizontal')
cb.locator = ticker.MaxNLocator(nbins=5)
cb.update_ticks()
cb.set_label('SO2 SCD Standard Deviation (DU)', fontsize=12)

fig.savefig('FigureS1_TEMPO_vs_OMPS_SO2_noise_' + period + '.png')
plt.close(fig)

