Skip to content

Commit fb03ef5

Browse files
committed
training run with parameters as in lib.rs
1 parent 73ec8d2 commit fb03ef5

File tree

3 files changed

+11
-17
lines changed

3 files changed

+11
-17
lines changed

src/clustering/lookup.rs

+2-4
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,16 @@ impl Save for Lookup {
5151
std::fs::metadata(format!("{}{}", street, Self::name())).is_ok()
5252
}
5353
fn make(street: Street) -> Self {
54-
let n = street.n_isomorphisms();
55-
let progress = crate::progress(n);
54+
// abstractions for River are calculated once via obs.equity
55+
// abstractions for Preflop are cequivalent to just enumerating isomorphisms
5656
match street {
5757
Street::Rive => IsomorphismIterator::from(Street::Rive)
5858
.map(|iso| (iso, Abstraction::from(iso.0.equity())))
59-
.inspect(|_| progress.inc(1))
6059
.collect::<BTreeMap<_, _>>()
6160
.into(),
6261
Street::Pref => IsomorphismIterator::from(Street::Pref)
6362
.enumerate()
6463
.map(|(k, iso)| (iso, Abstraction::from((Street::Pref, k))))
65-
.inspect(|_| progress.inc(1))
6664
.collect::<BTreeMap<_, _>>()
6765
.into(),
6866
_ => panic!("lookup must be learned via layer for {street}"),

src/lib.rs

+9-9
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,19 @@ const N_RAISE: usize = 3;
2626

2727
/// sinkhorn optimal transport parameters
2828
const SINKHORN_TEMPERATURE: Entropy = 0.125;
29-
const SINKHORN_ITERATIONS: usize = 16;
30-
const SINKHORN_TOLERANCE: Energy = 0.001;
29+
const SINKHORN_ITERATIONS: usize = 32;
30+
const SINKHORN_TOLERANCE: Energy = 0.005;
3131

3232
// kmeans clustering parameters
33-
const KMEANS_FLOP_TRAINING_ITERATIONS: usize = 32;
34-
const KMEANS_TURN_TRAINING_ITERATIONS: usize = 32;
35-
const KMEANS_FLOP_CLUSTER_COUNT: usize = 24;
36-
const KMEANS_TURN_CLUSTER_COUNT: usize = 16;
37-
const KMEANS_EQTY_CLUSTER_COUNT: usize = 64;
33+
const KMEANS_FLOP_TRAINING_ITERATIONS: usize = KMEANS_TURN_TRAINING_ITERATIONS;
34+
const KMEANS_TURN_TRAINING_ITERATIONS: usize = KMEANS_TURN_CLUSTER_COUNT;
35+
const KMEANS_FLOP_CLUSTER_COUNT: usize = 128;
36+
const KMEANS_TURN_CLUSTER_COUNT: usize = 144;
37+
const KMEANS_EQTY_CLUSTER_COUNT: usize = 101;
3838

3939
// mccfr parameters
40-
const CFR_BATCH_SIZE: usize = 16;
41-
const CFR_TREE_COUNT: usize = 1024; // WARNING THIS WILL NOT SOLVE ANYTHING
40+
const CFR_BATCH_SIZE: usize = 256;
41+
const CFR_TREE_COUNT: usize = 1_048_576;
4242
const CFR_ITERATIONS: usize = CFR_TREE_COUNT / CFR_BATCH_SIZE;
4343
const CFR_PRUNNING_PHASE: usize = 100_000_000 / CFR_BATCH_SIZE;
4444
const CFR_DISCOUNT_PHASE: usize = 100_000 / CFR_BATCH_SIZE;

src/mccfr/profile.rs

-4
Original file line numberDiff line numberDiff line change
@@ -583,12 +583,8 @@ mod tests {
583583
/// arguments to the save function to write to a temporary name
584584
/// and delete the file
585585
fn persistence() {
586-
let name = "test";
587-
let file = format!("{}.profile.pgcopy", name);
588586
let save = Profile::random();
589-
save.save();
590587
let load = Profile::load(Street::random());
591-
std::fs::remove_file(file).unwrap();
592588
assert!(std::iter::empty()
593589
.chain(save.strategies.iter().zip(load.strategies.iter()))
594590
.chain(load.strategies.iter().zip(save.strategies.iter()))

0 commit comments

Comments
 (0)