From 64bc830f187c62b071a08848faa0d4356a406c50 Mon Sep 17 00:00:00 2001 From: Laurent Date: Thu, 26 Mar 2026 21:21:31 +0100 Subject: [PATCH] Add multi-layer support --- build.gradle.kts | 2 + src/main/java/com/naaturel/ANN/Main.java | 30 +++++---- .../abstraction/ActivationFunction.java | 2 + .../ANN/domain/abstraction/Model.java | 3 + .../ANN/domain/abstraction/Neuron.java | 42 ------------ .../domain/abstraction/TrainingContext.java | 9 +-- .../ANN/domain/model/dataset/DataSet.java | 12 ++-- .../model/dataset/DatasetExtractor.java | 22 ++++-- .../ANN/domain/model/dataset/Label.java | 15 ----- .../ANN/domain/model/dataset/Labels.java | 16 +++++ .../ANN/domain/model/neuron/Layer.java | 6 +- .../ANN/domain/model/neuron/Network.java | 5 ++ .../ANN/domain/model/neuron/Neuron.java | 67 +++++++++++++++++++ .../model/training/TrainingPipeline.java | 14 ++-- .../GradientDescentErrorStrategy.java | 22 ++++-- .../gradientDescent/Linear.java | 2 +- .../gradientDescent/SquareLossStrategy.java | 6 +- .../neuron/SimplePerceptron.java | 40 ----------- .../simplePerceptron/Heaviside.java | 2 +- .../SimpleCorrectionStrategy.java | 22 ++++-- .../simplePerceptron/SimpleDeltaStrategy.java | 15 ++++- .../simplePerceptron/SimpleLossStrategy.java | 2 +- .../SimplePredictionStrategy.java | 3 +- .../training/AdalineTraining.java | 4 +- .../training/GradientDescentTraining.java | 2 +- .../training/steps/DeltaStep.java | 4 -- .../infrastructure/graph/GraphVisualizer.java | 9 +++ src/test/java/adaline/AdalineTest.java | 7 +- .../gradientDescent/GradientDescentTest.java | 8 +-- .../java/perceptron/SimplePerceptronTest.java | 7 +- 30 files changed, 228 insertions(+), 172 deletions(-) delete mode 100644 src/main/java/com/naaturel/ANN/domain/abstraction/Neuron.java delete mode 100644 src/main/java/com/naaturel/ANN/domain/model/dataset/Label.java create mode 100644 src/main/java/com/naaturel/ANN/domain/model/dataset/Labels.java create mode 100644 src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java delete mode 100644 src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java create mode 100644 src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java diff --git a/build.gradle.kts b/build.gradle.kts index d65d34b..a60785f 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -10,6 +10,8 @@ repositories { } dependencies { + implementation("org.jfree:jfreechart:1.5.4") + testImplementation(platform("org.junit:junit-bom:5.10.0")) testImplementation("org.junit.jupiter:junit-jupiter") testRuntimeOnly("org.junit.platform:junit-platform-launcher") diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 9c25e08..9ee6434 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -1,16 +1,13 @@ package com.naaturel.ANN; -import com.naaturel.ANN.domain.abstraction.Neuron; +import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.implementation.gradientDescent.Linear; -import com.naaturel.ANN.implementation.simplePerceptron.Heaviside; -import com.naaturel.ANN.implementation.neuron.SimplePerceptron; import com.naaturel.ANN.implementation.training.AdalineTraining; import com.naaturel.ANN.implementation.training.GradientDescentTraining; -import com.naaturel.ANN.implementation.training.SimpleTraining; import java.util.*; @@ -18,20 +15,27 @@ public class Main { public static void main(String[] args){ + int nbrInput = 3; + int nbrClass = 3; + DataSet dataset = new DatasetExtractor() - .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv"); + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv", nbrClass); - DataSet andDataset = new DatasetExtractor() - .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv"); + List neurons = new ArrayList<>(); - List syns = new ArrayList<>(); - syns.add(new Synapse(new Input(0), new Weight(0))); - syns.add(new Synapse(new Input(0), new Weight(0))); + for (int i=0; i < nbrClass; i++){ + List syns = new ArrayList<>(); + for (int j=0; j < nbrInput; j++){ + syns.add(new Synapse(new Input(0), new Weight(0))); + } - Bias bias = new Bias(new Weight(0)); + Bias bias = new Bias(new Weight(0)); - Neuron neuron = new SimplePerceptron(syns, bias, new Linear()); - Layer layer = new Layer(List.of(neuron)); + Neuron n = new Neuron(syns, bias, new Linear()); + neurons.add(n); + } + + Layer layer = new Layer(neurons); Network network = new Network(List.of(layer)); Trainer trainer = new AdalineTraining(); diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/ActivationFunction.java b/src/main/java/com/naaturel/ANN/domain/abstraction/ActivationFunction.java index 856f95c..d47ac8a 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/ActivationFunction.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/ActivationFunction.java @@ -1,5 +1,7 @@ package com.naaturel.ANN.domain.abstraction; +import com.naaturel.ANN.domain.model.neuron.Neuron; + public interface ActivationFunction { float accept(Neuron n); diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java index d850f9c..011618c 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/Model.java @@ -1,13 +1,16 @@ package com.naaturel.ANN.domain.abstraction; import com.naaturel.ANN.domain.model.neuron.Input; +import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.model.neuron.Synapse; import java.util.List; +import java.util.function.BiConsumer; import java.util.function.Consumer; public interface Model { int synCount(); + void forEachNeuron(Consumer consumer); void forEachSynapse(Consumer consumer); List predict(List inputs); diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/Neuron.java b/src/main/java/com/naaturel/ANN/domain/abstraction/Neuron.java deleted file mode 100644 index b8664b1..0000000 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/Neuron.java +++ /dev/null @@ -1,42 +0,0 @@ -package com.naaturel.ANN.domain.abstraction; -import com.naaturel.ANN.domain.model.neuron.Bias; -import com.naaturel.ANN.domain.model.neuron.Input; -import com.naaturel.ANN.domain.model.neuron.Synapse; -import com.naaturel.ANN.domain.model.neuron.Weight; - -import java.util.List; - -public abstract class Neuron implements Model { - - protected List synapses; - protected Bias bias; - protected ActivationFunction activationFunction; - - public Neuron(List synapses, Bias bias, ActivationFunction func){ - this.synapses = synapses; - this.bias = bias; - this.activationFunction = func; - } - - public abstract float calculateWeightedSum(); - - public void updateBias(Weight weight) { - this.bias.setWeight(weight.getValue()); - } - - public void updateWeight(int index, Weight weight) { - this.synapses.get(index).setWeight(weight.getValue()); - } - - protected void setInputs(List inputs){ - for(int i = 0; i < inputs.size() && i < synapses.size(); i++){ - Synapse syn = this.synapses.get(i); - syn.setInput(inputs.get(i)); - } - } - - @Override - public int synCount() { - return this.synapses.size()+1; //take the bias in account - } -} diff --git a/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java index e329ccd..14dc8c9 100644 --- a/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java +++ b/src/main/java/com/naaturel/ANN/domain/abstraction/TrainingContext.java @@ -2,16 +2,17 @@ package com.naaturel.ANN.domain.abstraction; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; -import com.naaturel.ANN.domain.model.dataset.Label; + +import java.util.List; public abstract class TrainingContext { public Model model; public DataSet dataset; public DataSetEntry currentEntry; - public Label currentLabel; - public float prediction; - public float delta; + public List expectations; + public List predictions; + public List deltas; public float globalLoss; public float localLoss; diff --git a/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java b/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java index dd10ca4..f9ba91a 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java +++ b/src/main/java/com/naaturel/ANN/domain/model/dataset/DataSet.java @@ -6,13 +6,13 @@ import java.util.*; public class DataSet implements Iterable{ - private Map data; + private final Map data; public DataSet() { this(new LinkedHashMap<>()); } - public DataSet(Map data){ + public DataSet(Map data){ this.data = data; } @@ -24,8 +24,8 @@ public class DataSet implements Iterable{ return new ArrayList<>(this.data.keySet()); } - public Label getLabel(DataSetEntry entry){ - return this.data.get(entry); + public List getLabelsAsFloat(DataSetEntry entry){ + return this.data.get(entry).getValues(); } public DataSet toNormalized() { @@ -38,13 +38,15 @@ public class DataSet implements Iterable{ .max(Float::compare) .orElse(1.0F); - Map normalized = new HashMap<>(); + Map normalized = new HashMap<>(); for (DataSetEntry entry : entries) { List normalizedData = new ArrayList<>(); + for (Input input : entry.getData()) { Input normalizedInput = new Input(Math.round((input.getValue() / maxAbs) * 100.0F) / 100.0F); normalizedData.add(normalizedInput); } + normalized.put(new DataSetEntry(normalizedData), this.data.get(entry)); } diff --git a/src/main/java/com/naaturel/ANN/domain/model/dataset/DatasetExtractor.java b/src/main/java/com/naaturel/ANN/domain/model/dataset/DatasetExtractor.java index 4f2688f..f3fd04a 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/dataset/DatasetExtractor.java +++ b/src/main/java/com/naaturel/ANN/domain/model/dataset/DatasetExtractor.java @@ -9,19 +9,29 @@ import java.util.*; public class DatasetExtractor { - public DataSet extract(String path) { - Map data = new LinkedHashMap<>(); + public DataSet extract(String path, int nbrLabels) { + Map data = new LinkedHashMap<>(); try (BufferedReader reader = new BufferedReader(new FileReader(path))) { String line; while ((line = reader.readLine()) != null) { String[] parts = line.split(","); + + String[] rawInputs = Arrays.copyOfRange(parts, 0, parts.length-nbrLabels); + String[] rawLabels = Arrays.copyOfRange(parts, parts.length-nbrLabels, parts.length); + List inputs = new ArrayList<>(); - for (int i = 0; i < parts.length - 1; i++) { - inputs.add(new Input(Float.parseFloat(parts[i].trim()))); + List labels = new ArrayList<>(); + + for (String entry : rawInputs) { + inputs.add(new Input(Float.parseFloat(entry.trim()))); } - float label = Float.parseFloat(parts[parts.length - 1].trim()); - data.put(new DataSetEntry(inputs), new Label(label)); + + for (String entry : rawLabels) { + labels.add(Float.parseFloat(entry.trim())); + } + + data.put(new DataSetEntry(inputs), new Labels(labels)); } } catch (IOException e) { throw new RuntimeException("Failed to read dataset from: " + path, e); diff --git a/src/main/java/com/naaturel/ANN/domain/model/dataset/Label.java b/src/main/java/com/naaturel/ANN/domain/model/dataset/Label.java deleted file mode 100644 index 4f1849e..0000000 --- a/src/main/java/com/naaturel/ANN/domain/model/dataset/Label.java +++ /dev/null @@ -1,15 +0,0 @@ -package com.naaturel.ANN.domain.model.dataset; - -public class Label { - - private float value; - - public Label(float value){ - this.value = value; - } - - - public float getValue() { - return value; - } -} diff --git a/src/main/java/com/naaturel/ANN/domain/model/dataset/Labels.java b/src/main/java/com/naaturel/ANN/domain/model/dataset/Labels.java new file mode 100644 index 0000000..9a7a785 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/domain/model/dataset/Labels.java @@ -0,0 +1,16 @@ +package com.naaturel.ANN.domain.model.dataset; + +import java.util.List; + +public class Labels { + + private final List values; + + public Labels(List value){ + this.values = value; + } + + public List getValues() { + return values.stream().toList(); + } +} diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java index e366374..6768ecd 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Layer.java @@ -1,6 +1,5 @@ package com.naaturel.ANN.domain.model.neuron; -import com.naaturel.ANN.domain.abstraction.Neuron; import com.naaturel.ANN.domain.abstraction.Model; import java.util.ArrayList; @@ -34,6 +33,11 @@ public class Layer implements Model { return res; } + @Override + public void forEachNeuron(Consumer consumer) { + this.neurons.forEach(consumer); + } + @Override public void forEachSynapse(Consumer consumer) { this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer)); diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java index 769f9fc..91d8b5e 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Network.java @@ -33,6 +33,11 @@ public class Network implements Model { return res; } + @Override + public void forEachNeuron(Consumer consumer) { + this.layers.forEach(layer -> layer.forEachNeuron(consumer)); + } + @Override public void forEachSynapse(Consumer consumer) { this.layers.forEach(layer -> layer.forEachSynapse(consumer)); diff --git a/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java new file mode 100644 index 0000000..e85174c --- /dev/null +++ b/src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java @@ -0,0 +1,67 @@ +package com.naaturel.ANN.domain.model.neuron; +import com.naaturel.ANN.domain.abstraction.ActivationFunction; +import com.naaturel.ANN.domain.abstraction.Model; + +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +public class Neuron implements Model { + + protected List synapses; + protected Bias bias; + protected ActivationFunction activationFunction; + + public Neuron(List synapses, Bias bias, ActivationFunction func){ + this.synapses = synapses; + this.bias = bias; + this.activationFunction = func; + } + + public void updateBias(Weight weight) { + this.bias.setWeight(weight.getValue()); + } + + public void updateWeight(int index, Weight weight) { + this.synapses.get(index).setWeight(weight.getValue()); + } + + protected void setInputs(List inputs){ + for(int i = 0; i < inputs.size() && i < synapses.size(); i++){ + Synapse syn = this.synapses.get(i); + syn.setInput(inputs.get(i)); + } + } + + @Override + public int synCount() { + return this.synapses.size()+1; //take the bias in account + } + + @Override + public List predict(List inputs) { + this.setInputs(inputs); + return List.of(activationFunction.accept(this)); + } + + @Override + public void forEachNeuron(Consumer consumer) { + consumer.accept(this); + } + + @Override + public void forEachSynapse(Consumer consumer) { + consumer.accept(this.bias); + this.synapses.forEach(consumer); + } + + public float calculateWeightedSum() { + float res = 0; + res += this.bias.getWeight() * this.bias.getInput(); + for(Synapse syn : this.synapses){ + res += syn.getWeight() * syn.getInput(); + } + return res; + } + +} diff --git a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java index d99ef80..3f24142 100644 --- a/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java +++ b/src/main/java/com/naaturel/ANN/domain/model/training/TrainingPipeline.java @@ -4,8 +4,8 @@ import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; -import java.sql.Time; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.function.Consumer; import java.util.function.Predicate; @@ -74,19 +74,23 @@ public class TrainingPipeline { private void executeSteps(TrainingContext ctx){ for (DataSetEntry entry : ctx.dataset) { + ctx.currentEntry = entry; - ctx.currentLabel = ctx.dataset.getLabel(entry); + ctx.expectations = ctx.dataset.getLabelsAsFloat(entry); + for (TrainingStep step : steps) { step.run(); } + if(this.verbose) { System.out.printf("Epoch : %d, ", ctx.epoch); - System.out.printf("predicted : %.2f, ", ctx.prediction); - System.out.printf("expected : %.2f, ", ctx.currentLabel.getValue()); - System.out.printf("delta : %.2f, ", ctx.delta); + System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray())); + System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray())); + System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas.toArray())); System.out.printf("loss : %.5f\n", ctx.localLoss); } } ctx.epoch += 1; } + } diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java index e326eb7..411cb2a 100644 --- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/GradientDescentErrorStrategy.java @@ -15,13 +15,23 @@ public class GradientDescentErrorStrategy implements AlgorithmStrategy { @Override public void apply() { - AtomicInteger i = new AtomicInteger(0); - context.model.forEachSynapse(syn -> { - float corrector = context.correctorTerms.get(i.get()); - corrector += context.learningRate * context.delta * syn.getInput(); - context.correctorTerms.set(i.get(), corrector); - i.incrementAndGet(); + + AtomicInteger neuronIndex = new AtomicInteger(0); + AtomicInteger synIndex = new AtomicInteger(0); + + context.model.forEachNeuron(neuron -> { + float correspondingDelta = context.deltas.get(neuronIndex.get()); + + neuron.forEachSynapse(syn -> { + float corrector = context.correctorTerms.get(synIndex.get()); + corrector += context.learningRate * correspondingDelta * syn.getInput(); + context.correctorTerms.set(synIndex.get(), corrector); + synIndex.incrementAndGet(); + }); + + neuronIndex.incrementAndGet(); }); + context.globalLoss += context.localLoss; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/Linear.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/Linear.java index b7bc8a7..18c7acc 100644 --- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/Linear.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/Linear.java @@ -1,7 +1,7 @@ package com.naaturel.ANN.implementation.gradientDescent; import com.naaturel.ANN.domain.abstraction.ActivationFunction; -import com.naaturel.ANN.domain.abstraction.Neuron; +import com.naaturel.ANN.domain.model.neuron.Neuron; public class Linear implements ActivationFunction { diff --git a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStrategy.java b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStrategy.java index 2aa2cd4..fa15fd7 100644 --- a/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/gradientDescent/SquareLossStrategy.java @@ -4,6 +4,8 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext; +import java.util.stream.Stream; + public class SquareLossStrategy implements AlgorithmStrategy { private final TrainingContext context; @@ -14,6 +16,8 @@ public class SquareLossStrategy implements AlgorithmStrategy { @Override public void apply() { - this.context.localLoss = (float)Math.pow(this.context.delta, 2)/2; + Stream deltaStream = this.context.deltas.stream(); + this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2))); + this.context.localLoss /= 2; } } diff --git a/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java b/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java deleted file mode 100644 index 0c7686f..0000000 --- a/src/main/java/com/naaturel/ANN/implementation/neuron/SimplePerceptron.java +++ /dev/null @@ -1,40 +0,0 @@ -package com.naaturel.ANN.implementation.neuron; - -import com.naaturel.ANN.domain.abstraction.ActivationFunction; -import com.naaturel.ANN.domain.abstraction.Neuron; -import com.naaturel.ANN.domain.model.neuron.Bias; -import com.naaturel.ANN.domain.model.neuron.Input; -import com.naaturel.ANN.domain.model.neuron.Synapse; - -import java.util.List; -import java.util.function.Consumer; - -public class SimplePerceptron extends Neuron { - - public SimplePerceptron(List synapses, Bias b, ActivationFunction func) { - super(synapses, b, func); - } - - @Override - public List predict(List inputs) { - super.setInputs(inputs); - return List.of(activationFunction.accept(this)); - } - - @Override - public void forEachSynapse(Consumer consumer) { - consumer.accept(this.bias); - this.synapses.forEach(consumer); - } - - @Override - public float calculateWeightedSum() { - float res = 0; - res += this.bias.getWeight() * this.bias.getInput(); - for(Synapse syn : super.synapses){ - res += syn.getWeight() * syn.getInput(); - } - return res; - } - -} diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java index b3e2ec4..af0df44 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/Heaviside.java @@ -1,7 +1,7 @@ package com.naaturel.ANN.implementation.simplePerceptron; import com.naaturel.ANN.domain.abstraction.ActivationFunction; -import com.naaturel.ANN.domain.abstraction.Neuron; +import com.naaturel.ANN.domain.model.neuron.Neuron; public class Heaviside implements ActivationFunction { diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java index 05671f9..26441a7 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleCorrectionStrategy.java @@ -3,6 +3,8 @@ package com.naaturel.ANN.implementation.simplePerceptron; import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.TrainingContext; +import java.util.concurrent.atomic.AtomicInteger; + public class SimpleCorrectionStrategy implements AlgorithmStrategy { @@ -14,12 +16,20 @@ public class SimpleCorrectionStrategy implements AlgorithmStrategy { @Override public void apply() { - if(context.currentLabel.getValue() == context.prediction) return ; - context.model.forEachSynapse(syn -> { - float currentW = syn.getWeight(); - float currentInput = syn.getInput(); - float newValue = currentW + (context.learningRate * context.delta * currentInput); - syn.setWeight(newValue); + if(context.expectations.equals(context.predictions)) return; + AtomicInteger neuronIndex = new AtomicInteger(0); + AtomicInteger synIndex = new AtomicInteger(0); + + context.model.forEachNeuron(neuron -> { + float correspondingDelta = context.deltas.get(neuronIndex.get()); + neuron.forEachSynapse(syn -> { + float currentW = syn.getWeight(); + float currentInput = syn.getInput(); + float newValue = currentW + (context.learningRate * correspondingDelta * currentInput); + syn.setWeight(newValue); + synIndex.incrementAndGet(); + }); + neuronIndex.incrementAndGet(); }); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStrategy.java index 4e7da26..ec57d65 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleDeltaStrategy.java @@ -4,7 +4,11 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DataSetEntry; -import com.naaturel.ANN.domain.model.dataset.Label; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; public class SimpleDeltaStrategy implements AlgorithmStrategy { @@ -18,9 +22,14 @@ public class SimpleDeltaStrategy implements AlgorithmStrategy { public void apply() { DataSet dataSet = context.dataset; DataSetEntry entry = context.currentEntry; - Label label = dataSet.getLabel(entry); + List predicted = context.predictions; + List expected = dataSet.getLabelsAsFloat(entry); - context.delta = label.getValue() - context.prediction; + //context.delta = label.getValue() - context.predictions; + context.deltas = IntStream.range(0, predicted.size()) + .mapToObj(i -> expected.get(i) - predicted.get(i)) + .collect(Collectors.toList()); + System.out.printf(""); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleLossStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleLossStrategy.java index 145413d..4e4247f 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleLossStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimpleLossStrategy.java @@ -12,6 +12,6 @@ public class SimpleLossStrategy implements AlgorithmStrategy { @Override public void apply() { - this.context.localLoss = Math.abs(this.context.delta); + this.context.localLoss = this.context.deltas.stream().reduce(0.0F, Float::sum); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStrategy.java b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStrategy.java index 64b7e2e..bf234ca 100644 --- a/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStrategy.java +++ b/src/main/java/com/naaturel/ANN/implementation/simplePerceptron/SimplePredictionStrategy.java @@ -15,7 +15,6 @@ public class SimplePredictionStrategy implements AlgorithmStrategy { @Override public void apply() { - List predictions = context.model.predict(context.currentEntry.getData()); - context.prediction = predictions.getFirst(); + context.predictions = context.model.predict(context.currentEntry.getData()); } } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java index f1a2166..8c4553f 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/AdalineTraining.java @@ -38,11 +38,11 @@ public class AdalineTraining implements Trainer { ); new TrainingPipeline(steps) - .stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 10000) + .stopCondition(ctx -> ctx.globalLoss <= 0.04F || ctx.epoch > 1000) .beforeEpoch(ctx -> ctx.globalLoss = 0.0F) .afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size()) - .withVerbose(true) .withTimeMeasurement(true) + .withVerbose(true) .run(context); } diff --git a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java index 8c2f975..2d98ccb 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/GradientDescentTraining.java @@ -38,7 +38,7 @@ public class GradientDescentTraining implements Trainer { ); new TrainingPipeline(steps) - .stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 5000) + .stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 500) .beforeEpoch(ctx -> { GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx; gdCtx.globalLoss = 0.0F; diff --git a/src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java b/src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java index d05c9df..c5c377b 100644 --- a/src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java +++ b/src/main/java/com/naaturel/ANN/implementation/training/steps/DeltaStep.java @@ -1,11 +1,7 @@ package com.naaturel.ANN.implementation.training.steps; import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; -import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingStep; -import com.naaturel.ANN.domain.model.dataset.DataSet; -import com.naaturel.ANN.domain.model.dataset.DataSetEntry; -import com.naaturel.ANN.domain.model.dataset.Label; public class DeltaStep implements TrainingStep { diff --git a/src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java b/src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java new file mode 100644 index 0000000..85f1627 --- /dev/null +++ b/src/main/java/com/naaturel/ANN/infrastructure/graph/GraphVisualizer.java @@ -0,0 +1,9 @@ +package com.naaturel.ANN.infrastructure.graph; + +public class GraphVisualizer { + + public GraphVisualizer(){ + + } + +} diff --git a/src/test/java/adaline/AdalineTest.java b/src/test/java/adaline/AdalineTest.java index 75c4f30..c360de6 100644 --- a/src/test/java/adaline/AdalineTest.java +++ b/src/test/java/adaline/AdalineTest.java @@ -1,7 +1,7 @@ package adaline; -import com.naaturel.ANN.domain.abstraction.Neuron; +import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DatasetExtractor; @@ -9,7 +9,6 @@ import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext; import com.naaturel.ANN.implementation.gradientDescent.*; -import com.naaturel.ANN.implementation.neuron.SimplePerceptron; import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy; @@ -37,7 +36,7 @@ public class AdalineTest { @BeforeEach public void init(){ dataset = new DatasetExtractor() - .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv"); + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1); List syns = new ArrayList<>(); syns.add(new Synapse(new Input(0), new Weight(0))); @@ -45,7 +44,7 @@ public class AdalineTest { bias = new Bias(new Weight(0)); - Neuron neuron = new SimplePerceptron(syns, bias, new Linear()); + Neuron neuron = new Neuron(syns, bias, new Linear()); Layer layer = new Layer(List.of(neuron)); network = new Network(List.of(layer)); diff --git a/src/test/java/gradientDescent/GradientDescentTest.java b/src/test/java/gradientDescent/GradientDescentTest.java index 6a33af6..29fe934 100644 --- a/src/test/java/gradientDescent/GradientDescentTest.java +++ b/src/test/java/gradientDescent/GradientDescentTest.java @@ -1,13 +1,12 @@ package gradientDescent; -import com.naaturel.ANN.domain.abstraction.Neuron; +import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.implementation.gradientDescent.*; -import com.naaturel.ANN.implementation.neuron.SimplePerceptron; import com.naaturel.ANN.implementation.simplePerceptron.*; import com.naaturel.ANN.implementation.training.steps.*; import org.junit.jupiter.api.BeforeEach; @@ -15,7 +14,6 @@ import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; import static org.junit.jupiter.api.Assertions.*; @@ -34,7 +32,7 @@ public class GradientDescentTest { @BeforeEach public void init(){ dataset = new DatasetExtractor() - .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv"); + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1); List syns = new ArrayList<>(); syns.add(new Synapse(new Input(0), new Weight(0))); @@ -42,7 +40,7 @@ public class GradientDescentTest { bias = new Bias(new Weight(0)); - Neuron neuron = new SimplePerceptron(syns, bias, new Linear()); + Neuron neuron = new Neuron(syns, bias, new Linear()); Layer layer = new Layer(List.of(neuron)); network = new Network(List.of(layer)); diff --git a/src/test/java/perceptron/SimplePerceptronTest.java b/src/test/java/perceptron/SimplePerceptronTest.java index 13ac89f..4615d10 100644 --- a/src/test/java/perceptron/SimplePerceptronTest.java +++ b/src/test/java/perceptron/SimplePerceptronTest.java @@ -1,12 +1,11 @@ package perceptron; -import com.naaturel.ANN.domain.abstraction.Neuron; +import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.model.dataset.DataSet; import com.naaturel.ANN.domain.model.dataset.DatasetExtractor; import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.training.TrainingPipeline; -import com.naaturel.ANN.implementation.neuron.SimplePerceptron; import com.naaturel.ANN.implementation.simplePerceptron.*; import com.naaturel.ANN.implementation.training.steps.*; import org.junit.jupiter.api.BeforeEach; @@ -32,7 +31,7 @@ public class SimplePerceptronTest { @BeforeEach public void init(){ dataset = new DatasetExtractor() - .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv"); + .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv", 1); List syns = new ArrayList<>(); syns.add(new Synapse(new Input(0), new Weight(0))); @@ -40,7 +39,7 @@ public class SimplePerceptronTest { bias = new Bias(new Weight(0)); - Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside()); + Neuron neuron = new Neuron(syns, bias, new Heaviside()); Layer layer = new Layer(List.of(neuron)); network = new Network(List.of(layer));