Implement model selector and fix tests

This commit is contained in:
2026-05-11 14:22:09 +02:00
parent 159e414cb8
commit 613bbbcbe2
11 changed files with 81 additions and 85 deletions

View File

@@ -5,6 +5,8 @@ plugins {
group = "be.naaturel" group = "be.naaturel"
version = "1.0-SNAPSHOT" version = "1.0-SNAPSHOT"
repositories { repositories {
mavenCentral() mavenCentral()
} }
@@ -13,11 +15,21 @@ dependencies {
implementation("org.jfree:jfreechart:1.5.4") implementation("org.jfree:jfreechart:1.5.4")
implementation("com.fasterxml.jackson.core:jackson-databind:2.21.2") implementation("com.fasterxml.jackson.core:jackson-databind:2.21.2")
implementation("org.jline:jline:3.27.1")
testImplementation(platform("org.junit:junit-bom:5.10.0")) testImplementation(platform("org.junit:junit-bom:5.10.0"))
testImplementation("org.junit.jupiter:junit-jupiter") testImplementation("org.junit.jupiter:junit-jupiter")
testRuntimeOnly("org.junit.platform:junit-platform-launcher") testRuntimeOnly("org.junit.platform:junit-platform-launcher")
} }
tasks.jar {
manifest {
attributes["Main-Class"] = "com.naaturel.ANN.Main"
}
from(configurations.runtimeClasspath.get().map { if (it.isDirectory) it else zipTree(it) })
duplicatesStrategy = DuplicatesStrategy.EXCLUDE
}
tasks.test { tasks.test {
useJUnitPlatform() useJUnitPlatform()
} }

View File

@@ -1,14 +1,14 @@
{ {
"model": { "model": {
"new": true, "new": true,
"parameters": [5, 5, 1], "parameters": [1],
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-4-12.json" "path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
}, },
"training" : { "training" : {
"learning_rate" : 0.03, "learning_rate" : 0.03,
"max_epoch" : 5000 "max_epoch" : 5000
}, },
"dataset" : { "dataset" : {
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/xor.csv" "path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv"
} }
} }

View File

@@ -1,16 +1,16 @@
package com.naaturel.ANN; package com.naaturel.ANN;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.implementation.multiLayers.TanH; import com.naaturel.ANN.implementation.training.AdalineTraining;
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
import com.naaturel.ANN.implementation.training.SimpleTraining;
import com.naaturel.ANN.infrastructure.config.ConfigDto; import com.naaturel.ANN.infrastructure.config.ConfigDto;
import com.naaturel.ANN.infrastructure.config.ConfigLoader; import com.naaturel.ANN.infrastructure.config.ConfigLoader;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer; import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot; import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
import com.naaturel.ANN.infrastructure.visualization.ModelVisualizer; import com.naaturel.ANN.infrastructure.visualization.ModelVisualizer;
@@ -21,7 +21,17 @@ public class Main {
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
String[] types = {"Simple perceptron", "Gradient descent", "Adaline", "Gradient retro-propagation"};
Scanner sc = new Scanner(System.in); Scanner sc = new Scanner(System.in);
for (int i = 0; i < types.length; i++) {
System.out.printf("%d - %s\n", i+1, types[i]);
}
System.out.print(">>> ");
int typeIndex = sc.nextInt() - 1;
sc.nextLine();
System.out.printf("\nChosen type: %s\n", types[typeIndex]);
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json"); ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
@@ -43,7 +53,13 @@ public class Main {
ModelSnapshot snapshot = new ModelSnapshot(); ModelSnapshot snapshot = new ModelSnapshot();
if(newModel) { if(newModel) {
Trainer trainer = new GradientBackpropagationTraining(modelParameters, nbrInput); Trainer trainer = switch (typeIndex) {
case 0 -> new SimpleTraining(modelParameters, nbrInput);
case 1 -> new GradientDescentTraining(modelParameters, nbrInput);
case 2 -> new AdalineTraining(modelParameters, nbrInput);
case 3 -> new GradientBackpropagationTraining(modelParameters, nbrInput);
default -> throw new IllegalStateException("Unexpected value: " + typeIndex);
};
trainer.train(learningRate, maxEpoch, dataset); trainer.train(learningRate, maxEpoch, dataset);
trainer.saveModel(snapshot, modelPath); trainer.saveModel(snapshot, modelPath);
} }
@@ -56,8 +72,6 @@ public class Main {
.display(); .display();
} }
private static void plotGraph(DataSet dataset, Model network){ private static void plotGraph(DataSet dataset, Model network){
if(dataset.getNbrInputs() != 2) return; if(dataset.getNbrInputs() != 2) return;

View File

@@ -1,15 +1,15 @@
package com.naaturel.ANN.domain.model.helpers; package com.naaturel.ANN.domain.model.helpers;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
public class ModelCreator { public class ModelCreator {
public static Model createModel(int[] neuronPerLayer, int nbrInput){ public static Model createModel(int[] neuronPerLayer, int nbrInput, ActivationFunction func){
int neuronId = 0; int neuronId = 0;
List<Layer> layers = new ArrayList<>(); List<Layer> layers = new ArrayList<>();
for (int i = 0; i < neuronPerLayer.length; i++){ for (int i = 0; i < neuronPerLayer.length; i++){
@@ -26,7 +26,7 @@ public class ModelCreator {
Bias bias = new Bias(new Weight()); Bias bias = new Bias(new Weight());
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH()); Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, func);
neurons.add(n); neurons.add(n);
neuronId++; neuronId++;
} }

View File

@@ -4,6 +4,7 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.helpers.ModelCreator; import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.implementation.gradientDescent.Linear;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext; import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
@@ -23,7 +24,7 @@ public class AdalineTraining implements Trainer {
private Model model; private Model model;
public AdalineTraining(int[] neurons, int nbrInputs){ public AdalineTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs); model = ModelCreator.createModel(neurons, nbrInputs, new Linear(1, 0));
} }
@Override @Override

View File

@@ -18,7 +18,7 @@ public class GradientBackpropagationTraining implements Trainer {
private Model model; private Model model;
public GradientBackpropagationTraining(int[] neurons, int nbrInputs){ public GradientBackpropagationTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs); model = ModelCreator.createModel(neurons, nbrInputs, new TanH());
} }
public Model getModel() { public Model getModel() {

View File

@@ -4,12 +4,9 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.helpers.ModelCreator; import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.implementation.gradientDescent.*;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep; import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot; import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
@@ -23,7 +20,7 @@ public class GradientDescentTraining implements Trainer {
private Model model; private Model model;
public GradientDescentTraining(int[] neurons, int nbrInputs){ public GradientDescentTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs); model = ModelCreator.createModel(neurons, nbrInputs, new Linear(1, 0));
} }
@Override @Override

View File

@@ -16,7 +16,7 @@ public class SimpleTraining implements Trainer {
private Model model; private Model model;
public SimpleTraining(int[] neurons, int nbrInputs){ public SimpleTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs); model = ModelCreator.createModel(neurons, nbrInputs, new Heaviside());
} }
@Override @Override

View File

@@ -1,8 +1,10 @@
package adaline; package adaline;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.neuron.*;
@@ -13,7 +15,6 @@ import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep; import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep; import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.implementation.training.steps.*;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@@ -27,9 +28,7 @@ public class AdalineTest {
private DataSet dataset; private DataSet dataset;
private AdalineTrainingContext context; private AdalineTrainingContext context;
private List<Synapse> synapses; private Model model;
private Bias bias;
private FullyConnectedNetwork network;
private TrainingPipeline pipeline; private TrainingPipeline pipeline;
@@ -42,22 +41,16 @@ public class AdalineTest {
syns.add(new Synapse(new Input(0), new Weight(0))); syns.add(new Synapse(new Input(0), new Weight(0)));
syns.add(new Synapse(new Input(0), new Weight(0))); syns.add(new Synapse(new Input(0), new Weight(0)));
bias = new Bias(new Weight(0)); model = ModelCreator.createModel(new int[]{1}, 2, new Linear(1, 0));
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0)); context = new AdalineTrainingContext(model, dataset);
Layer layer = new Layer(List.of(neuron));
network = new FullyConnectedNetwork(List.of(layer));
context = new AdalineTrainingContext(); List<AlgorithmStep> steps = List.of(
context.dataset = dataset; new SimplePredictionStep(context),
context.model = network; new SimpleDeltaStep(context),
new SquareLossStep(context),
List<TrainingStep> steps = List.of( new SimpleErrorRegistrationStep(context),
new PredictionStep(new SimplePredictionStep(context)), new SimpleCorrectionStep(context)
new DeltaStep(new SimpleDeltaStep(context)),
new LossStep(new SquareLossStep(context)),
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
new WeightCorrectionStep(new SimpleCorrectionStep(context))
); );
pipeline = new TrainingPipeline(steps) pipeline = new TrainingPipeline(steps)

View File

@@ -1,14 +1,14 @@
package gradientDescent; package gradientDescent;
import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.*; import com.naaturel.ANN.implementation.gradientDescent.*;
import com.naaturel.ANN.implementation.simplePerceptron.*; import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.implementation.training.steps.*;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@@ -23,9 +23,7 @@ public class GradientDescentTest {
private DataSet dataset; private DataSet dataset;
private GradientDescentTrainingContext context; private GradientDescentTrainingContext context;
private List<Synapse> synapses; private Model model;
private Bias bias;
private FullyConnectedNetwork network;
private TrainingPipeline pipeline; private TrainingPipeline pipeline;
@@ -34,26 +32,16 @@ public class GradientDescentTest {
dataset = new DatasetExtractor() dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1); .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1);
List<Synapse> syns = new ArrayList<>(); model = ModelCreator.createModel(new int[]{1}, 2, new Linear(1, 0));
syns.add(new Synapse(new Input(0), new Weight(0)));
syns.add(new Synapse(new Input(0), new Weight(0)));
bias = new Bias(new Weight(0)); context = new GradientDescentTrainingContext(model, dataset);
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
Layer layer = new Layer(List.of(neuron));
network = new FullyConnectedNetwork(List.of(layer));
context = new GradientDescentTrainingContext();
context.dataset = dataset;
context.model = network;
context.correctorTerms = new ArrayList<>(); context.correctorTerms = new ArrayList<>();
List<TrainingStep> steps = List.of( List<AlgorithmStep> steps = List.of(
new PredictionStep(new SimplePredictionStep(context)), new SimplePredictionStep(context),
new DeltaStep(new SimpleDeltaStep(context)), new SimpleDeltaStep(context),
new LossStep(new SquareLossStep(context)), new SquareLossStep(context),
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)) new GradientDescentErrorStrategy(context)
); );
pipeline = new TrainingPipeline(steps) pipeline = new TrainingPipeline(steps)
@@ -91,7 +79,7 @@ public class GradientDescentTest {
}); });
pipeline pipeline
.withVerbose(true) .withVerbose(true, 1)
.run(context); .run(context);
assertEquals(67, context.epoch); assertEquals(67, context.epoch);
} }

View File

@@ -1,13 +1,13 @@
package perceptron; package perceptron;
import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.simplePerceptron.*; import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.implementation.training.steps.*;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@@ -21,10 +21,7 @@ public class SimplePerceptronTest {
private DataSet dataset; private DataSet dataset;
private SimpleTrainingContext context; private SimpleTrainingContext context;
private Model model;
private List<Synapse> synapses;
private Bias bias;
private FullyConnectedNetwork network;
private TrainingPipeline pipeline; private TrainingPipeline pipeline;
@@ -37,22 +34,16 @@ public class SimplePerceptronTest {
syns.add(new Synapse(new Input(0), new Weight(0))); syns.add(new Synapse(new Input(0), new Weight(0)));
syns.add(new Synapse(new Input(0), new Weight(0))); syns.add(new Synapse(new Input(0), new Weight(0)));
bias = new Bias(new Weight(0)); model = ModelCreator.createModel(new int[]{1}, 2, new Heaviside());
Neuron neuron = new Neuron(syns, bias, new Heaviside()); context = new SimpleTrainingContext(model, dataset);
Layer layer = new Layer(List.of(neuron));
network = new FullyConnectedNetwork(List.of(layer));
context = new SimpleTrainingContext(); List<AlgorithmStep> steps = List.of(
context.dataset = dataset; new SimplePredictionStep(context),
context.model = network; new SimpleDeltaStep(context),
new SimpleLossStrategy(context),
List<TrainingStep> steps = List.of( new SimpleErrorRegistrationStep(context),
new PredictionStep(new SimplePredictionStep(context)), new SimpleCorrectionStep(context)
new DeltaStep(new SimpleDeltaStep(context)),
new LossStep(new SimpleLossStrategy(context)),
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
new WeightCorrectionStep(new SimpleCorrectionStep(context))
); );
pipeline = new TrainingPipeline(steps); pipeline = new TrainingPipeline(steps);
@@ -74,7 +65,7 @@ public class SimplePerceptronTest {
context.learningRate = 1F; context.learningRate = 1F;
pipeline.afterEpoch(ctx -> { pipeline.afterEpoch(ctx -> {
int index = ctx.epoch-1; int index = ctx.epoch;
assertEquals(expectedGlobalLosses.get(index), context.globalLoss); assertEquals(expectedGlobalLosses.get(index), context.globalLoss);
}); });