// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "causal_mask_preprocess.h"

#include "common/bfloat16.hpp"
#include "common/cpu_memcpy.h"
#include "cpu/x64/cpu_isa_traits.hpp"
#include "shape_inference/shape_inference_internal_dyn.hpp"
#include "utils/plain_tensor.hpp"

#include <chrono>
#include <string>
#include <vector>

namespace ov {
namespace intel_cpu {
namespace node {

/*
CausalMaskPreprocess:
    inputs:
        0: attention_mask           : i64[N, kv_len]
           0 means masked out, 1 means attended to
        1: batch_size (size_Gather) : i32[1]
        2: cache_positions          : i32[q_len]
        3: kvLen                    : i32[1]
    outputs:
        0: causal mask for SDPA     : f32[batch_size, 1, q_len, kvLen]

The functionality is equivalent to the following Python code:

    ##### preprocess
    min_dtype = torch.finfo(dtype).min
    causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
    causal_mask = causal_mask.to(dtype=dtype, device=device)

    mask_length = attention_mask.shape[-1]
    padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
    causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

    ##### when used, the mask is further sliced:
    causal_mask = attention_mask
    if attention_mask is not None and cache_position is not None:
        causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
*/
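/*
To make the kernel below easier to follow, here is a small worked example with
made-up values (not from the original source). Assume batch_size=1, qLen=2,
kvLen=5, attention_mask=[1, 1, 0, 1] (so mask_length=4) and
cache_positions=[2, 3]; writing M for the lowest value of the output type,
the produced mask is:

    query row 0 (position 2): [0, 0, M, M, M]   // col 2 is padding, cols 3..4 are future
    query row 1 (position 3): [0, 0, M, 0, M]   // col 3 becomes visible, col 4 is past mask_length and still future
*/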
template <typename T>
struct CausalMaskPreprocess::ExecutorCausalMaskPreprocess : public CausalMaskPreprocess::Executor {
    void execute(dnnl::stream strm,
                 intel_cpu::Node* pnode,
                 const intel_cpu::CausalMaskPreprocessNode::Config& config) override {
        ov::intel_cpu::PlainTensor t_attention_mask(pnode->getSrcMemoryAtPort(0));
        ov::intel_cpu::PlainTensor t_batch_size(pnode->getSrcMemoryAtPort(1));
        ov::intel_cpu::PlainTensor t_cache_positions(pnode->getSrcMemoryAtPort(2));
        ov::intel_cpu::PlainTensor t_kvLen(pnode->getSrcMemoryAtPort(3));

        auto mask_length = t_attention_mask.size(-1);
        auto batch_size = static_cast<size_t>(*t_batch_size.ptr<int32_t>(0));
        auto kvLen = static_cast<size_t>(*t_kvLen.ptr<int32_t>(0));
        auto qLen = t_cache_positions.size(0);

        VectorDims newDims{batch_size, 1, qLen, kvLen};
        pnode->redefineOutputMemory({newDims});
        ov::intel_cpu::PlainTensor t_dst(pnode->getDstMemoryAtPort(0));

        DEBUG_LOG("CausalMaskPreprocess::execute ", config.type, " batch_size=", batch_size, " qLen=", qLen, " kvLen=", kvLen);
        DEBUG_LOG("CausalMaskPreprocess::execute attention_mask=", t_attention_mask);
        DEBUG_LOG("CausalMaskPreprocess::execute cache_positions=", t_cache_positions);

        // the raw causal mask is already guaranteed to be upper-triangular (triu) by the transformation
        auto* prow = t_cache_positions.ptr<int32_t>(0);
        T min_dtype = std::numeric_limits<T>::lowest();

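        // Added for clarity: each (batch, query-row) pair is filled independently.
        // Columns covered by attention_mask combine the causal condition with the
        // padding mask; columns in [mask_length, kvLen) are decided by the causal
        // condition alone.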
        parallel_for2d(batch_size, qLen, [&](size_t n, size_t i) {
            auto* pamask = t_attention_mask.ptr<int32_t>(n, 0);
            auto* pdst = t_dst.ptr<T>(n, 0, i);
            auto row = static_cast<size_t>(prow[i]);
            size_t j = 0;
            for (; j < mask_length; j++) {
                bool cmask_eq0 = (j <= row);        // the causal mask allows attending to position j
                bool amask_eq0 = (pamask[j] == 0);  // attention_mask marks position j as padding
                bool padding_mask = (cmask_eq0 && amask_eq0);
                pdst[j] = (padding_mask || !cmask_eq0) ? min_dtype : T(0);
            }
            for (; j < kvLen; j++) {
                bool cmask_eq0 = (j <= row);
                pdst[j] = cmask_eq0 ? T(0) : min_dtype;
            }
        });
        DEBUG_LOG("CausalMaskPreprocess::execute dst=", t_dst);
    }
};

CausalMaskPreprocess::CausalMaskPreprocess(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context)
    : Node(op, context, InternalDynShapeInferFactory()) {
    std::string errorMessage;
    if (!isSupportedOperation(op, errorMessage)) {
        OPENVINO_THROW("CPU: " + errorMessage);
    }

    const auto node = std::dynamic_pointer_cast<const intel_cpu::CausalMaskPreprocessNode>(op);
    m_config = node->get_config();
}

bool CausalMaskPreprocess::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
    try {
        const auto node = std::dynamic_pointer_cast<const intel_cpu::CausalMaskPreprocessNode>(op);
        if (!node) {
            errorMessage = "Only CausalMaskPreprocessNode operation is supported";
            return false;
        }
    } catch (...) {
        return false;
    }
    return true;
}

void CausalMaskPreprocess::initSupportedPrimitiveDescriptors() {
    if (!supportedPrimitiveDescriptors.empty())
        return;

    std::vector<ov::element::Type> iprecs = getOriginalInputPrecisions();
    std::vector<ov::element::Type> oprecs = getOriginalOutputPrecisions();

    // precision preferences
    if (m_config.type == "CausalMaskPreprocess") {
        if (oprecs[0] == ov::element::bf16) {
            m_executor = std::make_shared<ExecutorCausalMaskPreprocess<ov::bfloat16>>();
        } else {
            // fall back to the default precision
            m_executor = std::make_shared<ExecutorCausalMaskPreprocess<float>>();
            oprecs[0] = ov::element::f32;
        }
        // all input precisions must be int32
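        // Note (added for clarity): the model-level attention_mask is i64 (see the
        // header comment), so requesting i32 here assumes the graph converts that
        // input before this node runs.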
        for (auto& prec : iprecs) prec = ov::element::i32;
    } else {
        OPENVINO_THROW("CPU: CausalMaskPreprocess type not supported: " + m_config.type);
    }

    std::vector<PortConfigurator> inPortConfigs;
    for (size_t i = 0; i < getOriginalInputsNumber(); i++)
        inPortConfigs.emplace_back(LayoutType::ncsp, iprecs[i], getInputShapeAtPort(i), false, -1);

    std::vector<PortConfigurator> outPortConfigs;
    for (size_t i = 0; i < getOriginalOutputsNumber(); i++)
        outPortConfigs.emplace_back(LayoutType::ncsp, oprecs[i], getOutputShapeAtPort(i), false, -1);

    addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any);
}

void CausalMaskPreprocess::execute(dnnl::stream strm) {
    m_executor->execute(strm, this, m_config);
}

} // namespace node
} // namespace intel_cpu
} // namespace ov