From 257114e60b257186c493da982f49cfa3580e58ef Mon Sep 17 00:00:00 2001
From: echo
Date: Tue, 7 Jan 2025 17:10:51 +0100
Subject: [PATCH] Fixed Adam and AdamW

---
 .../training/optimizers/impl/Adam.java       | 23 +++---
 .../training/optimizers/impl/AdamW.java      | 73 ++-----------------
 .../net/echo/brain4j/utils/GenericUtils.java | 12 +--
 src/test/java/XorTest.java                   |  4 +-
 4 files changed, 27 insertions(+), 85 deletions(-)

diff --git a/src/main/java/net/echo/brain4j/training/optimizers/impl/Adam.java b/src/main/java/net/echo/brain4j/training/optimizers/impl/Adam.java
index 98cac68..018f11c 100644
--- a/src/main/java/net/echo/brain4j/training/optimizers/impl/Adam.java
+++ b/src/main/java/net/echo/brain4j/training/optimizers/impl/Adam.java
@@ -10,16 +10,16 @@ public class Adam extends Optimizer {
 
     // Momentum vectors
-    private ThreadLocal<double[]> firstMomentum;
-    private ThreadLocal<double[]> secondMomentum;
+    protected ThreadLocal<double[]> firstMomentum;
+    protected ThreadLocal<double[]> secondMomentum;
 
-    private double beta1Timestep;
-    private double beta2Timestep;
+    protected double beta1Timestep;
+    protected double beta2Timestep;
 
-    private double beta1;
-    private double beta2;
-    private double epsilon;
-    private int timestep = 0;
+    protected double beta1;
+    protected double beta2;
+    protected double epsilon;
+    protected int timestep = 0;
 
     public Adam(double learningRate) {
         this(learningRate, 0.9, 0.999, 1e-8);
@@ -41,7 +41,7 @@ public void postInitialize() {
     @Override
     public double update(Synapse synapse, Object... params) {
         double[] firstMomentum = (double[]) params[0];
-        double[] secondMomentum = (double[]) params[0];
+        double[] secondMomentum = (double[]) params[1];
 
         double gradient = synapse.getOutputNeuron().getDelta() * synapse.getInputNeuron().getValue();
 
@@ -69,9 +69,12 @@ public void postIteration(Updater updater, List<Layer> layers) {
         this.beta1Timestep = Math.pow(beta1, timestep);
         this.beta2Timestep = Math.pow(beta2, timestep);
 
+        double[] firstMomentum = this.firstMomentum.get();
+        double[] secondMomentum = this.secondMomentum.get();
+
         for (Layer layer : layers) {
             for (Synapse synapse : layer.getSynapses()) {
-                double change = update(synapse);
+                double change = update(synapse, firstMomentum, secondMomentum);
                 updater.acknowledgeChange(synapse, change);
             }
         }
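Note: the Adam fix above corrects two coupled bugs. update() read both moment
vectors from params[0], so the second-moment estimate aliased the first, and
postIteration() called update(synapse) with no extra arguments, so params[0]
was out of bounds. A minimal standalone sketch of the corrected, bias-corrected
step (plain Java; class and method names here are illustrative, not brain4j's
API):

    // One bias-corrected Adam step for weight i, mirroring the patched
    // Adam.update/postIteration above. m and v must be distinct arrays.
    public final class AdamStepSketch {

        static double adamChange(double[] m, double[] v, int i, double gradient,
                                 double learningRate, double beta1, double beta2,
                                 double epsilon, int timestep) {
            // Exponential moving averages of the gradient and its square.
            m[i] = beta1 * m[i] + (1 - beta1) * gradient;
            v[i] = beta2 * v[i] + (1 - beta2) * gradient * gradient;

            // Bias correction compensates for the zero-initialized moments.
            double mHat = m[i] / (1 - Math.pow(beta1, timestep));
            double vHat = v[i] / (1 - Math.pow(beta2, timestep));

            return (learningRate * mHat) / (Math.sqrt(vHat) + epsilon);
        }

        public static void main(String[] args) {
            double[] m = new double[1];
            double[] v = new double[1];
            // First step with gradient 0.5: mHat = 0.5, vHat = 0.25, so the
            // change is roughly learningRate * sign(gradient), i.e. ~0.01.
            System.out.println(adamChange(m, v, 0, 0.5, 0.01, 0.9, 0.999, 1e-8, 1));
        }
    }

With the aliasing bug, both running averages read and wrote the same slot, so
neither estimate matched its definition.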
diff --git a/src/main/java/net/echo/brain4j/training/optimizers/impl/AdamW.java b/src/main/java/net/echo/brain4j/training/optimizers/impl/AdamW.java
index 6c73e4b..dd2b938 100644
--- a/src/main/java/net/echo/brain4j/training/optimizers/impl/AdamW.java
+++ b/src/main/java/net/echo/brain4j/training/optimizers/impl/AdamW.java
@@ -7,23 +7,12 @@
 
 import java.util.List;
 
-public class AdamW extends Optimizer {
+public class AdamW extends Adam {
 
-    // Momentum vectors
-    private ThreadLocal<double[]> firstMomentum;
-    private ThreadLocal<double[]> secondMomentum;
-
-    private double beta1Timestep;
-    private double beta2Timestep;
-
-    private double beta1;
-    private double beta2;
-    private double epsilon;
     private double weightDecay;
-    private int timestep = 0;
 
     public AdamW(double learningRate) {
-        this(learningRate, 0.01, 0.9, 0.999, 1e-8);
+        this(learningRate, 0.001);
     }
 
     public AdamW(double learningRate, double weightDecay) {
@@ -31,42 +20,16 @@ public AdamW(double learningRate, double weightDecay) {
     }
 
     public AdamW(double learningRate, double weightDecay, double beta1, double beta2, double epsilon) {
-        super(learningRate);
-        this.beta1 = beta1;
-        this.beta2 = beta2;
-        this.epsilon = epsilon;
+        super(learningRate, beta1, beta2, epsilon);
         this.weightDecay = weightDecay;
     }
 
-    @Override
-    public void postInitialize() {
-        this.firstMomentum = ThreadLocal.withInitial(() -> new double[Synapse.ID_COUNTER]);
-        this.secondMomentum = ThreadLocal.withInitial(() -> new double[Synapse.ID_COUNTER]);
-    }
-
     @Override
     public double update(Synapse synapse, Object... params) {
-        double[] firstMomentum = (double[]) params[0];
-        double[] secondMomentum = (double[]) params[0];
-
-        double gradient = synapse.getOutputNeuron().getDelta() * synapse.getInputNeuron().getValue();
-
-        int synapseId = synapse.getSynapseId();
-
-        double currentFirstMomentum = firstMomentum[synapseId];
-        double currentSecondMomentum = secondMomentum[synapseId];
-
-        double m = beta1 * currentFirstMomentum + (1 - beta1) * gradient;
-        double v = beta2 * currentSecondMomentum + (1 - beta2) * gradient * gradient;
-
-        firstMomentum[synapseId] = m;
-        secondMomentum[synapseId] = v;
-
-        double mHat = m / (1 - beta1Timestep);
-        double vHat = v / (1 - beta2Timestep);
-
+        double adamValue = super.update(synapse, params);
         double weightDecayTerm = weightDecay * synapse.getWeight();
-        return (learningRate * mHat) / (Math.sqrt(vHat) + epsilon) + weightDecayTerm;
+
+        return adamValue + weightDecayTerm;
     }
 
     @Override
@@ -87,30 +50,6 @@ public void postIteration(Updater updater, List<Layer> layers) {
         }
     }
 
-    public double getBeta1() {
-        return beta1;
-    }
-
-    public void setBeta1(double beta1) {
-        this.beta1 = beta1;
-    }
-
-    public double getBeta2() {
-        return beta2;
-    }
-
-    public void setBeta2(double beta2) {
-        this.beta2 = beta2;
-    }
-
-    public double getEpsilon() {
-        return epsilon;
-    }
-
-    public void setEpsilon(double epsilon) {
-        this.epsilon = epsilon;
-    }
-
     public double getWeightDecay() {
         return weightDecay;
     }
diff --git a/src/main/java/net/echo/brain4j/utils/GenericUtils.java b/src/main/java/net/echo/brain4j/utils/GenericUtils.java
index 8f33201..1704191 100644
--- a/src/main/java/net/echo/brain4j/utils/GenericUtils.java
+++ b/src/main/java/net/echo/brain4j/utils/GenericUtils.java
@@ -15,7 +15,7 @@ public class GenericUtils {
      * @param <T> the type of the enum
      * @return the best matching enum constant
      */
-    public static <T extends Enum<T>> T findBestMatch(double[] outputs, Class<T> clazz) {
+    public static <T extends Enum<T>> T findBestMatch(Vector outputs, Class<T> clazz) {
         return clazz.getEnumConstants()[indexOfMaxValue(outputs)];
     }
 
@@ -25,13 +25,13 @@ public static <T extends Enum<T>> T findBestMatch(double[] outputs, Class cla
      * @param inputs array of input values
      * @return index of the maximum value
      */
-    public static int indexOfMaxValue(double[] inputs) {
+    public static int indexOfMaxValue(Vector inputs) {
         int index = 0;
-        double max = inputs[0];
+        double max = inputs.get(0);
 
-        for (int i = 1; i < inputs.length; i++) {
-            if (inputs[i] > max) {
-                max = inputs[i];
+        for (int i = 1; i < inputs.size(); i++) {
+            if (inputs.get(i) > max) {
+                max = inputs.get(i);
                 index = i;
             }
         }
diff --git a/src/test/java/XorTest.java b/src/test/java/XorTest.java
index 5f7f5b3..07daf9a 100644
--- a/src/test/java/XorTest.java
+++ b/src/test/java/XorTest.java
@@ -25,7 +25,7 @@ public static void main(String[] args) {
                 WeightInit.HE,
                 LossFunctions.BINARY_CROSS_ENTROPY,
                 new AdamW(0.1),
-                new StochasticUpdater()
+                new NormalUpdater()
         );
 
         System.out.println(model.getStats());
@@ -38,7 +38,7 @@ public static void main(String[] args) {
         DataSet training = new DataSet(first, second, third, fourth);
         training.partition(1);
 
-        trainForBenchmark(model, training);
+        trainTillError(model, training);
     }
 
     private static void trainForBenchmark(Model model, DataSet data) {
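Note: with AdamW reduced to a subclass of Adam, its only behavioral difference
is the decoupled weight decay term added on top of the inherited step:
change = adamChange + weightDecay * weight. A short sketch of that composition
(illustrative names, not the library's API):

    // AdamW as "Adam plus decoupled weight decay", matching the patched
    // AdamW.update: adamValue + weightDecay * synapse.getWeight().
    public final class AdamWStepSketch {

        static double adamWChange(double adamChange, double weight, double weightDecay) {
            // The decay acts on the weight directly instead of being folded
            // into the gradient, so it bypasses the moment estimates. That is
            // what separates AdamW from L2 regularization run through Adam.
            return adamChange + weightDecay * weight;
        }

        public static void main(String[] args) {
            // An Adam change of 0.01 on a weight of 0.8 with decay 0.001:
            System.out.println(adamWChange(0.01, 0.8, 0.001)); // ~0.0108
        }
    }

Whether the Updater adds or subtracts this change is brain4j's own convention;
the sketch only shows how the two terms compose.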