Change signature of train method
This commit is contained in:
@@ -6,8 +6,10 @@ import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
|
||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
@@ -15,13 +17,13 @@ public class Main {
|
||||
|
||||
public static void main(String[] args){
|
||||
|
||||
int nbrInput = 25;
|
||||
int nbrClass = 4;
|
||||
|
||||
int[] neuronPerLayer = new int[]{10, nbrClass};
|
||||
int nbrClass = 1;
|
||||
|
||||
DataSet dataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_5.csv", nbrClass);
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
|
||||
|
||||
int[] neuronPerLayer = new int[]{10, dataset.getNbrLabels()};
|
||||
int nbrInput = dataset.getNbrInputs();
|
||||
|
||||
List<Layer> layers = new ArrayList<>();
|
||||
for (int i = 0; i < neuronPerLayer.length; i++){
|
||||
@@ -38,7 +40,7 @@ public class Main {
|
||||
|
||||
Bias bias = new Bias(new Weight());
|
||||
|
||||
Neuron n = new Neuron(syns, bias, new TanH());
|
||||
Neuron n = new Neuron(syns, bias, new Sigmoid(2));
|
||||
neurons.add(n);
|
||||
}
|
||||
Layer layer = new Layer(neurons);
|
||||
@@ -48,7 +50,28 @@ public class Main {
|
||||
FullyConnectedNetwork network = new FullyConnectedNetwork(layers);
|
||||
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(network, dataset);
|
||||
trainer.train(0.5F, network, dataset);
|
||||
|
||||
/*GraphVisualizer visualizer = new GraphVisualizer();
|
||||
|
||||
for (DataSetEntry entry : dataset) {
|
||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
||||
visualizer.addPoint("Label " + label.getFirst(), entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||
}
|
||||
|
||||
float min = -2F;
|
||||
float max = 2F;
|
||||
float step = 0.01F;
|
||||
for (float x = min; x < max; x+=step){
|
||||
for (float y = min; y < max; y+=step){
|
||||
float prediction = network.predict(List.of(new Input(x), new Input(y))).getFirst();
|
||||
float predSeries = prediction > 0.5F ? 1 : 0;
|
||||
visualizer.addPoint(Float.toString(predSeries), x, y);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
visualizer.buildScatterGraph();*/
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,5 +3,5 @@ package com.naaturel.ANN.domain.abstraction;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
|
||||
public interface Trainer {
|
||||
void train(Model model, DataSet dataset);
|
||||
void train(float learningRate, Model model, DataSet dataset);
|
||||
}
|
||||
|
||||
@@ -53,12 +53,11 @@ public class Neuron implements Model {
|
||||
}
|
||||
|
||||
public float calculateWeightedSum() {
|
||||
float res = 0;
|
||||
res += this.bias.getWeight() * this.bias.getInput();
|
||||
this.weightedSum = 0F;
|
||||
this.weightedSum += this.bias.getWeight() * this.bias.getInput();
|
||||
for(Synapse syn : this.synapses){
|
||||
res += syn.getWeight() * syn.getInput();
|
||||
this.weightedSum += syn.getWeight() * syn.getInput();
|
||||
}
|
||||
this.weightedSum = res;
|
||||
return this.weightedSum;
|
||||
}
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ public class TrainingPipeline {
|
||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
|
||||
i++;
|
||||
}
|
||||
this.visualizer.build();
|
||||
this.visualizer.buildLineGraph();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -23,11 +23,11 @@ public class AdalineTraining implements Trainer {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void train(Model model, DataSet dataset) {
|
||||
public void train(float learningRate, Model model, DataSet dataset) {
|
||||
AdalineTrainingContext context = new AdalineTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
context.learningRate = 0.003F;
|
||||
context.learningRate = learningRate;
|
||||
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
|
||||
@@ -13,14 +13,13 @@ import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
public class GradientBackpropagationTraining implements Trainer {
|
||||
@Override
|
||||
public void train(Model model, DataSet dataset) {
|
||||
public void train(float learningRate, Model model, DataSet dataset) {
|
||||
GradientBackpropagationContext context = new GradientBackpropagationContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
context.learningRate = 0.1F;
|
||||
context.learningRate = learningRate;
|
||||
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
@@ -30,8 +29,10 @@ public class GradientBackpropagationTraining implements Trainer {
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
.stopCondition(ctx -> ctx.epoch == 250)
|
||||
.withVerbose(true)
|
||||
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
|
||||
.afterEpoch(ctx -> ctx.globalLoss = ctx.localLoss/dataset.size())
|
||||
.stopCondition(ctx -> ctx.epoch > 1000000)
|
||||
.withVerbose(false)
|
||||
.withTimeMeasurement(true)
|
||||
.run(context);
|
||||
|
||||
|
||||
@@ -23,11 +23,11 @@ public class GradientDescentTraining implements Trainer {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void train(Model model, DataSet dataset) {
|
||||
public void train(float learningRate, Model model, DataSet dataset) {
|
||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
context.learningRate = 0.0008F;
|
||||
context.learningRate = learningRate;
|
||||
context.correctorTerms = new ArrayList<>();
|
||||
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
|
||||
@@ -16,11 +16,11 @@ public class SimpleTraining implements Trainer {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void train(Model model, DataSet dataset) {
|
||||
public void train(float learningRate, Model model, DataSet dataset) {
|
||||
SimpleTrainingContext context = new SimpleTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
context.learningRate = 0.3F;
|
||||
context.learningRate = learningRate;
|
||||
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
|
||||
@@ -3,23 +3,52 @@ package com.naaturel.ANN.infrastructure.dataset;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class DataSet implements Iterable<DataSetEntry>{
|
||||
|
||||
private final Map<DataSetEntry, Labels> data;
|
||||
|
||||
private final int nbrInputs;
|
||||
private final int nbrLabels;
|
||||
|
||||
public DataSet() {
|
||||
this(new LinkedHashMap<>());
|
||||
this(new LinkedHashMap<>()); //ensure iteration order is the same as insertion order
|
||||
}
|
||||
|
||||
public DataSet(Map<DataSetEntry, Labels> data){
|
||||
this.data = data;
|
||||
this.nbrInputs = this.calculateNbrInput();
|
||||
this.nbrLabels = this.calculateNbrLabel();
|
||||
}
|
||||
|
||||
private int calculateNbrInput(){
|
||||
//assumes every entry are the same length
|
||||
Stream<DataSetEntry> keyStream = this.data.keySet().stream();
|
||||
Optional<DataSetEntry> firstEntry = keyStream.findFirst();
|
||||
return firstEntry.map(inputs -> inputs.getData().size()).orElse(0);
|
||||
}
|
||||
|
||||
private int calculateNbrLabel(){
|
||||
//assumes every label are the same length
|
||||
Stream<DataSetEntry> keyStream = this.data.keySet().stream();
|
||||
Optional<DataSetEntry> firstEntry = keyStream.findFirst();
|
||||
return firstEntry.map(inputs -> this.data.get(inputs).getValues().size()).orElse(0);
|
||||
}
|
||||
|
||||
|
||||
public int size() {
|
||||
return data.size();
|
||||
}
|
||||
|
||||
public int getNbrInputs() {
|
||||
return this.nbrInputs;
|
||||
}
|
||||
|
||||
public int getNbrLabels(){
|
||||
return this.nbrLabels;
|
||||
}
|
||||
|
||||
public List<DataSetEntry> getData(){
|
||||
return new ArrayList<>(this.data.keySet());
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.naaturel.ANN.infrastructure.graph;
|
||||
import org.jfree.chart.ChartFactory;
|
||||
import org.jfree.chart.ChartPanel;
|
||||
import org.jfree.chart.JFreeChart;
|
||||
import org.jfree.chart.plot.XYPlot;
|
||||
import org.jfree.data.xy.XYSeries;
|
||||
import org.jfree.data.xy.XYSeriesCollection;
|
||||
|
||||
@@ -29,7 +30,7 @@ public class GraphVisualizer {
|
||||
}
|
||||
}
|
||||
|
||||
public void build(){
|
||||
public void buildLineGraph(){
|
||||
JFreeChart chart = ChartFactory.createXYLineChart(
|
||||
"Model learning", "X", "Y", dataset
|
||||
);
|
||||
@@ -39,4 +40,21 @@ public class GraphVisualizer {
|
||||
frame.pack();
|
||||
frame.setVisible(true);
|
||||
}
|
||||
|
||||
|
||||
public void buildScatterGraph(){
|
||||
JFreeChart chart = ChartFactory.createScatterPlot(
|
||||
"Predictions", "X", "Y", dataset
|
||||
);
|
||||
XYPlot plot = chart.getXYPlot();
|
||||
plot.getDomainAxis().setRange(-2, 2);
|
||||
plot.getRangeAxis().setRange(-2, 2);
|
||||
|
||||
JFrame frame = new JFrame("Predictions");
|
||||
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
|
||||
frame.add(new ChartPanel(chart));
|
||||
frame.pack();
|
||||
frame.setVisible(true);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
4
src/main/resources/assets/xor.csv
Normal file
4
src/main/resources/assets/xor.csv
Normal file
@@ -0,0 +1,4 @@
|
||||
0,0,0
|
||||
0,1,1
|
||||
1,0,1
|
||||
1,1,0
|
||||
|
Reference in New Issue
Block a user