Open newsbreak-tianxiang opened 1 year ago
When using the Adam optimizer, the inconsistent results occur in earlier training iterations than with SGD:
0
1
2
3
Traceback (most recent call last):
File "debug_torchrec_optim.py", line 166, in <module>
test_dmp_optim()
File "debug_torchrec_optim.py", line 161, in test_dmp_optim
torch.testing.assert_close(y, torchrec_y)
File "/home/tianxiangpan/venv/lib/python3.8/site-packages/torch/testing/_comparison.py", line 1511, in assert_close
raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!
Mismatched elements: 188 / 8192 (2.3%)
Greatest absolute difference: 1.5497207641601562e-05 at index (658, 1) (up to 1e-05 allowed)
Greatest relative difference: 0.005986993705095808 at index (1408, 1) (up to 1.3e-06 allowed)
@YLGH @henrylhtsang @joshuadeng
@newsbreak-tianxiang Finally I might have an answer for you.
We can actually reduce the code a bit:
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torchrec
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.distributed.model_parallel import DistributedModelParallel
from fbgemm_gpu.split_embedding_configs import EmbOptimType
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torch.distributed.elastic.utils.distributed import get_free_port
# Reproducibility setup: print tensors with 10 decimal digits, seed the CPU
# generator and every CUDA generator with the same value, and ask cuDNN for
# deterministic kernels with autotuning disabled.
seed = 12345
torch.set_printoptions(precision=10)

torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
class Model(nn.Module):
    """Sum-pooled embedding lookup for feature ``f1`` followed by a linear head.

    The embedding stage is either a plain ``nn.EmbeddingBag`` or a TorchRec
    ``EmbeddingBagCollection`` holding a single table named ``f1_table``;
    both use 20 rows of dimension 8. The linear head maps the 8-dim pooled
    embedding to one output and is moved to CUDA at construction time.

    Args:
        use_torchrec: Build the TorchRec embedding stage instead of the
            plain PyTorch one.
        *args, **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, use_torchrec=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.use_torchrec = use_torchrec
        if not self.use_torchrec:
            self.f1_embeddingBag = nn.EmbeddingBag(20, 8, mode="sum")
        else:
            f1_table = torchrec.EmbeddingBagConfig(
                name="f1_table",
                num_embeddings=20,
                embedding_dim=8,
                feature_names=["f1"],
                pooling=torchrec.PoolingType.SUM,
            )
            # The collection is created on the "meta" device.
            self.f_ebc = EmbeddingBagCollection(
                tables=[f1_table],
                device=torch.device("meta"),
            )
        self.linear = nn.Linear(8, 1).cuda()

    def forward(self, x):
        """Pool the ids in ``x[:, 0, :]`` and apply the linear head.

        Args:
            x: Integer id tensor indexed as (batch_size, 1, pool_size); the
                script feeds shape (batch_size, 1, 10).

        Returns:
            Output of ``self.linear`` applied to the pooled embeddings.
        """
        ids = x[:, 0, :]
        if not self.use_torchrec:
            pooled = self.f1_embeddingBag(ids)
        else:
            n_bags = x.shape[0]
            bag_len = x.shape[2]
            # One bag per sample, every bag holding bag_len ids.
            sparse_input = torchrec.KeyedJaggedTensor(
                keys=["f1"],
                values=torch.cat([ids.flatten()]).cuda(),
                lengths=torch.tensor([bag_len] * n_bags).cuda(),
            )
            pooled = self.f_ebc(sparse_input)["f1"]
        x_embedding = torch.cat([pooled], dim=1)
        # Alternative head left from debugging: torch.sum(x_embedding, dim=1)
        return self.linear(x_embedding)
def init_dist():
    """Set up a single-process distributed environment on the local GPU.

    Writes the rendezvous and rank environment variables for a world of
    size one, selects the CUDA device matching ``LOCAL_RANK``, and creates
    the default NCCL process group unless one is already initialized.
    """
    # Environment for a one-process "cluster" on this host.
    os.environ.update(
        {
            "MASTER_ADDR": str("localhost"),
            "MASTER_PORT": str(get_free_port()),
            "LOCAL_RANK": "0",
            "RANK": "0",
            "WORLD_SIZE": "1",
            "LOCAL_WORLD_SIZE": "1",
        }
    )

    # Bind this process to the GPU indexed by its local rank.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(torch.device(f"cuda:{local_rank}"))

    if torch.distributed.is_initialized():
        return
    dist.init_process_group(backend="nccl")
# Bring up the process group, then build the two models under comparison.
init_dist()
device = torch.device("cuda:0")

# TorchRec variant: the embedding table is sharded by DistributedModelParallel
# and updated by the optimizer described in fused_params (exact SGD, lr 0.01).
torchrec_model = Model(use_torchrec=True)
fused_params = dict(
    optimizer=EmbOptimType.EXACT_SGD,
    learning_rate=0.01,
    momentum=0.0,
)
torchrec_model = DistributedModelParallel(
    torchrec_model,
    init_data_parallel=False,
    sharders=[
        EmbeddingBagCollectionSharder(fused_params=fused_params)  # pyre-ignore
    ],
    device=device,
)
torchrec_model.init_data_parallel()

# Plain PyTorch reference model.
model = Model(use_torchrec=False).to(device)

# init the torchrec model with the same weights as the model
reference_state = model.state_dict()
torchrec_model.module.f_ebc.embedding_bags.f1_table.load_state_dict(
    {"weight": reference_state["f1_embeddingBag.weight"]}
)
torchrec_model.module.linear.load_state_dict(model.linear.state_dict())

criterion = torch.nn.NLLLoss()
torchrec_criterion = torch.nn.NLLLoss()

# The reference optimizer covers every parameter of the plain model; the
# torchrec-side optimizer is given only the parameters whose name contains
# "linear".
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.0)
linear_params = [
    param
    for name, param in dict(torchrec_model.named_parameters()).items()
    if "linear" in name
]
torchrec_optim = torch.optim.SGD(linear_params, lr=0.01, momentum=0.0)

# Both models must start from bit-identical weights; on failure the assertion
# message is the pair of mismatching tensors.
for torchrec_tensor, reference_tensor in (
    (torchrec_model.module.linear.weight, model.linear.weight),
    (torchrec_model.module.linear.bias, model.linear.bias),
    (
        torchrec_model.module.f_ebc.embedding_bags.f1_table.weight.data,
        model.f1_embeddingBag.weight.data,
    ),
):
    assert torch.equal(torchrec_tensor, reference_tensor), (
        torchrec_tensor,
        reference_tensor,
    )
# Train both models in lockstep on identical random batches and stop at the
# first iteration where any prediction, parameter, or gradient differs bitwise.
batch_size = 20
for i in range(100):
    print(i)
    # generate random data
    x = torch.randint(0, 20, (batch_size, 1, 10)).cuda()
    # randint's upper bound is exclusive, so every label is 0 -- the only
    # class index available for the single-column output of the linear head.
    label = torch.randint(0, 1, (batch_size,)).cuda()

    # Count repeated ids by value. Tensors hash by identity, so building a
    # Python set from tensor elements never collapses equal values and would
    # always report zero duplicates.
    num_duplicates = x.numel() - torch.unique(x).numel()
    if num_duplicates:
        print(f"number of duplicates = {num_duplicates}")
    # duplicate values
    # x[:1024, 0, :] = 5
    # x[:1024, 1, :] = 10

    # model forward
    optim.zero_grad()
    y = model(x)
    loss = criterion(y, label)
    loss.backward()
    optim.step()

    # torchrec model forward
    torchrec_optim.zero_grad()
    torchrec_y = torchrec_model(x)
    torchrec_loss = torchrec_criterion(torchrec_y, label)
    torchrec_loss.backward()
    torchrec_optim.step()

    # Bitwise comparisons, ordered from outputs to dense parameters to the
    # embedding table.
    assert torch.equal(y, torchrec_y), "prediction don't match"
    assert torch.equal(
        torchrec_model.module.linear.weight, model.linear.weight
    ), "linear weight doesn't match"
    assert torch.equal(
        torchrec_model.module.linear.bias, model.linear.bias
    ), "linear bias doesn't match"
    assert torch.equal(
        torchrec_model.module.linear.weight.grad, model.linear.weight.grad
    ), "linear weight grad doesn't match"
    assert torch.equal(
        torchrec_model.module.linear.bias.grad, model.linear.bias.grad
    ), "linear bias grad doesn't match"
    assert torch.equal(
        torchrec_model.module.f_ebc.embedding_bags.f1_table.weight.data,
        model.f1_embeddingBag.weight.data,
    ), "embedding table weights don't match"
The main difference is reducing the batch size. Then what I saw is:
0
1
2
3
4
embedding table weights don't match
In other words, even without the duplicate inputs, the results are identical for 4 iterations and then diverge.
After consulting with FBGEMM team, it turns out this is somewhat expected, since they involve accumulation. "The order of arithmetic operations can be different between the two implementations. (The order of arithmetics affects the floating point rounding results)"
I am attempting to convert an original PyTorch model into its torchrec version. During the experiment, I noticed a clear inconsistency between the results of the torchrec optimizer and those of the original PyTorch version built on torch.nn.EmbeddingBag when there are many duplicate inputs. I would like to ask whether such results are expected, how to check whether they are reasonable, and how to fix them if they are not. The code below reproduces the inconsistent results with PyTorch 2.0.1 and torchrec 0.4.0.