From 6d886513857cfcc7c2b693f8fc1998429443907d Mon Sep 17 00:00:00 2001 From: Laurent Date: Sat, 28 Mar 2026 12:25:59 +0100 Subject: [PATCH] Move dataset components --- src/main/java/com/naaturel/ANN/Main.java | 37 +++++++++++-------- .../ANN/domain/abstraction/Trainer.java | 2 +- .../domain/abstraction/TrainingContext.java | 4 +- .../ANN/domain/model/neuron/Network.java | 7 ++-- .../model/training/TrainingPipeline.java | 2 +- .../simplePerceptron/SimpleDeltaStrategy.java | 5 +-- .../training/AdalineTraining.java | 2 +- .../training/GradientDescentTraining.java | 10 ++--- .../training/SimpleTraining.java | 2 +- .../dataset/DataSet.java | 2 +- .../dataset/DataSetEntry.java | 2 +- .../dataset/DatasetExtractor.java | 2 +- .../dataset/Labels.java | 2 +- src/test/java/adaline/AdalineTest.java | 4 +- .../gradientDescent/GradientDescentTest.java | 4 +- .../java/perceptron/SimplePerceptronTest.java | 4 +- 16 files changed, 47 insertions(+), 44 deletions(-) rename src/main/java/com/naaturel/ANN/{domain/model => infrastructure}/dataset/DataSet.java (97%) rename src/main/java/com/naaturel/ANN/{domain/model => infrastructure}/dataset/DataSetEntry.java (94%) rename src/main/java/com/naaturel/ANN/{domain/model => infrastructure}/dataset/DatasetExtractor.java (96%) rename src/main/java/com/naaturel/ANN/{domain/model => infrastructure}/dataset/Labels.java (83%) diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index c6a5562..d53a1ab 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -2,11 +2,10 @@ package com.naaturel.ANN; import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.Trainer; -import com.naaturel.ANN.domain.model.dataset.DataSet; -import com.naaturel.ANN.domain.model.dataset.DatasetExtractor; +import com.naaturel.ANN.infrastructure.dataset.DataSet; +import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.implementation.gradientDescent.Linear; -import com.naaturel.ANN.implementation.training.AdalineTraining; import com.naaturel.ANN.implementation.training.GradientDescentTraining; import java.util.*; @@ -18,25 +17,31 @@ public class Main { int nbrInput = 2; int nbrClass = 3; + int nbrLayers = 1; + DataSet dataset = new DatasetExtractor() .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv", nbrClass); - List neurons = new ArrayList<>(); + List layers = new ArrayList<>(); + for(int i = 0; i < nbrLayers; i++){ - for (int i=0; i < nbrClass; i++){ - List syns = new ArrayList<>(); - for (int j=0; j < nbrInput; j++){ - syns.add(new Synapse(new Input(0), new Weight(0))); + List neurons = new ArrayList<>(); + for (int j=0; j < nbrClass; j++){ + + List syns = new ArrayList<>(); + for (int k=0; k < nbrInput; k++){ + syns.add(new Synapse(new Input(0), new Weight(0))); + } + + Bias bias = new Bias(new Weight(0)); + + Neuron n = new Neuron(syns, bias, new Linear()); + neurons.add(n); } - - Bias bias = new Bias(new Weight(0)); - - Neuron n = new Neuron(syns, bias, new Linear()); - neurons.add(n); + Layer layer = new Layer(neurons); + layers.add(layer); } - - Layer layer = new Layer(neurons); - Network network = new Network(List.of(layer)); + Network network = new Network(layers); Trainer trainer = new GradientDescentTraining(); trainer.train(network, dataset); diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java index 06875a2..867341d 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java @@ -1,6 +1,6 @@ package com.naaturel.ANN.domain.abstraction; -import com.naaturel.ANN.domain.model.dataset.DataSet; +import com.naaturel.ANN.infrastructure.dataset.DataSet; public interface Trainer { void train(Model model, DataSet dataset); diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java index 14dc8c9..beeab05 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java @@ -1,7 +1,7 @@ package com.naaturel.ANN.domain.abstraction; -import com.naaturel.ANN.domain.model.dataset.DataSet; -import com.naaturel.ANN.domain.model.dataset.DataSetEntry; +import com.naaturel.ANN.infrastructure.dataset.DataSet; +import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; import java.util.List; diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java index 91d8b5e..7210e5d 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java @@ -16,12 +16,11 @@ public class Network implements Model { @Override public List predict(List inputs) { - List result = new ArrayList<>(); + List currentLayerOutput = new ArrayList<>(inputs); for(Layer layer : this.layers){ - List res = layer.predict(inputs); - result.addAll(res); + currentLayerOutput = layer.predict(currentLayerOutput).stream().map(Input::new).toList(); } - return result; + return currentLayerOutput.stream().map(Input::getValue).toList(); } @Override diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java index 5e8865d..8975d94 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java @@ -2,7 +2,7 @@ package com.naaturel.ANN.domain.model.training; import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingStep; -import com.naaturel.ANN.domain.model.dataset.DataSetEntry; +import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStrategy.java index ec57d65..108798b 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStrategy.java @@ -2,10 +2,9 @@ package com.naaturel.ANN.implementation.simplePerceptron; import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.TrainingContext; -import com.naaturel.ANN.domain.model.dataset.DataSet; -import com.naaturel.ANN.domain.model.dataset.DataSetEntry; +import com.naaturel.ANN.infrastructure.dataset.DataSet; +import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; -import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; import java.util.stream.IntStream; diff --git a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java index e1dd0b8..58d434d 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java @@ -3,7 +3,7 @@ package com.naaturel.ANN.implementation.training; import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.TrainingStep; -import com.naaturel.ANN.domain.model.dataset.DataSet; +import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext; import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy; diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index d40ab28..f6a08d8 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -3,7 +3,7 @@ package com.naaturel.ANN.implementation.training; import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.TrainingStep; -import com.naaturel.ANN.domain.model.dataset.DataSet; +import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy; import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext; import com.naaturel.ANN.domain.model.training.TrainingPipeline; @@ -28,7 +28,7 @@ public class GradientDescentTraining implements Trainer { GradientDescentTrainingContext context = new GradientDescentTrainingContext(); context.dataset = dataset; context.model = model; - context.learningRate = 0.0005F; + context.learningRate = 0.0008F; context.correctorTerms = new ArrayList<>(); List steps = List.of( @@ -39,7 +39,7 @@ public class GradientDescentTraining implements Trainer { ); new TrainingPipeline(steps) - .stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > 50000) + .stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 150) .beforeEpoch(ctx -> { GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx; gdCtx.globalLoss = 0.0F; @@ -50,9 +50,9 @@ public class GradientDescentTraining implements Trainer { context.globalLoss /= context.dataset.size(); new GradientDescentCorrectionStrategy(context).apply(); }) - //.withVerbose(true) + .withVerbose(true) .withTimeMeasurement(true) - .withVisualization(false, new GraphVisualizer()) + .withVisualization(true, new GraphVisualizer()) .run(context); } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java index 3c7b97b..a68ec32 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java @@ -3,7 +3,7 @@ package com.naaturel.ANN.implementation.training; import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.TrainingStep; -import com.naaturel.ANN.domain.model.dataset.DataSet; +import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.implementation.simplePerceptron.*; import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.training.steps.*; diff --git a/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java b/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSet.java similarity index 97% rename from src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java rename to src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSet.java index f9ba91a..7fe085d 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java +++ b/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSet.java @@ -1,4 +1,4 @@ -package com.naaturel.ANN.domain.model.dataset; +package com.naaturel.ANN.infrastructure.dataset; import com.naaturel.ANN.domain.model.neuron.Input; diff --git a/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSetEntry.java b/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSetEntry.java similarity index 94% rename from src/main/java/com/naaturel/ANN/domain/model/dataset/DataSetEntry.java rename to src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSetEntry.java index c87bafe..99941ce 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSetEntry.java +++ b/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSetEntry.java @@ -1,4 +1,4 @@ -package com.naaturel.ANN.domain.model.dataset; +package com.naaturel.ANN.infrastructure.dataset; import com.naaturel.ANN.domain.model.neuron.Input; diff --git a/src/main/java/com/naaturel/ANN/domain/model/dataset/DatasetExtractor.java b/src/main/java/com/naaturel/ANN/infrastructure/dataset/DatasetExtractor.java similarity index 96% rename from src/main/java/com/naaturel/ANN/domain/model/dataset/DatasetExtractor.java rename to src/main/java/com/naaturel/ANN/infrastructure/dataset/DatasetExtractor.java index f3fd04a..612162a 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/dataset/DatasetExtractor.java +++ b/src/main/java/com/naaturel/ANN/infrastructure/dataset/DatasetExtractor.java @@ -1,4 +1,4 @@ -package com.naaturel.ANN.domain.model.dataset; +package com.naaturel.ANN.infrastructure.dataset; import com.naaturel.ANN.domain.model.neuron.Input; diff --git a/src/main/java/com/naaturel/ANN/domain/model/dataset/Labels.java b/src/main/java/com/naaturel/ANN/infrastructure/dataset/Labels.java similarity index 83% rename from src/main/java/com/naaturel/ANN/domain/model/dataset/Labels.java rename to src/main/java/com/naaturel/ANN/infrastructure/dataset/Labels.java index 9a7a785..3cbb1c7 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/dataset/Labels.java +++ b/src/main/java/com/naaturel/ANN/infrastructure/dataset/Labels.java @@ -1,4 +1,4 @@ -package com.naaturel.ANN.domain.model.dataset; +package com.naaturel.ANN.infrastructure.dataset; import java.util.List; diff --git a/src/test/java/adaline/AdalineTest.java b/src/test/java/adaline/AdalineTest.java index c360de6..1cbe109 100644 --- a/src/test/java/adaline/AdalineTest.java +++ b/src/test/java/adaline/AdalineTest.java @@ -3,8 +3,8 @@ package adaline; import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.TrainingStep; -import com.naaturel.ANN.domain.model.dataset.DataSet; -import com.naaturel.ANN.domain.model.dataset.DatasetExtractor; +import com.naaturel.ANN.infrastructure.dataset.DataSet; +import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext; diff --git a/src/test/java/gradientDescent/GradientDescentTest.java b/src/test/java/gradientDescent/GradientDescentTest.java index 29fe934..9952431 100644 --- a/src/test/java/gradientDescent/GradientDescentTest.java +++ b/src/test/java/gradientDescent/GradientDescentTest.java @@ -2,8 +2,8 @@ package gradientDescent; import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.TrainingStep; -import com.naaturel.ANN.domain.model.dataset.DataSet; -import com.naaturel.ANN.domain.model.dataset.DatasetExtractor; +import com.naaturel.ANN.infrastructure.dataset.DataSet; +import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.gradientDescent.*; diff --git a/src/test/java/perceptron/SimplePerceptronTest.java b/src/test/java/perceptron/SimplePerceptronTest.java index 4615d10..2251988 100644 --- a/src/test/java/perceptron/SimplePerceptronTest.java +++ b/src/test/java/perceptron/SimplePerceptronTest.java @@ -2,8 +2,8 @@ package perceptron; import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.TrainingStep; -import com.naaturel.ANN.domain.model.dataset.DataSet; -import com.naaturel.ANN.domain.model.dataset.DatasetExtractor; +import com.naaturel.ANN.infrastructure.dataset.DataSet; +import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.simplePerceptron.*;