import sys
from ase.io import read, write
from ase.visualize import view
import numpy as np
from ase import Atoms

# Directory (relative path, so resolved against the current working
# directory) from which run() reads the input cluster files.
inputdir = 'inputfiles'

def run(arg, indices, site_name, dist):
    """Add an on-top CO molecule to a cluster and write both structures.

    The cluster is read from ``inputdir/arg`` and centred with
    ``vacuum=6.0`` (ASE length units). It is then written unchanged to
    ``Pt.<natoms:04d>.cubo.traj``. A CO molecule is added and the combined
    system is written to ``CO.<site_name>.Pt.<natoms:04d>.cubo.traj``.
    That second file name is also printed.

    The three atoms in ``indices`` define a plane. C is placed on the
    plane normal through their centroid, at height
    ``dist + a0 / sqrt(3)``, where ``a0 = sqrt(2) * |r[i1] - r[i0]|``.
    O is placed 1.150 further out along the same normal, so the C-O axis
    is always perpendicular to the site plane.

    Parameters
    ----------
    arg : str
        File name, relative to ``inputdir``, of the cluster to read.
    indices : sequence of 3 ints
        Atom indices spanning the site plane. Their order no longer matters:
        the normal is oriented away from the cluster's mean position.
    site_name : str
        Label inserted into the output file name.
    dist : float
        Height of C above the ``a0 / sqrt(3)`` offset.

    Raises
    ------
    ValueError
        If ``indices`` does not hold exactly three entries, or if the three
        atoms are collinear and so define no unique plane.
    """
    # Validate up front with a real exception; an assert vanishes under -O.
    if len(indices) != 3:
        raise ValueError('expected 3 atom indices, got %d' % len(indices))

    system = read('%s/%s' % (inputdir, arg))
    system.center(vacuum=6.0)
    outfname1 = '%s.%04d.%s.traj' % ('Pt', len(system), 'cubo')
    outfname2 = 'CO.%s.%s' % (site_name, outfname1)
    write(outfname1, system)

    p0, p1, p2 = system.positions[list(indices)]
    # Nearest-neighbour distance times sqrt(2), i.e. the fcc lattice
    # constant. This assumes atoms 0 and 1 are nearest neighbours, so
    # a0 / sqrt(3) below is presumably the fcc (111) interlayer spacing.
    a0 = np.linalg.norm(p1 - p0) * np.sqrt(2.0)
    center = (p0 + p1 + p2) / 3.0

    normal = np.cross(p1 - p0, p2 - p0)
    length = np.linalg.norm(normal)
    if length < 1e-8:
        raise ValueError('atoms %s are collinear; no unique site normal'
                         % (list(indices),))
    normal = normal / length
    # The cross-product sign depends on the order of `indices`. Point the
    # normal away from the cluster so the adsorbate ends up outside it.
    if np.dot(center - system.positions.mean(axis=0), normal) < 0.0:
        normal = -normal

    # Build CO directly along the site normal, C nearest the metal. The
    # previous fixed (-1, 1, 1) rotation tilted O whenever the site normal
    # pointed elsewhere.
    c_pos = center + normal * (dist + a0 / np.sqrt(3.0))
    co = Atoms('CO', positions=[c_pos, c_pos + normal * 1.150])
    system += co
    print(outfname2)
    write(outfname2, system)

# Index triples, one per cluster, in the same order as `args` below. The
# trailing comment gives the cluster's atom count. Each triple names
# 2nd-nearest atoms in the layer *below* the surface.
indices_111_top = [
    [9, 11, 12],        # 13
    [15, 36, 43],       # 55
    [42, 48, 56],       # 147
    [121, 123, 124],    # 309
    [117, 223, 243],    # 561
    [218, 224, 242],    # 923
    [403, 405, 406],    # 1415
]

indices_top_edge = [
    [9, 11, 12],        # 13
    [15, 36, 43],       # 55
    [56, 106, 103],     # 147
    [118, 124, 131],    # 309
    [243, 245, 246],    # 561
    [243, 399, 415],    # 923
    [425, 644, 641],    # 1415
]

# Input cluster files (inside inputdir), ordered by size.
# NOTE(review): there are 8 files but only 7 triples per table, so zip()
# below silently skips Pt.2057.co.traj. Confirm this is intended.
args = [
    'Pt.0013.co.traj',
    'Pt.0055.co.traj',
    'Pt.0147.co.traj',
    'Pt.0309.co.traj',
    'Pt.0561.co.traj',
    'Pt.0923.co.traj',
    'Pt.1415.co.traj',
    'Pt.2057.co.traj',
]

# All 111_top structures are generated first, then all top_edge ones.
for site_name, dist, table in (('111_top', 1.900, indices_111_top),
                               ('top_edge', 2.000, indices_top_edge)):
    for arg, indices in zip(args, table):
        run(arg, indices, site_name, dist)

# Disabled code: this branch never executes.
# NOTE(review): it refers to `opts` and `cluster`, which are not defined in
# this file, so it looks like a leftover from another script and cannot be
# enabled as-is. It places an adsorbate on top of atom `opts.ontop` and
# builds O-only (`Otop`) and CO (`COtop`) variants.
if False:
    if opts.ontop is not None:
        Xtop = cluster.copy()
        # The last atom is presumably the adsorbed O (hence `Opos`).
        Opos = Xtop.positions[-1]
        # Atoms within 3.0 of that last atom, skipping atomic number 0.
        Oneighbours = [i for i, pos in enumerate(Xtop.positions[:-1])
                       if np.linalg.norm(pos - Opos) < 3.0
                       and Xtop.numbers[i]]
        if opts.ontop not in Oneighbours:
            # Same output as the former Python-2 `print a, b` statement,
            # but valid syntax on Python 3 as well.
            print('%s %s' % (Oneighbours, opts.ontop))
            raise ValueError('bad geometry or something')
        Oneighbours.remove(opts.ontop)
        Tpos = Xtop.positions[opts.ontop]
        d1 = Xtop.positions[Oneighbours[0]] - Tpos
        d2 = Xtop.positions[Oneighbours[1]] - Tpos
        normal = np.cross(d1, d2)
        norm = np.linalg.norm(normal)
        normal /= norm
        # Orient the normal towards the side the last atom started on.
        if np.vdot(Opos - Tpos, normal) < 0:
            normal *= -1.0
        # Last atom 1.90 above the top atom -> O-only structure.
        Xtop.positions[-1, :] = Tpos + normal * 1.90
        Otop = Xtop[Xtop.numbers != 0].copy()
        # Move it to 3.15, then append C at 2.00 -> CO with a 1.15 bond.
        Xtop.positions[-1, :] = Tpos + normal * 3.15
        Xtop += Atoms('C')
        Xtop.positions[-1, :] = Tpos + normal * 2.00
        COtop = Xtop[Xtop.numbers != 0].copy()

