Compare commits

...

5 Commits

Author SHA1 Message Date
Laurent
a2452fb4b8 Fix implementation 2026-03-25 16:11:09 +01:00
2936bf33bf Start to reimplement gradient descent 2026-03-23 23:12:52 +01:00
Laurent
1da32862f5 Just a regular commit 2026-03-23 18:47:36 +01:00
Laurent
fbf2a571ef Fix weights correction 2026-03-23 17:13:53 +01:00
Laurent
89d9abe329 Implement main structure of framework 2026-03-23 16:39:12 +01:00
36 changed files with 669 additions and 127 deletions

View File

@@ -3,16 +3,13 @@ package com.naaturel.ANN;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import com.naaturel.ANN.implementation.activationFunction.Linear;
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.implementation.gradientDescent.Linear;
import com.naaturel.ANN.implementation.simplePerceptron.Heaviside;
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.training.AdalineTraining;
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
import com.naaturel.ANN.implementation.training.SimpleTraining;
import java.util.*;
@@ -20,43 +17,11 @@ public class Main {
public static void main(String[] args){
DataSet orDataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(0.0F)),
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
));
DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv");
DataSet andDataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
));
DataSet dataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 9.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(7.0F, 10.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(2.0F, 5.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(2.0F, 7.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(2.0F, 8.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(6.0F, 8.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(6.0F, 9.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 5.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 6.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 8.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 9.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 7.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 8.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 10.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 11.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 6.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 7.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 9.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 10.0F)), new Label(1.0F))
));
DataSet orDataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv");
List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight(0)));
@@ -64,14 +29,12 @@ public class Main {
Bias bias = new Bias(new Weight(0));
Neuron n = new SimplePerceptron(syns, bias, new Linear());
Trainer trainer = new AdalineTraining();
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
Layer layer = new Layer(List.of(neuron));
Network network = new Network(List.of(layer));
long start = System.currentTimeMillis();
Trainer trainer = new GradientDescentTraining();
trainer.train(network, dataset);
trainer.train(n, 0.03F, andDataSet);
long end = System.currentTimeMillis();
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
}
}

View File

@@ -0,0 +1,7 @@
package com.naaturel.ANN.domain.abstraction;
/**
 * A single, self-contained piece of a training algorithm (prediction,
 * delta computation, loss, weight correction, ...). Implementations read
 * and write a shared TrainingContext supplied at construction time.
 */
public interface AlgorithmStrategy {
/** Executes this strategy against its training context. */
void apply();
}

View File

@@ -0,0 +1,14 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import java.util.List;
import java.util.function.Consumer;
/**
 * Abstraction over anything that can produce predictions from inputs:
 * a single Neuron, a Layer, or a whole Network (composite pattern).
 */
public interface Model {
/** Number of trainable connections visited by applyOnSynapses. */
int synCount();
/** Applies the consumer to every synapse of the model, in a fixed traversal order. */
void applyOnSynapses(Consumer<Synapse> consumer);
/** Computes the model's output values for the given inputs. */
List<Float> predict(List<Input> inputs);
}

View File

@@ -4,10 +4,9 @@ import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import java.util.ArrayList;
import java.util.List;
public abstract class Neuron {
public abstract class Neuron implements Model {
protected List<Synapse> synapses;
protected Bias bias;
@@ -19,37 +18,25 @@ public abstract class Neuron {
this.activationFunction = func;
}
public abstract float predict();
public abstract float calculateWeightedSum();
public int getSynCount(){
return this.synapses.size();
}
public void setInput(int index, Input input){
Synapse syn = this.synapses.get(index);
syn.setInput(input.getValue());
}
public Bias getBias(){
return this.bias;
}
public void updateBias(Weight weight) {
this.bias.setWeight(weight.getValue());
}
public Synapse getSynapse(int index){
return this.synapses.get(index);
public void updateWeight(int index, Weight weight) {
this.synapses.get(index).setWeight(weight.getValue());
}
public List<Synapse> getSynapses() {
return new ArrayList<>(this.synapses);
protected void setInputs(List<Input> inputs){
for(int i = 0; i < inputs.size() && i < synapses.size(); i++){
Synapse syn = this.synapses.get(i);
syn.setInput(inputs.get(i));
}
}
public void setWeight(int index, Weight weight){
Synapse syn = this.synapses.get(index);
syn.setWeight(weight.getValue());
@Override
public int synCount() {
return this.synapses.size()+1; //take the bias in account
}
}

View File

@@ -1,13 +0,0 @@
package com.naaturel.ANN.domain.abstraction;
public abstract class NeuronTrainer {
private Trainable trainable;
public NeuronTrainer(Trainable trainable){
this.trainable = trainable;
}
public abstract void train();
}

View File

@@ -3,6 +3,5 @@ package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.dataset.DataSet;
public interface Trainer {
void train(Neuron n, float learningRate, DataSet dataSet);
void train(Model model, DataSet dataset);
}

View File

@@ -0,0 +1,21 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label;
/**
 * Mutable bag of state shared by every step and strategy of a training run.
 * Fields are public by design: strategies read and write them directly.
 */
public abstract class TrainingContext {
// Model being trained.
public Model model;
// Dataset the model is trained on.
public DataSet dataset;
// Entry currently being processed (set by the pipeline before each step run).
public DataSetEntry currentEntry;
// Label of the current entry (set by the pipeline before each step run).
public Label currentLabel;
// Last prediction produced for the current entry.
public float prediction;
// Error of the last prediction: label - prediction.
public float delta;
// Loss accumulated over the current epoch.
public float globalLoss;
// Loss of the current entry only.
public float localLoss;
// Step size used by the weight-correction strategies.
public float learningRate;
// Number of completed passes over the dataset.
public int epoch;
}

View File

@@ -1,7 +1,7 @@
package com.naaturel.ANN.domain.abstraction;
public interface Trainable {
public interface TrainingStep {
void run();
}

View File

@@ -1,12 +1,14 @@
package com.naaturel.ANN.domain.model.dataset;
import com.naaturel.ANN.domain.model.neuron.Input;
import java.util.*;
public class DataSet implements Iterable<DataSetEntry>{
private Map<DataSetEntry, Label> data;
public DataSet(){
public DataSet() {
this(new HashMap<>());
}
@@ -31,15 +33,17 @@ public class DataSet implements Iterable<DataSetEntry>{
float maxAbs = entries.stream()
.flatMap(e -> e.getData().stream())
.map(Input::getValue)
.map(Math::abs)
.max(Float::compare)
.orElse(1.0F);
Map<DataSetEntry, Label> normalized = new HashMap<>();
for (DataSetEntry entry : entries) {
List<Float> normalizedData = new ArrayList<>();
for (float value : entry.getData()) {
normalizedData.add(Math.round((value / maxAbs) * 100.0F) / 100.0F);
List<Input> normalizedData = new ArrayList<>();
for (Input input : entry.getData()) {
Input normalizedInput = new Input(Math.round((input.getValue() / maxAbs) * 100.0F) / 100.0F);
normalizedData.add(normalizedInput);
}
normalized.put(new DataSetEntry(normalizedData), this.data.get(entry));
}

View File

@@ -1,16 +1,18 @@
package com.naaturel.ANN.domain.model.dataset;
import com.naaturel.ANN.domain.model.neuron.Input;
import java.util.*;
public class DataSetEntry implements Iterable<Float> {
public class DataSetEntry implements Iterable<Input> {
private List<Float> data;
private List<Input> data;
public DataSetEntry(List<Float> data){
public DataSetEntry(List<Input> data){
this.data = data;
}
public List<Float> getData() {
public List<Input> getData() {
return new ArrayList<>(data);
}
@@ -28,7 +30,7 @@ public class DataSetEntry implements Iterable<Float> {
}
@Override
public Iterator<Float> iterator() {
public Iterator<Input> iterator() {
return this.data.iterator();
}

View File

@@ -0,0 +1,36 @@
package com.naaturel.ANN.domain.model.dataset;
import com.naaturel.ANN.domain.model.neuron.Input;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Builds a DataSet from a CSV file where each row is "x1,x2,...,xn,label":
 * all columns but the last become inputs, the last column is the label.
 */
public class DatasetExtractor {

    /**
     * Reads and parses the CSV file at the given path. Blank lines (e.g. a
     * trailing newline) are skipped instead of crashing the parse.
     *
     * @param path filesystem path of the CSV file
     * @return the extracted data set
     * @throws RuntimeException if the file cannot be read or a value cannot
     *                          be parsed as a float
     */
    public DataSet extract(String path) {
        Map<DataSetEntry, Label> data = new HashMap<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // tolerate empty and trailing lines
                }
                String[] parts = line.split(",");
                List<Input> inputs = new ArrayList<>(parts.length - 1);
                for (int i = 0; i < parts.length - 1; i++) {
                    inputs.add(new Input(Float.parseFloat(parts[i].trim())));
                }
                float label = Float.parseFloat(parts[parts.length - 1].trim());
                data.put(new DataSetEntry(inputs), new Label(label));
            }
        } catch (IOException e) {
            throw new RuntimeException("Failed to read dataset from: " + path, e);
        } catch (NumberFormatException e) {
            // surface WHICH file was malformed, not just the bad token
            throw new RuntimeException("Malformed numeric value in dataset: " + path, e);
        }
        return new DataSet(data);
    }
}

View File

@@ -0,0 +1,41 @@
package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Model;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
/**
 * A group of neurons that all receive the same input vector. The layer's
 * output is the concatenation of each neuron's output, in neuron order.
 */
public class Layer implements Model {

    private final List<Neuron> neurons;

    public Layer(List<Neuron> neurons) {
        this.neurons = neurons;
    }

    /** Feeds the inputs to every neuron and concatenates their outputs. */
    @Override
    public List<Float> predict(List<Input> inputs) {
        List<Float> outputs = new ArrayList<>();
        this.neurons.forEach(neuron -> outputs.addAll(neuron.predict(inputs)));
        return outputs;
    }

    /** Total number of trainable connections across all neurons. */
    @Override
    public int synCount() {
        return this.neurons.stream().mapToInt(Neuron::synCount).sum();
    }

    /** Visits every synapse of every neuron, in neuron order. */
    @Override
    public void applyOnSynapses(Consumer<Synapse> consumer) {
        for (Neuron neuron : this.neurons) {
            neuron.applyOnSynapses(consumer);
        }
    }
}

View File

@@ -0,0 +1,40 @@
package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.Model;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
/**
 * Composite of layers.
 *
 * NOTE(review): predict() passes the ORIGINAL inputs to every layer and
 * concatenates the results — it does not feed one layer's output into the
 * next. That is fine for the single-layer networks built in Main, but
 * looks wrong for a true multi-layer forward pass; confirm intent before
 * stacking layers.
 */
public class Network implements Model {
private final List<Layer> layers;
public Network(List<Layer> layers) {
this.layers = layers;
}
/** Concatenation of each layer's prediction on the same inputs (see class note). */
@Override
public List<Float> predict(List<Input> inputs) {
List<Float> result = new ArrayList<>();
for(Layer layer : this.layers){
List<Float> res = layer.predict(inputs);
result.addAll(res);
}
return result;
}
/** Total number of trainable connections across all layers. */
@Override
public int synCount() {
int res = 0;
for(Layer layer : this.layers){
res += layer.synCount();
}
return res;
}
/** Visits every synapse of every layer, in layer order. */
@Override
public void applyOnSynapses(Consumer<Synapse> consumer) {
this.layers.forEach(layer -> layer.applyOnSynapses(consumer));
}
}

View File

@@ -14,8 +14,8 @@ public class Synapse {
return this.input.getValue();
}
public void setInput(float value){
this.input.setValue(value);
public void setInput(Input input){
this.input.setValue(input.getValue());
}
public float getWeight() {

View File

@@ -0,0 +1,82 @@
package com.naaturel.ANN.domain.model.training;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Predicate;
/**
 * Orchestrates a training run: for each epoch, executes the configured
 * steps on every entry of the context's dataset, with before/after-epoch
 * hooks and a stop condition evaluated after each epoch. At least one
 * epoch always runs (do/while).
 */
public class TrainingPipeline {

    private final List<TrainingStep> steps;
    private Consumer<TrainingContext> beforeEpoch;
    private Consumer<TrainingContext> afterEpoch;
    private Predicate<TrainingContext> stopCondition;
    private boolean verbose;
    private boolean timeMeasurement;

    public TrainingPipeline(List<TrainingStep> steps) {
        this.steps = new ArrayList<>(steps); // defensive copy
        this.stopCondition = (ctx) -> false; // default: never stop early — callers should set one
        this.beforeEpoch = (context -> {});
        this.afterEpoch = (context -> {});
    }

    /** Stops training when the predicate holds (checked after each epoch). */
    public TrainingPipeline stopCondition(Predicate<TrainingContext> predicate) {
        this.stopCondition = predicate;
        return this;
    }

    /** Hook executed before every epoch (e.g. reset the global loss). */
    public TrainingPipeline beforeEpoch(Consumer<TrainingContext> consumer) {
        this.beforeEpoch = consumer;
        return this;
    }

    /** Hook executed after every epoch (e.g. average the global loss). */
    public TrainingPipeline afterEpoch(Consumer<TrainingContext> consumer) {
        this.afterEpoch = consumer;
        return this;
    }

    /** Enables per-entry and per-epoch console logging. */
    public TrainingPipeline withVerbose(boolean enabled) {
        this.verbose = enabled;
        return this;
    }

    /** Enables printing the total wall-clock training time after run(). */
    public TrainingPipeline withTimeMeasurement(boolean enabled) {
        this.timeMeasurement = enabled;
        return this;
    }

    /**
     * Runs epochs until the stop condition holds.
     *
     * Fix: the timeMeasurement flag used to be stored but never read, so
     * withTimeMeasurement(true) had no effect; the elapsed time is now
     * measured and reported here.
     */
    public void run(TrainingContext ctx) {
        long start = System.currentTimeMillis();
        do {
            this.beforeEpoch.accept(ctx);
            this.executeSteps(ctx);
            this.afterEpoch.accept(ctx);
            if (this.verbose) {
                System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
            }
        } while (!this.stopCondition.test(ctx));
        if (this.timeMeasurement) {
            long end = System.currentTimeMillis();
            System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
        }
    }

    /** Runs every step on every dataset entry, then advances the epoch counter. */
    private void executeSteps(TrainingContext ctx) {
        for (DataSetEntry entry : ctx.dataset) {
            // publish the current entry/label so strategies can read them
            ctx.currentEntry = entry;
            ctx.currentLabel = ctx.dataset.getLabel(entry);
            for (TrainingStep step : steps) {
                step.run();
            }
            if (this.verbose) {
                System.out.printf("Epoch : %d, ", ctx.epoch);
                System.out.printf("predicted : %.2f, ", ctx.prediction);
                System.out.printf("expected : %.2f, ", ctx.currentLabel.getValue());
                System.out.printf("delta : %.2f, ", ctx.delta);
                System.out.printf("loss : %.5f\n", ctx.localLoss);
            }
        }
        ctx.epoch += 1;
    }
}

View File

@@ -0,0 +1,25 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Adds each accumulated corrector term to its matching synapse weight.
 * The synapse order must match the order used while accumulating the
 * corrector terms (the model's applyOnSynapses traversal order).
 */
public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {

    private final GradientDescentTrainingContext context;

    public GradientDescentCorrectionStrategy(GradientDescentTrainingContext context) {
        this.context = context;
    }

    /** Applies every pending corrector term to its synapse weight. */
    @Override
    public void apply() {
        // single-element array: mutable index usable from within the lambda
        int[] index = {0};
        this.context.model.applyOnSynapses(synapse -> {
            float correction = this.context.correctorTerms.get(index[0]);
            synapse.setWeight(synapse.getWeight() + correction);
            index[0]++;
        });
    }
}

View File

@@ -0,0 +1,26 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Accumulates, per synapse, the gradient-descent corrector term
 * learningRate * delta * input into the context's correctorTerms list,
 * in the model's applyOnSynapses traversal order.
 */
public class GradientDescentErrorStrategy implements AlgorithmStrategy {

    private final GradientDescentTrainingContext context;

    public GradientDescentErrorStrategy(GradientDescentTrainingContext context) {
        this.context = context;
    }

    /** Adds this entry's contribution to every synapse's corrector term. */
    @Override
    public void apply() {
        // single-element array: mutable index usable from within the lambda
        int[] index = {0};
        this.context.model.applyOnSynapses(synapse -> {
            float updated = this.context.correctorTerms.get(index[0])
                    + this.context.learningRate * this.context.delta * synapse.getInput();
            this.context.correctorTerms.set(index[0], updated);
            index[0]++;
        });
    }
}

View File

@@ -0,0 +1,11 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import java.util.List;
/**
 * Training context for gradient descent: additionally carries one
 * accumulated corrector term per synapse, in the model's applyOnSynapses
 * traversal order.
 */
public class GradientDescentTrainingContext extends TrainingContext {
// Pending weight corrections; filled by the error strategy, consumed by
// the correction strategy. Must be sized/initialized by the trainer.
public List<Float> correctorTerms;
}

View File

@@ -1,4 +1,4 @@
package com.naaturel.ANN.implementation.activationFunction;
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Neuron;

View File

@@ -0,0 +1,19 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
/**
 * Squared-error loss: L = delta^2 / 2. Also folds the local loss into the
 * running global loss, since the gradient-descent pipeline wires its
 * error-registration step to corrector-term accumulation, not to loss.
 */
public class SquareLossStrategy implements AlgorithmStrategy {

    private final GradientDescentTrainingContext context;

    public SquareLossStrategy(GradientDescentTrainingContext context) {
        this.context = context;
    }

    /** Computes the local squared loss and accumulates it into the global loss. */
    @Override
    public void apply() {
        float delta = this.context.delta;
        // delta * delta instead of Math.pow(delta, 2): no float->double round
        // trip for a plain square; the rounded result is identical.
        this.context.localLoss = (delta * delta) / 2.0F;
        this.context.globalLoss += this.context.localLoss;
    }
}

View File

@@ -2,21 +2,29 @@ package com.naaturel.ANN.implementation.neuron;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainable;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import java.util.List;
import java.util.function.Consumer;
public class SimplePerceptron extends Neuron implements Trainable {
public class SimplePerceptron extends Neuron {
public SimplePerceptron(List<Synapse> synapses, Bias b, ActivationFunction func) {
super(synapses, b, func);
}
@Override
public float predict() {
return activationFunction.accept(this);
public List<Float> predict(List<Input> inputs) {
super.setInputs(inputs);
return List.of(activationFunction.accept(this));
}
@Override
public void applyOnSynapses(Consumer<Synapse> consumer) {
consumer.accept(this.bias);
this.synapses.forEach(consumer);
}
@Override

View File

@@ -1,4 +1,4 @@
package com.naaturel.ANN.implementation.activationFunction;
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Neuron;

View File

@@ -0,0 +1,23 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
/**
 * Perceptron learning rule: when the last prediction is wrong, shift every
 * synapse weight by learningRate * delta * input. A correct prediction
 * leaves all weights untouched.
 */
public class SimpleCorrectionStrategy implements AlgorithmStrategy {

    private final SimpleTrainingContext context;

    public SimpleCorrectionStrategy(SimpleTrainingContext context) {
        this.context = context;
    }

    /** Updates all synapse weights, but only if the prediction was wrong. */
    @Override
    public void apply() {
        if (this.context.currentLabel.getValue() == this.context.prediction) {
            return;
        }
        // lr * delta is the same for every synapse, so hoist it out of the lambda
        float scale = this.context.learningRate * this.context.delta;
        this.context.model.applyOnSynapses(synapse ->
                synapse.setWeight(synapse.getWeight() + scale * synapse.getInput()));
    }
}

View File

@@ -0,0 +1,26 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label;
/**
 * Computes the training error for the current entry:
 * delta = expected label - predicted value.
 */
public class SimpleDeltaStrategy implements AlgorithmStrategy {

    private final TrainingContext context;

    public SimpleDeltaStrategy(TrainingContext context) {
        this.context = context;
    }

    /** Sets context.delta from the current label and the last prediction. */
    @Override
    public void apply() {
        // currentLabel is populated by the pipeline before the steps run,
        // so the extra dataset.getLabel(entry) lookup done previously was
        // redundant.
        context.delta = context.currentLabel.getValue() - context.prediction;
    }
}

View File

@@ -0,0 +1,17 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
/**
 * Folds the loss of the current entry into the epoch's global loss.
 */
public class SimpleErrorRegistrationStrategy implements AlgorithmStrategy {
private final SimpleTrainingContext context;
public SimpleErrorRegistrationStrategy(SimpleTrainingContext context) {
this.context = context;
}
/** Adds the current local loss to the running global loss. */
@Override
public void apply() {
context.globalLoss += context.localLoss;
}
}

View File

@@ -0,0 +1,17 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
/**
 * Absolute-error loss: L = |delta|.
 */
public class SimpleLossStrategy implements AlgorithmStrategy {
private final SimpleTrainingContext context;
public SimpleLossStrategy(SimpleTrainingContext context) {
this.context = context;
}
/** Stores |delta| as the local loss for the current entry. */
@Override
public void apply() {
this.context.localLoss = Math.abs(this.context.delta);
}
}

View File

@@ -0,0 +1,21 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import java.util.List;
/**
 * Runs the model on the current entry's inputs and stores the first output
 * value as the prediction.
 */
public class SimplePredictionStrategy implements AlgorithmStrategy {
private final TrainingContext context;
public SimplePredictionStrategy(TrainingContext context) {
this.context = context;
}
/**
 * Predicts on the current entry. Only the first model output is kept —
 * assumes a single-output model; TODO confirm for multi-output networks.
 * Note: getFirst() requires Java 21+ and throws NoSuchElementException
 * if the model returns an empty list.
 */
@Override
public void apply() {
List<Float> predictions = context.model.predict(context.currentEntry.getData());
context.prediction = predictions.getFirst();
}
}

View File

@@ -0,0 +1,6 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
/**
 * Training context for the simple perceptron rule. Adds no state beyond
 * the base TrainingContext but gives the simple strategies a concrete type.
 */
public class SimpleTrainingContext extends TrainingContext {
}

View File

@@ -9,7 +9,7 @@ import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
public class AdalineTraining implements Trainer {
/*public class AdalineTraining implements Trainer {
public AdalineTraining(){
@@ -78,4 +78,4 @@ public class AdalineTraining implements Trainer {
return (float) Math.pow(delta, 2)/2;
}
}
}*/

View File

@@ -1,16 +1,19 @@
package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
import com.naaturel.ANN.implementation.training.steps.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class GradientDescentTraining implements Trainer {
@@ -19,7 +22,38 @@ public class GradientDescentTraining implements Trainer {
}
public void train(Neuron n, float learningRate, DataSet dataSet) {
@Override
public void train(Model model, DataSet dataset) {
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
context.dataset = dataset;
context.model = model;
context.learningRate = 0.00011F;
context.correctorTerms = new ArrayList<>();
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)),
new DeltaStep(new SimpleDeltaStrategy(context)),
new LossStep(new SquareLossStrategy(context)),
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)),
new WeightCorrectionStep(new GradientDescentCorrectionStrategy(context))
);
TrainingPipeline pipeline = new TrainingPipeline(steps);
pipeline
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 1000)
.beforeEpoch(ctx -> {
ctx.globalLoss = 0.0F;
for (int i = 0; i < model.synCount(); i++){
context.correctorTerms.add(0F);
}
})
.afterEpoch(ctx -> ctx.globalLoss /= ctx.dataset.size())
.withVerbose(true)
.withTimeMeasurement(true)
.run(context);
}
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1;
int maxEpoch = 402;
float errorThreshold = 0F;
@@ -120,6 +154,6 @@ public class GradientDescentTraining implements Trainer {
variance /= dataSet.size();
return variance;
}
}*/
}

View File

@@ -1,12 +1,14 @@
package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.training.steps.*;
import java.util.List;
public class SimpleTraining implements Trainer {
@@ -14,7 +16,30 @@ public class SimpleTraining implements Trainer {
}
public void train(Neuron n, float learningRate, DataSet dataSet) {
@Override
public void train(Model model, DataSet dataset) {
SimpleTrainingContext context = new SimpleTrainingContext();
context.dataset = dataset;
context.model = model;
context.learningRate = 0.3F;
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)),
new DeltaStep(new SimpleDeltaStrategy(context)),
new LossStep(new SimpleLossStrategy(context)),
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
);
TrainingPipeline pipeline = new TrainingPipeline(steps);
pipeline
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
.beforeEpoch(ctx -> ctx.globalLoss = 0)
.withVerbose(true)
.run(context);
}
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1;
int errorCount;
@@ -65,5 +90,5 @@ public class SimpleTraining implements Trainer {
private float calculateLoss(float delta){
return Math.abs(delta);
}
*/
}

View File

@@ -0,0 +1,22 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label;
/**
 * Pipeline step that delegates delta (error) computation to its strategy.
 */
public class DeltaStep implements TrainingStep {
private final AlgorithmStrategy strategy;
public DeltaStep(AlgorithmStrategy strategy) {
this.strategy = strategy;
}
/** Runs the wrapped strategy. */
@Override
public void run() {
this.strategy.apply();
}
}

View File

@@ -0,0 +1,18 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
/**
 * Pipeline step that delegates error/loss registration to its strategy.
 */
public class ErrorRegistrationStep implements TrainingStep {
private final AlgorithmStrategy strategy;
public ErrorRegistrationStep(AlgorithmStrategy strategy) {
this.strategy = strategy;
}
/** Runs the wrapped strategy. */
@Override
public void run() {
this.strategy.apply();
}
}

View File

@@ -0,0 +1,20 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
/**
 * Pipeline step that delegates loss computation to its strategy.
 */
public class LossStep implements TrainingStep {
private final AlgorithmStrategy lossStrategy;
public LossStep(AlgorithmStrategy strategy) {
this.lossStrategy = strategy;
}
/** Runs the wrapped strategy. */
@Override
public void run() {
this.lossStrategy.apply();
}
}

View File

@@ -0,0 +1,23 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
import java.util.List;
/**
 * Pipeline step that delegates prediction to its strategy.
 *
 * Depends on the AlgorithmStrategy abstraction — like every other step —
 * rather than on the concrete SimplePredictionStrategy, so alternative
 * prediction strategies can be plugged in. Widening the constructor
 * parameter is backward-compatible: SimplePredictionStrategy implements
 * AlgorithmStrategy, so all existing callers still compile.
 */
public class PredictionStep implements TrainingStep {

    private final AlgorithmStrategy strategy;

    public PredictionStep(AlgorithmStrategy strategy) {
        this.strategy = strategy;
    }

    /** Runs the wrapped strategy. */
    @Override
    public void run() {
        this.strategy.apply();
    }
}

View File

@@ -0,0 +1,18 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
/**
 * Pipeline step that delegates weight correction to its strategy.
 */
public class WeightCorrectionStep implements TrainingStep {
private final AlgorithmStrategy correctionStrategy;
public WeightCorrectionStep(AlgorithmStrategy strategy) {
this.correctionStrategy = strategy;
}
/** Runs the wrapped strategy. */
@Override
public void run() {
this.correctionStrategy.apply();
}
}