Add JSON config loading

This commit is contained in:
2026-04-03 17:58:28 +02:00
parent 42e6d3dde8
commit 40ebca469e
4 changed files with 117 additions and 15 deletions

14
config.json Normal file
View File

@@ -0,0 +1,14 @@
{
"model": {
"new": false,
"parameters": [25, 50, 1],
"path": "C:/Users/Laurent/Desktop/ANN-framework/snapshot.json"
},
"training" : {
"learning_rate" : 0.01,
"max_epoch" : 2000
},
"dataset" : {
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv"
}
}

View File

@@ -8,6 +8,8 @@ import com.naaturel.ANN.implementation.gradientDescent.Linear;
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.infrastructure.config.ConfigDto;
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
@@ -22,27 +24,37 @@ public class Main {
public static void main(String[] args) throws Exception {
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
boolean newModel = config.getModelProperty("new", Boolean.class);
int[] modelParameters = config.getModelProperty("parameters", int[].class);
String modelPath = config.getModelProperty("path", String.class);
int maxEpoch = config.getTrainingProperty("max_epoch", Integer.class);
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
String datasetPath = config.getDatasetProperty("path", String.class);
int nbrClass = 1;
DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv", nbrClass);
int[] neuronPerLayer = new int[]{2, 3, dataset.getNbrLabels()};
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
int nbrInput = dataset.getNbrInputs();
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
System.out.println(network.synCount());
ModelSnapshot snapshot;
Model network;
if(newModel){
network = createNetwork(modelParameters, nbrInput);
snapshot = new ModelSnapshot(network);
System.out.println("Parameters: " + network.synCount());
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.01F, 2000, network, dataset);
trainer.train(learningRate, maxEpoch, network, dataset);
} else {
snapshot = new ModelSnapshot();
snapshot.loadFromFile(modelPath);
network = snapshot.getModel();
}
//plotGraph(dataset, network);
ModelSnapshot snapshot = new ModelSnapshot(network);
snapshot.saveToFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
snapshot.loadFromFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
Model m = snapshot.getModel();
System.out.println();
plotGraph(dataset, network);
snapshot.saveToFile(modelPath);
}
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){

View File

@@ -0,0 +1,54 @@
package com.naaturel.ANN.infrastructure.config;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import java.util.Map;
public class ConfigDto {
@JsonProperty("model")
private Map<String, Object> modelConfig;
@JsonProperty("training")
private Map<String, Object> trainingConfig;
@JsonProperty("dataset")
private Map<String, Object> datasetConfig;
public <T> T getModelProperty(String key, Class<T> type) {
Object value = find(key, this.modelConfig);
if (value instanceof List<?> list && type.isArray()) {
int[] arr = new int[list.size()];
for (int i = 0; i < list.size(); i++) {
arr[i] = ((Number) list.get(i)).intValue();
}
return type.cast(arr);
}
if (!type.isInstance(value)) {
throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
}
return type.cast(value);
}
public <T> T getTrainingProperty(String key, Class<T> type) {
Object value = find(key, this.trainingConfig);
if (!type.isInstance(value)) {
throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
}
return type.cast(value);
}
public <T> T getDatasetProperty(String key, Class<T> type) {
Object value = find(key, this.datasetConfig);
if (!type.isInstance(value)) {
throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
}
return type.cast(value);
}
private Object find(String key, Map<String, Object> config){
if(!config.containsKey(key)) throw new RuntimeException("Unable to find property for key '" + key + "'");
return config.get(key);
}
}

View File

@@ -0,0 +1,22 @@
package com.naaturel.ANN.infrastructure.config;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.File;
public class ConfigLoader {
public static ConfigDto load(String path) throws Exception {
try {
ObjectMapper mapper = new ObjectMapper();
ConfigDto config = mapper.readValue(new File("config.json"), ConfigDto.class);
return config;
} catch (Exception e){
throw new Exception("Unable to load config : " + e.getMessage());
}
}
}