From 45cdab0373c7b60e339b5f9d1795fc624100d9d0 Mon Sep 17 00:00:00 2001 From: Laurent Date: Mon, 11 May 2026 00:48:18 +0200 Subject: [PATCH] Add plotting support for both one-hot vector and single integer classes --- config.json | 8 ++++---- src/main/java/com/naaturel/ANN/Main.java | 23 ++++++++++++++++------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/config.json b/config.json index 479e407..69ff49f 100644 --- a/config.json +++ b/config.json @@ -1,14 +1,14 @@ { "model": { "new": true, - "parameters": [5, 2, 3], - "path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json" + "parameters": [5, 5, 1], + "path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-4-12.json" }, "training" : { - "learning_rate" : 0.0003, + "learning_rate" : 0.03, "max_epoch" : 5000 }, "dataset" : { - "path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv" + "path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/xor.csv" } } \ No newline at end of file diff --git a/src/main/java/com/naaturel/ANN/Main.java b/src/main/java/com/naaturel/ANN/Main.java index 0c5d71e..ec84328 100644 --- a/src/main/java/com/naaturel/ANN/Main.java +++ b/src/main/java/com/naaturel/ANN/Main.java @@ -37,7 +37,6 @@ public class Main { int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input); System.out.println(); - DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass); int nbrInput = dataset.getNbrInputs(); @@ -101,12 +100,18 @@ public class Main { float max = 10F; float step = Math.abs(max - min)/300; + boolean isOneHot = dataset.getNbrLabels() > 1; + //plot labels for (DataSetEntry entry : dataset) { List label = dataset.getLabelsAsFloat(entry); int expectedClass = 0; - for (int i = 1; i < label.size(); i++) { - if (label.get(i) > label.get(expectedClass)) expectedClass = i; + if (isOneHot) { + for (int i = 1; i < label.size(); i++) { + if (label.get(i) > label.get(expectedClass)) expectedClass = i; + } + } else { + expectedClass = Math.round(label.getFirst()); } visualizer.addPoint("Label " + expectedClass, entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); @@ -116,12 +121,16 @@ public class Main { for (float x = min; x < max; x+=step){ for (float y = min; y < max; y+=step){ List predictions = new ArrayList<>(); - for (float p : network.predict(new float[]{x, y})) predictions.add(p); - float highest = Collections.max(predictions); - int predictedClass = predictions.indexOf(highest); - visualizer.addPoint("Predict " + predictedClass, x, y); + int predictedClass; + if(isOneHot){ + predictedClass = predictions.indexOf(Collections.max(predictions)); + } else { + predictedClass = Math.round(predictions.getFirst()); + } + + visualizer.addPoint("Class " + predictedClass, x, y); } }