Minor fixes

This commit is contained in:
2026-03-31 16:26:28 +02:00
parent 165a2bc977
commit 5aca7b87e3
4 changed files with 19 additions and 11 deletions

View File

@@ -2,6 +2,7 @@ package com.naaturel.ANN;
import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.implementation.gradientDescent.Linear;
import com.naaturel.ANN.implementation.multiLayers.Sigmoid; import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
import com.naaturel.ANN.implementation.multiLayers.TanH; import com.naaturel.ANN.implementation.multiLayers.TanH;
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
@@ -18,12 +19,12 @@ public class Main {
public static void main(String[] args){ public static void main(String[] args){
int nbrClass = 3; int nbrClass = 1;
DataSet dataset = new DatasetExtractor() DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_14.csv", nbrClass); .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv", nbrClass);
int[] neuronPerLayer = new int[]{27, dataset.getNbrLabels()}; int[] neuronPerLayer = new int[]{10, 5, 5, dataset.getNbrLabels()};
int nbrInput = dataset.getNbrInputs(); int nbrInput = dataset.getNbrInputs();
List<Layer> layers = new ArrayList<>(); List<Layer> layers = new ArrayList<>();
@@ -51,7 +52,7 @@ public class Main {
FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0])); FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0]));
Trainer trainer = new GradientBackpropagationTraining(); Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.0001F, 15000, network, dataset); trainer.train(0.0005F, 15000, network, dataset);
GraphVisualizer visualizer = new GraphVisualizer(); GraphVisualizer visualizer = new GraphVisualizer();
@@ -60,13 +61,13 @@ public class Main {
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
} }
float min = -5F; float min = 0F;
float max = 5F; float max = 15F;
float step = 0.025F; float step = 0.03F;
for (float x = min; x < max; x+=step){ for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){ for (float y = min; y < max; y+=step){
float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst(); float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst();
float predSeries = prediction > 0.5F ? 1 : 0; float predSeries = prediction > 0.5F ? 1 : -1;
visualizer.addPoint(Float.toString(predSeries), x, y); visualizer.addPoint(Float.toString(predSeries), x, y);
} }
} }

View File

@@ -30,12 +30,12 @@ public class GradientBackpropagationTraining implements Trainer {
); );
new TrainingPipeline(steps) new TrainingPipeline(steps)
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch) .stopCondition(ctx -> ctx.globalLoss <= 0.001F || ctx.epoch > epoch)
.beforeEpoch(ctx -> { .beforeEpoch(ctx -> {
ctx.globalLoss = 0.0F; ctx.globalLoss = 0.0F;
}) })
.afterEpoch(ctx -> ctx.globalLoss /= dataset.size()) .afterEpoch(ctx -> ctx.globalLoss /= dataset.size())
.withVerbose(true, epoch/10) .withVerbose(true,epoch/10)
.withTimeMeasurement(true) .withTimeMeasurement(true)
.run(context); .run(context);
} }

View File

@@ -19,3 +19,10 @@
4,7,-1 4,7,-1
4,9,1 4,9,1
4,10,1 4,10,1
2,6,-1
7,7,-1
5,9,1
9,10,1
7,1,-1
5,0,1
9,5,1
1 1 6 1
19 4 7 -1
20 4 9 1
21 4 10 1
22 2 6 -1
23 7 7 -1
24 5 9 1
25 9 10 1
26 7 1 -1
27 5 0 1
28 9 5 1