diff --git a/config.json b/config.json index 97f4010..802a933 100644 --- a/config.json +++ b/config.json @@ -1,14 +1,14 @@ { "model": { "new": true, - "parameters": [1], + "parameters": [10, 5, 1], "path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json" }, "training" : { - "learning_rate" : 1.0, + "learning_rate" : 0.002, "max_epoch" : 5000 }, "dataset" : { - "path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv" + "path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv" } } \ No newline at end of file diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 4b73161..5a66eff 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -21,7 +21,7 @@ public class Main { public static void main(String[] args) throws Exception { - String[] types = {"Simple perceptron", "Gradient descent", "Adaline", "Gradient retro-propagation"}; + String[] types = {"Simple perceptron", "Gradient descent", "Adaline", "Gradient backpropagation"}; Scanner sc = new Scanner(System.in); for (int i = 0; i < types.length; i++) { @@ -77,22 +77,22 @@ public class Main { if(dataset.getNbrInputs() != 2) return; GraphVisualizer visualizer = new GraphVisualizer(); - float min = -10F; - float max = 10F; + float min = 0F; + float max = 15F; float step = Math.abs(max - min)/300; boolean isOneHot = dataset.getNbrLabels() > 1; //plot labels for (DataSetEntry entry : dataset) { - List label = dataset.getLabelsAsFloat(entry); + List labels = dataset.getLabelsAsFloat(entry); int expectedClass = 0; if (isOneHot) { - for (int i = 1; i < label.size(); i++) { - if (label.get(i) > label.get(expectedClass)) expectedClass = i; + for (int i = 1; i < labels.size(); i++) { + if (labels.get(i) > labels.get(expectedClass)) expectedClass = i; } } else { - expectedClass = Math.round(label.getFirst()); + expectedClass = Math.round(labels.getFirst()); } visualizer.addPoint("Label " + expectedClass, entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); @@ -105,12 +105,7 @@ public class Main { List predictions = new ArrayList<>(); for (float p : network.predict(new float[]{x, y})) predictions.add(p); - int predictedClass; - if(isOneHot){ - predictedClass = predictions.indexOf(Collections.max(predictions)); - } else { - predictedClass = Math.round(predictions.getFirst()); - } + int predictedClass = dataset.getPredictedClass(predictions); visualizer.addPoint("Class " + predictedClass, x, y); } diff --git a/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSet.java b/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSet.java index 8573314..1b22a77 100644 --- a/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSet.java +++ b/src/main/java/com/naaturel/ANN/infrastructure/dataset/DataSet.java @@ -36,7 +36,6 @@ public class DataSet implements Iterable{ return firstEntry.map(inputs -> this.data.get(inputs).getValues().size()).orElse(0); } - public int size() { return data.size(); } @@ -57,6 +56,25 @@ public class DataSet implements Iterable{ return this.data.get(entry).getValues(); } + public int getPredictedClass(List values) { + if (nbrLabels == 1) { + // single value — round to nearest known class + float value = values.getFirst(); + return data.values().stream() + .map(l -> l.getValues().getFirst()) + .distinct() + .min(Comparator.comparingDouble(l -> Math.abs(l - value))) + .map(Math::round) + .orElseThrow(); + } else { + // one-hot — return index of max value + int max = 0; + for (int i = 1; i < values.size(); i++) + if (values.get(i) > values.get(max)) max = i; + return max; + } + } + public DataSet toNormalized() { List entries = this.getData();