cuda: fuse snake activation (mul, sin, sqr, mul, add) (#22667)
- cuda: fuse snake activation (mul, sin, sqr, mul, add)
Add ggml_cuda_op_snake_fused with F32 / F16 / BF16 templates. The
matcher recognizes the naive 5-op decomposition (mul, sin, sqr, mul,
add) that audio decoders (BigVGAN, Vocos) emit for the snake activation,
    y = x + sin(a*x)^2 * inv_b
and rewrites it to a single elementwise kernel.
Add test_snake_fuse comparing CPU naive vs CUDA fused across
F32 / F16 / BF16.
- cuda: address review feedback from @am17an
Use ggml_cuda_cast for F32/F16/BF16 conversions and rename
kernel_snake to snake_kernel to match upstream conventions.
- cuda: snake fusion fastdiv on T_len
Suggested-by: @am17an
- Update tests/test-backend-ops.cpp
Co-authored-by: Aman Gupta <amangupta052@gmail.com>
- cuda: snake fusion check add->type matches x->type
Address review feedback from @am17an
- cuda: snake fusion check add->type matches x->type
Moved for readability (equivalent)
Address review feedback from @am17an
Co-authored-by: Aman Gupta <amangupta052@gmail.com>
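
For readers unfamiliar with the pattern, here is a minimal CPU-side sketch of what the fusion replaces. It is not the CUDA kernel from this change. The function names (snake_naive, snake_fused), the per-channel layout of a and inv_b, and the indexing by i / t_len are illustrative assumptions. The first function mirrors the 5-op graph (mul, sin, sqr, mul, add) with one temporary per op; the second computes y = x + sin(a*x)^2 * inv_b in a single pass.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Assumed layout (hypothetical, for illustration only): x holds C channels of
// t_len contiguous samples each; a and inv_b hold one value per channel.

// Naive path: five elementwise ops, each writing a full-size temporary.
std::vector<float> snake_naive(const std::vector<float>& x,
                               const std::vector<float>& a,
                               const std::vector<float>& inv_b,
                               std::size_t t_len) {
    assert(t_len > 0 && a.size() == inv_b.size() && x.size() == a.size() * t_len);
    const std::size_t n = x.size();
    std::vector<float> t0(n), t1(n), t2(n), t3(n), y(n);
    for (std::size_t i = 0; i < n; ++i) t0[i] = a[i / t_len] * x[i];      // mul
    for (std::size_t i = 0; i < n; ++i) t1[i] = std::sin(t0[i]);          // sin
    for (std::size_t i = 0; i < n; ++i) t2[i] = t1[i] * t1[i];            // sqr
    for (std::size_t i = 0; i < n; ++i) t3[i] = t2[i] * inv_b[i / t_len]; // mul
    for (std::size_t i = 0; i < n; ++i) y[i]  = x[i] + t3[i];             // add
    return y;
}

// Fused path: one pass over x, no intermediate buffers.
std::vector<float> snake_fused(const std::vector<float>& x,
                               const std::vector<float>& a,
                               const std::vector<float>& inv_b,
                               std::size_t t_len) {
    assert(t_len > 0 && a.size() == inv_b.size() && x.size() == a.size() * t_len);
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        const std::size_t c = i / t_len;       // channel index for this element
        const float s = std::sin(a[c] * x[i]);
        y[i] = x[i] + s * s * inv_b[c];
    }
    return y;
}
```

The fused form reads x once and writes y once, whereas the naive graph makes five passes over memory. That saving in memory traffic is the usual motivation for this kind of elementwise fusion. The integer division by t_len in the index computation is the step the fastdiv commit above refers to.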
macOS/iOS:
- macOS Apple Silicon (arm64)
- macOS Apple Silicon (arm64, KleidiAI enabled)
- macOS Intel (x64)
- iOS XCFramework
Linux:
- Ubuntu x64 (CPU)
- Ubuntu arm64 (CPU)
- Ubuntu s390x (CPU)
- Ubuntu x64 (Vulkan)
- Ubuntu arm64 (Vulkan)
- Ubuntu x64 (ROCm 7.2)
- Ubuntu x64 (OpenVINO)
- Ubuntu x64 (SYCL FP32)
- Ubuntu x64 (SYCL FP16)
Android:
Windows:
- Windows x64 (CPU)
- Windows arm64 (CPU)
- Windows x64 (CUDA 12) - CUDA 12.4 DLLs
- Windows x64 (CUDA 13) - CUDA 13.1 DLLs
- Windows x64 (Vulkan)
- Windows x64 (SYCL)
- Windows x64 (HIP)
openEuler: