Implement model selector and fix tests
This commit is contained in:
@@ -5,6 +5,8 @@ plugins {
|
||||
group = "be.naaturel"
|
||||
version = "1.0-SNAPSHOT"
|
||||
|
||||
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
}
|
||||
@@ -13,11 +15,21 @@ dependencies {
|
||||
implementation("org.jfree:jfreechart:1.5.4")
|
||||
implementation("com.fasterxml.jackson.core:jackson-databind:2.21.2")
|
||||
|
||||
implementation("org.jline:jline:3.27.1")
|
||||
|
||||
testImplementation(platform("org.junit:junit-bom:5.10.0"))
|
||||
testImplementation("org.junit.jupiter:junit-jupiter")
|
||||
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
|
||||
}
|
||||
|
||||
tasks.jar {
|
||||
manifest {
|
||||
attributes["Main-Class"] = "com.naaturel.ANN.Main"
|
||||
}
|
||||
from(configurations.runtimeClasspath.get().map { if (it.isDirectory) it else zipTree(it) })
|
||||
duplicatesStrategy = DuplicatesStrategy.EXCLUDE
|
||||
}
|
||||
|
||||
tasks.test {
|
||||
useJUnitPlatform()
|
||||
}
|
||||
@@ -1,14 +1,14 @@
|
||||
{
|
||||
"model": {
|
||||
"new": true,
|
||||
"parameters": [5, 5, 1],
|
||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-4-12.json"
|
||||
"parameters": [1],
|
||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
|
||||
},
|
||||
"training" : {
|
||||
"learning_rate" : 0.03,
|
||||
"max_epoch" : 5000
|
||||
},
|
||||
"dataset" : {
|
||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/xor.csv"
|
||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv"
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,16 @@
|
||||
package com.naaturel.ANN;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||
import com.naaturel.ANN.implementation.training.AdalineTraining;
|
||||
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
||||
import com.naaturel.ANN.implementation.training.SimpleTraining;
|
||||
import com.naaturel.ANN.infrastructure.config.ConfigDto;
|
||||
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
||||
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||
import com.naaturel.ANN.infrastructure.visualization.ModelVisualizer;
|
||||
@@ -21,7 +21,17 @@ public class Main {
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
|
||||
String[] types = {"Simple perceptron", "Gradient descent", "Adaline", "Gradient retro-propagation"};
|
||||
|
||||
Scanner sc = new Scanner(System.in);
|
||||
for (int i = 0; i < types.length; i++) {
|
||||
System.out.printf("%d - %s\n", i+1, types[i]);
|
||||
}
|
||||
|
||||
System.out.print(">>> ");
|
||||
int typeIndex = sc.nextInt() - 1;
|
||||
sc.nextLine();
|
||||
System.out.printf("\nChosen type: %s\n", types[typeIndex]);
|
||||
|
||||
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
|
||||
|
||||
@@ -43,7 +53,13 @@ public class Main {
|
||||
ModelSnapshot snapshot = new ModelSnapshot();
|
||||
|
||||
if(newModel) {
|
||||
Trainer trainer = new GradientBackpropagationTraining(modelParameters, nbrInput);
|
||||
Trainer trainer = switch (typeIndex) {
|
||||
case 0 -> new SimpleTraining(modelParameters, nbrInput);
|
||||
case 1 -> new GradientDescentTraining(modelParameters, nbrInput);
|
||||
case 2 -> new AdalineTraining(modelParameters, nbrInput);
|
||||
case 3 -> new GradientBackpropagationTraining(modelParameters, nbrInput);
|
||||
default -> throw new IllegalStateException("Unexpected value: " + typeIndex);
|
||||
};
|
||||
trainer.train(learningRate, maxEpoch, dataset);
|
||||
trainer.saveModel(snapshot, modelPath);
|
||||
}
|
||||
@@ -56,8 +72,6 @@ public class Main {
|
||||
.display();
|
||||
}
|
||||
|
||||
|
||||
|
||||
private static void plotGraph(DataSet dataset, Model network){
|
||||
|
||||
if(dataset.getNbrInputs() != 2) return;
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
package com.naaturel.ANN.domain.model.helpers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class ModelCreator {
|
||||
|
||||
public static Model createModel(int[] neuronPerLayer, int nbrInput){
|
||||
public static Model createModel(int[] neuronPerLayer, int nbrInput, ActivationFunction func){
|
||||
int neuronId = 0;
|
||||
List<Layer> layers = new ArrayList<>();
|
||||
for (int i = 0; i < neuronPerLayer.length; i++){
|
||||
@@ -26,7 +26,7 @@ public class ModelCreator {
|
||||
|
||||
Bias bias = new Bias(new Weight());
|
||||
|
||||
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
|
||||
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, func);
|
||||
neurons.add(n);
|
||||
neuronId++;
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
||||
@@ -23,7 +24,7 @@ public class AdalineTraining implements Trainer {
|
||||
private Model model;
|
||||
|
||||
public AdalineTraining(int[] neurons, int nbrInputs){
|
||||
model = ModelCreator.createModel(neurons, nbrInputs);
|
||||
model = ModelCreator.createModel(neurons, nbrInputs, new Linear(1, 0));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -18,7 +18,7 @@ public class GradientBackpropagationTraining implements Trainer {
|
||||
private Model model;
|
||||
|
||||
public GradientBackpropagationTraining(int[] neurons, int nbrInputs){
|
||||
model = ModelCreator.createModel(neurons, nbrInputs);
|
||||
model = ModelCreator.createModel(neurons, nbrInputs, new TanH());
|
||||
}
|
||||
|
||||
public Model getModel() {
|
||||
|
||||
@@ -4,12 +4,9 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.*;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||
@@ -23,7 +20,7 @@ public class GradientDescentTraining implements Trainer {
|
||||
private Model model;
|
||||
|
||||
public GradientDescentTraining(int[] neurons, int nbrInputs){
|
||||
model = ModelCreator.createModel(neurons, nbrInputs);
|
||||
model = ModelCreator.createModel(neurons, nbrInputs, new Linear(1, 0));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -16,7 +16,7 @@ public class SimpleTraining implements Trainer {
|
||||
private Model model;
|
||||
|
||||
public SimpleTraining(int[] neurons, int nbrInputs){
|
||||
model = ModelCreator.createModel(neurons, nbrInputs);
|
||||
model = ModelCreator.createModel(neurons, nbrInputs, new Heaviside());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package adaline;
|
||||
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
@@ -13,7 +15,6 @@ import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
@@ -27,9 +28,7 @@ public class AdalineTest {
|
||||
private DataSet dataset;
|
||||
private AdalineTrainingContext context;
|
||||
|
||||
private List<Synapse> synapses;
|
||||
private Bias bias;
|
||||
private FullyConnectedNetwork network;
|
||||
private Model model;
|
||||
|
||||
private TrainingPipeline pipeline;
|
||||
|
||||
@@ -42,22 +41,16 @@ public class AdalineTest {
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
|
||||
bias = new Bias(new Weight(0));
|
||||
model = ModelCreator.createModel(new int[]{1}, 2, new Linear(1, 0));
|
||||
|
||||
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
|
||||
Layer layer = new Layer(List.of(neuron));
|
||||
network = new FullyConnectedNetwork(List.of(layer));
|
||||
context = new AdalineTrainingContext(model, dataset);
|
||||
|
||||
context = new AdalineTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = network;
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStep(context)),
|
||||
new DeltaStep(new SimpleDeltaStep(context)),
|
||||
new LossStep(new SquareLossStep(context)),
|
||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStep(context))
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
new SimpleDeltaStep(context),
|
||||
new SquareLossStep(context),
|
||||
new SimpleErrorRegistrationStep(context),
|
||||
new SimpleCorrectionStep(context)
|
||||
);
|
||||
|
||||
pipeline = new TrainingPipeline(steps)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
package gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.*;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
@@ -23,9 +23,7 @@ public class GradientDescentTest {
|
||||
private DataSet dataset;
|
||||
private GradientDescentTrainingContext context;
|
||||
|
||||
private List<Synapse> synapses;
|
||||
private Bias bias;
|
||||
private FullyConnectedNetwork network;
|
||||
private Model model;
|
||||
|
||||
private TrainingPipeline pipeline;
|
||||
|
||||
@@ -34,26 +32,16 @@ public class GradientDescentTest {
|
||||
dataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1);
|
||||
|
||||
List<Synapse> syns = new ArrayList<>();
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
model = ModelCreator.createModel(new int[]{1}, 2, new Linear(1, 0));
|
||||
|
||||
bias = new Bias(new Weight(0));
|
||||
|
||||
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
|
||||
Layer layer = new Layer(List.of(neuron));
|
||||
network = new FullyConnectedNetwork(List.of(layer));
|
||||
|
||||
context = new GradientDescentTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = network;
|
||||
context = new GradientDescentTrainingContext(model, dataset);
|
||||
context.correctorTerms = new ArrayList<>();
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStep(context)),
|
||||
new DeltaStep(new SimpleDeltaStep(context)),
|
||||
new LossStep(new SquareLossStep(context)),
|
||||
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
new SimpleDeltaStep(context),
|
||||
new SquareLossStep(context),
|
||||
new GradientDescentErrorStrategy(context)
|
||||
);
|
||||
|
||||
pipeline = new TrainingPipeline(steps)
|
||||
@@ -91,7 +79,7 @@ public class GradientDescentTest {
|
||||
});
|
||||
|
||||
pipeline
|
||||
.withVerbose(true)
|
||||
.withVerbose(true, 1)
|
||||
.run(context);
|
||||
assertEquals(67, context.epoch);
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package perceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
@@ -21,10 +21,7 @@ public class SimplePerceptronTest {
|
||||
|
||||
private DataSet dataset;
|
||||
private SimpleTrainingContext context;
|
||||
|
||||
private List<Synapse> synapses;
|
||||
private Bias bias;
|
||||
private FullyConnectedNetwork network;
|
||||
private Model model;
|
||||
|
||||
private TrainingPipeline pipeline;
|
||||
|
||||
@@ -37,22 +34,16 @@ public class SimplePerceptronTest {
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
|
||||
bias = new Bias(new Weight(0));
|
||||
model = ModelCreator.createModel(new int[]{1}, 2, new Heaviside());
|
||||
|
||||
Neuron neuron = new Neuron(syns, bias, new Heaviside());
|
||||
Layer layer = new Layer(List.of(neuron));
|
||||
network = new FullyConnectedNetwork(List.of(layer));
|
||||
context = new SimpleTrainingContext(model, dataset);
|
||||
|
||||
context = new SimpleTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = network;
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStep(context)),
|
||||
new DeltaStep(new SimpleDeltaStep(context)),
|
||||
new LossStep(new SimpleLossStrategy(context)),
|
||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStep(context))
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
new SimpleDeltaStep(context),
|
||||
new SimpleLossStrategy(context),
|
||||
new SimpleErrorRegistrationStep(context),
|
||||
new SimpleCorrectionStep(context)
|
||||
);
|
||||
|
||||
pipeline = new TrainingPipeline(steps);
|
||||
@@ -74,7 +65,7 @@ public class SimplePerceptronTest {
|
||||
|
||||
context.learningRate = 1F;
|
||||
pipeline.afterEpoch(ctx -> {
|
||||
int index = ctx.epoch-1;
|
||||
int index = ctx.epoch;
|
||||
assertEquals(expectedGlobalLosses.get(index), context.globalLoss);
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user