John Wakefield, Single Layer Perceptron [online] (http://dynamicnotions.blogspot.com/2008/09/single-layer-perceptron.html)
Hiroki Arimura, Winnow Algorithm [online] (http://www-ikn.ist.hokudai.ac.jp/ikn-tokuron/arim4winnow.pdf)
Roni Khardon, The Winnow Algorithm [online] (http://www.cs.tufts.edu/~roni/Teaching/CLT/LN/lecture16.pdf)
Here We Go
Too much explanation? Please, no — I'm sleepy. Let me just post the code now and explain it later.
For simplicity, here is the class diagram.
/**
 * Base class for the linear classifiers in this example (a perceptron and Winnow variants).
 *
 * <p>Holds the training instances, their target outputs and a learning-rate parameter.
 * Subclasses implement the weight-learning procedure ({@link #calculateWeight()}) and the
 * prediction function ({@link #activationFunction(double[])}).
 */
public abstract class LinearClassifier {
    /** Training instances: one row per instance, one column per attribute. */
    protected double[][] inputs;
    /** Target output of each training instance, parallel to the rows of {@link #inputs}. */
    protected double[] outputs;
    /** Learning-rate parameter that subclasses use in their weight updates. */
    protected Double learningRate = 0.45;

    /**
     * Stores the training data. The arrays are kept by reference, not copied.
     *
     * @param inputs training instances, one row per instance; must not be null
     * @param output target output for each instance; must not be null and must contain exactly
     *               one value per row of {@code inputs}
     * @throws NullPointerException if {@code inputs} or {@code output} is null
     * @throws IllegalArgumentException if {@code output.length != inputs.length}
     */
    public LinearClassifier(double[][] inputs, double[] output) {
        // Fail here with a clear message instead of letting a subclass's training loop hit an
        // opaque NullPointerException / ArrayIndexOutOfBoundsException later.
        if (inputs == null) {
            throw new NullPointerException("inputs");
        }
        if (output == null) {
            throw new NullPointerException("output");
        }
        if (inputs.length != output.length) {
            throw new IllegalArgumentException("inputs has " + inputs.length
                    + " instances but output has " + output.length + " values");
        }
        this.inputs = inputs;
        this.outputs = output;
    }

    /** Learns the classifier's weights from {@link #inputs} and {@link #outputs}. */
    abstract void calculateWeight();

    /**
     * Computes the classifier's prediction for a single instance.
     *
     * @param input attribute values of one instance
     * @return the predicted output
     */
    abstract double activationFunction(double[] input);

    /**
     * Records the calculation status for one training step (one instance within one pass).
     * Subclasses fill in the fields they track.
     *
     * <p>This is an inner (non-static) class, so each instance holds a reference to the
     * classifier that created it.
     */
    public class IterationStatus {
        /** Index of the pass over the training set. */
        public int iteration;
        /** Index of the training instance processed in this step. */
        public int instanceNumber;
        /** Attribute values of the processed instance. */
        public double[] input;
        /** Target output of the processed instance. */
        public double output;
        /** Prediction made for the instance before any weight update in this step. */
        public double prediction;
        /** Prediction error recorded by the subclass (if it records one). */
        public double error;
    }
}
import java.util.ArrayList; public class Perceptron extends LinearClassifier { /*variable to record the state of computation for each iteration*/ private ArrayListiterations = new ArrayList (); public double[] weights; public Perceptron(double[][] inputs, double[] outputs) { super(inputs, outputs); this.calculateWeight(); } public ArrayList getIterations() { return this.iterations; } @Override void calculateWeight() { // TODO Auto-generated method stub int numOfInstance = this.outputs.length; int numOfAttribute = this.inputs[0].length; //1. Initialize weight to 1.0 this.weights = new double[numOfAttribute]; for(int i=0; i < numOfAttribute; i++) { this.weights[i] = 1.0 ; } //2. Define the total error of instances, iteration counter //and maximum iteration double totalErrorOfInstances; int iteration = 0; int maxIteration = 20; do { totalErrorOfInstances = 0; for(int i = 0; i < numOfInstance; i++) { //set iteration object PerceptronIterationStatus current = new PerceptronIterationStatus(); current.iteration = iteration; // System.out.println("\nInstance iteration : " + i); current.instanceNumber = i; //Calculate prediction for current instance double prediction = this.activationFunction(inputs[i]); //Calculate error double errorOfInstance = outputs[i] - prediction; current.error = errorOfInstance; //Check error if(errorOfInstance != 0) { for(int j = 0; j = 0) return 1; else return 0; } public class PerceptronIterationStatus extends IterationStatus { public double[] weight; } }
import java.util.ArrayList;

/**
 * Winnow linear classifier with multiplicative weight updates.
 *
 * All weights start at 1. Whenever an instance is misclassified, each attribute whose value is
 * exactly 1.0 has its weight multiplied by numOfAttribute^learningRate if the target output is
 * 1.0, or divided by that factor if the target output is 0.0. Training runs inside the
 * constructor for at most 20 passes and records one status entry per instance per pass.
 *
 * NOTE(review): this listing was damaged when it was extracted from the web page (see the
 * notes on the iterations field and inside activationFunction); as shown it does not compile.
 */
public class Winnow extends LinearClassifier {
    /*variable to record the state of computation for each iteration*/
    // NOTE(review): the generic type arguments appear to have been stripped here; presumably the
    // original read: private ArrayList<WinnowIterationStatus> iterations = new ArrayList<WinnowIterationStatus>();
    private ArrayListiterations = new ArrayList ();
    // Learned weights, one per attribute; filled in by calculateWeight().
    public double[] weights;

    /**
     * Stores the training data and immediately trains the classifier.
     *
     * @param inputs  training instances, one row per instance
     * @param outputs target output for each instance
     */
    public Winnow(double[][] inputs, double[] outputs) {
        super(inputs, outputs);
        this.calculateWeight();
    }

    /** Returns the internal list of per-instance training snapshots (not a copy). */
    public ArrayList getIterations() {
        return this.iterations;
    }

    /**
     * Trains the weights: makes passes over the training set until a pass ends with
     * productErrorOfInstances == 0 or 20 passes have been made.
     */
    @Override
    void calculateWeight() {
        //get the number of instances and attributes
        int numOfInstances = this.outputs.length;
        int numOfAttribute = this.inputs[0].length;
        // Initialize every weight to 1 (an earlier comment said 1/N, but the code assigns 1)
        weights = new double[numOfAttribute];
        for (int i = 0; i < weights.length; i++) {
            weights[i] = 1;
        }
        // NOTE(review): printed before the training loop below has run.
        System.out.println("Weight computed");
        //2. Initialize total error of instances, iteration counter and
        // max iteration limit
        double productErrorOfInstances;
        int iteration = 0, maxIteration = 20;
        //3. While the total error of instances is not zero
        // and the iteration limit has not been reached, iterate
        do {
            // Reset the error accumulator at the start of each pass
            productErrorOfInstances = 0.0;
            for (int i = 0; i < numOfInstances; i++) {
                // System.out.println("\nInstance iteration " + i);
                //set the iteration status
                WinnowIterationStatus current = new WinnowIterationStatus();
                current.iteration = iteration;
                current.instanceNumber = i;
                //calculate prediction
                double prediction = this.activationFunction(this.inputs[i]);
                // System.out.println("Prediction " + prediction);
                // System.out.println("Output " + this.outputs[i]);
                // Signed error: 0 when the prediction equals the target, non-zero otherwise
                double error = this.outputs[i] - prediction;
                // System.out.println("Error " + error );
                // Update weights only when the instance was misclassified (error != 0)
                if(error != 0.0) {
                    for (int j = 0; j < numOfAttribute; j++) {
                        // Only attributes equal to 1.0 are touched. If the target is 1.0 the
                        // weight is multiplied by numOfAttribute^learningRate (promotion); if the
                        // target is 0.0 it is divided by the same factor (demotion).
                        // NOTE(review): new Integer(int) is deprecated; (double) numOfAttribute
                        // gives the same value.
                        if(this.outputs[i] == 1.0 && this.inputs[i][j] == 1.0) {
                            weights[j] = weights[j] * Math.pow( new Integer( numOfAttribute).doubleValue(), this.learningRate);
                        } else if(this.outputs[i] == 0.0 && this.inputs[i][j] == 1.0) {
                            weights[j] = weights[j] / (Math.pow(new Integer(numOfAttribute).doubleValue(), this.learningRate));
                        }
                    }
                }
                current.input = this.inputs[i].clone();
                current.output = this.outputs[i];
                current.prediction = prediction;
                current.weight = this.weights.clone();
                this.iterations.add(current);
                // Accumulate this instance's signed error into the pass total.
                // NOTE(review): a positive error (target above prediction) and a negative error
                // in the same pass can cancel to 0 and end training while instances are still
                // misclassified; summing Math.abs(error) would avoid that.
                productErrorOfInstances = productErrorOfInstances + error;
            }
            //increment iteration
            iteration++;
        }while(productErrorOfInstances != 0 && iteration < maxIteration);
    }

    /**
     * Predicts the output for one instance. According to its comment it compares the weighted
     * sum w.x against the threshold n = number of weights.
     * NOTE(review): the body is truncated (see below), so the exact comparison cannot be
     * confirmed from this listing.
     */
    @Override
    double activationFunction(double[] inputs) {
        // if wx > n -> positive, if wx < n -> negative
        double output = 0.0;
        double n = (new Integer(weights.length)).doubleValue();
        // System.out.println("Inside activation func, n = " + n);
        // NOTE(review): the next line is corrupted. The text between "i" and "= numOfAttribute"
        // appears to have been lost during extraction (probably everything between a '<' and a
        // later '>'). That span presumably held:
        //   - the rest of this loop and its return statement;
        //   - the WinnowIterationStatus class used above (with its 'weight' field);
        //   - the end of the Winnow class;
        //   - the beginning of a separate BalancedWinnow class.
        // From here on the text is the tail of BalancedWinnow.
        // TODO: restore the missing code from the original post.
        for(int i=0; i = numOfAttribute) return 1; else return 0;
    }
    // Per-step snapshot for BalancedWinnow: its positive and negative weight vectors.
    // NOTE(review): belongs to the BalancedWinnow class whose beginning was lost; the final
    // closing brace below presumably closes that class.
    public class BalancedWinnowIterationStatus extends IterationStatus {
        public double[] positive_weights = null;
        public double[] negative_weights = null;
    }
}
import java.util.ArrayList; public class Main { /** * @param args */ public static void main(String[] args) { // TODO Auto-generated method stub //AND pattern double[][] andInput = new double[4][3]; double[] andOutput = new double[4]; andInput[0][0] = 1.0; //bias andInput[0][1] = 0; andInput[0][2] = 0; andOutput[0] = 0.0; andInput[1][0] = 1.0; //bias andInput[1][1] = 0; andInput[1][2] = 1.0; andOutput[1] = 0.0; andInput[2][0] = 1.0; //bias andInput[2][1] = 1.0; andInput[2][2] = 0; andOutput[2] = 0.0; andInput[3][0] = 1.0; //bias andInput[3][1] = 1.0; andInput[3][2] = 1.0; andOutput[3] = 1.0; System.out.println("\n\n=================================="); System.out.println("---------AND---------"); System.out.println("\n\n=================================="); System.out.println("PERCEPTRON ALGORITHM"); Perceptron p = new Perceptron(andInput, andOutput); displayIterations(p); System.out.println("\n=================================="); System.out.println("\n\n=================================="); System.out.println("WINNOW ALGORITHM"); Winnow w = new Winnow(andInput, andOutput); displayIterations(w); System.out.println("\n=================================="); System.out.println("\n\n=================================="); System.out.println("BALANCED WINNOW ALGORITHM"); BalancedWinnow bw = new BalancedWinnow(andInput, andOutput); displayIterations(bw); System.out.println("\n=================================="); //OR pattern double[][] orInput = new double[4][3]; double[] orOutput = new double[4]; orInput[0][0] = 1.0; //bias orInput[0][1] = 0.0; orInput[0][2] = 0.0; orOutput[0] = 0.0; orInput[1][0] = 1.0; //bias orInput[1][1] = 0.0; orInput[1][2] = 1.0; orOutput[1] = 1.0; orInput[2][0] = 1.0; //bias orInput[2][1] = 1.0; orInput[2][2] = 0; orOutput[2] = 1.0; orInput[3][0] = 1.0; //bias orInput[3][1] = 1.0; orInput[3][2] = 1.0; orOutput[3] = 1.0; System.out.println("\n\n=================================="); System.out.println("---------OR---------"); 
System.out.println("\n\n=================================="); System.out.println("PERCEPTRON ALGORITHM"); p = new Perceptron(orInput, orOutput); displayIterations(p); System.out.println("\n=================================="); System.out.println("\n\n=================================="); System.out.println("WINNOW ALGORITHM"); w = new Winnow(orInput, orOutput); displayIterations(w); System.out.println("\n=================================="); System.out.println("\n\n=================================="); System.out.println("BALANCED WINNOW ALGORITHM"); bw = new BalancedWinnow(orInput, orOutput); displayIterations(bw); System.out.println("\n=================================="); } static void displayIterations(LinearClassifier classifier) { if(classifier.getClass() == Perceptron.class) { Perceptron pClass = (Perceptron)classifier; ArrayListiterations = pClass.getIterations(); System.out.println(); System.out.print(" | Iteration | "); System.out.print(" | Instance | "); for (int j = 0; j < pClass.weights.length; j++) { System.out.print(" | W " + j + " | "); } System.out.print(" | Output | "); System.out.print(" | Prediction | "); for (int i = 0; i < iterations.size(); i++) { System.out.println(); System.out.print( " | " + iterations.get(i).iteration + " | "); System.out.print( " | " + iterations.get(i).instanceNumber + " | "); for (int j = 0; j < pClass.weights.length; j++) { System.out.print(" | " + iterations.get(i).weight[j] + " | "); } System.out.print( " | " + iterations.get(i).output + " | "); System.out.print( " | " + iterations.get(i).prediction + " | "); } } if(classifier.getClass() == Winnow.class) { Winnow wClass = (Winnow)classifier; ArrayList iterations = wClass.getIterations(); System.out.println(); System.out.print(" | Iteration | "); System.out.print(" | Instance | "); for (int j = 0; j < wClass.weights.length; j++) { System.out.print(" | W " + j + " | "); } System.out.print(" | Output | "); System.out.print(" | Prediction | "); for (int i = 0; i < 
iterations.size(); i++) { System.out.println(); System.out.print( " | " + iterations.get(i).iteration + " | "); System.out.print( " | " + iterations.get(i).instanceNumber + " | "); for (int j = 0; j < wClass.weights.length; j++) { System.out.print(" | " + iterations.get(i).weight[j] + " | "); } System.out.print( " | " + iterations.get(i).output + " | "); System.out.print( " | " + iterations.get(i).prediction + " | "); } } if(classifier.getClass() == BalancedWinnow.class) { BalancedWinnow bwClass = (BalancedWinnow)classifier; ArrayList iterations = bwClass.getIterations(); System.out.println(); System.out.print(" | Iteration | "); System.out.print(" | Instance | "); for (int j = 0; j < bwClass.positive_weights.length; j++) { System.out.print(" | P_W " + j + " | "); } for (int j = 0; j < bwClass.negative_weights.length; j++) { System.out.print(" | N_W " + j + " | "); } System.out.print(" | Output | "); System.out.print(" | Prediction | "); for (int i = 0; i < iterations.size(); i++) { System.out.println(); System.out.print( " | " + iterations.get(i).iteration + " | "); System.out.print( " | " + iterations.get(i).instanceNumber + " | "); for (int j = 0; j < bwClass.positive_weights.length; j++) { System.out.print(" | " + iterations.get(i).positive_weights[j] + " | "); } for (int j = 0; j < bwClass.negative_weights.length; j++) { System.out.print(" | " + iterations.get(i).negative_weights[j] + " | "); } System.out.print( " | " + iterations.get(i).output + " | "); System.out.print( " | " + iterations.get(i).prediction + " | "); } } } }
No comments:
Post a Comment