# Make figures of the mean SO2 VCDs (calculated by UTC time) over a
# specified region: a map (Figure 2a) and a histogram (Figure 2b).

from matplotlib.collections import PolyCollection
from matplotlib.colors import LinearSegmentedColormap
import sys
sys.path.insert(0,'../readers/')
from read_h5_TEMPOSO2 import read_h5_TEMPOSO2_TRU
from read_CF import read_GEOS_CF_MET
#import earthaccess
import cartopy.io.shapereader as shpreader
import re 
from matplotlib.colors import LogNorm
import matplotlib.colors
import h5py
import numpy as np
import glob
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cartopy.feature as cfeature
import cartopy.crs as ccrs
from pylab import *
from matplotlib import ticker
from scipy import stats


def plot_poly(axes, lon_corns, lat_corns, data, cmap, norm):
    """Draw one filled quadrilateral per pixel on *axes*, coloured by *data*.

    Parameters
    ----------
    axes : matplotlib Axes (or cartopy GeoAxes)
        Axes the polygons are added to.
    lon_corns, lat_corns : 2-D arrays indexed ``[pixel, corner]``
        Corner longitudes/latitudes; only corners 0-3 (the first four
        columns) are used.
    data : 1-D array, one value per pixel
        Colour value of each pixel. Pixels whose value is NaN are skipped.
    cmap, norm :
        Colormap and normalisation applied to *data*.

    Returns
    -------
    matplotlib.collections.PolyCollection
        The collection added to *axes* (usable as a colorbar mappable).
    """
    # (npix, 4, 2) array: for every pixel, its four (lon, lat) vertices.
    verts = np.stack((lon_corns[:, :4], lat_corns[:, :4]), axis=-1)

    # Skip pixels without a valid value.
    keep = ~np.isnan(data)
    plot_data = data[keep]

    # array/cmap/norm are set here once; re-setting them afterwards is redundant.
    coll = PolyCollection(verts[keep], array=plot_data, cmap=cmap, norm=norm,
                          linewidth=0, antialiased=False, rasterized=True)
    axes.add_collection(coll)
    return coll




# Making the base map
def make_map(axes, *, draw_roads=False):
    """Decorate a map axes with ticks and Natural Earth outlines.

    Adds latitude/longitude ticks every 15 degrees, and draws country
    borders, lake outlines and land outlines from the 10m Natural Earth
    layers in black.

    Parameters
    ----------
    axes : cartopy GeoAxes
        Axes to decorate.
    draw_roads : bool, keyword-only, default False
        If True, also draw US Interstate major highways and non-US
        expressways from the Natural Earth 10m ``roads`` layer. Off by
        default, so no roads are drawn and the large roads shapefile is not
        read at all.

    Returns
    -------
    (fig, axes)
        The figure that owns *axes*, and *axes* itself.
    """
    # Tick marks every 15 degrees of latitude and longitude.
    grid_ticks = 15
    axes.set_yticks(np.arange(-90, 110, grid_ticks))
    axes.set_xticks(np.arange(-180, 200, grid_ticks))

    # Country borders, lake outlines and land outlines (10m scale, black).
    country = cfeature.NaturalEarthFeature(category='cultural', name='admin_0_boundary_lines_land',
                                           scale='10m', facecolor='none')
    axes.add_feature(country, edgecolor='k')
    lakes = cfeature.NaturalEarthFeature(category='physical', name='lakes', scale='10m',
                                         facecolor='None', lw=0.3)
    axes.add_feature(lakes, edgecolor='k')
    land = cfeature.NaturalEarthFeature(category='physical', name='land', scale='10m',
                                        facecolor='None', lw=0.3)
    axes.add_feature(land, edgecolor='k')

    if draw_roads:
        _add_major_roads(axes)

    # Use the figure that owns these axes instead of a module-level `fig`.
    return axes.figure, axes


def _add_major_roads(axes):
    """Draw US Interstate major highways and non-US expressways on *axes*."""
    shpfilename = shpreader.natural_earth(resolution='10m', category='cultural', name='roads')
    for road in shpreader.Reader(shpfilename).records():
        attrs = road.attributes
        is_us_interstate = (attrs['sov_a3'] == 'USA'
                            and attrs['level'] == 'Interstate'
                            and attrs['type'] == 'Major Highway')
        is_foreign_expressway = attrs['sov_a3'] != 'USA' and attrs['expressway'] == 1
        if is_us_interstate or is_foreign_expressway:
            axes.add_geometries([road.geometry], ccrs.PlateCarree(), edgecolor='k',
                                facecolor='None', lw=0.3, alpha=0.9)


# Here is the main plotting code.

# Panel labels for Figure 2a (map) and Figure 2b (histogram).
panel_header = np.array(['(a)','(b)'])

# Tag appended to the output figure file names.
period = '20240515_S010'


# Input HDF5 file holding the grid coordinates and the mean SO2 SCD/VCD fields.
infile = 'TEMPOSO2_VCD_20240513-20240616_V3_RMS_fit0.08_CRF0.3_AMF0.0_S005_S012.h5'
# The context manager closes the file even if one of the reads fails.
with h5py.File(infile, 'r') as f:
    Lat_center = f['/Latitude_Center'][:]
    Lon_center = f['/Longitude_Center'][:]
    Lat_edge = f['/Latitude_Edge'][:]
    Lon_edge = f['/Longitude_Edge'][:]
    SCD_mean = f['/SCD_mean'][:]
    VCD_mean = f['/VCD_mean'][:]

# Cells where either mean holds a large negative fill value (< -1e29) are
# treated as missing (True = missing).
inds = (SCD_mean < -1e29) | (VCD_mean < -1e29)



# Figure 2a: map of the mean SO2 VCD on a single PlateCarree panel.
ncol = 1
nrow = 1
fig, axes = plt.subplots(nrows=nrow, ncols=ncol, figsize=(8, 6),
                         subplot_kw=dict(projection=ccrs.PlateCarree()))

# Panel a), whole domain. Colour-scale limits in DU.
cmin = -0.5
cmax = 0.5
# Colormap built from the colour table file (one colour per row --
# presumably RGB in 0-1; confirm against the .dat file).
cdata = np.loadtxt('022_Hue_Sat_Value_2.dat', delimiter=',')
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", cdata)
# 255 boundaries -> 254 discrete colour bins between cmin and cmax.
bounds = np.linspace(cmin, cmax, 255)
# Use the explicitly imported `matplotlib.colors` rather than the `mpl`
# alias that only exists as a side effect of `from pylab import *`.
norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N)

# Plotting boundary (degrees)
min_lon = -138
min_lat = 15
max_lon = -48
max_lat = 55
make_map(axes)

# Panel label
axes.set_title(panel_header[0], fontsize=14, loc='left')

# Mask the fill-value cells so they are left blank on the map.
data = VCD_mean
Zm = np.ma.array(data, mask=inds)
im = axes.pcolormesh(Lon_edge, Lat_edge, Zm, cmap=cmap, norm=norm,
                     transform=ccrs.PlateCarree())
axes.set_extent([min_lon, max_lon, min_lat, max_lat])

# Horizontal colorbar in its own axes below the map, with at most ~10 ticks.
cax = fig.add_axes([0.25, 0.15, 0.5, 0.02])
cb = fig.colorbar(im, cax=cax, orientation='horizontal')
cb.locator = ticker.MaxNLocator(nbins=10)
cb.update_ticks()
cb.set_label('SO2 VCD (DU)', fontsize=14)

fig.savefig('Figure2a_TEMPO_mean_SO2_VCD_' + period + '.png')
plt.close(fig)

# Make histogram plot (Figure 2b) of the same VCD field shown on the map.

# Apply the same validity mask as the map (fill value in SCD or VCD) and
# drop non-finite values, so the histogram and the percentages describe
# exactly the cells that were plotted.
valid = data[~inds & np.isfinite(data)]

# 0.01 DU bins spanning the colour-bar range [cmin, cmax]. linspace includes
# both end points; np.arange(-0.5, 0.5, 0.01) stops at 0.49 and silently
# drops the last bin.
bin_width = 0.01
nbins = int(round((cmax - cmin) / bin_width))
bin_edges = np.linspace(cmin, cmax, nbins + 1)

fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
axes.hist(valid, bins=bin_edges, edgecolor='black')
axes.set_title(panel_header[1], fontsize=14, loc='left')
axes.set_xlabel('SO2 VCD (DU)')
axes.set_ylabel('Number of Grid Cells')

# Share of valid cells whose |VCD| is below each threshold.
count_all = valid.size
count_p2 = np.count_nonzero(np.abs(valid) < 0.2)
count_p1 = np.count_nonzero(np.abs(valid) < 0.1)
count_p05 = np.count_nonzero(np.abs(valid) < 0.05)
count_p02 = np.count_nonzero(np.abs(valid) < 0.02)
# Guard against an empty field: report NaN instead of dividing by zero.
to_percent = 100.0 / count_all if count_all else float('nan')
perc_p2 = count_p2 * to_percent
perc_p1 = count_p1 * to_percent
perc_p05 = count_p05 * to_percent
perc_p02 = count_p02 * to_percent

# Place the summary in axes-fraction coordinates so it stays inside the
# panel regardless of the histogram's height.
summary = [(0.2, perc_p2), (0.1, perc_p1), (0.05, perc_p05), (0.02, perc_p02)]
for row, (limit, perc) in enumerate(summary):
    axes.text(0.02, 0.95 - 0.06 * row,
              '-{0} < VCD < {0}: {1:.2f}%'.format(limit, perc),
              transform=axes.transAxes, fontsize=12, va='top')

fig.savefig('Figure2b_TEMPO_SO2_VCD_histogram_' + period + '.png')
plt.close(fig)
