Compare commits
6 Commits
a2452fb4b8
...
76465ab6ee
| Author | SHA1 | Date | |
|---|---|---|---|
| 76465ab6ee | |||
| 65d3a0e3e4 | |||
| 0217607e9b | |||
| 5ace4952fb | |||
| a84c3d999d | |||
| b25aaba088 |
@@ -3,16 +3,13 @@ package com.naaturel.ANN;
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||
import com.naaturel.ANN.implementation.activationFunction.Linear;
|
||||
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.Heaviside;
|
||||
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||
import com.naaturel.ANN.implementation.training.AdalineTraining;
|
||||
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
||||
import com.naaturel.ANN.implementation.training.SimpleTraining;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
@@ -20,43 +17,11 @@ public class Main {
|
||||
|
||||
public static void main(String[] args){
|
||||
|
||||
DataSet orDataSet = new DataSet(Map.ofEntries(
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(0.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
|
||||
));
|
||||
DataSet dataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv");
|
||||
|
||||
DataSet andDataSet = new DataSet(Map.ofEntries(
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
|
||||
));
|
||||
|
||||
DataSet dataSet = new DataSet(Map.ofEntries(
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 9.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(7.0F, 10.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(2.0F, 5.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(2.0F, 7.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(2.0F, 8.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(6.0F, 8.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(6.0F, 9.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(3.0F, 5.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(3.0F, 6.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(3.0F, 8.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(3.0F, 9.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(5.0F, 7.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(5.0F, 8.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(5.0F, 10.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(5.0F, 11.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(4.0F, 6.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(4.0F, 7.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(4.0F, 9.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(4.0F, 10.0F)), new Label(1.0F))
|
||||
));
|
||||
DataSet andDataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv");
|
||||
|
||||
List<Synapse> syns = new ArrayList<>();
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
@@ -64,14 +29,12 @@ public class Main {
|
||||
|
||||
Bias bias = new Bias(new Weight(0));
|
||||
|
||||
Neuron n = new SimplePerceptron(syns, bias, new Linear());
|
||||
Trainer trainer = new AdalineTraining();
|
||||
Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
|
||||
Layer layer = new Layer(List.of(neuron));
|
||||
Network network = new Network(List.of(layer));
|
||||
|
||||
long start = System.currentTimeMillis();
|
||||
Trainer trainer = new SimpleTraining();
|
||||
trainer.train(network, andDataset);
|
||||
|
||||
trainer.train(n, 0.03F, andDataSet);
|
||||
|
||||
long end = System.currentTimeMillis();
|
||||
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
@FunctionalInterface
|
||||
public interface AlgorithmStrategy {
|
||||
|
||||
void apply();
|
||||
|
||||
}
|
||||
14
src/main/java/com/naaturel/ANN/domain/abstraction/Model.java
Normal file
14
src/main/java/com/naaturel/ANN/domain/abstraction/Model.java
Normal file
@@ -0,0 +1,14 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public interface Model {
|
||||
int synCount();
|
||||
void forEachSynapse(Consumer<Synapse> consumer);
|
||||
List<Float> predict(List<Input> inputs);
|
||||
|
||||
}
|
||||
@@ -4,10 +4,9 @@ import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public abstract class Neuron {
|
||||
public abstract class Neuron implements Model {
|
||||
|
||||
protected List<Synapse> synapses;
|
||||
protected Bias bias;
|
||||
@@ -19,37 +18,25 @@ public abstract class Neuron {
|
||||
this.activationFunction = func;
|
||||
}
|
||||
|
||||
public abstract float predict();
|
||||
public abstract float calculateWeightedSum();
|
||||
|
||||
public int getSynCount(){
|
||||
return this.synapses.size();
|
||||
}
|
||||
|
||||
public void setInput(int index, Input input){
|
||||
Synapse syn = this.synapses.get(index);
|
||||
syn.setInput(input.getValue());
|
||||
}
|
||||
|
||||
public Bias getBias(){
|
||||
return this.bias;
|
||||
}
|
||||
|
||||
public void updateBias(Weight weight) {
|
||||
this.bias.setWeight(weight.getValue());
|
||||
}
|
||||
|
||||
public Synapse getSynapse(int index){
|
||||
return this.synapses.get(index);
|
||||
public void updateWeight(int index, Weight weight) {
|
||||
this.synapses.get(index).setWeight(weight.getValue());
|
||||
}
|
||||
|
||||
public List<Synapse> getSynapses() {
|
||||
return new ArrayList<>(this.synapses);
|
||||
protected void setInputs(List<Input> inputs){
|
||||
for(int i = 0; i < inputs.size() && i < synapses.size(); i++){
|
||||
Synapse syn = this.synapses.get(i);
|
||||
syn.setInput(inputs.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
public void setWeight(int index, Weight weight){
|
||||
Synapse syn = this.synapses.get(index);
|
||||
syn.setWeight(weight.getValue());
|
||||
@Override
|
||||
public int synCount() {
|
||||
return this.synapses.size()+1; //take the bias in account
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
public abstract class NeuronTrainer {
|
||||
|
||||
private Trainable trainable;
|
||||
|
||||
public NeuronTrainer(Trainable trainable){
|
||||
this.trainable = trainable;
|
||||
}
|
||||
|
||||
public abstract void train();
|
||||
|
||||
}
|
||||
@@ -3,6 +3,5 @@ package com.naaturel.ANN.domain.abstraction;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
|
||||
public interface Trainer {
|
||||
|
||||
void train(Neuron n, float learningRate, DataSet dataSet);
|
||||
void train(Model model, DataSet dataset);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||
|
||||
public abstract class TrainingContext {
|
||||
public Model model;
|
||||
public DataSet dataset;
|
||||
public DataSetEntry currentEntry;
|
||||
|
||||
public Label currentLabel;
|
||||
public float prediction;
|
||||
public float delta;
|
||||
|
||||
public float globalLoss;
|
||||
public float localLoss;
|
||||
|
||||
public float learningRate;
|
||||
public int epoch;
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
public interface Trainable {
|
||||
|
||||
public interface TrainingStep {
|
||||
|
||||
void run();
|
||||
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
package com.naaturel.ANN.domain.model.dataset;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
public class DataSet implements Iterable<DataSetEntry>{
|
||||
@@ -7,7 +9,7 @@ public class DataSet implements Iterable<DataSetEntry>{
|
||||
private Map<DataSetEntry, Label> data;
|
||||
|
||||
public DataSet() {
|
||||
this(new HashMap<>());
|
||||
this(new LinkedHashMap<>());
|
||||
}
|
||||
|
||||
public DataSet(Map<DataSetEntry, Label> data){
|
||||
@@ -31,15 +33,17 @@ public class DataSet implements Iterable<DataSetEntry>{
|
||||
|
||||
float maxAbs = entries.stream()
|
||||
.flatMap(e -> e.getData().stream())
|
||||
.map(Input::getValue)
|
||||
.map(Math::abs)
|
||||
.max(Float::compare)
|
||||
.orElse(1.0F);
|
||||
|
||||
Map<DataSetEntry, Label> normalized = new HashMap<>();
|
||||
for (DataSetEntry entry : entries) {
|
||||
List<Float> normalizedData = new ArrayList<>();
|
||||
for (float value : entry.getData()) {
|
||||
normalizedData.add(Math.round((value / maxAbs) * 100.0F) / 100.0F);
|
||||
List<Input> normalizedData = new ArrayList<>();
|
||||
for (Input input : entry.getData()) {
|
||||
Input normalizedInput = new Input(Math.round((input.getValue() / maxAbs) * 100.0F) / 100.0F);
|
||||
normalizedData.add(normalizedInput);
|
||||
}
|
||||
normalized.put(new DataSetEntry(normalizedData), this.data.get(entry));
|
||||
}
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
package com.naaturel.ANN.domain.model.dataset;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
public class DataSetEntry implements Iterable<Float> {
|
||||
public class DataSetEntry implements Iterable<Input> {
|
||||
|
||||
private List<Float> data;
|
||||
private List<Input> data;
|
||||
|
||||
public DataSetEntry(List<Float> data){
|
||||
public DataSetEntry(List<Input> data){
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
public List<Float> getData() {
|
||||
public List<Input> getData() {
|
||||
return new ArrayList<>(data);
|
||||
}
|
||||
|
||||
@@ -28,7 +30,7 @@ public class DataSetEntry implements Iterable<Float> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<Float> iterator() {
|
||||
public Iterator<Input> iterator() {
|
||||
return this.data.iterator();
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
package com.naaturel.ANN.domain.model.dataset;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
|
||||
public class DatasetExtractor {
|
||||
|
||||
public DataSet extract(String path) {
|
||||
Map<DataSetEntry, Label> data = new LinkedHashMap<>();
|
||||
|
||||
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
String[] parts = line.split(",");
|
||||
List<Input> inputs = new ArrayList<>();
|
||||
for (int i = 0; i < parts.length - 1; i++) {
|
||||
inputs.add(new Input(Float.parseFloat(parts[i].trim())));
|
||||
}
|
||||
float label = Float.parseFloat(parts[parts.length - 1].trim());
|
||||
data.put(new DataSetEntry(inputs), new Label(label));
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to read dataset from: " + path, e);
|
||||
}
|
||||
|
||||
return new DataSet(data);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package com.naaturel.ANN.domain.model.neuron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public class Layer implements Model {
|
||||
|
||||
private final List<Neuron> neurons;
|
||||
|
||||
public Layer(List<Neuron> neurons) {
|
||||
this.neurons = neurons;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Float> predict(List<Input> inputs) {
|
||||
List<Float> result = new ArrayList<>();
|
||||
for(Neuron neuron : this.neurons){
|
||||
List<Float> res = neuron.predict(inputs);
|
||||
result.addAll(res);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int synCount() {
|
||||
int res = 0;
|
||||
for (Neuron neuron : this.neurons) {
|
||||
res += neuron.synCount();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||
this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package com.naaturel.ANN.domain.model.neuron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public class Network implements Model {
|
||||
|
||||
private final List<Layer> layers;
|
||||
|
||||
public Network(List<Layer> layers) {
|
||||
this.layers = layers;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Float> predict(List<Input> inputs) {
|
||||
List<Float> result = new ArrayList<>();
|
||||
for(Layer layer : this.layers){
|
||||
List<Float> res = layer.predict(inputs);
|
||||
result.addAll(res);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int synCount() {
|
||||
int res = 0;
|
||||
for(Layer layer : this.layers){
|
||||
res += layer.synCount();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||
this.layers.forEach(layer -> layer.forEachSynapse(consumer));
|
||||
}
|
||||
}
|
||||
@@ -14,8 +14,8 @@ public class Synapse {
|
||||
return this.input.getValue();
|
||||
}
|
||||
|
||||
public void setInput(float value){
|
||||
this.input.setValue(value);
|
||||
public void setInput(Input input){
|
||||
this.input.setValue(input.getValue());
|
||||
}
|
||||
|
||||
public float getWeight() {
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
package com.naaturel.ANN.domain.model.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
public class TrainingPipeline {
|
||||
|
||||
private final List<TrainingStep> steps;
|
||||
private Consumer<TrainingContext> beforeEpoch;
|
||||
private Consumer<TrainingContext> afterEpoch;
|
||||
private Predicate<TrainingContext> stopCondition;
|
||||
|
||||
private boolean verbose;
|
||||
private boolean timeMeasurement;
|
||||
|
||||
public TrainingPipeline(List<TrainingStep> steps) {
|
||||
this.steps = new ArrayList<>(steps);
|
||||
this.stopCondition = (ctx) -> false;
|
||||
this.beforeEpoch = (context -> {});
|
||||
this.afterEpoch = (context -> {});
|
||||
}
|
||||
|
||||
public TrainingPipeline stopCondition(Predicate<TrainingContext> predicate) {
|
||||
this.stopCondition = predicate;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TrainingPipeline beforeEpoch(Consumer<TrainingContext> consumer) {
|
||||
this.beforeEpoch = consumer;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TrainingPipeline afterEpoch(Consumer<TrainingContext> consumer) {
|
||||
this.afterEpoch = consumer;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TrainingPipeline withVerbose(boolean enabled) {
|
||||
this.verbose = enabled;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TrainingPipeline withTimeMeasurement(boolean enabled) {
|
||||
this.timeMeasurement = enabled;
|
||||
return this;
|
||||
}
|
||||
|
||||
public void run(TrainingContext ctx) {
|
||||
do {
|
||||
this.beforeEpoch.accept(ctx);
|
||||
this.executeSteps(ctx);
|
||||
this.afterEpoch.accept(ctx);
|
||||
} while (!this.stopCondition.test(ctx));
|
||||
}
|
||||
|
||||
private void executeSteps(TrainingContext ctx){
|
||||
for (DataSetEntry entry : ctx.dataset) {
|
||||
ctx.currentEntry = entry;
|
||||
ctx.currentLabel = ctx.dataset.getLabel(entry);
|
||||
for (TrainingStep step : steps) {
|
||||
step.run();
|
||||
}
|
||||
if(this.verbose) {
|
||||
System.out.printf("Epoch : %d, ", ctx.epoch);
|
||||
System.out.printf("predicted : %.2f, ", ctx.prediction);
|
||||
System.out.printf("expected : %.2f, ", ctx.currentLabel.getValue());
|
||||
System.out.printf("delta : %.2f, ", ctx.delta);
|
||||
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
||||
}
|
||||
}
|
||||
if(this.verbose) {
|
||||
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
|
||||
}
|
||||
ctx.epoch += 1;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
|
||||
|
||||
private final GradientDescentTrainingContext context;
|
||||
|
||||
public GradientDescentCorrectionStrategy(GradientDescentTrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
AtomicInteger i = new AtomicInteger(0);
|
||||
context.model.forEachSynapse(syn -> {
|
||||
float corrector = context.correctorTerms.get(i.get());
|
||||
float c = syn.getWeight() + corrector;
|
||||
syn.setWeight(c);
|
||||
i.incrementAndGet();
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public class GradientDescentErrorStrategy implements AlgorithmStrategy {
|
||||
|
||||
private final GradientDescentTrainingContext context;
|
||||
|
||||
public GradientDescentErrorStrategy(GradientDescentTrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
AtomicInteger i = new AtomicInteger(0);
|
||||
context.model.forEachSynapse(syn -> {
|
||||
float corrector = context.correctorTerms.get(i.get());
|
||||
corrector += context.learningRate * context.delta * syn.getInput();
|
||||
context.correctorTerms.set(i.get(), corrector);
|
||||
i.incrementAndGet();
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class GradientDescentTrainingContext extends TrainingContext {
|
||||
|
||||
public List<Float> correctorTerms;
|
||||
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.naaturel.ANN.implementation.activationFunction;
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
|
||||
|
||||
public class SquareLossStrategy implements AlgorithmStrategy {
|
||||
|
||||
private final GradientDescentTrainingContext context;
|
||||
|
||||
public SquareLossStrategy(GradientDescentTrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
this.context.localLoss = (float)Math.pow(this.context.delta, 2)/2;
|
||||
this.context.globalLoss += context.localLoss;
|
||||
}
|
||||
}
|
||||
@@ -2,30 +2,38 @@ package com.naaturel.ANN.implementation.neuron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public class SimplePerceptron extends Neuron implements Trainable {
|
||||
public class SimplePerceptron extends Neuron {
|
||||
|
||||
public SimplePerceptron(List<Synapse> synapses, Bias b, ActivationFunction func) {
|
||||
super(synapses, b, func);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float predict() {
|
||||
return activationFunction.accept(this);
|
||||
public List<Float> predict(List<Input> inputs) {
|
||||
super.setInputs(inputs);
|
||||
return List.of(activationFunction.accept(this));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||
consumer.accept(this.bias);
|
||||
this.synapses.forEach(consumer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float calculateWeightedSum() {
|
||||
float res = 0;
|
||||
res += this.bias.getWeight() * this.bias.getInput();
|
||||
for(Synapse syn : super.synapses){
|
||||
res += syn.getWeight() * syn.getInput();
|
||||
}
|
||||
res += this.bias.getWeight() * this.bias.getInput();
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.naaturel.ANN.implementation.activationFunction;
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
@@ -12,6 +12,6 @@ public class Heaviside implements ActivationFunction {
|
||||
@Override
|
||||
public float accept(Neuron n) {
|
||||
float weightedSum = n.calculateWeightedSum();
|
||||
return weightedSum <= 0 ? 0:1;
|
||||
return weightedSum < 0 ? 0:1;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
|
||||
|
||||
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
|
||||
|
||||
private final SimpleTrainingContext context;
|
||||
|
||||
public SimpleCorrectionStrategy(SimpleTrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
if(context.currentLabel.getValue() == context.prediction) return ;
|
||||
context.model.forEachSynapse(syn -> {
|
||||
float currentW = syn.getWeight();
|
||||
float currentInput = syn.getInput();
|
||||
float newValue = currentW + (context.learningRate * context.delta * currentInput);
|
||||
syn.setWeight(newValue);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||
|
||||
public class SimpleDeltaStrategy implements AlgorithmStrategy {
|
||||
|
||||
private final TrainingContext context;
|
||||
|
||||
public SimpleDeltaStrategy(TrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
DataSet dataSet = context.dataset;
|
||||
DataSetEntry entry = context.currentEntry;
|
||||
Label label = dataSet.getLabel(entry);
|
||||
|
||||
context.delta = label.getValue() - context.prediction;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
|
||||
public class SimpleErrorRegistrationStrategy implements AlgorithmStrategy {
|
||||
|
||||
private final SimpleTrainingContext context;
|
||||
|
||||
public SimpleErrorRegistrationStrategy(SimpleTrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
context.globalLoss += context.localLoss;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
|
||||
public class SimpleLossStrategy implements AlgorithmStrategy {
|
||||
|
||||
private final SimpleTrainingContext context;
|
||||
|
||||
public SimpleLossStrategy(SimpleTrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
this.context.localLoss = Math.abs(this.context.delta);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class SimplePredictionStrategy implements AlgorithmStrategy {
|
||||
|
||||
private final TrainingContext context;
|
||||
|
||||
public SimplePredictionStrategy(TrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void apply() {
|
||||
List<Float> predictions = context.model.predict(context.currentEntry.getData());
|
||||
context.prediction = predictions.getFirst();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
|
||||
public class SimpleTrainingContext extends TrainingContext {
|
||||
}
|
||||
@@ -9,7 +9,7 @@ import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||
|
||||
|
||||
public class AdalineTraining implements Trainer {
|
||||
/*public class AdalineTraining implements Trainer {
|
||||
|
||||
public AdalineTraining(){
|
||||
|
||||
@@ -78,4 +78,4 @@ public class AdalineTraining implements Trainer {
|
||||
return (float) Math.pow(delta, 2)/2;
|
||||
}
|
||||
|
||||
}
|
||||
}*/
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class GradientDescentTraining implements Trainer {
|
||||
@@ -19,7 +22,38 @@ public class GradientDescentTraining implements Trainer {
|
||||
|
||||
}
|
||||
|
||||
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
@Override
|
||||
public void train(Model model, DataSet dataset) {
|
||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
context.learningRate = 0.2F;
|
||||
context.correctorTerms = new ArrayList<>();
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||
new LossStep(new SquareLossStrategy(context)),
|
||||
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)),
|
||||
new WeightCorrectionStep(new GradientDescentCorrectionStrategy(context))
|
||||
);
|
||||
|
||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||
pipeline
|
||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 50)
|
||||
.beforeEpoch(ctx -> {
|
||||
ctx.globalLoss = 0.0F;
|
||||
for (int i = 0; i < model.synCount(); i++){
|
||||
context.correctorTerms.add(0F);
|
||||
}
|
||||
})
|
||||
.afterEpoch(ctx -> ctx.globalLoss /= ctx.dataset.size())
|
||||
.withVerbose(true)
|
||||
.withTimeMeasurement(true)
|
||||
.run(context);
|
||||
}
|
||||
|
||||
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
int epoch = 1;
|
||||
int maxEpoch = 402;
|
||||
float errorThreshold = 0F;
|
||||
@@ -120,6 +154,6 @@ public class GradientDescentTraining implements Trainer {
|
||||
variance /= dataSet.size();
|
||||
|
||||
return variance;
|
||||
}
|
||||
}*/
|
||||
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class SimpleTraining implements Trainer {
|
||||
|
||||
@@ -14,7 +16,30 @@ public class SimpleTraining implements Trainer {
|
||||
|
||||
}
|
||||
|
||||
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
@Override
|
||||
public void train(Model model, DataSet dataset) {
|
||||
SimpleTrainingContext context = new SimpleTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
context.learningRate = 0.3F;
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||
new LossStep(new SimpleLossStrategy(context)),
|
||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
|
||||
);
|
||||
|
||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||
pipeline
|
||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 10)
|
||||
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
||||
.withVerbose(true)
|
||||
.run(context);
|
||||
}
|
||||
|
||||
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
int epoch = 1;
|
||||
int errorCount;
|
||||
|
||||
@@ -65,5 +90,5 @@ public class SimpleTraining implements Trainer {
|
||||
private float calculateLoss(float delta){
|
||||
return Math.abs(delta);
|
||||
}
|
||||
|
||||
*/
|
||||
}
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||
|
||||
public class DeltaStep implements TrainingStep {
|
||||
|
||||
private final AlgorithmStrategy strategy;
|
||||
|
||||
public DeltaStep(AlgorithmStrategy strategy) {
|
||||
this.strategy = strategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
this.strategy.apply();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
|
||||
public class ErrorRegistrationStep implements TrainingStep {
|
||||
|
||||
private final AlgorithmStrategy strategy;
|
||||
|
||||
public ErrorRegistrationStep(AlgorithmStrategy strategy) {
|
||||
this.strategy = strategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
this.strategy.apply();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
|
||||
public class LossStep implements TrainingStep {
|
||||
|
||||
|
||||
private final AlgorithmStrategy lossStrategy;
|
||||
|
||||
public LossStep(AlgorithmStrategy strategy) {
|
||||
this.lossStrategy = strategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
this.lossStrategy.apply();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class PredictionStep implements TrainingStep {
|
||||
|
||||
private final SimplePredictionStrategy strategy;
|
||||
|
||||
public PredictionStep(SimplePredictionStrategy strategy) {
|
||||
this.strategy = strategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
this.strategy.apply();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.naaturel.ANN.implementation.training.steps;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
|
||||
public class WeightCorrectionStep implements TrainingStep {
|
||||
|
||||
private final AlgorithmStrategy correctionStrategy;
|
||||
|
||||
public WeightCorrectionStep(AlgorithmStrategy strategy) {
|
||||
this.correctionStrategy = strategy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
this.correctionStrategy.apply();
|
||||
}
|
||||
}
|
||||
85
src/test/java/perceptron/simplePerceptronTest.java
Normal file
85
src/test/java/perceptron/simplePerceptronTest.java
Normal file
@@ -0,0 +1,85 @@
|
||||
package perceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
|
||||
public class simplePerceptronTest {
|
||||
|
||||
private DataSet dataset;
|
||||
private SimpleTrainingContext context;
|
||||
|
||||
private List<Synapse> synapses;
|
||||
private Bias bias;
|
||||
private Network network;
|
||||
|
||||
private TrainingPipeline pipeline;
|
||||
|
||||
@BeforeEach
|
||||
public void init(){
|
||||
dataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv");
|
||||
|
||||
List<Synapse> syns = new ArrayList<>();
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
|
||||
bias = new Bias(new Weight(0));
|
||||
|
||||
Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
|
||||
Layer layer = new Layer(List.of(neuron));
|
||||
network = new Network(List.of(layer));
|
||||
|
||||
context = new SimpleTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = network;
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||
new LossStep(new SimpleLossStrategy(context)),
|
||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
|
||||
);
|
||||
|
||||
pipeline = new TrainingPipeline(steps);
|
||||
pipeline.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100);
|
||||
pipeline.beforeEpoch(ctx -> ctx.globalLoss = 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_the_whole_algorithm(){
|
||||
|
||||
List<Float> expectedGlobalLosses = List.of(
|
||||
2.0F,
|
||||
3.0F,
|
||||
3.0F,
|
||||
2.0F,
|
||||
1.0F,
|
||||
0.0F
|
||||
);
|
||||
|
||||
context.learningRate = 1F;
|
||||
pipeline.afterEpoch(ctx -> {
|
||||
int index = ctx.epoch-1;
|
||||
assertEquals(expectedGlobalLosses.get(index), context.globalLoss);
|
||||
});
|
||||
|
||||
pipeline.run(context);
|
||||
assertEquals(6, context.epoch);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user