import os
import sys
from optparse import OptionParser

import numpy as np
from ase import Atoms
from ase.io import read, write
from ase.visualize import view
from ase.calculators.neighborlist import NeighborList

# Command-line interface.  NAME is a label that becomes part of every output
# file name; the FILE arguments are cluster geometries, each one being the
# "small" reference for the file that follows it.
p = OptionParser(usage='%prog NAME FILE...')
# --seed stays a string on purpose: main() splits it on commas so several
# seeds can be given at once (e.g. "0,1,2").
p.add_option('--seed',
             default='0',
             help='random seed')
p.add_option('--write', action='store_true',
             help='actually write files')
p.add_option('--plot', action='store_true',
             help='make a beautiful plot of the construction procedure'
             ' for given sizes.')
p.add_option('--ontop', type=int, metavar='ATOM',
             help='move adsorbate on top of ATOM')
p.add_option('--central', action='store_true',
             help='ignore adsorbate')
opts, args = p.parse_args()

# Usage mistakes are reported through the parser (usage text + exit status 2)
# instead of a bare IndexError / AssertionError; unlike an assert, these
# checks also survive ``python -O``.
if not args:
    p.error('missing NAME argument')

name = args[0]
# NAME is a label, not an input.  If it names an existing file the user most
# likely forgot NAME and passed a geometry file in its place.
if os.path.isfile(name):
    p.error('NAME must not be an existing file: %r' % name)

# args1/args2 are the file list and the same list shifted by one, so zipping
# them yields consecutive (small, large) file pairs.
args1 = args[1:]
args2 = args[2:]


def get_coordination_numbers(system, maxbondlength=3.0):
    """Return the number of neighbours of every atom in *system*.

    Parameters:
        system: the atoms object handed to ``NeighborList.update``.
        maxbondlength: every atom gets a cutoff radius of
            ``maxbondlength / 2``.  The neighbour list is built with its
            default settings otherwise, so the effective pair cutoff may
            include the library's default skin -- confirm against the ASE
            documentation.

    Returns:
        Integer NumPy array of length ``len(system)``; entry ``a`` is the
        number of neighbours reported for atom ``a`` (``bothways=True``, so
        each bond is counted once for each of its two atoms).
    """
    # The original body referred to an undefined name ``cluster`` instead of
    # the ``system`` argument, so the function could not work as called.
    radius = maxbondlength / 2.0
    natoms = len(system)
    coordination = np.zeros(natoms, dtype=int)
    neighbors = NeighborList([radius] * natoms, self_interaction=False,
                             bothways=True)
    neighbors.update(system)
    for a in range(natoms):
        indices, offsets = neighbors.get_neighbors(a)
        assert len(offsets) == len(indices)
        coordination[a] += len(indices)
    return coordination

def main(smallfname, largefname, name):
    """Peel atoms off the large cluster until it matches the small one.

    Both geometries are read with ``ase.io.read``.  For every seed given by
    ``--seed`` the atoms of the large cluster that have no counterpart in the
    small cluster are removed one at a time, always picking (at random) among
    the removable atoms with the lowest current coordination number.  After
    every removal ``save()`` writes (or, without ``--write``, only announces)
    the intermediate geometries.  With ``--plot`` an EPS figure of selected
    intermediate clusters is produced.

    Parameters:
        smallfname: file with the smaller cluster (the end point).
        largefname: file with the larger cluster (the starting point).
        name: label inserted into the output file names; may be ''.

    Reads the module-level ``opts`` for all command-line switches.
    """
    small = read(smallfname)
    large = read(largefname)
    assert len(small) < len(large)
    
    if opts.central:
        # --central: replace the adsorbate by a single O atom at the centroid.
        # If the small cluster ends with oxygen (Z=8) the last atom is dropped
        # from BOTH clusters (assumes large ends with the same adsorbate --
        # TODO confirm).
        if small.numbers[-1] == 8:
            small.pop(-1)
            large.pop(-1)
        small.center()
        large.center()
        # Shift each cluster so that its mean position is the origin.
        cen_s = small.positions.sum(axis=0) / len(small)
        cen_l = large.positions.sum(axis=0) / len(large)
        small.positions -= cen_s
        large.positions -= cen_l
        # Append a fresh O atom; presumably created at the origin, i.e. at the
        # centroid after the shift above -- confirm ASE's default position.
        small += Atoms('O')
        large += Atoms('O')

    small.center(vacuum=5.0)
    large.center(vacuum=5.0)
    # Number of atoms in `large` whose element differs from that of atom 0.
    # Used below as the adsorbate size; the negative indexing assumes those
    # atoms are stored last -- TODO confirm.
    Na = len(large[large.numbers != large.numbers[0]])
    
    # --seed is a comma-separated list of integers; one full peel per seed.
    seeds = map(int, opts.seed.split(','))
    for seed in seeds:
        gen = np.random.RandomState(seed)

        # Working copy; removed atoms stay in place as "ghosts" with Z == 0 so
        # that atom indices remain stable during the whole peel.
        cluster = large.copy()

        maxbondlength = 3.0

        # Initial coordination numbers of the full cluster (same computation
        # as the module-level get_coordination_numbers, written inline).
        radius = maxbondlength / 2.0
        coordination = np.zeros(len(cluster), dtype=int)
        neighbors = NeighborList([radius] * len(cluster), self_interaction=False,
                                 bothways=True)
        neighbors.update(cluster)
        for a in range(len(cluster)):
            indices, offsets = neighbors.get_neighbors(a)
            assert len(offsets) == len(indices)
            #for b in indices:
            #    assert cluster.numbers[b] != 0
            #    coordination[b] += 1
            coordination[a] += len(indices)

        def get_neighbors(a):
            """Return the neighbours of atom `a` that are not ghosts (Z != 0)."""
            indices, offsets = neighbors.get_neighbors(a)
            return [index for index in indices if cluster.numbers[index] != 0]


        # Disabled alternative implementation; never executed.
        if 0:
            def get_coordination_numbers(system, maxbondlength):
                radius = maxbondlength / 2.0
                coordination = np.zeros(len(system), dtype=int)
                cutoffs = (system.numbers != 0).astype(int) * radius
                #print cutoffs, system.numbers
                assert len(cutoffs) == len(coordination), (len(cutoffs), len(coordination))
                neighbors = NeighborList(cutoffs, self_interaction=False)
                neighbors.update(system)
                for a in range(len(system)):
                    #if system.numbers[a] == 0:
                    #    continue
                    indices, offsets = neighbors.get_neighbors(a)
                    if system.numbers[a] == 0:
                        assert len(indices) == 0, (indices, system[indices])
                    assert len(offsets) == len(indices)
                    for b in indices:
                        assert system.numbers[b] != 0
                        coordination[b] += 1
                    coordination[a] += len(indices)
                return coordination


        # Reference atom used to align the two clusters: the atom at index
        # -Na in each (presumably the first adsorbate atom -- TODO confirm).
        sm_pos_O = small.positions[-Na]
        lg_pos_O = large.positions[-Na]
        # Centroids of the first len - Na atoms.  NOTE(review): neither value
        # is used by any active code below.
        cen_sm = small.positions[:-Na].sum(axis=0) / (len(small) - Na)
        cen_lg = large.positions[:-Na].sum(axis=0) / (len(large) - Na)

        #a_cen_large = np.argmin([((cen_lg - pos)**2).sum() for pos in large.positions])
        #print a_cen_large
        #err = np.linalg.norm(large.positions[a_cen_large] - cen_lg)
        #assert err < 1e-5, err

        # Position of every small-cluster atom relative to the reference atom.
        Odisps_sm = [pos - sm_pos_O for pos in small.positions]

        def find_removables():
            """Return indices of cluster atoms that have no match in `small`.

            Atoms are compared by their displacement from the reference atom.
            An atom whose displacement lies within a squared distance of 1e-2
            of some small-cluster displacement is kept ("unremovable"); any
            other real atom is removable; other atoms with Z == 0 are ghosts.
            Asserts that exactly len(small) atoms were matched.
            """
            Odisps_lg = [pos - lg_pos_O for a, pos in enumerate(cluster.positions)]

            removables = []
            unremovables = []
            ghosts = []

            for a, disp in enumerate(Odisps_lg):
                mostsimilar = min([((dispsm - disp)**2).sum() 
                                   for dispsm in Odisps_sm])
                #print mostsimilar
                if mostsimilar < 1e-2:
                    unremovables.append(a)
                elif cluster.numbers[a] != 0:
                    removables.append(a)
                else:
                    ghosts.append(a)
            assert len(unremovables) == len(small), (len(unremovables), len(small))
            return removables


        def save():
            """Write (or announce) the geometries for the current peel state.

            Output goes to ``seed.<seed>/Au.<N>/`` where N is the number of
            atoms left in the bare cluster.  The directories are created even
            when --write is not given; only the file writes are skipped.
            """
            # Cluster without ghosts.
            realcluster = cluster[cluster.numbers != 0]
            # Bare cluster: strip trailing atoms until the last atom has the
            # same element as the first one.
            purecluster = realcluster.copy()
            N0 = purecluster.numbers[0]
            while purecluster.numbers[-1] != N0:
                purecluster.pop(-1)
            Ngold = len(purecluster)

            # Bare cluster with the atom nearest its centroid removed.
            # NOTE(review): only used by the commented-out dowrite below.
            holycluster = purecluster.copy()
            holycenter = holycluster.positions.sum(axis=0) / len(holycluster)
            a_hole = np.argmin([np.linalg.norm(holycenter - pos)
                                for pos in holycluster.positions])
            holycluster.pop(a_hole)

            seeddir = 'seed.%02d' % seed
            if not os.path.exists(seeddir):
                os.mkdir(seeddir)
            geomdir = '%s/Au.%03d' % (seeddir, Ngold)
            if not os.path.exists(geomdir):
                os.mkdir(geomdir)

            # Copy in which every ghost is moved to the origin and turned into
            # Z=79.  NOTE(review): only used by the commented-out dowrite below.
            playable = cluster.copy()
            playable.center(vacuum=5.0)
            for a, pos in enumerate(playable.positions):
                if playable.numbers[a] == 0:
                    pos[:] = 0.0
                    playable.numbers[a] = 79


            if opts.ontop is not None:
                # --ontop ATOM: build variants with the last atom (taken to be
                # the adsorbate) placed along a surface normal through ATOM.
                Xtop = cluster.copy()
                Opos = Xtop.positions[-1]
                # Real (Z != 0) atoms closer than 3.0 to the last atom.
                Oneighbours = [i for i, pos in enumerate(Xtop.positions[:-1])
                               if np.linalg.norm(pos -
                                                 Opos) < 3.0 and Xtop.numbers[i]]
                if not opts.ontop in Oneighbours:
                    #print Oneighbours, opts.ontop
                    raise ValueError('bad geometry or something')
                Oneighbours.remove(opts.ontop)
                Tpos = Xtop.positions[opts.ontop]
                # Normal of the plane spanned by ATOM and the first two of the
                # remaining neighbours (requires at least two of them).
                d1 = Xtop.positions[Oneighbours[0]] - Tpos
                d2 = Xtop.positions[Oneighbours[1]] - Tpos
                normal = np.cross(d1, d2)
                norm = np.linalg.norm(normal)
                normal /= norm
                # Orient the normal towards the side the last atom is on.
                if np.vdot(Opos - Tpos, normal) < 0:
                    normal *= -1.0
                # Variant 1: last atom alone, 1.90 above ATOM along the normal.
                Xtop.positions[-1, :] = Tpos + normal * 1.90
                Otop = Xtop[Xtop.numbers != 0].copy()
                # Variant 2: existing atom moved out to 3.15 and a new C atom
                # inserted at 2.00 along the same normal.
                Xtop.positions[-1, :] = Tpos + normal * 3.15
                Xtop += Atoms('C')
                Xtop.positions[-1, :] = Tpos + normal * 2.00
                COtop = Xtop[Xtop.numbers != 0].copy()

            def dowrite(fname, cluster):
                """Write `cluster` to geomdir as [name.][central.]fname.

                Without --write only the would-be destination is printed.
                """
                tokens = []
                if name != '':
                    tokens.append(name)
                if opts.central:
                    tokens.append('central')
                tokens.append(fname)
                dst = '%s/%s' % (geomdir, '.'.join(tokens))
                if opts.write:
                    print dst
                    write(dst, cluster)
                else:
                    print 'would write to %s' % dst
            dowrite('cluster.traj', purecluster)
            if not opts.central:
                # geom: real atoms only; xgeom: including ghosts.
                dowrite('geom.traj', realcluster)
                dowrite('xgeom.traj', cluster)
            #dowrite('holy.traj', holycluster)
            #dowrite('playable.traj', playable)
            if opts.ontop is not None:
                dowrite('top.O.traj', Otop)
                dowrite('top.CO.traj', COtop)

        removables = find_removables()
        # Snapshot for mkplot; `removables` itself shrinks during the peel.
        original_removables = list(removables)

        # Hard-coded window of Z=79 atom counts (min inclusive, max exclusive)
        # for which intermediate clusters are collected for the plot.
        Nplotmin, Nplotmax = 81, 87
        #Nplotmin, Nplotmax = map(int, opts.plot.split('-'))
        plotorigclusters = []
        plotclusters = []
        plotcoordination = []
        plottable_removed_atoms = []

        # Save the untouched starting cluster, then one state per removal.
        save()
        for i in range(len(large) - len(small)):
            # Choose at random among the removable atoms that currently have
            # the lowest coordination number.
            rcoord = coordination[removables]
            mincoord = min(rcoord)
            removal_candidates = np.array(removables, int)[rcoord == mincoord]
            r = gen.randint(len(removal_candidates))
            a = removal_candidates[r]
            #print removal_candidates, coordination[removal_candidates], '->', a
            #print removables
            removables.remove(a)

            # Turn the atom into a ghost and update the coordination numbers
            # of its remaining real neighbours.
            cluster.numbers[a] = 0
            coordination[a] = 0
            for b in get_neighbors(a):
                coordination[b] -= 1

            if Nplotmin <= list(cluster.numbers).count(79) < Nplotmax:
                plotorigclusters.append(cluster.copy())
                realcluster = cluster[cluster.numbers != 0]
                # Coordination numbers restricted to the real atoms, so they
                # line up index-by-index with realcluster.
                coord = coordination[cluster.numbers != 0]
                plotclusters.append(realcluster)
                plotcoordination.append(list(coord))
                plottable_removed_atoms.append(a)

            save()

        def mkplot(clusters, origclusters, coordnumbers):
            """Render the collected intermediate clusters to 'peel.eps'.

            clusters: ghost-free snapshots; only the first is used (for the
                cell and the reference position), see the XXX line below.
            origclusters: the same snapshots including ghosts.
            coordnumbers: per-snapshot coordination numbers of the real atoms.

            Removable atoms are re-labelled with atomic numbers 90 + c (c =
            coordination number) so they can be coloured by coordination.
            """
            from ase.data import covalent_radii
            # Give all pseudo-elements Z >= 90 the same radius.  This mutates
            # ASE's module-level table.
            covalent_radii[90:] = 1.34
            from ase.gui.gui import GUI
            # NOTE(review): MyGUI is defined but never instantiated below.
            class MyGUI(GUI):
                def __init__(self, *args, **kwargs):
                    GUI.__init__(self, *args, **kwargs)

                def set_colors(self):
                    new = self.drawing_area.window.new_gc
                    alloc = self.colormap.alloc_color
                    c0 = np.array([1., 0., 0.])
                    c1 = np.array([0., 0., 1.])
                    for z in self.images.Z:
                        i = -1
                        if self.frame is not None:
                            i = self.frame - 1
                        try:
                            w = self.w[i][z]
                        except IndexError:
                            print self.w.shape, i, z
                            raise
                        # Linear blend between c0 (w=0) and c1 (w=1).
                        r, g, b = c0 * (1 - w) + c1 * w
                        self.colors[z] = new(alloc(int(65535 * r),
                                                   int(65535 * g),
                                                   int(65535 * b)))


            # All snapshots are merged into this one Atoms object.
            system = Atoms()

            # The first snapshot defines the cell and the reference position
            # (its last atom) that every other snapshot is aligned to.
            c1 = clusters[0].copy()
            c1.center(vacuum=2.0)
            cell = c1.cell
            opos = c1.positions[-1]

            from ase.data.colors import jmol_colors
            from ase.constraints import FixAtoms

            bigmask = []

            # Grey level per coordination number 0..12, scaled by 0.75 and
            # stored as the colour of pseudo-element 90 + coordination.  This
            # mutates ASE's module-level colour table.
            X = [0.0, 0.0, 0.0, 0.0, 0.2, 0.4, 0.7, 0.8, 0.9, 1., 1., 1., 1.]
            #X = np.linspace(1.0, 0.0, 13)
            for Z in range(len(X)):
                x = X[Z] * 0.75
                jmol_colors[90 + Z] = (x, x, x)

            # Z=79 atoms are drawn white.
            jmol_colors[79] = [1., 1., 1.]

            for i, (cluster, orig, coord) in enumerate(zip(clusters, origclusters,
                                                           coordnumbers)):
                cluster = orig # XXX
                # True for atoms that were removable at the start of the peel.
                removemask = np.array([(N in original_removables)
                                       for N in range(len(cluster))])
                existsmask = (cluster.numbers != 0)
                mask = removemask & existsmask
                cluster.numbers[mask] = 78
                #print 'sdjkf', len(coord), len(mask)
                coords = np.array(coord)
                #print min(coords)
                #cluster.numbers[mask & (coordnumbers == 6)] = 81

                # Drop the ghosts; mask and coords now index the same atoms.
                mask = removemask[existsmask]
                cluster = cluster[existsmask]
                bigmask.extend(mask)


                # Re-label removable atoms by coordination number.
                for c in range(13):
                    #cluster.numbers[mask & (coords == 5)] = 81
                    cluster.numbers[mask & (coords == c)] = 90 + c

                cluster.rotate([1, 0, 0], [1, 0.3, 0])
                cluster.rotate([0, 0, 1], [0.3, 0, 1])
                cluster.rotate([0, 0, 1], [1, 0, 0])
                #cluster.rotate([0, 1, 0], [0, 1, 0.3])
                #cluster.rotate([], [])
                cluster.cell = cell
                #inc
                #if i < 3:
                # Move the snapshot so its last atom sits at opos, then offset
                # it along x into column i % 6 (cell width plus 5% spacing).
                cluster.positions += opos - cluster.positions[-1]
                cluster.positions[:, 0] += cluster.cell[0,0] * (i % 6)*1.05
                if 0:#i > 2:
                    #cluster.positions += opos - cluster.positions[-1]
                    cluster.positions -= [0., 20., 0]
                    #cluster.positions[:, 0] -= [cluster.cell[0,0] * 3]
                system += cluster
                #system.#center(vacuum=3.0, axis=0)

            from ase.gui.images import Images
            from ase.gui.gui import GUI
            #system.rotate([0, 0, 1], [1, 0, 0])
            #system.rotate([0, 1, 0], [0, 1, 0.3])
            # bigmask ^ True inverts the mask: the constraint covers the atoms
            # that were NOT removable.
            c = FixAtoms(mask=np.array(bigmask)^True)
            system.set_constraint(c)

            from ase.io.eps import EPS

            eps = EPS(system, show_unit_cell=False)

            #eps.write('peel.eps')

            # The body is drawn by hand below instead of eps.write_body().
            eps.filename = 'peel.eps'
            eps.write_header()
            #eps.write_body()

            try:
                from matplotlib.path import Path
            except ImportError:
                Path = None
                from matplotlib.patches import Circle, Polygon
            else:
                from matplotlib.patches import Circle, PathPatch

            # Draw in order of increasing third coordinate of eps.X
            # (presumably depth, back to front -- TODO confirm).
            indices = eps.X[:, 2].argsort()
            from matplotlib.patches import Polygon
            for a in indices:
                xy = eps.X[a, :2]
                # Only entries below eps.natoms are drawn.
                if a < eps.natoms:
                    #c = eps.T[a]
                    x1 = eps.d[a] / 2
                    x2 = eps.colors[a]
                    circle = Circle(xy, x1, facecolor=x2)
                    circle.draw(eps.renderer)
                    # Hard-coded indices in the merged system that get a white
                    # cross drawn on top of the atom.
                    if a in [22, 110, 182, 278, 361]:
                        hxy = eps.d[a] / 4
                        col = 'white' #eps.colors[0]
                        line1 = Polygon((xy + hxy, xy - hxy), edgecolor=col,
                                        lw=2)
                        line2 = Polygon((xy + (hxy, -hxy), xy + (-hxy, hxy)),
                                        edgecolor=col, lw=2)
                        line1.draw(eps.renderer)
                        line2.draw(eps.renderer)


            eps.write_trailer()


            # Disabled interactive viewer.
            if 0:
                images = Images()
                images.initialize([system])
                gui = GUI(images, show_unit_cell=False)
                gui.run()
            #from ase.visualize import view
            #view(system)


        # The plot is built from the snapshots collected for this seed.
        if opts.plot:
            mkplot(plotclusters, plotorigclusters, plotcoordination)



# Walk the file list pairwise: each file serves as the "small" cluster for
# the file that follows it on the command line.
for small_path, large_path in zip(args1, args2):
    main(smallfname=small_path, largefname=large_path, name=name)

