NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License
8.16k stars 1.35k forks source link

Up to date patch for Windows compilation with Visual Studio 2022, CUDA 12.1 and PyTorch 2.2.2 #1792

Open doctorpangloss opened 3 months ago

doctorpangloss commented 3 months ago

Tested on Python 3.11

For the sake of your sanity, use Busybox for Windows so that you have a normal, native shell environment instead of PowerShell or cmd.exe. You can save Busybox as C:/Windows/sh.exe, then execute it from a command prompt using sh -ilX.

The patch: windows_support.patch

The wheel for Python 3.11: apex-0.1-cp311-cp311-win_amd64.whl.zip

Clone this repository, apply this patch and build:

# installs visual studio build tools if you do not already have it
# requires chocolatey
choco install -y visualstudio2022buildtools
choco install -y visualstudio2022-workload-vctools

# activates vcvars aka puts the compilation tools on the path.
# This starts a NEW interactive shell; run all remaining commands inside it.
# (If you saved Busybox as C:/Windows/sh.exe, replace `busybox sh -ilX` with `sh -ilX`.)
cmd /c 'C:\PROGRA~2\MICROS~2\2022\BUILDT~1\VC\AUXILI~1\Build\VCVARS~1.BAT amd64 & busybox sh -ilX'

git clone https://github.com/NVIDIA/apex.git

cd apex
git checkout 810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
# -f makes curl fail on HTTP errors instead of piping an error page into git apply
curl -fL "https://github.com/NVIDIA/apex/files/14844602/windows_support.patch" | git apply
python -m venv venv
source venv/scripts/activate
pip install packaging wheel
# --no-build-isolation (below) builds against the torch already installed in this
# venv, and setup.py imports torch, so install the matching PyTorch build first
pip install torch==2.2.2 --index-url https://download.pytorch.org/whl/cu121
export DISTUTILS_USE_SDK=1
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" --config-settings "--build-option=--deprecated_fused_adam" ./

Replace `pip install` with `pip wheel` in the last command if you need an installable wheel to commit to your repo. Then, if the wheel is located at the root of your repo, add a line like the following to your requirements.txt:

For example, if you built for Python 3.11 like I did:

apex @ {root:uri}/apex-0.1-cp311-cp311-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.11'

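Note: `{root:uri}` is Hatch's context-formatting syntax for the project root. pip's requirements-file format only documents `${VAR}` environment-variable expansion, so confirm that your installer actually expands `{root:uri}`. If it does not, use a relative path or an absolute `file:///` URI instead.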

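Take care to enable the build options (extensions) you need. The command above only passes `--cpp_ext`, `--cuda_ext` and `--deprecated_fused_adam`; add a further `--config-settings "--build-option=..."` flag for each other extension you want.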

Patch:

diff --git a/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu b/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu
index 9209df4..8e11883 100644
--- a/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu
+++ b/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu
@@ -486,7 +486,7 @@ __global__ void strided_check_finite_cuda_kernel(

     for (int j = i; j < tsize; j+=totThreads) {
         GRAD_T pi = p_copy[j];
-        if (!isfinite(pi)) {
+        if (!isfinite(static_cast<float>(pi))) {
             *noop_gmem = 1;
         }
     }
@@ -516,7 +516,7 @@ __global__ void strided_check_finite_cuda_kernel(
     for (int j = i; j < tsize; j+=totThreads) {
         at::Half pi;
         convert(p_copy[j], pi);
-        if (!isfinite(pi)) {
+        if (!isfinite(static_cast<float>(pi))) {
             *noop_gmem = 1;
         }
     }
diff --git a/csrc/mlp.cpp b/csrc/mlp.cpp
index 830d606..c128fd6 100644
--- a/csrc/mlp.cpp
+++ b/csrc/mlp.cpp
@@ -58,7 +58,7 @@ std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at
     output_features.push_back(inputs[i + 1].size(0));
   }

-  auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
+  auto reserved_size = (int64_t)get_mlp_reserved_space(batch_size, num_layers, output_features.data());

   // create output/workspace tensor
   auto out = at::empty({batch_size, output_features.back()}, inputs[0].type());
@@ -132,7 +132,7 @@ std::vector<at::Tensor> mlp_backward(
     }

     auto work_size =
-        get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data());
+            (int64_t)get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data());

     // auto work_space = at::empty({work_size*4}, at::kByte);
     auto work_space = at::empty({static_cast<long>(work_size / sizeof(scalar_t))}, inputs[0].type());
diff --git a/csrc/mlp_cuda.cu b/csrc/mlp_cuda.cu
index f93f1df..598c8a2 100644
--- a/csrc/mlp_cuda.cu
+++ b/csrc/mlp_cuda.cu
@@ -434,7 +434,7 @@ CLEANUP:
 // Bias ADD. Assume input X is [features x batch size], column major.
 // Bias is one 'features' long vector, with implicit broadcast.
 template <typename T>
-__global__ void biasAdd_fprop(T *X, T *b, uint batch_size, uint features) {
+__global__ void biasAdd_fprop(T *X, T *b, unsigned int batch_size, unsigned int features) {
   T r_x[ILP];
   T r_b[ILP];
   if(is_aligned(X) && is_aligned(b) && features % ILP ==0) {
@@ -481,7 +481,7 @@ __global__ void biasAdd_fprop(T *X, T *b, uint batch_size, uint features) {
 // Bias ADD + ReLU. Assume input X is [features x batch size], column major.
 // Activation support fuesed ReLU. Safe to call in-place.
 template <typename T>
-__global__ void biasAddRelu_fprop(T *X, T *b, uint batch_size, uint features) {
+__global__ void biasAddRelu_fprop(T *X, T *b, unsigned int batch_size, unsigned int features) {
   T r_x[ILP];
   T r_b[ILP];
   if(is_aligned(X) && is_aligned(b) && features % ILP ==0) {
@@ -528,7 +528,7 @@ __global__ void biasAddRelu_fprop(T *X, T *b, uint batch_size, uint features) {
 // ReLU. Assume input X is [features x batch size], column major.
 // Safe to call in-place.
 template <typename T>
-__global__ void Relu_fprop(T *X, uint batch_size, uint features) {
+__global__ void Relu_fprop(T *X, unsigned int batch_size, unsigned int features) {
   T r_x[ILP];
   if(is_aligned(X) && features % ILP ==0) {
     int tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -568,7 +568,7 @@ __global__ void Relu_fprop(T *X, uint batch_size, uint features) {
 // Sigmoid. Assume input X is [features x batch size], column major.
 // Safe to call in-place.
 template <typename T>
-__global__ void Sigmoid_fprop(T *X, uint batch_size, uint features) {
+__global__ void Sigmoid_fprop(T *X, unsigned int batch_size, unsigned int features) {
   T r_x[ILP];
   if(is_aligned(X) && features % ILP ==0) {
     int tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -608,7 +608,7 @@ __global__ void Sigmoid_fprop(T *X, uint batch_size, uint features) {
 // ReLU. Assume input X is [features x batch size], column major.
 // Safe to call in-place.
 template <typename T>
-__global__ void Relu_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
+__global__ void Relu_bprop(T *dY, T *Y, unsigned int batch_size, unsigned int features, T *dX) {
   T r_dy[ILP];
   T r_y[ILP];
   if(is_aligned(dY) &&
@@ -656,7 +656,7 @@ __global__ void Relu_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
 // Sigmoid. Assume input X is [features x batch size], column major.
 // Safe to call in-place.
 template <typename T>
-__global__ void Sigmoid_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
+__global__ void Sigmoid_bprop(T *dY, T *Y, unsigned int batch_size, unsigned int features, T *dX) {
   T r_dy[ILP];
   T r_y[ILP];
   if(is_aligned(dY) &&
@@ -1324,7 +1324,7 @@ int mlp_fp(
         return 1;
       }

-      const uint &input_size = ofeat;
+      const unsigned int &input_size = ofeat;
       int num_blocks = 0;
       int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
       // Call biasReLU
diff --git a/csrc/multi_tensor_adagrad.cu b/csrc/multi_tensor_adagrad.cu
index 699681b..291d4fc 100644
--- a/csrc/multi_tensor_adagrad.cu
+++ b/csrc/multi_tensor_adagrad.cu
@@ -9,7 +9,7 @@

 #include "multi_tensor_apply.cuh"
 #include "type_shim.h"
-
+#define _ENABLE_EXTENDED_ALIGNED_STORAGE
 #define BLOCK_SIZE 1024
 #define ILP 4

diff --git a/csrc/multi_tensor_axpby_kernel.cu b/csrc/multi_tensor_axpby_kernel.cu
index 021df27..43bd628 100644
--- a/csrc/multi_tensor_axpby_kernel.cu
+++ b/csrc/multi_tensor_axpby_kernel.cu
@@ -72,11 +72,11 @@ struct AxpbyFunctor
         {
           r_out[ii] = a*static_cast<float>(r_x[ii]) + b*static_cast<float>(r_y[ii]);
           if(arg_to_check == -1)
-            finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii]));
+            finite = finite && (isfinite(static_cast<float>(r_x[ii])) && isfinite(static_cast<float>(r_y[ii])));
           if(arg_to_check == 0)
-            finite = finite && isfinite(r_x[ii]);
+            finite = finite && isfinite(static_cast<float>(r_x[ii]));
           if(arg_to_check == 1)
-            finite = finite && isfinite(r_y[ii]);
+            finite = finite && isfinite(static_cast<float>(r_y[ii]));
         }
         // store
         load_store(out, r_out, i_start , 0);
@@ -104,11 +104,11 @@ struct AxpbyFunctor
         {
           r_out[ii] = a*static_cast<float>(r_x[ii]) + b*static_cast<float>(r_y[ii]);
           if(arg_to_check == -1)
-            finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii]));
+            finite = finite && (isfinite(static_cast<float>(r_x[ii])) && isfinite(static_cast<float>(r_y[ii])));
           if(arg_to_check == 0)
-            finite = finite && isfinite(r_x[ii]);
+            finite = finite && isfinite(static_cast<float>(r_x[ii]));
           if(arg_to_check == 1)
-            finite = finite && isfinite(r_y[ii]);
+            finite = finite && isfinite(static_cast<float>(r_y[ii]));
         }
         // see note in multi_tensor_scale_kernel.cu
 #pragma unroll
diff --git a/csrc/multi_tensor_scale_kernel.cu b/csrc/multi_tensor_scale_kernel.cu
index 4fd848b..649c2c2 100644
--- a/csrc/multi_tensor_scale_kernel.cu
+++ b/csrc/multi_tensor_scale_kernel.cu
@@ -66,7 +66,7 @@ struct ScaleFunctor
         for(int ii = 0; ii < ILP; ii++)
         {
           r_out[ii] = static_cast<float>(r_in[ii]) * scale;
-          finite = finite && isfinite(r_in[ii]);
+          finite = finite && isfinite(static_cast<float>(r_in[ii]));
         }
         // store
         load_store(out, r_out, i_start, 0);
@@ -94,7 +94,7 @@ struct ScaleFunctor
         for(int ii = 0; ii < ILP; ii++)
         {
           r_out[ii] = static_cast<float>(r_in[ii]) * scale;
-          finite = finite && isfinite(r_in[ii]);
+          finite = finite && isfinite(static_cast<float>(r_in[ii]));
         }
 #pragma unroll
         for(int ii = 0; ii < ILP; ii++)
diff --git a/setup.py b/setup.py
index a0c7fb6..b2066f9 100644
--- a/setup.py
+++ b/setup.py
@@ -16,6 +16,7 @@ from torch.utils.cpp_extension import (
     load,
 )

+#'-gencode arch=compute_61,code=sm_61 -gencode arch=compute_61,code=compute_61',
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))

@@ -151,6 +152,7 @@ if "--distributed_adam" in sys.argv:
                 "cxx": ["-O3"] + version_dependent_macros,
                 "nvcc": ["-O3", "--use_fast_math"] + version_dependent_macros,
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -169,6 +171,7 @@ if "--distributed_lamb" in sys.argv:
                 "cxx": ["-O3"] + version_dependent_macros,
                 "nvcc": ["-O3", "--use_fast_math"] + version_dependent_macros,
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -198,14 +201,15 @@ if "--cuda_ext" in sys.argv:
                 "csrc/update_scale_hysteresis.cu",
             ],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros,
+                "cxx": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
                 "nvcc": [
                     "-lineinfo",
                     "-O3",
                     # '--resource-usage',
                     "--use_fast_math",
-                ] + version_dependent_macros,
+                ] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )
     ext_modules.append(
@@ -213,9 +217,10 @@ if "--cuda_ext" in sys.argv:
             name="syncbn",
             sources=["csrc/syncbn.cpp", "csrc/welford.cu"],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros,
-                "nvcc": ["-O3"] + version_dependent_macros,
+                "cxx": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
+                "nvcc": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -224,9 +229,10 @@ if "--cuda_ext" in sys.argv:
             name="fused_layer_norm_cuda",
             sources=["csrc/layer_norm_cuda.cpp", "csrc/layer_norm_cuda_kernel.cu"],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros,
-                "nvcc": ["-maxrregcount=50", "-O3", "--use_fast_math"] + version_dependent_macros,
+                "cxx": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
+                "nvcc": ["-maxrregcount=50", "-O3", "--use_fast_math"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -235,9 +241,10 @@ if "--cuda_ext" in sys.argv:
             name="mlp_cuda",
             sources=["csrc/mlp.cpp", "csrc/mlp_cuda.cu"],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros,
-                "nvcc": ["-O3"] + version_dependent_macros,
+                "cxx": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
+                "nvcc": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )
     ext_modules.append(
@@ -245,9 +252,10 @@ if "--cuda_ext" in sys.argv:
             name="fused_dense_cuda",
             sources=["csrc/fused_dense.cpp", "csrc/fused_dense_cuda.cu"],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros,
-                "nvcc": ["-O3"] + version_dependent_macros,
+                "cxx": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
+                "nvcc": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -260,15 +268,16 @@ if "--cuda_ext" in sys.argv:
             ],
             include_dirs=[os.path.join(this_dir, "csrc")],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros,
+                "cxx": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
                 "nvcc": [
                     "-O3",
                     "-U__CUDA_NO_HALF_OPERATORS__",
                     "-U__CUDA_NO_HALF_CONVERSIONS__",
                     "--expt-relaxed-constexpr",
                     "--expt-extended-lambda",
-                ] + version_dependent_macros,
+                ] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -281,15 +290,16 @@ if "--cuda_ext" in sys.argv:
             ],
             include_dirs=[os.path.join(this_dir, "csrc")],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros,
+                "cxx": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
                 "nvcc": [
                     "-O3",
                     "-U__CUDA_NO_HALF_OPERATORS__",
                     "-U__CUDA_NO_HALF_CONVERSIONS__",
                     "--expt-relaxed-constexpr",
                     "--expt-extended-lambda",
-                ] + version_dependent_macros,
+                ] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -299,15 +309,16 @@ if "--cuda_ext" in sys.argv:
             sources=["csrc/megatron/scaled_masked_softmax.cpp", "csrc/megatron/scaled_masked_softmax_cuda.cu"],
             include_dirs=[os.path.join(this_dir, "csrc")],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros,
+                "cxx": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
                 "nvcc": [
                     "-O3",
                     "-U__CUDA_NO_HALF_OPERATORS__",
                     "-U__CUDA_NO_HALF_CONVERSIONS__",
                     "--expt-relaxed-constexpr",
                     "--expt-extended-lambda",
-                ] + version_dependent_macros,
+                ] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -317,15 +328,16 @@ if "--cuda_ext" in sys.argv:
             sources=["csrc/megatron/scaled_softmax.cpp", "csrc/megatron/scaled_softmax_cuda.cu"],
             include_dirs=[os.path.join(this_dir, "csrc")],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros,
+                "cxx": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
                 "nvcc": [
                     "-O3",
                     "-U__CUDA_NO_HALF_OPERATORS__",
                     "-U__CUDA_NO_HALF_CONVERSIONS__",
                     "--expt-relaxed-constexpr",
                     "--expt-extended-lambda",
-                ] + version_dependent_macros,
+                ] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -338,15 +350,16 @@ if "--cuda_ext" in sys.argv:
             ],
             include_dirs=[os.path.join(this_dir, "csrc")],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros,
+                "cxx": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
                 "nvcc": [
                     "-O3",
                     "-U__CUDA_NO_HALF_OPERATORS__",
                     "-U__CUDA_NO_HALF_CONVERSIONS__",
                     "--expt-relaxed-constexpr",
                     "--expt-extended-lambda",
-                ] + version_dependent_macros,
+                ] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -374,7 +387,7 @@ if "--cuda_ext" in sys.argv:
                     "csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu",
                 ],
                 extra_compile_args={
-                    "cxx": ["-O3"] + version_dependent_macros,
+                    "cxx": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
                     "nvcc": [
                         "-O3",
                         "-U__CUDA_NO_HALF_OPERATORS__",
@@ -382,8 +395,9 @@ if "--cuda_ext" in sys.argv:
                         "--expt-relaxed-constexpr",
                         "--expt-extended-lambda",
                         "--use_fast_math",
-                    ] + version_dependent_macros + cc_flag,
+                    ] + version_dependent_macros + cc_flag + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
                 },
+                extra_link_args=['cublas.lib', 'cublasLt.lib'],
             )
         )

@@ -398,8 +412,9 @@ if "--permutation_search" in sys.argv:
             CUDAExtension(name='permutation_search_cuda',
                           sources=['apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu'],
                           include_dirs=[os.path.join(this_dir, 'apex', 'contrib', 'sparsity', 'permutation_search_kernels', 'CUDA_kernels')],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc':['-O3'] + version_dependent_macros + cc_flag}))
+                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
+                                              'nvcc':['-O3'] + version_dependent_macros + cc_flag + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE']},
+                          extra_link_args=['cublas.lib', 'cublasLt.lib']))

 if "--bnp" in sys.argv:
     sys.argv.remove("--bnp")
@@ -415,14 +430,15 @@ if "--bnp" in sys.argv:
             ],
             include_dirs=[os.path.join(this_dir, "csrc")],
             extra_compile_args={
-                "cxx": [] + version_dependent_macros,
+                "cxx": [] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
                 "nvcc": [
                     "-DCUDA_HAS_FP16=1",
                     "-D__CUDA_NO_HALF_OPERATORS__",
                     "-D__CUDA_NO_HALF_CONVERSIONS__",
                     "-D__CUDA_NO_HALF2_OPERATORS__",
-                ] + version_dependent_macros,
+                ] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -438,9 +454,10 @@ if "--xentropy" in sys.argv:
             sources=["apex/contrib/csrc/xentropy/interface.cpp", "apex/contrib/csrc/xentropy/xentropy_kernel.cu"],
             include_dirs=[os.path.join(this_dir, "csrc")],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros + [f'-DXENTROPY_VER="{xentropy_ver}"'],
-                "nvcc": ["-O3"] + version_dependent_macros,
+                "cxx": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'] + [f'-DXENTROPY_VER="{xentropy_ver}"'],
+                "nvcc": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -456,9 +473,10 @@ if "--focal_loss" in sys.argv:
             ],
             include_dirs=[os.path.join(this_dir, 'csrc')],
             extra_compile_args={
-                'cxx': ['-O3'] + version_dependent_macros,
-                'nvcc':['-O3', '--use_fast_math', '--ftz=false'] + version_dependent_macros,
+                'cxx': ['-O3'] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
+                'nvcc':['-O3', '--use_fast_math', '--ftz=false'] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -482,11 +500,12 @@ if "--group_norm" in sys.argv:
             ] + glob.glob("apex/contrib/csrc/group_norm/*.cu"),
             include_dirs=[os.path.join(this_dir, 'csrc')],
             extra_compile_args={
-                "cxx": ["-O3", "-std=c++17"] + version_dependent_macros,
+                "cxx": ["-O3", "-std=c++17"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
                 "nvcc": [
                     "-O3", "-std=c++17", "--use_fast_math", "--ftz=false",
-                ] + arch_flags + version_dependent_macros,
+                ] + arch_flags + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -502,9 +521,10 @@ if "--index_mul_2d" in sys.argv:
             ],
             include_dirs=[os.path.join(this_dir, 'csrc')],
             extra_compile_args={
-                'cxx': ['-O3'] + version_dependent_macros,
-                'nvcc':['-O3', '--use_fast_math', '--ftz=false'] + version_dependent_macros,
+                'cxx': ['-O3'] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
+                'nvcc':['-O3', '--use_fast_math', '--ftz=false'] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -520,9 +540,10 @@ if "--deprecated_fused_adam" in sys.argv:
             ],
             include_dirs=[os.path.join(this_dir, "csrc")],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros,
-                "nvcc": ["-O3", "--use_fast_math"] + version_dependent_macros,
+                "cxx": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
+                "nvcc": ["-O3", "--use_fast_math"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -539,9 +560,10 @@ if "--deprecated_fused_lamb" in sys.argv:
             ],
             include_dirs=[os.path.join(this_dir, "csrc")],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros,
-                "nvcc": ["-O3", "--use_fast_math"] + version_dependent_macros,
+                "cxx": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
+                "nvcc": ["-O3", "--use_fast_math"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -576,7 +598,7 @@ if "--fast_layer_norm" in sys.argv:
                 "apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu",
             ],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros + generator_flag,
+                "cxx": ["-O3"] + version_dependent_macros + generator_flag + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
                 "nvcc": [
                     "-O3",
                     "-U__CUDA_NO_HALF_OPERATORS__",
@@ -589,9 +611,10 @@ if "--fast_layer_norm" in sys.argv:
                     "--expt-relaxed-constexpr",
                     "--expt-extended-lambda",
                     "--use_fast_math",
-                ] + version_dependent_macros + generator_flag + cc_flag,
+                ] + version_dependent_macros + generator_flag + cc_flag + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
             include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")],
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -626,7 +649,7 @@ if "--fmha" in sys.argv:
                 "apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu",
             ],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros + generator_flag,
+                "cxx": ["-O3"] + version_dependent_macros + generator_flag + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
                 "nvcc": [
                     "-O3",
                     "-U__CUDA_NO_HALF_OPERATORS__",
@@ -634,8 +657,9 @@ if "--fmha" in sys.argv:
                     "--expt-relaxed-constexpr",
                     "--expt-extended-lambda",
                     "--use_fast_math",
-                ] + version_dependent_macros + generator_flag + cc_flag,
+                ] + version_dependent_macros + generator_flag + cc_flag + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
             include_dirs=[
                 os.path.join(this_dir, "apex/contrib/csrc"),
                 os.path.join(this_dir, "apex/contrib/csrc/fmha/src"),
@@ -678,7 +702,7 @@ if "--fast_multihead_attn" in sys.argv:
                 "apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu",
             ],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros + generator_flag,
+                "cxx": ["-O3"] + version_dependent_macros + generator_flag + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
                 "nvcc": [
                     "-O3",
                     "-U__CUDA_NO_HALF_OPERATORS__",
@@ -689,8 +713,10 @@ if "--fast_multihead_attn" in sys.argv:
                 ]
                 + version_dependent_macros
                 + generator_flag
-                + cc_flag,
+                + cc_flag
+                + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
             include_dirs=[
                 os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass/include/"),
                 os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass/tools/util/include")
@@ -709,9 +735,10 @@ if "--transducer" in sys.argv:
                 "apex/contrib/csrc/transducer/transducer_joint_kernel.cu",
             ],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros + generator_flag,
-                "nvcc": ["-O3"] + version_dependent_macros + generator_flag,
+                "cxx": ["-O3"] + version_dependent_macros + generator_flag + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
+                "nvcc": ["-O3"] + version_dependent_macros + generator_flag + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
             include_dirs=[os.path.join(this_dir, "csrc"), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")],
         )
     )
@@ -724,9 +751,10 @@ if "--transducer" in sys.argv:
             ],
             include_dirs=[os.path.join(this_dir, "csrc")],
             extra_compile_args={
-                "cxx": ["-O3"] + version_dependent_macros,
-                "nvcc": ["-O3"] + version_dependent_macros,
+                "cxx": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
+                "nvcc": ["-O3"] + version_dependent_macros + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE'],
             },
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -743,7 +771,8 @@ if "--cudnn_gbn" in sys.argv:
                     "apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp",
                 ],
                 include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")],
-                extra_compile_args={"cxx": ["-O3", "-g"] + version_dependent_macros + generator_flag},
+                extra_compile_args={"cxx": ["-O3", "-g"] + version_dependent_macros + generator_flag + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE']},
+                extra_link_args=['cublas.lib', 'cublasLt.lib'],
             )
         )

@@ -757,7 +786,8 @@ if "--peer_memory" in sys.argv:
                 "apex/contrib/csrc/peer_memory/peer_memory_cuda.cu",
                 "apex/contrib/csrc/peer_memory/peer_memory.cpp",
             ],
-            extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
+            extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE']},
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )

@@ -780,7 +810,8 @@ if "--nccl_p2p" in sys.argv:
                     "apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu",
                     "apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp",
                 ],
-                extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
+                extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE']},
+                extra_link_args=['cublas.lib', 'cublasLt.lib'],
             )
         )
     else:
@@ -799,7 +830,8 @@ if "--fast_bottleneck" in sys.argv:
                 name="fast_bottleneck",
                 sources=["apex/contrib/csrc/bottleneck/bottleneck.cpp"],
                 include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")],
-                extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
+                extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE']},
+                extra_link_args=['cublas.lib', 'cublasLt.lib'],
             )
         )

@@ -814,7 +846,8 @@ if "--fused_conv_bias_relu" in sys.argv:
                 name="fused_conv_bias_relu",
                 sources=["apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp"],
                 include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")],
-                extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
+                extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE']},
+                extra_link_args=['cublas.lib', 'cublasLt.lib'],
             )
         )

@@ -828,7 +861,8 @@ if "--gpu_direct_storage" in sys.argv:
             sources=["apex/contrib/csrc/gpu_direct_storage/gds.cpp", "apex/contrib/csrc/gpu_direct_storage/gds_pybind.cpp"],
             include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/gpu_direct_storage")],
             libraries=["cufile"],
-            extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
+            extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag + ['-D_DISABLE_EXTENDED_ALIGNED_STORAGE']},
+            extra_link_args=['cublas.lib', 'cublasLt.lib'],
         )
     )
rkononovs commented 2 months ago

I was struggling with Apex installation and this is the only thing that helped me. Thank you, man <3