setup.py

import sys
import warnings
import os
import glob
from packaging.version import parse, Version

from setuptools import setup, find_packages
import subprocess

import torch
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDAExtension,
    CUDA_HOME,
    load,
)
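
# This script builds apex's optional C++/CUDA extensions.  Each extension is enabled by a
# command-line flag (e.g. --cpp_ext, --cuda_ext, --fast_layer_norm) that is popped from
# sys.argv before setuptools parses the remaining arguments.  As a rough sketch (the exact
# pip incantation depends on your pip version and environment), an install enabling the two
# core extensions might look like:
#   pip install -v --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./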
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])

    return raw_output, bare_metal_version
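
# Example for get_cuda_bare_metal_version(): if `nvcc -V` prints
# "Cuda compilation tools, release 12.1, V12.1.105", the token after "release" is "12.1,",
# so the function returns (raw_output, Version("12.1")).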


def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
    torch_binary_version = parse(torch.version.cuda)

    print("\nCompiling cuda extensions with")
    print(raw_output + "from " + cuda_dir + "/bin\n")

    if bare_metal_version != torch_binary_version:
        raise RuntimeError(
            "Cuda extensions are being compiled with a version of Cuda that does "
            "not match the version used to compile Pytorch binaries. "
            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
            + "In some cases, a minor-version mismatch will not cause later errors: "
            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
            "You can try commenting out this check (at your own risk)."
        )


def raise_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    raise RuntimeError(
        f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
        "only images whose names contain 'devel' will provide nvcc."
    )


def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
    cudnn_available = torch.backends.cudnn.is_available()
    cudnn_version = torch.backends.cudnn.version() if cudnn_available else None
    if not (cudnn_available and (cudnn_version >= required_cudnn_version)):
        warnings.warn(
            f"Skip `{global_option}` as it requires cuDNN {required_cudnn_version} or later, "
            f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}"
        )
        return False
    return True
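
# Note: torch.backends.cudnn.version() reports cuDNN's version as a single integer
# (e.g. 8500 for cuDNN 8.5.0), which is what required_cudnn_version values such as
# 8400/8500 below are compared against.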


if not torch.cuda.is_available():
    # https://github.com/NVIDIA/apex/issues/486
    # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
    # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
    print(
        "\nWarning: Torch did not find available GPUs on this system.\n",
        "If your intention is to cross-compile, this is not an error.\n"
        "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
        "Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
        "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
        "If you wish to cross-compile for a single specific architecture,\n"
        'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
    )
    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version >= Version("11.8"):
            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
        elif bare_metal_version >= Version("11.1"):
            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
        elif bare_metal_version == Version("11.0"):
            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
        else:
            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
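
# TORCH_CUDA_ARCH_LIST is the semicolon-separated list of compute capabilities that
# torch.utils.cpp_extension compiles device code for, e.g. "8.0;9.0".  Setting it here only
# affects the cross-compilation path above, where no GPU is visible to torch.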

print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])

if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
    raise RuntimeError(
        "Apex requires Pytorch 0.4 or newer.\nThe latest stable release can be obtained from https://pytorch.org/"
    )

cmdclass = {}
ext_modules = []

extras = {}

if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
    if TORCH_MAJOR == 0:
        raise RuntimeError(
            "--cpp_ext requires Pytorch 1.0 or later, found torch.__version__ = {}".format(torch.__version__)
        )

if "--cpp_ext" in sys.argv:
    sys.argv.remove("--cpp_ext")
    ext_modules.append(CppExtension("apex_C", ["csrc/flatten_unflatten.cpp"]))

# Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
# and
# https://github.com/NVIDIA/apex/issues/456
# https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac
version_ge_1_1 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
    version_ge_1_1 = ["-DVERSION_GE_1_1"]
version_ge_1_3 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
    version_ge_1_3 = ["-DVERSION_GE_1_3"]
version_ge_1_5 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
    version_ge_1_5 = ["-DVERSION_GE_1_5"]
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5

_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
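
# Every optional extension below follows the same pattern: pop its `--<flag>` from sys.argv,
# verify that nvcc is available (raise_if_cuda_home_none), and append a CUDAExtension whose
# cxx/nvcc compile args include version_dependent_macros (plus, where relevant, generator_flag
# and per-architecture -gencode flags).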

if "--distributed_adam" in sys.argv:
    sys.argv.remove("--distributed_adam")
    raise_if_cuda_home_none("--distributed_adam")
    ext_modules.append(
        CUDAExtension(
            name="distributed_adam_cuda",
            sources=[
                "apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp",
                "apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu",
            ],
            include_dirs=[os.path.join(this_dir, "csrc")],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": ["-O3", "--use_fast_math"] + version_dependent_macros,
            },
        )
    )

if "--distributed_lamb" in sys.argv:
    sys.argv.remove("--distributed_lamb")
    raise_if_cuda_home_none("--distributed_lamb")
    ext_modules.append(
        CUDAExtension(
            name="distributed_lamb_cuda",
            sources=[
                "apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp",
                "apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu",
            ],
            include_dirs=[os.path.join(this_dir, "csrc")],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": ["-O3", "--use_fast_math"] + version_dependent_macros,
            },
        )
    )

if "--cuda_ext" in sys.argv:
    sys.argv.remove("--cuda_ext")
    raise_if_cuda_home_none("--cuda_ext")
    check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)

    ext_modules.append(
        CUDAExtension(
            name="amp_C",
            sources=[
                "csrc/amp_C_frontend.cpp",
                "csrc/multi_tensor_sgd_kernel.cu",
                "csrc/multi_tensor_scale_kernel.cu",
                "csrc/multi_tensor_axpby_kernel.cu",
                "csrc/multi_tensor_l2norm_kernel.cu",
                "csrc/multi_tensor_l2norm_kernel_mp.cu",
                "csrc/multi_tensor_l2norm_scale_kernel.cu",
                "csrc/multi_tensor_lamb_stage_1.cu",
                "csrc/multi_tensor_lamb_stage_2.cu",
                "csrc/multi_tensor_adam.cu",
                "csrc/multi_tensor_adagrad.cu",
                "csrc/multi_tensor_novograd.cu",
                "csrc/multi_tensor_lamb.cu",
                "csrc/multi_tensor_lamb_mp.cu",
                "csrc/update_scale_hysteresis.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": [
                    "-lineinfo",
                    "-O3",
                    # '--resource-usage',
                    "--use_fast_math",
                ] + version_dependent_macros,
            },
        )
    )
    ext_modules.append(
        CUDAExtension(
            name="syncbn",
            sources=["csrc/syncbn.cpp", "csrc/welford.cu"],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": ["-O3"] + version_dependent_macros,
            },
        )
    )
    ext_modules.append(
        CUDAExtension(
            name="fused_layer_norm_cuda",
            sources=["csrc/layer_norm_cuda.cpp", "csrc/layer_norm_cuda_kernel.cu"],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": ["-maxrregcount=50", "-O3", "--use_fast_math"] + version_dependent_macros,
            },
        )
    )
    ext_modules.append(
        CUDAExtension(
            name="mlp_cuda",
            sources=["csrc/mlp.cpp", "csrc/mlp_cuda.cu"],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": ["-O3"] + version_dependent_macros,
            },
        )
    )
    ext_modules.append(
        CUDAExtension(
            name="fused_dense_cuda",
            sources=["csrc/fused_dense.cpp", "csrc/fused_dense_cuda.cu"],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": ["-O3"] + version_dependent_macros,
            },
        )
    )
    ext_modules.append(
        CUDAExtension(
            name="scaled_upper_triang_masked_softmax_cuda",
            sources=[
                "csrc/megatron/scaled_upper_triang_masked_softmax.cpp",
                "csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu",
            ],
            include_dirs=[os.path.join(this_dir, "csrc")],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": [
                    "-O3",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                ] + version_dependent_macros,
            },
        )
    )
    ext_modules.append(
        CUDAExtension(
            name="generic_scaled_masked_softmax_cuda",
            sources=[
                "csrc/megatron/generic_scaled_masked_softmax.cpp",
                "csrc/megatron/generic_scaled_masked_softmax_cuda.cu",
            ],
            include_dirs=[os.path.join(this_dir, "csrc")],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": [
                    "-O3",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                ] + version_dependent_macros,
            },
        )
    )
    ext_modules.append(
        CUDAExtension(
            name="scaled_masked_softmax_cuda",
            sources=["csrc/megatron/scaled_masked_softmax.cpp", "csrc/megatron/scaled_masked_softmax_cuda.cu"],
            include_dirs=[os.path.join(this_dir, "csrc")],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": [
                    "-O3",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                ] + version_dependent_macros,
            },
        )
    )
    ext_modules.append(
        CUDAExtension(
            name="scaled_softmax_cuda",
            sources=["csrc/megatron/scaled_softmax.cpp", "csrc/megatron/scaled_softmax_cuda.cu"],
            include_dirs=[os.path.join(this_dir, "csrc")],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": [
                    "-O3",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                ] + version_dependent_macros,
            },
        )
    )
    ext_modules.append(
        CUDAExtension(
            name="fused_rotary_positional_embedding",
            sources=[
                "csrc/megatron/fused_rotary_positional_embedding.cpp",
                "csrc/megatron/fused_rotary_positional_embedding_cuda.cu",
            ],
            include_dirs=[os.path.join(this_dir, "csrc")],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": [
                    "-O3",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                ] + version_dependent_macros,
            },
        )
    )

    if bare_metal_version >= Version("11.0"):
        cc_flag = []
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_70,code=sm_70")
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_80,code=sm_80")
        if bare_metal_version >= Version("11.1"):
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_86,code=sm_86")
        if bare_metal_version >= Version("11.8"):
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_90,code=sm_90")
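
        # Each "-gencode arch=compute_XX,code=sm_XX" pair asks nvcc to emit device code for
        # that specific SM architecture, so fused_weight_gradient_mlp_cuda below is built only
        # for the architectures the detected CUDA toolkit can target.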
        ext_modules.append(
            CUDAExtension(
                name="fused_weight_gradient_mlp_cuda",
                include_dirs=[os.path.join(this_dir, "csrc")],
                sources=[
                    "csrc/megatron/fused_weight_gradient_dense.cpp",
                    "csrc/megatron/fused_weight_gradient_dense_cuda.cu",
                    "csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu",
                ],
                extra_compile_args={
                    "cxx": ["-O3"] + version_dependent_macros,
                    "nvcc": [
                        "-O3",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                    ] + version_dependent_macros + cc_flag,
                },
            )
        )
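
# In summary, --cuda_ext builds the core kernels apex itself imports (amp_C, syncbn,
# fused_layer_norm_cuda, mlp_cuda, fused_dense_cuda), the Megatron-style fused softmax and
# rotary-embedding kernels, and, when the CUDA toolkit is >= 11.0, fused_weight_gradient_mlp_cuda.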

if "--permutation_search" in sys.argv:
    sys.argv.remove("--permutation_search")

    if CUDA_HOME is None:
        raise RuntimeError("--permutation_search was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
    else:
        cc_flag = ['-Xcompiler', '-fPIC', '-shared']
        ext_modules.append(
            CUDAExtension(name='permutation_search_cuda',
                          sources=['apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu'],
                          include_dirs=[os.path.join(this_dir, 'apex', 'contrib', 'sparsity', 'permutation_search_kernels', 'CUDA_kernels')],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc': ['-O3'] + version_dependent_macros + cc_flag}))

if "--bnp" in sys.argv:
    sys.argv.remove("--bnp")
    raise_if_cuda_home_none("--bnp")
    ext_modules.append(
        CUDAExtension(
            name="bnp",
            sources=[
                "apex/contrib/csrc/groupbn/batch_norm.cu",
                "apex/contrib/csrc/groupbn/ipc.cu",
                "apex/contrib/csrc/groupbn/interface.cpp",
                "apex/contrib/csrc/groupbn/batch_norm_add_relu.cu",
            ],
            include_dirs=[os.path.join(this_dir, "csrc")],
            extra_compile_args={
                "cxx": [] + version_dependent_macros,
                "nvcc": [
                    "-DCUDA_HAS_FP16=1",
                    "-D__CUDA_NO_HALF_OPERATORS__",
                    "-D__CUDA_NO_HALF_CONVERSIONS__",
                    "-D__CUDA_NO_HALF2_OPERATORS__",
                ] + version_dependent_macros,
            },
        )
    )

if "--xentropy" in sys.argv:
    from datetime import datetime

    sys.argv.remove("--xentropy")
    raise_if_cuda_home_none("--xentropy")
    xentropy_ver = datetime.today().strftime("%y.%m.%d")
    print(f"`--xentropy` setting version of {xentropy_ver}")
    ext_modules.append(
        CUDAExtension(
            name="xentropy_cuda",
            sources=["apex/contrib/csrc/xentropy/interface.cpp", "apex/contrib/csrc/xentropy/xentropy_kernel.cu"],
            include_dirs=[os.path.join(this_dir, "csrc")],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros + [f'-DXENTROPY_VER="{xentropy_ver}"'],
                "nvcc": ["-O3"] + version_dependent_macros,
            },
        )
    )

if "--focal_loss" in sys.argv:
    sys.argv.remove("--focal_loss")
    raise_if_cuda_home_none("--focal_loss")
    ext_modules.append(
        CUDAExtension(
            name='focal_loss_cuda',
            sources=[
                'apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp',
                'apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu',
            ],
            include_dirs=[os.path.join(this_dir, 'csrc')],
            extra_compile_args={
                'cxx': ['-O3'] + version_dependent_macros,
                'nvcc': ['-O3', '--use_fast_math', '--ftz=false'] + version_dependent_macros,
            },
        )
    )

if "--group_norm" in sys.argv:
    sys.argv.remove("--group_norm")
    raise_if_cuda_home_none("--group_norm")

    # CUDA group norm supports SM70 and newer architectures.
    arch_flags = []
    for arch in [70, 75, 80, 86, 90]:
        arch_flag = f"-gencode=arch=compute_{arch},code=sm_{arch}"
        arch_flags.append(arch_flag)
    arch_flag = "-gencode=arch=compute_90,code=compute_90"
    arch_flags.append(arch_flag)

    ext_modules.append(
        CUDAExtension(
            name="group_norm_cuda",
            sources=[
                "apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp",
            ] + glob.glob("apex/contrib/csrc/group_norm/*.cu"),
            include_dirs=[os.path.join(this_dir, 'csrc')],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"] + version_dependent_macros,
                "nvcc": [
                    "-O3", "-std=c++17", "--use_fast_math", "--ftz=false",
                ] + arch_flags + version_dependent_macros,
            },
        )
    )

if "--index_mul_2d" in sys.argv:
    sys.argv.remove("--index_mul_2d")
    raise_if_cuda_home_none("--index_mul_2d")
    ext_modules.append(
        CUDAExtension(
            name='fused_index_mul_2d',
            sources=[
                'apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp',
                'apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu',
            ],
            include_dirs=[os.path.join(this_dir, 'csrc')],
            extra_compile_args={
                'cxx': ['-O3'] + version_dependent_macros,
                'nvcc': ['-O3', '--use_fast_math', '--ftz=false'] + version_dependent_macros,
            },
        )
    )
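
# The two "--deprecated_*" flags below rebuild older contrib optimizer kernels kept for
# backward compatibility; the currently maintained fused optimizers in apex.optimizers rely
# on the multi_tensor_* kernels that --cuda_ext compiles into amp_C.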

if "--deprecated_fused_adam" in sys.argv:
    sys.argv.remove("--deprecated_fused_adam")
    raise_if_cuda_home_none("--deprecated_fused_adam")
    ext_modules.append(
        CUDAExtension(
            name="fused_adam_cuda",
            sources=[
                "apex/contrib/csrc/optimizers/fused_adam_cuda.cpp",
                "apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu",
            ],
            include_dirs=[os.path.join(this_dir, "csrc")],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": ["-O3", "--use_fast_math"] + version_dependent_macros,
            },
        )
    )

if "--deprecated_fused_lamb" in sys.argv:
    sys.argv.remove("--deprecated_fused_lamb")
    raise_if_cuda_home_none("--deprecated_fused_lamb")
    ext_modules.append(
        CUDAExtension(
            name="fused_lamb_cuda",
            sources=[
                "apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp",
                "apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu",
                "csrc/multi_tensor_l2norm_kernel.cu",
            ],
            include_dirs=[os.path.join(this_dir, "csrc")],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": ["-O3", "--use_fast_math"] + version_dependent_macros,
            },
        )
    )

# Check whether ATen/CUDAGeneratorImpl.h is found; otherwise fall back to ATen/cuda/CUDAGeneratorImpl.h.
# See https://github.com/pytorch/pytorch/pull/70650
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
    generator_flag = ["-DOLD_GENERATOR_PATH"]
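
# generator_flag is threaded into the compile args of the contrib extensions below so their
# C++/CUDA sources can include the CUDA generator header from whichever path this PyTorch
# installation provides.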

if "--fast_layer_norm" in sys.argv:
    sys.argv.remove("--fast_layer_norm")
    raise_if_cuda_home_none("--fast_layer_norm")

    cc_flag = []
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_70,code=sm_70")
    if bare_metal_version >= Version("11.0"):
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_80,code=sm_80")
    if bare_metal_version >= Version("11.8"):
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_90,code=sm_90")

    ext_modules.append(
        CUDAExtension(
            name="fast_layer_norm",
            sources=[
                "apex/contrib/csrc/layer_norm/ln_api.cpp",
                "apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu",
                "apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros + generator_flag,
                "nvcc": [
                    "-O3",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
                    "-I./apex/contrib/csrc/layer_norm/",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                    "--use_fast_math",
                ] + version_dependent_macros + generator_flag + cc_flag,
            },
            include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")],
        )
    )

if "--fmha" in sys.argv:
    sys.argv.remove("--fmha")
    raise_if_cuda_home_none("--fmha")
    if bare_metal_version < Version("11.0"):
        raise RuntimeError("--fmha only supported on sm_80 and sm_90 GPUs")

    cc_flag = []
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_80,code=sm_80")
    if bare_metal_version >= Version("11.8"):
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_90,code=sm_90")

    ext_modules.append(
        CUDAExtension(
            name="fmhalib",
            sources=[
                "apex/contrib/csrc/fmha/fmha_api.cpp",
                "apex/contrib/csrc/fmha/src/fmha_fill.cu",
                "apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu",
                "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu",
                "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu",
                "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu",
                "apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu",
                "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu",
                "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu",
                "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu",
                "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros + generator_flag,
                "nvcc": [
                    "-O3",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                    "--use_fast_math",
                ] + version_dependent_macros + generator_flag + cc_flag,
            },
            include_dirs=[
                os.path.join(this_dir, "apex/contrib/csrc"),
                os.path.join(this_dir, "apex/contrib/csrc/fmha/src"),
            ],
        )
    )

if "--fast_multihead_attn" in sys.argv:
    sys.argv.remove("--fast_multihead_attn")
    raise_if_cuda_home_none("--fast_multihead_attn")

    cc_flag = []
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_70,code=sm_70")
    if bare_metal_version >= Version("11.0"):
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_80,code=sm_80")
    if bare_metal_version >= Version("11.1"):
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_86,code=sm_86")
    if bare_metal_version >= Version("11.8"):
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_90,code=sm_90")

    subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
    ext_modules.append(
        CUDAExtension(
            name="fast_multihead_attn",
            sources=[
                "apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp",
                "apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu",
                "apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu",
                "apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu",
                "apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu",
                "apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu",
                "apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu",
                "apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu",
                "apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros + generator_flag,
                "nvcc": [
                    "-O3",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                    "--use_fast_math",
                ]
                + version_dependent_macros
                + generator_flag
                + cc_flag,
            },
            include_dirs=[
                os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass/include/"),
                os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass/tools/util/include"),
            ],
        )
    )

if "--transducer" in sys.argv:
    sys.argv.remove("--transducer")
    raise_if_cuda_home_none("--transducer")
    ext_modules.append(
        CUDAExtension(
            name="transducer_joint_cuda",
            sources=[
                "apex/contrib/csrc/transducer/transducer_joint.cpp",
                "apex/contrib/csrc/transducer/transducer_joint_kernel.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros + generator_flag,
                "nvcc": ["-O3"] + version_dependent_macros + generator_flag,
            },
            include_dirs=[os.path.join(this_dir, "csrc"), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")],
        )
    )
    ext_modules.append(
        CUDAExtension(
            name="transducer_loss_cuda",
            sources=[
                "apex/contrib/csrc/transducer/transducer_loss.cpp",
                "apex/contrib/csrc/transducer/transducer_loss_kernel.cu",
            ],
            include_dirs=[os.path.join(this_dir, "csrc")],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": ["-O3"] + version_dependent_macros,
            },
        )
    )

if "--cudnn_gbn" in sys.argv:
    sys.argv.remove("--cudnn_gbn")
    raise_if_cuda_home_none("--cudnn_gbn")
    if check_cudnn_version_and_warn("--cudnn_gbn", 8500):
        subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
        ext_modules.append(
            CUDAExtension(
                name="cudnn_gbn_lib",
                sources=[
                    "apex/contrib/csrc/cudnn_gbn/norm_sample.cpp",
                    "apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp",
                ],
                include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")],
                extra_compile_args={"cxx": ["-O3", "-g"] + version_dependent_macros + generator_flag},
            )
        )

if "--peer_memory" in sys.argv:
    sys.argv.remove("--peer_memory")
    raise_if_cuda_home_none("--peer_memory")
    ext_modules.append(
        CUDAExtension(
            name="peer_memory_cuda",
            sources=[
                "apex/contrib/csrc/peer_memory/peer_memory_cuda.cu",
                "apex/contrib/csrc/peer_memory/peer_memory.cpp",
            ],
            extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
        )
    )

# NOTE: Requires NCCL >= 2.10.3
if "--nccl_p2p" in sys.argv:
    sys.argv.remove("--nccl_p2p")
    raise_if_cuda_home_none("--nccl_p2p")
    # Check NCCL version.
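    # The small `_nccl_version_getter` helper below is JIT-compiled with
    # torch.utils.cpp_extension.load() at setup time; it exposes the available NCCL version
    # as a (major, minor) tuple so the check against (2, 10) can run before nccl_p2p_cuda
    # is added to ext_modules.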
    _nccl_version_getter = load(
        name="_nccl_version_getter",
        sources=["apex/contrib/csrc/nccl_p2p/nccl_version.cpp", "apex/contrib/csrc/nccl_p2p/nccl_version_check.cu"],
    )
    _available_nccl_version = _nccl_version_getter.get_nccl_version()
    if _available_nccl_version >= (2, 10):
        ext_modules.append(
            CUDAExtension(
                name="nccl_p2p_cuda",
                sources=[
                    "apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu",
                    "apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp",
                ],
                extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
            )
        )
    else:
        warnings.warn(
            f"Skip `--nccl_p2p` as it requires NCCL 2.10.3 or later, but {_available_nccl_version[0]}.{_available_nccl_version[1]}"
        )

# note (mkozuki): The `--fast_bottleneck` option (i.e. apex/contrib/bottleneck) depends on
# `--peer_memory` and `--nccl_p2p`.
if "--fast_bottleneck" in sys.argv:
    sys.argv.remove("--fast_bottleneck")
    raise_if_cuda_home_none("--fast_bottleneck")
    if check_cudnn_version_and_warn("--fast_bottleneck", 8400):
        subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
        ext_modules.append(
            CUDAExtension(
                name="fast_bottleneck",
                sources=["apex/contrib/csrc/bottleneck/bottleneck.cpp"],
                include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")],
                extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
            )
        )

if "--fused_conv_bias_relu" in sys.argv:
    sys.argv.remove("--fused_conv_bias_relu")
    raise_if_cuda_home_none("--fused_conv_bias_relu")
    if check_cudnn_version_and_warn("--fused_conv_bias_relu", 8400):
        subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
        ext_modules.append(
            CUDAExtension(
                name="fused_conv_bias_relu",
                sources=["apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp"],
                include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")],
                extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
            )
        )

if "--gpu_direct_storage" in sys.argv:
    sys.argv.remove("--gpu_direct_storage")
    raise_if_cuda_home_none("--gpu_direct_storage")
    ext_modules.append(
        CUDAExtension(
            name="_apex_gpu_direct_storage",
            sources=["apex/contrib/csrc/gpu_direct_storage/gds.cpp", "apex/contrib/csrc/gpu_direct_storage/gds_pybind.cpp"],
            include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/gpu_direct_storage")],
            libraries=["cufile"],
            extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
        )
    )

setup(
    name="apex",
    version="0.1",
    packages=find_packages(
        exclude=("build", "csrc", "include", "tests", "dist", "docs", "examples", "apex.egg-info",)
    ),
    install_requires=["packaging>20.6"],
    description="PyTorch Extensions written by NVIDIA",
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension} if ext_modules else {},
    extras_require=extras,
)