Fix restoration of activation function

This commit is contained in:
2026-05-11 19:42:20 +02:00
parent 613bbbcbe2
commit 9f4a76d9e5
11 changed files with 50 additions and 6 deletions

View File

@@ -5,7 +5,7 @@
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
},
"training" : {
"learning_rate" : 0.03,
"learning_rate" : 1.0,
"max_epoch" : 5000
},
"dataset" : {

View File

@@ -101,6 +101,7 @@ public class Main {
//plot predictions
for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){
List<Float> predictions = new ArrayList<>();
for (float p : network.predict(new float[]{x, y})) predictions.add(p);

View File

@@ -3,12 +3,23 @@ package com.naaturel.ANN.domain.model.helpers;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.implementation.simplePerceptron.Heaviside;
import java.util.ArrayList;
import java.util.List;
public class ModelCreator {
public static Model createPerceptron(int nbrInputs){
List<Synapse> syns = new ArrayList<>();
for (int k = 0; k < nbrInputs; k++){
syns.add(new Synapse(new Input(0), new Weight(0)));
}
Bias bias = new Bias(new Weight(0));
return new Neuron(0, syns.toArray(new Synapse[0]), bias, new Heaviside());
}
public static Model createModel(int[] neuronPerLayer, int nbrInput, ActivationFunction func){
int neuronId = 0;
List<Layer> layers = new ArrayList<>();

View File

@@ -23,4 +23,8 @@ public class Linear implements ActivationFunction {
return this.slope;
}
@Override
public String toString() {
return "linear";
}
}

View File

@@ -5,7 +5,7 @@ import com.naaturel.ANN.domain.model.neuron.Neuron;
public class Sigmoid implements ActivationFunction {
private float steepness;
private final float steepness;
public Sigmoid(float steepness) {
this.steepness = steepness;
@@ -20,4 +20,9 @@ public class Sigmoid implements ActivationFunction {
public float derivative(float value) {
return steepness * value * (1 - value);
}
@Override
public String toString() {
return "sigmoid";
}
}

View File

@@ -18,4 +18,9 @@ public class TanH implements ActivationFunction {
public float derivative(float value) {
return 1 - value * value;
}
@Override
public String toString() {
return "tanh";
}
}

View File

@@ -21,4 +21,10 @@ public class Heaviside implements ActivationFunction {
public float derivative(float value) {
throw new UnsupportedOperationException("Heaviside is not differentiable");
}
@Override
public String toString() {
return "heaviside";
}
}

View File

@@ -13,6 +13,6 @@ public class SimpleErrorRegistrationStep implements AlgorithmStep {
@Override
public void run() {
context.globalLoss += context.localLoss;
context.globalLoss += context.localLoss == 0 ? 0 : 1;
}
}

View File

@@ -45,7 +45,7 @@ public class AdalineTraining implements Trainer {
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
.afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
.withTimeMeasurement(true)
.withVerbose(true, 1)
.withVerbose(true, 500)
.withVisualization(true, new GraphVisualizer())
.run(context);
}

View File

@@ -16,7 +16,7 @@ public class SimpleTraining implements Trainer {
private Model model;
public SimpleTraining(int[] neurons, int nbrInputs){
model = ModelCreator.createModel(neurons, nbrInputs, new Heaviside());
model = ModelCreator.createPerceptron(nbrInputs);
}
@Override

View File

@@ -3,9 +3,13 @@ package com.naaturel.ANN.infrastructure.persistence;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.implementation.gradientDescent.Linear;
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import com.naaturel.ANN.implementation.simplePerceptron.Heaviside;
import java.io.File;
import java.util.ArrayList;
@@ -28,6 +32,7 @@ public class ModelSnapshot {
ObjectNode neuronNode = mapper.createObjectNode();
neuronNode.put("id", n.getId());
neuronNode.put("func", n.getActivationFunction().toString());
neuronNode.put("layerIndex", model.layerIndexOf(n));
ArrayNode weights = mapper.createArrayNode();
@@ -50,6 +55,13 @@ public class ModelSnapshot {
root.forEach(neuronNode -> {
int id = neuronNode.get("id").asInt();
int layerIndex = neuronNode.get("layerIndex").asInt();
ActivationFunction func = switch (neuronNode.get("func").asText()){
case "heaviside" -> new Heaviside();
case "linear" -> new Linear(1, 0);
case "sigmoid" -> new Sigmoid(1);
case "tanh" -> new TanH();
default -> throw new IllegalStateException("Unexpected value: " + neuronNode.get("func").asText());
};
ArrayNode weightsNode = (ArrayNode) neuronNode.get("weights");
Bias bias = new Bias(new Weight(weightsNode.get(0).floatValue()));
@@ -58,7 +70,7 @@ public class ModelSnapshot {
synapses[i] = new Synapse(new Input(0), new Weight(weightsNode.get(i + 1).floatValue()));
}
Neuron n = new Neuron(id, synapses, bias, new TanH());
Neuron n = new Neuron(id, synapses, bias, func);
neuronsByLayer.computeIfAbsent(layerIndex, k -> new ArrayList<>()).add(n);
});