Add JSON config loading
This commit is contained in:
14
config.json
Normal file
14
config.json
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
{
|
||||||
|
"model": {
|
||||||
|
"new": false,
|
||||||
|
"parameters": [25, 50, 1],
|
||||||
|
"path": "C:/Users/Laurent/Desktop/ANN-framework/snapshot.json"
|
||||||
|
},
|
||||||
|
"training" : {
|
||||||
|
"learning_rate" : 0.01,
|
||||||
|
"max_epoch" : 2000
|
||||||
|
},
|
||||||
|
"dataset" : {
|
||||||
|
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,6 +8,8 @@ import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
|||||||
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
|
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
|
||||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||||
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||||
|
import com.naaturel.ANN.infrastructure.config.ConfigDto;
|
||||||
|
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||||
@@ -22,27 +24,37 @@ public class Main {
|
|||||||
|
|
||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args) throws Exception {
|
||||||
|
|
||||||
|
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
|
||||||
|
|
||||||
|
boolean newModel = config.getModelProperty("new", Boolean.class);
|
||||||
|
int[] modelParameters = config.getModelProperty("parameters", int[].class);
|
||||||
|
String modelPath = config.getModelProperty("path", String.class);
|
||||||
|
int maxEpoch = config.getTrainingProperty("max_epoch", Integer.class);
|
||||||
|
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
|
||||||
|
String datasetPath = config.getDatasetProperty("path", String.class);
|
||||||
|
|
||||||
int nbrClass = 1;
|
int nbrClass = 1;
|
||||||
|
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
||||||
DataSet dataset = new DatasetExtractor()
|
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv", nbrClass);
|
|
||||||
|
|
||||||
int[] neuronPerLayer = new int[]{2, 3, dataset.getNbrLabels()};
|
|
||||||
int nbrInput = dataset.getNbrInputs();
|
int nbrInput = dataset.getNbrInputs();
|
||||||
|
|
||||||
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
|
|
||||||
|
|
||||||
System.out.println(network.synCount());
|
ModelSnapshot snapshot;
|
||||||
|
|
||||||
Trainer trainer = new GradientBackpropagationTraining();
|
Model network;
|
||||||
trainer.train(0.01F, 2000, network, dataset);
|
if(newModel){
|
||||||
|
network = createNetwork(modelParameters, nbrInput);
|
||||||
|
snapshot = new ModelSnapshot(network);
|
||||||
|
System.out.println("Parameters: " + network.synCount());
|
||||||
|
Trainer trainer = new GradientBackpropagationTraining();
|
||||||
|
trainer.train(learningRate, maxEpoch, network, dataset);
|
||||||
|
} else {
|
||||||
|
snapshot = new ModelSnapshot();
|
||||||
|
snapshot.loadFromFile(modelPath);
|
||||||
|
network = snapshot.getModel();
|
||||||
|
}
|
||||||
|
|
||||||
//plotGraph(dataset, network);
|
plotGraph(dataset, network);
|
||||||
ModelSnapshot snapshot = new ModelSnapshot(network);
|
snapshot.saveToFile(modelPath);
|
||||||
snapshot.saveToFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
|
|
||||||
snapshot.loadFromFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
|
|
||||||
Model m = snapshot.getModel();
|
|
||||||
System.out.println();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
||||||
|
|||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package com.naaturel.ANN.infrastructure.config;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class ConfigDto {
|
||||||
|
|
||||||
|
@JsonProperty("model")
|
||||||
|
private Map<String, Object> modelConfig;
|
||||||
|
|
||||||
|
@JsonProperty("training")
|
||||||
|
private Map<String, Object> trainingConfig;
|
||||||
|
|
||||||
|
@JsonProperty("dataset")
|
||||||
|
private Map<String, Object> datasetConfig;
|
||||||
|
|
||||||
|
public <T> T getModelProperty(String key, Class<T> type) {
|
||||||
|
Object value = find(key, this.modelConfig);
|
||||||
|
if (value instanceof List<?> list && type.isArray()) {
|
||||||
|
int[] arr = new int[list.size()];
|
||||||
|
for (int i = 0; i < list.size(); i++) {
|
||||||
|
arr[i] = ((Number) list.get(i)).intValue();
|
||||||
|
}
|
||||||
|
return type.cast(arr);
|
||||||
|
}
|
||||||
|
if (!type.isInstance(value)) {
|
||||||
|
throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
|
||||||
|
}
|
||||||
|
return type.cast(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
public <T> T getTrainingProperty(String key, Class<T> type) {
|
||||||
|
Object value = find(key, this.trainingConfig);
|
||||||
|
if (!type.isInstance(value)) {
|
||||||
|
throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
|
||||||
|
}
|
||||||
|
return type.cast(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
public <T> T getDatasetProperty(String key, Class<T> type) {
|
||||||
|
Object value = find(key, this.datasetConfig);
|
||||||
|
if (!type.isInstance(value)) {
|
||||||
|
throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
|
||||||
|
}
|
||||||
|
return type.cast(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Object find(String key, Map<String, Object> config){
|
||||||
|
if(!config.containsKey(key)) throw new RuntimeException("Unable to find property for key '" + key + "'");
|
||||||
|
return config.get(key);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package com.naaturel.ANN.infrastructure.config;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
|
||||||
|
public class ConfigLoader {
|
||||||
|
|
||||||
|
|
||||||
|
public static ConfigDto load(String path) throws Exception {
|
||||||
|
try {
|
||||||
|
|
||||||
|
ObjectMapper mapper = new ObjectMapper();
|
||||||
|
ConfigDto config = mapper.readValue(new File("config.json"), ConfigDto.class);
|
||||||
|
|
||||||
|
return config;
|
||||||
|
} catch (Exception e){
|
||||||
|
throw new Exception("Unable to load config : " + e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user