Fix plotting

This commit is contained in:
2026-05-11 20:23:25 +02:00
parent 9f4a76d9e5
commit 2f37d72efd
3 changed files with 30 additions and 17 deletions

View File

@@ -1,14 +1,14 @@
{
"model": {
"new": true,
"parameters": [1],
"parameters": [10, 5, 1],
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
},
"training" : {
"learning_rate" : 1.0,
"learning_rate" : 0.002,
"max_epoch" : 5000
},
"dataset" : {
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/and.csv"
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_10.csv"
}
}

View File

@@ -21,7 +21,7 @@ public class Main {
public static void main(String[] args) throws Exception {
String[] types = {"Simple perceptron", "Gradient descent", "Adaline", "Gradient retro-propagation"};
String[] types = {"Simple perceptron", "Gradient descent", "Adaline", "Gradient backpropagation"};
Scanner sc = new Scanner(System.in);
for (int i = 0; i < types.length; i++) {
@@ -77,22 +77,22 @@ public class Main {
if(dataset.getNbrInputs() != 2) return;
GraphVisualizer visualizer = new GraphVisualizer();
float min = -10F;
float max = 10F;
float min = 0F;
float max = 15F;
float step = Math.abs(max - min)/300;
boolean isOneHot = dataset.getNbrLabels() > 1;
//plot labels
for (DataSetEntry entry : dataset) {
List<Float> label = dataset.getLabelsAsFloat(entry);
List<Float> labels = dataset.getLabelsAsFloat(entry);
int expectedClass = 0;
if (isOneHot) {
for (int i = 1; i < label.size(); i++) {
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
for (int i = 1; i < labels.size(); i++) {
if (labels.get(i) > labels.get(expectedClass)) expectedClass = i;
}
} else {
expectedClass = Math.round(label.getFirst());
expectedClass = Math.round(labels.getFirst());
}
visualizer.addPoint("Label " + expectedClass,
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
@@ -105,12 +105,7 @@ public class Main {
List<Float> predictions = new ArrayList<>();
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
int predictedClass;
if(isOneHot){
predictedClass = predictions.indexOf(Collections.max(predictions));
} else {
predictedClass = Math.round(predictions.getFirst());
}
int predictedClass = dataset.getPredictedClass(predictions);
visualizer.addPoint("Class " + predictedClass, x, y);
}

View File

@@ -36,7 +36,6 @@ public class DataSet implements Iterable<DataSetEntry>{
return firstEntry.map(inputs -> this.data.get(inputs).getValues().size()).orElse(0);
}
public int size() {
return data.size();
}
@@ -57,6 +56,25 @@ public class DataSet implements Iterable<DataSetEntry>{
return this.data.get(entry).getValues();
}
public int getPredictedClass(List<Float> values) {
if (nbrLabels == 1) {
// single value — round to nearest known class
float value = values.getFirst();
return data.values().stream()
.map(l -> l.getValues().getFirst())
.distinct()
.min(Comparator.comparingDouble(l -> Math.abs(l - value)))
.map(Math::round)
.orElseThrow();
} else {
// one-hot — return index of max value
int max = 0;
for (int i = 1; i < values.size(); i++)
if (values.get(i) > values.get(max)) max = i;
return max;
}
}
public DataSet toNormalized() {
List<DataSetEntry> entries = this.getData();