CUDA: FA support for Deepseek (Ampere or newer) (#13306)
* CUDA: FA support for Deepseek (Ampere or newer) * do loop unrolling via C++ template
This commit is contained in:
parent
27ebfcacba
commit
0cf6725e9f
33 changed files with 852 additions and 547 deletions
|
@ -2,6 +2,17 @@
|
|||
|
||||
#include "common.cuh"
|
||||
|
||||
|
||||
static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
|
||||
#ifdef CP_ASYNC_AVAILABLE
|
||||
return __cvta_generic_to_shared(generic_ptr);
|
||||
#else
|
||||
GGML_UNUSED(generic_ptr);
|
||||
NO_DEVICE_CODE;
|
||||
return 0;
|
||||
#endif // CP_ASYNC_AVAILABLE
|
||||
}
|
||||
|
||||
// Copies data from global to shared memory, cg == cache global.
|
||||
// Both the src and dst pointers must be aligned to 16 bit.
|
||||
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue