Refactor handling of model creation

This commit is contained in:
2026-05-11 11:40:01 +02:00
parent 45cdab0373
commit 159e414cb8
8 changed files with 121 additions and 64 deletions

View File

@@ -40,22 +40,15 @@ public class Main {
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
int nbrInput = dataset.getNbrInputs();
ModelSnapshot snapshot;
ModelSnapshot snapshot = new ModelSnapshot();
Model network;
if(newModel) {
network = createNetwork(modelParameters, nbrInput);
snapshot = new ModelSnapshot(network);
System.out.println("Parameters: " + network.synCount());
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(learningRate, maxEpoch, network, dataset);
snapshot.saveToFile(modelPath);
} else {
snapshot = new ModelSnapshot();
snapshot.loadFromFile(modelPath);
network = snapshot.getModel();
Trainer trainer = new GradientBackpropagationTraining(modelParameters, nbrInput);
trainer.train(learningRate, maxEpoch, dataset);
trainer.saveModel(snapshot, modelPath);
}
Model network = snapshot.loadFromFile(modelPath);
plotGraph(dataset, network);
new ModelVisualizer(network)
@@ -63,33 +56,7 @@ public class Main {
.display();
}
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
int neuronId = 0;
List<Layer> layers = new ArrayList<>();
for (int i = 0; i < neuronPerLayer.length; i++){
List<Neuron> neurons = new ArrayList<>();
for (int j = 0; j < neuronPerLayer[i]; j++){
int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1];
List<Synapse> syns = new ArrayList<>();
for (int k=0; k < nbrSyn; k++){
syns.add(new Synapse(new Input(0), new Weight()));
}
Bias bias = new Bias(new Weight());
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
neurons.add(n);
neuronId++;
}
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
layers.add(layer);
}
return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
}
private static void plotGraph(DataSet dataset, Model network){

View File

@@ -1,7 +1,9 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
public interface Trainer {
void train(float learningRate, int epoch, Model model, DataSet dataset);
void train(float learningRate, int epoch, DataSet dataset);
void saveModel(ModelSnapshot snapshot, String path);
}

View File

@@ -0,0 +1,39 @@
package com.naaturel.ANN.domain.model.helpers;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import java.util.ArrayList;
import java.util.List;
public class ModelCreator {
public static Model createModel(int[] neuronPerLayer, int nbrInput){
int neuronId = 0;
List<Layer> layers = new ArrayList<>();
for (int i = 0; i < neuronPerLayer.length; i++){
List<Neuron> neurons = new ArrayList<>();
for (int j = 0; j < neuronPerLayer[i]; j++){
int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1];
List<Synapse> syns = new ArrayList<>();
for (int k=0; k < nbrSyn; k++){
syns.add(new Synapse(new Input(0), new Weight()));
}
Bias bias = new Bias(new Weight());
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
neurons.add(n);
neuronId++;
}
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
layers.add(layer);
}
return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
}
}

View File

@@ -3,6 +3,7 @@ package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
@@ -11,6 +12,7 @@ import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
import java.util.List;
@@ -18,12 +20,14 @@ import java.util.List;
public class AdalineTraining implements Trainer {
public AdalineTraining(){
private Model model;
public AdalineTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs);
}
@Override
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
public void train(float learningRate, int epoch, DataSet dataset) {
AdalineTrainingContext context = new AdalineTrainingContext(model, dataset);
context.learningRate = learningRate;
@@ -45,6 +49,15 @@ public class AdalineTraining implements Trainer {
.run(context);
}
@Override
public void saveModel(ModelSnapshot snapshot, String path) {
try {
snapshot.saveToFile(model, path);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1;
int maxEpoch = 202;

View File

@@ -3,19 +3,30 @@ package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep;
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep;
import com.naaturel.ANN.implementation.multiLayers.OutputLayerErrorStep;
import com.naaturel.ANN.implementation.multiLayers.*;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
import java.util.List;
public class GradientBackpropagationTraining implements Trainer {
private Model model;
public GradientBackpropagationTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs);
}
public Model getModel() {
return model;
}
@Override
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
public void train(float learningRate, int epoch, DataSet dataset) {
GradientBackpropagationContext context =
new GradientBackpropagationContext(model, dataset, learningRate, 10);
@@ -39,4 +50,13 @@ public class GradientBackpropagationTraining implements Trainer {
.withTimeMeasurement(true)
.run(context);
}
@Override
public void saveModel(ModelSnapshot snapshot, String path) {
try {
snapshot.saveToFile(model, path);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}

View File

@@ -3,6 +3,7 @@ package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
@@ -11,6 +12,7 @@ import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrection
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
import java.util.ArrayList;
@@ -18,12 +20,14 @@ import java.util.List;
public class GradientDescentTraining implements Trainer {
public GradientDescentTraining(){
private Model model;
public GradientDescentTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs);
}
@Override
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
public void train(float learningRate, int epoch, DataSet dataset) {
GradientDescentTrainingContext context = new GradientDescentTrainingContext(model, dataset);
context.learningRate = learningRate;
context.correctorTerms = new ArrayList<>();
@@ -55,6 +59,15 @@ public class GradientDescentTraining implements Trainer {
.run(context);
}
@Override
public void saveModel(ModelSnapshot snapshot, String path) {
try {
snapshot.saveToFile(model, path);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1;
int maxEpoch = 402;

View File

@@ -3,20 +3,24 @@ package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.helpers.ModelCreator;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
import java.util.List;
public class SimpleTraining implements Trainer {
public SimpleTraining() {
private Model model;
public SimpleTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs);
}
@Override
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
public void train(float learningRate, int epoch, DataSet dataset) {
SimpleTrainingContext context = new SimpleTrainingContext(model, dataset);
context.dataset = dataset;
context.model = model;
@@ -38,6 +42,15 @@ public class SimpleTraining implements Trainer {
.run(context);
}
@Override
public void saveModel(ModelSnapshot snapshot, String path) {
try {
snapshot.saveToFile(model, path);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1;
int errorCount;

View File

@@ -15,23 +15,13 @@ import java.util.Map;
public class ModelSnapshot {
private Model model;
private final ObjectMapper mapper;
public ModelSnapshot(){
this(null);
}
public ModelSnapshot(Model model){
this.model = model;
mapper = new ObjectMapper();
}
public Model getModel() {
return model;
}
public void saveToFile(String path) throws Exception {
public void saveToFile(Model model, String path) throws Exception {
ArrayNode root = mapper.createArrayNode();
model.forEachNeuron(n -> {
@@ -52,7 +42,7 @@ public class ModelSnapshot {
mapper.writerWithDefaultPrettyPrinter().writeValue(new File(path), root);
}
public void loadFromFile(String path) throws Exception {
public Model loadFromFile(String path) throws Exception {
ArrayNode root = (ArrayNode) mapper.readTree(new File(path));
Map<Integer, List<Neuron>> neuronsByLayer = new LinkedHashMap<>();
@@ -76,6 +66,6 @@ public class ModelSnapshot {
.map(neurons -> new Layer(neurons.toArray(new Neuron[0])))
.toArray(Layer[]::new);
this.model = new FullyConnectedNetwork(layers);
return new FullyConnectedNetwork(layers);
}
}