facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License

non-determinism when running inference with CPU offloading #297

Closed. walid0925 closed this 2 years ago.

walid0925 commented 2 years ago

Thanks so much for your work in maintaining ESM!

Bug description

I recently started to play around with these models and have noticed non-determinism when using CPU offloading, as shown in the examples/esm2_infer_fairscale_fsdp_cpu_offloading.py script. Some non-determinism is perhaps expected, but not of the magnitude I'm seeing. Please let me know if I've missed anything!

Reproduction steps

This code is almost directly copied from the example script, modified to use only one protein sequence and to print its mean representation. I've also used a smaller model here for speed, though I've seen the same behavior with other models such as esm2_t36_3B_UR50D.

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, wrap

import esm

# init the distributed world with world_size 1
url = "tcp://localhost:23456"
torch.distributed.init_process_group(backend="nccl", init_method=url, world_size=1, rank=0)

# download model data from the hub
model_data, regression_data = esm.pretrained._download_model_and_regression_data(
    "esm2_t33_650M_UR50D",
)
if regression_data is not None:
    model_data["model"].update(regression_data["model"])

# initialize the model with FSDP wrapper
fsdp_params = dict(
    mixed_precision=True,
    flatten_parameters=True,
    state_dict_device=torch.device("cpu"),  # reduce GPU mem usage
    cpu_offload=True,  # enable cpu offloading
)
with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
    model, vocab, _ = esm.pretrained._load_model_and_alphabet_core_v2(model_data)
    batch_converter = vocab.get_batch_converter()
    model.eval()

    # Wrap each layer in FSDP separately
    for name, child in model.named_children():
        if name == "layers":
            for layer_name, layer in child.named_children():
                wrapped_layer = wrap(layer)
                setattr(child, layer_name, wrapped_layer)
    model = wrap(model)

data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"), 
]

batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_tokens = batch_tokens.cuda()
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
    representations = {
        layer: t.to(device="cpu") for layer, t in results["representations"].items()
    }
    mean_representations = {
        layer: t[0, 1 : len(batch_strs[0]) + 1].mean(0).clone()
        for layer, t in representations.items()
    }
print(mean_representations)
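For readers unfamiliar with the slice t[0, 1 : len(batch_strs[0]) + 1]: the ESM-2 batch converter prepends a beginning-of-sequence token and appends an end-of-sequence token, so residue i of the input sits at token position i + 1. The slice therefore averages only over real residues. Below is a minimal pure-Python sketch of that mean pooling on toy numbers (the helper name mean_pool_residues and the toy values are hypothetical, not part of ESM).

```python
from typing import List, Sequence


def mean_pool_residues(token_reprs: Sequence[Sequence[float]], seq_len: int) -> List[float]:
    """Average per-token vectors over positions 1..seq_len (inclusive).

    Position 0 is the prepended beginning-of-sequence token and positions after
    seq_len hold the end-of-sequence token (plus any padding), so both are
    excluded, mirroring t[0, 1 : len(seq) + 1].mean(0) in the repro above.
    """
    residues = token_reprs[1 : seq_len + 1]
    if not residues:
        raise ValueError("seq_len must be at least 1")
    dim = len(residues[0])
    return [sum(vec[d] for vec in residues) / len(residues) for d in range(dim)]


# Toy example: sequence "MKT" -> tokens [BOS, M, K, T, EOS], 2-dim representations
toy = [
    [100.0, 100.0],    # BOS (excluded)
    [1.0, 2.0],        # M
    [3.0, 4.0],        # K
    [5.0, 6.0],        # T
    [-100.0, -100.0],  # EOS (excluded)
]
print(mean_pool_residues(toy, seq_len=3))  # [3.0, 4.0]
```

The large BOS/EOS values in the toy make it obvious if the special tokens leaked into the average.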

Repeated runs produce very different outputs. Here are the outputs from three consecutive runs:

# run 1
{33: tensor([-0.1011,  0.2788, -0.3147,  ..., -1.2031,  0.5732,  0.2917],
       dtype=torch.float16)}
# run 2
{33: tensor([-0.9829, -0.0180, -0.1920,  ...,  0.0896, -1.8984,  1.5332],
       dtype=torch.float16)}
# run 3
{33: tensor([ 1.9287,  0.3608, -0.5444,  ...,  0.7388,  1.6504, -0.4055],
       dtype=torch.float16)}

As you can see, the outputs differ substantially, with different magnitudes and even different signs.

Expected behavior

I would expect consecutive runs to produce approximately similar outputs, even if not exactly the same.
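"Approximately similar" can be made concrete with the element-wise rule torch.allclose uses, |a - b| <= atol + rtol * |b|. The pure-Python sketch below applies that rule, plus a simple sign-agreement fraction, to the only six values visible in the truncated printouts of runs 1 and 2. The tolerances (1e-2, loose because the outputs are float16) and the helper names are assumptions for illustration, and six values are far too few to characterize the full representation.

```python
from typing import Sequence


def approx_equal(a: Sequence[float], b: Sequence[float],
                 rtol: float = 1e-2, atol: float = 1e-2) -> bool:
    """Element-wise |a - b| <= atol + rtol * |b|, the same rule torch.allclose uses."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


def sign_agreement(a: Sequence[float], b: Sequence[float]) -> float:
    """Fraction of positions where a and b share a sign (zero counts as non-negative)."""
    same = sum((x >= 0) == (y >= 0) for x, y in zip(a, b))
    return same / len(a)


# The six values visible in the truncated printouts of runs 1 and 2
run1 = [-0.1011, 0.2788, -0.3147, -1.2031, 0.5732, 0.2917]
run2 = [-0.9829, -0.0180, -0.1920, 0.0896, -1.8984, 1.5332]
print(approx_equal(run1, run2))    # False
print(sign_agreement(run1, run2))  # 0.5
```

A sign agreement near 0.5 is what two unrelated vectors with symmetric entries would give on average, which fits the symptom, although this tiny sample cannot prove it.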

Additional context

Device: AWS p3.2xlarge (Tesla V100)

walid0925 commented 2 years ago

I think I've figured this out: the private method esm.pretrained._load_model_and_alphabet_core_v2(model_data) doesn't update the model's state_dict, so this is running inference through a randomly initialized model. I can make a quick PR to fix this.
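Why would skipping the state_dict update cause run-to-run differences in magnitude and sign rather than small numerical noise? The toy pure-Python sketch below illustrates the mechanism. It is not ESM or PyTorch code: build_layer, forward, and the pretrained weights are hypothetical, and different random seeds stand in for separate processes. In PyTorch the missing step corresponds to calling nn.Module.load_state_dict on the constructed model; this sketch does not show how the actual fix applies it.

```python
import random
from typing import Dict, List, Optional

Matrix = List[List[float]]


def build_layer(dim: int,
                state_dict: Optional[Dict[str, Matrix]] = None,
                rng: Optional[random.Random] = None) -> Matrix:
    """Hypothetical stand-in for constructing a model layer.

    The weights start out random, as a freshly constructed module's do, and
    are overwritten only if a state dict is actually loaded.
    """
    if rng is None:
        rng = random.Random()  # unseeded, so it differs between processes
    weight = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(dim)]
    if state_dict is not None:
        # The step that was being skipped: copy in the pretrained values.
        weight = [list(row) for row in state_dict["weight"]]
    return weight


def forward(weight: Matrix, x: List[float]) -> List[float]:
    """Matrix-vector product: one output per weight row."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


x = [1.0, -2.0, 0.5]
pretrained = {"weight": [[0.1, 0.2, 0.3], [0.0, -0.1, 0.4], [0.5, 0.0, -0.2]]}

# Two simulated runs without loading weights: outputs come from unrelated random weights.
print(forward(build_layer(3, rng=random.Random(1)), x))
print(forward(build_layer(3, rng=random.Random(2)), x))

# Two simulated runs with the weights loaded: the random init is discarded.
a = forward(build_layer(3, pretrained, rng=random.Random(1)), x)
b = forward(build_layer(3, pretrained, rng=random.Random(2)), x)
print(a == b)  # True
```

Without the load step, every run samples fresh weights, so the outputs are effectively unrelated across runs, matching the sign and magnitude changes reported above.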