pytorch / torchrec

PyTorch domain library for recommendation systems
https://pytorch.org/torchrec/
BSD 3-Clause "New" or "Revised" License

support stbe length rebatching and remove stbe output padding for MTIA #2523

Open seanx92 opened 1 month ago

seanx92 commented 1 month ago

Summary:

  1. When rebatching STBE lengths without output, the lengths must be a 2D tensor of shape [F x B], so the batches can be concatenated directly along dim 1;
  2. For MTIA inference, if the STBE runs remotely, its output is padded to the max batch size, which breaks the split. In this case, we remove the padding and restore the output to its original size.
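The two summary points can be illustrated with a small, hedged sketch. It is not torchrec's implementation: plain Python lists stand in for tensors, and `rebatch_lengths`, `strip_batch_padding` and `split_rows` are hypothetical helpers. Their PyTorch counterparts would be `torch.cat(batches, dim=1)`, slicing `padded[:batch_size]`, and `torch.split(t, split_sizes)`. The padding layout (real samples first, padding appended along the batch dimension) is an assumption.

```python
from typing import List

# Hypothetical stand-in for a [F x B] lengths tensor: F rows (one per feature),
# each holding B per-sample lengths.
Lengths2D = List[List[int]]
Rows = List[List[float]]


def rebatch_lengths(batches: List[Lengths2D]) -> Lengths2D:
    """Concatenate [F x B_i] length tensors along dim 1 (the batch dim).

    All inputs must share the same F; the result is [F x sum(B_i)].
    Mirrors torch.cat(batches, dim=1).
    """
    if not batches:
        raise ValueError("need at least one batch")
    num_features = len(batches[0])
    if any(len(b) != num_features for b in batches):
        raise ValueError("all batches must have the same number of features F")
    # For each feature row, append every batch's per-sample lengths in order.
    return [
        [length for batch in batches for length in batch[f]]
        for f in range(num_features)
    ]


def strip_batch_padding(padded_output: Rows, batch_size: int) -> Rows:
    """Drop rows that were added by padding to the max batch size.

    Assumes (hypothetically) real samples come first, so keeping the first
    `batch_size` rows restores the original size. Mirrors padded[:batch_size].
    """
    if batch_size > len(padded_output):
        raise ValueError("batch_size exceeds the padded length")
    return padded_output[:batch_size]


def split_rows(output: Rows, split_sizes: List[int]) -> List[Rows]:
    """Split rows into consecutive chunks of the given sizes.

    Like torch.split with a list of sizes, the sizes must sum exactly to the
    number of rows, which is why a padded output cannot be split.
    """
    if sum(split_sizes) != len(output):
        raise ValueError("split sizes must sum to the number of rows")
    chunks, start = [], 0
    for size in split_sizes:
        chunks.append(output[start:start + size])
        start += size
    return chunks
```

For example, `rebatch_lengths([[[1, 2], [0, 3]], [[4], [5]]])` joins an F=2, B=2 batch with an F=2, B=1 batch into `[[1, 2, 4], [0, 3, 5]]`. A 5-row output padded from a real batch of 3 fails in `split_rows(..., [2, 1])` until `strip_batch_padding` trims it back to 3 rows.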

Differential Revision: D64914077

facebook-github-bot commented 1 month ago

This pull request was exported from Phabricator. Differential Revision: D64914077