#!/usr/bin/env python3

from optparse import OptionParser
from ase.io import read
#from os import system
import numpy as np
from gpaw import GPAW
from gpaw.poisson import PoissonSolver
from gpaw.occupations import FermiDirac
from gpaw.tddft import TDDFT
from gpaw.lrtddft import LrTDDFT
from gpaw.mixer import Mixer
from os import path
from ase.data.molecules import molecule
from gpaw.lcaotddft import LCAOTDDFT
from gpaw.xc.gllb.c_gllbscr import C_GLLBScr
from gpaw.xc.gllb.c_response import C_Response
from gpaw.xc.gllb.c_xc import C_XC
from gpaw.xc.gllb.nonlocalfunctional import NonLocalFunctional


class TimedLCAOTDDFT(LCAOTDDFT):
    """LCAOTDDFT subclass kept as a hook for plugging in a different timer.

    The ``timer_class`` override below is commented out, so this class adds
    nothing on top of its base class at the moment.
    """
    #timer_class = ParallelTimer


class TimedTDDFT(TDDFT):
    """TDDFT subclass kept as a hook for plugging in a different timer.

    The ``timer_class`` override below is commented out, so this class adds
    nothing on top of its base class at the moment.
    """
    #timer_class = ParallelTimer


class TimedLrTDDFT(LrTDDFT):
    """LrTDDFT subclass kept as a hook for plugging in a different timer.

    The ``timer_class`` override below is commented out, so this class adds
    nothing on top of its base class at the moment.
    """
    #timer_class = ParallelTimer


def center_divisible(atoms, vacuum, h, n=4):
    """Center *atoms* so each cell edge holds a multiple of *n* grid points.

    The atoms are first centered with the requested vacuum on every axis.
    Each axis is then re-centered with a slightly larger vacuum, chosen so
    that ``edge / h`` is an integer divisible by *n*.

    atoms:  object providing ``center(vacuum=..., axis=...)`` and
            ``get_cell()``; it is modified in place.
    vacuum: one number used for all three axes, or a sequence with one
            value per axis.
    h:      grid spacing, in the same length unit as the cell.
    n:      required divisor of the grid-point count along each axis.
    """
    if isinstance(vacuum, (int, float)):
        vacuum = (vacuum,) * 3

    # First pass: plain centering with the requested vacuum.
    for axis, amount in enumerate(vacuum):
        atoms.center(vacuum=amount, axis=axis)

    # Only the diagonal of the cell is read, i.e. the cell is assumed to
    # have no off-diagonal components.
    edge_lengths = np.diag(atoms.get_cell())

    # Second pass: pad every axis up to the next multiple of n grid points.
    for axis, amount in zip(range(3), vacuum):
        edge = edge_lengths[axis]
        points = int(np.ceil(edge / h))
        excess = points % n
        if excess:
            points += n - excess
        # Edge length minus the vacuum on both sides is the extent of the
        # atoms themselves along this axis.
        extent = edge - 2 * amount
        padded_vacuum = (points * h - extent) / 2
        atoms.center(vacuum=padded_vacuum, axis=axis)


# Command-line interface.  Every value option is stored as a string (or its
# default); conversion to numbers/tuples happens later in calculate().
# Long-only options are registered without the former '' placeholder:
# optparse discards empty option strings, so the behaviour is identical.
parser = OptionParser()
parser.add_option('--kick', action='store', default='0,0,0.00001',
                  help='Kick strengths')
parser.add_option('--charge', action='store', default=None,
                  help='Charge of the system')
parser.add_option('--parallel_band', action='store', default=None,
                  help='Band parallelization')
parser.add_option('-d', '--directory', action='store', default=None,
                  help='Specify directory for output files')
parser.add_option('-n', '--name', action='store', default=None,
                  help='Prefix name for output files. Otherwise input name is used.')
parser.add_option('--collect_density', action='store', default=None,
                  help='Collect density.\n'
                  '--collect_density=full     collect full density\n'
                  '--collect_density=10       collect bands 0-9 and 10-nbands '
                  'separately.\n'
                  '--collect_density=10,20,35 collect bands '
                  '0-9, 10-19, 20-34, 35-nbands separately.\n')
parser.add_option('--collect_split_density', action='store', default=None,
                  help='Collect two densities, give orbital index which splits them')
parser.add_option('--collect_vresp', action='store_true', default=None,
                  help='Collect response potential of GLLBSC')
parser.add_option('-s', '--system', action='store', default=None,
                  help='Input an example system. Example C6H6')
parser.add_option('-g', '--gridspacing', action='store', default='0.2',
                  help='Specify grid spacing')
parser.add_option('-x', '--xc', action='store', default='LDA',
                  help='Specify xc functional')
parser.add_option('-f', '--fxc', action='store', default=None,
                  help='Specify xc functional for time-propagation separately.')
parser.add_option('-v', '--vacuum', action='store', default='5',
                  help='Specify amount of vacuum')
parser.add_option('-b', '--basis', action='store', default=None,
                  help='Perform LCAO calculation with basis set')
parser.add_option('-t', '--timestep', action='store', default='10',
                  help='Time step for LCAO calculation (in as)')
parser.add_option('-T', '--time', action='store', default='10',
                  help='Simulation time (in fs)')
# --lrtddft takes a value: calculate() evaluates it and hands the result to
# System.set_lrtddft() as the energy range.
parser.add_option('--lrtddft', action='store', default=False,
                  help='Perform linear response calculation')
# The string 'None' is deliberate: calculate() evaluates it to None.
parser.add_option('--nbands', action='store', default='None',
                  help='Number of bands')
parser.add_option('--dxc', action='store_true', default=False,
                  help='Calculate dxc')
parser.add_option('--hard', action='store_true', default=False,
                  help='Try harder to converge the system')
parser.add_option('--nodipole', action='store_true', default=False,
                  help='Do not use dipole correction in Poisson equation. For expert use only.')
parser.add_option('--gsonly', action='store_true', default=False,
                  help='Calculate only the ground state calculation and then stop')

# Parse the command line once, at import time.  `options` and `args` are
# module-level globals: calculate() reads `options`, and the script section
# at the end of the file treats each entry of `args` as an input structure
# file.
(options, args) = parser.parse_args()

class System:
    """A set of atoms together with all settings of one GPAW (TD)DFT run.

    The object is configured through its ``set_*`` methods and then driven
    with :meth:`groundstate` followed by :meth:`tddft`.  Output file names
    are derived from the settings; if a ground-state ``.gpw`` file with the
    derived name already exists, :meth:`groundstate` does not recalculate it.
    """

    def __init__(self, atoms, name='system'):
        """Store *atoms* and *name* and fill in defaults for every setting.

        All attributes read by the file-name helpers get a default here, so
        those helpers work even when a setter was never called.
        """
        self.atoms = atoms
        self.name = name
        self.mode = 'fd'            # switched to 'lcao' by set_basis()
        self.basis = None
        self.dxc = False
        self.vacuum = (5, 5, 5)     # one value per axis, see set_vacuum()
        self.h = 0.2                # grid spacing
        self.nbands = None
        self.xc = 'LDA'
        self.fxc = None
        self.directory = ''         # prefix prepended to every file name
        self.kick = [0.0, 0.0, 0.00001]
        self.dt = 10                # time step (CLI help: attoseconds)
        self.T = 10                 # total time (CLI help: femtoseconds)
        self.lrtddft = False
        self.energy = None          # energy range, only used with lrtddft
        self.hard = False
        self.nodipole = False
        # Passed to PoissonSolver(remove_moment=...); 0 once nodipole is set.
        self.remove_moment = 0 if self.nodipole else 9
        self.bandpar = None
        self.charge = 0
        self.collect_vresp = False
        self.collect_density = None

    # ------------------------------------------------------------------
    # Setters
    # ------------------------------------------------------------------

    def set_collect_vresp(self):
        """Enable attaching a VRespCollector to the calculators."""
        self.collect_vresp = True

    def set_collect_density(self, range_str):
        """Store the band-range string handed to DensityCollector."""
        self.collect_density = range_str

    def set_kick(self, kick):
        """Store the kick vector.

        *kick* may be any iterable of three numbers; it is materialised as a
        list because kick_str() indexes it (a Python 3 ``map`` object would
        not support that).
        """
        self.kick = list(kick)

    def set_charge(self, charge):
        """Set the total charge passed to GPAW."""
        self.charge = charge

    def set_bandpar(self, bandpar):
        """Set the value used for ``parallel={'band': ...}``."""
        self.bandpar = bandpar

    def set_nbands(self, nbands):
        """Set the number of bands passed to GPAW (None is allowed)."""
        self.nbands = nbands

    def set_h(self, h):
        """Set the grid spacing."""
        self.h = h

    def set_vacuum(self, vacuum):
        """Set the vacuum; a single number is expanded to all three axes."""
        if isinstance(vacuum, (int, float)):
            self.vacuum = (vacuum, vacuum, vacuum)
        else:
            self.vacuum = vacuum

    def set_lrtddft(self, energy):
        """Switch to linear-response mode using *energy* as energy range."""
        self.lrtddft = True
        self.energy = energy

    def set_time(self, T):
        """Set the total propagation time."""
        self.T = T

    def set_directory(self, directory):
        """Set the output prefix; None means no prefix.

        The value is prepended verbatim to file names, so a real directory
        must include its trailing separator.
        """
        self.directory = '' if directory is None else directory

    def set_timestep(self, dt):
        """Set the propagation time step."""
        self.dt = dt

    def set_hard(self):
        """Enable the slower mixer (and 'cg' eigensolver where applicable)."""
        self.hard = True

    def set_dxc(self):
        """Enable the delta-xc evaluation after the ground state."""
        self.dxc = True

    def set_xc(self, xc):
        """Set the ground-state xc functional."""
        self.xc = xc

    def set_fxc(self, fxc):
        """Set the xc functional used for the time propagation / response."""
        self.fxc = fxc

    def set_basis(self, basis):
        """Set the basis; any basis other than None selects LCAO mode."""
        self.basis = basis
        if basis is not None:
            self.mode = 'lcao'

    def set_nodipole(self):
        """Disable moment removal in the Poisson solver."""
        self.nodipole = True
        self.remove_moment = 0 if self.nodipole else 9

    # ------------------------------------------------------------------
    # File-name helpers
    # ------------------------------------------------------------------

    def get_base_name(self):
        """Return the name fragment that encodes the ground-state settings."""
        lr = 'lr_' if self.lrtddft else ''
        # groundstate() may have replaced self.xc by a functional object.
        xcname = self.xc if isinstance(self.xc, str) else self.xc.name
        return '%s%s%.2f_%s_%s_%s' % (lr,
                                      '_nodipole' if self.nodipole else '',
                                      self.h, xcname, self.basis,
                                      self.get_vacuum_str())

    def get_vacuum_str(self):
        """Return e.g. 'v555' for a vacuum of (5, 5, 5)."""
        return 'v%1.f%1.f%1.f' % (self.vacuum[0], self.vacuum[1], self.vacuum[2])

    def kick_str(self):
        """Return 'x', 'y' or 'z' for an axis-aligned kick, else the numbers."""
        if self.kick[0] == 0.0 and self.kick[1] == 0.0:
            return 'z'
        if self.kick[1] == 0.0 and self.kick[2] == 0.0:
            return 'x'
        if self.kick[0] == 0.0 and self.kick[2] == 0.0:
            return 'y'
        return '%.6f_%.6f_%.6f' % (self.kick[0], self.kick[1], self.kick[2])

    def get_td_base_name(self):
        """Return the name fragment that encodes the propagation settings."""
        return '%.1f_%.1f_%s_%s' % (self.dt, self.T, self.kick_str(), self.fxc)

    def get_gs_gpw_file(self):
        """Return the ground-state .gpw file name.

        Unlike the TD names there is no '_' between the system name and the
        base name; this is kept so existing files are still found.
        """
        return self.directory + '%s%s.gpw' % (self.name, self.get_base_name())

    def get_full_name(self):
        """Return the common stem of all time-propagation output files."""
        return self.directory + '%s_%s_td_%s' % (self.name, self.get_base_name(), self.get_td_base_name())

    def get_td_gpw_file(self):
        """Return the .gpw file name written after propagation."""
        return self.get_full_name() + '.gpw'

    def get_dipole_file(self):
        """Return the dipole-moment (.dm) file name."""
        return self.get_full_name() + '.dm'

    def get_spectrum_file(self):
        """Return the spectrum (.spec) file name."""
        return self.get_full_name() + '.spec'

    def get_lrtddft_file(self):
        """Return the linear-response (.gz) file name."""
        return self.get_full_name() + '.gz'

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _make_poissonsolver(self):
        """Build the PoissonSolver shared by all calculators of this system."""
        return PoissonSolver(eps=1e-20, remove_moment=self.remove_moment)

    def _attach_collectors(self, calc):
        """Attach the requested vresp/density collectors to *calc*."""
        if self.collect_vresp:
            from gpaw.lcaotddft.split import VRespCollector
            calc.attach(VRespCollector(self.get_full_name(), calc))
        # Truthiness test: None and the empty string both mean "no collector".
        if self.collect_density:
            from gpaw.lcaotddft.split import DensityCollector
            calc.attach(DensityCollector(self.get_full_name(), calc, self.collect_density))

    # ------------------------------------------------------------------
    # Calculations
    # ------------------------------------------------------------------

    def groundstate(self):
        """Calculate the ground state and write it to get_gs_gpw_file().

        If that file already exists nothing is calculated; in linear-response
        mode the file is loaded into ``self.calc`` instead.
        """
        gs_file = self.get_gs_gpw_file()
        if path.exists(gs_file):
            if self.lrtddft:
                self.calc = GPAW(gs_file)
            return

        atoms = self.atoms
        center_divisible(atoms, self.vacuum, self.h)
        # Complex wave functions only for LCAO time propagation.
        dtype = complex if self.mode == 'lcao' and not self.lrtddft else float
        basis = {} if self.basis is None else self.basis
        convergence = {'bands': 'occupied', 'density': 1e-8}
        mixer = Mixer(0.05, 5, weight=100.0) if self.hard else None
        eigensolver = None
        if basis == {} and self.hard and self.bandpar is None:
            eigensolver = 'cg'

        if self.xc == 'GLLBSC':
            self.xc = NonLocalFunctional('GLLBSC')
            scr = C_GLLBScr(self.xc, 1.0, 'GGA_X_PBE_SOL', width=None)  # No width
            # The two objects below are not used further here; presumably
            # their constructors register them with self.xc -- TODO confirm.
            C_Response(self.xc, 1.0, scr)
            C_XC(self.xc, 1.0, 'GGA_C_PBE_SOL')

        calc = GPAW(mode=self.mode, maxiter=300, h=self.h, nbands=self.nbands,
                    xc=self.xc, basis=basis,
                    parallel={'band': self.bandpar}, verbose=True,
                    poissonsolver=self._make_poissonsolver(), dtype=dtype,
                    charge=self.charge, eigensolver=eigensolver,
                    convergence=convergence, occupations=FermiDirac(0.05),
                    mixer=mixer)
        atoms.set_calculator(calc)
        self._attach_collectors(calc)

        atoms.get_potential_energy()

        # NOTE(review): `basis` is never None at this point (it is {} when no
        # basis was given), so this branch cannot run.  `self.basis is None`
        # was probably intended, but enabling it changes the calculation and
        # needs a guard for nbands=None -- left untouched on purpose.
        if self.lrtddft and basis is None:
            calc.set(convergence={'bands': self.nbands - 20}, eigensolver='cg', fixdensity=True)
            atoms.get_potential_energy()

        if self.dxc:
            response = calc.hamiltonian.xc.xcs['RESPONSE']
            response.calculate_delta_xc()
            EKs, Dxc = response.calculate_delta_xc_perturbation()
        # The name is rebuilt here because self.xc may have been replaced by
        # a functional object above.
        calc.write(self.get_gs_gpw_file(), mode='all')
        if self.lrtddft:
            self.calc = calc
        else:
            del calc

    def tddft(self):
        """Run the response calculation on top of the ground state.

        In linear-response mode an LrTDDFT object is built from ``self.calc``
        (set by groundstate()), written, diagonalized and turned into a
        spectrum file.  Otherwise the ground-state file is loaded, kicked and
        propagated in time; the dipole file and a final .gpw are written.
        """
        fxc = self.fxc
        if self.lrtddft:
            lr = TimedLrTDDFT(self.calc, xc=fxc, energy_range=self.energy)
            lr.write(self.get_lrtddft_file())
            from gpaw.lrtddft import photoabsorption_spectrum as spec
            lr.diagonalize()
            spec(lr, self.get_spectrum_file(), width=0.2)
            return

        if self.mode == 'lcao':
            td_calc = self.td_calc = TimedLCAOTDDFT(
                self.get_gs_gpw_file(), fxc=fxc,
                parallel={'band': self.bandpar},
                poissonsolver=self._make_poissonsolver())
            td_calc.occupations.calculate(td_calc.wfs)
            td_calc.kick(strength=self.kick)
            self._attach_collectors(td_calc)
        else:
            td_calc = TimedTDDFT(
                self.get_gs_gpw_file(),
                parallel={'band': self.bandpar},
                poissonsolver=self._make_poissonsolver())
            td_calc.occupations.calculate(td_calc.wfs)
            td_calc.absorption_kick(kick_strength=self.kick)
            if fxc:
                td_calc.hamiltonian.linearize_to_xc(fxc)
            self._attach_collectors(td_calc)

        # 1000 * T / dt: number of steps when T is in fs and dt in as.
        td_calc.propagate(self.dt, int(1000 * self.T / self.dt), self.get_dipole_file(), dump_interval=25)
        td_calc.write(self.get_td_gpw_file(), mode='all')
        del td_calc

 
def calculate(atoms, name=None):
    """Configure a System from the parsed command-line options and run it.

    atoms: the structure to calculate.
    name:  prefix used for the output file names.

    Reads the module-level ``options``.  Runs the ground state and, unless
    ``--gsonly`` was given, the TDDFT / linear-response step afterwards.

    NOTE(review): numeric options are converted with eval(), which executes
    arbitrary expressions.  That is acceptable for a trusted local command
    line only; use ast.literal_eval if arguments can come from elsewhere.
    """
    if options.collect_split_density:
        # System has no set_collect_split_density() (it is commented out
        # there), so this option used to die with an AttributeError.
        parser.error('--collect_split_density is not supported; '
                     'use --collect_density=<band index> instead')

    system = System(atoms, name)
    system.set_h(eval(options.gridspacing))
    system.set_vacuum(eval(options.vacuum))
    system.set_basis(options.basis)
    system.set_time(eval(options.time))
    system.set_timestep(eval(options.timestep))
    system.set_nbands(eval(options.nbands))
    system.set_xc(options.xc)
    system.set_fxc(options.fxc)
    system.set_directory(options.directory)
    # A real list: System indexes the kick, which a Python 3 map object
    # does not support.
    system.set_kick([eval(component) for component in options.kick.split(',')])
    system.set_collect_density(options.collect_density)

    if options.collect_vresp:
        system.set_collect_vresp()
    if options.charge:
        system.set_charge(eval(options.charge))
    if options.nodipole:
        system.set_nodipole()
    if options.dxc:
        system.set_dxc()
    if options.hard:
        system.set_hard()
    if options.lrtddft:
        system.set_lrtddft(eval(options.lrtddft))
    if options.parallel_band:
        system.set_bandpar(eval(options.parallel_band))

    system.groundstate()
    if options.gsonly:
        return
    system.tddft()

# Example system requested with -s/--system: built with molecule() and named
# after the option value unless -n/--name overrides it.
if options.system is not None:
    name = options.system if options.name is None else options.name
    calculate(molecule(options.system), name=name)

# Every positional argument is read as a structure file with ase.io.read and
# calculated in turn; the file name is the default output prefix.
for system in args:
    name = system if options.name is None else options.name
    atoms = read(system)
    calculate(atoms, name=name)
