Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

graph: backend: kernels: sdp verbose enhancement #2273

Merged
merged 1 commit into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/graph/backend/dnnl/kernels/sdp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@

#include "graph/backend/dnnl/dnnl_partition_impl.hpp"

#define VDISPATCH_GRAPH_SDP(msg, ...) \
VINFO(graph, create, dispatch, compile, msg, ##__VA_ARGS__)

namespace dnnl {
namespace impl {
namespace graph {
Expand Down Expand Up @@ -76,6 +79,11 @@ struct sdp_base_t : public kernel_base_t {
kernel = std::make_shared<larger_partition_kernel_t>();
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
}
if (ret == status::success)
VDISPATCH_GRAPH_SDP(
"sdpa is dispatched to (%s)", kernel->str().c_str());
else
VDISPATCH_GRAPH_SDP("sdpa is failed to dispatch");
return ret;
}

Expand Down
56 changes: 34 additions & 22 deletions src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

#include "graph/backend/dnnl/kernels/sdp_decomp_config.hpp"

#define VCHECK_SDP_DECOMP(cond, status, msg, ...) \
VCONDCHECK(graph, create, check, sdp_decomp_kernel_t, (cond), status, msg, \
##__VA_ARGS__);

namespace dnnl {
namespace impl {
namespace graph {
Expand All @@ -25,10 +29,10 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
const std::vector<logical_tensor_t> &inputs) {
// The order of input logical tensors in inputs is not certain, we need
// to record the input offset in a certain order of ops.
auto op_status = record_input_offset(sg, inputs);
if (op_status != status::success) return false;
CHECK_BOOL(record_input_offset(sg, inputs));
dims src1_user_dims = ltw(inputs[graph_inport[0]]).vdims();
if (src1_user_dims.size() != 4) return false;
VCHECK_SDP_DECOMP(src1_user_dims.size() == 4, false,
"Input dims should be 4, but got %zu", src1_user_dims.size());

// Initialize SDP input dimension according to the src of mm1
batch_size = src1_user_dims[0];
Expand All @@ -41,14 +45,17 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,

// Check batch size compatibility.
dims wei2_user_dims = ltw(inputs[graph_inport[4]]).vdims();
if (batch_size != wei1_user_dims[0] || batch_size != wei2_user_dims[0]) {
return false;
}
VCHECK_SDP_DECOMP(
batch_size == wei1_user_dims[0] && batch_size == wei2_user_dims[0],
false,
"Batch size mismatch, batch_size: %lld, wei1: %lld, wei2: %lld",
batch_size, wei1_user_dims[0], wei2_user_dims[0]);

// Check scale size
if (graph_inport[2] != -1) {
auto scale_sz = ltw(inputs[graph_inport[2]]).nelems();
if (scale_sz != 1) return false;
VCHECK_SDP_DECOMP(scale_sz == 1, false,
"Only supports single scale value, but got %lld", scale_sz);
}

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP
Expand All @@ -65,10 +72,13 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
#define RATIO 2
// Initialize nthr with current threads num
nthr = dnnl_get_current_num_threads();
return batch_size * num_head_q > RATIO * nthr;
#else
return true;
VCHECK_SDP_DECOMP(batch_size * num_head_q > RATIO * nthr, false,
"Doesn't meet condition for decompose: Batch size * num_head_q "
"should be larger than ratio * nthr, but got batch_size %lld, "
"num_head_q %lld, ration %d , nthr %d",
batch_size, num_head_q, RATIO, nthr);
#endif
return true;
}

template <bool quantized, memory::data_type dt>
Expand All @@ -78,7 +88,7 @@ impl::status_t sdp_decomp_config_t::construct_params(
const std::vector<logical_tensor_t> &inputs) {

// Record the ops inside of SDP pattern for later usage
record_sdp_ops(sg, quantized);
CHECK(record_sdp_ops(sg, quantized));

// Update SDPA input params. Sequence length for query and key/value are
// NOT always same.
Expand Down Expand Up @@ -435,11 +445,12 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
graph::op_kind::SoftMax};
for (const auto &cur_op : sg->get_ops()) {
const auto &op_kind = cur_op->get_kind();
if (op_kind == graph::op_kind::DynamicDequantize
&& cur_op->get_attr<std::string>(op_attr::qtype)
== "per_group") {
return status::unimplemented;
}
VCHECK_SDP_DECOMP(
!(op_kind == graph::op_kind::DynamicDequantize
&& cur_op->get_attr<std::string>(op_attr::qtype)
== "per_group"),
status::unimplemented,
"Not support per_group DynamicDequantize");
// both mm1 and mm2 are found.
if (mm1 && mm2) break;
if (op_kind != graph::op_kind::MatMul) continue;
Expand All @@ -451,9 +462,9 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
// TODO(xxx): Currently, p2 is not supported by decomp kernel.
// p1: [matmul] --> [scale] --> [select] --> [mask] --> ...
// p2: [matmul] --> [select] --> [scale] --> [mask] --> ...
if (post_op->get_kind() == graph::op_kind::Select) {
return status::unimplemented;
}
VCHECK_SDP_DECOMP(post_op->get_kind() != graph::op_kind::Select,
status::unimplemented,
"Not support select between matmul1 and scale");
// find scale
if (post_op->get_kind() == graph::op_kind::Divide
|| post_op->get_kind() == graph::op_kind::Multiply) {
Expand All @@ -478,8 +489,8 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
mm2 = cur_op;
}
}
if (impl::utils::one_of(nullptr, mm1, mm2)) return status::invalid_graph;

VCHECK_SDP_DECOMP(mm1 != nullptr && mm2 != nullptr, status::invalid_graph,
"Failed to find matmul1 or matmul2");
int src1_id = find_graph_inport(mm1->get_input_value(0));
graph_inport.emplace_back(src1_id);
int wei1_id = find_graph_inport(mm1->get_input_value(1));
Expand Down Expand Up @@ -534,7 +545,8 @@ impl::status_t sdp_decomp_config_t::record_sdp_ops(
auto post_op = get_post_op(cur_op);
if (!post_op || post_op->get_kind() != op_kind::dnnl_softmax) continue;
auto ppost_op = get_post_op(post_op);
if (!ppost_op) return status::invalid_graph;
VCHECK_SDP_DECOMP(ppost_op != nullptr, status::invalid_graph,
"Failed to find post post op for matmul");

op_ptr reorder1;
op_ptr reorder2;
Expand Down
4 changes: 3 additions & 1 deletion src/graph/backend/dnnl/kernels/sdp_primitive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ status_t sdp_primitive_kernel_t<quantized>::get_prim_exec_args(
&& res->find_value_mem_map(
cfg_.v_zero_points_.get(), mem_storage[9]);

if (!ok) return status::runtime_error;
VCONDCHECK(graph, exec, check, sdp_primitive_kernel, ok,
status::runtime_error,
"sdp_primitive_kernel get_prim_exec_args failed");

memory_arg_t mem_arg_q = {mem_storage[0].get(), true};
memory_arg_t mem_arg_k = {mem_storage[1].get(), true};
Expand Down
60 changes: 33 additions & 27 deletions src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@

#include "common/compiler_workarounds.hpp"

#define VCHECK_SDP_PRIMITIVE(cond, status, msg, ...) \
VCONDCHECK(graph, create, check, sdp_primitive_kernel_t, (cond), status, \
msg, ##__VA_ARGS__);

namespace dnnl {
namespace impl {
namespace graph {
Expand Down Expand Up @@ -63,7 +67,8 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
if (post_op && mm1_post_op_kind.count(post_op->get_kind())) {
// Locate mm1 and all post ops(scale and mask) here.
// 1. locate mm1
if (mm1) return status::unimplemented;
VCHECK_SDP_PRIMITIVE(mm1 == nullptr, status::unimplemented,
"Multiple mm1 found");
mm1 = cur_op;
// At least one of scale and mask exists
if (post_op->get_kind() == op_kind::dnnl_binary) {
Expand All @@ -84,15 +89,18 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
}
}
} else {
if (mm2) return status::unimplemented;
VCHECK_SDP_PRIMITIVE(mm2 == nullptr, status::unimplemented,
"Multiple mm2 found");
mm2 = cur_op;
}
}

// Locate input/outputs: Q, K, V, dst, scale, mask
mm1_ = mm1;
mm2_ = mm2;
if (!mm1 || !mm2 || !final_op) return status::unimplemented;
VCHECK_SDP_PRIMITIVE((mm1 && mm2 && final_op), status::unimplemented,
"Not all ops are found");

q_ = mm1->get_input_value(0);
k_ = mm1->get_input_value(1);
v_ = mm2->get_input_value(1);
Expand Down Expand Up @@ -136,7 +144,8 @@ status_t sdp_primitive_config_t::initial_check(
const std::shared_ptr<subgraph_t> &sg,
const std::vector<logical_tensor_t> &inputs) {
// At least 3 inputs: Q, K, V
if (inputs.size() < 3) return status::invalid_arguments;
VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments,
"At least 3 inputs are required");

// step1(pattern check): Not support sdpa variants with select as mask
// We already have a pattern matcher to ensure that the sdpa patterns
Expand Down Expand Up @@ -175,9 +184,9 @@ status_t sdp_primitive_config_t::initial_check(
mm1 = cur_op;
// Not support select between mm1 and scale(optional)
// GPT-J:[mm1] --> [select] --> [scale]* --> [mask]* --> ...
if (post_op->get_kind() == graph::op_kind::Select) {
return status::unimplemented;
}
VCHECK_SDP_PRIMITIVE(post_op->get_kind() != graph::op_kind::Select,
status::unimplemented,
"Not support select between mm1 and scale(optional)");
// scale
if (post_op->get_kind() == graph::op_kind::Divide
|| post_op->get_kind() == graph::op_kind::Multiply) {
Expand All @@ -193,9 +202,10 @@ status_t sdp_primitive_config_t::initial_check(

// Not support select after scale(optional) and mask(optional)
// Distill-Bert:[mm1] --> [scale]* --> [mask]* --> [select] --> ...
if (post_op->get_kind() == graph::op_kind::Select) {
return status::unimplemented;
}
VCHECK_SDP_PRIMITIVE(post_op->get_kind() != graph::op_kind::Select,
status::unimplemented,
"Not support select after scale(optional) and "
"mask(optional)");
} else {
mm2 = cur_op;
}
Expand All @@ -214,27 +224,29 @@ status_t sdp_primitive_config_t::initial_check(
return -1;
};

if (impl::utils::one_of(nullptr, mm1, mm2)) return status::invalid_graph;
VCHECK_SDP_PRIMITIVE(
mm1 && mm2, status::invalid_graph, "mm1 or mm2 is not found");

// step3(dims check): only support 4-dims now.
int q_id = find_graph_inport(mm1->get_input_value(0));
int k_id = find_graph_inport(mm1->get_input_value(1));
int v_id = find_graph_inport(mm2->get_input_value(1));

bool ok = true;
ok = ok && (q_id != -1) && (k_id != -1) && (v_id != -1);
if (!ok) return status::unimplemented;
ok = ok && ltw(inputs[q_id]).vdims().size() == 4
&& ltw(inputs[k_id]).vdims().size() == 4
&& ltw(inputs[v_id]).vdims().size() == 4;
VCHECK_SDP_PRIMITIVE(q_id != -1 && k_id != -1 && v_id != -1,
status::unimplemented, "Q, K, V are not found");
VCHECK_SDP_PRIMITIVE(ltw(inputs[q_id]).vdims().size() == 4
&& ltw(inputs[k_id]).vdims().size() == 4
&& ltw(inputs[v_id]).vdims().size() == 4,
status::unimplemented, "Q, K, V should be 4-dims");

// sdp_primitive only supports single scale value.
if (scale) {
const auto &s = scale->get_input_value(1)->get_logical_tensor();
if (ltw(s).nelems() != 1) return status::unimplemented;
VCHECK_SDP_PRIMITIVE(ltw(s).nelems() == 1, status::unimplemented,
"Scale should be single value");
}

return ok ? status::success : status::unimplemented;
return status::success;
}

status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,
Expand Down Expand Up @@ -281,14 +293,8 @@ status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,

auto status = sdpa_pd_->create_primitive(sdpa_prim_, p_engine.get());

if (status != status::success) {
if (get_verbose(verbose_t::create_dispatch, component_t::graph)) {
verbose_printf(
"graph,create:dispatch,sdpa,could not create primitive, "
"falling back\n");
}
}

VCONDCHECK(graph, create, dispatch, sdp, status == status::success, status,
"could not create sdp primitive, falling back\n");
return status;
}

Expand Down
Loading