Rename some stuff
This commit is contained in:
@@ -3,13 +3,10 @@ package com.naaturel.ANN;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
|
||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
@@ -20,7 +17,7 @@ public class Main {
|
||||
int nbrInput = 2;
|
||||
int nbrClass = 3;
|
||||
|
||||
int nbrLayers = 1;
|
||||
int nbrLayers = 2;
|
||||
|
||||
DataSet dataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv", nbrClass);
|
||||
@@ -44,7 +41,7 @@ public class Main {
|
||||
Layer layer = new Layer(neurons);
|
||||
layers.add(layer);
|
||||
}
|
||||
Network network = new Network(layers);
|
||||
FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
|
||||
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(network, dataset);
|
||||
|
||||
@@ -5,5 +5,6 @@ import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
public interface ActivationFunction {
|
||||
|
||||
float accept(Neuron n);
|
||||
float derivative(float value);
|
||||
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
@FunctionalInterface
|
||||
public interface AlgorithmStrategy {
|
||||
public interface AlgorithmStep {
|
||||
|
||||
void apply();
|
||||
void run();
|
||||
|
||||
}
|
||||
@@ -5,13 +5,14 @@ import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.BiConsumer;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public interface Model {
|
||||
int synCount();
|
||||
int neuronCount();
|
||||
void forEachNeuron(Consumer<Neuron> consumer);
|
||||
void forEachSynapse(Consumer<Synapse> consumer);
|
||||
void forEachOutputNeurons(Consumer<Neuron> consumer);
|
||||
void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer);
|
||||
List<Float> predict(List<Input> inputs);
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public interface Network {
|
||||
|
||||
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
public interface TrainingStep {
|
||||
/*public interface TrainingStep {
|
||||
|
||||
void run();
|
||||
|
||||
}
|
||||
}*/
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package com.naaturel.ANN.domain.model.neuron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Network;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* Represents a fully connected neural network
|
||||
*/
|
||||
public class FullyConnectedNetwork implements Model {
|
||||
|
||||
private final List<Layer> layers;;
|
||||
private final Map<Neuron, List<Neuron>> connectionMap;
|
||||
|
||||
public FullyConnectedNetwork(List<Layer> layers) {
|
||||
this.layers = layers;
|
||||
this.connectionMap = this.createConnectionMap();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Float> predict(List<Input> inputs) {
|
||||
List<Input> previousLayerOutputs = new ArrayList<>(inputs);
|
||||
for(Layer layer : this.layers){
|
||||
List<Float> currentLayerOutputs = layer.predict(previousLayerOutputs);
|
||||
previousLayerOutputs = currentLayerOutputs.stream().map(Input::new).toList();
|
||||
}
|
||||
return previousLayerOutputs.stream().map(Input::getValue).toList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int synCount() {
|
||||
int res = 0;
|
||||
for(Layer layer : this.layers){
|
||||
res += layer.synCount();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int neuronCount() {
|
||||
int res = 0;
|
||||
for(Layer layer : this.layers){
|
||||
res += layer.neuronCount();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||
this.layers.forEach(layer -> layer.forEachSynapse(consumer));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachNeuron(Consumer<Neuron> consumer) {
|
||||
this.layers.forEach(layer -> layer.forEachNeuron(consumer));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
||||
this.layers.getLast().forEachNeuron(consumer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
|
||||
this.connectionMap.get(n).forEach(consumer);
|
||||
}
|
||||
|
||||
private Map<Neuron, List<Neuron>> createConnectionMap() {
|
||||
Map<Neuron, List<Neuron>> res = new HashMap<>();
|
||||
|
||||
for (int i = 0; i < this.layers.size() - 1; i++) {
|
||||
List<Neuron> nextLayerNeurons = new ArrayList<>();
|
||||
this.layers.get(i + 1).forEachNeuron(nextLayerNeurons::add);
|
||||
this.layers.get(i).forEachNeuron(n -> res.put(n, nextLayerNeurons));
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
}
|
||||
@@ -33,6 +33,11 @@ public class Layer implements Model {
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int neuronCount() {
|
||||
return this.neurons.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachNeuron(Consumer<Neuron> consumer) {
|
||||
this.neurons.forEach(consumer);
|
||||
@@ -42,4 +47,14 @@ public class Layer implements Model {
|
||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||
this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
||||
this.neurons.forEach(consumer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
|
||||
throw new UnsupportedOperationException("Neurons have no connection within the same layer");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
package com.naaturel.ANN.domain.model.neuron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* Represents a fully connected neural network
|
||||
*/
|
||||
public class Network implements Model {
|
||||
|
||||
private final List<Layer> layers;
|
||||
|
||||
public Network(List<Layer> layers) {
|
||||
this.layers = layers;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Float> predict(List<Input> inputs) {
|
||||
List<Input> previousLayerOutput = new ArrayList<>(inputs);
|
||||
for(Layer layer : this.layers){
|
||||
List<Float> currentLayerOutput = layer.predict(previousLayerOutput);
|
||||
previousLayerOutput = currentLayerOutput.stream().map(Input::new).toList();
|
||||
}
|
||||
return previousLayerOutput.stream().map(Input::getValue).toList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int synCount() {
|
||||
int res = 0;
|
||||
for(Layer layer : this.layers){
|
||||
res += layer.synCount();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachNeuron(Consumer<Neuron> consumer) {
|
||||
this.layers.forEach(layer -> layer.forEachNeuron(consumer));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||
this.layers.forEach(layer -> layer.forEachSynapse(consumer));
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.BiConsumer;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public class Neuron implements Model {
|
||||
@@ -11,11 +10,13 @@ public class Neuron implements Model {
|
||||
protected List<Synapse> synapses;
|
||||
protected Bias bias;
|
||||
protected ActivationFunction activationFunction;
|
||||
protected Float output;
|
||||
|
||||
public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){
|
||||
this.synapses = synapses;
|
||||
this.bias = bias;
|
||||
this.activationFunction = func;
|
||||
this.output = 0F;
|
||||
}
|
||||
|
||||
public void updateBias(Weight weight) {
|
||||
@@ -33,15 +34,38 @@ public class Neuron implements Model {
|
||||
}
|
||||
}
|
||||
|
||||
public ActivationFunction getActivationFunction(){
|
||||
return this.activationFunction;
|
||||
}
|
||||
|
||||
public float getOutput(){
|
||||
return this.output;
|
||||
}
|
||||
|
||||
public float calculateWeightedSum() {
|
||||
float res = 0;
|
||||
res += this.bias.getWeight() * this.bias.getInput();
|
||||
for(Synapse syn : this.synapses){
|
||||
res += syn.getWeight() * syn.getInput();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int synCount() {
|
||||
return this.synapses.size()+1; //take the bias in account
|
||||
return this.synapses.size()+1; //take the bias into account
|
||||
}
|
||||
|
||||
@Override
|
||||
public int neuronCount() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Float> predict(List<Input> inputs) {
|
||||
this.setInputs(inputs);
|
||||
return List.of(activationFunction.accept(this));
|
||||
this.output = activationFunction.accept(this);
|
||||
return List.of(output);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -55,13 +79,13 @@ public class Neuron implements Model {
|
||||
this.synapses.forEach(consumer);
|
||||
}
|
||||
|
||||
public float calculateWeightedSum() {
|
||||
float res = 0;
|
||||
res += this.bias.getWeight() * this.bias.getInput();
|
||||
for(Synapse syn : this.synapses){
|
||||
res += syn.getWeight() * syn.getInput();
|
||||
}
|
||||
return res;
|
||||
@Override
|
||||
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
||||
consumer.accept(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
|
||||
throw new UnsupportedOperationException("Neurons have no connection with themselves");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.naaturel.ANN.domain.model.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||
@@ -15,7 +16,7 @@ import java.util.function.Predicate;
|
||||
|
||||
public class TrainingPipeline {
|
||||
|
||||
private final List<TrainingStep> steps;
|
||||
private final List<AlgorithmStep> steps;
|
||||
private Consumer<TrainingContext> beforeEpoch;
|
||||
private Consumer<TrainingContext> afterEpoch;
|
||||
private Predicate<TrainingContext> stopCondition;
|
||||
@@ -25,7 +26,7 @@ public class TrainingPipeline {
|
||||
private boolean visualization;
|
||||
private boolean timeMeasurement;
|
||||
|
||||
public TrainingPipeline(List<TrainingStep> steps) {
|
||||
public TrainingPipeline(List<AlgorithmStep> steps) {
|
||||
this.steps = new ArrayList<>(steps);
|
||||
this.stopCondition = (ctx) -> false;
|
||||
this.beforeEpoch = (context -> {});
|
||||
@@ -90,7 +91,7 @@ public class TrainingPipeline {
|
||||
ctx.currentEntry = entry;
|
||||
ctx.expectations = ctx.dataset.getLabelsAsFloat(entry);
|
||||
|
||||
for (TrainingStep step : steps) {
|
||||
for (AlgorithmStep step : steps) {
|
||||
step.run();
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
|
||||
public class GradientDescentCorrectionStrategy implements AlgorithmStep {
|
||||
|
||||
private final GradientDescentTrainingContext context;
|
||||
|
||||
@@ -13,7 +13,7 @@ public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
public void run() {
|
||||
AtomicInteger i = new AtomicInteger(0);
|
||||
context.model.forEachSynapse(syn -> {
|
||||
float corrector = context.correctorTerms.get(i.get());
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public class GradientDescentErrorStrategy implements AlgorithmStrategy {
|
||||
public class GradientDescentErrorStrategy implements AlgorithmStep {
|
||||
|
||||
private final GradientDescentTrainingContext context;
|
||||
|
||||
@@ -14,7 +14,7 @@ public class GradientDescentErrorStrategy implements AlgorithmStrategy {
|
||||
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
public void run() {
|
||||
|
||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
||||
AtomicInteger synIndex = new AtomicInteger(0);
|
||||
|
||||
@@ -5,9 +5,22 @@ import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
|
||||
public class Linear implements ActivationFunction {
|
||||
|
||||
private final float slope;
|
||||
private final float intercept;
|
||||
|
||||
public Linear(float slope, float intercept) {
|
||||
this.slope = slope;
|
||||
this.intercept = intercept;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float accept(Neuron n) {
|
||||
return n.calculateWeightedSum();
|
||||
return slope * n.calculateWeightedSum() + intercept;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float derivative(float value) {
|
||||
return this.slope;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
|
||||
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class SquareLossStrategy implements AlgorithmStrategy {
|
||||
public class SquareLossStep implements AlgorithmStep {
|
||||
|
||||
private final TrainingContext context;
|
||||
|
||||
public SquareLossStrategy(TrainingContext context) {
|
||||
public SquareLossStep(TrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
public void run() {
|
||||
Stream<Float> deltaStream = this.context.deltas.stream();
|
||||
this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2)));
|
||||
this.context.localLoss /= 2;
|
||||
@@ -2,5 +2,14 @@ package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class GradientBackpropagationContext extends TrainingContext {
|
||||
|
||||
public List<Float> hiddenDeltas;
|
||||
|
||||
public GradientBackpropagationContext(){
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
|
||||
public class GradientBackpropagationStep implements AlgorithmStep {
|
||||
|
||||
private GradientBackpropagationContext context;
|
||||
public GradientBackpropagationStep(GradientBackpropagationContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
private float calculateDeltaRecursive(Neuron n){
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
|
||||
public class GradientBackpropagationStrategy implements AlgorithmStrategy {
|
||||
|
||||
private GradientBackpropagationContext context;
|
||||
|
||||
public GradientBackpropagationStrategy(GradientBackpropagationContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
|
||||
}
|
||||
}
|
||||
@@ -15,4 +15,9 @@ public class Sigmoid implements ActivationFunction {
|
||||
public float accept(Neuron n) {
|
||||
return (float) (1.0/(1.0 + Math.exp(-steepness * n.calculateWeightedSum())));
|
||||
}
|
||||
|
||||
@Override
|
||||
public float derivative(float value) {
|
||||
return steepness * value * (1 - value);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,4 +14,8 @@ public class TanH implements ActivationFunction {
|
||||
return (float)(res);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float derivative(float value) {
|
||||
return 1 - value * value;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
|
||||
import javax.naming.OperationNotSupportedException;
|
||||
|
||||
public class Heaviside implements ActivationFunction {
|
||||
|
||||
public Heaviside(){
|
||||
@@ -14,4 +16,9 @@ public class Heaviside implements ActivationFunction {
|
||||
float weightedSum = n.calculateWeightedSum();
|
||||
return weightedSum < 0 ? 0:1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float derivative(float value) {
|
||||
throw new UnsupportedOperationException("Heaviside is not differentiable");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
|
||||
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
|
||||
public class SimpleCorrectionStep implements AlgorithmStep {
|
||||
|
||||
private final TrainingContext context;
|
||||
|
||||
public SimpleCorrectionStrategy(TrainingContext context) {
|
||||
public SimpleCorrectionStep(TrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
public void run() {
|
||||
if(context.expectations.equals(context.predictions)) return;
|
||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
||||
AtomicInteger synIndex = new AtomicInteger(0);
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||
@@ -9,16 +9,16 @@ import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
public class SimpleDeltaStrategy implements AlgorithmStrategy {
|
||||
public class SimpleDeltaStep implements AlgorithmStep {
|
||||
|
||||
private final TrainingContext context;
|
||||
|
||||
public SimpleDeltaStrategy(TrainingContext context) {
|
||||
public SimpleDeltaStep(TrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
public void run() {
|
||||
DataSet dataSet = context.dataset;
|
||||
DataSetEntry entry = context.currentEntry;
|
||||
List<Float> predicted = context.predictions;
|
||||
@@ -28,7 +28,6 @@ public class SimpleDeltaStrategy implements AlgorithmStrategy {
|
||||
context.deltas = IntStream.range(0, predicted.size())
|
||||
.mapToObj(i -> expected.get(i) - predicted.get(i))
|
||||
.collect(Collectors.toList());
|
||||
System.out.printf("");
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,18 +1,18 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
|
||||
public class SimpleErrorRegistrationStrategy implements AlgorithmStrategy {
|
||||
public class SimpleErrorRegistrationStep implements AlgorithmStep {
|
||||
|
||||
private final TrainingContext context;
|
||||
|
||||
public SimpleErrorRegistrationStrategy(TrainingContext context) {
|
||||
public SimpleErrorRegistrationStep(TrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
public void run() {
|
||||
context.globalLoss += context.localLoss;
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
|
||||
public class SimpleLossStrategy implements AlgorithmStrategy {
|
||||
public class SimpleLossStrategy implements AlgorithmStep {
|
||||
|
||||
private final SimpleTrainingContext context;
|
||||
|
||||
@@ -11,7 +11,7 @@ public class SimpleLossStrategy implements AlgorithmStrategy {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
public void run() {
|
||||
this.context.localLoss = this.context.deltas.stream().reduce(0.0F, Float::sum);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +1,18 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class SimplePredictionStrategy implements AlgorithmStrategy {
|
||||
public class SimplePredictionStep implements AlgorithmStep {
|
||||
|
||||
private final TrainingContext context;
|
||||
|
||||
public SimplePredictionStrategy(TrainingContext context) {
|
||||
public SimplePredictionStep(TrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
public void run() {
|
||||
context.predictions = context.model.predict(context.currentEntry.getData());
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,16 @@
|
||||
package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
||||
|
||||
import java.util.List;
|
||||
@@ -30,12 +29,12 @@ public class AdalineTraining implements Trainer {
|
||||
context.model = model;
|
||||
context.learningRate = 0.003F;
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||
new LossStep(new SquareLossStrategy(context)),
|
||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
new SimpleDeltaStep(context),
|
||||
new SquareLossStep(context),
|
||||
new SimpleErrorRegistrationStep(context),
|
||||
new SimpleCorrectionStep(context)
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||
import com.naaturel.ANN.implementation.training.steps.DeltaStep;
|
||||
import com.naaturel.ANN.implementation.training.steps.PredictionStep;
|
||||
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
|
||||
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
|
||||
import java.util.List;
|
||||
@@ -17,14 +16,15 @@ import java.util.List;
|
||||
public class GradientBackpropagationTraining implements Trainer {
|
||||
@Override
|
||||
public void train(Model model, DataSet dataset) {
|
||||
TrainingContext context = new GradientDescentTrainingContext();
|
||||
GradientBackpropagationContext context = new GradientBackpropagationContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
context.learningRate = 0.0008F;
|
||||
context.learningRate = 0.001F;
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||
new DeltaStep()
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
new SimpleDeltaStep(context),
|
||||
new GradientBackpropagationStep(context)
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
||||
|
||||
@@ -31,11 +31,11 @@ public class GradientDescentTraining implements Trainer {
|
||||
context.learningRate = 0.0008F;
|
||||
context.correctorTerms = new ArrayList<>();
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||
new LossStep(new SquareLossStrategy(context)),
|
||||
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
new SimpleDeltaStep(context),
|
||||
new SquareLossStep(context),
|
||||
new GradientDescentErrorStrategy(context)
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
@@ -48,7 +48,7 @@ public class GradientDescentTraining implements Trainer {
|
||||
})
|
||||
.afterEpoch(ctx -> {
|
||||
context.globalLoss /= context.dataset.size();
|
||||
new GradientDescentCorrectionStrategy(context).apply();
|
||||
new GradientDescentCorrectionStrategy(context).run();
|
||||
})
|
||||
//.withVerbose(true)
|
||||
.withTimeMeasurement(true)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
@@ -23,12 +23,12 @@ public class SimpleTraining implements Trainer {
|
||||
context.model = model;
|
||||
context.learningRate = 0.3F;
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||
new LossStep(new SimpleLossStrategy(context)),
|
||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
new SimpleDeltaStep(context),
|
||||
new SimpleLossStrategy(context),
|
||||
new SimpleErrorRegistrationStep(context),
|
||||
new SimpleCorrectionStep(context)
|
||||
);
|
||||
|
||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
|
||||
public class DeltaStep implements TrainingStep {
|
||||
|
||||
private final AlgorithmStrategy strategy;
|
||||
private final AlgorithmStep strategy;
|
||||
|
||||
public DeltaStep(AlgorithmStrategy strategy) {
|
||||
public DeltaStep(AlgorithmStep strategy) {
|
||||
this.strategy = strategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
this.strategy.apply();
|
||||
this.strategy.run();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
|
||||
public class ErrorRegistrationStep implements TrainingStep {
|
||||
|
||||
private final AlgorithmStrategy strategy;
|
||||
private final AlgorithmStep strategy;
|
||||
|
||||
public ErrorRegistrationStep(AlgorithmStrategy strategy) {
|
||||
public ErrorRegistrationStep(AlgorithmStep strategy) {
|
||||
this.strategy = strategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
this.strategy.apply();
|
||||
this.strategy.run();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
|
||||
public class LossStep implements TrainingStep {
|
||||
|
||||
|
||||
private final AlgorithmStrategy lossStrategy;
|
||||
private final AlgorithmStep lossStrategy;
|
||||
|
||||
public LossStep(AlgorithmStrategy strategy) {
|
||||
public LossStep(AlgorithmStep strategy) {
|
||||
this.lossStrategy = strategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
this.lossStrategy.apply();
|
||||
this.lossStrategy.run();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,23 +1,18 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class PredictionStep implements TrainingStep {
|
||||
|
||||
private final AlgorithmStrategy strategy;
|
||||
private final AlgorithmStep strategy;
|
||||
|
||||
public PredictionStep(AlgorithmStrategy strategy) {
|
||||
public PredictionStep(AlgorithmStep strategy) {
|
||||
this.strategy = strategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
this.strategy.apply();
|
||||
this.strategy.run();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
|
||||
public class WeightCorrectionStep implements TrainingStep {
|
||||
|
||||
private final AlgorithmStrategy correctionStrategy;
|
||||
private final AlgorithmStep correctionStrategy;
|
||||
|
||||
public WeightCorrectionStep(AlgorithmStrategy strategy) {
|
||||
public WeightCorrectionStep(AlgorithmStep strategy) {
|
||||
this.correctionStrategy = strategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
this.correctionStrategy.apply();
|
||||
this.correctionStrategy.run();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,10 +9,10 @@ import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.*;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -29,7 +29,7 @@ public class AdalineTest {
|
||||
|
||||
private List<Synapse> synapses;
|
||||
private Bias bias;
|
||||
private Network network;
|
||||
private FullyConnectedNetwork network;
|
||||
|
||||
private TrainingPipeline pipeline;
|
||||
|
||||
@@ -44,20 +44,20 @@ public class AdalineTest {
|
||||
|
||||
bias = new Bias(new Weight(0));
|
||||
|
||||
Neuron neuron = new Neuron(syns, bias, new Linear());
|
||||
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
|
||||
Layer layer = new Layer(List.of(neuron));
|
||||
network = new Network(List.of(layer));
|
||||
network = new FullyConnectedNetwork(List.of(layer));
|
||||
|
||||
context = new AdalineTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = network;
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||
new LossStep(new SquareLossStrategy(context)),
|
||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
|
||||
new PredictionStep(new SimplePredictionStep(context)),
|
||||
new DeltaStep(new SimpleDeltaStep(context)),
|
||||
new LossStep(new SquareLossStep(context)),
|
||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStep(context))
|
||||
);
|
||||
|
||||
pipeline = new TrainingPipeline(steps)
|
||||
|
||||
@@ -25,7 +25,7 @@ public class GradientDescentTest {
|
||||
|
||||
private List<Synapse> synapses;
|
||||
private Bias bias;
|
||||
private Network network;
|
||||
private FullyConnectedNetwork network;
|
||||
|
||||
private TrainingPipeline pipeline;
|
||||
|
||||
@@ -40,9 +40,9 @@ public class GradientDescentTest {
|
||||
|
||||
bias = new Bias(new Weight(0));
|
||||
|
||||
Neuron neuron = new Neuron(syns, bias, new Linear());
|
||||
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
|
||||
Layer layer = new Layer(List.of(neuron));
|
||||
network = new Network(List.of(layer));
|
||||
network = new FullyConnectedNetwork(List.of(layer));
|
||||
|
||||
context = new GradientDescentTrainingContext();
|
||||
context.dataset = dataset;
|
||||
@@ -50,9 +50,9 @@ public class GradientDescentTest {
|
||||
context.correctorTerms = new ArrayList<>();
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||
new LossStep(new SquareLossStrategy(context)),
|
||||
new PredictionStep(new SimplePredictionStep(context)),
|
||||
new DeltaStep(new SimpleDeltaStep(context)),
|
||||
new LossStep(new SquareLossStep(context)),
|
||||
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
|
||||
);
|
||||
|
||||
@@ -82,7 +82,7 @@ public class GradientDescentTest {
|
||||
context.learningRate = 0.2F;
|
||||
pipeline.afterEpoch(ctx -> {
|
||||
context.globalLoss /= context.dataset.size();
|
||||
new GradientDescentCorrectionStrategy(context).apply();
|
||||
new GradientDescentCorrectionStrategy(context).run();
|
||||
|
||||
int index = ctx.epoch-1;
|
||||
if(index >= expectedGlobalLosses.size()) return;
|
||||
|
||||
@@ -24,7 +24,7 @@ public class SimplePerceptronTest {
|
||||
|
||||
private List<Synapse> synapses;
|
||||
private Bias bias;
|
||||
private Network network;
|
||||
private FullyConnectedNetwork network;
|
||||
|
||||
private TrainingPipeline pipeline;
|
||||
|
||||
@@ -41,18 +41,18 @@ public class SimplePerceptronTest {
|
||||
|
||||
Neuron neuron = new Neuron(syns, bias, new Heaviside());
|
||||
Layer layer = new Layer(List.of(neuron));
|
||||
network = new Network(List.of(layer));
|
||||
network = new FullyConnectedNetwork(List.of(layer));
|
||||
|
||||
context = new SimpleTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = network;
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||
new PredictionStep(new SimplePredictionStep(context)),
|
||||
new DeltaStep(new SimpleDeltaStep(context)),
|
||||
new LossStep(new SimpleLossStrategy(context)),
|
||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
|
||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStep(context))
|
||||
);
|
||||
|
||||
pipeline = new TrainingPipeline(steps);
|
||||
|
||||
Reference in New Issue
Block a user