Minor changes
This commit is contained in:
@@ -23,7 +23,7 @@ public class Main {
|
|||||||
DataSet dataset = new DatasetExtractor()
|
DataSet dataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_14.csv", nbrClass);
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_14.csv", nbrClass);
|
||||||
|
|
||||||
int[] neuronPerLayer = new int[]{3, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 37, dataset.getNbrLabels()};
|
int[] neuronPerLayer = new int[]{27, dataset.getNbrLabels()};
|
||||||
int nbrInput = dataset.getNbrInputs();
|
int nbrInput = dataset.getNbrInputs();
|
||||||
|
|
||||||
List<Layer> layers = new ArrayList<>();
|
List<Layer> layers = new ArrayList<>();
|
||||||
@@ -51,7 +51,7 @@ public class Main {
|
|||||||
FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
||||||
|
|
||||||
Trainer trainer = new GradientBackpropagationTraining();
|
Trainer trainer = new GradientBackpropagationTraining();
|
||||||
trainer.train(0.01F, 5000, network, dataset);
|
trainer.train(0.0001F, 15000, network, dataset);
|
||||||
|
|
||||||
GraphVisualizer visualizer = new GraphVisualizer();
|
GraphVisualizer visualizer = new GraphVisualizer();
|
||||||
|
|
||||||
@@ -60,9 +60,9 @@ public class Main {
|
|||||||
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
float min = -3F;
|
float min = -5F;
|
||||||
float max = 2F;
|
float max = 5F;
|
||||||
float step = 0.01F;
|
float step = 0.025F;
|
||||||
for (float x = min; x < max; x+=step){
|
for (float x = min; x < max; x+=step){
|
||||||
for (float y = min; y < max; y+=step){
|
for (float y = min; y < max; y+=step){
|
||||||
float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst();
|
float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst();
|
||||||
@@ -71,8 +71,7 @@ public class Main {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
visualizer.buildScatterGraph((int)min-1, (int)max+1);
|
||||||
visualizer.buildScatterGraph();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ public class GradientBackpropagationTraining implements Trainer {
|
|||||||
ctx.globalLoss = 0.0F;
|
ctx.globalLoss = 0.0F;
|
||||||
})
|
})
|
||||||
.afterEpoch(ctx -> ctx.globalLoss /= dataset.size())
|
.afterEpoch(ctx -> ctx.globalLoss /= dataset.size())
|
||||||
.withVerbose(false, epoch/10)
|
.withVerbose(true, epoch/10)
|
||||||
.withTimeMeasurement(true)
|
.withTimeMeasurement(true)
|
||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,13 +42,13 @@ public class GraphVisualizer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public void buildScatterGraph(){
|
public void buildScatterGraph(int lower, int upper){
|
||||||
JFreeChart chart = ChartFactory.createScatterPlot(
|
JFreeChart chart = ChartFactory.createScatterPlot(
|
||||||
"Predictions", "X", "Y", dataset
|
"Predictions", "X", "Y", dataset
|
||||||
);
|
);
|
||||||
XYPlot plot = chart.getXYPlot();
|
XYPlot plot = chart.getXYPlot();
|
||||||
plot.getDomainAxis().setRange(-2, 2);
|
plot.getDomainAxis().setRange(lower, upper);
|
||||||
plot.getRangeAxis().setRange(-2, 2);
|
plot.getRangeAxis().setRange(lower, upper);
|
||||||
|
|
||||||
JFrame frame = new JFrame("Predictions");
|
JFrame frame = new JFrame("Predictions");
|
||||||
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
|
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
|
||||||
|
|||||||
Reference in New Issue
Block a user