Add JSON config loading

This commit is contained in:
2026-04-03 17:58:28 +02:00
parent 42e6d3dde8
commit 40ebca469e
4 changed files with 117 additions and 15 deletions

View File

@@ -8,6 +8,8 @@ import com.naaturel.ANN.implementation.gradientDescent.Linear;
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.infrastructure.config.ConfigDto;
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
@@ -22,27 +24,37 @@ public class Main {
public static void main(String[] args) throws Exception {
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
boolean newModel = config.getModelProperty("new", Boolean.class);
int[] modelParameters = config.getModelProperty("parameters", int[].class);
String modelPath = config.getModelProperty("path", String.class);
int maxEpoch = config.getTrainingProperty("max_epoch", Integer.class);
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
String datasetPath = config.getDatasetProperty("path", String.class);
int nbrClass = 1;
DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv", nbrClass);
int[] neuronPerLayer = new int[]{2, 3, dataset.getNbrLabels()};
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
int nbrInput = dataset.getNbrInputs();
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
System.out.println(network.synCount());
ModelSnapshot snapshot;
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.01F, 2000, network, dataset);
Model network;
if(newModel){
network = createNetwork(modelParameters, nbrInput);
snapshot = new ModelSnapshot(network);
System.out.println("Parameters: " + network.synCount());
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(learningRate, maxEpoch, network, dataset);
} else {
snapshot = new ModelSnapshot();
snapshot.loadFromFile(modelPath);
network = snapshot.getModel();
}
//plotGraph(dataset, network);
ModelSnapshot snapshot = new ModelSnapshot(network);
snapshot.saveToFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
snapshot.loadFromFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
Model m = snapshot.getModel();
System.out.println();
plotGraph(dataset, network);
snapshot.saveToFile(modelPath);
}
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){