Tune model parameters

This commit is contained in:
2026-04-03 21:19:27 +02:00
parent 40ebca469e
commit 8beb6aa870
6 changed files with 1257 additions and 17 deletions

View File

@@ -33,11 +33,10 @@ public class Main {
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
String datasetPath = config.getDatasetProperty("path", String.class);
int nbrClass = 1;
int nbrClass = 5;
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
int nbrInput = dataset.getNbrInputs();
ModelSnapshot snapshot;
Model network;
@@ -47,14 +46,13 @@ public class Main {
System.out.println("Parameters: " + network.synCount());
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(learningRate, maxEpoch, network, dataset);
snapshot.saveToFile(modelPath);
} else {
snapshot = new ModelSnapshot();
snapshot.loadFromFile(modelPath);
network = snapshot.getModel();
}
plotGraph(dataset, network);
snapshot.saveToFile(modelPath);
//plotGraph(dataset, network);
}
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
@@ -96,8 +94,8 @@ public class Main {
});
}
float min = -0F;
float max = 10F;
float min = -5F;
float max = 5F;
float step = 0.03F;
for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){