Added
- fMHA: `PagedBlockDiagonalGappyKeysMask`
- fMHA: heterogeneous queries in `triton_splitk`
- fMHA: support for paged attention in flash
- fMHA: Added backwards pass for `merge_attentions`
- fMHA: Added `torch.compile` support for 3 biases (`LowerTriangularMask`, `LowerTriangularMaskWithTensorBias` and `BlockDiagonalMask`) - some might require PyTorch 2.4
- fMHA: Added `torch.compile` support in `memory_efficient_attention` when passing the flash operator explicitly (e.g. `memory_efficient_attention(..., op=(flash.FwOp, flash.BwOp))`)
- fMHA: `memory_efficient_attention` now expects its `attn_bias` argument to be on the same device as the other input tensors. Previously, it would convert the bias to the right device.
- fMHA: `AttentionBias` subclasses are now constructed by default on the `cuda` device if available - they used to be created on the CPU device
- 2:4 sparsity: Added `xformers.ops.sp24.sparsify24_ste` for Straight Through Estimator (STE) with options to rescale the gradient differently for masked out/kept values
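
For readers unfamiliar with the STE entry above, the idea can be illustrated in plain Python. This is only a sketch under stated assumptions: `sparsify_2_4`, `ste_backward`, `scale_masked` and `scale_kept` are hypothetical names, not part of xformers, and the selection rule (keep the two largest magnitudes in each group of four) is a simplification that may differ from what the xformers kernel actually selects.

```python
from typing import List, Tuple


def sparsify_2_4(row: List[float]) -> Tuple[List[float], List[bool]]:
    """Forward pass sketch: keep the 2 largest-magnitude values per group of 4."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    out: List[float] = []
    mask: List[bool] = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two largest magnitudes; ties go to the lower index.
        keep = set(sorted(range(4), key=lambda i: (-abs(group[i]), i))[:2])
        for i, value in enumerate(group):
            kept = i in keep
            mask.append(kept)
            out.append(value if kept else 0.0)
    return out, mask


def ste_backward(
    grad_out: List[float],
    mask: List[bool],
    scale_masked: float = 1.0,  # hypothetical: multiplier for masked-out positions
    scale_kept: float = 1.0,    # hypothetical: multiplier for kept positions
) -> List[float]:
    """Backward pass sketch: pass the gradient straight through the sparsification,
    rescaling it separately for masked-out and kept positions."""
    return [g * (scale_kept if kept else scale_masked) for g, kept in zip(grad_out, mask)]


if __name__ == "__main__":
    values, kept_mask = sparsify_2_4([1.0, -3.0, 2.0, 0.5])
    print(values)     # sparsified row
    print(kept_mask)  # which positions were kept
    print(ste_backward([1.0, 1.0, 1.0, 1.0], kept_mask, scale_masked=0.5, scale_kept=2.0))
```

With both multipliers left at 1.0 the backward pass is the plain straight-through estimator: the gradient reaches every position as if no values had been masked.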
Improved
- fMHA: Fixed out-of-bounds reading for Split-K triton implementation
- Profiler: fix bug with modules that take a single tuple as argument
- Profiler: Added manual trigger for a profiling step, by creating a `trigger` file in the profiling directory
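
The manual trigger described above might be requested from a script roughly as follows. This is a hedged sketch: `request_profiling_step` is a hypothetical helper, and a temporary directory stands in for the real profiling directory. The changelog only states that a file named `trigger` is created there, so any further behaviour (for example whether the profiler removes the file afterwards) is not assumed here.

```python
import tempfile
from pathlib import Path


def request_profiling_step(profiling_dir: str) -> Path:
    """Create an empty `trigger` file in the given profiling directory."""
    trigger = Path(profiling_dir) / "trigger"
    trigger.touch()  # an empty file is assumed to be enough to signal the request
    return trigger


if __name__ == "__main__":
    # A temporary directory stands in for the directory the profiler writes to.
    with tempfile.TemporaryDirectory() as profiling_dir:
        path = request_profiling_step(profiling_dir)
        print(path.name, path.exists())
```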
Removed
- Removed support for PyTorch versions older than 2.2.0