Start to reimplement gradient descent

This commit is contained in:
2026-03-23 23:12:52 +01:00
parent 5ace4952fb
commit 0217607e9b
16 changed files with 157 additions and 89 deletions

View File

@@ -5,6 +5,7 @@ import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.domain.abstraction.TrainingStep;
import com.naaturel.ANN.domain.model.dataset.DataSet;
import com.naaturel.ANN.domain.model.dataset.DataSetEntry;
import com.naaturel.ANN.domain.model.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.dataset.Label;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.domain.model.training.TrainingContext;
@@ -15,49 +16,14 @@ import com.naaturel.ANN.implementation.neuron.SimplePerceptron;
import com.naaturel.ANN.implementation.training.SimpleTraining;
import com.naaturel.ANN.implementation.training.steps.*;
import javax.xml.crypto.Data;
import java.util.*;
public class Main {
public static void main(String[] args){
DataSet orDataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(0.0F)),
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
));
DataSet andDataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
));
DataSet dataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 9.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(7.0F, 10.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(2.0F, 5.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(2.0F, 7.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(2.0F, 8.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(6.0F, 8.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(6.0F, 9.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 5.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 6.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 8.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(3.0F, 9.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 7.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 8.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 10.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(5.0F, 11.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 6.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 7.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 9.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(4.0F, 10.0F)), new Label(1.0F))
));
DataSet dataset = new DatasetExtractor().extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/or.csv");
List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight(0)));
@@ -70,7 +36,7 @@ public class Main {
Network network = new Network(List.of(layer));
Trainer trainer = new SimpleTraining();
trainer.train(network, orDataSet);
trainer.train(network, dataset);
}
}