Add multi-layer support
This commit is contained in:
@@ -10,6 +10,8 @@ repositories {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
|
implementation("org.jfree:jfreechart:1.5.4")
|
||||||
|
|
||||||
testImplementation(platform("org.junit:junit-bom:5.10.0"))
|
testImplementation(platform("org.junit:junit-bom:5.10.0"))
|
||||||
testImplementation("org.junit.jupiter:junit-jupiter")
|
testImplementation("org.junit.jupiter:junit-jupiter")
|
||||||
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
|
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
|
||||||
|
|||||||
@@ -1,16 +1,13 @@
|
|||||||
package com.naaturel.ANN;
|
package com.naaturel.ANN;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.Heaviside;
|
|
||||||
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
|
||||||
import com.naaturel.ANN.implementation.training.AdalineTraining;
|
import com.naaturel.ANN.implementation.training.AdalineTraining;
|
||||||
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
|
||||||
import com.naaturel.ANN.implementation.training.SimpleTraining;
|
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
@@ -18,20 +15,27 @@ public class Main {
|
|||||||
|
|
||||||
public static void main(String[] args){
|
public static void main(String[] args){
|
||||||
|
|
||||||
|
int nbrInput = 3;
|
||||||
|
int nbrClass = 3;
|
||||||
|
|
||||||
DataSet dataset = new DatasetExtractor()
|
DataSet dataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv");
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv", nbrClass);
|
||||||
|
|
||||||
DataSet andDataset = new DatasetExtractor()
|
List<Neuron> neurons = new ArrayList<>();
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv");
|
|
||||||
|
|
||||||
|
for (int i=0; i < nbrClass; i++){
|
||||||
List<Synapse> syns = new ArrayList<>();
|
List<Synapse> syns = new ArrayList<>();
|
||||||
|
for (int j=0; j < nbrInput; j++){
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
}
|
||||||
|
|
||||||
Bias bias = new Bias(new Weight(0));
|
Bias bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
|
Neuron n = new Neuron(syns, bias, new Linear());
|
||||||
Layer layer = new Layer(List.of(neuron));
|
neurons.add(n);
|
||||||
|
}
|
||||||
|
|
||||||
|
Layer layer = new Layer(neurons);
|
||||||
Network network = new Network(List.of(layer));
|
Network network = new Network(List.of(layer));
|
||||||
|
|
||||||
Trainer trainer = new AdalineTraining();
|
Trainer trainer = new AdalineTraining();
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
|
|
||||||
public interface ActivationFunction {
|
public interface ActivationFunction {
|
||||||
|
|
||||||
float accept(Neuron n);
|
float accept(Neuron n);
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
package com.naaturel.ANN.domain.abstraction;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
import com.naaturel.ANN.domain.model.neuron.Input;
|
||||||
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.function.BiConsumer;
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
public interface Model {
|
public interface Model {
|
||||||
int synCount();
|
int synCount();
|
||||||
|
void forEachNeuron(Consumer<Neuron> consumer);
|
||||||
void forEachSynapse(Consumer<Synapse> consumer);
|
void forEachSynapse(Consumer<Synapse> consumer);
|
||||||
List<Float> predict(List<Input> inputs);
|
List<Float> predict(List<Input> inputs);
|
||||||
|
|
||||||
|
|||||||
@@ -1,42 +0,0 @@
|
|||||||
package com.naaturel.ANN.domain.abstraction;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Weight;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public abstract class Neuron implements Model {
|
|
||||||
|
|
||||||
protected List<Synapse> synapses;
|
|
||||||
protected Bias bias;
|
|
||||||
protected ActivationFunction activationFunction;
|
|
||||||
|
|
||||||
public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){
|
|
||||||
this.synapses = synapses;
|
|
||||||
this.bias = bias;
|
|
||||||
this.activationFunction = func;
|
|
||||||
}
|
|
||||||
|
|
||||||
public abstract float calculateWeightedSum();
|
|
||||||
|
|
||||||
public void updateBias(Weight weight) {
|
|
||||||
this.bias.setWeight(weight.getValue());
|
|
||||||
}
|
|
||||||
|
|
||||||
public void updateWeight(int index, Weight weight) {
|
|
||||||
this.synapses.get(index).setWeight(weight.getValue());
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void setInputs(List<Input> inputs){
|
|
||||||
for(int i = 0; i < inputs.size() && i < synapses.size(); i++){
|
|
||||||
Synapse syn = this.synapses.get(i);
|
|
||||||
syn.setInput(inputs.get(i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int synCount() {
|
|
||||||
return this.synapses.size()+1; //take the bias in account
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -2,16 +2,17 @@ package com.naaturel.ANN.domain.abstraction;
|
|||||||
|
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public abstract class TrainingContext {
|
public abstract class TrainingContext {
|
||||||
public Model model;
|
public Model model;
|
||||||
public DataSet dataset;
|
public DataSet dataset;
|
||||||
public DataSetEntry currentEntry;
|
public DataSetEntry currentEntry;
|
||||||
|
|
||||||
public Label currentLabel;
|
public List<Float> expectations;
|
||||||
public float prediction;
|
public List<Float> predictions;
|
||||||
public float delta;
|
public List<Float> deltas;
|
||||||
|
|
||||||
public float globalLoss;
|
public float globalLoss;
|
||||||
public float localLoss;
|
public float localLoss;
|
||||||
|
|||||||
@@ -6,13 +6,13 @@ import java.util.*;
|
|||||||
|
|
||||||
public class DataSet implements Iterable<DataSetEntry>{
|
public class DataSet implements Iterable<DataSetEntry>{
|
||||||
|
|
||||||
private Map<DataSetEntry, Label> data;
|
private final Map<DataSetEntry, Labels> data;
|
||||||
|
|
||||||
public DataSet() {
|
public DataSet() {
|
||||||
this(new LinkedHashMap<>());
|
this(new LinkedHashMap<>());
|
||||||
}
|
}
|
||||||
|
|
||||||
public DataSet(Map<DataSetEntry, Label> data){
|
public DataSet(Map<DataSetEntry, Labels> data){
|
||||||
this.data = data;
|
this.data = data;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,8 +24,8 @@ public class DataSet implements Iterable<DataSetEntry>{
|
|||||||
return new ArrayList<>(this.data.keySet());
|
return new ArrayList<>(this.data.keySet());
|
||||||
}
|
}
|
||||||
|
|
||||||
public Label getLabel(DataSetEntry entry){
|
public List<Float> getLabelsAsFloat(DataSetEntry entry){
|
||||||
return this.data.get(entry);
|
return this.data.get(entry).getValues();
|
||||||
}
|
}
|
||||||
|
|
||||||
public DataSet toNormalized() {
|
public DataSet toNormalized() {
|
||||||
@@ -38,13 +38,15 @@ public class DataSet implements Iterable<DataSetEntry>{
|
|||||||
.max(Float::compare)
|
.max(Float::compare)
|
||||||
.orElse(1.0F);
|
.orElse(1.0F);
|
||||||
|
|
||||||
Map<DataSetEntry, Label> normalized = new HashMap<>();
|
Map<DataSetEntry, Labels> normalized = new HashMap<>();
|
||||||
for (DataSetEntry entry : entries) {
|
for (DataSetEntry entry : entries) {
|
||||||
List<Input> normalizedData = new ArrayList<>();
|
List<Input> normalizedData = new ArrayList<>();
|
||||||
|
|
||||||
for (Input input : entry.getData()) {
|
for (Input input : entry.getData()) {
|
||||||
Input normalizedInput = new Input(Math.round((input.getValue() / maxAbs) * 100.0F) / 100.0F);
|
Input normalizedInput = new Input(Math.round((input.getValue() / maxAbs) * 100.0F) / 100.0F);
|
||||||
normalizedData.add(normalizedInput);
|
normalizedData.add(normalizedInput);
|
||||||
}
|
}
|
||||||
|
|
||||||
normalized.put(new DataSetEntry(normalizedData), this.data.get(entry));
|
normalized.put(new DataSetEntry(normalizedData), this.data.get(entry));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,19 +9,29 @@ import java.util.*;
|
|||||||
|
|
||||||
public class DatasetExtractor {
|
public class DatasetExtractor {
|
||||||
|
|
||||||
public DataSet extract(String path) {
|
public DataSet extract(String path, int nbrLabels) {
|
||||||
Map<DataSetEntry, Label> data = new LinkedHashMap<>();
|
Map<DataSetEntry, Labels> data = new LinkedHashMap<>();
|
||||||
|
|
||||||
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
|
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
|
||||||
String line;
|
String line;
|
||||||
while ((line = reader.readLine()) != null) {
|
while ((line = reader.readLine()) != null) {
|
||||||
String[] parts = line.split(",");
|
String[] parts = line.split(",");
|
||||||
|
|
||||||
|
String[] rawInputs = Arrays.copyOfRange(parts, 0, parts.length-nbrLabels);
|
||||||
|
String[] rawLabels = Arrays.copyOfRange(parts, parts.length-nbrLabels, parts.length);
|
||||||
|
|
||||||
List<Input> inputs = new ArrayList<>();
|
List<Input> inputs = new ArrayList<>();
|
||||||
for (int i = 0; i < parts.length - 1; i++) {
|
List<Float> labels = new ArrayList<>();
|
||||||
inputs.add(new Input(Float.parseFloat(parts[i].trim())));
|
|
||||||
|
for (String entry : rawInputs) {
|
||||||
|
inputs.add(new Input(Float.parseFloat(entry.trim())));
|
||||||
}
|
}
|
||||||
float label = Float.parseFloat(parts[parts.length - 1].trim());
|
|
||||||
data.put(new DataSetEntry(inputs), new Label(label));
|
for (String entry : rawLabels) {
|
||||||
|
labels.add(Float.parseFloat(entry.trim()));
|
||||||
|
}
|
||||||
|
|
||||||
|
data.put(new DataSetEntry(inputs), new Labels(labels));
|
||||||
}
|
}
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException("Failed to read dataset from: " + path, e);
|
throw new RuntimeException("Failed to read dataset from: " + path, e);
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
package com.naaturel.ANN.domain.model.dataset;
|
|
||||||
|
|
||||||
public class Label {
|
|
||||||
|
|
||||||
private float value;
|
|
||||||
|
|
||||||
public Label(float value){
|
|
||||||
this.value = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public float getValue() {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package com.naaturel.ANN.domain.model.dataset;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class Labels {
|
||||||
|
|
||||||
|
private final List<Float> values;
|
||||||
|
|
||||||
|
public Labels(List<Float> value){
|
||||||
|
this.values = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Float> getValues() {
|
||||||
|
return values.stream().toList();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
package com.naaturel.ANN.domain.model.neuron;
|
package com.naaturel.ANN.domain.model.neuron;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Model;
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -34,6 +33,11 @@ public class Layer implements Model {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void forEachNeuron(Consumer<Neuron> consumer) {
|
||||||
|
this.neurons.forEach(consumer);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||||
this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer));
|
this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer));
|
||||||
|
|||||||
@@ -33,6 +33,11 @@ public class Network implements Model {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void forEachNeuron(Consumer<Neuron> consumer) {
|
||||||
|
this.layers.forEach(layer -> layer.forEachNeuron(consumer));
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||||
this.layers.forEach(layer -> layer.forEachSynapse(consumer));
|
this.layers.forEach(layer -> layer.forEachSynapse(consumer));
|
||||||
|
|||||||
@@ -0,0 +1,67 @@
|
|||||||
|
package com.naaturel.ANN.domain.model.neuron;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
|
import com.naaturel.ANN.domain.abstraction.Model;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.function.BiConsumer;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
|
public class Neuron implements Model {
|
||||||
|
|
||||||
|
protected List<Synapse> synapses;
|
||||||
|
protected Bias bias;
|
||||||
|
protected ActivationFunction activationFunction;
|
||||||
|
|
||||||
|
public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){
|
||||||
|
this.synapses = synapses;
|
||||||
|
this.bias = bias;
|
||||||
|
this.activationFunction = func;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void updateBias(Weight weight) {
|
||||||
|
this.bias.setWeight(weight.getValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void updateWeight(int index, Weight weight) {
|
||||||
|
this.synapses.get(index).setWeight(weight.getValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void setInputs(List<Input> inputs){
|
||||||
|
for(int i = 0; i < inputs.size() && i < synapses.size(); i++){
|
||||||
|
Synapse syn = this.synapses.get(i);
|
||||||
|
syn.setInput(inputs.get(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int synCount() {
|
||||||
|
return this.synapses.size()+1; //take the bias in account
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Float> predict(List<Input> inputs) {
|
||||||
|
this.setInputs(inputs);
|
||||||
|
return List.of(activationFunction.accept(this));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void forEachNeuron(Consumer<Neuron> consumer) {
|
||||||
|
consumer.accept(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||||
|
consumer.accept(this.bias);
|
||||||
|
this.synapses.forEach(consumer);
|
||||||
|
}
|
||||||
|
|
||||||
|
public float calculateWeightedSum() {
|
||||||
|
float res = 0;
|
||||||
|
res += this.bias.getWeight() * this.bias.getInput();
|
||||||
|
for(Synapse syn : this.synapses){
|
||||||
|
res += syn.getWeight() * syn.getInput();
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -4,8 +4,8 @@ import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
|||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
|
|
||||||
import java.sql.Time;
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
import java.util.function.Predicate;
|
import java.util.function.Predicate;
|
||||||
@@ -74,19 +74,23 @@ public class TrainingPipeline {
|
|||||||
|
|
||||||
private void executeSteps(TrainingContext ctx){
|
private void executeSteps(TrainingContext ctx){
|
||||||
for (DataSetEntry entry : ctx.dataset) {
|
for (DataSetEntry entry : ctx.dataset) {
|
||||||
|
|
||||||
ctx.currentEntry = entry;
|
ctx.currentEntry = entry;
|
||||||
ctx.currentLabel = ctx.dataset.getLabel(entry);
|
ctx.expectations = ctx.dataset.getLabelsAsFloat(entry);
|
||||||
|
|
||||||
for (TrainingStep step : steps) {
|
for (TrainingStep step : steps) {
|
||||||
step.run();
|
step.run();
|
||||||
}
|
}
|
||||||
|
|
||||||
if(this.verbose) {
|
if(this.verbose) {
|
||||||
System.out.printf("Epoch : %d, ", ctx.epoch);
|
System.out.printf("Epoch : %d, ", ctx.epoch);
|
||||||
System.out.printf("predicted : %.2f, ", ctx.prediction);
|
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray()));
|
||||||
System.out.printf("expected : %.2f, ", ctx.currentLabel.getValue());
|
System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
|
||||||
System.out.printf("delta : %.2f, ", ctx.delta);
|
System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas.toArray()));
|
||||||
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
System.out.printf("loss : %.5f\n", ctx.localLoss);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ctx.epoch += 1;
|
ctx.epoch += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,13 +15,23 @@ public class GradientDescentErrorStrategy implements AlgorithmStrategy {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void apply() {
|
public void apply() {
|
||||||
AtomicInteger i = new AtomicInteger(0);
|
|
||||||
context.model.forEachSynapse(syn -> {
|
AtomicInteger neuronIndex = new AtomicInteger(0);
|
||||||
float corrector = context.correctorTerms.get(i.get());
|
AtomicInteger synIndex = new AtomicInteger(0);
|
||||||
corrector += context.learningRate * context.delta * syn.getInput();
|
|
||||||
context.correctorTerms.set(i.get(), corrector);
|
context.model.forEachNeuron(neuron -> {
|
||||||
i.incrementAndGet();
|
float correspondingDelta = context.deltas.get(neuronIndex.get());
|
||||||
|
|
||||||
|
neuron.forEachSynapse(syn -> {
|
||||||
|
float corrector = context.correctorTerms.get(synIndex.get());
|
||||||
|
corrector += context.learningRate * correspondingDelta * syn.getInput();
|
||||||
|
context.correctorTerms.set(synIndex.get(), corrector);
|
||||||
|
synIndex.incrementAndGet();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
neuronIndex.incrementAndGet();
|
||||||
|
});
|
||||||
|
|
||||||
context.globalLoss += context.localLoss;
|
context.globalLoss += context.localLoss;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.naaturel.ANN.implementation.gradientDescent;
|
package com.naaturel.ANN.implementation.gradientDescent;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
|
|
||||||
public class Linear implements ActivationFunction {
|
public class Linear implements ActivationFunction {
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
|||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
|
||||||
|
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
public class SquareLossStrategy implements AlgorithmStrategy {
|
public class SquareLossStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
private final TrainingContext context;
|
private final TrainingContext context;
|
||||||
@@ -14,6 +16,8 @@ public class SquareLossStrategy implements AlgorithmStrategy {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void apply() {
|
public void apply() {
|
||||||
this.context.localLoss = (float)Math.pow(this.context.delta, 2)/2;
|
Stream<Float> deltaStream = this.context.deltas.stream();
|
||||||
|
this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2)));
|
||||||
|
this.context.localLoss /= 2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
package com.naaturel.ANN.implementation.neuron;
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Bias;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Input;
|
|
||||||
import com.naaturel.ANN.domain.model.neuron.Synapse;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.function.Consumer;
|
|
||||||
|
|
||||||
public class SimplePerceptron extends Neuron {
|
|
||||||
|
|
||||||
public SimplePerceptron(List<Synapse> synapses, Bias b, ActivationFunction func) {
|
|
||||||
super(synapses, b, func);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<Float> predict(List<Input> inputs) {
|
|
||||||
super.setInputs(inputs);
|
|
||||||
return List.of(activationFunction.accept(this));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
|
||||||
consumer.accept(this.bias);
|
|
||||||
this.synapses.forEach(consumer);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public float calculateWeightedSum() {
|
|
||||||
float res = 0;
|
|
||||||
res += this.bias.getWeight() * this.bias.getInput();
|
|
||||||
for(Synapse syn : super.synapses){
|
|
||||||
res += syn.getWeight() * syn.getInput();
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.naaturel.ANN.implementation.simplePerceptron;
|
package com.naaturel.ANN.implementation.simplePerceptron;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
|
|
||||||
public class Heaviside implements ActivationFunction {
|
public class Heaviside implements ActivationFunction {
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package com.naaturel.ANN.implementation.simplePerceptron;
|
|||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
|
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
|
||||||
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
|
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
@@ -14,12 +16,20 @@ public class SimpleCorrectionStrategy implements AlgorithmStrategy {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void apply() {
|
public void apply() {
|
||||||
if(context.currentLabel.getValue() == context.prediction) return ;
|
if(context.expectations.equals(context.predictions)) return;
|
||||||
context.model.forEachSynapse(syn -> {
|
AtomicInteger neuronIndex = new AtomicInteger(0);
|
||||||
|
AtomicInteger synIndex = new AtomicInteger(0);
|
||||||
|
|
||||||
|
context.model.forEachNeuron(neuron -> {
|
||||||
|
float correspondingDelta = context.deltas.get(neuronIndex.get());
|
||||||
|
neuron.forEachSynapse(syn -> {
|
||||||
float currentW = syn.getWeight();
|
float currentW = syn.getWeight();
|
||||||
float currentInput = syn.getInput();
|
float currentInput = syn.getInput();
|
||||||
float newValue = currentW + (context.learningRate * context.delta * currentInput);
|
float newValue = currentW + (context.learningRate * correspondingDelta * currentInput);
|
||||||
syn.setWeight(newValue);
|
syn.setWeight(newValue);
|
||||||
|
synIndex.incrementAndGet();
|
||||||
|
});
|
||||||
|
neuronIndex.incrementAndGet();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,11 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
|||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
||||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
public class SimpleDeltaStrategy implements AlgorithmStrategy {
|
public class SimpleDeltaStrategy implements AlgorithmStrategy {
|
||||||
|
|
||||||
@@ -18,9 +22,14 @@ public class SimpleDeltaStrategy implements AlgorithmStrategy {
|
|||||||
public void apply() {
|
public void apply() {
|
||||||
DataSet dataSet = context.dataset;
|
DataSet dataSet = context.dataset;
|
||||||
DataSetEntry entry = context.currentEntry;
|
DataSetEntry entry = context.currentEntry;
|
||||||
Label label = dataSet.getLabel(entry);
|
List<Float> predicted = context.predictions;
|
||||||
|
List<Float> expected = dataSet.getLabelsAsFloat(entry);
|
||||||
|
|
||||||
context.delta = label.getValue() - context.prediction;
|
//context.delta = label.getValue() - context.predictions;
|
||||||
|
context.deltas = IntStream.range(0, predicted.size())
|
||||||
|
.mapToObj(i -> expected.get(i) - predicted.get(i))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
System.out.printf("");
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,6 @@ public class SimpleLossStrategy implements AlgorithmStrategy {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void apply() {
|
public void apply() {
|
||||||
this.context.localLoss = Math.abs(this.context.delta);
|
this.context.localLoss = this.context.deltas.stream().reduce(0.0F, Float::sum);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ public class SimplePredictionStrategy implements AlgorithmStrategy {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void apply() {
|
public void apply() {
|
||||||
List<Float> predictions = context.model.predict(context.currentEntry.getData());
|
context.predictions = context.model.predict(context.currentEntry.getData());
|
||||||
context.prediction = predictions.getFirst();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,11 +38,11 @@ public class AdalineTraining implements Trainer {
|
|||||||
);
|
);
|
||||||
|
|
||||||
new TrainingPipeline(steps)
|
new TrainingPipeline(steps)
|
||||||
.stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 10000)
|
.stopCondition(ctx -> ctx.globalLoss <= 0.04F || ctx.epoch > 1000)
|
||||||
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
|
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
|
||||||
.afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
|
.afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
|
||||||
.withVerbose(true)
|
|
||||||
.withTimeMeasurement(true)
|
.withTimeMeasurement(true)
|
||||||
|
.withVerbose(true)
|
||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ public class GradientDescentTraining implements Trainer {
|
|||||||
);
|
);
|
||||||
|
|
||||||
new TrainingPipeline(steps)
|
new TrainingPipeline(steps)
|
||||||
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 5000)
|
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 500)
|
||||||
.beforeEpoch(ctx -> {
|
.beforeEpoch(ctx -> {
|
||||||
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
||||||
gdCtx.globalLoss = 0.0F;
|
gdCtx.globalLoss = 0.0F;
|
||||||
|
|||||||
@@ -1,11 +1,7 @@
|
|||||||
package com.naaturel.ANN.implementation.training.steps;
|
package com.naaturel.ANN.implementation.training.steps;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingContext;
|
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
|
|
||||||
import com.naaturel.ANN.domain.model.dataset.Label;
|
|
||||||
|
|
||||||
public class DeltaStep implements TrainingStep {
|
public class DeltaStep implements TrainingStep {
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
package com.naaturel.ANN.infrastructure.graph;
|
||||||
|
|
||||||
|
public class GraphVisualizer {
|
||||||
|
|
||||||
|
public GraphVisualizer(){
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
package adaline;
|
package adaline;
|
||||||
|
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
||||||
@@ -9,7 +9,6 @@ import com.naaturel.ANN.domain.model.neuron.*;
|
|||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.*;
|
import com.naaturel.ANN.implementation.gradientDescent.*;
|
||||||
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy;
|
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy;
|
||||||
@@ -37,7 +36,7 @@ public class AdalineTest {
|
|||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void init(){
|
public void init(){
|
||||||
dataset = new DatasetExtractor()
|
dataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv");
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1);
|
||||||
|
|
||||||
List<Synapse> syns = new ArrayList<>();
|
List<Synapse> syns = new ArrayList<>();
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
@@ -45,7 +44,7 @@ public class AdalineTest {
|
|||||||
|
|
||||||
bias = new Bias(new Weight(0));
|
bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
|
Neuron neuron = new Neuron(syns, bias, new Linear());
|
||||||
Layer layer = new Layer(List.of(neuron));
|
Layer layer = new Layer(List.of(neuron));
|
||||||
network = new Network(List.of(layer));
|
network = new Network(List.of(layer));
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
package gradientDescent;
|
package gradientDescent;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.gradientDescent.*;
|
import com.naaturel.ANN.implementation.gradientDescent.*;
|
||||||
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||||
import com.naaturel.ANN.implementation.training.steps.*;
|
import com.naaturel.ANN.implementation.training.steps.*;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
@@ -15,7 +14,6 @@ import org.junit.jupiter.api.Test;
|
|||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@@ -34,7 +32,7 @@ public class GradientDescentTest {
|
|||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void init(){
|
public void init(){
|
||||||
dataset = new DatasetExtractor()
|
dataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv");
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1);
|
||||||
|
|
||||||
List<Synapse> syns = new ArrayList<>();
|
List<Synapse> syns = new ArrayList<>();
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
@@ -42,7 +40,7 @@ public class GradientDescentTest {
|
|||||||
|
|
||||||
bias = new Bias(new Weight(0));
|
bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
|
Neuron neuron = new Neuron(syns, bias, new Linear());
|
||||||
Layer layer = new Layer(List.of(neuron));
|
Layer layer = new Layer(List.of(neuron));
|
||||||
network = new Network(List.of(layer));
|
network = new Network(List.of(layer));
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
package perceptron;
|
package perceptron;
|
||||||
|
|
||||||
import com.naaturel.ANN.domain.abstraction.Neuron;
|
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||||
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
import com.naaturel.ANN.domain.abstraction.TrainingStep;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
import com.naaturel.ANN.domain.model.dataset.DataSet;
|
||||||
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
|
||||||
import com.naaturel.ANN.domain.model.neuron.*;
|
import com.naaturel.ANN.domain.model.neuron.*;
|
||||||
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
|
||||||
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
|
|
||||||
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
import com.naaturel.ANN.implementation.simplePerceptron.*;
|
||||||
import com.naaturel.ANN.implementation.training.steps.*;
|
import com.naaturel.ANN.implementation.training.steps.*;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
@@ -32,7 +31,7 @@ public class SimplePerceptronTest {
|
|||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void init(){
|
public void init(){
|
||||||
dataset = new DatasetExtractor()
|
dataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv");
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv", 1);
|
||||||
|
|
||||||
List<Synapse> syns = new ArrayList<>();
|
List<Synapse> syns = new ArrayList<>();
|
||||||
syns.add(new Synapse(new Input(0), new Weight(0)));
|
syns.add(new Synapse(new Input(0), new Weight(0)));
|
||||||
@@ -40,7 +39,7 @@ public class SimplePerceptronTest {
|
|||||||
|
|
||||||
bias = new Bias(new Weight(0));
|
bias = new Bias(new Weight(0));
|
||||||
|
|
||||||
Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
|
Neuron neuron = new Neuron(syns, bias, new Heaviside());
|
||||||
Layer layer = new Layer(List.of(neuron));
|
Layer layer = new Layer(List.of(neuron));
|
||||||
network = new Network(List.of(layer));
|
network = new Network(List.of(layer));
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user