/**
 * Matrix.
 *
 * General-purpose matrix of {@code long} values, plus static helpers that
 * operate directly on {@code long[][]} arrays.
 *
 * Copyright (C) 1999  Shazron Abdullah
 * 
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 * 
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 * @author Shazron Abdullah
 * @since JDK1.1.7
 */
public class Matrix
{
	public static final int DEFAULT_ROWS 	= 3;
	
	protected long[][] data;
	protected int num_rows = 0;
	protected int num_cols = 0;
	
	/**
	 * Constructor that makes the default 3x3 square matrix.
	 */
	public Matrix()
	{
		this( DEFAULT_ROWS );
	}
	
	public Matrix( long[][] values )
	{
		copyData( values ); // deep copy
	}
	
	/**
	 * Constructor to make a square matrix with the specified size.
	 */
	public Matrix( int size )
	{
		this( size, size );
	}

	/**
	 * Constructor to make a matrix with the specified rows and columns
	 */
	public Matrix( int rows, int cols ) 
	{
		setData( rows, cols );
	}
	
	/**
	 * Initialize the internal storage to store the matrix values, with the
	 * number of rows and columns specified.
	 */
	private void setData( int rows, int cols )
	{
		if ( rows <= 0 || cols <= 0 ) throw new NegativeArraySizeException();
		
		data = new long[rows][cols];
		num_rows = rows;
		num_cols = cols;
	}
	
	/**
	 * Set the value at a row and a column.
	 * The row and column values are 0-indexed.
	 */
	public void setValue( int row, int col, long value )
	{
		if ( row < 0 || col < 0 || row >= num_rows || col >= num_cols ) {
			throw new ArrayIndexOutOfBoundsException();
		}
		
		data[row][col] = value;
	}
	
	/**
	 * Get the value at a row and a column.
	 * The row and column values are 0-indexed.
	 */
	public long getValue( int row, int col )
	{
		if ( row < 0 || col < 0 || row >= num_rows || col >= num_cols ) {
			throw new ArrayIndexOutOfBoundsException();
		}
		
		return data[row][col];
	}
	
	
	/**
	 * Copies the values of the Matrix m to this matrix.
	 */
	public void copyMatrix( Matrix m )
	{
		copyData( m.data );
	}
	
	/**
	 * Copy the data from the array, to this matrix.
	 */
	public void copyData( long[][] values )
	{
		setData( values.length, values[0].length );
		
		for ( int row=0; row < num_rows; row++ ) {
			for ( int col=0; col < num_cols; col++ ) {
				data[row][col] = values[row][col];
			}
		}
	}
	
	/**
	 * Get this matrix's values as an array of array of longs.
	 */
	public long[][] getArray()
	{
		return data;
	}

	/**
	 * Zero out all the values in this matrix.
	 */
	public Matrix zero()
	{
		for ( int i=0; i < num_rows; i++ ) {
			for ( int j=0; j < num_cols; j++ ) {
				data[i][j] = 0;
			}
		}
		return this;
	}
	
	/**
	 * Convert this matrix to an identity matrix.
	 */
	public Matrix identity()
	{
		return identityMult(1);
	}
	
	/**
	 * Convert this matrix to an identity matrix, with the specified 
	 * magnitude. (only for square matrices).
	 */
	public Matrix identityMult( int multiplier )
	{
		if ( num_rows != num_cols ) {
			throw new MatrixNotSquareException();
		}
		
		zero();
		for ( int i=0; i < num_rows; i++ ) {
			data[i][i] = multiplier;
		}
		return this;
	}
	
	/**
	 * Right-hand-side matrix multiply (this matrix is on the left, m is on the right)
	 * 
	 */
	public Matrix mult( Matrix m )
	{
		return new Matrix( mult( getArray(), m.getArray() ) );
	}
	
	/**
	 * Equality function.
	 */
	public boolean equals( Matrix m )
	{
		if ( num_rows != m.num_rows )  return false;
		
		for ( int i=0; i < num_rows; i++ ) {
			for ( int j=0; j < num_rows; j++ ) {
				if ( data[i][j] != m.data[i][j] ) {
					return false;
				}
			}
		}
		return true;
	}
	
	/**
	 * Prints a representation of this Matrix as a String.
	 */
	public String toString()
	{
		return print( data );
	}
	
	
	/* ************************************************************
	 *
	 * Static functions.
	 *
	 *************************************************************/
	
	public static Matrix getIdentity(int num_rows) { return new Matrix( num_rows ).identity(); }
	public static Matrix getIdentity() { return getIdentity( DEFAULT_ROWS ); }
	public static Matrix getZero( int num_rows ) { return new Matrix( num_rows ).zero(); }
	public static Matrix getZero() { return getZero( DEFAULT_ROWS ); }
	
	
	/**
	 * Multiply two matrices.
	 */
	public static long[][] mult( long[][] m1, long[][] m2 )
	{
		if ( m1[0].length != m2.length )  throw new MatrixIncompatibleException();
		
		long[][] array = new long[m1.length][m2[0].length];
		for ( int row=0; row < m1.length; row++ ) {
			for ( int col=0; col < m2[0].length; col++ ) {
				long val = dotProduct( getRow(m1,row), getColumn(m2,col) );
				array[row][col] = val;
			}
		}
		return array;
	}
	
	/**
	 * Multiplies each value in the matrix by a scalar.
	 */
	public static long[][] scalar( long[][] m, int scalar )
	{
		long[][] array = new long[m.length][m[0].length];
		for ( int row=0; row < m.length; row++ ) {
			for ( int col=0; col < m[0].length; col++ ) {
				array[row][col] = scalar * m[row][col];
			}
		}
		return array;
	}
	
	
	/**
	 * Add two matrices.
	 */
	public static long[][] add( long[][] m1, long[][] m2 )
	{
		if ( m1[0].length != m2[0].length && m1.length != m2.length )  throw new MatrixIncompatibleException();
		
		long[][] array = new long[m1.length][m1[0].length];
		for ( int row=0; row < m1.length; row++ ) {
			for ( int col=0; col < m2[0].length; col++ ) {
				array[row][col] = m1[row][col] + m2[row][col];
			}
		}
		return array;
	}

	/**
	 * Subtract two matrices (negate the second matrix, add to first)
	 */
	public static long[][] subtract( long[][] m1, long[][] m2 )
	{
		return ( add( m1, negate( m2 ) ) );
	}
	
	/**
	 * Negate a matrix
	 */
	public static long[][] negate( long[][] m )
	{
		long[][] array = new long[m.length][m[0].length];
		for ( int row=0; row < m.length; row++ ) {
			for ( int col=0; col < m[0].length; col++ ) {
				array[row][col] = m[row][col] * -1;
			}
		}
		
		return array;
	}
	
	/**
	 * Calculate the dot product of two vectors (array of values).
	 */
	public static long dotProduct( long[] vect1, long[] vect2)
	{
		if ( vect1.length != vect2.length ) return 0;
		
		long val=0;
		for ( int i=0; i < vect1.length; i++ ) {
			val += vect1[i] * vect2[i];
		}
		return val;
	}
	
	/**
	 * Gets the specified column of a matrix as an array.
	 */
	public static long[] getColumn( long[][] matrix, int col )
	{
		if ( col < 0 || col >= matrix[0].length ) {
			throw new ArrayIndexOutOfBoundsException();
		}
		
		long[] val = new long[ matrix.length ];
		for (int row=0; row < matrix.length; row++ ) {
			val[row] = matrix[row][col];
		}
		return val;
	}
	
	/**
	 * Gets the specified row of a matrix as an array.
	 */
	public static long[] getRow( long[][] matrix, int row )
	{
		if ( row < 0 || row >= matrix.length ) {
			throw new ArrayIndexOutOfBoundsException();
		}
		
		long[] val = new long[ matrix[0].length ];
		for (int col=0; col < matrix[0].length; col++ ) {
			val[col] = matrix[row][col];
		}
		return val;
	}
	
	/**
	 * Prints the matrix into a String representation.
	 */
	public static String print( long[][] matrix )
	{
		final String SPACER = "   "; // 3 spaces
		
		StringBuffer b = new StringBuffer("");
		for ( int row=0; row < matrix.length; row++ ) {
			b.append("|").append( SPACER );
			for ( int col=0; col < matrix[0].length; col++ ) {
				b.append( matrix[row][col] ).append( SPACER );
			}
			b.append("|\n");
		}
		return b.toString();
	}
	
	public static class MatrixNotSquareException extends RuntimeException {} 
	public static class MatrixIncompatibleException extends RuntimeException {}
}

