Implement batch size

This commit is contained in:
2026-03-31 22:52:03 +02:00
parent 5aca7b87e3
commit daba4f8420
8 changed files with 82 additions and 36 deletions

View File

@@ -1,5 +1,7 @@
package com.naaturel.ANN; package com.naaturel.ANN;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Network;
import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.implementation.gradientDescent.Linear; import com.naaturel.ANN.implementation.gradientDescent.Linear;
@@ -22,11 +24,21 @@ public class Main {
int nbrClass = 1; int nbrClass = 1;
DataSet dataset = new DatasetExtractor() DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv", nbrClass); .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
int[] neuronPerLayer = new int[]{10, 5, 5, dataset.getNbrLabels()}; int[] neuronPerLayer = new int[]{50, 50, 50, dataset.getNbrLabels()};
int nbrInput = dataset.getNbrInputs(); int nbrInput = dataset.getNbrInputs();
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.01F, 2000, network, dataset);
//plotGraph(dataset, network);
}
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
List<Layer> layers = new ArrayList<>(); List<Layer> layers = new ArrayList<>();
for (int i = 0; i < neuronPerLayer.length; i++){ for (int i = 0; i < neuronPerLayer.length; i++){
@@ -49,26 +61,27 @@ public class Main {
layers.add(layer); layers.add(layer);
} }
FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0])); return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
}
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.0005F, 15000, network, dataset);
private static void plotGraph(DataSet dataset, Model network){
GraphVisualizer visualizer = new GraphVisualizer(); GraphVisualizer visualizer = new GraphVisualizer();
for (DataSetEntry entry : dataset) { for (DataSetEntry entry : dataset) {
List<Float> label = dataset.getLabelsAsFloat(entry); List<Float> label = dataset.getLabelsAsFloat(entry);
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); label.forEach(l -> {
visualizer.addPoint("Label " + l,
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
});
} }
float min = 0F; float min = -5F;
float max = 15F; float max = 5F;
float step = 0.03F; float step = 0.01F;
for (float x = min; x < max; x+=step){ for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){ for (float y = min; y < max; y+=step){
float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst(); List<Float> predictions = network.predict(List.of(new Input(x), new Input(y)));
float predSeries = prediction > 0.5F ? 1 : -1; visualizer.addPoint(Float.toString(Math.round(predictions.getFirst())), x, y);
visualizer.addPoint(Float.toString(predSeries), x, y);
} }
} }

View File

@@ -25,8 +25,4 @@ public class Synapse {
public void setWeight(float value){ public void setWeight(float value){
this.weight.setValue(value); this.weight.setValue(value);
} }
} }

View File

@@ -2,6 +2,8 @@ package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import java.util.concurrent.atomic.AtomicInteger;
public class BackpropagationCorrectionStep implements AlgorithmStep { public class BackpropagationCorrectionStep implements AlgorithmStep {
private GradientBackpropagationContext context; private GradientBackpropagationContext context;
@@ -12,13 +14,29 @@ public class BackpropagationCorrectionStep implements AlgorithmStep {
@Override @Override
public void run() { public void run() {
AtomicInteger synIndex = new AtomicInteger(0);
this.context.model.forEachNeuron(n -> { this.context.model.forEachNeuron(n -> {
float signal = context.errorSignals.get(n);
n.forEachSynapse(syn -> { n.forEachSynapse(syn -> {
float lr = context.learningRate; float lr = context.learningRate;
float signal = context.errorSignals.get(n); float corrector = lr * signal * syn.getInput();
float newWeight = syn.getWeight() + (lr * signal * syn.getInput()); float existingCorrector = context.correctionBuffer[synIndex.get()];
syn.setWeight(newWeight); float newCorrector = existingCorrector + corrector;
if(context.currentSample >= context.batchSize){
float newWeight = syn.getWeight() + newCorrector;
syn.setWeight(newWeight);
context.correctionBuffer[synIndex.get()] = 0;
} else {
context.correctionBuffer[synIndex.get()] = newCorrector;
}
synIndex.incrementAndGet();
}); });
}); });
if(context.currentSample >= context.batchSize) {
context.currentSample = 0;
}
context.currentSample += 1;
} }
} }

View File

@@ -0,0 +1,11 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
public class BatchAccumulatorStep implements AlgorithmStep {
@Override
public void run() {
}
}

View File

@@ -1,14 +1,29 @@
package com.naaturel.ANN.implementation.multiLayers; package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
public class GradientBackpropagationContext extends TrainingContext { public class GradientBackpropagationContext extends TrainingContext {
public Map<Neuron, Float> errorSignals; public final Map<Neuron, Float> errorSignals;
public final float[] correctionBuffer;
public GradientBackpropagationContext(){ public int currentSample;
public int batchSize;
public GradientBackpropagationContext(Model model, DataSet dataSet, float learningRate, int batchSize){
this.model = model;
this.dataset = dataSet;
this.learningRate = learningRate;
this.batchSize = batchSize;
this.errorSignals = new HashMap<>();
this.correctionBuffer = new float[model.synCount()];
this.currentSample = 1;
} }
} }

View File

@@ -23,7 +23,7 @@ public class OutputLayerErrorStep implements AlgorithmStep {
List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry); List<Float> expectations = this.context.dataset.getLabelsAsFloat(entry);
AtomicInteger index = new AtomicInteger(0); AtomicInteger index = new AtomicInteger(0);
context.errorSignals = new HashMap<>(); context.errorSignals.clear();
this.context.model.forEachOutputNeurons(n -> { this.context.model.forEachOutputNeurons(n -> {
float expected = expectations.get(index.get()); float expected = expectations.get(index.get());
float predicted = n.getOutput(); float predicted = n.getOutput();

View File

@@ -16,10 +16,8 @@ import java.util.List;
public class GradientBackpropagationTraining implements Trainer { public class GradientBackpropagationTraining implements Trainer {
@Override @Override
public void train(float learningRate, int epoch, Model model, DataSet dataset) { public void train(float learningRate, int epoch, Model model, DataSet dataset) {
GradientBackpropagationContext context = new GradientBackpropagationContext(); GradientBackpropagationContext context =
context.dataset = dataset; new GradientBackpropagationContext(model, dataset, learningRate, dataset.size());
context.model = model;
context.learningRate = learningRate;
List<AlgorithmStep> steps = List.of( List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context), new SimplePredictionStep(context),
@@ -34,7 +32,9 @@ public class GradientBackpropagationTraining implements Trainer {
.beforeEpoch(ctx -> { .beforeEpoch(ctx -> {
ctx.globalLoss = 0.0F; ctx.globalLoss = 0.0F;
}) })
.afterEpoch(ctx -> ctx.globalLoss /= dataset.size()) .afterEpoch(ctx -> {
ctx.globalLoss /= dataset.size();
})
.withVerbose(true,epoch/10) .withVerbose(true,epoch/10)
.withTimeMeasurement(true) .withTimeMeasurement(true)
.run(context); .run(context);

View File

@@ -19,10 +19,3 @@
4,7,-1 4,7,-1
4,9,1 4,9,1
4,10,1 4,10,1
2,6,-1
7,7,-1
5,9,1
9,10,1
7,1,-1
5,0,1
9,5,1
1 1 6 1
19 4 7 -1
20 4 9 1
21 4 10 1
2 6 -1
7 7 -1
5 9 1
9 10 1
7 1 -1
5 0 1
9 5 1