#!/usr/bin/python3

import math
import random
import multiprocessing
import subprocess
import argparse
import signal
import os
import sys
import time
import csv
import numpy as np

# Run flag: cleared to 0 by the SIGINT handler below; the capture and
# sleep loops poll it so they can stop early.
keep_going = 1

def handler(signum, frame):
    """SIGINT handler: flag shutdown, put the CPU back in auto mode, and exit.

    ``signum`` and ``frame`` are the standard signal-handler arguments; they
    are not used.
    """
    global keep_going
    keep_going = 0  # polled by the capture / sleep loops

    # switch the CPU back to auto_with_pitmode before leaving
    cpu_mode_cmd = ["voxl-set-cpu-mode", "auto_with_pitmode"]
    subprocess.run(cpu_mode_cmd)
    print('\nCtrl+C was pressed. Exiting gracefully...')

    # The keep_going flag alone has not proven reliable, so leave right away.
    raise SystemExit(0)

def hogcpu():
    """Busy-loop forever on throw-away math so one CPU core stays loaded.

    Used as a ``multiprocessing`` target; it never returns, so the caller
    has to terminate the process.
    """
    while 1:
        _ = math.sqrt(random.randint(1, 10))

def parse_px4_baro_output(text):
    """Extract ``(timestamp, device_id, pressure, temperature)`` from listener text.

    Each field comes from the LAST line containing its keyword ('timestamp',
    'device_id', 'pressure', 'temp'). Because this is substring matching, a
    'timestamp_sample' line that follows the 'timestamp' line supplies the
    returned timestamp.

    Fields that are absent, or whose value cannot be parsed (e.g. an error
    message that happens to contain a keyword), are returned as None.
    """
    timestamp = device_id = pressure = temperature = None
    for line in text.splitlines():
        try:
            if 'timestamp' in line:
                # "timestamp: 123456 (0.01 seconds ago)" -> first token after ':'
                timestamp = float(line.split(':')[1].split()[0])
            if 'pressure' in line:
                pressure = float(line.split(':')[-1])
            if 'temp' in line:
                temperature = float(line.split(':')[-1])
            if 'device_id' in line:
                # "device_id: 12018473 (Type: ..., I2C:4 ...)" -> first token after ':'
                device_id = int(line.split(':')[1].split()[0])
        except (IndexError, ValueError):
            # not a "key: value" line we understand; ignore it
            continue
    return timestamp, device_id, pressure, temperature


def get_px4_baro():
    """Read the current ``sensor_baro`` sample from PX4 via ``px4-listener``.

    Returns:
        tuple: ``(timestamp, device_id, pressure, temperature)`` as parsed by
        ``parse_px4_baro_output``; ``temperature`` is always a float.

    Exits the process with status 1 when ``px4-listener`` cannot be run,
    returns a non-zero status, or prints no parseable temperature.
    """
    cmd = ['px4-listener', 'sensor_baro']
    #cmd = ['px4-listener', 'vehicle_air_data']

    try:
        proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        returncode = proc.returncode
        # decode instead of str(bytes), which yields "b'...\\n'" and needed
        # fragile suffix slicing
        output = proc.stdout.decode('utf-8', errors='replace')
    except OSError:
        # px4-listener missing or not executable
        returncode, output = None, ''

    timestamp, device_id, pressure, temperature = parse_px4_baro_output(output)

    if returncode != 0 or not isinstance(temperature, float):
        print("ERROR calling px4-listener, likely px4 isn't running.")
        print("Start with systemctl start voxl-px4")
        sys.exit(1)

    return (timestamp, device_id, pressure, temperature)

# PX4 device IDs of the ICP-10100 barometer (internal or external). A sample
# from any other device aborts the capture.
BARO_SENSOR_IDS = (12018473, 12018433)

# Order of the pressure-vs-temperature polynomial that is fitted.
POLY_FIT_ORDER = 1

# Number of TC_B0_X<k> coefficient parameters written (X0..X5); terms above
# the fitted order are written as 0.
NUM_TC_COEFFS = 6

# File the calibration parameters are saved to before being loaded into PX4.
CAL_PARAM_FILE_NAME = '/data/px4/param/parameters_baro_tc.cal'


def _restore_cpu_mode():
    """Switch the VOXL CPU back to 'auto_with_pitmode' (main() selects 'perf')."""
    subprocess.run(["voxl-set-cpu-mode", "auto_with_pitmode"])


def _parse_args():
    """Parse and return the command-line options."""
    parser = argparse.ArgumentParser(description='Barometer Test Script')
    parser.add_argument('-n','--num-samples',   type=int, required=False, default=1250)
    parser.add_argument('-p','--log_path',   type=str, required=False, default='/data/baro_cal_logs/', help='location to save logs and plots to')
    parser.add_argument('-c','--num-cpus',   type=int, required=False, default=7)
    parser.add_argument('-i','--num-iter',   type=int, required=False, default=1)
    parser.add_argument('-f','--force', action='store_true', help='disable temperature range checks and run calibration anyway')
    parser.add_argument('-s','--sleep-interval',   type=int, required=False, default=300)
    parser.add_argument('-t','--enable-plots', action='store_true', help='Enable plotting')
    parser.add_argument('-x','--no-params',   type=int, required=False, default=0)
    return parser.parse_args()


def _wait_for_cool_baro(force):
    """Poll the baro until it reads below 35C and return that start temperature.

    With ``force`` the first reading is returned immediately. If the user had
    to cool the board, they are asked to turn the fan off before continuing.
    """
    needs_fan_turning_off = False
    while True:
        start_temp = get_px4_baro()[3]
        if force or start_temp < 35.0:
            break
        print(f"\rBaro too hot [{start_temp:5.1f}C] cool the VOXL with a fan until <35C", end="", flush=True)
        needs_fan_turning_off = True
        time.sleep(1)

    if needs_fan_turning_off:
        input("\nPlease turn the fan OFF, then press Enter to continue.\n")
    return start_temp


def _start_cpu_stress(num_cpus):
    """Start ``num_cpus`` busy-loop processes to heat the board and return them."""
    procs = []
    if num_cpus:
        print('\nStarting stress tasks')
        for _ in range(num_cpus):
            # daemon=True: if the script exits early (error / sys.exit) the hogs
            # are terminated instead of the interpreter joining them forever
            proc = multiprocessing.Process(target=hogcpu, daemon=True)
            procs.append(proc)
            proc.start()
    return procs


def _stop_cpu_stress(procs):
    """Terminate and reap the stress processes from _start_cpu_stress()."""
    if procs:
        print('Stopping stress tasks')
        for proc in procs:
            proc.terminate()
            proc.join()


def _capture_samples(log_file, max_sample_count, start_temp):
    """Record new baro samples until enough are collected or 30C of heating is seen.

    Each accepted sample is printed and appended to ``log_file`` as
    ``count, timestamp, device_id, pressure, temperature``. Capture stops
    early on a Ctrl+C (``keep_going``), once the temperature exceeds
    ``start_temp + 30``, or when an unexpected device ID is reported.

    Returns:
        tuple of lists: ``(timestamps, pressures, temperatures, device_ids)``.
    """
    ts, ps, temps, device_ids = [], [], [], []
    tlast = 0
    while len(ts) < max_sample_count and keep_going == 1:
        timestamp, device_id, pressure, temperature = get_px4_baro()

        # only record samples that changed since the last poll
        if timestamp is not None and timestamp != tlast:
            if device_id not in BARO_SENSOR_IDS:
                print(f'Incorrect baro device id: {device_id}')
                break

            ts.append(timestamp)
            ps.append(pressure)
            temps.append(temperature)
            device_ids.append(device_id)
            sample_count = len(ts)
            print(f'({sample_count:4d}) P:{pressure:8.1f}Pa T: {temperature:4.1f}C')
            log_file.write(f'{sample_count}, {int(timestamp)}, {device_id}, {pressure}, {temperature}\n')

            # can break early if we've already covered a wide 30C temp range
            if temperature > start_temp + 30.0:
                break

        tlast = timestamp
        time.sleep(0.01)

    return ts, ps, temps, device_ids


def _write_plots(ts, ps, temps, html_path):
    """Save an HTML figure of the captured pressure/temperature data to ``html_path``."""
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    print('Generating plots..')
    if not ts:
        print('No samples were captured, skipping plot generation')
        return

    fig = make_subplots(rows=4, cols=1, start_cell="top-left")

    ts_plot = (np.array(ts) - ts[0]) * 0.000001  # convert from us to s

    # approximate height change from start: about 12 Pascals per meter at
    # standard pressure and temperature; a pressure increase means lower height
    dh = (np.array(ps) - ps[0]) / 12.0 * -1.0

    t_plot = np.array(temps)
    p_plot = np.array(ps)

    fig.add_trace(go.Scatter(x=ts_plot, y=p_plot, name='Pressure (Pa)'), row=1, col=1)
    fig.add_trace(go.Scatter(x=ts_plot, y=t_plot, name='Temperature (deg C)'), row=2, col=1)
    fig.add_trace(go.Scatter(x=ts_plot, y=dh, name='Approx Height Change (m)'), row=3, col=1)
    fig.add_trace(go.Scatter(x=t_plot, y=p_plot, name='Pressure (Pa) vs Temperature'), row=4, col=1)

    fig.update_layout(title_text='Barometer Test Results')
    fig.update_xaxes(title_text='Time (s)',row=1, col=1)
    fig.update_yaxes(title_text='Pressure (Pa)',row=1, col=1)
    fig.update_xaxes(title_text='Time (s)',row=2, col=1)
    fig.update_yaxes(title_text='Temperature (deg C)',row=2, col=1)
    fig.update_xaxes(title_text='Time (s)',row=3, col=1)
    fig.update_yaxes(title_text='Approx Height Change (m)',row=3, col=1)
    fig.update_xaxes(title_text='Temperature (deg C)',row=4, col=1)
    fig.update_yaxes(title_text='Pressure (Pa)',row=4, col=1)

    print(f"Saving plot to: {html_path}")
    fig.write_html(html_path, include_plotlyjs='cdn')
    #fig.show()  #the figure will not show on VOXL because there is no display / browser


def _fit_baro_calibration(temps, pressures, poly_fit_order=POLY_FIT_ORDER):
    """Fit pressure offset as a polynomial of ``temperature - t_ref``.

    ``t_ref`` is the midpoint of the observed temperature range. Pressures are
    taken relative to the first sample warmer than ``t_ref``.

    Returns:
        ``(coeffs, t_ref, t_min, t_max)`` where ``coeffs`` is in
        ``np.polyfit`` order (highest power first).

    Raises:
        ValueError: if there are no samples or the temperature never changed
            (the fit would be degenerate).
    """
    t = np.asarray(temps, dtype=float)
    p = np.asarray(pressures, dtype=float)
    if t.size == 0:
        raise ValueError('no baro samples to fit')

    t_min = np.min(t)
    t_max = np.max(t)
    if not t_max > t_min:
        raise ValueError(f'baro temperature never changed ({t_min}C)')

    t_ref = (t_max + t_min) / 2
    above = np.flatnonzero(t > t_ref)
    # with t_max > t_min at least t_max lies above the midpoint; fall back to
    # the hottest sample only if float rounding put the midpoint on t_max
    i_ref = int(above[0]) if above.size else int(np.argmax(t))

    p_comp = p - p[i_ref]
    coeffs = np.polyfit(t - t_ref, p_comp, poly_fit_order)
    return coeffs, t_ref, t_min, t_max


def _run_calibration(ps, temps, device_ids, force, set_px4_parameters):
    """Fit the baro temperature calibration and publish it as PX4 TC_B* params.

    Prints the ``param set`` commands, writes them to CAL_PARAM_FILE_NAME and,
    if ``set_px4_parameters``, loads that file with voxl-configure-px4-params.
    Exits with status 1 on empty/all-zero data, a degenerate fit, or (unless
    ``force``) a temperature span below 20C.
    """
    # add a simple check that will exit if the data is all 0s
    if not temps or (np.allclose(ps, 0) and np.allclose(temps, 0)):
        print("ERROR: captured baro pressure/temperature data is empty or all zeros, aborting calibration.")
        sys.exit(1)

    # the capture only logs accepted IDs; the last sample's ID is the sensor
    baro_sensor_id = device_ids[-1]

    try:
        coeffs, t_ref, t_min, t_max = _fit_baro_calibration(temps, ps)
    except ValueError as e:
        print(f'ERROR: {e}, cannot fit a calibration')
        sys.exit(1)

    print('Poly Fit:')
    print(coeffs)

    # check we actually covered a decent (>20C temp range)
    temp_range = t_max - t_min
    if not force and temp_range < 20.0:
        print(f'ERROR: calibration only covered {temp_range}C, not enough for a good cal')
        sys.exit(1)

    # min temp should be very low to allow extrapolation of calibration in cold weather
    # cal was usually done on a bench with no airflow in a warm room
    # increase t_max a little bit to allow extrapolation outdoors in the sun
    if t_min > 0.0:
        t_min = 0.0
    t_max = t_max + 10

    # X<k> is the coefficient of power k; np.polyfit lists the highest power
    # first, and unused higher-order terms are written as 0
    order = len(coeffs) - 1
    tc_coeffs = [coeffs[order - k] if k <= order else 0.0 for k in range(NUM_TC_COEFFS)]

    # print the PX4 params
    print('param set TC_B_ENABLE 1')
    print('param set TC_B0_ID    %d' % baro_sensor_id)
    print('param set TC_B0_TREF  %.2f' % t_ref)
    print('param set TC_B0_TMIN  %.2f' % t_min)
    print('param set TC_B0_TMAX  %.2f' % t_max)
    for k, c in enumerate(tc_coeffs):
        print('param set TC_B0_X%d    %.5f' % (k, c))

    # Backup the parameters
    if os.path.exists(CAL_PARAM_FILE_NAME):
        print(f'Parameter file {CAL_PARAM_FILE_NAME} already exists, overwriting with new data')
    with open(CAL_PARAM_FILE_NAME, 'w') as cal_file:
        cal_file.write('1\t1\tTC_B_ENABLE\t1\t6\n')
        cal_file.write('1\t1\tTC_B0_ID\t%d\t6\n' % baro_sensor_id)
        cal_file.write('1\t1\tTC_B0_TREF\t%.2f\t9\n' % t_ref)
        cal_file.write('1\t1\tTC_B0_TMIN\t%.2f\t9\n' % t_min)
        cal_file.write('1\t1\tTC_B0_TMAX\t%.2f\t9\n' % t_max)
        for k, c in enumerate(tc_coeffs):
            cal_file.write('1\t1\tTC_B0_X%d\t%.6f\t9\n' % (k, c))

    if set_px4_parameters:
        print('\nLoading new params to PX4')
        subprocess.call(["voxl-configure-px4-params", "-n", "-f", CAL_PARAM_FILE_NAME])


def _sleep_between_iterations(seconds):
    """Sleep ``seconds`` in 1 s steps, exiting early if Ctrl+C cleared keep_going."""
    print('sleeping..')
    for _ in range(seconds):
        time.sleep(1.0)
        if keep_going == 0:
            print('Exiting')
            sys.exit(0)  # CPU mode is restored by main()'s finally
    print('done sleeping')


def main():
    """Run the barometer temperature calibration: capture, plot, fit, publish."""
    print('Barometer temperature calibration process starting')
    args = _parse_args()

    log_path = args.log_path
    if log_path != '.':
        try:
            os.makedirs(log_path, exist_ok=True)
        except OSError as e:
            print(f'Error: Could not create log directory: {e}')
            sys.exit(1)

    if args.force:
        print("running in forced mode")

    enable_plots = args.enable_plots
    if enable_plots:
        try:
            # availability check only; _write_plots() imports what it uses
            import plotly.graph_objects  # noqa: F401
            import plotly.subplots  # noqa: F401
        except ImportError:
            print('ERROR: In order to plot the results, install the Python "plotly" module:')
            print('sudo apt install -y python3-pip')
            print('pip3 install plotly --upgrade')
            sys.exit(1)

    set_px4_parameters = (args.no_params != 1)

    signal.signal(signal.SIGINT, handler)

    # disable cpu pitmode so we can get hot enough
    subprocess.run(["voxl-set-cpu-mode", "perf"])
    try:
        for i in range(args.num_iter):
            if keep_going == 0:
                print('Exiting')
                sys.exit(0)

            print(f'Starting iteration #{i}\n')

            time_start = time.time()
            fname = os.path.join(log_path, 'baro_log_%d_%d.txt' % (i, time_start))

            procs = []
            try:
                with open(fname, 'w') as log_file:
                    start_temp = _wait_for_cool_baro(args.force)
                    procs = _start_cpu_stress(args.num_cpus)
                    print('Starting data capture')
                    ts, ps, temps, device_ids = _capture_samples(log_file, args.num_samples, start_temp)
                    print('Finished collecting data')
                    print(f'Closing log file {fname}..')
            finally:
                # also runs if capture aborted, so the hogs never outlive it
                _stop_cpu_stress(procs)

            html_path = None
            if enable_plots:
                html_path = os.path.join(log_path, 'barometer_test_results_%d_%d.html' % (i, time_start))
                _write_plots(ts, ps, temps, html_path)

            # Figure out PX4 parameters only for first iteration
            if i == 0:
                _run_calibration(ps, temps, device_ids, args.force, set_px4_parameters)
                if enable_plots:
                    print('pull the plot file with')
                    print(f'adb pull {html_path}')

            if i < args.num_iter - 1:
                _sleep_between_iterations(args.sleep_interval)
    finally:
        # restore the CPU mode on normal completion AND on every error exit
        _restore_cpu_mode()

    print('Barometer temperature calibration process ending')


if __name__ == '__main__':
    main()
