Change signature of train method

This commit is contained in:
2026-03-30 18:28:21 +02:00
parent aed78fe9d2
commit ada01d350b
11 changed files with 100 additions and 26 deletions

View File

@@ -6,8 +6,10 @@ import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
import java.util.*;
@@ -15,13 +17,13 @@ public class Main {
public static void main(String[] args){
int nbrInput = 25;
int nbrClass = 4;
int[] neuronPerLayer = new int[]{10, nbrClass};
int nbrClass = 1;
DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_5.csv", nbrClass);
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
int[] neuronPerLayer = new int[]{10, dataset.getNbrLabels()};
int nbrInput = dataset.getNbrInputs();
List<Layer> layers = new ArrayList<>();
for (int i = 0; i < neuronPerLayer.length; i++){
@@ -38,7 +40,7 @@ public class Main {
Bias bias = new Bias(new Weight());
Neuron n = new Neuron(syns, bias, new TanH());
Neuron n = new Neuron(syns, bias, new Sigmoid(2));
neurons.add(n);
}
Layer layer = new Layer(neurons);
@@ -48,7 +50,28 @@ public class Main {
FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(network, dataset);
trainer.train(0.5F, network, dataset);
/*GraphVisualizer visualizer = new GraphVisualizer();
for (DataSetEntry entry : dataset) {
List<Float> label = dataset.getLabelsAsFloat(entry);
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
}
float min = -2F;
float max = 2F;
float step = 0.01F;
for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){
float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst();
float predSeries = prediction > 0.5F ? 1 : 0;
visualizer.addPoint(Float.toString(predSeries), x, y);
}
}
visualizer.buildScatterGraph();*/
}
}
}

View File

@@ -3,5 +3,5 @@ package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
public interface Trainer {
void train(Model model, DataSet dataset);
void train(float learningRate, Model model, DataSet dataset);
}

View File

@@ -53,12 +53,11 @@ public class Neuron implements Model {
}
public float calculateWeightedSum() {
float res = 0;
res += this.bias.getWeight() * this.bias.getInput();
this.weightedSum = 0F;
this.weightedSum += this.bias.getWeight() * this.bias.getInput();
for(Synapse syn : this.synapses){
res += syn.getWeight() * syn.getInput();
this.weightedSum += syn.getWeight() * syn.getInput();
}
this.weightedSum = res;
return this.weightedSum;
}

View File

@@ -124,7 +124,7 @@ public class TrainingPipeline {
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
i++;
}
this.visualizer.build();
this.visualizer.buildLineGraph();
}
}

View File

@@ -23,11 +23,11 @@ public class AdalineTraining implements Trainer {
}
@Override
public void train(Model model, DataSet dataset) {
public void train(float learningRate, Model model, DataSet dataset) {
AdalineTrainingContext context = new AdalineTrainingContext();
context.dataset = dataset;
context.model = model;
context.learningRate = 0.003F;
context.learningRate = learningRate;
List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context),

View File

@@ -13,14 +13,13 @@ import com.naaturel.ANN.infrastructure.dataset.DataSet;
import java.util.List;
public class GradientBackpropagationTraining implements Trainer {
@Override
public void train(Model model, DataSet dataset) {
public void train(float learningRate, Model model, DataSet dataset) {
GradientBackpropagationContext context = new GradientBackpropagationContext();
context.dataset = dataset;
context.model = model;
context.learningRate = 0.1F;
context.learningRate = learningRate;
List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context),
@@ -30,8 +29,10 @@ public class GradientBackpropagationTraining implements Trainer {
);
new TrainingPipeline(steps)
.stopCondition(ctx -> ctx.epoch == 250)
.withVerbose(true)
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
.afterEpoch(ctx -> ctx.globalLoss = ctx.localLoss/dataset.size())
.stopCondition(ctx -> ctx.epoch > 1000000)
.withVerbose(false)
.withTimeMeasurement(true)
.run(context);

View File

@@ -23,11 +23,11 @@ public class GradientDescentTraining implements Trainer {
}
@Override
public void train(Model model, DataSet dataset) {
public void train(float learningRate, Model model, DataSet dataset) {
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
context.dataset = dataset;
context.model = model;
context.learningRate = 0.0008F;
context.learningRate = learningRate;
context.correctorTerms = new ArrayList<>();
List<AlgorithmStep> steps = List.of(

View File

@@ -16,11 +16,11 @@ public class SimpleTraining implements Trainer {
}
@Override
public void train(Model model, DataSet dataset) {
public void train(float learningRate, Model model, DataSet dataset) {
SimpleTrainingContext context = new SimpleTrainingContext();
context.dataset = dataset;
context.model = model;
context.learningRate = 0.3F;
context.learningRate = learningRate;
List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context),

View File

@@ -3,23 +3,52 @@ package com.naaturel.ANN.infrastructure.dataset;
import com.naaturel.ANN.domain.model.neuron.Input;
import java.util.*;
import java.util.stream.Stream;
public class DataSet implements Iterable<DataSetEntry>{
private final Map<DataSetEntry, Labels> data;
private final int nbrInputs;
private final int nbrLabels;
public DataSet() {
this(new LinkedHashMap<>());
this(new LinkedHashMap<>()); //ensure iteration order is the same as insertion order
}
public DataSet(Map<DataSetEntry, Labels> data){
this.data = data;
this.nbrInputs = this.calculateNbrInput();
this.nbrLabels = this.calculateNbrLabel();
}
private int calculateNbrInput(){
//assumes every entry is the same length
Stream<DataSetEntry> keyStream = this.data.keySet().stream();
Optional<DataSetEntry> firstEntry = keyStream.findFirst();
return firstEntry.map(inputs -> inputs.getData().size()).orElse(0);
}
private int calculateNbrLabel(){
//assumes all labels are the same length
Stream<DataSetEntry> keyStream = this.data.keySet().stream();
Optional<DataSetEntry> firstEntry = keyStream.findFirst();
return firstEntry.map(inputs -> this.data.get(inputs).getValues().size()).orElse(0);
}
public int size() {
return data.size();
}
public int getNbrInputs() {
return this.nbrInputs;
}
public int getNbrLabels(){
return this.nbrLabels;
}
public List<DataSetEntry> getData(){
return new ArrayList<>(this.data.keySet());
}

View File

@@ -3,6 +3,7 @@ package com.naaturel.ANN.infrastructure.graph;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.XYPlot;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
@@ -29,7 +30,7 @@ public class GraphVisualizer {
}
}
public void build(){
public void buildLineGraph(){
JFreeChart chart = ChartFactory.createXYLineChart(
"Model learning", "X", "Y", dataset
);
@@ -39,4 +40,21 @@ public class GraphVisualizer {
frame.pack();
frame.setVisible(true);
}
public void buildScatterGraph(){
JFreeChart chart = ChartFactory.createScatterPlot(
"Predictions", "X", "Y", dataset
);
XYPlot plot = chart.getXYPlot();
plot.getDomainAxis().setRange(-2, 2);
plot.getRangeAxis().setRange(-2, 2);
JFrame frame = new JFrame("Predictions");
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
frame.add(new ChartPanel(chart));
frame.pack();
frame.setVisible(true);
}
}

View File

@@ -0,0 +1,4 @@
0,0,0
0,1,1
1,0,1
1,1,0
1 0 0 0
2 0 1 1
3 1 0 1
4 1 1 0