Skip to content

Commit

Permalink
Removed some debug info
Browse files Browse the repository at this point in the history
  • Loading branch information
xEcho1337 committed Dec 6, 2024
1 parent 94ed2a2 commit 0607fa7
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/main/java/net/echo/brain4j/activation/Activations.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ public enum Activations {

LINEAR(new LinearActivation()),
RELU(new ReLUActivation()),
GELU(new GELUActivation()),
LEAKY_RELU(new LeakyReLUActivation()),
SIGMOID(new SigmoidActivation()),
SOFTMAX(new SoftmaxActivation()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public double getDerivative(double input) {
@Override
public void apply(List<Neuron> neurons) {
    // Apply the GELU activation to each neuron's current value, in place.
    for (Neuron neuron : neurons) {
        neuron.setValue(activate(neuron.getValue()));
    }
}
}
8 changes: 6 additions & 2 deletions src/main/java/net/echo/brain4j/layer/impl/LayerNorm.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,20 @@ public void applyFunction(Layer previous) {
}
}

public void normalize(Vector input) {
public Vector normalize(Vector input) {
double mean = input.mean();
double variance = input.variance(mean);

double denominator = Math.sqrt(variance + epsilon);

for (int i = 0; i < input.size(); i++) {
double value = input.get(i);
double normalized = (value - mean) / Math.sqrt(variance + epsilon);
double normalized = (value - mean) / denominator;

input.set(i, normalized);
}

return input;
}

private double calculateMean(List<Neuron> inputs) {
Expand Down
1 change: 1 addition & 0 deletions src/main/java/net/echo/brain4j/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ public Vector predict(Vector input) {
}

Layer outputLayer = layers.get(layers.size() - 1);

double[] output = new double[outputLayer.getNeurons().size()];

for (int i = 0; i < output.length; i++) {
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/net/echo/brain4j/nlp/model/Transformer.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ public List<Vector> transform(List<Vector> embeddings) {
}
}

for (Vector vector : resulting) {
System.out.println("Resulting");
System.out.println(vector);
}

List<Vector> concatEmbeddings = new ArrayList<>(resulting);

for (Vector embedding : resulting) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public TransformerEncoder(int numHeads, int contextSize, int dimension, double t
this.attention = new MultiHeadAttention(numHeads, contextSize, dimension, temperature);
this.feedForward = new Model(
new DenseLayer(dimension, Activations.LINEAR),
new DenseLayer(4 * dimension, Activations.RELU),
new DenseLayer(4 * dimension, Activations.GELU),
new DenseLayer(dimension, Activations.LINEAR)
);
}
Expand All @@ -47,14 +47,18 @@ public List<Vector> transform(List<Vector> embeddings) {

for (Vector vector : embeddings) {
Vector embedding = Vector.of(vector.toArray());
normalizer.normalize(embedding);
embedding = normalizer.normalize(embedding);

Vector attended = attention.attend(embedding.toArray());
normalizer.normalize(attended);
attended = normalizer.normalize(attended);

System.out.println("Attended");
System.out.println(attended);
Vector result = feedForward.predict(attended);
normalizer.normalize(result);
result = normalizer.normalize(result);

System.out.println("Result");
System.out.println(result);
resulting.add(result);
}

Expand Down
2 changes: 1 addition & 1 deletion src/main/java/net/echo/brain4j/utils/Vector.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ private Vector(double... data) {
}

/**
 * Creates a vector from the given component values.
 *
 * <p>Defensively copies the varargs array so that later mutation of the
 * caller's array cannot change this vector's contents.
 *
 * @param data the component values
 * @return a new vector backed by a private copy of {@code data}
 */
public static Vector of(double... data) {
    return new Vector(Arrays.copyOf(data, data.length));
}

public static Vector random(int size) {
Expand Down
5 changes: 0 additions & 5 deletions src/test/java/antiswear/ToxicCommentClassification.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,9 @@ public static void main(String[] args) {
String phrase = "the pen is on the table";
var embeddings = getEmbeddings(vectors, phrase);

for (var embed : embeddings) {
System.out.println(embed);
}

List<Vector> output = transformer.transform(embeddings);

for (Vector vector : output) {
System.out.println("------------------------------");
System.out.println(vector);
}
}
Expand Down

0 comments on commit 0607fa7

Please sign in to comment.