diff --git a/src/main/java/net/echo/brain4j/activation/Activations.java b/src/main/java/net/echo/brain4j/activation/Activations.java
index 10022d4..05b2c13 100644
--- a/src/main/java/net/echo/brain4j/activation/Activations.java
+++ b/src/main/java/net/echo/brain4j/activation/Activations.java
@@ -6,6 +6,7 @@ public enum Activations {
 
     LINEAR(new LinearActivation()),
     RELU(new ReLUActivation()),
+    GELU(new GELUActivation()),
     LEAKY_RELU(new LeakyReLUActivation()),
     SIGMOID(new SigmoidActivation()),
     SOFTMAX(new SoftmaxActivation()),
diff --git a/src/main/java/net/echo/brain4j/activation/impl/GELUActivation.java b/src/main/java/net/echo/brain4j/activation/impl/GELUActivation.java
index 19d072f..fe5cffa 100644
--- a/src/main/java/net/echo/brain4j/activation/impl/GELUActivation.java
+++ b/src/main/java/net/echo/brain4j/activation/impl/GELUActivation.java
@@ -26,7 +26,7 @@ public double getDerivative(double input) {
     @Override
     public void apply(List<Neuron> neurons) {
         for (Neuron neuron : neurons) {
-            neuron.setValue(activate(neuron.getValue())); // Applicare GELU al valore di ciascun neurone
+            neuron.setValue(activate(neuron.getValue()));
         }
     }
 }
diff --git a/src/main/java/net/echo/brain4j/layer/impl/LayerNorm.java b/src/main/java/net/echo/brain4j/layer/impl/LayerNorm.java
index f7071cd..580d9e3 100644
--- a/src/main/java/net/echo/brain4j/layer/impl/LayerNorm.java
+++ b/src/main/java/net/echo/brain4j/layer/impl/LayerNorm.java
@@ -35,16 +35,20 @@ public void applyFunction(Layer previous) {
 
         }
     }
 
-    public void normalize(Vector input) {
+    public Vector normalize(Vector input) {
         double mean = input.mean();
         double variance = input.variance(mean);
 
+        double denominator = Math.sqrt(variance + epsilon);
+
         for (int i = 0; i < input.size(); i++) {
             double value = input.get(i);
-            double normalized = (value - mean) / Math.sqrt(variance + epsilon);
+            double normalized = (value - mean) / denominator;
             input.set(i, normalized);
         }
+
+        return input;
     }
 
     private double calculateMean(List<Vector> inputs) {
diff --git a/src/main/java/net/echo/brain4j/model/Model.java b/src/main/java/net/echo/brain4j/model/Model.java
index 685def2..4690e6f 100644
--- a/src/main/java/net/echo/brain4j/model/Model.java
+++ b/src/main/java/net/echo/brain4j/model/Model.java
@@ -166,6 +166,7 @@ public Vector predict(Vector input) {
         }
 
         Layer outputLayer = layers.get(layers.size() - 1);
+
         double[] output = new double[outputLayer.getNeurons().size()];
 
         for (int i = 0; i < output.length; i++) {
diff --git a/src/main/java/net/echo/brain4j/nlp/model/Transformer.java b/src/main/java/net/echo/brain4j/nlp/model/Transformer.java
index c3b3553..4af0c6e 100644
--- a/src/main/java/net/echo/brain4j/nlp/model/Transformer.java
+++ b/src/main/java/net/echo/brain4j/nlp/model/Transformer.java
@@ -38,6 +38,11 @@ public List<Vector> transform(List<Vector> embeddings) {
             }
         }
 
+        for (Vector vector : resulting) {
+            System.out.println("Resulting");
+            System.out.println(vector);
+        }
+
         List<Vector> concatEmbeddings = new ArrayList<>(resulting);
 
         for (Vector embedding : resulting) {
diff --git a/src/main/java/net/echo/brain4j/nlp/model/layers/TransformerEncoder.java b/src/main/java/net/echo/brain4j/nlp/model/layers/TransformerEncoder.java
index 72fcdb9..e149fdb 100644
--- a/src/main/java/net/echo/brain4j/nlp/model/layers/TransformerEncoder.java
+++ b/src/main/java/net/echo/brain4j/nlp/model/layers/TransformerEncoder.java
@@ -23,7 +23,7 @@ public TransformerEncoder(int numHeads, int contextSize, int dimension, double t
         this.attention = new MultiHeadAttention(numHeads, contextSize, dimension, temperature);
         this.feedForward = new Model(
                 new DenseLayer(dimension, Activations.LINEAR),
-                new DenseLayer(4 * dimension, Activations.RELU),
+                new DenseLayer(4 * dimension, Activations.GELU),
                 new DenseLayer(dimension, Activations.LINEAR)
         );
     }
@@ -47,14 +47,18 @@ public List<Vector> transform(List<Vector> embeddings) {
 
         for (Vector vector : embeddings) {
             Vector embedding = Vector.of(vector.toArray());
-            normalizer.normalize(embedding);
+            embedding = normalizer.normalize(embedding);
 
             Vector attended = attention.attend(embedding.toArray());
-            normalizer.normalize(attended);
+            attended = normalizer.normalize(attended);
+            System.out.println("Attended");
+            System.out.println(attended);
 
             Vector result = feedForward.predict(attended);
-            normalizer.normalize(result);
+            result = normalizer.normalize(result);
+            System.out.println("Result");
+            System.out.println(result);
 
             resulting.add(result);
         }
 
diff --git a/src/main/java/net/echo/brain4j/utils/Vector.java b/src/main/java/net/echo/brain4j/utils/Vector.java
index 10bcaf9..98b2c0a 100644
--- a/src/main/java/net/echo/brain4j/utils/Vector.java
+++ b/src/main/java/net/echo/brain4j/utils/Vector.java
@@ -16,7 +16,7 @@ private Vector(double... data) {
     }
 
     public static Vector of(double... data) {
-        return new Vector(data);
+        return new Vector(Arrays.copyOf(data, data.length));
    }
 
     public static Vector random(int size) {
diff --git a/src/test/java/antiswear/ToxicCommentClassification.java b/src/test/java/antiswear/ToxicCommentClassification.java
index 44c2ed8..9925507 100644
--- a/src/test/java/antiswear/ToxicCommentClassification.java
+++ b/src/test/java/antiswear/ToxicCommentClassification.java
@@ -41,14 +41,9 @@ public static void main(String[] args) {
         String phrase = "the pen is on the table";
         var embeddings = getEmbeddings(vectors, phrase);
 
-        for (var embed : embeddings) {
-            System.out.println(embed);
-        }
-
         List<Vector> output = transformer.transform(embeddings);
 
         for (Vector vector : output) {
-            System.out.println("------------------------------");
             System.out.println(vector);
         }
     }