Added dataset anf fixed gradient descent algorithm
This commit is contained in:
@@ -19,6 +19,13 @@ public class Main {
|
||||
|
||||
public static void main(String[] args){
|
||||
|
||||
DataSet orDataSet = new DataSet(Map.ofEntries(
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 0.0F)), new Label(0.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(0.0F, 1.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 0.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 1.0F)), new Label(1.0F))
|
||||
));
|
||||
|
||||
DataSet dataSet = new DataSet(Map.ofEntries(
|
||||
Map.entry(new DataSetEntry(List.of(1.0F, 6.0F)), new Label(1.0F)),
|
||||
Map.entry(new DataSetEntry(List.of(7.0F, 9.0F)), new Label(-1.0F)),
|
||||
@@ -51,7 +58,12 @@ public class Main {
|
||||
|
||||
Neuron n = new SimplePerceptron(syns, bias, new Linear());
|
||||
GradientDescentTraining st = new GradientDescentTraining();
|
||||
st.train(n, 0.0003F, dataSet);
|
||||
}
|
||||
|
||||
long start = System.currentTimeMillis();
|
||||
|
||||
st.train(n, 0.03F, orDataSet.toNormalized());
|
||||
|
||||
long end = System.currentTimeMillis();
|
||||
System.out.printf("Training completed in %.2f s%n", (end - start) / 1000.0);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user