@@ -44,7 +44,7 @@ case class MoveOrHoldCNN() extends ModelArch {
44
44
val builder = new NeuralNetConfiguration .Builder
45
45
builder.seed(123 )
46
46
builder.iterations(1 )
47
- builder.learningRate(0.02 )
47
+ builder.learningRate(0.08 )
48
48
builder.weightInit(WeightInit .XAVIER )
49
49
builder.optimizationAlgo(OptimizationAlgorithm .STOCHASTIC_GRADIENT_DESCENT )
50
50
builder.gradientNormalization(GradientNormalization .RenormalizeL2PerLayer )
@@ -63,7 +63,7 @@ case class MoveOrHoldCNN() extends ModelArch {
63
63
cnnBuilder0.biasInit(0.01 )
64
64
cnnBuilder0.biasLearningRate(0.02 )
65
65
cnnBuilder0.convolutionMode(ConvolutionMode .Same )
66
- cnnBuilder0.nOut(256 ) // number of filters in this layer
66
+ cnnBuilder0.nOut(128 ) // number of filters in this layer
67
67
cnnBuilder0.activation(Activation .LEAKYRELU )
68
68
listBuilder.layer(lindex, cnnBuilder0.build)
69
69
lindex += 1
@@ -75,7 +75,7 @@ case class MoveOrHoldCNN() extends ModelArch {
75
75
listBuilder.layer(lindex, subsamp0.build)
76
76
lindex += 1
77
77
78
- val cnnBuilder0_5 = new ConvolutionLayer .Builder (7 , 7 )
78
+ val cnnBuilder0_5 = new ConvolutionLayer .Builder (5 , 5 )
79
79
cnnBuilder0_5.stride(1 , 1 )
80
80
cnnBuilder0_5.padding(1 , 1 )
81
81
cnnBuilder0_5.biasInit(0.01 )
@@ -104,7 +104,7 @@ case class MoveOrHoldCNN() extends ModelArch {
104
104
listBuilder.layer(lindex, cnnBuilder1.build)
105
105
lindex += 1
106
106
107
- val cnnBuilder1_5 = new ConvolutionLayer .Builder (3 , 3 )
107
+ val cnnBuilder1_5 = new ConvolutionLayer .Builder (5 , 5 )
108
108
cnnBuilder1_5.stride(1 , 1 )
109
109
cnnBuilder1_5.padding(1 , 1 )
110
110
cnnBuilder1_5.biasInit(0.01 )
@@ -122,7 +122,7 @@ case class MoveOrHoldCNN() extends ModelArch {
122
122
listBuilder.layer(lindex, subsamp1.build)
123
123
lindex += 1
124
124
125
- val cnnBuilder2 = new ConvolutionLayer .Builder (3 , 3 )
125
+ val cnnBuilder2 = new ConvolutionLayer .Builder (5 , 5 )
126
126
cnnBuilder2.stride(1 , 1 )
127
127
cnnBuilder2.padding(1 , 1 )
128
128
cnnBuilder2.biasInit(0.01 )
@@ -133,7 +133,7 @@ case class MoveOrHoldCNN() extends ModelArch {
133
133
listBuilder.layer(lindex, cnnBuilder2.build)
134
134
lindex += 1
135
135
136
- val cnnBuilder3 = new ConvolutionLayer .Builder (3 , 3 )
136
+ val cnnBuilder3 = new ConvolutionLayer .Builder (5 , 5 )
137
137
cnnBuilder3.stride(1 , 1 )
138
138
cnnBuilder3.padding(1 , 1 )
139
139
cnnBuilder3.biasInit(0.01 )
@@ -151,9 +151,9 @@ case class MoveOrHoldCNN() extends ModelArch {
151
151
listBuilder.layer(lindex, subsamp2.build)
152
152
lindex += 1
153
153
154
- val cnnBuilder4 = new ConvolutionLayer .Builder (3 , 3 )
154
+ val cnnBuilder4 = new ConvolutionLayer .Builder (5 , 5 )
155
155
cnnBuilder4.stride(1 , 1 )
156
- cnnBuilder4.padding(1 , 1 )
156
+ cnnBuilder4.padding(2 , 2 )
157
157
cnnBuilder4.biasInit(0.01 )
158
158
cnnBuilder4.biasLearningRate(0.02 )
159
159
cnnBuilder4.convolutionMode(ConvolutionMode .Same )
@@ -162,9 +162,9 @@ case class MoveOrHoldCNN() extends ModelArch {
162
162
listBuilder.layer(lindex, cnnBuilder4.build)
163
163
lindex += 1
164
164
165
- val cnnBuilder5 = new ConvolutionLayer .Builder (3 , 3 )
165
+ val cnnBuilder5 = new ConvolutionLayer .Builder (5 , 5 )
166
166
cnnBuilder5.stride(1 , 1 )
167
- cnnBuilder5.padding(1 , 1 )
167
+ cnnBuilder5.padding(2 , 2 )
168
168
cnnBuilder5.biasInit(0.01 )
169
169
cnnBuilder5.biasLearningRate(0.02 )
170
170
cnnBuilder5.convolutionMode(ConvolutionMode .Same )
@@ -219,11 +219,7 @@ case class MoveOrHoldCNN() extends ModelArch {
219
219
}
220
220
221
221
class CNNDataFetcher () extends BaseDataFetcher {
222
- // / TODO: FIXME: Take only the first 5 to learn on /////
223
- val numToTake = 25
224
- var shuffledGameIds = util.Random .shuffle(Database .getAllGameIds()).take(numToTake)
225
- // //////////////////////////////////////////////////////
226
- // val shuffledGameIds = util.Random.shuffle(Database.getAllGameIds())
222
+ val shuffledGameIds = util.Random .shuffle(Database .getAllGameIds())
227
223
var curGame = new Game (shuffledGameIds(0 ))
228
224
val nChannels = curGame.getNumChannelsHoldOrMove()
229
225
numOutcomes = curGame.getNumOutcomesHoldOrMove()
@@ -290,11 +286,7 @@ class CNNDataFetcher() extends BaseDataFetcher {
290
286
def fetchNextGame () = {
291
287
val ls = shuffledGameIds.dropWhile(i => i != curGame.id).drop(1 )
292
288
if (ls.length == 0 ) {
293
- // TODO FIXME go on forever ////
294
- shuffledGameIds = util.Random .shuffle(Database .getAllGameIds()).take(numToTake)
295
- curGame = new Game (shuffledGameIds(0 ))
296
- // /////////////////////////////
297
- // curGame = null
289
+ curGame = null
298
290
} else {
299
291
curGame = new Game (ls(0 ))
300
292
}
0 commit comments