@@ -9,7 +9,7 @@ use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
9
9
use vortex_mask:: { AllOr , Mask } ;
10
10
11
11
use super :: { ComputeFnVTable , InvocationArgs , Output , cast} ;
12
- use crate :: builders:: builder_with_capacity;
12
+ use crate :: builders:: { ArrayBuilder , builder_with_capacity} ;
13
13
use crate :: compute:: { ComputeFn , Kernel } ;
14
14
use crate :: vtable:: VTable ;
15
15
use crate :: { Array , ArrayRef } ;
@@ -18,10 +18,10 @@ use crate::{Array, ArrayRef};
18
18
///
19
19
/// Returns a new array where `result[i] = if_true[i]` when `mask[i]` is true,
20
20
/// otherwise `result[i] = if_false[i]`.
21
- pub fn zip ( mask : & Mask , if_true : & dyn Array , if_false : & dyn Array ) -> VortexResult < ArrayRef > {
21
+ pub fn zip ( if_true : & dyn Array , if_false : & dyn Array , mask : & Mask ) -> VortexResult < ArrayRef > {
22
22
ZIP_FN
23
23
. invoke ( & InvocationArgs {
24
- inputs : & [ mask . into ( ) , if_true . into ( ) , if_false . into ( ) ] ,
24
+ inputs : & [ if_true . into ( ) , if_false . into ( ) , mask . into ( ) ] ,
25
25
options : & ( ) ,
26
26
} ) ?
27
27
. unwrap_array ( )
@@ -44,9 +44,9 @@ impl ComputeFnVTable for Zip {
44
44
kernels : & [ ArcRef < dyn Kernel > ] ,
45
45
) -> VortexResult < Output > {
46
46
let ZipArgs {
47
- mask,
48
47
if_true,
49
48
if_false,
49
+ mask,
50
50
} = ZipArgs :: try_from ( args) ?;
51
51
52
52
if mask. all_true ( ) {
@@ -57,11 +57,6 @@ impl ComputeFnVTable for Zip {
57
57
return Ok ( cast ( if_false, & zip_return_dtype ( if_true, if_false) ) ?. into ( ) ) ;
58
58
}
59
59
60
- if if_true. is_canonical ( ) && if_false. is_canonical ( ) {
61
- // skip kernel lookup if both arrays are canonical
62
- return Ok ( zip_impl ( mask, if_true, if_false) ?. into ( ) ) ;
63
- }
64
-
65
60
// check if if_true supports zip directly
66
61
for kernel in kernels {
67
62
if let Some ( output) = kernel. invoke ( args) ? {
@@ -77,9 +72,9 @@ impl ComputeFnVTable for Zip {
77
72
// kernel.invoke(Args(if_false, if_true, mask, invert_mask = true))
78
73
79
74
Ok ( zip_impl (
80
- mask,
81
75
if_true. to_canonical ( ) ?. as_ref ( ) ,
82
76
if_false. to_canonical ( ) ?. as_ref ( ) ,
77
+ mask,
83
78
) ?
84
79
. into ( ) )
85
80
}
@@ -96,8 +91,11 @@ impl ComputeFnVTable for Zip {
96
91
}
97
92
98
93
fn return_len ( & self , args : & InvocationArgs ) -> VortexResult < usize > {
99
- let ZipArgs { if_true, .. } = ZipArgs :: try_from ( args) ?;
94
+ let ZipArgs { if_true, mask , .. } = ZipArgs :: try_from ( args) ?;
100
95
// ComputeFn::invoke asserts if_true.len() == if_false.len(), because zip is elementwise
96
+ if if_true. len ( ) != mask. len ( ) {
97
+ vortex_bail ! ( "input arrays must have the same length as the mask" ) ;
98
+ }
101
99
Ok ( if_true. len ( ) )
102
100
}
103
101
@@ -107,9 +105,9 @@ impl ComputeFnVTable for Zip {
107
105
}
108
106
109
107
struct ZipArgs < ' a > {
110
- mask : & ' a Mask ,
111
108
if_true : & ' a dyn Array ,
112
109
if_false : & ' a dyn Array ,
110
+ mask : & ' a Mask ,
113
111
}
114
112
115
113
impl < ' a > TryFrom < & InvocationArgs < ' a > > for ZipArgs < ' a > {
@@ -119,32 +117,33 @@ impl<'a> TryFrom<&InvocationArgs<'a>> for ZipArgs<'a> {
119
117
if value. inputs . len ( ) != 3 {
120
118
vortex_bail ! ( "Expected 3 inputs for zip, found {}" , value. inputs. len( ) ) ;
121
119
}
122
- let mask = value. inputs [ 0 ]
123
- . mask ( )
124
- . ok_or_else ( || vortex_err ! ( "Expected input 0 to be a mask " ) ) ?;
120
+ let if_true = value. inputs [ 0 ]
121
+ . array ( )
122
+ . ok_or_else ( || vortex_err ! ( "Expected input 0 to be an array " ) ) ?;
125
123
126
- let if_true = value. inputs [ 1 ]
124
+ let if_false = value. inputs [ 1 ]
127
125
. array ( )
128
126
. ok_or_else ( || vortex_err ! ( "Expected input 1 to be an array" ) ) ?;
129
127
130
- let if_false = value. inputs [ 2 ]
131
- . array ( )
132
- . ok_or_else ( || vortex_err ! ( "Expected input 2 to be an array" ) ) ?;
128
+ let mask = value. inputs [ 2 ]
129
+ . mask ( )
130
+ . ok_or_else ( || vortex_err ! ( "Expected input 2 to be a mask" ) ) ?;
131
+
133
132
Ok ( Self {
134
- mask,
135
133
if_true,
136
134
if_false,
135
+ mask,
137
136
} )
138
137
}
139
138
}
140
139
141
140
pub trait ZipKernel : VTable {
142
141
fn zip (
143
142
& self ,
144
- mask : & Mask ,
145
143
if_true : & Self :: Array ,
146
144
if_false : & dyn Array ,
147
- ) -> VortexResult < ArrayRef > ;
145
+ mask : & Mask ,
146
+ ) -> VortexResult < Option < ArrayRef > > ;
148
147
}
149
148
150
149
pub struct ZipKernelRef ( pub ArcRef < dyn Kernel > ) ;
@@ -162,27 +161,35 @@ impl<V: VTable + ZipKernel> ZipKernelAdapter<V> {
162
161
impl < V : VTable + ZipKernel > Kernel for ZipKernelAdapter < V > {
163
162
fn invoke ( & self , args : & InvocationArgs ) -> VortexResult < Option < Output > > {
164
163
let ZipArgs {
165
- mask,
166
164
if_true,
167
165
if_false,
166
+ mask,
168
167
} = ZipArgs :: try_from ( args) ?;
169
168
let Some ( if_true) = if_true. as_opt :: < V > ( ) else {
170
169
return Ok ( None ) ;
171
170
} ;
172
- Ok ( Some ( V :: zip ( & self . 0 , mask , if_true, if_false) ?. into ( ) ) )
171
+ Ok ( V :: zip ( & self . 0 , if_true, if_false, mask ) ?. map ( Into :: into ) )
173
172
}
174
173
}
175
174
176
- fn zip_return_dtype ( if_true : & dyn Array , if_false : & dyn Array ) -> DType {
175
+ pub ( crate ) fn zip_return_dtype ( if_true : & dyn Array , if_false : & dyn Array ) -> DType {
177
176
if_true
178
177
. dtype ( )
179
178
. union_nullability ( if_false. dtype ( ) . nullability ( ) )
180
179
}
181
180
182
- fn zip_impl ( mask : & Mask , if_true : & dyn Array , if_false : & dyn Array ) -> VortexResult < ArrayRef > {
181
+ fn zip_impl ( if_true : & dyn Array , if_false : & dyn Array , mask : & Mask ) -> VortexResult < ArrayRef > {
183
182
// if_true.len() == if_false.len() from ComputeFn::invoke
184
- let mut builder = builder_with_capacity ( & zip_return_dtype ( if_true, if_false) , if_true. len ( ) ) ;
183
+ let builder = builder_with_capacity ( & zip_return_dtype ( if_true, if_false) , if_true. len ( ) ) ;
184
+ zip_impl_with_builder ( if_true, if_false, mask, builder)
185
+ }
185
186
187
+ pub ( crate ) fn zip_impl_with_builder (
188
+ if_true : & dyn Array ,
189
+ if_false : & dyn Array ,
190
+ mask : & Mask ,
191
+ mut builder : Box < dyn ArrayBuilder > ,
192
+ ) -> VortexResult < ArrayRef > {
186
193
match mask. slices ( ) {
187
194
AllOr :: All => Ok ( if_true. to_array ( ) ) ,
188
195
AllOr :: None => Ok ( if_false. to_array ( ) ) ,
@@ -213,7 +220,7 @@ mod tests {
213
220
let if_true = PrimitiveArray :: from_iter ( [ 10 , 20 , 30 , 40 , 50 ] ) . into_array ( ) ;
214
221
let if_false = PrimitiveArray :: from_iter ( [ 1 , 2 , 3 , 4 , 5 ] ) . into_array ( ) ;
215
222
216
- let result = zip ( & mask , & if_true , & if_false ) . unwrap ( ) ;
223
+ let result = zip ( & if_true , & if_false , & mask ) . unwrap ( ) ;
217
224
let expected = PrimitiveArray :: from_iter ( [ 10 , 2 , 3 , 40 , 5 ] ) ;
218
225
219
226
assert_eq ! (
@@ -229,7 +236,7 @@ mod tests {
229
236
let if_false =
230
237
PrimitiveArray :: from_option_iter ( [ Some ( 1 ) , Some ( 2 ) , Some ( 3 ) , None ] ) . into_array ( ) ;
231
238
232
- let result = zip ( & mask , & if_true , & if_false ) . unwrap ( ) ;
239
+ let result = zip ( & if_true , & if_false , & mask ) . unwrap ( ) ;
233
240
234
241
assert_eq ! (
235
242
result. to_primitive( ) . unwrap( ) . as_slice:: <i32 >( ) ,
@@ -247,6 +254,6 @@ mod tests {
247
254
let if_true = PrimitiveArray :: from_iter ( [ 10 , 20 , 30 ] ) . into_array ( ) ;
248
255
let if_false = PrimitiveArray :: from_iter ( [ 1 , 2 , 3 , 4 ] ) . into_array ( ) ;
249
256
250
- zip ( & mask , & if_true , & if_false ) . unwrap ( ) ;
257
+ zip ( & if_true , & if_false , & mask ) . unwrap ( ) ;
251
258
}
252
259
}
0 commit comments