Fix weighted sum back
This commit is contained in:
@@ -24,9 +24,9 @@ public class Main {
|
|||||||
int nbrClass = 1;
|
int nbrClass = 1;
|
||||||
|
|
||||||
DataSet dataset = new DatasetExtractor()
|
DataSet dataset = new DatasetExtractor()
|
||||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_2_9.csv", nbrClass);
|
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
|
||||||
|
|
||||||
int[] neuronPerLayer = new int[]{1800, 2, 1800, dataset.getNbrLabels()};
|
int[] neuronPerLayer = new int[]{10, 5, 10, dataset.getNbrLabels()};
|
||||||
int nbrInput = dataset.getNbrInputs();
|
int nbrInput = dataset.getNbrInputs();
|
||||||
|
|
||||||
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
|
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
|
||||||
@@ -36,7 +36,7 @@ public class Main {
|
|||||||
Trainer trainer = new GradientBackpropagationTraining();
|
Trainer trainer = new GradientBackpropagationTraining();
|
||||||
trainer.train(0.01F, 2000, network, dataset);
|
trainer.train(0.01F, 2000, network, dataset);
|
||||||
|
|
||||||
//plotGraph(dataset, network);
|
plotGraph(dataset, network);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
private static FullyConnectedNetwork createNetwork(int[] neuronPerLayer, int nbrInput){
|
||||||
@@ -78,8 +78,8 @@ public class Main {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
float min = -5F;
|
float min = -3F;
|
||||||
float max = 5F;
|
float max = 3F;
|
||||||
float step = 0.03F;
|
float step = 0.03F;
|
||||||
for (float x = min; x < max; x+=step){
|
for (float x = min; x < max; x+=step){
|
||||||
for (float y = min; y < max; y+=step){
|
for (float y = min; y < max; y+=step){
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ public abstract class TrainingContext {
|
|||||||
public TrainingContext(Model model, DataSet dataset) {
|
public TrainingContext(Model model, DataSet dataset) {
|
||||||
this.model = model;
|
this.model = model;
|
||||||
this.dataset = dataset;
|
this.dataset = dataset;
|
||||||
this.deltas = new float[model.neuronCount()];
|
this.deltas = new float[dataset.getNbrLabels()];
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -60,14 +60,12 @@ public class Neuron implements Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public float calculateWeightedSum() {
|
public float calculateWeightedSum() {
|
||||||
float sum = bias.getWeight() * bias.getInput();
|
this.weightedSum = 0F;
|
||||||
|
this.weightedSum += this.bias.getWeight() * this.bias.getInput();
|
||||||
for (int i = 0; i < weights.length; i++) {
|
for(Synapse syn : this.synapses){
|
||||||
sum += weights[i] * inputs[i];
|
this.weightedSum += syn.getWeight() * syn.getInput();
|
||||||
}
|
}
|
||||||
|
return this.weightedSum;
|
||||||
this.weightedSum = sum;
|
|
||||||
return sum;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getId(){
|
public int getId(){
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ public class GradientBackpropagationTraining implements Trainer {
|
|||||||
.afterEpoch(ctx -> {
|
.afterEpoch(ctx -> {
|
||||||
ctx.globalLoss /= dataset.size();
|
ctx.globalLoss /= dataset.size();
|
||||||
})
|
})
|
||||||
.withVerbose(false,epoch/10)
|
.withVerbose(true,epoch/10)
|
||||||
.withTimeMeasurement(true)
|
.withTimeMeasurement(true)
|
||||||
.run(context);
|
.run(context);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user