Optimize prediction

This commit is contained in:
2026-04-02 09:07:58 +02:00
parent 5ddf6dc580
commit 4c1eaff238
9 changed files with 38 additions and 33 deletions

View File

@@ -34,9 +34,9 @@ public class Main {
System.out.println(network.synCount());
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.001F, 2000, network, dataset);
trainer.train(0.01F, 2000, network, dataset);
plotGraph(dataset, network);
//plotGraph(dataset, network);
}
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
@@ -83,8 +83,8 @@ public class Main {
float step = 0.03F;
for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){
List<Float> predictions = network.predict(List.of(new Input(x), new Input(y)));
visualizer.addPoint(Float.toString(Math.round(predictions.getFirst())), x, y);
float[] predictions = network.predict(new float[]{x, y});
visualizer.addPoint(Float.toString(Math.round(predictions[0])), x, y);
}
}