Fixed backpropagation, now converges better
xEcho1337 committed Oct 24, 2024
1 parent 588e7ef commit 06e0f95
Showing 4 changed files with 26 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/main/java/net/echo/brain4j/layer/Layer.java
@@ -30,7 +30,7 @@ public void connectAll(Layer nextLayer, double bound) {
     for (Neuron nextNeuron : nextLayer.neurons) {
         Synapse synapse = new Synapse(neuron, nextNeuron, bound);

-        neuron.setSynapse(synapse);
+        neuron.addSynapse(synapse);

         synapses.add(synapse);
     }
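
This one-line change pairs with the Neuron refactor below: connectAll creates one Synapse per downstream neuron inside this loop, so a Neuron that stores only a single Synapse field keeps just the last connection created. A minimal, self-contained sketch of that failure mode (plain Java with hypothetical stub classes, not brain4j's own):

    import java.util.ArrayList;
    import java.util.List;

    // Hypothetical stubs, for illustration only.
    class Synapse {}

    class Neuron {
        private Synapse synapse;                                   // old design: one slot
        private final List<Synapse> synapses = new ArrayList<>();  // new design: all of them

        void setSynapse(Synapse s) { this.synapse = s; }      // overwrites the previous synapse
        void addSynapse(Synapse s) { this.synapses.add(s); }  // keeps every synapse
        int kept() { return synapses.size(); }
    }

    public class OverwriteDemo {
        public static void main(String[] args) {
            Neuron neuron = new Neuron();
            // connectAll-style loop: one synapse per downstream neuron.
            for (int i = 0; i < 3; i++) {
                Synapse s = new Synapse();
                neuron.setSynapse(s);  // old API: only the last synapse would survive
                neuron.addSynapse(s);  // new API: all three are kept
            }
            System.out.println(neuron.kept());  // prints 3
        }
    }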
13 changes: 8 additions & 5 deletions src/main/java/net/echo/brain4j/structure/Neuron.java
@@ -2,19 +2,22 @@

 import com.google.gson.annotations.Expose;

+import java.util.ArrayList;
+import java.util.List;
+
 public class Neuron {

-    private Synapse synapse;
+    private final List<Synapse> synapses = new ArrayList<>();
     private double delta;
     private double value;
     @Expose private double bias = 2 * Math.random() - 1;

-    public Synapse getSynapse() {
-        return synapse;
+    public List<Synapse> getSynapses() {
+        return synapses;
     }

-    public void setSynapse(Synapse synapse) {
-        this.synapse = synapse;
+    public void addSynapse(Synapse synapse) {
+        this.synapses.add(synapse);
     }

     public double getDelta() {
13 changes: 10 additions & 3 deletions src/main/java/net/echo/brain4j/training/BackPropagation.java
@@ -50,13 +50,20 @@ private void backpropagate(double[] targets, double[] outputs, double learningRate) {

         for (Neuron neuron : layer.getNeurons()) {
             double output = neuron.getValue();
-            double error = 0.0;
-
-            Synapse synapse = neuron.getSynapse();
+            for (Synapse synapse : neuron.getSynapses()) {
+                double error = synapse.getWeight() * synapse.getOutputNeuron().getDelta();
+
                 double delta = error * layer.getActivation().getFunction().getDerivative(output);
                 neuron.setDelta(delta);

+                synapse.setWeight(synapse.getWeight() + delta * synapse.getInputNeuron().getValue());
+            }
+            /*Synapse synapse = neuron.getSynapse();
+            error += synapse.getWeight() * synapse.getOutputNeuron().getDelta();
+            double delta = error * layer.getActivation().getFunction().getDerivative(output);
+            neuron.setDelta(delta);
+            neuron.setDelta(delta);*/
         }
     }
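
For context, the textbook hidden-layer update first accumulates the weighted deltas of all outgoing synapses, then derives a single delta per neuron and scales each weight step by the learning rate: delta_j = f'(output_j) * sum_k(w_jk * delta_k), followed by w_ij += learningRate * delta_j * value_i. The loop above instead derives and applies a delta per synapse. A sketch of the accumulated form, using only the accessors that appear in this diff plus the method's learningRate parameter (an illustration of the standard formulation, not the committed brain4j code):

    for (Neuron neuron : layer.getNeurons()) {
        double output = neuron.getValue();
        double error = 0.0;

        // First pass: sum the weighted deltas coming back from the next layer.
        for (Synapse synapse : neuron.getSynapses()) {
            error += synapse.getWeight() * synapse.getOutputNeuron().getDelta();
        }

        // One delta per neuron, via the activation derivative.
        double delta = error * layer.getActivation().getFunction().getDerivative(output);
        neuron.setDelta(delta);

        // Second pass: apply the gradient step, scaled by the learning rate.
        for (Synapse synapse : neuron.getSynapses()) {
            synapse.setWeight(synapse.getWeight()
                    + learningRate * delta * synapse.getInputNeuron().getValue());
        }
    }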

10 changes: 7 additions & 3 deletions src/test/java/XorTest.java
@@ -16,9 +16,9 @@ public static void main(String[] args) {
         Model network = new FeedForwardModel();

         network.add(new DenseLayer(2, Activations.LINEAR));
-        network.add(new DenseLayer(128, Activations.RELU));
-        network.add(new DenseLayer(128, Activations.RELU));
-        network.add(new DenseLayer(128, Activations.RELU));
+        network.add(new DenseLayer(256, Activations.RELU));
+        network.add(new DenseLayer(256, Activations.RELU));
+        network.add(new DenseLayer(256, Activations.RELU));
         network.add(new DenseLayer(1, Activations.SIGMOID));

         network.compile(InitializationType.XAVIER, LossFunctions.BINARY_CROSS_ENTROPY, new Adam(0.001));
@@ -38,6 +38,10 @@ public static void main(String[] args) {
         while (epoches < 500) {
             epoches++;

+            double error = network.evaluate(training);
+
+            System.out.println("Epoch " + epoches + " with error " + error);
+
             network.fit(training);
         }
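
Note that network.evaluate runs before network.fit inside this loop, so the printed error reflects the weights from before that epoch's update. A variant using the same Model calls as this test reports the post-update error instead:

    while (epoches < 500) {
        epoches++;

        network.fit(training);

        // Evaluate after fitting, so the reported error includes this epoch's update.
        double error = network.evaluate(training);
        System.out.println("Epoch " + epoches + " with error " + error);
    }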
