Rename some stuff

This commit is contained in:
2026-03-29 21:32:08 +02:00
parent 83526b72d4
commit 0fe309cd4e
38 changed files with 334 additions and 215 deletions

View File

@@ -3,13 +3,10 @@ package com.naaturel.ANN;
import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.implementation.multiLayers.Sigmoid; import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining; import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.implementation.gradientDescent.Linear;
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
import java.util.*; import java.util.*;
@@ -20,7 +17,7 @@ public class Main {
int nbrInput = 2; int nbrInput = 2;
int nbrClass = 3; int nbrClass = 3;
int nbrLayers = 1; int nbrLayers = 2;
DataSet dataset = new DatasetExtractor() DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv", nbrClass); .extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv", nbrClass);
@@ -44,7 +41,7 @@ public class Main {
Layer layer = new Layer(neurons); Layer layer = new Layer(neurons);
layers.add(layer); layers.add(layer);
} }
Network network = new Network(layers); FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
Trainer trainer = new GradientBackpropagationTraining(); Trainer trainer = new GradientBackpropagationTraining();
trainer.train(network, dataset); trainer.train(network, dataset);

View File

@@ -5,5 +5,6 @@ import com.naaturel.ANN.domain.model.neuron.Neuron;
public interface ActivationFunction { public interface ActivationFunction {
float accept(Neuron n); float accept(Neuron n);
float derivative(float value);
} }

View File

@@ -1,8 +1,8 @@
package com.naaturel.ANN.domain.abstraction; package com.naaturel.ANN.domain.abstraction;
@FunctionalInterface @FunctionalInterface
public interface AlgorithmStrategy { public interface AlgorithmStep {
void apply(); void run();
} }

View File

@@ -5,13 +5,14 @@ import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.model.neuron.Synapse; import com.naaturel.ANN.domain.model.neuron.Synapse;
import java.util.List; import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Consumer; import java.util.function.Consumer;
public interface Model { public interface Model {
int synCount(); int synCount();
int neuronCount();
void forEachNeuron(Consumer<Neuron> consumer); void forEachNeuron(Consumer<Neuron> consumer);
void forEachSynapse(Consumer<Synapse> consumer); void forEachSynapse(Consumer<Synapse> consumer);
void forEachOutputNeurons(Consumer<Neuron> consumer);
void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer);
List<Float> predict(List<Input> inputs); List<Float> predict(List<Input> inputs);
} }

View File

@@ -0,0 +1,10 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import java.util.function.Consumer;
public interface Network {
}

View File

@@ -1,7 +1,7 @@
package com.naaturel.ANN.domain.abstraction; package com.naaturel.ANN.domain.abstraction;
public interface TrainingStep { /*public interface TrainingStep {
void run(); void run();
} }*/

View File

@@ -0,0 +1,84 @@
package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Network;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
/**
* Represents a fully connected neural network
*/
public class FullyConnectedNetwork implements Model {
private final List<Layer> layers;;
private final Map<Neuron, List<Neuron>> connectionMap;
public FullyConnectedNetwork(List<Layer> layers) {
this.layers = layers;
this.connectionMap = this.createConnectionMap();
}
@Override
public List<Float> predict(List<Input> inputs) {
List<Input> previousLayerOutputs = new ArrayList<>(inputs);
for(Layer layer : this.layers){
List<Float> currentLayerOutputs = layer.predict(previousLayerOutputs);
previousLayerOutputs = currentLayerOutputs.stream().map(Input::new).toList();
}
return previousLayerOutputs.stream().map(Input::getValue).toList();
}
@Override
public int synCount() {
int res = 0;
for(Layer layer : this.layers){
res += layer.synCount();
}
return res;
}
@Override
public int neuronCount() {
int res = 0;
for(Layer layer : this.layers){
res += layer.neuronCount();
}
return res;
}
@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
this.layers.forEach(layer -> layer.forEachSynapse(consumer));
}
@Override
public void forEachNeuron(Consumer<Neuron> consumer) {
this.layers.forEach(layer -> layer.forEachNeuron(consumer));
}
@Override
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
this.layers.getLast().forEachNeuron(consumer);
}
@Override
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
this.connectionMap.get(n).forEach(consumer);
}
private Map<Neuron, List<Neuron>> createConnectionMap() {
Map<Neuron, List<Neuron>> res = new HashMap<>();
for (int i = 0; i < this.layers.size() - 1; i++) {
List<Neuron> nextLayerNeurons = new ArrayList<>();
this.layers.get(i + 1).forEachNeuron(nextLayerNeurons::add);
this.layers.get(i).forEachNeuron(n -> res.put(n, nextLayerNeurons));
}
return res;
}
}

View File

@@ -33,6 +33,11 @@ public class Layer implements Model {
return res; return res;
} }
@Override
public int neuronCount() {
return this.neurons.size();
}
@Override @Override
public void forEachNeuron(Consumer<Neuron> consumer) { public void forEachNeuron(Consumer<Neuron> consumer) {
this.neurons.forEach(consumer); this.neurons.forEach(consumer);
@@ -42,4 +47,14 @@ public class Layer implements Model {
public void forEachSynapse(Consumer<Synapse> consumer) { public void forEachSynapse(Consumer<Synapse> consumer) {
this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer)); this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer));
} }
@Override
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
this.neurons.forEach(consumer);
}
@Override
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
throw new UnsupportedOperationException("Neurons have no connection within the same layer");
}
} }

View File

@@ -1,48 +0,0 @@
package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.Model;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
/**
* Represents a fully connected neural network
*/
public class Network implements Model {
private final List<Layer> layers;
public Network(List<Layer> layers) {
this.layers = layers;
}
@Override
public List<Float> predict(List<Input> inputs) {
List<Input> previousLayerOutput = new ArrayList<>(inputs);
for(Layer layer : this.layers){
List<Float> currentLayerOutput = layer.predict(previousLayerOutput);
previousLayerOutput = currentLayerOutput.stream().map(Input::new).toList();
}
return previousLayerOutput.stream().map(Input::getValue).toList();
}
@Override
public int synCount() {
int res = 0;
for(Layer layer : this.layers){
res += layer.synCount();
}
return res;
}
@Override
public void forEachNeuron(Consumer<Neuron> consumer) {
this.layers.forEach(layer -> layer.forEachNeuron(consumer));
}
@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
this.layers.forEach(layer -> layer.forEachSynapse(consumer));
}
}

View File

@@ -3,7 +3,6 @@ import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import java.util.List; import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Consumer; import java.util.function.Consumer;
public class Neuron implements Model { public class Neuron implements Model {
@@ -11,11 +10,13 @@ public class Neuron implements Model {
protected List<Synapse> synapses; protected List<Synapse> synapses;
protected Bias bias; protected Bias bias;
protected ActivationFunction activationFunction; protected ActivationFunction activationFunction;
protected Float output;
public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){ public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){
this.synapses = synapses; this.synapses = synapses;
this.bias = bias; this.bias = bias;
this.activationFunction = func; this.activationFunction = func;
this.output = 0F;
} }
public void updateBias(Weight weight) { public void updateBias(Weight weight) {
@@ -33,15 +34,38 @@ public class Neuron implements Model {
} }
} }
public ActivationFunction getActivationFunction(){
return this.activationFunction;
}
public float getOutput(){
return this.output;
}
public float calculateWeightedSum() {
float res = 0;
res += this.bias.getWeight() * this.bias.getInput();
for(Synapse syn : this.synapses){
res += syn.getWeight() * syn.getInput();
}
return res;
}
@Override @Override
public int synCount() { public int synCount() {
return this.synapses.size()+1; //take the bias in account return this.synapses.size()+1; //take the bias into account
}
@Override
public int neuronCount() {
return 1;
} }
@Override @Override
public List<Float> predict(List<Input> inputs) { public List<Float> predict(List<Input> inputs) {
this.setInputs(inputs); this.setInputs(inputs);
return List.of(activationFunction.accept(this)); this.output = activationFunction.accept(this);
return List.of(output);
} }
@Override @Override
@@ -55,13 +79,13 @@ public class Neuron implements Model {
this.synapses.forEach(consumer); this.synapses.forEach(consumer);
} }
public float calculateWeightedSum() { @Override
float res = 0; public void forEachOutputNeurons(Consumer<Neuron> consumer) {
res += this.bias.getWeight() * this.bias.getInput(); consumer.accept(this);
for(Synapse syn : this.synapses){
res += syn.getWeight() * syn.getInput();
}
return res;
} }
@Override
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
throw new UnsupportedOperationException("Neurons have no connection with themselves");
}
} }

View File

@@ -1,5 +1,6 @@
package com.naaturel.ANN.domain.model.training; package com.naaturel.ANN.domain.model.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
@@ -15,7 +16,7 @@ import java.util.function.Predicate;
public class TrainingPipeline { public class TrainingPipeline {
private final List<TrainingStep> steps; private final List<AlgorithmStep> steps;
private Consumer<TrainingContext> beforeEpoch; private Consumer<TrainingContext> beforeEpoch;
private Consumer<TrainingContext> afterEpoch; private Consumer<TrainingContext> afterEpoch;
private Predicate<TrainingContext> stopCondition; private Predicate<TrainingContext> stopCondition;
@@ -25,7 +26,7 @@ public class TrainingPipeline {
private boolean visualization; private boolean visualization;
private boolean timeMeasurement; private boolean timeMeasurement;
public TrainingPipeline(List<TrainingStep> steps) { public TrainingPipeline(List<AlgorithmStep> steps) {
this.steps = new ArrayList<>(steps); this.steps = new ArrayList<>(steps);
this.stopCondition = (ctx) -> false; this.stopCondition = (ctx) -> false;
this.beforeEpoch = (context -> {}); this.beforeEpoch = (context -> {});
@@ -90,7 +91,7 @@ public class TrainingPipeline {
ctx.currentEntry = entry; ctx.currentEntry = entry;
ctx.expectations = ctx.dataset.getLabelsAsFloat(entry); ctx.expectations = ctx.dataset.getLabelsAsFloat(entry);
for (TrainingStep step : steps) { for (AlgorithmStep step : steps) {
step.run(); step.run();
} }

View File

@@ -1,10 +1,10 @@
package com.naaturel.ANN.implementation.gradientDescent; package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
public class GradientDescentCorrectionStrategy implements AlgorithmStrategy { public class GradientDescentCorrectionStrategy implements AlgorithmStep {
private final GradientDescentTrainingContext context; private final GradientDescentTrainingContext context;
@@ -13,7 +13,7 @@ public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
} }
@Override @Override
public void apply() { public void run() {
AtomicInteger i = new AtomicInteger(0); AtomicInteger i = new AtomicInteger(0);
context.model.forEachSynapse(syn -> { context.model.forEachSynapse(syn -> {
float corrector = context.correctorTerms.get(i.get()); float corrector = context.correctorTerms.get(i.get());

View File

@@ -1,10 +1,10 @@
package com.naaturel.ANN.implementation.gradientDescent; package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
public class GradientDescentErrorStrategy implements AlgorithmStrategy { public class GradientDescentErrorStrategy implements AlgorithmStep {
private final GradientDescentTrainingContext context; private final GradientDescentTrainingContext context;
@@ -14,7 +14,7 @@ public class GradientDescentErrorStrategy implements AlgorithmStrategy {
@Override @Override
public void apply() { public void run() {
AtomicInteger neuronIndex = new AtomicInteger(0); AtomicInteger neuronIndex = new AtomicInteger(0);
AtomicInteger synIndex = new AtomicInteger(0); AtomicInteger synIndex = new AtomicInteger(0);

View File

@@ -5,9 +5,22 @@ import com.naaturel.ANN.domain.model.neuron.Neuron;
public class Linear implements ActivationFunction { public class Linear implements ActivationFunction {
private final float slope;
private final float intercept;
public Linear(float slope, float intercept) {
this.slope = slope;
this.intercept = intercept;
}
@Override @Override
public float accept(Neuron n) { public float accept(Neuron n) {
return n.calculateWeightedSum(); return slope * n.calculateWeightedSum() + intercept;
}
@Override
public float derivative(float value) {
return this.slope;
} }
} }

View File

@@ -1,21 +1,20 @@
package com.naaturel.ANN.implementation.gradientDescent; package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
import java.util.stream.Stream; import java.util.stream.Stream;
public class SquareLossStrategy implements AlgorithmStrategy { public class SquareLossStep implements AlgorithmStep {
private final TrainingContext context; private final TrainingContext context;
public SquareLossStrategy(TrainingContext context) { public SquareLossStep(TrainingContext context) {
this.context = context; this.context = context;
} }
@Override @Override
public void apply() { public void run() {
Stream<Float> deltaStream = this.context.deltas.stream(); Stream<Float> deltaStream = this.context.deltas.stream();
this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2))); this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2)));
this.context.localLoss /= 2; this.context.localLoss /= 2;

View File

@@ -2,5 +2,14 @@ package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingContext;
import java.util.ArrayList;
import java.util.List;
public class GradientBackpropagationContext extends TrainingContext { public class GradientBackpropagationContext extends TrainingContext {
public List<Float> hiddenDeltas;
public GradientBackpropagationContext(){
}
} }

View File

@@ -0,0 +1,24 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.model.neuron.Neuron;
public class GradientBackpropagationStep implements AlgorithmStep {
private GradientBackpropagationContext context;
public GradientBackpropagationStep(GradientBackpropagationContext context) {
this.context = context;
}
@Override
public void run() {
}
private float calculateDeltaRecursive(Neuron n){
}
}

View File

@@ -1,17 +0,0 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
public class GradientBackpropagationStrategy implements AlgorithmStrategy {
private GradientBackpropagationContext context;
public GradientBackpropagationStrategy(GradientBackpropagationContext context) {
this.context = context;
}
@Override
public void apply() {
}
}

View File

@@ -15,4 +15,9 @@ public class Sigmoid implements ActivationFunction {
public float accept(Neuron n) { public float accept(Neuron n) {
return (float) (1.0/(1.0 + Math.exp(-steepness * n.calculateWeightedSum()))); return (float) (1.0/(1.0 + Math.exp(-steepness * n.calculateWeightedSum())));
} }
@Override
public float derivative(float value) {
return steepness * value * (1 - value);
}
} }

View File

@@ -14,4 +14,8 @@ public class TanH implements ActivationFunction {
return (float)(res); return (float)(res);
} }
@Override
public float derivative(float value) {
return 1 - value * value;
}
} }

View File

@@ -3,6 +3,8 @@ package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.ActivationFunction; import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.model.neuron.Neuron; import com.naaturel.ANN.domain.model.neuron.Neuron;
import javax.naming.OperationNotSupportedException;
public class Heaviside implements ActivationFunction { public class Heaviside implements ActivationFunction {
public Heaviside(){ public Heaviside(){
@@ -14,4 +16,9 @@ public class Heaviside implements ActivationFunction {
float weightedSum = n.calculateWeightedSum(); float weightedSum = n.calculateWeightedSum();
return weightedSum < 0 ? 0:1; return weightedSum < 0 ? 0:1;
} }
@Override
public float derivative(float value) {
throw new UnsupportedOperationException("Heaviside is not differentiable");
}
} }

View File

@@ -1,21 +1,21 @@
package com.naaturel.ANN.implementation.simplePerceptron; package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingContext;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
public class SimpleCorrectionStrategy implements AlgorithmStrategy { public class SimpleCorrectionStep implements AlgorithmStep {
private final TrainingContext context; private final TrainingContext context;
public SimpleCorrectionStrategy(TrainingContext context) { public SimpleCorrectionStep(TrainingContext context) {
this.context = context; this.context = context;
} }
@Override @Override
public void apply() { public void run() {
if(context.expectations.equals(context.predictions)) return; if(context.expectations.equals(context.predictions)) return;
AtomicInteger neuronIndex = new AtomicInteger(0); AtomicInteger neuronIndex = new AtomicInteger(0);
AtomicInteger synIndex = new AtomicInteger(0); AtomicInteger synIndex = new AtomicInteger(0);

View File

@@ -1,6 +1,6 @@
package com.naaturel.ANN.implementation.simplePerceptron; package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry; import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
@@ -9,16 +9,16 @@ import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
public class SimpleDeltaStrategy implements AlgorithmStrategy { public class SimpleDeltaStep implements AlgorithmStep {
private final TrainingContext context; private final TrainingContext context;
public SimpleDeltaStrategy(TrainingContext context) { public SimpleDeltaStep(TrainingContext context) {
this.context = context; this.context = context;
} }
@Override @Override
public void apply() { public void run() {
DataSet dataSet = context.dataset; DataSet dataSet = context.dataset;
DataSetEntry entry = context.currentEntry; DataSetEntry entry = context.currentEntry;
List<Float> predicted = context.predictions; List<Float> predicted = context.predictions;
@@ -28,7 +28,6 @@ public class SimpleDeltaStrategy implements AlgorithmStrategy {
context.deltas = IntStream.range(0, predicted.size()) context.deltas = IntStream.range(0, predicted.size())
.mapToObj(i -> expected.get(i) - predicted.get(i)) .mapToObj(i -> expected.get(i) - predicted.get(i))
.collect(Collectors.toList()); .collect(Collectors.toList());
System.out.printf("");
} }
} }

View File

@@ -1,18 +1,18 @@
package com.naaturel.ANN.implementation.simplePerceptron; package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingContext;
public class SimpleErrorRegistrationStrategy implements AlgorithmStrategy { public class SimpleErrorRegistrationStep implements AlgorithmStep {
private final TrainingContext context; private final TrainingContext context;
public SimpleErrorRegistrationStrategy(TrainingContext context) { public SimpleErrorRegistrationStep(TrainingContext context) {
this.context = context; this.context = context;
} }
@Override @Override
public void apply() { public void run() {
context.globalLoss += context.localLoss; context.globalLoss += context.localLoss;
} }
} }

View File

@@ -1,8 +1,8 @@
package com.naaturel.ANN.implementation.simplePerceptron; package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
public class SimpleLossStrategy implements AlgorithmStrategy { public class SimpleLossStrategy implements AlgorithmStep {
private final SimpleTrainingContext context; private final SimpleTrainingContext context;
@@ -11,7 +11,7 @@ public class SimpleLossStrategy implements AlgorithmStrategy {
} }
@Override @Override
public void apply() { public void run() {
this.context.localLoss = this.context.deltas.stream().reduce(0.0F, Float::sum); this.context.localLoss = this.context.deltas.stream().reduce(0.0F, Float::sum);
} }
} }

View File

@@ -1,20 +1,18 @@
package com.naaturel.ANN.implementation.simplePerceptron; package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext; import com.naaturel.ANN.domain.abstraction.TrainingContext;
import java.util.List; public class SimplePredictionStep implements AlgorithmStep {
public class SimplePredictionStrategy implements AlgorithmStrategy {
private final TrainingContext context; private final TrainingContext context;
public SimplePredictionStrategy(TrainingContext context) { public SimplePredictionStep(TrainingContext context) {
this.context = context; this.context = context;
} }
@Override @Override
public void apply() { public void run() {
context.predictions = context.model.predict(context.currentEntry.getData()); context.predictions = context.model.predict(context.currentEntry.getData());
} }
} }

View File

@@ -1,17 +1,16 @@
package com.naaturel.ANN.implementation.training; package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext; import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy; import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.implementation.training.steps.*;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
import java.util.List; import java.util.List;
@@ -30,12 +29,12 @@ public class AdalineTraining implements Trainer {
context.model = model; context.model = model;
context.learningRate = 0.003F; context.learningRate = 0.003F;
List<TrainingStep> steps = List.of( List<AlgorithmStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)), new SimplePredictionStep(context),
new DeltaStep(new SimpleDeltaStrategy(context)), new SimpleDeltaStep(context),
new LossStep(new SquareLossStrategy(context)), new SquareLossStep(context),
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)), new SimpleErrorRegistrationStep(context),
new WeightCorrectionStep(new SimpleCorrectionStrategy(context)) new SimpleCorrectionStep(context)
); );
new TrainingPipeline(steps) new TrainingPipeline(steps)

View File

@@ -1,14 +1,13 @@
package com.naaturel.ANN.implementation.training; package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext; import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationStep;
import com.naaturel.ANN.implementation.training.steps.DeltaStep; import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.training.steps.PredictionStep; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import java.util.List; import java.util.List;
@@ -17,14 +16,15 @@ import java.util.List;
public class GradientBackpropagationTraining implements Trainer { public class GradientBackpropagationTraining implements Trainer {
@Override @Override
public void train(Model model, DataSet dataset) { public void train(Model model, DataSet dataset) {
TrainingContext context = new GradientDescentTrainingContext(); GradientBackpropagationContext context = new GradientBackpropagationContext();
context.dataset = dataset; context.dataset = dataset;
context.model = model; context.model = model;
context.learningRate = 0.0008F; context.learningRate = 0.001F;
List<TrainingStep> steps = List.of( List<AlgorithmStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)), new SimplePredictionStep(context),
new DeltaStep() new SimpleDeltaStep(context),
new GradientBackpropagationStep(context)
); );
new TrainingPipeline(steps) new TrainingPipeline(steps)

View File

@@ -1,16 +1,16 @@
package com.naaturel.ANN.implementation.training; package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy; import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext; import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy; import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy; import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.implementation.training.steps.*; import com.naaturel.ANN.implementation.training.steps.*;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer; import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
@@ -31,11 +31,11 @@ public class GradientDescentTraining implements Trainer {
context.learningRate = 0.0008F; context.learningRate = 0.0008F;
context.correctorTerms = new ArrayList<>(); context.correctorTerms = new ArrayList<>();
List<TrainingStep> steps = List.of( List<AlgorithmStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)), new SimplePredictionStep(context),
new DeltaStep(new SimpleDeltaStrategy(context)), new SimpleDeltaStep(context),
new LossStep(new SquareLossStrategy(context)), new SquareLossStep(context),
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)) new GradientDescentErrorStrategy(context)
); );
new TrainingPipeline(steps) new TrainingPipeline(steps)
@@ -48,7 +48,7 @@ public class GradientDescentTraining implements Trainer {
}) })
.afterEpoch(ctx -> { .afterEpoch(ctx -> {
context.globalLoss /= context.dataset.size(); context.globalLoss /= context.dataset.size();
new GradientDescentCorrectionStrategy(context).apply(); new GradientDescentCorrectionStrategy(context).run();
}) })
//.withVerbose(true) //.withVerbose(true)
.withTimeMeasurement(true) .withTimeMeasurement(true)

View File

@@ -1,8 +1,8 @@
package com.naaturel.ANN.implementation.training; package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model; import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer; import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.implementation.simplePerceptron.*; import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
@@ -23,12 +23,12 @@ public class SimpleTraining implements Trainer {
context.model = model; context.model = model;
context.learningRate = 0.3F; context.learningRate = 0.3F;
List<TrainingStep> steps = List.of( List<AlgorithmStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)), new SimplePredictionStep(context),
new DeltaStep(new SimpleDeltaStrategy(context)), new SimpleDeltaStep(context),
new LossStep(new SimpleLossStrategy(context)), new SimpleLossStrategy(context),
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)), new SimpleErrorRegistrationStep(context),
new WeightCorrectionStep(new SimpleCorrectionStrategy(context)) new SimpleCorrectionStep(context)
); );
TrainingPipeline pipeline = new TrainingPipeline(steps); TrainingPipeline pipeline = new TrainingPipeline(steps);

View File

@@ -1,18 +1,18 @@
package com.naaturel.ANN.implementation.training.steps; package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.abstraction.TrainingStep;
public class DeltaStep implements TrainingStep { public class DeltaStep implements TrainingStep {
private final AlgorithmStrategy strategy; private final AlgorithmStep strategy;
public DeltaStep(AlgorithmStrategy strategy) { public DeltaStep(AlgorithmStep strategy) {
this.strategy = strategy; this.strategy = strategy;
} }
@Override @Override
public void run() { public void run() {
this.strategy.apply(); this.strategy.run();
} }
} }

View File

@@ -1,18 +1,18 @@
package com.naaturel.ANN.implementation.training.steps; package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.abstraction.TrainingStep;
public class ErrorRegistrationStep implements TrainingStep { public class ErrorRegistrationStep implements TrainingStep {
private final AlgorithmStrategy strategy; private final AlgorithmStep strategy;
public ErrorRegistrationStep(AlgorithmStrategy strategy) { public ErrorRegistrationStep(AlgorithmStep strategy) {
this.strategy = strategy; this.strategy = strategy;
} }
@Override @Override
public void run() { public void run() {
this.strategy.apply(); this.strategy.run();
} }
} }

View File

@@ -1,20 +1,19 @@
package com.naaturel.ANN.implementation.training.steps; package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.abstraction.TrainingStep;
public class LossStep implements TrainingStep { public class LossStep implements TrainingStep {
private final AlgorithmStrategy lossStrategy; private final AlgorithmStep lossStrategy;
public LossStep(AlgorithmStrategy strategy) { public LossStep(AlgorithmStep strategy) {
this.lossStrategy = strategy; this.lossStrategy = strategy;
} }
@Override @Override
public void run() { public void run() {
this.lossStrategy.apply(); this.lossStrategy.run();
} }
} }

View File

@@ -1,23 +1,18 @@
package com.naaturel.ANN.implementation.training.steps; package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
import java.util.List;
public class PredictionStep implements TrainingStep { public class PredictionStep implements TrainingStep {
private final AlgorithmStrategy strategy; private final AlgorithmStep strategy;
public PredictionStep(AlgorithmStrategy strategy) { public PredictionStep(AlgorithmStep strategy) {
this.strategy = strategy; this.strategy = strategy;
} }
@Override @Override
public void run() { public void run() {
this.strategy.apply(); this.strategy.run();
} }
} }

View File

@@ -1,18 +1,18 @@
package com.naaturel.ANN.implementation.training.steps; package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy; import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingStep; import com.naaturel.ANN.domain.abstraction.TrainingStep;
public class WeightCorrectionStep implements TrainingStep { public class WeightCorrectionStep implements TrainingStep {
private final AlgorithmStrategy correctionStrategy; private final AlgorithmStep correctionStrategy;
public WeightCorrectionStep(AlgorithmStrategy strategy) { public WeightCorrectionStep(AlgorithmStep strategy) {
this.correctionStrategy = strategy; this.correctionStrategy = strategy;
} }
@Override @Override
public void run() { public void run() {
this.correctionStrategy.apply(); this.correctionStrategy.run();
} }
} }

View File

@@ -9,10 +9,10 @@ import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline; import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext; import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
import com.naaturel.ANN.implementation.gradientDescent.*; import com.naaturel.ANN.implementation.gradientDescent.*;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy; import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.implementation.training.steps.*; import com.naaturel.ANN.implementation.training.steps.*;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@@ -29,7 +29,7 @@ public class AdalineTest {
private List<Synapse> synapses; private List<Synapse> synapses;
private Bias bias; private Bias bias;
private Network network; private FullyConnectedNetwork network;
private TrainingPipeline pipeline; private TrainingPipeline pipeline;
@@ -44,20 +44,20 @@ public class AdalineTest {
bias = new Bias(new Weight(0)); bias = new Bias(new Weight(0));
Neuron neuron = new Neuron(syns, bias, new Linear()); Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
Layer layer = new Layer(List.of(neuron)); Layer layer = new Layer(List.of(neuron));
network = new Network(List.of(layer)); network = new FullyConnectedNetwork(List.of(layer));
context = new AdalineTrainingContext(); context = new AdalineTrainingContext();
context.dataset = dataset; context.dataset = dataset;
context.model = network; context.model = network;
List<TrainingStep> steps = List.of( List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)), new PredictionStep(new SimplePredictionStep(context)),
new DeltaStep(new SimpleDeltaStrategy(context)), new DeltaStep(new SimpleDeltaStep(context)),
new LossStep(new SquareLossStrategy(context)), new LossStep(new SquareLossStep(context)),
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)), new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
new WeightCorrectionStep(new SimpleCorrectionStrategy(context)) new WeightCorrectionStep(new SimpleCorrectionStep(context))
); );
pipeline = new TrainingPipeline(steps) pipeline = new TrainingPipeline(steps)

View File

@@ -25,7 +25,7 @@ public class GradientDescentTest {
private List<Synapse> synapses; private List<Synapse> synapses;
private Bias bias; private Bias bias;
private Network network; private FullyConnectedNetwork network;
private TrainingPipeline pipeline; private TrainingPipeline pipeline;
@@ -40,9 +40,9 @@ public class GradientDescentTest {
bias = new Bias(new Weight(0)); bias = new Bias(new Weight(0));
Neuron neuron = new Neuron(syns, bias, new Linear()); Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
Layer layer = new Layer(List.of(neuron)); Layer layer = new Layer(List.of(neuron));
network = new Network(List.of(layer)); network = new FullyConnectedNetwork(List.of(layer));
context = new GradientDescentTrainingContext(); context = new GradientDescentTrainingContext();
context.dataset = dataset; context.dataset = dataset;
@@ -50,9 +50,9 @@ public class GradientDescentTest {
context.correctorTerms = new ArrayList<>(); context.correctorTerms = new ArrayList<>();
List<TrainingStep> steps = List.of( List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)), new PredictionStep(new SimplePredictionStep(context)),
new DeltaStep(new SimpleDeltaStrategy(context)), new DeltaStep(new SimpleDeltaStep(context)),
new LossStep(new SquareLossStrategy(context)), new LossStep(new SquareLossStep(context)),
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)) new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
); );
@@ -82,7 +82,7 @@ public class GradientDescentTest {
context.learningRate = 0.2F; context.learningRate = 0.2F;
pipeline.afterEpoch(ctx -> { pipeline.afterEpoch(ctx -> {
context.globalLoss /= context.dataset.size(); context.globalLoss /= context.dataset.size();
new GradientDescentCorrectionStrategy(context).apply(); new GradientDescentCorrectionStrategy(context).run();
int index = ctx.epoch-1; int index = ctx.epoch-1;
if(index >= expectedGlobalLosses.size()) return; if(index >= expectedGlobalLosses.size()) return;

View File

@@ -24,7 +24,7 @@ public class SimplePerceptronTest {
private List<Synapse> synapses; private List<Synapse> synapses;
private Bias bias; private Bias bias;
private Network network; private FullyConnectedNetwork network;
private TrainingPipeline pipeline; private TrainingPipeline pipeline;
@@ -41,18 +41,18 @@ public class SimplePerceptronTest {
Neuron neuron = new Neuron(syns, bias, new Heaviside()); Neuron neuron = new Neuron(syns, bias, new Heaviside());
Layer layer = new Layer(List.of(neuron)); Layer layer = new Layer(List.of(neuron));
network = new Network(List.of(layer)); network = new FullyConnectedNetwork(List.of(layer));
context = new SimpleTrainingContext(); context = new SimpleTrainingContext();
context.dataset = dataset; context.dataset = dataset;
context.model = network; context.model = network;
List<TrainingStep> steps = List.of( List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)), new PredictionStep(new SimplePredictionStep(context)),
new DeltaStep(new SimpleDeltaStrategy(context)), new DeltaStep(new SimpleDeltaStep(context)),
new LossStep(new SimpleLossStrategy(context)), new LossStep(new SimpleLossStrategy(context)),
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)), new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
new WeightCorrectionStep(new SimpleCorrectionStrategy(context)) new WeightCorrectionStep(new SimpleCorrectionStep(context))
); );
pipeline = new TrainingPipeline(steps); pipeline = new TrainingPipeline(steps);