Implement main structure of framework

This commit is contained in:
2026-03-23 16:39:12 +01:00
parent 76bc791889
commit b25aaba088
24 changed files with 353 additions and 73 deletions

View File

@@ -1,18 +1,17 @@
package com.naaturel.ANN;
import com.naaturel.ANN.domain.abstraction.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.Label;
import com.naaturel.ANN.domain.model.neuron.Bias;
import com.naaturel.ANN.domain.model.neuron.Input;
import com.naaturel.ANN.domain.model.neuron.Synapse;
import com.naaturel.ANN.domain.model.neuron.Weight;
import com.naaturel.ANN.implementation.activationFunction.Linear;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingContext;
import com.naaturel.ANN.domain.model.training.TrainingPipeline;
import com.naaturel.ANN.implementation.activation.Heaviside;
import com.naaturel.ANN.implementation.correction.SimpleCorrectionStrategy;
import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.training.AdalineTraining;
import com.naaturel.ANN.implementation.training.GradientDescentTraining;
import com.naaturel.ANN.implementation.training.steps.*;
import java.util.*;
@@ -64,14 +63,28 @@ public class Main {
Bias bias = new Bias(new Weight(0));
Neuron n = new SimplePerceptron(syns, bias, new Linear());
Trainer trainer = new AdalineTraining();
Neuron neuron = new SimplePerceptron(syns, bias, new Heaviside());
Layer layer = new Layer(List.of(neuron));
Network network = new Network(List.of(layer));
long start = System.currentTimeMillis();
TrainingContext context = new TrainingContext();
context.dataset = dataSet;
context.model = network;
trainer.train(n, 0.03F, andDataSet);
List<TrainingStep> steps = List.of(
new PredictionStep(),
new DeltaStep(),
new SimpleLossStep(),
new SimpleErrorDetectionStep(),
new WeightCorrectionStep(new SimpleCorrectionStrategy())
);
TrainingPipeline pipeline = new TrainingPipeline(steps);
pipeline
.stopCondition(ctx -> ctx.globalLoss == 0 && ctx.epoch >= 1000)
.afterEpoch(ctx -> ctx.globalLoss = 0)
.withVerbose(true)
.run(context);
long end = System.currentTimeMillis();
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
}
}