Optimize prediction
This commit is contained in:
@@ -34,9 +34,9 @@ public class Main {
|
|||||||
System.out.println(network.synCount());
|
System.out.println(network.synCount());
|
||||||
|
|
||||||
Trainer trainer = new GradientBackpropagationTraining();
|
Trainer trainer = new GradientBackpropagationTraining();
|
||||||
trainer.train(0.001F, 2000, network, dataset);
|
trainer.train(0.01F, 2000, network, dataset);
|
||||||
|
|
||||||
plotGraph(dataset, network);
|
//plotGraph(dataset, network);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
||||||
@@ -83,8 +83,8 @@ public class Main {
|
|||||||
float step = 0.03F;
|
float step = 0.03F;
|
||||||
for (float x = min; x < max; x+=step){
|
for (float x = min; x < max; x+=step){
|
||||||
for (float y = min; y < max; y+=step){
|
for (float y = min; y < max; y+=step){
|
||||||
List<Float> predictions = network.predict(List.of(new Input(x), new Input(y)));
|
float[] predictions = network.predict(new float[]{x, y});
|
||||||
visualizer.addPoint(Float.toString(Math.round(predictions.getFirst())), x, y);
|
visualizer.addPoint(Float.toString(Math.round(predictions[0])), x, y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,5 +15,5 @@ public interface Model {
|
|||||||
//void forEachSynapse(Consumer<Synapse> consumer);
|
//void forEachSynapse(Consumer<Synapse> consumer);
|
||||||
void forEachOutputNeurons(Consumer<Neuron> consumer);
|
void forEachOutputNeurons(Consumer<Neuron> consumer);
|
||||||
void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer);
|
void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer);
|
||||||
List<Float> predict(List<Input> inputs);
|
float[] predict(float[] inputs);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ public abstract class TrainingContext {
|
|||||||
public DataSetEntry currentEntry;
|
public DataSetEntry currentEntry;
|
||||||
|
|
||||||
public List<Float> expectations;
|
public List<Float> expectations;
|
||||||
public List<Float> predictions;
|
public float[] predictions;
|
||||||
public float[] deltas;
|
public float[] deltas;
|
||||||
|
|
||||||
public float globalLoss;
|
public float globalLoss;
|
||||||
|
|||||||
@@ -24,13 +24,12 @@ public class FullyConnectedNetwork implements Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Float> predict(List<Input> inputs) {
|
public float[] predict(float[] inputs) {
|
||||||
List<Input> previousLayerOutputs = new ArrayList<>(inputs);
|
float[] previousLayerOutputs = inputs;
|
||||||
for(Layer layer : this.layers){
|
for (Layer layer : layers) {
|
||||||
List<Float> currentLayerOutputs = layer.predict(previousLayerOutputs);
|
previousLayerOutputs = layer.predict(previousLayerOutputs);
|
||||||
previousLayerOutputs = currentLayerOutputs.stream().map(Input::new).toList();
|
|
||||||
}
|
}
|
||||||
return previousLayerOutputs.stream().map(Input::getValue).toList();
|
return previousLayerOutputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -19,11 +19,10 @@ public class Layer implements Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Float> predict(List<Input> inputs) {
|
public float[] predict(float[] inputs) {
|
||||||
List<Float> result = new ArrayList<>();
|
float[] result = new float[neurons.length];
|
||||||
for(Neuron neuron : this.neurons){
|
for (int i = 0; i < neurons.length; i++) {
|
||||||
List<Float> res = neuron.predict(inputs);
|
result[i] = neurons[i].predict(inputs)[0];
|
||||||
result.addAll(res);
|
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ public class Neuron implements Model {
|
|||||||
this.id = id;
|
this.id = id;
|
||||||
this.activationFunction = func;
|
this.activationFunction = func;
|
||||||
|
|
||||||
|
output = 0;
|
||||||
weights = new float[synapses.length+1]; //takes the bias into account
|
weights = new float[synapses.length+1]; //takes the bias into account
|
||||||
inputs = new float[synapses.length+1]; //takes the bias into account
|
inputs = new float[synapses.length+1]; //takes the bias into account
|
||||||
|
|
||||||
@@ -44,10 +45,6 @@ public class Neuron implements Model {
|
|||||||
return this.activationFunction;
|
return this.activationFunction;
|
||||||
}
|
}
|
||||||
|
|
||||||
public float getOutput(){
|
|
||||||
return this.output;
|
|
||||||
}
|
|
||||||
|
|
||||||
public float calculateWeightedSum() {
|
public float calculateWeightedSum() {
|
||||||
int count = synCount();
|
int count = synCount();
|
||||||
float weightedSum = 0F;
|
float weightedSum = 0F;
|
||||||
@@ -61,6 +58,10 @@ public class Neuron implements Model {
|
|||||||
return this.id;
|
return this.id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public float getOutput() {
|
||||||
|
return this.output;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int synCount() {
|
public int synCount() {
|
||||||
return this.weights.length;
|
return this.weights.length;
|
||||||
@@ -77,10 +78,10 @@ public class Neuron implements Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Float> predict(List<Input> inputs) {
|
public float[] predict(float[] inputs) {
|
||||||
this.setInputs(inputs);
|
this.setInputs(inputs);
|
||||||
this.output = activationFunction.accept(this);
|
output = activationFunction.accept(this);
|
||||||
return List.of(output);
|
return new float[] {output};
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -98,10 +99,8 @@ public class Neuron implements Model {
|
|||||||
throw new UnsupportedOperationException("Neurons have no connection with themselves");
|
throw new UnsupportedOperationException("Neurons have no connection with themselves");
|
||||||
}
|
}
|
||||||
|
|
||||||
private void setInputs(List<Input> values){
|
private void setInputs(float[] values){
|
||||||
for(int i = 0; i < values.size(); i++){
|
System.arraycopy(values, 0, inputs, 1, values.length);
|
||||||
inputs[i+1] = values.get(i).getValue();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ public class TrainingPipeline {
|
|||||||
|
|
||||||
if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
|
if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
|
||||||
System.out.printf("Epoch : %d, ", ctx.epoch);
|
System.out.printf("Epoch : %d, ", ctx.epoch);
|
||||||
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray()));
|
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions));
|
||||||
System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
|
System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
|
||||||
System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas));
|
System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas));
|
||||||
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
||||||
|
|||||||
@@ -21,11 +21,11 @@ public class SimpleDeltaStep implements AlgorithmStep {
|
|||||||
public void run() {
|
public void run() {
|
||||||
DataSet dataSet = context.dataset;
|
DataSet dataSet = context.dataset;
|
||||||
DataSetEntry entry = context.currentEntry;
|
DataSetEntry entry = context.currentEntry;
|
||||||
List<Float> predicted = context.predictions;
|
float[] predicted = context.predictions;
|
||||||
List<Float> expected = dataSet.getLabelsAsFloat(entry);
|
List<Float> expected = dataSet.getLabelsAsFloat(entry);
|
||||||
|
|
||||||
for (int i = 0; i < predicted.size(); i++) {
|
for (int i = 0; i < predicted.length; i++) {
|
||||||
context.deltas[i] = expected.get(i) - predicted.get(i);
|
context.deltas[i] = expected.get(i) - predicted[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,9 @@ package com.naaturel.ANN.implementation.simplePerceptron;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public class SimplePredictionStep implements AlgorithmStep {
|
public class SimplePredictionStep implements AlgorithmStep {
|
||||||
|
|
||||||
@@ -13,6 +16,11 @@ public class SimplePredictionStep implements AlgorithmStep {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
context.predictions = context.model.predict(context.currentEntry.getData());
|
List<Input> data = context.currentEntry.getData();
|
||||||
|
float[] flatData = new float[data.size()];
|
||||||
|
for (int i = 0; i < data.size(); i++) {
|
||||||
|
flatData[i] = data.get(i).getValue();
|
||||||
|
}
|
||||||
|
context.predictions = context.model.predict(flatData);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user