Optimize prediction

This commit is contained in:
2026-04-02 09:07:58 +02:00
parent 5ddf6dc580
commit 4c1eaff238
9 changed files with 38 additions and 33 deletions

View File

@@ -34,9 +34,9 @@ public class Main {
System.out.println(network.synCount()); System.out.println(network.synCount());
Trainer trainer = new GradientBackpropagationTraining(); Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.001F, 2000, network, dataset); trainer.train(0.01F, 2000, network, dataset);
plotGraph(dataset, network); //plotGraph(dataset, network);
} }
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){ private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
@@ -83,8 +83,8 @@ public class Main {
float step = 0.03F; float step = 0.03F;
for (float x = min; x < max; x+=step){ for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){ for (float y = min; y < max; y+=step){
List<Float> predictions = network.predict(List.of(new Input(x), new Input(y))); float[] predictions = network.predict(new float[]{x, y});
visualizer.addPoint(Float.toString(Math.round(predictions.getFirst())), x, y); visualizer.addPoint(Float.toString(Math.round(predictions[0])), x, y);
} }
} }

View File

@@ -15,5 +15,5 @@ public interface Model {
//void forEachSynapse(Consumer<Synapse> consumer); //void forEachSynapse(Consumer<Synapse> consumer);
void forEachOutputNeurons(Consumer<Neuron> consumer); void forEachOutputNeurons(Consumer<Neuron> consumer);
void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer); void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer);
List<Float> predict(List<Input> inputs); float[] predict(float[] inputs);
} }

View File

@@ -11,7 +11,7 @@ public abstract class TrainingContext {
public DataSetEntry currentEntry; public DataSetEntry currentEntry;
public List<Float> expectations; public List<Float> expectations;
public List<Float> predictions; public float[] predictions;
public float[] deltas; public float[] deltas;
public float globalLoss; public float globalLoss;

View File

@@ -24,13 +24,12 @@ public class FullyConnectedNetwork implements Model {
} }
@Override @Override
public List<Float> predict(List<Input> inputs) { public float[] predict(float[] inputs) {
List<Input> previousLayerOutputs = new ArrayList<>(inputs); float[] previousLayerOutputs = inputs;
for(Layer layer : this.layers){ for (Layer layer : layers) {
List<Float> currentLayerOutputs = layer.predict(previousLayerOutputs); previousLayerOutputs = layer.predict(previousLayerOutputs);
previousLayerOutputs = currentLayerOutputs.stream().map(Input::new).toList();
} }
return previousLayerOutputs.stream().map(Input::getValue).toList(); return previousLayerOutputs;
} }
@Override @Override

View File

@@ -19,11 +19,10 @@ public class Layer implements Model {
} }
@Override @Override
public List<Float> predict(List<Input> inputs) { public float[] predict(float[] inputs) {
List<Float> result = new ArrayList<>(); float[] result = new float[neurons.length];
for(Neuron neuron : this.neurons){ for (int i = 0; i < neurons.length; i++) {
List<Float> res = neuron.predict(inputs); result[i] = neurons[i].predict(inputs)[0];
result.addAll(res);
} }
return result; return result;
} }

View File

@@ -17,6 +17,7 @@ public class Neuron implements Model {
this.id = id; this.id = id;
this.activationFunction = func; this.activationFunction = func;
output = 0;
weights = new float[synapses.length+1]; //takes the bias into account weights = new float[synapses.length+1]; //takes the bias into account
inputs = new float[synapses.length+1]; //takes the bias into account inputs = new float[synapses.length+1]; //takes the bias into account
@@ -44,10 +45,6 @@ public class Neuron implements Model {
return this.activationFunction; return this.activationFunction;
} }
public float getOutput(){
return this.output;
}
public float calculateWeightedSum() { public float calculateWeightedSum() {
int count = synCount(); int count = synCount();
float weightedSum = 0F; float weightedSum = 0F;
@@ -61,6 +58,10 @@ public class Neuron implements Model {
return this.id; return this.id;
} }
public float getOutput() {
return this.output;
}
@Override @Override
public int synCount() { public int synCount() {
return this.weights.length; return this.weights.length;
@@ -77,10 +78,10 @@ public class Neuron implements Model {
} }
@Override @Override
public List<Float> predict(List<Input> inputs) { public float[] predict(float[] inputs) {
this.setInputs(inputs); this.setInputs(inputs);
this.output = activationFunction.accept(this); output = activationFunction.accept(this);
return List.of(output); return new float[] {output};
} }
@Override @Override
@@ -98,10 +99,8 @@ public class Neuron implements Model {
throw new UnsupportedOperationException("Neurons have no connection with themselves"); throw new UnsupportedOperationException("Neurons have no connection with themselves");
} }
private void setInputs(List<Input> values){ private void setInputs(float[] values){
for(int i = 0; i < values.size(); i++){ System.arraycopy(values, 0, inputs, 1, values.length);
inputs[i+1] = values.get(i).getValue();
}
} }
} }

View File

@@ -101,7 +101,7 @@ public class TrainingPipeline {
if(this.verbose && ctx.epoch % this.verboseDelay == 0) { if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
System.out.printf("Epoch : %d, ", ctx.epoch); System.out.printf("Epoch : %d, ", ctx.epoch);
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray())); System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions));
System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray())); System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas)); System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas));
System.out.printf("loss : %.5f\n", ctx.localLoss); System.out.printf("loss : %.5f\n", ctx.localLoss);

View File

@@ -21,11 +21,11 @@ public class SimpleDeltaStep implements AlgorithmStep {
public void run() { public void run() {
DataSet dataSet = context.dataset; DataSet dataSet = context.dataset;
DataSetEntry entry = context.currentEntry; DataSetEntry entry = context.currentEntry;
List<Float> predicted = context.predictions; float[] predicted = context.predictions;
List<Float> expected = dataSet.getLabelsAsFloat(entry); List<Float> expected = dataSet.getLabelsAsFloat(entry);
for (int i = 0; i < predicted.size(); i++) { for (int i = 0; i < predicted.length; i++) {
context.deltas[i] = expected.get(i) - predicted.get(i); context.deltas[i] = expected.get(i) - predicted[0];
} }
} }

View File

@@ -2,6 +2,9 @@ package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.model.neuron.Input;
import java.util.List;
public class SimplePredictionStep implements AlgorithmStep { public class SimplePredictionStep implements AlgorithmStep {
@@ -13,6 +16,11 @@ public class SimplePredictionStep implements AlgorithmStep {
@Override @Override
public void run() { public void run() {
context.predictions = context.model.predict(context.currentEntry.getData()); List<Input> data = context.currentEntry.getData();
float[] flatData = new float[data.size()];
for (int i = 0; i < data.size(); i++) {
flatData[i] = data.get(i).getValue();
}
context.predictions = context.model.predict(flatData);
} }
} }