import sys

import pylab as pl
from gpaw.atom import generator, basis

import plt

# --- Setup: reference data for one nitrogen orbital --------------------
# Every name bound in this section is read by the figure sections below.

# Confinement parameters handed to g.get_confinement_potential further
# down: outer radius rc, inner radius ri (60% of rc) and prefactor A.
rc = 4.0
ri = 0.6 * rc
A = 12.0

symbol = 'N'
g = generator.Generator(symbol)
# Double g.N before running (presumably the radial grid size -- TODO
# confirm against gpaw.atom.generator).
g.N *= 2
g.run(write_xml=False, **generator.parameters[symbol])
bm = basis.BasisMaker(g, run=False)

# Index -2 into the generator's per-state arrays selects the orbital that
# all figures describe; n and l give its "n" + "spd"-letter label.
j = -2
l, n = g.l_j[j], g.n_j[j]

vconf = g.get_confinement_potential(A, ri, rc)
u_free = g.u_j[j]
# Solve the same state twice: first without a potential (hard cutoff at
# rc), then with the smooth vconf.  `e` keeps the energy of the last solve.
u_hard, e = g.solve_confined(j, rc)
u_soft, e = g.solve_confined(j, rc, vconf)

# Smoothed (pseudo) counterparts of the three radial functions, in the
# same free -> hard -> soft order as above.
phit_free, phit_hard, phit_soft = [bm.smoothify(u, l)
                                   for u in (u_free, u_hard, u_soft)]

# --- Figure 1: free vs. confined all-electron orbital and its smoothed
# pseudo version, written out by plt.done('orbitaltransform') ----------

# Axis-label and marker strings; xlabel, ylabel and rcstring are also
# reused by later sections of this script.
xlabel = r'$r$'
ylabel = r'$r\phi(r)$'
rcstring = r'$r_{\mathrm{conf}}$'
orbital = '%d%s' % (n, 'spd'[l])

plt.mkaxis(rc)
pl.title('Transformation of %s %s orbital' % (plt.get_name(symbol), orbital))

# Curves are drawn (and therefore listed in the legend) in this order;
# zorder=42 lifts the dashed confined AE curve above the others.
for wave, fmt, label, extra in ((u_free, 'g', 'AE, free', {}),
                                (u_hard, 'r--', 'AE, confined',
                                 {'zorder': 42}),
                                (phit_hard, 'b', 'PS, confined', {})):
    pl.plot(g.r, wave, fmt, label=label, **extra)

plt.labels(xlabel, ylabel)

# Mark rc as represented on the radial grid (value at index g.r2g(rc)),
# and the generator's per-l cutoff radius for this orbital.
actual_rc = g.r[g.r2g(rc)]
plt.vert(actual_rc, rcstring, align='right')
plt.vert(g.rcut_l[l], r'$r_{\mathrm{cut}}$', align='left')

pl.legend()
pl.axis('tight')
# Keep the tight fit but show a little beyond the confinement radius.
pl.axis(xmax=rc * 1.05)

plt.done('orbitaltransform')

#######################################################

#pl.figure()
#pl.title('Confinement potentials for %s %d%s-orbital' % (symbol, n, 
#                                                         'spd'[l]))

#pl.plot(g.r, phit_free, 'r', label='Free atom')
#pl.plot(g.r, phit_hard, 'g', label='Infinite')
#pl.plot(g.r, phit_soft, 'b', label='Smooth')
#ax = pl.axis()
#pl.plot(g.r, vconf/vconf[g.r2g(rc * .98)], 'm--', 
#        label='Potential', zorder=-42)

#plt.vert(actual_rc, rcstring)
#plt.vert(g.r[g.r2g(ri)], r'$r_i$')

#plt.labels(xlabel, ylabel + ', $V_\mathrm{conf}(r)$')

#pl.legend()
#pl.axis(ax)
#pl.axis(xmax=rc * 1.1)
#pl.axis(ymin=0.)

#plt.done('confinement')

##################################################

pl.figure()

# --- Figure 2: split-valence (multiple-zeta) construction from the
# soft-confined pseudo orbital, written out by plt.done('splitvalence') --

# rsplit_by_norm receives the squared tail norm and returns the split
# radius, a partial norm (not used below) and the split wave.  sys.stdout
# is presumably the stream it reports to -- TODO confirm in gpaw.atom.basis.
tailnorm = .17
rsplit, partial_norm, splitwave = bm.rsplit_by_norm(l, phit_soft,
                                                    tailnorm ** 2.,
                                                    sys.stdout)

pl.plot(g.r, phit_soft, label='SZ function')
pl.plot(g.r, splitwave, label='polynomial/tail')
# The second-zeta function is the difference of the two curves above.
pl.plot(g.r, phit_soft - splitwave, label='DZ function')
plt.vert(rsplit, r'$r_\mathrm{split}$')
# Mark the confinement radius the same way as the first figure: use the
# grid-snapped value and the same label.  Pass `align` by keyword, as the
# other plt.vert calls in this script do, instead of relying on it being
# the third positional parameter.
plt.vert(g.r[g.r2g(rc)], r'$r_{\mathrm{conf}}$', align='right')
pl.axis(xmax=rc * 1.1)

pl.legend()
pl.title('Multiple-zeta generation, %s %d%s orbital' % (plt.get_name(symbol),
                                                        n, 'spd'[l]))
plt.labels(xlabel, ylabel)

plt.done('splitvalence')

