Integrate model persistence

This commit is contained in:
2026-04-03 16:13:39 +02:00
parent 5a73337687
commit 87536f5a55
8 changed files with 125 additions and 7 deletions

View File

@@ -13,20 +13,21 @@ import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
import java.io.Console;
import java.util.*;
public class Main {
public static void main(String[] args){
public static void main(String[] args) throws Exception {
int nbrClass = 1;
DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv", nbrClass);
int[] neuronPerLayer = new int[]{50, 50, 25, dataset.getNbrLabels()};
int[] neuronPerLayer = new int[]{2, 3, dataset.getNbrLabels()};
int nbrInput = dataset.getNbrInputs();
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
@@ -34,9 +35,14 @@ public class Main {
System.out.println(network.synCount());
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.001F, 2000, network, dataset);
trainer.train(0.01F, 2000, network, dataset);
//plotGraph(dataset, network);
ModelSnapshot snapshot = new ModelSnapshot(network);
snapshot.saveToFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
snapshot.loadFromFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
Model m = snapshot.getModel();
System.out.println();
}
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
@@ -78,8 +84,8 @@ public class Main {
});
}
float min = -3F;
float max = 3F;
float min = -0F;
float max = 10F;
float step = 0.03F;
for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){