From 9f4a76d9e5cf89a98136e2fb41caa9c892725b2d Mon Sep 17 00:00:00 2001 From: Laurent Date: Mon, 11 May 2026 19:42:20 +0200 Subject: [PATCH] Fix restoration of activation function --- config.json | 2 +- src/main/java/com/naaturel/ANN/Main.java | 1 + .../ANN/domain/model/helpers/ModelCreator.java | 11 +++++++++++ .../ANN/implementation/gradientDescent/Linear.java | 4 ++++ .../ANN/implementation/multiLayers/Sigmoid.java | 7 ++++++- .../ANN/implementation/multiLayers/TanH.java | 5 +++++ .../implementation/simplePerceptron/Heaviside.java | 6 ++++++ .../SimpleErrorRegistrationStep.java | 2 +- .../implementation/training/AdalineTraining.java | 2 +- .../implementation/training/SimpleTraining.java | 2 +- .../infrastructure/persistence/ModelSnapshot.java | 14 +++++++++++++- 11 files changed, 50 insertions(+), 6 deletions(-) diff --git a/config.json b/config.json index 41c2604..97f4010 100644 --- a/config.json +++ b/config.json @@ -5,7 +5,7 @@ "path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json" }, "training" : { - "learning_rate" : 0.03, + "learning_rate" : 1.0, "max_epoch" : 5000 }, "dataset" : { diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 741ab24..4b73161 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -101,6 +101,7 @@ public class Main { //plot predictions for (float x = min; x < max; x+=step){ for (float y = min; y < max; y+=step){ + List predictions = new ArrayList<>(); for (float p : network.predict(new float[]{x, y})) predictions.add(p); diff --git a/src/main/java/com/naaturel/ANN/domain/model/helpers/ModelCreator.java b/src/main/java/com/naaturel/ANN/domain/model/helpers/ModelCreator.java index 7665ee8..0a43793 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/helpers/ModelCreator.java +++ b/src/main/java/com/naaturel/ANN/domain/model/helpers/ModelCreator.java @@ -3,12 +3,23 @@ package com.naaturel.ANN.domain.model.helpers; import com.naaturel.ANN.domain.abstraction.ActivationFunction; import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.model.neuron.*; +import com.naaturel.ANN.implementation.simplePerceptron.Heaviside; import java.util.ArrayList; import java.util.List; public class ModelCreator { + public static Model createPerceptron(int nbrInputs){ + List syns = new ArrayList<>(); + for (int k = 0; k < nbrInputs; k++){ + syns.add(new Synapse(new Input(0), new Weight(0))); + } + + Bias bias = new Bias(new Weight(0)); + return new Neuron(0, syns.toArray(new Synapse[0]), bias, new Heaviside()); + } + public static Model createModel(int[] neuronPerLayer, int nbrInput, ActivationFunction func){ int neuronId = 0; List layers = new ArrayList<>(); diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/Linear.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/Linear.java index 2caff73..de92ec7 100644 --- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/Linear.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/Linear.java @@ -23,4 +23,8 @@ public class Linear implements ActivationFunction { return this.slope; } + @Override + public String toString() { + return "linear"; + } } diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/Sigmoid.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/Sigmoid.java index 95ae036..29076ec 100644 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/Sigmoid.java +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/Sigmoid.java @@ -5,7 +5,7 @@ import com.naaturel.ANN.domain.model.neuron.Neuron; public class Sigmoid implements ActivationFunction { - private float steepness; + private final float steepness; public Sigmoid(float steepness) { this.steepness = steepness; @@ -20,4 +20,9 @@ public class Sigmoid implements ActivationFunction { public float derivative(float value) { return steepness * value * (1 - value); } + + @Override + public String toString() { + return "sigmoid"; + } } diff --git a/src/main/java/com/naaturel/ANN/implementation/multiLayers/TanH.java b/src/main/java/com/naaturel/ANN/implementation/multiLayers/TanH.java index 59e7b16..b24ad6e 100644 --- a/src/main/java/com/naaturel/ANN/implementation/multiLayers/TanH.java +++ b/src/main/java/com/naaturel/ANN/implementation/multiLayers/TanH.java @@ -18,4 +18,9 @@ public class TanH implements ActivationFunction { public float derivative(float value) { return 1 - value * value; } + + @Override + public String toString() { + return "tanh"; + } } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java index 5f97847..42ec795 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java @@ -21,4 +21,10 @@ public class Heaviside implements ActivationFunction { public float derivative(float value) { throw new UnsupportedOperationException("Heaviside is not differentiable"); } + + + @Override + public String toString() { + return "heaviside"; + } } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStep.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStep.java index 07f5295..02a64ba 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleErrorRegistrationStep.java @@ -13,6 +13,6 @@ public class SimpleErrorRegistrationStep implements AlgorithmStep { @Override public void run() { - context.globalLoss += context.localLoss; + context.globalLoss += context.localLoss == 0 ? 0 : 1; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java index 4b441a1..ee28553 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java @@ -45,7 +45,7 @@ public class AdalineTraining implements Trainer { .beforeEpoch(ctx -> ctx.globalLoss = 0.0F) .afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size()) .withTimeMeasurement(true) - .withVerbose(true, 1) + .withVerbose(true, 500) .withVisualization(true, new GraphVisualizer()) .run(context); } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java index a7c3ee0..1929f6a 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/SimpleTraining.java @@ -16,7 +16,7 @@ public class SimpleTraining implements Trainer { private Model model; public SimpleTraining(int[] neurons, int nbrInputs){ - model = ModelCreator.createModel(neurons, nbrInputs, new Heaviside()); + model = ModelCreator.createPerceptron(nbrInputs); } @Override diff --git a/src/main/java/com/naaturel/ANN/infrastructure/persistence/ModelSnapshot.java b/src/main/java/com/naaturel/ANN/infrastructure/persistence/ModelSnapshot.java index 37800d1..9094516 100644 --- a/src/main/java/com/naaturel/ANN/infrastructure/persistence/ModelSnapshot.java +++ b/src/main/java/com/naaturel/ANN/infrastructure/persistence/ModelSnapshot.java @@ -3,9 +3,13 @@ package com.naaturel.ANN.infrastructure.persistence; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.naaturel.ANN.domain.abstraction.ActivationFunction; import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.model.neuron.*; +import com.naaturel.ANN.implementation.gradientDescent.Linear; +import com.naaturel.ANN.implementation.multiLayers.Sigmoid; import com.naaturel.ANN.implementation.multiLayers.TanH; +import com.naaturel.ANN.implementation.simplePerceptron.Heaviside; import java.io.File; import java.util.ArrayList; @@ -28,6 +32,7 @@ public class ModelSnapshot { ObjectNode neuronNode = mapper.createObjectNode(); neuronNode.put("id", n.getId()); + neuronNode.put("func", n.getActivationFunction().toString()); neuronNode.put("layerIndex", model.layerIndexOf(n)); ArrayNode weights = mapper.createArrayNode(); @@ -50,6 +55,13 @@ public class ModelSnapshot { root.forEach(neuronNode -> { int id = neuronNode.get("id").asInt(); int layerIndex = neuronNode.get("layerIndex").asInt(); + ActivationFunction func = switch (neuronNode.get("func").asText()){ + case "heaviside" -> new Heaviside(); + case "linear" -> new Linear(1, 0); + case "sigmoid" -> new Sigmoid(1); + case "tanh" -> new TanH(); + default -> throw new IllegalStateException("Unexpected value: " + neuronNode.get("func").asText()); + }; ArrayNode weightsNode = (ArrayNode) neuronNode.get("weights"); Bias bias = new Bias(new Weight(weightsNode.get(0).floatValue())); @@ -58,7 +70,7 @@ public class ModelSnapshot { synapses[i] = new Synapse(new Input(0), new Weight(weightsNode.get(i + 1).floatValue())); } - Neuron n = new Neuron(id, synapses, bias, new TanH()); + Neuron n = new Neuron(id, synapses, bias, func); neuronsByLayer.computeIfAbsent(layerIndex, k -> new ArrayList<>()).add(n); });