graph: api, doc, interface, utils, backend: support bottom-right implicit causal mask #2967
Conversation
std::fill(val.dims, val.dims + DNNL_MAX_NDIMS, DNNL_GRAPH_UNKNOWN_DIM);
std::fill(val.layout.strides, val.layout.strides + DNNL_MAX_NDIMS,
        DNNL_GRAPH_UNKNOWN_DIM);
if (ndims == 0) {
why the change?
To align the API behavior with other logical_tensor_init_... APIs.
We need to make sure that the uninitialized fields will not affect the hash results.
If ndims == 0, the hash result only combines ndims, without touching the dims and strides, so this change doesn't affect the hash result.
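For reference, a minimal self-contained sketch of the behavior described (hypothetical helper names, not the library's actual hashing code): when ndims == 0 the loop over dims never executes, so its contents cannot influence the result.

#include <cstddef>
#include <cstdint>
#include <functional>

// Hypothetical sketch, not oneDNN's exact code: with ndims == 0 the
// loop body never runs, so uninitialized dims cannot change the seed.
inline size_t hash_combine(size_t seed, int64_t v) {
    return seed
            ^ (std::hash<int64_t> {}(v) + 0x9e3779b9 + (seed << 6)
                    + (seed >> 2));
}

size_t hash_dims(const int64_t *dims, int ndims) {
    size_t seed = hash_combine(0, ndims);
    for (int i = 0; i < ndims; ++i) // skipped entirely when ndims == 0
        seed = hash_combine(seed, dims[i]);
    return seed;
}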
@@ -306,6 +306,7 @@ class logical_tensor {
    /// the library. For example, constant weight tensors in inference
    /// scenarios.
    constant = dnnl_graph_tensor_property_constant,
+   host_scalar = dnnl_graph_tensor_property_host_scalar,
Is host_scalar constant? How do we set the const property for a host_scalar?
host_scalar is not a constant; we don't currently have a way to set it as const. There was some discussion here, but no decision yet.
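For context, a sketch of how an input could be tagged with the property added in this PR (the id and usage are illustrative):

using namespace dnnl::graph;

// A 0-dimensional s32 input marked as a host-side scalar, e.g. a
// sequence length fed to the SDPA bottom-right causal mask.
logical_tensor seq_len_kv {/*tid=*/10, logical_tensor::data_type::s32,
        logical_tensor::dims {}, logical_tensor::layout_type::strided,
        logical_tensor::property_type::host_scalar};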
make test
@ElaineBao I want to confirm that the fusion and implementation still work when the shape is 1D and the property_type is reset.
Yes, I just tried modifying the shape to 1D and also resetting the property_type, and it passes:
onednn_verbose,v1,graph,exec,gpu,100002,sdp,bmm1;scale_mul_op;gen_index_row_op;mask_add_op;mask_sub_op;gen_index_col_op;mask_ge_op;mask_select_op;softmax;bmm2,,in0_f16:0:strided:undef:32x16x384x64:393216s24576s64s1 in1_f16:1:strided:undef:32x16x384x64:393216s24576s64s1 in2_f16:4:strided:undef:1:1 in3_s32:10:strided:undef:1:1 in4_s32:13:strided:undef:1:1 in5_f32:20:strided:undef:1:1 in6_f16:24:strided:undef:32x16x384x64:393216s24576s64s1 out0_f16:25:strided:undef:32x16x384x64:393216s24576s64s1,fpm:strict,larger_partition_kernel_t,dnnl_backend,118.229
make test
in sdp_primitive_kernel, which is used in the SDPA bottom-right causal mask
make test
@@ -76,6 +76,7 @@ typedef enum {
    /// optimizations for constant tensors or cache constant tensors inside the
    /// library. For example, constant weight tensors in inference scenarios.
    dnnl_graph_tensor_property_constant = 2,
+   dnnl_graph_tensor_property_host_scalar = 3,
Documentation?
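For example, a doc comment in the style of the neighboring entries (wording illustrative, not the PR's final text):

/// Tensor is a scalar located on the host and passed to the library at
/// execution time.
dnnl_graph_tensor_property_host_scalar = 3,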
@@ -16,6 +16,7 @@ in comparison to fp32.
| f16 | [IEEE half precision floating-point](https://en.wikipedia.org/wiki/Half-precision_floating-point_format#IEEE_754_half-precision_binary_floating-point_format:_binary16) |
| s8/u8 | signed/unsigned 8-bit integer |
| s4/u4 | signed/unsigned 4-bit integer |
+| s32 | signed 32-bit integer |
It probably could use a wiki page link. @vpirogov, could you share your thoughts, please?
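For instance (the link choice is just a suggestion, not part of the PR): | s32 | [signed 32-bit integer](https://en.wikipedia.org/wiki/Integer_(computer_science)) |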
// check unimplemented bottom-right causal mask
auto post_op = get_post_op(cur_op);
if (post_op && post_op->get_kind() == graph::op_kind::Add)
    return status::unimplemented;
Verbose message instead?
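A sketch of what that could look like, assuming a VCHECK_SDP_PRIMITIVE-style helper as used elsewhere in the backend (the exact macro is an assumption, not part of this diff):

// Report the reason in verbose output instead of failing silently.
auto post_op = get_post_op(cur_op);
VCHECK_SDP_PRIMITIVE(
        !(post_op && post_op->get_kind() == graph::op_kind::Add),
        status::unimplemented,
        "bottom-right implicit causal mask is not supported here");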
auto mem = make_dnnl_memory(md, p_engine, nullptr);
exec_args_set_.add_value_mem_map({in.get(), mem});
classify_mem(mem, in.get());
logical_tensor_t in_lt = in->get_logical_tensor();
const?
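i.e., presumably: const logical_tensor_t in_lt = in->get_logical_tensor();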
// property is host_scalar
int32_t s32_value = 0;
// TODO: add more dtype support
} scalar;
Suggested change:
-    } scalar;
+    } scalar_;
@@ -407,11 +407,6 @@ void skip_unimplemented_ops(const dnnl::graph::partition &partition,
        const deserialized_graph_t &dg, res_t *res) {
    // A list of ops that don't have DNNL backend support so far.
    static const std::vector<std::string> unimplemented_ops {"Pow"};
-   // A list of ops that don't have DNNL backend support so far on GPU.
-   static const std::vector<std::string> unimplemented_ops_gpu {
-           "GreaterEqual"};
I'd just remove the op from unimplemented_ops_gpu and keep the rest of the infrastructure intact for future unsupported ops.
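A sketch of that suggestion, with the GPU list kept in place but emptied:

// A list of ops that don't have DNNL backend support so far on GPU.
// Kept (empty) so future unsupported GPU ops can be added back easily.
static const std::vector<std::string> unimplemented_ops_gpu {};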
@@ -112,9 +114,14 @@ dnnl::graph::tensor dnn_graph_mem_t::make_graph_tensor(
    dnnl_memory_get_data_handle(mem_.m_, &data_handle);
    dnnl::graph::logical_tensor graph_lt(lt.id_, lt.get_data_type(), lt.shape_,
            str2layout(lt.layout_type_), lt.get_property_type());
    dnnl::graph::tensor ret(graph_lt, get_graph_engine(), data_handle);
Maybe use
const auto &g_eng = is_host_scalar
? get_graph_host_engine().operator const dnnl::engine &()
: get_graph_engine().operator const dnnl::engine &();
instead?
@@ -309,7 +309,7 @@ int partition_data_displacer_t::displace_input_data(
    } else if (filling_type == filling_type_t::causal_mask) {
        SAFE(gen_causal_mask_filling(mem_replace, mem.md_, res), WARN);
    } else if (filling_type == filling_type_t::minus_infinity) {
-       static const std::vector<float> user_set {-INFINITY};
+       static const std::vector<float> user_set {-1e4};
Remove?
}

inline const cpp_engine_t &get_graph_host_engine() {
    const dnnl::engine &g_eng
Suggested change:
+   // Return `get_graph_engine` for `is_cpu` to avoid different engines.
    const dnnl::engine &g_eng
Description
Implementation of RFC: #2885.
Support bottom-right implicit causal mask (large partition kernel) in the Graph API. The SDPA primitive kernel will be supported once the GPU primitive support is in place.
Currently this PR depends on:
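For illustration, a rough sketch (not this PR's code) of how the bottom-right implicit causal mask pattern from the verbose trace above (gen_index_row -> add -> sub -> gen_index_col -> ge -> select) could be expressed with Graph API ops; the ids, names, and axis values are illustrative:

#include <oneapi/dnnl/dnnl_graph.hpp>

using namespace dnnl::graph;

// score: f16 attention scores [b, h, q_len, kv_len]; seq_len_kv and
// seq_len_q: s32 host scalars; neg_inf: scalar fill value.
void add_bottom_right_causal_mask(graph &g, const logical_tensor &score,
        const logical_tensor &seq_len_kv, const logical_tensor &seq_len_q,
        const logical_tensor &neg_inf, const logical_tensor &masked_score) {
    using dt = logical_tensor::data_type;
    using lt = logical_tensor::layout_type;
    // Illustrative intermediate logical tensors.
    logical_tensor row {100, dt::s32, lt::strided};
    logical_tensor col {101, dt::s32, lt::strided};
    logical_tensor row_plus_kv {102, dt::s32, lt::strided};
    logical_tensor row_shifted {103, dt::s32, lt::strided};
    logical_tensor mask {104, dt::boolean, lt::strided};

    op gen_row {0, op::kind::GenIndex, {score}, {row}, "gen_index_row"};
    gen_row.set_attr<int64_t>(op::attr::axis, 2); // query (row) axis
    op gen_col {1, op::kind::GenIndex, {score}, {col}, "gen_index_col"};
    gen_col.set_attr<int64_t>(op::attr::axis, 3); // key (column) axis

    // Bottom-right alignment keeps a position when
    // row + kv_len - q_len >= col.
    op add {2, op::kind::Add, {row, seq_len_kv}, {row_plus_kv}, "mask_add"};
    op sub {3, op::kind::Subtract, {row_plus_kv, seq_len_q}, {row_shifted},
            "mask_sub"};
    op ge {4, op::kind::GreaterEqual, {row_shifted, col}, {mask}, "mask_ge"};
    op select {5, op::kind::Select, {mask, score, neg_inf}, {masked_score},
            "mask_select"};

    for (const op *o : {&gen_row, &gen_col, &add, &sub, &ge, &select})
        g.add_op(*o);
}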