Implement model selector and fix tests

This commit is contained in:
2026-05-11 14:22:09 +02:00
parent 159e414cb8
commit 613bbbcbe2
11 changed files with 81 additions and 85 deletions

View File

@@ -5,6 +5,8 @@ plugins {
group = "be.naaturel"
version = "1.0-SNAPSHOT"
repositories {
mavenCentral()
}
@@ -13,11 +15,21 @@ dependencies {
implementation("org.jfree:jfreechart:1.5.4")
implementation("com.fasterxml.jackson.core:jackson-databind:2.21.2")
implementation("org.jline:jline:3.27.1")
testImplementation(platform("org.junit:junit-bom:5.10.0"))
testImplementation("org.junit.jupiter:junit-jupiter")
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
}
tasks.jar {
manifest {
attributes["Main-Class"] = "com.naaturel.ANN.Main"
}
from(configurations.runtimeClasspath.get().map { if (it.isDirectory) it else zipTree(it) })
duplicatesStrategy = DuplicatesStrategy.EXCLUDE
}
tasks.test {
useJUnitPlatform()
}

View File

@@ -1,14 +1,14 @@
{
"model": {
"new": true,
"parameters": [5, 5, 1],
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-4-12.json"
"parameters": [1],
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
},
"training" : {
"learning_rate" : 0.03,
"max_epoch" : 5000
},
"dataset" : {
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/xor.csv"
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv"
}
}

View File

@@ -1,16 +1,16 @@
package com.naaturel.ANN;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import com.naaturel.ANN.implementation.training.AdalineTraining;
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
import com.naaturel.ANN.implementation.training.SimpleTraining;
import com.naaturel.ANN.infrastructure.config.ConfigDto;
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
import com.naaturel.ANN.infrastructure.visualization.ModelVisualizer;
@@ -21,7 +21,17 @@ public class Main {
public static void main(String[] args) throws Exception {
String[] types = {"Simple perceptron", "Gradient descent", "Adaline", "Gradient retro-propagation"};
Scanner sc = new Scanner(System.in);
for (int i = 0; i < types.length; i++) {
System.out.printf("%d - %s\n", i+1, types[i]);
}
System.out.print(">>> ");
int typeIndex = sc.nextInt() - 1;
sc.nextLine();
System.out.printf("\nChosen type: %s\n", types[typeIndex]);
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
@@ -43,7 +53,13 @@ public class Main {
ModelSnapshot snapshot = new ModelSnapshot();
if(newModel) {
Trainer trainer = new GradientBackpropagationTraining(modelParameters, nbrInput);
Trainer trainer = switch (typeIndex) {
case 0 -> new SimpleTraining(modelParameters, nbrInput);
case 1 -> new GradientDescentTraining(modelParameters, nbrInput);
case 2 -> new AdalineTraining(modelParameters, nbrInput);
case 3 -> new GradientBackpropagationTraining(modelParameters, nbrInput);
default -> throw new IllegalStateException("Unexpected value: " + typeIndex);
};
trainer.train(learningRate, maxEpoch, dataset);
trainer.saveModel(snapshot, modelPath);
}
@@ -56,8 +72,6 @@ public class Main {
.display();
}
private static void plotGraph(DataSet dataset, Model network){
if(dataset.getNbrInputs() != 2) return;

View File

@@ -1,15 +1,15 @@
package com.naaturel.ANN.domain.model.helpers;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import java.util.ArrayList;
import java.util.List;
public class ModelCreator {
public static Model createModel(int[] neuronPerLayer, int nbrInput){
public static Model createModel(int[] neuronPerLayer, int nbrInput, ActivationFunction func){
int neuronId = 0;
List<Layer> layers = new ArrayList<>();
for (int i = 0; i < neuronPerLayer.length; i++){
@@ -26,7 +26,7 @@ public class ModelCreator {
Bias bias = new Bias(new Weight());
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, func);
neurons.add(n);
neuronId++;
}

View File

@@ -4,6 +4,7 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.implementation.gradientDescent.Linear;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
@@ -23,7 +24,7 @@ public class AdalineTraining implements Trainer {
private Model model;
public AdalineTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs);
model = ModelCreator.createModel(neurons, nbrInputs, new Linear(1, 0));
}
@Override

View File

@@ -18,7 +18,7 @@ public class GradientBackpropagationTraining implements Trainer {
private Model model;
public GradientBackpropagationTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs);
model = ModelCreator.createModel(neurons, nbrInputs, new TanH());
}
public Model getModel() {

View File

@@ -4,12 +4,9 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.implementation.gradientDescent.*;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
@@ -23,7 +20,7 @@ public class GradientDescentTraining implements Trainer {
private Model model;
public GradientDescentTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs);
model = ModelCreator.createModel(neurons, nbrInputs, new Linear(1, 0));
}
@Override

View File

@@ -16,7 +16,7 @@ public class SimpleTraining implements Trainer {
private Model model;
public SimpleTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs);
model = ModelCreator.createModel(neurons, nbrInputs, new Heaviside());
}
@Override

View File

@@ -1,8 +1,10 @@
package adaline;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
@@ -13,7 +15,6 @@ import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.implementation.training.steps.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -27,9 +28,7 @@ public class AdalineTest {
private DataSet dataset;
private AdalineTrainingContext context;
private List<Synapse> synapses;
private Bias bias;
private FullyConnectedNetwork network;
private Model model;
private TrainingPipeline pipeline;
@@ -42,22 +41,16 @@ public class AdalineTest {
syns.add(new Synapse(new Input(0), new Weight(0)));
syns.add(new Synapse(new Input(0), new Weight(0)));
bias = new Bias(new Weight(0));
model = ModelCreator.createModel(new int[]{1}, 2, new Linear(1, 0));
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
Layer layer = new Layer(List.of(neuron));
network = new FullyConnectedNetwork(List.of(layer));
context = new AdalineTrainingContext(model, dataset);
context = new AdalineTrainingContext();
context.dataset = dataset;
context.model = network;
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStep(context)),
new DeltaStep(new SimpleDeltaStep(context)),
new LossStep(new SquareLossStep(context)),
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
new WeightCorrectionStep(new SimpleCorrectionStep(context))
List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context),
new SimpleDeltaStep(context),
new SquareLossStep(context),
new SimpleErrorRegistrationStep(context),
new SimpleCorrectionStep(context)
);
pipeline = new TrainingPipeline(steps)

View File

@@ -1,14 +1,14 @@
package gradientDescent;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.*;
import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.implementation.training.steps.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -23,9 +23,7 @@ public class GradientDescentTest {
private DataSet dataset;
private GradientDescentTrainingContext context;
private List<Synapse> synapses;
private Bias bias;
private FullyConnectedNetwork network;
private Model model;
private TrainingPipeline pipeline;
@@ -34,26 +32,16 @@ public class GradientDescentTest {
dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1);
List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight(0)));
syns.add(new Synapse(new Input(0), new Weight(0)));
model = ModelCreator.createModel(new int[]{1}, 2, new Linear(1, 0));
bias = new Bias(new Weight(0));
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
Layer layer = new Layer(List.of(neuron));
network = new FullyConnectedNetwork(List.of(layer));
context = new GradientDescentTrainingContext();
context.dataset = dataset;
context.model = network;
context = new GradientDescentTrainingContext(model, dataset);
context.correctorTerms = new ArrayList<>();
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStep(context)),
new DeltaStep(new SimpleDeltaStep(context)),
new LossStep(new SquareLossStep(context)),
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context),
new SimpleDeltaStep(context),
new SquareLossStep(context),
new GradientDescentErrorStrategy(context)
);
pipeline = new TrainingPipeline(steps)
@@ -91,7 +79,7 @@ public class GradientDescentTest {
});
pipeline
.withVerbose(true)
.withVerbose(true, 1)
.run(context);
assertEquals(67, context.epoch);
}

View File

@@ -1,13 +1,13 @@
package perceptron;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.implementation.training.steps.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -21,10 +21,7 @@ public class SimplePerceptronTest {
private DataSet dataset;
private SimpleTrainingContext context;
private List<Synapse> synapses;
private Bias bias;
private FullyConnectedNetwork network;
private Model model;
private TrainingPipeline pipeline;
@@ -37,22 +34,16 @@ public class SimplePerceptronTest {
syns.add(new Synapse(new Input(0), new Weight(0)));
syns.add(new Synapse(new Input(0), new Weight(0)));
bias = new Bias(new Weight(0));
model = ModelCreator.createModel(new int[]{1}, 2, new Heaviside());
Neuron neuron = new Neuron(syns, bias, new Heaviside());
Layer layer = new Layer(List.of(neuron));
network = new FullyConnectedNetwork(List.of(layer));
context = new SimpleTrainingContext(model, dataset);
context = new SimpleTrainingContext();
context.dataset = dataset;
context.model = network;
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStep(context)),
new DeltaStep(new SimpleDeltaStep(context)),
new LossStep(new SimpleLossStrategy(context)),
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
new WeightCorrectionStep(new SimpleCorrectionStep(context))
List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context),
new SimpleDeltaStep(context),
new SimpleLossStrategy(context),
new SimpleErrorRegistrationStep(context),
new SimpleCorrectionStep(context)
);
pipeline = new TrainingPipeline(steps);
@@ -74,7 +65,7 @@ public class SimplePerceptronTest {
context.learningRate = 1F;
pipeline.afterEpoch(ctx -> {
int index = ctx.epoch-1;
int index = ctx.epoch;
assertEquals(expectedGlobalLosses.get(index), context.globalLoss);
});