Implement model selector and fix tests
This commit is contained in:
@@ -5,6 +5,8 @@ plugins {
|
|||||||
group = "be.naaturel"
|
group = "be.naaturel"
|
||||||
version = "1.0-SNAPSHOT"
|
version = "1.0-SNAPSHOT"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
@@ -13,11 +15,21 @@ dependencies {
|
|||||||
implementation("org.jfree:jfreechart:1.5.4")
|
implementation("org.jfree:jfreechart:1.5.4")
|
||||||
implementation("com.fasterxml.jackson.core:jackson-databind:2.21.2")
|
implementation("com.fasterxml.jackson.core:jackson-databind:2.21.2")
|
||||||
|
|
||||||
|
implementation("org.jline:jline:3.27.1")
|
||||||
|
|
||||||
testImplementation(platform("org.junit:junit-bom:5.10.0"))
|
testImplementation(platform("org.junit:junit-bom:5.10.0"))
|
||||||
testImplementation("org.junit.jupiter:junit-jupiter")
|
testImplementation("org.junit.jupiter:junit-jupiter")
|
||||||
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
|
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tasks.jar {
|
||||||
|
manifest {
|
||||||
|
attributes["Main-Class"] = "com.naaturel.ANN.Main"
|
||||||
|
}
|
||||||
|
from(configurations.runtimeClasspath.get().map { if (it.isDirectory) it else zipTree(it) })
|
||||||
|
duplicatesStrategy = DuplicatesStrategy.EXCLUDE
|
||||||
|
}
|
||||||
|
|
||||||
tasks.test {
|
tasks.test {
|
||||||
useJUnitPlatform()
|
useJUnitPlatform()
|
||||||
}
|
}
|
||||||
@@ -1,14 +1,14 @@
|
|||||||
{
|
{
|
||||||
"model": {
|
"model": {
|
||||||
"new": true,
|
"new": true,
|
||||||
"parameters": [5, 5, 1],
|
"parameters": [1],
|
||||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-4-12.json"
|
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
|
||||||
},
|
},
|
||||||
"training" : {
|
"training" : {
|
||||||
"learning_rate" : 0.03,
|
"learning_rate" : 0.03,
|
||||||
"max_epoch" : 5000
|
"max_epoch" : 5000
|
||||||
},
|
},
|
||||||
"dataset" : {
|
"dataset" : {
|
||||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/xor.csv"
|
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,16 +1,16 @@
|
|||||||
package com.naaturel.ANN;
|
package com.naaturel.ANN;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
import com.naaturel.ANN.implementation.training.AdalineTraining;
|
||||||
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||||
|
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
||||||
|
import com.naaturel.ANN.implementation.training.SimpleTraining;
|
||||||
import com.naaturel.ANN.infrastructure.config.ConfigDto;
|
import com.naaturel.ANN.infrastructure.config.ConfigDto;
|
||||||
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
|
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
|
||||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
||||||
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||||
import com.naaturel.ANN.infrastructure.visualization.ModelVisualizer;
|
import com.naaturel.ANN.infrastructure.visualization.ModelVisualizer;
|
||||||
@@ -21,7 +21,17 @@ public class Main {
|
|||||||
|
|
||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args) throws Exception {
|
||||||
|
|
||||||
|
String[] types = {"Simple perceptron", "Gradient descent", "Adaline", "Gradient retro-propagation"};
|
||||||
|
|
||||||
Scanner sc = new Scanner(System.in);
|
Scanner sc = new Scanner(System.in);
|
||||||
|
for (int i = 0; i < types.length; i++) {
|
||||||
|
System.out.printf("%d - %s\n", i+1, types[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
System.out.print(">>> ");
|
||||||
|
int typeIndex = sc.nextInt() - 1;
|
||||||
|
sc.nextLine();
|
||||||
|
System.out.printf("\nChosen type: %s\n", types[typeIndex]);
|
||||||
|
|
||||||
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
|
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
|
||||||
|
|
||||||
@@ -43,7 +53,13 @@ public class Main {
|
|||||||
ModelSnapshot snapshot = new ModelSnapshot();
|
ModelSnapshot snapshot = new ModelSnapshot();
|
||||||
|
|
||||||
if(newModel) {
|
if(newModel) {
|
||||||
Trainer trainer = new GradientBackpropagationTraining(modelParameters, nbrInput);
|
Trainer trainer = switch (typeIndex) {
|
||||||
|
case 0 -> new SimpleTraining(modelParameters, nbrInput);
|
||||||
|
case 1 -> new GradientDescentTraining(modelParameters, nbrInput);
|
||||||
|
case 2 -> new AdalineTraining(modelParameters, nbrInput);
|
||||||
|
case 3 -> new GradientBackpropagationTraining(modelParameters, nbrInput);
|
||||||
|
default -> throw new IllegalStateException("Unexpected value: " + typeIndex);
|
||||||
|
};
|
||||||
trainer.train(learningRate, maxEpoch, dataset);
|
trainer.train(learningRate, maxEpoch, dataset);
|
||||||
trainer.saveModel(snapshot, modelPath);
|
trainer.saveModel(snapshot, modelPath);
|
||||||
}
|
}
|
||||||
@@ -56,8 +72,6 @@ public class Main {
|
|||||||
.display();
|
.display();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
private static void plotGraph(DataSet dataset, Model network){
|
private static void plotGraph(DataSet dataset, Model network){
|
||||||
|
|
||||||
if(dataset.getNbrInputs() != 2) return;
|
if(dataset.getNbrInputs() != 2) return;
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
package com.naaturel.ANN.domain.model.helpers;
|
package com.naaturel.ANN.domain.model.helpers;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class ModelCreator {
|
public class ModelCreator {
|
||||||
|
|
||||||
public static Model createModel(int[] neuronPerLayer, int nbrInput){
|
public static Model createModel(int[] neuronPerLayer, int nbrInput, ActivationFunction func){
|
||||||
int neuronId = 0;
|
int neuronId = 0;
|
||||||
List<Layer> layers = new ArrayList<>();
|
List<Layer> layers = new ArrayList<>();
|
||||||
for (int i = 0; i < neuronPerLayer.length; i++){
|
for (int i = 0; i < neuronPerLayer.length; i++){
|
||||||
@@ -26,7 +26,7 @@ public class ModelCreator {
|
|||||||
|
|
||||||
Bias bias = new Bias(new Weight());
|
Bias bias = new Bias(new Weight());
|
||||||
|
|
||||||
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
|
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, func);
|
||||||
neurons.add(n);
|
neurons.add(n);
|
||||||
neuronId++;
|
neuronId++;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
|||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||||
|
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
||||||
@@ -23,7 +24,7 @@ public class AdalineTraining implements Trainer {
|
|||||||
private Model model;
|
private Model model;
|
||||||
|
|
||||||
public AdalineTraining(int[] neurons, int nbrInputs){
|
public AdalineTraining(int[] neurons, int nbrInputs){
|
||||||
model = ModelCreator.createModel(neurons, nbrInputs);
|
model = ModelCreator.createModel(neurons, nbrInputs, new Linear(1, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ public class GradientBackpropagationTraining implements Trainer {
|
|||||||
private Model model;
|
private Model model;
|
||||||
|
|
||||||
public GradientBackpropagationTraining(int[] neurons, int nbrInputs){
|
public GradientBackpropagationTraining(int[] neurons, int nbrInputs){
|
||||||
model = ModelCreator.createModel(neurons, nbrInputs);
|
model = ModelCreator.createModel(neurons, nbrInputs, new TanH());
|
||||||
}
|
}
|
||||||
|
|
||||||
public Model getModel() {
|
public Model getModel() {
|
||||||
|
|||||||
@@ -4,12 +4,9 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
|||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||||
|
import com.naaturel.ANN.implementation.gradientDescent.*;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
|
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
|
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||||
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||||
@@ -23,7 +20,7 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
private Model model;
|
private Model model;
|
||||||
|
|
||||||
public GradientDescentTraining(int[] neurons, int nbrInputs){
|
public GradientDescentTraining(int[] neurons, int nbrInputs){
|
||||||
model = ModelCreator.createModel(neurons, nbrInputs);
|
model = ModelCreator.createModel(neurons, nbrInputs, new Linear(1, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ public class SimpleTraining implements Trainer {
|
|||||||
private Model model;
|
private Model model;
|
||||||
|
|
||||||
public SimpleTraining(int[] neurons, int nbrInputs){
|
public SimpleTraining(int[] neurons, int nbrInputs){
|
||||||
model = ModelCreator.createModel(neurons, nbrInputs);
|
model = ModelCreator.createModel(neurons, nbrInputs, new Heaviside());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package adaline;
|
package adaline;
|
||||||
|
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
|
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
@@ -13,7 +15,6 @@ import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
|
|||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||||
import com.naaturel.ANN.implementation.training.steps.*;
|
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
@@ -27,9 +28,7 @@ public class AdalineTest {
|
|||||||
private DataSet dataset;
|
private DataSet dataset;
|
||||||
private AdalineTrainingContext context;
|
private AdalineTrainingContext context;
|
||||||
|
|
||||||
private List<Synapse> synapses;
|
private Model model;
|
||||||
private Bias bias;
|
|
||||||
private FullyConnectedNetwork network;
|
|
||||||
|
|
||||||
private TrainingPipeline pipeline;
|
private TrainingPipeline pipeline;
|
||||||
|
|
||||||
@@ -42,22 +41,16 @@ public class AdalineTest {
|
|||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
|
|
||||||
bias = new Bias(new Weight(0));
|
model = ModelCreator.createModel(new int[]{1}, 2, new Linear(1, 0));
|
||||||
|
|
||||||
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
|
context = new AdalineTrainingContext(model, dataset);
|
||||||
Layer layer = new Layer(List.of(neuron));
|
|
||||||
network = new FullyConnectedNetwork(List.of(layer));
|
|
||||||
|
|
||||||
context = new AdalineTrainingContext();
|
List<AlgorithmStep> steps = List.of(
|
||||||
context.dataset = dataset;
|
new SimplePredictionStep(context),
|
||||||
context.model = network;
|
new SimpleDeltaStep(context),
|
||||||
|
new SquareLossStep(context),
|
||||||
List<TrainingStep> steps = List.of(
|
new SimpleErrorRegistrationStep(context),
|
||||||
new PredictionStep(new SimplePredictionStep(context)),
|
new SimpleCorrectionStep(context)
|
||||||
new DeltaStep(new SimpleDeltaStep(context)),
|
|
||||||
new LossStep(new SquareLossStep(context)),
|
|
||||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
|
|
||||||
new WeightCorrectionStep(new SimpleCorrectionStep(context))
|
|
||||||
);
|
);
|
||||||
|
|
||||||
pipeline = new TrainingPipeline(steps)
|
pipeline = new TrainingPipeline(steps)
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
package gradientDescent;
|
package gradientDescent;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
|
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.*;
|
import com.naaturel.ANN.implementation.gradientDescent.*;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||||
import com.naaturel.ANN.implementation.training.steps.*;
|
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
@@ -23,9 +23,7 @@ public class GradientDescentTest {
|
|||||||
private DataSet dataset;
|
private DataSet dataset;
|
||||||
private GradientDescentTrainingContext context;
|
private GradientDescentTrainingContext context;
|
||||||
|
|
||||||
private List<Synapse> synapses;
|
private Model model;
|
||||||
private Bias bias;
|
|
||||||
private FullyConnectedNetwork network;
|
|
||||||
|
|
||||||
private TrainingPipeline pipeline;
|
private TrainingPipeline pipeline;
|
||||||
|
|
||||||
@@ -34,26 +32,16 @@ public class GradientDescentTest {
|
|||||||
dataset = new DatasetExtractor()
|
dataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1);
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1);
|
||||||
|
|
||||||
List<Synapse> syns = new ArrayList<>();
|
model = ModelCreator.createModel(new int[]{1}, 2, new Linear(1, 0));
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
|
||||||
|
|
||||||
bias = new Bias(new Weight(0));
|
context = new GradientDescentTrainingContext(model, dataset);
|
||||||
|
|
||||||
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
|
|
||||||
Layer layer = new Layer(List.of(neuron));
|
|
||||||
network = new FullyConnectedNetwork(List.of(layer));
|
|
||||||
|
|
||||||
context = new GradientDescentTrainingContext();
|
|
||||||
context.dataset = dataset;
|
|
||||||
context.model = network;
|
|
||||||
context.correctorTerms = new ArrayList<>();
|
context.correctorTerms = new ArrayList<>();
|
||||||
|
|
||||||
List<TrainingStep> steps = List.of(
|
List<AlgorithmStep> steps = List.of(
|
||||||
new PredictionStep(new SimplePredictionStep(context)),
|
new SimplePredictionStep(context),
|
||||||
new DeltaStep(new SimpleDeltaStep(context)),
|
new SimpleDeltaStep(context),
|
||||||
new LossStep(new SquareLossStep(context)),
|
new SquareLossStep(context),
|
||||||
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
|
new GradientDescentErrorStrategy(context)
|
||||||
);
|
);
|
||||||
|
|
||||||
pipeline = new TrainingPipeline(steps)
|
pipeline = new TrainingPipeline(steps)
|
||||||
@@ -91,7 +79,7 @@ public class GradientDescentTest {
|
|||||||
});
|
});
|
||||||
|
|
||||||
pipeline
|
pipeline
|
||||||
.withVerbose(true)
|
.withVerbose(true, 1)
|
||||||
.run(context);
|
.run(context);
|
||||||
assertEquals(67, context.epoch);
|
assertEquals(67, context.epoch);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
package perceptron;
|
package perceptron;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
|
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||||
import com.naaturel.ANN.implementation.training.steps.*;
|
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
@@ -21,10 +21,7 @@ public class SimplePerceptronTest {
|
|||||||
|
|
||||||
private DataSet dataset;
|
private DataSet dataset;
|
||||||
private SimpleTrainingContext context;
|
private SimpleTrainingContext context;
|
||||||
|
private Model model;
|
||||||
private List<Synapse> synapses;
|
|
||||||
private Bias bias;
|
|
||||||
private FullyConnectedNetwork network;
|
|
||||||
|
|
||||||
private TrainingPipeline pipeline;
|
private TrainingPipeline pipeline;
|
||||||
|
|
||||||
@@ -37,22 +34,16 @@ public class SimplePerceptronTest {
|
|||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
|
|
||||||
bias = new Bias(new Weight(0));
|
model = ModelCreator.createModel(new int[]{1}, 2, new Heaviside());
|
||||||
|
|
||||||
Neuron neuron = new Neuron(syns, bias, new Heaviside());
|
context = new SimpleTrainingContext(model, dataset);
|
||||||
Layer layer = new Layer(List.of(neuron));
|
|
||||||
network = new FullyConnectedNetwork(List.of(layer));
|
|
||||||
|
|
||||||
context = new SimpleTrainingContext();
|
List<AlgorithmStep> steps = List.of(
|
||||||
context.dataset = dataset;
|
new SimplePredictionStep(context),
|
||||||
context.model = network;
|
new SimpleDeltaStep(context),
|
||||||
|
new SimpleLossStrategy(context),
|
||||||
List<TrainingStep> steps = List.of(
|
new SimpleErrorRegistrationStep(context),
|
||||||
new PredictionStep(new SimplePredictionStep(context)),
|
new SimpleCorrectionStep(context)
|
||||||
new DeltaStep(new SimpleDeltaStep(context)),
|
|
||||||
new LossStep(new SimpleLossStrategy(context)),
|
|
||||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
|
|
||||||
new WeightCorrectionStep(new SimpleCorrectionStep(context))
|
|
||||||
);
|
);
|
||||||
|
|
||||||
pipeline = new TrainingPipeline(steps);
|
pipeline = new TrainingPipeline(steps);
|
||||||
@@ -74,7 +65,7 @@ public class SimplePerceptronTest {
|
|||||||
|
|
||||||
context.learningRate = 1F;
|
context.learningRate = 1F;
|
||||||
pipeline.afterEpoch(ctx -> {
|
pipeline.afterEpoch(ctx -> {
|
||||||
int index = ctx.epoch-1;
|
int index = ctx.epoch;
|
||||||
assertEquals(expectedGlobalLosses.get(index), context.globalLoss);
|
assertEquals(expectedGlobalLosses.get(index), context.globalLoss);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user