import numpy as np
import tomllib
import xraylib
from transfocator_calcs import lookup_diameter, materials_to_deltas, materials_to_linear_attenuation
from transfocator_calcs import find_levels, calc_lookup_table, get_densities

# Macro (column) names looked up in the pattern line of the lens
# substitutions file parsed by singleTF.setupLookupTable().
MAT_MACRO      = 'MAT'        # lens material of each stack
NLENS_MACRO    = 'NUMLENS'    # number of lenses in each stack
RADIUS_MACRO   = 'RADIUS'     # lens radius of each stack
LOC_MACRO      = 'LOC'        # stack location, in units of the stack thickness
THICKERR_MACRO = 'THICKERR'   # RMS thickness error of each stack

'''
Config variables

Beam Properties
        energy      : energy in keV
        L_und       : undulator length in m
        sigmaH_e    : Sigma electron source size in H direction in m
        sigmaV_e    : Sigma electron source size in V direction in m
        sigmaHp_e   : Sigma electron divergence in H direction in rad
        sigmaVp_e   : Sigma electron divergence in V direction in rad
Beamline properties
        d_StoL1 : Source-to-CRL1 distance, in m
        d_Stof  : Source-to-sample distance, in m
CRL properties
        d_min   : Minimum thickness at the apex in m
        stack_d : Stack thickness in m
'''
# Fallback configuration used by singleTF when no TOML setup file is given.
DEFAULT_CONFIG = {
    'beam': {
        'energy': 15,           # keV
        'L_und': 4.7,           # undulator length, m
        'sigmaH_e': 14.8e-6,    # electron source size (H), m
        'sigmaV_e': 3.7e-6,     # electron source size (V), m
        'sigmaHp_e': 2.8e-6,    # electron divergence (H), rad
        'sigmaVp_e': 1.5e-6,    # electron divergence (V), rad
    },
    'beamline': {
        'd_StoL1': 51.9,        # source-to-CRL1 distance, m
        'd_Stof': 66.2,         # source-to-sample distance, m
    },
    'crl': {
        'stack_d': 50.0e-3,     # stack thickness, m
        'd_min': 3.0e-5,        # minimum thickness at the apex, m
    },
}

class singleTF():
    
    def __init__(self, crl_setup = None, beam_config = DEFAULT_CONFIG['beam'], 
                 beamline_config = DEFAULT_CONFIG['beamline'], 
                 crl_config = DEFAULT_CONFIG['crl']):

        self.verbose = True

        if crl_setup is None:
            beam = beam_config
            beamline = beamline_config
            crl = crl_config
#            slits = slits_config
        else:
            with open(crl_setup, "rb") as f:
                config = tomllib.load(f)
            beam = config['beam']
            beamline = config['beamline']
            crl = config['crl']
 #           slits = config['slits'] 

        self.beam = {}
        self.setupSource(beam)
        
        self.bl = {}
        self.setupBeamline(beamline)
        
        self.crl = {}
        self.setupCRL(crl)
        
        self.slits = [{'H':0,'V':0}]
        
#        self.slit1_H = 0
#        self.slit1_V = 0
                
        # Initialize lens variables -- TODO -- this is done via a subs file -- are any of these needed prior to that loading?
        self.numlens        = np.array([1,      1,      1,      1,      1,      1,      2,      4,      8,      16])                # CRL1 number of lenses in each stack (was L1_n)
        self.radius         = np.array([2.0e-3, 1.0e-3, 5.0e-4, 3.0e-4, 2.0e-4, 1.0e-4, 1.0e-4, 1.0e-4, 1.0e-4, 1.0e-4])            # CRL1 lens radius in each stack (was L1_R)
        self.materials      = np.array(["Be",   "Be",   "Be",   "Be",   "Be",   "Be",   "Be",   "Be",   "Be",   "Be"])              # CRL1 lens material in each stack (was L1_mater)
        self.lens_loc       = np.array([4.5,    3.5,    2.5,    1.5,    0.5,    -0.5,   -1.5,   -2.5,   -3.5,   -4.5])*self.crl['stack_d']      # CRL1 lens stack location relative to center stack, positive means upstream (was L1_Loc)
        self.lens_thickerr  = np.array([1.0e-6, 1.0e-6, 1.0e-6, 1.0e-6, 1.0e-6, 1.0e-6, 1.4e-6, 2.0e-6, 2.8e-6, 4.0e-6])            # CRL1 lens RMS thickness error (was L1_HE)


        self.Lens_diameter_table = [
                                    (50, 450.0),
                                    (100, 632.0),
                                    (200, 894.0),
                                    (300, 1095.0),
                                    (500, 1414.0),
                                    (1000, 2000.0),
                                    (1500, 2450.0),
                                ]
        # Convert the lookup table to a dictionary for faster lookup        
        self.Lens_diameter_dict = {int(col1): col2 for col1, col2 in self.Lens_diameter_table}
        
        self.energy = 0  # gets value from an ao (incoming beam energy)
        self.focalSize = 0 # get value from an ao (desired focal length)
        self.lenses = 0 # sets integer (2^10) whose binary representation indicates which lenses are in or out
        
        self.num_stacks = 10 # Number of lenses in system
        
        self.lookupTable = []
        
        self.thickerr_flag = True
                
    def setupSource(self, beam_properties):
        '''
        Beam properties can have entries for the following
        
        energy      : energy in keV
        L_und       : undulator length in m
        sigmaH_e    : Sigma electron source size in H direction in m
        sigmaV_e    : Sigma electron source size in V direction in m
        sigmaHp_e   : Sigma electron divergence in H direction in rad
        sigmaVp_e   : Sigma electron divergence in V direction in rad
        '''
        
        self.setEnergy(beam_properties['energy'])
        self.L_und = beam_properties['L_und']
        self.sigmaH_e = beam_properties['sigmaH_e']
        self.sigmaV_e = beam_properties['sigmaV_e']
        self.sigmaHp_e = beam_properties['sigmaHp_e']
        self.sigmaVp_e = beam_properties['sigmaVp_e']
        
        self.setupSourceEnergyDependent()
    
    def setupSourceEnergyDependent(self):
        '''
        Fill in later
        '''    
        self.beam['sigmaH'] =  (self.sigmaH_e**2 +  self.wl*self.L_und/2/np.pi/np.pi)**0.5
        self.beam['sigmaV'] =  (self.sigmaV_e**2 +  self.wl*self.L_und/2/np.pi/np.pi)**0.5
        self.beam['sigmaHp'] = (self.sigmaHp_e**2 + self.wl/self.L_und/2)**0.5
        self.beam['sigmaVp'] = (self.sigmaVp_e**2 + self.wl/self.L_und/2)**0.5
        
    def setupBeamline(self, beamline_properties, num=1):
        '''
        Beamline properties can contain entries for the following
        
        d_StoL1 : Source-to-CRL1 distance, in m
        d_Stof  : Source-to-sample distance, in m
        '''            
        
        self.bl['d_StoL1'] = beamline_properties['d_StoL1']
        self.bl['d_Stof'] = beamline_properties['d_Stof']
            
            
    def setupCRL(self, crl_properties):
        '''
        CRL properties can contiain entries for the following
        
        d_min   : Minimum thickness at the apex in m
        stack_d : Stack thickness in m
        '''
        self.crl['d_min'] = crl_properties['d_min'] 
        self.crl['stack_d'] = crl_properties['stack_d']

    def setupSlits(self, slit_properties):
        '''
        Slit properties can contain entries for the following
        
        '''
        pass 
                
    def setupLookupTable(self, subs_file, n_lenses):
        '''
        lookup table created after IOC startup
        energy and slit size are updated before this is called
        '''
        print(80*'#')
        print('Setting up lens control...')
        
        self.num_stacks = n_lenses
        self.num_configs = 2**self.num_stacks
        self.configs = np.arange(self.num_configs)
        
#       self.energy = energy
        
        #read in substitutions file
        try:
            subsFile = open(subs_file,"r")
        except:
            raise RuntimeError(f"Substiution file ({subsFile}) not found.")
        subsFileContent = subsFile.readlines()
        subsFile.close()
        
        macros = subsFileContent[2].replace('{','').replace('}','').replace(',','').split()
        lens_properties = {key: [] for key in macros} # dictionary of lists
        for i in range(self.num_stacks):
            try:
                xx = subsFileContent[3+i].replace('{','').replace('}','').replace(',','').replace('"','').split()
                lens_properties[macros[0]].append(xx[0])
                lens_properties[macros[1]].append(xx[1])
                lens_properties[macros[2]].append(xx[2])
                lens_properties[macros[3]].append(xx[3])
                lens_properties[macros[4]].append(xx[4])
                lens_properties[macros[5]].append(xx[5])
                lens_properties[macros[6]].append(xx[6])
            except:
                raise RuntimeError(f"Number of lenses ({self.num_stacks}) doesn't match substitution file")
        
        self.numlens = []
        self.radius = []
        self.materials = []
        self.lens_loc = []
        self.lens_thickerr = []
            
        # get number of lens for each lens stack from lens properties dictionary-list
        print('Getting lens materials...')
        if NLENS_MACRO in macros:
            self.numlens = np.array([int(i) for i in lens_properties[NLENS_MACRO]])
            print('Number of lens read in.\n')
        else:
            raise RuntimeError(f"Number of lenses macro ({NLENS_MACRO}) not found in substituion file")

        # get radii for each lens from lens properties dictionary-list
        print('Getting lens\' radii...')
        if RADIUS_MACRO in macros:
            self.radius = np.array([float(i) for i in lens_properties[RADIUS_MACRO]])
            print('Radius of lenses read in.\n')
        else:
            raise RuntimeError(f"Radius macro ({RADIUS_MACRO}) not found in substituion file")

        # get materials from lens properties dictionary-list
        print('Getting lens materials...')
        if MAT_MACRO in macros:
            self.materials = lens_properties[MAT_MACRO]
            print('Lens material read in.\n')
        else:
            raise RuntimeError(f"Material macro ({MAT_MACRO}) not found in substituion file")
        
        # get densities from local definition (for compounds) or from xraylib (for elements)
        densities = get_densities(self.materials)
        self.densities = np.array([densities[material] for material in self.materials])

        # get location of each lens from lens properties dictionary-list
        print('Getting lens\' locations...')
        if LOC_MACRO in macros:
            self.lens_loc = np.array([float(i)*self.crl['stack_d'] for i in lens_properties[LOC_MACRO]])
            print('Location of lenses read in.\n')
        else:
            raise RuntimeError(f"Location macro ({LOC_MACRO}) not found in substituion file")

        # get thicknesses errprfrom lens properties dictionary-list
        print('Getting lens thickness error...')
        if THICKERR_MACRO in macros:
            self.lens_thickerr = np.array([float(i) for i in lens_properties[THICKERR_MACRO]])
            print('Lens thickness errors read in.\n')
        else:
            raise RuntimeError(f"Thickness errors macro ({THICKERR_MACRO}) not found in substituion file")

        print('Constructing lookup table...')
        self.construct_lookup_table()
        print('Lookup table calculation complete.\n')
        
        print('Transfocator control setup complete.')
        print(80*'#')

    def construct_lookup_table(self):
        arr_a, arr_b, arr_c = calc_lookup_table(self.num_configs, self.radius, 
                                                self.materials, self.energy, self.wl,
                                                self.numlens, 
                                                self.lens_loc, self.beam, self.bl,
                                                self.crl, self.slits[0]['H'], self.slits[0]['V'],
                                                self.lens_thickerr, flag_HE = self.thickerr_flag,
                                                verbose = self.verbose)
        self.lookupTable = arr_a
        self.sorted_invF_index = arr_b
        self.sorted_invF = arr_c                                                            
                                                                    
#        self.sort_lookup_table()
        self.updateEnergyRBV()
        self.updateSlitSizeRBV('hor')
        self.updateSlitSizeRBV('vert')
        self.updateLookupWaveform()
        self.updateInvFWaveform()
        self.updateLookupConfigs()

        
#    def sort_lookup_table(self):
#        '''
#        
#        '''
#        if self.verbose: print(f'Sorting lookup table of length {len(self.lookupTable)}')        
#        self.sorted_L_index = np.argsort(self.lookupTable)

    def updateConfig(self, config_BW):
        '''
        When user manually changes lenses, this gets focal size and displays it
        along with updated RBVs but it doesn't set the config PV
        '''
        self.index = int(config_BW)
        # Find the configuration in the 1/f sorted list
        self.indexSorted = self.sorted_invF_index.tolist().index(self.index)

        self.setFocalSizeActual()
        self.updateLensRBV()
        self.updateFocalSizeRBVs()      



    def setFocalSizeActual(self):
        '''
        
        '''
#        self.focalSize_actual = self.culledTable[self.culledIndex] 
        self.focalSize_actual = self.lookupTable[self.indexSorted] 

    def find_config(self):
        ''' 
        User selected focal size, this function finds nearest acheivable focal 
        size from the lookup table
        '''
        # Code to search lookup table for nearest focal size to desired; note the
        # lookup table is already sorted by 1/f
        if self.verbose: print(f'Searching for config closest to {self.focalSize}')
#       simple approach
#        self.indexSorted = np.argmin(np.abs(self.lookupTable - self.focalSize))

        # XS approach -- can handle nan but in pydev application don't have a good
        # way to "transmit" errors (i.e. no solution found) to user.
        indices, _ = find_levels(self.lookupTable, self.focalSize, direction='forward')[0]
        self.indexSorted = indices[0]
        

        if self.verbose: print(f'1/f-sorted config index found at {self.indexSorted}')

        self.index = self.sorted_invF_index[self.indexSorted]
        if self.verbose: print(f'Config index found at {self.index}')

        # Update PVs
        self.setFocalSizeActual()
        self.updateLensConfigPV()
        self.updateLensRBV()
        self.updateFocalSizeRBVs()      

    def getPreviewFocalSize(self, sortedIndex):
        '''
        
        '''
        fSize_preview = self.lookupTable[sortedIndex]
        if self.verbose: print(f'Preview focal sizes for {sortedIndex} is {fSize_preview}')
        pydev.iointr('new_preview', fSize_preview)
    
    def setSlitSize(self, size, slit):
        '''
        Update proper slit size
        '''
        
        if slit == 'hor':
#            self.slit1_H = float(size)     # H slit size before CRL 1
            self.slits[0]['H'] = float(size)     # H slit size before CRL 1
        elif slit == 'vert':
#            self.slit1_V = float(size)     # V slit size before CRL 1
            self.slits[0]['V'] = float(size)     # V slit size before CRL 1
        else: 
            raise RuntimeError(f"Slit identifier ({slit}) not recognized. Should be 'hor' or 'vert'")
           
                    
    def updateSlitSize(self, size, slit):
        '''
        Slit size updates are propagated to CRL object from EPICS.  The beam
        size lookup table is then recalculated.
        '''
        self.setSlitSize(size, slit)
        
        if self.verbose: 
            if slit == 'hor':
#                print(f'Horizontal slit size updated to {self.slit1_H} m')
                print(f"Horizontal slit size updated to {self.slits[0]['H']} m")
            elif slit == 'vert':
#                print(f'Vertical slit size updated to {self.slit1_V} m')
                print(f"Vertical slit size updated to {self.slits[0]['V']} m")
        
        
    def setEnergy(self, energy):
        '''
        Sets various forms of energy
        '''
        if energy > 0.0001:
            self.energy = float(energy)
            self.energy_eV = self.energy*1000.0  # Energy in keV
            self.wl = 1239.84 / (self.energy_eV * 10**9)    #Wavelength in nm(?)
            if self.verbose: print(f'Setting energy to {self.energy} keV')
            
    def updateE(self, energy):
        '''
        Beam energy updates are propagated to CRL object from EPICS. The beam
        size lookup table is then recalculated.
        '''

        if energy > 0.0001:
            # Energy variable sent from IOC as a string
            self.setEnergy(energy)
            # Update beam properties that are dependent on energy
            self.setupSourceEnergyDependent()
        else:
            if verbose: print(f'Invalid energy setting: {energy} kev; staying at {self.energy} keV')
            
    def updateFsize(self, focalSize):
        '''
        User updates desired focal size. Lookup table is traversed to find nearest
        to desired.
        '''
        # focalPoint variable sent from IOC as a string
        self.focalSize = float(focalSize)
        self.find_config()
        
    def updateIndex(self, sortedIndex):
        '''
        User has updated desired sorted index
        '''
        self.indexSorted = int(sortedIndex)
        self.index = self.sorted_invF_index[self.indexSorted]

        # Update PVs
        self.setFocalSizeActual()
        self.updateLensConfigPV()
        self.updateLensRBV()
        self.updateFocalSizeRBVs()    

    def setThickerrFlag(self, flag):
        '''
        User has updated thickness error flag so that ...
        '''
        self.thickerr_flag = int(flag)
        if self.verbose: print(f'Thickness Error Flag set to {flag}')
        self.updateThickerrFlagRBV()

    def updateThickerrFlagRBV(self):
        '''
        Thickness error flag has been updated
        '''
        if self.verbose: print(f'Thickness Error Flag RBV set to {self.thickerr_flag}')
        pydev.iointr('updated_thickerr_Flag', self.thickerr_flag)

        
    def updateLensConfigPV(self):
        '''
        
        '''
        self.config = self.configs[self.index]
        pydev.iointr('new_lenses', int(self.config))

    def updateLensRBV(self):
        '''
        
        '''
        pydev.iointr('new_index', int(self.indexSorted))

    def updateEnergyRBV(self):
        '''
        
        '''
        pydev.iointr('updated_E', float(self.energy))

    def updateSlitSizeRBV(self, slit):
        '''
        Update proper slit size
        '''
        
        if slit == 'hor':
#            pydev.iointr('updated_slitSize_H', float(self.slit1_H))
#            if self.verbose: print(f'Horizontal slit size RBV updated to {self.slit1_H} m')
            pydev.iointr('updated_slitSize_H', float(self.slits[0]['H']))
            if self.verbose: print(f"Horizontal slit size RBV updated to {self.slits[0]['H']} m")
        elif slit == 'vert':
#            pydev.iointr('updated_slitSize_V', float(self.slit1_V))
#            if self.verbose: print(f'Vertical slit size RBV updated to {self.slit1_V} m')
            pydev.iointr('updated_slitSize_V', float(self.slits[0]['V']))
            if self.verbose: print(f"Vertical slit size RBV updated to {self.slits[0]['V']} m")
            

        
    def updateFocalSizeRBVs(self):
        '''
        
        '''
        pydev.iointr('new_fSize', self.focalSize_actual)
        
    def updateVerbosity(self, verbosity):
        '''
        Turn on minor printing
        '''
        print(f'Verbosity set to {verbosity}')
        self.verbose = int(verbosity)

    def updateLookupWaveform(self):
        '''
        Puts lookup table focal sizes into waveform PV
        '''
        pydev.iointr('new_lookupTable', self.lookupTable.tolist())

    def updateInvFWaveform(self):
        '''
        Puts invF list into waveform PV
        '''
        pydev.iointr('new_invFind_list', self.sorted_invF_index.tolist())
        pydev.iointr('new_invF_list', self.sorted_invF.tolist())

    
    def updateLookupConfigs(self):
        '''
        Puts lookup table config integers into waveform PV
        '''
        pydev.iointr('new_configs', self.configs.tolist())