'''
Written by Steve Flanagan, 2023
contact: saf204@gmail.com
'''

import seaborn as sns
#from itertools import count
import os, shutil, sys, gc, glob, warnings
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime, date, timedelta
import matplotlib
import pandas as pd

#creates a time series from the CSVs made by FireCamImageExtractor
#################################################################
#                  START USER INPUT                             #
#################################################################
# Suppress library warnings so the per-camera progress output stays readable.
# NOTE(review): this also hides useful warnings (e.g. many open matplotlib figures).
warnings.filterwarnings('ignore')

# Root folder produced by FireCamImageExtractor. It holds one sub-folder per camera,
# each with a CSVfiles/ folder of per-frame IR CSVs and a <camera>_array.npy file.
# Keep exactly one active assignment; earlier burns are kept as commented examples.
# wdir = "G:/eglin_2023_firecam/FireOnly/serdp_img/ExtractedImagesNewCalFinal/"
# wdir = 'C:/users/RxLabComp1/Desktop/Ft_Stewart_2024/FireOnly/serdp_img/ExtractedImagesNewCalFinal/'
wdir = 'C:/users/RxLabComp1/Desktop/Ft_Stewart_2024/2025_burns/FireOnly/combined_for_wfsi/ExtractedImagesNewCalFinal/'
# All output/input paths are built as wdir + "name", so wdir must end in a separator.
if not wdir.endswith(('/', '\\')):
    wdir += '/'

# Area represented by one pixel, used to turn flux density into per-frame power.
# NOTE(review): units presumably m^2 — confirm against the camera calibration.
pixel_size = .0025

# Camera sub-folders of wdir. Only folders holding a CSVfiles/ directory count as
# cameras, so this script's own output folders (time_csvs/, FinalImages/) are not
# picked up on a re-run. Sorted so the summary table has a stable row order.
os.chdir(wdir)
subdirs = sorted(next(os.walk('.'))[1])
camera_names = [d for d in subdirs if os.path.isdir(os.path.join(d, 'CSVfiles'))]
ignored = [d for d in subdirs if d not in camera_names]
if ignored:
    print('ignoring folders without CSVfiles/: ' + ', '.join(ignored))

# Output folders; each is created independently, and existing ones are left alone.
csvdir = wdir + "time_csvs/"
imgdir = wdir + "FinalImages/"
os.makedirs(imgdir, exist_ok=True)
os.makedirs(csvdir, exist_ok=True)

# Summary table, one row per camera; kept open for the per-camera loop to append to.
# NOTE(review): FRE is written as mean FRFD x duration, i.e. an energy density (J/m2),
# so the "FRE (W/m2)" label looks wrong — left unchanged in case tools match on it.
sumstats = open(wdir + 'summary_stats.csv', 'w')
sumstats.write(','.join(["Unit", "FRP_max (W/m2)", "FRP_mean (W/m2)", "Duration (sec)", "FRE (W/m2)"]) + '\n')


# Stefan-Boltzmann constant (W m^-2 K^-4) and the emissivity applied to every pixel.
SIGMA = 0.000000056703
EMISSIVITY = .98
# Per-mille of sorted pixel values clipped from each end for the image colour limits
# and the histogram: 25/1000 = 2.5%, i.e. 480 of the 19200 pixels of a 120x160 frame.
TRIM_PER_MILLE = 25

TIME_CSV_HEADER = ["Original Time Stamp", "Elapsed Seconds", "AvgT (K)",
                   "FRFD (W/m2)", "Total E (J)", "FRFD no avg (W/m2)"]


def _frame_time(path):
    """Return the HH:MM:SS stamp that follows "# " in the header row of a frame CSV."""
    header = pd.read_csv(path, index_col=0, nrows=0).columns.tolist()
    return str(header).split("# ")[1].split("'")[0]


def _write_time_series(files, out_path, pixel_area):
    """Write one row of flux statistics per frame to out_path; return the start stamp.

    files are frame CSVs in time order. The first frame only fixes the time origin
    for "Elapsed Seconds" and is not written as a row.
    """
    with open(out_path, "w") as out:
        out.write(','.join(TIME_CSV_HEADER) + '\n')
        start_obj = start_stamp = None
        for f in files:
            stamp = _frame_time(f)
            stamp_obj = datetime.strptime(stamp, '%H:%M:%S')
            if start_obj is None:
                start_obj, start_stamp = stamp_obj, stamp
                continue
            elapsed = (stamp_obj - start_obj).total_seconds()
            # Stamps carry no date: a jump back of more than 12 h is treated as a
            # midnight rollover and moved onto the next day.
            if elapsed < -43200:
                elapsed += 86400
            data = np.loadtxt(f, skiprows=1, delimiter=',').astype(float)
            # Flux density from the frame-mean temperature, then scaled to the frame.
            avgt = np.nanmean(data)
            frfd = avgt**4 * SIGMA * EMISSIVITY
            # NOTE(review): 120x160 frame size is hard-coded; confirm for every camera.
            joules = frfd * 120 * 160 * pixel_area
            # Flux computed per pixel and summed; nansum (like nanmean above) keeps
            # NaN pixels from turning the whole value into NaN.
            # NOTE(review): the meaning of the /48.0 divisor is not documented here.
            frfdnoavg = np.nansum(data**4 * SIGMA * EMISSIVITY * pixel_area / 48.0)
            out.write(','.join([stamp, str(elapsed), str(avgt), str(frfd),
                                str(joules), str(frfdnoavg)]) + '\n')
    return start_stamp


def _plot_time_series(csv_path, png_path, start_stamp, summary, unit):
    """Plot FRFD against elapsed time to png_path and append unit's row to summary."""
    df = pd.read_csv(csv_path)
    time_val = df[["Elapsed Seconds"]].to_numpy()
    energy_val = df[["FRFD (W/m2)"]].to_numpy()
    # NOTE(review): duration is the number of written frames, which equals seconds
    # only at one frame per second — confirm the camera frame interval.
    total_t = len(time_val)
    frp_max = np.max(energy_val)
    frp_mean = np.mean(energy_val)
    fre = frp_mean * total_t
    plt.scatter(time_val, energy_val)
    plt.xlabel('Elapsed time (sec)')
    plt.ylabel('FRFD (W/m2)')
    plt.title("Start time (UTC): " + start_stamp + " FRP max: " + str(int(frp_max))
              + " \nFRP mean: " + str(int(frp_mean)) + " Duration: " + str(total_t)
              + " FRE: " + str(int(fre)))
    plt.ticklabel_format(style='plain')
    plt.savefig(png_path, bbox_inches='tight')
    plt.clf()
    summary.write(','.join([unit, str(frp_max), str(frp_mean), str(total_t), str(fre)]) + '\n')


def _plot_spatial(array_path, image_png, hist_png):
    """Save the per-pixel summed-temperature image and a trimmed histogram of it."""
    arr = np.load(array_path)
    temp_summed = np.nansum(arr, axis=2)
    # Sorted, flattened pixel values; colour limits clip TRIM_PER_MILLE at each end.
    tl = np.sort(temp_summed, axis=None)
    trim = tl.size * TRIM_PER_MILLE // 1000
    plt.imshow(temp_summed, cmap='cividis', aspect=1.25,
               vmin=tl[trim], vmax=tl[min(tl.size - trim, tl.size - 1)])  # alt cmap: afmhot
    plt.title("Summed Temperature (K)")
    plt.colorbar()
    plt.savefig(image_png)
    plt.close()

    # Histogram of the non-zero pixels with the lowest/highest 2.5% of them removed.
    # Slicing by explicit bounds stays correct when the trim count is 0.
    histo = tl[tl != 0].astype('int')
    cut = histo.size * TRIM_PER_MILLE // 1000
    histo = histo[cut:histo.size - cut]
    plot = sns.displot(histo).set(title="Distribution of Summed temperature by pixel \n outliers (lowest/highest 2.5% removed)")
    plot.set_xlabels('Kelvin')
    plot.fig.savefig(hist_png, bbox_inches='tight')
    # Close (not just clear) the figure so figures do not accumulate across cameras.
    plt.close(plot.fig)


sns.set_style('darkgrid')
for c in camera_names:
    print('processing: ' + c)
    outputname = c.split('.')[0]

    csv_folder = wdir + c + '/CSVfiles/'
    if not os.path.isdir(csv_folder):
        print('  skipped: no CSVfiles folder')
        continue
    os.chdir(csv_folder)
    # Sorted so the time-origin frame is deterministic; glob returns files in
    # whatever order the filesystem lists them.
    files = sorted(glob.glob('*.csv'))
    if len(files) < 2:
        # The first frame only sets the time origin, so fewer than two frames would
        # give an empty time series with nothing to plot or summarise.
        print('  skipped: need at least 2 CSV frames, found ' + str(len(files)))
        continue

    time_csv = csvdir + outputname + "_time_data.csv"
    stime = _write_time_series(files, time_csv, pixel_size)
    _plot_time_series(time_csv, imgdir + outputname + "_energy_time.png",
                      stime, sumstats, outputname)
    _plot_spatial(wdir + c + '/' + c + '_array.npy',
                  imgdir + outputname + "_temp_summed.png",
                  imgdir + outputname + "_hist.png")

# Close the summary table and exit with status 0 so shells and batch runners see success.
sumstats.close()
print('End of script')
sys.exit(0)

