Changelog
Bugs
- The flash attention module was missing from the original codebase, which caused a module-not-found error at runtime.
- Flash attention was incorrectly integrated with the main attention module: the `forward` method of the `Attend` class did not handle the flash attention path correctly.
- The `flash_attn` method of the `Attend` class made incorrect assumptions about the dimensions of the `k` and `v` tensors, which led to dimension-mismatch errors in the tensor operations.
- The original `flash_attn` method did not handle the scale correctly when `qk_norm` was set to `True`.
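The scale bug can be illustrated with plain arithmetic. One plausible cause is assumed here and is not stated in the changelog. PyTorch 2.0's `F.scaled_dot_product_attention` has no `scale` argument (one was added in 2.1) and always multiplies the logits by `dim ** -0.5`. With `qk_norm`, however, the intended scale is a different constant. Pre-multiplying `q` by `scale / dim ** -0.5` cancels the built-in factor and yields the intended logits. The helpers `sdpa_logit` and `rescale_query`, and the scale value 10.0, are hypothetical and chosen only for illustration.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def sdpa_logit(q, k):
    # Hypothetical stand-in for SDPA's fixed scaling: (q . k) * dim ** -0.5
    return dot(q, k) * len(q) ** -0.5


def rescale_query(q, scale):
    # Fold the desired scale into q so SDPA's built-in factor cancels out.
    default_scale = len(q) ** -0.5
    return [x * (scale / default_scale) for x in q]


q = [0.6, 0.8, 0.0, 0.0]  # unit-norm, as after qk_norm
k = [0.0, 1.0, 0.0, 0.0]
scale = 10.0              # illustrative qk_norm scale (assumption)

naive = sdpa_logit(q, k)                        # 0.8 * 0.5 = 0.4, wrong temperature
fixed = sdpa_logit(rescale_query(q, scale), k)  # 16.0 * 0.5 = 8.0, i.e. 0.8 * scale
print(naive, fixed)
```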
Improvements
- Integrated the flash attention module into the main codebase and ensured the dimensions and operations are correct.
- Modified the `forward` method of the `Attend` class to handle flash attention correctly. It checks whether flash attention is enabled and calls the corresponding attention method.
- Adjusted the `flash_attn` method to handle `q`, `k`, and `v` tensors with missing dimensions and to correct possible dimension mismatches.
- Added a check for whether the tensors are on a CUDA device and, if so, applied the appropriate CUDA configuration for efficient attention.
- Corrected the handling of the scale in the `flash_attn` method when `qk_norm` is `True`.
- Added assertions with informative error messages for incompatible options, such as talking heads combined with flash attention.
- Ensured compatibility with PyTorch 2.0 and above when using flash attention.
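The improvements above can be sketched roughly as follows. This is a minimal, hypothetical `Attend` module, not the project's actual code. It makes three assumptions:

- `q` has shape `(batch, heads, seq, dim)`.
- `k` and `v` may arrive without a heads dimension (for example, when they are shared across heads) and are broadcast.
- `q` and `k` are already l2-normalized upstream when `qk_norm` is on.

The flash path calls PyTorch's `F.scaled_dot_product_attention` and leaves kernel selection to PyTorch, so the CUDA-specific configuration mentioned above is omitted. `torch_at_least`, `einsum_attn`, and the `scale` parameter are illustrative names.

```python
import torch
import torch.nn.functional as F
from torch import nn


def torch_at_least(major: int, minor: int) -> bool:
    # Hypothetical helper: parse e.g. "2.1.0+cu118" into (2, 1) and compare.
    parts = torch.__version__.split("+")[0].split(".")
    return (int(parts[0]), int(parts[1])) >= (major, minor)


class Attend(nn.Module):
    # Minimal sketch, not the project's actual Attend class.
    def __init__(self, causal=False, flash=False, talking_heads=False,
                 qk_norm=False, scale=None):
        super().__init__()
        # Reject incompatible option combinations up front.
        assert not (flash and talking_heads), \
            "talking heads is not compatible with flash attention"
        assert not (flash and not torch_at_least(2, 0)), \
            "flash attention requires PyTorch 2.0 or above"
        assert not (qk_norm and scale is None), \
            "qk_norm requires an explicit scale in this sketch"
        self.causal = causal
        self.flash = flash
        self.qk_norm = qk_norm
        self.scale = scale

    def forward(self, q, k, v):
        # q: (batch, heads, n, dim); k and v: (batch, heads, m, dim) or
        # (batch, m, dim), in which case they are broadcast across heads.
        heads = q.shape[1]
        if k.ndim == 3:
            k = k.unsqueeze(1).expand(-1, heads, -1, -1)
        if v.ndim == 3:
            v = v.unsqueeze(1).expand(-1, heads, -1, -1)
        # Dispatch to the flash path only when it is enabled.
        if self.flash:
            return self.flash_attn(q, k, v)
        return self.einsum_attn(q, k, v)

    def flash_attn(self, q, k, v):
        if self.qk_norm:
            # SDPA applies dim ** -0.5 internally; rescale q so the
            # effective scale is self.scale instead.
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)
        return F.scaled_dot_product_attention(q, k, v, is_causal=self.causal)

    def einsum_attn(self, q, k, v):
        # Reference path: explicit scaled dot-product attention.
        scale = self.scale if self.qk_norm else q.shape[-1] ** -0.5
        sim = torch.einsum("bhid,bhjd->bhij", q, k) * scale
        if self.causal:
            i, j = sim.shape[-2:]
            # Mask keys after each query position (top-left aligned,
            # matching SDPA's is_causal behaviour).
            mask = torch.ones(i, j, dtype=torch.bool, device=q.device).triu(1)
            sim = sim.masked_fill(mask, torch.finfo(sim.dtype).min)
        attn = sim.softmax(dim=-1)
        return torch.einsum("bhij,bhjd->bhid", attn, v)


if __name__ == "__main__":
    torch.manual_seed(0)
    q = F.normalize(torch.randn(2, 4, 5, 16), dim=-1)
    k = F.normalize(torch.randn(2, 5, 16), dim=-1)  # no heads dimension
    v = torch.randn(2, 5, 16)
    ref = Attend(causal=True, qk_norm=True, scale=10.0)(q, k, v)
    out = Attend(causal=True, flash=True, qk_norm=True, scale=10.0)(q, k, v)
    print(out.shape, torch.allclose(ref, out, atol=1e-5))
```

On the same CPU inputs, the two paths should agree up to floating-point tolerance. The einsum path therefore doubles as a reference when testing the flash path.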