Fixed Adam and AdamW
xEcho1337 committed Jan 7, 2025
1 parent cbf8b24 commit 257114e
Showing 4 changed files with 27 additions and 85 deletions.
23 changes: 13 additions & 10 deletions src/main/java/net/echo/brain4j/training/optimizers/impl/Adam.java
@@ -10,16 +10,16 @@
 public class Adam extends Optimizer {

     // Momentum vectors
-    private ThreadLocal<double[]> firstMomentum;
-    private ThreadLocal<double[]> secondMomentum;
+    protected ThreadLocal<double[]> firstMomentum;
+    protected ThreadLocal<double[]> secondMomentum;

-    private double beta1Timestep;
-    private double beta2Timestep;
+    protected double beta1Timestep;
+    protected double beta2Timestep;

-    private double beta1;
-    private double beta2;
-    private double epsilon;
-    private int timestep = 0;
+    protected double beta1;
+    protected double beta2;
+    protected double epsilon;
+    protected int timestep = 0;

     public Adam(double learningRate) {
         this(learningRate, 0.9, 0.999, 1e-8);
@@ -41,7 +41,7 @@ public void postInitialize() {
     @Override
     public double update(Synapse synapse, Object... params) {
         double[] firstMomentum = (double[]) params[0];
-        double[] secondMomentum = (double[]) params[0];
+        double[] secondMomentum = (double[]) params[1];

         double gradient = synapse.getOutputNeuron().getDelta() * synapse.getInputNeuron().getValue();

@@ -69,9 +69,12 @@ public void postIteration(Updater updater, List<Layer> layers) {
         this.beta1Timestep = Math.pow(beta1, timestep);
         this.beta2Timestep = Math.pow(beta2, timestep);

+        double[] firstMomentum = this.firstMomentum.get();
+        double[] secondMomentum = this.secondMomentum.get();
+
         for (Layer layer : layers) {
             for (Synapse synapse : layer.getSynapses()) {
-                double change = update(synapse);
+                double change = update(synapse, firstMomentum, secondMomentum);
                 updater.acknowledgeChange(synapse, change);
             }
         }
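Two defects are corrected in Adam: update() previously cast both moment vectors from params[0], so the second-moment estimate was read from the wrong array, and postIteration() called update(synapse) without passing the momentum arrays at all. Fetching the ThreadLocal arrays once before the loop also avoids one lookup per synapse. For orientation, a minimal self-contained sketch of the bias-corrected Adam step this class implements; the AdamStep class and its signature are illustrative, not part of brain4j.

    // A minimal, self-contained sketch of one bias-corrected Adam step.
    // AdamStep and its signature are illustrative, not the repository's API.
    public final class AdamStep {

        public static double step(double gradient, double[] m, double[] v,
                                  int id, int timestep, double learningRate,
                                  double beta1, double beta2, double epsilon) {
            // Exponential moving averages of the gradient (first moment)
            // and the squared gradient (second moment).
            m[id] = beta1 * m[id] + (1 - beta1) * gradient;
            v[id] = beta2 * v[id] + (1 - beta2) * gradient * gradient;

            // Bias correction: both averages start at zero, so early steps
            // are scaled up to compensate.
            double mHat = m[id] / (1 - Math.pow(beta1, timestep));
            double vHat = v[id] / (1 - Math.pow(beta2, timestep));

            // The change applied to the weight for this synapse.
            return (learningRate * mHat) / (Math.sqrt(vHat) + epsilon);
        }
    }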
73 changes: 6 additions & 67 deletions src/main/java/net/echo/brain4j/training/optimizers/impl/AdamW.java
@@ -7,66 +7,29 @@

 import java.util.List;

-public class AdamW extends Optimizer {
+public class AdamW extends Adam {

-    // Momentum vectors
-    private ThreadLocal<double[]> firstMomentum;
-    private ThreadLocal<double[]> secondMomentum;
-
-    private double beta1Timestep;
-    private double beta2Timestep;
-
-    private double beta1;
-    private double beta2;
-    private double epsilon;
     private double weightDecay;
-    private int timestep = 0;

     public AdamW(double learningRate) {
-        this(learningRate, 0.01, 0.9, 0.999, 1e-8);
+        this(learningRate, 0.001);
     }

     public AdamW(double learningRate, double weightDecay) {
         this(learningRate, weightDecay, 0.9, 0.999, 1e-8);
     }

     public AdamW(double learningRate, double weightDecay, double beta1, double beta2, double epsilon) {
-        super(learningRate);
-        this.beta1 = beta1;
-        this.beta2 = beta2;
-        this.epsilon = epsilon;
+        super(learningRate, beta1, beta2, epsilon);
         this.weightDecay = weightDecay;
     }

-    @Override
-    public void postInitialize() {
-        this.firstMomentum = ThreadLocal.withInitial(() -> new double[Synapse.ID_COUNTER]);
-        this.secondMomentum = ThreadLocal.withInitial(() -> new double[Synapse.ID_COUNTER]);
-    }
-
     @Override
     public double update(Synapse synapse, Object... params) {
-        double[] firstMomentum = (double[]) params[0];
-        double[] secondMomentum = (double[]) params[0];
-
-        double gradient = synapse.getOutputNeuron().getDelta() * synapse.getInputNeuron().getValue();
-
-        int synapseId = synapse.getSynapseId();
-
-        double currentFirstMomentum = firstMomentum[synapseId];
-        double currentSecondMomentum = secondMomentum[synapseId];
-
-        double m = beta1 * currentFirstMomentum + (1 - beta1) * gradient;
-        double v = beta2 * currentSecondMomentum + (1 - beta2) * gradient * gradient;
-
-        firstMomentum[synapseId] = m;
-        secondMomentum[synapseId] = v;
-
-        double mHat = m / (1 - beta1Timestep);
-        double vHat = v / (1 - beta2Timestep);
-
+        double adamValue = super.update(synapse, params);
         double weightDecayTerm = weightDecay * synapse.getWeight();
-        return (learningRate * mHat) / (Math.sqrt(vHat) + epsilon) + weightDecayTerm;
+
+        return adamValue + weightDecayTerm;
     }

     @Override
@@ -87,30 +50,6 @@ public void postIteration(Updater updater, List<Layer> layers) {
         }
     }

-    public double getBeta1() {
-        return beta1;
-    }
-
-    public void setBeta1(double beta1) {
-        this.beta1 = beta1;
-    }
-
-    public double getBeta2() {
-        return beta2;
-    }
-
-    public void setBeta2(double beta2) {
-        this.beta2 = beta2;
-    }
-
-    public double getEpsilon() {
-        return epsilon;
-    }
-
-    public void setEpsilon(double epsilon) {
-        this.epsilon = epsilon;
-    }
-
     public double getWeightDecay() {
         return weightDecay;
     }
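With the duplicated Adam logic removed, AdamW keeps only its weight decay state and reduces update() to decoupled weight decay: compute the ordinary Adam change via super.update, then add a term proportional to the current weight instead of folding the decay into the gradient. The beta and epsilon fields and accessors disappear along with the duplication, since that state now lives in Adam as protected members. A sketch of the composition, reusing the hypothetical AdamStep above:

    // Decoupled weight decay layered on the Adam step sketched above.
    // AdamWStep is an illustrative name; in the repository this corresponds
    // to AdamW.update delegating to super.update and then adding
    // weightDecay * synapse.getWeight().
    public final class AdamWStep {

        public static double step(double adamChange, double weight, double weightDecay) {
            // The decay term depends only on the current weight, not on the
            // gradient, so it stays decoupled from the adaptive moments.
            return adamChange + weightDecay * weight;
        }
    }

Note that the one-argument constructor's default weight decay also drops from 0.01 to 0.001 in this commit.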
12 changes: 6 additions & 6 deletions src/main/java/net/echo/brain4j/utils/GenericUtils.java
@@ -15,7 +15,7 @@ public class GenericUtils {
      * @param <T> the type of the enum
      * @return the best matching enum constant
      */
-    public static <T extends Enum<T>> T findBestMatch(double[] outputs, Class<T> clazz) {
+    public static <T extends Enum<T>> T findBestMatch(Vector outputs, Class<T> clazz) {
         return clazz.getEnumConstants()[indexOfMaxValue(outputs)];
     }

@@ -25,13 +25,13 @@ public static <T extends Enum<T>> T findBestMatch(double[] outputs, Class<T> clazz) {
      * @param inputs array of input values
      * @return index of the maximum value
      */
-    public static int indexOfMaxValue(double[] inputs) {
+    public static int indexOfMaxValue(Vector inputs) {
         int index = 0;
-        double max = inputs[0];
+        double max = inputs.get(0);

-        for (int i = 1; i < inputs.length; i++) {
-            if (inputs[i] > max) {
-                max = inputs[i];
+        for (int i = 1; i < inputs.size(); i++) {
+            if (inputs.get(i) > max) {
+                max = inputs.get(i);
                 index = i;
             }
         }
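These helpers migrate from raw double[] to the library's Vector type; the diff confirms that Vector exposes get(int) and size(). A hedged usage sketch follows: only those two methods are confirmed here, and the XorLabel enum is invented for illustration.

    // Hedged usage sketch for the migrated helper. Only get(int) and size()
    // on Vector are confirmed by this diff; XorLabel is invented here for
    // illustration.
    enum XorLabel { FALSE, TRUE }

    static XorLabel classify(Vector outputs) {
        // With outputs of, say, {0.12, 0.88}, index 1 holds the maximum,
        // so the second enum constant (TRUE) is returned.
        return GenericUtils.findBestMatch(outputs, XorLabel.class);
    }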
4 changes: 2 additions & 2 deletions src/test/java/XorTest.java
@@ -25,7 +25,7 @@ public static void main(String[] args) {
                 WeightInit.HE,
                 LossFunctions.BINARY_CROSS_ENTROPY,
                 new AdamW(0.1),
-                new StochasticUpdater()
+                new NormalUpdater()
         );

         System.out.println(model.getStats());
@@ -38,7 +38,7 @@ public static void main(String[] args) {
         DataSet training = new DataSet(first, second, third, fourth);
         training.partition(1);

-        trainForBenchmark(model, training);
+        trainTillError(model, training);
     }

     private static void trainForBenchmark(Model model, DataSet data) {
