Add model visualization
This commit is contained in:
@@ -1,23 +1,19 @@
|
||||
package com.naaturel.ANN;
|
||||
|
||||
import com.naaturel.ANN.domain.abstraction.Model;
|
||||
import com.naaturel.ANN.domain.abstraction.Network;
|
||||
import com.naaturel.ANN.domain.model.neuron.Neuron;
|
||||
import com.naaturel.ANN.domain.abstraction.Trainer;
|
||||
import com.naaturel.ANN.implementation.gradientDescent.Linear;
|
||||
import com.naaturel.ANN.implementation.multiLayers.Sigmoid;
|
||||
import com.naaturel.ANN.implementation.multiLayers.TanH;
|
||||
import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
|
||||
import com.naaturel.ANN.infrastructure.config.ConfigDto;
|
||||
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSet;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
|
||||
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
|
||||
import com.naaturel.ANN.domain.model.neuron.*;
|
||||
import com.naaturel.ANN.infrastructure.graph.GraphVisualizer;
|
||||
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
|
||||
import com.naaturel.ANN.infrastructure.persistence.ModelSnapshot;
|
||||
import com.naaturel.ANN.infrastructure.visualization.ModelVisualizer;
|
||||
|
||||
import java.io.Console;
|
||||
import java.util.*;
|
||||
|
||||
public class Main {
|
||||
@@ -53,6 +49,10 @@ public class Main {
|
||||
network = snapshot.getModel();
|
||||
}
|
||||
//plotGraph(dataset, network);
|
||||
|
||||
new ModelVisualizer(network)
|
||||
.withWeights(true)
|
||||
.display();
|
||||
}
|
||||
|
||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
||||
@@ -86,16 +86,16 @@ public class Main {
|
||||
private static void plotGraph(DataSet dataset, Model network){
|
||||
GraphVisualizer visualizer = new GraphVisualizer();
|
||||
|
||||
for (DataSetEntry entry : dataset) {
|
||||
/*for (DataSetEntry entry : dataset) {
|
||||
List<Float> label = dataset.getLabelsAsFloat(entry);
|
||||
label.forEach(l -> {
|
||||
visualizer.addPoint("Label " + l,
|
||||
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
|
||||
});
|
||||
}
|
||||
}*/
|
||||
|
||||
float min = -5F;
|
||||
float max = 5F;
|
||||
float min = -50F;
|
||||
float max = 50F;
|
||||
float step = 0.03F;
|
||||
for (float x = min; x < max; x+=step){
|
||||
for (float y = min; y < max; y+=step){
|
||||
|
||||
Reference in New Issue
Block a user