Change signature of train method

This commit is contained in:
2026-03-30 18:28:21 +02:00
parent aed78fe9d2
commit ada01d350b
11 changed files with 100 additions and 26 deletions

View File

@@ -6,8 +6,10 @@ import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
import com.naaturel.ANN.implementation.multiLayers.TanH; import com.naaturel.ANN.implementation.multiLayers.TanH;
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
import java.util.*; import java.util.*;
@@ -15,13 +17,13 @@ public class Main {
public static void main(String[] args){ public static void main(String[] args){
int nbrInput = 25; int nbrClass = 1;
int nbrClass = 4;
int[] neuronPerLayer = new int[]{10, nbrClass};
DataSet dataset = new DatasetExtractor() DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_5.csv", nbrClass); .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
int[] neuronPerLayer = new int[]{10, dataset.getNbrLabels()};
int nbrInput = dataset.getNbrInputs();
List<Layer> layers = new ArrayList<>(); List<Layer> layers = new ArrayList<>();
for (int i = 0; i < neuronPerLayer.length; i++){ for (int i = 0; i < neuronPerLayer.length; i++){
@@ -38,7 +40,7 @@ public class Main {
Bias bias = new Bias(new Weight()); Bias bias = new Bias(new Weight());
Neuron n = new Neuron(syns, bias, new TanH()); Neuron n = new Neuron(syns, bias, new Sigmoid(2));
neurons.add(n); neurons.add(n);
} }
Layer layer = new Layer(neurons); Layer layer = new Layer(neurons);
@@ -48,7 +50,28 @@ public class Main {
FullyConnectedNetwork network = new FullyConnectedNetwork(layers); FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
Trainer trainer = new GradientBackpropagationTraining(); Trainer trainer = new GradientBackpropagationTraining();
trainer.train(network, dataset); trainer.train(0.5F, network, dataset);
/*GraphVisualizer visualizer = new GraphVisualizer();
for (DataSetEntry entry : dataset) {
List<Float> label = dataset.getLabelsAsFloat(entry);
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
}
float min = -2F;
float max = 2F;
float step = 0.01F;
for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){
float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst();
float predSeries = prediction > 0.5F ? 1 : 0;
visualizer.addPoint(Float.toString(predSeries), x, y);
}
}
visualizer.buildScatterGraph();*/
} }
} }

View File

@@ -3,5 +3,5 @@ package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
public interface Trainer { public interface Trainer {
void train(Model model, DataSet dataset); void train(float learningRate, Model model, DataSet dataset);
} }

View File

@@ -53,12 +53,11 @@ public class Neuron implements Model {
} }
public float calculateWeightedSum() { public float calculateWeightedSum() {
float res = 0; this.weightedSum = 0F;
res += this.bias.getWeight() * this.bias.getInput(); this.weightedSum += this.bias.getWeight() * this.bias.getInput();
for(Synapse syn : this.synapses){ for(Synapse syn : this.synapses){
res += syn.getWeight() * syn.getInput(); this.weightedSum += syn.getWeight() * syn.getInput();
} }
this.weightedSum = res;
return this.weightedSum; return this.weightedSum;
} }

View File

@@ -124,7 +124,7 @@ public class TrainingPipeline {
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F); this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
i++; i++;
} }
this.visualizer.build(); this.visualizer.buildLineGraph();
} }
} }

View File

@@ -23,11 +23,11 @@ public class AdalineTraining implements Trainer {
} }
@Override @Override
public void train(Model model, DataSet dataset) { public void train(float learningRate, Model model, DataSet dataset) {
AdalineTrainingContext context = new AdalineTrainingContext(); AdalineTrainingContext context = new AdalineTrainingContext();
context.dataset = dataset; context.dataset = dataset;
context.model = model; context.model = model;
context.learningRate = 0.003F; context.learningRate = learningRate;
List<AlgorithmStep> steps = List.of( List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context), new SimplePredictionStep(context),

View File

@@ -13,14 +13,13 @@ import com.naaturel.ANN.infrastructure.dataset.DataSet;
import java.util.List; import java.util.List;
public class GradientBackpropagationTraining implements Trainer { public class GradientBackpropagationTraining implements Trainer {
@Override @Override
public void train(Model model, DataSet dataset) { public void train(float learningRate, Model model, DataSet dataset) {
GradientBackpropagationContext context = new GradientBackpropagationContext(); GradientBackpropagationContext context = new GradientBackpropagationContext();
context.dataset = dataset; context.dataset = dataset;
context.model = model; context.model = model;
context.learningRate = 0.1F; context.learningRate = learningRate;
List<AlgorithmStep> steps = List.of( List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context), new SimplePredictionStep(context),
@@ -30,8 +29,10 @@ public class GradientBackpropagationTraining implements Trainer {
); );
new TrainingPipeline(steps) new TrainingPipeline(steps)
.stopCondition(ctx -> ctx.epoch == 250) .beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
.withVerbose(true) .afterEpoch(ctx -> ctx.globalLoss = ctx.localLoss/dataset.size())
.stopCondition(ctx -> ctx.epoch > 1000000)
.withVerbose(false)
.withTimeMeasurement(true) .withTimeMeasurement(true)
.run(context); .run(context);

View File

@@ -23,11 +23,11 @@ public class GradientDescentTraining implements Trainer {
} }
@Override @Override
public void train(Model model, DataSet dataset) { public void train(float learningRate, Model model, DataSet dataset) {
GradientDescentTrainingContext context = new GradientDescentTrainingContext(); GradientDescentTrainingContext context = new GradientDescentTrainingContext();
context.dataset = dataset; context.dataset = dataset;
context.model = model; context.model = model;
context.learningRate = 0.0008F; context.learningRate = learningRate;
context.correctorTerms = new ArrayList<>(); context.correctorTerms = new ArrayList<>();
List<AlgorithmStep> steps = List.of( List<AlgorithmStep> steps = List.of(

View File

@@ -16,11 +16,11 @@ public class SimpleTraining implements Trainer {
} }
@Override @Override
public void train(Model model, DataSet dataset) { public void train(float learningRate, Model model, DataSet dataset) {
SimpleTrainingContext context = new SimpleTrainingContext(); SimpleTrainingContext context = new SimpleTrainingContext();
context.dataset = dataset; context.dataset = dataset;
context.model = model; context.model = model;
context.learningRate = 0.3F; context.learningRate = learningRate;
List<AlgorithmStep> steps = List.of( List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context), new SimplePredictionStep(context),

View File

@@ -3,23 +3,52 @@ package com.naaturel.ANN.infrastructure.dataset;
import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.domain.model.neuron.Input;
import java.util.*; import java.util.*;
import java.util.stream.Stream;
public class DataSet implements Iterable<DataSetEntry>{ public class DataSet implements Iterable<DataSetEntry>{
private final Map<DataSetEntry, Labels> data; private final Map<DataSetEntry, Labels> data;
private final int nbrInputs;
private final int nbrLabels;
public DataSet() { public DataSet() {
this(new LinkedHashMap<>()); this(new LinkedHashMap<>()); //ensure iteration order is the same as insertion order
} }
public DataSet(Map<DataSetEntry, Labels> data){ public DataSet(Map<DataSetEntry, Labels> data){
this.data = data; this.data = data;
this.nbrInputs = this.calculateNbrInput();
this.nbrLabels = this.calculateNbrLabel();
} }
private int calculateNbrInput(){
//assumes every entry are the same length
Stream<DataSetEntry> keyStream = this.data.keySet().stream();
Optional<DataSetEntry> firstEntry = keyStream.findFirst();
return firstEntry.map(inputs -> inputs.getData().size()).orElse(0);
}
private int calculateNbrLabel(){
//assumes every label are the same length
Stream<DataSetEntry> keyStream = this.data.keySet().stream();
Optional<DataSetEntry> firstEntry = keyStream.findFirst();
return firstEntry.map(inputs -> this.data.get(inputs).getValues().size()).orElse(0);
}
public int size() { public int size() {
return data.size(); return data.size();
} }
public int getNbrInputs() {
return this.nbrInputs;
}
public int getNbrLabels(){
return this.nbrLabels;
}
public List<DataSetEntry> getData(){ public List<DataSetEntry> getData(){
return new ArrayList<>(this.data.keySet()); return new ArrayList<>(this.data.keySet());
} }

View File

@@ -3,6 +3,7 @@ package com.naaturel.ANN.infrastructure.graph;
import org.jfree.chart.ChartFactory; import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel; import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart; import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.XYPlot;
import org.jfree.data.xy.XYSeries; import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection; import org.jfree.data.xy.XYSeriesCollection;
@@ -29,7 +30,7 @@ public class GraphVisualizer {
} }
} }
public void build(){ public void buildLineGraph(){
JFreeChart chart = ChartFactory.createXYLineChart( JFreeChart chart = ChartFactory.createXYLineChart(
"Model learning", "X", "Y", dataset "Model learning", "X", "Y", dataset
); );
@@ -39,4 +40,21 @@ public class GraphVisualizer {
frame.pack(); frame.pack();
frame.setVisible(true); frame.setVisible(true);
} }
public void buildScatterGraph(){
JFreeChart chart = ChartFactory.createScatterPlot(
"Predictions", "X", "Y", dataset
);
XYPlot plot = chart.getXYPlot();
plot.getDomainAxis().setRange(-2, 2);
plot.getRangeAxis().setRange(-2, 2);
JFrame frame = new JFrame("Predictions");
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
frame.add(new ChartPanel(chart));
frame.pack();
frame.setVisible(true);
}
} }

View File

@@ -0,0 +1,4 @@
0,0,0
0,1,1
1,0,1
1,1,0
1 0 0 0
2 0 1 1
3 1 0 1
4 1 1 0