diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 85c5fa0..ebae177 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -34,9 +34,9 @@ public class Main { System.out.println(network.synCount()); Trainer trainer = new GradientBackpropagationTraining(); - trainer.train(0.001F, 2000, network, dataset); + trainer.train(0.01F, 2000, network, dataset); - plotGraph(dataset, network); + //plotGraph(dataset, network); } private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){ @@ -83,8 +83,8 @@ public class Main { float step = 0.03F; for (float x = min; x < max; x+=step){ for (float y = min; y < max; y+=step){ - List predictions = network.predict(List.of(new Input(x), new Input(y))); - visualizer.addPoint(Float.toString(Math.round(predictions.getFirst())), x, y); + float[] predictions = network.predict(new float[]{x, y}); + visualizer.addPoint(Float.toString(Math.round(predictions[0])), x, y); } } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java index 9d13be9..b1f3d8c 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java @@ -15,5 +15,5 @@ public interface Model { //void forEachSynapse(Consumer consumer); void forEachOutputNeurons(Consumer consumer); void forEachNeuronConnectedTo(Neuron n, Consumer consumer); - List predict(List inputs); + float[] predict(float[] inputs); } diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java index 7de9377..47183d3 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java @@ -11,7 +11,7 @@ public abstract class TrainingContext { public DataSetEntry currentEntry; public List expectations; - public List predictions; + public float[] predictions; public float[] deltas; public float globalLoss; diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java index d7d9407..0e26063 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java @@ -24,13 +24,12 @@ public class FullyConnectedNetwork implements Model { } @Override - public List predict(List inputs) { - List previousLayerOutputs = new ArrayList<>(inputs); - for(Layer layer : this.layers){ - List currentLayerOutputs = layer.predict(previousLayerOutputs); - previousLayerOutputs = currentLayerOutputs.stream().map(Input::new).toList(); + public float[] predict(float[] inputs) { + float[] previousLayerOutputs = inputs; + for (Layer layer : layers) { + previousLayerOutputs = layer.predict(previousLayerOutputs); } - return previousLayerOutputs.stream().map(Input::getValue).toList(); + return previousLayerOutputs; } @Override diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java index a214296..9969f99 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java @@ -19,11 +19,10 @@ public class Layer implements Model { } @Override - public List predict(List inputs) { - List result = new ArrayList<>(); - for(Neuron neuron : this.neurons){ - List res = neuron.predict(inputs); - result.addAll(res); + public float[] predict(float[] inputs) { + float[] result = new float[neurons.length]; + for (int i = 0; i < neurons.length; i++) { + result[i] = neurons[i].predict(inputs)[0]; } return result; } diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java index 297f6dd..e9470c4 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java @@ -17,6 +17,7 @@ public class Neuron implements Model { this.id = id; this.activationFunction = func; + output = 0; weights = new float[synapses.length+1]; //takes the bias into account inputs = new float[synapses.length+1]; //takes the bias into account @@ -44,10 +45,6 @@ public class Neuron implements Model { return this.activationFunction; } - public float getOutput(){ - return this.output; - } - public float calculateWeightedSum() { int count = synCount(); float weightedSum = 0F; @@ -61,6 +58,10 @@ public class Neuron implements Model { return this.id; } + public float getOutput() { + return this.output; + } + @Override public int synCount() { return this.weights.length; @@ -77,10 +78,10 @@ public class Neuron implements Model { } @Override - public List predict(List inputs) { + public float[] predict(float[] inputs) { this.setInputs(inputs); - this.output = activationFunction.accept(this); - return List.of(output); + output = activationFunction.accept(this); + return new float[] {output}; } @Override @@ -98,10 +99,8 @@ public class Neuron implements Model { throw new UnsupportedOperationException("Neurons have no connection with themselves"); } - private void setInputs(List values){ - for(int i = 0; i < values.size(); i++){ - inputs[i+1] = values.get(i).getValue(); - } + private void setInputs(float[] values){ + System.arraycopy(values, 0, inputs, 1, values.length); } } diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java index 9e7a367..578a725 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java @@ -101,7 +101,7 @@ public class TrainingPipeline { if(this.verbose && ctx.epoch % this.verboseDelay == 0) { System.out.printf("Epoch : %d, ", ctx.epoch); - System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray())); + System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions)); System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray())); System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas)); System.out.printf("loss : %.5f\n", ctx.localLoss); diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStep.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStep.java index 5b5f88c..a3cb274 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStep.java @@ -21,11 +21,11 @@ public class SimpleDeltaStep implements AlgorithmStep { public void run() { DataSet dataSet = context.dataset; DataSetEntry entry = context.currentEntry; - List predicted = context.predictions; + float[] predicted = context.predictions; List expected = dataSet.getLabelsAsFloat(entry); - for (int i = 0; i < predicted.size(); i++) { - context.deltas[i] = expected.get(i) - predicted.get(i); + for (int i = 0; i < predicted.length; i++) { + context.deltas[i] = expected.get(i) - predicted[0]; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStep.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStep.java index ab25f23..ea62105 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStep.java @@ -2,6 +2,9 @@ package com.naaturel.ANN.implementation.simplePerceptron; import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.TrainingContext; +import com.naaturel.ANN.domain.model.neuron.Input; + +import java.util.List; public class SimplePredictionStep implements AlgorithmStep { @@ -13,6 +16,11 @@ public class SimplePredictionStep implements AlgorithmStep { @Override public void run() { - context.predictions = context.model.predict(context.currentEntry.getData()); + List data = context.currentEntry.getData(); + float[] flatData = new float[data.size()]; + for (int i = 0; i < data.size(); i++) { + flatData[i] = data.get(i).getValue(); + } + context.predictions = context.model.predict(flatData); } }