Minor changes

This commit is contained in:
2026-03-30 23:08:37 +02:00
parent 881088df28
commit 165a2bc977
3 changed files with 10 additions and 11 deletions

View File

@@ -23,7 +23,7 @@ public class Main {
DataSet dataset = new DatasetExtractor() DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_14.csv", nbrClass); .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_14.csv", nbrClass);
int[] neuronPerLayer = new int[]{3, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 37, dataset.getNbrLabels()}; int[] neuronPerLayer = new int[]{27, dataset.getNbrLabels()};
int nbrInput = dataset.getNbrInputs(); int nbrInput = dataset.getNbrInputs();
List<Layer> layers = new ArrayList<>(); List<Layer> layers = new ArrayList<>();
@@ -51,7 +51,7 @@ public class Main {
FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0])); FullyConnectedNetwork network = new FullyConnectedNetwork(layers.toArray(new Layer[0]));
Trainer trainer = new GradientBackpropagationTraining(); Trainer trainer = new GradientBackpropagationTraining();
trainer.train(0.01F, 5000, network, dataset); trainer.train(0.0001F, 15000, network, dataset);
GraphVisualizer visualizer = new GraphVisualizer(); GraphVisualizer visualizer = new GraphVisualizer();
@@ -60,9 +60,9 @@ public class Main {
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
} }
float min = -3F; float min = -5F;
float max = 2F; float max = 5F;
float step = 0.01F; float step = 0.025F;
for (float x = min; x < max; x+=step){ for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){ for (float y = min; y < max; y+=step){
float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst(); float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst();
@@ -71,8 +71,7 @@ public class Main {
} }
} }
visualizer.buildScatterGraph((int)min-1, (int)max+1);
visualizer.buildScatterGraph();
} }
} }

View File

@@ -35,7 +35,7 @@ public class GradientBackpropagationTraining implements Trainer {
ctx.globalLoss = 0.0F; ctx.globalLoss = 0.0F;
}) })
.afterEpoch(ctx -> ctx.globalLoss /= dataset.size()) .afterEpoch(ctx -> ctx.globalLoss /= dataset.size())
.withVerbose(false, epoch/10) .withVerbose(true, epoch/10)
.withTimeMeasurement(true) .withTimeMeasurement(true)
.run(context); .run(context);
} }

View File

@@ -42,13 +42,13 @@ public class GraphVisualizer {
} }
public void buildScatterGraph(){ public void buildScatterGraph(int lower, int upper){
JFreeChart chart = ChartFactory.createScatterPlot( JFreeChart chart = ChartFactory.createScatterPlot(
"Predictions", "X", "Y", dataset "Predictions", "X", "Y", dataset
); );
XYPlot plot = chart.getXYPlot(); XYPlot plot = chart.getXYPlot();
plot.getDomainAxis().setRange(-2, 2); plot.getDomainAxis().setRange(lower, upper);
plot.getRangeAxis().setRange(-2, 2); plot.getRangeAxis().setRange(lower, upper);
JFrame frame = new JFrame("Predictions"); JFrame frame = new JFrame("Predictions");
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);