Skip to content

Commit

Permalink
feat: leader participation (#50)
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <usamoi@outlook.com>
  • Loading branch information
usamoi authored Nov 7, 2024
1 parent 784edb8 commit f8f6453
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 19 deletions.
4 changes: 2 additions & 2 deletions src/algorithm/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::sync::atomic::AtomicBool;
use std::sync::Arc;

pub trait HeapRelation {
fn traverse<F>(&self, callback: F)
fn traverse<F>(&self, progress: bool, callback: F)
where
F: FnMut((Pointer, Vec<f32>));
fn opfamily(&self) -> Opfamily;
Expand Down Expand Up @@ -53,7 +53,7 @@ pub fn build<T: HeapRelation, R: Reporter>(
let max_number_of_samples = internal_build.nlist.saturating_mul(256);
let mut samples = Vec::new();
let mut number_of_samples = 0_u32;
heap_relation.traverse(|(_, vector)| {
heap_relation.traverse(false, |(_, vector)| {
assert_eq!(dims as usize, vector.len(), "invalid vector dimensions");
if number_of_samples < max_number_of_samples {
samples.extend(vector);
Expand Down
69 changes: 52 additions & 17 deletions src/index/am.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ pub unsafe extern "C" fn ambuild(
opfamily: Opfamily,
}
impl HeapRelation for Heap {
fn traverse<F>(&self, callback: F)
fn traverse<F>(&self, progress: bool, callback: F)
where
F: FnMut((Pointer, Vec<f32>)),
{
Expand Down Expand Up @@ -191,7 +191,7 @@ pub unsafe extern "C" fn ambuild(
self.index_info,
true,
false,
true,
progress,
0,
pgrx::pg_sys::InvalidBlockNumber,
Some(call::<F>),
Expand Down Expand Up @@ -246,6 +246,15 @@ pub unsafe extern "C" fn ambuild(
unsafe { RabbitholeLeader::enter(heap, index, (*index_info).ii_Concurrent) }
{
unsafe {
parallel_build(
index,
heap,
index_info,
leader.tablescandesc,
leader.rabbitholeshared,
true,
);
leader.wait();
let nparticipants = leader.nparticipants;
loop {
pgrx::pg_sys::SpinLockAcquire(&raw mut (*leader.rabbitholeshared).mutex);
Expand All @@ -264,7 +273,7 @@ pub unsafe extern "C" fn ambuild(
} else {
let mut tuples_done = 0;
reporter.tuples_done(tuples_done);
heap_relation.traverse(|(payload, vector)| {
heap_relation.traverse(true, |(payload, vector)| {
algorithm::insert::insert(
index_relation.clone(),
payload,
Expand Down Expand Up @@ -306,6 +315,7 @@ struct RabbitholeLeader {
pcxt: *mut pgrx::pg_sys::ParallelContext,
nparticipants: i32,
rabbitholeshared: *mut RabbitholeShared,
tablescandesc: *mut pgrx::pg_sys::ParallelTableScanDescData,
snapshot: pgrx::pg_sys::Snapshot,
}

Expand Down Expand Up @@ -417,10 +427,10 @@ impl RabbitholeLeader {
pgrx::pg_sys::LaunchParallelWorkers(pcxt);
}

let nparticipants = unsafe { (*pcxt).nworkers_launched };
let nworkers_launched = unsafe { (*pcxt).nworkers_launched };

unsafe {
if nparticipants == 0 {
if nworkers_launched == 0 {
pgrx::pg_sys::WaitForParallelWorkersToFinish(pcxt);
if is_mvcc_snapshot(snapshot) {
pgrx::pg_sys::UnregisterSnapshot(snapshot);
Expand All @@ -430,16 +440,21 @@ impl RabbitholeLeader {
return None;
}
}
unsafe {
pgrx::pg_sys::WaitForParallelWorkersToAttach(pcxt);
}

Some(Self {
pcxt,
nparticipants,
nparticipants: nworkers_launched + 1,
rabbitholeshared,
tablescandesc,
snapshot,
})
}

pub fn wait(&self) {
unsafe {
pgrx::pg_sys::WaitForParallelWorkersToAttach(self.pcxt);
}
}
}

impl Drop for RabbitholeLeader {
Expand Down Expand Up @@ -486,6 +501,31 @@ pub unsafe extern "C" fn rabbithole_parallel_build_main(
(*index_info).ii_Concurrent = (*rabbitholeshared).isconcurrent;
}

unsafe {
parallel_build(
index,
heap,
index_info,
tablescandesc,
rabbitholeshared,
false,
);
}

unsafe {
pgrx::pg_sys::index_close(index, index_lockmode);
pgrx::pg_sys::table_close(heap, heap_lockmode);
}
}

unsafe fn parallel_build(
index: *mut pgrx::pg_sys::RelationData,
heap: pgrx::pg_sys::Relation,
index_info: *mut pgrx::pg_sys::IndexInfo,
tablescandesc: *mut pgrx::pg_sys::ParallelTableScanDescData,
rabbitholeshared: *mut RabbitholeShared,
progress: bool,
) {
#[derive(Debug, Clone)]
pub struct Heap {
heap: pgrx::pg_sys::Relation,
Expand All @@ -495,7 +535,7 @@ pub unsafe extern "C" fn rabbithole_parallel_build_main(
scan: *mut pgrx::pg_sys::TableScanDescData,
}
impl HeapRelation for Heap {
fn traverse<F>(&self, callback: F)
fn traverse<F>(&self, progress: bool, callback: F)
where
F: FnMut((Pointer, Vec<f32>)),
{
Expand Down Expand Up @@ -541,7 +581,7 @@ pub unsafe extern "C" fn rabbithole_parallel_build_main(
self.index_info,
true,
false,
true,
progress,
0,
pgrx::pg_sys::InvalidBlockNumber,
Some(call::<F>),
Expand All @@ -566,7 +606,7 @@ pub unsafe extern "C" fn rabbithole_parallel_build_main(
opfamily,
scan,
};
heap_relation.traverse(|(payload, vector)| {
heap_relation.traverse(progress, |(payload, vector)| {
algorithm::insert::insert(
index_relation.clone(),
payload,
Expand All @@ -581,11 +621,6 @@ pub unsafe extern "C" fn rabbithole_parallel_build_main(
pgrx::pg_sys::SpinLockRelease(&raw mut (*rabbitholeshared).mutex);
pgrx::pg_sys::ConditionVariableSignal(&raw mut (*rabbitholeshared).workersdonecv);
}

unsafe {
pgrx::pg_sys::index_close(index, index_lockmode);
pgrx::pg_sys::table_close(heap, heap_lockmode);
}
}

#[pgrx::pg_guard]
Expand Down

0 comments on commit f8f6453

Please sign in to comment.