profitgrowinginnovator opened 1 week ago
Thanks, I've merged #2628 to support __dp4a on older architectures, which is similar to what you suggest (and is based on what is actually done in llama.cpp, see here).
Re multiprocess, could you provide more details about which part actually doesn't work on your GPU?
Many thanks! I am trying to run llama_multiprocess and get `Error: DriverError(CUDA_ERROR_NOT_FOUND, "named symbol not found")` when loading copy2d_bf16. My guess is that, just like __dp4a, copy2d_bf16 is not supported on older architectures. However, I was not able to solve it as easily as __dp4a.
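For context, this error pattern is what you typically see when a kernel is only compiled for newer architectures: the symbol never makes it into the PTX/cubin of an SM60 build, so looking it up at runtime fails. A minimal, hypothetical sketch of such a guard (illustrative names and signature, not candle's actual source) would be:

```cuda
#include <cuda_bf16.h>

// Hypothetical guard: only emit bf16 kernels when the target
// architecture has native __nv_bfloat16 support (SM80 and newer).
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
extern "C" __global__ void copy2d_bf16(const __nv_bfloat16 *src, __nv_bfloat16 *dst,
                                       unsigned int d1, unsigned int d2,
                                       unsigned int src_s, unsigned int dst_s) {
    unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= d1 * d2) return;
    unsigned int i = idx / d2;
    unsigned int j = idx % d2;
    dst[i * dst_s + j] = src[i * src_s + j];
}
#endif
// When the guard is false (e.g. an SM60 build), copy2d_bf16 is never
// compiled, so looking it up at load time fails with CUDA_ERROR_NOT_FOUND
// ("named symbol not found").
```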
This is likely because your hardware doesn't support bfloat16. Could you try an f16 model instead, e.g. use `--which v2-7b` to get llama 2 (which was trained in f16) rather than llama 3? Alternatively you can try something like `--dtype f32` or `--dtype f16`, though if the model has been trained with bf16 or f32, using f16 is likely to result in some NaNs.
candle-transformers requires SM80 for llama_multiprocess and SM61 for quantized models. Supporting SM60 would allow Nvidia Tesla and other SM60 cards to be used, which cost a few hundred dollars, instead of requiring cards that cost thousands of dollars even on eBay, or tens of thousands new.
The new file candle-kernels/src/custom_dp4a.cuh:

```cuda
#ifndef CUSTOM_DP4A_CUH
#define CUSTOM_DP4A_CUH

// Check if we're compiling for a CUDA architecture less than 6.1
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 610)

// Custom implementation of __dp4a for sm_60
__device__ inline int custom_dp4a(int a, int b, int c) {
    // Extract four 8-bit segments from each integer, sign-extended
    // (the int overload of __dp4a treats each byte as signed)
    int a0 = (int)(signed char)((a) & 0xFF);
    int a1 = (int)(signed char)((a >> 8) & 0xFF);
    int a2 = (int)(signed char)((a >> 16) & 0xFF);
    int a3 = (int)(signed char)((a >> 24) & 0xFF);
    int b0 = (int)(signed char)((b) & 0xFF);
    int b1 = (int)(signed char)((b >> 8) & 0xFF);
    int b2 = (int)(signed char)((b >> 16) & 0xFF);
    int b3 = (int)(signed char)((b >> 24) & 0xFF);
    // Dot product of the four byte lanes, accumulated into c
    return c + a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3;
}

// Redefine __dp4a to use custom_dp4a when compiling for sm_60
#define __dp4a(a, b, c) custom_dp4a(a, b, c)

#endif // __CUDA_ARCH__ < 610
#endif // CUSTOM_DP4A_CUH
```
and including

```cuda
// Make work with CUDA 6.0
#include "custom_dp4a.cuh"
// end
```

in candle-kernels/src/quantized.cu at least makes the __dp4a error go away.
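For completeness, a minimal standalone sketch (not part of candle, just a test harness I put together) that compares the fallback against a plain reference loop could look like this; on sm_61+ the macro is not defined and `__dp4a` resolves to the hardware instruction instead:

```cuda
#include <cstdio>
#include "custom_dp4a.cuh"

// Reference dot product over the four signed byte lanes of a and b.
__device__ int dp4a_ref(int a, int b, int c) {
    const signed char *a8 = reinterpret_cast<const signed char *>(&a);
    const signed char *b8 = reinterpret_cast<const signed char *>(&b);
    return c + a8[0] * b8[0] + a8[1] * b8[1] + a8[2] * b8[2] + a8[3] * b8[3];
}

__global__ void check_dp4a(int a, int b, int c, int *out) {
    out[0] = __dp4a(a, b, c);   // custom_dp4a on sm_60, hardware intrinsic on sm_61+
    out[1] = dp4a_ref(a, b, c);
}

int main() {
    int *out = nullptr;
    cudaMallocManaged(&out, 2 * sizeof(int));
    // 0x01FF02FE packs the signed byte lanes {-2, 2, -1, 1} (little-endian).
    check_dp4a<<<1, 1>>>(0x01FF02FE, 0x03FD04FC, 10, out);
    cudaDeviceSynchronize();
    printf("dp4a=%d reference=%d\n", out[0], out[1]);
    cudaFree(out);
    return 0;
}
```

Building with `nvcc -arch=sm_60` exercises the fallback path; both printed values should match.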
llama_multiprocess is harder to get working. Any pointers would be really appreciated.