import sys
from ase.io import read, write
from ase.visualize import view
import numpy as np
from ase import Atoms

# Directory, relative to the current working directory, that run() reads
# the input cluster files from.
inputdir = 'inputfiles'
# Offset of the O atom from the centroid of its three site atoms, measured
# along the site-plane normal (see run()); presumably in Angstrom, ASE's
# length unit. The origin of this value is not recorded here.
dist = 1.28670742509

def run(arg, indices):
    """Write a centred cluster and a copy with one O atom above a 3-fold site.

    Reads ``<inputdir>/<arg>``, centres it in its cell with 6.0 of vacuum and
    writes it to ``Pt.<natoms:04d>.cubo.traj``.  An O atom is then placed
    ``dist`` away from the centroid of the three atoms selected by
    ``indices``, along the normal of the plane they span, and the result is
    written to ``O.Pt.<natoms:04d>.cubo.traj``.

    The plane normal is oriented to point away from the geometric centre of
    the cluster, so the O atom lands on the outside regardless of the order
    in which the three indices are listed.

    Parameters
    ----------
    arg : str
        File name of the input structure, relative to ``inputdir``.
    indices : sequence of int
        Exactly three atom indices defining the adsorption site.

    Raises
    ------
    ValueError
        If ``indices`` does not hold exactly three entries, or the three
        atoms are (nearly) collinear so that no plane normal exists.
    """
    indices = list(indices)
    # Raise rather than ``assert``: asserts vanish under ``python -O``, and
    # validating first avoids writing an output file for a bad request.
    if len(indices) != 3:
        raise ValueError('expected exactly 3 site indices, got %d'
                         % len(indices))

    system = read('%s/%s' % (inputdir, arg))
    system.center(vacuum=6.0)
    outfname1 = '%s.%04d.%s.traj' % ('Pt', len(system), 'cubo')
    outfname2 = 'O.%s' % outfname1
    write(outfname1, system)

    # Fancy indexing returns a (3, 3) copy; the cluster itself is untouched.
    neighbourpos = system.positions[indices]
    center = neighbourpos.mean(axis=0)
    d1 = neighbourpos[1] - neighbourpos[0]
    d2 = neighbourpos[2] - neighbourpos[0]
    normal = np.cross(d1, d2)
    length = np.linalg.norm(normal)
    if length < 1e-8:
        # Dividing by ~0 would silently yield NaN coordinates.
        raise ValueError('site atoms %s are collinear; no plane normal'
                         % (indices,))
    normal /= length
    # The sign of the cross product depends on the index order; make the
    # normal point from the cluster centre towards the site (outwards).
    if np.dot(center - system.positions.mean(axis=0), normal) < 0:
        normal = -normal

    # Sanity output: the two triangle edge lengths and the unit-normal norm.
    print('%s %s' % (np.linalg.norm(d1), np.linalg.norm(d2)))
    print(np.linalg.norm(normal))

    Opos = center + normal * dist
    system += Atoms('O', positions=[Opos])
    write(outfname2, system)


# Three atom indices per input file defining the O site passed to run();
# entry i is paired with args[i].
allindices = [[3, 2, 10],
              [13, 17, 20],
              [58, 59, 60],
              [53, 122, 132],
              [115, 120, 131],
              [244, 245, 246],
              [216, 404, 427],
              # NOTE(review): no index triple is listed for Pt.2057.co.traj,
              # so that file is reported and skipped below.
              ]
# Input cluster files, read from ``inputdir`` by run().
args = """Pt.0013.co.traj
Pt.0055.co.traj
Pt.0147.co.traj
Pt.0309.co.traj
Pt.0561.co.traj
Pt.0923.co.traj
Pt.1415.co.traj
Pt.2057.co.traj""".split()

# zip() silently drops the tail of the longer list; report unpaired entries
# so a missing site definition is not mistaken for a completed run.
for missing in args[len(allindices):]:
    sys.stderr.write('warning: no site indices for %s; skipping it\n'
                     % missing)
for extra in allindices[len(args):]:
    sys.stderr.write('warning: site indices %s have no input file; '
                     'skipping them\n' % (extra,))

for arg, indices in zip(args, allindices):
    run(arg, indices)


# Disabled (never executed): leftover snippet that builds on-top O and CO
# geometries from a cluster whose last atom is the O adsorbate.  It refers
# to ``opts`` and ``cluster``, which are not defined in this script, so it
# cannot be enabled as-is.
if False:
    if opts.ontop is not None:
        Xtop = cluster.copy()
        # Copy, not a view: the last row of positions is overwritten below.
        Opos = Xtop.positions[-1].copy()
        # Atoms with nonzero atomic number within 3.0 of the O atom.
        Oneighbours = [i for i, pos in enumerate(Xtop.positions[:-1])
                       if np.linalg.norm(pos - Opos) < 3.0
                       and Xtop.numbers[i]]
        if opts.ontop not in Oneighbours:
            print('%s %s' % (Oneighbours, opts.ontop))
            raise ValueError('bad geometry or something')
        Oneighbours.remove(opts.ontop)
        if len(Oneighbours) < 2:
            raise ValueError('need two more O neighbours besides atom %s '
                             'to define a plane, got %s'
                             % (opts.ontop, Oneighbours))
        Tpos = Xtop.positions[opts.ontop]
        d1 = Xtop.positions[Oneighbours[0]] - Tpos
        d2 = Xtop.positions[Oneighbours[1]] - Tpos
        normal = np.cross(d1, d2)
        norm = np.linalg.norm(normal)
        if norm < 1e-8:
            raise ValueError('neighbour atoms are collinear; no plane normal')
        normal /= norm
        # Orient the normal towards the side the O atom was originally on.
        if np.vdot(Opos - Tpos, normal) < 0:
            normal *= -1.0
        # O on top of atom ``opts.ontop``; drop atoms with atomic number 0.
        Xtop.positions[-1, :] = Tpos + normal * 1.90
        Otop = Xtop[Xtop.numbers != 0].copy()
        # CO on top: move the former O atom out and insert a C below it.
        Xtop.positions[-1, :] = Tpos + normal * 3.15
        Xtop += Atoms('C')
        Xtop.positions[-1, :] = Tpos + normal * 2.00
        COtop = Xtop[Xtop.numbers != 0].copy()

