diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 20645be..1bf8d8d 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -13,20 +13,21 @@ import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; +import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot; import java.io.Console; import java.util.*; public class Main { - public static void main(String[] args){ + public static void main(String[] args) throws Exception { int nbrClass = 1; DataSet dataset = new DatasetExtractor() - .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass); + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv", nbrClass); - int[] neuronPerLayer = new int[]{50, 50, 25, dataset.getNbrLabels()}; + int[] neuronPerLayer = new int[]{2, 3, dataset.getNbrLabels()}; int nbrInput = dataset.getNbrInputs(); FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput); @@ -34,9 +35,14 @@ public class Main { System.out.println(network.synCount()); Trainer trainer = new GradientBackpropagationTraining(); - trainer.train(0.001F, 2000, network, dataset); + trainer.train(0.01F, 2000, network, dataset); //plotGraph(dataset, network); + ModelSnapshot snapshot = new ModelSnapshot(network); + snapshot.saveToFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json"); + snapshot.loadFromFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json"); + Model m = snapshot.getModel(); + System.out.println(); } private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){ @@ -78,8 +84,8 @@ public class Main { }); } - float min = -3F; - float max = 3F; + float min = -0F; + float max = 10F; float step = 0.03F; for (float x = min; x < max; x+=step){ for (float y = min; y < max; y+=step){ diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java index b1f3d8c..239bfed 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java @@ -10,6 +10,7 @@ import java.util.function.Consumer; public interface Model { int synCount(); int neuronCount(); + int layerIndexOf(Neuron n); int indexInLayerOf(Neuron n); void forEachNeuron(Consumer consumer); //void forEachSynapse(Consumer consumer); diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java index 0e26063..240c3eb 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/FullyConnectedNetwork.java @@ -49,7 +49,6 @@ public class FullyConnectedNetwork implements Model { } return res; } - @Override public void forEachNeuron(Consumer consumer) { for(Layer l : this.layers){ @@ -65,9 +64,15 @@ public class FullyConnectedNetwork implements Model { @Override public void forEachNeuronConnectedTo(Neuron n, Consumer consumer) { + if(!this.connectionMap.containsKey(n)) return; this.connectionMap.get(n).forEach(consumer); } + @Override + public int layerIndexOf(Neuron n) { + return this.layerIndexByNeuron.get(n); + } + @Override public int indexInLayerOf(Neuron n) { int layerIndex = this.layerIndexByNeuron.get(n); diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java index 9969f99..0f7d5b2 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java @@ -41,6 +41,11 @@ public class Layer implements Model { return this.neurons.length; } + @Override + public int layerIndexOf(Neuron n) { + return 0; + } + @Override public int indexInLayerOf(Neuron n) { return this.neuronIndex.get(n); diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java index 65fe2c5..6a7e618 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java @@ -72,6 +72,11 @@ public class Neuron implements Model { return 1; } + @Override + public int layerIndexOf(Neuron n) { + return 0; + } + @Override public int indexInLayerOf(Neuron n) { return 0; diff --git a/src/main/java/com/naaturel/ANN/infrastructure/persistence/ModelDto.java b/src/main/java/com/naaturel/ANN/infrastructure/persistence/ModelDto.java new file mode 100644 index 0000000..b904448 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/infrastructure/persistence/ModelDto.java @@ -0,0 +1,11 @@ +package com.naaturel.ANN.infrastructure.persistence; + +import java.util.List; + +public class ModelDto { + + private List neurons; + + + +} diff --git a/src/main/java/com/naaturel/ANN/infrastructure/persistence/ModelSnapshot.java b/src/main/java/com/naaturel/ANN/infrastructure/persistence/ModelSnapshot.java new file mode 100644 index 0000000..70b49ec --- /dev/null +++ b/src/main/java/com/naaturel/ANN/infrastructure/persistence/ModelSnapshot.java @@ -0,0 +1,81 @@ +package com.naaturel.ANN.infrastructure.persistence; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.naaturel.ANN.domain.abstraction.Model; +import com.naaturel.ANN.domain.model.neuron.*; +import com.naaturel.ANN.implementation.multiLayers.TanH; + +import java.io.File; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +public class ModelSnapshot { + + private Model model; + private final ObjectMapper mapper; + + public ModelSnapshot(){ + this(null); + } + + public ModelSnapshot(Model model){ + this.model = model; + mapper = new ObjectMapper(); + } + + public Model getModel() { + return model; + } + + public void saveToFile(String path) throws Exception { + + ArrayNode root = mapper.createArrayNode(); + model.forEachNeuron(n -> { + + ObjectNode neuronNode = mapper.createObjectNode(); + neuronNode.put("id", n.getId()); + neuronNode.put("layerIndex", model.layerIndexOf(n)); + + ArrayNode weights = mapper.createArrayNode(); + for (int i = 0; i < n.synCount(); i++) { + float weight = n.getWeight(i); + weights.add(weight); + } + neuronNode.set("weights", weights); + root.add(neuronNode); + }); + + mapper.writerWithDefaultPrettyPrinter().writeValue(new File(path), root); + } + + public void loadFromFile(String path) throws Exception { + ArrayNode root = (ArrayNode) mapper.readTree(new File(path)); + + Map> neuronsByLayer = new LinkedHashMap<>(); + + root.forEach(neuronNode -> { + int id = neuronNode.get("id").asInt(); + int layerIndex = neuronNode.get("layerIndex").asInt(); + ArrayNode weightsNode = (ArrayNode) neuronNode.get("weights"); + + Bias bias = new Bias(new Weight(weightsNode.get(0).floatValue())); + Synapse[] synapses = new Synapse[weightsNode.size() - 1]; + for (int i = 0; i < synapses.length; i++) { + synapses[i] = new Synapse(new Input(0), new Weight(weightsNode.get(i + 1).floatValue())); + } + + Neuron n = new Neuron(id, synapses, bias, new TanH()); + neuronsByLayer.computeIfAbsent(layerIndex, k -> new ArrayList<>()).add(n); + }); + + Layer[] layers = neuronsByLayer.values().stream() + .map(neurons -> new Layer(neurons.toArray(new Neuron[0]))) + .toArray(Layer[]::new); + + this.model = new FullyConnectedNetwork(layers); + } +} diff --git a/src/main/java/com/naaturel/ANN/infrastructure/persistence/NeuronDto.java b/src/main/java/com/naaturel/ANN/infrastructure/persistence/NeuronDto.java new file mode 100644 index 0000000..66be19a --- /dev/null +++ b/src/main/java/com/naaturel/ANN/infrastructure/persistence/NeuronDto.java @@ -0,0 +1,4 @@ +package com.naaturel.ANN.infrastructure.persistence; + +public class NeuronDto { +}