# Make a multi-panel figure of mean TEMPO SO2 VCDs (one panel per UTC hour)
# over a specified region.
# Optionally overlay GEOS-CF surface wind (that code is currently commented out).

from matplotlib.collections import PolyCollection
from matplotlib.colors import LinearSegmentedColormap
import sys
sys.path.insert(0,'../readers/')
from read_h5_TEMPOSO2 import read_h5_TEMPOSO2_TRU
from read_CF import read_GEOS_CF_MET
#import earthaccess
import cartopy.io.shapereader as shpreader
import re 
from matplotlib.colors import LogNorm
import matplotlib.colors
import h5py
import numpy as np
import glob
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cartopy.feature as cfeature
import cartopy.crs as ccrs
from pylab import *
from matplotlib import ticker


def plot_poly(axes,lon_corns,lat_corns,data,cmap,norm):
    """Draw each pixel as a filled quadrilateral on ``axes``.

    Parameters
    ----------
    axes : Axes to draw on (receives the collection via ``add_collection``).
    lon_corns, lat_corns : 2-D arrays indexed [pixel, corner]; only corner
        columns 0-3 are used.
    data : 1-D array with one value per pixel; NaN pixels are not drawn.
    cmap, norm : colormap and normalization used to colour the polygons.

    Returns
    -------
    The PolyCollection added to ``axes`` (usable as a colorbar mappable).
    """
    corner_cols = [0, 1, 2, 3]
    # (n_pixel, 4, 2) array: for every pixel, its four (lon, lat) corners.
    quads = np.stack((lon_corns[:, corner_cols], lat_corns[:, corner_cols]), axis=-1)

    # Keep only pixels that carry a value.
    finite = ~np.isnan(data)
    values = data[finite]

    colouring = dict(array=values, cmap=cmap, norm=norm)
    coll = PolyCollection(quads[finite], linewidth=0, antialiased=False,
                          rasterized=True, **colouring)
    axes.add_collection(coll)
    # Re-apply the colour mapping after the collection is attached to the axes.
    coll.set(**colouring)
    return coll




    #Making base map
def make_map(axes,irow,icol, nrow,ncol, grid_ticks=1, draw_roads=False):
    """Decorate panel ``axes[irow, icol]`` with tick marks and map outlines.

    Parameters
    ----------
    axes : 2-D array of cartopy GeoAxes (as returned by ``plt.subplots``).
    irow, icol : row/column of the panel to decorate.
    nrow, ncol : panel-grid size. Latitude ticks go only on the first column
        and longitude ticks only on the last row, so labels are not repeated.
    grid_ticks : tick spacing in degrees (default 1, as before).
    draw_roads : if True, overlay US interstate major highways and non-US
        expressways from the Natural Earth 1:10m roads layer. The default
        False gives the same picture as before, where road drawing was
        disabled, but no longer reads the roads shapefile for nothing.

    Returns
    -------
    (fig, axes) : the figure containing the panel, and ``axes`` unchanged.
    """
    ax = axes[irow, icol]

    # Tick positions over the valid lon/lat range only. The caller is
    # expected to restrict the view afterwards (e.g. with set_extent),
    # because setting ticks may widen the view limits.
    if icol == 0:
        ax.set_yticks(np.arange(-90, 90 + grid_ticks, grid_ticks))
    if irow == nrow-1:
        ax.set_xticks(np.arange(-180, 180 + grid_ticks, grid_ticks))

    # Country outlines, lake outlines and coastlines (land-polygon edges), 1:10m
    country = cfeature.NaturalEarthFeature(category='cultural', name='admin_0_boundary_lines_land', scale='10m', facecolor='none')
    ax.add_feature(country,edgecolor='k')
    lakes = cfeature.NaturalEarthFeature(category='physical',name='lakes',scale='10m',facecolor='None',lw=0.3)
    ax.add_feature(lakes,edgecolor='k')
    land = cfeature.NaturalEarthFeature(category='physical',name='land',scale='10m',facecolor='None',lw=0.3)
    ax.add_feature(land,edgecolor='k')

    # Optional highways/roads overlay
    if draw_roads:
        shpfilename = shpreader.natural_earth(resolution='10m',category='cultural',name='roads')
        for road in shpreader.Reader(shpfilename).records():
            attrs = road.attributes
            is_us_interstate = (attrs['sov_a3'] == 'USA'
                                and attrs['level'] == 'Interstate'
                                and attrs['type'] == 'Major Highway')
            is_foreign_expressway = attrs['sov_a3'] != 'USA' and attrs['expressway'] == 1
            if is_us_interstate or is_foreign_expressway:
                ax.add_geometries([road.geometry],ccrs.PlateCarree(),edgecolor='k',facecolor='None',lw=0.3,alpha=0.9)

    # Return the panel's own figure rather than relying on a module-level `fig`.
    return ax.figure, axes


# Here is the main plotting code
inpath_TEMPO = './byhour/'
hour_all = np.array(['T14','T15','T16','T17','T18','T19','T20','T21','T22'])
panel_header = np.array(['(a) 14UTC','(b) 15UTC','(c) 16UTC','(d) 17UTC','(e) 18UTC',
    '(f) 19UTC','(g) 20UTC','(h) 21UTC','(i) 22UTC'])

period = '20240513_20240616'

#region = 'Houston'; min_lon=-100; min_lat=28; max_lon=-92; max_lat=35
#region = 'NEUS'; min_lon=-80; min_lat=36; max_lon=-71; max_lat=42
#region = 'CONUS'; min_lon=-140; min_lat=15; max_lon=-45; max_lat=60
region = 'Cantarell'; min_lon=-94; min_lat=17; max_lon=-90; max_lat=21
x_box = [-92.4,-91.9,-91.9,-92.4,-92.4]
y_box = [19.25,19.25,19.75,19.75,19.25]
x_box1 = [-93.4,-92.9,-92.9,-93.4,-93.4]
y_box1 = [17.7,17.7,18.2,18.2,17.7]

ncol =3    
nrow =3    
# Start map
fig, axes = plt.subplots(nrows=nrow,ncols=ncol,figsize=(13.5,12),subplot_kw=dict(projection=ccrs.PlateCarree()))

#Making custom colormap for given colors, cmin/cmax range
cmin = -2
cmax = 2
cdata = np.loadtxt('022_Hue_Sat_Value_2.dat',delimiter=',')
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("",cdata)
bounds = np.linspace(cmin,cmax,255)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

k = 0
for irow in range(0,nrow):
    for icol in range(0,ncol):
        make_map(axes,irow,icol,nrow,ncol)
        hour = hour_all[k]
        infiles = glob.glob(inpath_TEMPO + '*'+hour+'*h5')
        
        print(infiles)
        
        # Loop through files
        for infile in infiles:
            print('TEMPO SO2 file: ', infile)
            # Time for the measurements 
            axes[irow,icol].set_title(panel_header[k], fontsize=14,loc='left')
            # Find the corresponding file for 
#            infiles_CF = glob.glob('./'+'GEOS-CF.v01.rpl.met_tavg_1hr_g1440x721_v36.'+date+
#                    '_'+HH+'30z'+'*.nc4')
#            if not infiles_CF:
#                print('Cannot find matching CF MET file')
#                sys.exit()
        
            # Read TEMPO SO2 data
            f = h5py.File(infile,'r')
            Lat_center = f['/Lat_Center'][:]
            Lon_center = f['/Lon_Center'][:]
            Lat_edge = f['/Lat_Edge'][:]
            Lon_edge = f['/Lon_Edge'][:]
            SCD_mean = f['/SCD_mean'][:]/2.69E16
            VCD_mean = f['/VCD_mean'][:]
            data_count = f['/data_count'][:]
            f.close()
        
            data = VCD_mean
            #idx_x = np.squeeze(np.where((Lon_center > np.min(x_box)) & (Lon_center < np.max(x_box))))
            #idx_y = np.squeeze(np.where((Lat_center > np.min(y_box)) & (Lat_center < np.max(y_box))))
            SCD_mean_box = np.mean(SCD_mean[45:54,352:362])
            VCD_mean_box = np.mean(VCD_mean[45:54,352:362])
            print(hour, ', SCD_mean_box: ', SCD_mean_box, ', VCD_mean_box: ', VCD_mean_box)
            Zm = np.ma.array(data,mask=np.isnan(data))
            im = axes[irow,icol].pcolormesh(Lon_edge, Lat_edge, Zm, cmap=cmap, norm=norm)
            axes[irow,icol].plot(x_box,y_box,color='gray')
            #axes[irow,icol].plot(x_box1,y_box1,color='gray')
        
#            # Read the GEOS-CF Met file
#            print('GEOS-CF MET file: ', infiles_CF[0])
#            GEOS = read_GEOS_CF_MET(infiles_CF[0])
#            print(GEOS.U.shape, GEOS.V.shape,GEOS.Lat.shape, GEOS.Lon.shape, GEOS.SLP.shape, GEOS.DELP.shape)
#            nlayer = GEOS.U.shape[1]
#            # Calculate the pressure at each layer
#            #for ind in range(0,len(G5_layers)-1):
#            print('mean and sigma of height:',
#                    np.mean(GEOS.mid_layer_heights[:,ilayer,:,:])
#                    ,np.std(GEOS.mid_layer_heights[:,ilayer,:,:]))
#        
#            indy =  (GEOS.Lat > min_lat) & (GEOS.Lat < max_lat)
#            indx =  (GEOS.Lon > min_lon) & (GEOS.Lon < max_lat)
#            Lat_G5 = GEOS.Lat[indy]
#            Lon_G5 = GEOS.Lon[indx]
#            U_layer = np.squeeze(GEOS.U[:,ilayer,:,:])
#            V_layer = np.squeeze(GEOS.V[:,ilayer,:,:])
#    
#            # Subset UV data
#            U = U_layer[np.ix_(indy,indx)]
#            V = V_layer[np.ix_(indy,indx)]
#    
#    
#    
#            #Plotting TEMPO SO2 polygons on map 
#            im = plot_poly(axes[irow,icol],
#                    loncnr[tempo_inds,:],latcnr[tempo_inds,:],SO2_TRU[tempo_inds],cmap,norm)
#    
#            #Add wind vectors
#    
#            #Q = axes.quiver(GEOS.Lon[indx],GEOS.Lat[indy],U[indy,indx],V[indy,indx],scale=300)
#            Q = axes[irow,icol].quiver(Lon_G5[::4],Lat_G5[::4],U[::4,::4],V[::4,::4],scale=1000)
#            qk =axes[irow,icol].quiverkey(Q, 0.9, 1.02, 30, r'$30 \frac{m}{s}$', labelpos='N',coordinates='axes',fontproperties={'size': 12})
#
            axes[irow,icol].set_extent([min_lon,max_lon,min_lat,max_lat])
            # End of Loop
            k = k+1
        
#cb = fig.colorbar(im,fraction=0.025)
cb = fig.colorbar(im,ax = axes[:,:],location='right',fraction=0.015)
tick_locator = ticker.MaxNLocator(nbins=10)
cb.locator = tick_locator
cb.update_ticks()
cb.set_label('SO2 VCD (DU)',fontsize=14)
#plt.tight_layout
#plt.suptitle(str(year)+'m'+str(month).zfill(2)+str(day).zfill(2)+' S011',fontsize=20)

plt.savefig( 'TEMPOSO2_mean_VCD_by_hour'+period+'_'+region+'_Cantarell.png')
plt.clf()
plt.close()

#    #earthaccess.download(results[fname_id],tempo_file_path)

##Grabbing TEMPO NO2 files after downloaded locally 
#files = glob.glob(tempo_file_path+'*'+str(year)+str(month).zfill(2)+str(day).zfill(2)+'*')


#    tempo_inds = (lat > min_lat) & (lat < max_lat) & (lon < max_lon) & (lon > min_lon) & (gpqf & 15  == 1)
    #tempo_inds = (lat > min_lat) & (lat < max_lat) & (lon < max_lon) & (lon > min_lon) & (flag_saturation == 0)

    #cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["white","lightsteelblue","orange","red","firebrick","maroon"])
    #cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["white","whitesmoke","orange","red","firebrick","maroon"])
    #cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["white","whitesmoke","lightsteelblue","orange","red"])
    #cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["white","plum","mediumslateblue","lightcyan","lime","yellow","red"])
    #cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["lightgrey","cyan","yellow","orange","red"])
    #bounds = np.linspace(cmin,cmax,21)
    #norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
