ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (… (#22286)
- ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (GQA=32)
Adds MMA-f16 and tile kernel configs, dispatch logic, template instances,
and tile .cu file for Mistral Small 4 (head sizes 320/256), restricting to
ncols2=32 to support GQA ratio 32 only.
- Add a check that returns BEST_FATTN_KERNEL_NONE when the GQA ratio is not 32
- Apply suggestions from code review
Address review comments
Co-authored-by: Johannes Gäßler johannesg@5d6.de
- Address review comments and make the kernel config default to DKQ=512, DV=512 instead of DKQ=256, DV=256
- Fix a bug with sinks=1: with ncols=32, two warp groups are created, but both used the same sink indices (0, ..., 15), so the output did not match the CPU output. Add sink_base, the base sink index for each warp group (threadIdx.y / np).
- Apply suggestions from code review
Co-authored-by: Johannes Gäßler johannesg@5d6.de
- Update ggml/src/ggml-cuda/template-instances/generate_cu_files.py
Co-authored-by: Johannes Gäßler johannesg@5d6.de
Co-authored-by: Johannes Gäßler johannesg@5d6.de
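
The two behavioural changes in the commit list above (the GQA=32 restriction and the sink_base fix) can be illustrated with a small host-side C++ sketch. This is not the code from the PR: the enum, both function names, the BEST_FATTN_KERNEL_OTHER value, the even split of columns between warp groups, and the formula sink_base = warp_group * cols_per_group are assumptions made for illustration. Only BEST_FATTN_KERNEL_NONE, sink_base, and threadIdx.y / np appear in the commit message.

```cpp
#include <vector>

// Hypothetical stand-in for the dispatcher's result type. Only the name
// BEST_FATTN_KERNEL_NONE comes from the commit message.
enum best_fattn_kernel_sketch {
    BEST_FATTN_KERNEL_NONE  = 0,
    BEST_FATTN_KERNEL_OTHER = 1, // placeholder for "some kernel was selected"
};

// Assumed rule for head sizes DKQ=320/DV=256: a kernel is only offered when
// the GQA ratio (query heads per KV head) is exactly 32.
inline best_fattn_kernel_sketch pick_kernel_320_256(int n_head_q, int n_head_kv) {
    if (n_head_kv <= 0 || n_head_q % n_head_kv != 0) {
        return BEST_FATTN_KERNEL_NONE;
    }
    const int gqa_ratio = n_head_q / n_head_kv;
    return gqa_ratio == 32 ? BEST_FATTN_KERNEL_OTHER : BEST_FATTN_KERNEL_NONE;
}

// Models which sink entry each of the ncols columns reads. np is the number
// of warps per warp group, so thread_y / np identifies the group (this mirrors
// threadIdx.y / np in the commit message). Assumes columns are split evenly
// between groups.
inline std::vector<int> sink_indices(int ncols, int n_groups, int np, bool use_sink_base) {
    const int cols_per_group = ncols / n_groups;
    std::vector<int> out;
    for (int g = 0; g < n_groups; ++g) {
        const int thread_y   = g * np;        // first warp of group g
        const int warp_group = thread_y / np; // recovers g
        // Without the base offset every group starts again at index 0.
        const int sink_base  = use_sink_base ? warp_group * cols_per_group : 0;
        for (int c = 0; c < cols_per_group; ++c) {
            out.push_back(sink_base + c);
        }
    }
    return out;
}
```

With ncols=32 and two warp groups, calling sink_indices without the base offset yields the indices 0 to 15 twice, which corresponds to the mismatch described in the commit message. With the offset, the second group reads indices 16 to 31.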
macOS/iOS:
- macOS Apple Silicon (arm64)
- macOS Apple Silicon (arm64, KleidiAI enabled)
- macOS Intel (x64)
- iOS XCFramework
Linux:
- Ubuntu x64 (CPU)
- Ubuntu arm64 (CPU)
- Ubuntu s390x (CPU)
- Ubuntu x64 (Vulkan)
- Ubuntu arm64 (Vulkan)
- Ubuntu x64 (ROCm 7.2)
- Ubuntu x64 (OpenVINO)
- Ubuntu x64 (SYCL FP32)
- Ubuntu x64 (SYCL FP16)
Android:
Windows:
- Windows x64 (CPU)
- Windows arm64 (CPU)
- Windows x64 (CUDA 12) - CUDA 12.4 DLLs
- Windows x64 (CUDA 13) - CUDA 13.1 DLLs
- Windows x64 (Vulkan)
- Windows x64 (SYCL)
- Windows x64 (HIP)
openEuler: