Skip to content

Commit 654c1d5

Browse files
author
Max Strange
committed
CNN memorizes small (5 games) dataset
1 parent a4b1014 commit 654c1d5

File tree

3 files changed

+22
-16
lines changed

3 files changed

+22
-16
lines changed

pyboat/src/main/scala/Database.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ case class LocalHost() extends DriverConfig {
1515
val password = ""
1616
}
1717
case class Synapse() extends DriverConfig {
18-
val url = "jdbc:mysql://10.75.6.229/diplomacy?autoReconnect=true&useSSL=false"
18+
val url = "jdbc:mysql://10.1.58.68/diplomacy?autoReconnect=true&useSSL=false"
1919
val username = "maxst"
2020
val password = Database.readPasswordFromFile()
2121
}
2222

2323
object Database {
2424
val driver = "com.mysql.jdbc.Driver"
25-
val driverConfig = LocalHost()
25+
val driverConfig = Synapse()
2626
val url = driverConfig.url
2727
val username = driverConfig.username
2828
val password = driverConfig.password

pyboat/src/main/scala/Main.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ object PyBoat {
2727
util.Random.setSeed(12345)
2828

2929
// !! Change this value to change what model is being trained !! //
30-
val architecture: ModelArch = MoveOrHoldMLP()
30+
val architecture: ModelArch = MoveOrHoldCNN()
3131
println("ARCHITECTURE: " + architecture)
3232

3333
val networkConf: MultiLayerConfiguration = architecture match {
@@ -87,11 +87,11 @@ object PyBoat {
8787
println("Value of [0, 25]: " + out.getDouble(0, 25))
8888
if (i % 100 == 0) {
8989
println("=================== Mask ===================")
90-
println(ds.getLabelsMaskArray().getRow(10))
90+
println(ds.getLabelsMaskArray().getRow(0))
9191
println("=================== Label ===================")
92-
println(ds.getLabels().getRow(10))
92+
println(ds.getLabels().getRow(0))
9393
println("=================== Output ===================")
94-
println(out.getRow(10))
94+
println(out.getRow(0))
9595
}
9696
if (i % 1000 == 0) {
9797
println("Saving model...")

pyboat/src/main/scala/models/MoveOrHoldCNN.scala

+16-10
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ case class MoveOrHoldCNN() extends ModelArch {
7575
listBuilder.layer(lindex, subsamp0.build)
7676
lindex += 1
7777

78-
val cnnBuilder0_5 = new ConvolutionLayer.Builder(5, 5)
78+
val cnnBuilder0_5 = new ConvolutionLayer.Builder(7, 7)
7979
cnnBuilder0_5.stride(1, 1)
8080
cnnBuilder0_5.padding(1, 1)
8181
cnnBuilder0_5.biasInit(0.01)
@@ -93,7 +93,7 @@ case class MoveOrHoldCNN() extends ModelArch {
9393
listBuilder.layer(lindex, subsamp0_5.build)
9494
lindex += 1
9595

96-
val cnnBuilder1 = new ConvolutionLayer.Builder(3, 3)
96+
val cnnBuilder1 = new ConvolutionLayer.Builder(5, 5)
9797
cnnBuilder1.stride(1, 1)
9898
cnnBuilder1.padding(1, 1)
9999
cnnBuilder1.biasInit(0.01)
@@ -219,7 +219,11 @@ case class MoveOrHoldCNN() extends ModelArch {
219219
}
220220

221221
class CNNDataFetcher() extends BaseDataFetcher {
222-
val shuffledGameIds = util.Random.shuffle(Database.getAllGameIds())
222+
/// TODO: FIXME: Take only the first 5 to learn on /////
223+
val numToTake = 25
224+
var shuffledGameIds = util.Random.shuffle(Database.getAllGameIds()).take(numToTake)
225+
////////////////////////////////////////////////////////
226+
//val shuffledGameIds = util.Random.shuffle(Database.getAllGameIds())
223227
var curGame = new Game(shuffledGameIds(0))
224228
val nChannels = curGame.getNumChannelsHoldOrMove()
225229
numOutcomes = curGame.getNumOutcomesHoldOrMove()
@@ -284,14 +288,16 @@ class CNNDataFetcher() extends BaseDataFetcher {
284288
* is done, this will set it to null.
285289
*/
286290
def fetchNextGame() = {
287-
////// TODO: FIXME: DEBUG: Use only the current game again and again /////
288-
val ls = shuffledGameIds.dropWhile(i => i != curGame.id)
289-
///////////////////////////////
290-
//val ls = shuffledGameIds.dropWhile(i => i != curGame.id).drop(1)
291-
if (ls.length == 0)
292-
curGame = null
293-
else
291+
val ls = shuffledGameIds.dropWhile(i => i != curGame.id).drop(1)
292+
if (ls.length == 0) {
293+
//TODO FIXME go on forever ////
294+
shuffledGameIds = util.Random.shuffle(Database.getAllGameIds()).take(numToTake)
295+
curGame = new Game(shuffledGameIds(0))
296+
///////////////////////////////
297+
//curGame = null
298+
} else {
294299
curGame = new Game(ls(0))
300+
}
295301
println("GAME ID: " + curGame.id)
296302
}
297303
}

0 commit comments

Comments
 (0)