Minor changes

This commit is contained in:
2026-05-09 16:44:12 +02:00
parent b253fb74ee
commit dd664efa92
4 changed files with 68 additions and 35 deletions

View File

@@ -1,14 +1,14 @@
{
"model": {
"new": true,
"parameters": [2, 4, 2, 1],
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/best-snapshot.json"
"parameters": [5, 2, 3],
"path": "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/snapshots/snapshot-test.json"
},
"training" : {
"learning_rate" : 0.0003,
"max_epoch" : 5000
},
"dataset" : {
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/LangageDesSignes/data_formatted.csv"
"path" : "C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_3_1.csv"
}
}

View File

@@ -8,6 +8,7 @@ import com.naaturel.ANN.implementation.training.GradientBackpropagationTraining;
import com.naaturel.ANN.infrastructure.config.ConfigDto;
import com.naaturel.ANN.infrastructure.config.ConfigLoader;
import com.naaturel.ANN.infrastructure.dataset.DataSet;
import com.naaturel.ANN.infrastructure.dataset.DataSetEntry;
import com.naaturel.ANN.infrastructure.dataset.DatasetExtractor;
import com.naaturel.ANN.domain.model.neuron.*;
import com.naaturel.ANN.infrastructure.visualization.GraphVisualizer;
@@ -20,6 +21,8 @@ public class Main {
public static void main(String[] args) throws Exception {
Scanner sc = new Scanner(System.in);
ConfigDto config = ConfigLoader.load("C:/Users/Laurent/Desktop/ANN-framework/config.json");
boolean newModel = config.getModelProperty("new", Boolean.class);
@@ -29,7 +32,12 @@ public class Main {
float learningRate = config.getTrainingProperty("learning_rate", Double.class).floatValue();
String datasetPath = config.getDatasetProperty("path", String.class);
int nbrClass = 5;
System.out.printf("How many classes ? [%d] : ", modelParameters[modelParameters.length-1]);
String input = sc.nextLine().trim();
int nbrClass = input.isEmpty() ? modelParameters[modelParameters.length-1] : Integer.parseInt(input);
System.out.println();
DataSet dataset = new DatasetExtractor().extract(datasetPath, nbrClass);
int nbrInput = dataset.getNbrInputs();
@@ -48,7 +56,8 @@ public class Main {
snapshot.loadFromFile(modelPath);
network = snapshot.getModel();
}
//plotGraph(dataset, network);
plotGraph(dataset, network);
new ModelVisualizer(network)
.withWeights(true)
@@ -84,23 +93,35 @@ public class Main {
}
private static void plotGraph(DataSet dataset, Model network){
if(dataset.getNbrInputs() != 2) return;
GraphVisualizer visualizer = new GraphVisualizer();
float min = -10F;
float max = 10F;
float step = Math.abs(max - min)/300;
/*for (DataSetEntry entry : dataset) {
//plot labels
for (DataSetEntry entry : dataset) {
List<Float> label = dataset.getLabelsAsFloat(entry);
label.forEach(l -> {
visualizer.addPoint("Label " + l,
int expectedClass = 0;
for (int i = 1; i < label.size(); i++) {
if (label.get(i) > label.get(expectedClass)) expectedClass = i;
}
visualizer.addPoint("Label " + expectedClass,
entry.getData().get(0).getValue(), entry.getData().get(1).getValue());
});
}*/
}
float min = -50F;
float max = 50F;
float step = 0.03F;
//plot predictions
for (float x = min; x < max; x+=step){
for (float y = min; y < max; y+=step){
float[] predictions = network.predict(new float[]{x, y});
visualizer.addPoint(Float.toString(Math.round(predictions[0])), x, y);
List<Float> predictions = new ArrayList<>();
for (float p : network.predict(new float[]{x, y})) predictions.add(p);
float highest = Collections.max(predictions);
int predictedClass = predictions.indexOf(highest);
visualizer.addPoint("Predict " + predictedClass, x, y);
}
}

View File

@@ -84,7 +84,7 @@ public class TrainingPipeline {
System.out.printf("[Training finished in %.3fs]\n", (end-start)/1000.0);
}
System.out.printf("[Final global error] : %f\n", ctx.globalLoss);
//if(this.visualization) this.visualize(ctx);
System.out.println("Accuracy: " + accuracyCheck(ctx) + "%");
}
private void executeSteps(TrainingContext ctx){
@@ -107,26 +107,28 @@ public class TrainingPipeline {
}
}
/*private void visualize(TrainingContext ctx){
AtomicInteger neuronIndex = new AtomicInteger(0);
ctx.model.forEachNeuron(n -> {
List<Float> weights = new ArrayList<>();
n.forEachSynapse(syn -> weights.add(syn.getWeight()));
private float accuracyCheck(TrainingContext ctx){
int correct = 0;
for (DataSetEntry entry : ctx.dataset) {
float[] predictions = ctx.model.predict(entry.getDataAsFloat());
List<Float> labels = ctx.dataset.getLabelsAsFloat(entry);
float b = weights.get(0);
float w1 = weights.get(1);
float w2 = weights.get(2);
this.visualizer.addEquation("boundary_" + neuronIndex.getAndIncrement(), w1, w2, b, -3, 3);
});
int i = 0;
for(DataSetEntry entry : ctx.dataset){
List<Input> inputs = entry.getData();
this.visualizer.addPoint("p"+i, inputs.get(0).getValue(), inputs.get(1).getValue());
this.visualizer.addPoint("p"+i, inputs.get(0).getValue()+0.01F, inputs.get(1).getValue()+0.01F);
i++;
// argmax of predictions
int predictedClass = 0;
for (int i = 1; i < predictions.length; i++) {
if (predictions[i] > predictions[predictedClass]) predictedClass = i;
}
// argmax of labels
int expectedClass = 0;
for (int i = 1; i < labels.size(); i++) {
if (labels.get(i) > labels.get(expectedClass)) expectedClass = i;
}
if (predictedClass == expectedClass) correct++;
}
return 100f * correct / ctx.dataset.size();
}
this.visualizer.buildLineGraph();
}*/
}

View File

@@ -17,6 +17,16 @@ public class DataSetEntry implements Iterable<Input> {
}
public float[] getDataAsFloat() {
float[] res = new float[data.size()];
for(int i = 0; i < data.size(); i++){
res[i] = data.get(i).getValue();
}
return res;
}
@Override
public int hashCode() {
return Objects.hash(this.data);