Add plotting support for both one-hot vector and single integer classes
This commit is contained in:
@@ -37,7 +37,6 @@ public class Main {
|
||||
int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input);
|
||||
System.out.println();
|
||||
|
||||
|
||||
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
|
||||
int nbrInput = dataset.getNbrInputs();
|
||||
|
||||
@@ -101,12 +100,18 @@ public class Main {
|
||||
float max = 10F;
|
||||
float step = Math.abs(max - min)/300;
|
||||
|
||||
boolean isOneHot = dataset.getNbrLabels() > 1;
|
||||
|
||||
//plot labels
|
||||
for (DataSetEntry entry : dataset) {
|
||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
||||
int expectedClass = 0;
|
||||
for (int i = 1; i < label.size(); i++) {
|
||||
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
|
||||
if (isOneHot) {
|
||||
for (int i = 1; i < label.size(); i++) {
|
||||
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
|
||||
}
|
||||
} else {
|
||||
expectedClass = Math.round(label.getFirst());
|
||||
}
|
||||
visualizer.addPoint("Label " + expectedClass,
|
||||
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||
@@ -116,12 +121,16 @@ public class Main {
|
||||
for (float x = min; x < max; x+=step){
|
||||
for (float y = min; y < max; y+=step){
|
||||
List<Float> predictions = new ArrayList<>();
|
||||
|
||||
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
|
||||
float highest = Collections.max(predictions);
|
||||
int predictedClass = predictions.indexOf(highest);
|
||||
|
||||
visualizer.addPoint("Predict " + predictedClass, x, y);
|
||||
int predictedClass;
|
||||
if(isOneHot){
|
||||
predictedClass = predictions.indexOf(Collections.max(predictions));
|
||||
} else {
|
||||
predictedClass = Math.round(predictions.getFirst());
|
||||
}
|
||||
|
||||
visualizer.addPoint("Class " + predictedClass, x, y);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user