Start to reimplement gradient descent

This commit is contained in:
2026-03-23 23:12:52 +01:00
parent 1da32862f5
commit 2936bf33bf
16 changed files with 157 additions and 89 deletions

View File

@@ -5,6 +5,7 @@ import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.dataset.Label;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingContext;
@@ -15,49 +16,14 @@ import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.training.SimpleTraining;
import com.naaturel.ANN.implementation.training.steps.*;
import javax.xml.crypto.Data;
import java.util.*;
public class Main {
public static void main(String[] args){
DataSet orDataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(0.0F)),
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
));
DataSet andDataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
));
DataSet dataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 9.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(7.0F, 10.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(2.0F, 5.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(2.0F, 7.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(2.0F, 8.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(6.0F, 8.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(6.0F, 9.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 5.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 6.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 8.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 9.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 7.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 8.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 10.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 11.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 6.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 7.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 9.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 10.0F)), new Label(1.0F))
));
DataSet dataset = new DatasetExtractor().extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv");
List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight(0)));
@@ -70,7 +36,7 @@ public class Main {
Network network = new Network(List.of(layer));
Trainer trainer = new SimpleTraining();
trainer.train(network, orDataSet);
trainer.train(network, dataset);
}
}

View File

@@ -2,8 +2,8 @@ package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public interface CorrectionStrategy {
public interface AlgorithmStrategy {
void apply(TrainingContext context);
void apply(TrainingContext ctx);
}

View File

@@ -1,5 +1,7 @@
package com.naaturel.ANN.domain.model.dataset;
import com.naaturel.ANN.domain.model.neuron.Input;
import java.util.*;
public class DataSet implements Iterable<DataSetEntry>{
@@ -31,15 +33,17 @@ public class DataSet implements Iterable<DataSetEntry>{
float maxAbs = entries.stream()
.flatMap(e -> e.getData().stream())
.map(Input::getValue)
.map(Math::abs)
.max(Float::compare)
.orElse(1.0F);
Map<DataSetEntry, Label> normalized = new HashMap<>();
for (DataSetEntry entry : entries) {
List<Float> normalizedData = new ArrayList<>();
for (float value : entry.getData()) {
normalizedData.add(Math.round((value / maxAbs) * 100.0F) / 100.0F);
List<Input> normalizedData = new ArrayList<>();
for (Input input : entry.getData()) {
Input normalizedInput = new Input(Math.round((input.getValue() / maxAbs) * 100.0F) / 100.0F);
normalizedData.add(normalizedInput);
}
normalized.put(new DataSetEntry(normalizedData), this.data.get(entry));
}

View File

@@ -1,16 +1,18 @@
package com.naaturel.ANN.domain.model.dataset;
import com.naaturel.ANN.domain.model.neuron.Input;
import java.util.*;
public class DataSetEntry implements Iterable<Float> {
public class DataSetEntry implements Iterable<Input> {
private List<Float> data;
private List<Input> data;
public DataSetEntry(List<Float> data){
public DataSetEntry(List<Input> data){
this.data = data;
}
public List<Float> getData() {
public List<Input> getData() {
return new ArrayList<>(data);
}
@@ -28,7 +30,7 @@ public class DataSetEntry implements Iterable<Float> {
}
@Override
public Iterator<Float> iterator() {
public Iterator<Input> iterator() {
return this.data.iterator();
}

View File

@@ -0,0 +1,36 @@
package com.naaturel.ANN.domain.model.dataset;
import com.naaturel.ANN.domain.model.neuron.Input;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class DatasetExtractor {

    /**
     * Reads a CSV file where every row holds the input features followed by the
     * label in the last column, and builds a {@link DataSet} from it.
     *
     * <p>NOTE(review): duplicate feature rows overwrite each other because the
     * backing map is keyed by {@code DataSetEntry} — confirm this is intended.
     *
     * @param path filesystem path of the CSV file to read
     * @return a DataSet mapping each entry (all columns but the last) to its label
     * @throws RuntimeException if the file cannot be read; a malformed numeric
     *         cell propagates as {@link NumberFormatException}
     */
    public DataSet extract(String path) {
        Map<DataSetEntry, Label> data = new HashMap<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            while ((line = reader.readLine()) != null) {
                // Skip blank lines so a trailing newline or empty row does not
                // reach Float.parseFloat("") and crash the import.
                if (line.isBlank()) {
                    continue;
                }
                String[] parts = line.split(",");
                List<Input> inputs = new ArrayList<>();
                // All columns except the last are features.
                for (int i = 0; i < parts.length - 1; i++) {
                    inputs.add(new Input(Float.parseFloat(parts[i].trim())));
                }
                // Last column is the label.
                float label = Float.parseFloat(parts[parts.length - 1].trim());
                data.put(new DataSetEntry(inputs), new Label(label));
            }
        } catch (IOException e) {
            throw new RuntimeException("Failed to read dataset from: " + path, e);
        }
        return new DataSet(data);
    }
}

View File

@@ -1,9 +1,21 @@
package com.naaturel.ANN.implementation.correction;
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public class GradientDescentCorrectionStrategy implements CorrectionStrategy {
import java.util.ArrayList;
import java.util.List;
public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
List<Float> correctorTerms;
public GradientDescentCorrectionStrategy(int nbrCorrectors){
this.correctorTerms = new ArrayList<>();
for (int i = 0; i < nbrCorrectors; i++){
correctorTerms.add(0F);
}
}
@Override
public void apply(TrainingContext context) {

View File

@@ -1,9 +1,9 @@
package com.naaturel.ANN.implementation.correction;
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public class SimpleCorrectionStrategy implements CorrectionStrategy {
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
@Override
public void apply(TrainingContext context) {

View File

@@ -0,0 +1,11 @@
package com.naaturel.ANN.implementation.loss;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public class SimpleLossStrategy implements AlgorithmStrategy {

    /**
     * Absolute-error loss: stores |delta| on the context as the per-sample loss.
     */
    @Override
    public void apply(TrainingContext ctx) {
        float error = ctx.delta;
        ctx.localLoss = Math.abs(error);
    }
}

View File

@@ -0,0 +1,11 @@
package com.naaturel.ANN.implementation.loss;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public class SquareLossStrategy implements AlgorithmStrategy {

    /**
     * Squared-error loss: localLoss = delta^2 / 2. The 1/2 factor makes the
     * derivative with respect to delta simply delta, which is convenient for
     * gradient-descent weight updates.
     */
    @Override
    public void apply(TrainingContext ctx) {
        // delta * delta is clearer and cheaper than Math.pow for squaring,
        // and yields the same float result (the square of a float is exact in
        // double, so both paths round once).
        ctx.localLoss = (ctx.delta * ctx.delta) / 2.0F;
    }
}

View File

@@ -1,13 +1,10 @@
package com.naaturel.ANN.implementation.neuron;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainable;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import java.util.List;
import java.util.function.Consumer;

View File

@@ -1,25 +1,50 @@
package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainable;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import com.naaturel.ANN.domain.model.training.TrainingContext;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.correction.GradientDescentCorrectionStrategy;
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
import com.naaturel.ANN.implementation.loss.SquareLossStrategy;
import com.naaturel.ANN.implementation.training.steps.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/*public class GradientDescentTraining implements Trainer {
public class GradientDescentTraining implements Trainer {
public GradientDescentTraining(){
}
public void train(Neuron n, float learningRate, DataSet dataSet) {
@Override
public void train(Trainable model, DataSet dataset) {
TrainingContext context = new TrainingContext();
context.dataset = dataset;
context.model = model;
context.learningRate = 0.3F;
List<TrainingStep> steps = List.of(
new PredictionStep(),
new DeltaStep(),
new LossStep(new SquareLossStrategy()),
new SimpleErrorDetectionStep(),
new WeightCorrectionStep(new GradientDescentCorrectionStrategy(2))
);
TrainingPipeline pipeline = new TrainingPipeline(steps);
pipeline
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
.beforeEpoch(ctx -> ctx.globalLoss = 0)
.afterEpoch(ctx -> ())
.withVerbose(true)
.withTimeMeasurement(true)
.run(context);
}
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1;
int maxEpoch = 402;
float errorThreshold = 0F;
@@ -120,6 +145,6 @@ import java.util.List;
variance /= dataSet.size();
return variance;
}
}*/
}

View File

@@ -7,6 +7,7 @@ import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.training.TrainingContext;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
import com.naaturel.ANN.implementation.loss.SimpleLossStrategy;
import com.naaturel.ANN.implementation.training.steps.*;
import java.util.List;
@@ -27,7 +28,7 @@ public class SimpleTraining implements Trainer {
List<TrainingStep> steps = List.of(
new PredictionStep(),
new DeltaStep(),
new SimpleLossStep(),
new LossStep(new SimpleLossStrategy()),
new SimpleErrorDetectionStep(),
new WeightCorrectionStep(new SimpleCorrectionStrategy())
);

View File

@@ -0,0 +1,19 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public class LossStep implements TrainingStep {

    /** Pluggable strategy that computes the per-sample loss on the context. */
    private final AlgorithmStrategy strategy;

    public LossStep(AlgorithmStrategy strategy) {
        this.strategy = strategy;
    }

    /** Delegates loss computation to the configured strategy. */
    @Override
    public void run(TrainingContext ctx) {
        strategy.apply(ctx);
    }
}

View File

@@ -11,11 +11,7 @@ public class PredictionStep implements TrainingStep {
@Override
public void run(TrainingContext ctx) {
List<Input> inputs = new ArrayList<>();
for(Float f : ctx.currentEntry.getData()){
inputs.add(new Input(f));
}
List<Float> predictions = ctx.model.predict(inputs);
List<Float> predictions = ctx.model.predict(ctx.currentEntry.getData());
ctx.prediction = predictions.getFirst();
}
}

View File

@@ -1,12 +0,0 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.training.TrainingContext;
// NOTE(review): this class is deleted by this change — its behavior is
// superseded by LossStep combined with SimpleLossStrategy.
public class SimpleLossStep implements TrainingStep {
    @Override
    public void run(TrainingContext ctx) {
        // Per-sample loss is the absolute prediction error.
        ctx.localLoss = Math.abs(ctx.delta);
    }
}

View File

@@ -1,14 +1,14 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.training.TrainingContext;
public class WeightCorrectionStep implements TrainingStep {
private final CorrectionStrategy correctionStrategy;
private final AlgorithmStrategy correctionStrategy;
public WeightCorrectionStep(CorrectionStrategy strategy) {
public WeightCorrectionStep(AlgorithmStrategy strategy) {
this.correctionStrategy = strategy;
}