NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch

Why does a kernel like CUDAFunctor_add appear when testing MixedFusedRMSNorm? #1753

Open HangJie720 opened 1 year ago

HangJie720 commented 1 year ago

When I execute the following Python code, a kernel named CUDAFunctor_add appears in the profile, and I don't know why.

```python
import torch
from apex.normalization import MixedFusedRMSNorm

datatype = torch.bfloat16
input = torch.randn([1024, 8192], dtype=datatype).cuda()
input1 = torch.rand_like(input)
input1.requires_grad = True

# Detached leaf tensor so gradients accumulate on it directly.
input1_nofuse = input1.detach().requires_grad_(True)

grad = torch.rand_like(input)

for i in range(10):
    norm = MixedFusedRMSNorm(input.size()[1:]).to(device='cuda', dtype=datatype)
    input3 = input1_nofuse
    output = norm(input3)
    output_ = output.to('cuda:0')
    output_.backward(gradient=grad)
```

Run with the nvprof tool (`nvprof python3 test_rmsnorm.py`) to show the kernels:

```
==14326== NVPROF is profiling process 14326, command: python3 test_rmsnorm.py
==14326== Profiling application: python3 test_rmsnorm.py
==14326== Profiling result:
            Type  Time(%)      Time  Calls       Avg       Min       Max  Name
 GPU activities:  76.46%  12.271ms     11  1.1155ms  3.3270us  12.237ms  [CUDA memcpy HtoD]
                   9.53%  1.5299ms     10  152.99us  151.65us  154.56us  void cuComputeGradInput<c10::BFloat16, float, c10::BFloat16>(c10::BFloat16 const *, c10::BFloat16 const *, int, int, float const *, float const *, float, c10::BFloat16 const *, c10::BFloat16*, bool)
                   6.61%  1.0608ms     10  106.08us  104.38us  108.22us  void cuApplyRMSNorm<c10::BFloat16, float, c10::BFloat16>(c10::BFloat16*, float*, c10::BFloat16 const *, int, int, float, c10::BFloat16 const *)
                   3.40%  545.94us     10  54.594us  53.439us  56.767us  void cuComputePartGradGammaBeta<c10::BFloat16, float, c10::BFloat16>(c10::BFloat16 const *, c10::BFloat16 const *, int, int, float const *, float const *, float, float*, float*, bool)
                   3.03%  486.97us      9  54.107us  53.248us  55.103us  void at::native::vectorized_elementwise_kernel<int=4, at::native::CUDAFunctor_add<c10::BFloat16>, at::detail::Array<char*, int=3>>(int, c10::BFloat16, at::native::CUDAFunctor_add<c10::BFloat16>)
                   0.59%  95.392us      2  47.696us  47.040us  48.352us  _ZN2at6native55_GLOBAL__N__722798bb_22_DistributionUniform_cu_f2fea07d43distribution_elementwise_grid_stride_kernelIfLi4EZNS0_9templates4cuda21uniform_and_transformIN3c108BFloat16EfLm4EPNS_17CUDAGeneratorImplEZZZNS4_14uniform_kernelIS9_EEvRNS_18TensorIteratorBaseEddT_ENKUlvE_clEvENKUlvE2_clEvEUlfE_EEvSC_T2_T3_EUlP24curandStatePhilox4_32_10E0_ZNS1_27distribution_nullary_kernelIS7_fLi4ES9_SL_SG_EEvSC_SH_RKSI_T4_EUlifE_EEviNS_15PhiloxCudaStateET1_SH_
                   0.37%  59.137us     10  5.9130us  5.6000us  6.5600us  void cuComputeGradGammaBeta<float, c10::BFloat16>(float const *, float const *, int, int, int, c10::BFloat16*, c10::BFloat16*, bool)
      API calls:  97.73%  1.24793s      1  1.24793s  1.24793s  1.24793s  cudaDeviceGetStreamPriorityRange
```

There is a CUDAFunctor_add kernel in this output, and I want to ask why it appears.
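A guess at the source (an assumption on my part, not confirmed anywhere in this thread): the loop calls `backward()` ten times on graphs rooted at the same leaf tensor `input1_nofuse`, so the first call stores the gradient into `.grad` and each of the remaining nine calls accumulates into the existing `.grad` tensor with an elementwise add, which is exactly the `CUDAFunctor_add` kernel and would match its 9 calls in the profile. A minimal sketch to test this hypothesis, using a plain multiply in place of `MixedFusedRMSNorm` (the variable names here are hypothetical):

```python
import torch

# Assumption: gradient accumulation on a reused leaf tensor launches
# CUDAFunctor_add. The first backward() stores x.grad directly; each of the
# following nine backward() calls does x.grad += <new grad>, an elementwise add.
x = torch.randn(1024, 8192, dtype=torch.bfloat16, device='cuda', requires_grad=True)
g = torch.rand_like(x)

for i in range(10):
    y = x * 2.0                # any differentiable op stands in for the norm
    y.backward(gradient=g)     # iterations 2..10 should each launch one add
```

If accumulation is the cause, profiling this sketch should show nine `vectorized_elementwise_kernel<..., CUDAFunctor_add<...>>` launches, and setting `input1_nofuse.grad = None` before each `backward()` in the original script should make the kernel disappear.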