Change signature of train method
This commit is contained in:
@@ -6,8 +6,10 @@ import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
|
|||||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||||
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
|
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
|
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
@@ -15,13 +17,13 @@ public class Main {
|
|||||||
|
|
||||||
public static void main(String[] args){
|
public static void main(String[] args){
|
||||||
|
|
||||||
int nbrInput = 25;
|
int nbrClass = 1;
|
||||||
int nbrClass = 4;
|
|
||||||
|
|
||||||
int[] neuronPerLayer = new int[]{10, nbrClass};
|
|
||||||
|
|
||||||
DataSet dataset = new DatasetExtractor()
|
DataSet dataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_5.csv", nbrClass);
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
|
||||||
|
|
||||||
|
int[] neuronPerLayer = new int[]{10, dataset.getNbrLabels()};
|
||||||
|
int nbrInput = dataset.getNbrInputs();
|
||||||
|
|
||||||
List<Layer> layers = new ArrayList<>();
|
List<Layer> layers = new ArrayList<>();
|
||||||
for (int i = 0; i < neuronPerLayer.length; i++){
|
for (int i = 0; i < neuronPerLayer.length; i++){
|
||||||
@@ -38,7 +40,7 @@ public class Main {
|
|||||||
|
|
||||||
Bias bias = new Bias(new Weight());
|
Bias bias = new Bias(new Weight());
|
||||||
|
|
||||||
Neuron n = new Neuron(syns, bias, new TanH());
|
Neuron n = new Neuron(syns, bias, new Sigmoid(2));
|
||||||
neurons.add(n);
|
neurons.add(n);
|
||||||
}
|
}
|
||||||
Layer layer = new Layer(neurons);
|
Layer layer = new Layer(neurons);
|
||||||
@@ -48,7 +50,28 @@ public class Main {
|
|||||||
FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
|
FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
|
||||||
|
|
||||||
Trainer trainer = new GradientBackpropagationTraining();
|
Trainer trainer = new GradientBackpropagationTraining();
|
||||||
trainer.train(network, dataset);
|
trainer.train(0.5F, network, dataset);
|
||||||
|
|
||||||
|
/*GraphVisualizer visualizer = new GraphVisualizer();
|
||||||
|
|
||||||
|
for (DataSetEntry entry : dataset) {
|
||||||
|
List<Float> label = dataset.getLabelsAsFloat(entry);
|
||||||
|
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
float min = -2F;
|
||||||
|
float max = 2F;
|
||||||
|
float step = 0.01F;
|
||||||
|
for (float x = min; x < max; x+=step){
|
||||||
|
for (float y = min; y < max; y+=step){
|
||||||
|
float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst();
|
||||||
|
float predSeries = prediction > 0.5F ? 1 : 0;
|
||||||
|
visualizer.addPoint(Float.toString(predSeries), x, y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
visualizer.buildScatterGraph();*/
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,5 +3,5 @@ package com.naaturel.ANN.domain.abstraction;
|
|||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||||
|
|
||||||
public interface Trainer {
|
public interface Trainer {
|
||||||
void train(Model model, DataSet dataset);
|
void train(float learningRate, Model model, DataSet dataset);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,12 +53,11 @@ public class Neuron implements Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public float calculateWeightedSum() {
|
public float calculateWeightedSum() {
|
||||||
float res = 0;
|
this.weightedSum = 0F;
|
||||||
res += this.bias.getWeight() * this.bias.getInput();
|
this.weightedSum += this.bias.getWeight() * this.bias.getInput();
|
||||||
for(Synapse syn : this.synapses){
|
for(Synapse syn : this.synapses){
|
||||||
res += syn.getWeight() * syn.getInput();
|
this.weightedSum += syn.getWeight() * syn.getInput();
|
||||||
}
|
}
|
||||||
this.weightedSum = res;
|
|
||||||
return this.weightedSum;
|
return this.weightedSum;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ public class TrainingPipeline {
|
|||||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
|
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
this.visualizer.build();
|
this.visualizer.buildLineGraph();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,11 +23,11 @@ public class AdalineTraining implements Trainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(Model model, DataSet dataset) {
|
public void train(float learningRate, Model model, DataSet dataset) {
|
||||||
AdalineTrainingContext context = new AdalineTrainingContext();
|
AdalineTrainingContext context = new AdalineTrainingContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
context.learningRate = 0.003F;
|
context.learningRate = learningRate;
|
||||||
|
|
||||||
List<AlgorithmStep> steps = List.of(
|
List<AlgorithmStep> steps = List.of(
|
||||||
new SimplePredictionStep(context),
|
new SimplePredictionStep(context),
|
||||||
|
|||||||
@@ -13,14 +13,13 @@ import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
|||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
|
||||||
public class GradientBackpropagationTraining implements Trainer {
|
public class GradientBackpropagationTraining implements Trainer {
|
||||||
@Override
|
@Override
|
||||||
public void train(Model model, DataSet dataset) {
|
public void train(float learningRate, Model model, DataSet dataset) {
|
||||||
GradientBackpropagationContext context = new GradientBackpropagationContext();
|
GradientBackpropagationContext context = new GradientBackpropagationContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
context.learningRate = 0.1F;
|
context.learningRate = learningRate;
|
||||||
|
|
||||||
List<AlgorithmStep> steps = List.of(
|
List<AlgorithmStep> steps = List.of(
|
||||||
new SimplePredictionStep(context),
|
new SimplePredictionStep(context),
|
||||||
@@ -30,8 +29,10 @@ public class GradientBackpropagationTraining implements Trainer {
|
|||||||
);
|
);
|
||||||
|
|
||||||
new TrainingPipeline(steps)
|
new TrainingPipeline(steps)
|
||||||
.stopCondition(ctx -> ctx.epoch == 250)
|
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
|
||||||
.withVerbose(true)
|
.afterEpoch(ctx -> ctx.globalLoss = ctx.localLoss/dataset.size())
|
||||||
|
.stopCondition(ctx -> ctx.epoch > 1000000)
|
||||||
|
.withVerbose(false)
|
||||||
.withTimeMeasurement(true)
|
.withTimeMeasurement(true)
|
||||||
.run(context);
|
.run(context);
|
||||||
|
|
||||||
|
|||||||
@@ -23,11 +23,11 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(Model model, DataSet dataset) {
|
public void train(float learningRate, Model model, DataSet dataset) {
|
||||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
context.learningRate = 0.0008F;
|
context.learningRate = learningRate;
|
||||||
context.correctorTerms = new ArrayList<>();
|
context.correctorTerms = new ArrayList<>();
|
||||||
|
|
||||||
List<AlgorithmStep> steps = List.of(
|
List<AlgorithmStep> steps = List.of(
|
||||||
|
|||||||
@@ -16,11 +16,11 @@ public class SimpleTraining implements Trainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(Model model, DataSet dataset) {
|
public void train(float learningRate, Model model, DataSet dataset) {
|
||||||
SimpleTrainingContext context = new SimpleTrainingContext();
|
SimpleTrainingContext context = new SimpleTrainingContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
context.learningRate = 0.3F;
|
context.learningRate = learningRate;
|
||||||
|
|
||||||
List<AlgorithmStep> steps = List.of(
|
List<AlgorithmStep> steps = List.of(
|
||||||
new SimplePredictionStep(context),
|
new SimplePredictionStep(context),
|
||||||
|
|||||||
@@ -3,23 +3,52 @@ package com.naaturel.ANN.infrastructure.dataset;
|
|||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
public class DataSet implements Iterable<DataSetEntry>{
|
public class DataSet implements Iterable<DataSetEntry>{
|
||||||
|
|
||||||
private final Map<DataSetEntry, Labels> data;
|
private final Map<DataSetEntry, Labels> data;
|
||||||
|
|
||||||
|
private final int nbrInputs;
|
||||||
|
private final int nbrLabels;
|
||||||
|
|
||||||
public DataSet() {
|
public DataSet() {
|
||||||
this(new LinkedHashMap<>());
|
this(new LinkedHashMap<>()); //ensure iteration order is the same as insertion order
|
||||||
}
|
}
|
||||||
|
|
||||||
public DataSet(Map<DataSetEntry, Labels> data){
|
public DataSet(Map<DataSetEntry, Labels> data){
|
||||||
this.data = data;
|
this.data = data;
|
||||||
|
this.nbrInputs = this.calculateNbrInput();
|
||||||
|
this.nbrLabels = this.calculateNbrLabel();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private int calculateNbrInput(){
|
||||||
|
//assumes every entry are the same length
|
||||||
|
Stream<DataSetEntry> keyStream = this.data.keySet().stream();
|
||||||
|
Optional<DataSetEntry> firstEntry = keyStream.findFirst();
|
||||||
|
return firstEntry.map(inputs -> inputs.getData().size()).orElse(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
private int calculateNbrLabel(){
|
||||||
|
//assumes every label are the same length
|
||||||
|
Stream<DataSetEntry> keyStream = this.data.keySet().stream();
|
||||||
|
Optional<DataSetEntry> firstEntry = keyStream.findFirst();
|
||||||
|
return firstEntry.map(inputs -> this.data.get(inputs).getValues().size()).orElse(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
public int size() {
|
public int size() {
|
||||||
return data.size();
|
return data.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public int getNbrInputs() {
|
||||||
|
return this.nbrInputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNbrLabels(){
|
||||||
|
return this.nbrLabels;
|
||||||
|
}
|
||||||
|
|
||||||
public List<DataSetEntry> getData(){
|
public List<DataSetEntry> getData(){
|
||||||
return new ArrayList<>(this.data.keySet());
|
return new ArrayList<>(this.data.keySet());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package com.naaturel.ANN.infrastructure.graph;
|
|||||||
import org.jfree.chart.ChartFactory;
|
import org.jfree.chart.ChartFactory;
|
||||||
import org.jfree.chart.ChartPanel;
|
import org.jfree.chart.ChartPanel;
|
||||||
import org.jfree.chart.JFreeChart;
|
import org.jfree.chart.JFreeChart;
|
||||||
|
import org.jfree.chart.plot.XYPlot;
|
||||||
import org.jfree.data.xy.XYSeries;
|
import org.jfree.data.xy.XYSeries;
|
||||||
import org.jfree.data.xy.XYSeriesCollection;
|
import org.jfree.data.xy.XYSeriesCollection;
|
||||||
|
|
||||||
@@ -29,7 +30,7 @@ public class GraphVisualizer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void build(){
|
public void buildLineGraph(){
|
||||||
JFreeChart chart = ChartFactory.createXYLineChart(
|
JFreeChart chart = ChartFactory.createXYLineChart(
|
||||||
"Model learning", "X", "Y", dataset
|
"Model learning", "X", "Y", dataset
|
||||||
);
|
);
|
||||||
@@ -39,4 +40,21 @@ public class GraphVisualizer {
|
|||||||
frame.pack();
|
frame.pack();
|
||||||
frame.setVisible(true);
|
frame.setVisible(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public void buildScatterGraph(){
|
||||||
|
JFreeChart chart = ChartFactory.createScatterPlot(
|
||||||
|
"Predictions", "X", "Y", dataset
|
||||||
|
);
|
||||||
|
XYPlot plot = chart.getXYPlot();
|
||||||
|
plot.getDomainAxis().setRange(-2, 2);
|
||||||
|
plot.getRangeAxis().setRange(-2, 2);
|
||||||
|
|
||||||
|
JFrame frame = new JFrame("Predictions");
|
||||||
|
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
|
||||||
|
frame.add(new ChartPanel(chart));
|
||||||
|
frame.pack();
|
||||||
|
frame.setVisible(true);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
4
src/main/resources/assets/xor.csv
Normal file
4
src/main/resources/assets/xor.csv
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
0,0,0
|
||||||
|
0,1,1
|
||||||
|
1,0,1
|
||||||
|
1,1,0
|
||||||
|
Reference in New Issue
Block a user