Rename AlgorithmStrategy to AlgorithmStep, the Strategy step classes to *Step, and Network to FullyConnectedNetwork

This commit is contained in:
2026-03-29 21:32:08 +02:00
parent 83526b72d4
commit 0fe309cd4e
38 changed files with 334 additions and 215 deletions

View File

@@ -3,13 +3,10 @@ package com.naaturel.ANN;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.implementation.gradientDescent.Linear;
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
import java.util.*;
@@ -20,7 +17,7 @@ public class Main {
int nbrInput = 2;
int nbrClass = 3;
int nbrLayers = 1;
int nbrLayers = 2;
DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv", nbrClass);
@@ -44,7 +41,7 @@ public class Main {
Layer layer = new Layer(neurons);
layers.add(layer);
}
Network network = new Network(layers);
FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(network, dataset);

View File

@@ -5,5 +5,6 @@ import com.naaturel.ANN.domain.model.neuron.Neuron;
/**
 * Contract for a neuron activation function and its derivative, used by
 * gradient-based training.
 */
public interface ActivationFunction {
/** Computes the activation for neuron {@code n} (implementations read its weighted sum). */
float accept(Neuron n);
/** Derivative of the activation; implementations (Sigmoid, TanH) take {@code value} as the activation OUTPUT, not the raw input — TODO confirm caller contract. */
float derivative(float value);
}

View File

@@ -1,8 +1,8 @@
package com.naaturel.ANN.domain.abstraction;
@FunctionalInterface
public interface AlgorithmStrategy {
public interface AlgorithmStep {
void apply();
void run();
}

View File

@@ -5,13 +5,14 @@ import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
/**
 * Read-only traversal and prediction contract shared by whole networks,
 * layers and single neurons (all three implement this interface in SOURCE).
 */
public interface Model {
/** Total number of synapses, bias connections included. */
int synCount();
/** Total number of neurons in this model. */
int neuronCount();
/** Visits every neuron in the model. */
void forEachNeuron(Consumer<Neuron> consumer);
/** Visits every synapse in the model. */
void forEachSynapse(Consumer<Synapse> consumer);
/** Visits only the output (last-layer) neurons. */
void forEachOutputNeurons(Consumer<Neuron> consumer);
/** Visits the neurons directly downstream of {@code n}; may be unsupported for leaf models. */
void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer);
/** Feeds {@code inputs} forward and returns the output values. */
List<Float> predict(List<Input> inputs);
}

View File

@@ -0,0 +1,10 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import java.util.function.Consumer;
// NOTE(review): currently an empty marker interface. The Neuron/Consumer
// imports in this file suggest traversal methods were planned but not yet
// added — confirm intent before relying on this type.
public interface Network {
}

View File

@@ -1,7 +1,7 @@
package com.naaturel.ANN.domain.abstraction;
public interface TrainingStep {
/*public interface TrainingStep {
void run();
}
}*/

View File

@@ -0,0 +1,84 @@
package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Network;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
/**
* Represents a fully connected neural network
*/
/**
 * Represents a fully connected neural network: an ordered list of layers in
 * which every neuron of layer i feeds every neuron of layer i+1.
 */
public class FullyConnectedNetwork implements Model {

    private final List<Layer> layers; // was "layers;;" — removed the stray empty declaration
    // Maps each non-output neuron to the neurons of the next layer it feeds.
    private final Map<Neuron, List<Neuron>> connectionMap;

    public FullyConnectedNetwork(List<Layer> layers) {
        this.layers = layers;
        this.connectionMap = this.createConnectionMap();
    }

    /**
     * Feeds the inputs through each layer in order; the outputs of one layer
     * become the inputs of the next.
     */
    @Override
    public List<Float> predict(List<Input> inputs) {
        List<Input> previousLayerOutputs = new ArrayList<>(inputs);
        for (Layer layer : this.layers) {
            List<Float> currentLayerOutputs = layer.predict(previousLayerOutputs);
            previousLayerOutputs = currentLayerOutputs.stream().map(Input::new).toList();
        }
        return previousLayerOutputs.stream().map(Input::getValue).toList();
    }

    /** Total synapse count over all layers (bias connections are counted by Neuron.synCount). */
    @Override
    public int synCount() {
        int res = 0;
        for (Layer layer : this.layers) {
            res += layer.synCount();
        }
        return res;
    }

    /** Total neuron count over all layers. */
    @Override
    public int neuronCount() {
        int res = 0;
        for (Layer layer : this.layers) {
            res += layer.neuronCount();
        }
        return res;
    }

    @Override
    public void forEachSynapse(Consumer<Synapse> consumer) {
        this.layers.forEach(layer -> layer.forEachSynapse(consumer));
    }

    @Override
    public void forEachNeuron(Consumer<Neuron> consumer) {
        this.layers.forEach(layer -> layer.forEachNeuron(consumer));
    }

    /** The output neurons are those of the last layer (getLast requires Java 21+). */
    @Override
    public void forEachOutputNeurons(Consumer<Neuron> consumer) {
        this.layers.getLast().forEachNeuron(consumer);
    }

    /**
     * Visits every neuron directly downstream of {@code n}. Output-layer
     * neurons have no entry in the map; the original code dereferenced the
     * null lookup and threw NullPointerException for them — treat them as
     * having no downstream connections instead.
     */
    @Override
    public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
        this.connectionMap.getOrDefault(n, List.of()).forEach(consumer);
    }

    /**
     * Builds the downstream-connection map. Every neuron of layer i shares the
     * same list of layer i+1 neurons; the shared list is never modified after
     * construction, so sharing one list per layer pair is safe.
     */
    private Map<Neuron, List<Neuron>> createConnectionMap() {
        Map<Neuron, List<Neuron>> res = new HashMap<>();
        for (int i = 0; i < this.layers.size() - 1; i++) {
            List<Neuron> nextLayerNeurons = new ArrayList<>();
            this.layers.get(i + 1).forEachNeuron(nextLayerNeurons::add);
            this.layers.get(i).forEachNeuron(n -> res.put(n, nextLayerNeurons));
        }
        return res;
    }
}

View File

@@ -33,6 +33,11 @@ public class Layer implements Model {
return res;
}
@Override
// A layer's neuron count is simply the size of its neuron list.
public int neuronCount() {
return this.neurons.size();
}
@Override
public void forEachNeuron(Consumer<Neuron> consumer) {
this.neurons.forEach(consumer);
@@ -42,4 +47,14 @@ public class Layer implements Model {
public void forEachSynapse(Consumer<Synapse> consumer) {
this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer));
}
@Override
// For a standalone layer, every neuron is an output neuron.
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
this.neurons.forEach(consumer);
}
@Override
// Intra-layer connections do not exist in this architecture, so this
// traversal is deliberately unsupported at the layer level.
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
throw new UnsupportedOperationException("Neurons have no connection within the same layer");
}
}

View File

@@ -1,48 +0,0 @@
package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.Model;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
/**
* Represents a fully connected neural network
*/
public class Network implements Model {
// Ordered layers; index 0 receives the raw inputs.
private final List<Layer> layers;
public Network(List<Layer> layers) {
this.layers = layers;
}
@Override
// Feeds the inputs through each layer in order; each layer's outputs are
// wrapped back into Input objects for the next layer.
public List<Float> predict(List<Input> inputs) {
List<Input> previousLayerOutput = new ArrayList<>(inputs);
for(Layer layer : this.layers){
List<Float> currentLayerOutput = layer.predict(previousLayerOutput);
previousLayerOutput = currentLayerOutput.stream().map(Input::new).toList();
}
return previousLayerOutput.stream().map(Input::getValue).toList();
}
@Override
// Sum of synapse counts over all layers.
public int synCount() {
int res = 0;
for(Layer layer : this.layers){
res += layer.synCount();
}
return res;
}
@Override
public void forEachNeuron(Consumer<Neuron> consumer) {
this.layers.forEach(layer -> layer.forEachNeuron(consumer));
}
@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
this.layers.forEach(layer -> layer.forEachSynapse(consumer));
}
}

View File

@@ -3,7 +3,6 @@ import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Model;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
public class Neuron implements Model {
@@ -11,11 +10,13 @@ public class Neuron implements Model {
protected List<Synapse> synapses;
protected Bias bias;
protected ActivationFunction activationFunction;
protected Float output;
public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){
this.synapses = synapses;
this.bias = bias;
this.activationFunction = func;
this.output = 0F;
}
public void updateBias(Weight weight) {
@@ -33,15 +34,38 @@ public class Neuron implements Model {
}
}
/** Returns the activation function assigned to this neuron at construction. */
public ActivationFunction getActivationFunction(){
return this.activationFunction;
}
/** Last activation value produced by predict(); 0 before the first prediction. */
public float getOutput(){
return this.output;
}
/**
 * Weighted sum of all inputs: bias contribution plus each synapse's
 * weight * input. Activation functions call this before applying their curve.
 */
public float calculateWeightedSum() {
float res = 0;
res += this.bias.getWeight() * this.bias.getInput();
for(Synapse syn : this.synapses){
res += syn.getWeight() * syn.getInput();
}
return res;
}
@Override
public int synCount() {
return this.synapses.size()+1; //take the bias in account
return this.synapses.size()+1; //take the bias into account
}
@Override
// A single neuron counts as exactly one.
public int neuronCount() {
return 1;
}
@Override
public List<Float> predict(List<Input> inputs) {
this.setInputs(inputs);
return List.of(activationFunction.accept(this));
this.output = activationFunction.accept(this);
return List.of(output);
}
@Override
@@ -55,13 +79,13 @@ public class Neuron implements Model {
this.synapses.forEach(consumer);
}
public float calculateWeightedSum() {
float res = 0;
res += this.bias.getWeight() * this.bias.getInput();
for(Synapse syn : this.synapses){
res += syn.getWeight() * syn.getInput();
}
return res;
@Override
// A lone neuron is its own output neuron.
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
consumer.accept(this);
}
@Override
// A single neuron has no downstream peers; this traversal only makes
// sense at the network level.
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
throw new UnsupportedOperationException("Neurons have no connection with themselves");
}
}

View File

@@ -1,5 +1,6 @@
package com.naaturel.ANN.domain.model.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
@@ -15,7 +16,7 @@ import java.util.function.Predicate;
public class TrainingPipeline {
private final List<TrainingStep> steps;
private final List<AlgorithmStep> steps;
private Consumer<TrainingContext> beforeEpoch;
private Consumer<TrainingContext> afterEpoch;
private Predicate<TrainingContext> stopCondition;
@@ -25,7 +26,7 @@ public class TrainingPipeline {
private boolean visualization;
private boolean timeMeasurement;
public TrainingPipeline(List<TrainingStep> steps) {
public TrainingPipeline(List<AlgorithmStep> steps) {
this.steps = new ArrayList<>(steps);
this.stopCondition = (ctx) -> false;
this.beforeEpoch = (context -> {});
@@ -90,7 +91,7 @@ public class TrainingPipeline {
ctx.currentEntry = entry;
ctx.expectations = ctx.dataset.getLabelsAsFloat(entry);
for (TrainingStep step : steps) {
for (AlgorithmStep step : steps) {
step.run();
}

View File

@@ -1,10 +1,10 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import java.util.concurrent.atomic.AtomicInteger;
public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
public class GradientDescentCorrectionStrategy implements AlgorithmStep {
private final GradientDescentTrainingContext context;
@@ -13,7 +13,7 @@ public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
}
@Override
public void apply() {
public void run() {
AtomicInteger i = new AtomicInteger(0);
context.model.forEachSynapse(syn -> {
float corrector = context.correctorTerms.get(i.get());

View File

@@ -1,10 +1,10 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import java.util.concurrent.atomic.AtomicInteger;
public class GradientDescentErrorStrategy implements AlgorithmStrategy {
public class GradientDescentErrorStrategy implements AlgorithmStep {
private final GradientDescentTrainingContext context;
@@ -14,7 +14,7 @@ public class GradientDescentErrorStrategy implements AlgorithmStrategy {
@Override
public void apply() {
public void run() {
AtomicInteger neuronIndex = new AtomicInteger(0);
AtomicInteger synIndex = new AtomicInteger(0);

View File

@@ -5,9 +5,22 @@ import com.naaturel.ANN.domain.model.neuron.Neuron;
public class Linear implements ActivationFunction {
private final float slope;
private final float intercept;
public Linear(float slope, float intercept) {
this.slope = slope;
this.intercept = intercept;
}
@Override
public float accept(Neuron n) {
return n.calculateWeightedSum();
return slope * n.calculateWeightedSum() + intercept;
}
@Override
public float derivative(float value) {
return this.slope;
}
}

View File

@@ -1,21 +1,20 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
import java.util.stream.Stream;
public class SquareLossStrategy implements AlgorithmStrategy {
public class SquareLossStep implements AlgorithmStep {
private final TrainingContext context;
public SquareLossStrategy(TrainingContext context) {
public SquareLossStep(TrainingContext context) {
this.context = context;
}
@Override
public void apply() {
public void run() {
Stream<Float> deltaStream = this.context.deltas.stream();
this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2)));
this.context.localLoss /= 2;

View File

@@ -2,5 +2,14 @@ package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import java.util.ArrayList;
import java.util.List;
public class GradientBackpropagationContext extends TrainingContext {
public List<Float> hiddenDeltas;
public GradientBackpropagationContext(){
}
}

View File

@@ -0,0 +1,24 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.model.neuron.Neuron;
/**
 * Training step intended to backpropagate output-layer deltas into the hidden
 * layers. NOTE(review): this is a stub — run() does nothing yet; the hidden
 * deltas in {@link GradientBackpropagationContext#hiddenDeltas} are never
 * populated, so training with this step is a no-op until it is implemented.
 */
public class GradientBackpropagationStep implements AlgorithmStep {

    private final GradientBackpropagationContext context; // made final: never reassigned

    public GradientBackpropagationStep(GradientBackpropagationContext context) {
        this.context = context;
    }

    @Override
    public void run() {
        // TODO: walk the network backwards and fill context.hiddenDeltas via
        // calculateDeltaRecursive. Intentionally a no-op for now.
    }

    private float calculateDeltaRecursive(Neuron n) {
        // The original body was empty, which does not compile for a
        // float-returning method. Fail fast until backpropagation is written.
        throw new UnsupportedOperationException(
                "Backpropagation delta computation not yet implemented");
    }
}

View File

@@ -1,17 +0,0 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
public class GradientBackpropagationStrategy implements AlgorithmStrategy {
private GradientBackpropagationContext context;
public GradientBackpropagationStrategy(GradientBackpropagationContext context) {
this.context = context;
}
@Override
public void apply() {
}
}

View File

@@ -15,4 +15,9 @@ public class Sigmoid implements ActivationFunction {
public float accept(Neuron n) {
return (float) (1.0/(1.0 + Math.exp(-steepness * n.calculateWeightedSum())));
}
@Override
// Derivative expressed in terms of the sigmoid OUTPUT: with s = accept(n),
// d/dx sigmoid(kx) = k * s * (1 - s). Assumes `value` is the activation
// output, not the raw weighted sum — TODO confirm caller contract.
public float derivative(float value) {
return steepness * value * (1 - value);
}
}

View File

@@ -14,4 +14,8 @@ public class TanH implements ActivationFunction {
return (float)(res);
}
@Override
// Derivative in terms of the tanh OUTPUT: d/dx tanh(x) = 1 - tanh(x)^2.
// Assumes `value` is the activation output — TODO confirm caller contract.
public float derivative(float value) {
return 1 - value * value;
}
}

View File

@@ -3,6 +3,8 @@ package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import javax.naming.OperationNotSupportedException;
public class Heaviside implements ActivationFunction {
public Heaviside(){
@@ -14,4 +16,9 @@ public class Heaviside implements ActivationFunction {
float weightedSum = n.calculateWeightedSum();
return weightedSum < 0 ? 0:1;
}
@Override
// Heaviside is a step function: derivative is 0 everywhere except the
// discontinuity at 0, so gradient-based training cannot use it.
public float derivative(float value) {
throw new UnsupportedOperationException("Heaviside is not differentiable");
}
}

View File

@@ -1,21 +1,21 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import java.util.concurrent.atomic.AtomicInteger;
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
public class SimpleCorrectionStep implements AlgorithmStep {
private final TrainingContext context;
public SimpleCorrectionStrategy(TrainingContext context) {
public SimpleCorrectionStep(TrainingContext context) {
this.context = context;
}
@Override
public void apply() {
public void run() {
if(context.expectations.equals(context.predictions)) return;
AtomicInteger neuronIndex = new AtomicInteger(0);
AtomicInteger synIndex = new AtomicInteger(0);

View File

@@ -1,6 +1,6 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
@@ -9,16 +9,16 @@ import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class SimpleDeltaStrategy implements AlgorithmStrategy {
public class SimpleDeltaStep implements AlgorithmStep {
private final TrainingContext context;
public SimpleDeltaStrategy(TrainingContext context) {
public SimpleDeltaStep(TrainingContext context) {
this.context = context;
}
@Override
public void apply() {
public void run() {
DataSet dataSet = context.dataset;
DataSetEntry entry = context.currentEntry;
List<Float> predicted = context.predictions;
@@ -28,7 +28,6 @@ public class SimpleDeltaStrategy implements AlgorithmStrategy {
context.deltas = IntStream.range(0, predicted.size())
.mapToObj(i -> expected.get(i) - predicted.get(i))
.collect(Collectors.toList());
System.out.printf("");
}
}

View File

@@ -1,18 +1,18 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
public class SimpleErrorRegistrationStrategy implements AlgorithmStrategy {
public class SimpleErrorRegistrationStep implements AlgorithmStep {
private final TrainingContext context;
public SimpleErrorRegistrationStrategy(TrainingContext context) {
public SimpleErrorRegistrationStep(TrainingContext context) {
this.context = context;
}
@Override
public void apply() {
public void run() {
context.globalLoss += context.localLoss;
}
}

View File

@@ -1,8 +1,8 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
public class SimpleLossStrategy implements AlgorithmStrategy {
public class SimpleLossStrategy implements AlgorithmStep {
private final SimpleTrainingContext context;
@@ -11,7 +11,7 @@ public class SimpleLossStrategy implements AlgorithmStrategy {
}
@Override
public void apply() {
public void run() {
this.context.localLoss = this.context.deltas.stream().reduce(0.0F, Float::sum);
}
}

View File

@@ -1,20 +1,18 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import java.util.List;
public class SimplePredictionStrategy implements AlgorithmStrategy {
public class SimplePredictionStep implements AlgorithmStep {
private final TrainingContext context;
public SimplePredictionStrategy(TrainingContext context) {
public SimplePredictionStep(TrainingContext context) {
this.context = context;
}
@Override
public void apply() {
public void run() {
context.predictions = context.model.predict(context.currentEntry.getData());
}
}

View File

@@ -1,17 +1,16 @@
package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
import com.naaturel.ANN.implementation.training.steps.*;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
import java.util.List;
@@ -30,12 +29,12 @@ public class AdalineTraining implements Trainer {
context.model = model;
context.learningRate = 0.003F;
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)),
new DeltaStep(new SimpleDeltaStrategy(context)),
new LossStep(new SquareLossStrategy(context)),
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context),
new SimpleDeltaStep(context),
new SquareLossStep(context),
new SimpleErrorRegistrationStep(context),
new SimpleCorrectionStep(context)
);
new TrainingPipeline(steps)

View File

@@ -1,14 +1,13 @@
package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
import com.naaturel.ANN.implementation.training.steps.DeltaStep;
import com.naaturel.ANN.implementation.training.steps.PredictionStep;
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import java.util.List;
@@ -17,14 +16,15 @@ import java.util.List;
public class GradientBackpropagationTraining implements Trainer {
@Override
public void train(Model model, DataSet dataset) {
TrainingContext context = new GradientDescentTrainingContext();
GradientBackpropagationContext context = new GradientBackpropagationContext();
context.dataset = dataset;
context.model = model;
context.learningRate = 0.0008F;
context.learningRate = 0.001F;
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)),
new DeltaStep()
List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context),
new SimpleDeltaStep(context),
new GradientBackpropagationStep(context)
);
new TrainingPipeline(steps)

View File

@@ -1,16 +1,16 @@
package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.implementation.training.steps.*;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
@@ -31,11 +31,11 @@ public class GradientDescentTraining implements Trainer {
context.learningRate = 0.0008F;
context.correctorTerms = new ArrayList<>();
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)),
new DeltaStep(new SimpleDeltaStrategy(context)),
new LossStep(new SquareLossStrategy(context)),
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context),
new SimpleDeltaStep(context),
new SquareLossStep(context),
new GradientDescentErrorStrategy(context)
);
new TrainingPipeline(steps)
@@ -48,7 +48,7 @@ public class GradientDescentTraining implements Trainer {
})
.afterEpoch(ctx -> {
context.globalLoss /= context.dataset.size();
new GradientDescentCorrectionStrategy(context).apply();
new GradientDescentCorrectionStrategy(context).run();
})
//.withVerbose(true)
.withTimeMeasurement(true)

View File

@@ -1,8 +1,8 @@
package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
@@ -23,12 +23,12 @@ public class SimpleTraining implements Trainer {
context.model = model;
context.learningRate = 0.3F;
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)),
new DeltaStep(new SimpleDeltaStrategy(context)),
new LossStep(new SimpleLossStrategy(context)),
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context),
new SimpleDeltaStep(context),
new SimpleLossStrategy(context),
new SimpleErrorRegistrationStep(context),
new SimpleCorrectionStep(context)
);
TrainingPipeline pipeline = new TrainingPipeline(steps);

View File

@@ -1,18 +1,18 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
public class DeltaStep implements TrainingStep {
private final AlgorithmStrategy strategy;
private final AlgorithmStep strategy;
public DeltaStep(AlgorithmStrategy strategy) {
public DeltaStep(AlgorithmStep strategy) {
this.strategy = strategy;
}
@Override
public void run() {
this.strategy.apply();
this.strategy.run();
}
}

View File

@@ -1,18 +1,18 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
public class ErrorRegistrationStep implements TrainingStep {
private final AlgorithmStrategy strategy;
private final AlgorithmStep strategy;
public ErrorRegistrationStep(AlgorithmStrategy strategy) {
public ErrorRegistrationStep(AlgorithmStep strategy) {
this.strategy = strategy;
}
@Override
public void run() {
this.strategy.apply();
this.strategy.run();
}
}

View File

@@ -1,20 +1,19 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
public class LossStep implements TrainingStep {
private final AlgorithmStrategy lossStrategy;
private final AlgorithmStep lossStrategy;
public LossStep(AlgorithmStrategy strategy) {
public LossStep(AlgorithmStep strategy) {
this.lossStrategy = strategy;
}
@Override
public void run() {
this.lossStrategy.apply();
this.lossStrategy.run();
}
}

View File

@@ -1,23 +1,18 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
import java.util.List;
public class PredictionStep implements TrainingStep {
private final AlgorithmStrategy strategy;
private final AlgorithmStep strategy;
public PredictionStep(AlgorithmStrategy strategy) {
public PredictionStep(AlgorithmStep strategy) {
this.strategy = strategy;
}
@Override
public void run() {
this.strategy.apply();
this.strategy.run();
}
}

View File

@@ -1,18 +1,18 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
public class WeightCorrectionStep implements TrainingStep {
private final AlgorithmStrategy correctionStrategy;
private final AlgorithmStep correctionStrategy;
public WeightCorrectionStep(AlgorithmStrategy strategy) {
public WeightCorrectionStep(AlgorithmStep strategy) {
this.correctionStrategy = strategy;
}
@Override
public void run() {
this.correctionStrategy.apply();
this.correctionStrategy.run();
}
}

View File

@@ -9,10 +9,10 @@ import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
import com.naaturel.ANN.implementation.gradientDescent.*;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.implementation.training.steps.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -29,7 +29,7 @@ public class AdalineTest {
private List<Synapse> synapses;
private Bias bias;
private Network network;
private FullyConnectedNetwork network;
private TrainingPipeline pipeline;
@@ -44,20 +44,20 @@ public class AdalineTest {
bias = new Bias(new Weight(0));
Neuron neuron = new Neuron(syns, bias, new Linear());
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
Layer layer = new Layer(List.of(neuron));
network = new Network(List.of(layer));
network = new FullyConnectedNetwork(List.of(layer));
context = new AdalineTrainingContext();
context.dataset = dataset;
context.model = network;
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)),
new DeltaStep(new SimpleDeltaStrategy(context)),
new LossStep(new SquareLossStrategy(context)),
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
new PredictionStep(new SimplePredictionStep(context)),
new DeltaStep(new SimpleDeltaStep(context)),
new LossStep(new SquareLossStep(context)),
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
new WeightCorrectionStep(new SimpleCorrectionStep(context))
);
pipeline = new TrainingPipeline(steps)

View File

@@ -25,7 +25,7 @@ public class GradientDescentTest {
private List<Synapse> synapses;
private Bias bias;
private Network network;
private FullyConnectedNetwork network;
private TrainingPipeline pipeline;
@@ -40,9 +40,9 @@ public class GradientDescentTest {
bias = new Bias(new Weight(0));
Neuron neuron = new Neuron(syns, bias, new Linear());
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
Layer layer = new Layer(List.of(neuron));
network = new Network(List.of(layer));
network = new FullyConnectedNetwork(List.of(layer));
context = new GradientDescentTrainingContext();
context.dataset = dataset;
@@ -50,9 +50,9 @@ public class GradientDescentTest {
context.correctorTerms = new ArrayList<>();
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)),
new DeltaStep(new SimpleDeltaStrategy(context)),
new LossStep(new SquareLossStrategy(context)),
new PredictionStep(new SimplePredictionStep(context)),
new DeltaStep(new SimpleDeltaStep(context)),
new LossStep(new SquareLossStep(context)),
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
);
@@ -82,7 +82,7 @@ public class GradientDescentTest {
context.learningRate = 0.2F;
pipeline.afterEpoch(ctx -> {
context.globalLoss /= context.dataset.size();
new GradientDescentCorrectionStrategy(context).apply();
new GradientDescentCorrectionStrategy(context).run();
int index = ctx.epoch-1;
if(index >= expectedGlobalLosses.size()) return;

View File

@@ -24,7 +24,7 @@ public class SimplePerceptronTest {
private List<Synapse> synapses;
private Bias bias;
private Network network;
private FullyConnectedNetwork network;
private TrainingPipeline pipeline;
@@ -41,18 +41,18 @@ public class SimplePerceptronTest {
Neuron neuron = new Neuron(syns, bias, new Heaviside());
Layer layer = new Layer(List.of(neuron));
network = new Network(List.of(layer));
network = new FullyConnectedNetwork(List.of(layer));
context = new SimpleTrainingContext();
context.dataset = dataset;
context.model = network;
List<TrainingStep> steps = List.of(
new PredictionStep(new SimplePredictionStrategy(context)),
new DeltaStep(new SimpleDeltaStrategy(context)),
new PredictionStep(new SimplePredictionStep(context)),
new DeltaStep(new SimpleDeltaStep(context)),
new LossStep(new SimpleLossStrategy(context)),
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
new WeightCorrectionStep(new SimpleCorrectionStep(context))
);
pipeline = new TrainingPipeline(steps);