@@ -46,7 +46,9 @@ def convolution_backward(
46
46
47
47
return grad_input , grad_weight , grad_bias
48
48
49
+
49
50
if len (get_decompositions ([aten ._scaled_dot_product_flash_attention .default ])) == 0 :
51
+
50
52
@register_decomposition (aten ._scaled_dot_product_flash_attention .default )
51
53
def scaled_dot_product_flash_attention (
52
54
query ,
@@ -101,16 +103,197 @@ def scaled_dot_product_flash_attention(
101
103
102
104
103
105
def get_aot_decomposition_list ():
104
- return ([torch .ops .aten ._scaled_dot_product_flash_attention .default ,
105
- torch .ops .aten ._softmax .default ,
106
- torch .ops .aten ._softmax_backward_data .default ,
107
- torch .ops .aten .convolution_backward .default ,
108
- torch .ops .aten .gelu_backward .default ,
109
- torch .ops .aten .native_group_norm .default ,
110
- torch .ops .aten .native_group_norm_backward .default ,
111
- torch .ops .aten .native_layer_norm .default ,
112
- torch .ops .aten .native_layer_norm_backward .default ,
113
- torch .ops .aten .slice_backward .default ])
106
+ return [
107
+ torch .ops .aten ._scaled_dot_product_flash_attention .default ,
108
+ torch .ops .aten ._softmax .default ,
109
+ torch .ops .aten ._softmax_backward_data .default ,
110
+ torch .ops .aten .convolution_backward .default ,
111
+ torch .ops .aten .gelu_backward .default ,
112
+ torch .ops .aten .native_group_norm .default ,
113
+ torch .ops .aten .native_group_norm_backward .default ,
114
+ torch .ops .aten .native_layer_norm .default ,
115
+ torch .ops .aten .native_layer_norm_backward .default ,
116
+ torch .ops .aten .slice_backward .default ,
117
+ ]
118
+
114
119
115
120
def get_inf_decomposition_list ():
116
- return ([torch .ops .aten .nll_loss_forward .default ])
121
+ return [torch .ops .aten .nll_loss_forward .default ]
122
+
123
+
124
+ def get_export_decomposition_list ():
125
+ # List of decompositions from torch._decomp.core_aten_decompositions
126
+ # removed _backward ops and ops supported without decomposition
127
+ decomp = [
128
+ torch .ops .aten .addcdiv ,
129
+ torch .ops .aten .addcdiv_ ,
130
+ torch .ops .aten .addcmul ,
131
+ torch .ops .aten .addcmul_ ,
132
+ torch .ops .aten .addr ,
133
+ torch .ops .aten .affine_grid_generator ,
134
+ torch .ops .aten .all ,
135
+ torch .ops .aten .aminmax ,
136
+ torch .ops .aten .arange .default ,
137
+ torch .ops .aten .arange .start ,
138
+ torch .ops .aten .baddbmm ,
139
+ torch .ops .aten .binary_cross_entropy ,
140
+ torch .ops .aten .binary_cross_entropy_with_logits ,
141
+ torch .ops .aten .block_diag ,
142
+ torch .ops .aten .celu ,
143
+ torch .ops .aten .celu_ ,
144
+ torch .ops .aten .clamp_max ,
145
+ torch .ops .aten .clamp_min ,
146
+ torch .ops .aten .count_nonzero ,
147
+ torch .ops .aten .linalg_cross ,
148
+ torch .ops .aten .cudnn_batch_norm ,
149
+ torch .ops .aten .deg2rad ,
150
+ torch .ops .aten .deg2rad_ ,
151
+ torch .ops .aten .detach ,
152
+ torch .ops .aten .diag_embed ,
153
+ torch .ops .aten .dot ,
154
+ torch .ops .aten .vdot ,
155
+ torch .ops .aten .elu ,
156
+ torch .ops .aten .elu_ ,
157
+ torch .ops .aten ._embedding_bag ,
158
+ torch .ops .aten .empty_like ,
159
+ torch .ops .aten ._euclidean_dist .default ,
160
+ torch .ops .aten .expand_as ,
161
+ torch .ops .aten .eye ,
162
+ torch .ops .aten .fill ,
163
+ torch .ops .aten .fill_ ,
164
+ torch .ops .aten .floor_divide ,
165
+ torch .ops .aten .frac ,
166
+ torch .ops .aten .frac_ ,
167
+ torch .ops .aten ._fused_moving_avg_obs_fq_helper ,
168
+ torch .ops .aten .gelu_ ,
169
+ torch .ops .aten .glu ,
170
+ torch .ops .aten .hardshrink ,
171
+ torch .ops .aten .hardsigmoid ,
172
+ torch .ops .aten .hardsigmoid_ ,
173
+ torch .ops .aten .hardswish ,
174
+ torch .ops .aten .hardswish_ ,
175
+ torch .ops .aten .hardtanh_ ,
176
+ torch .ops .aten .heaviside ,
177
+ torch .ops .aten .heaviside_ ,
178
+ torch .ops .aten .huber_loss ,
179
+ torch .ops .aten .im2col ,
180
+ torch .ops .aten .index_add ,
181
+ torch .ops .aten .index_add_ ,
182
+ torch .ops .aten .index_copy ,
183
+ torch .ops .aten .index_copy_ ,
184
+ torch .ops .aten .index_fill ,
185
+ torch .ops .aten .index_fill_ ,
186
+ torch .ops .aten .isin ,
187
+ torch .ops .aten .isneginf ,
188
+ torch .ops .aten .isposinf ,
189
+ torch .ops .aten .l1_loss ,
190
+ torch .ops .aten .leaky_relu_ ,
191
+ torch .ops .aten .lerp ,
192
+ torch .ops .aten .lerp_ ,
193
+ torch .ops .aten .linspace ,
194
+ torch .ops .aten .logaddexp ,
195
+ torch .ops .aten .logaddexp2 ,
196
+ torch .ops .aten .logit ,
197
+ torch .ops .aten .logit_ ,
198
+ torch .ops .aten .log_sigmoid_forward ,
199
+ torch .ops .aten .logspace ,
200
+ torch .ops .aten .logsumexp .default ,
201
+ torch .ops .aten .masked_fill ,
202
+ torch .ops .aten .masked_fill_ ,
203
+ torch .ops .aten .mish ,
204
+ torch .ops .aten .mish_ ,
205
+ torch .ops .aten .mse_loss ,
206
+ torch .ops .aten .multi_margin_loss ,
207
+ torch .ops .aten .multilabel_margin_loss_forward ,
208
+ torch .ops .aten .mv ,
209
+ torch .ops .aten .mvlgamma ,
210
+ torch .ops .aten .mvlgamma_ ,
211
+ torch .ops .aten .nansum ,
212
+ torch .ops .aten .nan_to_num ,
213
+ torch .ops .aten .nan_to_num_ ,
214
+ torch .ops .aten .narrow ,
215
+ torch .ops .aten .new_empty ,
216
+ torch .ops .aten .new_full ,
217
+ torch .ops .aten .new_ones ,
218
+ torch .ops .aten .new_zeros ,
219
+ torch .ops .aten .nll_loss_forward ,
220
+ torch .ops .aten .norm ,
221
+ torch .ops .aten .ones ,
222
+ torch .ops .aten .ones_like ,
223
+ torch .ops .aten ._prelu_kernel ,
224
+ torch .ops .aten ._reshape_alias ,
225
+ torch .ops .aten .rad2deg ,
226
+ torch .ops .aten .rad2deg_ ,
227
+ torch .ops .aten .reflection_pad1d ,
228
+ torch .ops .aten .reflection_pad2d ,
229
+ torch .ops .aten .reflection_pad3d ,
230
+ torch .ops .aten .replication_pad1d ,
231
+ torch .ops .aten .replication_pad2d ,
232
+ torch .ops .aten .replication_pad3d ,
233
+ torch .ops .aten .renorm ,
234
+ torch .ops .aten .renorm_ ,
235
+ torch .ops .aten .resize_as ,
236
+ torch .ops .aten .roll ,
237
+ torch .ops .aten .rot90 ,
238
+ torch .ops .aten .rrelu_with_noise ,
239
+ torch .ops .aten .rrelu_with_noise_ ,
240
+ torch .ops .aten .rsub ,
241
+ torch .ops .aten .select_scatter ,
242
+ torch .ops .aten .sgn ,
243
+ torch .ops .aten .sgn_ ,
244
+ torch .ops .aten .silu ,
245
+ torch .ops .aten .silu_ ,
246
+ torch .ops .aten .sinc ,
247
+ torch .ops .aten .sinc_ ,
248
+ torch .ops .aten .smooth_l1_loss ,
249
+ torch .ops .aten .soft_margin_loss ,
250
+ torch .ops .aten .softplus ,
251
+ torch .ops .aten .softshrink ,
252
+ torch .ops .aten .special_entr ,
253
+ torch .ops .aten .special_log_ndtr ,
254
+ torch .ops .aten .special_xlog1py ,
255
+ torch .ops .aten .split .Tensor ,
256
+ torch .ops .aten .split_with_sizes_copy ,
257
+ torch .ops .aten .squeeze .default ,
258
+ torch .ops .aten .squeeze .dim ,
259
+ torch .ops .aten .std ,
260
+ torch .ops .aten .std_mean ,
261
+ torch .ops .aten .stack ,
262
+ torch .ops .aten .sum .default ,
263
+ torch .ops .aten .sum .out ,
264
+ torch .ops .aten .t ,
265
+ torch .ops .aten .take ,
266
+ torch .ops .aten .threshold ,
267
+ torch .ops .aten .threshold_ ,
268
+ torch .ops .aten .trace ,
269
+ torch .ops .aten .transpose .int ,
270
+ torch .ops .aten .tril ,
271
+ torch .ops .aten .tril_ ,
272
+ torch .ops .aten .triu ,
273
+ torch .ops .aten .triu_ ,
274
+ torch .ops .aten .unbind ,
275
+ torch .ops .aten .unfold_copy ,
276
+ torch .ops .aten ._unsafe_index ,
277
+ torch .ops .aten .unsafe_split .Tensor ,
278
+ torch .ops .aten .unsafe_split_with_sizes ,
279
+ torch .ops .aten ._unsafe_view ,
280
+ torch .ops .aten .view_as_complex ,
281
+ torch .ops .aten .xlogy ,
282
+ torch .ops .aten .xlogy_ ,
283
+ torch .ops .aten .zero ,
284
+ torch .ops .aten .zero_ ,
285
+ torch .ops .aten .zeros ,
286
+ torch .ops .aten .zeros_like ,
287
+ torch .ops .aten ._weight_norm_interface ,
288
+ ]
289
+ try :
290
+ from packaging import version
291
+ if version .parse (torch .__version__ ) >= version .parse ("2.3" ):
292
+ decomp += [
293
+ torch .ops .aten ._lazy_clone ,
294
+ torch .ops .aten ._test_parallel_materialize ,
295
+ torch .ops .aten ._chunk_cat ,
296
+ ]
297
+ except ImportError :
298
+ pass
299
+ return decomp
0 commit comments