Skip to content

Commit f0f3130

Browse files
committed
MCCFR training working on small shortdeck!!
1 parent 9cc24b0 commit f0f3130

File tree

9 files changed

+93
-89
lines changed

9 files changed

+93
-89
lines changed

src/clustering/encoding.rs

+16-30
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,11 @@ impl Encoder {
104104
/// wrap the (Game, Bucket) in a Data
105105
pub fn encode(&self, game: Game, action: Action, past: &Vec<&Edge>) -> (Data, Edge) {
106106
let edge = Edge::from(action);
107-
let chance = self.chance_abstraction(&game);
108107
let choice = self.action_abstraction(&past, &edge);
108+
let chance = self.chance_abstraction(&game);
109109
let bucket = Bucket::from((choice, chance));
110-
let choice = Data::from((game, bucket));
111-
(choice, edge)
110+
let data = Data::from((game, bucket));
111+
(data, edge)
112112
}
113113

114114
/// i like to think of this as "positional encoding"
@@ -118,25 +118,18 @@ impl Encoder {
118118
/// the cards we see at a Node are memoryless, but the
119119
/// Path represents "how we got here"
120120
///
121-
/// for 2-players, depth works okay but there are definitely tradeoffs:
122-
/// - the same Card info at the same depth doesn't necessarily
123-
/// allow for the same available actions. which is actually a breaking problem
124-
/// since we assume all Nodes in the same Infoset have the same avaialble actions...
125-
///
126121
/// we need to assert that: any Nodes in the same Infoset have the
127-
/// same available actions. in addition to depth, we should consider
128-
/// whether we can Check, Raise, Fold, Call
122+
/// same available actions. in addition to depth, we consider
123+
/// whether or not we are in a Checkable or Foldable state.
129124
fn action_abstraction(&self, past: &Vec<&Edge>, edge: &Edge) -> Path {
130-
match edge {
131-
Edge::Random => Path::from(0),
132-
Edge::Choice(_) => Path::from(
133-
past.iter()
134-
.rev()
135-
.take_while(|edge| matches!(edge, Edge::Choice(_)))
136-
.count() as u64
137-
+ 1,
138-
),
139-
}
125+
let mut round = past
126+
.iter()
127+
.chain(std::iter::once(&edge))
128+
.rev()
129+
.take_while(|e| e.is_choice());
130+
let depth = round.clone().count();
131+
let raise = round.any(|e| e.is_raise());
132+
Path::from((depth, raise))
140133
}
141134

142135
/// the compressed card information for an observation
@@ -191,16 +184,9 @@ impl From<Street> for Encoder {
191184
impl Encoder {
192185
/// indicates whether the abstraction table is already on disk
193186
pub fn done() -> bool {
194-
[
195-
"flop.abstraction.pgcopy",
196-
"turn.abstraction.pgcopy",
197-
"preflop.metric.pgcopy",
198-
"flop.metric.pgcopy",
199-
"turn.metric.pgcopy",
200-
"river.metric.pgcopy",
201-
]
202-
.iter()
203-
.any(|file| std::path::Path::new(file).exists())
187+
["flop.abstraction.pgcopy", "turn.abstraction.pgcopy"]
188+
.iter()
189+
.any(|file| std::path::Path::new(file).exists())
204190
}
205191

206192
/// pulls the entire pre-computed abstraction table

src/main.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ fn main() {
66
// The k-means earth mover's distance hand-clustering algorithm.
77
clustering::encoding::Encoder::learn();
88
// Monet Carlo counter-factual regret minimization. External sampling, alternating regret updates, linear weighting schedules.
9-
mccfr::training::Blueprint::load().train();
9+
mccfr::minimizer::Blueprint::load().train();
1010
// After 100s of CPU-days of training in the arena, the CPU is ready to see you.
1111
play::game::Game::play();
1212
}

src/mccfr/edge.rs

+14
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,19 @@ pub enum Edge {
77
Random,
88
}
99

10+
impl Edge {
11+
pub fn is_raise(&self) -> bool {
12+
if let Edge::Choice(action) = self {
13+
matches!(action, Action::Raise(_) | Action::Shove(_))
14+
} else {
15+
false
16+
}
17+
}
18+
pub fn is_choice(&self) -> bool {
19+
matches!(self, Edge::Choice(_))
20+
}
21+
}
22+
1023
impl From<Action> for Edge {
1124
fn from(action: Action) -> Self {
1225
match action {
@@ -33,6 +46,7 @@ impl From<Edge> for u32 {
3346
}
3447
}
3548
}
49+
3650
impl std::fmt::Display for Edge {
3751
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3852
match self {

src/mccfr/training.rs src/mccfr/minimizer.rs

+23-32
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ impl Blueprint {
8888
/// Build the Tree iteratively starting from the root node.
8989
/// This function uses a stack to simulate recursion and builds the tree in a depth-first manner.
9090
fn sample(&mut self) -> Sample {
91-
log::info!("sampling tree");
9291
let mut tree = Tree::empty();
9392
let mut partition = Partition::new();
9493
let ref mut queue = Vec::new();
@@ -104,47 +103,39 @@ impl Blueprint {
104103
let head = tree.at(tail);
105104
self.visit(&head, queue, infos);
106105
}
107-
println!("\n{}\n", self.profile);
108-
println!("\n{}\n", tree);
109106
Sample(tree, partition)
110107
}
111108

112109
/// Process a node: witness it for profile and partition if necessary,
113110
/// and add its children to the exploration queue.
111+
/// under external sampling rules:
112+
/// - explore ALL my options
113+
/// - explore 1 of Chance
114+
/// - explore 1 of Villain
114115
fn visit(&mut self, head: &Node, queue: &mut Vec<Branch>, infosets: &mut Partition) {
115-
let explored = self.explore(head);
116-
if head.player() == self.profile.walker() {
117-
infosets.witness(head);
118-
}
119-
if head.player() != Player::chance() {
120-
self.profile.witness(head, &explored);
121-
}
122-
for (tail, from) in explored {
123-
queue.push((tail, from, head.index()));
124-
}
125-
}
126-
127-
/// generate children for a given node
128-
/// under external sampling rules.
129-
/// explore all MY options
130-
/// but only 1 of Chance, 1 of Villain
131-
fn explore(&self, node: &Node) -> Vec<(Data, Edge)> {
132-
let children = self.children(node);
133-
let walker = self.profile.walker();
134116
let chance = Player::chance();
135-
let player = node.player();
136-
if children.is_empty() {
117+
let player = head.player();
118+
let walker = self.profile.walker();
119+
let children = self.children(head);
120+
let sample = if children.is_empty() {
137121
vec![]
138122
} else if player == chance {
139-
self.take_any(children, node)
140-
} else if player == walker {
141-
self.take_all(children, node)
123+
self.sample_any(children, head)
142124
} else if player != walker {
143-
self.take_one(children, node)
125+
self.profile.witness(head, &children);
126+
self.sample_one(children, head)
127+
} else if player == walker {
128+
infosets.witness(head);
129+
self.profile.witness(head, &children);
130+
self.sample_all(children, head)
144131
} else {
145132
panic!("at the disco")
133+
};
134+
for (tail, from) in sample {
135+
queue.push((tail, from, head.index()));
146136
}
147137
}
138+
148139
fn children(&self, node: &Node) -> Vec<(Data, Edge)> {
149140
const MAX_N_RAISE: usize = 2;
150141
let ref past = node.history();
@@ -173,14 +164,14 @@ impl Blueprint {
173164
// external sampling
174165

175166
/// full exploration of my decision space Edges
176-
fn take_all(&self, choices: Vec<(Data, Edge)>, _: &Node) -> Vec<(Data, Edge)> {
167+
fn sample_all(&self, choices: Vec<(Data, Edge)>, _: &Node) -> Vec<(Data, Edge)> {
177168
assert!(choices
178169
.iter()
179170
.all(|(_, edge)| matches!(edge, Edge::Choice(_))));
180171
choices
181172
}
182173
/// uniform sampling of chance Edge
183-
fn take_any(&self, mut choices: Vec<(Data, Edge)>, head: &Node) -> Vec<(Data, Edge)> {
174+
fn sample_any(&self, mut choices: Vec<(Data, Edge)>, head: &Node) -> Vec<(Data, Edge)> {
184175
let ref mut rng = self.profile.rng(head);
185176
let n = choices.len();
186177
let choice = rng.gen_range(0..n);
@@ -189,7 +180,7 @@ impl Blueprint {
189180
vec![chosen]
190181
}
191182
/// Profile-weighted sampling of opponent Edge
192-
fn take_one(&self, mut choices: Vec<(Data, Edge)>, head: &Node) -> Vec<(Data, Edge)> {
183+
fn sample_one(&self, mut choices: Vec<(Data, Edge)>, head: &Node) -> Vec<(Data, Edge)> {
193184
let ref mut rng = self.profile.rng(head);
194185
let policy = choices
195186
.iter()
@@ -207,7 +198,7 @@ impl Blueprint {
207198
#[cfg(test)]
208199
mod tests {
209200
use super::*;
210-
use crate::mccfr::training::Blueprint;
201+
use crate::mccfr::minimizer::Blueprint;
211202
use petgraph::graph::NodeIndex;
212203

213204
#[test]

src/mccfr/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ pub mod bucket;
22
pub mod data;
33
pub mod edge;
44
pub mod info;
5+
pub mod minimizer;
56
pub mod node;
67
pub mod partition;
78
pub mod path;
89
pub mod player;
910
pub mod profile;
1011
pub mod strategy;
11-
pub mod training;
1212
pub mod tree;

src/mccfr/node.rs

-6
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,6 @@ impl<'tree> Node<'tree> {
110110
.collect()
111111
}
112112
}
113-
/// SAFETY:
114-
/// we have logical assurance that lifetimes work out effectively:
115-
/// 'info: 'node: 'tree
116-
/// Info is created from a Node
117-
/// Node is created from a Tree
118-
/// Tree owns its Graph
119113
pub fn graph(&self) -> &'tree DiGraph<Data, Edge> {
120114
self.graph
121115
}

src/mccfr/path.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq, Ord, PartialOrd)]
22
pub struct Path(u64);
33

4+
impl From<(usize, bool)> for Path {
5+
fn from((depth, raise): (usize, bool)) -> Self {
6+
Path((depth as u64) << 1 | raise as u64)
7+
}
8+
}
9+
410
impl From<u64> for Path {
511
fn from(value: u64) -> Self {
612
Path(value)
@@ -15,6 +21,6 @@ impl From<Path> for u64 {
1521

1622
impl std::fmt::Display for Path {
1723
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18-
write!(f, "d{:02}", self.0)
24+
write!(f, "H{:02}", self.0)
1925
}
2026
}

src/mccfr/profile.rs

+30-17
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use rand::rngs::SmallRng;
1313
use rand::SeedableRng;
1414
use std::collections::hash_map::DefaultHasher;
1515
use std::collections::BTreeMap;
16+
use std::collections::HashSet;
1617
use std::hash::Hash;
1718
use std::hash::Hasher;
1819

@@ -42,30 +43,42 @@ impl Profile {
4243
/// increment Epoch counter
4344
/// and return current count
4445
pub fn next(&mut self) -> usize {
46+
log::info!("{:>10}", self.iterations);
4547
self.iterations += 1;
4648
self.iterations
4749
}
4850
/// idempotent initialization of Profile
4951
/// at a given Node.
5052
///
51-
/// if we've already visited this Infoset,
52-
/// then we can skip over it.
53+
/// if we've already visited this Bucket,
54+
/// then we just want to make sure that
55+
/// the available outgoing Edges are consistent.
5356
///
5457
/// otherwise, we initialize the strategy
5558
/// at this Node with uniform distribution
56-
/// over its spawned support:
57-
/// Data -> Vec<(Data, Edge)>.
59+
/// over its outgoing Edges .
60+
///
61+
/// @assertion
5862
pub fn witness(&mut self, node: &Node, children: &Vec<(Data, Edge)>) {
59-
let n = children.len();
60-
let uniform = 1. / n as Probability;
6163
let bucket = node.bucket();
62-
for (_, edge) in children {
63-
self.strategies
64-
.entry(bucket.clone())
65-
.or_insert_with(BTreeMap::default)
66-
.entry(edge.clone())
67-
.or_insert_with(Strategy::default)
68-
.policy = uniform;
64+
match self.strategies.get(bucket) {
65+
Some(strategy) => {
66+
let expected = children.iter().map(|(_, e)| e).collect::<HashSet<_>>();
67+
let observed = strategy.keys().collect::<HashSet<_>>();
68+
assert!(observed == expected);
69+
}
70+
None => {
71+
let n = children.len();
72+
let uniform = 1. / n as Probability;
73+
for (_, edge) in children {
74+
self.strategies
75+
.entry(bucket.clone())
76+
.or_insert_with(BTreeMap::default)
77+
.entry(edge.clone())
78+
.or_insert_with(Strategy::default)
79+
.policy = uniform;
80+
}
81+
}
6982
}
7083
}
7184

@@ -89,10 +102,10 @@ impl Profile {
89102
let epochs = self.epochs();
90103
for (action, policy) in vector {
91104
let strategy = self.strategy(bucket, action);
92-
strategy.policy = *policy;
93105
strategy.advice *= epochs as Probability;
94106
strategy.advice += policy;
95107
strategy.advice /= epochs as Probability + 1.;
108+
strategy.policy = *policy;
96109
}
97110
}
98111

@@ -142,7 +155,7 @@ impl Profile {
142155
/// division by 2 is used to allow each player
143156
/// one iteration to walk the Tree in a single Epoch
144157
pub fn epochs(&self) -> usize {
145-
self.iterations
158+
self.iterations / 2
146159
}
147160
/// which player is traversing the Tree on this Epoch?
148161
/// used extensively in assertions and utility calculations
@@ -158,7 +171,7 @@ impl Profile {
158171
/// emulate the "opponent" strategy. the opponent is just whoever is not
159172
/// the traverser
160173
pub fn policy(&self, node: &Node, edge: &Edge) -> Probability {
161-
assert!(node.player() != Player::chance().to_owned());
174+
assert!(node.player() != Player::chance());
162175
assert!(node.player() != self.walker());
163176
self.strategies
164177
.get(node.bucket())
@@ -413,7 +426,7 @@ impl Profile {
413426
}
414427
}
415428
impl std::fmt::Display for Profile {
416-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
429+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
417430
write!(
418431
f,
419432
"{}",

src/mccfr/tree.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ impl Tree {
4949
.0
5050
.edge_weight(self.0.find_edge(index, *child).unwrap())
5151
.unwrap();
52-
writeln!(f, "{}{}──{} -> {}", prefix, stem, edge, head)?;
52+
writeln!(f, "{}{}──{} {}", prefix, stem, edge, head)?;
5353
self.draw(f, *child, &format!("{}{}", prefix, gaps))?;
5454
}
5555
Ok(())

0 commit comments

Comments
 (0)