Add model visualization

This commit is contained in:
2026-04-04 17:16:18 +02:00
parent 8beb6aa870
commit b253fb74ee
9 changed files with 146 additions and 1240 deletions

View File

@@ -1,23 +1,19 @@
package com.naaturel.ANN;
import com.naaturel.ANN.domain.abstraction.Model;
import com.naaturel.ANN.domain.abstraction.Network;
import com.naaturel.ANN.domain.model.neuron.Neuron;
import com.naaturel.ANN.domain.abstraction.Trainer;
import com.naaturel.ANN.implementation.gradientDescent.Linear;
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
import com.naaturel.ANN.implementation.multiLayers.TanH;
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.infrastructure.config.ConfigDto;
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
import com.naaturel.ANN.infrastructure.visualization.ModelVisualizer;
import java.io.Console;
import java.util.*;
public class Main {
@@ -53,6 +49,10 @@ public class Main {
network = snapshot.getModel();
}
//plotGraph(dataset, network);
new ModelVisualizer(network)
.withWeights(true)
.display();
}
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
@@ -86,16 +86,16 @@ public class Main {
private static void plotGraph(DataSet dataset, Model network){
GraphVisualizer visualizer = new GraphVisualizer();
for (DataSetEntry entry : dataset) {
/*for (DataSetEntry entry : dataset) {
List<Float> label = dataset.getLabelsAsFloat(entry);
label.forEach(l -> {
visualizer.addPoint("Label " + l,
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
});
}
}*/
float min = -5F;
float max = 5F;
float min = -50F;
float max = 50F;
float step = 0.03F;
for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){