pytorch / torchrec

PyTorch domain library for recommendation systems
https://pytorch.org/torchrec/
BSD 3-Clause "New" or "Revised" License

Inconsistent training result when getting duplicate inputs #1298

Open · newsbreak-tianxiang opened this issue 1 year ago

newsbreak-tianxiang commented 1 year ago

I am attempting to convert an existing PyTorch model to a torchrec version. During the experiment, I noticed a clear inconsistency between the results of the torchrec model and the original torch.nn.EmbeddingBag / plain PyTorch version when there are many duplicate inputs. I would like to ask whether such results are expected, and how to fix them or at least verify that they are reasonable. The code below reproduces the inconsistency with PyTorch 2.0.1 and torchrec 0.4.0.

import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torchrec
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.optim.keyed import KeyedOptimizerWrapper
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.optimizers import in_backward_optimizer_filter

class Model(nn.Module):
    def __init__(self, use_torchrec=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.use_torchrec = use_torchrec

        if self.use_torchrec:
            self.f_ebc = EmbeddingBagCollection(
                tables = [
                    torchrec.EmbeddingBagConfig(name='f1_table', num_embeddings=100, embedding_dim=8, feature_names=["f1"], pooling=torchrec.PoolingType.SUM),
                    torchrec.EmbeddingBagConfig(name='f2_table', num_embeddings=100, embedding_dim=8, feature_names=["f2"], pooling=torchrec.PoolingType.SUM)
                ],
                device=torch.device("meta"),
            )
        else:
            self.f1_embeddingBag = nn.EmbeddingBag(100, 8, mode='sum')
            self.f2_embeddingBag = nn.EmbeddingBag(100, 8, mode='sum')

        self.linear = nn.Linear(16, 2).cuda()

    def forward(self, x):
        # x: (batch_size, 2, 10)
        batch_size = x.shape[0]
        pool_size = x.shape[2]

        f1 = x[:, 0, :]
        f2 = x[:, 1, :]

        if self.use_torchrec:
            ebc_values = torch.cat([f1.flatten(), f2.flatten()]).long().cuda()
            ebc_length = torch.tensor([pool_size]*batch_size*2).long().cuda()
            kjt_input = torchrec.KeyedJaggedTensor(
                keys = ['f1', 'f2'],
                values=ebc_values,
                lengths=ebc_length,
            )
            ebc_result = self.f_ebc(kjt_input)
            x_embedding = torch.cat([ebc_result['f1'], ebc_result['f2']], dim=1)
        else:
            f1 = self.f1_embeddingBag(f1)
            f2 = self.f2_embeddingBag(f2)
            x_embedding = torch.cat([f1, f2], dim=1)

        y = self.linear(x_embedding)
        return y

def init_dist():

    # init the distributed environment
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "30001"
    os.environ["NODE_RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["LOCAL_WORLD_SIZE"] = "1"

    # init the process group
    rank = int(os.environ.get("LOCAL_RANK", "0"))
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)

    if not torch.distributed.is_initialized():
        dist.init_process_group(backend=backend)

def test_dmp_optim():
    init_dist()
    device = torch.device("cuda:0")

    model = Model(use_torchrec=False).to(device)
    torchrec_model = Model(use_torchrec=True)

    apply_optimizer_in_backward(
        optimizer_class=torch.optim.SGD,
        params=list(torchrec_model.f_ebc.parameters()),
        optimizer_kwargs={
            "lr": 0.01,
            "momentum": 0.0,
        },
    )
    torchrec_model = DistributedModelParallel(torchrec_model, device=device)

    # init the torchrec model with the same weights as the model
    torchrec_model.module.f_ebc.embedding_bags.f1_table.load_state_dict({'weight': model.state_dict()['f1_embeddingBag.weight']})
    torchrec_model.module.f_ebc.embedding_bags.f2_table.load_state_dict({'weight': model.state_dict()['f2_embeddingBag.weight']})
    torchrec_model.module.linear.load_state_dict(model.linear.state_dict())

    criterion = torch.nn.NLLLoss()
    optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.0)
    torchrec_optim = KeyedOptimizerWrapper(
        {k: v for k, v in dict(in_backward_optimizer_filter(
            torchrec_model.named_parameters())).items()},
        lambda params: torch.optim.SGD(params, lr=0.01, momentum=0.0),
    )

    batch_size = 4096
    for i in range(100):
        print(i)

        # generate random data
        x = torch.randint(0, 100, (batch_size, 2, 10)).cuda()
        label = torch.randint(0, 2, (batch_size,)).cuda()

        # duplicate values
        x[:1024, 0, :] = 5
        x[:1024, 1, :] = 10

        # model forward
        optim.zero_grad()
        y = model(x)
        loss = criterion(y, label)
        loss.backward()
        optim.step()

        # torchrec model forward
        torchrec_optim.zero_grad()
        torchrec_y = torchrec_model(x)
        torchrec_loss = criterion(torchrec_y, label)
        torchrec_loss.backward()
        torchrec_optim.step()

        torch.testing.assert_close(y, torchrec_y)

if __name__ == "__main__":
    test_dmp_optim()
newsbreak-tianxiang commented 1 year ago

When using the Adam optimizer, the inconsistency shows up at earlier training iterations than with SGD (a sketch of the Adam setup is included after the traceback):

0
1
2
3
Traceback (most recent call last):
  File "debug_torchrec_optim.py", line 166, in <module>
    test_dmp_optim()
  File "debug_torchrec_optim.py", line 161, in test_dmp_optim
    torch.testing.assert_close(y, torchrec_y)
  File "/home/tianxiangpan/venv/lib/python3.8/site-packages/torch/testing/_comparison.py", line 1511, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 188 / 8192 (2.3%)
Greatest absolute difference: 1.5497207641601562e-05 at index (658, 1) (up to 1e-05 allowed)
Greatest relative difference: 0.005986993705095808 at index (1408, 1) (up to 1.3e-06 allowed)
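
For reference, switching the repro above from SGD to Adam amounts to roughly the following (only these lines change relative to the SGD version; the lr value here is just illustrative):

# Sketch: Adam variant of the optimizer setup in test_dmp_optim().
# Only these lines differ from the SGD version above; lr is illustrative.
apply_optimizer_in_backward(
    optimizer_class=torch.optim.Adam,
    params=list(torchrec_model.f_ebc.parameters()),
    optimizer_kwargs={"lr": 0.01},
)

optim = torch.optim.Adam(model.parameters(), lr=0.01)
torchrec_optim = KeyedOptimizerWrapper(
    dict(in_backward_optimizer_filter(torchrec_model.named_parameters())),
    lambda params: torch.optim.Adam(params, lr=0.01),
)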
newsbreak-tianxiang commented 1 year ago

@YLGH @henrylhtsang @joshuadeng

henrylhtsang commented 8 months ago

@newsbreak-tianxiang Finally I might have an answer for you.

We can actually reduce the code a bit:

import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torchrec
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.distributed.model_parallel import DistributedModelParallel
from fbgemm_gpu.split_embedding_configs import EmbOptimType

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

from torch.distributed.elastic.utils.distributed import get_free_port

torch.set_printoptions(precision=10)

seed = 12345
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class Model(nn.Module):
    def __init__(self, use_torchrec=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.use_torchrec = use_torchrec

        if self.use_torchrec:
            self.f_ebc = EmbeddingBagCollection(
                tables=[
                    torchrec.EmbeddingBagConfig(
                        name="f1_table",
                        num_embeddings=20,
                        embedding_dim=8,
                        feature_names=["f1"],
                        pooling=torchrec.PoolingType.SUM,
                    ),
                ],
                device=torch.device("meta"),
            )
        else:
            self.f1_embeddingBag = nn.EmbeddingBag(20, 8, mode="sum")

        self.linear = nn.Linear(8, 1).cuda()

    def forward(self, x):
        # x: (batch_size, 1, 10)
        batch_size = x.shape[0]
        pool_size = x.shape[2]

        f1 = x[:, 0, :]

        if self.use_torchrec:
            ebc_values = torch.cat([f1.flatten()]).cuda()
            ebc_length = torch.tensor([pool_size] * batch_size).cuda()
            kjt_input = torchrec.KeyedJaggedTensor(
                keys=["f1"],
                values=ebc_values,
                lengths=ebc_length,
            )
            ebc_result = self.f_ebc(kjt_input)
            x_embedding = torch.cat([ebc_result["f1"]], dim=1)
        else:
            f1 = self.f1_embeddingBag(f1)
            x_embedding = torch.cat([f1], dim=1)

        y = self.linear(x_embedding)
        # y = torch.sum(x_embedding, dim=1)
        return y

def init_dist():
    # init the distributed environment
    os.environ["MASTER_ADDR"] = str("localhost")
    os.environ["MASTER_PORT"] = str(get_free_port())
    os.environ["LOCAL_RANK"] = "0"
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["LOCAL_WORLD_SIZE"] = "1"

    # init the process group
    rank = int(os.environ.get("LOCAL_RANK", "0"))
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)

    if not torch.distributed.is_initialized():
        dist.init_process_group(backend=backend)

init_dist()
device = torch.device("cuda:0")

torchrec_model = Model(use_torchrec=True)

fused_params = {
    "optimizer": EmbOptimType.EXACT_SGD,
    "learning_rate": 0.01,
    "momentum": 0.0,
}
torchrec_model = DistributedModelParallel(
    torchrec_model,
    init_data_parallel=False,
    sharders=[EmbeddingBagCollectionSharder(fused_params=fused_params)],  # pyre-ignore
    device=device,
)
torchrec_model.init_data_parallel()

model = Model(use_torchrec=False).to(device)

# init the torchrec model with the same weights as the model
torchrec_model.module.f_ebc.embedding_bags.f1_table.load_state_dict(
    {"weight": model.state_dict()["f1_embeddingBag.weight"]}
)
torchrec_model.module.linear.load_state_dict(model.linear.state_dict())

criterion = torch.nn.NLLLoss()
torchrec_criterion = torch.nn.NLLLoss()
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.0)
torchrec_optim = torch.optim.SGD(
    [v for k, v in dict(torchrec_model.named_parameters()).items() if "linear" in k],
    lr=0.01,
    momentum=0.0,
)

assert torch.equal(torchrec_model.module.linear.weight, model.linear.weight), (
    torchrec_model.module.linear.weight,
    model.linear.weight,
)
assert torch.equal(torchrec_model.module.linear.bias, model.linear.bias), (
    torchrec_model.module.linear.bias,
    model.linear.bias,
)
assert torch.equal(
    torchrec_model.module.f_ebc.embedding_bags.f1_table.weight.data,
    model.f1_embeddingBag.weight.data,
), (
    torchrec_model.module.f_ebc.embedding_bags.f1_table.weight.data,
    model.f1_embeddingBag.weight.data,
)

batch_size = 20
for i in range(100):
    print(i)

    # generate random data
    x = torch.randint(0, 20, (batch_size, 1, 10)).cuda()
    label = torch.randint(0, 1, (batch_size,)).cuda()

    num_duplicates = len(list(x.flatten())) - len(set(x.flatten()))
    if num_duplicates:
        print(f"number of duplicates = {num_duplicates}")

    # duplicate values
    # x[:1024, 0, :] = 5
    # x[:1024, 1, :] = 10

    # model forward
    optim.zero_grad()
    y = model(x)
    loss = criterion(y, label)
    loss.backward()
    optim.step()

    # torchrec model forward
    torchrec_optim.zero_grad()
    torchrec_y = torchrec_model(x)
    torchrec_loss = torchrec_criterion(torchrec_y, label)
    torchrec_loss.backward()
    torchrec_optim.step()

    assert torch.equal(y, torchrec_y), "predictions don't match"

    assert torch.equal(
        torchrec_model.module.linear.weight, model.linear.weight
    ), "linear weight doesn't match"
    assert torch.equal(
        torchrec_model.module.linear.bias, model.linear.bias
    ), "linear bias doesn't match"
    assert torch.equal(
        torchrec_model.module.linear.weight.grad, model.linear.weight.grad
    ), "linear weight grad doesn't match"
    assert torch.equal(
        torchrec_model.module.linear.bias.grad, model.linear.bias.grad
    ), "linear bias doesn't match"
    assert torch.equal(
        torchrec_model.module.f_ebc.embedding_bags.f1_table.weight.data,
        model.f1_embeddingBag.weight.data,
    ), "embedding table weights don't match"

The main difference is the reduced batch size. With that change, what I saw is:

0
1
2
3
4
embedding table weights don't match

In other words, even without the duplicate inputs, the results are identical for 4 iterations before they diverge.
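
If the goal is only to check that the torchrec path stays numerically close rather than bit-identical, a tolerance-based comparison is probably more meaningful than torch.equal; something like the following (the atol/rtol values are only illustrative, not anything torchrec prescribes):

# Compare with a small tolerance so floating-point rounding differences are
# not flagged as real divergence. The tolerances here are only a suggestion.
torch.testing.assert_close(
    torchrec_model.module.f_ebc.embedding_bags.f1_table.weight.data,
    model.f1_embeddingBag.weight.data,
    atol=1e-5,
    rtol=1e-4,
)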

After consulting with the FBGEMM team, it turns out this is somewhat expected, since the implementations involve accumulation: "The order of arithmetic operations can be different between the two implementations. (The order of arithmetic affects the floating point rounding results.)"
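
A tiny, torchrec-independent illustration of that point: the same numbers accumulated in a different order can round differently.

import torch

# Float addition is not associative: the same three numbers, grouped
# differently, round to different results.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False

# The same effect shows up when a reduction groups elements differently.
torch.manual_seed(0)
vals = torch.randn(10_000, dtype=torch.float32)
whole = vals.sum()
halves = vals[:5_000].sum() + vals[5_000:].sum()
print(whole.item(), halves.item())  # typically differ in the last few bits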