#!/usr/bin/env python

import numpy as np
import pylab as pl

# Equation:
#
#   x'' + 2(eps - v) x = 0
#
# With v=-60 for r < 1, v = 0 otherwise, a solution will oscillate inside
# the well but decay outside, if its eigenvalue eps satisfies v < eps < 0.
#
# Inside:
#   x(r) = A sin kr
#
# Outside:
#   x(r) = B exp(-K r)
#
# We'll have to match the values and derivatives at the boundary.  This is
# equivalent to matching the logarithmic derivatives x'(r)/x(r) which will
# require a bit of work.

# Potential inside well:
v = -60.0

# We know that k and K are related to the eigenvalue epsilon as per these
# formulas:
def get_k(eps):
    """Return the wave number k = sqrt(2 (eps - v)) inside the well.

    Real for eps >= v; NaN (with a NumPy warning) for eps < v.
    """
    kinetic = eps - v
    return np.sqrt(2. * kinetic)


def get_K(eps):
    """Return the decay constant K = sqrt(-2 eps) outside the well.

    Real for eps <= 0; NaN (with a NumPy warning) for eps > 0.
    """
    return np.sqrt(eps * -2.)

# This is the logarithmic derivative of a sine:
def kcotk(k):
    """Return k cot k, the logarithmic derivative at r = 1 of x(r) = sin(k r).

    Works elementwise on arrays; infinite/NaN where sin k == 0.
    """
    cosine = np.cos(k)
    sine = np.sin(k)
    return k * cosine / sine

# For an exponential the logarithmic derivative is simply -K.
# To match this we must therefore solve k cot k = -K.
#
# Both k and K are functions of eps.  Let's make a plot.
# Energy grid for the plot (reused further down for the logarithmic-derivative
# figure).  It starts slightly below the well bottom v, where 2 (eps - v) < 0,
# so get_k returns NaN there (with a NumPy warning); those points are simply
# not drawn.
eps = np.linspace(-61.0, 0.0, 1000)
k = get_k(eps)
K = get_K(eps)

FIT_FIGURE = 1 # figure numbers, so later code can switch back to this plot
pl.figure(FIT_FIGURE)
pl.title('Fitting logarithmic derivatives')
# Eigenvalues are where these two curves cross (k cot k = -K).
pl.plot(eps, kcotk(k), label='k cot k')
pl.plot(eps, -K, label='-K')
pl.xlabel('epsilon')
pl.legend()

# From the plot we see three apparent roots.  Let's use the scipy
# root-finding function fsolve to get more precise values.
#
# Define a function which is zero when k cot k = -K:
def kcotk_minus_K(eps):
    """Return the matching residual k cot k + K at energy eps.

    Zero exactly when the inner log derivative k cot k equals the outer one,
    -K, i.e. when eps is a bound-state eigenvalue of the well.
    """
    inner_logderiv = kcotk(get_k(eps))
    decay = get_K(eps)
    return inner_logderiv + decay

# We solve it by calling:
#   root = fsolve(function, initial_guess)
from scipy.optimize import fsolve
# One initial guess per crossing seen in the plot (1s, 2s, 3s, lowest first).
# fsolve returns a length-1 array for a scalar guess, so eps0 has shape (3, 1).
eps0 = np.array([fsolve(kcotk_minus_K, -55.8),
                 fsolve(kcotk_minus_K, -43.5),
                 fsolve(kcotk_minus_K, -23.6)])

# eps0 are now the actual eigenvalues with matching logarithmic derivatives.
#
# Let's evaluate the corresponding k and K values:
k0 = get_k(eps0)
K0 = get_K(eps0)

# Print each eigenvalue and add it as a vertical line in the previous plot.
# Labels beginning with an underscore are left out of the legend.
for k1, K1, eps1 in zip(k0, K0, eps0):
    print 'Eigenvalue, k, K: %7.4f %7.4f %7.4f' % (eps1, k1, K1)
    pl.figure(FIT_FIGURE)
    pl.plot([eps1, eps1], [-100, 100], 'k--', label='__nolegend__')
# Zoom in vertically so the crossings stay visible despite the poles of k cot k.
pl.axis(ymin=-15, ymax=5)

# Now that we know k and K, we can plot all the solutions.
# Radial grids for plotting functions inside and outside r=1:
r1 = np.linspace(0.0, 1.0, 1000)
r2 = np.linspace(1.0, 2.0, 1000)
# r1[-1] and r2[0] are both exactly r = 1, the edge of the well.
WF_FIGURE = 2
pl.figure(WF_FIGURE)
pl.title('Wave functions')
# For each eigenstate: sin(k r) inside the well, and exp(-K r) outside,
# scaled so that both curves take the same value at r = 1.
for eps1, k1, K1 in zip(eps0, k0, K0):
    pl.plot(r1, np.sin(k1 * r1), label='inner[%.2f]' % eps1)

    # Make sure the plotted functions match at boundary:
    factor = np.sin(k1 * r2[0]) / np.exp(-K1 * r2[0])
    pl.plot(r2, np.exp(-K1 * r2) * factor, label='outer[%.2f]' % eps1)

#-----------------------------------------------------------------------

# We now throw the 1s and 2s states away, and want to construct a
# pseudopotential to describe the 3s state only.
#
#                              ~
# We define the PS wave as sin kr.
#
#                 ~
# We have to find k of the pseudo wave function.
#
# We know it must match the decaying exponential outside and have the same
# eigenvalue as the known 3s solution.
#      ~     ~
# Thus k cot k = -K
#
# The 3s state is the 'last' one.  Extract numbers from arrays for convenience:
# Take the last (3s) entry of each array.  Each entry is itself a length-1
# array, because the eigenvalues came from fsolve.
K3s, eps3s, k3s = K0[-1], eps0[-1], k0[-1]

# We'll use fsolve again:
def ktfunc(k):
    """Return k cot k + K3s; its root is the pseudo wave number k-tilde."""
    residual = kcotk(k) + K3s
    return residual
# Solve kt cot kt = -K3s.  The guess 2.0 lies below pi, and so does the root
# found there (~2.76), which makes sin(kt r) nodeless for 0 < r <= 1.
kt = fsolve(ktfunc, 2.0) # kt means 'k-tilde'
print 'K3s', K3s
print 'k-tilde', kt # ~2.76
# Compare this to the 1s value of k which is around 2.88

# Add the pseudofunction to the previous wavefunction plot:
pl.figure(WF_FIGURE)
# Scale the pseudofunction to the AE 3s curve sin(k3s r) at r = 1
# (k0[-1] is the 3s wave number, i.e. k3s).
factor = np.sin(k0[-1] * r2[0]) / np.sin(kt * r2[0])
pl.plot(r1, np.sin(kt * r1) * factor, '--', label='3s PS WF')

# The pseudopotential is found by rewriting the Schroedinger equation:
#
#     1 ~     ~ ~         ~
#   - - x'' + v x = eps   x
#     2                3s
#
# yielding:
#                  ~
#   ~           1  x''            1 ~2
#   v = eps   + - ----  = eps   - - k
#          3s   2  ~         3s   2  3s
#                  x
#
# In other words the pseudopotential is equal to the all-electron potential,
# but just shifted by a constant.  This is hardly surprising as we have
# required the pseudofunction to be sinusoidal, so no matter what we should
# end up with the same differential equation as before.
#
# Constant pseudopotential inside the well, from the formula above.
vt = eps3s - 0.5 * kt**2.0
print 'pseudopotential inside:', vt # -27'ish

# Let's also make a plot of all the potentials for comparison.
# Of course they're all straight lines so far.
POT_FIGURE = 3
pl.figure(POT_FIGURE)
pl.title('Potentials')
# Inside (r1): the all-electron well depth v and the constant pseudopotential
# vt.  Outside (r2): both potentials are zero.
pl.plot(r1, np.ones_like(r1) * v, label='AEP')
pl.plot(r1, np.ones_like(r1) * vt, label='PSP')
pl.plot(r2, np.zeros_like(r2), label='outside')


# Plot the logarithmic derivatives at r = 1 of the all-electron and the
# (constant-potential) pseudo wave functions as a function of energy:
LOGDERIV_FIGURE = 4


def get_ktilde(eps):
    """Return the pseudo wave number sqrt(2 (eps - vt)) in the constant PS well.

    NaN (with a NumPy warning) for eps < vt.
    """
    above_bottom = eps - vt
    return np.sqrt(2. * above_bottom)
# Logarithmic derivatives k cot k at r = 1 as a function of energy, for the
# all-electron well (AE) and the constant pseudopotential (PS).  By
# construction both equal -K3s at eps3s, so the curves meet at the dashed
# line.  get_ktilde gives NaN for eps < vt, so that part of PS is not drawn.
pl.figure(LOGDERIV_FIGURE)
pl.title('Logarithmic derivatives')
pl.plot(eps, kcotk(get_k(eps)), label='AE')
pl.plot(eps, kcotk(get_ktilde(eps)), label='PS')
pl.plot([eps3s, eps3s], [-50, 50], 'k--', label='3s eig.')
pl.axis([-30, 0, -50, 50])

#-----------------------------------------------------------------

# Next we want to find a *normconserving* pseudopotential.  This is a bit
# more complicated, so we'll do this numerically.

# Calculate the norm of the AE 3s state by integration.
# Grid spacing of r1.  Use the difference of neighbouring points: r1[1]
# alone equals the spacing only because r1 happens to start at 0.
dr = r1[1] - r1[0]
# Plain sum-times-dr quadrature (both endpoints weighted fully).  get_norm uses
# the same rule, so the AE and pseudo norms are computed consistently.
norm1 = (np.sin(k3s * r1)**2).sum() * dr / np.sin(k3s)**2
# (The division by sin(k3s)**2 defines the function to be 1 at the boundary,
# thus the function is not actually normalized.  But we don't care about that,
# as we just have to construct something with the *same* norm.)
# Report the reference norm that the norm-conserving pseudofunction must match.
print '3s AE state norm', norm1

# Strategy:
#
# We define our pseudofunction as xt = r + a r**3 + b r**5.
#
# Matching the logarithmic derivative imposes one requirement, meaning
# that a can be expressed in terms of b.  That leaves b as a free parameter,
# which we will adjust such that the norm of the pseudofunction equals that
# of the AE function.
#
# With paper and pencil we find the following formula for a given b, such
# that the pseudofunction matches the 3s logarithmic derivative:
def get_a(b):
    """Return the r**3 coefficient a that matches the 3s logarithmic derivative.

    For xt = r + a r**3 + b r**5 the condition xt'(1) / xt(1) = -K3s reads
    1 + 3a + 5b = -K3s (1 + a + b).  This is linear in a and b, and solving
    it for a gives the expression below.
    """
    numerator = K3s + 1. + (5. + K3s) * b
    return -numerator / (3. + K3s)

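# Next we have to find b to conserve the norm.  Since the pseudofunction is
# a polynomial, we could square it and integrate analytically.  The square is
# a 10th-order polynomial in r, but a depends linearly on b, so the resulting
# norm condition is only a quadratic equation in b (hence two solutions).
# Rather than work out its coefficients by hand, we solve it numerically.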

# This function calculates the norm of the pseudofunction given b.
def get_norm(b):
    """Return the norm of xt = r + a r**3 + b r**5 on the inner grid r1.

    a is tied to b by get_a, so that xt matches the 3s logarithmic derivative.
    The integral uses the same plain sum-times-dr rule as norm1.  xt is not
    rescaled: its value at r = 1 is 1 + a + b.
    """
    coeff_a = get_a(b)
    pseudo = r1 + coeff_a * r1**3 + b * r1**5
    return np.sum(pseudo**2) * dr

# This function calculates the absolute deviation of the norm from the 3s AE
# norm.  We want to find the zero of this function using e.g. fsolve.
def deviation(b): 
    return get_norm(b) - norm1

# This is the deviation of the norm from the target value as a function of b:
NORM_OPTIMIZATION_FIGURE = 5
pl.figure(NORM_OPTIMIZATION_FIGURE)
pl.title('Enforcing norm conservation')
B0 = np.linspace(-10., 10., 50)
pl.plot(B0, [deviation(B) for B in B0])
pl.xlabel('b')
pl.ylabel('error in norm')
# There are two solutions in this area.

# We'll go with the -5'ish solution.  If we choose the other one, that
# function will also match the logarithmic derivative, but have the
# wrong sign - i.e. it is not continuous unless we change its sign, at
# which point it is no longer differentiable.

bvalue = fsolve(deviation, -5.0)
alternative_bvalue = fsolve(deviation, 5.0)
avalue = get_a(bvalue)
print 'a, b = ', avalue, bvalue
print '(Or a, b = %s, %s)' % (get_a(alternative_bvalue), alternative_bvalue)


print 'Value of b and error in norm', bvalue, deviation(bvalue)

# This is the normconserving pseudofunction, then:
phit = r1 + avalue * r1**3 + bvalue * r1**5

# Check that things are correct...
print 'Norm of pseudofunction', (phit**2).sum() * dr

# The pseudopotential is found like in the previous case by
# differentiating twice and dividing (vt = eps + (1/2) xt''/xt):
def get_vtnc(r, b=bvalue):
    """Return the norm-conserving pseudopotential at radius r inside the well.

    For xt = r + a r**3 + b r**5:
        xt''/xt = (6 a + 20 b r**2) / (1 + a r**2 + b r**4)
    so
        vt = eps3s + xt''/(2 xt)
           = eps3s + (3 a + 10 b r**2) / (1 + a r**2 + b r**4).

    r may be a scalar or a NumPy array.  b defaults to the norm-conserving
    solution bvalue.  a is always derived from the b passed in, so the
    (a, b) pair matches the 3s logarithmic derivative.
    """
    # Use the caller's b, not the module-level bvalue, so that non-default
    # b values yield a consistent potential.
    a = get_a(b)
    numerator = 3 * a + 10 * r**2 * b
    denominator = 1 + a * r**2 + b * r**4
    return eps3s + numerator / denominator

# Now plot stuff:
# Norm-conserving pseudopotential on the inner grid (default b = bvalue).
vtnc = get_vtnc(r1)
pl.figure(POT_FIGURE)
pl.plot(r1, vtnc, label='NCPSP')

pl.figure(WF_FIGURE)
# Value of the plotted AE 3s curve sin(k3s r) at r = 1.  The pseudofunction
# is rescaled so that it takes this value there.
matchvalue = np.sin(k3s * r2[0])
pl.plot(r1, phit * matchvalue / phit[-1], '--', label='3s normcons')

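# (We rescale phit so that it equals the plotted AE 3s curve, sin(k3s r), at
# r = 1.  The polynomial's own boundary value is 1 + a + b.  In general that
# differs both from this curve's value and from the value 1 used for the AE
# norm, and the other curves here are arbitrarily scaled anyway.)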


#--------------------------------------------------------------

# Finally we want to calculate the logarithmic derivatives as a function
# of various energies.

# This we want to consider a range of energies, and for each energy
# solve the differential equation.  For each solution we can then get
# the logarithmic derivatives.

# Unfortunately we cannot solve the differential equation analytically
# anymore, since the pseudopotential is a bit more complicated now.
# We'll have to use a numerical method again:

from scipy.integrate import odeint

# Convert the second-order differential equation into a system of two
# first-order ones:
#
#   -y'/2 + vx = eps x
#            y = x'
#
# This function integrates the above differential equations given any energy,
# then returns the solution as a matrix:
def solve_with_eps(eps1):
    """Integrate the radial equation with the NC pseudopotential at energy eps1.

    Solves x'' = 2 (vtnc(r) - eps1) x on the grid r1 as the first-order
    system X = (x, y) with y = x'.  It starts from x(0) = 0, x'(0) = 1, the
    same start as xt = r + a r**3 + b r**5.

    Returns the odeint solution: one row per point of r1, with columns
    (x, x').
    """
    def derivs(X, r):
        x, y = X
        xprime = y
        yprime = 2.0 * (get_vtnc(r) - eps1) * x
        # get_vtnc(r) (and eps1) may be length-1 arrays, since bvalue and the
        # eigenvalues come from fsolve.  hstack flattens both entries, so
        # odeint always receives a shape-(2,) vector instead of a ragged
        # [scalar, array] pair, which newer NumPy refuses to build.
        return np.hstack([xprime, yprime])
    return odeint(derivs, np.array([0.0, 1.0]), r1)

# As a sanity check, verify that the numerical function gets the same
# solution as the defined PS wave function at the eigenvalue:
# Integrate at the 3s eigenvalue and keep column 0, the wave function x.
phit3s = solve_with_eps(eps3s)[:, 0]

pl.figure(WF_FIGURE)
# Rescaled like the analytic curve, so the two should lie on top of each other.
pl.plot(r1, phit3s * matchvalue / phit3s[-1], ':', 
        label='3s NC numeric')
# Looks okay.

# Now for each energy in some grid, calculate the logarithmic derivative:
# 200 energies within +/- 2 of eps3s.  The ratio x'(1)/x(1) diverges at
# energies where the solution has a node exactly at r = 1.
epsrange = np.linspace(-2., 2., 200) + eps3s
lderivs = []
for eps1 in epsrange:
    solution = solve_with_eps(eps1)
    last_phi, last_deriv = solution[-1] # value and derivative at r=1
    logderiv = last_deriv / last_phi
    lderivs.append(logderiv)

# Add logarithmic derivatives to previous plot:
pl.figure(LOGDERIV_FIGURE)
pl.plot(epsrange, lderivs, label='3s NC')
# Same window as before, so the NC curve can be compared with AE and PS.
pl.axis([-30, 0, -50, 50])
# It WORKS.  Another job well done!

# Draw legends for the figures with labelled curves (the norm-optimization
# figure has none), then display everything.
for fig in [LOGDERIV_FIGURE, WF_FIGURE, FIT_FIGURE, POT_FIGURE]:
    pl.figure(fig)
    pl.legend()
pl.show()

