'''
Written by Steve Flanagan, 2022
contact: saf204@gmail.com

Code to extract images from the FireCams built by Dan Julio
ref: https://github.com/danjulio/firecam

Can run through all cameras at once if placed in a folder above
where all the ".json" files are stored

Creates a folder "ExtractedImages" with a sub-folder for each camera

Each camera folder will then have sub-folders "OriginalImages",
which holds the extracted visual and IR images, and "CSVfiles", which
holds the per-pixel (x,y) IR values needed for the IR_processing script

If the user wants, they can turn on "additional_images" and "movie"
in the inputs below. Additional folders will then be made that 
produce images and a movie that might help the user when they run the
IR_processing script.

'''
#################################################################
#                       START USER INPUT
#################################################################

# Folder holding one sub-folder per camera, each containing the
# '.json' files produced by the FireCams. Keep exactly one active line.
#wdir = 'C:/users/RxLabComp1/Desktop/Ft_Stewart_2024/FireOnly/serdp_img/'
wdir = 'C:/users/RxLabComp1/Desktop/Ft_Stewart_2024/2025_burns/FireOnly/combined_for_wfsi/'
#wdir = 'C:/users/RxLabComp1/Desktop/Ft_Stewart_2024/2025_burns/FireOnly/BU8/'
#wdir = 'C:/users/RxLabComp1/Desktop/gallagher_data/card1_cameraFF39/SDHC/'
#wdir = "/mnt/c/Users/sflanagan/Desktop/thermal_camera/trial/"  # Ubuntu

# Filter correction, applied as y = m*x + b.
# When no filter is used, set nd_mult = 1 and nd_ladj = 0.
nd_mult = 7.1      # m, the slope
nd_ladj = -136.4   # b, the offset

# The metadata normally stores the temperature used for the adjustment.
# In the rare case it was not recorded, this fallback is used instead.
# Units: degrees Celsius.
amb_temp = 25

######################## OPTIONAL INPUTS ##########################

# Flags: 0 = off, 1 = on.
# Additional images are mainly for comparing the visual and IR views.
additional_images = 0   # 1 = write extra comparison images
movie = 0               # 1 = also build a movie; requires additional_images = 1

# Colour-scale limits for the additional IR images, in Kelvin
min_temp = 300
max_temp = 1200

#################################################################
#                        END USER INPUT
#################################################################



#libraries
import os, shutil, sys, base64, json, gc
import numpy as np
import matplotlib.pyplot as plt
import PIL
from PIL import Image
import matplotlib
matplotlib.use('agg')
from datetime import datetime

if movie==1:
    import cv2


os.chdir(wdir)

# Create the top-level output folder. If it already exists, the data was
# processed before, so ask before wiping it. Only FileExistsError is handled
# here; any other failure (bad path, permissions) propagates instead of being
# misreported as "already exists" and triggering an rmtree.
extract_dir = wdir + "ExtractedImagesNewCalFinal/"
print("Making a directory for the extracted images")
try:
    os.mkdir(extract_dir)
except FileExistsError:
    print("Directory already exists")
    overwrite = input("To overwrite, enter 'y'")
    # Tolerate surrounding whitespace and an upper-case 'Y'
    if overwrite.strip().lower() == "y":
        print("Removing directory, Reprocessing images")
        shutil.rmtree(extract_dir)
        os.mkdir(extract_dir)
    else:
        print('Exiting script')
        sys.exit(1)

# Every sub-folder of wdir is treated as one camera, except the output
# folder(s) this script creates (names starting with 'Ext').
# The context manager closes the scandir iterator. Sorting gives a stable,
# platform-independent processing order.
with os.scandir(wdir) as entries:
    cameras = sorted(
        entry.name
        for entry in entries
        if entry.is_dir() and not entry.name.startswith('Ext')
    )

print("Cameras found to process:")
print(cameras)


for c, camera in enumerate(cameras):
    # ------------------------------------------------------------------
    # Process one camera folder. For each .json capture:
    #   - decode the visual JPEG,
    #   - decode the radiometric data and save it as a CSV and a PNG,
    #   - optionally write comparison images and movie frames.
    # Finally stack all frames except the first into a numpy array.
    # ------------------------------------------------------------------
    print("Processing camera " + str(c+1) + " of " + str(len(cameras)))
    print("Checking for 'json' files in " + camera + " folder and sub-folders")

    # Collect the .json captures. os.walk yields entries in arbitrary,
    # platform-dependent order, but the first file becomes the baseline and
    # time origin below. So walk sub-folders and files sorted by name
    # (sorting `dirs` in place steers os.walk's traversal).
    filelist = []
    for root, dirs, files in os.walk(wdir + camera + "/"):
        dirs.sort()
        for name in sorted(files):
            if name.endswith('.json'):
                filelist.append(os.path.join(root, name).replace("\\", "/"))
    if not filelist:
        # Skip only this camera so the rest of the batch is still processed
        print("No 'json' files found, check that file structure is correct")
        print("Skipping camera " + camera)
        continue
    print(str(len(filelist)) + " .json files found for " + camera)

    # Output directories (created inside the camera's output folder)
    print("Creating output dirs")
    ndir = wdir + "ExtractedImagesNewCalFinal/" + camera + "/"
    os.makedirs(ndir)
    os.chdir(ndir)
    os.makedirs("OriginalImages")
    os.makedirs("CSVfiles")
    if additional_images == 1:
        os.makedirs("ResizedImages")
        os.makedirs("OverlayedImages")
        os.makedirs("CombinedImages")
        if movie == 1:
            os.makedirs("IRandVisualOnly")

    # Error log file: one line per failed image
    errorlog = open("errorlog.txt", 'w+')
    errorlog.write("Camera           Type        Failed Image \n")

    # Image failure checks and total count
    radfail = 0
    jpegfail = 0
    count = 0

    # Loop-invariant calibration constants and the metadata time format
    mult = float(nd_mult)
    ladj = float(nd_ladj)
    date_format = '%H:%M:%S'
    seconds_per_day = 24 * 60 * 60

    print("Every failed rad or jpeg will print to the screen")
    print("And an update of every 1000 files processed")

    for f in filelist:

        # Set if either image fails; the extra images are then skipped
        imgFail = 0

        with open(f) as fh:
            myjsonfile = json.load(fh)
        outputname = f.split('/')[-1].split('.json')[0]
        cameraName = myjsonfile['metadata']['Camera']

        if count == 0:
            # The camera sees itself: the first frame is kept as a baseline
            # that every later frame subtracts. Frame 0 is therefore not
            # stored in img_array, hence the "- 1" in the sizes below.
            img_array = np.zeros((120, 160, len(filelist)-1))
            baseline = np.zeros((120, 160))
            temp_adj = myjsonfile['metadata'].get('Lens Temp', amb_temp)

            # Time origin for the per-frame time offsets
            time_array = np.zeros(len(filelist)-1)
            time_start = myjsonfile["metadata"]["Time"]
            ts_converted = datetime.strptime(time_start, date_format)

        if count % 1000 == 0:
            print('Processing image ' + str(count) + " of " + str(len(filelist)) + " for " + camera)

        # Decode the jpeg data and save the visual image
        try:
            jpegbinary = base64.b64decode(myjsonfile['jpeg'])
            with open("OriginalImages/" + outputname + "_visual.jpg", 'wb') as o:
                o.write(jpegbinary)
        except Exception:
            print("jpeg failed " + f)
            errorlog.write(cameraName + "     Visible       " + outputname + "\n")
            jpegfail += 1
            imgFail = 1

        # Decode the radiometric data
        try:
            radbinary = base64.b64decode(myjsonfile['radiometric'])
            raddata2D = np.frombuffer(radbinary, dtype=np.uint16).reshape(120, 160)

            # Convert raw counts using the recorded resolution (0.01 = high
            # gain default). Raw values appear to be K*100 (K*10 at 0.1
            # resolution). Then apply the filter correction y = m*x + b.
            try:
                res = float(myjsonfile["metadata"]["Lepton Resolution"])
            except (KeyError, TypeError, ValueError):
                res = 0.01
            if res == 0.1:
                final_array = (raddata2D - 2731.50)/10.00*mult + ladj
            else:
                final_array = (raddata2D - 27315.0)/100.0*mult + ladj

            if count == 0:
                # NOTE(review): frame 0 is saved below without the baseline
                # or Kelvin shift. If frame 0's radiometric data fails to
                # decode, baseline stays all zeros for later frames.
                baseline = final_array
            else:
                # Remove the self-image baseline and add the reference
                # temperature (assumed to be in Celsius, like amb_temp) + 273.15
                final_array -= baseline
                final_array += temp_adj + 273.15
                # Set misfires to the reference temperature
                final_array[final_array > 3000] = temp_adj + 273.15
                img_array[:, :, count-1] = final_array

                # Seconds since the first capture. Times carry no date, so a
                # recording that crosses midnight yields a negative value;
                # wrap it (assumes sessions shorter than 24 h).
                ct_converted = datetime.strptime(myjsonfile["metadata"]["Time"], date_format)
                diff_int = int((ct_converted - ts_converted).total_seconds())
                if diff_int < 0:
                    diff_int += seconds_per_day
                time_array[count-1] = diff_int

            # Save as CSV for data processing (capture time in the header)
            outputcsv = outputname + '_IR.csv'
            np.savetxt("CSVfiles/" + outputcsv, final_array.astype(float), delimiter=',', header=myjsonfile["metadata"]["Time"])

            # Save as image
            outputrad = outputname + '_IR.png'
            plt.imshow(final_array, cmap='RdBu_r')
            plt.colorbar()
            plt.title('Temperature (K)')
            plt.savefig("OriginalImages/" + outputrad)
            plt.close('all')
            gc.collect()

        except Exception:
            print('rad failed ' + f)
            errorlog.write(cameraName + "     Radiometric   " + outputname + "\n")
            radfail += 1
            imgFail = 1

        count += 1

        # If both images were recorded, perform some comparisons
        if imgFail == 0 and additional_images == 1:
            # Resize the visual pixels (640 x 480) to match the radiometric (160 x 120)
            with Image.open("OriginalImages/" + outputname + "_visual.jpg") as visual:
                jpeg_resized = np.array(visual.resize((160, 120)))
            plt.imsave("ResizedImages/" + outputname + ".png", jpeg_resized)

            # Overlay the two images
            plt.imshow(final_array, cmap='tab20c', vmin=min_temp, vmax=max_temp)
            plt.colorbar()
            plt.imshow(jpeg_resized, alpha=.7)
            plt.savefig("OverlayedImages/" + outputname)
            plt.close('all')
            gc.collect()

            # Combine IR, visual, gridded IR, and overlay into one image
            fig, ax = plt.subplots(nrows=2, ncols=2, clear=True)
            ax[0, 0].imshow(final_array, cmap='RdBu_r', vmin=min_temp, vmax=max_temp)
            ax[0, 1].imshow(jpeg_resized)
            ax[1, 1].imshow(jpeg_resized)
            ax[1, 1].imshow(final_array, cmap='RdBu_r', alpha=.7, vmin=min_temp, vmax=max_temp)
            ax[1, 0].set_xticks(np.arange(0, 160, 5))
            ax[1, 0].set_yticks(np.arange(0, 120, 5))
            ax[1, 0].grid(color='black')
            ax[1, 0].imshow(final_array, cmap='RdBu_r', vmin=min_temp, vmax=max_temp)
            fig.savefig("CombinedImages/" + outputname)
            plt.close('all')
            gc.collect()

            # Side-by-side IR / visual frame for the movie
            if movie == 1:
                fig, ax = plt.subplots(nrows=1, ncols=2, clear=True)
                ax[0].axis('off')
                ax[0].set_title('IR')
                ax[0].imshow(final_array, cmap='RdBu_r', vmin=min_temp, vmax=max_temp)
                ax[1].axis('off')
                ax[1].set_title('Visual')
                ax[1].imshow(jpeg_resized)
                fig.savefig("IRandVisualOnly/" + outputname)
                plt.close('all')
                gc.collect()

    # Make a movie of IR and visual combined. The frames are only written
    # when additional_images is on, so both flags are required.
    if movie == 1 and additional_images == 1:
        image_folder = 'IRandVisualOnly'
        video_name = 'IRandVisual.avi'

        # os.listdir order is arbitrary; sort so frames follow file-name order
        images = sorted(img for img in os.listdir(image_folder) if img.endswith(".png"))
        if images:
            frame = cv2.imread(os.path.join(image_folder, images[0]))
            height, width = frame.shape[:2]
            video = cv2.VideoWriter(video_name, 0, 1, (width, height))
            for image in images:
                video.write(cv2.imread(os.path.join(image_folder, image)))
            video.release()
        else:
            print("No frames to make a movie for " + camera)

    # Write out total failures
    errorlog.write("\n\n\n Total camera failues: \n Visible "+str(jpegfail)+"\n Radiometric " +str(radfail))
    errorlog.close()
    # Pixel (0, 0) of each stored frame is overwritten with that frame's
    # time offset in seconds, so the times travel with the saved array
    img_array[0, 0, :] = time_array
    np.save(camera + "_array", img_array)
    

print('End of script')
# Exit status 0 tells the calling shell or scheduler the run succeeded
sys.exit(0)
