
graph: api, doc, interface, utils, backend: support bottom-right implicit causal mask #2967

Open · wants to merge 14 commits into `main`
Conversation

ElaineBao (Contributor)
Description

Implementation of RFC #2885.
This PR adds support for the bottom-right implicit causal mask (large partition kernel) in the Graph API. The SDPA primitive-based kernel will follow once GPU primitive support is available.
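For context on the semantics: with query length `s_q` and key/value length `s_kv`, the bottom-right variant keeps the score at position `(row, col)` only when `row + s_kv - s_q >= col`, so the causal diagonal is anchored at the bottom-right corner of the score matrix (the usual top-left mask is the special case `s_q == s_kv`). A minimal standalone sketch that materializes this predicate (illustrative only, not library code; the fused kernel applies it implicitly, and the GenIndex/Add/Subtract/GreaterEqual/Select chain in the verbose log further down expresses it op by op):

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int s_q = 3, s_kv = 5; // query and key/value sequence lengths
    std::vector<char> mask(s_q * s_kv);
    for (int row = 0; row < s_q; ++row)
        for (int col = 0; col < s_kv; ++col)
            // Keep score(row, col) iff row + s_kv - s_q >= col.
            mask[row * s_kv + col] = (row + s_kv - s_q >= col);
    // Prints:
    // 1 1 1 0 0
    // 1 1 1 1 0
    // 1 1 1 1 1
    for (int row = 0; row < s_q; ++row, std::printf("\n"))
        for (int col = 0; col < s_kv; ++col)
            std::printf("%d ", mask[row * s_kv + col]);
    return 0;
}
```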

Currently this PR depends on:

```cpp
std::fill(val.dims, val.dims + DNNL_MAX_NDIMS, DNNL_GRAPH_UNKNOWN_DIM);
std::fill(val.layout.strides, val.layout.strides + DNNL_MAX_NDIMS,
        DNNL_GRAPH_UNKNOWN_DIM);
if (ndims == 0) {
```
Contributor: why the change?

ElaineBao (Author): To align the API behavior with other logical_tensor_init_... APIs.

Contributor: We need to make sure that the uninitialized fields will not affect the hash results.

ElaineBao (Author): If ndims == 0, the hash only combines ndims and never touches dims or strides, so this change doesn't affect the hash result.
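A sketch of the shape of that argument (hypothetical names and combine function, not the library's actual hashing code): when `ndims == 0`, only `ndims` feeds the hash, so the freshly filled `dims`/`strides` can never influence it.

```cpp
#include <cstddef>
#include <functional>

struct lt_t { // stand-in for the logical tensor fields relevant here
    int ndims;
    long dims[12];
    long strides[12];
};

inline size_t hash_combine(size_t seed, size_t v) {
    return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

size_t hash_lt(const lt_t &lt) {
    size_t seed = hash_combine(0, std::hash<int>()(lt.ndims));
    // For ndims == 0 this loop never runs, so dims/strides (now filled
    // with DNNL_GRAPH_UNKNOWN_DIM) cannot change the result.
    for (int i = 0; i < lt.ndims; ++i) {
        seed = hash_combine(seed, std::hash<long>()(lt.dims[i]));
        seed = hash_combine(seed, std::hash<long>()(lt.strides[i]));
    }
    return seed;
}
```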

```diff
@@ -306,6 +306,7 @@ class logical_tensor {
         /// the library. For example, constant weight tensors in inference
         /// scenarios.
         constant = dnnl_graph_tensor_property_constant,
+        host_scalar = dnnl_graph_tensor_property_host_scalar,
```
Contributor: Is host_scalar constant? How does one set the const property for a host_scalar?

ElaineBao (Author): host_scalar is not a constant; we currently don't have a way to set it as const. There was some discussion here, but no decision yet.
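For illustration, creating a host-side scalar input (such as `s_kv`) with the new property would presumably look like this; the tensor id and the 1-element shape are arbitrary example values, while the constructor is the existing public `logical_tensor` one:

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

dnnl::graph::logical_tensor make_s_kv_lt() {
    using lt = dnnl::graph::logical_tensor;
    // A 1-element s32 tensor living on the host, tagged with the new
    // host_scalar property added by this PR.
    const lt::dims shape {1};
    return lt(/*tid=*/10, lt::data_type::s32, shape,
            lt::layout_type::strided, lt::property_type::host_scalar);
}
```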

ElaineBao (Author):

```
make test
set test_scope=NIGHTLY
disable benchdnn_all
enable benchdnn_graph
```
TaoLv (Contributor), Apr 2, 2025:

@ElaineBao I want to confirm that the fusion and implementation still work when s_kv and s_q are device-side scalars.

ElaineBao (Author):

> @ElaineBao I want to confirm that the fusion and implementation still work when s_kv and s_q are device-side scalars.

Yes, I just tried modifying the shape to 1D and resetting the property_type; it passes:

```
onednn_verbose,v1,graph,exec,gpu,100002,sdp,bmm1;scale_mul_op;gen_index_row_op;mask_add_op;mask_sub_op;gen_index_col_op;mask_ge_op;mask_select_op;softmax;bmm2,,in0_f16:0:strided:undef:32x16x384x64:393216s24576s64s1 in1_f16:1:strided:undef:32x16x384x64:393216s24576s64s1 in2_f16:4:strided:undef:1:1 in3_s32:10:strided:undef:1:1 in4_s32:13:strided:undef:1:1 in5_f32:20:strided:undef:1:1 in6_f16:24:strided:undef:32x16x384x64:393216s24576s64s1 out0_f16:25:strided:undef:32x16x384x64:393216s24576s64s1,fpm:strict,larger_partition_kernel_t,dnnl_backend,118.229
```

```diff
@@ -76,6 +76,7 @@ typedef enum {
     /// optimizations for constant tensors or cache constant tensors inside the
     /// library. For example, constant weight tensors in inference scenarios.
     dnnl_graph_tensor_property_constant = 2,
+    dnnl_graph_tensor_property_host_scalar = 3,
```
Contributor: Documentation?

```diff
@@ -16,6 +16,7 @@ in comparison to fp32.
 | f16 | [IEEE half precision floating-point](https://en.wikipedia.org/wiki/Half-precision_floating-point_format#IEEE_754_half-precision_binary_floating-point_format:_binary16) |
 | s8/u8 | signed/unsigned 8-bit integer |
 | s4/u4 | signed/unsigned 4-bit integer |
+| s32 | signed 32-bit integer |
```
Contributor: It probably could use a wiki page link. @vpirogov, could you share your thoughts, please?

```cpp
// check unimplemented bottom-right causal mask
auto post_op = get_post_op(cur_op);
if (post_op && post_op->get_kind() == graph::op_kind::Add)
    return status::unimplemented;
```
Contributor: Verbose message instead?
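Presumably something along these lines (a hedged sketch; the exact check/verbose macro used in this part of the backend is an assumption, not confirmed by the thread):

```cpp
// Report why the pattern is rejected instead of failing silently.
auto post_op = get_post_op(cur_op);
VCHECK_UNIMPLEMENTED(!(post_op && post_op->get_kind() == graph::op_kind::Add),
        "bottom-right implicit causal mask is not supported by this kernel");
```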

```cpp
auto mem = make_dnnl_memory(md, p_engine, nullptr);
exec_args_set_.add_value_mem_map({in.get(), mem});
classify_mem(mem, in.get());
logical_tensor_t in_lt = in->get_logical_tensor();
```
Contributor: const?
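i.e., presumably:

```cpp
const logical_tensor_t in_lt = in->get_logical_tensor();
```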

```cpp
    // property is host_scalar
    int32_t s32_value = 0;
    // TODO: add more dtype support
} scalar;
```
Contributor:

Suggested change:

```diff
-} scalar;
+} scalar_;
```

```diff
@@ -407,11 +407,6 @@ void skip_unimplemented_ops(const dnnl::graph::partition &partition,
         const deserialized_graph_t &dg, res_t *res) {
     // A list of ops that don't have DNNL backend support so far.
     static const std::vector<std::string> unimplemented_ops {"Pow"};
-    // A list of ops that don't have DNNL backend support so far on GPU.
-    static const std::vector<std::string> unimplemented_ops_gpu {
-            "GreaterEqual"};
```
Contributor: I'd just remove the op from unimplemented_ops_gpu and keep the rest of the infrastructure intact for future unsupported ops.
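Read as code (my sketch of the suggestion), that would be:

```cpp
// Keep the GPU skip-list plumbing for future unsupported ops; just leave
// it empty now that GreaterEqual is implemented on GPU.
static const std::vector<std::string> unimplemented_ops_gpu {};
```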

```diff
@@ -112,9 +114,14 @@ dnnl::graph::tensor dnn_graph_mem_t::make_graph_tensor(
     dnnl_memory_get_data_handle(mem_.m_, &data_handle);
     dnnl::graph::logical_tensor graph_lt(lt.id_, lt.get_data_type(), lt.shape_,
             str2layout(lt.layout_type_), lt.get_property_type());
     dnnl::graph::tensor ret(graph_lt, get_graph_engine(), data_handle);
```
Contributor: Maybe use

```cpp
const auto &g_eng = is_host_scalar
        ? get_graph_host_engine().operator const dnnl::engine &()
        : get_graph_engine().operator const dnnl::engine &();
```

instead?
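The tensor construction on the quoted line would then presumably become:

```cpp
dnnl::graph::tensor ret(graph_lt, g_eng, data_handle);
```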

```diff
@@ -309,7 +309,7 @@ int partition_data_displacer_t::displace_input_data(
     } else if (filling_type == filling_type_t::causal_mask) {
         SAFE(gen_causal_mask_filling(mem_replace, mem.md_, res), WARN);
     } else if (filling_type == filling_type_t::minus_infinity) {
-        static const std::vector<float> user_set {-INFINITY};
+        static const std::vector<float> user_set {-1e4};
```
Contributor: Remove?

```cpp
}

inline const cpp_engine_t &get_graph_host_engine() {
    const dnnl::engine &g_eng
```

Contributor:

Suggested change:

```diff
-    const dnnl::engine &g_eng
+    // Return `get_graph_engine` for `is_cpu` to avoid different engines.
+    const dnnl::engine &g_eng
```
