facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License

Predicted ESM2 logits depend on other elements within a batch #273

Closed FedericoV closed 1 year ago

FedericoV commented 2 years ago

Thank you so much for all the work on ESM and ESM2. I ran into some surprising behaviour:

Bug description: ESM2 predicts slightly different logits, even in eval mode, depending on the other elements in a batch.

Reproduction steps

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'device: {device}')

model_path = "facebookresearch/esm:main"
model_name = "esm2_t36_3B_UR50D"
esm_model, alphabet = torch.hub.load(model_path, model_name)

esm_model = esm_model.eval().cuda()
batch_converter = alphabet.get_batch_converter()

# These are arbitrary sequences; it doesn't matter which ones are used
sequences = [
    'A' * 255,
    'Y' * 310
]

model_input = batch_converter([(None, seq) for seq in sequences])[2]
model_input = model_input.to(device)

# Here is the surprising part:
with torch.no_grad():
    logits1 = esm_model(model_input[[0]])['logits']
    logits2 = esm_model(model_input)['logits']

torch.linalg.norm(logits1 - logits2[0])

tensor(0.3426, device='cuda:0')

This gives roughly 0.3426, with many element-wise differences significantly different from zero. I expected this to come from some batch-norm-like functionality, but the model is in eval mode.
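As a side note, one quick way to judge whether a gap like this exceeds ordinary float32 rounding noise is an element-wise tolerance check. Below is a generic sketch with synthetic arrays (the shapes and the injected 1e-3 drift are hypothetical stand-ins, not the actual ESM logits):

```python
import numpy as np

# Hypothetical stand-ins for the two logit tensors (same shape).
logits1 = np.random.default_rng(0).normal(size=(1, 257, 33)).astype(np.float32)
logits2 = logits1 + np.float32(1e-3)  # simulate a small batch-dependent drift

# allclose with an absolute tolerance tells us whether the gap is
# plausibly just accumulated float32 rounding error or something larger.
print(np.allclose(logits1, logits2, atol=1e-5))  # tight tolerance: False
print(np.allclose(logits1, logits2, atol=1e-2))  # loose tolerance: True
```

A per-element tolerance is often more informative than a whole-tensor norm, since the norm grows with tensor size even when every element agrees to several decimal places.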

nikitos9000 commented 2 years ago

Hi @FedericoV, thanks for reporting this, but unfortunately I can't reproduce it with any of the models – the difference is really close to zero for me. Could you please provide more details about your environment (conda list / pip freeze), and maybe try clearing the cache with rm -rf ~/.cache/torch/hub/facebookresearch_esm_main/ and re-running the script? Also, please post the logit outputs here; that might be useful.

FedericoV commented 2 years ago

Hi @nikitos9000, thanks for looking into it. First of all, I did a bit more digging (this is with esm2_t36_3B_UR50D):

diffs = []
for len_2 in range(5, 100):

    sequences = [
        'A' * 4,
        'Y' * len_2
    ]

    model_input = batch_converter([(None, seq) for seq in sequences])[2]
    model_input = model_input.to(device)

    with torch.no_grad():
        logits1 = esm_model(model_input[[0]])['logits']
        logits2 = esm_model(model_input)['logits'][0]

    norm = float(torch.linalg.norm(logits1 - logits2).cpu())

    diffs.append(norm)

and this is what it ends up looking like.

(image: plot of the logit-difference norm as a function of len_2)

For my environment:

channels:
  - gpytorch
  - bioconda
  - anaconda
  - conda-forge/label/cloudml_hypertune_dev
  - pytorch
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=1_llvm
  - aiohttp=3.8.1=py39hb9d737c_1
  - aiosignal=1.2.0=pyhd8ed1ab_0
  - alsa-lib=1.2.3=h516909a_0
  - aom=3.3.0=h27087fc_1
  - argcomplete=2.0.0=pyhd8ed1ab_0
  - argh=0.26.2=pyh9f0ad1d_1002
  - asttokens=2.0.5=pyhd8ed1ab_0
  - async-timeout=4.0.2=pyhd8ed1ab_0
  - attrs=21.4.0=pyhd8ed1ab_0
  - backcall=0.2.0=pyh9f0ad1d_0
  - backports=1.0=py_2
  - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
  - biopython=1.79=py39h3811e60_1
  - blas=1.0=mkl
  - bleach=5.0.0=pyhd8ed1ab_0
  - blinker=1.4=py_1
  - bokeh=2.4.2=py39hf3d152e_1
  - botorch=0.6.4=0
  - brotli=1.0.9=h166bdaf_7
  - brotli-bin=1.0.9=h166bdaf_7
  - brotlipy=0.7.0=py39hb9d737c_1004
  - bzip2=1.0.8=h7f98852_4
  - c-ares=1.18.1=h7f98852_0
  - ca-certificates=2022.6.15=ha878542_0
  - cachetools=5.0.0=pyhd8ed1ab_0
  - certifi=2022.6.15=pyhd8ed1ab_1
  - cffi=1.15.0=py39h4bc2ebd_0
  - charset-normalizer=2.0.12=pyhd8ed1ab_0
  - click=8.1.2=py39hf3d152e_0
  - cloudml-hypertune=0.1.0.dev6=pyh01e6276_0
  - cloudpickle=2.0.0=pyhd8ed1ab_0
  - colorama=0.4.4=pyh9f0ad1d_0
  - colorcet=3.0.0=pyhd8ed1ab_0
  - cryptography=36.0.0=py39h9ce1e76_0
  - cudatoolkit=11.3.1=h2bc3f7f_2
  - cycler=0.11.0=pyhd8ed1ab_0
  - cytoolz=0.11.2=py39hb9d737c_2
  - dask=2022.4.0=pyhd8ed1ab_0
  - dask-core=2022.4.0=pyhd8ed1ab_0
  - datashader=0.13.0=pyh6c4a22f_0
  - datashape=0.5.4=py_1
  - dbus=1.13.6=h5008d03_3
  - debugpy=1.6.0=py39h5a03fae_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - distributed=2022.4.0=pyhd8ed1ab_0
  - entrypoints=0.4=pyhd8ed1ab_0
  - et_xmlfile=1.1.0=py39h06a4308_0
  - executing=0.8.3=pyhd8ed1ab_0
  - expat=2.4.8=h27087fc_0
  - fairscale=0.4.3=py39hf3d152e_0
  - ffmpeg=5.0.1=h594f047_0
  - fire=0.4.0=pyh44b312d_0
  - flask=2.1.0=pyhd8ed1ab_0
  - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
  - font-ttf-inconsolata=3.000=h77eed37_0
  - font-ttf-source-code-pro=2.038=h77eed37_0
  - font-ttf-ubuntu=0.83=hab24e00_0
  - fontconfig=2.14.0=h8e229c2_0
  - fonts-conda-ecosystem=1=0
  - fonts-conda-forge=1=0
  - fonttools=4.32.0=py39hb9d737c_0
  - freetype=2.10.4=h0708190_1
  - frozenlist=1.3.0=py39hb9d737c_1
  - fsspec=2022.2.0=pyhd8ed1ab_0
  - future=0.18.2=py39hf3d152e_5
  - gcsfs=2022.2.0=pyhd8ed1ab_0
  - gettext=0.19.8.1=h73d1719_1008
  - gffutils=0.11.0=pyh5e36f6f_0
  - giflib=5.2.1=h36c2ea0_2
  - gmp=6.2.1=h58526e2_0
  - gnutls=3.6.13=h85f3911_1
  - google-api-core=2.5.0=pyhd8ed1ab_0
  - google-auth=2.6.4=pyh6c4a22f_0
  - google-auth-oauthlib=0.5.1=pyhd8ed1ab_0
  - google-cloud-core=2.2.2=pyh6c4a22f_0
  - google-cloud-storage=2.1.0=pyh6c4a22f_0
  - google-crc32c=1.1.2=py39hb81f231_2
  - google-resumable-media=2.1.0=pyh6c4a22f_0
  - googleapis-common-protos=1.56.0=py39hf3d152e_0
  - gpytorch=1.6.0=0
  - grpcio=1.45.0=py39h7fbbf82_0
  - gst-plugins-base=1.20.1=hcf0ee16_1
  - gstreamer=1.20.1=hd4edc92_1
  - hdbscan=0.8.28=py39hce5d2b2_1
  - heapdict=1.0.1=py_0
  - holoviews=1.14.8=pyhd8ed1ab_0
  - icu=69.1=h9c3ff4c_0
  - idna=3.3=pyhd8ed1ab_0
  - ignite=0.4.9=py_0
  - imagecodecs-lite=2019.12.3=py39hd257fcd_5
  - imageio=2.16.2=pyhcf75d05_0
  - importlib-metadata=4.11.3=py39hf3d152e_1
  - importlib_metadata=4.11.3=hd8ed1ab_1
  - ipykernel=6.10.0=py39hef51801_0
  - ipython=8.2.0=py39hf3d152e_0
  - ipython_genutils=0.2.0=py_1
  - itsdangerous=2.1.2=pyhd8ed1ab_0
  - jbig=2.1=h7f98852_2003
  - jedi=0.18.1=py39hf3d152e_1
  - jinja2=3.1.1=pyhd8ed1ab_0
  - joblib=1.1.0=pyhd8ed1ab_0
  - jpeg=9e=h7f98852_0
  - jupyter_client=7.2.2=pyhd8ed1ab_1
  - jupyter_core=4.9.2=py39hf3d152e_0
  - keyutils=1.6.1=h166bdaf_0
  - kiwisolver=1.4.2=py39hf939315_1
  - krb5=1.19.3=h08a2579_0
  - lame=3.100=h7f98852_1001
  - lcms2=2.12=hddcbb42_0
  - ld_impl_linux-64=2.36.1=hea4e1c9_2
  - lerc=3.0=h9c3ff4c_0
  - libbrotlicommon=1.0.9=h166bdaf_7
  - libbrotlidec=1.0.9=h166bdaf_7
  - libbrotlienc=1.0.9=h166bdaf_7
  - libclang=13.0.1=default_hc23dcda_0
  - libcrc32c=1.1.2=h9c3ff4c_0
  - libdeflate=1.10=h7f98852_0
  - libdrm=2.4.109=h7f98852_0
  - libedit=3.1.20191231=he28a2e2_2
  - libevent=2.1.10=h28343ad_4
  - libffi=3.4.2=h7f98852_5
  - libgcc-ng=11.2.0=h1d223b6_15
  - libgfortran-ng=7.5.0=h14aa051_20
  - libgfortran4=7.5.0=h14aa051_20
  - libglib=2.70.2=h174f98d_4
  - libiconv=1.16=h516909a_0
  - libllvm11=11.1.0=hf817b99_3
  - libllvm13=13.0.1=hf817b99_2
  - libnsl=2.0.0=h7f98852_0
  - libogg=1.3.4=h7f98852_1
  - libopus=1.3.1=h7f98852_1
  - libpciaccess=0.16=h516909a_0
  - libpng=1.6.37=h21135ba_2
  - libpq=14.2=h676c864_0
  - libprotobuf=3.20.0=h6239696_0
  - libsodium=1.0.18=h36c2ea0_1
  - libstdcxx-ng=11.2.0=he4da1e4_15
  - libtiff=4.3.0=h542a066_3
  - libuuid=2.32.1=h7f98852_1000
  - libuv=1.43.0=h7f98852_0
  - libva=2.14.0=h7f98852_0
  - libvorbis=1.3.7=h9c3ff4c_0
  - libvpx=1.11.0=h9c3ff4c_3
  - libwebp=1.2.2=h3452ae3_0
  - libwebp-base=1.2.2=h7f98852_1
  - libxcb=1.13=h7f98852_1004
  - libxkbcommon=1.0.3=he3ba5ed_0
  - libxml2=2.9.12=h885dcf4_1
  - libzlib=1.2.11=h166bdaf_1014
  - line_profiler=3.4.0=py39hf939315_1
  - llvm-openmp=13.0.1=he0ac6c6_1
  - llvmlite=0.38.0=py39h7d9a04d_1
  - locket=0.2.0=py_2
  - lz4=4.0.0=py39h029007f_1
  - lz4-c=1.9.3=h9c3ff4c_1
  - markdown=3.3.6=pyhd8ed1ab_0
  - markupsafe=2.1.1=py39hb9d737c_1
  - matplotlib=3.5.1=py39hf3d152e_0
  - matplotlib-base=3.5.1=py39h2fa2bec_0
  - matplotlib-inline=0.1.3=pyhd8ed1ab_0
  - mkl=2021.4.0=h8d4b97c_729
  - mkl-service=2.4.0=py39h7e14d7c_0
  - mkl_fft=1.3.1=py39h0c7bc48_1
  - mkl_random=1.2.2=py39hde0f152_0
  - msgpack-python=1.0.3=py39hf939315_1
  - multidict=6.0.2=py39hb9d737c_1
  - multipledispatch=0.6.0=py_0
  - munkres=1.1.4=pyh9f0ad1d_0
  - mysql-common=8.0.28=h26416b9_3
  - mysql-libs=8.0.28=hbc51c84_3
  - ncurses=6.3=h27087fc_1
  - nest-asyncio=1.5.5=pyhd8ed1ab_0
  - nettle=3.6=he412f7d_0
  - networkx=2.7.1=pyhd8ed1ab_0
  - nspr=4.32=h9c3ff4c_1
  - nss=3.77=h2350873_0
  - numba=0.55.1=py39h56b8d98_0
  - numpy=1.21.5=py39he7a7128_1
  - numpy-base=1.21.5=py39hf524024_1
  - oauthlib=3.2.0=pyhd8ed1ab_0
  - openh264=2.1.1=h780b84a_0
  - openjpeg=2.4.0=hb52868f_1
  - openpyxl=3.0.9=pyhd3eb1b0_0
  - openssl=3.0.3=h166bdaf_0
  - opt_einsum=3.3.0=pyhd8ed1ab_1
  - packaging=21.3=pyhd8ed1ab_0
  - pandas=1.4.2=py39h1832856_1
  - panel=0.12.7=pyhd8ed1ab_0
  - param=1.12.1=pyh6c4a22f_0
  - parso=0.8.3=pyhd8ed1ab_0
  - partd=1.2.0=pyhd8ed1ab_0
  - patsy=0.5.2=pyhd8ed1ab_0
  - pcre=8.45=h9c3ff4c_0
  - pexpect=4.8.0=pyh9f0ad1d_2
  - pickleshare=0.7.5=py_1003
  - pillow=9.1.0=py39hae2aec6_2
  - pip=22.0.4=pyhd8ed1ab_0
  - prompt-toolkit=3.0.29=pyha770c72_0
  - protobuf=3.20.0=py39h5a03fae_4
  - psutil=5.9.0=py39hb9d737c_1
  - pthread-stubs=0.4=h36c2ea0_1001
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.2=pyhd8ed1ab_0
  - pyasn1=0.4.8=py_0
  - pyasn1-modules=0.2.7=py_0
  - pycparser=2.21=pyhd8ed1ab_0
  - pyct=0.4.6=py_0
  - pyct-core=0.4.6=py_0
  - pyfaidx=0.7.0=pyh5e36f6f_0
  - pygments=2.11.2=pyhd8ed1ab_0
  - pyjwt=2.3.0=pyhd8ed1ab_1
  - pynndescent=0.5.6=pyh6c4a22f_0
  - pyopenssl=22.0.0=pyhd8ed1ab_0
  - pyparsing=3.0.8=pyhd8ed1ab_0
  - pyqt=5.12.3=py39hf3d152e_8
  - pyqt-impl=5.12.3=py39hde8b62d_8
  - pyqt5-sip=4.19.18=py39he80948d_8
  - pyqtchart=5.12=py39h0fcd23e_8
  - pyqtwebengine=5.12.1=py39h0fcd23e_8
  - pyro-api=0.1.2=pyhd8ed1ab_0
  - pyro-ppl=1.8.0=pyhd8ed1ab_0
  - pysocks=1.7.1=py39hf3d152e_5
  - python=3.9.12=h2660328_1_cpython
  - python-dateutil=2.8.2=pyhd8ed1ab_0
  - python_abi=3.9=2_cp39
  - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0
  - pytorch-mutex=1.0=cuda
  - pytz=2022.1=pyhd8ed1ab_0
  - pyu2f=0.1.5=pyhd8ed1ab_0
  - pyviz_comms=2.2.0=pyhd8ed1ab_0
  - pywavelets=1.3.0=py39hd257fcd_1
  - pyyaml=6.0=py39hb9d737c_4
  - pyzmq=22.3.0=py39headdf64_2
  - qt=5.12.9=h1304e3e_6
  - readline=8.1=h46c0cb4_0
  - requests=2.27.1=pyhd8ed1ab_0
  - requests-oauthlib=1.3.1=pyhd8ed1ab_0
  - rsa=4.8=pyhd8ed1ab_0
  - scikit-image=0.19.2=py39hde0f152_0
  - scikit-learn=1.0.2=py39h51133e4_1
  - scipy=1.7.3=py39hc147768_0
  - seaborn=0.11.0=py_0
  - setuptools=59.8.0=py39hf3d152e_1
  - simplejson=3.17.6=py39hb9d737c_1
  - six=1.16.0=pyh6c4a22f_0
  - sortedcontainers=2.4.0=pyhd8ed1ab_0
  - sqlite=3.38.2=h4ff8645_0
  - stack_data=0.2.0=pyhd8ed1ab_0
  - statsmodels=0.13.2=py39hce5d2b2_0
  - svt-av1=0.9.1=h27087fc_0
  - tbb=2021.5.0=h924138e_1
  - tblib=1.7.0=pyhd8ed1ab_0
  - termcolor=1.1.0=pyhd8ed1ab_3
  - threadpoolctl=3.1.0=pyh8a188c0_0
  - tifffile=2020.6.3=py_0
  - tk=8.6.12=h27826a3_0
  - toolz=0.11.2=pyhd8ed1ab_0
  - torchaudio=0.11.0=py39_cu113
  - torchvision=0.12.0=py39_cu113
  - tornado=6.1=py39hb9d737c_3
  - tqdm=4.64.0=pyhd8ed1ab_0
  - traitlets=5.1.1=pyhd8ed1ab_0
  - typing-extensions=4.1.1=hd8ed1ab_0
  - typing_extensions=4.1.1=pyha770c72_0
  - tzdata=2022a=h191b570_0
  - umap-learn=0.5.2=py39hf3d152e_1
  - uncertainties=3.1.7=pyhd8ed1ab_0
  - unicodedata2=14.0.0=py39hb9d737c_1
  - urllib3=1.26.9=pyhd8ed1ab_0
  - wcwidth=0.2.5=pyh9f0ad1d_2
  - webencodings=0.5.1=py_1
  - werkzeug=2.1.1=pyhd8ed1ab_0
  - wheel=0.37.1=pyhd8ed1ab_0
  - x264=1!161.3030=h7f98852_1
  - x265=3.5=h924138e_3
  - xarray=2022.3.0=pyhd8ed1ab_0
  - xorg-fixesproto=5.0=h7f98852_1002
  - xorg-kbproto=1.0.7=h7f98852_1002
  - xorg-libx11=1.7.2=h7f98852_0
  - xorg-libxau=1.0.9=h7f98852_0
  - xorg-libxdmcp=1.1.3=h7f98852_0
  - xorg-libxext=1.3.4=h7f98852_1
  - xorg-libxfixes=5.0.3=h7f98852_1004
  - xorg-xextproto=7.3.0=h7f98852_1002
  - xorg-xproto=7.0.31=h7f98852_1007
  - xz=5.2.5=h516909a_1
  - yaml=0.2.5=h7f98852_2
  - yarl=1.7.2=py39hb9d737c_2
  - zeromq=4.3.4=h9c3ff4c_1
  - zict=2.1.0=pyhd8ed1ab_0
  - zipp=3.8.0=pyhd8ed1ab_0
  - zlib=1.2.11=h166bdaf_1014
  - zstd=1.5.2=ha95c52a_0
  - pip:
    - fair-esm==1.0.1
    - filelock==3.6.0
    - google-api-python-client==2.45.0
    - google-auth-httplib2==0.1.0
    - httplib2==0.20.4
    - huggingface-hub==0.5.1
    - regex==2022.3.15
    - sacremoses==0.0.49
    - sentencepiece==0.1.96
    - tokenizers==0.12.1
    - transformers==4.18.0
    - uritemplate==4.1.1

I will post the actual embeddings in another reply to avoid cluttering.

FedericoV commented 2 years ago
len_2 = 30

sequences = [
    'A' * 4,
    'Y' * len_2
]

model_input = batch_converter([(None, seq) for seq in sequences])[2]
model_input = model_input.to(device)

# Here is the surprising part:
with torch.no_grad():
    logits1 = esm_model(model_input[[0]])['logits']
    logits2 = esm_model(model_input)['logits'][0]

    norm = float(torch.linalg.norm(logits1 - logits2).detach().cpu())

logits1:

tensor([[[ 8.5170e+00, -1.2241e+01, -3.9846e+00,  ..., -1.5335e+01,
          -1.3460e+01, -1.2178e+01],
         [-1.0977e+01, -1.7280e+01, -1.2593e+01,  ..., -1.5288e+01,
          -1.5942e+01, -1.7287e+01],
         [-1.2390e+01, -1.8101e+01, -1.2171e+01,  ..., -1.5465e+01,
          -1.5759e+01, -1.8096e+01],
         ...,
         [-1.2104e+01, -1.8944e+01, -1.1913e+01,  ..., -1.6254e+01,
          -1.6069e+01, -1.8937e+01],
         [ 1.3983e-02, -8.0200e+00,  2.4321e+01,  ..., -1.3348e+01,
          -1.3025e+01, -7.9783e+00],
         [-4.7551e-01, -9.2902e+00,  1.8570e+01,  ..., -1.3072e+01,
          -1.2925e+01, -9.2340e+00]]], device='cuda:0')

logits2:

tensor([[ 8.5033e+00, -1.2248e+01, -3.9885e+00,  ..., -1.5335e+01,
         -1.3462e+01, -1.2184e+01],
        [-1.0979e+01, -1.7284e+01, -1.2595e+01,  ..., -1.5289e+01,
         -1.5943e+01, -1.7291e+01],
        [-1.2391e+01, -1.8103e+01, -1.2171e+01,  ..., -1.5465e+01,
         -1.5758e+01, -1.8097e+01],
        ...,
        [-1.2103e+01, -1.8943e+01, -1.1913e+01,  ..., -1.6254e+01,
         -1.6069e+01, -1.8936e+01],
        [ 8.6889e-03, -8.0334e+00,  2.4305e+01,  ..., -1.3349e+01,
         -1.3026e+01, -7.9916e+00],
        [-4.8808e-01, -9.3043e+00,  1.8519e+01,  ..., -1.3073e+01,
         -1.2927e+01, -9.2482e+00]], device='cuda:0')
FedericoV commented 2 years ago

I see a similar pattern with esm1b_t33_650M_UR50S:

(image: the same pattern of nonzero differences, for esm1b_t33_650M_UR50S)

FedericoV commented 2 years ago

I just cleared the cache and re-downloaded the models from torch.hub, and I got the exact same pattern.

FedericoV commented 2 years ago

I created a brand-new VM and reinstalled everything from scratch, and I keep seeing the same thing.

FedericoV commented 2 years ago

I converted all of the model weights (esm2) to float64 and repeated the process:

(image: the same plot with float64 weights; the differences are near zero)

This is almost certainly a numerical precision issue: the differences become very small once the weights are cast to float64.
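The mechanism behind such precision sensitivity can be seen even without a GPU: floating-point addition is not associative, so any change in how partial sums are grouped inside a matmul (which a different batch shape can trigger) perturbs the low bits of the result. A minimal, framework-free illustration:

```python
# Float addition is not associative: regrouping the same three
# terms changes the last bits of the result.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c
right = a + (b + c)
print(left == right)      # False
print(abs(left - right))  # tiny, on the order of 1e-16
```

At float64 the per-operation error is ~2^-52 instead of ~2^-23, which is consistent with the differences shrinking dramatically after the cast above.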

nikitos9000 commented 2 years ago

@FedericoV Thanks for the detailed info. Yes, it certainly looks like a numerical issue, but I can reproduce it only when going down to float16 precision, not in float32 – in float32 everything works well for all the input lengths I tried. Could you please provide your GPU info as well? Also, please try comparing the logits without the pad tokens, i.e. logits1[:, :seq_len + 2]: pad positions are masked during inference, so no meaningful output is expected there, and sequences in a batch are padded to the max size (+ 2 service tokens).
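Concretely, the suggested comparison would slice both tensors down to the seq_len + 2 real positions before taking the norm, so that whatever values sit in the pad positions cannot contribute. A sketch with synthetic arrays (the vocab size and padded length here are made up for illustration):

```python
import numpy as np

seq_len = 4        # length of the shorter sequence ('A' * 4)
vocab = 33         # hypothetical vocabulary size
padded_len = 32    # batch padded to the longest sequence (+ 2 service tokens)

# Stand-ins for the single-sequence and in-batch logits; both come back
# at the padded length, and pad positions may hold arbitrary values.
logits1 = np.zeros((1, padded_len, vocab), dtype=np.float32)
logits2 = np.zeros((1, padded_len, vocab), dtype=np.float32)
logits1[:, seq_len + 2:] = 99.0
logits2[:, seq_len + 2:] = -99.0

# Compare only the real token positions (+ 2 service tokens).
diff = logits1[:, :seq_len + 2] - logits2[:, :seq_len + 2]
print(float(np.linalg.norm(diff)))  # 0.0 — pad positions are excluded
```

Without the slice, the norm above would be dominated by the meaningless pad values.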

FedericoV commented 2 years ago

@nikitos9000 happy to look into it more.

Re: GPU: I'm using an A100 on Google Cloud.

(image: GPU info screenshot showing an A100)

Re: cropping. Most of the ESM embedding examples show the embeddings being cropped like so:

with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]

# Generate per-sequence representations via averaging
# NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
sequence_representations = []
for i, (_, seq) in enumerate(data):
    sequence_representations.append(token_representations[i, 1 : len(seq) + 1].mean(0))

For esm2, do we need to crop from 0 to len(seq)+2 instead?

BernhoferM commented 2 years ago

This might be caused by non-deterministic behavior of PyTorch and cuBLAS on GPU. There are a few things you can configure to get as much determinism as possible, but even then you might get different results on different hardware.

https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility https://pytorch.org/docs/stable/notes/randomness.html#reproducibility

Below is a minimal example (with a single linear layer) to test your system. You can uncomment the lines regarding cuDNN and deterministic algorithms to see how the results change. In my personal experience, setting CUBLAS_WORKSPACE_CONFIG to :16:8 gave the most stable results.

import torch
import numpy
import random

seed = 101

random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)

'''
Set EnvVar 'CUBLAS_WORKSPACE_CONFIG' to either ':16:8' or ':4096:8'
'''
# torch.use_deterministic_algorithms(True)

# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

B, N, F = 64, 200, 512

f = torch.nn.Linear(F, 64).cuda()
x = torch.randn(B, N, F).float().cuda()

print('### TEST B DIMENSION ###')
for n in range(1, B+1):
    y = f(x)[:n]
    z = f(x[:n])

    print(n, torch.equal(y, z), torch.abs(y - z).max().item())

print('### TEST N DIMENSION ###')
for n in range(1, N+1):
    y = f(x)[:, :n, :]
    z = f(x[:, :n, :])

    print(n, torch.equal(y, z), torch.abs(y - z).max().item())
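One practical note on CUBLAS_WORKSPACE_CONFIG: it is read when the cuBLAS context is created, so it has to be in the environment before the first CUDA call. Setting it at the very top of the script (or in the shell before launch) is the safe pattern; a stdlib-only sketch:

```python
import os

# Must be set before torch initializes CUDA, otherwise
# torch.use_deterministic_algorithms(True) can raise for cuBLAS ops.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"  # or ":4096:8"
print(os.environ["CUBLAS_WORKSPACE_CONFIG"])

# Only now import torch and touch the GPU.
```

Setting the variable after a CUDA context already exists silently has no effect on the existing handles, which is an easy way to believe determinism is enabled when it is not.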
nikitos9000 commented 2 years ago

@FedericoV Have you gotten any new results after dropping the pad logits and setting deterministic=True?