Implement main structure of framework

This commit is contained in:
Laurent
2026-03-23 16:39:12 +01:00
parent 76bc791889
commit 89d9abe329
24 changed files with 353 additions and 73 deletions

View File

@@ -1,18 +1,17 @@
package com.naaturel.ANN;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import com.naaturel.ANN.implementation.activationFunction.Linear;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingContext;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.activation.Heaviside;
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.training.AdalineTraining;
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
import com.naaturel.ANN.implementation.training.steps.*;
import java.util.*;
@@ -64,14 +63,28 @@ public class Main {
Bias bias = new Bias(new Weight(0));
Neuron n = new SimplePerceptron(syns, bias, new Linear());
Trainer trainer = new AdalineTraining();
Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
Layer layer = new Layer(List.of(neuron));
Network network = new Network(List.of(layer));
long start = System.currentTimeMillis();
TrainingContext context = new TrainingContext();
context.dataset = dataSet;
context.model = network;
trainer.train(n, 0.03F, andDataSet);
List<TrainingStep> steps = List.of(
new PredictionStep(),
new DeltaStep(),
new SimpleLossStep(),
new SimpleErrorDetectionStep(),
new WeightCorrectionStep(new SimpleCorrectionStrategy())
);
TrainingPipeline pipeline = new TrainingPipeline(steps);
pipeline
.stopCondition(ctx -> ctx.globalLoss == 0 && ctx.epoch >= 1000)
.afterEpoch(ctx -> ctx.globalLoss = 0)
.withVerbose(true)
.run(context);
long end = System.currentTimeMillis();
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
}
}

View File

@@ -0,0 +1,9 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.training.TrainingContext;
/**
 * Strategy for adjusting the model's weights after an error has been
 * measured. Implementations read whatever they need from the
 * {@link TrainingContext} (delta, learning rate, model) and mutate the
 * model's synapses in place.
 */
public interface CorrectionStrategy {
    /**
     * Applies one weight-correction pass to {@code context.model}.
     *
     * @param context shared training state; expected to hold the model,
     *                the current delta and the learning rate
     */
    void apply(TrainingContext context);
}

View File

@@ -7,7 +7,7 @@ import com.naaturel.ANN.domain.model.neuron.Weight;
import java.util.ArrayList;
import java.util.List;
public abstract class Neuron {
public abstract class Neuron implements Trainable {
protected List<Synapse> synapses;
protected Bias bias;
@@ -19,37 +19,20 @@ public abstract class Neuron {
this.activationFunction = func;
}
public abstract float predict();
public abstract float calculateWeightedSum();
public int getSynCount(){
return this.synapses.size();
}
public void setInput(int index, Input input){
Synapse syn = this.synapses.get(index);
syn.setInput(input.getValue());
}
public Bias getBias(){
return this.bias;
}
public void updateBias(Weight weight) {
this.bias.setWeight(weight.getValue());
}
public Synapse getSynapse(int index){
return this.synapses.get(index);
public void updateWeight(int index, Weight weight) {
this.synapses.get(index).setWeight(weight.getValue());
}
public List<Synapse> getSynapses() {
return new ArrayList<>(this.synapses);
protected void setInputs(List<Input> inputs){
for(int i = 0; i < inputs.size() && i < synapses.size(); i++){
Synapse syn = this.synapses.get(i);
syn.setInput(inputs.get(i));
}
public void setWeight(int index, Weight weight){
Synapse syn = this.synapses.get(index);
syn.setWeight(weight.getValue());
}
}

View File

@@ -1,13 +0,0 @@
package com.naaturel.ANN.domain.abstraction;
public abstract class NeuronTrainer {
private Trainable trainable;
public NeuronTrainer(Trainable trainable){
this.trainable = trainable;
}
public abstract void train();
}

View File

@@ -1,7 +1,14 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import java.util.List;
import java.util.function.Consumer;
/**
 * Anything that can produce predictions and expose its synapses for
 * weight updates. Implemented by single neurons as well as by the
 * Layer and Network composites, so training code can treat all three
 * uniformly.
 */
public interface Trainable {
    /**
     * Runs a forward pass on the given inputs.
     *
     * @param inputs input values fed to the model
     * @return the model's output value(s); composites concatenate the
     *         outputs of their children
     */
    List<Float> predict(List<Input> inputs);

    /**
     * Visits every synapse of this model (recursively for composites),
     * typically to apply a weight correction.
     *
     * @param consumer action invoked once per synapse
     */
    void forEachSynapse(Consumer<Synapse> consumer);
}

View File

@@ -1,8 +1,10 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.training.TrainingContext;
import java.util.List;
public interface Trainer {
void train(Neuron n, float learningRate, DataSet dataSet);
void train(TrainingContext context, List<TrainingStep> steps);
}

View File

@@ -0,0 +1,9 @@
package com.naaturel.ANN.domain.abstraction;
import com.naaturel.ANN.domain.model.training.TrainingContext;
/**
 * One stage of a training pipeline (e.g. prediction, delta computation,
 * loss, weight correction). Steps communicate exclusively through the
 * mutable {@link TrainingContext} passed along the chain.
 */
public interface TrainingStep {
    /**
     * Executes this stage, reading from and writing to the shared context.
     *
     * @param ctx mutable state shared by all steps of the pipeline
     */
    void run(TrainingContext ctx);
}

View File

@@ -0,0 +1,33 @@
package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainable;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
/**
 * A layer of neurons. Feeding the layer runs every neuron on the same
 * inputs and concatenates their outputs in neuron order.
 */
public class Layer implements Trainable {

    // Immutable snapshot: List.copyOf prevents the caller's list from
    // mutating the layer after construction (the original stored the
    // caller's reference directly).
    private final List<Neuron> neurons;

    /**
     * Creates a layer from the given neurons.
     *
     * @param neurons neurons of this layer; defensively copied, must not
     *                be null or contain null
     */
    public Layer(List<Neuron> neurons) {
        this.neurons = List.copyOf(neurons);
    }

    /**
     * Runs a forward pass: every neuron receives the same inputs and the
     * per-neuron outputs are concatenated.
     *
     * @param inputs input values shared by all neurons of this layer
     * @return concatenated predictions of all neurons, in neuron order
     */
    @Override
    public List<Float> predict(List<Input> inputs) {
        List<Float> result = new ArrayList<>();
        for (Neuron neuron : this.neurons) {
            result.addAll(neuron.predict(inputs));
        }
        return result;
    }

    /** Visits every synapse of every neuron in this layer. */
    @Override
    public void forEachSynapse(Consumer<Synapse> consumer) {
        this.neurons.forEach(neuron -> neuron.forEachSynapse(consumer));
    }
}

View File

@@ -0,0 +1,31 @@
package com.naaturel.ANN.domain.model.neuron;
import com.naaturel.ANN.domain.abstraction.Trainable;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
/**
 * A network of layers. NOTE: predict feeds the ORIGINAL inputs to every
 * layer and concatenates the results — layers are not chained
 * output-to-input. This mirrors the current single-layer usage; revisit
 * before stacking multiple layers.
 */
public class Network implements Trainable {

    // Immutable snapshot: List.copyOf prevents the caller's list from
    // mutating the network after construction (the original stored the
    // caller's reference directly).
    private final List<Layer> layers;

    /**
     * Creates a network from the given layers.
     *
     * @param layers layers of this network; defensively copied, must not
     *               be null or contain null
     */
    public Network(List<Layer> layers) {
        this.layers = List.copyOf(layers);
    }

    /**
     * Runs a forward pass over all layers on the same inputs and
     * concatenates their outputs in layer order.
     *
     * @param inputs input values fed to each layer
     * @return concatenated predictions of all layers
     */
    @Override
    public List<Float> predict(List<Input> inputs) {
        List<Float> result = new ArrayList<>();
        for (Layer layer : this.layers) {
            result.addAll(layer.predict(inputs));
        }
        return result;
    }

    /** Visits every synapse of every layer in this network. */
    @Override
    public void forEachSynapse(Consumer<Synapse> consumer) {
        this.layers.forEach(layer -> layer.forEachSynapse(consumer));
    }
}

View File

@@ -14,8 +14,8 @@ public class Synapse {
return this.input.getValue();
}
public void setInput(float value){
this.input.setValue(value);
public void setInput(Input input){
this.input.setValue(input.getValue());
}
public float getWeight() {

View File

@@ -0,0 +1,19 @@
package com.naaturel.ANN.domain.model.training;
import com.naaturel.ANN.domain.abstraction.Trainable;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
/**
 * Mutable state shared by all {@link com.naaturel.ANN.domain.abstraction.TrainingStep}s
 * of a pipeline run. Fields are deliberately public: steps read and
 * write them directly as a blackboard.
 */
public class TrainingContext {
    // The model being trained (neuron, layer or network).
    public Trainable model;
    // Full training set iterated once per epoch by the pipeline.
    public DataSet dataset;
    // Sample currently being processed; set by the pipeline before each step chain.
    public DataSetEntry currentEntry;
    // Output of the last forward pass (first model output only).
    public float prediction;
    // expected - prediction for the current entry; written by DeltaStep.
    public float delta;
    // Loss of the current sample; written by SimpleLossStep.
    public float localLoss;
    // Accumulated loss across samples; NOTE(review): never reset here —
    // callers are expected to zero it in an afterEpoch hook.
    public float globalLoss;
    // Step size used by weight-correction strategies.
    // NOTE(review): defaults to 0, which makes corrections no-ops — confirm callers set it.
    public float learningRate;
    // Completed-epoch counter, incremented by the pipeline after each pass.
    public int epoch;
}

View File

@@ -0,0 +1,68 @@
package com.naaturel.ANN.domain.model.training;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Predicate;
/**
 * Orders a chain of {@link TrainingStep}s and runs them over the dataset
 * until a stop condition is met. Configuration is fluent:
 * {@code pipeline.stopCondition(...).afterEpoch(...).withVerbose(true).run(ctx)}.
 */
public class TrainingPipeline {

    // Steps executed in order for every sample of every epoch.
    private final List<TrainingStep> steps;
    // Optional hook invoked once after each epoch (e.g. to reset globalLoss).
    private Consumer<TrainingContext> afterAll;
    // Optional termination predicate, evaluated after each epoch.
    private Predicate<TrainingContext> stopCondition;
    private boolean verbose;
    private boolean timeMeasurement;

    /**
     * @param steps pipeline stages, run in list order; copied defensively
     */
    public TrainingPipeline(List<TrainingStep> steps) {
        this.steps = new ArrayList<>(steps);
    }

    /**
     * Sets the predicate that, when true after an epoch, stops training.
     * If never set, exactly one epoch is run (the original threw an NPE
     * in that case).
     */
    public TrainingPipeline stopCondition(Predicate<TrainingContext> predicate) {
        this.stopCondition = predicate;
        return this;
    }

    /** Sets a hook invoked once after every epoch. */
    public TrainingPipeline afterEpoch(Consumer<TrainingContext> consumer) {
        this.afterAll = consumer;
        return this;
    }

    /** Enables per-sample progress logging to stdout. */
    public TrainingPipeline withVerbose(boolean enabled) {
        this.verbose = enabled;
        return this;
    }

    /** Enables printing of total wall-clock training time after run(). */
    public TrainingPipeline withTimeMeasurement(boolean enabled) {
        this.timeMeasurement = enabled;
        return this;
    }

    /**
     * Trains until the stop condition holds (or after one epoch when no
     * condition was configured). Always runs at least one epoch.
     *
     * @param ctx shared training state; must carry the dataset and model
     */
    public void run(TrainingContext ctx) {
        // Fix: the timeMeasurement flag was previously set but never read.
        long start = System.currentTimeMillis();
        do {
            this.executeSteps(ctx);
            if (this.afterAll != null) {
                this.afterAll.accept(ctx);
            }
            // Fix: guard against an unset stop condition (original NPEd).
        } while (this.stopCondition != null && !this.stopCondition.test(ctx));
        if (this.timeMeasurement) {
            long elapsed = System.currentTimeMillis() - start;
            System.out.printf("Training completed in %.2f s%n", elapsed / 1000.0);
        }
    }

    // Runs the full step chain on every sample, then bumps the epoch counter.
    private void executeSteps(TrainingContext ctx) {
        for (DataSetEntry sample : ctx.dataset) {
            ctx.currentEntry = sample;
            for (TrainingStep step : steps) {
                step.run(ctx);
            }
            // Fix: log once per sample AFTER the chain completes; the
            // original printed inside the step loop, repeating the same
            // line per step with partially-computed prediction/delta.
            if (this.verbose) {
                System.out.printf("Epoch : %d, ", ctx.epoch);
                System.out.printf("predicted : %.2f, ", ctx.prediction);
                System.out.printf("expected : %.2f, ", ctx.dataset.getLabel(ctx.currentEntry).getValue());
                System.out.printf("delta : %.2f\n", ctx.delta);
            }
        }
        ctx.epoch += 1;
    }
}

View File

@@ -1,4 +1,4 @@
package com.naaturel.ANN.implementation.activationFunction;
package com.naaturel.ANN.implementation.activation;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Neuron;

View File

@@ -1,4 +1,4 @@
package com.naaturel.ANN.implementation.activationFunction;
package com.naaturel.ANN.implementation.activation;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.Neuron;

View File

@@ -0,0 +1,17 @@
package com.naaturel.ANN.implementation.correction;
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
import com.naaturel.ANN.domain.model.training.TrainingContext;
/**
 * Delta-rule weight correction: every synapse weight becomes
 * {@code w + learningRate * delta * input}.
 */
public class SimpleCorrectionStrategy implements CorrectionStrategy {

    /**
     * Updates each synapse of the model in place using the delta and
     * learning rate currently held by the context.
     *
     * @param context carries the model, the delta and the learning rate
     */
    @Override
    public void apply(TrainingContext context) {
        // Hoist the factors: they are fixed for the whole sweep.
        float rate = context.learningRate;
        float delta = context.delta;
        context.model.forEachSynapse(synapse ->
                synapse.setWeight(synapse.getWeight() + rate * delta * synapse.getInput()));
    }
}

View File

@@ -1,22 +1,32 @@
package com.naaturel.ANN.implementation.neuron;
import com.naaturel.ANN.domain.abstraction.ActivationFunction;
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainable;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import java.util.List;
import java.util.function.Consumer;
public class SimplePerceptron extends Neuron implements Trainable {
public class SimplePerceptron extends Neuron {
public SimplePerceptron(List<Synapse> synapses, Bias b, ActivationFunction func) {
super(synapses, b, func);
}
@Override
public float predict() {
return activationFunction.accept(this);
public List<Float> predict(List<Input> inputs) {
super.setInputs(inputs);
return List.of(activationFunction.accept(this));
}
@Override
public void forEachSynapse(Consumer<Synapse> consumer) {
this.synapses.forEach(consumer);
}
@Override

View File

@@ -9,7 +9,7 @@ import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
public class AdalineTraining implements Trainer {
/*public class AdalineTraining implements Trainer {
public AdalineTraining(){
@@ -78,4 +78,4 @@ public class AdalineTraining implements Trainer {
return (float) Math.pow(delta, 2)/2;
}
}
}*/

View File

@@ -13,7 +13,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class GradientDescentTraining implements Trainer {
/*public class GradientDescentTraining implements Trainer {
public GradientDescentTraining(){
@@ -122,4 +122,4 @@ public class GradientDescentTraining implements Trainer {
return variance;
}
}
}*/

View File

@@ -1,12 +1,15 @@
package com.naaturel.ANN.implementation.training;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import com.naaturel.ANN.domain.model.neuron.Network;
import com.naaturel.ANN.domain.model.training.TrainingContext;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
import com.naaturel.ANN.implementation.training.steps.*;
import java.util.List;
public class SimpleTraining implements Trainer {
@@ -14,7 +17,12 @@ public class SimpleTraining implements Trainer {
}
public void train(Neuron n, float learningRate, DataSet dataSet) {
@Override
public void train(TrainingContext context, List<TrainingStep> steps) {
}
/*public void train(Neuron n, float learningRate, DataSet dataSet) {
int epoch = 1;
int errorCount;
@@ -65,5 +73,5 @@ public class SimpleTraining implements Trainer {
private float calculateLoss(float delta){
return Math.abs(delta);
}
*/
}

View File

@@ -0,0 +1,19 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label;
import com.naaturel.ANN.domain.model.training.TrainingContext;
/**
 * Computes the prediction error for the current sample:
 * {@code delta = expected label - prediction}.
 */
public class DeltaStep implements TrainingStep {

    /**
     * Looks up the label of the current entry and stores the signed
     * error in {@code ctx.delta}.
     *
     * @param ctx must carry the dataset, current entry and prediction
     */
    @Override
    public void run(TrainingContext ctx) {
        Label expected = ctx.dataset.getLabel(ctx.currentEntry);
        ctx.delta = expected.getValue() - ctx.prediction;
    }
}

View File

@@ -0,0 +1,21 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.training.TrainingContext;
import java.util.ArrayList;
import java.util.List;
/**
 * Runs the model forward on the current sample and records the first
 * output as the prediction.
 */
public class PredictionStep implements TrainingStep {

    /**
     * Wraps the current entry's raw values as {@code Input}s, feeds them
     * to the model, and stores the first output in {@code ctx.prediction}.
     *
     * @param ctx must carry the model and the current entry
     */
    @Override
    public void run(TrainingContext ctx) {
        List<Input> inputs = ctx.currentEntry.getData().stream()
                .map(Input::new)
                .toList();
        ctx.prediction = ctx.model.predict(inputs).getFirst();
    }
}

View File

@@ -0,0 +1,13 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.training.TrainingContext;
/**
 * Accumulates the per-sample loss into the epoch-wide total.
 * NOTE(review): despite the name, no "detection" happens here — the
 * step only sums {@code localLoss} into {@code globalLoss}; the stop
 * condition elsewhere inspects the total. Consider renaming.
 */
public class SimpleErrorDetectionStep implements TrainingStep {
    /**
     * Adds the current sample's loss to the running global loss.
     * Relies on an afterEpoch hook to reset {@code globalLoss}.
     */
    @Override
    public void run(TrainingContext ctx) {
        ctx.globalLoss += ctx.localLoss;
    }
}

View File

@@ -0,0 +1,12 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.training.TrainingContext;
/**
 * Absolute-error loss: {@code localLoss = |delta|}.
 */
public class SimpleLossStep implements TrainingStep {
    /**
     * Stores the magnitude of the current delta as the sample's loss.
     * Must run after a step that has populated {@code ctx.delta}.
     */
    @Override
    public void run(TrainingContext ctx) {
        ctx.localLoss = Math.abs(ctx.delta);
    }
}

View File

@@ -0,0 +1,19 @@
package com.naaturel.ANN.implementation.training.steps;
import com.naaturel.ANN.domain.abstraction.CorrectionStrategy;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.training.TrainingContext;
/**
 * Pipeline stage that delegates weight updates to a pluggable
 * {@link CorrectionStrategy}.
 */
public class WeightCorrectionStep implements TrainingStep {

    // The rule that actually mutates the model's weights.
    private final CorrectionStrategy strategy;

    /**
     * @param strategy correction rule applied on every run
     */
    public WeightCorrectionStep(CorrectionStrategy strategy) {
        this.strategy = strategy;
    }

    /** Applies the configured strategy to the shared context. */
    @Override
    public void run(TrainingContext ctx) {
        this.strategy.apply(ctx);
    }
}