diff --git a/CMakeLists.txt b/CMakeLists.txt index 24026c6..3841b79 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -107,6 +107,11 @@ if(GPU_RUNTIME STREQUAL "CUDA") endforeach() message("Set CUDA flags: " ${CMAKE_CUDA_FLAGS}) endif() + # Set torch cuda architecture list + set(TORCH_CUDA_ARCH_LIST ${CMAKE_CUDA_ARCHITECTURES}) + list(TRANSFORM TORCH_CUDA_ARCH_LIST REPLACE "([0-9])([0-9])" "\\1.\\2") + string(REPLACE ";" " " TORCH_CUDA_ARCH_LIST "${TORCH_CUDA_ARCH_LIST}") + message(STATUS "** Updated TORCH_CUDA_ARCH_LIST to ${TORCH_CUDA_ARCH_LIST} **") endif() elseif(GPU_RUNTIME STREQUAL "HIP") set(USE_HIP ON CACHE BOOL "Use HIP for GPU acceleration")