Fully fixed gradient descent

This commit is contained in:
2026-03-18 13:15:01 +01:00
parent 7dc55a370f
commit a2a74566ba
3 changed files with 25 additions and 16 deletions

View File

@@ -26,6 +26,13 @@ public class Main {
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
));
DataSet andDataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(-1.0F)),
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
));
DataSet dataSet = new DataSet(Map.ofEntries(
Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)),
Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)),
@@ -51,17 +58,17 @@ public class Main {
));
List<Synapse> syns = new ArrayList<>();
syns.add(new Synapse(new Input(0), new Weight()));
syns.add(new Synapse(new Input(0), new Weight()));
syns.add(new Synapse(new Input(0), new Weight(0)));
syns.add(new Synapse(new Input(0), new Weight(0)));
Bias bias = new Bias(new Weight());
Bias bias = new Bias(new Weight(0));
Neuron n = new SimplePerceptron(syns, bias, new Linear());
GradientDescentTraining st = new GradientDescentTraining();
long start = System.currentTimeMillis();
st.train(n, 0.03F, orDataSet.toNormalized());
st.train(n, 0.2F, andDataSet);
long end = System.currentTimeMillis();
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);