pytorch / torchrec

PyTorch domain library for recommendation systems
BSD 3-Clause "New" or "Revised" License

[Question] Is there FP8 embeddings support for training? #2264

Open ShijieZZZZ opened 1 month ago

ShijieZZZZ commented 1 month ago

Hello, it looks like EmbeddingBagCollection forces the data type to be float32 or float16 during initialization: https://github.com/pytorch/torchrec/blob/main/torchrec/modules/embedding_modules.py#L179

Is there any support for FP8 embeddings? Note that this is for training.

[Screenshot 2024-07-31 160639]
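For reference, a minimal repro of the restriction (a sketch: I am assuming the constructor raises on an FP8 table, and the exact exception type and message may differ):

import torch
import torchrec
from torchrec.modules.embedding_configs import DataType

# Try to build an EmbeddingBagCollection backed by a single FP8 table.
# Expected to fail at init, since only FP32/FP16 appear to be accepted.
try:
    ebc = torchrec.EmbeddingBagCollection(
        device=torch.device("cpu"),
        tables=[
            torchrec.EmbeddingBagConfig(
                name="t1",
                embedding_dim=8,
                num_embeddings=32,
                feature_names=["f1"],
                pooling=torchrec.PoolingType.SUM,
                data_type=DataType.FP8,
            )
        ],
    )
except Exception as err:
    print(f"FP8 table rejected at init: {err}")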

PaulZhang12 commented 1 month ago

Doesn't look like it, no. What is your use case? Feel free to put up a pull request.

ShijieZZZZ commented 1 month ago

Hello @PaulZhang12, thanks for your reply. The use case is ordinary deep-learning recommendation model training with all embeddings in FP8 format. The reason I do not use FP32 or FP16 embeddings is that I want to save memory. A simple example is below:

import torch
import torchrec
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

class MyModel(torch.nn.Module):
    def __init__(self, input_size: int, output_size: int):
        super().__init__()

        # Dense layer applied to the concatenated pooled embeddings.
        self.linear = torch.nn.Linear(input_size, output_size)
        # Both tables request FP8 storage; this is exactly where the
        # FP32/FP16 restriction described above gets in the way.
        self.ebc = torchrec.EmbeddingBagCollection(
            device=torch.device("cpu"),
            tables=[
                torchrec.EmbeddingBagConfig(
                    name="t1",
                    embedding_dim=8,
                    num_embeddings=32,
                    feature_names=["f1"],
                    pooling=torchrec.PoolingType.SUM,
                    data_type=torchrec.modules.embedding_configs.DataType.FP8,
                ),
                torchrec.EmbeddingBagConfig(
                    name="t2",
                    embedding_dim=8,
                    num_embeddings=32,
                    feature_names=["f2"],
                    pooling=torchrec.PoolingType.SUM,
                    data_type=torchrec.modules.embedding_configs.DataType.FP8,
                ),
            ],
        )

    def forward(self, kjt):
        # EmbeddingBagCollection returns a KeyedTensor of pooled
        # embeddings, one entry per feature name.
        embeddings = self.ebc(kjt)
        pooled = [embeddings["f1"], embeddings["f2"]]

        cat = torch.cat(pooled, dim=1)
        output = self.linear(cat)
        return output

# Training

model = MyModel(input_size=16, output_size=1)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(1000):
    optimizer.zero_grad()
    # Two features, batch size 2 each: lengths [2, 2] belong to "f1" and
    # [1, 3] to "f2"; values are random row ids in [0, 32).
    kjt = KeyedJaggedTensor(
        keys=["f1", "f2"],
        values=torch.randint(0, 32, (8,)),
        lengths=torch.tensor([2, 2, 1, 3]),
    )

    prediction = model(kjt)
    # Random binary targets; randint's upper bound is exclusive, so
    # (0, 2) yields 0s and 1s (the original (0, 1) only ever gave 0).
    target = torch.randint(0, 2, (2, 1))
    loss = loss_fn(prediction, target.float())
    loss.backward()
    optimizer.step()
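
To give a sense of the memory at stake, here is a rough back-of-the-envelope estimate (a sketch with an assumed table size; it ignores optimizer state and any FP8 scale/offset metadata):

# Rough per-table memory: num_embeddings x embedding_dim elements,
# at 4 bytes (FP32), 2 bytes (FP16), or 1 byte (FP8).
def table_bytes(num_embeddings: int, embedding_dim: int, bytes_per_elem: int) -> int:
    return num_embeddings * embedding_dim * bytes_per_elem

num_embeddings, dim = 100_000_000, 128
for name, size in [("FP32", 4), ("FP16", 2), ("FP8", 1)]:
    print(f"{name}: {table_bytes(num_embeddings, dim, size) / 2**30:.1f} GiB")
# FP32: 47.7 GiB, FP16: 23.8 GiB, FP8: 11.9 GiB for this one table.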