#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 12 15:05:45 2020

@author: nicky
"""


import matplotlib as mpl
import netCDF4 as nc
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cm
import numpy as np
#import cmocean
import matplotlib.gridspec as gridspec
from scipy.interpolate import griddata

#%%
# Restore Matplotlib's built-in default rc parameters before any figure is
# created (presumably so local matplotlibrc/style settings do not change the
# look of the figures produced below).
mpl.rcdefaults()

#%%
# =============================================================================
# Importing DATA via OPENDAP
#
#Please verify the file paths
# =============================================================================
# OPeNDAP endpoint of the model output; opened with netCDF4 further below.
file = (
    'https://thredds-iow.io-warnemuende.de/thredds/dodsC/Baltic/'
    'IOW-THREDDS-Baltic_Burchard_etal_2020_2020-11-11-10.nc'
)

# =============================================================================
# slicing parameters
# =============================================================================
# Half-open index windows applied when reading variables as
# [t:tend, :, jstart:jend, iright:iend].
t, tend = 0, None          # first axis (None = read to the end)
iright, iend = 1, 201      # last axis
jstart, jend = 1, 31       # second-to-last axis

# =============================================================================
# Please check the variable names in the file with ncdump
# =============================================================================

with nc.Dataset(file, 'r') as ncdata:
    fields = ncdata.variables

    # Index windows shared by the reads below.
    box4 = (slice(t, tend), slice(None), slice(jstart, jend), slice(iright, iend))
    box2 = (slice(jstart, jend), slice(iright, iend))

    # Every variable is converted to float64 right after it is read.
    print('Loading time...')
    time = np.float64(fields['time'][t:tend])

    print('Loading salt...')
    salt = np.float64(fields['salt'][box4])

    print('Loading zc...')
    zc = np.float64(fields['zc'][box4])

    # Note: the dataset variable is called 'hn'; the script refers to it as h.
    print('Loading h...')
    h = np.float64(fields['hn'][box4])

    print('Loading bathymetry...')
    bathymetry = np.float64(fields['bathymetry'][box2])

    # Not loaded here (reads were disabled in this script): uu, U, xic, etac,
    # Sfluxu, Sfluxu2, nummix_salt, phymix_salt -- presumably present in the
    # file as well; verify with ncdump before enabling.
#%%   
# =============================================================================
# TEF DATA (sorted into Salinity classes S)
# =============================================================================
with nc.Dataset(file, 'r') as ncdata:
    tef_vars = ncdata.variables

    # Same [t:tend, :, jstart:jend, iright:iend] window for every 4-D read.
    tef_box = (slice(t, tend), slice(None), slice(jstart, jend), slice(iright, iend))

    print('Loading h_s...')
    h_s = np.float64(tef_vars['h_s'][tef_box])

    # Salinity-class axis, read in full (no progress message for this one).
    salt_s = np.float64(tef_vars['salt_s'][:])

    print('Loading hpmS_s...')
    hpmS_s = np.float64(tef_vars['hpmS_s'][tef_box])  # phymixS content (bin)

    print('Loading hnmS_s...')
    hnmS_s = np.float64(tef_vars['hnmS_s'][tef_box])  # nummixS content (bin)

    print('Loading flags_s...')
    flags_s = np.float64(tef_vars['flags_s'][tef_box])

    # Not loaded here (reads were disabled in this script): uu_s, hS_s,
    # hS2_s, Sfluxu_s, S2fluxu_s.

#%%
# =============================================================================
# Grid related data
# =============================================================================
with nc.Dataset(file, 'r') as ncdata:
    grid_vars = ncdata.variables
    # Horizontal window [jstart:jend, iright:iend] used for all grid fields.
    horiz = (slice(jstart, jend), slice(iright, iend))

    print('Loading dx...')
    dx = np.float64(grid_vars['dxc'][horiz])

    print('Loading dy...')
    dy = np.float64(grid_vars['dyc'][horiz])

    print('Loading areaC...')
    areaC = np.float64(grid_vars['areaC'][horiz])

#%%
# =============================================================================
# Corner points of each grid cell (only for plotting purposes)
#
# Please download the numpy files and verify the paths
# =============================================================================
# NOTE(security): allow_pickle=True lets np.load unpickle arbitrary Python
# objects, which can execute code. Only load these .npy files from a trusted
# source.
# NOTE(review): if the files contain plain numeric arrays, allow_pickle is not
# needed and could be dropped -- confirm against the actual files first.
x = np.load('x-coordinate.npy',allow_pickle=True) 

y = np.load('y-coordinate.npy',allow_pickle=True) 
#%%

# Width of one salinity class, taken from the first two entries of salt_s.
# NOTE(review): this assumes salt_s is uniformly spaced -- confirm.
ds = salt_s[1] - salt_s[0]

# Horizontal cell area, used as the integration weight throughout the script.
dA = areaC
#%%
# =============================================================================
# Computation of flags
# =============================================================================
print('Computation of flags...')
vol_test = h_s*dA

# fl is 1 everywhere except where the cell volume in a salinity class is
# exactly zero, where it is 0.  This replaces a four-deep Python loop (with a
# hard-coded range(30) for the third axis) by a single vectorized assignment.
# np.ma.filled(..., False) keeps masked entries at 1, matching the element-wise
# test in which a masked comparison is never true.
fl = np.ones_like(flags_s)
fl[np.ma.filled(vol_test == 0, False)] = 0
#%%
# =============================================================================
# Temporal averaging of Mixing variables and layer height in the Salinity classes
# =============================================================================
print('Temporal averaging...')
# Average each salinity-class field over its first axis (the one sliced with
# t:tend when the data were read).
hpmS_s_mean = hpmS_s.mean(axis=0)  # physical mixing
hnmS_s_mean = hnmS_s.mean(axis=0)  # numerical mixing
h_s_mean = h_s.mean(axis=0)        # layer height
#%%
# =============================================================================
# Integration along X and Y; Division by ds to get variables per salinity class
# =============================================================================
print('Integration along X and Y...')
# Area-weighted sum over both horizontal axes (axis 1 twice, because the first
# reduction removes one axis), divided by ds to obtain values per salinity
# class.  The former np.zeros preallocations were dead code: each name was
# rebound to the result immediately afterwards.
mms_phy = np.sum(np.sum(hpmS_s_mean*dA,axis=1),axis=1)/ds
mms_num = np.sum(np.sum(hnmS_s_mean*dA,axis=1),axis=1)/ds
vol_s = np.sum(np.sum(h_s_mean*dA,axis=1),axis=1)/ds

# Total mixing per salinity class = physical + numerical contribution.
mms_total = mms_phy+mms_num

#%%
# =============================================================================
# Preparation for figure 9 (Mixing distribution)
# =============================================================================
print('Preparation for Fig 9...')
# Weight by dA/dx and sum over axis 1 only, so the last axis is kept; the
# result is per salinity class (division by ds).
mms_phy_x = (hpmS_s_mean * dA / dx).sum(axis=1) / ds
mms_num_x = (hnmS_s_mean * dA / dx).sum(axis=1) / ds
vol_s_x = (h_s_mean * dA / dx).sum(axis=1) / ds

# physical + numerical contribution
mms_total_x = mms_phy_x + mms_num_x


#%%
# =============================================================================
# Integration along Y and division by ds
# =============================================================================
print('Integration along Y...')
# Flagged cell area: mean over the first axis, then summed over axis 1.
a_x = np.sum(np.mean(fl * areaC, axis=0), axis=1)

# Same reduction over axis 1 for the area-weighted fields, per salinity class.
mms_phy_x1 = (hpmS_s_mean * dA).sum(axis=1) / ds
mms_num_x1 = (hnmS_s_mean * dA).sum(axis=1) / ds
vol_s_x1 = (h_s_mean * dA).sum(axis=1) / ds

#%%
# =============================================================================
# Computation of diffusivities along X-direction
# =============================================================================
print('Computation diffusivities...')
# Output arrays take their shape from a_x, because they are indexed with
# positions found in a_x below.
K_phy_x = np.zeros(a_x.shape, dtype=np.float64)
K_num_x = np.zeros(a_x.shape, dtype=np.float64)

# K = 0.5 * m * v / a**2, evaluated only where the area is positive so that
# no division by zero occurs; everywhere else K stays 0.
ids = np.where(a_x > 0)
K_phy_x[ids] = 0.5*mms_phy_x1[ids]*vol_s_x1[ids]/(a_x[ids]**2)
K_num_x[ids] = 0.5*mms_num_x1[ids]*vol_s_x1[ids]/(a_x[ids]**2)

# Volume-weighted mean along the last axis.
# NOTE(review): rows whose vol_s_x sums to zero give 0/0 = NaN here.
K_phy_x_mean = np.sum(K_phy_x*vol_s_x,axis=1)/np.sum(vol_s_x,axis=1)
K_num_x_mean = np.sum(K_num_x*vol_s_x,axis=1)/np.sum(vol_s_x,axis=1)
K_total_x_mean = K_phy_x_mean + K_num_x_mean

K_total_x = K_phy_x+K_num_x

# Masked versions used for plotting.
vol_s_x_mas = np.ma.masked_where(vol_s_x<10,vol_s_x)
mms_total_x_mas = np.ma.masked_where(mms_total_x==0,mms_total_x)
K_total_x_mas = np.ma.masked_where(K_total_x<=0,K_total_x)
# Comparing with np.nan is always False (NaN != NaN), so the former
# masked_where(... == np.nan, ...) never masked anything.  masked_invalid
# masks NaN/inf and keeps the mask already present on K_total_x_mas.
K_total_x_mas2 = np.ma.masked_invalid(K_total_x_mas)
# NOTE(review): this rebinds K_total_x_mean (the volume-weighted mean computed
# above) to a plain sum of K_total_x along the last axis.  Kept as in the
# original script; confirm which of the two definitions is intended.
K_total_x_mean = np.ma.sum(K_total_x,axis=1)
#%%
# =============================================================================
# No Integration (variable distribution in [S,y,x])
# =============================================================================

# Per-cell quantities per salinity class: area-weighted, divided by ds,
# no summation over any axis.
mms_phy_x_y = hpmS_s_mean * dA / ds
mms_num_x_y = hnmS_s_mean * dA / ds
vol_s_x_y = h_s_mean * dA / ds

# Flagged cell area averaged over the first axis.
a_x_y = (fl * areaC).mean(axis=0)

K_phy_x_y = np.zeros_like(mms_phy_x_y, dtype=np.float64)
K_num_x_y = np.zeros_like(mms_num_x_y, dtype=np.float64)

# K = 0.5 * m * v / a**2 wherever the area is positive; 0 elsewhere.
ids = np.where(a_x_y > 0)
area_squared = a_x_y[ids] ** 2
K_phy_x_y[ids] = 0.5 * mms_phy_x_y[ids] * vol_s_x_y[ids] / area_squared
K_num_x_y[ids] = 0.5 * mms_num_x_y[ids] * vol_s_x_y[ids] / area_squared

K_total_x_y = K_phy_x_y + K_num_x_y

# =============================================================================
# Masking out negative values caused by the anti-diffusive advection scheme
# that occurred on the coarse grid (zeros are masked as well)
# =============================================================================
K_total_x_y_mas = np.ma.masked_where(K_total_x_y <= 0, K_total_x_y)


#%%
# =============================================================================
# Mixing integrated along S
# =============================================================================

mms_total_onlyx = mms_total_x.sum(axis=0)


#%%
# =============================================================================
# Diahaline diffusivity
# =============================================================================
print('Diahaline diffusivity...')

# Flagged cell area: mean over the first axis, then summed over both
# remaining horizontal axes, leaving one value per salinity class.
a = np.mean(fl * areaC, axis=0)
a = np.sum(np.sum(a, axis=1), axis=1)

# Mask empty classes so they do not enter the division below.
V = np.ma.masked_where(vol_s == 0, vol_s)
aa = np.ma.masked_where(a == 0, a)

n_classes = len(salt_s)
K_phy = np.zeros(n_classes)
K_num = np.zeros(n_classes)

# K = 0.5 * m * v / a**2, evaluated class by class.
for idx in range(n_classes):
    volume = V[idx]
    area_squared = aa[idx] ** 2
    K_phy[idx] = (0.5 * mms_phy[idx] * volume) / area_squared
    K_num[idx] = (0.5 * mms_num[idx] * volume) / area_squared

K_total = np.ma.array(K_phy + K_num)

# Constant used in the "theoretical mixing" curve 2*S*Q_r of the figures.
Q_r = 700

#%%
# =============================================================================
# Integrated Mixing for salinities < S
# =============================================================================
print('Integrated Mixing...')
# Entry i holds ds * sum(m[:i]), i.e. the mixing summed over all classes
# below class i (entry 0 is therefore 0).  A running sum replaces the former
# loop that re-summed the prefix for every i.  np.ma.getdata returns the
# underlying values for both plain and masked arrays, instead of relying on
# the `.data` attribute, which is a raw memoryview on a plain ndarray.
n_classes = len(salt_s)

mms_phyInt = np.zeros(n_classes)
mms_phyInt[1:] = np.cumsum(np.ma.getdata(mms_phy)[:n_classes - 1]) * ds

mms_numInt = np.zeros(n_classes)
mms_numInt[1:] = np.cumsum(np.ma.getdata(mms_num)[:n_classes - 1]) * ds

mms_totalInt = mms_phyInt+mms_numInt

#%%
# =============================================================================
# PLOTS!!!!!!!!!!!!!!!!!!
# =============================================================================
print('Plotting...')
plt.ion()
print('Fig. 8...')
plt.close('all')
xline =22  # salinity at which a dashed vertical marker is drawn in each panel
fig, ax = plt.subplots(2,2)
fig.tight_layout()
fig.set_size_inches(14,10)

# --- a: mixing per salinity class (the last 9 classes are left out) ---------
# Mathtext labels are raw strings so that "\m..." is not read as a (invalid)
# Python escape sequence.
ax[0,0].plot(salt_s[:-9],mms_total[:-9], label = "total mixing $m$")
ax[0,0].plot(salt_s[:-9],mms_phy[:-9], label = r"physical mixing $m^{\mathrm{phy}}$")
ax[0,0].plot(salt_s[:-9],mms_num[:-9], label = r"numerical mixing $m^{\mathrm{num}}$")
ax[0,0].plot(salt_s[:-9],2*salt_s[:-9]*Q_r, label='theoretical mixing $2SQ_{r}$')
ax[0,0].axvline(x=xline,color='k',linestyle='dashed')
ax[0,0].set_ylabel("$m(S)$ [m$^{3}$s$^{-1}$(g/kg)]")
ax[0,0].set_title("a: mixing per salinity class")
ax[0,0].set_xlim(0,35)
ax[0,0].legend()
ax[0,0].grid()

# --- b: isohaline area and volume per salinity class ------------------------
ax[0,1].plot(salt_s,aa, label = "area of isohaline $a(S)$")
ax[0,1].plot(salt_s,V, label = "volume per salinity class $v(S)$")
ax[0,1].axvline(x=xline,color='k',linestyle='dashed')
ax[0,1].set_ylabel("$a(S)$ [m$^{2}$] , $v(S)$ [m$^{3}$(g/kg)$^{-1}$]")
ax[0,1].set_title("b: isohaline area and volume")
ax[0,1].set_xlim(0,35)
ax[0,1].legend()
ax[0,1].grid()
ax[0,1].set_yscale('symlog')

# --- c: ratio a/v ------------------------------------------------------------
ax[1,0].plot(salt_s,aa/V, label = "b$^{-1}$=a/v(s)")
ax[1,0].axvline(x=xline,color='k',linestyle='dashed')
ax[1,0].set_xlabel("$S$ [g/kg] ")
ax[1,0].set_ylabel("$b^{-1}$ [(g/kg)m$^{-1}$]")
ax[1,0].set_title("c: salinity gradient per salinity class")
ax[1,0].set_xlim(0,35)
ax[1,0].grid()

# --- d: effective diahaline diffusivities ------------------------------------
ax[1,1].plot(salt_s[:],K_total[:], label = r"total diffusivity $\overline{K_{n}}$")
ax[1,1].plot(salt_s[:],K_phy[:], label = r"physical diffusivity $\overline{K^{\mathrm{phy}}_{n}}$")
ax[1,1].plot(salt_s[:],K_num[:], label = r"numerical diffusivity $\overline{K^{\mathrm{num}}_{n}}$")
ax[1,1].axvline(x=xline,color='k',linestyle='dashed')
ax[1,1].set_xlabel("$S$ [g/kg] ")
ax[1,1].set_ylabel(r"$\overline{K_{n}}(S)$ (m$^{2}$ s$^{-1}$)")
ax[1,1].set_title("d: effective diahaline diffusivities")
ax[1,1].set_xlim(0,35)
ax[1,1].set_ylim(0.000001,0.001)
# Scale names are case-sensitive in current Matplotlib: 'Log' raises
# ValueError, the registered name is 'log'.
ax[1,1].set_yscale('log')
ax[1,1].legend()
ax[1,1].grid(True,which='both')

fig.savefig('Figure_8.png', format='png', dpi=400)
plt.show()


#%%
print('Fig. 9...')
f = plt.figure()
f.set_size_inches(20,14)
# 2x2 layout with a narrow left column and a short top row; cell 0 (top left)
# stays empty.
gs = gridspec.GridSpec(2, 2, width_ratios=[1, 3], height_ratios=[1,3])

# --- a (bottom right): total mixing per salinity class per meter ------------
# The last 9 salinity classes and the first 7 columns are left out of the plot.
ax = plt.subplot(gs[3])
cmap = plt.get_cmap('cubehelix_r')
norm = colors.Normalize(vmin=mms_total_x_mas[:-9,7:].min(),vmax=mms_total_x_mas[:-9,7:].max())
im = ax.pcolormesh(x[15,7:]/1000, salt_s[:-9], mms_total_x_mas[:-9,7:],cmap=cmap,norm=norm)
ax.set_xlabel('x [km]')
ax.set_ylabel('$S$ [g/kg]')
ax.set_title('a: total mixing per salinity class per meter')
ax.set_xlim(-100,-60)
ax.set_ylim(0,35)

# --- b (bottom left): total mixing per salinity class, S on the y axis ------
ax1 = plt.subplot(gs[2])
ax1.plot(mms_total[:-9],salt_s[:-9], label = "total mixing $m$")
ax1.plot(2*salt_s[:-9]*Q_r,salt_s[:-9], label='theoretical mixing $2SQ_{r}$')
ax1.set_ylabel("$S$ [g/kg] ")
ax1.set_xlabel("m(S) [m$^{3}$s$^{-1}$(g/kg)]")
ax1.set_title("b: total mixing per salinity class")
ax1.set_ylim(0,35)
ax1.legend(loc='lower right')
ax1.grid()

# --- c (top right): mixing summed over all salinity classes -----------------
ax2 = plt.subplot(gs[1])
ax2.plot(x[15,7:]/1000,mms_total_onlyx[7:])
ax2.set_xlabel("x [km] ")
# Raw strings: "\p" is not a valid Python escape sequence and triggers a
# warning (an error in future Python versions) in a normal string literal.
ax2.set_ylabel(r"$\partial_x M(S)$ [m$^{2}$s$^{-1}$(g/kg)$^{2}$]")
ax2.set_title("c: total mixing per meter")
ax2.set_xlim(-100,-60)
ax2.grid()

# Make room on the right and put the colorbar of panel a into its own axes.
f.subplots_adjust(bottom=0.1, right=0.8, top=0.9,hspace = 0.25,wspace= 0.15)
cbar_ax = f.add_axes([0.85, 0.1, 0.02, 0.54])
f.colorbar(im, cax=cbar_ax)
cbar_ax.set_ylabel(r'$\partial_x m(S)$ [m$^{2}$s$^{-1}$(g/kg)]')
f.savefig('Figure_9.png', format='png', dpi=400)


#%%
print('Fig. 10...')
f, ax = plt.subplots(1,2)
f.set_size_inches(14,7)
f.tight_layout()
cmap = plt.get_cmap('cubehelix_r')

# --- a: volume per salinity class per meter (log colour scale) --------------
# The last 9 salinity classes are left out; the upper colour limit is fixed
# at 10e4 (= 1e5).
norm1 = colors.LogNorm(vmin=vol_s_x_mas[:-9,:].min(),vmax=10e4)
im1 = ax[0].pcolormesh(x[15,:]/1000, salt_s[:-9], vol_s_x_mas[:-9,:],cmap=cmap,norm=norm1)
ax[0].set_xlabel('x [km]')
ax[0].set_ylabel('$S$ [g/kg]')
ax[0].set_title('a: volume per salinity class per meter')
ax[0].set_xlim(-100,-60)
ax[0].set_ylim(0,35)
cbar = f.colorbar(im1,ax=ax[0])
# Raw strings: "\p" and "\o" are not valid Python escape sequences and trigger
# a warning (an error in future Python versions) in normal string literals.
cbar.ax.set_ylabel(r'$\partial_xv(S)$ [m$^{2}$(g/kg)$^{-1}$]')

# --- b: effective total diahaline diffusivity (log colour scale) ------------
# Lower colour limit fixed at 1e-6; the first 7 columns are left out.
norm = colors.LogNorm(vmin=1e-6,vmax=K_total_x_mas2.max())
im = ax[1].pcolormesh(x[15,7:]/1000, salt_s[:-9], K_total_x_mas2[:-9,7:],cmap=cmap,norm=norm)
ax[1].set_xlabel('x [km]')
ax[1].set_ylabel('$S$ [g/kg]')
ax[1].set_title('b: effective total diahaline diffusivity')
ax[1].set_xlim(-100,-60)
ax[1].set_ylim(0,35)
cbar = f.colorbar(im,ax=ax[1])
cbar.ax.set_ylabel(r'$\overline{K_n}$ [m$^{2}$s$^{-1}$]')

f.tight_layout()
f.savefig('Figure_10.png', format='png', dpi=600)


#%%
print('Fig. 11...')
i = 25  # salinity class shown in all three panels
f, ax = plt.subplots(3)
f.set_size_inches(15,15)
f.tight_layout()
cmap = plt.get_cmap('cubehelix_r')

# Positions along the salinity axis whose value, rounded to two decimals,
# equals i.  Looked up once and reused for all three fields.
index = np.array(np.where(np.round(salt_s,decimals=2)==i))
K_s_x_y = np.sum(K_total_x_y_mas[index[0,:],:,:],axis=0)
K_s_x_y_phy = np.sum(K_phy_x_y[index[0,:],:,:],axis=0)
K_s_x_y_num = np.sum(K_num_x_y[index[0,:],:,:],axis=0)

# One logarithmic colour scale, taken from the total diffusivity, is used for
# all panels so that the single colorbar on the right is valid for each of
# them.  (Separate per-panel norms used to be computed but were never used;
# they have been removed.)
norm = colors.LogNorm(vmin=K_s_x_y.min(),vmax=K_s_x_y.max())

# Raw strings: "\o" and "\m" are not valid Python escape sequences.
panels = (
    (K_s_x_y, r'a: $\overline{K_{n}}$  ($S$='),
    (K_s_x_y_phy, r'b: $\overline{K_{n}^{\mathrm{phy}}}$ ($S$='),
    (K_s_x_y_num, r'c: $\overline{K_{n}^{\mathrm{num}}}$ ($S$='),
)

meshes = []
for axis, (field, title_head) in zip(ax, panels):
    meshes.append(axis.pcolormesh(x[:,:]/1000, y[:,:]/1000, field[:,:],cmap=cmap,norm=norm))
    # Outline of the grid: first and last row, and first column, of the
    # corner-point arrays.
    axis.plot(x[0,:]/1000,y[0,:]/1000,color='k', linewidth=1)
    axis.plot(x[-1,:]/1000,y[-1,:]/1000,color='k', linewidth=1)
    axis.plot(x[:,0]/1000,y[:,0]/1000,color='k', linewidth=1)
    axis.set_ylabel('y [km]')
    axis.set_title(title_head+str(i)+'g/kg)')
    axis.set_xlim(-110,-60)
    axis.set_ylim(-10,10)

# Only the bottom panel carries the x label.
ax[2].set_xlabel('x [km]')
im, im1, im2 = meshes

f.subplots_adjust(left=0.1, bottom=0.05, right=0.8, top=0.95,hspace = 0.3,wspace= 0.3)
cbar_ax = f.add_axes([0.85, 0.1, 0.02, 0.80])
f.colorbar(im, cax=cbar_ax)
cbar_ax.set_ylabel(r'$\overline{K_n}$ [m$^{2}$s$^{-1}$]')
f.savefig('Figure_11.png', format='png', dpi=600)




#%%
# =============================================================================
# TEF-averaged isohaline distribution
# Please verify the file path of s_TEF.csv
# =============================================================================
print('TEF-averaged isohaline distribution...')
# Read the comma-separated file once; unpack=True returns one array per
# requested column (column 0 -> index_h, column 1 -> height, column 2 -> tef)
# instead of parsing the same file three times.
index_h, height, tef = np.loadtxt(
    's_TEF.csv', delimiter=',', usecols=[0, 1, 2], unpack=True
)

#%%
# Regular target grid for the interpolation: integer positions 0..199 on one
# axis and depth levels starting at 0 and going down in steps of 0.1.
index = np.arange(200)
depth = np.arange(0, -15.1, -0.1)

# 2-D coordinate arrays of shape (len(depth), len(index)).
i, d = np.meshgrid(index, depth, indexing='xy')

#%%

# Linear interpolation of the scattered (index_h, height) -> tef samples onto
# the regular (i, d) grid.
s_tef = griddata(
    (index_h, height),
    tef,
    (i, d),
    method='linear',
)

#%%
# zc + h/2, averaged over the first axis (presumably the upper interface of
# each layer: centre depth plus half the layer thickness -- confirm).
z_new = np.mean(zc + h / 2, axis=0)

# Salinity averaged over the first axis.
salt_mean = salt.mean(axis=0)

print('Fig. 6...')
# 2-D coordinates for the contour calls: x of grid row 15 against the depth
# levels.  A single meshgrid on `depth` replaces two calls that used dummy
# np.zeros(151) / np.zeros(len(x)) axes, so the shapes follow `depth`
# automatically; the former, unused `xplot` array has been removed.
xplot0, depthplot = np.meshgrid(x[15,:], depth)
color = 'lightgray'
f, ax = plt.subplots(1)
f.set_size_inches(15,5)
cmap = plt.get_cmap('viridis')
norm = colors.Normalize(vmin=0,vmax=35)
# NOTE(review): vmin == vmax makes every cell of the first pcolormesh map to
# the same colour; presumably intended to draw the model layers as a uniform
# background underneath the TEF-based field -- confirm.
norm1 = colors.Normalize(vmin=0,vmax=0)
ax.pcolormesh(x[15,:]/1000,z_new[:,15,:],salt_mean[:,15,:],cmap=cmap,norm=norm1)
im = ax.pcolormesh(x[15,:]/1000,depth[:],s_tef[:,:],cmap=cmap,norm=norm)
# Black isolines (35 automatically chosen levels) on all but the last depth
# row, plus a thick red line at the value 15.
c=ax.contour(xplot0[0:150,:]/1000, depthplot[0:150,:], s_tef[0:150,:],levels=35,colors='k')
ax.contour(xplot0/1000, depthplot, s_tef,levels=[15],colors='red',linewidths=3)
ax.set_ylabel("z [m] ")
ax.set_xlabel("x [km]")
ax.set_xlim(-100,-50)
ax.set_ylim(-15,2.5)
ax.set_facecolor(color)


cbar=f.colorbar(im, ax=ax)
cbar.ax.set_ylabel('salinity [g/kg]')
f.savefig('Figure_6.png', format='png', dpi=600)
