Compare commits
5 Commits
main
...
a2452fb4b8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a2452fb4b8 | ||
| 2936bf33bf | |||
|
|
1da32862f5 | ||
|
|
fbf2a571ef | ||
|
|
89d9abe329 |
@@ -10,9 +10,6 @@ repositories {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation("org.jfree:jfreechart:1.5.4")
|
|
||||||
implementation("com.fasterxml.jackson.core:jackson-databind:2.21.2")
|
|
||||||
|
|
||||||
testImplementation(platform("org.junit:junit-bom:5.10.0"))
|
testImplementation(platform("org.junit:junit-bom:5.10.0"))
|
||||||
testImplementation("org.junit.jupiter:junit-jupiter")
|
testImplementation("org.junit.jupiter:junit-jupiter")
|
||||||
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
|
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
|
||||||
|
|||||||
14
config.json
14
config.json
@@ -1,14 +0,0 @@
|
|||||||
{
|
|
||||||
"model": {
|
|
||||||
"new": true,
|
|
||||||
"parameters": [2, 4, 2, 1],
|
|
||||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/best-snapshot.json"
|
|
||||||
},
|
|
||||||
"training" : {
|
|
||||||
"learning_rate" : 0.0003,
|
|
||||||
"max_epoch" : 5000
|
|
||||||
},
|
|
||||||
"dataset" : {
|
|
||||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/LangageDesSignes/data_formatted.csv"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,110 +1,40 @@
|
|||||||
package com.naaturel.ANN;
|
package com.naaturel.ANN;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.infrastructure.config.ConfigDto;
|
|
||||||
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||||
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
import com.naaturel.ANN.implementation.simplePerceptron.Heaviside;
|
||||||
import com.naaturel.ANN.infrastructure.visualization.ModelVisualizer;
|
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||||
|
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
||||||
|
import com.naaturel.ANN.implementation.training.SimpleTraining;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
|
|
||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args){
|
||||||
|
|
||||||
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
|
DataSet dataset = new DatasetExtractor()
|
||||||
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv");
|
||||||
|
|
||||||
boolean newModel = config.getModelProperty("new", Boolean.class);
|
DataSet orDataset = new DatasetExtractor()
|
||||||
int[] modelParameters = config.getModelProperty("parameters", int[].class);
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv");
|
||||||
String modelPath = config.getModelProperty("path", String.class);
|
|
||||||
int maxEpoch = config.getTrainingProperty("max_epoch", Integer.class);
|
|
||||||
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
|
|
||||||
String datasetPath = config.getDatasetProperty("path", String.class);
|
|
||||||
|
|
||||||
int nbrClass = 5;
|
List<Synapse> syns = new ArrayList<>();
|
||||||
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
int nbrInput = dataset.getNbrInputs();
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
|
|
||||||
ModelSnapshot snapshot;
|
Bias bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
Model network;
|
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
|
||||||
if(newModel){
|
Layer layer = new Layer(List.of(neuron));
|
||||||
network = createNetwork(modelParameters, nbrInput);
|
Network network = new Network(List.of(layer));
|
||||||
snapshot = new ModelSnapshot(network);
|
|
||||||
System.out.println("Parameters: " + network.synCount());
|
Trainer trainer = new GradientDescentTraining();
|
||||||
Trainer trainer = new GradientBackpropagationTraining();
|
trainer.train(network, dataset);
|
||||||
trainer.train(learningRate, maxEpoch, network, dataset);
|
|
||||||
snapshot.saveToFile(modelPath);
|
|
||||||
} else {
|
|
||||||
snapshot = new ModelSnapshot();
|
|
||||||
snapshot.loadFromFile(modelPath);
|
|
||||||
network = snapshot.getModel();
|
|
||||||
}
|
|
||||||
//plotGraph(dataset, network);
|
|
||||||
|
|
||||||
new ModelVisualizer(network)
|
|
||||||
.withWeights(true)
|
|
||||||
.display();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
|
||||||
int neuronId = 0;
|
|
||||||
List<Layer> layers = new ArrayList<>();
|
|
||||||
for (int i = 0; i < neuronPerLayer.length; i++){
|
|
||||||
|
|
||||||
List<Neuron> neurons = new ArrayList<>();
|
|
||||||
for (int j = 0; j < neuronPerLayer[i]; j++){
|
|
||||||
|
|
||||||
int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1];
|
|
||||||
|
|
||||||
List<Synapse> syns = new ArrayList<>();
|
|
||||||
for (int k=0; k < nbrSyn; k++){
|
|
||||||
syns.add(new Synapse(new Input(0), new Weight()));
|
|
||||||
}
|
|
||||||
|
|
||||||
Bias bias = new Bias(new Weight());
|
|
||||||
|
|
||||||
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
|
|
||||||
neurons.add(n);
|
|
||||||
neuronId++;
|
|
||||||
}
|
|
||||||
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
|
|
||||||
layers.add(layer);
|
|
||||||
}
|
|
||||||
|
|
||||||
return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
|
||||||
}
|
|
||||||
|
|
||||||
private static void plotGraph(DataSet dataset, Model network){
|
|
||||||
GraphVisualizer visualizer = new GraphVisualizer();
|
|
||||||
|
|
||||||
/*for (DataSetEntry entry : dataset) {
|
|
||||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
|
||||||
label.forEach(l -> {
|
|
||||||
visualizer.addPoint("Label " + l,
|
|
||||||
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
|
||||||
});
|
|
||||||
}*/
|
|
||||||
|
|
||||||
float min = -50F;
|
|
||||||
float max = 50F;
|
|
||||||
float step = 0.03F;
|
|
||||||
for (float x = min; x < max; x+=step){
|
|
||||||
for (float y = min; y < max; y+=step){
|
|
||||||
float[] predictions = network.predict(new float[]{x, y});
|
|
||||||
visualizer.addPoint(Float.toString(Math.round(predictions[0])), x, y);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
visualizer.buildScatterGraph((int)min-1, (int)max+1);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|
||||||
|
|
||||||
public interface ActivationFunction {
|
public interface ActivationFunction {
|
||||||
|
|
||||||
float accept(Neuron n);
|
float accept(Neuron n);
|
||||||
float derivative(float value);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
|
public interface AlgorithmStrategy {
|
||||||
|
|
||||||
|
void apply();
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -9,12 +8,7 @@ import java.util.function.Consumer;
|
|||||||
|
|
||||||
public interface Model {
|
public interface Model {
|
||||||
int synCount();
|
int synCount();
|
||||||
int neuronCount();
|
void applyOnSynapses(Consumer<Synapse> consumer);
|
||||||
int layerIndexOf(Neuron n);
|
List<Float> predict(List<Input> inputs);
|
||||||
int indexInLayerOf(Neuron n);
|
|
||||||
void forEachNeuron(Consumer<Neuron> consumer);
|
|
||||||
//void forEachSynapse(Consumer<Synapse> consumer);
|
|
||||||
void forEachOutputNeurons(Consumer<Neuron> consumer);
|
|
||||||
void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer);
|
|
||||||
float[] predict(float[] inputs);
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|
||||||
|
|
||||||
import java.util.function.Consumer;
|
|
||||||
|
|
||||||
public interface Network {
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Bias;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public abstract class Neuron implements Model {
|
||||||
|
|
||||||
|
protected List<Synapse> synapses;
|
||||||
|
protected Bias bias;
|
||||||
|
protected ActivationFunction activationFunction;
|
||||||
|
|
||||||
|
public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){
|
||||||
|
this.synapses = synapses;
|
||||||
|
this.bias = bias;
|
||||||
|
this.activationFunction = func;
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract float calculateWeightedSum();
|
||||||
|
|
||||||
|
public void updateBias(Weight weight) {
|
||||||
|
this.bias.setWeight(weight.getValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void updateWeight(int index, Weight weight) {
|
||||||
|
this.synapses.get(index).setWeight(weight.getValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void setInputs(List<Input> inputs){
|
||||||
|
for(int i = 0; i < inputs.size() && i < synapses.size(); i++){
|
||||||
|
Synapse syn = this.synapses.get(i);
|
||||||
|
syn.setInput(inputs.get(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int synCount() {
|
||||||
|
return this.synapses.size()+1; //take the bias in account
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
|
|
||||||
public interface Trainer {
|
public interface Trainer {
|
||||||
void train(float learningRate, int epoch, Model model, DataSet dataset);
|
void train(Model model, DataSet dataset);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,29 +1,21 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public abstract class TrainingContext {
|
public abstract class TrainingContext {
|
||||||
public Model model;
|
public Model model;
|
||||||
public DataSet dataset;
|
public DataSet dataset;
|
||||||
public DataSetEntry currentEntry;
|
public DataSetEntry currentEntry;
|
||||||
|
|
||||||
public List<Float> expectations;
|
public Label currentLabel;
|
||||||
public float[] predictions;
|
public float prediction;
|
||||||
public float[] deltas;
|
public float delta;
|
||||||
|
|
||||||
public float globalLoss;
|
public float globalLoss;
|
||||||
public float localLoss;
|
public float localLoss;
|
||||||
|
|
||||||
public float learningRate;
|
public float learningRate;
|
||||||
public int epoch;
|
public int epoch;
|
||||||
|
|
||||||
public TrainingContext(Model model, DataSet dataset) {
|
|
||||||
this.model = model;
|
|
||||||
this.dataset = dataset;
|
|
||||||
this.deltas = new float[dataset.getNbrLabels()];
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
@FunctionalInterface
|
public interface TrainingStep {
|
||||||
public interface AlgorithmStep {
|
|
||||||
|
|
||||||
void run();
|
void run();
|
||||||
|
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
package com.naaturel.ANN.domain.model.dataset;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
public class DataSet implements Iterable<DataSetEntry>{
|
||||||
|
|
||||||
|
private Map<DataSetEntry, Label> data;
|
||||||
|
|
||||||
|
public DataSet() {
|
||||||
|
this(new HashMap<>());
|
||||||
|
}
|
||||||
|
|
||||||
|
public DataSet(Map<DataSetEntry, Label> data){
|
||||||
|
this.data = data;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int size() {
|
||||||
|
return data.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<DataSetEntry> getData(){
|
||||||
|
return new ArrayList<>(this.data.keySet());
|
||||||
|
}
|
||||||
|
|
||||||
|
public Label getLabel(DataSetEntry entry){
|
||||||
|
return this.data.get(entry);
|
||||||
|
}
|
||||||
|
|
||||||
|
public DataSet toNormalized() {
|
||||||
|
List<DataSetEntry> entries = this.getData();
|
||||||
|
|
||||||
|
float maxAbs = entries.stream()
|
||||||
|
.flatMap(e -> e.getData().stream())
|
||||||
|
.map(Input::getValue)
|
||||||
|
.map(Math::abs)
|
||||||
|
.max(Float::compare)
|
||||||
|
.orElse(1.0F);
|
||||||
|
|
||||||
|
Map<DataSetEntry, Label> normalized = new HashMap<>();
|
||||||
|
for (DataSetEntry entry : entries) {
|
||||||
|
List<Input> normalizedData = new ArrayList<>();
|
||||||
|
for (Input input : entry.getData()) {
|
||||||
|
Input normalizedInput = new Input(Math.round((input.getValue() / maxAbs) * 100.0F) / 100.0F);
|
||||||
|
normalizedData.add(normalizedInput);
|
||||||
|
}
|
||||||
|
normalized.put(new DataSetEntry(normalizedData), this.data.get(entry));
|
||||||
|
}
|
||||||
|
|
||||||
|
return new DataSet(normalized);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Iterator<DataSetEntry> iterator() {
|
||||||
|
return this.data.keySet().iterator();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.naaturel.ANN.infrastructure.dataset;
|
package com.naaturel.ANN.domain.model.dataset;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
|
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
package com.naaturel.ANN.domain.model.dataset;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
|
|
||||||
|
import java.io.BufferedReader;
|
||||||
|
import java.io.FileReader;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class DatasetExtractor {
|
||||||
|
|
||||||
|
public DataSet extract(String path) {
|
||||||
|
Map<DataSetEntry, Label> data = new HashMap<>();
|
||||||
|
|
||||||
|
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
|
||||||
|
String line;
|
||||||
|
while ((line = reader.readLine()) != null) {
|
||||||
|
String[] parts = line.split(",");
|
||||||
|
List<Input> inputs = new ArrayList<>();
|
||||||
|
for (int i = 0; i < parts.length - 1; i++) {
|
||||||
|
inputs.add(new Input(Float.parseFloat(parts[i].trim())));
|
||||||
|
}
|
||||||
|
float label = Float.parseFloat(parts[parts.length - 1].trim());
|
||||||
|
data.put(new DataSetEntry(inputs), new Label(label));
|
||||||
|
}
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException("Failed to read dataset from: " + path, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
return new DataSet(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package com.naaturel.ANN.domain.model.dataset;
|
||||||
|
|
||||||
|
public class Label {
|
||||||
|
|
||||||
|
private float value;
|
||||||
|
|
||||||
|
public Label(float value){
|
||||||
|
this.value = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public float getValue() {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,102 +0,0 @@
|
|||||||
package com.naaturel.ANN.domain.model.neuron;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
import java.util.function.Consumer;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents a fully connected neural network
|
|
||||||
*/
|
|
||||||
public class FullyConnectedNetwork implements Model {
|
|
||||||
|
|
||||||
private final Layer[] layers;
|
|
||||||
private final Map<Neuron, List<Neuron>> connectionMap;
|
|
||||||
private final Map<Neuron, Integer> layerIndexByNeuron;
|
|
||||||
public FullyConnectedNetwork(Layer[] layers) {
|
|
||||||
this.layers = layers;
|
|
||||||
this.connectionMap = this.createConnectionMap();
|
|
||||||
this.layerIndexByNeuron = this.createNeuronIndex();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public float[] predict(float[] inputs) {
|
|
||||||
float[] previousLayerOutputs = inputs;
|
|
||||||
for (Layer layer : layers) {
|
|
||||||
previousLayerOutputs = layer.predict(previousLayerOutputs);
|
|
||||||
}
|
|
||||||
return previousLayerOutputs;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int synCount() {
|
|
||||||
int res = 0;
|
|
||||||
for(Layer layer : this.layers){
|
|
||||||
res += layer.synCount();
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int neuronCount() {
|
|
||||||
int res = 0;
|
|
||||||
for(Layer layer : this.layers){
|
|
||||||
res += layer.neuronCount();
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
@Override
|
|
||||||
public void forEachNeuron(Consumer<Neuron> consumer) {
|
|
||||||
for(Layer l : this.layers){
|
|
||||||
l.forEachNeuron(consumer);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
|
||||||
int lastIndex = this.layers.length-1;
|
|
||||||
this.layers[lastIndex].forEachNeuron(consumer);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
|
|
||||||
if(!this.connectionMap.containsKey(n)) return;
|
|
||||||
this.connectionMap.get(n).forEach(consumer);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int layerIndexOf(Neuron n) {
|
|
||||||
return this.layerIndexByNeuron.get(n);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int indexInLayerOf(Neuron n) {
|
|
||||||
int layerIndex = this.layerIndexByNeuron.get(n);
|
|
||||||
return this.layers[layerIndex].indexInLayerOf(n);
|
|
||||||
}
|
|
||||||
|
|
||||||
private Map<Neuron, List<Neuron>> createConnectionMap() {
|
|
||||||
Map<Neuron, List<Neuron>> res = new HashMap<>();
|
|
||||||
|
|
||||||
for (int i = 0; i < this.layers.length - 1; i++) {
|
|
||||||
List<Neuron> nextLayerNeurons = new ArrayList<>();
|
|
||||||
this.layers[i + 1].forEachNeuron(nextLayerNeurons::add);
|
|
||||||
this.layers[i].forEachNeuron(n -> res.put(n, nextLayerNeurons));
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Map<Neuron, Integer> createNeuronIndex() {
|
|
||||||
Map<Neuron, Integer> res = new HashMap<>();
|
|
||||||
AtomicInteger index = new AtomicInteger(0);
|
|
||||||
for(Layer l : this.layers){
|
|
||||||
l.forEachNeuron(n -> res.put(n, index.get()));
|
|
||||||
index.incrementAndGet();
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,28 +1,26 @@
|
|||||||
package com.naaturel.ANN.domain.model.neuron;
|
package com.naaturel.ANN.domain.model.neuron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
public class Layer implements Model {
|
public class Layer implements Model {
|
||||||
|
|
||||||
private final Neuron[] neurons;
|
private final List<Neuron> neurons;
|
||||||
private final Map<Neuron, Integer> neuronIndex;
|
|
||||||
|
|
||||||
public Layer(Neuron[] neurons) {
|
public Layer(List<Neuron> neurons) {
|
||||||
this.neurons = neurons;
|
this.neurons = neurons;
|
||||||
this.neuronIndex = createNeuronIndex();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float[] predict(float[] inputs) {
|
public List<Float> predict(List<Input> inputs) {
|
||||||
float[] result = new float[neurons.length];
|
List<Float> result = new ArrayList<>();
|
||||||
for (int i = 0; i < neurons.length; i++) {
|
for(Neuron neuron : this.neurons){
|
||||||
result[i] = neurons[i].predict(inputs)[0];
|
List<Float> res = neuron.predict(inputs);
|
||||||
|
result.addAll(res);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@@ -37,51 +35,7 @@ public class Layer implements Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int neuronCount() {
|
public void applyOnSynapses(Consumer<Synapse> consumer) {
|
||||||
return this.neurons.length;
|
this.neurons.forEach(neuron -> neuron.applyOnSynapses(consumer));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public int layerIndexOf(Neuron n) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int indexInLayerOf(Neuron n) {
|
|
||||||
return this.neuronIndex.get(n);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void forEachNeuron(Consumer<Neuron> consumer) {
|
|
||||||
for (Neuron n : this.neurons){
|
|
||||||
consumer.accept(n);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/*@Override
|
|
||||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
|
||||||
for (Neuron n : this.neurons){
|
|
||||||
n.forEachSynapse(consumer);
|
|
||||||
}
|
|
||||||
}*/
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
|
||||||
this.forEachNeuron(consumer);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
|
|
||||||
throw new UnsupportedOperationException("Neurons have no connection within the same layer");
|
|
||||||
}
|
|
||||||
|
|
||||||
private Map<Neuron, Integer> createNeuronIndex() {
|
|
||||||
Map<Neuron, Integer> res = new HashMap<>();
|
|
||||||
int[] index = {0};
|
|
||||||
this.forEachNeuron(n -> {
|
|
||||||
res.put(n, index[0]++);
|
|
||||||
});
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
package com.naaturel.ANN.domain.model.neuron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
|
public class Network implements Model {
|
||||||
|
|
||||||
|
private final List<Layer> layers;
|
||||||
|
|
||||||
|
public Network(List<Layer> layers) {
|
||||||
|
this.layers = layers;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Float> predict(List<Input> inputs) {
|
||||||
|
List<Float> result = new ArrayList<>();
|
||||||
|
for(Layer layer : this.layers){
|
||||||
|
List<Float> res = layer.predict(inputs);
|
||||||
|
result.addAll(res);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int synCount() {
|
||||||
|
int res = 0;
|
||||||
|
for(Layer layer : this.layers){
|
||||||
|
res += layer.synCount();
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void applyOnSynapses(Consumer<Synapse> consumer) {
|
||||||
|
this.layers.forEach(layer -> layer.applyOnSynapses(consumer));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,111 +0,0 @@
|
|||||||
package com.naaturel.ANN.domain.model.neuron;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.function.Consumer;
|
|
||||||
|
|
||||||
public class Neuron implements Model {
|
|
||||||
|
|
||||||
private final int id;
|
|
||||||
private float output;
|
|
||||||
private final float[] weights;
|
|
||||||
private final float[] inputs;
|
|
||||||
private final ActivationFunction activationFunction;
|
|
||||||
|
|
||||||
public Neuron(int id, Synapse[] synapses, Bias bias, ActivationFunction func){
|
|
||||||
this.id = id;
|
|
||||||
this.activationFunction = func;
|
|
||||||
|
|
||||||
output = 0;
|
|
||||||
weights = new float[synapses.length+1]; //takes the bias into account
|
|
||||||
inputs = new float[synapses.length+1]; //takes the bias into account
|
|
||||||
|
|
||||||
weights[0] = bias.getWeight();
|
|
||||||
inputs[0] = bias.getInput();
|
|
||||||
for (int i = 0; i < synapses.length; i++){
|
|
||||||
weights[i+1] = synapses[i].getWeight();
|
|
||||||
inputs[i+1] = synapses[i].getInput();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setWeight(int index, float value) {
|
|
||||||
this.weights[index] = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
public float getWeight(int index) {
|
|
||||||
return this.weights[index];
|
|
||||||
}
|
|
||||||
|
|
||||||
public float getInput(int index) {
|
|
||||||
return this.inputs[index];
|
|
||||||
}
|
|
||||||
|
|
||||||
public ActivationFunction getActivationFunction(){
|
|
||||||
return this.activationFunction;
|
|
||||||
}
|
|
||||||
|
|
||||||
public float calculateWeightedSum() {
|
|
||||||
int count = weights.length;
|
|
||||||
float weightedSum = 0F;
|
|
||||||
for (int i = 0; i < count; i++){
|
|
||||||
weightedSum += weights[i] * inputs[i];
|
|
||||||
}
|
|
||||||
return weightedSum;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getId(){
|
|
||||||
return this.id;
|
|
||||||
}
|
|
||||||
|
|
||||||
public float getOutput() {
|
|
||||||
return this.output;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int synCount() {
|
|
||||||
return this.weights.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int neuronCount() {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int layerIndexOf(Neuron n) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int indexInLayerOf(Neuron n) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public float[] predict(float[] inputs) {
|
|
||||||
this.setInputs(inputs);
|
|
||||||
output = activationFunction.accept(this);
|
|
||||||
return new float[] {output};
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void forEachNeuron(Consumer<Neuron> consumer) {
|
|
||||||
consumer.accept(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
|
||||||
consumer.accept(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
|
|
||||||
throw new UnsupportedOperationException("Neurons have no connection with themselves");
|
|
||||||
}
|
|
||||||
|
|
||||||
private void setInputs(float[] values){
|
|
||||||
System.arraycopy(values, 0, inputs, 1, values.length);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -25,4 +25,8 @@ public class Synapse {
|
|||||||
public void setWeight(float value){
|
public void setWeight(float value){
|
||||||
this.weight.setValue(value);
|
this.weight.setValue(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,31 +1,25 @@
|
|||||||
package com.naaturel.ANN.domain.model.training;
|
package com.naaturel.ANN.domain.model.training;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
import java.util.function.Predicate;
|
import java.util.function.Predicate;
|
||||||
|
|
||||||
public class TrainingPipeline {
|
public class TrainingPipeline {
|
||||||
|
|
||||||
private final List<AlgorithmStep> steps;
|
private final List<TrainingStep> steps;
|
||||||
private Consumer<TrainingContext> beforeEpoch;
|
private Consumer<TrainingContext> beforeEpoch;
|
||||||
private Consumer<TrainingContext> afterEpoch;
|
private Consumer<TrainingContext> afterEpoch;
|
||||||
private Predicate<TrainingContext> stopCondition;
|
private Predicate<TrainingContext> stopCondition;
|
||||||
|
|
||||||
private boolean verbose;
|
private boolean verbose;
|
||||||
private boolean visualization;
|
|
||||||
private boolean timeMeasurement;
|
private boolean timeMeasurement;
|
||||||
|
|
||||||
private GraphVisualizer visualizer;
|
public TrainingPipeline(List<TrainingStep> steps) {
|
||||||
private int verboseDelay;
|
|
||||||
|
|
||||||
public TrainingPipeline(List<AlgorithmStep> steps) {
|
|
||||||
this.steps = new ArrayList<>(steps);
|
this.steps = new ArrayList<>(steps);
|
||||||
this.stopCondition = (ctx) -> false;
|
this.stopCondition = (ctx) -> false;
|
||||||
this.beforeEpoch = (context -> {});
|
this.beforeEpoch = (context -> {});
|
||||||
@@ -47,16 +41,8 @@ public class TrainingPipeline {
|
|||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public TrainingPipeline withVerbose(boolean enabled, int epochDelay) {
|
public TrainingPipeline withVerbose(boolean enabled) {
|
||||||
if(epochDelay <= 0) throw new IllegalArgumentException("Epoch delay cannot lower or equal to 0");
|
|
||||||
this.verbose = enabled;
|
this.verbose = enabled;
|
||||||
this.verboseDelay = epochDelay;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public TrainingPipeline withVisualization(boolean enabled, GraphVisualizer visualizer) {
|
|
||||||
this.visualization = enabled;
|
|
||||||
this.visualizer = visualizer;
|
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,67 +52,31 @@ public class TrainingPipeline {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void run(TrainingContext ctx) {
|
public void run(TrainingContext ctx) {
|
||||||
|
|
||||||
long start = this.timeMeasurement ? System.currentTimeMillis() : 0;
|
|
||||||
|
|
||||||
do {
|
do {
|
||||||
this.beforeEpoch.accept(ctx);
|
this.beforeEpoch.accept(ctx);
|
||||||
this.executeSteps(ctx);
|
this.executeSteps(ctx);
|
||||||
this.afterEpoch.accept(ctx);
|
this.afterEpoch.accept(ctx);
|
||||||
if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
|
if(this.verbose) {
|
||||||
System.out.printf("[Global error] : %f\n", ctx.globalLoss);
|
System.out.printf("[Global error] : %.2f\n", ctx.globalLoss);
|
||||||
}
|
}
|
||||||
ctx.epoch += 1;
|
|
||||||
} while (!this.stopCondition.test(ctx));
|
} while (!this.stopCondition.test(ctx));
|
||||||
|
|
||||||
if(this.timeMeasurement) {
|
|
||||||
long end = System.currentTimeMillis();
|
|
||||||
System.out.printf("[Training finished in %.3fs]\n", (end-start)/1000.0);
|
|
||||||
}
|
|
||||||
System.out.printf("[Final global error] : %f\n", ctx.globalLoss);
|
|
||||||
//if(this.visualization) this.visualize(ctx);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void executeSteps(TrainingContext ctx){
|
private void executeSteps(TrainingContext ctx){
|
||||||
for (DataSetEntry entry : ctx.dataset) {
|
for (DataSetEntry entry : ctx.dataset) {
|
||||||
|
|
||||||
ctx.currentEntry = entry;
|
ctx.currentEntry = entry;
|
||||||
ctx.expectations = ctx.dataset.getLabelsAsFloat(entry);
|
ctx.currentLabel = ctx.dataset.getLabel(entry);
|
||||||
|
for (TrainingStep step : steps) {
|
||||||
for (AlgorithmStep step : steps) {
|
|
||||||
step.run();
|
step.run();
|
||||||
}
|
}
|
||||||
|
if(this.verbose) {
|
||||||
if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
|
|
||||||
System.out.printf("Epoch : %d, ", ctx.epoch);
|
System.out.printf("Epoch : %d, ", ctx.epoch);
|
||||||
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions));
|
System.out.printf("predicted : %.2f, ", ctx.prediction);
|
||||||
System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
|
System.out.printf("expected : %.2f, ", ctx.currentLabel.getValue());
|
||||||
System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas));
|
System.out.printf("delta : %.2f, ", ctx.delta);
|
||||||
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ctx.epoch += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*private void visualize(TrainingContext ctx){
|
|
||||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
|
||||||
ctx.model.forEachNeuron(n -> {
|
|
||||||
List<Float> weights = new ArrayList<>();
|
|
||||||
n.forEachSynapse(syn -> weights.add(syn.getWeight()));
|
|
||||||
|
|
||||||
float b = weights.get(0);
|
|
||||||
float w1 = weights.get(1);
|
|
||||||
float w2 = weights.get(2);
|
|
||||||
|
|
||||||
this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3);
|
|
||||||
});
|
|
||||||
int i = 0;
|
|
||||||
for(DataSetEntry entry : ctx.dataset){
|
|
||||||
List<Input> inputs = entry.getData();
|
|
||||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue());
|
|
||||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
this.visualizer.buildLineGraph();
|
|
||||||
}*/
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.adaline;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
|
||||||
|
|
||||||
public class AdalineTrainingContext extends TrainingContext {
|
|
||||||
public AdalineTrainingContext(Model model, DataSet dataset) {
|
|
||||||
super(model, dataset);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
package com.naaturel.ANN.implementation.gradientDescent;
|
package com.naaturel.ANN.implementation.gradientDescent;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
public class GradientDescentCorrectionStrategy implements AlgorithmStep {
|
public class GradientDescentCorrectionStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
private final GradientDescentTrainingContext context;
|
private final GradientDescentTrainingContext context;
|
||||||
|
|
||||||
@@ -13,15 +13,13 @@ public class GradientDescentCorrectionStrategy implements AlgorithmStep {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void apply() {
|
||||||
int[] globalSynIndex = {0};
|
AtomicInteger i = new AtomicInteger(0);
|
||||||
context.model.forEachNeuron(n -> {
|
context.model.applyOnSynapses(syn -> {
|
||||||
for(int i = 0; i < n.synCount(); i++){
|
float corrector = context.correctorTerms.get(i.get());
|
||||||
float corrector = context.correctorTerms.get(globalSynIndex[0]);
|
float c = syn.getWeight() + corrector;
|
||||||
float c = n.getWeight(i) + corrector;
|
syn.setWeight(c);
|
||||||
n.setWeight(i, c);
|
i.incrementAndGet();
|
||||||
globalSynIndex[0]++;
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
package com.naaturel.ANN.implementation.gradientDescent;
|
package com.naaturel.ANN.implementation.gradientDescent;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
public class GradientDescentErrorStrategy implements AlgorithmStep {
|
public class GradientDescentErrorStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
private final GradientDescentTrainingContext context;
|
private final GradientDescentTrainingContext context;
|
||||||
|
|
||||||
@@ -14,23 +14,13 @@ public class GradientDescentErrorStrategy implements AlgorithmStep {
|
|||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void apply() {
|
||||||
|
AtomicInteger i = new AtomicInteger(0);
|
||||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
context.model.applyOnSynapses(syn -> {
|
||||||
AtomicInteger synIndex = new AtomicInteger(0);
|
float corrector = context.correctorTerms.get(i.get());
|
||||||
|
corrector += context.learningRate * context.delta * syn.getInput();
|
||||||
context.model.forEachNeuron(neuron -> {
|
context.correctorTerms.set(i.get(), corrector);
|
||||||
float correspondingDelta = context.deltas[neuronIndex.get()];
|
i.incrementAndGet();
|
||||||
|
|
||||||
for(int i = 0; i < neuron.synCount(); i++){
|
|
||||||
float corrector = context.correctorTerms.get(synIndex.get());
|
|
||||||
corrector += context.learningRate * correspondingDelta * neuron.getInput(i);
|
|
||||||
context.correctorTerms.set(synIndex.get(), corrector);
|
|
||||||
synIndex.incrementAndGet();
|
|
||||||
}
|
|
||||||
neuronIndex.incrementAndGet();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
context.globalLoss += context.localLoss;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package com.naaturel.ANN.implementation.gradientDescent;
|
package com.naaturel.ANN.implementation.gradientDescent;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@@ -10,7 +8,4 @@ public class GradientDescentTrainingContext extends TrainingContext {
|
|||||||
|
|
||||||
public List<Float> correctorTerms;
|
public List<Float> correctorTerms;
|
||||||
|
|
||||||
public GradientDescentTrainingContext(Model model, DataSet dataset) {
|
|
||||||
super(model, dataset);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,26 +1,13 @@
|
|||||||
package com.naaturel.ANN.implementation.gradientDescent;
|
package com.naaturel.ANN.implementation.gradientDescent;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
|
|
||||||
public class Linear implements ActivationFunction {
|
public class Linear implements ActivationFunction {
|
||||||
|
|
||||||
private final float slope;
|
|
||||||
private final float intercept;
|
|
||||||
|
|
||||||
public Linear(float slope, float intercept) {
|
|
||||||
this.slope = slope;
|
|
||||||
this.intercept = intercept;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float accept(Neuron n) {
|
public float accept(Neuron n) {
|
||||||
return slope * n.calculateWeightedSum() + intercept;
|
return n.calculateWeightedSum();
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public float derivative(float value) {
|
|
||||||
return this.slope;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.gradientDescent;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
|
||||||
|
|
||||||
import java.util.stream.Stream;
|
|
||||||
|
|
||||||
public class SquareLossStep implements AlgorithmStep {
|
|
||||||
|
|
||||||
private final TrainingContext context;
|
|
||||||
|
|
||||||
public SquareLossStep(TrainingContext context) {
|
|
||||||
this.context = context;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void run() {
|
|
||||||
float loss = 0f;
|
|
||||||
for (float d : this.context.deltas) {
|
|
||||||
loss += d * d;
|
|
||||||
}
|
|
||||||
this.context.localLoss = loss / 2f;
|
|
||||||
this.context.globalLoss += this.context.localLoss;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package com.naaturel.ANN.implementation.gradientDescent;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
|
||||||
|
|
||||||
|
public class SquareLossStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
|
private final GradientDescentTrainingContext context;
|
||||||
|
|
||||||
|
public SquareLossStrategy(GradientDescentTrainingContext context) {
|
||||||
|
this.context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply() {
|
||||||
|
this.context.localLoss = (float)Math.pow(this.context.delta, 2)/2;
|
||||||
|
this.context.globalLoss += context.localLoss;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.multiLayers;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
|
||||||
|
|
||||||
public class BackpropagationCorrectionStep implements AlgorithmStep {
|
|
||||||
|
|
||||||
private final GradientBackpropagationContext context;
|
|
||||||
private final int synCount;
|
|
||||||
private final float[] inputs;
|
|
||||||
private final float[] signals;
|
|
||||||
|
|
||||||
public BackpropagationCorrectionStep(GradientBackpropagationContext context){
|
|
||||||
this.context = context;
|
|
||||||
this.synCount = context.correctionBuffer.length;
|
|
||||||
this.inputs = new float[synCount];
|
|
||||||
this.signals = new float[synCount];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void run() {
|
|
||||||
int[] synIndex = {0};
|
|
||||||
context.model.forEachNeuron(n -> {
|
|
||||||
float signal = context.errorSignals[n.getId()];
|
|
||||||
for (int i = 0; i < n.synCount(); i++){
|
|
||||||
inputs[synIndex[0]] = n.getInput(i);
|
|
||||||
signals[synIndex[0]] = signal;
|
|
||||||
synIndex[0]++;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
float lr = context.learningRate;
|
|
||||||
boolean applyUpdate = context.currentSample >= context.batchSize;
|
|
||||||
|
|
||||||
for (int i = 0; i < synCount; i++) {
|
|
||||||
context.correctionBuffer[i] += lr * signals[i] * inputs[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (applyUpdate) {
|
|
||||||
syncWeights();
|
|
||||||
context.currentSample = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
context.currentSample++;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void syncWeights() {
|
|
||||||
int[] synIndex = {0};
|
|
||||||
context.model.forEachNeuron(n -> {
|
|
||||||
for (int i = 0; i < n.synCount(); i++) {
|
|
||||||
n.setWeight(i, n.getWeight(i) + context.correctionBuffer[synIndex[0]]);
|
|
||||||
context.correctionBuffer[synIndex[0]] = 0f;
|
|
||||||
synIndex[0]++;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.multiLayers;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
|
||||||
|
|
||||||
public class BatchAccumulatorStep implements AlgorithmStep {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void run() {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.multiLayers;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
|
||||||
|
|
||||||
public class ErrorSignalStep implements AlgorithmStep {
|
|
||||||
|
|
||||||
private final GradientBackpropagationContext context;
|
|
||||||
|
|
||||||
public ErrorSignalStep(GradientBackpropagationContext context) {
|
|
||||||
this.context = context;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void run() {
|
|
||||||
|
|
||||||
context.model.forEachNeuron(n -> {
|
|
||||||
if (context.errorSignalsComputed[n.getId()]) return;
|
|
||||||
|
|
||||||
int neuronIndex = context.model.indexInLayerOf(n);
|
|
||||||
float[] signalSum = {0f};
|
|
||||||
context.model.forEachNeuronConnectedTo(n, connected -> {
|
|
||||||
signalSum[0] += context.errorSignals[connected.getId()] * connected.getWeight(neuronIndex);
|
|
||||||
});
|
|
||||||
|
|
||||||
context.errorSignals[n.getId()] = n.getActivationFunction().derivative(n.getOutput()) * signalSum[0];
|
|
||||||
context.errorSignalsComputed[n.getId()] = true;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.multiLayers;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
public class GradientBackpropagationContext extends TrainingContext {
|
|
||||||
|
|
||||||
public final float[] errorSignals;
|
|
||||||
public final float[] correctionBuffer;
|
|
||||||
public final boolean[] errorSignalsComputed;
|
|
||||||
|
|
||||||
public int currentSample;
|
|
||||||
public int batchSize;
|
|
||||||
|
|
||||||
public GradientBackpropagationContext(Model model, DataSet dataSet, float learningRate, int batchSize){
|
|
||||||
super(model, dataSet);
|
|
||||||
this.learningRate = learningRate;
|
|
||||||
this.batchSize = batchSize;
|
|
||||||
|
|
||||||
this.errorSignals = new float[model.neuronCount()];
|
|
||||||
this.correctionBuffer = new float[model.synCount()];
|
|
||||||
this.errorSignalsComputed = new boolean[model.neuronCount()];
|
|
||||||
this.currentSample = 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.multiLayers;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class OutputLayerErrorStep implements AlgorithmStep {
|
|
||||||
|
|
||||||
private final GradientBackpropagationContext context;
|
|
||||||
private final float[] expectations;
|
|
||||||
|
|
||||||
public OutputLayerErrorStep(GradientBackpropagationContext context){
|
|
||||||
this.context = context;
|
|
||||||
this.expectations = new float[context.dataset.getNbrLabels()];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void run() {
|
|
||||||
Arrays.fill(context.errorSignals, 0f);
|
|
||||||
Arrays.fill(context.errorSignalsComputed, false);
|
|
||||||
|
|
||||||
DataSetEntry entry = context.currentEntry;
|
|
||||||
List<Float> labels = context.dataset.getLabelsAsFloat(entry);
|
|
||||||
for (int i = 0; i < labels.size(); i++) {
|
|
||||||
expectations[i] = labels.get(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
int[] index = {0};
|
|
||||||
context.model.forEachOutputNeurons(n -> {
|
|
||||||
float expected = expectations[index[0]];
|
|
||||||
float predicted = n.getOutput();
|
|
||||||
float delta = expected - predicted;
|
|
||||||
|
|
||||||
context.deltas[index[0]] = delta;
|
|
||||||
context.errorSignals[n.getId()] = delta * n.getActivationFunction().derivative(predicted);
|
|
||||||
context.errorSignalsComputed[n.getId()] = true;
|
|
||||||
index[0]++;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.multiLayers;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|
||||||
|
|
||||||
public class Sigmoid implements ActivationFunction {
|
|
||||||
|
|
||||||
private float steepness;
|
|
||||||
|
|
||||||
public Sigmoid(float steepness) {
|
|
||||||
this.steepness = steepness;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public float accept(Neuron n) {
|
|
||||||
return (float) (1.0/(1.0 + Math.exp(-steepness * n.calculateWeightedSum())));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public float derivative(float value) {
|
|
||||||
return steepness * value * (1 - value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.multiLayers;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|
||||||
|
|
||||||
public class TanH implements ActivationFunction {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public float accept(Neuron n) {
|
|
||||||
//For educational purpose. Math.tanh() could have been used here
|
|
||||||
float weightedSum = n.calculateWeightedSum();
|
|
||||||
double exp = Math.exp(weightedSum);
|
|
||||||
double res = (exp-(1/exp))/(exp+(1/exp));
|
|
||||||
return (float)(res);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public float derivative(float value) {
|
|
||||||
return 1 - value * value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
package com.naaturel.ANN.implementation.neuron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Bias;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
|
public class SimplePerceptron extends Neuron {
|
||||||
|
|
||||||
|
public SimplePerceptron(List<Synapse> synapses, Bias b, ActivationFunction func) {
|
||||||
|
super(synapses, b, func);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Float> predict(List<Input> inputs) {
|
||||||
|
super.setInputs(inputs);
|
||||||
|
return List.of(activationFunction.accept(this));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void applyOnSynapses(Consumer<Synapse> consumer) {
|
||||||
|
consumer.accept(this.bias);
|
||||||
|
this.synapses.forEach(consumer);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float calculateWeightedSum() {
|
||||||
|
float res = 0;
|
||||||
|
for(Synapse syn : super.synapses){
|
||||||
|
res += syn.getWeight() * syn.getInput();
|
||||||
|
}
|
||||||
|
res += this.bias.getWeight() * this.bias.getInput();
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,9 +1,7 @@
|
|||||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
|
|
||||||
import javax.naming.OperationNotSupportedException;
|
|
||||||
|
|
||||||
public class Heaviside implements ActivationFunction {
|
public class Heaviside implements ActivationFunction {
|
||||||
|
|
||||||
@@ -14,11 +12,6 @@ public class Heaviside implements ActivationFunction {
|
|||||||
@Override
|
@Override
|
||||||
public float accept(Neuron n) {
|
public float accept(Neuron n) {
|
||||||
float weightedSum = n.calculateWeightedSum();
|
float weightedSum = n.calculateWeightedSum();
|
||||||
return weightedSum < 0 ? 0:1;
|
return weightedSum <= 0 ? 0:1;
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public float derivative(float value) {
|
|
||||||
throw new UnsupportedOperationException("Heaviside is not differentiable");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,34 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
|
||||||
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
|
|
||||||
|
|
||||||
public class SimpleCorrectionStep implements AlgorithmStep {
|
|
||||||
|
|
||||||
private final TrainingContext context;
|
|
||||||
|
|
||||||
public SimpleCorrectionStep(TrainingContext context) {
|
|
||||||
this.context = context;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void run() {
|
|
||||||
if(context.expectations.equals(context.predictions)) return;
|
|
||||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
|
||||||
|
|
||||||
context.model.forEachNeuron(neuron -> {
|
|
||||||
float correspondingDelta = context.deltas[neuronIndex.get()];
|
|
||||||
|
|
||||||
for(int i = 0; i < neuron.synCount(); i++){
|
|
||||||
float currentW = neuron.getWeight(i);
|
|
||||||
float currentInput = neuron.getInput(i);
|
|
||||||
float newValue = currentW + (context.learningRate * correspondingDelta * currentInput);
|
|
||||||
neuron.setWeight(i, newValue);
|
|
||||||
}
|
|
||||||
neuronIndex.incrementAndGet();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
|
||||||
|
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
|
private final SimpleTrainingContext context;
|
||||||
|
|
||||||
|
public SimpleCorrectionStrategy(SimpleTrainingContext context) {
|
||||||
|
this.context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply() {
|
||||||
|
if(context.currentLabel.getValue() == context.prediction) return ;
|
||||||
|
context.model.applyOnSynapses(syn -> {
|
||||||
|
float currentW = syn.getWeight();
|
||||||
|
float currentInput = syn.getInput();
|
||||||
|
float newValue = currentW + (context.learningRate * context.delta * currentInput);
|
||||||
|
syn.setWeight(newValue);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import java.util.stream.IntStream;
|
|
||||||
|
|
||||||
public class SimpleDeltaStep implements AlgorithmStep {
|
|
||||||
|
|
||||||
private final TrainingContext context;
|
|
||||||
|
|
||||||
public SimpleDeltaStep(TrainingContext context) {
|
|
||||||
this.context = context;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void run() {
|
|
||||||
DataSet dataSet = context.dataset;
|
|
||||||
DataSetEntry entry = context.currentEntry;
|
|
||||||
float[] predicted = context.predictions;
|
|
||||||
List<Float> expected = dataSet.getLabelsAsFloat(entry);
|
|
||||||
|
|
||||||
for (int i = 0; i < predicted.length; i++) {
|
|
||||||
context.deltas[i] = expected.get(i) - predicted[0];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||||
|
|
||||||
|
public class SimpleDeltaStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
|
private final TrainingContext context;
|
||||||
|
|
||||||
|
public SimpleDeltaStrategy(TrainingContext context) {
|
||||||
|
this.context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply() {
|
||||||
|
DataSet dataSet = context.dataset;
|
||||||
|
DataSetEntry entry = context.currentEntry;
|
||||||
|
Label label = dataSet.getLabel(entry);
|
||||||
|
|
||||||
|
context.delta = label.getValue() - context.prediction;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
|
||||||
|
|
||||||
public class SimpleErrorRegistrationStep implements AlgorithmStep {
|
|
||||||
|
|
||||||
private final TrainingContext context;
|
|
||||||
|
|
||||||
public SimpleErrorRegistrationStep(TrainingContext context) {
|
|
||||||
this.context = context;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void run() {
|
|
||||||
context.globalLoss += context.localLoss;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
|
||||||
|
public class SimpleErrorRegistrationStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
|
private final SimpleTrainingContext context;
|
||||||
|
|
||||||
|
public SimpleErrorRegistrationStrategy(SimpleTrainingContext context) {
|
||||||
|
this.context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply() {
|
||||||
|
context.globalLoss += context.localLoss;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
|
||||||
public class SimpleLossStrategy implements AlgorithmStep {
|
public class SimpleLossStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
private final SimpleTrainingContext context;
|
private final SimpleTrainingContext context;
|
||||||
|
|
||||||
@@ -11,11 +11,7 @@ public class SimpleLossStrategy implements AlgorithmStep {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void apply() {
|
||||||
float loss = 0f;
|
this.context.localLoss = Math.abs(this.context.delta);
|
||||||
for (float d : context.deltas) {
|
|
||||||
loss += d;
|
|
||||||
}
|
|
||||||
context.localLoss = loss;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class SimplePredictionStep implements AlgorithmStep {
|
|
||||||
|
|
||||||
private final TrainingContext context;
|
|
||||||
|
|
||||||
public SimplePredictionStep(TrainingContext context) {
|
|
||||||
this.context = context;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void run() {
|
|
||||||
List<Input> data = context.currentEntry.getData();
|
|
||||||
float[] flatData = new float[data.size()];
|
|
||||||
for (int i = 0; i < data.size(); i++) {
|
|
||||||
flatData[i] = data.get(i).getValue();
|
|
||||||
}
|
|
||||||
context.predictions = context.model.predict(flatData);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class SimplePredictionStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
|
private final TrainingContext context;
|
||||||
|
|
||||||
|
public SimplePredictionStrategy(TrainingContext context) {
|
||||||
|
this.context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void apply() {
|
||||||
|
List<Float> predictions = context.model.predict(context.currentEntry.getData());
|
||||||
|
context.prediction = predictions.getFirst();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,11 +1,6 @@
|
|||||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
|
||||||
|
|
||||||
public class SimpleTrainingContext extends TrainingContext {
|
public class SimpleTrainingContext extends TrainingContext {
|
||||||
public SimpleTrainingContext(Model model, DataSet dataset) {
|
|
||||||
super(model, dataset);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,51 +1,21 @@
|
|||||||
package com.naaturel.ANN.implementation.training;
|
package com.naaturel.ANN.implementation.training;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
|
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
|
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
|
||||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
|
|
||||||
public class AdalineTraining implements Trainer {
|
/*public class AdalineTraining implements Trainer {
|
||||||
|
|
||||||
public AdalineTraining(){
|
public AdalineTraining(){
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
|
||||||
AdalineTrainingContext context = new AdalineTrainingContext(model, dataset);
|
|
||||||
context.learningRate = learningRate;
|
|
||||||
|
|
||||||
List<AlgorithmStep> steps = List.of(
|
|
||||||
new SimplePredictionStep(context),
|
|
||||||
new SimpleDeltaStep(context),
|
|
||||||
new SquareLossStep(context),
|
|
||||||
new SimpleErrorRegistrationStep(context),
|
|
||||||
new SimpleCorrectionStep(context)
|
|
||||||
);
|
|
||||||
|
|
||||||
new TrainingPipeline(steps)
|
|
||||||
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
|
|
||||||
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
|
|
||||||
.afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
|
|
||||||
.withTimeMeasurement(true)
|
|
||||||
.withVerbose(true, 1)
|
|
||||||
.withVisualization(true, new GraphVisualizer())
|
|
||||||
.run(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
|
||||||
int epoch = 1;
|
int epoch = 1;
|
||||||
int maxEpoch = 202;
|
int maxEpoch = 202;
|
||||||
float errorThreshold = 0.0F;
|
float errorThreshold = 0.0F;
|
||||||
@@ -106,6 +76,6 @@ public class AdalineTraining implements Trainer {
|
|||||||
|
|
||||||
private float calculateLoss(float delta){
|
private float calculateLoss(float delta){
|
||||||
return (float) Math.pow(delta, 2)/2;
|
return (float) Math.pow(delta, 2)/2;
|
||||||
}*/
|
}
|
||||||
|
|
||||||
}
|
}*/
|
||||||
|
|||||||
@@ -1,42 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.training;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
|
||||||
import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep;
|
|
||||||
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
|
|
||||||
import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep;
|
|
||||||
import com.naaturel.ANN.implementation.multiLayers.OutputLayerErrorStep;
|
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class GradientBackpropagationTraining implements Trainer {
|
|
||||||
@Override
|
|
||||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
|
||||||
GradientBackpropagationContext context =
|
|
||||||
new GradientBackpropagationContext(model, dataset, learningRate, 10);
|
|
||||||
|
|
||||||
List<AlgorithmStep> steps = List.of(
|
|
||||||
new SimplePredictionStep(context),
|
|
||||||
new OutputLayerErrorStep(context),
|
|
||||||
new ErrorSignalStep(context),
|
|
||||||
new BackpropagationCorrectionStep(context),
|
|
||||||
new SquareLossStep(context)
|
|
||||||
);
|
|
||||||
|
|
||||||
new TrainingPipeline(steps)
|
|
||||||
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
|
|
||||||
.beforeEpoch(ctx -> {
|
|
||||||
ctx.globalLoss = 0.0F;
|
|
||||||
})
|
|
||||||
.afterEpoch(ctx -> {
|
|
||||||
ctx.globalLoss /= dataset.size();
|
|
||||||
})
|
|
||||||
.withVerbose(true,epoch/10)
|
|
||||||
.withTimeMeasurement(true)
|
|
||||||
.run(context);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,17 +1,17 @@
|
|||||||
package com.naaturel.ANN.implementation.training;
|
package com.naaturel.ANN.implementation.training;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
|
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
|
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStrategy;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
import com.naaturel.ANN.implementation.training.steps.*;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -23,36 +23,34 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
public void train(Model model, DataSet dataset) {
|
||||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext(model, dataset);
|
GradientDescentTrainingContext context = new GradientDescentTrainingContext();
|
||||||
context.learningRate = learningRate;
|
context.dataset = dataset;
|
||||||
|
context.model = model;
|
||||||
|
context.learningRate = 0.00011F;
|
||||||
context.correctorTerms = new ArrayList<>();
|
context.correctorTerms = new ArrayList<>();
|
||||||
|
|
||||||
List<AlgorithmStep> steps = List.of(
|
List<TrainingStep> steps = List.of(
|
||||||
new SimplePredictionStep(context),
|
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||||
new SimpleDeltaStep(context),
|
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||||
new SquareLossStep(context),
|
new LossStep(new SquareLossStrategy(context)),
|
||||||
new GradientDescentErrorStrategy(context)
|
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context)),
|
||||||
|
new WeightCorrectionStep(new GradientDescentCorrectionStrategy(context))
|
||||||
);
|
);
|
||||||
|
|
||||||
new TrainingPipeline(steps)
|
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||||
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > epoch)
|
pipeline
|
||||||
.beforeEpoch(ctx -> {
|
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 1000)
|
||||||
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
.beforeEpoch(ctx -> {
|
||||||
gdCtx.globalLoss = 0.0F;
|
ctx.globalLoss = 0.0F;
|
||||||
gdCtx.correctorTerms.clear();
|
for (int i = 0; i < model.synCount(); i++){
|
||||||
for(int i = 0; i < gdCtx.model.synCount(); i++){
|
context.correctorTerms.add(0F);
|
||||||
gdCtx.correctorTerms.add(0F);
|
}
|
||||||
}
|
})
|
||||||
})
|
.afterEpoch(ctx -> ctx.globalLoss /= ctx.dataset.size())
|
||||||
.afterEpoch(ctx -> {
|
.withVerbose(true)
|
||||||
context.globalLoss /= context.dataset.size();
|
.withTimeMeasurement(true)
|
||||||
new GradientDescentCorrectionStrategy(context).run();
|
.run(context);
|
||||||
})
|
|
||||||
//.withVerbose(true)
|
|
||||||
.withTimeMeasurement(true)
|
|
||||||
.withVisualization(true, new GraphVisualizer())
|
|
||||||
.run(context);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
package com.naaturel.ANN.implementation.training;
|
package com.naaturel.ANN.implementation.training;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
|
import com.naaturel.ANN.implementation.training.steps.*;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@@ -16,25 +17,25 @@ public class SimpleTraining implements Trainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
public void train(Model model, DataSet dataset) {
|
||||||
SimpleTrainingContext context = new SimpleTrainingContext(model, dataset);
|
SimpleTrainingContext context = new SimpleTrainingContext();
|
||||||
context.dataset = dataset;
|
context.dataset = dataset;
|
||||||
context.model = model;
|
context.model = model;
|
||||||
context.learningRate = learningRate;
|
context.learningRate = 0.3F;
|
||||||
|
|
||||||
List<AlgorithmStep> steps = List.of(
|
List<TrainingStep> steps = List.of(
|
||||||
new SimplePredictionStep(context),
|
new PredictionStep(new SimplePredictionStrategy(context)),
|
||||||
new SimpleDeltaStep(context),
|
new DeltaStep(new SimpleDeltaStrategy(context)),
|
||||||
new SimpleLossStrategy(context),
|
new LossStep(new SimpleLossStrategy(context)),
|
||||||
new SimpleErrorRegistrationStep(context),
|
new ErrorRegistrationStep(new SimpleErrorRegistrationStrategy(context)),
|
||||||
new SimpleCorrectionStep(context)
|
new WeightCorrectionStep(new SimpleCorrectionStrategy(context))
|
||||||
);
|
);
|
||||||
|
|
||||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||||
pipeline
|
pipeline
|
||||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > epoch)
|
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100)
|
||||||
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
||||||
.withVerbose(true, 1)
|
.withVerbose(true)
|
||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package com.naaturel.ANN.implementation.training.steps;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
|
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||||
|
|
||||||
|
public class DeltaStep implements TrainingStep {
|
||||||
|
|
||||||
|
private final AlgorithmStrategy strategy;
|
||||||
|
|
||||||
|
public DeltaStep(AlgorithmStrategy strategy) {
|
||||||
|
this.strategy = strategy;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
this.strategy.apply();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package com.naaturel.ANN.implementation.training.steps;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
|
||||||
|
public class ErrorRegistrationStep implements TrainingStep {
|
||||||
|
|
||||||
|
private final AlgorithmStrategy strategy;
|
||||||
|
|
||||||
|
public ErrorRegistrationStep(AlgorithmStrategy strategy) {
|
||||||
|
this.strategy = strategy;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
this.strategy.apply();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package com.naaturel.ANN.implementation.training.steps;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
|
||||||
|
public class LossStep implements TrainingStep {
|
||||||
|
|
||||||
|
|
||||||
|
private final AlgorithmStrategy lossStrategy;
|
||||||
|
|
||||||
|
public LossStep(AlgorithmStrategy strategy) {
|
||||||
|
this.lossStrategy = strategy;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
this.lossStrategy.apply();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
package com.naaturel.ANN.implementation.training.steps;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStrategy;
|
||||||
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class PredictionStep implements TrainingStep {
|
||||||
|
|
||||||
|
private final SimplePredictionStrategy strategy;
|
||||||
|
|
||||||
|
public PredictionStep(SimplePredictionStrategy strategy) {
|
||||||
|
this.strategy = strategy;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
this.strategy.apply();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package com.naaturel.ANN.implementation.training.steps;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
|
|
||||||
|
public class WeightCorrectionStep implements TrainingStep {
|
||||||
|
|
||||||
|
private final AlgorithmStrategy correctionStrategy;
|
||||||
|
|
||||||
|
public WeightCorrectionStep(AlgorithmStrategy strategy) {
|
||||||
|
this.correctionStrategy = strategy;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
this.correctionStrategy.apply();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
package com.naaturel.ANN.infrastructure.config;
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
public class ConfigDto {
|
|
||||||
|
|
||||||
@JsonProperty("model")
|
|
||||||
private Map<String, Object> modelConfig;
|
|
||||||
|
|
||||||
@JsonProperty("training")
|
|
||||||
private Map<String, Object> trainingConfig;
|
|
||||||
|
|
||||||
@JsonProperty("dataset")
|
|
||||||
private Map<String, Object> datasetConfig;
|
|
||||||
|
|
||||||
public <T> T getModelProperty(String key, Class<T> type) {
|
|
||||||
Object value = find(key, this.modelConfig);
|
|
||||||
if (value instanceof List<?> list && type.isArray()) {
|
|
||||||
int[] arr = new int[list.size()];
|
|
||||||
for (int i = 0; i < list.size(); i++) {
|
|
||||||
arr[i] = ((Number) list.get(i)).intValue();
|
|
||||||
}
|
|
||||||
return type.cast(arr);
|
|
||||||
}
|
|
||||||
if (!type.isInstance(value)) {
|
|
||||||
throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
|
|
||||||
}
|
|
||||||
return type.cast(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
public <T> T getTrainingProperty(String key, Class<T> type) {
|
|
||||||
Object value = find(key, this.trainingConfig);
|
|
||||||
if (!type.isInstance(value)) {
|
|
||||||
throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
|
|
||||||
}
|
|
||||||
return type.cast(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
public <T> T getDatasetProperty(String key, Class<T> type) {
|
|
||||||
Object value = find(key, this.datasetConfig);
|
|
||||||
if (!type.isInstance(value)) {
|
|
||||||
throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
|
|
||||||
}
|
|
||||||
return type.cast(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
private Object find(String key, Map<String, Object> config){
|
|
||||||
if(!config.containsKey(key)) throw new RuntimeException("Unable to find property for key '" + key + "'");
|
|
||||||
return config.get(key);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
package com.naaturel.ANN.infrastructure.config;
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
|
|
||||||
public class ConfigLoader {
|
|
||||||
|
|
||||||
|
|
||||||
public static ConfigDto load(String path) throws Exception {
|
|
||||||
try {
|
|
||||||
|
|
||||||
ObjectMapper mapper = new ObjectMapper();
|
|
||||||
ConfigDto config = mapper.readValue(new File("config.json"), ConfigDto.class);
|
|
||||||
|
|
||||||
return config;
|
|
||||||
} catch (Exception e){
|
|
||||||
throw new Exception("Unable to load config : " + e.getMessage());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
package com.naaturel.ANN.infrastructure.dataset;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
|
||||||
|
|
||||||
import java.util.*;
|
|
||||||
import java.util.stream.Stream;
|
|
||||||
|
|
||||||
public class DataSet implements Iterable<DataSetEntry>{
|
|
||||||
|
|
||||||
private final Map<DataSetEntry, Labels> data;
|
|
||||||
|
|
||||||
private final int nbrInputs;
|
|
||||||
private final int nbrLabels;
|
|
||||||
|
|
||||||
public DataSet() {
|
|
||||||
this(new LinkedHashMap<>()); //ensure iteration order is the same as insertion order
|
|
||||||
}
|
|
||||||
|
|
||||||
public DataSet(Map<DataSetEntry, Labels> data){
|
|
||||||
this.data = data;
|
|
||||||
this.nbrInputs = this.calculateNbrInput();
|
|
||||||
this.nbrLabels = this.calculateNbrLabel();
|
|
||||||
}
|
|
||||||
|
|
||||||
private int calculateNbrInput(){
|
|
||||||
//assumes every entry are the same length
|
|
||||||
Stream<DataSetEntry> keyStream = this.data.keySet().stream();
|
|
||||||
Optional<DataSetEntry> firstEntry = keyStream.findFirst();
|
|
||||||
return firstEntry.map(inputs -> inputs.getData().size()).orElse(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
private int calculateNbrLabel(){
|
|
||||||
//assumes every label are the same length
|
|
||||||
Stream<DataSetEntry> keyStream = this.data.keySet().stream();
|
|
||||||
Optional<DataSetEntry> firstEntry = keyStream.findFirst();
|
|
||||||
return firstEntry.map(inputs -> this.data.get(inputs).getValues().size()).orElse(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public int size() {
|
|
||||||
return data.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getNbrInputs() {
|
|
||||||
return this.nbrInputs;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getNbrLabels(){
|
|
||||||
return this.nbrLabels;
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<DataSetEntry> getData(){
|
|
||||||
return new ArrayList<>(this.data.keySet());
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<Float> getLabelsAsFloat(DataSetEntry entry){
|
|
||||||
return this.data.get(entry).getValues();
|
|
||||||
}
|
|
||||||
|
|
||||||
public DataSet toNormalized() {
|
|
||||||
List<DataSetEntry> entries = this.getData();
|
|
||||||
|
|
||||||
float maxAbs = entries.stream()
|
|
||||||
.flatMap(e -> e.getData().stream())
|
|
||||||
.map(Input::getValue)
|
|
||||||
.map(Math::abs)
|
|
||||||
.max(Float::compare)
|
|
||||||
.orElse(1.0F);
|
|
||||||
|
|
||||||
Map<DataSetEntry, Labels> normalized = new HashMap<>();
|
|
||||||
for (DataSetEntry entry : entries) {
|
|
||||||
List<Input> normalizedData = new ArrayList<>();
|
|
||||||
|
|
||||||
for (Input input : entry.getData()) {
|
|
||||||
Input normalizedInput = new Input(Math.round((input.getValue() / maxAbs) * 100.0F) / 100.0F);
|
|
||||||
normalizedData.add(normalizedInput);
|
|
||||||
}
|
|
||||||
|
|
||||||
normalized.put(new DataSetEntry(normalizedData), this.data.get(entry));
|
|
||||||
}
|
|
||||||
|
|
||||||
return new DataSet(normalized);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Iterator<DataSetEntry> iterator() {
|
|
||||||
return this.data.keySet().iterator();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
package com.naaturel.ANN.infrastructure.dataset;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
|
||||||
|
|
||||||
import java.io.BufferedReader;
|
|
||||||
import java.io.FileReader;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
public class DatasetExtractor {
|
|
||||||
|
|
||||||
public DataSet extract(String path, int nbrLabels) {
|
|
||||||
Map<DataSetEntry, Labels> data = new LinkedHashMap<>();
|
|
||||||
|
|
||||||
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
|
|
||||||
String line;
|
|
||||||
while ((line = reader.readLine()) != null) {
|
|
||||||
String[] parts = line.split(",");
|
|
||||||
|
|
||||||
String[] rawInputs = Arrays.copyOfRange(parts, 0, parts.length-nbrLabels);
|
|
||||||
String[] rawLabels = Arrays.copyOfRange(parts, parts.length-nbrLabels, parts.length);
|
|
||||||
|
|
||||||
List<Input> inputs = new ArrayList<>();
|
|
||||||
List<Float> labels = new ArrayList<>();
|
|
||||||
|
|
||||||
for (String entry : rawInputs) {
|
|
||||||
inputs.add(new Input(Float.parseFloat(entry.trim())));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (String entry : rawLabels) {
|
|
||||||
labels.add(Float.parseFloat(entry.trim()));
|
|
||||||
}
|
|
||||||
|
|
||||||
data.put(new DataSetEntry(inputs), new Labels(labels));
|
|
||||||
}
|
|
||||||
} catch (IOException e) {
|
|
||||||
throw new RuntimeException("Failed to read dataset from: " + path, e);
|
|
||||||
}
|
|
||||||
|
|
||||||
return new DataSet(data);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
package com.naaturel.ANN.infrastructure.dataset;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class Labels {
|
|
||||||
|
|
||||||
private final List<Float> values;
|
|
||||||
|
|
||||||
public Labels(List<Float> value){
|
|
||||||
this.values = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<Float> getValues() {
|
|
||||||
return values.stream().toList();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
package com.naaturel.ANN.infrastructure.persistence;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class ModelDto {
|
|
||||||
|
|
||||||
private List<NeuronDto> neurons;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
package com.naaturel.ANN.infrastructure.persistence;
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
import com.fasterxml.jackson.databind.node.ArrayNode;
|
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
|
||||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
public class ModelSnapshot {
|
|
||||||
|
|
||||||
private Model model;
|
|
||||||
private final ObjectMapper mapper;
|
|
||||||
|
|
||||||
public ModelSnapshot(){
|
|
||||||
this(null);
|
|
||||||
}
|
|
||||||
|
|
||||||
public ModelSnapshot(Model model){
|
|
||||||
this.model = model;
|
|
||||||
mapper = new ObjectMapper();
|
|
||||||
}
|
|
||||||
|
|
||||||
public Model getModel() {
|
|
||||||
return model;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void saveToFile(String path) throws Exception {
|
|
||||||
|
|
||||||
ArrayNode root = mapper.createArrayNode();
|
|
||||||
model.forEachNeuron(n -> {
|
|
||||||
|
|
||||||
ObjectNode neuronNode = mapper.createObjectNode();
|
|
||||||
neuronNode.put("id", n.getId());
|
|
||||||
neuronNode.put("layerIndex", model.layerIndexOf(n));
|
|
||||||
|
|
||||||
ArrayNode weights = mapper.createArrayNode();
|
|
||||||
for (int i = 0; i < n.synCount(); i++) {
|
|
||||||
float weight = n.getWeight(i);
|
|
||||||
weights.add(weight);
|
|
||||||
}
|
|
||||||
neuronNode.set("weights", weights);
|
|
||||||
root.add(neuronNode);
|
|
||||||
});
|
|
||||||
|
|
||||||
mapper.writerWithDefaultPrettyPrinter().writeValue(new File(path), root);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void loadFromFile(String path) throws Exception {
|
|
||||||
ArrayNode root = (ArrayNode) mapper.readTree(new File(path));
|
|
||||||
|
|
||||||
Map<Integer, List<Neuron>> neuronsByLayer = new LinkedHashMap<>();
|
|
||||||
|
|
||||||
root.forEach(neuronNode -> {
|
|
||||||
int id = neuronNode.get("id").asInt();
|
|
||||||
int layerIndex = neuronNode.get("layerIndex").asInt();
|
|
||||||
ArrayNode weightsNode = (ArrayNode) neuronNode.get("weights");
|
|
||||||
|
|
||||||
Bias bias = new Bias(new Weight(weightsNode.get(0).floatValue()));
|
|
||||||
Synapse[] synapses = new Synapse[weightsNode.size() - 1];
|
|
||||||
for (int i = 0; i < synapses.length; i++) {
|
|
||||||
synapses[i] = new Synapse(new Input(0), new Weight(weightsNode.get(i + 1).floatValue()));
|
|
||||||
}
|
|
||||||
|
|
||||||
Neuron n = new Neuron(id, synapses, bias, new TanH());
|
|
||||||
neuronsByLayer.computeIfAbsent(layerIndex, k -> new ArrayList<>()).add(n);
|
|
||||||
});
|
|
||||||
|
|
||||||
Layer[] layers = neuronsByLayer.values().stream()
|
|
||||||
.map(neurons -> new Layer(neurons.toArray(new Neuron[0])))
|
|
||||||
.toArray(Layer[]::new);
|
|
||||||
|
|
||||||
this.model = new FullyConnectedNetwork(layers);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
package com.naaturel.ANN.infrastructure.persistence;
|
|
||||||
|
|
||||||
public class NeuronDto {
|
|
||||||
}
|
|
||||||
@@ -1,59 +0,0 @@
|
|||||||
package com.naaturel.ANN.infrastructure.visualization;
|
|
||||||
|
|
||||||
import org.jfree.chart.ChartFactory;
|
|
||||||
import org.jfree.chart.ChartPanel;
|
|
||||||
import org.jfree.chart.JFreeChart;
|
|
||||||
import org.jfree.chart.plot.XYPlot;
|
|
||||||
import org.jfree.data.xy.XYSeries;
|
|
||||||
import org.jfree.data.xy.XYSeriesCollection;
|
|
||||||
|
|
||||||
import javax.swing.*;
|
|
||||||
|
|
||||||
public class GraphVisualizer {
|
|
||||||
|
|
||||||
XYSeriesCollection dataset;
|
|
||||||
|
|
||||||
public GraphVisualizer(){
|
|
||||||
this.dataset = new XYSeriesCollection();
|
|
||||||
}
|
|
||||||
|
|
||||||
public void addPoint(String title, float x, float y) {
|
|
||||||
if (this.dataset.getSeriesIndex(title) == -1)
|
|
||||||
this.dataset.addSeries(new XYSeries(title));
|
|
||||||
this.dataset.getSeries(title).add(x, y);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void addEquation(String title, float y1, float y2, float k, float xMin, float xMax) {
|
|
||||||
for (float x1 = xMin; x1 <= xMax; x1 += 0.01f) {
|
|
||||||
float x2 = (-y1 * x1 - k) / y2;
|
|
||||||
addPoint(title, x1, x2);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public void buildLineGraph(){
|
|
||||||
JFreeChart chart = ChartFactory.createXYLineChart(
|
|
||||||
"Model learning", "X", "Y", dataset
|
|
||||||
);
|
|
||||||
JFrame frame = new JFrame("Training Loss");
|
|
||||||
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
|
|
||||||
frame.add(new ChartPanel(chart));
|
|
||||||
frame.pack();
|
|
||||||
frame.setVisible(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void buildScatterGraph(int lower, int upper){
|
|
||||||
JFreeChart chart = ChartFactory.createScatterPlot(
|
|
||||||
"Predictions", "X", "Y", dataset
|
|
||||||
);
|
|
||||||
XYPlot plot = chart.getXYPlot();
|
|
||||||
plot.getDomainAxis().setRange(lower, upper);
|
|
||||||
plot.getRangeAxis().setRange(lower, upper);
|
|
||||||
|
|
||||||
JFrame frame = new JFrame("Predictions");
|
|
||||||
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
|
|
||||||
frame.add(new ChartPanel(chart));
|
|
||||||
frame.pack();
|
|
||||||
frame.setVisible(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,114 +0,0 @@
|
|||||||
package com.naaturel.ANN.infrastructure.visualization;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|
||||||
|
|
||||||
import javax.swing.*;
|
|
||||||
import java.awt.*;
|
|
||||||
import java.util.*;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class ModelVisualizer {
|
|
||||||
|
|
||||||
private final JFrame frame;
|
|
||||||
private final Model model;
|
|
||||||
private boolean withWeights;
|
|
||||||
|
|
||||||
public ModelVisualizer(Model model){
|
|
||||||
this.frame = initFrame();
|
|
||||||
this.model = model;
|
|
||||||
this.withWeights = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
public ModelVisualizer withWeights(boolean value){
|
|
||||||
this.withWeights = value;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void display() {
|
|
||||||
JPanel panel = buildPanel();
|
|
||||||
frame.add(panel);
|
|
||||||
frame.setVisible(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
private JFrame initFrame(){
|
|
||||||
JFrame frame = new JFrame("Model Visualizer");
|
|
||||||
frame.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
|
|
||||||
frame.setSize(800, 600);
|
|
||||||
return frame;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Map<Integer, Point> computeNeuronPositions(int width, int height) {
|
|
||||||
|
|
||||||
Map<Integer, List<Neuron>> neuronsByLayer = new LinkedHashMap<>();
|
|
||||||
model.forEachNeuron(n -> {
|
|
||||||
int layerIndex = model.layerIndexOf(n);
|
|
||||||
neuronsByLayer.computeIfAbsent(layerIndex, k -> new ArrayList<>()).add(n);
|
|
||||||
});
|
|
||||||
|
|
||||||
Map<Integer, Point> neuronPositions = new HashMap<>();
|
|
||||||
int layerCount = neuronsByLayer.size();
|
|
||||||
int layerX = width / (layerCount + 1);
|
|
||||||
|
|
||||||
for (Map.Entry<Integer, List<Neuron>> entry : neuronsByLayer.entrySet()) {
|
|
||||||
int layerIndex = entry.getKey();
|
|
||||||
List<Neuron> neurons = entry.getValue();
|
|
||||||
int x = layerX * (layerIndex + 1);
|
|
||||||
int neuronCount = neurons.size();
|
|
||||||
|
|
||||||
for (int i = 0; i < neuronCount; i++) {
|
|
||||||
int y = height / (neuronCount + 1) * (i + 1);
|
|
||||||
neuronPositions.put(neurons.get(i).getId(), new Point(x, y));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return neuronPositions;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void drawConnections(Graphics2D g2, Map<Integer, Point> neuronPositions) {
|
|
||||||
model.forEachNeuron(n -> {
|
|
||||||
Point from = neuronPositions.get(n.getId());
|
|
||||||
int neuronIndex = model.indexInLayerOf(n);
|
|
||||||
model.forEachNeuronConnectedTo(n, connected -> {
|
|
||||||
Point to = neuronPositions.get(connected.getId());
|
|
||||||
g2.setColor(Color.LIGHT_GRAY);
|
|
||||||
g2.drawLine(from.x, from.y, to.x, to.y);
|
|
||||||
|
|
||||||
if(!this.withWeights) return;
|
|
||||||
int mx = (from.x + to.x) / 2;
|
|
||||||
int my = (from.y + to.y) / 2;
|
|
||||||
float weight = connected.getWeight(neuronIndex + 1);
|
|
||||||
g2.setColor(Color.DARK_GRAY);
|
|
||||||
g2.drawString(String.format("%.2f", weight), mx, my);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private void drawNeurons(Graphics2D g2, Map<Integer, Point> neuronPositions, int neuronRadius) {
|
|
||||||
model.forEachNeuron(n -> {
|
|
||||||
Point p = neuronPositions.get(n.getId());
|
|
||||||
g2.setColor(Color.WHITE);
|
|
||||||
g2.fillOval(p.x - neuronRadius, p.y - neuronRadius, neuronRadius * 2, neuronRadius * 2);
|
|
||||||
g2.setColor(Color.BLACK);
|
|
||||||
g2.drawOval(p.x - neuronRadius, p.y - neuronRadius, neuronRadius * 2, neuronRadius * 2);
|
|
||||||
g2.drawString(String.valueOf(n.getId()), p.x - 5, p.y + 5);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private JPanel buildPanel() {
|
|
||||||
int neuronRadius = 20;
|
|
||||||
return new JPanel() {
|
|
||||||
@Override
|
|
||||||
protected void paintComponent(Graphics g) {
|
|
||||||
super.paintComponent(g);
|
|
||||||
Graphics2D g2 = (Graphics2D) g;
|
|
||||||
g2.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
|
|
||||||
|
|
||||||
Map<Integer, Point> neuronPositions = computeNeuronPositions(getWidth(), getHeight());
|
|
||||||
drawConnections(g2, neuronPositions);
|
|
||||||
drawNeurons(g2, neuronPositions, neuronRadius);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
0,0,-1
|
|
||||||
0,1,-1
|
|
||||||
1,0,-1
|
|
||||||
1,1,1
|
|
||||||
|
@@ -1,4 +0,0 @@
|
|||||||
0,0,0
|
|
||||||
0,1,1
|
|
||||||
1,0,1
|
|
||||||
1,1,0
|
|
||||||
|
@@ -1,37 +0,0 @@
|
|||||||
[ {
|
|
||||||
"id" : 0,
|
|
||||||
"layerIndex" : 0,
|
|
||||||
"weights" : [ 0.73659766, -0.77503574, -0.23966503, 0.5794077, -0.1634599, 0.6023351, 0.770439, -0.53199804, -0.75792193, 0.36339867, 0.017356396, -0.3005041, -0.5709046, 0.8757278, 0.22738981, -0.22997773, 0.81583726, -0.6008339, 0.8148706, -0.27952576, 0.11130214, 0.21506107, -0.96409214, 0.8534266, 0.9998076, 0.8971077, -0.55812895, 0.28677964, -0.4315225, -0.12088442, 0.41834033, -0.83330417, 0.013990879, -0.021193504, 0.95400894, 0.24115086, 0.122039676, 0.7069808, 0.74929047, 0.2558391, 0.1307528, 0.45781684, -0.19833839 ]
|
|
||||||
}, {
|
|
||||||
"id" : 1,
|
|
||||||
"layerIndex" : 0,
|
|
||||||
"weights" : [ -0.25173664, -0.15121317, 0.3947369, 0.39584184, 0.8823843, -0.43822396, 0.38901758, -0.83357096, 0.4349022, -0.57956433, -0.78882015, -0.7952547, 0.87495327, 0.20618606, -0.802194, -0.108419776, 0.47257674, -0.7384255, 0.41351187, -0.6334251, -0.61948025, 0.40434432, -0.18576205, -0.47782838, 0.57454073, -0.36851442, -0.7262291, 0.3893906, 0.83334255, -0.6979947, 0.43623865, 0.5753021, 0.0041633844, 0.6598717, 0.21344411, -0.40663266, 0.73282886, -0.00848031, -0.7269838, -0.36129093, -0.7280016, -0.0039184093, -0.71608007 ]
|
|
||||||
}, {
|
|
||||||
"id" : 2,
|
|
||||||
"layerIndex" : 1,
|
|
||||||
"weights" : [ 0.0760473, -0.54335964, -0.789695 ]
|
|
||||||
}, {
|
|
||||||
"id" : 3,
|
|
||||||
"layerIndex" : 1,
|
|
||||||
"weights" : [ -0.16066408, 0.97240365, 0.72418904 ]
|
|
||||||
}, {
|
|
||||||
"id" : 4,
|
|
||||||
"layerIndex" : 1,
|
|
||||||
"weights" : [ 0.58718896, -0.59475064, 0.81476605 ]
|
|
||||||
}, {
|
|
||||||
"id" : 5,
|
|
||||||
"layerIndex" : 1,
|
|
||||||
"weights" : [ -0.09139645, 0.71847045, -0.19625723 ]
|
|
||||||
}, {
|
|
||||||
"id" : 6,
|
|
||||||
"layerIndex" : 2,
|
|
||||||
"weights" : [ 0.60579014, 0.7003196, 0.82173157, 0.7627305, 0.83753014 ]
|
|
||||||
}, {
|
|
||||||
"id" : 7,
|
|
||||||
"layerIndex" : 2,
|
|
||||||
"weights" : [ -0.7915581, -0.5355048, -1.0218064, 0.35036424, -0.6445867 ]
|
|
||||||
}, {
|
|
||||||
"id" : 8,
|
|
||||||
"layerIndex" : 3,
|
|
||||||
"weights" : [ -0.24528235, 0.08879423, -0.46046266 ]
|
|
||||||
} ]
|
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
package adaline;
|
|
||||||
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
|
||||||
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.*;
|
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
|
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
|
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
|
||||||
import com.naaturel.ANN.implementation.training.steps.*;
|
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
|
|
||||||
public class AdalineTest {
|
|
||||||
|
|
||||||
private DataSet dataset;
|
|
||||||
private AdalineTrainingContext context;
|
|
||||||
|
|
||||||
private List<Synapse> synapses;
|
|
||||||
private Bias bias;
|
|
||||||
private FullyConnectedNetwork network;
|
|
||||||
|
|
||||||
private TrainingPipeline pipeline;
|
|
||||||
|
|
||||||
@BeforeEach
|
|
||||||
public void init(){
|
|
||||||
dataset = new DatasetExtractor()
|
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1);
|
|
||||||
|
|
||||||
List<Synapse> syns = new ArrayList<>();
|
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
|
||||||
|
|
||||||
bias = new Bias(new Weight(0));
|
|
||||||
|
|
||||||
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
|
|
||||||
Layer layer = new Layer(List.of(neuron));
|
|
||||||
network = new FullyConnectedNetwork(List.of(layer));
|
|
||||||
|
|
||||||
context = new AdalineTrainingContext();
|
|
||||||
context.dataset = dataset;
|
|
||||||
context.model = network;
|
|
||||||
|
|
||||||
List<TrainingStep> steps = List.of(
|
|
||||||
new PredictionStep(new SimplePredictionStep(context)),
|
|
||||||
new DeltaStep(new SimpleDeltaStep(context)),
|
|
||||||
new LossStep(new SquareLossStep(context)),
|
|
||||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
|
|
||||||
new WeightCorrectionStep(new SimpleCorrectionStep(context))
|
|
||||||
);
|
|
||||||
|
|
||||||
pipeline = new TrainingPipeline(steps)
|
|
||||||
.stopCondition(ctx -> ctx.globalLoss <= 0.1329F || ctx.epoch > 10000)
|
|
||||||
.beforeEpoch(ctx -> {
|
|
||||||
ctx.globalLoss = 0.0F;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test_the_whole_algorithm(){
|
|
||||||
|
|
||||||
List<Float> expectedGlobalLosses = List.of(
|
|
||||||
0.501522F,
|
|
||||||
0.498601F
|
|
||||||
);
|
|
||||||
|
|
||||||
context.learningRate = 0.03F;
|
|
||||||
pipeline.afterEpoch(ctx -> {
|
|
||||||
ctx.globalLoss /= context.dataset.size();
|
|
||||||
|
|
||||||
int index = ctx.epoch-1;
|
|
||||||
if(index >= expectedGlobalLosses.size()) return;
|
|
||||||
|
|
||||||
//assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f);
|
|
||||||
});
|
|
||||||
|
|
||||||
pipeline.run(context);
|
|
||||||
assertEquals(214, context.epoch);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@@ -1,98 +0,0 @@
|
|||||||
package gradientDescent;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.*;
|
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
|
||||||
import com.naaturel.ANN.implementation.training.steps.*;
|
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
|
||||||
|
|
||||||
|
|
||||||
public class GradientDescentTest {
|
|
||||||
|
|
||||||
private DataSet dataset;
|
|
||||||
private GradientDescentTrainingContext context;
|
|
||||||
|
|
||||||
private List<Synapse> synapses;
|
|
||||||
private Bias bias;
|
|
||||||
private FullyConnectedNetwork network;
|
|
||||||
|
|
||||||
private TrainingPipeline pipeline;
|
|
||||||
|
|
||||||
@BeforeEach
|
|
||||||
public void init(){
|
|
||||||
dataset = new DatasetExtractor()
|
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1);
|
|
||||||
|
|
||||||
List<Synapse> syns = new ArrayList<>();
|
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
|
||||||
|
|
||||||
bias = new Bias(new Weight(0));
|
|
||||||
|
|
||||||
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
|
|
||||||
Layer layer = new Layer(List.of(neuron));
|
|
||||||
network = new FullyConnectedNetwork(List.of(layer));
|
|
||||||
|
|
||||||
context = new GradientDescentTrainingContext();
|
|
||||||
context.dataset = dataset;
|
|
||||||
context.model = network;
|
|
||||||
context.correctorTerms = new ArrayList<>();
|
|
||||||
|
|
||||||
List<TrainingStep> steps = List.of(
|
|
||||||
new PredictionStep(new SimplePredictionStep(context)),
|
|
||||||
new DeltaStep(new SimpleDeltaStep(context)),
|
|
||||||
new LossStep(new SquareLossStep(context)),
|
|
||||||
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
|
|
||||||
);
|
|
||||||
|
|
||||||
pipeline = new TrainingPipeline(steps)
|
|
||||||
.stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 100)
|
|
||||||
.beforeEpoch(ctx -> {
|
|
||||||
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
|
||||||
gdCtx.globalLoss = 0.0F;
|
|
||||||
gdCtx.correctorTerms.clear();
|
|
||||||
for (int i = 0; i < ctx.model.synCount(); i++){
|
|
||||||
gdCtx.correctorTerms.add(0F);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test_the_whole_algorithm(){
|
|
||||||
|
|
||||||
List<Float> expectedGlobalLosses = List.of(
|
|
||||||
0.5F,
|
|
||||||
0.38F,
|
|
||||||
0.3176F,
|
|
||||||
0.272096F,
|
|
||||||
0.237469F
|
|
||||||
);
|
|
||||||
|
|
||||||
context.learningRate = 0.2F;
|
|
||||||
pipeline.afterEpoch(ctx -> {
|
|
||||||
context.globalLoss /= context.dataset.size();
|
|
||||||
new GradientDescentCorrectionStrategy(context).run();
|
|
||||||
|
|
||||||
int index = ctx.epoch-1;
|
|
||||||
if(index >= expectedGlobalLosses.size()) return;
|
|
||||||
|
|
||||||
assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f);
|
|
||||||
});
|
|
||||||
|
|
||||||
pipeline
|
|
||||||
.withVerbose(true)
|
|
||||||
.run(context);
|
|
||||||
assertEquals(67, context.epoch);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,84 +0,0 @@
|
|||||||
package perceptron;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
|
||||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
|
||||||
import com.naaturel.ANN.implementation.training.steps.*;
|
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
|
||||||
|
|
||||||
|
|
||||||
public class SimplePerceptronTest {
|
|
||||||
|
|
||||||
private DataSet dataset;
|
|
||||||
private SimpleTrainingContext context;
|
|
||||||
|
|
||||||
private List<Synapse> synapses;
|
|
||||||
private Bias bias;
|
|
||||||
private FullyConnectedNetwork network;
|
|
||||||
|
|
||||||
private TrainingPipeline pipeline;
|
|
||||||
|
|
||||||
@BeforeEach
|
|
||||||
public void init(){
|
|
||||||
dataset = new DatasetExtractor()
|
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv", 1);
|
|
||||||
|
|
||||||
List<Synapse> syns = new ArrayList<>();
|
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
|
||||||
|
|
||||||
bias = new Bias(new Weight(0));
|
|
||||||
|
|
||||||
Neuron neuron = new Neuron(syns, bias, new Heaviside());
|
|
||||||
Layer layer = new Layer(List.of(neuron));
|
|
||||||
network = new FullyConnectedNetwork(List.of(layer));
|
|
||||||
|
|
||||||
context = new SimpleTrainingContext();
|
|
||||||
context.dataset = dataset;
|
|
||||||
context.model = network;
|
|
||||||
|
|
||||||
List<TrainingStep> steps = List.of(
|
|
||||||
new PredictionStep(new SimplePredictionStep(context)),
|
|
||||||
new DeltaStep(new SimpleDeltaStep(context)),
|
|
||||||
new LossStep(new SimpleLossStrategy(context)),
|
|
||||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
|
|
||||||
new WeightCorrectionStep(new SimpleCorrectionStep(context))
|
|
||||||
);
|
|
||||||
|
|
||||||
pipeline = new TrainingPipeline(steps);
|
|
||||||
pipeline.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100);
|
|
||||||
pipeline.beforeEpoch(ctx -> ctx.globalLoss = 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test_the_whole_algorithm(){
|
|
||||||
|
|
||||||
List<Float> expectedGlobalLosses = List.of(
|
|
||||||
2.0F,
|
|
||||||
3.0F,
|
|
||||||
3.0F,
|
|
||||||
2.0F,
|
|
||||||
1.0F,
|
|
||||||
0.0F
|
|
||||||
);
|
|
||||||
|
|
||||||
context.learningRate = 1F;
|
|
||||||
pipeline.afterEpoch(ctx -> {
|
|
||||||
int index = ctx.epoch-1;
|
|
||||||
assertEquals(expectedGlobalLosses.get(index), context.globalLoss);
|
|
||||||
});
|
|
||||||
|
|
||||||
pipeline.run(context);
|
|
||||||
assertEquals(6, context.epoch);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user