Fix restoration of activation function
This commit is contained in:
@@ -5,7 +5,7 @@
|
|||||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
|
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
|
||||||
},
|
},
|
||||||
"training" : {
|
"training" : {
|
||||||
"learning_rate" : 0.03,
|
"learning_rate" : 1.0,
|
||||||
"max_epoch" : 5000
|
"max_epoch" : 5000
|
||||||
},
|
},
|
||||||
"dataset" : {
|
"dataset" : {
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ public class Main {
|
|||||||
//plot predictions
|
//plot predictions
|
||||||
for (float x = min; x < max; x+=step){
|
for (float x = min; x < max; x+=step){
|
||||||
for (float y = min; y < max; y+=step){
|
for (float y = min; y < max; y+=step){
|
||||||
|
|
||||||
List<Float> predictions = new ArrayList<>();
|
List<Float> predictions = new ArrayList<>();
|
||||||
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
|
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
|
||||||
|
|
||||||
|
|||||||
@@ -3,12 +3,23 @@ package com.naaturel.ANN.domain.model.helpers;
|
|||||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.Heaviside;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class ModelCreator {
|
public class ModelCreator {
|
||||||
|
|
||||||
|
public static Model createPerceptron(int nbrInputs){
|
||||||
|
List<Synapse> syns = new ArrayList<>();
|
||||||
|
for (int k = 0; k < nbrInputs; k++){
|
||||||
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Bias bias = new Bias(new Weight(0));
|
||||||
|
return new Neuron(0, syns.toArray(new Synapse[0]), bias, new Heaviside());
|
||||||
|
}
|
||||||
|
|
||||||
public static Model createModel(int[] neuronPerLayer, int nbrInput, ActivationFunction func){
|
public static Model createModel(int[] neuronPerLayer, int nbrInput, ActivationFunction func){
|
||||||
int neuronId = 0;
|
int neuronId = 0;
|
||||||
List<Layer> layers = new ArrayList<>();
|
List<Layer> layers = new ArrayList<>();
|
||||||
|
|||||||
@@ -23,4 +23,8 @@ public class Linear implements ActivationFunction {
|
|||||||
return this.slope;
|
return this.slope;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "linear";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|||||||
|
|
||||||
public class Sigmoid implements ActivationFunction {
|
public class Sigmoid implements ActivationFunction {
|
||||||
|
|
||||||
private float steepness;
|
private final float steepness;
|
||||||
|
|
||||||
public Sigmoid(float steepness) {
|
public Sigmoid(float steepness) {
|
||||||
this.steepness = steepness;
|
this.steepness = steepness;
|
||||||
@@ -20,4 +20,9 @@ public class Sigmoid implements ActivationFunction {
|
|||||||
public float derivative(float value) {
|
public float derivative(float value) {
|
||||||
return steepness * value * (1 - value);
|
return steepness * value * (1 - value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "sigmoid";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,4 +18,9 @@ public class TanH implements ActivationFunction {
|
|||||||
public float derivative(float value) {
|
public float derivative(float value) {
|
||||||
return 1 - value * value;
|
return 1 - value * value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "tanh";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,4 +21,10 @@ public class Heaviside implements ActivationFunction {
|
|||||||
public float derivative(float value) {
|
public float derivative(float value) {
|
||||||
throw new UnsupportedOperationException("Heaviside is not differentiable");
|
throw new UnsupportedOperationException("Heaviside is not differentiable");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "heaviside";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,6 @@ public class SimpleErrorRegistrationStep implements AlgorithmStep {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
context.globalLoss += context.localLoss;
|
context.globalLoss += context.localLoss == 0 ? 0 : 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ public class AdalineTraining implements Trainer {
|
|||||||
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
|
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
|
||||||
.afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
|
.afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
|
||||||
.withTimeMeasurement(true)
|
.withTimeMeasurement(true)
|
||||||
.withVerbose(true, 1)
|
.withVerbose(true, 500)
|
||||||
.withVisualization(true, new GraphVisualizer())
|
.withVisualization(true, new GraphVisualizer())
|
||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ public class SimpleTraining implements Trainer {
|
|||||||
private Model model;
|
private Model model;
|
||||||
|
|
||||||
public SimpleTraining(int[] neurons, int nbrInputs){
|
public SimpleTraining(int[] neurons, int nbrInputs){
|
||||||
model = ModelCreator.createModel(neurons, nbrInputs, new Heaviside());
|
model = ModelCreator.createPerceptron(nbrInputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -3,9 +3,13 @@ package com.naaturel.ANN.infrastructure.persistence;
|
|||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.fasterxml.jackson.databind.node.ArrayNode;
|
import com.fasterxml.jackson.databind.node.ArrayNode;
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
|
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||||
|
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
|
||||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.Heaviside;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -28,6 +32,7 @@ public class ModelSnapshot {
|
|||||||
|
|
||||||
ObjectNode neuronNode = mapper.createObjectNode();
|
ObjectNode neuronNode = mapper.createObjectNode();
|
||||||
neuronNode.put("id", n.getId());
|
neuronNode.put("id", n.getId());
|
||||||
|
neuronNode.put("func", n.getActivationFunction().toString());
|
||||||
neuronNode.put("layerIndex", model.layerIndexOf(n));
|
neuronNode.put("layerIndex", model.layerIndexOf(n));
|
||||||
|
|
||||||
ArrayNode weights = mapper.createArrayNode();
|
ArrayNode weights = mapper.createArrayNode();
|
||||||
@@ -50,6 +55,13 @@ public class ModelSnapshot {
|
|||||||
root.forEach(neuronNode -> {
|
root.forEach(neuronNode -> {
|
||||||
int id = neuronNode.get("id").asInt();
|
int id = neuronNode.get("id").asInt();
|
||||||
int layerIndex = neuronNode.get("layerIndex").asInt();
|
int layerIndex = neuronNode.get("layerIndex").asInt();
|
||||||
|
ActivationFunction func = switch (neuronNode.get("func").asText()){
|
||||||
|
case "heaviside" -> new Heaviside();
|
||||||
|
case "linear" -> new Linear(1, 0);
|
||||||
|
case "sigmoid" -> new Sigmoid(1);
|
||||||
|
case "tanh" -> new TanH();
|
||||||
|
default -> throw new IllegalStateException("Unexpected value: " + neuronNode.get("func").asText());
|
||||||
|
};
|
||||||
ArrayNode weightsNode = (ArrayNode) neuronNode.get("weights");
|
ArrayNode weightsNode = (ArrayNode) neuronNode.get("weights");
|
||||||
|
|
||||||
Bias bias = new Bias(new Weight(weightsNode.get(0).floatValue()));
|
Bias bias = new Bias(new Weight(weightsNode.get(0).floatValue()));
|
||||||
@@ -58,7 +70,7 @@ public class ModelSnapshot {
|
|||||||
synapses[i] = new Synapse(new Input(0), new Weight(weightsNode.get(i + 1).floatValue()));
|
synapses[i] = new Synapse(new Input(0), new Weight(weightsNode.get(i + 1).floatValue()));
|
||||||
}
|
}
|
||||||
|
|
||||||
Neuron n = new Neuron(id, synapses, bias, new TanH());
|
Neuron n = new Neuron(id, synapses, bias, func);
|
||||||
neuronsByLayer.computeIfAbsent(layerIndex, k -> new ArrayList<>()).add(n);
|
neuronsByLayer.computeIfAbsent(layerIndex, k -> new ArrayList<>()).add(n);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user