Add multi-layer support

This commit is contained in:
2026-03-26 21:21:31 +01:00
parent 3dd4404f51
commit 64bc830f18
30 changed files with 228 additions and 172 deletions

View File

@@ -1,16 +1,13 @@
package com.naaturel.ANN;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.implementation.gradientDescent.Linear;
import com.naaturel.ANN.implementation.simplePerceptron.Heaviside;
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.training.AdalineTraining;
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
import com.naaturel.ANN.implementation.training.SimpleTraining;
import java.util.*;
@@ -18,20 +15,27 @@ public class Main {
public static void main(String[] args){
int nbrInput = 3;
int nbrClass = 3;
DataSet dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv");
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv", nbrClass);
DataSet andDataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv");
List<Neuron> neurons = new ArrayList<>();
List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight(0)));
syns.add(new Synapse(new Input(0), new Weight(0)));
for (int i=0; i < nbrClass; i++){
List<Synapse> syns = new ArrayList<>();
for (int j=0; j < nbrInput; j++){
syns.add(new Synapse(new Input(0), new Weight(0)));
}
Bias bias = new Bias(new Weight(0));
Bias bias = new Bias(new Weight(0));
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
Layer layer = new Layer(List.of(neuron));
Neuron n = new Neuron(syns, bias, new Linear());
neurons.add(n);
}
Layer layer = new Layer(neurons);
Network network = new Network(List.of(layer));
Trainer trainer = new AdalineTraining();

View File

@@ -1,5 +1,7 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.neuron.Neuron;
public interface ActivationFunction {
float accept(Neuron n);

View File

@@ -1,13 +1,16 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
public interface Model {
int synCount();
void forEachNeuron(Consumer<Neuron> consumer);
void forEachSynapse(Consumer<Synapse> consumer);
List<Float> predict(List<Input> inputs);

View File

@@ -1,42 +0,0 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import java.util.List;
/**
 * Base abstraction for a single artificial neuron: a list of weighted input
 * synapses plus a bias, combined through an activation function. Concrete
 * subclasses define how the weighted sum is computed.
 */
public abstract class Neuron implements Model {
// Weighted input connections; positional order is significant (parallel to the input list).
protected List<Synapse> synapses;
// Constant-input connection whose weight shifts the activation threshold.
protected Bias bias;
// Maps this neuron's weighted sum to its output value.
protected ActivationFunction activationFunction;
/**
 * @param synapses weighted input connections (stored by reference, not copied)
 * @param bias     bias connection
 * @param func     activation function applied to the weighted sum
 */
public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){
this.synapses = synapses;
this.bias = bias;
this.activationFunction = func;
}
/** Computes the neuron's pre-activation value; implementation is subclass-specific. */
public abstract float calculateWeightedSum();
/** Replaces the bias weight with the given weight's value. */
public void updateBias(Weight weight) {
this.bias.setWeight(weight.getValue());
}
/** Replaces the weight of the synapse at {@code index}. */
public void updateWeight(int index, Weight weight) {
this.synapses.get(index).setWeight(weight.getValue());
}
/**
 * Copies the given inputs onto the synapses positionally.
 * NOTE(review): silently ignores surplus inputs or synapses when the sizes
 * differ — a mismatch is likely a caller bug; consider failing fast.
 */
protected void setInputs(List<Input> inputs){
for(int i = 0; i < inputs.size() && i < synapses.size(); i++){
Synapse syn = this.synapses.get(i);
syn.setInput(inputs.get(i));
}
}
@Override
public int synCount() {
return this.synapses.size()+1; //take the bias in account
}
}

View File

@@ -2,16 +2,17 @@ package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label;
import java.util.List;
public abstract class TrainingContext {
public Model model;
public DataSet dataset;
public DataSetEntry currentEntry;
public Label currentLabel;
public float prediction;
public float delta;
public List<Float> expectations;
public List<Float> predictions;
public List<Float> deltas;
public float globalLoss;
public float localLoss;

View File

@@ -6,13 +6,13 @@ import java.util.*;
public class DataSet implements Iterable<DataSetEntry>{
private Map<DataSetEntry, Label> data;
private final Map<DataSetEntry, Labels> data;
public DataSet() {
this(new LinkedHashMap<>());
}
public DataSet(Map<DataSetEntry, Label> data){
public DataSet(Map<DataSetEntry, Labels> data){
this.data = data;
}
@@ -24,8 +24,8 @@ public class DataSet implements Iterable<DataSetEntry>{
return new ArrayList<>(this.data.keySet());
}
public Label getLabel(DataSetEntry entry){
return this.data.get(entry);
public List<Float> getLabelsAsFloat(DataSetEntry entry){
return this.data.get(entry).getValues();
}
public DataSet toNormalized() {
@@ -38,13 +38,15 @@ public class DataSet implements Iterable<DataSetEntry>{
.max(Float::compare)
.orElse(1.0F);
Map<DataSetEntry, Label> normalized = new HashMap<>();
Map<DataSetEntry, Labels> normalized = new HashMap<>();
for (DataSetEntry entry : entries) {
List<Input> normalizedData = new ArrayList<>();
for (Input input : entry.getData()) {
Input normalizedInput = new Input(Math.round((input.getValue() / maxAbs) * 100.0F) / 100.0F);
normalizedData.add(normalizedInput);
}
normalized.put(new DataSetEntry(normalizedData), this.data.get(entry));
}

View File

@@ -9,19 +9,29 @@ import java.util.*;
public class DatasetExtractor {
public DataSet extract(String path) {
Map<DataSetEntry, Label> data = new LinkedHashMap<>();
public DataSet extract(String path, int nbrLabels) {
Map<DataSetEntry, Labels> data = new LinkedHashMap<>();
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
String line;
while ((line = reader.readLine()) != null) {
String[] parts = line.split(",");
String[] rawInputs = Arrays.copyOfRange(parts, 0, parts.length-nbrLabels);
String[] rawLabels = Arrays.copyOfRange(parts, parts.length-nbrLabels, parts.length);
List<Input> inputs = new ArrayList<>();
for (int i = 0; i < parts.length - 1; i++) {
inputs.add(new Input(Float.parseFloat(parts[i].trim())));
List<Float> labels = new ArrayList<>();
for (String entry : rawInputs) {
inputs.add(new Input(Float.parseFloat(entry.trim())));
}
float label = Float.parseFloat(parts[parts.length - 1].trim());
data.put(new DataSetEntry(inputs), new Label(label));
for (String entry : rawLabels) {
labels.add(Float.parseFloat(entry.trim()));
}
data.put(new DataSetEntry(inputs), new Labels(labels));
}
} catch (IOException e) {
throw new RuntimeException("Failed to read dataset from: " + path, e);

View File

@@ -1,15 +0,0 @@
package com.naaturel.ANN.domain.model.dataset;
public class Label {
private float value;
public Label(float value){
this.value = value;
}
public float getValue() {
return value;
}
}

View File

@@ -0,0 +1,16 @@
package com.naaturel.ANN.domain.model.dataset;
import java.util.List;

/**
 * Immutable holder for the expected output values (one per output neuron)
 * attached to a single dataset entry.
 */
public class Labels {

    private final List<Float> values;

    /**
     * @param values expected output values; defensively copied so later
     *               mutation of the caller's list cannot affect this instance.
     *               Must not be null or contain null elements.
     */
    public Labels(List<Float> values) {
        // List.copyOf yields an unmodifiable snapshot and rejects nulls.
        this.values = List.copyOf(values);
    }

    /** @return an unmodifiable list of the label values. */
    public List<Float> getValues() {
        // Already unmodifiable via List.copyOf — no per-call copy needed
        // (the original re-materialized a new list on every call via stream().toList()).
        return values;
    }
}

View File

@@ -1,6 +1,5 @@
package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Model;
import java.util.ArrayList;
@@ -34,6 +33,11 @@ public class Layer implements Model {
return res;
}
@Override
public void forEachNeuron(Consumer<Neuron> consumer) {
this.neurons.forEach(consumer);
}
@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer));

View File

@@ -33,6 +33,11 @@ public class Network implements Model {
return res;
}
@Override
public void forEachNeuron(Consumer<Neuron> consumer) {
this.layers.forEach(layer -> layer.forEachNeuron(consumer));
}
@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
this.layers.forEach(layer -> layer.forEachSynapse(consumer));

View File

@@ -0,0 +1,67 @@
package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Model;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
/**
 * A single artificial neuron acting as the smallest {@link Model} unit:
 * weighted input synapses plus a bias, passed through an activation function.
 * Layers and networks compose these via the forEach* traversal methods.
 */
public class Neuron implements Model {
// Weighted input connections; positional order is significant (parallel to the input list).
protected List<Synapse> synapses;
// Constant-input connection whose weight shifts the activation threshold.
protected Bias bias;
// Maps the weighted sum to this neuron's output value.
protected ActivationFunction activationFunction;
/**
 * @param synapses weighted input connections (stored by reference, not copied)
 * @param bias     bias connection
 * @param func     activation function applied to the weighted sum
 */
public Neuron(List<Synapse> synapses, Bias bias, ActivationFunction func){
this.synapses = synapses;
this.bias = bias;
this.activationFunction = func;
}
/** Replaces the bias weight with the given weight's value. */
public void updateBias(Weight weight) {
this.bias.setWeight(weight.getValue());
}
/** Replaces the weight of the synapse at {@code index}. */
public void updateWeight(int index, Weight weight) {
this.synapses.get(index).setWeight(weight.getValue());
}
/**
 * Copies the given inputs onto the synapses positionally.
 * NOTE(review): silently ignores surplus inputs or synapses when the sizes
 * differ — a mismatch is likely a caller bug; consider failing fast.
 */
protected void setInputs(List<Input> inputs){
for(int i = 0; i < inputs.size() && i < synapses.size(); i++){
Synapse syn = this.synapses.get(i);
syn.setInput(inputs.get(i));
}
}
@Override
public int synCount() {
return this.synapses.size()+1; //take the bias in account
}
/**
 * Binds the inputs to the synapses, then returns this neuron's single
 * activation as a one-element list (a neuron is a one-output Model).
 */
@Override
public List<Float> predict(List<Input> inputs) {
this.setInputs(inputs);
return List.of(activationFunction.accept(this));
}
/** A neuron visits only itself. */
@Override
public void forEachNeuron(Consumer<Neuron> consumer) {
consumer.accept(this);
}
/** Visits the bias first, then every synapse in order — callers that index corrector terms rely on this order. */
@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
consumer.accept(this.bias);
this.synapses.forEach(consumer);
}
/** @return bias contribution plus the sum of weight*input over all synapses. */
public float calculateWeightedSum() {
float res = 0;
res += this.bias.getWeight() * this.bias.getInput();
for(Synapse syn : this.synapses){
res += syn.getWeight() * syn.getInput();
}
return res;
}
}

View File

@@ -4,8 +4,8 @@ import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import java.sql.Time;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Predicate;
@@ -74,19 +74,23 @@ public class TrainingPipeline {
private void executeSteps(TrainingContext ctx){
for (DataSetEntry entry : ctx.dataset) {
ctx.currentEntry = entry;
ctx.currentLabel = ctx.dataset.getLabel(entry);
ctx.expectations = ctx.dataset.getLabelsAsFloat(entry);
for (TrainingStep step : steps) {
step.run();
}
if(this.verbose) {
System.out.printf("Epoch : %d, ", ctx.epoch);
System.out.printf("predicted : %.2f, ", ctx.prediction);
System.out.printf("expected : %.2f, ", ctx.currentLabel.getValue());
System.out.printf("delta : %.2f, ", ctx.delta);
System.out.printf("predicted : %s, ", Arrays.toString(ctx.predictions.toArray()));
System.out.printf("expected : %s, ", Arrays.toString(ctx.expectations.toArray()));
System.out.printf("delta : %s, ", Arrays.toString(ctx.deltas.toArray()));
System.out.printf("loss : %.5f\n", ctx.localLoss);
}
}
ctx.epoch += 1;
}
}

View File

@@ -15,13 +15,23 @@ public class GradientDescentErrorStrategy implements AlgorithmStrategy {
@Override
public void apply() {
AtomicInteger i = new AtomicInteger(0);
context.model.forEachSynapse(syn -> {
float corrector = context.correctorTerms.get(i.get());
corrector += context.learningRate * context.delta * syn.getInput();
context.correctorTerms.set(i.get(), corrector);
i.incrementAndGet();
AtomicInteger neuronIndex = new AtomicInteger(0);
AtomicInteger synIndex = new AtomicInteger(0);
context.model.forEachNeuron(neuron -> {
float correspondingDelta = context.deltas.get(neuronIndex.get());
neuron.forEachSynapse(syn -> {
float corrector = context.correctorTerms.get(synIndex.get());
corrector += context.learningRate * correspondingDelta * syn.getInput();
context.correctorTerms.set(synIndex.get(), corrector);
synIndex.incrementAndGet();
});
neuronIndex.incrementAndGet();
});
context.globalLoss += context.localLoss;
}
}

View File

@@ -1,7 +1,7 @@
package com.naaturel.ANN.implementation.gradientDescent;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.model.neuron.Neuron;
public class Linear implements ActivationFunction {

View File

@@ -4,6 +4,8 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleTrainingContext;
import java.util.stream.Stream;
public class SquareLossStrategy implements AlgorithmStrategy {
private final TrainingContext context;
@@ -14,6 +16,8 @@ public class SquareLossStrategy implements AlgorithmStrategy {
@Override
public void apply() {
this.context.localLoss = (float)Math.pow(this.context.delta, 2)/2;
Stream<Float> deltaStream = this.context.deltas.stream();
this.context.localLoss = deltaStream.reduce(0.0F, (acc, d) -> (float) (acc + Math.pow(d, 2)));
this.context.localLoss /= 2;
}
}

View File

@@ -1,40 +0,0 @@
package com.naaturel.ANN.implementation.neuron;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import java.util.List;
import java.util.function.Consumer;
/**
 * Concrete single-output neuron: binds inputs to its synapses and produces
 * one activation value through the configured activation function.
 */
public class SimplePerceptron extends Neuron {
/**
 * @param synapses weighted input connections
 * @param b        bias connection
 * @param func     activation function applied to the weighted sum
 */
public SimplePerceptron(List<Synapse> synapses, Bias b, ActivationFunction func) {
super(synapses, b, func);
}
/** Binds the inputs, then returns the single activation as a one-element list. */
@Override
public List<Float> predict(List<Input> inputs) {
super.setInputs(inputs);
return List.of(activationFunction.accept(this));
}
/** Visits the bias first, then every synapse in order — callers indexing corrector terms rely on this order. */
@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
consumer.accept(this.bias);
this.synapses.forEach(consumer);
}
/** @return bias contribution plus the sum of weight*input over all synapses. */
@Override
public float calculateWeightedSum() {
float res = 0;
res += this.bias.getWeight() * this.bias.getInput();
for(Synapse syn : super.synapses){
res += syn.getWeight() * syn.getInput();
}
return res;
}
}

View File

@@ -1,7 +1,7 @@
package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.model.neuron.Neuron;
public class Heaviside implements ActivationFunction {

View File

@@ -3,6 +3,8 @@ package com.naaturel.ANN.implementation.simplePerceptron;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import java.util.concurrent.atomic.AtomicInteger;
public class SimpleCorrectionStrategy implements AlgorithmStrategy {
@@ -14,12 +16,20 @@ public class SimpleCorrectionStrategy implements AlgorithmStrategy {
@Override
public void apply() {
if(context.currentLabel.getValue() == context.prediction) return ;
context.model.forEachSynapse(syn -> {
float currentW = syn.getWeight();
float currentInput = syn.getInput();
float newValue = currentW + (context.learningRate * context.delta * currentInput);
syn.setWeight(newValue);
if(context.expectations.equals(context.predictions)) return;
AtomicInteger neuronIndex = new AtomicInteger(0);
AtomicInteger synIndex = new AtomicInteger(0);
context.model.forEachNeuron(neuron -> {
float correspondingDelta = context.deltas.get(neuronIndex.get());
neuron.forEachSynapse(syn -> {
float currentW = syn.getWeight();
float currentInput = syn.getInput();
float newValue = currentW + (context.learningRate * correspondingDelta * currentInput);
syn.setWeight(newValue);
synIndex.incrementAndGet();
});
neuronIndex.incrementAndGet();
});
}
}

View File

@@ -4,7 +4,11 @@ import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class SimpleDeltaStrategy implements AlgorithmStrategy {
@@ -18,9 +22,14 @@ public class SimpleDeltaStrategy implements AlgorithmStrategy {
public void apply() {
DataSet dataSet = context.dataset;
DataSetEntry entry = context.currentEntry;
Label label = dataSet.getLabel(entry);
List<Float> predicted = context.predictions;
List<Float> expected = dataSet.getLabelsAsFloat(entry);
context.delta = label.getValue() - context.prediction;
//context.delta = label.getValue() - context.predictions;
context.deltas = IntStream.range(0, predicted.size())
.mapToObj(i -> expected.get(i) - predicted.get(i))
.collect(Collectors.toList());
System.out.printf("");
}
}

View File

@@ -12,6 +12,6 @@ public class SimpleLossStrategy implements AlgorithmStrategy {
@Override
public void apply() {
this.context.localLoss = Math.abs(this.context.delta);
this.context.localLoss = this.context.deltas.stream().reduce(0.0F, Float::sum);
}
}

View File

@@ -15,7 +15,6 @@ public class SimplePredictionStrategy implements AlgorithmStrategy {
@Override
public void apply() {
List<Float> predictions = context.model.predict(context.currentEntry.getData());
context.prediction = predictions.getFirst();
context.predictions = context.model.predict(context.currentEntry.getData());
}
}

View File

@@ -38,11 +38,11 @@ public class AdalineTraining implements Trainer {
);
new TrainingPipeline(steps)
.stopCondition(ctx -> ctx.globalLoss <= 0.125F || ctx.epoch > 10000)
.stopCondition(ctx -> ctx.globalLoss <= 0.04F || ctx.epoch > 1000)
.beforeEpoch(ctx -> ctx.globalLoss = 0.0F)
.afterEpoch(ctx -> ctx.globalLoss /= context.dataset.size())
.withVerbose(true)
.withTimeMeasurement(true)
.withVerbose(true)
.run(context);
}

View File

@@ -38,7 +38,7 @@ public class GradientDescentTraining implements Trainer {
);
new TrainingPipeline(steps)
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 5000)
.stopCondition(ctx -> ctx.globalLoss <= 0.08F || ctx.epoch > 500)
.beforeEpoch(ctx -> {
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
gdCtx.globalLoss = 0.0F;

View File

@@ -1,11 +1,7 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.AlgorithmStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingContext;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label;
public class DeltaStep implements TrainingStep {

View File

@@ -0,0 +1,9 @@
package com.naaturel.ANN.infrastructure.graph;
/**
 * Placeholder for a future network/graph visualization component.
 * NOTE(review): currently empty — no behavior implemented yet; the intended
 * rendering backend is not visible from this file.
 */
public class GraphVisualizer {
public GraphVisualizer(){
}
}

View File

@@ -1,7 +1,7 @@
package adaline;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
@@ -9,7 +9,6 @@ import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.adaline.AdalineTrainingContext;
import com.naaturel.ANN.implementation.gradientDescent.*;
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleCorrectionStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleDeltaStrategy;
import com.naaturel.ANN.implementation.simplePerceptron.SimpleErrorRegistrationStrategy;
@@ -37,7 +36,7 @@ public class AdalineTest {
@BeforeEach
public void init(){
dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv");
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1);
List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight(0)));
@@ -45,7 +44,7 @@ public class AdalineTest {
bias = new Bias(new Weight(0));
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
Neuron neuron = new Neuron(syns, bias, new Linear());
Layer layer = new Layer(List.of(neuron));
network = new Network(List.of(layer));

View File

@@ -1,13 +1,12 @@
package gradientDescent;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.gradientDescent.*;
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.implementation.training.steps.*;
import org.junit.jupiter.api.BeforeEach;
@@ -15,7 +14,6 @@ import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.jupiter.api.Assertions.*;
@@ -34,7 +32,7 @@ public class GradientDescentTest {
@BeforeEach
public void init(){
dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv");
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and-gradient.csv", 1);
List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight(0)));
@@ -42,7 +40,7 @@ public class GradientDescentTest {
bias = new Bias(new Weight(0));
Neuron neuron = new SimplePerceptron(syns, bias, new Linear());
Neuron neuron = new Neuron(syns, bias, new Linear());
Layer layer = new Layer(List.of(neuron));
network = new Network(List.of(layer));

View File

@@ -1,12 +1,11 @@
package perceptron;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.simplePerceptron.*;
import com.naaturel.ANN.implementation.training.steps.*;
import org.junit.jupiter.api.BeforeEach;
@@ -32,7 +31,7 @@ public class SimplePerceptronTest {
@BeforeEach
public void init(){
dataset = new DatasetExtractor()
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv");
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv", 1);
List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight(0)));
@@ -40,7 +39,7 @@ public class SimplePerceptronTest {
bias = new Bias(new Weight(0));
Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
Neuron neuron = new Neuron(syns, bias, new Heaviside());
Layer layer = new Layer(List.of(neuron));
network = new Network(List.of(layer));