Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

common: postops: add post-ops support for binary select op #2900

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion doc/primitives/binary.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ argument index as specified by the following table.
| \f$\src_1\f$ | DNNL_ARG_SRC_1 |
| \f$\src_2\f$ | DNNL_ARG_SRC_2 |
| \dst | DNNL_ARG_DST |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1 |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1,|
| | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_2 |
| \f$binary scale0\f$ | DNNL_ARG_ATTR_SCALES \| DNNL_ARG_SRC_0 |
| \f$binary scale1\f$ | DNNL_ARG_ATTR_SCALES \| DNNL_ARG_SRC_1 |

Expand Down
3 changes: 2 additions & 1 deletion doc/primitives/convolution.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ argument index as specified by the following table.
| \diffbias | DNNL_ARG_DIFF_BIAS |
| \diffdst | DNNL_ARG_DIFF_DST |
| \f$depthwise\f$ | DNNL_ARG_ATTR_POST_OP_DW |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1 |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1, |
| | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_2 |
| \f$\text{prelu post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(prelu_post_op_position) \| DNNL_ARG_WEIGHTS |

## Implementation Details
Expand Down
3 changes: 2 additions & 1 deletion doc/primitives/eltwise.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ argument index as specified by the following table.
| \dst | DNNL_ARG_DST |
| \diffsrc | DNNL_ARG_DIFF_SRC |
| \diffdst | DNNL_ARG_DIFF_DST |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1 |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1,|
| | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_2 |

## Implementation Details

Expand Down
3 changes: 2 additions & 1 deletion doc/primitives/group_normalization.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ argument index as specified by the following table.
| \diffsrc | DNNL_ARG_DIFF_SRC |
| \f$\diffgamma\f$ | DNNL_ARG_DIFF_SCALE |
| \f$\diffbeta\f$ | DNNL_ARG_DIFF_SHIFT |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1 |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1,|
| | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_2 |

## Implementation Details

Expand Down
3 changes: 2 additions & 1 deletion doc/primitives/inner_product.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ argument index as specified by the following table.
| \diffweights | DNNL_ARG_DIFF_WEIGHTS |
| \diffbias | DNNL_ARG_DIFF_BIAS |
| \diffdst | DNNL_ARG_DIFF_DST |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1 |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1, |
| | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_2 |
| \f$\text{prelu post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(prelu_post_op_position) \| DNNL_ARG_WEIGHTS |

## Implementation Details
Expand Down
4 changes: 2 additions & 2 deletions doc/primitives/layer_normalization.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ argument index as specified by the following table.
| \diffbeta | DNNL_ARG_DIFF_SHIFT |
| \f$src scale\f$ | DNNL_ARG_ATTR_SCALES \| DNNL_ARG_SRC |
| \f$dst scale\f$ | DNNL_ARG_ATTR_SCALES \| DNNL_ARG_DST |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1 |

| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1,|
| | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_2 |

## Implementation Details

Expand Down
3 changes: 2 additions & 1 deletion doc/primitives/matmul.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ argument index as specified by the following table.
| \f$\text{dropout output mask}\f$ | DNNL_ARG_ATTR_DROPOUT_MASK |
| \f$\text{dropout probability}\f$ | DNNL_ARG_ATTR_DROPOUT_PROBABILITY |
| \f$\text{dropout rng seed}\f$ | DNNL_ARG_ATTR_DROPOUT_SEED |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1 |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1, |
| | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_2 |
| \f$\text{prelu post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(prelu_post_op_position) \| DNNL_ARG_WEIGHTS |

## Implementation Details
Expand Down
3 changes: 2 additions & 1 deletion doc/primitives/pooling.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ argument index as specified by the following table.
| workspace | DNNL_ARG_WORKSPACE |
| \diffsrc | DNNL_ARG_DIFF_SRC |
| \diffdst | DNNL_ARG_DIFF_DST |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1 |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1,|
| | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_2 |

## Implementation Details

Expand Down
3 changes: 2 additions & 1 deletion doc/primitives/reduction.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ argument index as specified by the following table.
|-----------------------------|---------------------------------------------------------------------------|
| \src | DNNL_ARG_SRC |
| \dst | DNNL_ARG_DST |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1 |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1,|
| | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_2 |

## Implementation Details

Expand Down
3 changes: 2 additions & 1 deletion doc/primitives/resampling.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ argument index as specified by the following table.
| \dst | DNNL_ARG_DST |
| \diffsrc | DNNL_ARG_DIFF_SRC |
| \diffdst | DNNL_ARG_DIFF_DST |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1 |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1,|
| | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_2 |

## Implementation Details

Expand Down
3 changes: 2 additions & 1 deletion doc/primitives/softmax.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ argument index as specified by the following table.
| \diffdst | DNNL_ARG_DIFF_DST |
| \f$src scale\f$ | DNNL_ARG_ATTR_SCALES \| DNNL_ARG_SRC |
| \f$dst scale\f$ | DNNL_ARG_ATTR_SCALES \| DNNL_ARG_DST |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1 |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1,|
| | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_2 |

## Implementation Details

Expand Down
23 changes: 23 additions & 0 deletions doc/programming_model/attributes_post_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,29 @@ Currently the following scenarios are optimized:
post-operation which format may be queried from attributes using
`dnnl::post_ops::get_params_binary(...)` function call.

For the binary select operation, an additional conditional tensor is required
to execute the operation which is implemented using:

~~~cpp
void dnnl::post_ops::append_binary(
algorithm alg, // binary algorithm to apply
const memory::desc &src1 // memory descriptor for a second memory operand
const memory::desc &src2 // memory descriptor for a third memory operand
);
~~~

The `alg`, `src1` and `src2` parameters are the same as in
@ref dev_guide_binary.

The binary post-op thus becomes:

\f[
\dst[:] = \operatorname{binary}(\operatorname{Op}(...), Source\_1[:], Source\_2[:])
\f]

There is no broadcasting support for the conditional tensor. The select op is only
supported for CPU implementations.

@anchor dev_guide_attributes_post_ops_prelu
### Prelu Post-op

Expand Down
46 changes: 45 additions & 1 deletion include/oneapi/dnnl/dnnl.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2016-2024 Intel Corporation
* Copyright 2016-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -736,6 +736,31 @@ dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw(
dnnl_status_t DNNL_API dnnl_post_ops_append_binary(dnnl_post_ops_t post_ops,
dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src1_desc);

/// Appends a binary post-op with ternary operators.
///
/// The kind of this post operation is #dnnl_binary.
///
/// In the simplest case when the binary is the only post operation, the
/// computations would be:
///
/// dst[:] <- binary_op (dst[:], another_input1[:], another_input2[:])
///
/// where binary_op is configured with the given parameters. binary_op supports
/// broadcast semantics for a second operand.
///
/// @param post_ops Post-ops.
/// @param alg_kind Binary algorithm for the post-op.
/// @param src1_desc Memory descriptor of a second operand.
/// @param src2_desc Memory descriptor of a third operand. If the specificed
/// algorithm is not one that requires a ternary input, src2_desc will be
/// ignored.

/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_binary_v2(dnnl_post_ops_t post_ops,
dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src1_desc,
const_dnnl_memory_desc_t src2_desc);

/// Returns the parameters of a binary post-op.
///
/// @param post_ops Post-ops.
Expand All @@ -750,6 +775,25 @@ dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary(
const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
const_dnnl_memory_desc_t *src1_desc);

/// Returns the parameters of a binary post-op with ternary operators.
///
/// @param post_ops Post-ops.
/// @param index Index of the binary post-op.
/// @param alg_kind Output binary algorithm kind.
/// @param src1_desc Output memory descriptor of a second operand.
/// @param src2_desc Output memory descriptor of a third operand. If the
/// specificed algorithm is not one that requires a ternary input, src2_desc
/// will be ignored.

/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
/// @returns #dnnl_invalid_arguments if @p index does not refer to a binary
/// post-op.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary_v2(
const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
const_dnnl_memory_desc_t *src1_desc,
const_dnnl_memory_desc_t *src2_desc);

/// Appends a prelu forward post-op.
///
/// The kind of this post-op is #dnnl::primitive::kind::prelu.
Expand Down
53 changes: 53 additions & 0 deletions include/oneapi/dnnl/dnnl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3867,6 +3867,32 @@ struct post_ops : public handle<dnnl_post_ops_t> {
"could not append a binary post-op");
}

/// Appends a binary post-op with ternary operators.
///
/// The kind of this post operation is #dnnl_binary.
///
/// In the simplest case when this is the only post operation, the
/// computations would be:
///
/// dst[:] <- binary_op (dst[:], another_input1[:], another_input2[:])
///
/// where binary_op is configured with the given parameters. binary_op
/// supports broadcast semantics only for the second operand and not for the
/// third operand.
///
/// @param aalgorithm Binary algorithm for the post-op.
/// @param src1_desc Memory descriptor of the second operand.
/// @param src2_desc Memory descriptor of the third operand. If the specificed
/// algorithm is not one that requires a ternary input, src2_desc will be
/// ignored.
void append_binary(algorithm aalgorithm, const memory::desc &src1_desc,
const memory::desc &src2_desc) {
error::wrap_c_api(
dnnl_post_ops_append_binary_v2(get(), convert_to_c(aalgorithm),
src1_desc.get(), src2_desc.get()),
"could not append a binary post-op with ternary operators");
}

/// Returns the parameters of a binary post-op.
///
/// @param index Index of the binary post-op.
Expand All @@ -3886,6 +3912,33 @@ struct post_ops : public handle<dnnl_post_ops_t> {
src1_desc = memory::desc(cloned_md);
}

/// Returns the parameters of a binary post-op with ternary operators.
///
/// @param index Index of the binary post-op.
/// @param aalgorithm Output binary algorithm kind.
/// @param src1_desc Output memory descriptor of the second operand.
/// @param src2_desc Output memory descriptor of the third operand.
void get_params_binary(int index, algorithm &aalgorithm,
memory::desc &src1_desc, memory::desc &src2_desc) const {
dnnl_alg_kind_t c_alg;
const_dnnl_memory_desc_t cdesc1, cdesc2;
error::wrap_c_api(dnnl_post_ops_get_params_binary_v2(
get(), index, &c_alg, &cdesc1, &cdesc2),
"could not get parameters of a binary post-op with ternary "
"operators");
aalgorithm = static_cast<dnnl::algorithm>(c_alg);
dnnl_memory_desc_t cloned_md1 = nullptr;
dnnl_memory_desc_t cloned_md2 = nullptr;

error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md1, cdesc1),
"could not clone a memory descriptor");
src1_desc = memory::desc(cloned_md1);

error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md2, cdesc2),
"could not clone a memory descriptor");
src2_desc = memory::desc(cloned_md2);
}

/// Appends a prelu forward post-op.
///
/// The kind of this post-op is #dnnl::primitive::kind::prelu.
Expand Down
Loading
Loading