Just a regular commit
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package com.naaturel.ANN;
|
package com.naaturel.ANN;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
@@ -11,6 +12,7 @@ import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
|||||||
import com.naaturel.ANN.implementation.activation.Heaviside;
|
import com.naaturel.ANN.implementation.activation.Heaviside;
|
||||||
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
|
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
|
||||||
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||||
|
import com.naaturel.ANN.implementation.training.SimpleTraining;
|
||||||
import com.naaturel.ANN.implementation.training.steps.*;
|
import com.naaturel.ANN.implementation.training.steps.*;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
@@ -67,25 +69,8 @@ public class Main {
|
|||||||
Layer layer = new Layer(List.of(neuron));
|
Layer layer = new Layer(List.of(neuron));
|
||||||
Network network = new Network(List.of(layer));
|
Network network = new Network(List.of(layer));
|
||||||
|
|
||||||
TrainingContext context = new TrainingContext();
|
Trainer trainer = new SimpleTraining();
|
||||||
context.dataset = orDataSet;
|
trainer.train(network, orDataSet);
|
||||||
context.model = network;
|
|
||||||
context.learningRate = 0.3F;
|
|
||||||
|
|
||||||
List<TrainingStep> steps = List.of(
|
|
||||||
new PredictionStep(),
|
|
||||||
new DeltaStep(),
|
|
||||||
new SimpleLossStep(),
|
|
||||||
new SimpleErrorDetectionStep(),
|
|
||||||
new WeightCorrectionStep(new SimpleCorrectionStrategy())
|
|
||||||
);
|
|
||||||
|
|
||||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
|
||||||
pipeline
|
|
||||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
|
|
||||||
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
|
||||||
.withVerbose(true)
|
|
||||||
.run(context);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public interface Trainer {
|
public interface Trainer {
|
||||||
|
void train(Trainable model, DataSet dataset);
|
||||||
void train(TrainingContext context, List<TrainingStep> steps);
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
package com.naaturel.ANN.implementation.training;
|
package com.naaturel.ANN.implementation.training;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Network;
|
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
import com.naaturel.ANN.domain.model.training.TrainingContext;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
|
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
|
||||||
@@ -18,8 +18,26 @@ public class SimpleTraining implements Trainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(TrainingContext context, List<TrainingStep> steps) {
|
public void train(Trainable model, DataSet dataset) {
|
||||||
|
TrainingContext context = new TrainingContext();
|
||||||
|
context.dataset = dataset;
|
||||||
|
context.model = model;
|
||||||
|
context.learningRate = 0.3F;
|
||||||
|
|
||||||
|
List<TrainingStep> steps = List.of(
|
||||||
|
new PredictionStep(),
|
||||||
|
new DeltaStep(),
|
||||||
|
new SimpleLossStep(),
|
||||||
|
new SimpleErrorDetectionStep(),
|
||||||
|
new WeightCorrectionStep(new SimpleCorrectionStrategy())
|
||||||
|
);
|
||||||
|
|
||||||
|
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||||
|
pipeline
|
||||||
|
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
|
||||||
|
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
||||||
|
.withVerbose(true)
|
||||||
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||||
|
|||||||
Reference in New Issue
Block a user