Saturday, November 26, 2011

Perceptron, Winnow and Balanced Winnow in Java

References

John Wakefield, Single Layer Perceptron [online] (http://dynamicnotions.blogspot.com/2008/09/single-layer-perceptron.html)
Hiroki Arimura, Winnow Algorithm [online] (http://www-ikn.ist.hokudai.ac.jp/ikn-tokuron/arim4winnow.pdf)
Roni Khardon, The Winnow Algorithm [online] (http://www.cs.tufts.edu/~roni/Teaching/CLT/LN/lecture16.pdf)

Here We Go

So much explanation? Please no, I'm sleepy. Let me just post the code now and explain it later.
For simplicity, here is the class diagram.

/**
 * Base class shared by the linear classifiers in this post.
 *
 * <p>It stores the training data handed to the constructor and a learning
 * rate, and leaves the training procedure and the activation function to the
 * concrete subclasses.
 */
public abstract class LinearClassifier {

	/** Training instances; subclasses read row {@code i} as the attributes of instance {@code i}. */
	protected double[][] inputs;

	/** Expected output of each training instance, indexed like {@link #inputs}. */
	protected double[] outputs;

	/** Learning rate available to subclasses; starts at 0.45. */
	protected Double learningRate = 0.45;

	/**
	 * Stores the training set. The arrays are kept by reference, not copied.
	 *
	 * @param inputs attribute values of the training instances
	 * @param output expected output of each training instance
	 */
	public LinearClassifier(double[][] inputs, double[] output) {
		this.outputs = output;
		this.inputs = inputs;
	}

	/** Trains the classifier on the stored training set. */
	abstract void calculateWeight();

	/**
	 * Computes the classifier's prediction for one instance.
	 *
	 * @param input attribute values of the instance
	 * @return the predicted output
	 */
	abstract double activationFunction(double[] input);

	/** Snapshot of the computation state for one instance within one training iteration. */
	public class IterationStatus {
		/** Index of the training iteration (pass over the data). */
		public int iteration;
		/** Index of the instance inside the training set. */
		public int instanceNumber;
		/** Attribute values of the instance. */
		public double[] input;
		/** Expected output of the instance. */
		public double output;
		/** Output predicted by the classifier. */
		public double prediction;
		/** Error recorded for the instance. */
		public double error;
	}
}


import java.util.ArrayList;

public class Perceptron extends LinearClassifier {

	/*variable to record the state of computation for each iteration*/
	private ArrayList iterations = new ArrayList();
	
	public double[] weights;

	public Perceptron(double[][] inputs, double[] outputs) {
		super(inputs, outputs);
		
		this.calculateWeight();
	}

	public ArrayList getIterations() {
		return this.iterations;
	}
	
	@Override
	void calculateWeight() {
		// TODO Auto-generated method stub
		
		int	numOfInstance = this.outputs.length;
		int numOfAttribute = this.inputs[0].length;
		
		//1. Initialize weight to 1.0
		this.weights = new double[numOfAttribute];
		for(int i=0; i < numOfAttribute; i++) { 
			this.weights[i] = 1.0 ;
		}
		
		
		//2. Define the total error of instances, iteration counter
		//and maximum iteration
		double totalErrorOfInstances;
		int iteration = 0;
		int maxIteration = 20;
		
		do
		{
			totalErrorOfInstances = 0;
			
			for(int i = 0; i < numOfInstance; i++) {
				
				//set iteration object
				PerceptronIterationStatus current = new PerceptronIterationStatus();
				current.iteration = iteration;
				
				// System.out.println("\nInstance iteration : " + i);
				current.instanceNumber = i;
				
				//Calculate prediction for current instance
				double prediction = this.activationFunction(inputs[i]);
				
				//Calculate error
				double errorOfInstance = outputs[i] - prediction;
				current.error = errorOfInstance;
				
				//Check error
				if(errorOfInstance != 0) {
					
					for(int j = 0; j= 0) return 1;
		else return 0;	
	}
	
	public class PerceptronIterationStatus extends IterationStatus {
		public double[] weight;
	}
}


import java.util.ArrayList;

public class Winnow extends LinearClassifier {

	/*variable to record the state of computation for each iteration*/
	private ArrayList iterations = new ArrayList();
	
	public double[] weights;

	public Winnow(double[][] inputs, double[] outputs) {
		super(inputs, outputs);
		this.calculateWeight();
	}
	
	public ArrayList getIterations() {
		return this.iterations;
	}

	@Override
	void calculateWeight() {
		//get the number of instances and attributes
		int numOfInstances = this.outputs.length;
		int numOfAttribute = this.inputs[0].length;
		
		//Initialize the weights to 1/N
		weights = new double[numOfAttribute];
		for (int i = 0; i < weights.length; i++) {
			weights[i] = 1;
		}
		System.out.println("Weight computed");
		//2. Initialize total error of instances, iteration counter and 
		// max iteration limit
		double productErrorOfInstances;
		int iteration = 0, maxIteration = 20;
		
		//3. While the total error of instance is not zero
		// and it hasn't reach the max limit of iteration, iterate
		do {
			//For each iteration set this variable to 1
			productErrorOfInstances  = 0.0;
			
			for (int i = 0; i < numOfInstances; i++) {
				// System.out.println("\nInstance iteration " + i);
				
				//set the iteration status 
				WinnowIterationStatus current = new WinnowIterationStatus();
				current.iteration = iteration;
				
				current.instanceNumber = i;
				
				//calculate prediction
				double prediction = this.activationFunction(this.inputs[i]);
				// System.out.println("Prediction " + prediction);
				// System.out.println("Output " + this.outputs[i]);
				
				//check the error of this instance, this statement will
				//produce 0 if error, 1 otherwise
				double error = this.outputs[i] - prediction;
				// System.out.println("Error " + error );
				
				//if error, e.g. error = 0
				if(error != 0.0) {
					for (int j = 0; j < numOfAttribute; j++) {
						//if error occurs in positive example, e.g. output == 1
						//then multiply weight by n^input[j]
						if(this.outputs[i] == 1.0 && this.inputs[i][j] == 1.0) {
							weights[j] = weights[j] * Math.pow( new Integer( numOfAttribute).doubleValue(), this.learningRate);
						}
						else if(this.outputs[i] == 0.0 && this.inputs[i][j] == 1.0) {
							weights[j] = weights[j] / (Math.pow(new Integer(numOfAttribute).doubleValue(), this.learningRate));
						}
					}
					
					
				}
				
				current.input = this.inputs[i].clone();
				current.output = this.outputs[i];
				current.prediction = prediction;
				current.weight = this.weights.clone();
				
				this.iterations.add(current);
				//set the productErrorOfInstances
				productErrorOfInstances = productErrorOfInstances + error;
				
			}
			//increment iteration
			iteration++;
			
		}while(productErrorOfInstances != 0 && iteration < maxIteration);
		
	}

	@Override
	double activationFunction(double[] inputs) {
		// if wx > n -> positive, if wx < n -> negative
		double output = 0.0;
		double n = (new Integer(weights.length)).doubleValue();
		// System.out.println("Inside activation func, n = " + n);
		for(int i=0; i= numOfAttribute) return 1;
		else return 0;
	}
	
	public class BalancedWinnowIterationStatus extends IterationStatus {
		public double[] positive_weights = null;
		public double[] negative_weights = null;
	}
}


import java.util.ArrayList;


public class Main {

	/**
	 * @param args
	 */
	public static void main(String[] args) {
		// TODO Auto-generated method stub

		//AND pattern
		double[][] andInput = new double[4][3];
		double[] andOutput = new double[4];
		
		andInput[0][0] = 1.0; //bias
		andInput[0][1] = 0;
		andInput[0][2] = 0;
		andOutput[0] = 0.0;
		
		andInput[1][0] = 1.0; //bias
		andInput[1][1] = 0;
		andInput[1][2] = 1.0;
		andOutput[1] = 0.0;
		
		andInput[2][0] = 1.0; //bias
		andInput[2][1] = 1.0;
		andInput[2][2] = 0;
		andOutput[2] = 0.0;
		
		andInput[3][0] = 1.0; //bias
		andInput[3][1] = 1.0;
		andInput[3][2] = 1.0;
		andOutput[3] = 1.0;
		
		System.out.println("\n\n==================================");
		System.out.println("---------AND---------");
		System.out.println("\n\n==================================");
		System.out.println("PERCEPTRON ALGORITHM");
		Perceptron p = new Perceptron(andInput, andOutput);
		displayIterations(p);
		System.out.println("\n==================================");
		
		System.out.println("\n\n==================================");
		System.out.println("WINNOW ALGORITHM");
		Winnow w = new Winnow(andInput, andOutput);
		displayIterations(w);
		System.out.println("\n==================================");
		
		System.out.println("\n\n==================================");
		System.out.println("BALANCED WINNOW ALGORITHM");
		BalancedWinnow bw = new BalancedWinnow(andInput, andOutput);
		displayIterations(bw);
		System.out.println("\n==================================");
		
		//OR pattern
		double[][] orInput = new double[4][3];
		double[] orOutput = new double[4];
		
		orInput[0][0] = 1.0; //bias
		orInput[0][1] = 0.0;
		orInput[0][2] = 0.0;
		orOutput[0] = 0.0;
		
		orInput[1][0] = 1.0; //bias
		orInput[1][1] = 0.0;
		orInput[1][2] = 1.0;
		orOutput[1] = 1.0;
		
		orInput[2][0] = 1.0; //bias
		orInput[2][1] = 1.0;
		orInput[2][2] = 0;
		orOutput[2] = 1.0;
		
		orInput[3][0] = 1.0; //bias
		orInput[3][1] = 1.0;
		orInput[3][2] = 1.0;
		orOutput[3] = 1.0;
		
		System.out.println("\n\n==================================");
		System.out.println("---------OR---------");
		System.out.println("\n\n==================================");
		System.out.println("PERCEPTRON ALGORITHM");
		p = new Perceptron(orInput, orOutput);
		displayIterations(p);
		System.out.println("\n==================================");
		
		System.out.println("\n\n==================================");
		System.out.println("WINNOW ALGORITHM");
		w = new Winnow(orInput, orOutput);
		displayIterations(w);
		System.out.println("\n==================================");
		
		System.out.println("\n\n==================================");
		System.out.println("BALANCED WINNOW ALGORITHM");
		bw = new BalancedWinnow(orInput, orOutput);
		displayIterations(bw);
		System.out.println("\n==================================");
	}
	
	static void displayIterations(LinearClassifier classifier) {
		
		if(classifier.getClass() == Perceptron.class) {
			Perceptron pClass = (Perceptron)classifier;
			ArrayList iterations = pClass.getIterations();
			
			System.out.println();
			System.out.print(" | Iteration | ");
			System.out.print(" | Instance | ");
			for (int j = 0; j <  pClass.weights.length; j++) {
				System.out.print(" | W " + j + " | ");
			}
			System.out.print(" | Output | ");
			System.out.print(" | Prediction | ");
			
			for (int i = 0; i < iterations.size(); i++) {
				System.out.println();
				System.out.print( " | " + iterations.get(i).iteration + " | ");
				System.out.print( " | " + iterations.get(i).instanceNumber + " | ");
				for (int j = 0; j < pClass.weights.length; j++) {
					System.out.print(" | " + iterations.get(i).weight[j] + " | ");
				}
				System.out.print( " | " + iterations.get(i).output + " | ");
				System.out.print( " | " + iterations.get(i).prediction + " | ");
			}
		}
		
		if(classifier.getClass() == Winnow.class) {
			Winnow wClass = (Winnow)classifier;
			ArrayList iterations = wClass.getIterations();
			
			System.out.println();
			System.out.print(" | Iteration | ");
			System.out.print(" | Instance | "); 
			for (int j = 0; j <  wClass.weights.length; j++) {
				System.out.print(" | W " + j + " | ");
			}
			System.out.print(" | Output | ");
			System.out.print(" | Prediction | ");
			
			for (int i = 0; i < iterations.size(); i++) {
				System.out.println();
				System.out.print( " | " + iterations.get(i).iteration + " | ");
				System.out.print( " | " + iterations.get(i).instanceNumber + " | ");
				for (int j = 0; j < wClass.weights.length; j++) {
					System.out.print(" | " + iterations.get(i).weight[j] + " | ");
				}
				System.out.print( " | " + iterations.get(i).output + " | ");
				System.out.print( " | " + iterations.get(i).prediction + " | ");
			}
		}
		if(classifier.getClass() == BalancedWinnow.class) {
			BalancedWinnow bwClass = (BalancedWinnow)classifier;
			ArrayList iterations = bwClass.getIterations();
			
			System.out.println();
			System.out.print(" | Iteration | ");
			System.out.print(" | Instance | "); 
			for (int j = 0; j <  bwClass.positive_weights.length; j++) {
				System.out.print(" | P_W " + j + " | ");
			}
			for (int j = 0; j <  bwClass.negative_weights.length; j++) {
				System.out.print(" | N_W " + j + " | ");
			}
			System.out.print(" | Output | ");
			System.out.print(" | Prediction | ");
			
			for (int i = 0; i < iterations.size(); i++) {
				System.out.println();
				System.out.print( " | " + iterations.get(i).iteration + " | ");
				System.out.print( " | " + iterations.get(i).instanceNumber + " | ");
				for (int j = 0; j < bwClass.positive_weights.length; j++) {
					System.out.print(" | " + iterations.get(i).positive_weights[j] + " | ");
				}
				for (int j = 0; j < bwClass.negative_weights.length; j++) {
					System.out.print(" | " + iterations.get(i).negative_weights[j] + " | ");
				}
				System.out.print( " | " + iterations.get(i).output + " | ");
				System.out.print( " | " + iterations.get(i).prediction + " | ");
			}
		}
		
	}
}


No comments:

Post a Comment