Skip to content

Commit

Permalink
Better testing
Browse files Browse the repository at this point in the history
  • Loading branch information
xEcho1337 committed Oct 23, 2024
1 parent 9b96626 commit 588e7ef
Showing 1 changed file with 6 additions and 17 deletions.
23 changes: 6 additions & 17 deletions src/test/java/XorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ public static void main(String[] args) {
Model network = new FeedForwardModel();

network.add(new DenseLayer(2, Activations.LINEAR));
network.add(new DenseLayer(32, Activations.RELU));
network.add(new DenseLayer(32, Activations.RELU));
network.add(new DenseLayer(32, Activations.RELU));
network.add(new DenseLayer(128, Activations.RELU));
network.add(new DenseLayer(128, Activations.RELU));
network.add(new DenseLayer(128, Activations.RELU));
network.add(new DenseLayer(1, Activations.SIGMOID));

network.compile(InitializationType.XAVIER, LossFunctions.BINARY_CROSS_ENTROPY, new Adam(0.001));
Expand All @@ -32,27 +32,16 @@ public static void main(String[] args) {

DataSet training = new DataSet(first, second, third, fourth);

double error;

long start = System.nanoTime();
int epoches = 0;

do {
while (epoches < 500) {
epoches++;

network.fit(training);
}

double evalStart = System.nanoTime();
error = network.evaluate(training);
double evalTook = System.nanoTime() - evalStart;

if (epoches % 100 == 0) {

System.out.println("Epoch #" + epoches + " has error " + error);
System.out.println("Eval took " + (evalTook / 1e6) + "ms");
}
} while (error > 1.0E-4);

double error = network.evaluate(training);
double took = (System.nanoTime() - start) / 1e6;

System.out.println("Took " + took + " ms with an average of " + (took / epoches) + " ms per epoch and error " + error);
Expand Down

0 comments on commit 588e7ef

Please sign in to comment.