Start to reimplement gradient descent
This commit is contained in:
@@ -5,6 +5,7 @@ import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
@@ -15,49 +16,14 @@ import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||
import com.naaturel.ANN.implementation.training.SimpleTraining;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
|
||||
import javax.xml.crypto.Data;
|
||||
import java.util.*;
|
||||
|
||||
public class Main {
|
||||
|
||||
public static void main(String[] args){
|
||||
|
||||
DataSet orDataSet = new DataSet(Map.ofEntries(
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(0.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
|
||||
));
|
||||
|
||||
DataSet andDataSet = new DataSet(Map.ofEntries(
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
|
||||
));
|
||||
|
||||
DataSet dataSet = new DataSet(Map.ofEntries(
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 9.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(7.0F, 10.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(2.0F, 5.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(2.0F, 7.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(2.0F, 8.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(6.0F, 8.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(6.0F, 9.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(3.0F, 5.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(3.0F, 6.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(3.0F, 8.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(3.0F, 9.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(5.0F, 7.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(5.0F, 8.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(5.0F, 10.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(5.0F, 11.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(4.0F, 6.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(4.0F, 7.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(4.0F, 9.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(4.0F, 10.0F)), new Label(1.0F))
|
||||
));
|
||||
DataSet dataset = new DatasetExtractor().extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv");
|
||||
|
||||
List<Synapse> syns = new ArrayList<>();
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
@@ -70,7 +36,7 @@ public class Main {
|
||||
Network network = new Network(List.of(layer));
|
||||
|
||||
Trainer trainer = new SimpleTraining();
|
||||
trainer.train(network, orDataSet);
|
||||
trainer.train(network, dataset);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,8 @@ package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public interface CorrectionStrategy {
|
||||
public interface AlgorithmStrategy {
|
||||
|
||||
void apply(TrainingContext context);
|
||||
void apply(TrainingContext ctx);
|
||||
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
package com.naaturel.ANN.domain.model.dataset;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
public class DataSet implements Iterable<DataSetEntry>{
|
||||
@@ -31,15 +33,17 @@ public class DataSet implements Iterable<DataSetEntry>{
|
||||
|
||||
float maxAbs = entries.stream()
|
||||
.flatMap(e -> e.getData().stream())
|
||||
.map(Input::getValue)
|
||||
.map(Math::abs)
|
||||
.max(Float::compare)
|
||||
.orElse(1.0F);
|
||||
|
||||
Map<DataSetEntry, Label> normalized = new HashMap<>();
|
||||
for (DataSetEntry entry : entries) {
|
||||
List<Float> normalizedData = new ArrayList<>();
|
||||
for (float value : entry.getData()) {
|
||||
normalizedData.add(Math.round((value / maxAbs) * 100.0F) / 100.0F);
|
||||
List<Input> normalizedData = new ArrayList<>();
|
||||
for (Input input : entry.getData()) {
|
||||
Input normalizedInput = new Input(Math.round((input.getValue() / maxAbs) * 100.0F) / 100.0F);
|
||||
normalizedData.add(normalizedInput);
|
||||
}
|
||||
normalized.put(new DataSetEntry(normalizedData), this.data.get(entry));
|
||||
}
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
package com.naaturel.ANN.domain.model.dataset;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
public class DataSetEntry implements Iterable<Float> {
|
||||
public class DataSetEntry implements Iterable<Input> {
|
||||
|
||||
private List<Float> data;
|
||||
private List<Input> data;
|
||||
|
||||
public DataSetEntry(List<Float> data){
|
||||
public DataSetEntry(List<Input> data){
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
public List<Float> getData() {
|
||||
public List<Input> getData() {
|
||||
return new ArrayList<>(data);
|
||||
}
|
||||
|
||||
@@ -28,7 +30,7 @@ public class DataSetEntry implements Iterable<Float> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<Float> iterator() {
|
||||
public Iterator<Input> iterator() {
|
||||
return this.data.iterator();
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
package com.naaturel.ANN.domain.model.dataset;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class DatasetExtractor {
|
||||
|
||||
public DataSet extract(String path) {
|
||||
Map<DataSetEntry, Label> data = new HashMap<>();
|
||||
|
||||
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
String[] parts = line.split(",");
|
||||
List<Input> inputs = new ArrayList<>();
|
||||
for (int i = 0; i < parts.length - 1; i++) {
|
||||
inputs.add(new Input(Float.parseFloat(parts[i].trim())));
|
||||
}
|
||||
float label = Float.parseFloat(parts[parts.length - 1].trim());
|
||||
data.put(new DataSetEntry(inputs), new Label(label));
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to read dataset from: " + path, e);
|
||||
}
|
||||
|
||||
return new DataSet(data);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,9 +1,21 @@
|
||||
package com.naaturel.ANN.implementation.correction;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public class GradientDescentCorrectionStrategy implements CorrectionStrategy {
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
|
||||
|
||||
List<Float> correctorTerms;
|
||||
|
||||
public GradientDescentCorrectionStrategy(int nbrCorrectors){
|
||||
this.correctorTerms = new ArrayList<>();
|
||||
for (int i = 0; i < nbrCorrectors; i++){
|
||||
correctorTerms.add(0F);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply(TrainingContext context) {
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package com.naaturel.ANN.implementation.correction;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public class SimpleCorrectionStrategy implements CorrectionStrategy {
|
||||
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
|
||||
|
||||
@Override
|
||||
public void apply(TrainingContext context) {
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.naaturel.ANN.implementation.loss;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public class SimpleLossStrategy implements AlgorithmStrategy {
|
||||
@Override
|
||||
public void apply(TrainingContext ctx) {
|
||||
ctx.localLoss = Math.abs(ctx.delta);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.naaturel.ANN.implementation.loss;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public class SquareLossStrategy implements AlgorithmStrategy {
|
||||
@Override
|
||||
public void apply(TrainingContext ctx) {
|
||||
ctx.localLoss = (float)Math.pow(ctx.delta, 2) / 2;
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,10 @@
|
||||
package com.naaturel.ANN.implementation.neuron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
@@ -1,25 +1,50 @@
|
||||
package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.correction.GradientDescentCorrectionStrategy;
|
||||
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
|
||||
import com.naaturel.ANN.implementation.loss.SquareLossStrategy;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/*public class GradientDescentTraining implements Trainer {
|
||||
public class GradientDescentTraining implements Trainer {
|
||||
|
||||
public GradientDescentTraining(){
|
||||
|
||||
}
|
||||
|
||||
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
@Override
|
||||
public void train(Trainable model, DataSet dataset) {
|
||||
TrainingContext context = new TrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
context.learningRate = 0.3F;
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(),
|
||||
new DeltaStep(),
|
||||
new LossStep(new SquareLossStrategy()),
|
||||
new SimpleErrorDetectionStep(),
|
||||
new WeightCorrectionStep(new GradientDescentCorrectionStrategy(2))
|
||||
);
|
||||
|
||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||
pipeline
|
||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
|
||||
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
||||
.afterEpoch(ctx -> ())
|
||||
.withVerbose(true)
|
||||
.withTimeMeasurement(true)
|
||||
.run(context);
|
||||
}
|
||||
|
||||
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
int epoch = 1;
|
||||
int maxEpoch = 402;
|
||||
float errorThreshold = 0F;
|
||||
@@ -120,6 +145,6 @@ import java.util.List;
|
||||
variance /= dataSet.size();
|
||||
|
||||
return variance;
|
||||
}
|
||||
|
||||
}*/
|
||||
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
|
||||
import com.naaturel.ANN.implementation.loss.SimpleLossStrategy;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
|
||||
import java.util.List;
|
||||
@@ -27,7 +28,7 @@ public class SimpleTraining implements Trainer {
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(),
|
||||
new DeltaStep(),
|
||||
new SimpleLossStep(),
|
||||
new LossStep(new SimpleLossStrategy()),
|
||||
new SimpleErrorDetectionStep(),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStrategy())
|
||||
);
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public class LossStep implements TrainingStep {
|
||||
|
||||
private final AlgorithmStrategy lossStrategy;
|
||||
|
||||
public LossStep(AlgorithmStrategy strategy) {
|
||||
this.lossStrategy = strategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run(TrainingContext ctx) {
|
||||
this.lossStrategy.apply(ctx);
|
||||
}
|
||||
}
|
||||
@@ -11,11 +11,7 @@ public class PredictionStep implements TrainingStep {
|
||||
|
||||
@Override
|
||||
public void run(TrainingContext ctx) {
|
||||
List<Input> inputs = new ArrayList<>();
|
||||
for(Float f : ctx.currentEntry.getData()){
|
||||
inputs.add(new Input(f));
|
||||
}
|
||||
List<Float> predictions = ctx.model.predict(inputs);
|
||||
List<Float> predictions = ctx.model.predict(ctx.currentEntry.getData());
|
||||
ctx.prediction = predictions.getFirst();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public class SimpleLossStep implements TrainingStep {
|
||||
|
||||
@Override
|
||||
public void run(TrainingContext ctx) {
|
||||
ctx.localLoss = Math.abs(ctx.delta);
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,14 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||
|
||||
public class WeightCorrectionStep implements TrainingStep {
|
||||
|
||||
private final CorrectionStrategy correctionStrategy;
|
||||
private final AlgorithmStrategy correctionStrategy;
|
||||
|
||||
public WeightCorrectionStep(CorrectionStrategy strategy) {
|
||||
public WeightCorrectionStep(AlgorithmStrategy strategy) {
|
||||
this.correctionStrategy = strategy;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user