From 9bf40d33cab4a0f6a75a97388671c5ca897719e3 Mon Sep 17 00:00:00 2001
From: "Zhang, Rong A"
Date: Mon, 16 Dec 2024 00:40:06 -0800
Subject: [PATCH] graph: backend: kernels: sdp verbose enhancement

---
 src/graph/backend/dnnl/kernels/sdp.hpp        |  8 +++
 .../dnnl/kernels/sdp_decomp_config.cpp        | 56 ++++++++++-------
 .../backend/dnnl/kernels/sdp_primitive.cpp    |  4 +-
 .../dnnl/kernels/sdp_primitive_config.cpp     | 60 ++++++++++---------
 4 files changed, 78 insertions(+), 50 deletions(-)

diff --git a/src/graph/backend/dnnl/kernels/sdp.hpp b/src/graph/backend/dnnl/kernels/sdp.hpp
index 6391d9dbab3..9df511dfdaa 100644
--- a/src/graph/backend/dnnl/kernels/sdp.hpp
+++ b/src/graph/backend/dnnl/kernels/sdp.hpp
@@ -30,6 +30,9 @@
 
 #include "graph/backend/dnnl/dnnl_partition_impl.hpp"
 
+#define VDISPATCH_GRAPH_SDP(msg, ...) \
+    VINFO(graph, create, dispatch, compile, msg, ##__VA_ARGS__)
+
 namespace dnnl {
 namespace impl {
 namespace graph {
@@ -76,6 +79,11 @@ struct sdp_base_t : public kernel_base_t {
             kernel = std::make_shared();
             ret = kernel->compile_impl(part, g_engine, inputs, outputs);
         }
+        if (ret == status::success)
+            VDISPATCH_GRAPH_SDP(
+                    "sdpa is dispatched to (%s)", kernel->str().c_str());
+        else
+            VDISPATCH_GRAPH_SDP("sdpa failed to dispatch");
         return ret;
     }
 
diff --git a/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp b/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp
index f9418f7f9e1..8e49149d2c5 100644
--- a/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp
+++ b/src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp
@@ -16,6 +16,10 @@
 
 #include "graph/backend/dnnl/kernels/sdp_decomp_config.hpp"
 
+#define VCHECK_SDP_DECOMP(cond, status, msg, ...) \
+    VCONDCHECK(graph, create, check, sdp_decomp_kernel_t, (cond), status, msg, \
+            ##__VA_ARGS__);
+
 namespace dnnl {
 namespace impl {
 namespace graph {
@@ -25,10 +29,10 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
         const std::vector<logical_tensor_t> &inputs) {
     // The order of input logical tensors in inputs is not certain, we need
     // to record the input offset in a certain order of ops.
-    auto op_status = record_input_offset(sg, inputs);
-    if (op_status != status::success) return false;
+    CHECK_BOOL(record_input_offset(sg, inputs));
     dims src1_user_dims = ltw(inputs[graph_inport[0]]).vdims();
-    if (src1_user_dims.size() != 4) return false;
+    VCHECK_SDP_DECOMP(src1_user_dims.size() == 4, false,
+            "Input dims should be 4, but got %zu", src1_user_dims.size());
 
     // Initialize SDP input dimension according to the src of mm1
     batch_size = src1_user_dims[0];
@@ -41,14 +45,17 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
 
     // Check batch size compatibility.
     dims wei2_user_dims = ltw(inputs[graph_inport[4]]).vdims();
-    if (batch_size != wei1_user_dims[0] || batch_size != wei2_user_dims[0]) {
-        return false;
-    }
+    VCHECK_SDP_DECOMP(
+            batch_size == wei1_user_dims[0] && batch_size == wei2_user_dims[0],
+            false,
+            "Batch size mismatch, batch_size: %lld, wei1: %lld, wei2: %lld",
+            batch_size, wei1_user_dims[0], wei2_user_dims[0]);
 
     // Check scale size
     if (graph_inport[2] != -1) {
         auto scale_sz = ltw(inputs[graph_inport[2]]).nelems();
-        if (scale_sz != 1) return false;
+        VCHECK_SDP_DECOMP(scale_sz == 1, false,
+                "Only a single scale value is supported, but got %lld",
+                scale_sz);
     }
 
 #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP
@@ -65,10 +72,13 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
 #define RATIO 2
     // Initialize nthr with current threads num
     nthr = dnnl_get_current_num_threads();
-    return batch_size * num_head_q > RATIO * nthr;
-#else
-    return true;
+    VCHECK_SDP_DECOMP(batch_size * num_head_q > RATIO * nthr, false,
+            "Doesn't meet condition for decompose: batch_size * num_head_q "
+            "should be larger than ratio * nthr, but got batch_size %lld, "
+            "num_head_q %lld, ratio %d, nthr %d",
+            batch_size, num_head_q, RATIO, nthr);
 #endif
+    return true;
 }
 
 template
@@ -78,7 +88,7 @@ impl::status_t sdp_decomp_config_t::construct_params(
         const std::vector<logical_tensor_t> &inputs) {
 
     // Record the ops inside of SDP pattern for later usage
-    record_sdp_ops(sg, quantized);
+    CHECK(record_sdp_ops(sg, quantized));
 
     // Update SDPA input params. Sequence length for query and key/value are
     // NOT always same.
@@ -435,11 +445,12 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
             graph::op_kind::SoftMax};
     for (const auto &cur_op : sg->get_ops()) {
         const auto &op_kind = cur_op->get_kind();
-        if (op_kind == graph::op_kind::DynamicDequantize
-                && cur_op->get_attr<std::string>(op_attr::qtype)
-                        == "per_group") {
-            return status::unimplemented;
-        }
+        VCHECK_SDP_DECOMP(
+                !(op_kind == graph::op_kind::DynamicDequantize
+                        && cur_op->get_attr<std::string>(op_attr::qtype)
+                                == "per_group"),
+                status::unimplemented,
+                "per_group DynamicDequantize is not supported");
         // both mm1 and mm2 are found.
         if (mm1 && mm2) break;
         if (op_kind != graph::op_kind::MatMul) continue;
@@ -451,9 +462,9 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
             // TODO(xxx): Currently, p2 is not supported by decomp kernel.
             // p1: [matmul] --> [scale] --> [select] --> [mask] --> ...
             // p2: [matmul] --> [select] --> [scale] --> [mask] --> ...
-            if (post_op->get_kind() == graph::op_kind::Select) {
-                return status::unimplemented;
-            }
+            VCHECK_SDP_DECOMP(post_op->get_kind() != graph::op_kind::Select,
+                    status::unimplemented,
+                    "Select between matmul1 and scale is not supported");
             // find scale
             if (post_op->get_kind() == graph::op_kind::Divide
                     || post_op->get_kind() == graph::op_kind::Multiply) {
@@ -478,8 +489,8 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
             mm2 = cur_op;
         }
     }
-    if (impl::utils::one_of(nullptr, mm1, mm2)) return status::invalid_graph;
-
+    VCHECK_SDP_DECOMP(mm1 != nullptr && mm2 != nullptr, status::invalid_graph,
+            "Failed to find matmul1 or matmul2");
     int src1_id = find_graph_inport(mm1->get_input_value(0));
     graph_inport.emplace_back(src1_id);
     int wei1_id = find_graph_inport(mm1->get_input_value(1));
@@ -534,7 +545,8 @@ impl::status_t sdp_decomp_config_t::record_sdp_ops(
         auto post_op = get_post_op(cur_op);
         if (!post_op || post_op->get_kind() != op_kind::dnnl_softmax) continue;
         auto ppost_op = get_post_op(post_op);
-        if (!ppost_op) return status::invalid_graph;
+        VCHECK_SDP_DECOMP(ppost_op != nullptr, status::invalid_graph,
+                "Failed to find the post-post op for matmul");
 
         op_ptr reorder1;
         op_ptr reorder2;
diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive.cpp b/src/graph/backend/dnnl/kernels/sdp_primitive.cpp
index f98d9bbedc9..1ae50306d26 100644
--- a/src/graph/backend/dnnl/kernels/sdp_primitive.cpp
+++ b/src/graph/backend/dnnl/kernels/sdp_primitive.cpp
@@ -181,7 +181,9 @@ status_t sdp_primitive_kernel_t::get_prim_exec_args(
             && res->find_value_mem_map(
                     cfg_.v_zero_points_.get(), mem_storage[9]);
 
-    if (!ok) return status::runtime_error;
+    VCONDCHECK(graph, exec, check, sdp_primitive_kernel, ok,
+            status::runtime_error,
+            "sdp_primitive_kernel get_prim_exec_args failed");
 
     memory_arg_t mem_arg_q = {mem_storage[0].get(), true};
     memory_arg_t mem_arg_k = {mem_storage[1].get(), true};
diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp
index 2897f45bacc..6ebf5d76498 100644
--- a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp
+++ b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp
@@ -19,6 +19,10 @@
 
 #include "common/compiler_workarounds.hpp"
 
+#define VCHECK_SDP_PRIMITIVE(cond, status, msg, ...) \
+    VCONDCHECK(graph, create, check, sdp_primitive_kernel_t, (cond), status, \
+            msg, ##__VA_ARGS__);
+
 namespace dnnl {
 namespace impl {
 namespace graph {
@@ -63,7 +67,8 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
         if (post_op && mm1_post_op_kind.count(post_op->get_kind())) {
             // Locate mm1 and all post ops(scale and mask) here.
             // 1. locate mm1
-            if (mm1) return status::unimplemented;
+            VCHECK_SDP_PRIMITIVE(mm1 == nullptr, status::unimplemented,
+                    "Multiple mm1 found");
             mm1 = cur_op;
             // At least one of scale and mask exists
             if (post_op->get_kind() == op_kind::dnnl_binary) {
@@ -84,7 +89,8 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
                 }
             }
         } else {
-            if (mm2) return status::unimplemented;
+            VCHECK_SDP_PRIMITIVE(mm2 == nullptr, status::unimplemented,
+                    "Multiple mm2 found");
             mm2 = cur_op;
         }
     }
@@ -92,7 +98,9 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
     // Locate input/outputs: Q, K, V, dst, scale, mask
     mm1_ = mm1;
     mm2_ = mm2;
-    if (!mm1 || !mm2 || !final_op) return status::unimplemented;
+    VCHECK_SDP_PRIMITIVE((mm1 && mm2 && final_op), status::unimplemented,
+            "Not all SDPA ops (mm1, mm2, final op) are found");
+
     q_ = mm1->get_input_value(0);
     k_ = mm1->get_input_value(1);
     v_ = mm2->get_input_value(1);
@@ -136,7 +144,8 @@ status_t sdp_primitive_config_t::initial_check(
         const std::shared_ptr<subgraph_t> &sg,
         const std::vector<logical_tensor_t> &inputs) {
     // At least 3 inputs: Q, K, V
-    if (inputs.size() < 3) return status::invalid_arguments;
+    VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments,
+            "At least 3 inputs are required");
 
     // step1(pattern check): Not support sdpa variants with select as mask
     // We already have a pattern matcher to ensure that the sdpa patterns
@@ -175,9 +184,9 @@ status_t sdp_primitive_config_t::initial_check(
             mm1 = cur_op;
             // Not support select between mm1 and scale(optional)
            // GPT-J:[mm1] --> [select] --> [scale]* --> [mask]* --> ...
-            if (post_op->get_kind() == graph::op_kind::Select) {
-                return status::unimplemented;
-            }
+            VCHECK_SDP_PRIMITIVE(post_op->get_kind() != graph::op_kind::Select,
+                    status::unimplemented,
+                    "Select between mm1 and scale(optional) is not supported");
             // scale
             if (post_op->get_kind() == graph::op_kind::Divide
                     || post_op->get_kind() == graph::op_kind::Multiply) {
@@ -193,9 +202,10 @@ status_t sdp_primitive_config_t::initial_check(
             // Not support select after scale(optional) and mask(optional)
             // Distill-Bert:[mm1] --> [scale]* --> [mask]* --> [select] --> ...
-            if (post_op->get_kind() == graph::op_kind::Select) {
-                return status::unimplemented;
-            }
+            VCHECK_SDP_PRIMITIVE(post_op->get_kind() != graph::op_kind::Select,
+                    status::unimplemented,
+                    "Select after scale(optional) and mask(optional) is not "
+                    "supported");
         } else {
             mm2 = cur_op;
         }
     }
@@ -214,27 +224,29 @@ status_t sdp_primitive_config_t::initial_check(
         return -1;
     };
 
-    if (impl::utils::one_of(nullptr, mm1, mm2)) return status::invalid_graph;
+    VCHECK_SDP_PRIMITIVE(
+            mm1 && mm2, status::invalid_graph, "mm1 or mm2 is not found");
 
     // step3(dims check): only support 4-dims now.
     int q_id = find_graph_inport(mm1->get_input_value(0));
     int k_id = find_graph_inport(mm1->get_input_value(1));
     int v_id = find_graph_inport(mm2->get_input_value(1));
 
-    bool ok = true;
-    ok = ok && (q_id != -1) && (k_id != -1) && (v_id != -1);
-    if (!ok) return status::unimplemented;
-    ok = ok && ltw(inputs[q_id]).vdims().size() == 4
-            && ltw(inputs[k_id]).vdims().size() == 4
-            && ltw(inputs[v_id]).vdims().size() == 4;
+    VCHECK_SDP_PRIMITIVE(q_id != -1 && k_id != -1 && v_id != -1,
+            status::unimplemented, "Q, K, or V input is not found");
+    VCHECK_SDP_PRIMITIVE(ltw(inputs[q_id]).vdims().size() == 4
+                    && ltw(inputs[k_id]).vdims().size() == 4
+                    && ltw(inputs[v_id]).vdims().size() == 4,
+            status::unimplemented, "Q, K, V should be 4-dimensional");
 
     // sdp_primitive only supports single scale value.
     if (scale) {
         const auto &s = scale->get_input_value(1)->get_logical_tensor();
-        if (ltw(s).nelems() != 1) return status::unimplemented;
+        VCHECK_SDP_PRIMITIVE(ltw(s).nelems() == 1, status::unimplemented,
+                "Scale should be a single value");
     }
 
-    return ok ? status::success : status::unimplemented;
+    return status::success;
 }
 
 status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,
@@ -281,14 +293,8 @@ status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,
 
     auto status = sdpa_pd_->create_primitive(sdpa_prim_, p_engine.get());
 
-    if (status != status::success) {
-        if (get_verbose(verbose_t::create_dispatch, component_t::graph)) {
-            verbose_printf(
-                    "graph,create:dispatch,sdpa,could not create primitive, "
-                    "falling back\n");
-        }
-    }
-
+    VCONDCHECK(graph, create, dispatch, sdp, status == status::success, status,
+            "could not create sdp primitive, falling back");
     return status;
 }
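Aside from the patch itself: the change replaces bare early returns with "check, log verbosely, return status" macros. Below is a minimal, self-contained sketch of that pattern under the assumption of a plain stderr logger; it is not oneDNN's actual VCONDCHECK/VINFO implementation, and SKETCH_VCHECK, status_t, and check_inputs are illustrative names only.

// Sketch of a condition-check macro that logs a formatted message and
// returns early with a status when the condition does not hold.
#include <cstdio>

enum class status_t { success, invalid_arguments, unimplemented };

#define SKETCH_VCHECK(cond, ret, msg, ...) \
    do { \
        if (!(cond)) { \
            /* illustrative log prefix; the real verbose line format differs */ \
            std::fprintf(stderr, "verbose,graph,create:check," msg "\n", \
                    ##__VA_ARGS__); \
            return (ret); \
        } \
    } while (0)

// Example use mirroring the input-count check added to
// sdp_primitive_config_t::initial_check() in this patch.
status_t check_inputs(size_t n_inputs) {
    SKETCH_VCHECK(n_inputs >= 3, status_t::invalid_arguments,
            "at least 3 inputs are required, got %zu", n_inputs);
    return status_t::success;
}

The benefit over the previous `if (...) return ...;` style is that every rejected dispatch leaves a reason in the verbose log instead of failing silently.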