Add plotting support for both one-hot vector and single integer classes

This commit is contained in:
2026-05-11 00:48:18 +02:00
parent dd664efa92
commit 45cdab0373
2 changed files with 20 additions and 11 deletions

View File

@@ -37,7 +37,6 @@ public class Main {
int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input);
System.out.println();
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
int nbrInput = dataset.getNbrInputs();
@@ -101,12 +100,18 @@ public class Main {
float max = 10F;
float step = Math.abs(max - min)/300;
boolean isOneHot = dataset.getNbrLabels() > 1;
//plot labels
for (DataSetEntry entry : dataset) {
List<Float> label = dataset.getLabelsAsFloat(entry);
int expectedClass = 0;
for (int i = 1; i < label.size(); i++) {
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
if (isOneHot) {
for (int i = 1; i < label.size(); i++) {
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
}
} else {
expectedClass = Math.round(label.getFirst());
}
visualizer.addPoint("Label " + expectedClass,
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
@@ -116,12 +121,16 @@ public class Main {
for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){
List<Float> predictions = new ArrayList<>();
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
float highest = Collections.max(predictions);
int predictedClass = predictions.indexOf(highest);
visualizer.addPoint("Predict " + predictedClass, x, y);
int predictedClass;
if(isOneHot){
predictedClass = predictions.indexOf(Collections.max(predictions));
} else {
predictedClass = Math.round(predictions.getFirst());
}
visualizer.addPoint("Class " + predictedClass, x, y);
}
}