#!/usr/bin/env python

import sys
from ase.io import read
from ase.data import chemical_symbols
from ase.calculators.neighborlist import NeighborList
import numpy as np

def d(r):
    """Return ``np.linalg.norm(r)`` (the Euclidean length for a 1-D vector *r*)."""
    length = np.linalg.norm(r)
    return length


def get_coordination(system, Da=1.4, D=3.3):
    """Yield neighbour/coordination data for every adsorbate atom in *system*.

    The substrate is every atom with the same atomic number as ``system[0]``;
    every other atom is treated as an adsorbate. Substrate atoms may appear
    anywhere in *system* -- they do not need to be stored first.

    A single ASE ``NeighborList`` is built with a per-atom cutoff radius of
    ``D / 2`` for substrate atoms and ``Da / 2`` for adsorbate atoms (ASE
    pairs two atoms when their cutoff spheres overlap, plus the list's skin).

    Parameters
    ----------
    system : ase.Atoms
        Structure to analyse.
    Da : float
        Cutoff diameter assigned to adsorbate atoms (default 1.4).
    D : float
        Cutoff diameter assigned to substrate atoms (default 3.3).

    Yields
    ------
    tuple ``(asym, neighbours, symbols, nnvalues, nncnvalues)``
        asym       -- chemical symbol of the adsorbate atom.
        neighbours -- list of indices into *system* of the substrate atoms
                      the neighbour list reports for this adsorbate.
        symbols    -- concatenated chemical symbols of ``neighbours``.
        nnvalues   -- for each entry of ``neighbours``, the number of
                      substrate neighbour entries it has, excluding itself.
        nncnvalues -- for every ordered pair (N1, N2) of entries in
                      ``neighbours`` with different indices, the number of
                      distinct substrate atoms (other than N1 and N2) that
                      neighbour both N1 and N2.
    """
    if len(system) == 0:
        return

    numbers = np.asarray(system.numbers)
    # Classify atoms with a mask instead of assuming the substrate block
    # comes first, so indices stay correct for any atom ordering.
    is_substrate = numbers == numbers[0]
    cutoffs = [D / 2.0 if sub else Da / 2.0 for sub in is_substrate]

    nl = NeighborList(cutoffs, bothways=True)
    nl.update(system)

    def substrate_neighbours(index):
        """Return the neighbour-list entries of *index* that are substrate atoms."""
        found, _ = nl.get_neighbors(index)
        return [n for n in found if is_substrate[n]]

    for index in np.nonzero(~is_substrate)[0]:
        asym = chemical_symbols[numbers[index]]
        neighbours = substrate_neighbours(index)

        # Query each neighbouring substrate atom once and reuse the result
        # for both the coordination numbers and the common-neighbour counts.
        shells = {}
        for n in neighbours:
            if n not in shells:
                shells[n] = substrate_neighbours(n)

        # Coordination number of each neighbouring substrate atom; entries
        # equal to the atom itself are dropped.
        nnvalues = [len([m for m in shells[n] if m != n]) for n in neighbours]

        # Number of common substrate neighbours for each ordered pair of
        # distinct neighbour indices (duplicates collapse via the set).
        nncnvalues = []
        for n1 in neighbours:
            shell1 = set(shells[n1])
            for n2 in neighbours:
                if n2 == n1:
                    continue
                common = shell1.intersection(shells[n2])
                common.discard(n1)
                common.discard(n2)
                nncnvalues.append(len(common))

        symbols = ''.join(chemical_symbols[numbers[n]] for n in neighbours)
        yield (asym, neighbours, symbols, nnvalues, nncnvalues)

def main(argv):
    """Print coordination statistics for each structure file named in *argv*.

    For every file: print its name, read it with ``ase.io.read``, then print
    one report per adsorbate atom yielded by ``get_coordination``, followed
    by a blank line.
    """
    for filename in argv:
        print(filename)
        system = read(filename)
        for (asym, neighbours, symbols, nnvalues,
             nncnvalues) in get_coordination(system):
            if len(neighbours) == 0:
                print('%s: no neighbours' % asym)
                continue
            # len(neighbours) > 0 here, so nnvalues is non-empty.
            avg = sum(nnvalues) / float(len(nnvalues))
            print('%s: %d neighbours (%s), coords %s -> avg %f'
                  % (asym, len(neighbours), symbols, str(nnvalues), avg))
            print('nncn %s' % (nncnvalues,))
            # NOTE(review): the label says "/ 3" but the value is divided by
            # the number of neighbours (3 only for a three-fold site) --
            # confirm the intended formula.
            combined = ((2 * sum(nncnvalues) + sum(nnvalues))
                        / float(len(nnvalues)))
            print('(coord + 2nncn) / 3 %s' % (combined,))
        print('')


if __name__ == '__main__':
    main(sys.argv[1:])

