diff --git a/src/main/java/net/echo/brain4j/training/updater/Updater.java b/src/main/java/net/echo/brain4j/training/updater/Updater.java new file mode 100644 index 0000000..ed66881 --- /dev/null +++ b/src/main/java/net/echo/brain4j/training/updater/Updater.java @@ -0,0 +1,20 @@ +package net.echo.brain4j.training.updater; + +import net.echo.brain4j.layer.Layer; +import net.echo.brain4j.structure.Synapse; + +import java.util.List; + +public abstract class Updater { + + public void postInitialize() { + } + + public void postIteration(List layers, double learningRate) { + } + + public void postFit(List layers, double learningRate) { + } + + public abstract void acknowledgeChange(Synapse synapse, double change, double learningRate); +} diff --git a/src/main/java/net/echo/brain4j/training/updater/impl/NormalUpdater.java b/src/main/java/net/echo/brain4j/training/updater/impl/NormalUpdater.java new file mode 100644 index 0000000..d83c19a --- /dev/null +++ b/src/main/java/net/echo/brain4j/training/updater/impl/NormalUpdater.java @@ -0,0 +1,50 @@ +package net.echo.brain4j.training.updater.impl; + +import net.echo.brain4j.layer.Layer; +import net.echo.brain4j.structure.Neuron; +import net.echo.brain4j.structure.Synapse; +import net.echo.brain4j.training.updater.Updater; + +import java.util.Arrays; +import java.util.List; + +public class NormalUpdater extends Updater { + + private Synapse[] synapses; + private Double[] gradients; + private int minSynapseIndex; + + @Override + public void postInitialize() { + this.synapses = new Synapse[Synapse.ID_COUNTER]; + this.gradients = new Double[Synapse.ID_COUNTER]; + Arrays.fill(gradients, 0.0); + } + + @Override + public void postFit(List layers, double learningRate) { + for (int i = 0; i < gradients.length; i++) { + Synapse synapse = synapses[i]; + double gradient = gradients[i]; + + synapse.setWeight(synapse.getWeight() + learningRate * gradient); + } + + for (Layer layer : layers) { + for (Neuron neuron : layer.getNeurons()) { + double deltaBias = learningRate * neuron.getDelta(); + neuron.setBias(neuron.getBias() + deltaBias); + } + } + + Arrays.fill(gradients, 0.0); + } + + @Override + public void acknowledgeChange(Synapse synapse, double change, double learningRate) { + int id = synapse.getSynapseId(); + + synapses[id] = synapse; + gradients[id] += change; + } +} diff --git a/src/main/java/net/echo/brain4j/training/updater/impl/StochasticUpdater.java b/src/main/java/net/echo/brain4j/training/updater/impl/StochasticUpdater.java new file mode 100644 index 0000000..12916c1 --- /dev/null +++ b/src/main/java/net/echo/brain4j/training/updater/impl/StochasticUpdater.java @@ -0,0 +1,26 @@ +package net.echo.brain4j.training.updater.impl; + +import net.echo.brain4j.layer.Layer; +import net.echo.brain4j.structure.Neuron; +import net.echo.brain4j.structure.Synapse; +import net.echo.brain4j.training.updater.Updater; + +import java.util.List; + +public class StochasticUpdater extends Updater { + + @Override + public void postIteration(List layers, double learningRate) { + for (Layer layer : layers) { + for (Neuron neuron : layer.getNeurons()) { + double deltaBias = learningRate * neuron.getDelta(); + neuron.setBias(neuron.getBias() + deltaBias); + } + } + } + + @Override + public void acknowledgeChange(Synapse synapse, double change, double learningRate) { + synapse.setWeight(synapse.getWeight() + change); + } +}