diff --git a/config.json b/config.json new file mode 100644 index 0000000..dd22219 --- /dev/null +++ b/config.json @@ -0,0 +1,14 @@ +{ + "model": { + "new": false, + "parameters": [25, 50, 1], + "path": "C:/Users/Laurent/Desktop/ANN-framework/snapshot.json" + }, + "training" : { + "learning_rate" : 0.01, + "max_epoch" : 2000 + }, + "dataset" : { + "path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv" + } +} \ No newline at end of file diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 1bf8d8d..45b13be 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -8,6 +8,8 @@ import com.naaturel.ANN.implementation.gradientDescent.Linear; import com.naaturel.ANN.implementation.multiLayers.Sigmoid; import com.naaturel.ANN.implementation.multiLayers.TanH; import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; +import com.naaturel.ANN.infrastructure.config.ConfigDto; +import com.naaturel.ANN.infrastructure.config.ConfigLoader; import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; @@ -22,27 +24,37 @@ public class Main { public static void main(String[] args) throws Exception { + ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json"); + + boolean newModel = config.getModelProperty("new", Boolean.class); + int[] modelParameters = config.getModelProperty("parameters", int[].class); + String modelPath = config.getModelProperty("path", String.class); + int maxEpoch = config.getTrainingProperty("max_epoch", Integer.class); + float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue(); + String datasetPath = config.getDatasetProperty("path", String.class); + int nbrClass = 1; - - DataSet dataset = new DatasetExtractor() - .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv", nbrClass); - - int[] neuronPerLayer = new int[]{2, 3, dataset.getNbrLabels()}; + DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass); int nbrInput = dataset.getNbrInputs(); - FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput); - System.out.println(network.synCount()); + ModelSnapshot snapshot; - Trainer trainer = new GradientBackpropagationTraining(); - trainer.train(0.01F, 2000, network, dataset); + Model network; + if(newModel){ + network = createNetwork(modelParameters, nbrInput); + snapshot = new ModelSnapshot(network); + System.out.println("Parameters: " + network.synCount()); + Trainer trainer = new GradientBackpropagationTraining(); + trainer.train(learningRate, maxEpoch, network, dataset); + } else { + snapshot = new ModelSnapshot(); + snapshot.loadFromFile(modelPath); + network = snapshot.getModel(); + } - //plotGraph(dataset, network); - ModelSnapshot snapshot = new ModelSnapshot(network); - snapshot.saveToFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json"); - snapshot.loadFromFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json"); - Model m = snapshot.getModel(); - System.out.println(); + plotGraph(dataset, network); + snapshot.saveToFile(modelPath); } private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){ diff --git a/src/main/java/com/naaturel/ANN/infrastructure/config/ConfigDto.java b/src/main/java/com/naaturel/ANN/infrastructure/config/ConfigDto.java new file mode 100644 index 0000000..a5aa313 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/infrastructure/config/ConfigDto.java @@ -0,0 +1,54 @@ +package com.naaturel.ANN.infrastructure.config; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Map; + +public class ConfigDto { + + @JsonProperty("model") + private Map modelConfig; + + @JsonProperty("training") + private Map trainingConfig; + + @JsonProperty("dataset") + private Map datasetConfig; + + public T getModelProperty(String key, Class type) { + Object value = find(key, this.modelConfig); + if (value instanceof List list && type.isArray()) { + int[] arr = new int[list.size()]; + for (int i = 0; i < list.size(); i++) { + arr[i] = ((Number) list.get(i)).intValue(); + } + return type.cast(arr); + } + if (!type.isInstance(value)) { + throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName()); + } + return type.cast(value); + } + + public T getTrainingProperty(String key, Class type) { + Object value = find(key, this.trainingConfig); + if (!type.isInstance(value)) { + throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName()); + } + return type.cast(value); + } + + public T getDatasetProperty(String key, Class type) { + Object value = find(key, this.datasetConfig); + if (!type.isInstance(value)) { + throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName()); + } + return type.cast(value); + } + + private Object find(String key, Map config){ + if(!config.containsKey(key)) throw new RuntimeException("Unable to find property for key '" + key + "'"); + return config.get(key); + } +} diff --git a/src/main/java/com/naaturel/ANN/infrastructure/config/ConfigLoader.java b/src/main/java/com/naaturel/ANN/infrastructure/config/ConfigLoader.java new file mode 100644 index 0000000..0fb1e6c --- /dev/null +++ b/src/main/java/com/naaturel/ANN/infrastructure/config/ConfigLoader.java @@ -0,0 +1,22 @@ +package com.naaturel.ANN.infrastructure.config; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.io.File; + +public class ConfigLoader { + + + public static ConfigDto load(String path) throws Exception { + try { + + ObjectMapper mapper = new ObjectMapper(); + ConfigDto config = mapper.readValue(new File("config.json"), ConfigDto.class); + + return config; + } catch (Exception e){ + throw new Exception("Unable to load config : " + e.getMessage()); + } + } + +}