Commit 72d94f7
database stream patches
1 parent d03c4bc

File tree: 4 files changed, +102 −63 lines

src/clustering/lookup.rs (+2 −2)

@@ -164,9 +164,9 @@ impl Upload for Lookup {
         while reader.read_exact(buffer).is_ok() {
             match u16::from_be_bytes(buffer.clone()) {
                 2 => {
-                    reader.read_u32::<BE>().expect("observation length");
+                    assert!(8 == reader.read_u32::<BE>().expect("observation length"));
                     let iso = reader.read_i64::<BE>().expect("read observation");
-                    reader.read_u32::<BE>().expect("abstraction length");
+                    assert!(8 == reader.read_u32::<BE>().expect("abstraction length"));
                     let abs = reader.read_i64::<BE>().expect("read abstraction");
                     let observation = Isomorphism::from(iso);
                     let abstraction = Abstraction::from(abs);
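
Note (not part of the diff): the new assertions check that each field's 4-byte length prefix is 8 before reading an i64, which matches the per-tuple framing these lookup files appear to use: a big-endian u16 field count, then for each field a u32 byte length followed by the payload. A minimal sketch of serializing one (observation, abstraction) pair in that framing, with a hypothetical helper name:

use byteorder::{WriteBytesExt, BE};
use std::io::Write;

// Hypothetical helper: write one 2-field tuple in the framing
// that the Lookup reader above consumes.
fn write_pair<W: Write>(w: &mut W, iso: i64, abs: i64) -> std::io::Result<()> {
    w.write_u16::<BE>(2)?;   // field count: two columns per row
    w.write_u32::<BE>(8)?;   // observation length prefix (i64 = 8 bytes)
    w.write_i64::<BE>(iso)?; // observation payload (Isomorphism as i64)
    w.write_u32::<BE>(8)?;   // abstraction length prefix
    w.write_i64::<BE>(abs)?; // abstraction payload (Abstraction as i64)
    Ok(())
}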

src/lib.rs (+2 −1)

@@ -100,7 +100,8 @@ pub fn init() {
 pub async fn db() -> std::sync::Arc<tokio_postgres::Client> {
     log::info!("connecting to database");
     let tls = tokio_postgres::tls::NoTls;
-    let ref url = std::env::var("DB_URL").expect("DB_URL not set");
+    let ref url =
+        std::env::var("DB_URL").unwrap_or("postgresql://localhost:5432/postgres".to_string());
     let (client, connection) = tokio_postgres::connect(url, tls)
         .await
         .expect("database connection failed");
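
Note (not part of the diff): std::env::var returns a Result, so unwrap_or here allocates the default connection string even when DB_URL is set. A lazily-evaluated variant, sketched only as a possible follow-up:

// Same fallback behavior, but the default String is only built when DB_URL is unset.
let url = std::env::var("DB_URL")
    .unwrap_or_else(|_| "postgresql://localhost:5432/postgres".to_string());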

src/save/upload.rs (−17)

@@ -1,9 +1,4 @@
 use crate::cards::street::Street;
-use byteorder::ReadBytesExt;
-use byteorder::BE;
-use std::fs::File;
-use std::io::BufReader;
-use tokio_postgres::types::ToSql;
 use tokio_postgres::types::Type;

 // blueprint ~ 154M, (grows with number of CFR iterations)
@@ -35,18 +30,6 @@ pub trait Upload {
     fn load(street: Street) -> Self;
     /// write to disk
     fn save(&self);
-    /// given a BufReader<File>, read a row according to Self::columns()
-    fn read(reader: &mut BufReader<File>) -> Vec<Box<dyn ToSql + Sync>> {
-        Self::columns()
-            .iter()
-            .cloned()
-            .map(|ty| match ty {
-                Type::FLOAT4 => Box::new(reader.read_f32::<BE>().unwrap()) as Box<dyn ToSql + Sync>,
-                Type::INT8 => Box::new(reader.read_i64::<BE>().unwrap()) as Box<dyn ToSql + Sync>,
-                _ => panic!("unsupported type: {}", ty),
-            })
-            .collect::<Vec<Box<dyn ToSql + Sync>>>()
-    }

     /// query to nuke table in Postgres
     fn nuke() -> String {
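
Note (not part of the diff): with the default read method removed, Upload is left as a pure description of a table and its on-disk sources; all binary decoding now lives in Writer. Reconstructed from the call sites in writer.rs below, the trait surface looks roughly like this sketch; every signature other than load, save, and nuke is an assumption inferred from usage:

use crate::cards::street::Street;
use tokio_postgres::types::Type;

pub trait Upload {
    fn name() -> String;             // table name (assumed return type)
    fn columns() -> &'static [Type]; // column types handed to BinaryCopyInWriter (assumed)
    fn sources() -> Vec<String>;     // paths of the binary files to stream (assumed)
    fn copy() -> String;             // COPY ... FROM STDIN BINARY statement (assumed)
    fn prepare() -> String;          // CREATE TABLE statement (assumed)
    fn nuke() -> String;             // query to nuke table in Postgres
    fn indices() -> String;          // CREATE INDEX statements (assumed)
    fn load(street: Street) -> Self; // read from disk
    fn save(&self);                  // write to disk
}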

src/save/writer.rs (+98 −43)

@@ -6,12 +6,16 @@ use crate::clustering::metric::Metric;
 use crate::clustering::transitions::Decomp;
 use crate::mccfr::encoder::Encoder;
 use crate::mccfr::profile::Profile;
+use byteorder::ReadBytesExt;
+use byteorder::BE;
 use std::fs::File;
 use std::io::BufReader;
 use std::io::Read;
 use std::io::Seek;
 use std::sync::Arc;
 use tokio_postgres::binary_copy::BinaryCopyInWriter;
+use tokio_postgres::types::ToSql;
+use tokio_postgres::types::Type;
 use tokio_postgres::Client;
 use tokio_postgres::Error as E;

@@ -28,59 +32,22 @@ impl Writer {
         let postgres = Self(crate::db().await);
         postgres.upload::<Metric>().await?;
         postgres.upload::<Decomp>().await?;
-        postgres.upload::<Encoder>().await?;
+        postgres.upload::<Encoder>().await?; // Lookup ?
         postgres.upload::<Profile>().await?;
         postgres.derive::<Abstraction>().await?;
         postgres.derive::<Street>().await?;
         postgres.vacuum().await?;
         Ok(())
     }

-    async fn upload<U>(&self) -> Result<(), E>
-    where
-        U: Upload,
-    {
-        let ref name = U::name();
-        if self.has_rows(name).await? {
-            log::info!("tables data already uploaded ({})", name);
-            Ok(())
-        } else {
-            log::info!("copying {}", name);
-            self.0.batch_execute(&U::prepare()).await?;
-            self.0.batch_execute(&U::nuke()).await?;
-            let sink = self.0.copy_in(&U::copy()).await?;
-            let writer = BinaryCopyInWriter::new(sink, U::columns());
-            futures::pin_mut!(writer);
-            let ref mut count = [0u8; 2];
-            for ref mut reader in U::sources()
-                .iter()
-                .map(|s| File::open(s).expect("file not found"))
-                .map(|f| BufReader::new(f))
-            {
-                reader.seek(std::io::SeekFrom::Start(19)).unwrap();
-                while let Ok(_) = reader.read_exact(count) {
-                    match u16::from_be_bytes(count.clone()) {
-                        0xFFFF => break,
-                        length => {
-                            assert!(length == U::columns().len() as u16);
-                            let row = U::read(reader);
-                            let row = row.iter().map(|b| &**b).collect::<Vec<_>>();
-                            writer.as_mut().write(&row).await?;
-                        }
-                    }
-                }
-            }
-            writer.finish().await?;
-            self.0.batch_execute(&U::indices()).await?;
-            Ok(())
-        }
-    }
-
     async fn derive<D>(&self) -> Result<(), E>
     where
         D: Derive,
     {
         let ref name = D::name();
+        // if !does_exist(name).await? {
+        //     create
+        // }
         if self.has_rows(name).await? {
             log::info!("tables data already uploaded ({})", name);
             Ok(())
@@ -103,6 +70,73 @@ impl Writer {
         }
     }

+    async fn upload<U>(&self) -> Result<(), E>
+    where
+        U: Upload,
+    {
+        let ref name = U::name();
+        // if !does_exist(name).await? {
+        //     create
+        // }
+        if self.has_rows(name).await? {
+            log::info!("tables data already uploaded ({})", name);
+            Ok(())
+        } else {
+            log::info!("copying {}", name);
+            self.prepare::<U>().await?;
+            self.stream::<U>().await?;
+            self.index::<U>().await?;
+            Ok(())
+        }
+    }
+
+    async fn prepare<T>(&self) -> Result<(), E>
+    where
+        T: Upload,
+    {
+        self.0.batch_execute(&T::prepare()).await?;
+        self.0.batch_execute(&T::nuke()).await?;
+        Ok(())
+    }
+
+    async fn index<T>(&self) -> Result<(), E>
+    where
+        T: Upload,
+    {
+        self.0.batch_execute(&T::indices()).await?;
+        Ok(())
+    }
+
+    async fn stream<T>(&self) -> Result<(), E>
+    where
+        T: Upload,
+    {
+        let sink = self.0.copy_in(&T::copy()).await?;
+        let writer = BinaryCopyInWriter::new(sink, T::columns());
+        futures::pin_mut!(writer);
+        let ref mut count = [0u8; 2];
+        for ref mut reader in T::sources()
+            .iter()
+            .map(|s| File::open(s).expect("file not found"))
+            .map(|f| BufReader::new(f))
+        {
+            reader.seek(std::io::SeekFrom::Start(19)).unwrap();
+            while let Ok(_) = reader.read_exact(count) {
+                match u16::from_be_bytes(count.clone()) {
+                    0xFFFF => break,
+                    length => {
+                        assert!(length == T::columns().len() as u16);
+                        let row = Self::read::<T>(reader);
+                        let row = row.iter().map(|b| &**b).collect::<Vec<_>>();
+                        writer.as_mut().write(&row).await?;
+                    }
+                }
+            }
+        }
+        writer.finish().await?;
+        Ok(())
+    }
+
     async fn vacuum(&self) -> Result<(), E> {
         self.0.batch_execute("VACUUM ANALYZE;").await
     }
@@ -116,7 +150,7 @@ impl Writer {
             ",
                 table
            );
-            Ok(0 != self.0.query_one(sql, &[]).await?.get::<_, i64>(0))
+            Ok(!self.0.query(sql, &[]).await?.is_empty())
         } else {
             Ok(false)
         }
@@ -130,6 +164,27 @@ impl Writer {
             ",
             table
         );
-        Ok(1 == self.0.query_one(sql, &[]).await?.get::<_, i64>(0))
+        Ok(!self.0.query(sql, &[]).await?.is_empty())
+    }
+
+    fn read<T>(reader: &mut BufReader<File>) -> Vec<Box<dyn ToSql + Sync>>
+    where
+        T: Upload,
+    {
+        T::columns()
+            .iter()
+            .cloned()
+            .map(|ty| match ty {
+                Type::FLOAT4 => {
+                    assert!(reader.read_u32::<BE>().expect("length") == 4);
+                    Box::new(reader.read_f32::<BE>().expect("data")) as Box<dyn ToSql + Sync>
+                }
+                Type::INT8 => {
+                    assert!(reader.read_u32::<BE>().expect("length") == 8);
+                    Box::new(reader.read_i64::<BE>().expect("data")) as Box<dyn ToSql + Sync>
+                }
+                _ => panic!("unsupported type: {}", ty),
+            })
+            .collect::<Vec<Box<dyn ToSql + Sync>>>()
     }
 }
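
Note (not part of the diff): the seek to byte 19, the u16 tuple field count, the u32 per-field lengths, and the 0xFFFF terminator handled by stream and read line up with PostgreSQL's binary COPY file layout: an 11-byte PGCOPY signature plus a 4-byte flags field and a 4-byte header-extension length (19 bytes total), then one tuple per row, then a 16-bit trailer of -1. A standalone sketch of a reader for that framing, assuming the source files follow it with no header extension (hypothetical helper, not from this repo):

use byteorder::{ReadBytesExt, BE};
use std::io::{Read, Seek, SeekFrom};

// Hypothetical helper: walk the tuples of a PGCOPY-style binary file and
// hand each raw field payload to a callback. NULL fields (length 0xFFFFFFFF)
// are not handled in this sketch.
fn for_each_field<R, F>(reader: &mut R, mut on_field: F) -> std::io::Result<()>
where
    R: Read + Seek,
    F: FnMut(Vec<u8>),
{
    reader.seek(SeekFrom::Start(19))?; // skip signature + flags + extension length
    loop {
        let fields = reader.read_u16::<BE>()?; // field count for this tuple
        if fields == 0xFFFF {
            break; // -1 trailer: end of data
        }
        for _ in 0..fields {
            let len = reader.read_u32::<BE>()?; // byte length of this field
            let mut buf = vec![0u8; len as usize];
            reader.read_exact(&mut buf)?; // big-endian field payload
            on_field(buf);
        }
    }
    Ok(())
}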
