From 609997e1b903065a110834d68e2ef09bd4944b45 Mon Sep 17 00:00:00 2001
From: Raju Konda <kraju@nvidia.com>
Date: Wed, 12 Mar 2025 21:55:14 -0700
Subject: [PATCH] Add missing copy of input data from file to memory for
 YUV444P

Also add support for frame count calculations for YUV444P and YUV422P

Signed-off-by: Raju Konda <kraju@nvidia.com>
---
 common/libs/VkCodecUtils/YCbCrConvUtilsCpu.h  | 45 +++++++++++++++++++
 .../libs/VkVideoEncoder/VkEncoderConfig.h     | 17 ++++++-
 .../libs/VkVideoEncoder/VkVideoEncoder.cpp    | 43 ++++++++++++++++--
 3 files changed, 99 insertions(+), 6 deletions(-)

diff --git a/common/libs/VkCodecUtils/YCbCrConvUtilsCpu.h b/common/libs/VkCodecUtils/YCbCrConvUtilsCpu.h
index 4b9cdd6..e70d1a4 100644
--- a/common/libs/VkCodecUtils/YCbCrConvUtilsCpu.h
+++ b/common/libs/VkCodecUtils/YCbCrConvUtilsCpu.h
@@ -240,6 +240,51 @@ class YCbCrConvUtilsCpu
 
         return 0;
     }
+
+    static int I444ToP444(const planeType* src_y,   // source Y plane
+                         int src_stride_y,          // source Y row stride, in bytes
+                         const planeType* src_u,    // source U plane
+                         int src_stride_u,          // source U row stride, in bytes
+                         const planeType* src_v,    // source V plane
+                         int src_stride_v,          // source V row stride, in bytes
+                         planeType* dst_y,          // destination Y plane
+                         int dst_stride_y,          // destination Y row stride, in bytes
+                         planeType* dst_uv,         // destination plane receiving the merged U and V data
+                         int dst_stride_uv,         // destination UV row stride, in bytes
+                         int width,                 // must be > 0
+                         int height,                // must be != 0; a negative value flips the source vertically
+                         int shiftBits = 0) {       // forwarded unchanged to CopyPlane and MergeUVPlane
+        if (!src_y || !src_u || !src_v || !dst_y || !dst_uv || width <= 0 || height == 0) { // null plane, non-positive width or zero height
+            return -1;
+        }
+
+        // Strides arrive in bytes; convert them to planeType element counts so they can be used for pointer arithmetic
+        src_stride_y /= (int)sizeof(planeType);
+        dst_stride_y /= (int)sizeof(planeType);
+        src_stride_u /= (int)sizeof(planeType);
+        src_stride_v /= (int)sizeof(planeType);
+        dst_stride_uv /= (int)sizeof(planeType);
+
+        // A negative height requests image inversion: start at the last source row and step backwards through the source planes
+        if (height < 0) {
+            height = -height;
+            src_y = src_y + (height - 1) * src_stride_y;
+            src_u = src_u + (height - 1) * src_stride_u;
+            src_v = src_v + (height - 1) * src_stride_v;
+            src_stride_y = -src_stride_y;
+            src_stride_u = -src_stride_u;
+            src_stride_v = -src_stride_v;
+        }
+
+        // Copy the Y plane at full resolution (width x height)
+        CopyPlane(src_y, src_stride_y, dst_y, dst_stride_y, width, height, shiftBits);
+
+        // Merge the U and V planes into dst_uv; the full width and height are passed because 4:4:4 chroma is not subsampled
+        MergeUVPlane(src_u, src_stride_u, src_v, src_stride_v, dst_uv, dst_stride_uv,
+                     width, height, shiftBits);
+
+        return 0; // success; -1 is returned only by the argument check above
+    }
 };
 
 #endif /* _VKCODECUTILS_YCBCRCONVUTILSCPU_H_ */
diff --git a/vk_video_encoder/libs/VkVideoEncoder/VkEncoderConfig.h b/vk_video_encoder/libs/VkVideoEncoder/VkEncoderConfig.h
index 7a42afb..f0ccfcc 100644
--- a/vk_video_encoder/libs/VkVideoEncoder/VkEncoderConfig.h
+++ b/vk_video_encoder/libs/VkVideoEncoder/VkEncoderConfig.h
@@ -393,13 +393,23 @@ class EncoderInputFileHandler
 
     uint32_t GetFrameCount(uint32_t width, uint32_t height, uint8_t bpp, VkVideoChromaSubsamplingFlagBitsKHR chromaSubsampling) {
         uint8_t nBytes = (uint8_t)(bpp + 7) / 8;
-        double samplingFactor = 1.5;
+        double samplingFactor = 1.5; // Default for 420
         switch (chromaSubsampling)
         {
+        case VK_VIDEO_CHROMA_SUBSAMPLING_MONOCHROME_BIT_KHR:
+            samplingFactor = 1.0; // Only Y component
+            break;
         case VK_VIDEO_CHROMA_SUBSAMPLING_420_BIT_KHR:
-            samplingFactor = 1.5;
+            samplingFactor = 1.5; // Y + 1/4 U + 1/4 V = 1.5
+            break;
+        case VK_VIDEO_CHROMA_SUBSAMPLING_422_BIT_KHR:
+            samplingFactor = 2.0; // Y + 1/2 U + 1/2 V = 2.0
+            break;
+        case VK_VIDEO_CHROMA_SUBSAMPLING_444_BIT_KHR:
+            samplingFactor = 3.0; // Full Y + full U + full V = 3.0
             break;
         default:
+            assert(!"Unknown chroma subsampling");
             break;
         }
         uint32_t frameSize = (uint32_t)(width * height * nBytes * samplingFactor);
@@ -929,6 +939,9 @@ struct EncoderConfig : public VkVideoRefCountBase {
             return VK_ERROR_INVALID_VIDEO_STD_PARAMETERS_KHR;
         }
 
+        // Copy chroma subsampling from input to encoder config
+        encodeChromaSubsampling = input.chromaSubsampling;
+
         if ((encodeWidth == 0) || (encodeWidth > input.width)) {
             encodeWidth = input.width;
         }
diff --git a/vk_video_encoder/libs/VkVideoEncoder/VkVideoEncoder.cpp b/vk_video_encoder/libs/VkVideoEncoder/VkVideoEncoder.cpp
index 370ede0..ec4045b 100644
--- a/vk_video_encoder/libs/VkVideoEncoder/VkVideoEncoder.cpp
+++ b/vk_video_encoder/libs/VkVideoEncoder/VkVideoEncoder.cpp
@@ -170,8 +170,24 @@ VkResult VkVideoEncoder::LoadNextFrame(VkSharedBaseObj<VkVideoEncodeFrameInfo>&
     int yCbCrConvResult = 0;
     if (m_encoderConfig->input.bpp == 8) {
 
-        // Load current 8-bit frame from file and convert to NV12
-        yCbCrConvResult = YCbCrConvUtilsCpu<uint8_t>::I420ToNV12(
+        if (m_encoderConfig->encodeChromaSubsampling == VK_VIDEO_CHROMA_SUBSAMPLING_444_BIT_KHR) {
+            // Load current 8-bit frame from file and convert to 2-plane YUV444
+            yCbCrConvResult = YCbCrConvUtilsCpu<uint8_t>::I444ToP444(
+                    pInputFrameData + m_encoderConfig->input.planeLayouts[0].offset,         // src_y
+                    (int)m_encoderConfig->input.planeLayouts[0].rowPitch,                    // src_stride_y
+                    pInputFrameData + m_encoderConfig->input.planeLayouts[1].offset,         // src_u
+                    (int)m_encoderConfig->input.planeLayouts[1].rowPitch,                    // src_stride_u
+                    pInputFrameData + m_encoderConfig->input.planeLayouts[2].offset,         // src_v
+                    (int)m_encoderConfig->input.planeLayouts[2].rowPitch,                    // src_stride_v
+                    writeImagePtr + dstSubresourceLayout[0].offset,                          // dst_y
+                    (int)dstSubresourceLayout[0].rowPitch,                                   // dst_stride_y
+                    writeImagePtr + dstSubresourceLayout[1].offset,                          // dst_uv
+                    (int)dstSubresourceLayout[1].rowPitch,                                   // dst_stride_uv
+                    std::min(m_encoderConfig->encodeWidth,  m_encoderConfig->input.width),   // width
+                    std::min(m_encoderConfig->encodeHeight, m_encoderConfig->input.height)); // height
+        } else {
+            // Load current 8-bit frame from file and convert to NV12
+            yCbCrConvResult = YCbCrConvUtilsCpu<uint8_t>::I420ToNV12(
                     pInputFrameData + m_encoderConfig->input.planeLayouts[0].offset,         // src_y,
                     (int)m_encoderConfig->input.planeLayouts[0].rowPitch,                    // src_stride_y,
                     pInputFrameData + m_encoderConfig->input.planeLayouts[1].offset,         // src_u,
@@ -184,6 +200,7 @@ VkResult VkVideoEncoder::LoadNextFrame(VkSharedBaseObj<VkVideoEncodeFrameInfo>&
                     (int)dstSubresourceLayout[1].rowPitch,                                   // dst_stride_uv,
                     std::min(m_encoderConfig->encodeWidth,  m_encoderConfig->input.width),   // width
                     std::min(m_encoderConfig->encodeHeight, m_encoderConfig->input.height)); // height
+        }
 
     } else if (m_encoderConfig->input.bpp == 10) { // 10-bit - actually 16-bit only for now.
 
@@ -194,8 +211,25 @@ VkResult VkVideoEncoder::LoadNextFrame(VkSharedBaseObj<VkVideoEncodeFrameInfo>&
             shiftBits = 16 - m_encoderConfig->input.bpp;
         }
 
-        // Load current 10-bit frame from file and convert to P010/P016
-        yCbCrConvResult = YCbCrConvUtilsCpu<uint16_t>::I420ToNV12(
+        if (m_encoderConfig->encodeChromaSubsampling == VK_VIDEO_CHROMA_SUBSAMPLING_444_BIT_KHR) {
+            // Load current 10-bit frame from file and convert to 2-plane YUV444
+            yCbCrConvResult = YCbCrConvUtilsCpu<uint16_t>::I444ToP444(
+                    (const uint16_t*)(pInputFrameData + m_encoderConfig->input.planeLayouts[0].offset), // src_y
+                    (int)m_encoderConfig->input.planeLayouts[0].rowPitch,                               // src_stride_y
+                    (const uint16_t*)(pInputFrameData + m_encoderConfig->input.planeLayouts[1].offset), // src_u
+                    (int)m_encoderConfig->input.planeLayouts[1].rowPitch,                               // src_stride_u
+                    (const uint16_t*)(pInputFrameData + m_encoderConfig->input.planeLayouts[2].offset), // src_v
+                    (int)m_encoderConfig->input.planeLayouts[2].rowPitch,                               // src_stride_v
+                    (uint16_t*)(writeImagePtr + dstSubresourceLayout[0].offset),                        // dst_y
+                    (int)dstSubresourceLayout[0].rowPitch,                                              // dst_stride_y
+                    (uint16_t*)(writeImagePtr + dstSubresourceLayout[1].offset),                        // dst_uv
+                    (int)dstSubresourceLayout[1].rowPitch,                                              // dst_stride_uv
+                    std::min(m_encoderConfig->encodeWidth,  m_encoderConfig->input.width),              // width
+                    std::min(m_encoderConfig->encodeHeight, m_encoderConfig->input.height),             // height
+                    shiftBits);
+        } else {
+            // Load current 10-bit frame from file and convert to P010/P016
+            yCbCrConvResult = YCbCrConvUtilsCpu<uint16_t>::I420ToNV12(
                     (const uint16_t*)(pInputFrameData + m_encoderConfig->input.planeLayouts[0].offset), // src_y,
                     (int)m_encoderConfig->input.planeLayouts[0].rowPitch,                               // src_stride_y,
                     (const uint16_t*)(pInputFrameData + m_encoderConfig->input.planeLayouts[1].offset), // src_u,
@@ -209,6 +243,7 @@ VkResult VkVideoEncoder::LoadNextFrame(VkSharedBaseObj<VkVideoEncodeFrameInfo>&
                     std::min(m_encoderConfig->encodeWidth,  m_encoderConfig->input.width),              // width
                     std::min(m_encoderConfig->encodeHeight, m_encoderConfig->input.height),             // height
                     shiftBits);
+        }
 
     } else {
         assert(!"Requested bit-depth is not supported!");