graph: api, doc, interface, utils, backend: support bottom-right implicit causal mask #2967

Open · wants to merge 14 commits into base: main
10 changes: 5 additions & 5 deletions doc/graph/fusion_patterns/sdpa.md
@@ -68,11 +68,12 @@ optional.

2. Implicit library-generated mask: You can use the operations in the library
to generate a mask by constructing a subgraph. Currently, Graph API supports
generating an implicit causal mask (top-left aligned) using operations of
[GenIndex](@ref dev_guide_op_genindex), [GreaterEqual](@ref dev_guide_op_greaterequal)
generating an implicit causal mask (top-left or bottom-right aligned) using
operations of [GenIndex](@ref dev_guide_op_genindex), [Add](@ref dev_guide_op_add),
[Subtract](@ref dev_guide_op_subtract), [GreaterEqual](@ref dev_guide_op_greaterequal)
and [Select](@ref dev_guide_op_select).

![SDPA-mask-3](images/sdpa-mask-3.png)
![SDPA-mask-3](images/sdpa-mask-3.png) ![SDPA-mask-4](images/sdpa-mask-4.png)

4. The SoftMax operation takes the masked output and transforms it into
probabilities between 0 and 1. See [SoftMax](@ref dev_guide_op_softmax)
@@ -114,8 +115,7 @@ platforms follow the general description in @ref dev_guide_data_types.
softmax primitives. The reference implementation requires memory to store the
intermediate results of the dot products between Query and Key which takes
\f$O(S^2)\f$ memory. It may lead to out-of-memory error when computing long
sequence length input on platforms with limited memory. For an implicit
causal mask, the reference implementation is only available on CPU.
sequence length input on platforms with limited memory.
2. The SDPA patterns functionally support all input shapes meeting the shape
requirements of each operation in the graph. For example, Add, Multiply,
Divide, and Select operations require the input tensors to have the same
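
To make the documentation change above more concrete, here is a minimal C++ sketch (not part of this diff) of how the bottom-right implicit causal mask subgraph could be assembled with the oneDNN Graph API. Only the op kinds named in the hunk above (GenIndex, Subtract, Add, GreaterEqual, Select) are taken from this patch; the tensor ids, shapes, axis values, and data types are assumptions made for illustration.

```cpp
#include <oneapi/dnnl/dnnl_graph.hpp>

#include <cstdint>
#include <vector>

using namespace dnnl::graph;
using dt = logical_tensor::data_type;
using lt = logical_tensor::layout_type;

int main() {
    // Illustrative problem size: 1 batch, 16 heads, s_q = 384 query rows,
    // s_kv = 512 key/value columns (all values are assumptions).
    const std::vector<int64_t> score_dims {1, 16, 384, 512};

    logical_tensor score {0, dt::f32, score_dims, lt::strided}; // QK^T scores
    logical_tensor row_idx {1, dt::s32, score_dims, lt::strided};
    logical_tensor col_idx {2, dt::s32, score_dims, lt::strided};
    logical_tensor s_kv {3, dt::s32, {1}, lt::strided}; // key seq length
    logical_tensor s_q {4, dt::s32, {1}, lt::strided}; // query seq length
    logical_tensor offset {5, dt::s32, {1}, lt::strided}; // s_kv - s_q
    logical_tensor row_shift {6, dt::s32, score_dims, lt::strided};
    logical_tensor keep {7, dt::boolean, score_dims, lt::strided};
    logical_tensor neg_inf {8, dt::f32, {1}, lt::strided};
    logical_tensor masked {9, dt::f32, score_dims, lt::strided};

    // Row/column index of every element of the score matrix.
    op gen_row {10, op::kind::GenIndex, {score}, {row_idx}, "gen_row"};
    gen_row.set_attr<int64_t>(op::attr::axis, -2);
    op gen_col {11, op::kind::GenIndex, {score}, {col_idx}, "gen_col"};
    gen_col.set_attr<int64_t>(op::attr::axis, -1);

    // Bottom-right alignment keeps (i, j) when j <= i + (s_kv - s_q).
    op sub {12, op::kind::Subtract, {s_kv, s_q}, {offset}, "kv_minus_q"};
    op add {13, op::kind::Add, {row_idx, offset}, {row_shift}, "shift_rows"};
    op ge {14, op::kind::GreaterEqual, {row_shift, col_idx}, {keep}, "cmp"};
    op sel {15, op::kind::Select, {keep, score, neg_inf}, {masked}, "mask"};

    graph g {dnnl::engine::kind::cpu};
    for (const op *o : {&gen_row, &gen_col, &sub, &add, &ge, &sel})
        g.add_op(*o);
    g.finalize();
    // With a supporting backend, these mask ops should fuse into the SDPA
    // partition together with the surrounding MatMul/SoftMax ops.
    return g.get_partitions().empty() ? 1 : 0;
}
```

The Subtract/Add pair realizes the s_kv - s_q offset that distinguishes bottom-right from top-left alignment; with s_q == s_kv the offset is zero and the subgraph reduces to the existing top-left mask.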
1 change: 1 addition & 0 deletions doc/graph/operations/Add.md
@@ -51,3 +51,4 @@ Add operation supports the following data type combinations.
| f16 | f16 | f16 |
| f32 | bf16, f16 | f32 |
| bf16, f16 | f32 | f32 |
| s32 | s32 | s32 |
1 change: 1 addition & 0 deletions doc/graph/operations/Subtract.md
@@ -51,3 +51,4 @@ Subtract operation supports the following data type combinations.
| f16 | f16 | f16 |
| f32 | bf16, f16 | f32 |
| bf16, f16 | f32 | f32 |
| s32 | s32 | s32 |
9 changes: 5 additions & 4 deletions doc/programming_model/data_types.md
@@ -16,6 +16,7 @@ in comparison to fp32.
| f16 | [IEEE half precision floating-point](https://en.wikipedia.org/wiki/Half-precision_floating-point_format#IEEE_754_half-precision_binary_floating-point_format:_binary16) |
| s8/u8 | signed/unsigned 8-bit integer |
| s4/u4 | signed/unsigned 4-bit integer |
| s32 | signed 32-bit integer |
Contributor:
It probably could use a wiki page link. @vpirogov, could you share your thoughts, please?

| f64 | [IEEE double precision floating-point](https://en.wikipedia.org/wiki/Double-precision_floating-point_format#IEEE_754_double-precision_binary_floating-point_format:_binary64) |
| boolean | bool (size is C++ implementation defined) |
| f8\_e5m2 | [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf) with 5 exponent and 2 mantissa bits |
@@ -29,10 +30,10 @@ in comparison to fp32.

oneDNN supports training and inference with the following data types:

| Usage mode | CPU | GPU |
|:-----------|:-----------------------------------------------------------------------------|:----------------------------------------------|
| Inference | f32, bf16, f16, f8\_e5m2/f8\_e4m3, f4\_e2m1, f4\_e3m0, s8/u8, s4/u4, boolean | f32, bf16, f16, f8\_e5m2/f8\_e4m3, s8/u8, f64 |
| Training | f32, bf16, f16, f8\_e5m2/f8\_e4m3 | f32, bf16, f16, f8\_e5m2/f8\_e4m3, f64 |
| Usage mode | CPU | GPU |
|:-----------|:----------------------------------------------------------------------------------|:------------------------------------------------------------|
| Inference | f32, bf16, f16, f8\_e5m2/f8\_e4m3, f4\_e2m1, f4\_e3m0, s8/u8, s4/u4, s32, boolean | f32, bf16, f16, f8\_e5m2/f8\_e4m3, s8/u8, s32, f64, boolean |
| Training | f32, bf16, f16, f8\_e5m2/f8\_e4m3 | f32, bf16, f16, f8\_e5m2/f8\_e4m3, f64 |

@note
Using lower precision arithmetic may require changes in the deep learning
3 changes: 3 additions & 0 deletions include/oneapi/dnnl/dnnl_graph.hpp
@@ -306,6 +306,9 @@ class logical_tensor {
/// the library. For example, constant weight tensors in inference
/// scenarios.
constant = dnnl_graph_tensor_property_constant,
/// Host scalar means the tensor will be a 0-D scalar tensor on host.
/// It should be used with a CPU engine when creating the tensor.
host_scalar = dnnl_graph_tensor_property_host_scalar,
Contributor:
Is host_scalar constant? How do we set the const property for host_scalar?

Contributor (Author):
host_scalar is not a constant; we don't currently have a way to set it as const. There was some discussion here, but no decision yet.

};

/// default constructor
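
As a side note on the new property, a host-scalar input might be created roughly as follows. The property_type::host_scalar value and the CPU-engine requirement are taken from the header comment above; the logical-tensor id, the empty dims list used to express the 0-D shape, and the s32 payload are assumptions of this sketch.

```cpp
#include <oneapi/dnnl/dnnl_graph.hpp>

#include <cstdint>
#include <vector>

using namespace dnnl::graph;

int main() {
    // 0-D s32 logical tensor carrying, e.g., a sequence length on the host.
    logical_tensor seq_len_lt {100, logical_tensor::data_type::s32,
            std::vector<int64_t> {}, // empty dims: 0-D scalar (assumption)
            logical_tensor::layout_type::strided,
            logical_tensor::property_type::host_scalar};

    // Host scalars are created against a CPU engine, even when the rest of
    // the partition runs elsewhere.
    dnnl::engine cpu_eng {dnnl::engine::kind::cpu, 0};

    int32_t seq_len = 512; // plain host memory, no device allocation
    tensor seq_len_ts {seq_len_lt, cpu_eng, &seq_len};

    (void)seq_len_ts; // would be passed among the inputs of partition::execute()
    return 0;
}
```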
3 changes: 3 additions & 0 deletions include/oneapi/dnnl/dnnl_graph_types.h
@@ -76,6 +76,9 @@ typedef enum {
/// optimizations for constant tensors or cache constant tensors inside the
/// library. For example, constant weight tensors in inference scenarios.
dnnl_graph_tensor_property_constant = 2,
/// Host scalar means the tensor will be a 0-D scalar tensor on host.
/// It should be used with a CPU engine when creating the tensor.
dnnl_graph_tensor_property_host_scalar = 3,
} dnnl_graph_tensor_property_t;

/// Logical tensor. It is based on an ID, a number of dimensions, dimensions
28 changes: 24 additions & 4 deletions src/graph/backend/dnnl/dnnl_op_def.hpp
@@ -93,6 +93,20 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_mul_scales, 1,
executable_creator<reorder_executable_t>)
.SET_ARG_INDICES_GETTER(reorder_executable_t))

DNNL_GRAPH_OP_SCHEMA(dnnl_host_scalar, 1,
op_schema_t()
.set_num_inputs(1)
.set_num_outputs(1)
.set_input(0, "scalar")
.set_output(0, "output")
.SET_ATTR_IS_CONSTANT // used for constant prop and cache
.set_shape_inference_function(
infer_dnnl_host_scalar_output_shape)
.SET_LAYOUT_PROPAGATOR(layout_propagator_for_host_scalar)
.SET_EXECUTABLE_CREATOR(
executable_creator<host_scalar_executable_t>)
.SET_ARG_INDICES_GETTER(host_scalar_executable_t))

DNNL_GRAPH_OP_SCHEMA(dnnl_constant_scales, 1,
op_schema_t()
.set_num_inputs(0)
@@ -1119,14 +1133,20 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_groupnorm, 1,

DNNL_GRAPH_OP_SCHEMA(dnnl_mask, 1,
op_schema_t()
.set_num_inputs(2)
.set_inputs_option(op_schema_t::param_num_option::optional)
.set_num_inputs(std::set<size_t>({2, 4}))
.set_num_outputs(1)
.set_input(0, "input")
.set_input(1, "-inf")
.set_input(2, "s_kv")
.set_input(3, "s_q")
.set_output(0, "output")
// Attributes inherited from front gen_index ops
.set_attr(op_attr::axis_row, true, attribute_kind::i)
.set_attr(op_attr::axis_col, true, attribute_kind::i)
// mask_type attribute indicates whether the mask is an explicit mask, a
// top-left implicit causal mask, or a bottom-right implicit causal mask
.set_attr(op_attr::mask_type, true, attribute_kind::i)
.SET_ATTR_IS_CONSTANT // used for constant prop and cache
// Analysis rules
.set_shape_inference_function(infer_identity_output_shape)
@@ -1152,9 +1172,9 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_sdpa, 1,
.set_attr(op_attr::with_scale, true, attribute_kind::b)
.set_attr(op_attr::is_invert_scale, false, attribute_kind::b,
false)
.set_attr(op_attr::with_mask, true, attribute_kind::b)
// with_causal attribute support top-left mask type only
.set_attr(op_attr::with_causal, true, attribute_kind::b)
// mask_type attribute indicates whether the mask is an explicit mask, a
// top-left implicit causal mask, or a bottom-right implicit causal mask
.set_attr(op_attr::mask_type, true, attribute_kind::i)
.set_shape_inference_function(infer_dnnl_sdpa_output_shape)
.SET_LAYOUT_PROPAGATOR(layout_propagator_for_sdpa)
.SET_EXECUTABLE_CREATOR(executable_creator<sdpa_executable_t>)
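
For orientation on the mask_type attribute introduced above: with bottom-right alignment the causal boundary is anchored to the last key column instead of the first, i.e. score position (i, j) is kept when j <= i + (s_kv - s_q), which reduces to the usual top-left rule j <= i when s_q == s_kv. As a small worked example with illustrative numbers, s_q = 4 and s_kv = 6 give an offset of 2, so query row 0 attends to key columns 0..2 while the last query row attends to all 6 columns; this offset is why dnnl_mask gains the optional s_kv and s_q inputs.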
2 changes: 2 additions & 0 deletions src/graph/backend/dnnl/dnnl_opset.hpp
@@ -73,6 +73,8 @@ class dnnl_opset_t {
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(
dnnl_eltwise_bwd, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_gen_index, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(
dnnl_host_scalar, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_mask, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_shuffle, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_sum, 1)>());
10 changes: 10 additions & 0 deletions src/graph/backend/dnnl/dnnl_shape_infer.cpp
@@ -602,6 +602,16 @@ status_t infer_dnnl_sdpa_output_shape(op_t *n,
return status::success;
}

status_t infer_dnnl_host_scalar_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs) {
// host scalar output is always strided
// with shape = {1}, strides = {1}
outputs[0]->layout_type = layout_type::strided;
set_shape_and_strides(*outputs[0], {1});
return status::success;
}

} // namespace dnnl_impl
} // namespace graph
} // namespace impl
4 changes: 4 additions & 0 deletions src/graph/backend/dnnl/dnnl_shape_infer.hpp
@@ -111,6 +111,10 @@ status_t infer_dnnl_sdpa_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs);

status_t infer_dnnl_host_scalar_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs);

} // namespace dnnl_impl
} // namespace graph
} // namespace impl
6 changes: 2 additions & 4 deletions src/graph/backend/dnnl/internal_attrs.hpp
@@ -47,8 +47,7 @@ const op_attr_t with_sum = 0x1000e;
const op_attr_t keep_dst_layout = 0x1000f;
const op_attr_t with_scale = 0x10010;
const op_attr_t is_invert_scale = 0x10011;
const op_attr_t with_causal = 0x10012;
const op_attr_t with_mask = 0x10013;
const op_attr_t mask_type = 0x10012;

// int64_t
const op_attr_t alg_kind = 0x10100;
@@ -92,8 +91,7 @@ static inline std::string internal_attr2str(op_attr_t attr) {
CASE(keep_dst_layout);
CASE(with_scale);
CASE(is_invert_scale);
CASE(with_causal);
CASE(with_mask);
CASE(mask_type);
CASE(alg_kind);
CASE(fusion_info_key);
CASE(axis_row);
3 changes: 2 additions & 1 deletion src/graph/backend/dnnl/internal_ops.hpp
@@ -80,7 +80,8 @@ namespace op_kind {
X(dnnl_groupnorm, Dnnl_groupnorm) \
X(dnnl_gen_index, Dnnl_gen_index) \
X(dnnl_mask, Dnnl_mask) \
X(dnnl_sdpa, Dnnl_sdpa)
X(dnnl_sdpa, Dnnl_sdpa) \
X(dnnl_host_scalar, Dnnl_host_scalar)

enum kind_t {
kDNNL_INTERNAL_OP_STARTER = 0x1234,
18 changes: 18 additions & 0 deletions src/graph/backend/dnnl/kernels/large_partition.cpp
@@ -36,6 +36,8 @@ void larger_partition_kernel_t::setup_pipeline_stage1(
pass_pipeline_t &pipeline) {
// Directly lower down (1 to 1 mapping)
BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
// handle the case that the input is a scalar tensor
BACKEND_DNNL_ADD_PASS(pipeline, insert_host_scalar);
// Decompose select to binary ops if necessary
BACKEND_DNNL_ADD_PASS(pipeline, decompose_select_to_binary_ops);

@@ -166,6 +168,19 @@ void larger_partition_kernel_t::setup_pipeline(pass_pipeline_t &pipeline,
setup_pipeline_stage2(pipeline, mem_planner, enable_constant_cache);
}

void larger_partition_kernel_t::prepare_host_scalar_args(
execution_args_set_t *res, const std::vector<tensor_t> &inputs) {
for (const auto &host_scalar_info : res->get_host_scalar_infos()) {
auto mem = make_dnnl_memory(host_scalar_info.md,
make_dnnl_engine(
*(inputs[host_scalar_info.input_idx].get_engine())),
inputs[host_scalar_info.input_idx].get_data_handle());
auto args = res->get_exec_args()[host_scalar_info.exec_idx];
args.insert({host_scalar_info.arg, mem});
res->reset_exec_args(host_scalar_info.exec_idx, args);
}
}

void larger_partition_kernel_t::prepare_args_set(
const execution_args_set_t *res, const std::vector<tensor_t> &inputs,
const std::vector<tensor_t> &outputs, const scratchpad_t &scratchpad) {
@@ -251,6 +266,7 @@ status_t larger_partition_kernel_t::execute_impl(const stream_t *g_stream,
assertm(scratchpad.size()
>= memory_planner_.total_internal_temporary_size(),
"no enough scratchpad memory");
prepare_host_scalar_args(res, inputs);
prepare_args_set(res, inputs, outputs, scratchpad);

constant_cache_t::cached_t c_buffer;
@@ -321,6 +337,7 @@ status_t larger_partition_kernel_t::sycl_execute_impl(const stream_t *g_stream,
assertm(scratchpad.size()
>= memory_planner_.total_internal_temporary_size(),
"no enough scratchpad memory");
prepare_host_scalar_args(res, inputs);
prepare_args_set(res, inputs, outputs, scratchpad);

constant_cache_t::cached_t c_buffer;
@@ -396,6 +413,7 @@ status_t larger_partition_kernel_t::ocl_execute_impl(const stream_t *g_stream,
assertm(scratchpad.size()
>= memory_planner_.total_internal_temporary_size(),
"no enough scratchpad memory");
prepare_host_scalar_args(res, inputs);
prepare_args_set(res, inputs, outputs, scratchpad);

constant_cache_t::cached_t c_buffer;
3 changes: 3 additions & 0 deletions src/graph/backend/dnnl/kernels/large_partition.hpp
@@ -82,6 +82,9 @@ class larger_partition_kernel_t : public kernel_base_t {
return status::success;
}

void prepare_host_scalar_args(
execution_args_set_t *res, const std::vector<tensor_t> &inputs);

void prepare_args_set(const execution_args_set_t *res,
const std::vector<tensor_t> &inputs,
const std::vector<tensor_t> &outputs,
12 changes: 7 additions & 5 deletions src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp
@@ -86,12 +86,15 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
// 3. locate mask if have
if (post_op->get_kind() == op_kind::dnnl_binary) {
add = post_op;
mask_type_ = attn_mask_type::buffer;
} else if (post_op->get_kind() == op_kind::dnnl_mask) {
// implicit causal mask
causal_mask_ = true;
mask_type_ = static_cast<attn_mask_type_t>(
post_op->get_attr<int64_t>(op_attr::mask_type));
}
} else if (post_op->get_kind() == op_kind::dnnl_mask) {
causal_mask_ = true;
mask_type_ = static_cast<attn_mask_type_t>(
post_op->get_attr<int64_t>(op_attr::mask_type));
}
} else {
VCHECK_SDP_PRIMITIVE(mm2 == nullptr, status::unimplemented,
@@ -363,9 +366,8 @@ status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,

CHECK(create_sdpa_pd(sdpa_pd_, p_engine.get(), md_q.get(), md_k.get(),
md_v.get(), md_dst.get(), md_mask.get(), scale_dt, invert_scale_,
kv_head_number_,
causal_mask_ ? attn_mask_type::top_left : attn_mask_type::buffer,
attr.get(), qk_attr.get(), vs_attr.get()));
kv_head_number_, mask_type_, attr.get(), qk_attr.get(),
vs_attr.get()));

auto status = sdpa_pd_->create_primitive(sdpa_prim_, p_engine.get());

2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp
@@ -61,7 +61,7 @@ struct sdp_primitive_config_t {

bool invert_scale_ = false;
bool quantized_ = false;
bool causal_mask_ = false;
attn_mask_type_t mask_type_ = attn_mask_type::undef;
dim_t kv_head_number_;

// SDP pd and primitive.
13 changes: 13 additions & 0 deletions src/graph/backend/dnnl/layout_propagator.cpp
@@ -1757,6 +1757,19 @@ status_t layout_propagator_for_sdpa(std::shared_ptr<op_t> &op,
return status;
}

status_t layout_propagator_for_host_scalar(std::shared_ptr<op_t> &op,
const dnnl::engine &p_engine, fusion_info_mgr_t &mgr,
pd_cache_t &pd_cache, subgraph_rewriter_t &rewriter) {
// no need to do layout propagation for host scalar
// as its output is always strided
UNUSED(op);
UNUSED(p_engine);
UNUSED(mgr);
UNUSED(pd_cache);
UNUSED(rewriter);
return status::success;
}

} // namespace dnnl_impl
} // namespace graph
} // namespace impl
1 change: 1 addition & 0 deletions src/graph/backend/dnnl/layout_propagator.hpp
@@ -94,6 +94,7 @@ DECLARE_LAYOUT_PROPAGATOR(groupnorm);
DECLARE_LAYOUT_PROPAGATOR(gen_index);
DECLARE_LAYOUT_PROPAGATOR(mask);
DECLARE_LAYOUT_PROPAGATOR(sdpa);
DECLARE_LAYOUT_PROPAGATOR(host_scalar);

#undef DECLARE_LAYOUT_PROPAGATOR

16 changes: 14 additions & 2 deletions src/graph/backend/dnnl/op_executable.cpp
@@ -2390,6 +2390,17 @@ arg_indices_t reorder_executable_t::get_arg_indices(
return arg_indices;
}

arg_indices_t host_scalar_executable_t::get_arg_indices(
const op_t *op, fusion_info_mgr_t &mgr) {
UNUSED(op);
UNUSED(mgr);
arg_indices_t arg_indices;

arg_indices.insert({DNNL_ARG_FROM, indices_t {input, 0}});
arg_indices.insert({DNNL_ARG_TO, indices_t {output, 0}});
return arg_indices;
}

arg_indices_t softmax_bwd_executable_t::get_arg_indices(
const op_t *op, fusion_info_mgr_t &mgr) {
UNUSED(mgr);
@@ -2462,10 +2473,11 @@ arg_indices_t sdpa_executable_t::get_arg_indices(
arg_indices.insert({DNNL_ARG_QUERIES, indices_t {input, index++}});
arg_indices.insert({DNNL_ARG_KEYS, indices_t {input, index++}});
arg_indices.insert({DNNL_ARG_VALUES, indices_t {input, index++}});
if (op->get_attr<bool>(dnnl::impl::graph::dnnl_impl::op_attr::with_scale)) {
if (op->get_attr<bool>(op_attr::with_scale)) {
arg_indices.insert({DNNL_ARG_SCALE, indices_t {input, index++}});
}
if (op->get_attr<bool>(dnnl::impl::graph::dnnl_impl::op_attr::with_mask)) {
if (op->get_attr<int64_t>(op_attr::mask_type)
== static_cast<int64_t>(attn_mask_type::buffer)) {
arg_indices.insert({DNNL_ARG_ATTN_MASK, indices_t {input, index++}});
}
