import numpy as np

from gpaw.utilities import erf


class DipoleCorrectionPoissonSolver:
    """Poisson-solver wrapper that applies a dipole correction along one axis.

    The wrapped solver is handed the density plus a compensating density
    obtained from ``DipoleCorrection``; the matching potential correction is
    then added to the solution in place.
    """

    def __init__(self, poissonsolver, direction):
        self.corrector = DipoleCorrection(direction)
        self.poissonsolver = poissonsolver
        # Fixed values exposed on the wrapper; presumably inspected by code
        # that treats this object like a regular Poisson solver -- TODO confirm.
        self.relax_method = 0
        self.nn = 17

    def set_grid_descriptor(self, gd):
        """Forward the grid descriptor to the wrapped solver."""
        self.poissonsolver.set_grid_descriptor(gd)

    def initialize(self):
        """Initialize the wrapped solver."""
        self.poissonsolver.initialize()

    def solve(self, phi, rho, **kwargs):
        """Solve for ``phi`` (updated in place) with the dipole correction.

        ``rho`` itself is not modified: the corrected density is a new array.
        Keyword arguments are passed straight to the wrapped solver, whose
        return value (presumably an iteration count) is returned unchanged.
        """
        inner = self.poissonsolver
        delta_rho, delta_phi = self.corrector.get_dipole_correction(inner.gd,
                                                                    rho)
        niter = inner.solve(phi, rho + delta_rho, **kwargs)
        phi += delta_phi
        return niter

    def estimate_memory(self, mem):
        """Delegate memory estimation to the wrapped solver."""
        self.poissonsolver.estimate_memory(mem)


class DipoleCorrection:
    """Dipole correction for a density along a single cell axis.

    A localized, antisymmetric compensating density ``drho_g`` is built so
    that the dipole moment of ``rhot_g + drho_g`` along axis ``direction``
    vanishes; ``dphi_g`` is the potential correction paired with it (an erf
    profile along the same axis).
    """

    def __init__(self, direction):
        # Index (0, 1 or 2) of the cell axis along which to correct.
        self.c = direction

    def get_dipole_correction(self, gd, rhot_g):
        """Return ``(drho_g, dphi_g)`` for the density ``rhot_g``.

        If the dipole moment along the axis is below 1e-12, two zero arrays
        from ``gd.zeros()`` are returned.  ``rhot_g`` is not modified.
        """
        c = self.c
        moment = gd.calculate_dipole_moment(rhot_g)[c]
        if abs(moment) < 1e-12:
            return gd.zeros(), gd.zeros()

        r_g = gd.get_grid_point_coordinates()[c]

        # Only the diagonal cell element is used; presumably this assumes an
        # orthorhombic cell -- TODO confirm.
        cellsize = abs(gd.cell_cv[c, c])
        # Coordinate along the axis rescaled so the cell spans [-1, 1).
        sr_g = 2.0 / cellsize * r_g - 1.0
        alpha = 12.0
        drho_g = sr_g * np.exp(-alpha * sr_g**2)
        # Scale the trial density so its moment exactly cancels the original.
        moment2 = gd.calculate_dipole_moment(drho_g)[c]
        factor = -moment / moment2
        drho_g *= factor

        # NOTE(review): drho_g is not explicitly made charge neutral; its net
        # charge depends on how symmetric the grid is around the cell centre.
        phifactor = factor * (np.pi / alpha)**1.5 * cellsize**2 / 4.0
        dphi_g = -phifactor * erf(sr_g * np.sqrt(alpha))
        return drho_g, dphi_g

    def apply_dipole_correction(self, gd, rhot_g):
        """Add the compensating density to ``rhot_g`` in place.

        Returns the potential correction ``dphi_g`` that belongs to the
        added density (to be added to a potential solved from the updated
        ``rhot_g``).
        """
        drho_g, dphi_g = self.get_dipole_correction(gd, rhot_g)
        rhot_g += drho_g
        return dphi_g


class MehDipoleCorrection:
    """Experimental dipole correction along a selectable set of cell axes.

    The compensating density is placed along the direction of the dipole
    moment restricted to the selected axes.

    Parameters
    ----------
    filter_c : sequence of length 3
        Truthy entries mark the cell axes that take part in the correction.
    """

    def __init__(self, filter_c):
        assert len(filter_c) == 3
        # A real boolean array (``map`` would be a one-shot iterator on
        # Python 3) so it can be multiplied with moment vectors repeatedly.
        self.filter_c = np.array([bool(f) for f in filter_c])
        self.directions = [c for c, f in enumerate(filter_c) if f]

    def get_correction_functions(self, scaled_r):
        # NOTE(review): ``get_dipole_correction_functions`` is not defined in
        # the visible part of this module; unless it is provided elsewhere
        # this call raises NameError -- TODO confirm.
        return get_dipole_correction_functions(scaled_r)

    def _get_cell_radius(self, gd):
        """Return the effective cell length ``a`` over the selected axes."""
        # Only diagonal cell elements are used; presumably this assumes an
        # orthorhombic cell -- TODO confirm.
        cell = np.take(gd.cell_cv.diagonal(), self.directions)
        ndirections = len(cell)
        if ndirections == 1:
            return cell[0]
        if ndirections == 2:
            x, y = cell
            return 2 * x * y / (x + y)
        if ndirections == 3:
            x, y, z = cell
            return 2 * x * y * z / (y * z + z * x + x * y)
        raise ValueError('no cell axes selected for dipole correction')

    def apply_dipole_correction(self, gd, rhot_g):
        """Return ``(drho_g, dphi_g)`` correcting the dipole of ``rhot_g``.

        Only the components of the dipole moment selected by ``filter_c``
        are corrected.  If their norm is below 1e-12, two zero arrays from
        ``gd.zeros()`` are returned.  ``rhot_g`` itself is not modified.
        """
        moment_v = gd.calculate_dipole_moment(rhot_g)
        correctible_moment_v = moment_v * self.filter_c
        momentsize = np.linalg.norm(correctible_moment_v)
        if momentsize < 1e-12:
            return gd.zeros(), gd.zeros()

        direction_v = correctible_moment_v / momentsize
        r_vg = gd.get_grid_point_coordinates()
        a = self._get_cell_radius(gd)

        # Scaled coordinate sr = direction . (2 r / a - 1), evaluated on the
        # flattened grid and reshaped back to the grid shape.
        r1_vg = r_vg.reshape(3, -1)
        sr_g = np.dot(direction_v, 2.0 / a * r1_vg - 1.0).reshape(
            r_vg.shape[1:])
        alpha = 16.0
        drho_g = sr_g * np.exp(-alpha * sr_g**2)

        # Scale the trial density so its moment has the size of the
        # correctible moment.
        moment2_v = gd.calculate_dipole_moment(drho_g)
        factor = momentsize / np.linalg.norm(moment2_v)
        drho_g *= factor

        # Remove the net charge by a uniform shift.  NOTE(review): the point
        # count N_c - filter_c presumably assumes the selected axes are
        # non-periodic -- TODO confirm.
        charge = gd.integrate(drho_g)
        drho_g -= charge / gd.dv / np.prod(gd.N_c - self.filter_c)

        phifactor = factor * (np.pi / alpha)**1.5 * a**2 / 4.0
        dphi_g = -phifactor * erf(sr_g * np.sqrt(alpha))
        return drho_g, dphi_g

