Optimize prediction
This commit is contained in:
@@ -34,9 +34,9 @@ public class Main {
|
||||
System.out.println(network.synCount());
|
||||
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(0.001F, 2000, network, dataset);
|
||||
trainer.train(0.01F, 2000, network, dataset);
|
||||
|
||||
plotGraph(dataset, network);
|
||||
//plotGraph(dataset, network);
|
||||
}
|
||||
|
||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
||||
@@ -83,8 +83,8 @@ public class Main {
|
||||
float step = 0.03F;
|
||||
for (float x = min; x < max; x+=step){
|
||||
for (float y = min; y < max; y+=step){
|
||||
List<Float> predictions = network.predict(List.of(new Input(x), new Input(y)));
|
||||
visualizer.addPoint(Float.toString(Math.round(predictions.getFirst())), x, y);
|
||||
float[] predictions = network.predict(new float[]{x, y});
|
||||
visualizer.addPoint(Float.toString(Math.round(predictions[0])), x, y);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user