Optimize some stuff

commit 1e8b02089c
parent daba4f8420
2026-04-01 16:14:13 +02:00

20 changed files with 150 additions and 102 deletions

View File: Main.java

@@ -24,21 +24,23 @@ public class Main {
        int nbrClass = 1;
        DataSet dataset = new DatasetExtractor()
-               .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
+               .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv", nbrClass);
-       int[] neuronPerLayer = new int[]{50, 50, 50, dataset.getNbrLabels()};
+       int[] neuronPerLayer = new int[]{1800, 2, 1800, dataset.getNbrLabels()};
        int nbrInput = dataset.getNbrInputs();
        FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
+       System.out.println(network.synCount());
        Trainer trainer = new GradientBackpropagationTraining();
        trainer.train(0.01F, 2000, network, dataset);
        //plotGraph(dataset, network);
    }

    private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
+       int neuronId = 0;
        List<Layer> layers = new ArrayList<>();
        for (int i = 0; i < neuronPerLayer.length; i++){
@@ -54,8 +56,9 @@ public class Main {
                Bias bias = new Bias(new Weight());
-               Neuron n = new Neuron(syns.toArray(new Synapse[0]), bias, new TanH());
+               Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
                neurons.add(n);
+               neuronId++;
            }
            Layer layer = new Layer(neurons.toArray(new Neuron[0]));
            layers.add(layer);
@@ -77,7 +80,7 @@ public class Main {
        float min = -5F;
        float max = 5F;
-       float step = 0.01F;
+       float step = 0.03F;
        for (float x = min; x < max; x+=step){
            for (float y = min; y < max; y+=step){
                List<Float> predictions = network.predict(List.of(new Input(x), new Input(y)));
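
Reviewer note: the new `System.out.println(network.synCount())` makes it easy to see how large the `{1800, 2, 1800, …}` bottleneck topology really is. A quick stand-alone way to sanity-check the printed figure, assuming fully connected layers and counting only inter-neuron synapses (whether `synCount()` also counts bias weights is not visible here); the 2-input / 1-label values are illustrative, not taken from the CSV:

    public class SynCountCheck {
        static int synCount(int nbrInput, int[] neuronPerLayer) {
            int total = 0, prev = nbrInput;
            for (int size : neuronPerLayer) {
                total += prev * size; // each neuron connects to the whole previous layer
                prev = size;
            }
            return total;
        }
        public static void main(String[] args) {
            // e.g. 2 inputs and {1800, 2, 1800, 1}: 3600 + 3600 + 3600 + 1800 = 12600
            System.out.println(synCount(2, new int[]{1800, 2, 1800, 1}));
        }
    }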

View File: TrainingContext.java

@@ -12,11 +12,18 @@ public abstract class TrainingContext {
    public List<Float> expectations;
    public List<Float> predictions;
-   public List<Float> deltas;
+   public float[] deltas;
    public float globalLoss;
    public float localLoss;
    public float learningRate;
    public int epoch;
+   public TrainingContext(Model model, DataSet dataset) {
+       this.model = model;
+       this.dataset = dataset;
+       this.deltas = new float[model.neuronCount()];
+   }
}
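
Reviewer note: moving `deltas` from `List<Float>` to `float[]` trades one boxed `Float` allocation per delta per sample for a single up-front allocation sized by `model.neuronCount()`. A minimal stand-alone illustration of the two shapes (values made up):

    import java.util.ArrayList;
    import java.util.List;

    public class DeltaRepresentation {
        public static void main(String[] args) {
            float[] expected  = {1f, 0f, 1f, 0f};
            float[] predicted = {0.9f, 0.2f, 0.7f, 0.1f};
            int n = expected.length;          // stands in for model.neuronCount()

            // Old shape: every add() autoboxes a java.lang.Float.
            List<Float> boxed = new ArrayList<>();
            for (int i = 0; i < n; i++) boxed.add(expected[i] - predicted[i]);

            // New shape: one allocation, then plain primitive stores.
            float[] deltas = new float[n];
            for (int i = 0; i < n; i++) deltas[i] = expected[i] - predicted[i];

            System.out.println(boxed + " vs " + java.util.Arrays.toString(deltas));
        }
    }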

View File: Neuron.java

@@ -7,18 +7,25 @@ import java.util.function.Consumer;
public class Neuron implements Model {
-   protected Synapse[] synapses;
+   protected int id;
+   protected final Synapse[] synapses;
    protected Bias bias;
    protected ActivationFunction activationFunction;
    protected Float output;
    protected Float weightedSum;
+   protected final float[] weights;
+   protected final float[] inputs;
-   public Neuron(Synapse[] synapses, Bias bias, ActivationFunction func){
+   public Neuron(int id, Synapse[] synapses, Bias bias, ActivationFunction func){
+       this.id = id;
        this.synapses = synapses;
        this.bias = bias;
        this.activationFunction = func;
        this.output = null;
        this.weightedSum = null;
+       weights = new float[synapses.length];
+       inputs = new float[synapses.length];
    }
    public void updateBias(Weight weight) {
@@ -53,12 +60,18 @@ public class Neuron implements Model {
    }
    public float calculateWeightedSum() {
-       this.weightedSum = 0F;
-       this.weightedSum += this.bias.getWeight() * this.bias.getInput();
-       for(Synapse syn : this.synapses){
-           this.weightedSum += syn.getWeight() * syn.getInput();
+       float sum = bias.getWeight() * bias.getInput();
+       for (int i = 0; i < weights.length; i++) {
+           sum += weights[i] * inputs[i];
        }
-       return this.weightedSum;
+       this.weightedSum = sum;
+       return sum;
    }
+   public int getId(){
+       return this.id;
+   }
    @Override
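
Reviewer note: the old loop accumulated into `this.weightedSum`, a boxed `Float` field, so every `+=` unboxed and re-boxed the running total; the new version keeps the sum in a local `float` and writes the field once. (How the new `weights[]`/`inputs[]` caches get filled from the synapses is outside this hunk.) A stand-alone sketch of the pattern, with illustrative names:

    public class WeightedSum {
        private Float weightedSum;               // boxed field, as in the old code

        float slow(float[] w, float[] x) {
            this.weightedSum = 0F;
            for (int i = 0; i < w.length; i++) {
                this.weightedSum += w[i] * x[i]; // unbox, add, re-box on every step
            }
            return this.weightedSum;
        }

        float fast(float[] w, float[] x) {
            float sum = 0F;                      // stays a primitive in the loop
            for (int i = 0; i < w.length; i++) {
                sum += w[i] * x[i];
            }
            this.weightedSum = sum;              // a single box at the end
            return sum;
        }
    }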

View File: TrainingPipeline.java

@@ -103,7 +103,7 @@ public class TrainingPipeline {
System.out.printf("Epoch : %d, ", ctx.epoch);
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray()));
System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas.toArray()));
System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas));
System.out.printf("loss : %.5f\n", ctx.localLoss);
}
}

View File: AdalineTrainingContext.java

@@ -1,6 +1,11 @@
package com.naaturel.ANN.implementation.adaline;

+import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
+import com.naaturel.ANN.infrastructure.dataset.DataSet;

public class AdalineTrainingContext extends TrainingContext {
+   public AdalineTrainingContext(Model model, DataSet dataset) {
+       super(model, dataset);
+   }
}

View File: GradientDescentErrorStrategy.java

@@ -20,7 +20,7 @@ public class GradientDescentErrorStrategy implements AlgorithmStep {
        AtomicInteger synIndex = new AtomicInteger(0);
        context.model.forEachNeuron(neuron -> {
-           float correspondingDelta = context.deltas.get(neuronIndex.get());
+           float correspondingDelta = context.deltas[neuronIndex.get()];
            neuron.forEachSynapse(syn -> {
                float corrector = context.correctorTerms.get(synIndex.get());

View File: GradientDescentTrainingContext.java

@@ -1,6 +1,8 @@
package com.naaturel.ANN.implementation.gradientDescent;

+import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
+import com.naaturel.ANN.infrastructure.dataset.DataSet;
import java.util.List;
@@ -8,4 +10,7 @@ public class GradientDescentTrainingContext extends TrainingContext {
    public List<Float> correctorTerms;
+   public GradientDescentTrainingContext(Model model, DataSet dataset) {
+       super(model, dataset);
+   }
}

View File: SquareLossStep.java

@@ -15,9 +15,11 @@ public class SquareLossStep implements AlgorithmStep {
    @Override
    public void run() {
-       Stream<Float> deltaStream = this.context.deltas.stream();
-       this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2)));
-       this.context.localLoss /= 2;
-       this.context.globalLoss += this.context.localLoss; //broke MSE in gradientDescentTraining
+       float loss = 0f;
+       for (float d : this.context.deltas) {
+           loss += d * d;
+       }
+       this.context.localLoss = loss / 2f;
+       this.context.globalLoss += this.context.localLoss;
    }
}
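
Reviewer note: both versions compute the same half sum of squared errors, loss = ½ · Σ dᵢ²; the gain is dropping the `Stream<Float>` pipeline, which allocates a stream per sample and boxes every delta, plus `Math.pow` for a simple square. A stand-alone comparison (illustrative only):

    import java.util.List;

    public class HalfSquaredError {
        // Old shape: boxed reduction with Math.pow.
        static float viaStream(List<Float> deltas) {
            return deltas.stream()
                    .reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2))) / 2;
        }

        // New shape: primitive loop, d * d instead of Math.pow.
        static float viaLoop(float[] deltas) {
            float loss = 0f;
            for (float d : deltas) loss += d * d;
            return loss / 2f;
        }

        public static void main(String[] args) {
            System.out.println(viaStream(List.of(1f, 2f)));   // 2.5
            System.out.println(viaLoop(new float[]{1f, 2f})); // 2.5
        }
    }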

View File: BackpropagationCorrectionStep.java

@@ -2,41 +2,55 @@ package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
-import java.util.concurrent.atomic.AtomicInteger;

public class BackpropagationCorrectionStep implements AlgorithmStep {
-   private GradientBackpropagationContext context;
+   private final GradientBackpropagationContext context;
+   private final int synCount;
+   private final float[] inputs;
+   private final float[] signals;
    public BackpropagationCorrectionStep(GradientBackpropagationContext context){
        this.context = context;
+       this.synCount = context.correctionBuffer.length;
+       this.inputs = new float[synCount];
+       this.signals = new float[synCount];
    }
    @Override
    public void run() {
-       AtomicInteger synIndex = new AtomicInteger(0);
-       this.context.model.forEachNeuron(n -> {
-           float signal = context.errorSignals.get(n);
+       int[] synIndex = {0};
+       context.model.forEachNeuron(n -> {
+           float signal = context.errorSignals[n.getId()];
            n.forEachSynapse(syn -> {
-               float lr = context.learningRate;
-               float corrector = lr * signal * syn.getInput();
-               float existingCorrector = context.correctionBuffer[synIndex.get()];
-               float newCorrector = existingCorrector + corrector;
-               if(context.currentSample >= context.batchSize){
-                   float newWeight = syn.getWeight() + newCorrector;
-                   syn.setWeight(newWeight);
-                   context.correctionBuffer[synIndex.get()] = 0;
-               } else {
-                   context.correctionBuffer[synIndex.get()] = newCorrector;
-               }
-               synIndex.incrementAndGet();
+               inputs[synIndex[0]] = syn.getInput();
+               signals[synIndex[0]] = signal;
+               synIndex[0]++;
            });
        });
-       if(context.currentSample >= context.batchSize) {
+       float lr = context.learningRate;
+       boolean applyUpdate = context.currentSample >= context.batchSize;
+       for (int i = 0; i < synCount; i++) {
+           context.correctionBuffer[i] += lr * signals[i] * inputs[i];
+       }
+       if (applyUpdate) {
+           syncWeights();
            context.currentSample = 0;
        }
-       context.currentSample += 1;
+       context.currentSample++;
    }
+   private void syncWeights() {
+       int[] i = {0};
+       context.model.forEachNeuron(n -> {
+           n.forEachSynapse(syn -> {
+               syn.setWeight(syn.getWeight() + context.correctionBuffer[i[0]]);
+               context.correctionBuffer[i[0]] = 0f;
+               i[0]++;
+           });
+       });
+   }
}
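
Reviewer note: the rewrite splits the work into two phases: the lambda traversal only snapshots each synapse's input and error signal into flat arrays, and the arithmetic then runs as one tight primitive loop over `correctionBuffer`; weights are touched only once per batch via `syncWeights()`. A stand-alone sketch of that accumulate-then-flush pattern (names illustrative):

    public class BatchedUpdater {
        private final float[] buffer;  // one accumulated correction per synapse
        private int seen = 0;

        BatchedUpdater(int synCount) { this.buffer = new float[synCount]; }

        /** Accumulate one sample's corrections; flush into weights at batch end. */
        void accumulate(float[] corrections, float[] weights, int batchSize) {
            for (int i = 0; i < buffer.length; i++) buffer[i] += corrections[i];
            if (++seen >= batchSize) {
                for (int i = 0; i < buffer.length; i++) {
                    weights[i] += buffer[i];  // apply the summed correction
                    buffer[i] = 0f;           // reset for the next batch
                }
                seen = 0;
            }
        }
    }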

View File: ErrorSignalStep.java

@@ -1,38 +1,29 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
-import com.naaturel.ANN.domain.model.neuron.Neuron;
-import java.util.Map;
-import java.util.concurrent.atomic.AtomicReference;

public class ErrorSignalStep implements AlgorithmStep {
-   private GradientBackpropagationContext context;
+   private final GradientBackpropagationContext context;
    public ErrorSignalStep(GradientBackpropagationContext context) {
        this.context = context;
    }
    @Override
    public void run() {
-       this.context.model.forEachNeuron(n -> {
-           calculateErrorSignalRecursive(n, this.context.errorSignals);
+       context.model.forEachNeuron(n -> {
+           if (context.errorSignalsComputed[n.getId()]) return;
+           int neuronIndex = context.model.indexInLayerOf(n);
+           float[] signalSum = {0f};
+           context.model.forEachNeuronConnectedTo(n, connected -> {
+               signalSum[0] += context.errorSignals[connected.getId()] * connected.getWeight(neuronIndex);
+           });
+           context.errorSignals[n.getId()] = n.getActivationFunction().derivative(n.getOutput()) * signalSum[0];
+           context.errorSignalsComputed[n.getId()] = true;
        });
    }
-   private float calculateErrorSignalRecursive(Neuron n, Map<Neuron, Float> signals) {
-       if (signals.containsKey(n)) return signals.get(n);
-       int neuronIndex = this.context.model.indexInLayerOf(n);
-       AtomicReference<Float> signalSum = new AtomicReference<>(0F);
-       this.context.model.forEachNeuronConnectedTo(n, connected -> {
-           float weightedSignal = calculateErrorSignalRecursive(connected, signals) * connected.getWeight(neuronIndex);
-           signalSum.set(signalSum.get() + weightedSignal);
-       });
-       float derivative = n.getActivationFunction().derivative(n.getOutput());
-       float finalSignal = derivative * signalSum.get();
-       signals.put(n, finalSignal);
-       return finalSignal;
-   }
}
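
Reviewer note: the recursion-plus-`HashMap` memoization becomes a flat lookup table indexed by neuron id. The iterative form is only correct if every neuron is visited after the neurons it feeds into: the output layer is pre-seeded (and flagged in `errorSignalsComputed`) by `OutputLayerErrorStep`, and `forEachNeuron` presumably walks from the output layer back toward the inputs. A stand-alone sketch of the same backward sweep over a layered net (illustrative layout, not the repo's types):

    public class ErrorSignalSweep {
        /**
         * weights[l][j][i] = weight from neuron i of layer l to neuron j of layer l+1;
         * derivs[l][i]     = activation derivative of neuron i in layer l;
         * outputSignals    = pre-seeded error signals of the last layer.
         */
        static float[][] sweep(float[][][] weights, float[][] derivs, float[] outputSignals) {
            int layers = derivs.length;
            float[][] signals = new float[layers][];
            signals[layers - 1] = outputSignals;
            for (int l = layers - 2; l >= 0; l--) {         // output -> input
                signals[l] = new float[derivs[l].length];
                for (int i = 0; i < signals[l].length; i++) {
                    float sum = 0f;
                    for (int j = 0; j < signals[l + 1].length; j++) {
                        sum += signals[l + 1][j] * weights[l][j][i];
                    }
                    signals[l][i] = derivs[l][i] * sum;     // delta_i = f'(z_i) * sum_j w_ji * delta_j
                }
            }
            return signals;
        }
    }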

View File: GradientBackpropagationContext.java

@@ -10,20 +10,21 @@ import java.util.Map;
public class GradientBackpropagationContext extends TrainingContext {
-   public final Map<Neuron, Float> errorSignals;
+   public final float[] errorSignals;
    public final float[] correctionBuffer;
+   public final boolean[] errorSignalsComputed;
    public int currentSample;
    public int batchSize;
    public GradientBackpropagationContext(Model model, DataSet dataSet, float learningRate, int batchSize){
-       this.model = model;
-       this.dataset = dataSet;
+       super(model, dataSet);
        this.learningRate = learningRate;
        this.batchSize = batchSize;
-       this.errorSignals = new HashMap<>();
+       this.errorSignals = new float[model.neuronCount()];
        this.correctionBuffer = new float[model.synCount()];
+       this.errorSignalsComputed = new boolean[model.neuronCount()];
        this.currentSample = 1;
    }
}
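
Reviewer note: `Map<Neuron, Float>` becomes a `float[]` indexed by neuron id, which only works because `Main.createNetwork` now hands out dense ids 0..neuronCount-1 with its `neuronId` counter. A minimal illustration of the swap (values made up):

    import java.util.HashMap;
    import java.util.Map;

    public class SignalStore {
        public static void main(String[] args) {
            int n = 3;

            Map<Integer, Float> byMap = new HashMap<>(); // old shape: hashing + boxing per access
            for (int id = 0; id < n; id++) byMap.put(id, id * 0.5f);

            float[] byArray = new float[n];              // new shape: O(1) indexed loads
            boolean[] computed = new boolean[n];         // replaces Map.containsKey
            for (int id = 0; id < n; id++) { byArray[id] = id * 0.5f; computed[id] = true; }

            System.out.println(byMap.get(2) + " == " + byArray[2]);
        }
    }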

View File: OutputLayerErrorStep.java

@@ -3,37 +3,40 @@ package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
-import java.util.ArrayList;
-import java.util.HashMap;
+import java.util.Arrays;
import java.util.List;
-import java.util.concurrent.atomic.AtomicInteger;

public class OutputLayerErrorStep implements AlgorithmStep {
    private final GradientBackpropagationContext context;
+   private final float[] expectations;
    public OutputLayerErrorStep(GradientBackpropagationContext context){
        this.context = context;
+       this.expectations = new float[context.model.neuronCount()];
    }
    @Override
    public void run() {
-       context.deltas = new ArrayList<>();
-       DataSetEntry entry = this.context.currentEntry;
-       List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry);
-       AtomicInteger index = new AtomicInteger(0);
-       context.errorSignals.clear();
-       this.context.model.forEachOutputNeurons(n -> {
-           float expected = expectations.get(index.get());
+       Arrays.fill(context.errorSignals, 0f);
+       Arrays.fill(context.errorSignalsComputed, false);
+       DataSetEntry entry = context.currentEntry;
+       List<Float> labels = context.dataset.getLabelsAsFloat(entry);
+       for (int i = 0; i < labels.size(); i++) {
+           expectations[i] = labels.get(i);
+       }
+       int[] index = {0};
+       context.model.forEachOutputNeurons(n -> {
+           float expected = expectations[index[0]];
            float predicted = n.getOutput();
-           float output = n.getOutput();
            float delta = expected - predicted;
-           float signal = delta * n.getActivationFunction().derivative(output);
-           this.context.deltas.add(delta);
-           this.context.errorSignals.put(n, signal);
-           index.incrementAndGet();
+           context.deltas[index[0]] = delta;
+           context.errorSignals[n.getId()] = delta * n.getActivationFunction().derivative(predicted);
+           context.errorSignalsComputed[n.getId()] = true;
+           index[0]++;
        });
    }
}
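
Reviewer note: the output-layer signal is δ = (expected − predicted) · f′, and `derivative` is called on the neuron's activated output rather than its weighted sum; that is valid for TanH because tanh′(z) = 1 − tanh(z)² is computable from the output alone. A minimal sketch; the `ActivationFunction` shape below is assumed from the calls in this hunk:

    interface ActivationFunction {
        float apply(float z);
        float derivative(float output);  // assumed: takes the activated output, not z
    }

    class TanH implements ActivationFunction {
        public float apply(float z) { return (float) Math.tanh(z); }
        public float derivative(float output) { return 1f - output * output; }
    }

    public class OutputDelta {
        public static void main(String[] args) {
            ActivationFunction f = new TanH();
            float predicted = f.apply(0.3f);
            float delta = 1f - predicted;                   // expected = 1
            float signal = delta * f.derivative(predicted); // output-layer error signal
            System.out.println(signal);
        }
    }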

View File: SimpleCorrectionStep.java

@@ -21,7 +21,7 @@ public class SimpleCorrectionStep implements AlgorithmStep {
        AtomicInteger synIndex = new AtomicInteger(0);
        context.model.forEachNeuron(neuron -> {
-           float correspondingDelta = context.deltas.get(neuronIndex.get());
+           float correspondingDelta = context.deltas[neuronIndex.get()];
            neuron.forEachSynapse(syn -> {
                float currentW = syn.getWeight();
                float currentInput = syn.getInput();

View File: SimpleDeltaStep.java

@@ -24,10 +24,9 @@ public class SimpleDeltaStep implements AlgorithmStep {
        List<Float> predicted = context.predictions;
        List<Float> expected = dataSet.getLabelsAsFloat(entry);
-       //context.delta = label.getValue() - context.predictions;
-       context.deltas = IntStream.range(0, predicted.size())
-               .mapToObj(i -> expected.get(i) - predicted.get(i))
-               .collect(Collectors.toList());
+       for (int i = 0; i < predicted.size(); i++) {
+           context.deltas[i] = expected.get(i) - predicted.get(i);
+       }
    }
}

View File: SimpleLossStrategy.java

@@ -12,6 +12,10 @@ public class SimpleLossStrategy implements AlgorithmStep {
    @Override
    public void run() {
-       this.context.localLoss = this.context.deltas.stream().reduce(0.0F, Float::sum);
+       float loss = 0f;
+       for (float d : context.deltas) {
+           loss += d;
+       }
+       context.localLoss = loss;
    }
}

View File: SimpleTrainingContext.java

@@ -1,6 +1,11 @@
package com.naaturel.ANN.implementation.simplePerceptron;

+import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
+import com.naaturel.ANN.infrastructure.dataset.DataSet;

public class SimpleTrainingContext extends TrainingContext {
+   public SimpleTrainingContext(Model model, DataSet dataset) {
+       super(model, dataset);
+   }
}

View File: AdalineTraining.java

@@ -24,9 +24,7 @@ public class AdalineTraining implements Trainer {
    @Override
    public void train(float learningRate, int epoch, Model model, DataSet dataset) {
-       AdalineTrainingContext context = new AdalineTrainingContext();
-       context.dataset = dataset;
-       context.model = model;
+       AdalineTrainingContext context = new AdalineTrainingContext(model, dataset);
        context.learningRate = learningRate;
        List<AlgorithmStep> steps = List.of(

View File: GradientBackpropagationTraining.java

@@ -17,7 +17,7 @@ public class GradientBackpropagationTraining implements Trainer {
    @Override
    public void train(float learningRate, int epoch, Model model, DataSet dataset) {
        GradientBackpropagationContext context =
-               new GradientBackpropagationContext(model, dataset, learningRate, dataset.size());
+               new GradientBackpropagationContext(model, dataset, learningRate, dataset.size()/3);
        List<AlgorithmStep> steps = List.of(
            new SimplePredictionStep(context),
@@ -28,14 +28,14 @@ public class GradientBackpropagationTraining implements Trainer {
        );
        new TrainingPipeline(steps)
-               .stopCondition(ctx -> ctx.globalLoss <= 0.001F || ctx.epoch > epoch)
+               .stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
                .beforeEpoch(ctx -> {
                    ctx.globalLoss = 0.0F;
                })
                .afterEpoch(ctx -> {
                    ctx.globalLoss /= dataset.size();
                })
-               .withVerbose(true,epoch/10)
+               .withVerbose(false,epoch/10)
                .withTimeMeasurement(true)
                .run(context);
}
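
Reviewer note: `dataset.size()/3` switches from one full-batch update per epoch to three mini-batch updates, which is what the new accumulate-and-`syncWeights()` path in BackpropagationCorrectionStep is built for. Also, since the half-squared-error loss is never negative, the tightened condition `globalLoss <= 0.00F` only fires at exactly zero loss, so in practice the `epoch` cap is the sole stop criterion.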

View File: GradientDescentTraining.java

@@ -24,9 +24,7 @@ public class GradientDescentTraining implements Trainer {
    @Override
    public void train(float learningRate, int epoch, Model model, DataSet dataset) {
-       GradientDescentTrainingContext context = new GradientDescentTrainingContext();
-       context.dataset = dataset;
-       context.model = model;
+       GradientDescentTrainingContext context = new GradientDescentTrainingContext(model, dataset);
        context.learningRate = learningRate;
        context.correctorTerms = new ArrayList<>();

View File: SimpleTraining.java

@@ -17,7 +17,7 @@ public class SimpleTraining implements Trainer {
    @Override
    public void train(float learningRate, int epoch, Model model, DataSet dataset) {
-       SimpleTrainingContext context = new SimpleTrainingContext();
+       SimpleTrainingContext context = new SimpleTrainingContext(model, dataset);
        context.dataset = dataset;
        context.model = model;
        context.learningRate = learningRate;
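
Reviewer note: now that the constructor takes `(model, dataset)`, the two field assignments kept above are redundant; they could be dropped in a follow-up, as was already done in AdalineTraining and GradientDescentTraining.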