Skip to content

Commit 44f9105

Browse files
committed
Abstractor -> Encoder
1 parent 5a1cbb6 commit 44f9105

File tree

7 files changed

+66
-28
lines changed

7 files changed

+66
-28
lines changed

src/clustering/abstractor.rs src/clustering/encoding.rs

+9-7
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@ use std::collections::BTreeMap;
2020
/// full game tree, learned by kmeans
2121
/// rooted in showdown equity at the River.
2222
#[derive(Default)]
23-
pub struct Abstractor(BTreeMap<Isomorphism, Abstraction>);
23+
pub struct Encoder(BTreeMap<Isomorphism, Abstraction>);
2424

2525
/* learning methods
2626
*
2727
* during clustering, we're constantly inserting and updating
2828
* the abstraction mapping. needs to help project layers
2929
* hierarchically, while also
3030
*/
31-
impl Abstractor {
31+
impl Encoder {
3232
/// only run this once.
3333
pub fn learn() {
3434
if Self::done() {
@@ -39,6 +39,8 @@ impl Abstractor {
3939
.inner() // cluster turn
4040
.save()
4141
.inner() // cluster flop
42+
.save()
43+
.inner() // cluster preflop
4244
.save();
4345
}
4446
}
@@ -89,7 +91,7 @@ impl Abstractor {
8991
* by sampling according to a given Profile. here we provide
9092
* methods for unraveling the Tree
9193
*/
92-
impl Abstractor {
94+
impl Encoder {
9395
/// abstraction methods
9496
pub fn chance_abstraction(&self, game: &Game) -> Abstraction {
9597
self.abstraction(&Isomorphism::from(Observation::from(game)))
@@ -145,7 +147,7 @@ use std::io::Write;
145147
* straightforward to compute on the fly, for different reasons
146148
*/
147149

148-
impl From<Street> for Abstractor {
150+
impl From<Street> for Encoder {
149151
fn from(street: Street) -> Self {
150152
let file = File::open(format!("{}.abstraction.pgcopy", street)).expect("open file");
151153
let mut buffer = [0u8; 2];
@@ -170,7 +172,7 @@ impl From<Street> for Abstractor {
170172
}
171173
}
172174

173-
impl Abstractor {
175+
impl Encoder {
174176
/// indicates whether the abstraction table is already on disk
175177
pub fn done() -> bool {
176178
[
@@ -230,15 +232,15 @@ mod tests {
230232
fn persistence() {
231233
let street = Street::Rive;
232234
let file = format!("{}.abstraction.pgcopy", street);
233-
let save = Abstractor(
235+
let save = Encoder(
234236
(0..100)
235237
.map(|_| Observation::from(street))
236238
.map(|o| Isomorphism::from(o))
237239
.map(|o| (o, Abstraction::random()))
238240
.collect(),
239241
);
240242
save.save(street);
241-
let load = Abstractor::from(street);
243+
let load = Encoder::from(street);
242244
std::iter::empty()
243245
.chain(save.0.iter().zip(load.0.iter()))
244246
.chain(load.0.iter().zip(save.0.iter()))

src/clustering/layer.rs

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use super::abstraction::Abstraction;
2-
use super::abstractor::Abstractor;
32
use super::datasets::AbstractionSpace;
43
use super::datasets::ObservationSpace;
4+
use super::encoding::Encoder;
55
use super::histogram::Histogram;
66
use super::metric::Metric;
77
use super::xor::Pair;
@@ -40,7 +40,7 @@ use std::collections::BTreeMap;
4040
pub struct Layer {
4141
street: Street,
4242
metric: Metric,
43-
lookup: Abstractor,
43+
lookup: Encoder,
4444
kmeans: AbstractionSpace,
4545
points: ObservationSpace,
4646
}
@@ -56,8 +56,8 @@ impl Layer {
5656
const fn k(street: Street) -> usize {
5757
match street {
5858
Street::Pref => 169,
59-
Street::Flop => 8,
60-
Street::Turn => 8,
59+
Street::Flop => 32,
60+
Street::Turn => 32,
6161
Street::Rive => unreachable!(),
6262
}
6363
}
@@ -69,8 +69,8 @@ impl Layer {
6969
const fn t(street: Street) -> usize {
7070
match street {
7171
Street::Pref => 0,
72-
Street::Flop => 16,
73-
Street::Turn => 16,
72+
Street::Flop => 32,
73+
Street::Turn => 32,
7474
Street::Rive => unreachable!(),
7575
}
7676
}
@@ -85,7 +85,7 @@ impl Layer {
8585
Self {
8686
street: Street::Rive,
8787
metric: Metric::default(),
88-
lookup: Abstractor::default(),
88+
lookup: Encoder::default(),
8989
kmeans: AbstractionSpace::default(),
9090
points: ObservationSpace::default(),
9191
}
@@ -97,7 +97,7 @@ impl Layer {
9797
/// 3. cluster kmeans centroids
9898
pub fn inner(&self) -> Self {
9999
let mut layer = Self {
100-
lookup: Abstractor::default(), // assigned during clustering
100+
lookup: Encoder::default(), // assigned during clustering
101101
kmeans: AbstractionSpace::default(), // assigned during clustering
102102
street: self.inner_street(), // uniquely determined by outer layer
103103
metric: self.inner_metric(), // uniquely determined by outer layer
@@ -143,7 +143,7 @@ impl Layer {
143143
let x = self.kmeans.0.get(a).expect("pre-computed").histogram();
144144
let y = self.kmeans.0.get(b).expect("pre-computed").histogram();
145145
let distance = self.metric.emd(x, y) + self.metric.emd(y, x);
146-
let distance = distance / 2.0;
146+
let distance = distance / 2.;
147147
metric.insert(index, distance);
148148
}
149149
}

src/clustering/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
pub mod abstraction;
2-
pub mod abstractor;
32
pub mod centroid;
43
pub mod datasets;
4+
pub mod encoding;
55
pub mod equity;
66
pub mod histogram;
77
pub mod layer;

src/kmeans/mod.rs

+40-7
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,44 @@
1-
pub trait KMeans<P> {
2-
fn k(&self) -> usize;
3-
fn dataset(&self) -> &[P; N];
4-
fn cluster(&self) -> &[P; K];
1+
pub trait Point: Clone {}
2+
3+
pub trait KMeans<P>
4+
where
5+
P: Point,
6+
{
7+
fn loss(&self) -> f32;
58
fn measure(&self, a: &P, b: &P) -> f32;
6-
fn average(&self, cluster: &[P]) -> P;
7-
fn assignments(&self) -> &[usize; N]; // to what cluster is each point assigned
8-
fn frequencies(&self) -> &[usize; K]; // how many points are in each cluster
9+
fn average(&self, points: &[P]) -> P;
10+
fn dataset(&self) -> &[P; N];
11+
fn centers(&self) -> &[P; K];
12+
fn distances(&mut self) -> &mut [f32; N];
13+
fn neighbors(&mut self) -> &mut [usize; N]; // to what cluster is each point assigned
14+
fn densities(&mut self) -> &mut [usize; K]; // how many points are in each cluster
15+
}
16+
17+
impl<P> Iterator for dyn KMeans<P>
18+
where
19+
P: Point,
20+
{
21+
type Item = f32;
22+
fn next(&mut self) -> Option<Self::Item> {
23+
// do the inner of Layer::cluster_kmeans() loop
24+
// calculate neighbors &self.neighbors()
25+
// calculate densities &self.densities()
26+
// check against stopping rule(s)
27+
Some(self.loss())
28+
}
29+
}
30+
31+
pub enum Initialization {
32+
Random, // chooses random points from dataset
33+
Spaced, // chooses evenly spaced points from dataset
34+
FullPlusPlus, // weights every point inverse distance to the nearest centroid
35+
MiniPlusPlus(usize), // weights batch point inverse distance to the nearest centroid
36+
}
37+
38+
pub enum Termination {
39+
Iterations(usize),
40+
Convergent(usize, f32),
941
}
42+
1043
const N: usize = 1000;
1144
const K: usize = 10;

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
pub mod cards;
22
pub mod clustering;
3+
pub mod kmeans;
34
pub mod mccfr;
45
pub mod play;
56
pub mod players;

src/main.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ fn main() {
44
// Boring stuff
55
logging();
66
// The k-means earth mover's distance hand-clustering algorithm.
7-
clustering::abstractor::Abstractor::learn();
7+
clustering::encoding::Encoder::learn();
88
// Monet Carlo counter-factual regret minimization. External sampling, alternating regret updates, linear weighting schedules.
99
mccfr::trainer::Blueprint::load().train();
1010
// After 100s of CPU-days of training in the arena, the CPU is ready to see you.

src/mccfr/trainer.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use super::node::Node;
66
use super::player::Player;
77
use super::profile::Profile;
88
use super::tree::Tree;
9-
use crate::clustering::abstractor::Abstractor;
9+
use crate::clustering::encoding::Encoder;
1010
use crate::play::game::Game;
1111
use crate::Probability;
1212
use crate::Utility;
@@ -25,14 +25,14 @@ struct Sample(Tree, Partition);
2525

2626
pub struct Blueprint {
2727
profile: Profile,
28-
encoder: Abstractor,
28+
encoder: Encoder,
2929
}
3030

3131
impl Blueprint {
3232
pub fn load() -> Self {
3333
Self {
3434
profile: Profile::load(),
35-
encoder: Abstractor::load(),
35+
encoder: Encoder::load(),
3636
}
3737
}
3838

@@ -81,6 +81,7 @@ impl Blueprint {
8181
let root = tree.insert(root);
8282
let root = tree.at(root);
8383
assert!(0 == root.index().index());
84+
self.profile.witness(root);
8485
if self.profile.walker() == root.player() {
8586
partition.witness(root);
8687
}
@@ -92,6 +93,7 @@ impl Blueprint {
9293
let from = tree.attach(from, tail, root);
9394
let root = tree.at(tail);
9495
assert!(1 == root.index().index() - from.index());
96+
self.profile.witness(root);
9597
if self.profile.walker() == root.player() {
9698
partition.witness(root);
9799
}

0 commit comments

Comments
 (0)