-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from Adversing/feature/nlp-adversing
Feat: Added WIP NLP actual implementation (to-be-fixed)
- Loading branch information
Showing
25 changed files
with
762 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,14 @@ | ||
package net.echo.brain4j.activation; | ||
|
||
import net.echo.brain4j.structure.Neuron; | ||
|
||
import java.util.List; | ||
|
||
public interface Activation { | ||
|
||
double activate(double input); | ||
double[] activate(double[] input); | ||
|
||
double getDerivative(double input); | ||
void apply(List<Neuron> neurons); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
55 changes: 55 additions & 0 deletions
55
src/main/java/net/echo/brain4j/activation/impl/SoftmaxActivation.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
package net.echo.brain4j.activation.impl; | ||
|
||
import net.echo.brain4j.activation.Activation; | ||
import net.echo.brain4j.structure.Neuron; | ||
|
||
import java.util.List; | ||
|
||
public class SoftmaxActivation implements Activation { | ||
@Override | ||
public double activate(double input) { | ||
throw new UnsupportedOperationException("Softmax activation function is not supported for single value"); | ||
} | ||
|
||
@Override | ||
public double[] activate(double[] inputs) { | ||
double maxInput = Double.NEGATIVE_INFINITY; | ||
for (double input : inputs) { | ||
if (input > maxInput) { | ||
maxInput = input; | ||
} | ||
} | ||
|
||
double[] expValues = new double[inputs.length]; | ||
double sum = 0.0; | ||
for (int i = 0; i < inputs.length; i++) { | ||
expValues[i] = Math.exp(inputs[i] - maxInput); | ||
sum += expValues[i]; | ||
} | ||
|
||
for (int i = 0; i < expValues.length; i++) { | ||
expValues[i] /= sum; | ||
} | ||
|
||
return expValues; | ||
} | ||
|
||
@Override | ||
public double getDerivative(double input) { | ||
return input * (1.0 - input); | ||
} | ||
|
||
@Override | ||
public void apply(List<Neuron> neurons) { | ||
double[] values = new double[neurons.size()]; | ||
for (int i = 0; i < neurons.size(); i++) { | ||
values[i] = neurons.get(i).getValue() + neurons.get(i).getBias(); | ||
} | ||
|
||
double[] activatedValues = activate(values); | ||
|
||
for (int i = 0; i < neurons.size(); i++) { | ||
neurons.get(i).setValue(activatedValues[i]); | ||
} | ||
} | ||
} |
36 changes: 36 additions & 0 deletions
36
src/main/java/net/echo/brain4j/activation/impl/TanhActivation.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package net.echo.brain4j.activation.impl; | ||
|
||
import net.echo.brain4j.activation.Activation; | ||
import net.echo.brain4j.structure.Neuron; | ||
|
||
import java.util.List; | ||
|
||
public class TanhActivation implements Activation { | ||
|
||
@Override | ||
public double activate(double input) { | ||
return Math.tanh(input); | ||
} | ||
|
||
@Override | ||
public double[] activate(double[] inputs) { | ||
double[] result = new double[inputs.length]; | ||
for (int i = 0; i < inputs.length; i++) { | ||
result[i] = activate(inputs[i]); | ||
} | ||
return result; | ||
} | ||
|
||
@Override | ||
public double getDerivative(double input) { | ||
return 1.0 - Math.pow(Math.tanh(input), 2); | ||
} | ||
|
||
@Override | ||
public void apply(List<Neuron> neurons) { | ||
for (Neuron neuron : neurons) { | ||
double value = neuron.getValue() + neuron.getBias(); | ||
neuron.setValue(activate(value)); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
14 changes: 14 additions & 0 deletions
14
src/main/java/net/echo/brain4j/loss/impl/CategoricalCrossEntropy.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package net.echo.brain4j.loss.impl; | ||
|
||
import net.echo.brain4j.loss.LossFunction; | ||
|
||
public class CategoricalCrossEntropy implements LossFunction { | ||
@Override | ||
public double calculate(double[] expected, double[] actual) { | ||
double sum = 0.0; | ||
for (int i = 0; i < expected.length; i++) { | ||
sum += -expected[i] * Math.log(actual[i] + 1e-15); | ||
} | ||
return sum; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
package net.echo.brain4j.nlp.agents; | ||
|
||
import net.echo.brain4j.training.data.DataSet; | ||
|
||
public interface Agent { | ||
String process(String input); | ||
void train(DataSet conversationData); | ||
double evaluate(DataSet testData); | ||
void save(String path); | ||
void load(String path); | ||
} |
6 changes: 6 additions & 0 deletions
6
src/main/java/net/echo/brain4j/nlp/agents/attention/AttentionMechanism.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
package net.echo.brain4j.nlp.agents.attention; | ||
|
||
public interface AttentionMechanism { | ||
double[] attend(double[] input, String contextKey); | ||
} | ||
|
Oops, something went wrong.