"""Test of pblas_pdgemm.  This test requires 4 processors.

This is a test of the GPAW interface to the parallel
matrix multiplication routine, pblas_pdgemm.

The test generates random matrices A and B and their product C on master.

Then A and B are distributed, and pblas_pdgemm is invoked to get C in
distributed form.  This is then collected to master and compared to
the serially calculated reference.
"""

import numpy as np
from gpaw.blacs import BlacsGrid, Redistributor, parallelprint
from gpaw.utilities.blacs import pblas_simple_gemm
from gpaw.mpi import world, rank
import _gpaw

def pblas_pdgemv(M, N, alpha, a, adesc, x, xdesc, beta, y, ydesc):
    """Forward a distributed matrix-vector product to _gpaw.pblas_pdgemv.

    The matrix is handed over as the transposed view ``a.T`` and every
    descriptor is converted with its ``asarray()`` method before the call.
    Nothing is returned.

    NOTE(review): ``M`` is accepted but never used; ``N`` is passed for both
    leading size arguments.  That looks unintentional for non-square
    matrices -- confirm against the _gpaw.pblas_pdgemv signature.
    """
    kernel = _gpaw.pblas_pdgemv
    a_transposed = a.T
    adesc_array = adesc.asarray()
    xdesc_array = xdesc.asarray()
    ydesc_array = ydesc.asarray()
    kernel(N, N, alpha,
           a_transposed, adesc_array,
           x, xdesc_array,
           beta,
           y, ydesc_array)

def pblas_simple_pdgemv(adesc, xdesc, ydesc, a, x, y):
    """Call pblas_pdgemv with alpha = 1.0 and beta = 0.0.

    The two size arguments are unpacked from ``adesc.gshape``.
    """
    nrows, ncols = adesc.gshape
    pblas_pdgemv(nrows, ncols, 1.0, a, adesc, x, xdesc, 0.0, y, ydesc)


def main():
    """Compare a distributed matrix product with a serial numpy reference.

    Random operands A (M x K), B (K x N) and X (K x 1) are created through
    descriptors whose block size equals the global size, and the reference
    products C0 = A0 * B0 and Y0 = A0 * X0 are computed with numpy on rank 0.
    The operands are then redistributed to descriptors with smaller block
    sizes on a 2 x 2 BlacsGrid, multiplied with pblas_simple_gemm and
    pblas_simple_pdgemv, and C is redistributed back for comparison.

    Raises AssertionError if a descriptor check fails or if the largest
    absolute difference between C0 and the collected C1 is not below 1e-14.

    NOTE(review): the matrix-vector result Y is never redistributed back or
    compared with Y0; X is collected into X1 instead and X1 is not used
    afterwards.  Only the gemm result is actually verified.
    """
    seed = 42
    gen = np.random.RandomState(seed)
    grid = BlacsGrid(world, 2, 2)

    M = 16
    N = 12
    K = 14

    # Create matrices on master:
    globA = grid.new_descriptor(M, K, M, K)
    globB = grid.new_descriptor(K, N, K, N)
    globC = grid.new_descriptor(M, N, M, N)
    globX = grid.new_descriptor(K, 1, K, 1)
    globY = grid.new_descriptor(M, 1, M, 1)

    # Random matrices local to master:
    A0 = gen.rand(*globA.shape)
    B0 = gen.rand(*globB.shape)
    X0 = gen.rand(*globX.shape)

    # Local result matrices
    Y0 = globY.empty()
    C0 = globC.empty()

    # Local matrix product:
    if rank == 0:
        C0[:] = np.dot(A0, B0)
        Y0[:] = np.dot(A0, X0)

    assert globA.check(A0) and globB.check(B0) and globC.check(C0)
    assert globX.check(X0) and globY.check(Y0)

    # Create distributed descriptors with various block sizes:
    distA = grid.new_descriptor(M, K, 2, 3)
    distB = grid.new_descriptor(K, N, 2, 4)
    distC = grid.new_descriptor(M, N, 3, 2)
    distX = grid.new_descriptor(1, K, 1, 2)
    distY = grid.new_descriptor(1, M, 1, 2)

    # Distributed matrices:
    A = distA.empty()
    B = distB.empty()
    C = distC.empty()
    X = distX.empty()
    Y = distY.empty()

    Redistributor(world, globA, distA).redistribute(A0, A)
    Redistributor(world, globB, distB).redistribute(B0, B)
    Redistributor(world, globX, distX).redistribute(X0, X)

    pblas_simple_gemm(distA, distB, distC, A, B, C)
    pblas_simple_pdgemv(distA, distX, distY, A, X, Y)

    # Collect result back on master
    C1 = globC.empty()
    X1 = globX.empty()
    Redistributor(world, distC, globC).redistribute(C, C1)
    Redistributor(world, distX, globX).redistribute(X, X1)

    if rank == 0:
        err = abs(C0 - C1).max()
        # A single pre-formatted string prints identically whether print is
        # a statement or a function.
        print('Err %s' % err)
    else:
        err = 0.0
    err = world.sum(err)  # We don't like exceptions on only one process
    assert err < 1e-14
    

# Run the test when this file is executed as a script.
if __name__ == '__main__':
    main()

