Refactor handling of model creation

This commit is contained in:
2026-05-11 11:40:01 +02:00
parent 45cdab0373
commit 159e414cb8
8 changed files with 121 additions and 64 deletions

View File

@@ -40,22 +40,15 @@ public class Main {
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass); DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
int nbrInput = dataset.getNbrInputs(); int nbrInput = dataset.getNbrInputs();
ModelSnapshot snapshot; ModelSnapshot snapshot = new ModelSnapshot();
Model network;
if(newModel) { if(newModel) {
network = createNetwork(modelParameters, nbrInput); Trainer trainer = new GradientBackpropagationTraining(modelParameters, nbrInput);
snapshot = new ModelSnapshot(network); trainer.train(learningRate, maxEpoch, dataset);
System.out.println("Parameters: " + network.synCount()); trainer.saveModel(snapshot, modelPath);
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(learningRate, maxEpoch, network, dataset);
snapshot.saveToFile(modelPath);
} else {
snapshot = new ModelSnapshot();
snapshot.loadFromFile(modelPath);
network = snapshot.getModel();
} }
Model network = snapshot.loadFromFile(modelPath);
plotGraph(dataset, network); plotGraph(dataset, network);
new ModelVisualizer(network) new ModelVisualizer(network)
@@ -63,33 +56,7 @@ public class Main {
.display(); .display();
} }
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
int neuronId = 0;
List<Layer> layers = new ArrayList<>();
for (int i = 0; i < neuronPerLayer.length; i++){
List<Neuron> neurons = new ArrayList<>();
for (int j = 0; j < neuronPerLayer[i]; j++){
int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1];
List<Synapse> syns = new ArrayList<>();
for (int k=0; k < nbrSyn; k++){
syns.add(new Synapse(new Input(0), new Weight()));
}
Bias bias = new Bias(new Weight());
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
neurons.add(n);
neuronId++;
}
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
layers.add(layer);
}
return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
}
private static void plotGraph(DataSet dataset, Model network){ private static void plotGraph(DataSet dataset, Model network){

View File

@@ -1,7 +1,9 @@
package com.naaturel.ANN.domain.abstraction; package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
public interface Trainer { public interface Trainer {
void train(float learningRate, int epoch, Model model, DataSet dataset); void train(float learningRate, int epoch, DataSet dataset);
void saveModel(ModelSnapshot snapshot, String path);
} }

View File

@@ -0,0 +1,39 @@
package com.naaturel.ANN.domain.model.helpers;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import java.util.ArrayList;
import java.util.List;
/**
 * Static factory for freshly initialised fully connected networks.
 *
 * <p>The class holds no state and exposes only static methods, so it is
 * declared {@code final} and cannot be instantiated.
 */
public final class ModelCreator {

    /** Not instantiable: this class only hosts static factory methods. */
    private ModelCreator() {
    }

    /**
     * Builds a {@link FullyConnectedNetwork} with one layer per entry of
     * {@code neuronPerLayer}.
     *
     * <p>Every neuron of the first layer receives {@code nbrInput} synapses;
     * every neuron of a later layer receives one synapse per neuron of the
     * layer directly before it. Each synapse is created from {@code new Input(0)}
     * and {@code new Weight()}, each neuron gets a bias built from
     * {@code new Weight()} and a {@link TanH} activation. Neuron ids are
     * assigned sequentially across the whole network, starting at 0.
     *
     * @param neuronPerLayer number of neurons in each layer, in layer order
     * @param nbrInput number of synapses given to each neuron of the first layer
     * @return the newly built network
     */
    public static Model createModel(int[] neuronPerLayer, int nbrInput) {
        int nextNeuronId = 0;
        List<Layer> layers = new ArrayList<>();
        for (int layerIndex = 0; layerIndex < neuronPerLayer.length; layerIndex++) {
            // The synapse count is identical for every neuron of a layer, so it is
            // computed once per layer rather than once per neuron.
            int synapsesPerNeuron = layerIndex == 0 ? nbrInput : neuronPerLayer[layerIndex - 1];
            List<Neuron> neurons = new ArrayList<>();
            for (int j = 0; j < neuronPerLayer[layerIndex]; j++) {
                neurons.add(createNeuron(nextNeuronId, synapsesPerNeuron));
                nextNeuronId++;
            }
            layers.add(new Layer(neurons.toArray(new Neuron[0])));
        }
        return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
    }

    /**
     * Creates one neuron with {@code synapseCount} synapses, a bias and a TanH activation.
     *
     * <p>Synapses are created before the bias; this order is kept on purpose.
     * NOTE(review): if {@code Weight()} draws from a shared random source, the
     * creation order determines the initial values - confirm before reordering.
     */
    private static Neuron createNeuron(int id, int synapseCount) {
        List<Synapse> synapses = new ArrayList<>();
        for (int k = 0; k < synapseCount; k++) {
            synapses.add(new Synapse(new Input(0), new Weight()));
        }
        Bias bias = new Bias(new Weight());
        return new Neuron(id, synapses.toArray(new Synapse[0]), bias, new TanH());
    }
}

View File

@@ -3,6 +3,7 @@ package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext; import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
@@ -11,6 +12,7 @@ import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep; import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep; import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer; import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
import java.util.List; import java.util.List;
@@ -18,12 +20,14 @@ import java.util.List;
public class AdalineTraining implements Trainer { public class AdalineTraining implements Trainer {
public AdalineTraining(){ private Model model;
public AdalineTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs);
} }
@Override @Override
public void train(float learningRate, int epoch, Model model, DataSet dataset) { public void train(float learningRate, int epoch, DataSet dataset) {
AdalineTrainingContext context = new AdalineTrainingContext(model, dataset); AdalineTrainingContext context = new AdalineTrainingContext(model, dataset);
context.learningRate = learningRate; context.learningRate = learningRate;
@@ -45,6 +49,15 @@ public class AdalineTraining implements Trainer {
.run(context); .run(context);
} }
@Override
public void saveModel(ModelSnapshot snapshot, String path) {
try {
snapshot.saveToFile(model, path);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/*public void train(Neuron n, float learningRate, DataSet dataSet) { /*public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1; int epoch = 1;
int maxEpoch = 202; int maxEpoch = 202;

View File

@@ -3,19 +3,30 @@ package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep; import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep; import com.naaturel.ANN.implementation.multiLayers.*;
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep;
import com.naaturel.ANN.implementation.multiLayers.OutputLayerErrorStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
import java.util.List; import java.util.List;
public class GradientBackpropagationTraining implements Trainer { public class GradientBackpropagationTraining implements Trainer {
private Model model;
public GradientBackpropagationTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs);
}
public Model getModel() {
return model;
}
@Override @Override
public void train(float learningRate, int epoch, Model model, DataSet dataset) { public void train(float learningRate, int epoch, DataSet dataset) {
GradientBackpropagationContext context = GradientBackpropagationContext context =
new GradientBackpropagationContext(model, dataset, learningRate, 10); new GradientBackpropagationContext(model, dataset, learningRate, 10);
@@ -39,4 +50,13 @@ public class GradientBackpropagationTraining implements Trainer {
.withTimeMeasurement(true) .withTimeMeasurement(true)
.run(context); .run(context);
} }
@Override
public void saveModel(ModelSnapshot snapshot, String path) {
try {
snapshot.saveToFile(model, path);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
} }

View File

@@ -3,6 +3,7 @@ package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy; import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext; import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
@@ -11,6 +12,7 @@ import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrection
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep; import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep; import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer; import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
import java.util.ArrayList; import java.util.ArrayList;
@@ -18,12 +20,14 @@ import java.util.List;
public class GradientDescentTraining implements Trainer { public class GradientDescentTraining implements Trainer {
public GradientDescentTraining(){ private Model model;
public GradientDescentTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs);
} }
@Override @Override
public void train(float learningRate, int epoch, Model model, DataSet dataset) { public void train(float learningRate, int epoch, DataSet dataset) {
GradientDescentTrainingContext context = new GradientDescentTrainingContext(model, dataset); GradientDescentTrainingContext context = new GradientDescentTrainingContext(model, dataset);
context.learningRate = learningRate; context.learningRate = learningRate;
context.correctorTerms = new ArrayList<>(); context.correctorTerms = new ArrayList<>();
@@ -55,6 +59,15 @@ public class GradientDescentTraining implements Trainer {
.run(context); .run(context);
} }
@Override
public void saveModel(ModelSnapshot snapshot, String path) {
try {
snapshot.saveToFile(model, path);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/*public void train(Neuron n, float learningRate, DataSet dataSet) { /*public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1; int epoch = 1;
int maxEpoch = 402; int maxEpoch = 402;

View File

@@ -3,20 +3,24 @@ package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.implementation.simplePerceptron.*; import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
import java.util.List; import java.util.List;
public class SimpleTraining implements Trainer { public class SimpleTraining implements Trainer {
public SimpleTraining() { private Model model;
public SimpleTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs);
} }
@Override @Override
public void train(float learningRate, int epoch, Model model, DataSet dataset) { public void train(float learningRate, int epoch, DataSet dataset) {
SimpleTrainingContext context = new SimpleTrainingContext(model, dataset); SimpleTrainingContext context = new SimpleTrainingContext(model, dataset);
context.dataset = dataset; context.dataset = dataset;
context.model = model; context.model = model;
@@ -38,6 +42,15 @@ public class SimpleTraining implements Trainer {
.run(context); .run(context);
} }
@Override
public void saveModel(ModelSnapshot snapshot, String path) {
try {
snapshot.saveToFile(model, path);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/*public void train(Neuron n, float learningRate, DataSet dataSet) { /*public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1; int epoch = 1;
int errorCount; int errorCount;

View File

@@ -15,23 +15,13 @@ import java.util.Map;
public class ModelSnapshot { public class ModelSnapshot {
private Model model;
private final ObjectMapper mapper; private final ObjectMapper mapper;
public ModelSnapshot(){ public ModelSnapshot(){
this(null);
}
public ModelSnapshot(Model model){
this.model = model;
mapper = new ObjectMapper(); mapper = new ObjectMapper();
} }
public Model getModel() { public void saveToFile(Model model, String path) throws Exception {
return model;
}
public void saveToFile(String path) throws Exception {
ArrayNode root = mapper.createArrayNode(); ArrayNode root = mapper.createArrayNode();
model.forEachNeuron(n -> { model.forEachNeuron(n -> {
@@ -52,7 +42,7 @@ public class ModelSnapshot {
mapper.writerWithDefaultPrettyPrinter().writeValue(new File(path), root); mapper.writerWithDefaultPrettyPrinter().writeValue(new File(path), root);
} }
public void loadFromFile(String path) throws Exception { public Model loadFromFile(String path) throws Exception {
ArrayNode root = (ArrayNode) mapper.readTree(new File(path)); ArrayNode root = (ArrayNode) mapper.readTree(new File(path));
Map<Integer, List<Neuron>> neuronsByLayer = new LinkedHashMap<>(); Map<Integer, List<Neuron>> neuronsByLayer = new LinkedHashMap<>();
@@ -76,6 +66,6 @@ public class ModelSnapshot {
.map(neurons -> new Layer(neurons.toArray(new Neuron[0]))) .map(neurons -> new Layer(neurons.toArray(new Neuron[0])))
.toArray(Layer[]::new); .toArray(Layer[]::new);
this.model = new FullyConnectedNetwork(layers); return new FullyConnectedNetwork(layers);
} }
} }