Skip to content

Commit

Permalink
Fix the bug when sender was never dropped
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Dec 14, 2024
1 parent a4af360 commit bc69244
Showing 1 changed file with 36 additions and 6 deletions.
42 changes: 36 additions & 6 deletions ipa-core/src/query/runner/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ use std::{
convert::{Infallible, Into},
marker::PhantomData,
ops::Add,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

use futures::{stream::iter, StreamExt, TryStreamExt};
use futures_util::TryFutureExt;
use futures::{stream::iter, Stream, StreamExt, TryStreamExt};
use futures_util::{stream, TryFutureExt};
use generic_array::ArrayLength;
use tokio_stream::wrappers::ReceiverStream;
use tokio::sync::mpsc::Receiver;

use super::QueryResult;
use crate::{
Expand Down Expand Up @@ -73,6 +75,23 @@ impl<C, HV, R: PrivateKeyRegistry> Query<C, HV, R> {
}
}

struct KnownSizeReceiverStream<T> {
rx: Receiver<T>,
sz: usize,
}

impl<T> Stream for KnownSizeReceiverStream<T> {
type Item = T;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx.poll_recv(cx)
}

fn size_hint(&self) -> (usize, Option<usize>) {
(0, Some(self.sz))
}
}

impl<C, HV, R> Query<C, HV, R>
where
C: UpgradableContext
Expand Down Expand Up @@ -153,11 +172,18 @@ where

let (tx, rx) = tokio::sync::mpsc::channel(ctx.active_work().get());
let (_, (decrypted_reports, resharded_tags)) = assert_send(futures::future::try_join(
seq_join(ctx.active_work(), stream)
.try_for_each(|(report, tag)| tx.send((report, tag)).map_err(|_| Error::Internal)),
{
let f = seq_join(ctx.active_work(), stream)
.zip(stream::repeat(tx))
.map(|(r, tx)| r.map(|v| (v, tx)))
.try_for_each(|((report, tag), tx)| async move {
tx.send((report, tag)).map_err(|_| Error::Internal).await
});
f
},
reshard_aad(
ctx.narrow(&HybridStep::ReshardByTag),
ReceiverStream::new(rx).map(Ok),
KnownSizeReceiverStream { rx, sz }.map(Ok),
|ctx, _, tag| tag.shard_picker(ctx.shard_count()),
),
))
Expand Down Expand Up @@ -229,6 +255,7 @@ mod tests {
use rand_core::SeedableRng;

use crate::{
executor::IpaRuntime,
ff::{
boolean_array::{BA3, BA32, BA8},
U128Conversions,
Expand Down Expand Up @@ -335,6 +362,7 @@ mod tests {
HybridQuery::<_, BA32, KeyRegistry<KeyPair>>::new(
query_params,
Arc::clone(&key_registry),
IpaRuntime::current(),
)
.execute(ctx, query_size, input)
})
Expand Down Expand Up @@ -418,6 +446,7 @@ mod tests {
HybridQuery::<_, BA32, KeyRegistry<KeyPair>>::new(
query_params,
Arc::clone(&key_registry),
IpaRuntime::current(),
)
.execute(ctx, query_size, input)
})
Expand Down Expand Up @@ -464,6 +493,7 @@ mod tests {
HybridQuery::<_, BA32, KeyRegistry<KeyPair>>::new(
query_params,
Arc::clone(&key_registry),
IpaRuntime::current(),
)
.execute(ctx, query_size, input)
})
Expand Down

0 comments on commit bc69244

Please sign in to comment.