Implement main structure of framework
This commit is contained in:
@@ -1,18 +1,17 @@
|
|||||||
package com.naaturel.ANN;
|
package com.naaturel.ANN;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
import com.naaturel.ANN.implementation.activation.Heaviside;
|
||||||
import com.naaturel.ANN.implementation.activationFunction.Linear;
|
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
|
||||||
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||||
import com.naaturel.ANN.implementation.training.AdalineTraining;
|
import com.naaturel.ANN.implementation.training.steps.*;
|
||||||
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
@@ -64,14 +63,28 @@ public class Main {
|
|||||||
|
|
||||||
Bias bias = new Bias(new Weight(0));
|
Bias bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
Neuron n = new SimplePerceptron(syns, bias, new Linear());
|
Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
|
||||||
Trainer trainer = new AdalineTraining();
|
Layer layer = new Layer(List.of(neuron));
|
||||||
|
Network network = new Network(List.of(layer));
|
||||||
|
|
||||||
long start = System.currentTimeMillis();
|
TrainingContext context = new TrainingContext();
|
||||||
|
context.dataset = dataSet;
|
||||||
|
context.model = network;
|
||||||
|
|
||||||
trainer.train(n, 0.03F, andDataSet);
|
List<TrainingStep> steps = List.of(
|
||||||
|
new PredictionStep(),
|
||||||
|
new DeltaStep(),
|
||||||
|
new SimpleLossStep(),
|
||||||
|
new SimpleErrorDetectionStep(),
|
||||||
|
new WeightCorrectionStep(new SimpleCorrectionStrategy())
|
||||||
|
);
|
||||||
|
|
||||||
|
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||||
|
pipeline
|
||||||
|
.stopCondition(ctx -> ctx.globalLoss == 0 && ctx.epoch >= 1000)
|
||||||
|
.afterEpoch(ctx -> ctx.globalLoss = 0)
|
||||||
|
.withVerbose(true)
|
||||||
|
.run(context);
|
||||||
|
|
||||||
long end = System.currentTimeMillis();
|
|
||||||
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||||
|
|
||||||
|
public interface CorrectionStrategy {
|
||||||
|
|
||||||
|
void apply(TrainingContext context);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -7,7 +7,7 @@ import com.naaturel.ANN.domain.model.neuron.Weight;
|
|||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public abstract class Neuron {
|
public abstract class Neuron implements Trainable {
|
||||||
|
|
||||||
protected List<Synapse> synapses;
|
protected List<Synapse> synapses;
|
||||||
protected Bias bias;
|
protected Bias bias;
|
||||||
@@ -19,37 +19,20 @@ public abstract class Neuron {
|
|||||||
this.activationFunction = func;
|
this.activationFunction = func;
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract float predict();
|
|
||||||
public abstract float calculateWeightedSum();
|
public abstract float calculateWeightedSum();
|
||||||
|
|
||||||
public int getSynCount(){
|
|
||||||
return this.synapses.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setInput(int index, Input input){
|
|
||||||
Synapse syn = this.synapses.get(index);
|
|
||||||
syn.setInput(input.getValue());
|
|
||||||
}
|
|
||||||
|
|
||||||
public Bias getBias(){
|
|
||||||
return this.bias;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void updateBias(Weight weight) {
|
public void updateBias(Weight weight) {
|
||||||
this.bias.setWeight(weight.getValue());
|
this.bias.setWeight(weight.getValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
public Synapse getSynapse(int index){
|
public void updateWeight(int index, Weight weight) {
|
||||||
return this.synapses.get(index);
|
this.synapses.get(index).setWeight(weight.getValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<Synapse> getSynapses() {
|
protected void setInputs(List<Input> inputs){
|
||||||
return new ArrayList<>(this.synapses);
|
for(int i = 0; i < inputs.size() && i < synapses.size(); i++){
|
||||||
|
Synapse syn = this.synapses.get(i);
|
||||||
|
syn.setInput(inputs.get(i));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setWeight(int index, Weight weight){
|
|
||||||
Synapse syn = this.synapses.get(index);
|
|
||||||
syn.setWeight(weight.getValue());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
|
||||||
|
|
||||||
public abstract class NeuronTrainer {
|
|
||||||
|
|
||||||
private Trainable trainable;
|
|
||||||
|
|
||||||
public NeuronTrainer(Trainable trainable){
|
|
||||||
this.trainable = trainable;
|
|
||||||
}
|
|
||||||
|
|
||||||
public abstract void train();
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,14 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
public interface Trainable {
|
public interface Trainable {
|
||||||
|
List<Float> predict(List<Input> inputs);
|
||||||
|
|
||||||
|
void forEachSynapse(Consumer<Synapse> consumer);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public interface Trainer {
|
public interface Trainer {
|
||||||
|
|
||||||
void train(Neuron n, float learningRate, DataSet dataSet);
|
void train(TrainingContext context, List<TrainingStep> steps);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||||
|
|
||||||
|
public interface TrainingStep {
|
||||||
|
|
||||||
|
void run(TrainingContext ctx);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
package com.naaturel.ANN.domain.model.neuron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
|
public class Layer implements Trainable {
|
||||||
|
|
||||||
|
private final List<Neuron> neurons;
|
||||||
|
|
||||||
|
public Layer(List<Neuron> neurons) {
|
||||||
|
this.neurons = neurons;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Float> predict(List<Input> inputs) {
|
||||||
|
List<Float> result = new ArrayList<>();
|
||||||
|
for(Neuron neuron : this.neurons){
|
||||||
|
List<Float> res = neuron.predict(inputs);
|
||||||
|
result.addAll(res);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||||
|
this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
package com.naaturel.ANN.domain.model.neuron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
|
public class Network implements Trainable {
|
||||||
|
|
||||||
|
private final List<Layer> layers;
|
||||||
|
|
||||||
|
public Network(List<Layer> layers) {
|
||||||
|
this.layers = layers;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Float> predict(List<Input> inputs) {
|
||||||
|
List<Float> result = new ArrayList<>();
|
||||||
|
for(Layer layer : this.layers){
|
||||||
|
List<Float> res = layer.predict(inputs);
|
||||||
|
result.addAll(res);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||||
|
this.layers.forEach(layer -> layer.forEachSynapse(consumer));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,8 +14,8 @@ public class Synapse {
|
|||||||
return this.input.getValue();
|
return this.input.getValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setInput(float value){
|
public void setInput(Input input){
|
||||||
this.input.setValue(value);
|
this.input.setValue(input.getValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
public float getWeight() {
|
public float getWeight() {
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package com.naaturel.ANN.domain.model.training;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
|
|
||||||
|
public class TrainingContext {
|
||||||
|
public Trainable model;
|
||||||
|
public DataSet dataset;
|
||||||
|
public DataSetEntry currentEntry;
|
||||||
|
|
||||||
|
public float prediction;
|
||||||
|
public float delta;
|
||||||
|
public float localLoss;
|
||||||
|
public float globalLoss;
|
||||||
|
public float learningRate;
|
||||||
|
|
||||||
|
public int epoch;
|
||||||
|
}
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
package com.naaturel.ANN.domain.model.training;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
|
||||||
|
public class TrainingPipeline {
|
||||||
|
|
||||||
|
private final List<TrainingStep> steps;
|
||||||
|
private Consumer<TrainingContext> afterAll;
|
||||||
|
private Predicate<TrainingContext> stopCondition;
|
||||||
|
|
||||||
|
private boolean verbose;
|
||||||
|
private boolean timeMeasurement;
|
||||||
|
|
||||||
|
public TrainingPipeline(List<TrainingStep> steps) {
|
||||||
|
this.steps = new ArrayList<>(steps);
|
||||||
|
}
|
||||||
|
|
||||||
|
public TrainingPipeline stopCondition(Predicate<TrainingContext> predicate) {
|
||||||
|
this.stopCondition = predicate;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public TrainingPipeline afterEpoch(Consumer<TrainingContext> consumer) {
|
||||||
|
this.afterAll = consumer;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public TrainingPipeline withVerbose(boolean enabled) {
|
||||||
|
this.verbose = enabled;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public TrainingPipeline withTimeMeasurement(boolean enabled) {
|
||||||
|
this.timeMeasurement = enabled;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void run(TrainingContext ctx) {
|
||||||
|
do {
|
||||||
|
this.executeSteps(ctx);
|
||||||
|
if(this.afterAll != null) {
|
||||||
|
this.afterAll.accept(ctx);
|
||||||
|
}
|
||||||
|
} while (!this.stopCondition.test(ctx));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void executeSteps(TrainingContext ctx){
|
||||||
|
for (DataSetEntry sample : ctx.dataset) {
|
||||||
|
ctx.currentEntry = sample;
|
||||||
|
for (TrainingStep step : steps) {
|
||||||
|
step.run(ctx);
|
||||||
|
if(this.verbose) {
|
||||||
|
System.out.printf("Epoch : %d, ", ctx.epoch);
|
||||||
|
System.out.printf("predicted : %.2f, ", ctx.prediction);
|
||||||
|
System.out.printf("expected : %.2f, ", ctx.dataset.getLabel(ctx.currentEntry).getValue());
|
||||||
|
System.out.printf("delta : %.2f\n", ctx.delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx.epoch += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.naaturel.ANN.implementation.activationFunction;
|
package com.naaturel.ANN.implementation.activation;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.naaturel.ANN.implementation.activationFunction;
|
package com.naaturel.ANN.implementation.activation;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package com.naaturel.ANN.implementation.correction;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
|
||||||
|
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||||
|
|
||||||
|
public class SimpleCorrectionStrategy implements CorrectionStrategy {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply(TrainingContext context) {
|
||||||
|
context.model.forEachSynapse(syn -> {
|
||||||
|
float currentW = syn.getWeight();
|
||||||
|
float currentInput = syn.getInput();
|
||||||
|
float newValue = currentW + (context.learningRate * context.delta * currentInput);
|
||||||
|
syn.setWeight(newValue);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,22 +1,32 @@
|
|||||||
package com.naaturel.ANN.implementation.neuron;
|
package com.naaturel.ANN.implementation.neuron;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
import com.naaturel.ANN.domain.model.neuron.Bias;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
public class SimplePerceptron extends Neuron implements Trainable {
|
public class SimplePerceptron extends Neuron {
|
||||||
|
|
||||||
public SimplePerceptron(List<Synapse> synapses, Bias b, ActivationFunction func) {
|
public SimplePerceptron(List<Synapse> synapses, Bias b, ActivationFunction func) {
|
||||||
super(synapses, b, func);
|
super(synapses, b, func);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float predict() {
|
public List<Float> predict(List<Input> inputs) {
|
||||||
return activationFunction.accept(this);
|
super.setInputs(inputs);
|
||||||
|
return List.of(activationFunction.accept(this));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||||
|
this.synapses.forEach(consumer);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import com.naaturel.ANN.domain.model.neuron.Synapse;
|
|||||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||||
|
|
||||||
|
|
||||||
public class AdalineTraining implements Trainer {
|
/*public class AdalineTraining implements Trainer {
|
||||||
|
|
||||||
public AdalineTraining(){
|
public AdalineTraining(){
|
||||||
|
|
||||||
@@ -78,4 +78,4 @@ public class AdalineTraining implements Trainer {
|
|||||||
return (float) Math.pow(delta, 2)/2;
|
return (float) Math.pow(delta, 2)/2;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}*/
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import java.util.ArrayList;
|
|||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class GradientDescentTraining implements Trainer {
|
/*public class GradientDescentTraining implements Trainer {
|
||||||
|
|
||||||
public GradientDescentTraining(){
|
public GradientDescentTraining(){
|
||||||
|
|
||||||
@@ -122,4 +122,4 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
return variance;
|
return variance;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}*/
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
package com.naaturel.ANN.implementation.training;
|
package com.naaturel.ANN.implementation.training;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.neuron.Network;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.training.steps.*;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public class SimpleTraining implements Trainer {
|
public class SimpleTraining implements Trainer {
|
||||||
|
|
||||||
@@ -14,7 +17,12 @@ public class SimpleTraining implements Trainer {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
@Override
|
||||||
|
public void train(TrainingContext context, List<TrainingStep> steps) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||||
int epoch = 1;
|
int epoch = 1;
|
||||||
int errorCount;
|
int errorCount;
|
||||||
|
|
||||||
@@ -65,5 +73,5 @@ public class SimpleTraining implements Trainer {
|
|||||||
private float calculateLoss(float delta){
|
private float calculateLoss(float delta){
|
||||||
return Math.abs(delta);
|
return Math.abs(delta);
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package com.naaturel.ANN.implementation.training.steps;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||||
|
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||||
|
|
||||||
|
public class DeltaStep implements TrainingStep {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run(TrainingContext ctx) {
|
||||||
|
DataSet dataSet = ctx.dataset;
|
||||||
|
DataSetEntry entry = ctx.currentEntry;
|
||||||
|
Label label = dataSet.getLabel(entry);
|
||||||
|
|
||||||
|
ctx.delta = label.getValue() - ctx.prediction;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
package com.naaturel.ANN.implementation.training.steps;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
|
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class PredictionStep implements TrainingStep {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run(TrainingContext ctx) {
|
||||||
|
List<Input> inputs = new ArrayList<>();
|
||||||
|
for(Float f : ctx.currentEntry.getData()){
|
||||||
|
inputs.add(new Input(f));
|
||||||
|
}
|
||||||
|
List<Float> predictions = ctx.model.predict(inputs);
|
||||||
|
ctx.prediction = predictions.getFirst();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
package com.naaturel.ANN.implementation.training.steps;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||||
|
|
||||||
|
public class SimpleErrorDetectionStep implements TrainingStep {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run(TrainingContext ctx) {
|
||||||
|
ctx.globalLoss += ctx.localLoss;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package com.naaturel.ANN.implementation.training.steps;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||||
|
|
||||||
|
public class SimpleLossStep implements TrainingStep {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run(TrainingContext ctx) {
|
||||||
|
ctx.localLoss = Math.abs(ctx.delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package com.naaturel.ANN.implementation.training.steps;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||||
|
|
||||||
|
public class WeightCorrectionStep implements TrainingStep {
|
||||||
|
|
||||||
|
private final CorrectionStrategy correctionStrategy;
|
||||||
|
|
||||||
|
public WeightCorrectionStep(CorrectionStrategy strategy) {
|
||||||
|
this.correctionStrategy = strategy;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run(TrainingContext ctx) {
|
||||||
|
this.correctionStrategy.apply(ctx);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user