Add some plotting
This commit is contained in:
@@ -15,7 +15,7 @@ public class Main {
|
||||
|
||||
public static void main(String[] args){
|
||||
|
||||
int nbrInput = 3;
|
||||
int nbrInput = 2;
|
||||
int nbrClass = 3;
|
||||
|
||||
DataSet dataset = new DatasetExtractor()
|
||||
|
||||
@@ -3,10 +3,13 @@ package com.naaturel.ANN.domain.model.training;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
@@ -17,7 +20,9 @@ public class TrainingPipeline {
|
||||
private Consumer<TrainingContext> afterEpoch;
|
||||
private Predicate<TrainingContext> stopCondition;
|
||||
|
||||
private GraphVisualizer visualizer;
|
||||
private boolean verbose;
|
||||
private boolean visualization;
|
||||
private boolean timeMeasurement;
|
||||
|
||||
public TrainingPipeline(List<TrainingStep> steps) {
|
||||
@@ -47,6 +52,12 @@ public class TrainingPipeline {
|
||||
return this;
|
||||
}
|
||||
|
||||
public TrainingPipeline withVisualization(boolean enabled, GraphVisualizer visualizer) {
|
||||
this.visualization = enabled;
|
||||
this.visualizer = visualizer;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TrainingPipeline withTimeMeasurement(boolean enabled) {
|
||||
this.timeMeasurement = enabled;
|
||||
return this;
|
||||
@@ -70,6 +81,25 @@ public class TrainingPipeline {
|
||||
System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0);
|
||||
}
|
||||
|
||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
||||
ctx.model.forEachNeuron(n -> {
|
||||
List<Float> weights = new ArrayList<>();
|
||||
n.forEachSynapse(syn -> weights.add(syn.getWeight()));
|
||||
|
||||
float b = weights.get(0);
|
||||
float w1 = weights.get(1);
|
||||
float w2 = weights.get(2);
|
||||
|
||||
this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3);
|
||||
});
|
||||
int i = 0;
|
||||
for(DataSetEntry entry : ctx.dataset){
|
||||
List<Input> inputs = entry.getData();
|
||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue());
|
||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
|
||||
i++;
|
||||
}
|
||||
this.visualizer.build();
|
||||
}
|
||||
|
||||
private void executeSteps(TrainingContext ctx){
|
||||
|
||||
@@ -12,6 +12,7 @@ import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -38,11 +39,12 @@ public class AdalineTraining implements Trainer {
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.04F || ctx.epoch > 1000)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > 25)
|
||||
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
|
||||
.afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
|
||||
.withTimeMeasurement(true)
|
||||
.withVerbose(true)
|
||||
.withVisualization(true, new GraphVisualizer())
|
||||
.run(context);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,42 @@
|
||||
package com.naaturel.ANN.infrastructure.graph;
|
||||
|
||||
import org.jfree.chart.ChartFactory;
|
||||
import org.jfree.chart.ChartPanel;
|
||||
import org.jfree.chart.JFreeChart;
|
||||
import org.jfree.data.xy.XYSeries;
|
||||
import org.jfree.data.xy.XYSeriesCollection;
|
||||
|
||||
import javax.swing.*;
|
||||
|
||||
public class GraphVisualizer {
|
||||
|
||||
public GraphVisualizer(){
|
||||
XYSeriesCollection dataset;
|
||||
|
||||
public GraphVisualizer(){
|
||||
this.dataset = new XYSeriesCollection();
|
||||
}
|
||||
|
||||
public void addPoint(String title, float x, float y) {
|
||||
if (this.dataset.getSeriesIndex(title) == -1)
|
||||
this.dataset.addSeries(new XYSeries(title));
|
||||
this.dataset.getSeries(title).add(x, y);
|
||||
}
|
||||
|
||||
public void addEquation(String title, float y1, float y2, float k, float xMin, float xMax) {
|
||||
for (float x1 = xMin; x1 <= xMax; x1 += 0.01f) {
|
||||
float x2 = (-y1 * x1 - k) / y2;
|
||||
addPoint(title, x1, x2);
|
||||
}
|
||||
}
|
||||
|
||||
public void build(){
|
||||
JFreeChart chart = ChartFactory.createXYLineChart(
|
||||
"Model learning", "X", "Y", dataset
|
||||
);
|
||||
JFrame frame = new JFrame("Training Loss");
|
||||
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
|
||||
frame.add(new ChartPanel(chart));
|
||||
frame.pack();
|
||||
frame.setVisible(true);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user