Skip to content

Commit 984af82

Browse files
author
Max Strange
committed
Tweak CNN
1 parent 654c1d5 commit 984af82

File tree

1 file changed

+12
-20
lines changed

1 file changed

+12
-20
lines changed

pyboat/src/main/scala/models/MoveOrHoldCNN.scala

+12-20
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ case class MoveOrHoldCNN() extends ModelArch {
4444
val builder = new NeuralNetConfiguration.Builder
4545
builder.seed(123)
4646
builder.iterations(1)
47-
builder.learningRate(0.02)
47+
builder.learningRate(0.08)
4848
builder.weightInit(WeightInit.XAVIER)
4949
builder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
5050
builder.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
@@ -63,7 +63,7 @@ case class MoveOrHoldCNN() extends ModelArch {
6363
cnnBuilder0.biasInit(0.01)
6464
cnnBuilder0.biasLearningRate(0.02)
6565
cnnBuilder0.convolutionMode(ConvolutionMode.Same)
66-
cnnBuilder0.nOut(256) //number of filters in this layer
66+
cnnBuilder0.nOut(128) //number of filters in this layer
6767
cnnBuilder0.activation(Activation.LEAKYRELU)
6868
listBuilder.layer(lindex, cnnBuilder0.build)
6969
lindex += 1
@@ -75,7 +75,7 @@ case class MoveOrHoldCNN() extends ModelArch {
7575
listBuilder.layer(lindex, subsamp0.build)
7676
lindex += 1
7777

78-
val cnnBuilder0_5 = new ConvolutionLayer.Builder(7, 7)
78+
val cnnBuilder0_5 = new ConvolutionLayer.Builder(5, 5)
7979
cnnBuilder0_5.stride(1, 1)
8080
cnnBuilder0_5.padding(1, 1)
8181
cnnBuilder0_5.biasInit(0.01)
@@ -104,7 +104,7 @@ case class MoveOrHoldCNN() extends ModelArch {
104104
listBuilder.layer(lindex, cnnBuilder1.build)
105105
lindex += 1
106106

107-
val cnnBuilder1_5 = new ConvolutionLayer.Builder(3, 3)
107+
val cnnBuilder1_5 = new ConvolutionLayer.Builder(5, 5)
108108
cnnBuilder1_5.stride(1, 1)
109109
cnnBuilder1_5.padding(1, 1)
110110
cnnBuilder1_5.biasInit(0.01)
@@ -122,7 +122,7 @@ case class MoveOrHoldCNN() extends ModelArch {
122122
listBuilder.layer(lindex, subsamp1.build)
123123
lindex += 1
124124

125-
val cnnBuilder2 = new ConvolutionLayer.Builder(3, 3)
125+
val cnnBuilder2 = new ConvolutionLayer.Builder(5, 5)
126126
cnnBuilder2.stride(1, 1)
127127
cnnBuilder2.padding(1, 1)
128128
cnnBuilder2.biasInit(0.01)
@@ -133,7 +133,7 @@ case class MoveOrHoldCNN() extends ModelArch {
133133
listBuilder.layer(lindex, cnnBuilder2.build)
134134
lindex += 1
135135

136-
val cnnBuilder3 = new ConvolutionLayer.Builder(3, 3)
136+
val cnnBuilder3 = new ConvolutionLayer.Builder(5, 5)
137137
cnnBuilder3.stride(1, 1)
138138
cnnBuilder3.padding(1, 1)
139139
cnnBuilder3.biasInit(0.01)
@@ -151,9 +151,9 @@ case class MoveOrHoldCNN() extends ModelArch {
151151
listBuilder.layer(lindex, subsamp2.build)
152152
lindex += 1
153153

154-
val cnnBuilder4 = new ConvolutionLayer.Builder(3, 3)
154+
val cnnBuilder4 = new ConvolutionLayer.Builder(5, 5)
155155
cnnBuilder4.stride(1, 1)
156-
cnnBuilder4.padding(1, 1)
156+
cnnBuilder4.padding(2, 2)
157157
cnnBuilder4.biasInit(0.01)
158158
cnnBuilder4.biasLearningRate(0.02)
159159
cnnBuilder4.convolutionMode(ConvolutionMode.Same)
@@ -162,9 +162,9 @@ case class MoveOrHoldCNN() extends ModelArch {
162162
listBuilder.layer(lindex, cnnBuilder4.build)
163163
lindex += 1
164164

165-
val cnnBuilder5 = new ConvolutionLayer.Builder(3, 3)
165+
val cnnBuilder5 = new ConvolutionLayer.Builder(5, 5)
166166
cnnBuilder5.stride(1, 1)
167-
cnnBuilder5.padding(1, 1)
167+
cnnBuilder5.padding(2, 2)
168168
cnnBuilder5.biasInit(0.01)
169169
cnnBuilder5.biasLearningRate(0.02)
170170
cnnBuilder5.convolutionMode(ConvolutionMode.Same)
@@ -219,11 +219,7 @@ case class MoveOrHoldCNN() extends ModelArch {
219219
}
220220

221221
class CNNDataFetcher() extends BaseDataFetcher {
222-
/// TODO: FIXME: Take only the first 5 to learn on /////
223-
val numToTake = 25
224-
var shuffledGameIds = util.Random.shuffle(Database.getAllGameIds()).take(numToTake)
225-
////////////////////////////////////////////////////////
226-
//val shuffledGameIds = util.Random.shuffle(Database.getAllGameIds())
222+
val shuffledGameIds = util.Random.shuffle(Database.getAllGameIds())
227223
var curGame = new Game(shuffledGameIds(0))
228224
val nChannels = curGame.getNumChannelsHoldOrMove()
229225
numOutcomes = curGame.getNumOutcomesHoldOrMove()
@@ -290,11 +286,7 @@ class CNNDataFetcher() extends BaseDataFetcher {
290286
def fetchNextGame() = {
291287
val ls = shuffledGameIds.dropWhile(i => i != curGame.id).drop(1)
292288
if (ls.length == 0) {
293-
//TODO FIXME go on forever ////
294-
shuffledGameIds = util.Random.shuffle(Database.getAllGameIds()).take(numToTake)
295-
curGame = new Game(shuffledGameIds(0))
296-
///////////////////////////////
297-
//curGame = null
289+
curGame = null
298290
} else {
299291
curGame = new Game(ls(0))
300292
}

0 commit comments

Comments
 (0)