Add JSON config loading
This commit is contained in:
14
config.json
Normal file
14
config.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"model": {
|
||||
"new": false,
|
||||
"parameters": [25, 50, 1],
|
||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/snapshot.json"
|
||||
},
|
||||
"training" : {
|
||||
"learning_rate" : 0.01,
|
||||
"max_epoch" : 2000
|
||||
},
|
||||
"dataset" : {
|
||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv"
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,8 @@ import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
|
||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||
import com.naaturel.ANN.infrastructure.config.ConfigDto;
|
||||
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
@@ -22,27 +24,37 @@ public class Main {
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
|
||||
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
|
||||
|
||||
boolean newModel = config.getModelProperty("new", Boolean.class);
|
||||
int[] modelParameters = config.getModelProperty("parameters", int[].class);
|
||||
String modelPath = config.getModelProperty("path", String.class);
|
||||
int maxEpoch = config.getTrainingProperty("max_epoch", Integer.class);
|
||||
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
|
||||
String datasetPath = config.getDatasetProperty("path", String.class);
|
||||
|
||||
int nbrClass = 1;
|
||||
|
||||
DataSet dataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv", nbrClass);
|
||||
|
||||
int[] neuronPerLayer = new int[]{2, 3, dataset.getNbrLabels()};
|
||||
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
||||
int nbrInput = dataset.getNbrInputs();
|
||||
|
||||
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
|
||||
|
||||
System.out.println(network.synCount());
|
||||
ModelSnapshot snapshot;
|
||||
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(0.01F, 2000, network, dataset);
|
||||
Model network;
|
||||
if(newModel){
|
||||
network = createNetwork(modelParameters, nbrInput);
|
||||
snapshot = new ModelSnapshot(network);
|
||||
System.out.println("Parameters: " + network.synCount());
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(learningRate, maxEpoch, network, dataset);
|
||||
} else {
|
||||
snapshot = new ModelSnapshot();
|
||||
snapshot.loadFromFile(modelPath);
|
||||
network = snapshot.getModel();
|
||||
}
|
||||
|
||||
//plotGraph(dataset, network);
|
||||
ModelSnapshot snapshot = new ModelSnapshot(network);
|
||||
snapshot.saveToFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
|
||||
snapshot.loadFromFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
|
||||
Model m = snapshot.getModel();
|
||||
System.out.println();
|
||||
plotGraph(dataset, network);
|
||||
snapshot.saveToFile(modelPath);
|
||||
}
|
||||
|
||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package com.naaturel.ANN.infrastructure.config;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class ConfigDto {
|
||||
|
||||
@JsonProperty("model")
|
||||
private Map<String, Object> modelConfig;
|
||||
|
||||
@JsonProperty("training")
|
||||
private Map<String, Object> trainingConfig;
|
||||
|
||||
@JsonProperty("dataset")
|
||||
private Map<String, Object> datasetConfig;
|
||||
|
||||
public <T> T getModelProperty(String key, Class<T> type) {
|
||||
Object value = find(key, this.modelConfig);
|
||||
if (value instanceof List<?> list && type.isArray()) {
|
||||
int[] arr = new int[list.size()];
|
||||
for (int i = 0; i < list.size(); i++) {
|
||||
arr[i] = ((Number) list.get(i)).intValue();
|
||||
}
|
||||
return type.cast(arr);
|
||||
}
|
||||
if (!type.isInstance(value)) {
|
||||
throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
|
||||
}
|
||||
return type.cast(value);
|
||||
}
|
||||
|
||||
public <T> T getTrainingProperty(String key, Class<T> type) {
|
||||
Object value = find(key, this.trainingConfig);
|
||||
if (!type.isInstance(value)) {
|
||||
throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
|
||||
}
|
||||
return type.cast(value);
|
||||
}
|
||||
|
||||
public <T> T getDatasetProperty(String key, Class<T> type) {
|
||||
Object value = find(key, this.datasetConfig);
|
||||
if (!type.isInstance(value)) {
|
||||
throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
|
||||
}
|
||||
return type.cast(value);
|
||||
}
|
||||
|
||||
private Object find(String key, Map<String, Object> config){
|
||||
if(!config.containsKey(key)) throw new RuntimeException("Unable to find property for key '" + key + "'");
|
||||
return config.get(key);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package com.naaturel.ANN.infrastructure.config;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
public class ConfigLoader {
|
||||
|
||||
|
||||
public static ConfigDto load(String path) throws Exception {
|
||||
try {
|
||||
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
ConfigDto config = mapper.readValue(new File("config.json"), ConfigDto.class);
|
||||
|
||||
return config;
|
||||
} catch (Exception e){
|
||||
throw new Exception("Unable to load config : " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user