import sys
from ase.io import read, write
from ase.visualize import view
import numpy as np
from ase import Atoms

# Directory holding the input structure files that run() reads.
inputdir = "inputfiles"

def run(arg, indices, site_name, dist):
    """Build a CO-covered Au cluster plus its bare reference and write both.

    The structure ``inputdir/arg`` is read and every atom is turned into gold
    (Z = 79).  Positions are scaled by 4.218 / 3.999 (presumably a ratio of
    lattice constants -- TODO confirm).  The cluster is then centred with
    6.0 of vacuum.  A CO molecule (C-O bond 1.150) is placed C-down on the
    atom ``indices[-1]``.  It points along the unit vector running from the
    midpoint of atoms ``indices[0]`` and ``indices[1]`` to that site atom.

    Two files are written, and each name is printed before it is written:

    * ``O.<site_name>.Au.<NNNN>.cubo.morespace.traj`` -- cluster + CO,
      re-centred with 6.0 of vacuum.
    * ``Au.<NNNN>.cubo.morespace.traj`` -- the same cluster after the CO
      atoms are removed again.  It keeps the cell and atom positions of the
      CO system.

    ``NNNN`` is the number of cluster atoms, zero-padded to four digits.

    Parameters
    ----------
    arg : str
        Input structure filename, relative to ``inputdir``.
    indices : sequence of 3 ints
        ``(below1, below2, site)``.  The midpoint of the first two atoms and
        the ``site`` atom together define the outward direction.  CO is
        placed on ``site``.
    site_name : str
        Label inserted into the name of the CO output file.
    dist : float
        Distance from the site atom to the C atom, in the structure's length
        units (Angstrom for ASE).

    Raises
    ------
    ValueError
        If ``indices`` does not hold exactly three entries, or if the site
        atom coincides with the midpoint of the two reference atoms (the
        outward direction is then undefined).
    """
    # Validate before touching the filesystem.  A bare assert would be
    # stripped under ``python -O``.
    if len(indices) != 3:
        raise ValueError('indices must hold exactly 3 atom indices '
                         '(below1, below2, site), got %r' % (indices,))
    below1, below2, site = indices

    system = read('%s/%s' % (inputdir, arg))
    system.numbers[:] = 79  # relabel every atom as gold
    system.positions *= 4.218 / 3.999
    system.center(vacuum=6.0)
    outfname1 = '%s.%04d.%s.morespace.traj' % ('Au', len(system), 'cubo')
    outfname2 = 'O.%s.%s' % (site_name, outfname1)

    pos = system.positions
    sitepos = pos[site].copy()
    midpoint = (pos[below1] + pos[below2]) / 2.0

    # Outward direction: from the midpoint of the two reference atoms
    # towards the adsorption site.
    normal = sitepos - midpoint
    length = np.linalg.norm(normal)
    if length == 0.0:
        raise ValueError('site atom %d coincides with the midpoint of atoms '
                         '%d and %d; adsorption direction is undefined'
                         % (site, below1, below2))
    normal /= length

    # Build CO along +z (C at the origin), rotate +z onto the outward
    # direction, then shift it so C sits `dist` beyond the site atom.
    co = Atoms('CO', positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.150]])
    co.rotate([0.0, 0.0, 1.0], normal)
    co.positions += sitepos + normal * dist
    system += co
    system.center(vacuum=6.0)
    print(outfname2)
    write(outfname2, system)

    # Drop the CO atoms (appended last).  The bare reference cluster thus
    # keeps exactly the cell and Au positions of the CO system.
    for _ in range(len(co)):
        system.pop()
    print(outfname1)
    write(outfname1, system)

# Atom-index triples consumed by run() as (below1, below2, site).  The midpoint
# of the first two atoms and the third atom together give the outward
# direction, and CO is placed on the third atom.  The trailing comment on each
# row is the atom count of the matching Pt.NNNN.co.traj input file.
#
# Disabled table for the 111_top variant (see the commented-out loop at the
# end of the file).  It was originally described as "triples of indices of
# 2nd-nearest atoms in the layer *below* the surface" -- NOTE(review): confirm
# it still fits run()'s (below1, below2, site) convention before re-enabling.
#indices_111_top = [[9, 11, 12], #13
#                  [15, 36, 43], #55
#                  [42, 48, 56], #147
#                  [121, 123, 124], #309
#                  [117, 223, 243], #561
#                  [218, 224, 242], #923
#                  [403, 405, 406] #1415
#                  ]

indices_top_edge = [
    [9, 11, 10],        # 13
    [15, 36, 20],       # 55
    [56, 106, 60],      # 147
    [118, 124, 132],    # 309
    [243, 245, 244],    # 561
    [243, 399, 253],    # 923
    [425, 644, 429],    # 1415
]

# Input structure files (relative to inputdir), one per cluster size.
args = [
    'Pt.0013.co.traj',
    'Pt.0055.co.traj',
    'Pt.0147.co.traj',
    'Pt.0309.co.traj',
    'Pt.0561.co.traj',
    'Pt.0923.co.traj',
    'Pt.1415.co.traj',
    'Pt.2057.co.traj',
]

# Disabled variant: CO on the 111 top site, driven by the commented-out
# indices_111_top table.
#for arg, indices in zip(args, indices_111_top):
#    run(arg, indices, '111_top', 1.900)

# Active variant: label 'bridge_edge', C placed 1.850 from the site atom.
# zip() stops at the shorter sequence.  args lists 8 files but
# indices_top_edge has only 7 rows, so Pt.2057.co.traj is not processed.
for arg, indices in zip(args, indices_top_edge):
    run(arg, indices, site_name='bridge_edge', dist=1.850)

