Minor changes

This commit is contained in:
2026-05-09 16:44:12 +02:00
parent b253fb74ee
commit dd664efa92
4 changed files with 68 additions and 35 deletions

View File

@@ -1,14 +1,14 @@
{ {
"model": { "model": {
"new": true, "new": true,
"parameters": [2, 4, 2, 1], "parameters": [5, 2, 3],
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/best-snapshot.json" "path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
}, },
"training" : { "training" : {
"learning_rate" : 0.0003, "learning_rate" : 0.0003,
"max_epoch" : 5000 "max_epoch" : 5000
}, },
"dataset" : { "dataset" : {
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/LangageDesSignes/data_formatted.csv" "path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv"
} }
} }

View File

@@ -8,6 +8,7 @@ import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.infrastructure.config.ConfigDto; import com.naaturel.ANN.infrastructure.config.ConfigDto;
import com.naaturel.ANN.infrastructure.config.ConfigLoader; import com.naaturel.ANN.infrastructure.config.ConfigLoader;
import com.naaturel.ANN.infrastructure.dataset.DataSet; import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor; import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*; import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer; import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
@@ -20,6 +21,8 @@ public class Main {
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
Scanner sc = new Scanner(System.in);
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json"); ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
boolean newModel = config.getModelProperty("new", Boolean.class); boolean newModel = config.getModelProperty("new", Boolean.class);
@@ -29,7 +32,12 @@ public class Main {
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue(); float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
String datasetPath = config.getDatasetProperty("path", String.class); String datasetPath = config.getDatasetProperty("path", String.class);
int nbrClass = 5; System.out.printf("How many classes ? [%d] : ", modelParameters[modelParameters.length-1]);
String input = sc.nextLine().trim();
int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input);
System.out.println();
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass); DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
int nbrInput = dataset.getNbrInputs(); int nbrInput = dataset.getNbrInputs();
@@ -48,7 +56,8 @@ public class Main {
snapshot.loadFromFile(modelPath); snapshot.loadFromFile(modelPath);
network = snapshot.getModel(); network = snapshot.getModel();
} }
//plotGraph(dataset, network);
plotGraph(dataset, network);
new ModelVisualizer(network) new ModelVisualizer(network)
.withWeights(true) .withWeights(true)
@@ -84,23 +93,35 @@ public class Main {
} }
private static void plotGraph(DataSet dataset, Model network){ private static void plotGraph(DataSet dataset, Model network){
if(dataset.getNbrInputs() != 2) return;
GraphVisualizer visualizer = new GraphVisualizer(); GraphVisualizer visualizer = new GraphVisualizer();
float min = -10F;
float max = 10F;
float step = Math.abs(max - min)/300;
/*for (DataSetEntry entry : dataset) { //plot labels
for (DataSetEntry entry : dataset) {
List<Float> label = dataset.getLabelsAsFloat(entry); List<Float> label = dataset.getLabelsAsFloat(entry);
label.forEach(l -> { int expectedClass = 0;
visualizer.addPoint("Label " + l, for (int i = 1; i < label.size(); i++) {
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
}
visualizer.addPoint("Label " + expectedClass,
entry.getData().get(0).getValue(), entry.getData().get(1).getValue()); entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
}); }
}*/
float min = -50F; //plot predictions
float max = 50F;
float step = 0.03F;
for (float x = min; x < max; x+=step){ for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){ for (float y = min; y < max; y+=step){
float[] predictions = network.predict(new float[]{x, y}); List<Float> predictions = new ArrayList<>();
visualizer.addPoint(Float.toString(Math.round(predictions[0])), x, y);
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
float highest = Collections.max(predictions);
int predictedClass = predictions.indexOf(highest);
visualizer.addPoint("Predict " + predictedClass, x, y);
} }
} }

View File

@@ -84,7 +84,7 @@ public class TrainingPipeline {
System.out.printf("[Training finished in %.3fs]\n", (end-start)/1000.0); System.out.printf("[Training finished in %.3fs]\n", (end-start)/1000.0);
} }
System.out.printf("[Final global error] : %f\n", ctx.globalLoss); System.out.printf("[Final global error] : %f\n", ctx.globalLoss);
//if(this.visualization) this.visualize(ctx); System.out.println("Accuracy: " + accuracyCheck(ctx) + "%");
} }
private void executeSteps(TrainingContext ctx){ private void executeSteps(TrainingContext ctx){
@@ -107,26 +107,28 @@ public class TrainingPipeline {
} }
} }
/*private void visualize(TrainingContext ctx){ private float accuracyCheck(TrainingContext ctx){
AtomicInteger neuronIndex = new AtomicInteger(0); int correct = 0;
ctx.model.forEachNeuron(n -> { for (DataSetEntry entry : ctx.dataset) {
List<Float> weights = new ArrayList<>(); float[] predictions = ctx.model.predict(entry.getDataAsFloat());
n.forEachSynapse(syn -> weights.add(syn.getWeight())); List<Float> labels = ctx.dataset.getLabelsAsFloat(entry);
float b = weights.get(0); // argmax of predictions
float w1 = weights.get(1); int predictedClass = 0;
float w2 = weights.get(2); for (int i = 1; i < predictions.length; i++) {
if (predictions[i] > predictions[predictedClass]) predictedClass = i;
this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3); }
});
int i = 0; // argmax of labels
for(DataSetEntry entry : ctx.dataset){ int expectedClass = 0;
List<Input> inputs = entry.getData(); for (int i = 1; i < labels.size(); i++) {
this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue()); if (labels.get(i) > labels.get(expectedClass)) expectedClass = i;
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F); }
i++;
if (predictedClass == expectedClass) correct++;
}
return 100f * correct / ctx.dataset.size();
} }
this.visualizer.buildLineGraph();
}*/
} }

View File

@@ -17,6 +17,16 @@ public class DataSetEntry implements Iterable<Input> {
} }
public float[] getDataAsFloat() {
float[] res = new float[data.size()];
for(int i = 0; i < data.size(); i++){
res[i] = data.get(i).getValue();
}
return res;
}
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(this.data); return Objects.hash(this.data);