From 613bbbcbe2d57f1096f6eb4d1e8bdea3f15892c5 Mon Sep 17 00:00:00 2001 From: Laurent Date: Mon, 11 May 2026 14:22:09 +0200 Subject: [PATCH] Implement model selector and fix tests --- build.gradle.kts | 12 +++++++ config.json | 6 ++-- src/main/java/com/naaturel/ANN/Main.java | 26 ++++++++++---- .../domain/model/helpers/ModelCreator.java | 6 ++-- .../training/AdalineTraining.java | 3 +- .../GradientBackpropagationTraining.java | 2 +- .../training/GradientDescentTraining.java | 7 ++-- .../training/SimpleTraining.java | 2 +- src/test/java/adaline/AdalineTest.java | 31 +++++++--------- .../gradientDescent/GradientDescentTest.java | 36 +++++++------------ .../java/perceptron/SimplePerceptronTest.java | 35 +++++++----------- 11 files changed, 81 insertions(+), 85 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index f649391..adebd89 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -5,6 +5,8 @@ plugins { group = "be.naaturel" version = "1.0-SNAPSHOT" + + repositories { mavenCentral() } @@ -13,11 +15,21 @@ dependencies { implementation("org.jfree:jfreechart:1.5.4") implementation("com.fasterxml.jackson.core:jackson-databind:2.21.2") + implementation("org.jline:jline:3.27.1") + testImplementation(platform("org.junit:junit-bom:5.10.0")) testImplementation("org.junit.jupiter:junit-jupiter") testRuntimeOnly("org.junit.platform:junit-platform-launcher") } +tasks.jar { + manifest { + attributes["Main-Class"] = "com.naaturel.ANN.Main" + } + from(configurations.runtimeClasspath.get().map { if (it.isDirectory) it else zipTree(it) }) + duplicatesStrategy = DuplicatesStrategy.EXCLUDE +} + tasks.test { useJUnitPlatform() } \ No newline at end of file diff --git a/config.json b/config.json index 69ff49f..41c2604 100644 --- a/config.json +++ b/config.json @@ -1,14 +1,14 @@ { "model": { "new": true, - "parameters": [5, 5, 1], - "path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-4-12.json" + "parameters": [1], + "path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json" }, "training" : { "learning_rate" : 0.03, "max_epoch" : 5000 }, "dataset" : { - "path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/xor.csv" + "path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv" } } \ No newline at end of file diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 6d01538..741ab24 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -1,16 +1,16 @@ package com.naaturel.ANN; import com.naaturel.ANN.domain.abstraction.Model; -import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.Trainer; -import com.naaturel.ANN.implementation.multiLayers.TanH; +import com.naaturel.ANN.implementation.training.AdalineTraining; import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; +import com.naaturel.ANN.implementation.training.GradientDescentTraining; +import com.naaturel.ANN.implementation.training.SimpleTraining; import com.naaturel.ANN.infrastructure.config.ConfigDto; import com.naaturel.ANN.infrastructure.config.ConfigLoader; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; -import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer; import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot; import com.naaturel.ANN.infrastructure.visualization.ModelVisualizer; @@ -21,7 +21,17 @@ public class Main { public static void main(String[] args) throws Exception { + String[] types = {"Simple perceptron", "Gradient descent", "Adaline", "Gradient retro-propagation"}; + Scanner sc = new Scanner(System.in); + for (int i = 0; i < types.length; i++) { + System.out.printf("%d - %s\n", i+1, types[i]); + } + + System.out.print(">>> "); + int typeIndex = sc.nextInt() - 1; + sc.nextLine(); + System.out.printf("\nChosen type: %s\n", types[typeIndex]); ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json"); @@ -43,7 +53,13 @@ public class Main { ModelSnapshot snapshot = new ModelSnapshot(); if(newModel) { - Trainer trainer = new GradientBackpropagationTraining(modelParameters, nbrInput); + Trainer trainer = switch (typeIndex) { + case 0 -> new SimpleTraining(modelParameters, nbrInput); + case 1 -> new GradientDescentTraining(modelParameters, nbrInput); + case 2 -> new AdalineTraining(modelParameters, nbrInput); + case 3 -> new GradientBackpropagationTraining(modelParameters, nbrInput); + default -> throw new IllegalStateException("Unexpected value: " + typeIndex); + }; trainer.train(learningRate, maxEpoch, dataset); trainer.saveModel(snapshot, modelPath); } @@ -56,8 +72,6 @@ public class Main { .display(); } - - private static void plotGraph(DataSet dataset, Model network){ if(dataset.getNbrInputs() != 2) return; diff --git a/src/main/java/com/naaturel/ANN/domain/model/helpers/ModelCreator.java b/src/main/java/com/naaturel/ANN/domain/model/helpers/ModelCreator.java index b00e662..7665ee8 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/helpers/ModelCreator.java +++ b/src/main/java/com/naaturel/ANN/domain/model/helpers/ModelCreator.java @@ -1,15 +1,15 @@ package com.naaturel.ANN.domain.model.helpers; +import com.naaturel.ANN.domain.abstraction.ActivationFunction; import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.model.neuron.*; -import com.naaturel.ANN.implementation.multiLayers.TanH; import java.util.ArrayList; import java.util.List; public class ModelCreator { - public static Model createModel(int[] neuronPerLayer, int nbrInput){ + public static Model createModel(int[] neuronPerLayer, int nbrInput, ActivationFunction func){ int neuronId = 0; List layers = new ArrayList<>(); for (int i = 0; i < neuronPerLayer.length; i++){ @@ -26,7 +26,7 @@ public class ModelCreator { Bias bias = new Bias(new Weight()); - Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH()); + Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, func); neurons.add(n); neuronId++; } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java index 27758ce..4b441a1 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java @@ -4,6 +4,7 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.model.helpers.ModelCreator; +import com.naaturel.ANN.implementation.gradientDescent.Linear; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext; @@ -23,7 +24,7 @@ public class AdalineTraining implements Trainer { private Model model; public AdalineTraining(int[] neurons, int nbrInputs){ - model = ModelCreator.createModel(neurons, nbrInputs); + model = ModelCreator.createModel(neurons, nbrInputs, new Linear(1, 0)); } @Override diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java index adc7005..9a5f727 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java @@ -18,7 +18,7 @@ public class GradientBackpropagationTraining implements Trainer { private Model model; public GradientBackpropagationTraining(int[] neurons, int nbrInputs){ - model = ModelCreator.createModel(neurons, nbrInputs); + model = ModelCreator.createModel(neurons, nbrInputs, new TanH()); } public Model getModel() { diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index ef4c0c1..2848f4d 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -4,12 +4,9 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.model.helpers.ModelCreator; +import com.naaturel.ANN.implementation.gradientDescent.*; import com.naaturel.ANN.infrastructure.dataset.DataSet; -import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy; -import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext; import com.naaturel.ANN.domain.model.training.TrainingPipeline; -import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy; -import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep; import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep; import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot; @@ -23,7 +20,7 @@ public class GradientDescentTraining implements Trainer { private Model model; public GradientDescentTraining(int[] neurons, int nbrInputs){ - model = ModelCreator.createModel(neurons, nbrInputs); + model = ModelCreator.createModel(neurons, nbrInputs, new Linear(1, 0)); } @Override diff --git a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java index 4c1c2f8..a7c3ee0 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java @@ -16,7 +16,7 @@ public class SimpleTraining implements Trainer { private Model model; public SimpleTraining(int[] neurons, int nbrInputs){ - model = ModelCreator.createModel(neurons, nbrInputs); + model = ModelCreator.createModel(neurons, nbrInputs, new Heaviside()); } @Override diff --git a/src/test/java/adaline/AdalineTest.java b/src/test/java/adaline/AdalineTest.java index 82b2c0c..416f1ed 100644 --- a/src/test/java/adaline/AdalineTest.java +++ b/src/test/java/adaline/AdalineTest.java @@ -1,8 +1,10 @@ package adaline; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; +import com.naaturel.ANN.domain.abstraction.Model; +import com.naaturel.ANN.domain.model.helpers.ModelCreator; import com.naaturel.ANN.domain.model.neuron.Neuron; -import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; @@ -13,7 +15,6 @@ import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep; import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep; import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep; -import com.naaturel.ANN.implementation.training.steps.*; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -27,9 +28,7 @@ public class AdalineTest { private DataSet dataset; private AdalineTrainingContext context; - private List synapses; - private Bias bias; - private FullyConnectedNetwork network; + private Model model; private TrainingPipeline pipeline; @@ -42,22 +41,16 @@ public class AdalineTest { syns.add(new Synapse(new Input(0), new Weight(0))); syns.add(new Synapse(new Input(0), new Weight(0))); - bias = new Bias(new Weight(0)); + model = ModelCreator.createModel(new int[]{1}, 2, new Linear(1, 0)); - Neuron neuron = new Neuron(syns, bias, new Linear(1, 0)); - Layer layer = new Layer(List.of(neuron)); - network = new FullyConnectedNetwork(List.of(layer)); + context = new AdalineTrainingContext(model, dataset); - context = new AdalineTrainingContext(); - context.dataset = dataset; - context.model = network; - - List steps = List.of( - new PredictionStep(new SimplePredictionStep(context)), - new DeltaStep(new SimpleDeltaStep(context)), - new LossStep(new SquareLossStep(context)), - new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)), - new WeightCorrectionStep(new SimpleCorrectionStep(context)) + List steps = List.of( + new SimplePredictionStep(context), + new SimpleDeltaStep(context), + new SquareLossStep(context), + new SimpleErrorRegistrationStep(context), + new SimpleCorrectionStep(context) ); pipeline = new TrainingPipeline(steps) diff --git a/src/test/java/gradientDescent/GradientDescentTest.java b/src/test/java/gradientDescent/GradientDescentTest.java index 0312862..145ea57 100644 --- a/src/test/java/gradientDescent/GradientDescentTest.java +++ b/src/test/java/gradientDescent/GradientDescentTest.java @@ -1,14 +1,14 @@ package gradientDescent; -import com.naaturel.ANN.domain.model.neuron.Neuron; -import com.naaturel.ANN.domain.abstraction.TrainingStep; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; +import com.naaturel.ANN.domain.abstraction.Model; +import com.naaturel.ANN.domain.model.helpers.ModelCreator; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.gradientDescent.*; import com.naaturel.ANN.implementation.simplePerceptron.*; -import com.naaturel.ANN.implementation.training.steps.*; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -23,9 +23,7 @@ public class GradientDescentTest { private DataSet dataset; private GradientDescentTrainingContext context; - private List synapses; - private Bias bias; - private FullyConnectedNetwork network; + private Model model; private TrainingPipeline pipeline; @@ -34,26 +32,16 @@ public class GradientDescentTest { dataset = new DatasetExtractor() .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1); - List syns = new ArrayList<>(); - syns.add(new Synapse(new Input(0), new Weight(0))); - syns.add(new Synapse(new Input(0), new Weight(0))); + model = ModelCreator.createModel(new int[]{1}, 2, new Linear(1, 0)); - bias = new Bias(new Weight(0)); - - Neuron neuron = new Neuron(syns, bias, new Linear(1, 0)); - Layer layer = new Layer(List.of(neuron)); - network = new FullyConnectedNetwork(List.of(layer)); - - context = new GradientDescentTrainingContext(); - context.dataset = dataset; - context.model = network; + context = new GradientDescentTrainingContext(model, dataset); context.correctorTerms = new ArrayList<>(); - List steps = List.of( - new PredictionStep(new SimplePredictionStep(context)), - new DeltaStep(new SimpleDeltaStep(context)), - new LossStep(new SquareLossStep(context)), - new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)) + List steps = List.of( + new SimplePredictionStep(context), + new SimpleDeltaStep(context), + new SquareLossStep(context), + new GradientDescentErrorStrategy(context) ); pipeline = new TrainingPipeline(steps) @@ -91,7 +79,7 @@ public class GradientDescentTest { }); pipeline - .withVerbose(true) + .withVerbose(true, 1) .run(context); assertEquals(67, context.epoch); } diff --git a/src/test/java/perceptron/SimplePerceptronTest.java b/src/test/java/perceptron/SimplePerceptronTest.java index cabd4e1..94a3a92 100644 --- a/src/test/java/perceptron/SimplePerceptronTest.java +++ b/src/test/java/perceptron/SimplePerceptronTest.java @@ -1,13 +1,13 @@ package perceptron; -import com.naaturel.ANN.domain.model.neuron.Neuron; -import com.naaturel.ANN.domain.abstraction.TrainingStep; +import com.naaturel.ANN.domain.abstraction.AlgorithmStep; +import com.naaturel.ANN.domain.abstraction.Model; +import com.naaturel.ANN.domain.model.helpers.ModelCreator; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.simplePerceptron.*; -import com.naaturel.ANN.implementation.training.steps.*; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -21,10 +21,7 @@ public class SimplePerceptronTest { private DataSet dataset; private SimpleTrainingContext context; - - private List synapses; - private Bias bias; - private FullyConnectedNetwork network; + private Model model; private TrainingPipeline pipeline; @@ -37,22 +34,16 @@ public class SimplePerceptronTest { syns.add(new Synapse(new Input(0), new Weight(0))); syns.add(new Synapse(new Input(0), new Weight(0))); - bias = new Bias(new Weight(0)); + model = ModelCreator.createModel(new int[]{1}, 2, new Heaviside()); - Neuron neuron = new Neuron(syns, bias, new Heaviside()); - Layer layer = new Layer(List.of(neuron)); - network = new FullyConnectedNetwork(List.of(layer)); + context = new SimpleTrainingContext(model, dataset); - context = new SimpleTrainingContext(); - context.dataset = dataset; - context.model = network; - - List steps = List.of( - new PredictionStep(new SimplePredictionStep(context)), - new DeltaStep(new SimpleDeltaStep(context)), - new LossStep(new SimpleLossStrategy(context)), - new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)), - new WeightCorrectionStep(new SimpleCorrectionStep(context)) + List steps = List.of( + new SimplePredictionStep(context), + new SimpleDeltaStep(context), + new SimpleLossStrategy(context), + new SimpleErrorRegistrationStep(context), + new SimpleCorrectionStep(context) ); pipeline = new TrainingPipeline(steps); @@ -74,7 +65,7 @@ public class SimplePerceptronTest { context.learningRate = 1F; pipeline.afterEpoch(ctx -> { - int index = ctx.epoch-1; + int index = ctx.epoch; assertEquals(expectedGlobalLosses.get(index), context.globalLoss); });