Reworked synapses data structure
This commit is contained in:
@@ -26,7 +26,7 @@ public class Main {
|
||||
DataSet dataset = new DatasetExtractor()
|
||||
.extract("C:/Users/Laurent/Desktop/ANN-framework/src/main/resources/assets/table_4_12.csv", nbrClass);
|
||||
|
||||
int[] neuronPerLayer = new int[]{10, 5, 10, dataset.getNbrLabels()};
|
||||
int[] neuronPerLayer = new int[]{100, 100, 50, dataset.getNbrLabels()};
|
||||
int nbrInput = dataset.getNbrInputs();
|
||||
|
||||
FullyConnectedNetwork network = createNetwork(neuronPerLayer, nbrInput);
|
||||
@@ -34,7 +34,7 @@ public class Main {
|
||||
System.out.println(network.synCount());
|
||||
|
||||
Trainer trainer = new GradientBackpropagationTraining();
|
||||
trainer.train(0.01F, 2000, network, dataset);
|
||||
trainer.train(0.001F, 2000, network, dataset);
|
||||
|
||||
plotGraph(dataset, network);
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ public interface Model {
|
||||
int neuronCount();
|
||||
int indexInLayerOf(Neuron n);
|
||||
void forEachNeuron(Consumer<Neuron> consumer);
|
||||
void forEachSynapse(Consumer<Synapse> consumer);
|
||||
//void forEachSynapse(Consumer<Synapse> consumer);
|
||||
void forEachOutputNeurons(Consumer<Neuron> consumer);
|
||||
void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer);
|
||||
List<Float> predict(List<Input> inputs);
|
||||
|
||||
@@ -51,13 +51,6 @@ public class FullyConnectedNetwork implements Model {
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||
for(Layer l : this.layers){
|
||||
l.forEachSynapse(consumer);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachNeuron(Consumer<Neuron> consumer) {
|
||||
for(Layer l : this.layers){
|
||||
|
||||
@@ -54,12 +54,12 @@ public class Layer implements Model {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
/*@Override
|
||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||
for (Neuron n : this.neurons){
|
||||
n.forEachSynapse(consumer);
|
||||
}
|
||||
}
|
||||
}*/
|
||||
|
||||
@Override
|
||||
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
||||
|
||||
@@ -7,40 +7,37 @@ import java.util.function.Consumer;
|
||||
|
||||
public class Neuron implements Model {
|
||||
|
||||
protected int id;
|
||||
protected final Synapse[] synapses;
|
||||
protected Bias bias;
|
||||
protected ActivationFunction activationFunction;
|
||||
protected Float output;
|
||||
protected Float weightedSum;
|
||||
protected final float[] weights;
|
||||
protected final float[] inputs;
|
||||
private final int id;
|
||||
private float output;
|
||||
private final float[] weights;
|
||||
private final float[] inputs;
|
||||
private final ActivationFunction activationFunction;
|
||||
|
||||
public Neuron(int id, Synapse[] synapses, Bias bias, ActivationFunction func){
|
||||
this.id = id;
|
||||
this.synapses = synapses;
|
||||
this.bias = bias;
|
||||
this.activationFunction = func;
|
||||
this.output = null;
|
||||
this.weightedSum = null;
|
||||
|
||||
weights = new float[synapses.length];
|
||||
inputs = new float[synapses.length];
|
||||
weights = new float[synapses.length+1]; //takes the bias into account
|
||||
inputs = new float[synapses.length+1]; //takes the bias into account
|
||||
|
||||
weights[0] = bias.getWeight();
|
||||
inputs[0] = bias.getInput();
|
||||
for (int i = 0; i < synapses.length; i++){
|
||||
weights[i+1] = synapses[i].getWeight();
|
||||
inputs[i+1] = synapses[i].getInput();
|
||||
}
|
||||
}
|
||||
|
||||
public void updateBias(Weight weight) {
|
||||
this.bias.setWeight(weight.getValue());
|
||||
public void setWeight(int index, float value) {
|
||||
this.weights[index] = value;
|
||||
}
|
||||
|
||||
public void updateWeight(int index, Weight weight) {
|
||||
this.synapses[index].setWeight(weight.getValue());
|
||||
public float getWeight(int index) {
|
||||
return this.weights[index];
|
||||
}
|
||||
|
||||
protected void setInputs(List<Input> inputs){
|
||||
for(int i = 0; i < inputs.size() && i < synapses.length; i++){
|
||||
Synapse syn = this.synapses[i];
|
||||
syn.setInput(inputs.get(i));
|
||||
}
|
||||
public float getInput(int index) {
|
||||
return this.inputs[index];
|
||||
}
|
||||
|
||||
public ActivationFunction getActivationFunction(){
|
||||
@@ -51,21 +48,13 @@ public class Neuron implements Model {
|
||||
return this.output;
|
||||
}
|
||||
|
||||
public float getWeight(int index){
|
||||
return this.synapses[index].getWeight();
|
||||
}
|
||||
|
||||
public float getWeightedSum(){
|
||||
return this.weightedSum;
|
||||
}
|
||||
|
||||
public float calculateWeightedSum() {
|
||||
this.weightedSum = 0F;
|
||||
this.weightedSum += this.bias.getWeight() * this.bias.getInput();
|
||||
for(Synapse syn : this.synapses){
|
||||
this.weightedSum += syn.getWeight() * syn.getInput();
|
||||
int count = synCount();
|
||||
float weightedSum = 0F;
|
||||
for (int i = 0; i < count; i++){
|
||||
weightedSum += weights[i] * inputs[i];
|
||||
}
|
||||
return this.weightedSum;
|
||||
return weightedSum;
|
||||
}
|
||||
|
||||
public int getId(){
|
||||
@@ -74,7 +63,7 @@ public class Neuron implements Model {
|
||||
|
||||
@Override
|
||||
public int synCount() {
|
||||
return this.synapses.length+1; //take the bias into account
|
||||
return this.weights.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -99,14 +88,6 @@ public class Neuron implements Model {
|
||||
consumer.accept(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachSynapse(Consumer<Synapse> consumer) {
|
||||
consumer.accept(this.bias);
|
||||
for (Synapse syn : this.synapses){
|
||||
consumer.accept(syn);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void forEachOutputNeurons(Consumer<Neuron> consumer) {
|
||||
consumer.accept(this);
|
||||
@@ -116,4 +97,11 @@ public class Neuron implements Model {
|
||||
public void forEachNeuronConnectedTo(Neuron n, Consumer<Neuron> consumer) {
|
||||
throw new UnsupportedOperationException("Neurons have no connection with themselves");
|
||||
}
|
||||
|
||||
private void setInputs(List<Input> values){
|
||||
for(int i = 0; i < values.size(); i++){
|
||||
inputs[i+1] = values.get(i).getValue();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ public class TrainingPipeline {
|
||||
System.out.printf("[Training finished in %.3fs]", (end-start)/1000.0);
|
||||
}
|
||||
|
||||
if(this.visualization) this.visualize(ctx);
|
||||
//if(this.visualization) this.visualize(ctx);
|
||||
}
|
||||
|
||||
private void executeSteps(TrainingContext ctx){
|
||||
@@ -109,7 +109,7 @@ public class TrainingPipeline {
|
||||
}
|
||||
}
|
||||
|
||||
private void visualize(TrainingContext ctx){
|
||||
/*private void visualize(TrainingContext ctx){
|
||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
||||
ctx.model.forEachNeuron(n -> {
|
||||
List<Float> weights = new ArrayList<>();
|
||||
@@ -129,6 +129,6 @@ public class TrainingPipeline {
|
||||
i++;
|
||||
}
|
||||
this.visualizer.buildLineGraph();
|
||||
}
|
||||
}*/
|
||||
|
||||
}
|
||||
|
||||
@@ -14,12 +14,14 @@ public class GradientDescentCorrectionStrategy implements AlgorithmStep {
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
AtomicInteger i = new AtomicInteger(0);
|
||||
context.model.forEachSynapse(syn -> {
|
||||
float corrector = context.correctorTerms.get(i.get());
|
||||
float c = syn.getWeight() + corrector;
|
||||
syn.setWeight(c);
|
||||
i.incrementAndGet();
|
||||
int[] globalSynIndex = {0};
|
||||
context.model.forEachNeuron(n -> {
|
||||
for(int i = 0; i < n.synCount(); i++){
|
||||
float corrector = context.correctorTerms.get(globalSynIndex[0]);
|
||||
float c = n.getWeight(i) + corrector;
|
||||
n.setWeight(i, c);
|
||||
globalSynIndex[0]++;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,13 +22,12 @@ public class GradientDescentErrorStrategy implements AlgorithmStep {
|
||||
context.model.forEachNeuron(neuron -> {
|
||||
float correspondingDelta = context.deltas[neuronIndex.get()];
|
||||
|
||||
neuron.forEachSynapse(syn -> {
|
||||
for(int i = 0; i < neuron.synCount(); i++){
|
||||
float corrector = context.correctorTerms.get(synIndex.get());
|
||||
corrector += context.learningRate * correspondingDelta * syn.getInput();
|
||||
corrector += context.learningRate * correspondingDelta * neuron.getInput(i);
|
||||
context.correctorTerms.set(synIndex.get(), corrector);
|
||||
synIndex.incrementAndGet();
|
||||
});
|
||||
|
||||
}
|
||||
neuronIndex.incrementAndGet();
|
||||
});
|
||||
|
||||
|
||||
@@ -21,11 +21,11 @@ public class BackpropagationCorrectionStep implements AlgorithmStep {
|
||||
int[] synIndex = {0};
|
||||
context.model.forEachNeuron(n -> {
|
||||
float signal = context.errorSignals[n.getId()];
|
||||
n.forEachSynapse(syn -> {
|
||||
inputs[synIndex[0]] = syn.getInput();
|
||||
for (int i = 0; i < n.synCount(); i++){
|
||||
inputs[synIndex[0]] = n.getInput(i);
|
||||
signals[synIndex[0]] = signal;
|
||||
synIndex[0]++;
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
float lr = context.learningRate;
|
||||
@@ -44,13 +44,13 @@ public class BackpropagationCorrectionStep implements AlgorithmStep {
|
||||
}
|
||||
|
||||
private void syncWeights() {
|
||||
int[] i = {0};
|
||||
int[] synIndex = {0};
|
||||
context.model.forEachNeuron(n -> {
|
||||
n.forEachSynapse(syn -> {
|
||||
syn.setWeight(syn.getWeight() + context.correctionBuffer[i[0]]);
|
||||
context.correctionBuffer[i[0]] = 0f;
|
||||
i[0]++;
|
||||
});
|
||||
for (int i = 0; i < n.synCount(); i++) {
|
||||
n.setWeight(i, n.getWeight(i) + context.correctionBuffer[synIndex[0]]);
|
||||
context.correctionBuffer[synIndex[0]] = 0f;
|
||||
synIndex[0]++;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,7 @@ public class OutputLayerErrorStep implements AlgorithmStep {
|
||||
|
||||
public OutputLayerErrorStep(GradientBackpropagationContext context){
|
||||
this.context = context;
|
||||
this.expectations = new float[context.model.neuronCount()];
|
||||
this.expectations = new float[context.dataset.getNbrLabels()];
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -18,17 +18,16 @@ public class SimpleCorrectionStep implements AlgorithmStep {
|
||||
public void run() {
|
||||
if(context.expectations.equals(context.predictions)) return;
|
||||
AtomicInteger neuronIndex = new AtomicInteger(0);
|
||||
AtomicInteger synIndex = new AtomicInteger(0);
|
||||
|
||||
context.model.forEachNeuron(neuron -> {
|
||||
float correspondingDelta = context.deltas[neuronIndex.get()];
|
||||
neuron.forEachSynapse(syn -> {
|
||||
float currentW = syn.getWeight();
|
||||
float currentInput = syn.getInput();
|
||||
|
||||
for(int i = 0; i < neuron.synCount(); i++){
|
||||
float currentW = neuron.getWeight(i);
|
||||
float currentInput = neuron.getInput(i);
|
||||
float newValue = currentW + (context.learningRate * correspondingDelta * currentInput);
|
||||
syn.setWeight(newValue);
|
||||
synIndex.incrementAndGet();
|
||||
});
|
||||
neuron.setWeight(i, newValue);
|
||||
}
|
||||
neuronIndex.incrementAndGet();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ public class GradientBackpropagationTraining implements Trainer {
|
||||
@Override
|
||||
public void train(float learningRate, int epoch, Model model, DataSet dataset) {
|
||||
GradientBackpropagationContext context =
|
||||
new GradientBackpropagationContext(model, dataset, learningRate, dataset.size()/3);
|
||||
new GradientBackpropagationContext(model, dataset, learningRate, dataset.size());
|
||||
|
||||
List<AlgorithmStep> steps = List.of(
|
||||
new SimplePredictionStep(context),
|
||||
@@ -35,7 +35,7 @@ public class GradientBackpropagationTraining implements Trainer {
|
||||
.afterEpoch(ctx -> {
|
||||
ctx.globalLoss /= dataset.size();
|
||||
})
|
||||
.withVerbose(true,epoch/10)
|
||||
.withVerbose(false,epoch/10)
|
||||
.withTimeMeasurement(true)
|
||||
.run(context);
|
||||
}
|
||||
|
||||
@@ -41,7 +41,9 @@ public class GradientDescentTraining implements Trainer {
|
||||
GradientDescentTrainingContext gdCtx = (GradientDescentTrainingContext) ctx;
|
||||
gdCtx.globalLoss = 0.0F;
|
||||
gdCtx.correctorTerms.clear();
|
||||
gdCtx.model.forEachSynapse(syn -> gdCtx.correctorTerms.add(0F));
|
||||
for(int i = 0; i < gdCtx.model.synCount(); i++){
|
||||
gdCtx.correctorTerms.add(0F);
|
||||
}
|
||||
})
|
||||
.afterEpoch(ctx -> {
|
||||
context.globalLoss /= context.dataset.size();
|
||||
|
||||
Reference in New Issue
Block a user