Integrate model persistence

This commit is contained in:
2026-04-03 16:13:39 +02:00
parent 5a73337687
commit 87536f5a55
8 changed files with 125 additions and 7 deletions

View File

@@ -13,20 +13,21 @@ import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
import java.io.Console; import java.io.Console;
import java.util.*; import java.util.*;
public class Main { public class Main {
public static void main(String[] args){ public static void main(String[] args) throws Exception {
int nbrClass = 1; int nbrClass = 1;
DataSet dataset = new DatasetExtractor() DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass); .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv", nbrClass);
int[] neuronPerLayer = new int[]{50, 50, 25, dataset.getNbrLabels()}; int[] neuronPerLayer = new int[]{2, 3, dataset.getNbrLabels()};
int nbrInput = dataset.getNbrInputs(); int nbrInput = dataset.getNbrInputs();
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput); FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
@@ -34,9 +35,14 @@ public class Main {
System.out.println(network.synCount()); System.out.println(network.synCount());
Trainer trainer = new GradientBackpropagationTraining(); Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.001F, 2000, network, dataset); trainer.train(0.01F, 2000, network, dataset);
//plotGraph(dataset, network); //plotGraph(dataset, network);
ModelSnapshot snapshot = new ModelSnapshot(network);
snapshot.saveToFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
snapshot.loadFromFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
Model m = snapshot.getModel();
System.out.println();
} }
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){ private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
@@ -78,8 +84,8 @@ public class Main {
}); });
} }
float min = -3F; float min = -0F;
float max = 3F; float max = 10F;
float step = 0.03F; float step = 0.03F;
for (float x = min; x < max; x+=step){ for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){ for (float y = min; y < max; y+=step){

View File

@@ -10,6 +10,7 @@ import java.util.function.Consumer;
public interface Model { public interface Model {
int synCount(); int synCount();
int neuronCount(); int neuronCount();
int layerIndexOf(Neuron n);
int indexInLayerOf(Neuron n); int indexInLayerOf(Neuron n);
void forEachNeuron(Consumer<Neuron> consumer); void forEachNeuron(Consumer<Neuron> consumer);
//void forEachSynapse(Consumer<Synapse> consumer); //void forEachSynapse(Consumer<Synapse> consumer);

View File

@@ -49,7 +49,6 @@ public class FullyConnectedNetwork implements Model {
} }
return res; return res;
} }
@Override @Override
public void forEachNeuron(Consumer<Neuron> consumer) { public void forEachNeuron(Consumer<Neuron> consumer) {
for(Layer l : this.layers){ for(Layer l : this.layers){
@@ -65,9 +64,15 @@ public class FullyConnectedNetwork implements Model {
@Override @Override
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) { public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
if(!this.connectionMap.containsKey(n)) return;
this.connectionMap.get(n).forEach(consumer); this.connectionMap.get(n).forEach(consumer);
} }
@Override
public int layerIndexOf(Neuron n) {
return this.layerIndexByNeuron.get(n);
}
@Override @Override
public int indexInLayerOf(Neuron n) { public int indexInLayerOf(Neuron n) {
int layerIndex = this.layerIndexByNeuron.get(n); int layerIndex = this.layerIndexByNeuron.get(n);

View File

@@ -41,6 +41,11 @@ public class Layer implements Model {
return this.neurons.length; return this.neurons.length;
} }
@Override
public int layerIndexOf(Neuron n) {
return 0;
}
@Override @Override
public int indexInLayerOf(Neuron n) { public int indexInLayerOf(Neuron n) {
return this.neuronIndex.get(n); return this.neuronIndex.get(n);

View File

@@ -72,6 +72,11 @@ public class Neuron implements Model {
return 1; return 1;
} }
@Override
public int layerIndexOf(Neuron n) {
return 0;
}
@Override @Override
public int indexInLayerOf(Neuron n) { public int indexInLayerOf(Neuron n) {
return 0; return 0;

View File

@@ -0,0 +1,11 @@
package com.naaturel.ANN.infrastructure.persistence;
import java.util.List;
public class ModelDto {
private List<NeuronDto> neurons;
}

View File

@@ -0,0 +1,81 @@
package com.naaturel.ANN.infrastructure.persistence;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import java.io.File;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
public class ModelSnapshot {
private Model model;
private final ObjectMapper mapper;
public ModelSnapshot(){
this(null);
}
public ModelSnapshot(Model model){
this.model = model;
mapper = new ObjectMapper();
}
public Model getModel() {
return model;
}
public void saveToFile(String path) throws Exception {
ArrayNode root = mapper.createArrayNode();
model.forEachNeuron(n -> {
ObjectNode neuronNode = mapper.createObjectNode();
neuronNode.put("id", n.getId());
neuronNode.put("layerIndex", model.layerIndexOf(n));
ArrayNode weights = mapper.createArrayNode();
for (int i = 0; i < n.synCount(); i++) {
float weight = n.getWeight(i);
weights.add(weight);
}
neuronNode.set("weights", weights);
root.add(neuronNode);
});
mapper.writerWithDefaultPrettyPrinter().writeValue(new File(path), root);
}
public void loadFromFile(String path) throws Exception {
ArrayNode root = (ArrayNode) mapper.readTree(new File(path));
Map<Integer, List<Neuron>> neuronsByLayer = new LinkedHashMap<>();
root.forEach(neuronNode -> {
int id = neuronNode.get("id").asInt();
int layerIndex = neuronNode.get("layerIndex").asInt();
ArrayNode weightsNode = (ArrayNode) neuronNode.get("weights");
Bias bias = new Bias(new Weight(weightsNode.get(0).floatValue()));
Synapse[] synapses = new Synapse[weightsNode.size() - 1];
for (int i = 0; i < synapses.length; i++) {
synapses[i] = new Synapse(new Input(0), new Weight(weightsNode.get(i + 1).floatValue()));
}
Neuron n = new Neuron(id, synapses, bias, new TanH());
neuronsByLayer.computeIfAbsent(layerIndex, k -> new ArrayList<>()).add(n);
});
Layer[] layers = neuronsByLayer.values().stream()
.map(neurons -> new Layer(neurons.toArray(new Neuron[0])))
.toArray(Layer[]::new);
this.model = new FullyConnectedNetwork(layers);
}
}

View File

@@ -0,0 +1,4 @@
package com.naaturel.ANN.infrastructure.persistence;
public class NeuronDto {
}