-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
common: postops: add post-ops support for binary select op #2900
base: main
Are you sure you want to change the base?
Conversation
tests/benchdnn/dnn_types.cpp
Outdated
@@ -590,6 +594,12 @@ std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks( | |||
: policy2mask( | |||
DNNL_ARG_SRC_1, e.binary.policy, prim_kind, ndims); | |||
arg = DNNL_ARG_SRC_1; | |||
|
|||
if (e.is_binary_kind_with_ternary_op()) { | |||
mask2 = e.binary.mask; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see it working if user specified the policy instead of a mask number.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally, there should be no mask requirement for the src2 tensor since broadcasting is disabled for the tensor and it has the same dimensions as src0. This has been updated in the code - thanks for the catch!
@@ -206,6 +206,9 @@ attr_t::post_ops_t parse_attr_post_ops_func(const std::string &s) { | |||
} else if (e.is_binary_kind()) { | |||
const auto dt_str = get_substr(subs, subs_pos, ':'); | |||
e.binary.src1_dt = str2dt(dt_str.c_str()); | |||
|
|||
if (e.is_binary_kind_with_ternary_op()) e.binary.src2_dt = dnnl_s8; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No new inputs for benchdnn? Since we target sdpa fusion, I'd check that select post-op can be fused into matmul successfully.
After some thinking, it seems that the only change for post-op input is an algorithm, mask and format will be re-used and the data type is fixed. What about a broadcast case the Graph team asked for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since broadcasting support is not enabled for select op
, the broadcast example from the Graph team may not be of use for this case. Other than matmul
, are there any additional cases of interest - perhaps, fusion with binary
or conv
can also be helpful?
TIME_FILL(SAFE( | ||
fill_random_real(mem, ref_mem, res, binary_fill_cfg), | ||
WARN)); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure I'm following how it's working... I'd expect an } else if (exec_arg & DNNL_ARG_SRC_2) {
to appear instead of mixing stuff together... and a separate memory object to be filled.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added the explanation for the working in the comments. The filling for both the binary src1/2
tensors is the same except for the minimum fill value. For the src2
tensor, the condition (is_binary_src1_arg || attr.post_ops.entry[bin_po_idx].is_binary_kind_with_ternary_op())
ensures that the values are filled only when the binary algorithm needs a third input and ignored otherwise.
20dff17
to
7edca7d
Compare
Description
The PR adds updates post-ops implementations so that they support binary select operation as a post-op:
The updates account for the third conditional input that the select op requires for computation.
Implements MFDNN-12883.
Testcase: