Integrate model persistence
This commit is contained in:
@@ -13,20 +13,21 @@ import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
|||||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
||||||
|
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||||
|
|
||||||
import java.io.Console;
|
import java.io.Console;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
|
|
||||||
public static void main(String[] args){
|
public static void main(String[] args) throws Exception {
|
||||||
|
|
||||||
int nbrClass = 1;
|
int nbrClass = 1;
|
||||||
|
|
||||||
DataSet dataset = new DatasetExtractor()
|
DataSet dataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv", nbrClass);
|
||||||
|
|
||||||
int[] neuronPerLayer = new int[]{50, 50, 25, dataset.getNbrLabels()};
|
int[] neuronPerLayer = new int[]{2, 3, dataset.getNbrLabels()};
|
||||||
int nbrInput = dataset.getNbrInputs();
|
int nbrInput = dataset.getNbrInputs();
|
||||||
|
|
||||||
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
|
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
|
||||||
@@ -34,9 +35,14 @@ public class Main {
|
|||||||
System.out.println(network.synCount());
|
System.out.println(network.synCount());
|
||||||
|
|
||||||
Trainer trainer = new GradientBackpropagationTraining();
|
Trainer trainer = new GradientBackpropagationTraining();
|
||||||
trainer.train(0.001F, 2000, network, dataset);
|
trainer.train(0.01F, 2000, network, dataset);
|
||||||
|
|
||||||
//plotGraph(dataset, network);
|
//plotGraph(dataset, network);
|
||||||
|
ModelSnapshot snapshot = new ModelSnapshot(network);
|
||||||
|
snapshot.saveToFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
|
||||||
|
snapshot.loadFromFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
|
||||||
|
Model m = snapshot.getModel();
|
||||||
|
System.out.println();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
||||||
@@ -78,8 +84,8 @@ public class Main {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
float min = -3F;
|
float min = -0F;
|
||||||
float max = 3F;
|
float max = 10F;
|
||||||
float step = 0.03F;
|
float step = 0.03F;
|
||||||
for (float x = min; x < max; x+=step){
|
for (float x = min; x < max; x+=step){
|
||||||
for (float y = min; y < max; y+=step){
|
for (float y = min; y < max; y+=step){
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import java.util.function.Consumer;
|
|||||||
public interface Model {
|
public interface Model {
|
||||||
int synCount();
|
int synCount();
|
||||||
int neuronCount();
|
int neuronCount();
|
||||||
|
int layerIndexOf(Neuron n);
|
||||||
int indexInLayerOf(Neuron n);
|
int indexInLayerOf(Neuron n);
|
||||||
void forEachNeuron(Consumer<Neuron> consumer);
|
void forEachNeuron(Consumer<Neuron> consumer);
|
||||||
//void forEachSynapse(Consumer<Synapse> consumer);
|
//void forEachSynapse(Consumer<Synapse> consumer);
|
||||||
|
|||||||
@@ -49,7 +49,6 @@ public class FullyConnectedNetwork implements Model {
|
|||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void forEachNeuron(Consumer<Neuron> consumer) {
|
public void forEachNeuron(Consumer<Neuron> consumer) {
|
||||||
for(Layer l : this.layers){
|
for(Layer l : this.layers){
|
||||||
@@ -65,9 +64,15 @@ public class FullyConnectedNetwork implements Model {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
|
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
|
||||||
|
if(!this.connectionMap.containsKey(n)) return;
|
||||||
this.connectionMap.get(n).forEach(consumer);
|
this.connectionMap.get(n).forEach(consumer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int layerIndexOf(Neuron n) {
|
||||||
|
return this.layerIndexByNeuron.get(n);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int indexInLayerOf(Neuron n) {
|
public int indexInLayerOf(Neuron n) {
|
||||||
int layerIndex = this.layerIndexByNeuron.get(n);
|
int layerIndex = this.layerIndexByNeuron.get(n);
|
||||||
|
|||||||
@@ -41,6 +41,11 @@ public class Layer implements Model {
|
|||||||
return this.neurons.length;
|
return this.neurons.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int layerIndexOf(Neuron n) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int indexInLayerOf(Neuron n) {
|
public int indexInLayerOf(Neuron n) {
|
||||||
return this.neuronIndex.get(n);
|
return this.neuronIndex.get(n);
|
||||||
|
|||||||
@@ -72,6 +72,11 @@ public class Neuron implements Model {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int layerIndexOf(Neuron n) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int indexInLayerOf(Neuron n) {
|
public int indexInLayerOf(Neuron n) {
|
||||||
return 0;
|
return 0;
|
||||||
|
|||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package com.naaturel.ANN.infrastructure.persistence;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class ModelDto {
|
||||||
|
|
||||||
|
private List<NeuronDto> neurons;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
package com.naaturel.ANN.infrastructure.persistence;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.node.ArrayNode;
|
||||||
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
|
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class ModelSnapshot {
|
||||||
|
|
||||||
|
private Model model;
|
||||||
|
private final ObjectMapper mapper;
|
||||||
|
|
||||||
|
public ModelSnapshot(){
|
||||||
|
this(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ModelSnapshot(Model model){
|
||||||
|
this.model = model;
|
||||||
|
mapper = new ObjectMapper();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Model getModel() {
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void saveToFile(String path) throws Exception {
|
||||||
|
|
||||||
|
ArrayNode root = mapper.createArrayNode();
|
||||||
|
model.forEachNeuron(n -> {
|
||||||
|
|
||||||
|
ObjectNode neuronNode = mapper.createObjectNode();
|
||||||
|
neuronNode.put("id", n.getId());
|
||||||
|
neuronNode.put("layerIndex", model.layerIndexOf(n));
|
||||||
|
|
||||||
|
ArrayNode weights = mapper.createArrayNode();
|
||||||
|
for (int i = 0; i < n.synCount(); i++) {
|
||||||
|
float weight = n.getWeight(i);
|
||||||
|
weights.add(weight);
|
||||||
|
}
|
||||||
|
neuronNode.set("weights", weights);
|
||||||
|
root.add(neuronNode);
|
||||||
|
});
|
||||||
|
|
||||||
|
mapper.writerWithDefaultPrettyPrinter().writeValue(new File(path), root);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void loadFromFile(String path) throws Exception {
|
||||||
|
ArrayNode root = (ArrayNode) mapper.readTree(new File(path));
|
||||||
|
|
||||||
|
Map<Integer, List<Neuron>> neuronsByLayer = new LinkedHashMap<>();
|
||||||
|
|
||||||
|
root.forEach(neuronNode -> {
|
||||||
|
int id = neuronNode.get("id").asInt();
|
||||||
|
int layerIndex = neuronNode.get("layerIndex").asInt();
|
||||||
|
ArrayNode weightsNode = (ArrayNode) neuronNode.get("weights");
|
||||||
|
|
||||||
|
Bias bias = new Bias(new Weight(weightsNode.get(0).floatValue()));
|
||||||
|
Synapse[] synapses = new Synapse[weightsNode.size() - 1];
|
||||||
|
for (int i = 0; i < synapses.length; i++) {
|
||||||
|
synapses[i] = new Synapse(new Input(0), new Weight(weightsNode.get(i + 1).floatValue()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Neuron n = new Neuron(id, synapses, bias, new TanH());
|
||||||
|
neuronsByLayer.computeIfAbsent(layerIndex, k -> new ArrayList<>()).add(n);
|
||||||
|
});
|
||||||
|
|
||||||
|
Layer[] layers = neuronsByLayer.values().stream()
|
||||||
|
.map(neurons -> new Layer(neurons.toArray(new Neuron[0])))
|
||||||
|
.toArray(Layer[]::new);
|
||||||
|
|
||||||
|
this.model = new FullyConnectedNetwork(layers);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
package com.naaturel.ANN.infrastructure.persistence;
|
||||||
|
|
||||||
|
public class NeuronDto {
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user