Skip to content

Commit 2feda7a

Browse files
committed
optimal transport formulation. wip
1 parent 32e0b6e commit 2feda7a

File tree

5 files changed

+260
-141
lines changed

5 files changed

+260
-141
lines changed

src/clustering/abstractor.rs

+7-4
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ impl Abstractor {
3838
.inner() // cluster turn
3939
.save()
4040
.inner() // cluster flop
41+
.save()
42+
.inner() // cluster preflop (but really just save flop.metric)
4143
.save();
4244
}
4345
}
@@ -236,9 +238,10 @@ impl Abstractor {
236238
/// 3. Write the extension (4 bytes)
237239
/// 4. Write the observation and abstraction pairs
238240
/// 5. Write the trailer (2 bytes)
239-
pub fn save(&self, name: String) {
240-
log::info!("saving abstraction lookup {}", name);
241-
let ref mut file = File::create(format!("{}.abstraction.pgcopy", name)).expect("new file");
241+
pub fn save(&self, street: Street) {
242+
log::info!("{:<32}{:<32}", "saving abstraction lookup", street);
243+
let ref mut file =
244+
File::create(format!("{}.abstraction.pgcopy", street)).expect("new file");
242245
file.write_all(b"PGCOPY\n\xff\r\n\0").expect("header");
243246
file.write_u32::<BigEndian>(0).expect("flags");
244247
file.write_u32::<BigEndian>(0).expect("extension");
@@ -272,7 +275,7 @@ mod tests {
272275
.map(|o| (o, Abstraction::random()))
273276
.collect(),
274277
);
275-
save.save(street.to_string());
278+
save.save(street);
276279
// Load from disk
277280
let load = Abstractor::load_street(street);
278281
std::iter::empty()

src/clustering/histogram.rs

+13-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,14 @@ impl Histogram {
2929
/// all witnessed Abstractions.
3030
/// treat this like an unordered array
3131
/// even though we use BTreeMap for struct.
32-
pub fn support(&self) -> Vec<&Abstraction> {
33-
self.contribution.keys().collect()
32+
pub fn support(&self) -> impl Iterator<Item = &Abstraction> {
33+
self.contribution.keys()
34+
}
35+
pub fn normalized(&self) -> BTreeMap<Abstraction, f32> {
36+
self.contribution
37+
.iter()
38+
.map(|(&a, &count)| (a, count as f32 / self.mass as f32))
39+
.collect()
3440
}
3541

3642
/// useful only for k-means edge case of centroid drift
@@ -39,6 +45,11 @@ impl Histogram {
3945
self.contribution.is_empty()
4046
}
4147

48+
/// size of the support
49+
pub fn size(&self) -> usize {
50+
self.contribution.len()
51+
}
52+
4253
/// insert the Abstraction into our support,
4354
/// incrementing its local weight,
4455
/// incrementing our global norm.

src/clustering/layer.rs

+52-45
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,6 @@ use rayon::iter::IntoParallelRefIterator;
1717
use rayon::iter::ParallelIterator;
1818
use std::collections::BTreeMap;
1919

20-
/// number of kmeans centroids.
21-
/// this determines the granularity of the abstraction space.
22-
///
23-
/// - CPU: O(N^2) for kmeans initialization
24-
/// - CPU: O(N) for kmeans clustering
25-
/// - RAM: O(N^2) for learned metric
26-
/// - RAM: O(N) for learned centroids
27-
const N_KMEANS_CENTROIDS: usize = 256;
28-
29-
/// number of kmeans iterations.
30-
/// this controls the precision of the abstraction space.
31-
///
32-
/// - CPU: O(N) for kmeans clustering
33-
const N_KMEANS_ITERATION: usize = 64;
34-
3520
/// Hierarchical K Means Learner.
3621
/// this is decomposed into the necessary data structures
3722
/// for kmeans clustering to occur for a given `Street`.
@@ -61,6 +46,35 @@ pub struct Layer {
6146
}
6247

6348
impl Layer {
49+
/// number of kmeans centroids.
50+
/// this determines the granularity of the abstraction space.
51+
///
52+
/// - CPU: O(N^2) for kmeans initialization
53+
/// - CPU: O(N) for kmeans clustering
54+
/// - RAM: O(N^2) for learned metric
55+
/// - RAM: O(N) for learned centroids
56+
const fn k(street: Street) -> usize {
57+
match street {
58+
Street::Pref => 169,
59+
Street::Flop => 8,
60+
Street::Turn => 8,
61+
Street::Rive => unreachable!(),
62+
}
63+
}
64+
65+
/// number of kmeans iterations.
66+
/// this controls the precision of the abstraction space.
67+
///
68+
/// - CPU: O(N) for kmeans clustering
69+
const fn t(street: Street) -> usize {
70+
match street {
71+
Street::Pref => 0,
72+
Street::Flop => 128,
73+
Street::Turn => 32,
74+
Street::Rive => unreachable!(),
75+
}
76+
}
77+
6478
/// start with the River layer. everything is empty because we
6579
/// can generate `Abstractor` and `SmallSpace` from "scratch".
6680
/// - `lookup`: lazy equity calculation of river observations
@@ -95,8 +109,8 @@ impl Layer {
95109
}
96110
/// save the current layer's `Metric` and `Abstractor` to disk
97111
pub fn save(self) -> Self {
98-
self.metric.save(format!("{}", self.street.next())); // outer layer generates this purely (metric over projections)
99-
self.lookup.save(format!("{}", self.street)); // while inner layer generates this (clusters)
112+
self.metric.save(self.street.next()); // outer layer generates this purely (metric over projections)
113+
self.lookup.save(self.street); // while inner layer generates this (clusters)
100114
self
101115
}
102116

@@ -115,7 +129,7 @@ impl Layer {
115129
///
116130
/// we symmetrize the distance by averaging the EMDs in both directions.
117131
/// the distance isn't symmetric in the first place only because our heuristic algo is not fully accurate
118-
pub fn inner_metric(&self) -> Metric {
132+
fn inner_metric(&self) -> Metric {
119133
log::info!(
120134
"{:<32}{:<32}",
121135
"computing metric",
@@ -170,13 +184,13 @@ impl Layer {
170184
log::info!(
171185
"{:<32}{:<32}",
172186
"declaring abstractions",
173-
format!("{} {} clusters", self.street, N_KMEANS_CENTROIDS)
187+
format!("{} {} clusters", self.street, Self::k(self.street))
174188
);
175189
let ref mut rng = rand::thread_rng();
176-
let progress = Self::progress(N_KMEANS_CENTROIDS);
190+
let progress = Self::progress(Self::k(self.street));
177191
self.kmeans.expand(self.sample_uniform(rng));
178192
progress.inc(1);
179-
while self.kmeans.0.len() < N_KMEANS_CENTROIDS {
193+
while self.kmeans.0.len() < Self::k(self.street) {
180194
self.kmeans.expand(self.sample_outlier(rng));
181195
progress.inc(1);
182196
}
@@ -189,17 +203,16 @@ impl Layer {
189203
log::info!(
190204
"{:<32}{:<32}",
191205
"clustering observations",
192-
format!("{} {} iterations", self.street, N_KMEANS_ITERATION)
206+
format!("{} {} iterations", self.street, Self::t(self.street))
193207
);
194-
let progress = Self::progress(N_KMEANS_ITERATION);
195-
for _ in 0..N_KMEANS_ITERATION {
208+
let progress = Self::progress(Self::t(self.street));
209+
for _ in 0..Self::t(self.street) {
196210
let neighbors = self
197211
.points
198212
.0
199213
.par_iter()
200214
.map(|(_, h)| self.nearest_neighbor(h))
201215
.collect::<Vec<(Abstraction, f32)>>();
202-
self.kmeans.clear();
203216
self.assign_nearest_neighbor(neighbors);
204217
self.assign_orphans_randomly();
205218
progress.inc(1);
@@ -211,36 +224,33 @@ impl Layer {
211224
/// by computing the EMD distance between the `Observation`'s `Histogram` and each `Centroid`'s `Histogram`
212225
/// and returning the `Abstraction` of the nearest `Centroid`
213226
fn assign_nearest_neighbor(&mut self, neighbors: Vec<(Abstraction, f32)>) {
227+
self.kmeans.clear();
214228
let mut loss = 0.;
215-
for ((observation, histogram), (abstraction, distance)) in
216-
std::iter::zip(self.points.0.iter_mut(), neighbors.iter())
217-
{
218-
loss += distance * distance;
219-
self.lookup.assign(abstraction, observation);
220-
self.kmeans.absorb(abstraction, histogram);
229+
for ((obs, hist), (abs, dist)) in self.points.0.iter_mut().zip(neighbors.iter()) {
230+
loss += dist * dist;
231+
self.lookup.assign(abs, obs);
232+
self.kmeans.absorb(abs, hist);
221233
}
222-
log::debug!("LOSS {:>12.8}", loss / self.points.0.len() as f32);
234+
let loss = loss / self.points.0.len() as f32;
235+
log::trace!("LOSS {:>12.8}", loss);
223236
}
224237
/// centroid drift may make it such that some centroids are empty
225238
/// so we reinitialize empty centroids with random Observations if necessary
226239
fn assign_orphans_randomly(&mut self) {
227240
for ref a in self.kmeans.orphans() {
228-
log::warn!(
229-
"{:<32}{:<32}",
230-
"reassigning empty centroid",
231-
format!("0x{}", a)
232-
);
233241
let ref mut rng = rand::thread_rng();
234242
let ref sample = self.sample_uniform(rng);
235243
self.kmeans.absorb(a, sample);
244+
log::debug!(
245+
"{:<32}{:<32}",
246+
"reassigned empty centroid",
247+
format!("0x{}", a)
248+
);
236249
}
237250
}
238251

239252
/// the first Centroid is uniformly random across all `Observation` `Histogram`s
240-
fn sample_uniform<R>(&self, rng: &mut R) -> Histogram
241-
where
242-
R: Rng,
243-
{
253+
fn sample_uniform<R: Rng>(&self, rng: &mut R) -> Histogram {
244254
self.points
245255
.0
246256
.values()
@@ -251,10 +261,7 @@ impl Layer {
251261
/// each next Centroid is selected with probability proportional to
252262
/// the squared distance to the nearest neighboring Centroid.
253263
/// faster convergence, i guess. on the shoulders of giants
254-
fn sample_outlier<R>(&self, rng: &mut R) -> Histogram
255-
where
256-
R: Rng,
257-
{
264+
fn sample_outlier<R: Rng>(&self, rng: &mut R) -> Histogram {
258265
let weights = self
259266
.points
260267
.0

0 commit comments

Comments
 (0)