Skip to content

Commit 029af03

Browse files
authored
varbinview zip kernel (#4054)
Uses the buffer deduplicating builder to construct the zipped array. This guards against the pathological case where we are zipping two varbinview arrays with a mask that has lots of contiguous slices. Each `builder.extend_from_array(input.slice(..))` would duplicate the entire buffers of `input`, and each slice in the mask would add the same buffers to the result array over and over again. --------- Signed-off-by: Onur Satici <onur@spiraldb.com>
1 parent 195e80e commit 029af03

File tree

3 files changed

+129
-30
lines changed

3 files changed

+129
-30
lines changed

vortex-array/src/arrays/varbinview/compute/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod is_sorted;
88
mod mask;
99
mod min_max;
1010
mod take;
11+
mod zip;
1112

1213
#[cfg(test)]
1314
mod tests {
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexResult;
5+
6+
use crate::arrays::{VarBinViewArray, VarBinViewVTable};
7+
use crate::builders::VarBinViewBuilder;
8+
use crate::compute::{ZipKernel, ZipKernelAdapter, zip_impl_with_builder, zip_return_dtype};
9+
use crate::{Array, ArrayRef, register_kernel};
10+
11+
impl ZipKernel for VarBinViewVTable {
12+
fn zip(
13+
&self,
14+
if_true: &VarBinViewArray,
15+
if_false: &dyn Array,
16+
mask: &vortex_mask::Mask,
17+
) -> VortexResult<Option<ArrayRef>> {
18+
let Some(if_false) = if_false.as_opt::<VarBinViewVTable>() else {
19+
return Ok(None);
20+
};
21+
Ok(Some(zip_impl_with_builder(
22+
if_true.as_ref(),
23+
if_false.as_ref(),
24+
mask,
25+
Box::new(VarBinViewBuilder::with_buffer_deduplication(
26+
zip_return_dtype(if_true.as_ref(), if_false.as_ref()),
27+
if_true.len(),
28+
)),
29+
)?))
30+
}
31+
}
32+
33+
register_kernel!(ZipKernelAdapter(VarBinViewVTable).lift());
34+
35+
#[cfg(test)]
36+
mod tests {
37+
use arrow_array::cast::AsArray;
38+
use arrow_select::zip::zip as arrow_zip;
39+
use vortex_dtype::{DType, Nullability};
40+
use vortex_mask::Mask;
41+
42+
use crate::IntoArray;
43+
use crate::arrays::VarBinViewVTable;
44+
use crate::arrow::IntoArrowArray;
45+
use crate::builders::{ArrayBuilder as _, VarBinViewBuilder};
46+
use crate::compute::zip;
47+
48+
#[test]
49+
fn test_varbinview_zip() {
50+
let if_true = {
51+
let mut builder =
52+
VarBinViewBuilder::with_capacity(DType::Utf8(Nullability::NonNullable), 10);
53+
for _ in 0..100 {
54+
builder.append_value("Hello");
55+
builder.append_value("Hello this is a long string that won't be inlined.");
56+
}
57+
builder.finish()
58+
};
59+
60+
let if_false = {
61+
let mut builder =
62+
VarBinViewBuilder::with_capacity(DType::Utf8(Nullability::NonNullable), 10);
63+
for _ in 0..100 {
64+
builder.append_value("Hello2");
65+
builder.append_value("Hello2 this is a long string that won't be inlined.");
66+
}
67+
builder.finish()
68+
};
69+
70+
// [1,2,4,5,7,8,..]
71+
let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect());
72+
73+
let zipped = zip(&if_true, &if_false, &mask).unwrap();
74+
let zipped = zipped.as_opt::<VarBinViewVTable>().unwrap();
75+
assert_eq!(zipped.nbuffers(), 2);
76+
77+
// assert the result is the same as arrow
78+
let expected = arrow_zip(
79+
mask.into_array()
80+
.into_arrow_preferred()
81+
.unwrap()
82+
.as_boolean(),
83+
&if_true.into_arrow_preferred().unwrap(),
84+
&if_false.into_arrow_preferred().unwrap(),
85+
)
86+
.unwrap();
87+
88+
let actual = zipped.clone().into_array().into_arrow_preferred().unwrap();
89+
assert_eq!(actual.as_ref(), expected.as_ref());
90+
}
91+
}

vortex-array/src/compute/zip.rs

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
99
use vortex_mask::{AllOr, Mask};
1010

1111
use super::{ComputeFnVTable, InvocationArgs, Output, cast};
12-
use crate::builders::builder_with_capacity;
12+
use crate::builders::{ArrayBuilder, builder_with_capacity};
1313
use crate::compute::{ComputeFn, Kernel};
1414
use crate::vtable::VTable;
1515
use crate::{Array, ArrayRef};
@@ -18,10 +18,10 @@ use crate::{Array, ArrayRef};
1818
///
1919
/// Returns a new array where `result[i] = if_true[i]` when `mask[i]` is true,
2020
/// otherwise `result[i] = if_false[i]`.
21-
pub fn zip(mask: &Mask, if_true: &dyn Array, if_false: &dyn Array) -> VortexResult<ArrayRef> {
21+
pub fn zip(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
2222
ZIP_FN
2323
.invoke(&InvocationArgs {
24-
inputs: &[mask.into(), if_true.into(), if_false.into()],
24+
inputs: &[if_true.into(), if_false.into(), mask.into()],
2525
options: &(),
2626
})?
2727
.unwrap_array()
@@ -44,9 +44,9 @@ impl ComputeFnVTable for Zip {
4444
kernels: &[ArcRef<dyn Kernel>],
4545
) -> VortexResult<Output> {
4646
let ZipArgs {
47-
mask,
4847
if_true,
4948
if_false,
49+
mask,
5050
} = ZipArgs::try_from(args)?;
5151

5252
if mask.all_true() {
@@ -57,11 +57,6 @@ impl ComputeFnVTable for Zip {
5757
return Ok(cast(if_false, &zip_return_dtype(if_true, if_false))?.into());
5858
}
5959

60-
if if_true.is_canonical() && if_false.is_canonical() {
61-
// skip kernel lookup if both arrays are canonical
62-
return Ok(zip_impl(mask, if_true, if_false)?.into());
63-
}
64-
6560
// check if if_true supports zip directly
6661
for kernel in kernels {
6762
if let Some(output) = kernel.invoke(args)? {
@@ -77,9 +72,9 @@ impl ComputeFnVTable for Zip {
7772
// kernel.invoke(Args(if_false, if_true, mask, invert_mask = true))
7873

7974
Ok(zip_impl(
80-
mask,
8175
if_true.to_canonical()?.as_ref(),
8276
if_false.to_canonical()?.as_ref(),
77+
mask,
8378
)?
8479
.into())
8580
}
@@ -96,8 +91,11 @@ impl ComputeFnVTable for Zip {
9691
}
9792

9893
fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
99-
let ZipArgs { if_true, .. } = ZipArgs::try_from(args)?;
94+
let ZipArgs { if_true, mask, .. } = ZipArgs::try_from(args)?;
10095
// ComputeFn::invoke asserts if_true.len() == if_false.len(), because zip is elementwise
96+
if if_true.len() != mask.len() {
97+
vortex_bail!("input arrays must have the same length as the mask");
98+
}
10199
Ok(if_true.len())
102100
}
103101

@@ -107,9 +105,9 @@ impl ComputeFnVTable for Zip {
107105
}
108106

109107
struct ZipArgs<'a> {
110-
mask: &'a Mask,
111108
if_true: &'a dyn Array,
112109
if_false: &'a dyn Array,
110+
mask: &'a Mask,
113111
}
114112

115113
impl<'a> TryFrom<&InvocationArgs<'a>> for ZipArgs<'a> {
@@ -119,32 +117,33 @@ impl<'a> TryFrom<&InvocationArgs<'a>> for ZipArgs<'a> {
119117
if value.inputs.len() != 3 {
120118
vortex_bail!("Expected 3 inputs for zip, found {}", value.inputs.len());
121119
}
122-
let mask = value.inputs[0]
123-
.mask()
124-
.ok_or_else(|| vortex_err!("Expected input 0 to be a mask"))?;
120+
let if_true = value.inputs[0]
121+
.array()
122+
.ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
125123

126-
let if_true = value.inputs[1]
124+
let if_false = value.inputs[1]
127125
.array()
128126
.ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
129127

130-
let if_false = value.inputs[2]
131-
.array()
132-
.ok_or_else(|| vortex_err!("Expected input 2 to be an array"))?;
128+
let mask = value.inputs[2]
129+
.mask()
130+
.ok_or_else(|| vortex_err!("Expected input 2 to be a mask"))?;
131+
133132
Ok(Self {
134-
mask,
135133
if_true,
136134
if_false,
135+
mask,
137136
})
138137
}
139138
}
140139

141140
pub trait ZipKernel: VTable {
142141
fn zip(
143142
&self,
144-
mask: &Mask,
145143
if_true: &Self::Array,
146144
if_false: &dyn Array,
147-
) -> VortexResult<ArrayRef>;
145+
mask: &Mask,
146+
) -> VortexResult<Option<ArrayRef>>;
148147
}
149148

150149
pub struct ZipKernelRef(pub ArcRef<dyn Kernel>);
@@ -162,27 +161,35 @@ impl<V: VTable + ZipKernel> ZipKernelAdapter<V> {
162161
impl<V: VTable + ZipKernel> Kernel for ZipKernelAdapter<V> {
163162
fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
164163
let ZipArgs {
165-
mask,
166164
if_true,
167165
if_false,
166+
mask,
168167
} = ZipArgs::try_from(args)?;
169168
let Some(if_true) = if_true.as_opt::<V>() else {
170169
return Ok(None);
171170
};
172-
Ok(Some(V::zip(&self.0, mask, if_true, if_false)?.into()))
171+
Ok(V::zip(&self.0, if_true, if_false, mask)?.map(Into::into))
173172
}
174173
}
175174

176-
fn zip_return_dtype(if_true: &dyn Array, if_false: &dyn Array) -> DType {
175+
pub(crate) fn zip_return_dtype(if_true: &dyn Array, if_false: &dyn Array) -> DType {
177176
if_true
178177
.dtype()
179178
.union_nullability(if_false.dtype().nullability())
180179
}
181180

182-
fn zip_impl(mask: &Mask, if_true: &dyn Array, if_false: &dyn Array) -> VortexResult<ArrayRef> {
181+
fn zip_impl(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
183182
// if_true.len() == if_false.len() from ComputeFn::invoke
184-
let mut builder = builder_with_capacity(&zip_return_dtype(if_true, if_false), if_true.len());
183+
let builder = builder_with_capacity(&zip_return_dtype(if_true, if_false), if_true.len());
184+
zip_impl_with_builder(if_true, if_false, mask, builder)
185+
}
185186

187+
pub(crate) fn zip_impl_with_builder(
188+
if_true: &dyn Array,
189+
if_false: &dyn Array,
190+
mask: &Mask,
191+
mut builder: Box<dyn ArrayBuilder>,
192+
) -> VortexResult<ArrayRef> {
186193
match mask.slices() {
187194
AllOr::All => Ok(if_true.to_array()),
188195
AllOr::None => Ok(if_false.to_array()),
@@ -213,7 +220,7 @@ mod tests {
213220
let if_true = PrimitiveArray::from_iter([10, 20, 30, 40, 50]).into_array();
214221
let if_false = PrimitiveArray::from_iter([1, 2, 3, 4, 5]).into_array();
215222

216-
let result = zip(&mask, &if_true, &if_false).unwrap();
223+
let result = zip(&if_true, &if_false, &mask).unwrap();
217224
let expected = PrimitiveArray::from_iter([10, 2, 3, 40, 5]);
218225

219226
assert_eq!(
@@ -229,7 +236,7 @@ mod tests {
229236
let if_false =
230237
PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();
231238

232-
let result = zip(&mask, &if_true, &if_false).unwrap();
239+
let result = zip(&if_true, &if_false, &mask).unwrap();
233240

234241
assert_eq!(
235242
result.to_primitive().unwrap().as_slice::<i32>(),
@@ -247,6 +254,6 @@ mod tests {
247254
let if_true = PrimitiveArray::from_iter([10, 20, 30]).into_array();
248255
let if_false = PrimitiveArray::from_iter([1, 2, 3, 4]).into_array();
249256

250-
zip(&mask, &if_true, &if_false).unwrap();
257+
zip(&if_true, &if_false, &mask).unwrap();
251258
}
252259
}

0 commit comments

Comments
 (0)