Compare commits
36 Commits
edf1276a20
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| b253fb74ee | |||
| 8beb6aa870 | |||
| 40ebca469e | |||
| 42e6d3dde8 | |||
| 87536f5a55 | |||
| 5a73337687 | |||
| 4c1eaff238 | |||
| 5ddf6dc580 | |||
| 4441b149f9 | |||
| 1e8b02089c | |||
| daba4f8420 | |||
| 5aca7b87e3 | |||
| 165a2bc977 | |||
| 881088df28 | |||
| fd97d0853c | |||
| ada01d350b | |||
| aed78fe9d2 | |||
| b36a900f87 | |||
| 0fe309cd4e | |||
| 83526b72d4 | |||
| 17cff89b44 | |||
| 6d88651385 | |||
| 7fb4a7c057 | |||
| 572e5c7484 | |||
| 64bc830f18 | |||
| 3dd4404f51 | |||
| 0d3ab0de8d | |||
| c389646794 | |||
| 76465ab6ee | |||
| 65d3a0e3e4 | |||
| 0217607e9b | |||
| 5ace4952fb | |||
| a84c3d999d | |||
| b25aaba088 | |||
| 76bc791889 | |||
| 56f88bded3 |
10
.idea/.gitignore
generated
vendored
10
.idea/.gitignore
generated
vendored
@@ -1,10 +0,0 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Ignored default folder with query files
|
||||
/queries/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
1
.idea/.name
generated
1
.idea/.name
generated
@@ -1 +0,0 @@
|
||||
ANN
|
||||
17
.idea/gradle.xml
generated
17
.idea/gradle.xml
generated
@@ -1,17 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="GradleMigrationSettings" migrationVersion="1" />
|
||||
<component name="GradleSettings">
|
||||
<option name="linkedExternalProjectsSettings">
|
||||
<GradleProjectSettings>
|
||||
<option name="externalProjectPath" value="$PROJECT_DIR$" />
|
||||
<option name="gradleHome" value="" />
|
||||
<option name="modules">
|
||||
<set>
|
||||
<option value="$PROJECT_DIR$" />
|
||||
</set>
|
||||
</option>
|
||||
</GradleProjectSettings>
|
||||
</option>
|
||||
</component>
|
||||
</project>
|
||||
10
.idea/misc.xml
generated
10
.idea/misc.xml
generated
@@ -1,10 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ExternalStorageConfigurationManager" enabled="true" />
|
||||
<component name="FrameworkDetectionExcludesConfiguration">
|
||||
<file type="web" url="file://$PROJECT_DIR$" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_21" default="true" project-jdk-name="21" project-jdk-type="JavaSDK">
|
||||
<output url="file://$PROJECT_DIR$/out" />
|
||||
</component>
|
||||
</project>
|
||||
6
.idea/vcs.xml
generated
6
.idea/vcs.xml
generated
@@ -1,6 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
@@ -10,6 +10,9 @@ repositories {
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation("org.jfree:jfreechart:1.5.4")
|
||||
implementation("com.fasterxml.jackson.core:jackson-databind:2.21.2")
|
||||
|
||||
testImplementation(platform("org.junit:junit-bom:5.10.0"))
|
||||
testImplementation("org.junit.jupiter:junit-jupiter")
|
||||
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
|
||||
|
||||
14
config.json
Normal file
14
config.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"model": {
|
||||
"new": true,
|
||||
"parameters": [2, 4, 2, 1],
|
||||
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/best-snapshot.json"
|
||||
},
|
||||
"training" : {
|
||||
"learning_rate" : 0.0003,
|
||||
"max_epoch" : 5000
|
||||
},
|
||||
"dataset" : {
|
||||
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/LangageDesSignes/data_formatted.csv"
|
||||
}
|
||||
}
|
||||
@@ -1,77 +1,110 @@
|
||||
package com.naaturel.ANN;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||
import com.naaturel.ANN.implementation.activationFunction.Heaviside;
|
||||
import com.naaturel.ANN.implementation.activationFunction.Linear;
|
||||
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
||||
import com.naaturel.ANN.implementation.training.AdalineTraining;
|
||||
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||
import com.naaturel.ANN.infrastructure.config.ConfigDto;
|
||||
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
||||
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||
import com.naaturel.ANN.infrastructure.visualization.ModelVisualizer;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
public class Main {
|
||||
|
||||
public static void main(String[] args){
|
||||
public static void main(String[] args) throws Exception {
|
||||
|
||||
DataSet orDataSet = new DataSet(Map.ofEntries(
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(0.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
|
||||
));
|
||||
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
|
||||
|
||||
DataSet andDataSet = new DataSet(Map.ofEntries(
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
|
||||
));
|
||||
boolean newModel = config.getModelProperty("new", Boolean.class);
|
||||
int[] modelParameters = config.getModelProperty("parameters", int[].class);
|
||||
String modelPath = config.getModelProperty("path", String.class);
|
||||
int maxEpoch = config.getTrainingProperty("max_epoch", Integer.class);
|
||||
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
|
||||
String datasetPath = config.getDatasetProperty("path", String.class);
|
||||
|
||||
DataSet dataSet = new DataSet(Map.ofEntries(
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 9.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(7.0F, 10.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(2.0F, 5.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(2.0F, 7.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(2.0F, 8.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(6.0F, 8.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(6.0F, 9.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(3.0F, 5.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(3.0F, 6.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(3.0F, 8.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(3.0F, 9.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(5.0F, 7.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(5.0F, 8.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(5.0F, 10.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(5.0F, 11.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(4.0F, 6.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(4.0F, 7.0F)), new Label(-1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(4.0F, 9.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(4.0F, 10.0F)), new Label(1.0F))
|
||||
));
|
||||
int nbrClass = 5;
|
||||
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
||||
int nbrInput = dataset.getNbrInputs();
|
||||
|
||||
List<Synapse> syns = new ArrayList<>();
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
ModelSnapshot snapshot;
|
||||
|
||||
Bias bias = new Bias(new Weight(0));
|
||||
Model network;
|
||||
if(newModel){
|
||||
network = createNetwork(modelParameters, nbrInput);
|
||||
snapshot = new ModelSnapshot(network);
|
||||
System.out.println("Parameters: " + network.synCount());
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(learningRate, maxEpoch, network, dataset);
|
||||
snapshot.saveToFile(modelPath);
|
||||
} else {
|
||||
snapshot = new ModelSnapshot();
|
||||
snapshot.loadFromFile(modelPath);
|
||||
network = snapshot.getModel();
|
||||
}
|
||||
//plotGraph(dataset, network);
|
||||
|
||||
Neuron n = new SimplePerceptron(syns, bias, new Linear());
|
||||
AdalineTraining st = new AdalineTraining();
|
||||
|
||||
long start = System.currentTimeMillis();
|
||||
|
||||
st.train(n, 0.03F, andDataSet);
|
||||
|
||||
long end = System.currentTimeMillis();
|
||||
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
|
||||
new ModelVisualizer(network)
|
||||
.withWeights(true)
|
||||
.display();
|
||||
}
|
||||
|
||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
||||
int neuronId = 0;
|
||||
List<Layer> layers = new ArrayList<>();
|
||||
for (int i = 0; i < neuronPerLayer.length; i++){
|
||||
|
||||
List<Neuron> neurons = new ArrayList<>();
|
||||
for (int j = 0; j < neuronPerLayer[i]; j++){
|
||||
|
||||
int nbrSyn = i == 0 ? nbrInput: neuronPerLayer[i-1];
|
||||
|
||||
List<Synapse> syns = new ArrayList<>();
|
||||
for (int k=0; k < nbrSyn; k++){
|
||||
syns.add(new Synapse(new Input(0), new Weight()));
|
||||
}
|
||||
|
||||
Bias bias = new Bias(new Weight());
|
||||
|
||||
Neuron n = new Neuron(neuronId, syns.toArray(new Synapse[0]), bias, new TanH());
|
||||
neurons.add(n);
|
||||
neuronId++;
|
||||
}
|
||||
Layer layer = new Layer(neurons.toArray(new Neuron[0]));
|
||||
layers.add(layer);
|
||||
}
|
||||
|
||||
return new FullyConnectedNetwork(layers.toArray(new Layer[0]));
|
||||
}
|
||||
|
||||
private static void plotGraph(DataSet dataset, Model network){
|
||||
GraphVisualizer visualizer = new GraphVisualizer();
|
||||
|
||||
/*for (DataSetEntry entry : dataset) {
|
||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
||||
label.forEach(l -> {
|
||||
visualizer.addPoint("Label " + l,
|
||||
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||
});
|
||||
}*/
|
||||
|
||||
float min = -50F;
|
||||
float max = 50F;
|
||||
float step = 0.03F;
|
||||
for (float x = min; x < max; x+=step){
|
||||
for (float y = min; y < max; y+=step){
|
||||
float[] predictions = network.predict(new float[]{x, y});
|
||||
visualizer.addPoint(Float.toString(Math.round(predictions[0])), x, y);
|
||||
}
|
||||
}
|
||||
|
||||
visualizer.buildScatterGraph((int)min-1, (int)max+1);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
|
||||
public interface ActivationFunction {
|
||||
|
||||
float accept(Neuron n);
|
||||
float derivative(float value);
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
@FunctionalInterface
|
||||
public interface AlgorithmStep {
|
||||
|
||||
void run();
|
||||
|
||||
}
|
||||
20
src/main/java/com/naaturel/ANN/domain/abstraction/Model.java
Normal file
20
src/main/java/com/naaturel/ANN/domain/abstraction/Model.java
Normal file
@@ -0,0 +1,20 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public interface Model {
|
||||
int synCount();
|
||||
int neuronCount();
|
||||
int layerIndexOf(Neuron n);
|
||||
int indexInLayerOf(Neuron n);
|
||||
void forEachNeuron(Consumer<Neuron> consumer);
|
||||
//void forEachSynapse(Consumer<Synapse> consumer);
|
||||
void forEachOutputNeurons(Consumer<Neuron> consumer);
|
||||
void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer);
|
||||
float[] predict(float[] inputs);
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public interface Network {
|
||||
|
||||
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public abstract class Neuron {
|
||||
|
||||
protected List<Synapse> synapses;
|
||||
protected Bias bias;
|
||||
protected ActivationFunction activationFunction;
|
||||
|
||||
public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){
|
||||
this.synapses = synapses;
|
||||
this.bias = bias;
|
||||
this.activationFunction = func;
|
||||
}
|
||||
|
||||
public abstract float predict();
|
||||
public abstract float calculateWeightedSum();
|
||||
|
||||
public int getSynCount(){
|
||||
return this.synapses.size();
|
||||
}
|
||||
|
||||
public void setInput(int index, Input input){
|
||||
Synapse syn = this.synapses.get(index);
|
||||
syn.setInput(input.getValue());
|
||||
}
|
||||
|
||||
public Bias getBias(){
|
||||
return this.bias;
|
||||
}
|
||||
|
||||
public void updateBias(Weight weight) {
|
||||
this.bias.setWeight(weight.getValue());
|
||||
}
|
||||
|
||||
public Synapse getSynapse(int index){
|
||||
return this.synapses.get(index);
|
||||
}
|
||||
|
||||
public List<Synapse> getSynapses() {
|
||||
return new ArrayList<>(this.synapses);
|
||||
}
|
||||
|
||||
public void setWeight(int index, Weight weight){
|
||||
Synapse syn = this.synapses.get(index);
|
||||
syn.setWeight(weight.getValue());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
public abstract class NeuronTrainer {
|
||||
|
||||
private Trainable trainable;
|
||||
|
||||
public NeuronTrainer(Trainable trainable){
|
||||
this.trainable = trainable;
|
||||
}
|
||||
|
||||
public abstract void train();
|
||||
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
public interface Trainable {
|
||||
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
|
||||
public interface Trainer {
|
||||
void train(float learningRate, int epoch, Model model, DataSet dataset);
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package com.naaturel.ANN.domain.abstraction;
|
||||
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public abstract class TrainingContext {
|
||||
public Model model;
|
||||
public DataSet dataset;
|
||||
public DataSetEntry currentEntry;
|
||||
|
||||
public List<Float> expectations;
|
||||
public float[] predictions;
|
||||
public float[] deltas;
|
||||
|
||||
public float globalLoss;
|
||||
public float localLoss;
|
||||
|
||||
public float learningRate;
|
||||
public int epoch;
|
||||
|
||||
public TrainingContext(Model model, DataSet dataset) {
|
||||
this.model = model;
|
||||
this.dataset = dataset;
|
||||
this.deltas = new float[dataset.getNbrLabels()];
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
package com.naaturel.ANN.domain.model.dataset;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
public class DataSet implements Iterable<DataSetEntry>{
|
||||
|
||||
private Map<DataSetEntry, Label> data;
|
||||
|
||||
public DataSet(){
|
||||
this(new HashMap<>());
|
||||
}
|
||||
|
||||
public DataSet(Map<DataSetEntry, Label> data){
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
public int size() {
|
||||
return data.size();
|
||||
}
|
||||
|
||||
public List<DataSetEntry> getData(){
|
||||
return new ArrayList<>(this.data.keySet());
|
||||
}
|
||||
|
||||
public Label getLabel(DataSetEntry entry){
|
||||
return this.data.get(entry);
|
||||
}
|
||||
|
||||
public DataSet toNormalized() {
|
||||
List<DataSetEntry> entries = this.getData();
|
||||
|
||||
float maxAbs = entries.stream()
|
||||
.flatMap(e -> e.getData().stream())
|
||||
.map(Math::abs)
|
||||
.max(Float::compare)
|
||||
.orElse(1.0F);
|
||||
|
||||
Map<DataSetEntry, Label> normalized = new HashMap<>();
|
||||
for (DataSetEntry entry : entries) {
|
||||
List<Float> normalizedData = new ArrayList<>();
|
||||
for (float value : entry.getData()) {
|
||||
normalizedData.add(Math.round((value / maxAbs) * 100.0F) / 100.0F);
|
||||
}
|
||||
normalized.put(new DataSetEntry(normalizedData), this.data.get(entry));
|
||||
}
|
||||
|
||||
return new DataSet(normalized);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<DataSetEntry> iterator() {
|
||||
return this.data.keySet().iterator();
|
||||
}
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package com.naaturel.ANN.domain.model.dataset;
|
||||
|
||||
public class Label {
|
||||
|
||||
private float value;
|
||||
|
||||
public Label(float value){
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
|
||||
public float getValue() {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package com.naaturel.ANN.domain.model.neuron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* Represents a fully connected neural network
|
||||
*/
|
||||
public class FullyConnectedNetwork implements Model {
|
||||
|
||||
private final Layer[] layers;
|
||||
private final Map<Neuron, List<Neuron>> connectionMap;
|
||||
private final Map<Neuron, Integer> layerIndexByNeuron;
|
||||
public FullyConnectedNetwork(Layer[] layers) {
|
||||
this.layers = layers;
|
||||
this.connectionMap = this.createConnectionMap();
|
||||
this.layerIndexByNeuron = this.createNeuronIndex();
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] predict(float[] inputs) {
|
||||
float[] previousLayerOutputs = inputs;
|
||||
for (Layer layer : layers) {
|
||||
previousLayerOutputs = layer.predict(previousLayerOutputs);
|
||||
}
|
||||
return previousLayerOutputs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int synCount() {
|
||||
int res = 0;
|
||||
for(Layer layer : this.layers){
|
||||
res += layer.synCount();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int neuronCount() {
|
||||
int res = 0;
|
||||
for(Layer layer : this.layers){
|
||||
res += layer.neuronCount();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@Override
|
||||
public void forEachNeuron(Consumer<Neuron> consumer) {
|
||||
for(Layer l : this.layers){
|
||||
l.forEachNeuron(consumer);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
||||
int lastIndex = this.layers.length-1;
|
||||
this.layers[lastIndex].forEachNeuron(consumer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
|
||||
if(!this.connectionMap.containsKey(n)) return;
|
||||
this.connectionMap.get(n).forEach(consumer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int layerIndexOf(Neuron n) {
|
||||
return this.layerIndexByNeuron.get(n);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int indexInLayerOf(Neuron n) {
|
||||
int layerIndex = this.layerIndexByNeuron.get(n);
|
||||
return this.layers[layerIndex].indexInLayerOf(n);
|
||||
}
|
||||
|
||||
private Map<Neuron, List<Neuron>> createConnectionMap() {
|
||||
Map<Neuron, List<Neuron>> res = new HashMap<>();
|
||||
|
||||
for (int i = 0; i < this.layers.length - 1; i++) {
|
||||
List<Neuron> nextLayerNeurons = new ArrayList<>();
|
||||
this.layers[i + 1].forEachNeuron(nextLayerNeurons::add);
|
||||
this.layers[i].forEachNeuron(n -> res.put(n, nextLayerNeurons));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
private Map<Neuron, Integer> createNeuronIndex() {
|
||||
Map<Neuron, Integer> res = new HashMap<>();
|
||||
AtomicInteger index = new AtomicInteger(0);
|
||||
for(Layer l : this.layers){
|
||||
l.forEachNeuron(n -> res.put(n, index.get()));
|
||||
index.incrementAndGet();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package com.naaturel.ANN.domain.model.neuron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public class Layer implements Model {
|
||||
|
||||
private final Neuron[] neurons;
|
||||
private final Map<Neuron, Integer> neuronIndex;
|
||||
|
||||
public Layer(Neuron[] neurons) {
|
||||
this.neurons = neurons;
|
||||
this.neuronIndex = createNeuronIndex();
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] predict(float[] inputs) {
|
||||
float[] result = new float[neurons.length];
|
||||
for (int i = 0; i < neurons.length; i++) {
|
||||
result[i] = neurons[i].predict(inputs)[0];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int synCount() {
|
||||
int res = 0;
|
||||
for (Neuron neuron : this.neurons) {
|
||||
res += neuron.synCount();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int neuronCount() {
|
||||
return this.neurons.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int layerIndexOf(Neuron n) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int indexInLayerOf(Neuron n) {
|
||||
return this.neuronIndex.get(n);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachNeuron(Consumer<Neuron> consumer) {
|
||||
for (Neuron n : this.neurons){
|
||||
consumer.accept(n);
|
||||
}
|
||||
}
|
||||
|
||||
/*@Override
|
||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||
for (Neuron n : this.neurons){
|
||||
n.forEachSynapse(consumer);
|
||||
}
|
||||
}*/
|
||||
|
||||
@Override
|
||||
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
||||
this.forEachNeuron(consumer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
|
||||
throw new UnsupportedOperationException("Neurons have no connection within the same layer");
|
||||
}
|
||||
|
||||
private Map<Neuron, Integer> createNeuronIndex() {
|
||||
Map<Neuron, Integer> res = new HashMap<>();
|
||||
int[] index = {0};
|
||||
this.forEachNeuron(n -> {
|
||||
res.put(n, index[0]++);
|
||||
});
|
||||
return res;
|
||||
}
|
||||
|
||||
}
|
||||
111
src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java
Normal file
111
src/main/java/com/naaturel/ANN/domain/model/neuron/Neuron.java
Normal file
@@ -0,0 +1,111 @@
|
||||
package com.naaturel.ANN.domain.model.neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public class Neuron implements Model {
|
||||
|
||||
private final int id;
|
||||
private float output;
|
||||
private final float[] weights;
|
||||
private final float[] inputs;
|
||||
private final ActivationFunction activationFunction;
|
||||
|
||||
public Neuron(int id, Synapse[] synapses, Bias bias, ActivationFunction func){
|
||||
this.id = id;
|
||||
this.activationFunction = func;
|
||||
|
||||
output = 0;
|
||||
weights = new float[synapses.length+1]; //takes the bias into account
|
||||
inputs = new float[synapses.length+1]; //takes the bias into account
|
||||
|
||||
weights[0] = bias.getWeight();
|
||||
inputs[0] = bias.getInput();
|
||||
for (int i = 0; i < synapses.length; i++){
|
||||
weights[i+1] = synapses[i].getWeight();
|
||||
inputs[i+1] = synapses[i].getInput();
|
||||
}
|
||||
}
|
||||
|
||||
public void setWeight(int index, float value) {
|
||||
this.weights[index] = value;
|
||||
}
|
||||
|
||||
public float getWeight(int index) {
|
||||
return this.weights[index];
|
||||
}
|
||||
|
||||
public float getInput(int index) {
|
||||
return this.inputs[index];
|
||||
}
|
||||
|
||||
public ActivationFunction getActivationFunction(){
|
||||
return this.activationFunction;
|
||||
}
|
||||
|
||||
public float calculateWeightedSum() {
|
||||
int count = weights.length;
|
||||
float weightedSum = 0F;
|
||||
for (int i = 0; i < count; i++){
|
||||
weightedSum += weights[i] * inputs[i];
|
||||
}
|
||||
return weightedSum;
|
||||
}
|
||||
|
||||
public int getId(){
|
||||
return this.id;
|
||||
}
|
||||
|
||||
public float getOutput() {
|
||||
return this.output;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int synCount() {
|
||||
return this.weights.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int neuronCount() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int layerIndexOf(Neuron n) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int indexInLayerOf(Neuron n) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] predict(float[] inputs) {
|
||||
this.setInputs(inputs);
|
||||
output = activationFunction.accept(this);
|
||||
return new float[] {output};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachNeuron(Consumer<Neuron> consumer) {
|
||||
consumer.accept(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
||||
consumer.accept(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
|
||||
throw new UnsupportedOperationException("Neurons have no connection with themselves");
|
||||
}
|
||||
|
||||
private void setInputs(float[] values){
|
||||
System.arraycopy(values, 0, inputs, 1, values.length);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -14,8 +14,8 @@ public class Synapse {
|
||||
return this.input.getValue();
|
||||
}
|
||||
|
||||
public void setInput(float value){
|
||||
this.input.setValue(value);
|
||||
public void setInput(Input input){
|
||||
this.input.setValue(input.getValue());
|
||||
}
|
||||
|
||||
public float getWeight() {
|
||||
@@ -25,6 +25,4 @@ public class Synapse {
|
||||
public void setWeight(float value){
|
||||
this.weight.setValue(value);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
package com.naaturel.ANN.domain.model.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
public class TrainingPipeline {
|
||||
|
||||
private final List<AlgorithmStep> steps;
|
||||
private Consumer<TrainingContext> beforeEpoch;
|
||||
private Consumer<TrainingContext> afterEpoch;
|
||||
private Predicate<TrainingContext> stopCondition;
|
||||
|
||||
private boolean verbose;
|
||||
private boolean visualization;
|
||||
private boolean timeMeasurement;
|
||||
|
||||
private GraphVisualizer visualizer;
|
||||
private int verboseDelay;
|
||||
|
||||
public TrainingPipeline(List<AlgorithmStep> steps) {
|
||||
this.steps = new ArrayList<>(steps);
|
||||
this.stopCondition = (ctx) -> false;
|
||||
this.beforeEpoch = (context -> {});
|
||||
this.afterEpoch = (context -> {});
|
||||
}
|
||||
|
||||
public TrainingPipeline stopCondition(Predicate<TrainingContext> predicate) {
|
||||
this.stopCondition = predicate;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TrainingPipeline beforeEpoch(Consumer<TrainingContext> consumer) {
|
||||
this.beforeEpoch = consumer;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TrainingPipeline afterEpoch(Consumer<TrainingContext> consumer) {
|
||||
this.afterEpoch = consumer;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TrainingPipeline withVerbose(boolean enabled, int epochDelay) {
|
||||
if(epochDelay <= 0) throw new IllegalArgumentException("Epoch delay cannot lower or equal to 0");
|
||||
this.verbose = enabled;
|
||||
this.verboseDelay = epochDelay;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TrainingPipeline withVisualization(boolean enabled, GraphVisualizer visualizer) {
|
||||
this.visualization = enabled;
|
||||
this.visualizer = visualizer;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TrainingPipeline withTimeMeasurement(boolean enabled) {
|
||||
this.timeMeasurement = enabled;
|
||||
return this;
|
||||
}
|
||||
|
||||
public void run(TrainingContext ctx) {
|
||||
|
||||
long start = this.timeMeasurement ? System.currentTimeMillis() : 0;
|
||||
|
||||
do {
|
||||
this.beforeEpoch.accept(ctx);
|
||||
this.executeSteps(ctx);
|
||||
this.afterEpoch.accept(ctx);
|
||||
if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
|
||||
System.out.printf("[Global error] : %f\n", ctx.globalLoss);
|
||||
}
|
||||
ctx.epoch += 1;
|
||||
} while (!this.stopCondition.test(ctx));
|
||||
|
||||
if(this.timeMeasurement) {
|
||||
long end = System.currentTimeMillis();
|
||||
System.out.printf("[Training finished in %.3fs]\n", (end-start)/1000.0);
|
||||
}
|
||||
System.out.printf("[Final global error] : %f\n", ctx.globalLoss);
|
||||
//if(this.visualization) this.visualize(ctx);
|
||||
}
|
||||
|
||||
private void executeSteps(TrainingContext ctx){
|
||||
for (DataSetEntry entry : ctx.dataset) {
|
||||
|
||||
ctx.currentEntry = entry;
|
||||
ctx.expectations = ctx.dataset.getLabelsAsFloat(entry);
|
||||
|
||||
for (AlgorithmStep step : steps) {
|
||||
step.run();
|
||||
}
|
||||
|
||||
if(this.verbose && ctx.epoch % this.verboseDelay == 0) {
|
||||
System.out.printf("Epoch : %d, ", ctx.epoch);
|
||||
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions));
|
||||
System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
|
||||
System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas));
|
||||
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*private void visualize(TrainingContext ctx){
|
||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
||||
ctx.model.forEachNeuron(n -> {
|
||||
List<Float> weights = new ArrayList<>();
|
||||
n.forEachSynapse(syn -> weights.add(syn.getWeight()));
|
||||
|
||||
float b = weights.get(0);
|
||||
float w1 = weights.get(1);
|
||||
float w2 = weights.get(2);
|
||||
|
||||
this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3);
|
||||
});
|
||||
int i = 0;
|
||||
for(DataSetEntry entry : ctx.dataset){
|
||||
List<Input> inputs = entry.getData();
|
||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue());
|
||||
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
|
||||
i++;
|
||||
}
|
||||
this.visualizer.buildLineGraph();
|
||||
}*/
|
||||
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package com.naaturel.ANN.implementation.activationFunction;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
|
||||
public class Heaviside implements ActivationFunction {
|
||||
|
||||
public Heaviside(){
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public float accept(Neuron n) {
|
||||
float weightedSum = n.calculateWeightedSum();
|
||||
return weightedSum <= 0 ? 0:1;
|
||||
}
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package com.naaturel.ANN.implementation.activationFunction;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
|
||||
public class Linear implements ActivationFunction {
|
||||
|
||||
@Override
|
||||
public float accept(Neuron n) {
|
||||
return n.calculateWeightedSum();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.naaturel.ANN.implementation.adaline;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
|
||||
public class AdalineTrainingContext extends TrainingContext {
|
||||
public AdalineTrainingContext(Model model, DataSet dataset) {
|
||||
super(model, dataset);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public class GradientDescentCorrectionStrategy implements AlgorithmStep {
|
||||
|
||||
private final GradientDescentTrainingContext context;
|
||||
|
||||
public GradientDescentCorrectionStrategy(GradientDescentTrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
int[] globalSynIndex = {0};
|
||||
context.model.forEachNeuron(n -> {
|
||||
for(int i = 0; i < n.synCount(); i++){
|
||||
float corrector = context.correctorTerms.get(globalSynIndex[0]);
|
||||
float c = n.getWeight(i) + corrector;
|
||||
n.setWeight(i, c);
|
||||
globalSynIndex[0]++;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public class GradientDescentErrorStrategy implements AlgorithmStep {
|
||||
|
||||
private final GradientDescentTrainingContext context;
|
||||
|
||||
public GradientDescentErrorStrategy(GradientDescentTrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
|
||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
||||
AtomicInteger synIndex = new AtomicInteger(0);
|
||||
|
||||
context.model.forEachNeuron(neuron -> {
|
||||
float correspondingDelta = context.deltas[neuronIndex.get()];
|
||||
|
||||
for(int i = 0; i < neuron.synCount(); i++){
|
||||
float corrector = context.correctorTerms.get(synIndex.get());
|
||||
corrector += context.learningRate * correspondingDelta * neuron.getInput(i);
|
||||
context.correctorTerms.set(synIndex.get(), corrector);
|
||||
synIndex.incrementAndGet();
|
||||
}
|
||||
neuronIndex.incrementAndGet();
|
||||
});
|
||||
|
||||
context.globalLoss += context.localLoss;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class GradientDescentTrainingContext extends TrainingContext {
|
||||
|
||||
public List<Float> correctorTerms;
|
||||
|
||||
public GradientDescentTrainingContext(Model model, DataSet dataset) {
|
||||
super(model, dataset);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
|
||||
public class Linear implements ActivationFunction {
|
||||
|
||||
private final float slope;
|
||||
private final float intercept;
|
||||
|
||||
public Linear(float slope, float intercept) {
|
||||
this.slope = slope;
|
||||
this.intercept = intercept;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float accept(Neuron n) {
|
||||
return slope * n.calculateWeightedSum() + intercept;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float derivative(float value) {
|
||||
return this.slope;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.naaturel.ANN.implementation.gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class SquareLossStep implements AlgorithmStep {
|
||||
|
||||
private final TrainingContext context;
|
||||
|
||||
public SquareLossStep(TrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
float loss = 0f;
|
||||
for (float d : this.context.deltas) {
|
||||
loss += d * d;
|
||||
}
|
||||
this.context.localLoss = loss / 2f;
|
||||
this.context.globalLoss += this.context.localLoss;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
|
||||
public class BackpropagationCorrectionStep implements AlgorithmStep {
|
||||
|
||||
private final GradientBackpropagationContext context;
|
||||
private final int synCount;
|
||||
private final float[] inputs;
|
||||
private final float[] signals;
|
||||
|
||||
public BackpropagationCorrectionStep(GradientBackpropagationContext context){
|
||||
this.context = context;
|
||||
this.synCount = context.correctionBuffer.length;
|
||||
this.inputs = new float[synCount];
|
||||
this.signals = new float[synCount];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
int[] synIndex = {0};
|
||||
context.model.forEachNeuron(n -> {
|
||||
float signal = context.errorSignals[n.getId()];
|
||||
for (int i = 0; i < n.synCount(); i++){
|
||||
inputs[synIndex[0]] = n.getInput(i);
|
||||
signals[synIndex[0]] = signal;
|
||||
synIndex[0]++;
|
||||
}
|
||||
});
|
||||
|
||||
float lr = context.learningRate;
|
||||
boolean applyUpdate = context.currentSample >= context.batchSize;
|
||||
|
||||
for (int i = 0; i < synCount; i++) {
|
||||
context.correctionBuffer[i] += lr * signals[i] * inputs[i];
|
||||
}
|
||||
|
||||
if (applyUpdate) {
|
||||
syncWeights();
|
||||
context.currentSample = 0;
|
||||
}
|
||||
|
||||
context.currentSample++;
|
||||
}
|
||||
|
||||
private void syncWeights() {
|
||||
int[] synIndex = {0};
|
||||
context.model.forEachNeuron(n -> {
|
||||
for (int i = 0; i < n.synCount(); i++) {
|
||||
n.setWeight(i, n.getWeight(i) + context.correctionBuffer[synIndex[0]]);
|
||||
context.correctionBuffer[synIndex[0]] = 0f;
|
||||
synIndex[0]++;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
|
||||
public class BatchAccumulatorStep implements AlgorithmStep {
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
|
||||
public class ErrorSignalStep implements AlgorithmStep {
|
||||
|
||||
private final GradientBackpropagationContext context;
|
||||
|
||||
public ErrorSignalStep(GradientBackpropagationContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
|
||||
context.model.forEachNeuron(n -> {
|
||||
if (context.errorSignalsComputed[n.getId()]) return;
|
||||
|
||||
int neuronIndex = context.model.indexInLayerOf(n);
|
||||
float[] signalSum = {0f};
|
||||
context.model.forEachNeuronConnectedTo(n, connected -> {
|
||||
signalSum[0] += context.errorSignals[connected.getId()] * connected.getWeight(neuronIndex);
|
||||
});
|
||||
|
||||
context.errorSignals[n.getId()] = n.getActivationFunction().derivative(n.getOutput()) * signalSum[0];
|
||||
context.errorSignalsComputed[n.getId()] = true;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class GradientBackpropagationContext extends TrainingContext {
|
||||
|
||||
public final float[] errorSignals;
|
||||
public final float[] correctionBuffer;
|
||||
public final boolean[] errorSignalsComputed;
|
||||
|
||||
public int currentSample;
|
||||
public int batchSize;
|
||||
|
||||
public GradientBackpropagationContext(Model model, DataSet dataSet, float learningRate, int batchSize){
|
||||
super(model, dataSet);
|
||||
this.learningRate = learningRate;
|
||||
this.batchSize = batchSize;
|
||||
|
||||
this.errorSignals = new float[model.neuronCount()];
|
||||
this.correctionBuffer = new float[model.synCount()];
|
||||
this.errorSignalsComputed = new boolean[model.neuronCount()];
|
||||
this.currentSample = 1;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class OutputLayerErrorStep implements AlgorithmStep {
|
||||
|
||||
private final GradientBackpropagationContext context;
|
||||
private final float[] expectations;
|
||||
|
||||
public OutputLayerErrorStep(GradientBackpropagationContext context){
|
||||
this.context = context;
|
||||
this.expectations = new float[context.dataset.getNbrLabels()];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
Arrays.fill(context.errorSignals, 0f);
|
||||
Arrays.fill(context.errorSignalsComputed, false);
|
||||
|
||||
DataSetEntry entry = context.currentEntry;
|
||||
List<Float> labels = context.dataset.getLabelsAsFloat(entry);
|
||||
for (int i = 0; i < labels.size(); i++) {
|
||||
expectations[i] = labels.get(i);
|
||||
}
|
||||
|
||||
int[] index = {0};
|
||||
context.model.forEachOutputNeurons(n -> {
|
||||
float expected = expectations[index[0]];
|
||||
float predicted = n.getOutput();
|
||||
float delta = expected - predicted;
|
||||
|
||||
context.deltas[index[0]] = delta;
|
||||
context.errorSignals[n.getId()] = delta * n.getActivationFunction().derivative(predicted);
|
||||
context.errorSignalsComputed[n.getId()] = true;
|
||||
index[0]++;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
|
||||
public class Sigmoid implements ActivationFunction {
|
||||
|
||||
private float steepness;
|
||||
|
||||
public Sigmoid(float steepness) {
|
||||
this.steepness = steepness;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float accept(Neuron n) {
|
||||
return (float) (1.0/(1.0 + Math.exp(-steepness * n.calculateWeightedSum())));
|
||||
}
|
||||
|
||||
@Override
|
||||
public float derivative(float value) {
|
||||
return steepness * value * (1 - value);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.naaturel.ANN.implementation.multiLayers;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
|
||||
public class TanH implements ActivationFunction {
|
||||
|
||||
@Override
|
||||
public float accept(Neuron n) {
|
||||
//For educational purpose. Math.tanh() could have been used here
|
||||
float weightedSum = n.calculateWeightedSum();
|
||||
double exp = Math.exp(weightedSum);
|
||||
double res = (exp-(1/exp))/(exp+(1/exp));
|
||||
return (float)(res);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float derivative(float value) {
|
||||
return 1 - value * value;
|
||||
}
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package com.naaturel.ANN.implementation.neuron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainable;
|
||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class SimplePerceptron extends Neuron implements Trainable {
|
||||
|
||||
public SimplePerceptron(List<Synapse> synapses, Bias b, ActivationFunction func) {
|
||||
super(synapses, b, func);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float predict() {
|
||||
return activationFunction.accept(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float calculateWeightedSum() {
|
||||
float res = 0;
|
||||
for(Synapse syn : super.synapses){
|
||||
res += syn.getWeight() * syn.getInput();
|
||||
}
|
||||
res += this.bias.getWeight() * this.bias.getInput();
|
||||
return res;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
|
||||
import javax.naming.OperationNotSupportedException;
|
||||
|
||||
public class Heaviside implements ActivationFunction {
|
||||
|
||||
public Heaviside(){
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public float accept(Neuron n) {
|
||||
float weightedSum = n.calculateWeightedSum();
|
||||
return weightedSum < 0 ? 0:1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float derivative(float value) {
|
||||
throw new UnsupportedOperationException("Heaviside is not differentiable");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
|
||||
public class SimpleCorrectionStep implements AlgorithmStep {
|
||||
|
||||
private final TrainingContext context;
|
||||
|
||||
public SimpleCorrectionStep(TrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
if(context.expectations.equals(context.predictions)) return;
|
||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
||||
|
||||
context.model.forEachNeuron(neuron -> {
|
||||
float correspondingDelta = context.deltas[neuronIndex.get()];
|
||||
|
||||
for(int i = 0; i < neuron.synCount(); i++){
|
||||
float currentW = neuron.getWeight(i);
|
||||
float currentInput = neuron.getInput(i);
|
||||
float newValue = currentW + (context.learningRate * correspondingDelta * currentInput);
|
||||
neuron.setWeight(i, newValue);
|
||||
}
|
||||
neuronIndex.incrementAndGet();
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
public class SimpleDeltaStep implements AlgorithmStep {
|
||||
|
||||
private final TrainingContext context;
|
||||
|
||||
public SimpleDeltaStep(TrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
DataSet dataSet = context.dataset;
|
||||
DataSetEntry entry = context.currentEntry;
|
||||
float[] predicted = context.predictions;
|
||||
List<Float> expected = dataSet.getLabelsAsFloat(entry);
|
||||
|
||||
for (int i = 0; i < predicted.length; i++) {
|
||||
context.deltas[i] = expected.get(i) - predicted[0];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
|
||||
public class SimpleErrorRegistrationStep implements AlgorithmStep {
|
||||
|
||||
private final TrainingContext context;
|
||||
|
||||
public SimpleErrorRegistrationStep(TrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
context.globalLoss += context.localLoss;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
|
||||
public class SimpleLossStrategy implements AlgorithmStep {
|
||||
|
||||
private final SimpleTrainingContext context;
|
||||
|
||||
public SimpleLossStrategy(SimpleTrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
float loss = 0f;
|
||||
for (float d : context.deltas) {
|
||||
loss += d;
|
||||
}
|
||||
context.localLoss = loss;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class SimplePredictionStep implements AlgorithmStep {
|
||||
|
||||
private final TrainingContext context;
|
||||
|
||||
public SimplePredictionStep(TrainingContext context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
List<Input> data = context.currentEntry.getData();
|
||||
float[] flatData = new float[data.size()];
|
||||
for (int i = 0; i < data.size(); i++) {
|
||||
flatData[i] = data.get(i).getValue();
|
||||
}
|
||||
context.predictions = context.model.predict(flatData);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
|
||||
public class SimpleTrainingContext extends TrainingContext {
|
||||
public SimpleTrainingContext(Model model, DataSet dataset) {
|
||||
super(model, dataset);
|
||||
}
|
||||
}
|
||||
@@ -1,33 +1,60 @@
|
||||
package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class AdalineTraining {
|
||||
|
||||
public class AdalineTraining implements Trainer {
|
||||
|
||||
public AdalineTraining(){
|
||||
|
||||
}
|
||||
|
||||
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
@Override
|
||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||
AdalineTrainingContext context = new AdalineTrainingContext(model, dataset);
|
||||
context.learningRate = learningRate;
|
||||
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
new SimpleDeltaStep(context),
|
||||
new SquareLossStep(context),
|
||||
new SimpleErrorRegistrationStep(context),
|
||||
new SimpleCorrectionStep(context)
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
|
||||
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
|
||||
.afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
|
||||
.withTimeMeasurement(true)
|
||||
.withVerbose(true, 1)
|
||||
.withVisualization(true, new GraphVisualizer())
|
||||
.run(context);
|
||||
}
|
||||
|
||||
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
int epoch = 1;
|
||||
int maxEpoch = 1000;
|
||||
int maxEpoch = 202;
|
||||
float errorThreshold = 0.0F;
|
||||
float mse;
|
||||
|
||||
do {
|
||||
if(epoch > maxEpoch) break;
|
||||
mse = 0;
|
||||
|
||||
for(DataSetEntry entry : dataSet) {
|
||||
for(DataSetEntry entry : dataSet) {
|
||||
this.updateInputs(n, entry);
|
||||
float prediction = n.predict();
|
||||
float expectation = dataSet.getLabel(entry).getValue();
|
||||
@@ -49,23 +76,22 @@ public class AdalineTraining {
|
||||
System.out.printf("predicted : %.2f, ", prediction);
|
||||
System.out.printf("expected : %.2f, ", expectation);
|
||||
System.out.printf("delta : %.2f, ", delta);
|
||||
System.out.printf("loss : %.2f\n", loss);
|
||||
System.out.printf("loss : %.5f\n", loss);
|
||||
}
|
||||
mse /= dataSet.size();
|
||||
System.out.printf("[Total error : %f]\n", mse);
|
||||
|
||||
System.out.println("[Final weights]");
|
||||
System.out.printf("Bias: %f\n", n.getBias().getWeight());
|
||||
int i = 1;
|
||||
for(Synapse syn : n.getSynapses()){
|
||||
System.out.printf("Syn %d: %f\n", i, syn.getWeight());
|
||||
i++;
|
||||
}
|
||||
epoch++;
|
||||
} while(mse > errorThreshold);
|
||||
|
||||
}
|
||||
|
||||
private List<Float> initCorrectorTerms(int number){
|
||||
List<Float> res = new ArrayList<>();
|
||||
for(int i = 0; i < number; i++){
|
||||
res.add(0F);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
private void updateInputs(Neuron n, DataSetEntry entry){
|
||||
int index = 0;
|
||||
for(float value : entry){
|
||||
@@ -80,10 +106,6 @@ public class AdalineTraining {
|
||||
|
||||
private float calculateLoss(float delta){
|
||||
return (float) Math.pow(delta, 2)/2;
|
||||
}
|
||||
|
||||
private float calculateWeightCorrection(float value, float delta){
|
||||
return value * delta;
|
||||
}
|
||||
}*/
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
||||
import com.naaturel.ANN.implementation.multiLayers.BackpropagationCorrectionStep;
|
||||
import com.naaturel.ANN.implementation.multiLayers.GradientBackpropagationContext;
|
||||
import com.naaturel.ANN.implementation.multiLayers.ErrorSignalStep;
|
||||
import com.naaturel.ANN.implementation.multiLayers.OutputLayerErrorStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import java.util.List;
|
||||
|
||||
public class GradientBackpropagationTraining implements Trainer {
|
||||
@Override
|
||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||
GradientBackpropagationContext context =
|
||||
new GradientBackpropagationContext(model, dataset, learningRate, 10);
|
||||
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
new OutputLayerErrorStep(context),
|
||||
new ErrorSignalStep(context),
|
||||
new BackpropagationCorrectionStep(context),
|
||||
new SquareLossStep(context)
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.00F || ctx.epoch > epoch)
|
||||
.beforeEpoch(ctx -> {
|
||||
ctx.globalLoss = 0.0F;
|
||||
})
|
||||
.afterEpoch(ctx -> {
|
||||
ctx.globalLoss /= dataset.size();
|
||||
})
|
||||
.withVerbose(true,epoch/10)
|
||||
.withTimeMeasurement(true)
|
||||
.run(context);
|
||||
}
|
||||
}
|
||||
@@ -1,26 +1,64 @@
|
||||
package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentErrorStrategy;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentTrainingContext;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.GradientDescentCorrectionStrategy;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.SquareLossStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class GradientDescentTraining {
|
||||
public class GradientDescentTraining implements Trainer {
|
||||
|
||||
public GradientDescentTraining(){
|
||||
|
||||
}
|
||||
|
||||
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
@Override
|
||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||
GradientDescentTrainingContext context = new GradientDescentTrainingContext(model, dataset);
|
||||
context.learningRate = learningRate;
|
||||
context.correctorTerms = new ArrayList<>();
|
||||
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
new SimpleDeltaStep(context),
|
||||
new SquareLossStep(context),
|
||||
new GradientDescentErrorStrategy(context)
|
||||
);
|
||||
|
||||
new TrainingPipeline(steps)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > epoch)
|
||||
.beforeEpoch(ctx -> {
|
||||
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
||||
gdCtx.globalLoss = 0.0F;
|
||||
gdCtx.correctorTerms.clear();
|
||||
for(int i = 0; i < gdCtx.model.synCount(); i++){
|
||||
gdCtx.correctorTerms.add(0F);
|
||||
}
|
||||
})
|
||||
.afterEpoch(ctx -> {
|
||||
context.globalLoss /= context.dataset.size();
|
||||
new GradientDescentCorrectionStrategy(context).run();
|
||||
})
|
||||
//.withVerbose(true)
|
||||
.withTimeMeasurement(true)
|
||||
.withVisualization(true, new GraphVisualizer())
|
||||
.run(context);
|
||||
}
|
||||
|
||||
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
int epoch = 1;
|
||||
int maxEpoch = 1000;
|
||||
float errorThreshold = 0.0F;
|
||||
int maxEpoch = 402;
|
||||
float errorThreshold = 0F;
|
||||
float mse;
|
||||
|
||||
do {
|
||||
@@ -54,6 +92,7 @@ public class GradientDescentTraining {
|
||||
System.out.printf("delta : %.2f, ", delta);
|
||||
System.out.printf("loss : %.2f\n", loss);
|
||||
}
|
||||
mse /= dataSet.size();
|
||||
System.out.printf("[Total error : %f]\n", mse);
|
||||
|
||||
float currentBias = n.getBias().getWeight();
|
||||
@@ -69,6 +108,13 @@ public class GradientDescentTraining {
|
||||
epoch++;
|
||||
} while(mse > errorThreshold);
|
||||
|
||||
System.out.println("[Final weights]");
|
||||
System.out.printf("Bias: %f\n", n.getBias().getWeight());
|
||||
int i = 1;
|
||||
for(Synapse syn : n.getSynapses()){
|
||||
System.out.printf("Syn %d: %f\n", i, syn.getWeight());
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
private List<Float> initCorrectorTerms(int number){
|
||||
@@ -95,8 +141,21 @@ public class GradientDescentTraining {
|
||||
return (float) Math.pow(delta, 2)/2;
|
||||
}
|
||||
|
||||
private float calculateWeightCorrection(float value, float delta){
|
||||
return value * delta;
|
||||
}
|
||||
public float computeThreshold(DataSet dataSet) {
|
||||
float sum = 0;
|
||||
for (DataSetEntry entry : dataSet) {
|
||||
sum += dataSet.getLabel(entry).getValue();
|
||||
}
|
||||
float mean = sum / dataSet.size();
|
||||
|
||||
float variance = 0;
|
||||
for (DataSetEntry entry : dataSet) {
|
||||
float diff = dataSet.getLabel(entry).getValue() - mean;
|
||||
variance += diff * diff;
|
||||
}
|
||||
variance /= dataSet.size();
|
||||
|
||||
return variance;
|
||||
}*/
|
||||
|
||||
}
|
||||
|
||||
@@ -1,19 +1,44 @@
|
||||
package com.naaturel.ANN.implementation.training;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStep;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
|
||||
public class SimpleTraining {
|
||||
import java.util.List;
|
||||
|
||||
public class SimpleTraining implements Trainer {
|
||||
|
||||
public SimpleTraining() {
|
||||
|
||||
}
|
||||
|
||||
public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
@Override
|
||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||
SimpleTrainingContext context = new SimpleTrainingContext(model, dataset);
|
||||
context.dataset = dataset;
|
||||
context.model = model;
|
||||
context.learningRate = learningRate;
|
||||
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
new SimpleDeltaStep(context),
|
||||
new SimpleLossStrategy(context),
|
||||
new SimpleErrorRegistrationStep(context),
|
||||
new SimpleCorrectionStep(context)
|
||||
);
|
||||
|
||||
TrainingPipeline pipeline = new TrainingPipeline(steps);
|
||||
pipeline
|
||||
.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > epoch)
|
||||
.beforeEpoch(ctx -> ctx.globalLoss = 0)
|
||||
.withVerbose(true, 1)
|
||||
.run(context);
|
||||
}
|
||||
|
||||
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
|
||||
int epoch = 1;
|
||||
int errorCount;
|
||||
|
||||
@@ -64,5 +89,5 @@ public class SimpleTraining {
|
||||
private float calculateLoss(float delta){
|
||||
return Math.abs(delta);
|
||||
}
|
||||
|
||||
*/
|
||||
}
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package com.naaturel.ANN.infrastructure.config;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class ConfigDto {
|
||||
|
||||
@JsonProperty("model")
|
||||
private Map<String, Object> modelConfig;
|
||||
|
||||
@JsonProperty("training")
|
||||
private Map<String, Object> trainingConfig;
|
||||
|
||||
@JsonProperty("dataset")
|
||||
private Map<String, Object> datasetConfig;
|
||||
|
||||
public <T> T getModelProperty(String key, Class<T> type) {
|
||||
Object value = find(key, this.modelConfig);
|
||||
if (value instanceof List<?> list && type.isArray()) {
|
||||
int[] arr = new int[list.size()];
|
||||
for (int i = 0; i < list.size(); i++) {
|
||||
arr[i] = ((Number) list.get(i)).intValue();
|
||||
}
|
||||
return type.cast(arr);
|
||||
}
|
||||
if (!type.isInstance(value)) {
|
||||
throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
|
||||
}
|
||||
return type.cast(value);
|
||||
}
|
||||
|
||||
public <T> T getTrainingProperty(String key, Class<T> type) {
|
||||
Object value = find(key, this.trainingConfig);
|
||||
if (!type.isInstance(value)) {
|
||||
throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
|
||||
}
|
||||
return type.cast(value);
|
||||
}
|
||||
|
||||
public <T> T getDatasetProperty(String key, Class<T> type) {
|
||||
Object value = find(key, this.datasetConfig);
|
||||
if (!type.isInstance(value)) {
|
||||
throw new RuntimeException("Property '" + key + "' is not of type " + type.getSimpleName());
|
||||
}
|
||||
return type.cast(value);
|
||||
}
|
||||
|
||||
private Object find(String key, Map<String, Object> config){
|
||||
if(!config.containsKey(key)) throw new RuntimeException("Unable to find property for key '" + key + "'");
|
||||
return config.get(key);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package com.naaturel.ANN.infrastructure.config;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
public class ConfigLoader {
|
||||
|
||||
|
||||
public static ConfigDto load(String path) throws Exception {
|
||||
try {
|
||||
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
ConfigDto config = mapper.readValue(new File("config.json"), ConfigDto.class);
|
||||
|
||||
return config;
|
||||
} catch (Exception e){
|
||||
throw new Exception("Unable to load config : " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package com.naaturel.ANN.infrastructure.dataset;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class DataSet implements Iterable<DataSetEntry>{
|
||||
|
||||
private final Map<DataSetEntry, Labels> data;
|
||||
|
||||
private final int nbrInputs;
|
||||
private final int nbrLabels;
|
||||
|
||||
public DataSet() {
|
||||
this(new LinkedHashMap<>()); //ensure iteration order is the same as insertion order
|
||||
}
|
||||
|
||||
public DataSet(Map<DataSetEntry, Labels> data){
|
||||
this.data = data;
|
||||
this.nbrInputs = this.calculateNbrInput();
|
||||
this.nbrLabels = this.calculateNbrLabel();
|
||||
}
|
||||
|
||||
private int calculateNbrInput(){
|
||||
//assumes every entry are the same length
|
||||
Stream<DataSetEntry> keyStream = this.data.keySet().stream();
|
||||
Optional<DataSetEntry> firstEntry = keyStream.findFirst();
|
||||
return firstEntry.map(inputs -> inputs.getData().size()).orElse(0);
|
||||
}
|
||||
|
||||
private int calculateNbrLabel(){
|
||||
//assumes every label are the same length
|
||||
Stream<DataSetEntry> keyStream = this.data.keySet().stream();
|
||||
Optional<DataSetEntry> firstEntry = keyStream.findFirst();
|
||||
return firstEntry.map(inputs -> this.data.get(inputs).getValues().size()).orElse(0);
|
||||
}
|
||||
|
||||
|
||||
public int size() {
|
||||
return data.size();
|
||||
}
|
||||
|
||||
public int getNbrInputs() {
|
||||
return this.nbrInputs;
|
||||
}
|
||||
|
||||
public int getNbrLabels(){
|
||||
return this.nbrLabels;
|
||||
}
|
||||
|
||||
public List<DataSetEntry> getData(){
|
||||
return new ArrayList<>(this.data.keySet());
|
||||
}
|
||||
|
||||
public List<Float> getLabelsAsFloat(DataSetEntry entry){
|
||||
return this.data.get(entry).getValues();
|
||||
}
|
||||
|
||||
public DataSet toNormalized() {
|
||||
List<DataSetEntry> entries = this.getData();
|
||||
|
||||
float maxAbs = entries.stream()
|
||||
.flatMap(e -> e.getData().stream())
|
||||
.map(Input::getValue)
|
||||
.map(Math::abs)
|
||||
.max(Float::compare)
|
||||
.orElse(1.0F);
|
||||
|
||||
Map<DataSetEntry, Labels> normalized = new HashMap<>();
|
||||
for (DataSetEntry entry : entries) {
|
||||
List<Input> normalizedData = new ArrayList<>();
|
||||
|
||||
for (Input input : entry.getData()) {
|
||||
Input normalizedInput = new Input(Math.round((input.getValue() / maxAbs) * 100.0F) / 100.0F);
|
||||
normalizedData.add(normalizedInput);
|
||||
}
|
||||
|
||||
normalized.put(new DataSetEntry(normalizedData), this.data.get(entry));
|
||||
}
|
||||
|
||||
return new DataSet(normalized);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<DataSetEntry> iterator() {
|
||||
return this.data.keySet().iterator();
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,18 @@
|
||||
package com.naaturel.ANN.domain.model.dataset;
|
||||
package com.naaturel.ANN.infrastructure.dataset;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
public class DataSetEntry implements Iterable<Float> {
|
||||
public class DataSetEntry implements Iterable<Input> {
|
||||
|
||||
private List<Float> data;
|
||||
private List<Input> data;
|
||||
|
||||
public DataSetEntry(List<Float> data){
|
||||
public DataSetEntry(List<Input> data){
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
public List<Float> getData() {
|
||||
public List<Input> getData() {
|
||||
return new ArrayList<>(data);
|
||||
}
|
||||
|
||||
@@ -28,7 +30,7 @@ public class DataSetEntry implements Iterable<Float> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<Float> iterator() {
|
||||
public Iterator<Input> iterator() {
|
||||
return this.data.iterator();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package com.naaturel.ANN.infrastructure.dataset;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
|
||||
public class DatasetExtractor {
|
||||
|
||||
public DataSet extract(String path, int nbrLabels) {
|
||||
Map<DataSetEntry, Labels> data = new LinkedHashMap<>();
|
||||
|
||||
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
String[] parts = line.split(",");
|
||||
|
||||
String[] rawInputs = Arrays.copyOfRange(parts, 0, parts.length-nbrLabels);
|
||||
String[] rawLabels = Arrays.copyOfRange(parts, parts.length-nbrLabels, parts.length);
|
||||
|
||||
List<Input> inputs = new ArrayList<>();
|
||||
List<Float> labels = new ArrayList<>();
|
||||
|
||||
for (String entry : rawInputs) {
|
||||
inputs.add(new Input(Float.parseFloat(entry.trim())));
|
||||
}
|
||||
|
||||
for (String entry : rawLabels) {
|
||||
labels.add(Float.parseFloat(entry.trim()));
|
||||
}
|
||||
|
||||
data.put(new DataSetEntry(inputs), new Labels(labels));
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to read dataset from: " + path, e);
|
||||
}
|
||||
|
||||
return new DataSet(data);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.naaturel.ANN.infrastructure.dataset;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class Labels {
|
||||
|
||||
private final List<Float> values;
|
||||
|
||||
public Labels(List<Float> value){
|
||||
this.values = value;
|
||||
}
|
||||
|
||||
public List<Float> getValues() {
|
||||
return values.stream().toList();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.naaturel.ANN.infrastructure.persistence;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class ModelDto {
|
||||
|
||||
private List<NeuronDto> neurons;
|
||||
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package com.naaturel.ANN.infrastructure.persistence;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.databind.node.ArrayNode;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class ModelSnapshot {
|
||||
|
||||
private Model model;
|
||||
private final ObjectMapper mapper;
|
||||
|
||||
public ModelSnapshot(){
|
||||
this(null);
|
||||
}
|
||||
|
||||
public ModelSnapshot(Model model){
|
||||
this.model = model;
|
||||
mapper = new ObjectMapper();
|
||||
}
|
||||
|
||||
public Model getModel() {
|
||||
return model;
|
||||
}
|
||||
|
||||
public void saveToFile(String path) throws Exception {
|
||||
|
||||
ArrayNode root = mapper.createArrayNode();
|
||||
model.forEachNeuron(n -> {
|
||||
|
||||
ObjectNode neuronNode = mapper.createObjectNode();
|
||||
neuronNode.put("id", n.getId());
|
||||
neuronNode.put("layerIndex", model.layerIndexOf(n));
|
||||
|
||||
ArrayNode weights = mapper.createArrayNode();
|
||||
for (int i = 0; i < n.synCount(); i++) {
|
||||
float weight = n.getWeight(i);
|
||||
weights.add(weight);
|
||||
}
|
||||
neuronNode.set("weights", weights);
|
||||
root.add(neuronNode);
|
||||
});
|
||||
|
||||
mapper.writerWithDefaultPrettyPrinter().writeValue(new File(path), root);
|
||||
}
|
||||
|
||||
public void loadFromFile(String path) throws Exception {
|
||||
ArrayNode root = (ArrayNode) mapper.readTree(new File(path));
|
||||
|
||||
Map<Integer, List<Neuron>> neuronsByLayer = new LinkedHashMap<>();
|
||||
|
||||
root.forEach(neuronNode -> {
|
||||
int id = neuronNode.get("id").asInt();
|
||||
int layerIndex = neuronNode.get("layerIndex").asInt();
|
||||
ArrayNode weightsNode = (ArrayNode) neuronNode.get("weights");
|
||||
|
||||
Bias bias = new Bias(new Weight(weightsNode.get(0).floatValue()));
|
||||
Synapse[] synapses = new Synapse[weightsNode.size() - 1];
|
||||
for (int i = 0; i < synapses.length; i++) {
|
||||
synapses[i] = new Synapse(new Input(0), new Weight(weightsNode.get(i + 1).floatValue()));
|
||||
}
|
||||
|
||||
Neuron n = new Neuron(id, synapses, bias, new TanH());
|
||||
neuronsByLayer.computeIfAbsent(layerIndex, k -> new ArrayList<>()).add(n);
|
||||
});
|
||||
|
||||
Layer[] layers = neuronsByLayer.values().stream()
|
||||
.map(neurons -> new Layer(neurons.toArray(new Neuron[0])))
|
||||
.toArray(Layer[]::new);
|
||||
|
||||
this.model = new FullyConnectedNetwork(layers);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
package com.naaturel.ANN.infrastructure.persistence;
|
||||
|
||||
public class NeuronDto {
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package com.naaturel.ANN.infrastructure.visualization;
|
||||
|
||||
import org.jfree.chart.ChartFactory;
|
||||
import org.jfree.chart.ChartPanel;
|
||||
import org.jfree.chart.JFreeChart;
|
||||
import org.jfree.chart.plot.XYPlot;
|
||||
import org.jfree.data.xy.XYSeries;
|
||||
import org.jfree.data.xy.XYSeriesCollection;
|
||||
|
||||
import javax.swing.*;
|
||||
|
||||
public class GraphVisualizer {
|
||||
|
||||
XYSeriesCollection dataset;
|
||||
|
||||
public GraphVisualizer(){
|
||||
this.dataset = new XYSeriesCollection();
|
||||
}
|
||||
|
||||
public void addPoint(String title, float x, float y) {
|
||||
if (this.dataset.getSeriesIndex(title) == -1)
|
||||
this.dataset.addSeries(new XYSeries(title));
|
||||
this.dataset.getSeries(title).add(x, y);
|
||||
}
|
||||
|
||||
public void addEquation(String title, float y1, float y2, float k, float xMin, float xMax) {
|
||||
for (float x1 = xMin; x1 <= xMax; x1 += 0.01f) {
|
||||
float x2 = (-y1 * x1 - k) / y2;
|
||||
addPoint(title, x1, x2);
|
||||
}
|
||||
}
|
||||
|
||||
public void buildLineGraph(){
|
||||
JFreeChart chart = ChartFactory.createXYLineChart(
|
||||
"Model learning", "X", "Y", dataset
|
||||
);
|
||||
JFrame frame = new JFrame("Training Loss");
|
||||
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
|
||||
frame.add(new ChartPanel(chart));
|
||||
frame.pack();
|
||||
frame.setVisible(true);
|
||||
}
|
||||
|
||||
public void buildScatterGraph(int lower, int upper){
|
||||
JFreeChart chart = ChartFactory.createScatterPlot(
|
||||
"Predictions", "X", "Y", dataset
|
||||
);
|
||||
XYPlot plot = chart.getXYPlot();
|
||||
plot.getDomainAxis().setRange(lower, upper);
|
||||
plot.getRangeAxis().setRange(lower, upper);
|
||||
|
||||
JFrame frame = new JFrame("Predictions");
|
||||
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
|
||||
frame.add(new ChartPanel(chart));
|
||||
frame.pack();
|
||||
frame.setVisible(true);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
package com.naaturel.ANN.infrastructure.visualization;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
|
||||
import javax.swing.*;
|
||||
import java.awt.*;
|
||||
import java.util.*;
|
||||
import java.util.List;
|
||||
|
||||
public class ModelVisualizer {
|
||||
|
||||
private final JFrame frame;
|
||||
private final Model model;
|
||||
private boolean withWeights;
|
||||
|
||||
public ModelVisualizer(Model model){
|
||||
this.frame = initFrame();
|
||||
this.model = model;
|
||||
this.withWeights = false;
|
||||
}
|
||||
|
||||
public ModelVisualizer withWeights(boolean value){
|
||||
this.withWeights = value;
|
||||
return this;
|
||||
}
|
||||
|
||||
public void display() {
|
||||
JPanel panel = buildPanel();
|
||||
frame.add(panel);
|
||||
frame.setVisible(true);
|
||||
}
|
||||
|
||||
private JFrame initFrame(){
|
||||
JFrame frame = new JFrame("Model Visualizer");
|
||||
frame.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
|
||||
frame.setSize(800, 600);
|
||||
return frame;
|
||||
}
|
||||
|
||||
private Map<Integer, Point> computeNeuronPositions(int width, int height) {
|
||||
|
||||
Map<Integer, List<Neuron>> neuronsByLayer = new LinkedHashMap<>();
|
||||
model.forEachNeuron(n -> {
|
||||
int layerIndex = model.layerIndexOf(n);
|
||||
neuronsByLayer.computeIfAbsent(layerIndex, k -> new ArrayList<>()).add(n);
|
||||
});
|
||||
|
||||
Map<Integer, Point> neuronPositions = new HashMap<>();
|
||||
int layerCount = neuronsByLayer.size();
|
||||
int layerX = width / (layerCount + 1);
|
||||
|
||||
for (Map.Entry<Integer, List<Neuron>> entry : neuronsByLayer.entrySet()) {
|
||||
int layerIndex = entry.getKey();
|
||||
List<Neuron> neurons = entry.getValue();
|
||||
int x = layerX * (layerIndex + 1);
|
||||
int neuronCount = neurons.size();
|
||||
|
||||
for (int i = 0; i < neuronCount; i++) {
|
||||
int y = height / (neuronCount + 1) * (i + 1);
|
||||
neuronPositions.put(neurons.get(i).getId(), new Point(x, y));
|
||||
}
|
||||
}
|
||||
return neuronPositions;
|
||||
}
|
||||
|
||||
private void drawConnections(Graphics2D g2, Map<Integer, Point> neuronPositions) {
|
||||
model.forEachNeuron(n -> {
|
||||
Point from = neuronPositions.get(n.getId());
|
||||
int neuronIndex = model.indexInLayerOf(n);
|
||||
model.forEachNeuronConnectedTo(n, connected -> {
|
||||
Point to = neuronPositions.get(connected.getId());
|
||||
g2.setColor(Color.LIGHT_GRAY);
|
||||
g2.drawLine(from.x, from.y, to.x, to.y);
|
||||
|
||||
if(!this.withWeights) return;
|
||||
int mx = (from.x + to.x) / 2;
|
||||
int my = (from.y + to.y) / 2;
|
||||
float weight = connected.getWeight(neuronIndex + 1);
|
||||
g2.setColor(Color.DARK_GRAY);
|
||||
g2.drawString(String.format("%.2f", weight), mx, my);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
private void drawNeurons(Graphics2D g2, Map<Integer, Point> neuronPositions, int neuronRadius) {
|
||||
model.forEachNeuron(n -> {
|
||||
Point p = neuronPositions.get(n.getId());
|
||||
g2.setColor(Color.WHITE);
|
||||
g2.fillOval(p.x - neuronRadius, p.y - neuronRadius, neuronRadius * 2, neuronRadius * 2);
|
||||
g2.setColor(Color.BLACK);
|
||||
g2.drawOval(p.x - neuronRadius, p.y - neuronRadius, neuronRadius * 2, neuronRadius * 2);
|
||||
g2.drawString(String.valueOf(n.getId()), p.x - 5, p.y + 5);
|
||||
});
|
||||
}
|
||||
|
||||
private JPanel buildPanel() {
|
||||
int neuronRadius = 20;
|
||||
return new JPanel() {
|
||||
@Override
|
||||
protected void paintComponent(Graphics g) {
|
||||
super.paintComponent(g);
|
||||
Graphics2D g2 = (Graphics2D) g;
|
||||
g2.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
|
||||
|
||||
Map<Integer, Point> neuronPositions = computeNeuronPositions(getWidth(), getHeight());
|
||||
drawConnections(g2, neuronPositions);
|
||||
drawNeurons(g2, neuronPositions, neuronRadius);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
4
src/main/resources/assets/and-gradient.csv
Normal file
4
src/main/resources/assets/and-gradient.csv
Normal file
@@ -0,0 +1,4 @@
|
||||
0,0,-1
|
||||
0,1,-1
|
||||
1,0,-1
|
||||
1,1,1
|
||||
|
@@ -18,4 +18,4 @@
|
||||
4,6,-1
|
||||
4,7,-1
|
||||
4,9,1
|
||||
4,10,1
|
||||
4,10,1
|
||||
|
4
src/main/resources/assets/xor.csv
Normal file
4
src/main/resources/assets/xor.csv
Normal file
@@ -0,0 +1,4 @@
|
||||
0,0,0
|
||||
0,1,1
|
||||
1,0,1
|
||||
1,1,0
|
||||
|
37
src/main/resources/snapshots/best-snapshot.json
Normal file
37
src/main/resources/snapshots/best-snapshot.json
Normal file
@@ -0,0 +1,37 @@
|
||||
[ {
|
||||
"id" : 0,
|
||||
"layerIndex" : 0,
|
||||
"weights" : [ 0.73659766, -0.77503574, -0.23966503, 0.5794077, -0.1634599, 0.6023351, 0.770439, -0.53199804, -0.75792193, 0.36339867, 0.017356396, -0.3005041, -0.5709046, 0.8757278, 0.22738981, -0.22997773, 0.81583726, -0.6008339, 0.8148706, -0.27952576, 0.11130214, 0.21506107, -0.96409214, 0.8534266, 0.9998076, 0.8971077, -0.55812895, 0.28677964, -0.4315225, -0.12088442, 0.41834033, -0.83330417, 0.013990879, -0.021193504, 0.95400894, 0.24115086, 0.122039676, 0.7069808, 0.74929047, 0.2558391, 0.1307528, 0.45781684, -0.19833839 ]
|
||||
}, {
|
||||
"id" : 1,
|
||||
"layerIndex" : 0,
|
||||
"weights" : [ -0.25173664, -0.15121317, 0.3947369, 0.39584184, 0.8823843, -0.43822396, 0.38901758, -0.83357096, 0.4349022, -0.57956433, -0.78882015, -0.7952547, 0.87495327, 0.20618606, -0.802194, -0.108419776, 0.47257674, -0.7384255, 0.41351187, -0.6334251, -0.61948025, 0.40434432, -0.18576205, -0.47782838, 0.57454073, -0.36851442, -0.7262291, 0.3893906, 0.83334255, -0.6979947, 0.43623865, 0.5753021, 0.0041633844, 0.6598717, 0.21344411, -0.40663266, 0.73282886, -0.00848031, -0.7269838, -0.36129093, -0.7280016, -0.0039184093, -0.71608007 ]
|
||||
}, {
|
||||
"id" : 2,
|
||||
"layerIndex" : 1,
|
||||
"weights" : [ 0.0760473, -0.54335964, -0.789695 ]
|
||||
}, {
|
||||
"id" : 3,
|
||||
"layerIndex" : 1,
|
||||
"weights" : [ -0.16066408, 0.97240365, 0.72418904 ]
|
||||
}, {
|
||||
"id" : 4,
|
||||
"layerIndex" : 1,
|
||||
"weights" : [ 0.58718896, -0.59475064, 0.81476605 ]
|
||||
}, {
|
||||
"id" : 5,
|
||||
"layerIndex" : 1,
|
||||
"weights" : [ -0.09139645, 0.71847045, -0.19625723 ]
|
||||
}, {
|
||||
"id" : 6,
|
||||
"layerIndex" : 2,
|
||||
"weights" : [ 0.60579014, 0.7003196, 0.82173157, 0.7627305, 0.83753014 ]
|
||||
}, {
|
||||
"id" : 7,
|
||||
"layerIndex" : 2,
|
||||
"weights" : [ -0.7915581, -0.5355048, -1.0218064, 0.35036424, -0.6445867 ]
|
||||
}, {
|
||||
"id" : 8,
|
||||
"layerIndex" : 3,
|
||||
"weights" : [ -0.24528235, 0.08879423, -0.46046266 ]
|
||||
} ]
|
||||
92
src/test/java/adaline/AdalineTest.java
Normal file
92
src/test/java/adaline/AdalineTest.java
Normal file
@@ -0,0 +1,92 @@
|
||||
package adaline;
|
||||
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.*;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStep;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.SimplePredictionStep;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class AdalineTest {
|
||||
|
||||
private DataSet dataset;
|
||||
private AdalineTrainingContext context;
|
||||
|
||||
private List<Synapse> synapses;
|
||||
private Bias bias;
|
||||
private FullyConnectedNetwork network;
|
||||
|
||||
private TrainingPipeline pipeline;
|
||||
|
||||
@BeforeEach
|
||||
public void init(){
|
||||
dataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1);
|
||||
|
||||
List<Synapse> syns = new ArrayList<>();
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
|
||||
bias = new Bias(new Weight(0));
|
||||
|
||||
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
|
||||
Layer layer = new Layer(List.of(neuron));
|
||||
network = new FullyConnectedNetwork(List.of(layer));
|
||||
|
||||
context = new AdalineTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = network;
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStep(context)),
|
||||
new DeltaStep(new SimpleDeltaStep(context)),
|
||||
new LossStep(new SquareLossStep(context)),
|
||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStep(context))
|
||||
);
|
||||
|
||||
pipeline = new TrainingPipeline(steps)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.1329F || ctx.epoch > 10000)
|
||||
.beforeEpoch(ctx -> {
|
||||
ctx.globalLoss = 0.0F;
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_the_whole_algorithm(){
|
||||
|
||||
List<Float> expectedGlobalLosses = List.of(
|
||||
0.501522F,
|
||||
0.498601F
|
||||
);
|
||||
|
||||
context.learningRate = 0.03F;
|
||||
pipeline.afterEpoch(ctx -> {
|
||||
ctx.globalLoss /= context.dataset.size();
|
||||
|
||||
int index = ctx.epoch-1;
|
||||
if(index >= expectedGlobalLosses.size()) return;
|
||||
|
||||
//assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f);
|
||||
});
|
||||
|
||||
pipeline.run(context);
|
||||
assertEquals(214, context.epoch);
|
||||
}
|
||||
}
|
||||
|
||||
98
src/test/java/gradientDescent/GradientDescentTest.java
Normal file
98
src/test/java/gradientDescent/GradientDescentTest.java
Normal file
@@ -0,0 +1,98 @@
|
||||
package gradientDescent;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.*;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
|
||||
public class GradientDescentTest {
|
||||
|
||||
private DataSet dataset;
|
||||
private GradientDescentTrainingContext context;
|
||||
|
||||
private List<Synapse> synapses;
|
||||
private Bias bias;
|
||||
private FullyConnectedNetwork network;
|
||||
|
||||
private TrainingPipeline pipeline;
|
||||
|
||||
@BeforeEach
|
||||
public void init(){
|
||||
dataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1);
|
||||
|
||||
List<Synapse> syns = new ArrayList<>();
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
|
||||
bias = new Bias(new Weight(0));
|
||||
|
||||
Neuron neuron = new Neuron(syns, bias, new Linear(1, 0));
|
||||
Layer layer = new Layer(List.of(neuron));
|
||||
network = new FullyConnectedNetwork(List.of(layer));
|
||||
|
||||
context = new GradientDescentTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = network;
|
||||
context.correctorTerms = new ArrayList<>();
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStep(context)),
|
||||
new DeltaStep(new SimpleDeltaStep(context)),
|
||||
new LossStep(new SquareLossStep(context)),
|
||||
new ErrorRegistrationStep(new GradientDescentErrorStrategy(context))
|
||||
);
|
||||
|
||||
pipeline = new TrainingPipeline(steps)
|
||||
.stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 100)
|
||||
.beforeEpoch(ctx -> {
|
||||
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
||||
gdCtx.globalLoss = 0.0F;
|
||||
gdCtx.correctorTerms.clear();
|
||||
for (int i = 0; i < ctx.model.synCount(); i++){
|
||||
gdCtx.correctorTerms.add(0F);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_the_whole_algorithm(){
|
||||
|
||||
List<Float> expectedGlobalLosses = List.of(
|
||||
0.5F,
|
||||
0.38F,
|
||||
0.3176F,
|
||||
0.272096F,
|
||||
0.237469F
|
||||
);
|
||||
|
||||
context.learningRate = 0.2F;
|
||||
pipeline.afterEpoch(ctx -> {
|
||||
context.globalLoss /= context.dataset.size();
|
||||
new GradientDescentCorrectionStrategy(context).run();
|
||||
|
||||
int index = ctx.epoch-1;
|
||||
if(index >= expectedGlobalLosses.size()) return;
|
||||
|
||||
assertEquals(expectedGlobalLosses.get(index), context.globalLoss, 0.00001f);
|
||||
});
|
||||
|
||||
pipeline
|
||||
.withVerbose(true)
|
||||
.run(context);
|
||||
assertEquals(67, context.epoch);
|
||||
}
|
||||
}
|
||||
84
src/test/java/perceptron/SimplePerceptronTest.java
Normal file
84
src/test/java/perceptron/SimplePerceptronTest.java
Normal file
@@ -0,0 +1,84 @@
|
||||
package perceptron;
|
||||
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||
import com.naaturel.ANN.implementation.training.steps.*;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
|
||||
public class SimplePerceptronTest {
|
||||
|
||||
private DataSet dataset;
|
||||
private SimpleTrainingContext context;
|
||||
|
||||
private List<Synapse> synapses;
|
||||
private Bias bias;
|
||||
private FullyConnectedNetwork network;
|
||||
|
||||
private TrainingPipeline pipeline;
|
||||
|
||||
@BeforeEach
|
||||
public void init(){
|
||||
dataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv", 1);
|
||||
|
||||
List<Synapse> syns = new ArrayList<>();
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||
|
||||
bias = new Bias(new Weight(0));
|
||||
|
||||
Neuron neuron = new Neuron(syns, bias, new Heaviside());
|
||||
Layer layer = new Layer(List.of(neuron));
|
||||
network = new FullyConnectedNetwork(List.of(layer));
|
||||
|
||||
context = new SimpleTrainingContext();
|
||||
context.dataset = dataset;
|
||||
context.model = network;
|
||||
|
||||
List<TrainingStep> steps = List.of(
|
||||
new PredictionStep(new SimplePredictionStep(context)),
|
||||
new DeltaStep(new SimpleDeltaStep(context)),
|
||||
new LossStep(new SimpleLossStrategy(context)),
|
||||
new ErrorRegistrationStep(new SimpleErrorRegistrationStep(context)),
|
||||
new WeightCorrectionStep(new SimpleCorrectionStep(context))
|
||||
);
|
||||
|
||||
pipeline = new TrainingPipeline(steps);
|
||||
pipeline.stopCondition(ctx -> ctx.globalLoss == 0.0F || ctx.epoch > 100);
|
||||
pipeline.beforeEpoch(ctx -> ctx.globalLoss = 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_the_whole_algorithm(){
|
||||
|
||||
List<Float> expectedGlobalLosses = List.of(
|
||||
2.0F,
|
||||
3.0F,
|
||||
3.0F,
|
||||
2.0F,
|
||||
1.0F,
|
||||
0.0F
|
||||
);
|
||||
|
||||
context.learningRate = 1F;
|
||||
pipeline.afterEpoch(ctx -> {
|
||||
int index = ctx.epoch-1;
|
||||
assertEquals(expectedGlobalLosses.get(index), context.globalLoss);
|
||||
});
|
||||
|
||||
pipeline.run(context);
|
||||
assertEquals(6, context.epoch);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user