Add plotting support for both one-hot vector and single integer classes

This commit is contained in:
2026-05-11 00:48:18 +02:00
parent dd664efa92
commit 45cdab0373
2 changed files with 20 additions and 11 deletions

View File

@@ -1,14 +1,14 @@
{
"model": {
"new": true,
"parameters": [5, 2, 3],
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
"parameters": [5, 5, 1],
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-4-12.json"
},
"training" : {
"learning_rate" : 0.0003,
"learning_rate" : 0.03,
"max_epoch" : 5000
},
"dataset" : {
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv"
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/xor.csv"
}
}

View File

@@ -37,7 +37,6 @@ public class Main {
int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input);
System.out.println();
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
int nbrInput = dataset.getNbrInputs();
@@ -101,13 +100,19 @@ public class Main {
float max = 10F;
float step = Math.abs(max - min)/300;
boolean isOneHot = dataset.getNbrLabels() > 1;
//plot labels
for (DataSetEntry entry : dataset) {
List<Float> label = dataset.getLabelsAsFloat(entry);
int expectedClass = 0;
if (isOneHot) {
for (int i = 1; i < label.size(); i++) {
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
}
} else {
expectedClass = Math.round(label.getFirst());
}
visualizer.addPoint("Label " + expectedClass,
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
}
@@ -116,12 +121,16 @@ public class Main {
for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){
List<Float> predictions = new ArrayList<>();
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
float highest = Collections.max(predictions);
int predictedClass = predictions.indexOf(highest);
visualizer.addPoint("Predict " + predictedClass, x, y);
int predictedClass;
if(isOneHot){
predictedClass = predictions.indexOf(Collections.max(predictions));
} else {
predictedClass = Math.round(predictions.getFirst());
}
visualizer.addPoint("Class " + predictedClass, x, y);
}
}