Add some plotting

This commit is contained in:
2026-03-26 22:35:07 +01:00
parent 64bc830f18
commit 572e5c7484
4 changed files with 68 additions and 3 deletions

View File

@@ -15,7 +15,7 @@ public class Main {
public static void main(String[] args){
int nbrInput = 3;
int nbrInput = 2;
int nbrClass = 3;
DataSet dataset = new DatasetExtractor()

View File

@@ -3,10 +3,13 @@ package com.naaturel.ANN.domain.model.training;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Predicate;
@@ -17,7 +20,9 @@ public class TrainingPipeline {
private Consumer<TrainingContext> afterEpoch;
private Predicate<TrainingContext> stopCondition;
private GraphVisualizer visualizer;
private boolean verbose;
private boolean visualization;
private boolean timeMeasurement;
public TrainingPipeline(List<TrainingStep> steps) {
@@ -47,6 +52,12 @@ public class TrainingPipeline {
return this;
}
public TrainingPipeline withVisualization(boolean enabled, GraphVisualizer visualizer) {
this.visualization = enabled;
this.visualizer = visualizer;
return this;
}
public TrainingPipeline withTimeMeasurement(boolean enabled) {
this.timeMeasurement = enabled;
return this;
@@ -70,6 +81,25 @@ public class TrainingPipeline {
System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0);
}
AtomicInteger neuronIndex = new AtomicInteger(0);
ctx.model.forEachNeuron(n -> {
List<Float> weights = new ArrayList<>();
n.forEachSynapse(syn -> weights.add(syn.getWeight()));
float b = weights.get(0);
float w1 = weights.get(1);
float w2 = weights.get(2);
this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3);
});
int i = 0;
for(DataSetEntry entry : ctx.dataset){
List<Input> inputs = entry.getData();
this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue());
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
i++;
}
this.visualizer.build();
}
private void executeSteps(TrainingContext ctx){

View File

@@ -12,6 +12,7 @@ import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
import com.naaturel.ANN.implementation.training.steps.*;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
import java.util.List;
@@ -38,11 +39,12 @@ public class AdalineTraining implements Trainer {
);
new TrainingPipeline(steps)
.stopCondition(ctx -> ctx.globalLoss <= 0.04F || ctx.epoch > 1000)
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > 25)
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
.afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
.withTimeMeasurement(true)
.withVerbose(true)
.withVisualization(true, new GraphVisualizer())
.run(context);
}

View File

@@ -1,9 +1,42 @@
package com.naaturel.ANN.infrastructure.graph;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import javax.swing.*;
public class GraphVisualizer {
public GraphVisualizer(){
XYSeriesCollection dataset;
public GraphVisualizer(){
this.dataset = new XYSeriesCollection();
}
public void addPoint(String title, float x, float y) {
if (this.dataset.getSeriesIndex(title) == -1)
this.dataset.addSeries(new XYSeries(title));
this.dataset.getSeries(title).add(x, y);
}
public void addEquation(String title, float y1, float y2, float k, float xMin, float xMax) {
for (float x1 = xMin; x1 <= xMax; x1 += 0.01f) {
float x2 = (-y1 * x1 - k) / y2;
addPoint(title, x1, x2);
}
}
public void build(){
JFreeChart chart = ChartFactory.createXYLineChart(
"Model learning", "X", "Y", dataset
);
JFrame frame = new JFrame("Training Loss");
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
frame.add(new ChartPanel(chart));
frame.pack();
frame.setVisible(true);
}
}