Tune model parameters
This commit is contained in:
@@ -33,11 +33,10 @@ public class Main {
|
||||
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
|
||||
String datasetPath = config.getDatasetProperty("path", String.class);
|
||||
|
||||
int nbrClass = 1;
|
||||
int nbrClass = 5;
|
||||
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
||||
int nbrInput = dataset.getNbrInputs();
|
||||
|
||||
|
||||
ModelSnapshot snapshot;
|
||||
|
||||
Model network;
|
||||
@@ -47,14 +46,13 @@ public class Main {
|
||||
System.out.println("Parameters: " + network.synCount());
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(learningRate, maxEpoch, network, dataset);
|
||||
snapshot.saveToFile(modelPath);
|
||||
} else {
|
||||
snapshot = new ModelSnapshot();
|
||||
snapshot.loadFromFile(modelPath);
|
||||
network = snapshot.getModel();
|
||||
}
|
||||
|
||||
plotGraph(dataset, network);
|
||||
snapshot.saveToFile(modelPath);
|
||||
//plotGraph(dataset, network);
|
||||
}
|
||||
|
||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
||||
@@ -96,8 +94,8 @@ public class Main {
|
||||
});
|
||||
}
|
||||
|
||||
float min = -0F;
|
||||
float max = 10F;
|
||||
float min = -5F;
|
||||
float max = 5F;
|
||||
float step = 0.03F;
|
||||
for (float x = min; x < max; x+=step){
|
||||
for (float y = min; y < max; y+=step){
|
||||
|
||||
Reference in New Issue
Block a user