From 159e414cb825cef581129ab293ab519c0dcc5105 Mon Sep 17 00:00:00 2001
From: Laurent
Date: Mon, 11 May 2026 11:40:01 +0200
Subject: [PATCH] Refactor handling of model creation

---
Review notes (everything between the "---" line above and the first
"diff --git" line is ignored by git am / git apply):

 * This patch was recovered from a copy whose line breaks had been collapsed
   and whose <...> spans had been stripped. The line structure was rebuilt so
   that every hunk matches the counts in its @@ header and in the diffstat.
 * ASSUMED, not recovered: indentation (4 spaces) and the exact position of
   blank context lines, including the blank context line inside the removed
   createNetwork() body in Main.java. TODO: verify against the repository, or
   apply with "git apply --ignore-whitespace", before merging.
 * Generic type arguments were restored from usage: List<Layer>,
   List<Neuron>, List<Synapse>. The key type in Map<Integer, List<Neuron>>
   (ModelSnapshot.loadFromFile, context line) is a guess. TODO: confirm.
 * The author's e-mail address on the From: line was lost and is not restored.
 * Review fixes folded into the added lines (line counts unchanged, so the
   post-image blob ids on the "index" lines no longer correspond):
     - trainers: the model field is assigned once, so it is now final;
     - trainers: saveModel() reports the target path when it rethrows;
     - ModelCreator: final class with a private constructor (static-only).

 src/main/java/com/naaturel/ANN/Main.java      | 45 +++----------------
 .../ANN/domain/abstraction/Trainer.java       |  4 +-
 .../domain/model/helpers/ModelCreator.java    | 39 ++++++++++++++++
 .../training/AdalineTraining.java             | 17 ++++++-
 .../GradientBackpropagationTraining.java      | 30 ++++++++++---
 .../training/GradientDescentTraining.java     | 17 ++++++-
 .../training/SimpleTraining.java              | 17 ++++++-
 .../persistence/ModelSnapshot.java            | 16 ++-----
 8 files changed, 121 insertions(+), 64 deletions(-)
 create mode 100644 src/main/java/com/naaturel/ANN/domain/model/helpers/ModelCreator.java

diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java
index ec84328..6d01538 100644
--- a/src/main/java/com/naaturel/ANN/Main.java
+++ b/src/main/java/com/naaturel/ANN/Main.java
@@ -40,22 +40,15 @@ public class Main {
         DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
         int nbrInput = dataset.getNbrInputs();
 
-        ModelSnapshot snapshot;
+        ModelSnapshot snapshot = new ModelSnapshot();
 
-        Model network;
-        if(newModel){
-            network = createNetwork(modelParameters, nbrInput);
-            snapshot = new ModelSnapshot(network);
-            System.out.println("Parameters: " + network.synCount());
-            Trainer trainer = new GradientBackpropagationTraining();
-            trainer.train(learningRate, maxEpoch, network, dataset);
-            snapshot.saveToFile(modelPath);
-        } else {
-            snapshot = new ModelSnapshot();
-            snapshot.loadFromFile(modelPath);
-            network = snapshot.getModel();
+        if(newModel) {
+            Trainer trainer = new GradientBackpropagationTraining(modelParameters, nbrInput);
+            trainer.train(learningRate, maxEpoch, dataset);
+            trainer.saveModel(snapshot, modelPath);
         }
 
+        Model network = snapshot.loadFromFile(modelPath);
         plotGraph(dataset, network);
 
         new ModelVisualizer(network)
@@ -63,33 +56,7 @@ public class Main {
                 .display();
     }
 
-    private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
-        int neuronId = 0;
-        List<Layer> layers = new ArrayList<>();
-        for (int i = 0; i < neuronPerLayer.length; i++){
 
-            List<Neuron> neurons = new ArrayList<>();
-            for (int j = 0; j < neuronPerLayer[i]; j++){
-
-                int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1];
-
-                List<Synapse> syns = new ArrayList<>();
-                for (int k=0; k < nbrSyn; k++){
-                    syns.add(new Synapse(new Input(0), new Weight()));
-                }
-
-                Bias bias = new Bias(new Weight());
-
-                Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
-                neurons.add(n);
-                neuronId++;
-            }
-            Layer layer = new Layer(neurons.toArray(new Neuron[0]));
-            layers.add(layer);
-        }
-
-        return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
-    }
 
     private static void plotGraph(DataSet dataset, Model network){
 
diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java
index 4286d48..faef043 100644
--- a/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java
+++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Trainer.java
@@ -1,7 +1,9 @@
 package com.naaturel.ANN.domain.abstraction;
 
 import com.naaturel.ANN.infrastructure.dataset.DataSet;
+import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
 
 public interface Trainer {
-    void train(float learningRate, int epoch, Model model, DataSet dataset);
+    void train(float learningRate, int epoch, DataSet dataset);
+    void saveModel(ModelSnapshot snapshot, String path);
 }
diff --git a/src/main/java/com/naaturel/ANN/domain/model/helpers/ModelCreator.java b/src/main/java/com/naaturel/ANN/domain/model/helpers/ModelCreator.java
new file mode 100644
index 0000000..b00e662
--- /dev/null
+++ b/src/main/java/com/naaturel/ANN/domain/model/helpers/ModelCreator.java
@@ -0,0 +1,39 @@
+package com.naaturel.ANN.domain.model.helpers;
+
+import com.naaturel.ANN.domain.abstraction.Model;
+import com.naaturel.ANN.domain.model.neuron.*;
+import com.naaturel.ANN.implementation.multiLayers.TanH;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public final class ModelCreator {
+
+    private ModelCreator() {}
+
+    public static Model createModel(int[] neuronPerLayer, int nbrInput){
+        int neuronId = 0;
+        List<Layer> layers = new ArrayList<>();
+        for (int i = 0; i < neuronPerLayer.length; i++){
+            List<Neuron> neurons = new ArrayList<>();
+            for (int j = 0; j < neuronPerLayer[i]; j++){
+                int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1];
+
+                List<Synapse> syns = new ArrayList<>();
+                for (int k=0; k < nbrSyn; k++){
+                    syns.add(new Synapse(new Input(0), new Weight()));
+                }
+
+                Bias bias = new Bias(new Weight());
+
+                Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
+                neurons.add(n);
+                neuronId++;
+            }
+            Layer layer = new Layer(neurons.toArray(new Neuron[0]));
+            layers.add(layer);
+        }
+
+        return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
+    }
+}
diff --git a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java
index 52c2172..27758ce 100644
--- a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java
+++ b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java
@@ -3,6 +3,7 @@ package com.naaturel.ANN.implementation.training;
 import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
 import com.naaturel.ANN.domain.abstraction.Model;
 import com.naaturel.ANN.domain.abstraction.Trainer;
+import com.naaturel.ANN.domain.model.helpers.ModelCreator;
 import com.naaturel.ANN.infrastructure.dataset.DataSet;
 import com.naaturel.ANN.domain.model.training.TrainingPipeline;
 import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
@@ -11,6 +12,7 @@ import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
 import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
 import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
 import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
+import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
 import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
 
 import java.util.List;
@@ -18,12 +20,14 @@ import java.util.List;
 
 public class AdalineTraining implements Trainer {
 
-    public AdalineTraining(){
+    private final Model model;
 
+    public AdalineTraining(int[] neurons, int nbrInputs){
+        model = ModelCreator.createModel(neurons, nbrInputs);
     }
 
     @Override
-    public void train(float learningRate, int epoch, Model model, DataSet dataset) {
+    public void train(float learningRate, int epoch, DataSet dataset) {
         AdalineTrainingContext context = new AdalineTrainingContext(model, dataset);
         context.learningRate = learningRate;
 
@@ -45,6 +49,15 @@ public class AdalineTraining implements Trainer {
                 .run(context);
     }
 
+    @Override
+    public void saveModel(ModelSnapshot snapshot, String path) {
+        try {
+            snapshot.saveToFile(model, path);
+        } catch (Exception e) {
+            throw new RuntimeException("Failed to save model to " + path, e);
+        }
+    }
+
     /*public void train(Neuron n, float learningRate, DataSet dataSet) {
         int epoch = 1;
         int maxEpoch = 202;
diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java
index 6091357..adc7005 100644
--- a/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java
+++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientBackpropagationTraining.java
@@ -3,19 +3,30 @@ package com.naaturel.ANN.implementation.training;
 import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
 import com.naaturel.ANN.domain.abstraction.Model;
 import com.naaturel.ANN.domain.abstraction.Trainer;
+import com.naaturel.ANN.domain.model.helpers.ModelCreator;
 import com.naaturel.ANN.domain.model.training.TrainingPipeline;
 import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
-import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep;
-import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
-import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep;
-import com.naaturel.ANN.implementation.multiLayers.OutputLayerErrorStep;
+import com.naaturel.ANN.implementation.multiLayers.*;
 import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
 import com.naaturel.ANN.infrastructure.dataset.DataSet;
+import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
+
 import java.util.List;
 
 public class GradientBackpropagationTraining implements Trainer {
+
+    private final Model model;
+
+    public GradientBackpropagationTraining(int[] neurons, int nbrInputs){
+        model = ModelCreator.createModel(neurons, nbrInputs);
+    }
+
+    public Model getModel() {
+        return model;
+    }
+
     @Override
-    public void train(float learningRate, int epoch, Model model, DataSet dataset) {
+    public void train(float learningRate, int epoch, DataSet dataset) {
         GradientBackpropagationContext context =
                 new GradientBackpropagationContext(model, dataset, learningRate, 10);
 
@@ -39,4 +50,13 @@ public class GradientBackpropagationTraining implements Trainer {
                 .withTimeMeasurement(true)
                 .run(context);
     }
+
+    @Override
+    public void saveModel(ModelSnapshot snapshot, String path) {
+        try {
+            snapshot.saveToFile(model, path);
+        } catch (Exception e) {
+            throw new RuntimeException("Failed to save model to " + path, e);
+        }
+    }
 }
diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java
index d5e15ad..ef4c0c1 100644
--- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java
+++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java
@@ -3,6 +3,7 @@ package com.naaturel.ANN.implementation.training;
 import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
 import com.naaturel.ANN.domain.abstraction.Model;
 import com.naaturel.ANN.domain.abstraction.Trainer;
+import com.naaturel.ANN.domain.model.helpers.ModelCreator;
 import com.naaturel.ANN.infrastructure.dataset.DataSet;
 import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
 import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
@@ -11,6 +12,7 @@ import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrection
 import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
 import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
 import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
+import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
 import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
 
 import java.util.ArrayList;
@@ -18,12 +20,14 @@ import java.util.List;
 
 public class GradientDescentTraining implements Trainer {
 
-    public GradientDescentTraining(){
+    private final Model model;
 
+    public GradientDescentTraining(int[] neurons, int nbrInputs){
+        model = ModelCreator.createModel(neurons, nbrInputs);
     }
 
     @Override
-    public void train(float learningRate, int epoch, Model model, DataSet dataset) {
+    public void train(float learningRate, int epoch, DataSet dataset) {
         GradientDescentTrainingContext context = new GradientDescentTrainingContext(model, dataset);
         context.learningRate = learningRate;
         context.correctorTerms = new ArrayList<>();
@@ -55,6 +59,15 @@ public class GradientDescentTraining implements Trainer {
                 .run(context);
     }
 
+    @Override
+    public void saveModel(ModelSnapshot snapshot, String path) {
+        try {
+            snapshot.saveToFile(model, path);
+        } catch (Exception e) {
+            throw new RuntimeException("Failed to save model to " + path, e);
+        }
+    }
+
     /*public void train(Neuron n, float learningRate, DataSet dataSet) {
         int epoch = 1;
         int maxEpoch = 402;
diff --git a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java
index 0cdf100..4c1c2f8 100644
--- a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java
+++ b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java
@@ -3,20 +3,24 @@ package com.naaturel.ANN.implementation.training;
 import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
 import com.naaturel.ANN.domain.abstraction.Model;
 import com.naaturel.ANN.domain.abstraction.Trainer;
+import com.naaturel.ANN.domain.model.helpers.ModelCreator;
 import com.naaturel.ANN.infrastructure.dataset.DataSet;
 import com.naaturel.ANN.implementation.simplePerceptron.*;
 import com.naaturel.ANN.domain.model.training.TrainingPipeline;
+import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
 
 import java.util.List;
 
 public class SimpleTraining implements Trainer {
 
-    public SimpleTraining() {
+    private final Model model;
 
+    public SimpleTraining(int[] neurons, int nbrInputs){
+        model = ModelCreator.createModel(neurons, nbrInputs);
     }
 
     @Override
-    public void train(float learningRate, int epoch, Model model, DataSet dataset) {
+    public void train(float learningRate, int epoch, DataSet dataset) {
         SimpleTrainingContext context = new SimpleTrainingContext(model, dataset);
         context.dataset = dataset;
         context.model = model;
@@ -38,6 +42,15 @@ public class SimpleTraining implements Trainer {
                 .run(context);
     }
 
+    @Override
+    public void saveModel(ModelSnapshot snapshot, String path) {
+        try {
+            snapshot.saveToFile(model, path);
+        } catch (Exception e) {
+            throw new RuntimeException("Failed to save model to " + path, e);
+        }
+    }
+
     /*public void train(Neuron n, float learningRate, DataSet dataSet) {
         int epoch = 1;
         int errorCount;
diff --git a/src/main/java/com/naaturel/ANN/infrastructure/persistence/ModelSnapshot.java b/src/main/java/com/naaturel/ANN/infrastructure/persistence/ModelSnapshot.java
index 70b49ec..37800d1 100644
--- a/src/main/java/com/naaturel/ANN/infrastructure/persistence/ModelSnapshot.java
+++ b/src/main/java/com/naaturel/ANN/infrastructure/persistence/ModelSnapshot.java
@@ -15,23 +15,13 @@ import java.util.Map;
 
 public class ModelSnapshot {
 
-    private Model model;
     private final ObjectMapper mapper;
 
     public ModelSnapshot(){
-        this(null);
-    }
-
-    public ModelSnapshot(Model model){
-        this.model = model;
         mapper = new ObjectMapper();
     }
 
-    public Model getModel() {
-        return model;
-    }
-
-    public void saveToFile(String path) throws Exception {
+    public void saveToFile(Model model, String path) throws Exception {
         ArrayNode root = mapper.createArrayNode();
 
         model.forEachNeuron(n -> {
@@ -52,7 +42,7 @@ public class ModelSnapshot {
         mapper.writerWithDefaultPrettyPrinter().writeValue(new File(path), root);
     }
 
-    public void loadFromFile(String path) throws Exception {
+    public Model loadFromFile(String path) throws Exception {
         ArrayNode root = (ArrayNode) mapper.readTree(new File(path));
         Map<Integer, List<Neuron>> neuronsByLayer = new LinkedHashMap<>();
 
@@ -76,6 +66,6 @@ public class ModelSnapshot {
                 .map(neurons -> new Layer(neurons.toArray(new Neuron[0])))
                 .toArray(Layer[]::new);
 
-        this.model = new FullyConnectedNetwork(layers);
+        return new FullyConnectedNetwork(layers);
     }
 }