Commit 390d2035 by Ting PAN

Add support to link CuDNN 8.3.x statically

Summary:
This commit adds cmake modules to handle linked libraries.
1 parent a79a3bba
# ---[ Protobuf
file(GLOB PROTO_FILES ${PROJECT_SOURCE_DIR}/proto/*.proto)
protobuf_generate_cpp(${PROTO_FILES})
# ---[ Runtime
if (PYTHON_EXECUTABLE AND BUILD_RUNTIME)
set(HAS_RUNTIME_CODEGEN ON)
execute_process(
COMMAND
${PYTHON_EXECUTABLE}
${PROJECT_SOURCE_DIR}/../tools/codegen_runtime.py
${PROJECT_SOURCE_DIR} "REMOVE_GRADIENT")
else()
set(HAS_RUNTIME_CODEGEN OFF)
endif()
...@@ -13,16 +13,13 @@ if (BUILD_PYTHON) ...@@ -13,16 +13,13 @@ if (BUILD_PYTHON)
include(${PROJECT_SOURCE_DIR}/../cmake/FindNumPy.cmake) include(${PROJECT_SOURCE_DIR}/../cmake/FindNumPy.cmake)
endif() endif()
if (USE_CUDA) if (USE_CUDA)
find_package(CUDA REQUIRED) include(${PROJECT_SOURCE_DIR}/../cmake/FindCUDA.cmake)
include(${PROJECT_SOURCE_DIR}/../cmake/SelectCudaArch.cmake) endif()
CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH ${CUDA_ARCH}) if (USE_CUDNN)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CUDA_ARCH}") include(${PROJECT_SOURCE_DIR}/../cmake/FindCUDNN.cmake)
if (MSVC) endif()
# Suppress all warnings for msvc compiler if (USE_MPI)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -w") include(${PROJECT_SOURCE_DIR}/../cmake/FindMPI.cmake)
else()
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -std=c++14")
endif()
endif() endif()
if (USE_TENSORRT) if (USE_TENSORRT)
if (NOT TENSORRT_SDK_ROOT_DIR) if (NOT TENSORRT_SDK_ROOT_DIR)
...@@ -30,10 +27,11 @@ if (USE_TENSORRT) ...@@ -30,10 +27,11 @@ if (USE_TENSORRT)
endif() endif()
endif() endif()
# ---[ Include directories # ---[ Directories
include_directories(${PROJECT_SOURCE_DIR}/../) include_directories(${PROJECT_SOURCE_DIR}/../)
include_directories(${THIRD_PARTY_DIR}/eigen) include_directories(${THIRD_PARTY_DIR}/eigen)
include_directories(${PROTOBUF_SDK_ROOT_DIR}/include) include_directories(${PROTOBUF_SDK_ROOT_DIR}/include)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${PROTOBUF_SDK_ROOT_DIR}/lib)
if(APPLE) if(APPLE)
include_directories(/usr/local/include) include_directories(/usr/local/include)
endif() endif()
...@@ -43,36 +41,18 @@ if (BUILD_PYTHON) ...@@ -43,36 +41,18 @@ if (BUILD_PYTHON)
include_directories(${THIRD_PARTY_DIR}/pybind11/include) include_directories(${THIRD_PARTY_DIR}/pybind11/include)
endif() endif()
if (USE_CUDA) if (USE_CUDA)
include_directories(${CUDA_INCLUDE_DIRS}) include_directories(${CUDA_INCLUDE_DIR})
include_directories(${THIRD_PARTY_DIR}/cub)
endif() endif()
if (USE_CUDNN) if (USE_CUDNN)
include_directories(${THIRD_PARTY_DIR}/cudnn/include) include_directories(${CUDNN_INCLUDE_DIR})
endif() endif()
if (USE_MPI) if (USE_MPI)
include_directories(${THIRD_PARTY_DIR}/mpi/include) include_directories(${MPI_INCLUDE_DIR})
endif() endif()
if (USE_TENSORRT) if (USE_TENSORRT)
include_directories(${TENSORRT_SDK_ROOT_DIR}/include) include_directories(${TENSORRT_SDK_ROOT_DIR}/include)
endif()
# ---[ Library directories
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${PROTOBUF_SDK_ROOT_DIR}/lib)
if (USE_MPI)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/mpi/lib)
endif()
if (USE_CUDNN)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/cudnn/lib)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/cudnn/lib64)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/cudnn/lib/x64)
endif()
if (USE_TENSORRT)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${TENSORRT_SDK_ROOT_DIR}/lib) list(APPEND THIRD_PARTY_LIBRARY_DIRS ${TENSORRT_SDK_ROOT_DIR}/lib)
endif() endif()
if (USE_CUDA)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
endif()
# ---[ Defines # ---[ Defines
if (BUILD_PYTHON) if (BUILD_PYTHON)
......
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
# Licensed under the BSD 2-Clause License.
# - Find the CUDA libraries
#
# Following variables can be set and are optional:
#
# CUDA_VERSION - version of the CUDA
# CUDA_VERSION_MAJOR - the major version number of CUDA
# CUDA_VERSION_MINOR - the minor version number of CUDA
# CUDA_TOOLKIT_ROOT_DIR - path to the CUDA toolkit
# CUDA_INCLUDE_DIR - path to the CUDA headers
# CUDA_LIBRARIES_SHARED - path to the CUDA shared library
# CUDA_LIBRARIES_STATIC - path to the CUDA static library
#
find_package(CUDA REQUIRED)
include(${PROJECT_SOURCE_DIR}/../cmake/SelectCudaArch.cmake)
# Set NVCC flags.
CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH ${CUDA_ARCH})
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CUDA_ARCH}")
if (MSVC)
# Suppress all warnings for msvc compiler.
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -w")
else()
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -std=c++14")
endif()
# Set include directory.
set(CUDA_INCLUDE_DIR ${CUDA_INCLUDE_DIRS} ${THIRD_PARTY_DIR}/cub)
# Set libraries.
if (EXISTS "${CUDA_TOOLKIT_ROOT_DIR}/lib64")
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
elseif (EXISTS "${CUDA_TOOLKIT_ROOT_DIR}/lib/x64")
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
endif()
set(CUDA_LIBRARIES_SHARED cudart cublas curand)
set(CUDA_LIBRARIES_STATIC culibos cudart_static cublas_static curand_static)
if (CUDA_VERSION VERSION_GREATER "10.0")
set(CUDA_LIBRARIES_STATIC ${CUDA_LIBRARIES_STATIC} cublasLt_static)
endif()
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
# Licensed under the BSD 2-Clause License.
# - Find the CUDNN libraries
#
# Following variables can be set and are optional:
#
# CUDNN_VERSION - version of the CUDNN
# CUDNN_VERSION_MAJOR - the major version number of CUDNN
# CUDNN_VERSION_MINOR - the minor version number of CUDNN
# CUDNN_VERSION_PATCH - the patch version number of CUDNN
# CUDNN_INCLUDE_DIR - path to the CUDNN headers
# CUDNN_LIBRARIES_SHARED - path to the CUDNN shared library
# CUDNN_LIBRARIES_STATIC - path to the CUDNN static library
#
# Set include directory.
if (EXISTS "${CUDA_INCLUDE_DIRS}/cudnn.h")
set(CUDNN_INCLUDE_DIR ${CUDA_INCLUDE_DIRS})
elseif (EXISTS "${THIRD_PARTY_DIR}/cudnn/include/cudnn.h")
set(CUDNN_INCLUDE_DIR ${THIRD_PARTY_DIR}/cudnn/include)
endif()
# Set version.
if (CUDNN_INCLUDE_DIR)
set(_file ${CUDNN_INCLUDE_DIR}/cudnn.h)
if (EXISTS "${CUDNN_INCLUDE_DIR}/cudnn_version.h")
set(_file ${CUDNN_INCLUDE_DIR}/cudnn_version.h)
endif()
file(READ ${_file} tmp)
string(REGEX MATCH "define CUDNN_MAJOR * +([0-9]+)" _major "${tmp}")
string(REGEX REPLACE "define CUDNN_MAJOR * +([0-9]+)" "\\1" _major "${_major}")
string(REGEX MATCH "define CUDNN_MINOR * +([0-9]+)" _minor "${tmp}")
string(REGEX REPLACE "define CUDNN_MINOR * +([0-9]+)" "\\1" _minor "${_minor}")
string(REGEX MATCH "define CUDNN_PATCHLEVEL * +([0-9]+)" _patch "${tmp}")
string(REGEX REPLACE "define CUDNN_PATCHLEVEL * +([0-9]+)" "\\1" _patch "${_patch}")
set(CUDNN_VERSION ${_major}.${_minor}.${_patch})
set(CUDNN_VERSION_MAJOR ${_major})
set(CUDNN_VERSION_MINOR ${_minor})
set(CUDNN_VERSION_PATCH ${_patch})
endif()
# Check version.
if (CUDNN_VERSION VERSION_LESS "7.0.0")
message(FATAL_ERROR "CuDNN ${CUDNN_VERSION} is not supported. (Required: >= 7.0.0)")
else()
get_filename_component(_dir "${CUDNN_INCLUDE_DIR}" ABSOLUTE)
message(STATUS "Found CUDNN: ${_dir} (found version \"${CUDNN_VERSION}\")")
endif()
# Set libraries.
if (EXISTS "${CUDNN_INCLUDE_DIR}/../lib64")
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDNN_INCLUDE_DIR}/../lib64)
elseif (EXISTS "${CUDNN_INCLUDE_DIR}/../lib/x64")
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDNN_INCLUDE_DIR}/../lib/x64)
elseif (EXISTS "${CUDNN_INCLUDE_DIR}/../lib")
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${CUDNN_INCLUDE_DIR}/../lib)
endif()
set(CUDNN_LIBRARIES_SHARED cudnn)
set(CUDNN_LIBRARIES_STATIC cudnn_static)
if (CUDNN_VERSION VERSION_GREATER "8.2.4")
set(CUDNN_LIBRARIES_STATIC cudnn_adv_infer_static cudnn_adv_train_static
cudnn_cnn_infer_static cudnn_cnn_train_static
cudnn_ops_infer_static cudnn_ops_train_static)
endif()
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
# Licensed under the BSD 2-Clause License.
# Following variables can be set and are optional:
#
# MPI_INCLUDE_DIR - path to the MPI headers
# MPI_LIBRARIES - path to the MPI library
# MPI_LIBRARIES_SHARED - path to the MPI shared library
# MPI_LIBRARIES_STATIC - path to the MPI static library
#
# Set include directory.
set(MPI_INCLUDE_DIR ${THIRD_PARTY_DIR}/mpi/include)
# Set libraries.
if (EXISTS "${THIRD_PARTY_DIR}/mpi/lib")
list(APPEND THIRD_PARTY_LIBRARY_DIRS ${THIRD_PARTY_DIR}/mpi/lib)
endif()
set(MPI_LIBRARIES z)
if (UNIX AND (NOT APPLE))
set(MPI_LIBRARIES ${MPI_LIBRARIES} udev)
endif()
set(MPI_LIBRARIES_SHARED mpi)
set(MPI_LIBRARIES_STATIC mpi open-rte open-pal)
# - Find the NumPy libraries # - Find the NumPy libraries
# This module finds if NumPy is installed, and sets the following variables # This module finds if NumPy is installed, and sets Following variables
# indicating where it is. # indicating where it is.
# #
# TODO: Update to provide the libraries and paths for linking npymath lib. # TODO: Update to provide the libraries and paths for linking npymath lib.
......
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
# - Find the protobuf libraries # - Find the protobuf libraries
# #
# The Following variables can be set and are optional: # Following variables can be set and are optional:
# #
# PROTOBUF_SDK_ROOT_DIR - The root dir of protobuf sdk # PROTOBUF_SDK_ROOT_DIR - The root dir of protobuf sdk
# PROTOBUF_PROTOC_EXECUTABLE - The protoc compiler # PROTOBUF_PROTOC_EXECUTABLE - The protoc compiler
# #
# The following function are defined: # Following function are defined:
# #
# protobuf_generate_cpp(<proto_file>...) - Process the proto to C++ sources # protobuf_generate_cpp(<proto_file>...) - Process the proto to C++ sources
# protobuf_generate_lite(<proto_file>...) - Process the proto to Lite sources # protobuf_generate_lite(<proto_file>...) - Process the proto to Lite sources
......
# - Find python libraries # - Find python libraries
# This module finds the libraries corresponding to the Python interpreter # This module finds the libraries corresponding to the Python interpreter
# FindPythonInterp provides. # FindPythonInterp provides.
# This code sets the following variables: # This code sets Following variables:
# #
# PYTHONLIBS_FOUND - have the Python libs been found # PYTHONLIBS_FOUND - have the Python libs been found
# PYTHON_PREFIX - path to the Python installation # PYTHON_PREFIX - path to the Python installation
...@@ -22,14 +22,14 @@ ...@@ -22,14 +22,14 @@
# All rights reserved. # All rights reserved.
# #
# Redistribution and use in source and binary forms, with or without # Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions # modification, are permitted provided that Following conditions
# are met: # are met:
# #
# * Redistributions of source code must retain the above copyright # * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer. # notice, this list of conditions and Following disclaimer.
# #
# * Redistributions in binary form must reproduce the above copyright # * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the # notice, this list of conditions and Following disclaimer in the
# documentation and/or other materials provided with the distribution. # documentation and/or other materials provided with the distribution.
# #
# * Neither the names of Kitware, Inc., the Insight Software Consortium, # * Neither the names of Kitware, Inc., the Insight Software Consortium,
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# - Link the libraries according to the project suffix hint # - Link the libraries according to the project suffix hint
# #
# The following function are defined: # Following function are defined:
# #
# target_link_libraries_v2(<target> <item>...]) - Link the libraries to target # target_link_libraries_v2(<target> <item>...]) - Link the libraries to target
# target_get_libraries(<out_variable> <target> [<keyword>...]) - Query the libraries of target # target_get_libraries(<out_variable> <target> [<keyword>...]) - Query the libraries of target
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# - Strip debug information from a c/c++ target # - Strip debug information from a c/c++ target
# #
# The following function are defined: # Following function are defined:
# #
# strip_debug_symbol(<target>) - Strip the debug symbol in the target # strip_debug_symbol(<target>) - Strip the debug symbol in the target
# #
......
...@@ -58,6 +58,7 @@ include(${PROJECT_SOURCE_DIR}/../cmake/MiscCheck.cmake) ...@@ -58,6 +58,7 @@ include(${PROJECT_SOURCE_DIR}/../cmake/MiscCheck.cmake)
include(${PROJECT_SOURCE_DIR}/../cmake/LinkLibrary.cmake) include(${PROJECT_SOURCE_DIR}/../cmake/LinkLibrary.cmake)
include(${PROJECT_SOURCE_DIR}/../cmake/StripDebugInfo.cmake) include(${PROJECT_SOURCE_DIR}/../cmake/StripDebugInfo.cmake)
include(${PROJECT_SOURCE_DIR}/../cmake/Dependencies.cmake) include(${PROJECT_SOURCE_DIR}/../cmake/Dependencies.cmake)
include(${PROJECT_SOURCE_DIR}/../cmake/Codegen.cmake)
# ---[ CMake settings # ---[ CMake settings
set(CMAKE_BUILD_TYPE Release CACHE INTERNAL "" FORCE) set(CMAKE_BUILD_TYPE Release CACHE INTERNAL "" FORCE)
...@@ -70,22 +71,6 @@ if (NOT LIBRARY_INSTALL_PREFIX) ...@@ -70,22 +71,6 @@ if (NOT LIBRARY_INSTALL_PREFIX)
set(LIBRARY_INSTALL_PREFIX "") set(LIBRARY_INSTALL_PREFIX "")
endif() endif()
# ---[ Command - Protobuf
file(GLOB PROTO_FILES ${PROJECT_SOURCE_DIR}/proto/*.proto)
protobuf_generate_cpp(${PROTO_FILES})
# ---[ Command - CodeGen
if (PYTHON_EXECUTABLE AND BUILD_RUNTIME)
set(HAS_RUNTIME_CODEGEN ON)
execute_process(
COMMAND
${PYTHON_EXECUTABLE}
${PROJECT_SOURCE_DIR}/../tools/codegen_runtime.py
${PROJECT_SOURCE_DIR} "REMOVE_GRADIENT")
else()
set(HAS_RUNTIME_CODEGEN OFF)
endif()
# ---[ Subdirectories # ---[ Subdirectories
if (BUILD_PYTHON) if (BUILD_PYTHON)
add_subdirectory(modules/python) add_subdirectory(modules/python)
......
...@@ -55,23 +55,18 @@ if (USE_OPENMP) ...@@ -55,23 +55,18 @@ if (USE_OPENMP)
endif() endif()
if (USE_CUDA) if (USE_CUDA)
if (USE_SHARED_LIBS) if (USE_SHARED_LIBS)
target_link_libraries_v2(dragon cudart) target_link_libraries_v2(dragon ${CUDA_LIBRARIES_SHARED})
target_link_libraries_v2(dragon cublas)
target_link_libraries_v2(dragon curand)
else() else()
target_link_libraries_v2(dragon cudart_static) target_link_libraries_v2(dragon ${CUDA_LIBRARIES_STATIC})
target_link_libraries_v2(dragon cublas_static)
target_link_libraries_v2(dragon curand_static)
if (CUDA_VERSION VERSION_GREATER "10.0")
target_link_libraries_v2(dragon cublasLt_static)
endif()
endif() endif()
endif() endif()
if (USE_CUDNN) if (USE_CUDNN)
if (USE_SHARED_LIBS) if (USE_SHARED_LIBS)
target_link_libraries_v2(dragon cudnn) target_link_libraries_v2(dragon ${CUDNN_LIBRARIES_SHARED})
else() else()
target_link_libraries_v2(dragon_python -Wl,--whole-archive cudnn_static -Wl,--no-whole-archive) target_link_libraries_v2(
dragon_python -Wl,--whole-archive
${CUDNN_LIBRARIES_STATIC} -Wl,--no-whole-archive)
endif() endif()
endif() endif()
if (USE_NCCL) if (USE_NCCL)
...@@ -81,25 +76,15 @@ if (USE_NCCL) ...@@ -81,25 +76,15 @@ if (USE_NCCL)
target_link_libraries_v2(dragon nccl_static) target_link_libraries_v2(dragon nccl_static)
endif() endif()
endif() endif()
if (USE_CUDA AND (NOT USE_SHARED_LIBS))
target_link_libraries_v2(dragon culibos)
endif()
if (USE_MPI) if (USE_MPI)
target_link_libraries(dragon ${MPI_LIBRARIES})
if (USE_SHARED_LIBS) if (USE_SHARED_LIBS)
target_link_libraries_v2(dragon mpi) target_link_libraries_v2(dragon ${MPI_LIBRARIES_SHARED})
else()
target_link_libraries_v2(dragon mpi open-rte open-pal)
if (UNIX)
if (APPLE)
target_link_libraries(dragon z)
else() else()
target_link_libraries(dragon z udev) target_link_libraries_v2(dragon ${MPI_LIBRARIES_STATIC})
endif()
endif()
endif() endif()
endif() endif()
if(WIN32) if (WIN32)
target_link_libraries(dragon ${PYTHON_LIBRARIES})
target_link_libraries(dragon_python ${PYTHON_LIBRARIES}) target_link_libraries(dragon_python ${PYTHON_LIBRARIES})
endif() endif()
......
...@@ -63,28 +63,20 @@ if (USE_OPENMP) ...@@ -63,28 +63,20 @@ if (USE_OPENMP)
endif() endif()
if (USE_CUDA) if (USE_CUDA)
if (USE_SHARED_LIBS) if (USE_SHARED_LIBS)
target_link_libraries_v2(dragonrt cudart) target_link_libraries_v2(dragonrt ${CUDA_LIBRARIES_SHARED})
target_link_libraries_v2(dragonrt cublas)
target_link_libraries_v2(dragonrt curand)
else() else()
target_link_libraries_v2(dragonrt cudart_static) target_link_libraries_v2(dragonrt ${CUDA_LIBRARIES_STATIC})
target_link_libraries_v2(dragonrt cublas_static)
target_link_libraries_v2(dragonrt curand_static)
if (CUDA_VERSION VERSION_GREATER "10.0")
target_link_libraries_v2(dragonrt cublasLt_static)
endif()
endif() endif()
endif() endif()
if (USE_CUDNN) if (USE_CUDNN)
if (USE_SHARED_LIBS) if (USE_SHARED_LIBS)
target_link_libraries_v2(dragonrt cudnn) target_link_libraries_v2(dragonrt ${CUDNN_LIBRARIES_SHARED})
else() else()
target_link_libraries_v2(dragonrt -Wl,--whole-archive cudnn_static -Wl,--no-whole-archive) target_link_libraries_v2(
dragonrt -Wl,--whole-archive
${CUDNN_LIBRARIES_STATIC} -Wl,--no-whole-archive)
endif() endif()
endif() endif()
if (USE_CUDA AND (NOT USE_SHARED_LIBS))
target_link_libraries_v2(dragonrt culibos)
endif()
# ---[ Command - Strip # ---[ Command - Strip
strip_debug_symbol(dragonrt) strip_debug_symbol(dragonrt)
......
...@@ -171,6 +171,7 @@ setuptools.setup( ...@@ -171,6 +171,7 @@ setuptools.setup(
'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Mathematics', 'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Artificial Intelligence',
......
...@@ -82,7 +82,6 @@ def path_remove_gradient(project_source_dir): ...@@ -82,7 +82,6 @@ def path_remove_gradient(project_source_dir):
FileWriter().apply_regex( FileWriter().apply_regex(
glob_recurse(operators_dir, '.cc'), [ glob_recurse(operators_dir, '.cc'), [
r'DEPLOY_.+[(].*Gradient[)][;]', r'DEPLOY_.+[(].*Gradient[)][;]',
r'OPERATOR_SCHEMA[(].+Gradient.*[)][\s\S]*?[;]',
r'REGISTER_GRADIENT[(].+[)][;]', r'REGISTER_GRADIENT[(].+[)][;]',
r'class GradientMaker[\s\S]*[;]', r'class GradientMaker[\s\S]*[;]',
] ]
...@@ -90,14 +89,10 @@ def path_remove_gradient(project_source_dir): ...@@ -90,14 +89,10 @@ def path_remove_gradient(project_source_dir):
if __name__ == '__main__': if __name__ == '__main__':
path_remove_gradient('/Users/neo/workspace/dragon/dragon')
while True:
pass
if len(sys.argv) != 3: if len(sys.argv) != 3:
raise ValueError('Usage: codegen.py ' raise ValueError('Usage: codegen.py '
'<PROJECT_SOURCE_DIR> <PATH_NAME>') '<PROJECT_SOURCE_DIR> <PATH_NAME>')
project_source_dir, path_name = sys.argv[1:] project_source_dir, path_name = sys.argv[1:]
if path_name == 'REMOVE_GRADIENT': if path_name == 'REMOVE_GRADIENT':
path_remove_gradient(project_source_dir) path_remove_gradient(project_source_dir)
else: else:
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!