Start to reimplement gradient descent

This commit is contained in:
2026-03-23 23:12:52 +01:00
parent 1da32862f5
commit 2936bf33bf
16 changed files with 157 additions and 89 deletions

View File

@@ -5,6 +5,7 @@ import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.dataset.Label; import com.naaturel.ANN.domain.model.dataset.Label;
import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingContext; import com.naaturel.ANN.domain.model.training.TrainingContext;
@@ -15,49 +16,14 @@ import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.training.SimpleTraining; import com.naaturel.ANN.implementation.training.SimpleTraining;
import com.naaturel.ANN.implementation.training.steps.*; import com.naaturel.ANN.implementation.training.steps.*;
import javax.xml.crypto.Data;
import java.util.*; import java.util.*;
public class Main { public class Main {
public static void main(String[] args){ public static void main(String[] args){
DataSet orDataSet = new DataSet(Map.ofEntries( DataSet dataset = new DatasetExtractor().extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv");
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(0.0F)),
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
));
DataSet andDataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
));
DataSet dataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 9.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(7.0F, 10.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(2.0F, 5.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(2.0F, 7.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(2.0F, 8.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(6.0F, 8.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(6.0F, 9.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 5.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 6.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 8.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 9.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 7.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 8.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 10.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 11.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 6.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 7.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 9.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 10.0F)), new Label(1.0F))
));
List<Synapse> syns = new ArrayList<>(); List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight(0))); syns.add(new Synapse(new Input(0), new Weight(0)));
@@ -70,7 +36,7 @@ public class Main {
Network network = new Network(List.of(layer)); Network network = new Network(List.of(layer));
Trainer trainer = new SimpleTraining(); Trainer trainer = new SimpleTraining();
trainer.train(network, orDataSet); trainer.train(network, dataset);
} }
} }

View File

@@ -2,8 +2,8 @@ package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.training.TrainingContext; import com.naaturel.ANN.domain.model.training.TrainingContext;
public interface CorrectionStrategy { public interface AlgorithmStrategy {
void apply(TrainingContext context); void apply(TrainingContext ctx);
} }

View File

@@ -1,12 +1,14 @@
package com.naaturel.ANN.domain.model.dataset; package com.naaturel.ANN.domain.model.dataset;
import com.naaturel.ANN.domain.model.neuron.Input;
import java.util.*; import java.util.*;
public class DataSet implements Iterable<DataSetEntry>{ public class DataSet implements Iterable<DataSetEntry>{
private Map<DataSetEntry, Label> data; private Map<DataSetEntry, Label> data;
public DataSet(){ public DataSet() {
this(new HashMap<>()); this(new HashMap<>());
} }
@@ -31,15 +33,17 @@ public class DataSet implements Iterable<DataSetEntry>{
float maxAbs = entries.stream() float maxAbs = entries.stream()
.flatMap(e -> e.getData().stream()) .flatMap(e -> e.getData().stream())
.map(Input::getValue)
.map(Math::abs) .map(Math::abs)
.max(Float::compare) .max(Float::compare)
.orElse(1.0F); .orElse(1.0F);
Map<DataSetEntry, Label> normalized = new HashMap<>(); Map<DataSetEntry, Label> normalized = new HashMap<>();
for (DataSetEntry entry : entries) { for (DataSetEntry entry : entries) {
List<Float> normalizedData = new ArrayList<>(); List<Input> normalizedData = new ArrayList<>();
for (float value : entry.getData()) { for (Input input : entry.getData()) {
normalizedData.add(Math.round((value / maxAbs) * 100.0F) / 100.0F); Input normalizedInput = new Input(Math.round((input.getValue() / maxAbs) * 100.0F) / 100.0F);
normalizedData.add(normalizedInput);
} }
normalized.put(new DataSetEntry(normalizedData), this.data.get(entry)); normalized.put(new DataSetEntry(normalizedData), this.data.get(entry));
} }

View File

@@ -1,16 +1,18 @@
package com.naaturel.ANN.domain.model.dataset; package com.naaturel.ANN.domain.model.dataset;
import com.naaturel.ANN.domain.model.neuron.Input;
import java.util.*; import java.util.*;
public class DataSetEntry implements Iterable<Float> { public class DataSetEntry implements Iterable<Input> {
private List<Float> data; private List<Input> data;
public DataSetEntry(List<Float> data){ public DataSetEntry(List<Input> data){
this.data = data; this.data = data;
} }
public List<Float> getData() { public List<Input> getData() {
return new ArrayList<>(data); return new ArrayList<>(data);
} }
@@ -28,7 +30,7 @@ public class DataSetEntry implements Iterable<Float> {
} }
@Override @Override
public Iterator<Float> iterator() { public Iterator<Input> iterator() {
return this.data.iterator(); return this.data.iterator();
} }

View File

@@ -0,0 +1,36 @@
package com.naaturel.ANN.domain.model.dataset;
import com.naaturel.ANN.domain.model.neuron.Input;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Loads a {@link DataSet} from a comma-separated text file.
 *
 * <p>Every non-blank line describes one sample: all fields except the last are parsed as
 * input values, and the last field is parsed as the label.
 */
public class DatasetExtractor {

    /**
     * Reads the file at {@code path} and builds a data set from its lines.
     *
     * @param path location of the comma-separated file to read
     * @return a data set built from the parsed lines
     * @throws RuntimeException if the file cannot be opened or read (the {@link IOException} is kept as cause)
     * @throws NumberFormatException if a field cannot be parsed as a float
     */
    public DataSet extract(String path) {
        Map<DataSetEntry, Label> data = new HashMap<>();

        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            while ((line = reader.readLine()) != null) {
                // A blank line splits into a single empty field, which would make
                // Float.parseFloat("") throw and abort the whole load; skip it instead.
                if (line.isBlank()) {
                    continue;
                }

                String[] parts = line.split(",");
                int labelIndex = parts.length - 1;

                // Every field before the last one is an input value.
                List<Input> inputs = new ArrayList<>();
                for (int i = 0; i < labelIndex; i++) {
                    inputs.add(new Input(Float.parseFloat(parts[i].trim())));
                }

                // The last field is the expected output for this sample.
                float label = Float.parseFloat(parts[labelIndex].trim());
                data.put(new DataSetEntry(inputs), new Label(label));
            }
        } catch (IOException e) {
            throw new RuntimeException("Failed to read dataset from: " + path, e);
        }

        return new DataSet(data);
    }
}

View File

@@ -1,9 +1,21 @@
package com.naaturel.ANN.implementation.correction; package com.naaturel.ANN.implementation.correction;
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.model.training.TrainingContext; import com.naaturel.ANN.domain.model.training.TrainingContext;
public class GradientDescentCorrectionStrategy implements CorrectionStrategy { import java.util.ArrayList;
import java.util.List;
public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
List<Float> correctorTerms;
public GradientDescentCorrectionStrategy(int nbrCorrectors){
    // Build the list locally, filling it with one zero per requested corrector term,
    // then publish it to the field once it is complete.
    List<Float> zeroedTerms = new ArrayList<>();
    for (int remaining = nbrCorrectors; remaining > 0; remaining--) {
        zeroedTerms.add(0F);
    }
    this.correctorTerms = zeroedTerms;
}
@Override @Override
public void apply(TrainingContext context) { public void apply(TrainingContext context) {

View File

@@ -1,9 +1,9 @@
package com.naaturel.ANN.implementation.correction; package com.naaturel.ANN.implementation.correction;
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.model.training.TrainingContext; import com.naaturel.ANN.domain.model.training.TrainingContext;
public class SimpleCorrectionStrategy implements CorrectionStrategy { public class SimpleCorrectionStrategy implements AlgorithmStrategy {
@Override @Override
public void apply(TrainingContext context) { public void apply(TrainingContext context) {

View File

@@ -0,0 +1,11 @@
package com.naaturel.ANN.implementation.loss;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.model.training.TrainingContext;
/**
 * Loss strategy that stores the absolute value of the context's delta as its local loss.
 */
public class SimpleLossStrategy implements AlgorithmStrategy {

    /**
     * Writes {@code |ctx.delta|} into {@code ctx.localLoss}.
     *
     * @param ctx training context whose {@code delta} is read and whose {@code localLoss} is overwritten
     */
    @Override
    public void apply(TrainingContext ctx) {
        final var magnitude = Math.abs(ctx.delta);
        ctx.localLoss = magnitude;
    }
}

View File

@@ -0,0 +1,11 @@
package com.naaturel.ANN.implementation.loss;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.model.training.TrainingContext;
/**
 * Loss strategy that stores half of the squared delta as the context's local loss.
 */
public class SquareLossStrategy implements AlgorithmStrategy {

    /**
     * Writes {@code delta^2 / 2} into {@code ctx.localLoss}.
     *
     * @param ctx training context whose {@code delta} is read and whose {@code localLoss} is overwritten
     */
    @Override
    public void apply(TrainingContext ctx) {
        // The square is narrowed to float before it is halved.
        final float squaredDelta = (float) Math.pow(ctx.delta, 2);
        ctx.localLoss = squaredDelta / 2;
    }
}

View File

@@ -1,13 +1,10 @@
package com.naaturel.ANN.implementation.neuron; package com.naaturel.ANN.implementation.neuron;
import com.naaturel.ANN.domain.abstraction.ActivationFunction; import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
import com.naaturel.ANN.domain.abstraction.Neuron; import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainable;
import com.naaturel.ANN.domain.model.neuron.Bias; import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse; import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import java.util.List; import java.util.List;
import java.util.function.Consumer; import java.util.function.Consumer;

View File

@@ -1,25 +1,50 @@
package com.naaturel.ANN.implementation.training; package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Neuron; import com.naaturel.ANN.domain.abstraction.Trainable;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry; import com.naaturel.ANN.domain.model.training.TrainingContext;
import com.naaturel.ANN.domain.model.neuron.Bias; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.domain.model.neuron.Input; import com.naaturel.ANN.implementation.correction.GradientDescentCorrectionStrategy;
import com.naaturel.ANN.domain.model.neuron.Synapse; import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
import com.naaturel.ANN.domain.model.neuron.Weight; import com.naaturel.ANN.implementation.loss.SquareLossStrategy;
import com.naaturel.ANN.implementation.training.steps.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
/*public class GradientDescentTraining implements Trainer { public class GradientDescentTraining implements Trainer {
public GradientDescentTraining(){ public GradientDescentTraining(){
} }
public void train(Neuron n, float learningRate, DataSet dataSet) { @Override
public void train(Trainable model, DataSet dataset) {
TrainingContext context = new TrainingContext();
context.dataset = dataset;
context.model = model;
context.learningRate = 0.3F;
List<TrainingStep> steps = List.of(
new PredictionStep(),
new DeltaStep(),
new LossStep(new SquareLossStrategy()),
new SimpleErrorDetectionStep(),
new WeightCorrectionStep(new GradientDescentCorrectionStrategy(2))
);
TrainingPipeline pipeline = new TrainingPipeline(steps);
pipeline
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
.beforeEpoch(ctx -> ctx.globalLoss = 0)
.afterEpoch(ctx -> ())
.withVerbose(true)
.withTimeMeasurement(true)
.run(context);
}
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1; int epoch = 1;
int maxEpoch = 402; int maxEpoch = 402;
float errorThreshold = 0F; float errorThreshold = 0F;
@@ -120,6 +145,6 @@ import java.util.List;
variance /= dataSet.size(); variance /= dataSet.size();
return variance; return variance;
} }*/
}*/ }

View File

@@ -7,6 +7,7 @@ import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.training.TrainingContext; import com.naaturel.ANN.domain.model.training.TrainingContext;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy; import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
import com.naaturel.ANN.implementation.loss.SimpleLossStrategy;
import com.naaturel.ANN.implementation.training.steps.*; import com.naaturel.ANN.implementation.training.steps.*;
import java.util.List; import java.util.List;
@@ -27,7 +28,7 @@ public class SimpleTraining implements Trainer {
List<TrainingStep> steps = List.of( List<TrainingStep> steps = List.of(
new PredictionStep(), new PredictionStep(),
new DeltaStep(), new DeltaStep(),
new SimpleLossStep(), new LossStep(new SimpleLossStrategy()),
new SimpleErrorDetectionStep(), new SimpleErrorDetectionStep(),
new WeightCorrectionStep(new SimpleCorrectionStrategy()) new WeightCorrectionStep(new SimpleCorrectionStrategy())
); );

View File

@@ -0,0 +1,19 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.training.TrainingContext;
/**
 * Training step that delegates to an injected {@link AlgorithmStrategy}.
 */
public class LossStep implements TrainingStep {

    /** Strategy invoked on every run of this step. */
    private final AlgorithmStrategy delegate;

    /**
     * @param strategy strategy to apply each time this step runs
     */
    public LossStep(AlgorithmStrategy strategy) {
        delegate = strategy;
    }

    /**
     * Applies the configured strategy to the given context.
     *
     * @param ctx training context handed to the strategy
     */
    @Override
    public void run(TrainingContext ctx) {
        delegate.apply(ctx);
    }
}

View File

@@ -11,11 +11,7 @@ public class PredictionStep implements TrainingStep {
@Override @Override
public void run(TrainingContext ctx) { public void run(TrainingContext ctx) {
List<Input> inputs = new ArrayList<>(); List<Float> predictions = ctx.model.predict(ctx.currentEntry.getData());
for(Float f : ctx.currentEntry.getData()){
inputs.add(new Input(f));
}
List<Float> predictions = ctx.model.predict(inputs);
ctx.prediction = predictions.getFirst(); ctx.prediction = predictions.getFirst();
} }
} }

View File

@@ -1,12 +0,0 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.training.TrainingContext;
/**
 * Training step that stores the absolute value of the context's delta as its local loss.
 */
public class SimpleLossStep implements TrainingStep {

    /**
     * Writes {@code |ctx.delta|} into {@code ctx.localLoss}.
     *
     * @param ctx training context whose {@code delta} is read and whose {@code localLoss} is overwritten
     */
    @Override
    public void run(TrainingContext ctx) {
        final var magnitude = Math.abs(ctx.delta);
        ctx.localLoss = magnitude;
    }
}

View File

@@ -1,14 +1,14 @@
package com.naaturel.ANN.implementation.training.steps; package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.training.TrainingContext; import com.naaturel.ANN.domain.model.training.TrainingContext;
public class WeightCorrectionStep implements TrainingStep { public class WeightCorrectionStep implements TrainingStep {
private final CorrectionStrategy correctionStrategy; private final AlgorithmStrategy correctionStrategy;
public WeightCorrectionStep(CorrectionStrategy strategy) { public WeightCorrectionStep(AlgorithmStrategy strategy) {
this.correctionStrategy = strategy; this.correctionStrategy = strategy;
} }