Add plotting support for both one-hot vector and single integer classes

This commit is contained in:
2026-05-11 00:48:18 +02:00
parent dd664efa92
commit 45cdab0373
2 changed files with 20 additions and 11 deletions

View File

@@ -1,14 +1,14 @@
{ {
"model": { "model": {
"new": true, "new": true,
"parameters": [5, 2, 3], "parameters": [5, 5, 1],
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json" "path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-4-12.json"
}, },
"training" : { "training" : {
"learning_rate" : 0.0003, "learning_rate" : 0.03,
"max_epoch" : 5000 "max_epoch" : 5000
}, },
"dataset" : { "dataset" : {
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv" "path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/xor.csv"
} }
} }

View File

@@ -37,7 +37,6 @@ public class Main {
int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input); int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input);
System.out.println(); System.out.println();
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass); DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
int nbrInput = dataset.getNbrInputs(); int nbrInput = dataset.getNbrInputs();
@@ -101,13 +100,19 @@ public class Main {
float max = 10F; float max = 10F;
float step = Math.abs(max - min)/300; float step = Math.abs(max - min)/300;
boolean isOneHot = dataset.getNbrLabels() > 1;
//plot labels //plot labels
for (DataSetEntry entry : dataset) { for (DataSetEntry entry : dataset) {
List<Float> label = dataset.getLabelsAsFloat(entry); List<Float> label = dataset.getLabelsAsFloat(entry);
int expectedClass = 0; int expectedClass = 0;
if (isOneHot) {
for (int i = 1; i < label.size(); i++) { for (int i = 1; i < label.size(); i++) {
if (label.get(i) > label.get(expectedClass)) expectedClass = i; if (label.get(i) > label.get(expectedClass)) expectedClass = i;
} }
} else {
expectedClass = Math.round(label.getFirst());
}
visualizer.addPoint("Label " + expectedClass, visualizer.addPoint("Label " + expectedClass,
entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
} }
@@ -116,12 +121,16 @@ public class Main {
for (float x = min; x < max; x+=step){ for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){ for (float y = min; y < max; y+=step){
List<Float> predictions = new ArrayList<>(); List<Float> predictions = new ArrayList<>();
for (float p : network.predict(new float[]{x, y})) predictions.add(p); for (float p : network.predict(new float[]{x, y})) predictions.add(p);
float highest = Collections.max(predictions);
int predictedClass = predictions.indexOf(highest);
visualizer.addPoint("Predict " + predictedClass, x, y); int predictedClass;
if(isOneHot){
predictedClass = predictions.indexOf(Collections.max(predictions));
} else {
predictedClass = Math.round(predictions.getFirst());
}
visualizer.addPoint("Class " + predictedClass, x, y);
} }
} }