New Features
Inference Table Batched Embedding (TBE) Enhancements (#951, #984)
The table batched embedding (TBE) operator is an important base operation for embedding lookup for recommendation system inference on GPU. We added the following enhancements for performance and flexibility:
- Alignment restriction removed: Previously, embedding dimension * data type size had to be a multiple of 4B. Now it only has to be a multiple of 1B.
- The overhead of the UVM caching kernels now scales linearly with the number of tables that use UVM caching. Previously, the overhead was similar to that of all tables using UVM caching, regardless of how many actually did.
- UVM caching kernel overhead is much smaller than before.
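To make the alignment change concrete, here is a minimal sketch that reads the rule above literally (row size = embedding dimension * data type size). The helper name `row_is_aligned` and the bit-width table are illustrative assumptions, not part of the FBGEMM API, and the sketch ignores any per-row metadata a real layout may add.

```python
# Hypothetical helper: checks the stated alignment rule, nothing more.
# Bit widths for the weight types listed in these notes.
DTYPE_BITS = {"FP32": 32, "FP16": 16, "INT8": 8, "INT4": 4, "INT2": 2}


def row_is_aligned(dim, dtype, alignment_bytes):
    """Return True if dim * dtype size is a multiple of alignment_bytes."""
    row_bits = dim * DTYPE_BITS[dtype]
    return row_bits % (alignment_bytes * 8) == 0


# A 6-wide INT8 row is 6 bytes: rejected under the old 4B rule,
# accepted under the new 1B rule.
old_ok = row_is_aligned(6, "INT8", alignment_bytes=4)
new_ok = row_is_aligned(6, "INT8", alignment_bytes=1)
```

Under this reading, sub-byte types still need a whole number of bytes per row, so a 3-wide INT4 row (12 bits) would not qualify even with 1B alignment.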
Inference FP8 Table Batched Embedding (TBE) (#1091)
The table batched embedding (TBE) operator previously supported FP32, FP16, INT8, INT4, and INT2 embedding weight types. While these weight types work well in many models, we integrated FP8 weight types (in both GPU and CPU operations) to allow for numerical and performance evaluations of FP8 in our models. Compared to INT8, FP8 does not require the additional bias and scale storage and calculations. Additionally, next-generation H100 GPUs support FP8 on Tensor Cores (mainly for matmul ops).
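The bias and scale mentioned above come from how INT8 rows are typically quantized. The sketch below is a generic row-wise min/max scheme in plain Python, assumed here for illustration only; it is not FBGEMM's kernel, and the function names are hypothetical. It shows that every INT8 row carries two extra values that must be stored and applied on lookup, which an FP8 row does not need.

```python
# Generic row-wise INT8 quantization (illustrative assumption, not FBGEMM code).
def quantize_row_int8(row):
    """Map a row of floats to codes in [0, 255] plus a per-row scale and bias."""
    lo, hi = min(row), max(row)
    scale = (hi - lo) / 255.0 or 1.0  # fall back to 1.0 for a constant row
    bias = lo
    codes = [round((x - bias) / scale) for x in row]
    return codes, scale, bias


def dequantize_row_int8(codes, scale, bias):
    """Recover approximate floats; needs the stored scale and bias."""
    return [c * scale + bias for c in codes]


row = [-1.0, 0.25, 0.5, 3.0]
codes, scale, bias = quantize_row_int8(row)
restored = dequantize_row_int8(codes, scale, bias)
```

With FP8 weights, each element is itself a small floating-point number, so the lookup can convert elements directly without the per-row `scale` and `bias`.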
Jagged Tensor Kernels (#1006, #1008)
We added optimized kernels to speed up TorchRec Jagged Tensor. The purpose of JaggedTensor is to handle the case where one dimension of the input data is “jagged”, meaning that each consecutive row in a given dimension may be a different length, which is often the case with sparse feature inputs in recommendation systems.
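As a rough illustration of what "jagged" means, the sketch below stores variable-length rows as one flat `values` list plus an `offsets` list, and converts them to a padded dense layout. This is a plain-Python reference written for these notes; the function name and signature are assumptions and do not correspond to a specific FBGEMM operator.

```python
# Illustrative reference only: rows of different lengths stored as
# a flat values list plus offsets marking where each row starts and ends.
def jagged_to_padded_dense(values, offsets, max_len, pad=0):
    """Expand jagged rows into fixed-width rows, truncating or padding."""
    rows = []
    for start, end in zip(offsets[:-1], offsets[1:]):
        row = values[start:end][:max_len]  # truncate rows that are too long
        rows.append(row + [pad] * (max_len - len(row)))  # pad short rows
    return rows


# Three rows with lengths 2, 0, and 4.
values = [1, 2, 3, 4, 5, 6]
offsets = [0, 2, 2, 6]
dense = jagged_to_padded_dense(values, offsets, max_len=3)
```

The jagged form avoids storing the padding, which matters when most sparse features have only a few entries per sample.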
Optimized permute102-baddbmm-permute102 (#1048)
It is difficult to fuse multiple matrix multiplications when their batch dimension is not the batch size of the model. Switching the batch dimension is a quick solution. We created the permute102_baddbmm_permute102 operation, which switches the first and second dimensions, performs the batched matrix multiplication, and then switches them back. Currently, we only support the forward pass with the FP16 data type; we will support FP32 and the backward pass in the future.
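The following plain-Python reference sketches the semantics described above using nested lists. The argument shapes (`a` as [B][T][M], `b` as [T][M][N], `bias` as [T][N]) are assumptions made for this example, and the helper names are hypothetical; the fused operator performs the same three steps in one kernel rather than as separate calls.

```python
# Unfused reference of permute(1, 0, 2) -> baddbmm -> permute(1, 0, 2).
# Shapes are assumed for illustration: a [B][T][M], b [T][M][N], bias [T][N].
def permute102(x):
    """Swap the first two dimensions of a 3-D nested list."""
    d0, d1 = len(x), len(x[0])
    return [[x[i][j] for i in range(d0)] for j in range(d1)]


def baddbmm(bias, batch1, batch2):
    """For each batch t: out[t] = bias[t] (broadcast over rows) + batch1[t] @ batch2[t]."""
    out = []
    for t, (x, w) in enumerate(zip(batch1, batch2)):
        inner, n = len(w), len(w[0])
        out.append([
            [bias[t][j] + sum(x_row[k] * w[k][j] for k in range(inner))
             for j in range(n)]
            for x_row in x
        ])
    return out


def permute102_baddbmm_permute102_ref(bias, a, b):
    """[B][T][M] input -> [B][T][N] output, multiplying per t instead of per sample."""
    return permute102(baddbmm(bias, permute102(a), b))


a = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # B=2, T=2, M=2
b = [[[1], [1]], [[2], [0]]]              # T=2, M=2, N=1
bias = [[10], [20]]                       # T=2, N=1
result = permute102_baddbmm_permute102_ref(bias, a, b)
```

Here the T dimension plays the role of the batch inside the multiplication, while the model's batch dimension B is restored in the output.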
Optimized index_select for dim 0 index selection (#1113)
index_select is normally used as part of a sparse operation. While PyTorch supports a generic index_select for selection along an arbitrary dimension, its performance for a special case like dim 0 index selection is suboptimal. For this reason, we implemented a specialized index_select for dim 0. In some cases, we have observed a 1.4x performance gain from FBGEMM’s index_select compared to the one from PyTorch (using a uniform index distribution).
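For reference, dim 0 index selection simply gathers whole rows by index. The sketch below is a plain-Python statement of that behavior with a hypothetical function name; it says nothing about how the optimized kernel is implemented.

```python
# Reference semantics only: out[i] = rows[indices[i]].
def index_select_dim0(rows, indices):
    """Gather rows along dim 0; indices may repeat and appear in any order."""
    return [list(rows[i]) for i in indices]


table = [[1, 2], [3, 4], [5, 6]]
picked = index_select_dim0(table, [2, 0, 2])
```

Because each selected row is contiguous in memory for dim 0, this case lends itself to a specialized kernel, which is the case the optimization above targets.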
Full Changelog: https://github.com/pytorch/FBGEMM/commits/v0.2.0