From 572e5c7484593bd44193959a461648c8bf8e3b7b Mon Sep 17 00:00:00 2001 From: Laurent Date: Thu, 26 Mar 2026 22:35:07 +0100 Subject: [PATCH] Add some plotting --- src/main/java/com/naaturel/ANN/Main.java | 2 +- .../model/training/TrainingPipeline.java | 30 ++++++++++++++++ .../training/AdalineTraining.java | 4 ++- .../infrastructure/graph/GraphVisualizer.java | 35 ++++++++++++++++++- 4 files changed, 68 insertions(+), 3 deletions(-) diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 9ee6434..8f52b09 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -15,7 +15,7 @@ public class Main { public static void main(String[] args){ - int nbrInput = 3; + int nbrInput = 2; int nbrClass = 3; DataSet dataset = new DatasetExtractor() diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java index 3f24142..fc8aa1a 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java @@ -3,10 +3,13 @@ package com.naaturel.ANN.domain.model.training; import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; +import com.naaturel.ANN.domain.model.neuron.Input; +import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.function.Predicate; @@ -17,7 +20,9 @@ public class TrainingPipeline { private Consumer afterEpoch; private Predicate stopCondition; + private GraphVisualizer visualizer; private boolean verbose; + private boolean visualization; private boolean timeMeasurement; public TrainingPipeline(List steps) { @@ -47,6 +52,12 @@ public class TrainingPipeline { return this; } + public TrainingPipeline withVisualization(boolean enabled, GraphVisualizer visualizer) { + this.visualization = enabled; + this.visualizer = visualizer; + return this; + } + public TrainingPipeline withTimeMeasurement(boolean enabled) { this.timeMeasurement = enabled; return this; @@ -70,6 +81,25 @@ public class TrainingPipeline { System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0); } + AtomicInteger neuronIndex = new AtomicInteger(0); + ctx.model.forEachNeuron(n -> { + List weights = new ArrayList<>(); + n.forEachSynapse(syn -> weights.add(syn.getWeight())); + + float b = weights.get(0); + float w1 = weights.get(1); + float w2 = weights.get(2); + + this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3); + }); + int i = 0; + for(DataSetEntry entry : ctx.dataset){ + List inputs = entry.getData(); + this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue()); + this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F); + i++; + } + this.visualizer.build(); } private void executeSteps(TrainingContext ctx){ diff --git a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java index 8c4553f..e1dd0b8 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java @@ -12,6 +12,7 @@ import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; import com.naaturel.ANN.implementation.training.steps.*; +import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; import java.util.List; @@ -38,11 +39,12 @@ public class AdalineTraining implements Trainer { ); new TrainingPipeline(steps) - .stopCondition(ctx -> ctx.globalLoss <= 0.04F || ctx.epoch > 1000) + .stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > 25) .beforeEpoch(ctx -> ctx.globalLoss = 0.0F) .afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size()) .withTimeMeasurement(true) .withVerbose(true) + .withVisualization(true, new GraphVisualizer()) .run(context); } diff --git a/src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java b/src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java index 85f1627..67d2feb 100644 --- a/src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java +++ b/src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java @@ -1,9 +1,42 @@ package com.naaturel.ANN.infrastructure.graph; +import org.jfree.chart.ChartFactory; +import org.jfree.chart.ChartPanel; +import org.jfree.chart.JFreeChart; +import org.jfree.data.xy.XYSeries; +import org.jfree.data.xy.XYSeriesCollection; + +import javax.swing.*; + public class GraphVisualizer { - public GraphVisualizer(){ + XYSeriesCollection dataset; + public GraphVisualizer(){ + this.dataset = new XYSeriesCollection(); } + public void addPoint(String title, float x, float y) { + if (this.dataset.getSeriesIndex(title) == -1) + this.dataset.addSeries(new XYSeries(title)); + this.dataset.getSeries(title).add(x, y); + } + + public void addEquation(String title, float y1, float y2, float k, float xMin, float xMax) { + for (float x1 = xMin; x1 <= xMax; x1 += 0.01f) { + float x2 = (-y1 * x1 - k) / y2; + addPoint(title, x1, x2); + } + } + + public void build(){ + JFreeChart chart = ChartFactory.createXYLineChart( + "Model learning", "X", "Y", dataset + ); + JFrame frame = new JFrame("Training Loss"); + frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); + frame.add(new ChartPanel(chart)); + frame.pack(); + frame.setVisible(true); + } }