@@ -75,7 +75,7 @@ case class MoveOrHoldCNN() extends ModelArch {
75
75
listBuilder.layer(lindex, subsamp0.build)
76
76
lindex += 1
77
77
78
- val cnnBuilder0_5 = new ConvolutionLayer .Builder (5 , 5 )
78
+ val cnnBuilder0_5 = new ConvolutionLayer .Builder (7 , 7 )
79
79
cnnBuilder0_5.stride(1 , 1 )
80
80
cnnBuilder0_5.padding(1 , 1 )
81
81
cnnBuilder0_5.biasInit(0.01 )
@@ -93,7 +93,7 @@ case class MoveOrHoldCNN() extends ModelArch {
93
93
listBuilder.layer(lindex, subsamp0_5.build)
94
94
lindex += 1
95
95
96
- val cnnBuilder1 = new ConvolutionLayer .Builder (3 , 3 )
96
+ val cnnBuilder1 = new ConvolutionLayer .Builder (5 , 5 )
97
97
cnnBuilder1.stride(1 , 1 )
98
98
cnnBuilder1.padding(1 , 1 )
99
99
cnnBuilder1.biasInit(0.01 )
@@ -219,7 +219,11 @@ case class MoveOrHoldCNN() extends ModelArch {
219
219
}
220
220
221
221
class CNNDataFetcher () extends BaseDataFetcher {
222
- val shuffledGameIds = util.Random .shuffle(Database .getAllGameIds())
222
+ // / TODO: FIXME: Take only the first 5 to learn on /////
223
+ val numToTake = 25
224
+ var shuffledGameIds = util.Random .shuffle(Database .getAllGameIds()).take(numToTake)
225
+ // //////////////////////////////////////////////////////
226
+ // val shuffledGameIds = util.Random.shuffle(Database.getAllGameIds())
223
227
var curGame = new Game (shuffledGameIds(0 ))
224
228
val nChannels = curGame.getNumChannelsHoldOrMove()
225
229
numOutcomes = curGame.getNumOutcomesHoldOrMove()
@@ -284,14 +288,16 @@ class CNNDataFetcher() extends BaseDataFetcher {
284
288
* is done, this will set it to null.
285
289
*/
286
290
def fetchNextGame () = {
287
- // //// TODO: FIXME: DEBUG: Use only the current game again and again /////
288
- val ls = shuffledGameIds.dropWhile(i => i != curGame.id)
289
- // /////////////////////////////
290
- // val ls = shuffledGameIds.dropWhile(i => i != curGame.id).drop(1)
291
- if (ls.length == 0 )
292
- curGame = null
293
- else
291
+ val ls = shuffledGameIds.dropWhile(i => i != curGame.id).drop(1 )
292
+ if (ls.length == 0 ) {
293
+ // TODO FIXME go on forever ////
294
+ shuffledGameIds = util.Random .shuffle(Database .getAllGameIds()).take(numToTake)
295
+ curGame = new Game (shuffledGameIds(0 ))
296
+ // /////////////////////////////
297
+ // curGame = null
298
+ } else {
294
299
curGame = new Game (ls(0 ))
300
+ }
295
301
println(" GAME ID: " + curGame.id)
296
302
}
297
303
}
0 commit comments