blob: b74ce4c8dfe042dcbea82bbceeea449d2e051164 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Configure CUTLASS support when both USE_CUDA and USE_CUTLASS are enabled.
#
# Effects, all in the scope that includes this file:
#   - appends the Relax CUTLASS backend sources to COMPILER_SRCS;
#   - adds the 3rdparty cutlass_fpA_intB_gemm and libflash_attn subprojects;
#   - creates object libraries for TVM's own CUTLASS runtime sources;
#   - appends those objects (wrapped in CUTLASS_GEN_COND) to TVM_RUNTIME_EXT_OBJS.
if(USE_CUDA AND USE_CUTLASS)
  # Guard applied to every runtime object list below. Inside this if() it
  # always evaluates true; it is kept as an explicit genex condition.
  set(CUTLASS_GEN_COND "$<AND:$<BOOL:${USE_CUDA}>,$<BOOL:${USE_CUTLASS}>>")
  set(CUTLASS_RUNTIME_OBJS "")

  # Relax backend (codegen) sources for CUTLASS go into the compiler library.
  tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC
    src/relax/backend/contrib/cutlass/*.cc
  )
  list(APPEND COMPILER_SRCS ${CUTLASS_CONTRIB_SRC})

  # Set before add_subdirectory() so the subproject can see them
  # (presumably read by 3rdparty/cutlass_fpA_intB_gemm/CMakeLists.txt).
  set(FPA_INTB_GEMM_TVM_BINDING ON)
  set(FPA_INTB_GEMM_TVM_HOME ${PROJECT_SOURCE_DIR})

  ### Build cutlass runtime objects for fpA_intB_gemm using its cutlass submodule
  add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm)
  target_include_directories(fpA_intB_gemm PRIVATE
    ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm
    ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include
  )
  # NOTE(review): fpA_intB_gemm_tvm is assumed to be defined by the subproject
  # when FPA_INTB_GEMM_TVM_BINDING is ON — confirm against its CMakeLists.txt.
  target_link_libraries(fpA_intB_gemm_tvm PRIVATE tvm_ffi_header)

  # TVM-side weight preprocessing for fpA_intB_gemm, built as an object library.
  set(CUTLASS_FPA_INTB_RUNTIME_SRCS "")
  list(APPEND CUTLASS_FPA_INTB_RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
  add_library(fpA_intB_cutlass_objs OBJECT ${CUTLASS_FPA_INTB_RUNTIME_SRCS})
  target_link_libraries(fpA_intB_cutlass_objs PRIVATE tvm_ffi_header)
  target_compile_definitions(fpA_intB_cutlass_objs PRIVATE DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
  target_include_directories(fpA_intB_cutlass_objs PRIVATE
    ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm
    ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include
  )
  list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:fpA_intB_cutlass_objs>>")

  ### Build cutlass runtime objects for flash attention
  add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn)
  target_include_directories(flash_attn PRIVATE
    ${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn
    ${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn/cutlass/include
  )

  ### Build cutlass runtime objects using TVM's 3rdparty/cutlass submodule
  set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
  set(TVM_CUTLASS_RUNTIME_SRCS "")
  # Kernels are only added when the matching arch-specific target ("90a" /
  # "100a") appears in CMAKE_CUDA_ARCHITECTURES. MATCHES (not IN_LIST) also
  # accepts suffixed entries such as "90a-real".
  if("${CMAKE_CUDA_ARCHITECTURES}" MATCHES "90a")
    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu)
    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu)
    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_gemm.cu)
    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu)
  endif()
  if("${CMAKE_CUDA_ARCHITECTURES}" MATCHES "100a")
    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu)
    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu)
    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu)
  endif()
  # Only create the object library when at least one kernel source was selected.
  if(TVM_CUTLASS_RUNTIME_SRCS)
    add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
    # Quoted so the genex stays one argument; ';' separates the two CUDA-only
    # flags (an unquoted space would split the genex into malformed halves).
    target_compile_options(tvm_cutlass_objs PRIVATE
      "$<$<COMPILE_LANGUAGE:CUDA>:-lineinfo;--expt-relaxed-constexpr>"
    )
    target_include_directories(tvm_cutlass_objs PRIVATE
      ${CUTLASS_DIR}/include
      ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass_extensions/include
    )
    target_link_libraries(tvm_cutlass_objs PRIVATE tvm_ffi_header)
    target_compile_definitions(tvm_cutlass_objs PRIVATE DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
    # Note: enable this to get more detailed logs for cutlass kernels
    # target_compile_definitions(tvm_cutlass_objs PRIVATE CUTLASS_DEBUG_TRACE_LEVEL=2)
    list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:tvm_cutlass_objs>>")
  endif()

  ### Add cutlass objects to list of TVM runtime extension objs
  list(APPEND TVM_RUNTIME_EXT_OBJS "${CUTLASS_RUNTIME_OBJS}")
  message(STATUS "Build with CUTLASS")
endif()