This release brings improvements to the memory_efficient_attention operator.
Pip wheels now target PyTorch 2.0.0; conda builds are available for PyTorch 2.0.0, 1.13.1, and 1.12.1.
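
For context, a minimal usage sketch of the operator (the tensor shapes, dtype, and device below are illustrative assumptions, not part of these notes):

```python
import torch
import xformers.ops as xops

# Illustrative inputs in the [batch, seq_len, num_heads, head_dim] layout
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)

# Dispatches to the most efficient attention kernel available for this GPU/dtype
out = xops.memory_efficient_attention(q, k, v)  # output has the same shape as q
```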
Fixed
- fMHA: Fixed the backward pass on Sm86/Sm89 GPUs when K > 64 (RTX 3090, RTX 4090, A6000, ...) [#631]
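
A hedged sketch of the configuration this fix concerns, assuming a CUDA device with an Sm86/Sm89 GPU and an illustrative head dimension of 128 (> 64):

```python
import torch
import xformers.ops as xops

# Head dimension 128 (> 64) with gradients enabled exercises the backward
# path addressed by #631 on Sm86/Sm89 GPUs; sizes here are illustrative.
q, k, v = (
    torch.randn(2, 512, 8, 128, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)

out = xops.memory_efficient_attention(q, k, v)
out.sum().backward()  # backward pass through the fMHA kernel
```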