pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

TorchScript error in JitTraceEnum_ELBO (Torch 2.2.1, CUDA 12.3) #3338

Open mtvector opened 7 months ago

mtvector commented 7 months ago

Hi there, I noticed a bug in JitTraceEnum_ELBO. My code runs fine with a previous version of PyTorch, or with JitTrace_ELBO (I can use RelaxedOneHotCategorical instead of OneHotCategorical for the variable I was enumerating). I don't personally need this fixed right now, and the bug is out of my depth to debug, but I figured I'd report it in case someone else hits the same problem:

The error seems to come from a TorchScript issue while computing the enumeration ELBO in pyro.infer.SVI:

    315 def step(self, *args, **kwargs):
    316     # Compute loss and gradients
    317     with poutine.trace(param_only=True) as param_capture:
--> 318         loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    320     loss_val = torch_item(loss)
    321     self.losses.append(loss_val)

File /allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/utils/miniconda3/envs/pyro2/lib/python3.11/site-packages/pyro/infer/traceenum_elbo.py:564, in JitTraceEnum_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
    563 def loss_and_grads(self, model, guide, *args, **kwargs):
--> 564     differentiable_loss = self.differentiable_loss(model, guide, *args, **kwargs)
    565     differentiable_loss.backward()  # this line triggers jit compilation
    566     loss = differentiable_loss.item()

File /allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/utils/miniconda3/envs/pyro2/lib/python3.11/site-packages/pyro/infer/traceenum_elbo.py:561, in JitTraceEnum_ELBO.differentiable_loss(self, model, guide, *args, **kwargs)
    557         return elbo * (-1.0 / self.num_particles)
    559     self._differentiable_loss = differentiable_loss
--> 561 return self._differentiable_loss(*args, **kwargs)

File /allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/utils/miniconda3/envs/pyro2/lib/python3.11/site-packages/pyro/ops/jit.py:120, in CompiledFunction.__call__(self, *args, **kwargs)
    118 with poutine.block(hide=self._param_names):
    119     with poutine.trace(param_only=True) as param_capture:
--> 120         ret = self.compiled[key](*params_and_args)
    122 for name in param_capture.trace.nodes.keys():
    123     if name not in self._param_names:

```
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: default_program(23): error: extra text after expected end of number
      aten_exp[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = expf(v - (tshift_1_1<-3.402823466385289e+38.f ? -3.402823466385289e+38.f : tshift_1_1));
                                                                                                                        ^

default_program(23): error: extra text after expected end of number
      aten_exp[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = expf(v - (tshift_1_1<-3.402823466385289e+38.f ? -3.402823466385289e+38.f : tshift_1_1));
                                                                                                                                                   ^

2 errors detected in the compilation of "default_program".

nvrtc compilation failed:

#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)

template<typename T>
__device__ T maximum(T a, T b) {
  return isnan(a) ? a : (a > b ? a : b);
}

template<typename T>
__device__ T minimum(T a, T b) {
  return isnan(a) ? a : (a < b ? a : b);
}

extern "C" __global__
void fused_clamp_sub_exp(float* tt_3, float* tshift_1, float* aten_exp) {
{
if ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)<45150ll ? 1 : 0) {
    float tshift_1_1 = __ldg(tshift_1 + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
    float v = __ldg(tt_3 + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
    aten_exp[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = expf(v - (tshift_1_1<-3.402823466385289e+38.f ? -3.402823466385289e+38.f : tshift_1_1));
  }}
}
```
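The mis-printed constant is float32's most negative finite value (the lower clamp bound baked into the kernel). As a small illustration independent of the report, here is where that value comes from and why the emitted spelling is ill-formed CUDA C (a suffix must follow the exponent directly, with no extra dot):

```python
import struct

# float32's most negative finite value (bit pattern 0xFF7FFFFF),
# i.e. what torch.finfo(torch.float32).min reports -- the clamp
# bound that appears in the generated kernel above.
lowest = struct.unpack("<f", struct.pack("<I", 0xFF7FFFFF))[0]
assert f"{lowest:.15e}" == "-3.402823466385289e+38"

# A valid CUDA C float literal puts the 'f' suffix immediately after
# the exponent.  The fuser emitted an extra '.' before the suffix,
# which nvrtc rejects as "extra text after expected end of number".
valid = f"{lowest:.15e}f"       # -3.402823466385289e+38f  (well-formed)
generated = f"{lowest:.15e}.f"  # -3.402823466385289e+38.f (ill-formed)
```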

My environment is as follows:

``` absl-py==2.1.0 aiohttp==3.9.1 aiosignal==1.3.1 anndata==0.10.4 annotated-types==0.6.0 anyio==4.2.0 array_api_compat==1.4.1 arrow==1.3.0 asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work attrs==23.2.0 backoff==2.2.1 beautifulsoup4==4.12.3 blessed==1.20.0 boto3==1.34.28 botocore==1.34.28 certifi==2023.11.17 charset-normalizer==3.3.2 chex==0.1.7 click==8.1.7 comm @ file:///work/ci_py311/comm_1677709131612/work contextlib2==21.6.0 contourpy==1.2.0 croniter==1.4.1 cycler==0.12.1 dateutils==0.6.12 debugpy @ file:///croot/debugpy_1690905042057/work decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work deepdiff==6.7.1 dm-tree==0.1.8 docrep==0.3.2 editor==1.6.6 etils==1.6.0 executing @ file:///opt/conda/conda-bld/executing_1646925071911/work fastapi==0.109.0 filelock @ file:///croot/filelock_1700591183607/work flax==0.8.0 fonttools==4.47.2 frozenlist==1.4.1 fsspec==2023.12.2 gmpy2 @ file:///work/ci_py311/gmpy2_1676839849213/work h11==0.14.0 h5py==3.10.0 idna==3.6 igraph==0.11.3 importlib-resources==6.1.1 inquirer==3.2.1 ipykernel @ file:///croot/ipykernel_1705933831282/work ipython @ file:///croot/ipython_1704833016303/work itsdangerous==2.1.2 jax==0.4.23 jaxlib==0.4.23 jedi @ file:///work/ci_py311_2/jedi_1679336495545/work Jinja2 @ file:///work/ci_py311/jinja2_1676823587943/work jmespath==1.0.1 joblib==1.3.2 jupyter_client @ file:///croot/jupyter_client_1699455897726/work jupyter_core @ file:///croot/jupyter_core_1698937308754/work kiwisolver==1.4.5 leidenalg==0.10.2 lightning==2.0.9.post0 lightning-cloud==0.5.61 lightning-utilities==0.10.1 llvmlite==0.41.1 markdown-it-py==3.0.0 MarkupSafe @ file:///croot/markupsafe_1704205993651/work matplotlib==3.8.2 matplotlib-inline @ file:///work/ci_py311/matplotlib-inline_1676823841154/work mdurl==0.1.2 mkl-fft @ file:///croot/mkl_fft_1695058164594/work mkl-random @ file:///croot/mkl_random_1695059800811/work mkl-service==2.4.0 ml-collections==0.1.1 ml-dtypes @ 
file:///croot/ml_dtypes_1702691022032/work mpmath @ file:///croot/mpmath_1690848262763/work msgpack==1.0.7 mudata==0.2.3 multidict==6.0.4 multipledispatch==1.0.0 natsort==8.4.0 nest-asyncio @ file:///work/ci_py311/nest-asyncio_1676823382924/work networkx==3.2.1 numba==0.58.1 numpy==1.26.1 numpyro==0.13.2 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.19.3 nvidia-nvjitlink-cu12==12.4.99 nvidia-nvtx-cu12==12.1.105 opt-einsum @ file:///home/conda/feedstock_root/build_artifacts/opt_einsum_1696448916724/work optax==0.1.8 orbax-checkpoint==0.5.1 ordered-set==4.1.0 packaging @ file:///croot/packaging_1693575174725/work pandas==2.2.0 parso @ file:///opt/conda/conda-bld/parso_1641458642106/work patsy==0.5.6 pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work pillow==10.2.0 platformdirs @ file:///croot/platformdirs_1692205439124/work prompt-toolkit @ file:///croot/prompt-toolkit_1704404351921/work protobuf==4.25.2 psutil @ file:///work/ci_py311_2/psutil_1679337388738/work ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work pydantic==2.1.1 pydantic_core==2.4.0 Pygments @ file:///croot/pygments_1684279966437/work PyJWT==2.8.0 pymde==0.1.18 pynndescent==0.5.11 pyparsing==3.1.1 pyro-api==0.1.2 pyro-ppl==1.8.6 python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work python-multipart==0.0.6 pytorch-lightning==2.1.3 pytz==2023.3.post1 PyYAML @ file:///croot/pyyaml_1698096049011/work pyzmq @ file:///croot/pyzmq_1705605076900/work readchar==4.0.5 requests==2.31.0 rich==13.7.0 runs==1.2.2 s3transfer==0.10.0 scanpy==1.9.6 scikit-learn==1.3.2 scipy==1.11.4 
scvi-tools==1.0.4 seaborn==0.13.1 session-info==1.0.0 six @ file:///tmp/build/80754af9/six_1644875935023/work sniffio==1.3.0 soupsieve==2.5 sparse==0.15.1 stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work starlette==0.35.1 starsessions==1.3.0 statsmodels==0.14.1 stdlib-list==0.10.0 sympy @ file:///croot/sympy_1701397643339/work tensorstore==0.1.52 texttable==1.7.0 threadpoolctl==3.2.0 toolz==0.12.1 torch==2.2.1 torchmetrics==1.3.0.post0 torchvision==0.17.1 tornado @ file:///croot/tornado_1696936946304/work tqdm==4.66.1 traitlets @ file:///work/ci_py311/traitlets_1676823305040/work triton==2.2.0 types-python-dateutil==2.8.19.20240106 typing_extensions @ file:///croot/typing_extensions_1705599297034/work tzdata==2023.4 umap-learn==0.5.5 urllib3==2.0.7 uvicorn==0.27.0 wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work websocket-client==1.7.0 websockets==12.0 xarray==2024.1.1 xgboost==2.0.1 xmod==1.8.1 yarl==1.9.4 zipp==3.17.0 ```

Thanks for all the development work, pyro rules!

fritzo commented 7 months ago

Thanks for the bug report. My guess is that this is an upstream bug in PyTorch's code generation, where it writes a floating point constant with two decimal points (`-3.402823466385289e+38.f`). I'm not sure there's anything we can do but wait for an upstream fix.
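One possible (untested) mitigation while waiting for an upstream fix is to disable the TorchScript TensorExpr fuser, which is what generates the failing nvrtc program. Note `_jit_set_texpr_fuser_enabled` is an internal PyTorch hook rather than a supported API, so treat this as a sketch:

```python
import torch

# Untested mitigation sketch: the failing "fused_clamp_sub_exp" kernel
# is produced by TorchScript's NNC/TensorExpr fuser.  Turning that
# fuser off makes traced code fall back to ordinary eager kernels
# (slower, but no nvrtc compilation of generated CUDA source).
torch._C._jit_set_texpr_fuser_enabled(False)

# Alternatively, scope the change with the fuser context manager:
# with torch.jit.fuser("fuser0"):   # "fuser0" = legacy fuser
#     svi.step(...)
```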