Compare commits

...

36 Commits

Author SHA1 Message Date
b253fb74ee Add model visualization 2026-04-04 17:16:18 +02:00
8beb6aa870 Tune model parameters 2026-04-03 21:19:27 +02:00
40ebca469e Add JSON config loading 2026-04-03 17:58:28 +02:00
42e6d3dde8 Forgot to add deps 2026-04-03 16:25:46 +02:00
87536f5a55 Integrate model persistence 2026-04-03 16:13:39 +02:00
5a73337687 Minor changes 2026-04-02 09:21:38 +02:00
4c1eaff238 Optimize prediction 2026-04-02 09:07:58 +02:00
5ddf6dc580 Reworked synapses data structure 2026-04-01 22:48:06 +02:00
4441b149f9 Fix weighted sum back 2026-04-01 17:40:33 +02:00
1e8b02089c Optimize some stuff 2026-04-01 16:14:13 +02:00
daba4f8420 Implement batch size 2026-03-31 22:52:03 +02:00
5aca7b87e3 Minor fixes 2026-03-31 16:26:28 +02:00
165a2bc977 Minor changes 2026-03-30 23:08:37 +02:00
881088df28 Minor performance improvements 2026-03-30 22:14:33 +02:00
fd97d0853c Fix multi layer implementation 2026-03-30 21:13:03 +02:00
ada01d350b Change signature of train method 2026-03-30 18:28:21 +02:00
aed78fe9d2 Implement multi layer 2026-03-30 13:38:44 +02:00
b36a900f87 Delete some stuff 2026-03-29 21:33:00 +02:00
0fe309cd4e Rename some stuff 2026-03-29 21:32:08 +02:00
83526b72d4 Just a regular commit 2026-03-28 17:53:21 +01:00
17cff89b44 Implement gradient backpropagation stub 2026-03-28 13:19:36 +01:00
6d88651385 Move dataset components 2026-03-28 12:25:59 +01:00
7fb4a7c057 Minor changes 2026-03-27 12:40:00 +01:00
572e5c7484 Add some plotting 2026-03-26 22:35:07 +01:00
64bc830f18 Add multi-layer support 2026-03-26 21:21:31 +01:00
3dd4404f51 Clean unused import statements 2026-03-26 11:27:50 +01:00
0d3ab0de8d Reimplement Adaline 2026-03-26 11:27:10 +01:00
c389646794 Add gradient descent test 2026-03-26 08:23:24 +01:00
76465ab6ee Start to add test coverage 2026-03-25 22:36:26 +01:00
65d3a0e3e4 Fix implementation 2026-03-25 16:11:09 +01:00
0217607e9b Start to reimplement gradient descent 2026-03-23 23:12:52 +01:00
5ace4952fb Just a regular commit 2026-03-23 18:47:36 +01:00
a84c3d999d Fix weights correction 2026-03-23 17:13:53 +01:00
b25aaba088 Implement main structure of framework 2026-03-23 16:39:12 +01:00
76bc791889 Just a regular commit 2026-03-22 23:36:44 +01:00
56f88bded3 Remove .idea 2026-03-20 17:36:01 +01:00
69 changed files with 2162 additions and 374 deletions

10
.idea/.gitignore generated vendored
View File

@@ -1,10 +0,0 @@
# Default ignored files
/shelf/
/workspace.xml
# Ignored default folder with query files
/queries/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/

1
.idea/.name generated
View File

@@ -1 +0,0 @@
ANN

17
.idea/gradle.xml generated
View File

@@ -1,17 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="GradleMigrationSettings" migrationVersion="1" />
<component name="GradleSettings">
<option name="linkedExternalProjectsSettings">
<GradleProjectSettings>
<option name="externalProjectPath" value="$PROJECT_DIR$" />
<option name="gradleHome" value="" />
<option name="modules">
<set>
<option value="$PROJECT_DIR$" />
</set>
</option>
</GradleProjectSettings>
</option>
</component>
</project>

10
.idea/misc.xml generated
View File

@@ -1,10 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ExternalStorageConfigurationManager" enabled="true" />
<component name="FrameworkDetectionExcludesConfiguration">
<file type="web" url="file://$PROJECT_DIR$" />
</component>
<component name="ProjectRootManager" version="2" languageLevel="JDK_21" default="true" project-jdk-name="21" project-jdk-type="JavaSDK">
<output url="file://$PROJECT_DIR$/out" />
</component>
</project>

6
.idea/vcs.xml generated
View File

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

View File

@@ -10,6 +10,9 @@ repositories {
}
dependencies {
implementation("org.jfree:jfreechart:1.5.4")
implementation("com.fasterxml.jackson.core:jackson-databind:2.21.2")
testImplementation(platform("org.junit:junit-bom:5.10.0"))
testImplementation("org.junit.jupiter:junit-jupiter")
testRuntimeOnly("org.junit.platform:junit-platform-launcher")

14
config.json Normal file
View File

@@ -0,0 +1,14 @@
{
"model": {
"new": true,
"parameters": [2, 4, 2, 1],
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/best-snapshot.json"
},
"training" : {
"learning_rate" : 0.0003,
"max_epoch" : 5000
},
"dataset" : {
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/LangageDesSignes/data_formatted.csv"
}
}

View File

@@ -1,77 +1,110 @@
package com.naaturel.ANN;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import com.naaturel.ANN.implementation.activationFunction.Heaviside;
import com.naaturel.ANN.implementation.activationFunction.Linear;
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.training.AdalineTraining;
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.infrastructure.config.ConfigDto;
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
import com.naaturel.ANN.infrastructure.visualization.ModelVisualizer;
import java.util.*;
public class Main {
public static void main(String[] args){
public static void main(String[] args) throws Exception {
DataSet orDataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(0.0F)),
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
));
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
DataSet andDataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
));
boolean newModel = config.getModelProperty("new", Boolean.class);
int[] modelParameters = config.getModelProperty("parameters", int[].class);
String modelPath = config.getModelProperty("path", String.class);
int maxEpoch = config.getTrainingProperty("max_epoch", Integer.class);
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
String datasetPath = config.getDatasetProperty("path", String.class);
DataSet dataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 9.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(7.0F, 10.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(2.0F, 5.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(2.0F, 7.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(2.0F, 8.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(6.0F, 8.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(6.0F, 9.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 5.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 6.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 8.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 9.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 7.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 8.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 10.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 11.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 6.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 7.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 9.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 10.0F)), new Label(1.0F))
));
int nbrClass = 5;
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
int nbrInput = dataset.getNbrInputs();
List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight(0)));
syns.add(new Synapse(new Input(0), new Weight(0)));
ModelSnapshot snapshot;
Bias bias = new Bias(new Weight(0));
Model network;
if(newModel){
network = createNetwork(modelParameters, nbrInput);
snapshot = new ModelSnapshot(network);
System.out.println("Parameters: " + network.synCount());
Trainer trainer = new GradientBackpropagationTraining();
trainer.train(learningRate, maxEpoch, network, dataset);
snapshot.saveToFile(modelPath);
} else {
snapshot = new ModelSnapshot();
snapshot.loadFromFile(modelPath);
network = snapshot.getModel();
}
//plotGraph(dataset, network);
Neuron n = new SimplePerceptron(syns, bias, new Linear());
AdalineTraining st = new AdalineTraining();
long start = System.currentTimeMillis();
st.train(n, 0.03F, andDataSet);
long end = System.currentTimeMillis();
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
new ModelVisualizer(network)
.withWeights(true)
.display();
}
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
int neuronId = 0;
List<Layer> layers = new ArrayList<>();
for (int i = 0; i < neuronPerLayer.length; i++){
List<Neuron> neurons = new ArrayList<>();
for (int j = 0; j < neuronPerLayer[i]; j++){
int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1];
List<Synapse> syns = new ArrayList<>();
for (int k=0; k < nbrSyn; k++){
syns.add(new Synapse(new Input(0), new Weight()));
}
Bias bias = new Bias(new Weight());
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
neurons.add(n);
neuronId++;
}
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
layers.add(layer);
}
return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
}
private static void plotGraph(DataSet dataset, Model network){
GraphVisualizer visualizer = new GraphVisualizer();
/*for (DataSetEntry entry : dataset) {
List<Float> label = dataset.getLabelsAsFloat(entry);
label.forEach(l -> {
visualizer.addPoint("Label " + l,
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
});
}*/
float min = -50F;
float max = 50F;
float step = 0.03F;
for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){
float[] predictions = network.predict(new float[]{x, y});
visualizer.addPoint(Float.toString(Math.round(predictions[0])), x, y);
}
}
visualizer.buildScatterGraph((int)min-1, (int)max+1);
}
}

View File

@@ -1,7 +1,10 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.neuron.Neuron;
/**
 * Activation function applied to a neuron's weighted sum.
 */
public interface ActivationFunction {
/** Computes the activation value from {@code n}'s weighted sum. */
float accept(Neuron n);
/** Derivative of the activation evaluated at {@code value}; intended for gradient-based training. */
float derivative(float value);
}

View File

@@ -0,0 +1,8 @@
package com.naaturel.ANN.domain.abstraction;
/**
 * A single step of a training algorithm, executed once per dataset entry
 * by the training pipeline. Implementations typically read and mutate a
 * shared training context supplied at construction time.
 */
@FunctionalInterface
public interface AlgorithmStep {
/** Executes this step. */
void run();
}

View File

@@ -0,0 +1,20 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import java.util.List;
import java.util.function.Consumer;
/**
 * Common abstraction over anything that can produce predictions from inputs:
 * a single neuron, a layer, or a whole network (composite pattern).
 */
public interface Model {

    /** Total number of weighted connections; implementations may count bias slots as well. */
    int synCount();

    /** Total number of neurons contained in this model. */
    int neuronCount();

    /** Index of the layer containing {@code n}; 0 for single-layer/single-neuron models. */
    int layerIndexOf(Neuron n);

    /** Position of {@code n} within its layer. */
    int indexInLayerOf(Neuron n);

    /** Applies {@code consumer} to every neuron in the model, in traversal order. */
    void forEachNeuron(Consumer<Neuron> consumer);

    /** Applies {@code consumer} to every neuron of the output layer. */
    void forEachOutputNeurons(Consumer<Neuron> consumer);

    /** Applies {@code consumer} to each neuron fed by {@code n}'s output. */
    void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer);

    /** Runs a forward pass and returns the output values. */
    float[] predict(float[] inputs);
}

View File

@@ -0,0 +1,10 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import java.util.function.Consumer;
/**
 * Marker interface for network types.
 * NOTE(review): currently empty and not referenced by the visible code —
 * consider removing it or folding its role into Model.
 */
public interface Network {
}

View File

@@ -1,55 +0,0 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import java.util.ArrayList;
import java.util.List;
public abstract class Neuron {
protected List<Synapse> synapses;
protected Bias bias;
protected ActivationFunction activationFunction;
public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){
this.synapses = synapses;
this.bias = bias;
this.activationFunction = func;
}
public abstract float predict();
public abstract float calculateWeightedSum();
public int getSynCount(){
return this.synapses.size();
}
public void setInput(int index, Input input){
Synapse syn = this.synapses.get(index);
syn.setInput(input.getValue());
}
public Bias getBias(){
return this.bias;
}
public void updateBias(Weight weight) {
this.bias.setWeight(weight.getValue());
}
public Synapse getSynapse(int index){
return this.synapses.get(index);
}
public List<Synapse> getSynapses() {
return new ArrayList<>(this.synapses);
}
public void setWeight(int index, Weight weight){
Synapse syn = this.synapses.get(index);
syn.setWeight(weight.getValue());
}
}

View File

@@ -1,13 +0,0 @@
package com.naaturel.ANN.domain.abstraction;
public abstract class NeuronTrainer {
private Trainable trainable;
public NeuronTrainer(Trainable trainable){
this.trainable = trainable;
}
public abstract void train();
}

View File

@@ -1,7 +0,0 @@
package com.naaturel.ANN.domain.abstraction;
public interface Trainable {
}

View File

@@ -0,0 +1,7 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
/**
 * Trains a {@link Model} against a dataset.
 */
public interface Trainer {
/**
 * Runs training, updating the model's weights in place.
 *
 * @param learningRate step size applied to weight corrections
 * @param epoch        maximum number of epochs to run
 * @param model        the model to train
 * @param dataset      the training data
 */
void train(float learningRate, int epoch, Model model, DataSet dataset);
}

View File

@@ -0,0 +1,29 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import java.util.List;
/**
 * Mutable state shared between the steps of a training algorithm.
 * Fields are public by design: each AlgorithmStep reads and writes them
 * directly while the pipeline iterates over the dataset.
 */
public abstract class TrainingContext {
public Model model;
public DataSet dataset;
// Entry currently being processed by the pipeline.
public DataSetEntry currentEntry;
// Expected outputs (labels) for the current entry.
public List<Float> expectations;
// Model outputs for the current entry.
public float[] predictions;
// Per-output error terms; sized to the dataset's label count.
public float[] deltas;
// Loss accumulated across entries (see SquareLossStep / error strategies).
public float globalLoss;
// Loss for the current entry only.
public float localLoss;
public float learningRate;
// Current epoch counter, advanced by the pipeline.
public int epoch;
public TrainingContext(Model model, DataSet dataset) {
this.model = model;
this.dataset = dataset;
this.deltas = new float[dataset.getNbrLabels()];
}
}

View File

@@ -1,54 +0,0 @@
package com.naaturel.ANN.domain.model.dataset;
import java.util.*;
public class DataSet implements Iterable<DataSetEntry>{
private Map<DataSetEntry, Label> data;
public DataSet(){
this(new HashMap<>());
}
public DataSet(Map<DataSetEntry, Label> data){
this.data = data;
}
public int size() {
return data.size();
}
public List<DataSetEntry> getData(){
return new ArrayList<>(this.data.keySet());
}
public Label getLabel(DataSetEntry entry){
return this.data.get(entry);
}
public DataSet toNormalized() {
List<DataSetEntry> entries = this.getData();
float maxAbs = entries.stream()
.flatMap(e -> e.getData().stream())
.map(Math::abs)
.max(Float::compare)
.orElse(1.0F);
Map<DataSetEntry, Label> normalized = new HashMap<>();
for (DataSetEntry entry : entries) {
List<Float> normalizedData = new ArrayList<>();
for (float value : entry.getData()) {
normalizedData.add(Math.round((value / maxAbs) * 100.0F) / 100.0F);
}
normalized.put(new DataSetEntry(normalizedData), this.data.get(entry));
}
return new DataSet(normalized);
}
@Override
public Iterator<DataSetEntry> iterator() {
return this.data.keySet().iterator();
}
}

View File

@@ -1,15 +0,0 @@
package com.naaturel.ANN.domain.model.dataset;
public class Label {
private float value;
public Label(float value){
this.value = value;
}
public float getValue() {
return value;
}
}

View File

@@ -0,0 +1,102 @@
package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.Model;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
/**
 * A feed-forward network in which every neuron of one layer feeds every
 * neuron of the next layer.
 */
public class FullyConnectedNetwork implements Model {

    private final Layer[] layers;
    // For each neuron, the neurons of the following layer it feeds into.
    private final Map<Neuron, List<Neuron>> connectionMap;
    // For each neuron, the index of the layer that contains it.
    private final Map<Neuron, Integer> layerIndexByNeuron;

    public FullyConnectedNetwork(Layer[] layers) {
        this.layers = layers;
        this.connectionMap = buildConnectionMap();
        this.layerIndexByNeuron = buildLayerIndex();
    }

    /** Feeds {@code inputs} through each layer in turn; each layer's output becomes the next layer's input. */
    @Override
    public float[] predict(float[] inputs) {
        float[] signal = inputs;
        for (Layer layer : this.layers) {
            signal = layer.predict(signal);
        }
        return signal;
    }

    /** Sum of the synapse counts of all layers. */
    @Override
    public int synCount() {
        int total = 0;
        for (Layer layer : this.layers) {
            total += layer.synCount();
        }
        return total;
    }

    /** Sum of the neuron counts of all layers. */
    @Override
    public int neuronCount() {
        int total = 0;
        for (Layer layer : this.layers) {
            total += layer.neuronCount();
        }
        return total;
    }

    @Override
    public void forEachNeuron(Consumer<Neuron> consumer) {
        for (Layer layer : this.layers) {
            layer.forEachNeuron(consumer);
        }
    }

    /** Visits only the neurons of the final (output) layer. */
    @Override
    public void forEachOutputNeurons(Consumer<Neuron> consumer) {
        this.layers[this.layers.length - 1].forEachNeuron(consumer);
    }

    /** Visits downstream neurons of {@code n}; does nothing for output-layer neurons. */
    @Override
    public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
        this.connectionMap.getOrDefault(n, List.of()).forEach(consumer);
    }

    @Override
    public int layerIndexOf(Neuron n) {
        return this.layerIndexByNeuron.get(n);
    }

    @Override
    public int indexInLayerOf(Neuron n) {
        return this.layers[this.layerIndexByNeuron.get(n)].indexInLayerOf(n);
    }

    /** Maps every non-output neuron to the full list of neurons in the next layer. */
    private Map<Neuron, List<Neuron>> buildConnectionMap() {
        Map<Neuron, List<Neuron>> map = new HashMap<>();
        for (int i = 0; i + 1 < this.layers.length; i++) {
            List<Neuron> downstream = new ArrayList<>();
            this.layers[i + 1].forEachNeuron(downstream::add);
            this.layers[i].forEachNeuron(source -> map.put(source, downstream));
        }
        return map;
    }

    /** Maps every neuron to the index of its containing layer. */
    private Map<Neuron, Integer> buildLayerIndex() {
        Map<Neuron, Integer> index = new HashMap<>();
        for (int layerIdx = 0; layerIdx < this.layers.length; layerIdx++) {
            final int li = layerIdx;
            this.layers[layerIdx].forEachNeuron(neuron -> index.put(neuron, li));
        }
        return index;
    }
}

View File

@@ -0,0 +1,87 @@
package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.Model;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
/**
 * A single layer of neurons. Also implements {@link Model} so a lone layer
 * can act as a one-layer model.
 */
public class Layer implements Model {

    private final Neuron[] neurons;
    // Position of each neuron within this layer.
    private final Map<Neuron, Integer> neuronIndex;

    public Layer(Neuron[] neurons) {
        this.neurons = neurons;
        this.neuronIndex = createNeuronIndex();
    }

    /** Feeds the same inputs to every neuron and returns one output per neuron. */
    @Override
    public float[] predict(float[] inputs) {
        float[] result = new float[neurons.length];
        for (int i = 0; i < neurons.length; i++) {
            // Each neuron produces a single value; predict() wraps it in a 1-element array.
            result[i] = neurons[i].predict(inputs)[0];
        }
        return result;
    }

    /** Sum of the synapse counts of all neurons in this layer. */
    @Override
    public int synCount() {
        int res = 0;
        for (Neuron neuron : this.neurons) {
            res += neuron.synCount();
        }
        return res;
    }

    @Override
    public int neuronCount() {
        return this.neurons.length;
    }

    /** A standalone layer is its own (only) layer, so the layer index is always 0. */
    @Override
    public int layerIndexOf(Neuron n) {
        return 0;
    }

    @Override
    public int indexInLayerOf(Neuron n) {
        return this.neuronIndex.get(n);
    }

    @Override
    public void forEachNeuron(Consumer<Neuron> consumer) {
        for (Neuron n : this.neurons) {
            consumer.accept(n);
        }
    }

    /** Every neuron of a standalone layer is an output neuron. */
    @Override
    public void forEachOutputNeurons(Consumer<Neuron> consumer) {
        this.forEachNeuron(consumer);
    }

    @Override
    public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
        throw new UnsupportedOperationException("Neurons have no connection within the same layer");
    }

    /** Maps each neuron to its position in the {@code neurons} array. */
    private Map<Neuron, Integer> createNeuronIndex() {
        Map<Neuron, Integer> res = new HashMap<>();
        int[] index = {0};
        this.forEachNeuron(n -> res.put(n, index[0]++));
        return res;
    }
}

View File

@@ -0,0 +1,111 @@
package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Model;
import java.util.List;
import java.util.function.Consumer;
/**
 * A single neuron: stores its weights and inputs in flat arrays and applies
 * an activation function to their weighted sum.
 *
 * Index 0 of both arrays is reserved for the bias (its weight and input);
 * synapse i lives at index i+1. Consequently synCount() includes the bias slot.
 */
public class Neuron implements Model {
private final int id;
// Last value produced by predict(); 0 until the first prediction.
private float output;
private final float[] weights;
private final float[] inputs;
private final ActivationFunction activationFunction;
public Neuron(int id, Synapse[] synapses, Bias bias, ActivationFunction func){
this.id = id;
this.activationFunction = func;
output = 0;
weights = new float[synapses.length+1]; //takes the bias into account
inputs = new float[synapses.length+1]; //takes the bias into account
weights[0] = bias.getWeight();
inputs[0] = bias.getInput();
for (int i = 0; i < synapses.length; i++){
weights[i+1] = synapses[i].getWeight();
inputs[i+1] = synapses[i].getInput();
}
}
// Note: index is the flat slot index — 0 addresses the bias weight.
public void setWeight(int index, float value) {
this.weights[index] = value;
}
public float getWeight(int index) {
return this.weights[index];
}
public float getInput(int index) {
return this.inputs[index];
}
public ActivationFunction getActivationFunction(){
return this.activationFunction;
}
// Dot product of weights and inputs, bias slot included.
public float calculateWeightedSum() {
int count = weights.length;
float weightedSum = 0F;
for (int i = 0; i < count; i++){
weightedSum += weights[i] * inputs[i];
}
return weightedSum;
}
public int getId(){
return this.id;
}
public float getOutput() {
return this.output;
}
// NOTE: returns weights.length, i.e. the synapse count PLUS the bias slot;
// callers iterating [0, synCount()) therefore also visit the bias weight.
@Override
public int synCount() {
return this.weights.length;
}
@Override
public int neuronCount() {
return 1;
}
// A lone neuron acts as its own single-layer model.
@Override
public int layerIndexOf(Neuron n) {
return 0;
}
@Override
public int indexInLayerOf(Neuron n) {
return 0;
}
// Stores the given inputs (bias input untouched), activates, and caches the output.
@Override
public float[] predict(float[] inputs) {
this.setInputs(inputs);
output = activationFunction.accept(this);
return new float[] {output};
}
@Override
public void forEachNeuron(Consumer<Neuron> consumer) {
consumer.accept(this);
}
@Override
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
consumer.accept(this);
}
@Override
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
throw new UnsupportedOperationException("Neurons have no connection with themselves");
}
// Copies values into inputs[1..]; index 0 (the bias input) is preserved.
private void setInputs(float[] values){
System.arraycopy(values, 0, inputs, 1, values.length);
}
}

View File

@@ -14,8 +14,8 @@ public class Synapse {
return this.input.getValue();
}
public void setInput(float value){
this.input.setValue(value);
public void setInput(Input input){
this.input.setValue(input.getValue());
}
public float getWeight() {
@@ -25,6 +25,4 @@ public class Synapse {
public void setWeight(float value){
this.weight.setValue(value);
}
}

View File

@@ -0,0 +1,132 @@
package com.naaturel.ANN.domain.model.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Predicate;
/**
 * Orchestrates a training run: executes the configured algorithm steps for
 * every dataset entry, once per epoch, until the stop condition is met.
 * Configuration methods return {@code this} so calls can be chained.
 */
public class TrainingPipeline {

    private final List<AlgorithmStep> steps;
    private Consumer<TrainingContext> beforeEpoch;
    private Consumer<TrainingContext> afterEpoch;
    private Predicate<TrainingContext> stopCondition;
    private boolean verbose;
    // NOTE(review): visualization settings are stored but never read by run().
    private boolean visualization;
    private boolean timeMeasurement;
    private GraphVisualizer visualizer;
    // Verbose output is emitted only on epochs divisible by this delay.
    private int verboseDelay;

    public TrainingPipeline(List<AlgorithmStep> steps) {
        // Defensive copy so later mutation of the caller's list has no effect.
        this.steps = new ArrayList<>(steps);
        // Defaults: never stop on condition alone, no epoch hooks.
        this.stopCondition = (ctx) -> false;
        this.beforeEpoch = (context -> {});
        this.afterEpoch = (context -> {});
    }

    /** Stops training once {@code predicate} holds after an epoch. */
    public TrainingPipeline stopCondition(Predicate<TrainingContext> predicate) {
        this.stopCondition = predicate;
        return this;
    }

    /** Hook invoked before each epoch (e.g. to reset accumulated loss). */
    public TrainingPipeline beforeEpoch(Consumer<TrainingContext> consumer) {
        this.beforeEpoch = consumer;
        return this;
    }

    /** Hook invoked after each epoch. */
    public TrainingPipeline afterEpoch(Consumer<TrainingContext> consumer) {
        this.afterEpoch = consumer;
        return this;
    }

    /**
     * Enables per-entry/per-epoch console output every {@code epochDelay} epochs.
     *
     * @throws IllegalArgumentException if {@code epochDelay} is not positive
     */
    public TrainingPipeline withVerbose(boolean enabled, int epochDelay) {
        if(epochDelay <= 0) throw new IllegalArgumentException("Epoch delay cannot be lower than or equal to 0");
        this.verbose = enabled;
        this.verboseDelay = epochDelay;
        return this;
    }

    /** Stores visualization settings; kept for interface stability (currently unused by run()). */
    public TrainingPipeline withVisualization(boolean enabled, GraphVisualizer visualizer) {
        this.visualization = enabled;
        this.visualizer = visualizer;
        return this;
    }

    /** Enables wall-clock measurement of the whole run. */
    public TrainingPipeline withTimeMeasurement(boolean enabled) {
        this.timeMeasurement = enabled;
        return this;
    }

    /**
     * Runs epochs until the stop condition is satisfied; at least one epoch
     * always runs. Optionally prints progress and total elapsed time.
     */
    public void run(TrainingContext ctx) {
        long start = this.timeMeasurement ? System.currentTimeMillis() : 0;
        do {
            this.beforeEpoch.accept(ctx);
            this.executeSteps(ctx);
            this.afterEpoch.accept(ctx);
            if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
                System.out.printf("[Global error] : %f\n", ctx.globalLoss);
            }
            ctx.epoch += 1;
        } while (!this.stopCondition.test(ctx));
        if(this.timeMeasurement) {
            long end = System.currentTimeMillis();
            System.out.printf("[Training finished in %.3fs]\n", (end-start)/1000.0);
        }
        System.out.printf("[Final global error] : %f\n", ctx.globalLoss);
    }

    /** Runs every step for every dataset entry, publishing the entry and its labels on the context first. */
    private void executeSteps(TrainingContext ctx){
        for (DataSetEntry entry : ctx.dataset) {
            ctx.currentEntry = entry;
            ctx.expectations = ctx.dataset.getLabelsAsFloat(entry);
            for (AlgorithmStep step : steps) {
                step.run();
            }
            if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
                System.out.printf("Epoch : %d, ", ctx.epoch);
                System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions));
                System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
                System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas));
                System.out.printf("loss : %.5f\n", ctx.localLoss);
            }
        }
    }
}

View File

@@ -1,17 +0,0 @@
package com.naaturel.ANN.implementation.activationFunction;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Neuron;
public class Heaviside implements ActivationFunction {
public Heaviside(){
}
@Override
public float accept(Neuron n) {
float weightedSum = n.calculateWeightedSum();
return weightedSum <= 0 ? 0:1;
}
}

View File

@@ -1,13 +0,0 @@
package com.naaturel.ANN.implementation.activationFunction;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Neuron;
public class Linear implements ActivationFunction {
@Override
public float accept(Neuron n) {
return n.calculateWeightedSum();
}
}

View File

@@ -0,0 +1,11 @@
package com.naaturel.ANN.implementation.adaline;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
/**
 * Training context for the Adaline algorithm; currently adds no state beyond
 * the base TrainingContext.
 */
public class AdalineTrainingContext extends TrainingContext {
public AdalineTrainingContext(Model model, DataSet dataset) {
super(model, dataset);
}
}

View File

@@ -0,0 +1,27 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Applies the accumulated corrector terms to every weight of the model.
 * Corrector terms are stored flat, in the same neuron/synapse traversal
 * order used here.
 */
public class GradientDescentCorrectionStrategy implements AlgorithmStep {

    private final GradientDescentTrainingContext context;

    public GradientDescentCorrectionStrategy(GradientDescentTrainingContext context) {
        this.context = context;
    }

    /** Adds each stored corrector term to its matching weight, walking neurons in model order. */
    @Override
    public void run() {
        AtomicInteger cursor = new AtomicInteger(0);
        context.model.forEachNeuron(neuron -> {
            for (int slot = 0; slot < neuron.synCount(); slot++) {
                float correction = context.correctorTerms.get(cursor.getAndIncrement());
                neuron.setWeight(slot, neuron.getWeight(slot) + correction);
            }
        });
    }
}

View File

@@ -0,0 +1,36 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Accumulates gradient-descent corrector terms for every weight from the
 * per-output deltas, then folds the local loss into the global loss.
 */
public class GradientDescentErrorStrategy implements AlgorithmStep {

    private final GradientDescentTrainingContext context;

    public GradientDescentErrorStrategy(GradientDescentTrainingContext context) {
        this.context = context;
    }

    /**
     * For each synapse of each neuron, adds learningRate * delta * input to
     * the flat corrector-term list (one entry per synapse, model order).
     */
    @Override
    public void run() {
        int[] neuronCursor = {0};
        int[] synCursor = {0};
        context.model.forEachNeuron(neuron -> {
            // One delta per neuron, read before walking its synapses.
            float delta = context.deltas[neuronCursor[0]++];
            for (int slot = 0; slot < neuron.synCount(); slot++) {
                int flat = synCursor[0]++;
                float updated = context.correctorTerms.get(flat)
                        + context.learningRate * delta * neuron.getInput(slot);
                context.correctorTerms.set(flat, updated);
            }
        });
        context.globalLoss += context.localLoss;
    }
}

View File

@@ -0,0 +1,16 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import java.util.List;
/**
 * Training context for plain gradient descent.
 */
public class GradientDescentTrainingContext extends TrainingContext {
// Flat list of pending weight corrections, one entry per synapse, in model
// traversal order.
// NOTE(review): never initialized here — callers must populate it before the
// error/correction steps run, or those steps will fail on a null list.
public List<Float> correctorTerms;
public GradientDescentTrainingContext(Model model, DataSet dataset) {
super(model, dataset);
}
}

View File

@@ -0,0 +1,26 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.model.neuron.Neuron;
/**
 * Linear (affine) activation: f(s) = slope * s + intercept, where s is the
 * neuron's weighted sum. Its derivative is the constant slope.
 */
public class Linear implements ActivationFunction {

    private final float slope;
    private final float intercept;

    public Linear(float slope, float intercept) {
        this.slope = slope;
        this.intercept = intercept;
    }

    /** Applies the affine transform to the neuron's weighted sum. */
    @Override
    public float accept(Neuron n) {
        float weightedSum = n.calculateWeightedSum();
        return slope * weightedSum + intercept;
    }

    /** The derivative of an affine function is its slope, regardless of {@code value}. */
    @Override
    public float derivative(float value) {
        return this.slope;
    }
}

View File

@@ -0,0 +1,25 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import java.util.stream.Stream;
/**
 * Computes the half squared-error loss for the current entry —
 * localLoss = (1/2) * sum of squared deltas — and adds it to the global loss.
 * NOTE(review): GradientDescentErrorStrategy also adds localLoss to
 * globalLoss; if both steps run in one pipeline the loss is counted twice.
 * Confirm which step should own the accumulation.
 */
public class SquareLossStep implements AlgorithmStep {

    private final TrainingContext context;

    public SquareLossStep(TrainingContext context) {
        this.context = context;
    }

    /** Squares and sums the deltas, halves the total, and accumulates it globally. */
    @Override
    public void run() {
        float sumOfSquares = 0f;
        float[] deltas = this.context.deltas;
        for (int i = 0; i < deltas.length; i++) {
            sumOfSquares += deltas[i] * deltas[i];
        }
        this.context.localLoss = sumOfSquares / 2f;
        this.context.globalLoss += this.context.localLoss;
    }
}

View File

@@ -0,0 +1,56 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
/**
 * Accumulates per-synapse weight corrections (learningRate * errorSignal * input)
 * for the current sample and flushes them onto the model's weights once a full
 * mini-batch has been seen.
 */
public class BackpropagationCorrectionStep implements AlgorithmStep {
    private final GradientBackpropagationContext context;
    // Total number of synapses in the model (sizes the flat buffers below).
    private final int synCount;
    // Flat per-synapse snapshots of the current sample, refilled on every run.
    private final float[] inputs;
    private final float[] signals;

    public BackpropagationCorrectionStep(GradientBackpropagationContext context){
        this.context = context;
        this.synCount = context.correctionBuffer.length;
        this.inputs = new float[synCount];
        this.signals = new float[synCount];
    }

    @Override
    public void run() {
        // Flatten each neuron's inputs and its error signal into synapse order.
        int[] synIndex = {0};
        context.model.forEachNeuron(n -> {
            float signal = context.errorSignals[n.getId()];
            for (int i = 0; i < n.synCount(); i++){
                inputs[synIndex[0]] = n.getInput(i);
                signals[synIndex[0]] = signal;
                synIndex[0]++;
            }
        });
        float lr = context.learningRate;
        // Decide before accumulating: the batch is complete once currentSample
        // reaches batchSize (currentSample starts at 1), so the current sample
        // is still included in the flushed batch.
        boolean applyUpdate = context.currentSample >= context.batchSize;
        for (int i = 0; i < synCount; i++) {
            context.correctionBuffer[i] += lr * signals[i] * inputs[i];
        }
        if (applyUpdate) {
            syncWeights();
            context.currentSample = 0;
        }
        context.currentSample++;
        // NOTE(review): a partial batch left over at the end of an epoch is not
        // flushed here and carries into the next epoch — confirm this is intended.
    }

    // Applies the accumulated corrections to the weights and zeroes the buffer.
    private void syncWeights() {
        int[] synIndex = {0};
        context.model.forEachNeuron(n -> {
            for (int i = 0; i < n.synCount(); i++) {
                n.setWeight(i, n.getWeight(i) + context.correctionBuffer[synIndex[0]]);
                context.correctionBuffer[synIndex[0]] = 0f;
                synIndex[0]++;
            }
        });
    }
}

View File

@@ -0,0 +1,11 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
/**
 * Placeholder pipeline step: intentionally does nothing.
 * NOTE(review): batch accumulation currently lives in
 * BackpropagationCorrectionStep — either implement this stub or remove it.
 */
public class BatchAccumulatorStep implements AlgorithmStep {
    @Override
    public void run() {
    }
}

View File

@@ -0,0 +1,29 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
/**
 * Backpropagates error signals to hidden neurons: each neuron's signal is
 * f'(output) times the weighted sum of the signals of the neurons it feeds.
 */
public class ErrorSignalStep implements AlgorithmStep {
    private final GradientBackpropagationContext context;

    public ErrorSignalStep(GradientBackpropagationContext context) {
        this.context = context;
    }

    @Override
    public void run() {
        context.model.forEachNeuron(n -> {
            // Output-layer signals were already set by OutputLayerErrorStep.
            if (context.errorSignalsComputed[n.getId()]) return;
            int neuronIndex = context.model.indexInLayerOf(n);
            float[] signalSum = {0f};
            // Sum signal * incoming-weight over every downstream neuron fed by n;
            // neuronIndex selects the synapse of `connected` that n feeds into.
            context.model.forEachNeuronConnectedTo(n, connected -> {
                signalSum[0] += context.errorSignals[connected.getId()] * connected.getWeight(neuronIndex);
            });
            context.errorSignals[n.getId()] = n.getActivationFunction().derivative(n.getOutput()) * signalSum[0];
            context.errorSignalsComputed[n.getId()] = true;
            // NOTE(review): correctness relies on forEachNeuron visiting layers from
            // the output side backwards (downstream signals must already be computed
            // when read) — confirm the model's iteration order.
        });
    }
}

View File

@@ -0,0 +1,30 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import java.util.HashMap;
import java.util.Map;
/**
 * Training context for mini-batch gradient backpropagation.
 */
public class GradientBackpropagationContext extends TrainingContext {
    // Error signal per neuron id, recomputed for every sample.
    public final float[] errorSignals;
    // Accumulated weight corrections, one slot per synapse of the model.
    public final float[] correctionBuffer;
    // Marks which neurons already have their error signal for the current sample.
    public final boolean[] errorSignalsComputed;
    // 1-based counter of samples seen in the current mini-batch.
    public int currentSample;
    // Number of samples accumulated before the weights are updated.
    public int batchSize;

    public GradientBackpropagationContext(Model model, DataSet dataSet, float learningRate, int batchSize){
        super(model, dataSet);
        this.learningRate = learningRate;
        this.batchSize = batchSize;
        this.errorSignals = new float[model.neuronCount()];
        this.correctionBuffer = new float[model.synCount()];
        this.errorSignalsComputed = new boolean[model.neuronCount()];
        this.currentSample = 1;
    }
}

View File

@@ -0,0 +1,42 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import java.util.Arrays;
import java.util.List;
/**
 * Seeds backpropagation for the current sample: resets the per-neuron signal
 * state, then computes the delta and the error signal of every output neuron.
 */
public class OutputLayerErrorStep implements AlgorithmStep {
    private final GradientBackpropagationContext context;
    private final float[] expectations;

    public OutputLayerErrorStep(GradientBackpropagationContext context){
        this.context = context;
        this.expectations = new float[context.dataset.getNbrLabels()];
    }

    @Override
    public void run() {
        // Start the sample with a clean slate: no signal computed yet.
        Arrays.fill(context.errorSignals, 0f);
        Arrays.fill(context.errorSignalsComputed, false);

        // Snapshot the expected outputs of the current entry.
        List<Float> labels = context.dataset.getLabelsAsFloat(context.currentEntry);
        int pos = 0;
        for (Float label : labels) {
            expectations[pos++] = label;
        }

        int[] outIdx = {0};
        context.model.forEachOutputNeurons(n -> {
            int i = outIdx[0]++;
            float predicted = n.getOutput();
            float delta = expectations[i] - predicted;
            context.deltas[i] = delta;
            // Output-layer signal: delta * f'(output).
            context.errorSignals[n.getId()] = delta * n.getActivationFunction().derivative(predicted);
            context.errorSignalsComputed[n.getId()] = true;
        });
    }
}

View File

@@ -0,0 +1,23 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.model.neuron.Neuron;
/**
 * Logistic sigmoid activation with configurable steepness k:
 * f(x) = 1 / (1 + e^(-k*x)).
 */
public class Sigmoid implements ActivationFunction {
    // Never reassigned after construction — made final (was mutable).
    private final float steepness;

    public Sigmoid(float steepness) {
        this.steepness = steepness;
    }

    @Override
    public float accept(Neuron n) {
        return (float) (1.0/(1.0 + Math.exp(-steepness * n.calculateWeightedSum())));
    }

    /**
     * Derivative expressed in terms of the activation value y = f(x):
     * f'(x) = k * y * (1 - y).
     *
     * @param value the already-computed activation output, not the raw input
     */
    @Override
    public float derivative(float value) {
        return steepness * value * (1 - value);
    }
}

View File

@@ -0,0 +1,21 @@
package com.naaturel.ANN.implementation.multiLayers;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.model.neuron.Neuron;
/**
 * Hyperbolic tangent activation.
 */
public class TanH implements ActivationFunction {

    @Override
    public float accept(Neuron n) {
        //For educational purpose. Math.tanh() could have been used here
        float weightedSum = n.calculateWeightedSum();
        // Numerically stable form: the previous (e^x - 1/e^x)/(e^x + 1/e^x)
        // overflowed to Infinity/Infinity = NaN for large |x|. Using
        // e^(-2|x|) keeps the exponential within (0, 1].
        double e2 = Math.exp(-2.0 * Math.abs(weightedSum));
        double magnitude = (1.0 - e2) / (1.0 + e2);
        return (float) (weightedSum >= 0 ? magnitude : -magnitude);
    }

    /**
     * Derivative in terms of the activation value y = tanh(x): 1 - y^2.
     *
     * @param value the already-computed activation output, not the raw input
     */
    @Override
    public float derivative(float value) {
        return 1 - value * value;
    }
}

View File

@@ -1,32 +0,0 @@
package com.naaturel.ANN.implementation.neuron;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainable;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import java.util.List;
public class SimplePerceptron extends Neuron implements Trainable {
public SimplePerceptron(List<Synapse> synapses, Bias b, ActivationFunction func) {
super(synapses, b, func);
}
@Override
public float predict() {
return activationFunction.accept(this);
}
@Override
public float calculateWeightedSum() {
float res = 0;
for(Synapse syn : super.synapses){
res += syn.getWeight() * syn.getInput();
}
res += this.bias.getWeight() * this.bias.getInput();
return res;
}
}

View File

@@ -0,0 +1,24 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import javax.naming.OperationNotSupportedException;
/**
 * Heaviside step activation: 0 for strictly negative weighted sums, 1 otherwise.
 */
public class Heaviside implements ActivationFunction {

    public Heaviside(){
    }

    @Override
    public float accept(Neuron n) {
        // Threshold at zero; a weighted sum of exactly 0 activates the neuron.
        if (n.calculateWeightedSum() < 0) {
            return 0;
        }
        return 1;
    }

    /** The step function has no usable gradient. */
    @Override
    public float derivative(float value) {
        throw new UnsupportedOperationException("Heaviside is not differentiable");
    }
}

View File

@@ -0,0 +1,34 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Classic perceptron weight update: w += learningRate * delta * input for
 * every synapse, skipped entirely when the prediction is already correct.
 */
public class SimpleCorrectionStep implements AlgorithmStep {
    private final TrainingContext context;

    public SimpleCorrectionStep(TrainingContext context) {
        this.context = context;
    }

    @Override
    public void run() {
        // NOTE(review): if expectations/predictions are arrays, equals() is
        // reference equality and this guard almost never fires — confirm the
        // field types and consider an element-wise comparison (Arrays.equals).
        if(context.expectations.equals(context.predictions)) return;
        AtomicInteger neuronIndex = new AtomicInteger(0);
        context.model.forEachNeuron(neuron -> {
            // Each neuron is corrected with the delta of its own output slot.
            float correspondingDelta = context.deltas[neuronIndex.get()];
            for(int i = 0; i < neuron.synCount(); i++){
                float currentW = neuron.getWeight(i);
                float currentInput = neuron.getInput(i);
                float newValue = currentW + (context.learningRate * correspondingDelta * currentInput);
                neuron.setWeight(i, newValue);
            }
            neuronIndex.incrementAndGet();
        });
    }
}

View File

@@ -0,0 +1,32 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
 * Computes the per-output deltas of the current sample:
 * delta[i] = expected[i] - predicted[i].
 */
public class SimpleDeltaStep implements AlgorithmStep {
    private final TrainingContext context;

    public SimpleDeltaStep(TrainingContext context) {
        this.context = context;
    }

    @Override
    public void run() {
        DataSet dataSet = context.dataset;
        DataSetEntry entry = context.currentEntry;
        float[] predicted = context.predictions;
        List<Float> expected = dataSet.getLabelsAsFloat(entry);
        for (int i = 0; i < predicted.length; i++) {
            // Fix: compare each label with its own prediction. The previous code
            // used predicted[0] for every slot, producing wrong deltas for any
            // model with more than one output.
            context.deltas[i] = expected.get(i) - predicted[i];
        }
    }
}

View File

@@ -0,0 +1,18 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
/**
 * Adds the loss of the current sample (localLoss) into the epoch-wide
 * accumulator (globalLoss); the pipeline resets globalLoss before each epoch.
 */
public class SimpleErrorRegistrationStep implements AlgorithmStep {
    private final TrainingContext context;

    public SimpleErrorRegistrationStep(TrainingContext context) {
        this.context = context;
    }

    @Override
    public void run() {
        context.globalLoss += context.localLoss;
    }
}

View File

@@ -0,0 +1,21 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
/**
 * Perceptron loss of the current sample: the sum of absolute deltas.
 */
public class SimpleLossStrategy implements AlgorithmStep {
    private final SimpleTrainingContext context;

    public SimpleLossStrategy(SimpleTrainingContext context) {
        this.context = context;
    }

    @Override
    public void run() {
        float loss = 0f;
        for (float d : context.deltas) {
            // Fix: accumulate |d|, not the signed delta — opposite-signed deltas
            // previously cancelled out, so the loss could read zero (and the
            // pipeline's globalLoss == 0 stop condition fire) while individual
            // outputs were still wrong. Matches the legacy calculateLoss, which
            // returned Math.abs(delta).
            loss += Math.abs(d);
        }
        context.localLoss = loss;
    }
}

View File

@@ -0,0 +1,26 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.model.neuron.Input;
import java.util.List;
/**
 * Runs the model's forward pass on the current dataset entry and stores the
 * resulting predictions on the training context.
 */
public class SimplePredictionStep implements AlgorithmStep {
    private final TrainingContext context;

    public SimplePredictionStep(TrainingContext context) {
        this.context = context;
    }

    @Override
    public void run() {
        // Flatten the entry's Input objects into the primitive array the model expects.
        List<Input> data = context.currentEntry.getData();
        float[] flatData = new float[data.size()];
        int i = 0;
        for (Input input : data) {
            flatData[i++] = input.getValue();
        }
        context.predictions = context.model.predict(flatData);
    }
}

View File

@@ -0,0 +1,11 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
/**
 * Minimal training context for the simple perceptron rule; adds no state
 * beyond what the base TrainingContext already carries.
 */
public class SimpleTrainingContext extends TrainingContext {
    public SimpleTrainingContext(Model model, DataSet dataset) {
        super(model, dataset);
    }
}

View File

@@ -1,33 +1,60 @@
package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
import java.util.ArrayList;
import java.util.List;
public class AdalineTraining {
public class AdalineTraining implements Trainer {
public AdalineTraining(){
}
public void train(Neuron n, float learningRate, DataSet dataSet) {
@Override
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
AdalineTrainingContext context = new AdalineTrainingContext(model, dataset);
context.learningRate = learningRate;
List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context),
new SimpleDeltaStep(context),
new SquareLossStep(context),
new SimpleErrorRegistrationStep(context),
new SimpleCorrectionStep(context)
);
new TrainingPipeline(steps)
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
.afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
.withTimeMeasurement(true)
.withVerbose(true, 1)
.withVisualization(true, new GraphVisualizer())
.run(context);
}
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1;
int maxEpoch = 1000;
int maxEpoch = 202;
float errorThreshold = 0.0F;
float mse;
do {
if(epoch > maxEpoch) break;
mse = 0;
for(DataSetEntry entry : dataSet) {
for(DataSetEntry entry : dataSet) {
this.updateInputs(n, entry);
float prediction = n.predict();
float expectation = dataSet.getLabel(entry).getValue();
@@ -49,23 +76,22 @@ public class AdalineTraining {
System.out.printf("predicted : %.2f, ", prediction);
System.out.printf("expected : %.2f, ", expectation);
System.out.printf("delta : %.2f, ", delta);
System.out.printf("loss : %.2f\n", loss);
System.out.printf("loss : %.5f\n", loss);
}
mse /= dataSet.size();
System.out.printf("[Total error : %f]\n", mse);
System.out.println("[Final weights]");
System.out.printf("Bias: %f\n", n.getBias().getWeight());
int i = 1;
for(Synapse syn : n.getSynapses()){
System.out.printf("Syn %d: %f\n", i, syn.getWeight());
i++;
}
epoch++;
} while(mse > errorThreshold);
}
private List<Float> initCorrectorTerms(int number){
List<Float> res = new ArrayList<>();
for(int i = 0; i < number; i++){
res.add(0F);
}
return res;
}
private void updateInputs(Neuron n, DataSetEntry entry){
int index = 0;
for(float value : entry){
@@ -80,10 +106,6 @@ public class AdalineTraining {
private float calculateLoss(float delta){
return (float) Math.pow(delta, 2)/2;
}
private float calculateWeightCorrection(float value, float delta){
return value * delta;
}
}*/
}

View File

@@ -0,0 +1,42 @@
package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep;
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep;
import com.naaturel.ANN.implementation.multiLayers.OutputLayerErrorStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import java.util.List;
public class GradientBackpropagationTraining implements Trainer {
    /** Number of samples accumulated before weight corrections are applied. */
    private static final int DEFAULT_BATCH_SIZE = 10;

    /**
     * Trains the model with mini-batch gradient backpropagation.
     *
     * @param learningRate step size applied to every weight correction
     * @param epoch        maximum number of epochs to run
     * @param model        network to train (updated in place)
     * @param dataset      training samples with their expected labels
     */
    @Override
    public void train(float learningRate, int epoch, Model model, DataSet dataset) {
        GradientBackpropagationContext context =
            new GradientBackpropagationContext(model, dataset, learningRate, DEFAULT_BATCH_SIZE);
        List<AlgorithmStep> steps = List.of(
            new SimplePredictionStep(context),
            new OutputLayerErrorStep(context),
            new ErrorSignalStep(context),
            new BackpropagationCorrectionStep(context),
            new SquareLossStep(context)
        );
        // Fix: epoch/10 is 0 when epoch < 10; clamp so the verbose interval is
        // always at least 1.
        int verboseInterval = Math.max(1, epoch / 10);
        new TrainingPipeline(steps)
            .stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
            .beforeEpoch(ctx -> {
                ctx.globalLoss = 0.0F;
            })
            .afterEpoch(ctx -> {
                // Report the mean loss per sample rather than the raw epoch sum.
                ctx.globalLoss /= dataset.size();
            })
            .withVerbose(true, verboseInterval)
            .withTimeMeasurement(true)
            .run(context);
    }
}

View File

@@ -1,26 +1,64 @@
package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
import java.util.ArrayList;
import java.util.List;
public class GradientDescentTraining {
public class GradientDescentTraining implements Trainer {
public GradientDescentTraining(){
}
public void train(Neuron n, float learningRate, DataSet dataSet) {
@Override
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
GradientDescentTrainingContext context = new GradientDescentTrainingContext(model, dataset);
context.learningRate = learningRate;
context.correctorTerms = new ArrayList<>();
List<AlgorithmStep> steps = List.of(
new SimplePredictionStep(context),
new SimpleDeltaStep(context),
new SquareLossStep(context),
new GradientDescentErrorStrategy(context)
);
new TrainingPipeline(steps)
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > epoch)
.beforeEpoch(ctx -> {
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
gdCtx.globalLoss = 0.0F;
gdCtx.correctorTerms.clear();
for(int i = 0; i < gdCtx.model.synCount(); i++){
gdCtx.correctorTerms.add(0F);
}
})
.afterEpoch(ctx -> {
context.globalLoss /= context.dataset.size();
new GradientDescentCorrectionStrategy(context).run();
})
//.withVerbose(true)
.withTimeMeasurement(true)
.withVisualization(true, new GraphVisualizer())
.run(context);
}
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1;
int maxEpoch = 1000;
float errorThreshold = 0.0F;
int maxEpoch = 402;
float errorThreshold = 0F;
float mse;
do {
@@ -54,6 +92,7 @@ public class GradientDescentTraining {
System.out.printf("delta : %.2f, ", delta);
System.out.printf("loss : %.2f\n", loss);
}
mse /= dataSet.size();
System.out.printf("[Total error : %f]\n", mse);
float currentBias = n.getBias().getWeight();
@@ -69,6 +108,13 @@ public class GradientDescentTraining {
epoch++;
} while(mse > errorThreshold);
System.out.println("[Final weights]");
System.out.printf("Bias: %f\n", n.getBias().getWeight());
int i = 1;
for(Synapse syn : n.getSynapses()){
System.out.printf("Syn %d: %f\n", i, syn.getWeight());
i++;
}
}
private List<Float> initCorrectorTerms(int number){
@@ -95,8 +141,21 @@ public class GradientDescentTraining {
return (float) Math.pow(delta, 2)/2;
}
private float calculateWeightCorrection(float value, float delta){
return value * delta;
}
public float computeThreshold(DataSet dataSet) {
float sum = 0;
for (DataSetEntry entry : dataSet) {
sum += dataSet.getLabel(entry).getValue();
}
float mean = sum / dataSet.size();
float variance = 0;
for (DataSetEntry entry : dataSet) {
float diff = dataSet.getLabel(entry).getValue() - mean;
variance += diff * diff;
}
variance /= dataSet.size();
return variance;
}*/
}

View File

@@ -1,19 +1,44 @@
package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
public class SimpleTraining {
import java.util.List;
public class SimpleTraining implements Trainer {
public SimpleTraining() {
}
public void train(Neuron n, float learningRate, DataSet dataSet) {
@Override
/**
 * Trains the model with the classic perceptron learning rule.
 *
 * @param learningRate step size applied to weight corrections
 * @param epoch        maximum number of epochs to run
 * @param model        network to train (updated in place)
 * @param dataset      training samples with their expected labels
 */
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
    // The context already receives model and dataset through its constructor;
    // the previous re-assignment of context.dataset / context.model was
    // redundant (no other trainer does it).
    SimpleTrainingContext context = new SimpleTrainingContext(model, dataset);
    context.learningRate = learningRate;
    List<AlgorithmStep> steps = List.of(
        new SimplePredictionStep(context),
        new SimpleDeltaStep(context),
        new SimpleLossStrategy(context),
        new SimpleErrorRegistrationStep(context),
        new SimpleCorrectionStep(context)
    );
    TrainingPipeline pipeline = new TrainingPipeline(steps);
    pipeline
        .stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > epoch)
        .beforeEpoch(ctx -> ctx.globalLoss = 0)
        .withVerbose(true, 1)
        .run(context);
}
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1;
int errorCount;
@@ -64,5 +89,5 @@ public class SimpleTraining {
private float calculateLoss(float delta){
return Math.abs(delta);
}
*/
}

View File

@@ -0,0 +1,54 @@
package com.naaturel.ANN.infrastructure.config;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import java.util.Map;
/**
 * JSON configuration holder with three sections (model / training / dataset),
 * each a free-form key-value map, plus typed accessors.
 */
public class ConfigDto {
    @JsonProperty("model")
    private Map<String, Object> modelConfig;
    @JsonProperty("training")
    private Map<String, Object> trainingConfig;
    @JsonProperty("dataset")
    private Map<String, Object> datasetConfig;

    /**
     * Returns a value from the "model" section cast to the requested type.
     * JSON arrays arrive from Jackson as List&lt;Number&gt;; when an array type
     * is requested they are coerced into an int[].
     *
     * @throws RuntimeException if the key is absent or the value has the wrong type
     */
    public <T> T getModelProperty(String key, Class<T> type) {
        Object value = find(key, this.modelConfig);
        if (value instanceof List<?> list && type.isArray()) {
            int[] arr = new int[list.size()];
            for (int i = 0; i < list.size(); i++) {
                arr[i] = ((Number) list.get(i)).intValue();
            }
            return type.cast(arr);
        }
        return castOrThrow(key, value, type);
    }

    /** Returns a value from the "training" section cast to the requested type. */
    public <T> T getTrainingProperty(String key, Class<T> type) {
        return castOrThrow(key, find(key, this.trainingConfig), type);
    }

    /** Returns a value from the "dataset" section cast to the requested type. */
    public <T> T getDatasetProperty(String key, Class<T> type) {
        return castOrThrow(key, find(key, this.datasetConfig), type);
    }

    // Shared type check + cast, extracted from the three getters (was triplicated).
    private <T> T castOrThrow(String key, Object value, Class<T> type) {
        if (!type.isInstance(value)) {
            throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
        }
        return type.cast(value);
    }

    // Looks a key up in one section, failing loudly when it is missing.
    private Object find(String key, Map<String, Object> config){
        if(!config.containsKey(key)) throw new RuntimeException("Unable to find property for key '" + key + "'");
        return config.get(key);
    }
}

View File

@@ -0,0 +1,22 @@
package com.naaturel.ANN.infrastructure.config;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.File;
/**
 * Loads a {@link ConfigDto} from a JSON file on disk.
 */
public class ConfigLoader {
    /**
     * @param path the JSON configuration file to read
     * @throws Exception if the file cannot be read or parsed
     */
    public static ConfigDto load(String path) throws Exception {
        try {
            ObjectMapper mapper = new ObjectMapper();
            // Fix: use the path argument — the previous code ignored it and
            // always read the hard-coded "config.json".
            return mapper.readValue(new File(path), ConfigDto.class);
        } catch (Exception e){
            // Keep the original exception as the cause for debuggability.
            throw new Exception("Unable to load config : " + e.getMessage(), e);
        }
    }
}

View File

@@ -0,0 +1,89 @@
package com.naaturel.ANN.infrastructure.dataset;
import com.naaturel.ANN.domain.model.neuron.Input;
import java.util.*;
import java.util.stream.Stream;
/**
 * In-memory training dataset mapping each entry (the inputs) to its labels.
 * Iteration order is the backing map's order; the default constructor uses a
 * LinkedHashMap so iteration follows insertion order.
 */
public class DataSet implements Iterable<DataSetEntry>{
    private final Map<DataSetEntry, Labels> data;
    private final int nbrInputs;
    private final int nbrLabels;

    public DataSet() {
        this(new LinkedHashMap<>()); //ensure iteration order is the same as insertion order
    }

    public DataSet(Map<DataSetEntry, Labels> data){
        this.data = data;
        this.nbrInputs = this.calculateNbrInput();
        this.nbrLabels = this.calculateNbrLabel();
    }

    // Number of input values per entry, derived from the first entry (0 if empty).
    private int calculateNbrInput(){
        //assumes every entry are the same length
        Stream<DataSetEntry> keyStream = this.data.keySet().stream();
        Optional<DataSetEntry> firstEntry = keyStream.findFirst();
        return firstEntry.map(inputs -> inputs.getData().size()).orElse(0);
    }

    // Number of label values per entry, derived from the first entry (0 if empty).
    private int calculateNbrLabel(){
        //assumes every label are the same length
        Stream<DataSetEntry> keyStream = this.data.keySet().stream();
        Optional<DataSetEntry> firstEntry = keyStream.findFirst();
        return firstEntry.map(inputs -> this.data.get(inputs).getValues().size()).orElse(0);
    }

    public int size() {
        return data.size();
    }

    public int getNbrInputs() {
        return this.nbrInputs;
    }

    public int getNbrLabels(){
        return this.nbrLabels;
    }

    /** Returns the entries as a fresh list; mutating it does not affect the dataset. */
    public List<DataSetEntry> getData(){
        return new ArrayList<>(this.data.keySet());
    }

    public List<Float> getLabelsAsFloat(DataSetEntry entry){
        return this.data.get(entry).getValues();
    }

    /**
     * Returns a copy of this dataset with every input divided by the largest
     * absolute input value (rounded to two decimals); labels are shared with
     * the original entries.
     */
    public DataSet toNormalized() {
        List<DataSetEntry> entries = this.getData();
        float maxAbs = entries.stream()
            .flatMap(e -> e.getData().stream())
            .map(Input::getValue)
            .map(Math::abs)
            .max(Float::compare)
            .orElse(1.0F);
        // Fix: use a LinkedHashMap (was HashMap) so the normalized dataset keeps
        // the insertion/iteration order the rest of the class relies on.
        Map<DataSetEntry, Labels> normalized = new LinkedHashMap<>();
        for (DataSetEntry entry : entries) {
            List<Input> normalizedData = new ArrayList<>();
            for (Input input : entry.getData()) {
                Input normalizedInput = new Input(Math.round((input.getValue() / maxAbs) * 100.0F) / 100.0F);
                normalizedData.add(normalizedInput);
            }
            normalized.put(new DataSetEntry(normalizedData), this.data.get(entry));
        }
        return new DataSet(normalized);
    }

    @Override
    public Iterator<DataSetEntry> iterator() {
        return this.data.keySet().iterator();
    }
}

View File

@@ -1,16 +1,18 @@
package com.naaturel.ANN.domain.model.dataset;
package com.naaturel.ANN.infrastructure.dataset;
import com.naaturel.ANN.domain.model.neuron.Input;
import java.util.*;
public class DataSetEntry implements Iterable<Float> {
public class DataSetEntry implements Iterable<Input> {
private List<Float> data;
private List<Input> data;
public DataSetEntry(List<Float> data){
public DataSetEntry(List<Input> data){
this.data = data;
}
public List<Float> getData() {
public List<Input> getData() {
return new ArrayList<>(data);
}
@@ -28,7 +30,7 @@ public class DataSetEntry implements Iterable<Float> {
}
@Override
public Iterator<Float> iterator() {
public Iterator<Input> iterator() {
return this.data.iterator();
}

View File

@@ -0,0 +1,43 @@
package com.naaturel.ANN.infrastructure.dataset;
import com.naaturel.ANN.domain.model.neuron.Input;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.*;
/**
 * Builds a {@link DataSet} from a CSV file: each row's last nbrLabels columns
 * are the labels, the preceding columns the inputs.
 */
public class DatasetExtractor {
    /**
     * @param path      CSV file to read
     * @param nbrLabels number of trailing columns to treat as labels
     * @throws RuntimeException wrapping any I/O failure
     */
    public DataSet extract(String path, int nbrLabels) {
        Map<DataSetEntry, Labels> data = new LinkedHashMap<>();
        // Read as UTF-8 explicitly: the one-arg FileReader used the platform
        // default charset, which varies across machines (pre-Java 18).
        try (BufferedReader reader = new BufferedReader(
                new FileReader(path, java.nio.charset.StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                // Tolerate blank/trailing lines instead of failing in parseFloat.
                if (line.isBlank()) continue;
                String[] parts = line.split(",");
                String[] rawInputs = Arrays.copyOfRange(parts, 0, parts.length-nbrLabels);
                String[] rawLabels = Arrays.copyOfRange(parts, parts.length-nbrLabels, parts.length);
                List<Input> inputs = new ArrayList<>();
                List<Float> labels = new ArrayList<>();
                for (String entry : rawInputs) {
                    inputs.add(new Input(Float.parseFloat(entry.trim())));
                }
                for (String entry : rawLabels) {
                    labels.add(Float.parseFloat(entry.trim()));
                }
                data.put(new DataSetEntry(inputs), new Labels(labels));
            }
        } catch (IOException e) {
            throw new RuntimeException("Failed to read dataset from: " + path, e);
        }
        return new DataSet(data);
    }
}

View File

@@ -0,0 +1,16 @@
package com.naaturel.ANN.infrastructure.dataset;
import java.util.List;
/**
 * Immutable holder for the label values of one dataset entry.
 */
public class Labels {
    private final List<Float> values;

    public Labels(List<Float> value){
        // Fix: take a defensive, unmodifiable copy — previously the caller's
        // list was stored by reference, so external mutation leaked in.
        this.values = value.stream().toList();
    }

    /** Returns the labels as an unmodifiable list. */
    public List<Float> getValues() {
        return values;
    }
}

View File

@@ -0,0 +1,11 @@
package com.naaturel.ANN.infrastructure.persistence;
import java.util.List;
/**
 * Serialization DTO for a persisted model.
 * NOTE(review): the field is never read or exposed yet — placeholder awaiting
 * accessors or Jackson mapping.
 */
public class ModelDto {
    private List<NeuronDto> neurons;
}

View File

@@ -0,0 +1,81 @@
package com.naaturel.ANN.infrastructure.persistence;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import java.io.File;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
 * Persists a model's weights to a JSON file and rebuilds a model from one.
 */
public class ModelSnapshot {
    private Model model;
    private final ObjectMapper mapper;

    public ModelSnapshot(){
        this(null);
    }

    public ModelSnapshot(Model model){
        this.model = model;
        mapper = new ObjectMapper();
    }

    public Model getModel() {
        return model;
    }

    /**
     * Writes every neuron as {id, layerIndex, weights[]} into a JSON array.
     */
    public void saveToFile(String path) throws Exception {
        ArrayNode root = mapper.createArrayNode();
        model.forEachNeuron(n -> {
            ObjectNode neuronNode = mapper.createObjectNode();
            neuronNode.put("id", n.getId());
            neuronNode.put("layerIndex", model.layerIndexOf(n));
            ArrayNode weights = mapper.createArrayNode();
            for (int i = 0; i < n.synCount(); i++) {
                float weight = n.getWeight(i);
                weights.add(weight);
            }
            neuronNode.set("weights", weights);
            root.add(neuronNode);
        });
        mapper.writerWithDefaultPrettyPrinter().writeValue(new File(path), root);
    }

    /**
     * Rebuilds a FullyConnectedNetwork from a file written by saveToFile.
     *
     * NOTE(review): weights[0] is interpreted as the bias here, so saveToFile's
     * getWeight(0) must also be the bias weight — confirm Neuron's weight indexing.
     * NOTE(review): the activation function is not persisted; every loaded neuron
     * gets a TanH regardless of what the saved model used — confirm intended.
     */
    public void loadFromFile(String path) throws Exception {
        ArrayNode root = (ArrayNode) mapper.readTree(new File(path));
        // Group neurons by layer index, preserving file order within each layer.
        Map<Integer, List<Neuron>> neuronsByLayer = new LinkedHashMap<>();
        root.forEach(neuronNode -> {
            int id = neuronNode.get("id").asInt();
            int layerIndex = neuronNode.get("layerIndex").asInt();
            ArrayNode weightsNode = (ArrayNode) neuronNode.get("weights");
            Bias bias = new Bias(new Weight(weightsNode.get(0).floatValue()));
            Synapse[] synapses = new Synapse[weightsNode.size() - 1];
            for (int i = 0; i < synapses.length; i++) {
                synapses[i] = new Synapse(new Input(0), new Weight(weightsNode.get(i + 1).floatValue()));
            }
            Neuron n = new Neuron(id, synapses, bias, new TanH());
            neuronsByLayer.computeIfAbsent(layerIndex, k -> new ArrayList<>()).add(n);
        });
        Layer[] layers = neuronsByLayer.values().stream()
            .map(neurons -> new Layer(neurons.toArray(new Neuron[0])))
            .toArray(Layer[]::new);
        this.model = new FullyConnectedNetwork(layers);
    }
}

View File

@@ -0,0 +1,4 @@
package com.naaturel.ANN.infrastructure.persistence;
public class NeuronDto {
}

View File

@@ -0,0 +1,59 @@
package com.naaturel.ANN.infrastructure.visualization;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.XYPlot;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import javax.swing.*;
/**
 * Thin JFreeChart wrapper used to plot training curves and predictions.
 */
public class GraphVisualizer {
    XYSeriesCollection dataset;

    public GraphVisualizer(){
        this.dataset = new XYSeriesCollection();
    }

    /** Adds (x, y) to the series named title, creating the series on first use. */
    public void addPoint(String title, float x, float y) {
        if (this.dataset.getSeriesIndex(title) == -1)
            this.dataset.addSeries(new XYSeries(title));
        this.dataset.getSeries(title).add(x, y);
    }

    /**
     * Plots the line y1*x + y2*y + k = 0 (e.g. a decision boundary) by sampling
     * x over [xMin, xMax] and solving for the second coordinate.
     */
    public void addEquation(String title, float y1, float y2, float k, float xMin, float xMax) {
        for (float x1 = xMin; x1 <= xMax; x1 += 0.01f) {
            float x2 = (-y1 * x1 - k) / y2;
            addPoint(title, x1, x2);
        }
    }

    /** Opens a window showing all collected series as a line chart. */
    public void buildLineGraph(){
        JFreeChart chart = ChartFactory.createXYLineChart(
            "Model learning", "X", "Y", dataset
        );
        JFrame frame = new JFrame("Training Loss");
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.add(new ChartPanel(chart));
        frame.pack();
        frame.setVisible(true);
    }

    /** Opens a window showing all series as a scatter plot, both axes clamped to [lower, upper]. */
    public void buildScatterGraph(int lower, int upper){
        JFreeChart chart = ChartFactory.createScatterPlot(
            "Predictions", "X", "Y", dataset
        );
        XYPlot plot = chart.getXYPlot();
        plot.getDomainAxis().setRange(lower, upper);
        plot.getRangeAxis().setRange(lower, upper);
        JFrame frame = new JFrame("Predictions");
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.add(new ChartPanel(chart));
        frame.pack();
        frame.setVisible(true);
    }
}

View File

@@ -0,0 +1,114 @@
package com.naaturel.ANN.infrastructure.visualization;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import javax.swing.*;
import java.awt.*;
import java.util.*;
import java.util.List;
/**
 * Renders a network as a Swing diagram: one column of circles per layer,
 * light-gray edges between connected neurons, optionally labelled with
 * the connection weight.
 */
public class ModelVisualizer {
    private final JFrame frame;
    private final Model model;
    private boolean withWeights;

    /** Creates a visualizer for the given model; weight labels are off by default. */
    public ModelVisualizer(Model model) {
        this.frame = initFrame();
        this.model = model;
        this.withWeights = false;
    }

    /** Fluent toggle for drawing each connection's weight at the edge midpoint. */
    public ModelVisualizer withWeights(boolean value) {
        this.withWeights = value;
        return this;
    }

    /** Attaches the drawing panel to the window and shows it. */
    public void display() {
        frame.add(buildPanel());
        frame.setVisible(true);
    }

    private JFrame initFrame() {
        JFrame window = new JFrame("Model Visualizer");
        window.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
        window.setSize(800, 600);
        return window;
    }

    /**
     * Assigns a screen position to every neuron id: layers are spread evenly
     * across the width, and each layer's neurons evenly down its column.
     */
    private Map<Integer, Point> computeNeuronPositions(int width, int height) {
        Map<Integer, List<Neuron>> byLayer = new LinkedHashMap<>();
        model.forEachNeuron(neuron ->
                byLayer.computeIfAbsent(model.layerIndexOf(neuron), k -> new ArrayList<>())
                        .add(neuron));
        Map<Integer, Point> positions = new HashMap<>();
        int columnSpacing = width / (byLayer.size() + 1);
        byLayer.forEach((layerIndex, neurons) -> {
            int x = columnSpacing * (layerIndex + 1);
            int rowSpacing = height / (neurons.size() + 1);
            for (int row = 0; row < neurons.size(); row++) {
                positions.put(neurons.get(row).getId(), new Point(x, rowSpacing * (row + 1)));
            }
        });
        return positions;
    }

    /** Draws an edge from each neuron to every neuron it feeds into. */
    private void drawConnections(Graphics2D g2, Map<Integer, Point> positions) {
        model.forEachNeuron(source -> {
            Point from = positions.get(source.getId());
            int indexInLayer = model.indexInLayerOf(source);
            model.forEachNeuronConnectedTo(source, target -> {
                Point to = positions.get(target.getId());
                g2.setColor(Color.LIGHT_GRAY);
                g2.drawLine(from.x, from.y, to.x, to.y);
                if (this.withWeights) {
                    // Index 0 appears to hold the bias weight (matches the persisted
                    // layout), so this source's synapse sits at indexInLayer + 1.
                    float weight = target.getWeight(indexInLayer + 1);
                    g2.setColor(Color.DARK_GRAY);
                    g2.drawString(String.format("%.2f", weight),
                            (from.x + to.x) / 2, (from.y + to.y) / 2);
                }
            });
        });
    }

    /** Draws each neuron as a white circle with its id centered inside. */
    private void drawNeurons(Graphics2D g2, Map<Integer, Point> positions, int neuronRadius) {
        int diameter = neuronRadius * 2;
        model.forEachNeuron(neuron -> {
            Point center = positions.get(neuron.getId());
            g2.setColor(Color.WHITE);
            g2.fillOval(center.x - neuronRadius, center.y - neuronRadius, diameter, diameter);
            g2.setColor(Color.BLACK);
            g2.drawOval(center.x - neuronRadius, center.y - neuronRadius, diameter, diameter);
            g2.drawString(String.valueOf(neuron.getId()), center.x - 5, center.y + 5);
        });
    }

    /** Builds the panel that repaints the whole diagram on every paint pass. */
    private JPanel buildPanel() {
        final int neuronRadius = 20;
        return new JPanel() {
            @Override
            protected void paintComponent(Graphics g) {
                super.paintComponent(g);
                Graphics2D g2 = (Graphics2D) g;
                g2.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
                Map<Integer, Point> positions = computeNeuronPositions(getWidth(), getHeight());
                // Edges first so the neuron circles are painted on top of them.
                drawConnections(g2, positions);
                drawNeurons(g2, positions, neuronRadius);
            }
        };
    }
}

View File

@@ -0,0 +1,4 @@
0,0,-1
0,1,-1
1,0,-1
1,1,1
1 0 0 -1
2 0 1 -1
3 1 0 -1
4 1 1 1

View File

@@ -18,4 +18,4 @@
4,6,-1
4,7,-1
4,9,1
4,10,1
4,10,1
1 1 6 1
18 4 6 -1
19 4 7 -1
20 4 9 1
21 4 10 1

View File

@@ -0,0 +1,4 @@
0,0,0
0,1,1
1,0,1
1,1,0
1 0 0 0
2 0 1 1
3 1 0 1
4 1 1 0

View File

@@ -0,0 +1,37 @@
[ {
"id" : 0,
"layerIndex" : 0,
"weights" : [ 0.73659766, -0.77503574, -0.23966503, 0.5794077, -0.1634599, 0.6023351, 0.770439, -0.53199804, -0.75792193, 0.36339867, 0.017356396, -0.3005041, -0.5709046, 0.8757278, 0.22738981, -0.22997773, 0.81583726, -0.6008339, 0.8148706, -0.27952576, 0.11130214, 0.21506107, -0.96409214, 0.8534266, 0.9998076, 0.8971077, -0.55812895, 0.28677964, -0.4315225, -0.12088442, 0.41834033, -0.83330417, 0.013990879, -0.021193504, 0.95400894, 0.24115086, 0.122039676, 0.7069808, 0.74929047, 0.2558391, 0.1307528, 0.45781684, -0.19833839 ]
}, {
"id" : 1,
"layerIndex" : 0,
"weights" : [ -0.25173664, -0.15121317, 0.3947369, 0.39584184, 0.8823843, -0.43822396, 0.38901758, -0.83357096, 0.4349022, -0.57956433, -0.78882015, -0.7952547, 0.87495327, 0.20618606, -0.802194, -0.108419776, 0.47257674, -0.7384255, 0.41351187, -0.6334251, -0.61948025, 0.40434432, -0.18576205, -0.47782838, 0.57454073, -0.36851442, -0.7262291, 0.3893906, 0.83334255, -0.6979947, 0.43623865, 0.5753021, 0.0041633844, 0.6598717, 0.21344411, -0.40663266, 0.73282886, -0.00848031, -0.7269838, -0.36129093, -0.7280016, -0.0039184093, -0.71608007 ]
}, {
"id" : 2,
"layerIndex" : 1,
"weights" : [ 0.0760473, -0.54335964, -0.789695 ]
}, {
"id" : 3,
"layerIndex" : 1,
"weights" : [ -0.16066408, 0.97240365, 0.72418904 ]
}, {
"id" : 4,
"layerIndex" : 1,
"weights" : [ 0.58718896, -0.59475064, 0.81476605 ]
}, {
"id" : 5,
"layerIndex" : 1,
"weights" : [ -0.09139645, 0.71847045, -0.19625723 ]
}, {
"id" : 6,
"layerIndex" : 2,
"weights" : [ 0.60579014, 0.7003196, 0.82173157, 0.7627305, 0.83753014 ]
}, {
"id" : 7,
"layerIndex" : 2,
"weights" : [ -0.7915581, -0.5355048, -1.0218064, 0.35036424, -0.6445867 ]
}, {
"id" : 8,
"layerIndex" : 3,
"weights" : [ -0.24528235, 0.08879423, -0.46046266 ]
} ]

View File

@@ -0,0 +1,92 @@
package adaline;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
import com.naaturel.ANN.implementation.gradientDescent.*;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
import com.naaturel.ANN.implementation.training.steps.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
 * End-to-end test of the Adaline pipeline on the AND dataset:
 * a single linear neuron trained until the mean squared loss drops
 * below the threshold.
 */
public class AdalineTest {
    // Resolved relative to the project root instead of a machine-specific
    // absolute path, so the test runs on any checkout.
    private static final String DATASET_PATH = "src/main/resources/assets/and-gradient.csv";

    private DataSet dataset;
    private AdalineTrainingContext context;
    private Bias bias;
    private FullyConnectedNetwork network;
    private TrainingPipeline pipeline;

    @BeforeEach
    public void init() {
        dataset = new DatasetExtractor().extract(DATASET_PATH, 1);
        // Classic Adaline: one linear neuron, two zero-initialised inputs plus bias.
        List<Synapse> syns = new ArrayList<>();
        syns.add(new Synapse(new Input(0), new Weight(0)));
        syns.add(new Synapse(new Input(0), new Weight(0)));
        bias = new Bias(new Weight(0));
        Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
        Layer layer = new Layer(List.of(neuron));
        network = new FullyConnectedNetwork(List.of(layer));
        context = new AdalineTrainingContext();
        context.dataset = dataset;
        context.model = network;
        List<TrainingStep> steps = List.of(
                new PredictionStep(new SimplePredictionStep(context)),
                new DeltaStep(new SimpleDeltaStep(context)),
                new LossStep(new SquareLossStep(context)),
                new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
                new WeightCorrectionStep(new SimpleCorrectionStep(context))
        );
        pipeline = new TrainingPipeline(steps)
                // Stop once the loss is low enough, or bail out after 10k epochs.
                .stopCondition(ctx -> ctx.globalLoss <= 0.1329F || ctx.epoch > 10000)
                .beforeEpoch(ctx -> ctx.globalLoss = 0.0F);
    }

    @Test
    public void test_the_whole_algorithm() {
        List<Float> expectedGlobalLosses = List.of(
                0.501522F,
                0.498601F
        );
        context.learningRate = 0.03F;
        pipeline.afterEpoch(ctx -> {
            // Normalise the accumulated loss into a per-sample mean.
            ctx.globalLoss /= context.dataset.size();
            int index = ctx.epoch - 1;
            if (index >= expectedGlobalLosses.size()) return;
            // TODO(review): re-enable once the expected per-epoch losses are confirmed.
            //assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f);
        });
        pipeline.run(context);
        // Convergence takes a known number of epochs with this seed-free setup.
        assertEquals(214, context.epoch);
    }
}

View File

@@ -0,0 +1,98 @@
package gradientDescent;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.*;
import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.implementation.training.steps.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
/**
 * End-to-end test of batch gradient descent on the AND dataset:
 * corrector terms are accumulated over each epoch and applied afterwards.
 */
public class GradientDescentTest {
    // Resolved relative to the project root instead of a machine-specific
    // absolute path, so the test runs on any checkout.
    private static final String DATASET_PATH = "src/main/resources/assets/and-gradient.csv";

    private DataSet dataset;
    private GradientDescentTrainingContext context;
    private Bias bias;
    private FullyConnectedNetwork network;
    private TrainingPipeline pipeline;

    @BeforeEach
    public void init() {
        dataset = new DatasetExtractor().extract(DATASET_PATH, 1);
        // One linear neuron with two zero-initialised inputs plus bias.
        List<Synapse> syns = new ArrayList<>();
        syns.add(new Synapse(new Input(0), new Weight(0)));
        syns.add(new Synapse(new Input(0), new Weight(0)));
        bias = new Bias(new Weight(0));
        Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
        Layer layer = new Layer(List.of(neuron));
        network = new FullyConnectedNetwork(List.of(layer));
        context = new GradientDescentTrainingContext();
        context.dataset = dataset;
        context.model = network;
        context.correctorTerms = new ArrayList<>();
        // No WeightCorrectionStep here: batch correction runs in afterEpoch instead.
        List<TrainingStep> steps = List.of(
                new PredictionStep(new SimplePredictionStep(context)),
                new DeltaStep(new SimpleDeltaStep(context)),
                new LossStep(new SquareLossStep(context)),
                new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
        );
        pipeline = new TrainingPipeline(steps)
                .stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 100)
                .beforeEpoch(ctx -> {
                    GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
                    gdCtx.globalLoss = 0.0F;
                    // Reset one accumulated corrector term per synapse.
                    gdCtx.correctorTerms.clear();
                    for (int i = 0; i < ctx.model.synCount(); i++) {
                        gdCtx.correctorTerms.add(0F);
                    }
                });
    }

    @Test
    public void test_the_whole_algorithm() {
        List<Float> expectedGlobalLosses = List.of(
                0.5F,
                0.38F,
                0.3176F,
                0.272096F,
                0.237469F
        );
        context.learningRate = 0.2F;
        pipeline.afterEpoch(ctx -> {
            // Normalise into a per-sample mean, then apply the batch correction.
            context.globalLoss /= context.dataset.size();
            new GradientDescentCorrectionStrategy(context).run();
            int index = ctx.epoch - 1;
            if (index >= expectedGlobalLosses.size()) return;
            assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f);
        });
        pipeline
                .withVerbose(true)
                .run(context);
        assertEquals(67, context.epoch);
    }
}

View File

@@ -0,0 +1,84 @@
package perceptron;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.implementation.training.steps.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
/**
 * End-to-end test of the simple perceptron (Heaviside activation) on the
 * AND dataset, checking the per-epoch global loss trajectory.
 */
public class SimplePerceptronTest {
    // Resolved relative to the project root instead of a machine-specific
    // absolute path, so the test runs on any checkout.
    private static final String DATASET_PATH = "src/main/resources/assets/and.csv";

    private DataSet dataset;
    private SimpleTrainingContext context;
    private Bias bias;
    private FullyConnectedNetwork network;
    private TrainingPipeline pipeline;

    @BeforeEach
    public void init() {
        dataset = new DatasetExtractor().extract(DATASET_PATH, 1);
        // One Heaviside neuron with two zero-initialised inputs plus bias.
        List<Synapse> syns = new ArrayList<>();
        syns.add(new Synapse(new Input(0), new Weight(0)));
        syns.add(new Synapse(new Input(0), new Weight(0)));
        bias = new Bias(new Weight(0));
        Neuron neuron = new Neuron(syns, bias, new Heaviside());
        Layer layer = new Layer(List.of(neuron));
        network = new FullyConnectedNetwork(List.of(layer));
        context = new SimpleTrainingContext();
        context.dataset = dataset;
        context.model = network;
        List<TrainingStep> steps = List.of(
                new PredictionStep(new SimplePredictionStep(context)),
                new DeltaStep(new SimpleDeltaStep(context)),
                new LossStep(new SimpleLossStrategy(context)),
                new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
                new WeightCorrectionStep(new SimpleCorrectionStep(context))
        );
        pipeline = new TrainingPipeline(steps);
        pipeline.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100);
        pipeline.beforeEpoch(ctx -> ctx.globalLoss = 0);
    }

    @Test
    public void test_the_whole_algorithm() {
        List<Float> expectedGlobalLosses = List.of(
                2.0F,
                3.0F,
                3.0F,
                2.0F,
                1.0F,
                0.0F
        );
        context.learningRate = 1F;
        pipeline.afterEpoch(ctx -> {
            int index = ctx.epoch - 1;
            // Guard against epoch overrun (consistent with the other pipeline tests)
            // so a regression fails on the epoch assertion below instead of
            // masking it with an IndexOutOfBoundsException here.
            if (index >= expectedGlobalLosses.size()) return;
            assertEquals(expectedGlobalLosses.get(index), context.globalLoss);
        });
        pipeline.run(context);
        assertEquals(6, context.epoch);
    }
}