import sys
import fileinput

import pylab as pl
import numpy as np


#files = sys.argv[1:]

class History:
    """Energy and weight-vector history parsed from a line-oriented log.

    Lines starting with ``energy`` carry one float in their second field;
    lines starting with ``weights`` carry a whitespace-separated list of
    floats after the keyword. After :meth:`fromfile`, ``weights`` is a 2-D
    float array with one row per iteration.
    """

    def __init__(self):
        self.energies = []  # one float per ``energy`` line, in file order
        self.weights = []   # one row per ``weights`` line; 2-D ndarray after fromfile()

    def fromfile(self, inp, expected_nw=64):
        """Parse ``energy`` / ``weights`` lines from *inp* and append them.

        Parameters
        ----------
        inp : iterable of str
            Lines of the log, e.g. an open file or ``fileinput.input()``.
        expected_nw : int or None, optional
            Number of values each ``weights`` line must contain (default 64,
            the value previously hard-coded). Pass ``None`` to skip the check.

        Sets ``self.weights`` to a float array of shape ``(iters, nw)`` and
        sets ``self.iters`` and ``self.nw``.

        Raises
        ------
        AssertionError
            If the number of weight rows differs from the number of energies,
            or the row length differs from *expected_nw*.
        ValueError
            If a numeric field cannot be parsed, or rows differ in length.
        """
        # Work on plain lists so a second call can extend a previous parse
        # (self.weights is an ndarray after the first call).
        energies = list(self.energies)
        rows = list(self.weights)
        for line in inp:
            if line.startswith('energy'):
                energies.append(float(line.split()[1]))
            elif line.startswith('weights'):
                # Materialise the values in a list: under Python 3 ``map`` is a
                # lazy iterator and ``np.array(map(...))`` gives a 0-d object array.
                rows.append(np.array([float(x) for x in line.split()[1:]]))
        # Exactly one surplus weights row (presumably a final weights line whose
        # energy was never written) is dropped so rows and energies pair up.
        if len(rows) == 1 + len(energies):
            rows.pop()
        self.energies = energies
        if rows:
            self.weights = np.asarray(rows, dtype=float)
        else:
            # Keep the array 2-D so ``shape[1]`` is defined for empty input.
            self.weights = np.empty((0, 0))
        assert len(self.weights) == len(self.energies), \
               '%d vs %d' % (len(self.weights), len(self.energies))
        self.iters = len(self.energies)
        self.nw = self.weights.shape[1]
        if expected_nw is not None:
            assert self.nw == expected_nw, \
                   'expected %d weights per line, got %d' % (expected_nw, self.nw)

    def plot_energy(self):
        """Plot energy against iteration index and return the pylab module."""
        pl.plot(self.energies)
        return pl

    def plot_weights(self):
        """Plot every weight component against iteration number.

        Components alternate in runs of four: indices with ``i % 8 < 4`` are
        drawn red, the rest green. The y-label names these groups Cl vs Na;
        that site layout is assumed by the colouring, not checked here.
        """
        iterations = np.arange(len(self.weights))
        for i, series in enumerate(self.weights.T):
            color = 'r' if i % 8 < 4 else 'g'
            pl.plot(iterations, series, color)
        pl.xlabel('iteration')
        pl.ylabel('weights (colour by expected type Cl vs Na)')

    def show(self):
        """Display all pending pylab figures."""
        pl.show()

    def plot_weightsum(self):
        """Plot, per iteration, the sum of all weights divided by ``nw``."""
        pl.plot(self.weights.sum(axis=1) / self.nw)
        pl.xlabel('iteration')
        # The plotted value is normalised by nw, i.e. the mean weight.
        pl.ylabel('mean weight (sum of all weights / nw)')

if __name__ == '__main__':
    # Guarded so importing this module (to reuse History) does not read
    # stdin or open plot windows. fileinput.input() reads the files named
    # on the command line, or stdin when none are given.
    h = History()
    h.fromfile(fileinput.input())
    # Alternative views; uncomment as needed:
    # h.plot_energy()
    h.plot_weights()
    # h.plot_weightsum()
    h.show()

