Skip to content

Commit

Permalink
* Fixed He and Xavier initializations
Browse files Browse the repository at this point in the history
* added normal and uniform xavier init
  • Loading branch information
xEcho1337 committed Jan 6, 2025
1 parent 90b575e commit f88a8ae
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package net.echo.brain4j.model.initialization;

import net.echo.brain4j.model.initialization.impl.HeInit;
import net.echo.brain4j.model.initialization.impl.LeCunInit;
import net.echo.brain4j.model.initialization.impl.NormalInit;
import net.echo.brain4j.model.initialization.impl.XavierInit;
import net.echo.brain4j.model.initialization.impl.*;

/**
* Enum that defines the different types of weight initialization strategies used for neural networks.
Expand All @@ -24,10 +21,18 @@ public enum WeightInit {
HE(new HeInit()),

/**
* Xavier initialization (also known as Glorot initialization) is designed for layers with sigmoid or tanh activations.
* It initializes weights using a uniform distribution with a variance based on the number of input and output neurons.
* Uniform Xavier initialization is specifically designed for layers with sigmoid or tanh activations.
* It initializes weights using a uniform distribution,
* scaled by the square root of 6 divided by the sum of the number of input and output neurons.
*/
XAVIER(new XavierInit()),
UNIFORM_XAVIER(new UniformXavierInit()),

/**
* Normal Xavier initialization is specifically designed for layers with sigmoid or tanh activations.
* It initializes weights using a normal distribution,
* scaled by the square root of 2 divided by the sum of the number of input and output neurons.
*/
NORMAL_XAVIER(new NormalXavierInit()),

/**
* LeCun initialization is specifically designed for layers with the sigmoid or tanh activation functions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ public class HeInit implements WeightInitializer {

@Override
public double getBound(int nIn, int nOut) {
return SQRT_OF_6 / Math.sqrt(nIn);
return Math.sqrt(2.0 / nIn);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import net.echo.brain4j.model.initialization.WeightInitializer;

public class XavierInit implements WeightInitializer {
public class NormalXavierInit implements WeightInitializer {

@Override
public double getBound(int nIn, int nOut) {
return SQRT_OF_6 / Math.sqrt(nIn + nOut);
return Math.sqrt(2.0 / (nIn + nOut));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package net.echo.brain4j.model.initialization.impl;

import net.echo.brain4j.model.initialization.WeightInitializer;

/**
 * Uniform (Glorot) Xavier weight initialization, intended for layers with
 * sigmoid or tanh activations. Yields a symmetric bound of
 * sqrt(6 / (nIn + nOut)) for sampling weights from a uniform distribution.
 */
public class UniformXavierInit implements WeightInitializer {

    /**
     * Computes the uniform Xavier initialization bound.
     *
     * @param nIn  number of input neurons (fan-in)
     * @param nOut number of output neurons (fan-out)
     * @return sqrt(6 / (nIn + nOut))
     */
    @Override
    public double getBound(int nIn, int nOut) {
        int fanSum = nIn + nOut;
        return Math.sqrt(6.0 / fanSum);
    }
}

0 comments on commit f88a8ae

Please sign in to comment.