Add JSON config loading
This commit is contained in:
@@ -8,6 +8,8 @@ import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
|
||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||
import com.naaturel.ANN.infrastructure.config.ConfigDto;
|
||||
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
@@ -22,27 +24,37 @@ public class Main {
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
|
||||
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
|
||||
|
||||
boolean newModel = config.getModelProperty("new", Boolean.class);
|
||||
int[] modelParameters = config.getModelProperty("parameters", int[].class);
|
||||
String modelPath = config.getModelProperty("path", String.class);
|
||||
int maxEpoch = config.getTrainingProperty("max_epoch", Integer.class);
|
||||
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
|
||||
String datasetPath = config.getDatasetProperty("path", String.class);
|
||||
|
||||
int nbrClass = 1;
|
||||
|
||||
DataSet dataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv", nbrClass);
|
||||
|
||||
int[] neuronPerLayer = new int[]{2, 3, dataset.getNbrLabels()};
|
||||
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
||||
int nbrInput = dataset.getNbrInputs();
|
||||
|
||||
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
|
||||
|
||||
System.out.println(network.synCount());
|
||||
ModelSnapshot snapshot;
|
||||
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(0.01F, 2000, network, dataset);
|
||||
Model network;
|
||||
if(newModel){
|
||||
network = createNetwork(modelParameters, nbrInput);
|
||||
snapshot = new ModelSnapshot(network);
|
||||
System.out.println("Parameters: " + network.synCount());
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(learningRate, maxEpoch, network, dataset);
|
||||
} else {
|
||||
snapshot = new ModelSnapshot();
|
||||
snapshot.loadFromFile(modelPath);
|
||||
network = snapshot.getModel();
|
||||
}
|
||||
|
||||
//plotGraph(dataset, network);
|
||||
ModelSnapshot snapshot = new ModelSnapshot(network);
|
||||
snapshot.saveToFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
|
||||
snapshot.loadFromFile("C:/Users/Laurent/Desktop/MASI4-RNA/snapshot.json");
|
||||
Model m = snapshot.getModel();
|
||||
System.out.println();
|
||||
plotGraph(dataset, network);
|
||||
snapshot.saveToFile(modelPath);
|
||||
}
|
||||
|
||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
||||
|
||||
Reference in New Issue
Block a user