BasisResearch / chirho

An experimental language for causal reasoning
https://basisresearch.github.io/chirho/getting_started.html
Apache License 2.0
168 stars 11 forks source link

Deep structural causal model counterfactuals example produces `nan` results #461

Open Flawless1202 opened 9 months ago

Flawless1202 commented 9 months ago

When I run the tutorial Example: Deep structural causal model counterfactuals, the following cell:

# Training-loop sizing: number of passes over the data and samples per batch.
num_epochs = 100
batch_size = 128

# Keyword arguments for the optimizer (learning rate only).
adam_params = {
    "lr": 1e-3,
}

class LightningSVI(pl.LightningModule):
    """Lightning module that fits a model/guide pair by minimizing an ELBO loss.

    The ``elbo`` object is called once with ``(model, guide)`` and the result is
    kept as ``self.elbo``. That result computes the loss in ``training_step``
    and supplies the parameters handed to the optimizer.
    """

    def __init__(
        self,
        model: ConditionedDeepSCM,
        guide: pyro.nn.PyroModule,
        elbo: pyro.infer.ELBO,
        optim_params: dict,
    ):
        """Store the model and guide and bind them to the ELBO.

        Args:
            model: Model to fit.
            guide: Guide paired with ``model``.
            elbo: ELBO object; called here with the model and guide.
            optim_params: Keyword arguments forwarded to ``torch.optim.Adam``.
        """
        super().__init__()
        self._optim_params = optim_params
        self.model = model
        self.guide = guide
        self.elbo = elbo(self.model, self.guide)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch, log it as ``train_loss`` and return it."""
        thickness_obs, intensity_obs, image_obs = batch
        # torch.rand_like draws uniform noise in [0, 1) shaped like the images;
        # presumably dequantization of integer pixel values -- TODO confirm.
        jittered = image_obs + torch.rand_like(image_obs)
        step_loss = self.elbo(thickness_obs, intensity_obs, jittered)
        self.log("train_loss", step_loss)
        return step_loss

    def configure_optimizers(self):
        """Build an Adam optimizer over ``self.elbo.parameters()``."""
        return torch.optim.Adam(self.elbo.parameters(), **self._optim_params)

# Build the guide and the ELBO object, then wrap them with the model in the
# Lightning module defined above.
guide = pyro.infer.autoguide.AutoDelta(conditioned_model)
elbo = pyro.infer.Trace_ELBO()
lightning_svi = LightningSVI(conditioned_model, guide, elbo, adam_params)

# Each batch is a (thickness, intensity, images) triple, reshuffled every epoch.
dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(thickness, intensity, images),
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)

# GPU trainer with gradient clipping at 1.0; logs and checkpoints are rooted at
# ./lightning_logs/deepscm_ckpt/deepscm_joint.
trainer = pl.Trainer(
    max_epochs=num_epochs,
    gradient_clip_val=1.0,
    accelerator="gpu",
    default_root_dir=os.path.join("./lightning_logs/deepscm_ckpt", "deepscm_joint"),
    callbacks=[
        # Checkpoint weights only, tracking the minimum of the logged "train_loss".
        pl.callbacks.ModelCheckpoint(
            save_weights_only=True, mode="min", monitor="train_loss"
        ),
    ],
)

# Try to restore weights from one specific, hard-coded checkpoint file; if that
# file does not exist, fall back to training from scratch.
try:
    state_path = os.path.join(
        "./lightning_logs/deepscm_ckpt",
        "deepscm_joint",
        "lightning_logs",
        "version_114",
        "checkpoints",
        "epoch=47-step=2064.ckpt",
    )
    lightning_svi.load_state_dict(torch.load(state_path)["state_dict"])
    lightning_svi.eval()
except FileNotFoundError:
    # NOTE(review): an exception raised inside fit() is reported as occurring
    # "during handling" of the FileNotFoundError, so both tracebacks are shown.
    trainer.fit(model=lightning_svi, train_dataloaders=dataloader)

causes the following error:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:64                                                                                   │
│                                                                                                  │
│   61 │   │   "checkpoints",                                                                      │
│   62 │   │   "epoch=47-step=2064.ckpt",                                                          │
│   63 │   )                                                                                       │
│ ❱ 64 │   lightning_svi.load_state_dict(torch.load(state_path)["state_dict"])                     │
│   65 │   lightning_svi.eval()                                                                    │
│   66 except FileNotFoundError:                                                                   │
│   67 │   trainer.fit(model=lightning_svi, train_dataloaders=dataloader)                          │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/serialization.py:791 │
│ in load                                                                                          │
│                                                                                                  │
│    788 │   if 'encoding' not in pickle_load_args.keys():                                         │
│    789 │   │   pickle_load_args['encoding'] = 'utf-8'                                            │
│    790 │                                                                                         │
│ ❱  791 │   with _open_file_like(f, 'rb') as opened_file:                                         │
│    792 │   │   if _is_zipfile(opened_file):                                                      │
│    793 │   │   │   # The zipfile reader is going to advance the current file position.           │
│    794 │   │   │   # If we want to actually tail call to torch.jit.load, we need to              │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/serialization.py:271 │
│ in _open_file_like                                                                               │
│                                                                                                  │
│    268                                                                                           │
│    269 def _open_file_like(name_or_buffer, mode):                                                │
│    270 │   if _is_path(name_or_buffer):                                                          │
│ ❱  271 │   │   return _open_file(name_or_buffer, mode)                                           │
│    272 │   else:                                                                                 │
│    273 │   │   if 'w' in mode:                                                                   │
│    274 │   │   │   return _open_buffer_writer(name_or_buffer)                                    │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/serialization.py:252 │
│ in __init__                                                                                      │
│                                                                                                  │
│    249                                                                                           │
│    250 class _open_file(_opener):                                                                │
│    251 │   def __init__(self, name, mode):                                                       │
│ ❱  252 │   │   super().__init__(open(name, mode))                                                │
│    253 │                                                                                         │
│    254 │   def __exit__(self, *args):                                                            │
│    255 │   │   self.file_like.close()                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
FileNotFoundError: [Errno 2] No such file or directory: 
'./lightning_logs/deepscm_ckpt/deepscm_joint/lightning_logs/version_114/checkpoints/epoch=47-step=2064.ckpt'

During handling of the above exception, another exception occurred:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pyro/poutine/trace_struct. │
│ py:230 in compute_log_prob                                                                       │
│                                                                                                  │
│   227 │   │   │   if site["type"] == "sample" and site_filter(name, site):                       │
│   228 │   │   │   │   if "log_prob" not in site:                                                 │
│   229 │   │   │   │   │   try:                                                                   │
│ ❱ 230 │   │   │   │   │   │   log_p = site["fn"].log_prob(                                       │
│   231 │   │   │   │   │   │   │   site["value"], *site["args"], **site["kwargs"]                 │
│   232 │   │   │   │   │   │   )                                                                  │
│   233 │   │   │   │   │   except ValueError as e:                                                │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/distributions/transf │
│ ormed_distribution.py:153 in log_prob                                                            │
│                                                                                                  │
│   150 │   │   │   │   │   │   │   │   │   │   │   │    event_dim - transform.domain.event_dim)   │
│   151 │   │   │   y = x                                                                          │
│   152 │   │                                                                                      │
│ ❱ 153 │   │   log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),                   │
│   154 │   │   │   │   │   │   │   │   │   │   │    event_dim - len(self.base_dist.event_shape)   │
│   155 │   │   return log_prob                                                                    │
│   156                                                                                            │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/distributions/indepe │
│ ndent.py:99 in log_prob                                                                          │
│                                                                                                  │
│    96 │   │   return self.base_dist.rsample(sample_shape)                                        │
│    97 │                                                                                          │
│    98 │   def log_prob(self, value):                                                             │
│ ❱  99 │   │   log_prob = self.base_dist.log_prob(value)                                          │
│   100 │   │   return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)                    │
│   101 │                                                                                          │
│   102 │   def entropy(self):                                                                     │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/distributions/normal │
│ .py:79 in log_prob                                                                               │
│                                                                                                  │
│    76 │                                                                                          │
│    77 │   def log_prob(self, value):                                                             │
│    78 │   │   if self._validate_args:                                                            │
│ ❱  79 │   │   │   self._validate_sample(value)                                                   │
│    80 │   │   # compute the variance                                                             │
│    81 │   │   var = (self.scale ** 2)                                                            │
│    82 │   │   log_scale = math.log(self.scale) if isinstance(self.scale, Real) else self.scale   │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/distributions/distri │
│ bution.py:300 in _validate_sample                                                                │
│                                                                                                  │
│   297 │   │   assert support is not None                                                         │
│   298 │   │   valid = support.check(value)                                                       │
│   299 │   │   if not valid.all():                                                                │
│ ❱ 300 │   │   │   raise ValueError(                                                              │
│   301 │   │   │   │   "Expected value argument "                                                 │
│   302 │   │   │   │   f"({type(value).__name__} of shape {tuple(value.shape)}) "                 │
│   303 │   │   │   │   f"to be within the support ({repr(support)}) "                             │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Expected value argument (Tensor of shape (128, 1)) to be within the support (Real()) of the 
distribution Normal(loc: torch.Size([128, 1]), scale: torch.Size([128, 1])), but found invalid values:
tensor([[   nan],
        [   nan],
        [   nan],
        [1.3494],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.0990],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.0442],
        [1.1113],
        [1.0319],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.4944],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.0511],
        [   nan],
        [1.1086],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.1693],
        [1.2343],
        [   nan],
        [   nan],
        [   nan],
        [1.2089],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.1464],
        [   nan],
        [   nan],
        [1.0146],
        [   nan],
        [1.0616],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.0513],
        [   nan],
        [   nan],
        [1.0878],
        [   nan],
        [   nan],
        [1.0663],
        [1.0729],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.2027],
        [   nan],
        [1.0763],
        [1.0316],
        [1.0748],
        [1.3685],
        [   nan],
        [   nan],
        [   nan],
        [1.3050],
        [   nan],
        [   nan],
        [   nan],
        [1.2104],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.0835],
        [   nan],
        [   nan]], device='cuda:0', grad_fn=<IndexPutBackward0>)

The above exception was the direct cause of the following exception:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:67                                                                                   │
│                                                                                                  │
│   64 │   lightning_svi.load_state_dict(torch.load(state_path)["state_dict"])                     │
│   65 │   lightning_svi.eval()                                                                    │
│   66 except FileNotFoundError:                                                                   │
│ ❱ 67 │   trainer.fit(model=lightning_svi, train_dataloaders=dataloader)                          │
│   68                                                                                             │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/trainer/ │
│ trainer.py:544 in fit                                                                            │
│                                                                                                  │
│    541 │   │   self.state.fn = TrainerFn.FITTING                                                 │
│    542 │   │   self.state.status = TrainerStatus.RUNNING                                         │
│    543 │   │   self.training = True                                                              │
│ ❱  544 │   │   call._call_and_handle_interrupt(                                                  │
│    545 │   │   │   self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule,  │
│    546 │   │   )                                                                                 │
│    547                                                                                           │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/trainer/ │
│ call.py:44 in _call_and_handle_interrupt                                                         │
│                                                                                                  │
│    41 │   try:                                                                                   │
│    42 │   │   if trainer.strategy.launcher is not None:                                          │
│    43 │   │   │   return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer,    │
│ ❱  44 │   │   return trainer_fn(*args, **kwargs)                                                 │
│    45 │                                                                                          │
│    46 │   except _TunerExitException:                                                            │
│    47 │   │   _call_teardown_hook(trainer)                                                       │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/trainer/ │
│ trainer.py:580 in _fit_impl                                                                      │
│                                                                                                  │
│    577 │   │   │   model_provided=True,                                                          │
│    578 │   │   │   model_connected=self.lightning_module is not None,                            │
│    579 │   │   )                                                                                 │
│ ❱  580 │   │   self._run(model, ckpt_path=ckpt_path)                                             │
│    581 │   │                                                                                     │
│    582 │   │   assert self.state.stopped                                                         │
│    583 │   │   self.training = False                                                             │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/trainer/ │
│ trainer.py:989 in _run                                                                           │
│                                                                                                  │
│    986 │   │   # ----------------------------                                                    │
│    987 │   │   # RUN THE TRAINER                                                                 │
│    988 │   │   # ----------------------------                                                    │
│ ❱  989 │   │   results = self._run_stage()                                                       │
│    990 │   │                                                                                     │
│    991 │   │   # ----------------------------                                                    │
│    992 │   │   # POST-Training CLEAN UP                                                          │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/trainer/ │
│ trainer.py:1035 in _run_stage                                                                    │
│                                                                                                  │
│   1032 │   │   │   with isolate_rng():                                                           │
│   1033 │   │   │   │   self._run_sanity_check()                                                  │
│   1034 │   │   │   with torch.autograd.set_detect_anomaly(self._detect_anomaly):                 │
│ ❱ 1035 │   │   │   │   self.fit_loop.run()                                                       │
│   1036 │   │   │   return None                                                                   │
│   1037 │   │   raise RuntimeError(f"Unexpected state {self.state}")                              │
│   1038                                                                                           │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/loops/fi │
│ t_loop.py:202 in run                                                                             │
│                                                                                                  │
│   199 │   │   while not self.done:                                                               │
│   200 │   │   │   try:                                                                           │
│   201 │   │   │   │   self.on_advance_start()                                                    │
│ ❱ 202 │   │   │   │   self.advance()                                                             │
│   203 │   │   │   │   self.on_advance_end()                                                      │
│   204 │   │   │   │   self._restarting = False                                                   │
│   205 │   │   │   except StopIteration:                                                          │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/loops/fi │
│ t_loop.py:359 in advance                                                                         │
│                                                                                                  │
│   356 │   │   │   )                                                                              │
│   357 │   │   with self.trainer.profiler.profile("run_training_epoch"):                          │
│   358 │   │   │   assert self._data_fetcher is not None                                          │
│ ❱ 359 │   │   │   self.epoch_loop.run(self._data_fetcher)                                        │
│   360 │                                                                                          │
│   361 │   def on_advance_end(self) -> None:                                                      │
│   362 │   │   trainer = self.trainer                                                             │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/loops/tr │
│ aining_epoch_loop.py:136 in run                                                                  │
│                                                                                                  │
│   133 │   │   self.on_run_start(data_fetcher)                                                    │
│   134 │   │   while not self.done:                                                               │
│   135 │   │   │   try:                                                                           │
│ ❱ 136 │   │   │   │   self.advance(data_fetcher)                                                 │
│   137 │   │   │   │   self.on_advance_end(data_fetcher)                                          │
│   138 │   │   │   │   self._restarting = False                                                   │
│   139 │   │   │   except StopIteration:                                                          │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/loops/tr │
│ aining_epoch_loop.py:240 in advance                                                              │
│                                                                                                  │
│   237 │   │   │   with trainer.profiler.profile("run_training_batch"):                           │
│   238 │   │   │   │   if trainer.lightning_module.automatic_optimization:                        │
│   239 │   │   │   │   │   # in automatic optimization, there can only be one optimizer           │
│ ❱ 240 │   │   │   │   │   batch_output = self.automatic_optimization.run(trainer.optimizers[0]   │
│   241 │   │   │   │   else:                                                                      │
│   242 │   │   │   │   │   batch_output = self.manual_optimization.run(kwargs)                    │
│   243                                                                                            │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/loops/op │
│ timization/automatic.py:187 in run                                                               │
│                                                                                                  │
│   184 │   │   # ------------------------------                                                   │
│   185 │   │   # gradient update with accumulated gradients                                       │
│   186 │   │   else:                                                                              │
│ ❱ 187 │   │   │   self._optimizer_step(batch_idx, closure)                                       │
│   188 │   │                                                                                      │
│   189 │   │   result = closure.consume_result()                                                  │
│   190 │   │   if result.loss is None:                                                            │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/loops/op │
│ timization/automatic.py:265 in _optimizer_step                                                   │
│                                                                                                  │
│   262 │   │   │   self.optim_progress.optimizer.step.increment_ready()                           │
│   263 │   │                                                                                      │
│   264 │   │   # model hook                                                                       │
│ ❱ 265 │   │   call._call_lightning_module_hook(                                                  │
│   266 │   │   │   trainer,                                                                       │
│   267 │   │   │   "optimizer_step",                                                              │
│   268 │   │   │   trainer.current_epoch,                                                         │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/trainer/ │
│ call.py:157 in _call_lightning_module_hook                                                       │
│                                                                                                  │
│   154 │   pl_module._current_fx_name = hook_name                                                 │
│   155 │                                                                                          │
│   156 │   with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hoo   │
│ ❱ 157 │   │   output = fn(*args, **kwargs)                                                       │
│   158 │                                                                                          │
│   159 │   # restore current_fx when nested context                                               │
│   160 │   pl_module._current_fx_name = prev_fx_name                                              │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/core/mod │
│ ule.py:1282 in optimizer_step                                                                    │
│                                                                                                  │
│   1279 │   │   │   │   │   │   pg["lr"] = lr_scale * self.learning_rate                          │
│   1280 │   │                                                                                     │
│   1281 │   │   """                                                                               │
│ ❱ 1282 │   │   optimizer.step(closure=optimizer_closure)                                         │
│   1283 │                                                                                         │
│   1284 │   def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer) -> N  │
│   1285 │   │   """Override this method to change the default behaviour of ``optimizer.zero_grad  │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/core/opt │
│ imizer.py:151 in step                                                                            │
│                                                                                                  │
│   148 │   │   │   raise MisconfigurationException("When `optimizer.step(closure)` is called, t   │
│   149 │   │                                                                                      │
│   150 │   │   assert self._strategy is not None                                                  │
│ ❱ 151 │   │   step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)    │
│   152 │   │                                                                                      │
│   153 │   │   self._on_after_step()                                                              │
│   154                                                                                            │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/strategi │
│ es/strategy.py:230 in optimizer_step                                                             │
│                                                                                                  │
│   227 │   │   model = model or self.lightning_module                                             │
│   228 │   │   # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed    │
│   229 │   │   assert isinstance(model, pl.LightningModule)                                       │
│ ❱ 230 │   │   return self.precision_plugin.optimizer_step(optimizer, model=model, closure=clos   │
│   231 │                                                                                          │
│   232 │   def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) ->   │
│   233 │   │   """Setup a model and multiple optimizers together.                                 │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/plugins/ │
│ precision/precision.py:117 in optimizer_step                                                     │
│                                                                                                  │
│   114 │   ) -> Any:                                                                              │
│   115 │   │   """Hook to run the optimizer step."""                                              │
│   116 │   │   closure = partial(self._wrap_closure, model, optimizer, closure)                   │
│ ❱ 117 │   │   return optimizer.step(closure=closure, **kwargs)                                   │
│   118 │                                                                                          │
│   119 │   def _clip_gradients(                                                                   │
│   120 │   │   self,                                                                              │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/optim/optimizer.py:2 │
│ 80 in wrapper                                                                                    │
│                                                                                                  │
│   277 │   │   │   │   │   │   │   raise RuntimeError(f"{func} must return None or a tuple of (   │
│   278 │   │   │   │   │   │   │   │   │   │   │      f"but got {result}.")                       │
│   279 │   │   │   │                                                                              │
│ ❱ 280 │   │   │   │   out = func(*args, **kwargs)                                                │
│   281 │   │   │   │   self._optimizer_step_code()                                                │
│   282 │   │   │   │                                                                              │
│   283 │   │   │   │   # call optimizer step post hooks                                           │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/optim/optimizer.py:3 │
│ 3 in _use_grad                                                                                   │
│                                                                                                  │
│    30 │   │   prev_grad = torch.is_grad_enabled()                                                │
│    31 │   │   try:                                                                               │
│    32 │   │   │   torch.set_grad_enabled(self.defaults['differentiable'])                        │
│ ❱  33 │   │   │   ret = func(self, *args, **kwargs)                                              │
│    34 │   │   finally:                                                                           │
│    35 │   │   │   torch.set_grad_enabled(prev_grad)                                              │
│    36 │   │   return ret                                                                         │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/optim/adam.py:121 in │
│ step                                                                                             │
│                                                                                                  │
│   118 │   │   loss = None                                                                        │
│   119 │   │   if closure is not None:                                                            │
│   120 │   │   │   with torch.enable_grad():                                                      │
│ ❱ 121 │   │   │   │   loss = closure()                                                           │
│   122 │   │                                                                                      │
│   123 │   │   for group in self.param_groups:                                                    │
│   124 │   │   │   params_with_grad = []                                                          │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/plugins/ │
│ precision/precision.py:104 in _wrap_closure                                                      │
│                                                                                                  │
│   101 │   │   consistent with the ``Precision`` subclasses that cannot pass ``optimizer.step(c   │
│   102 │   │                                                                                      │
│   103 │   │   """                                                                                │
│ ❱ 104 │   │   closure_result = closure()                                                         │
│   105 │   │   self._after_closure(model, optimizer)                                              │
│   106 │   │   return closure_result                                                              │
│   107                                                                                            │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/loops/op │
│ timization/automatic.py:140 in __call__                                                          │
│                                                                                                  │
│   137 │   │   return step_output                                                                 │
│   138 │                                                                                          │
│   139 │   def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:                     │
│ ❱ 140 │   │   self._result = self.closure(*args, **kwargs)                                       │
│   141 │   │   return self._result.loss                                                           │
│   142                                                                                            │
│   143                                                                                            │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/utils/_contextlib.py │
│ :115 in decorate_context                                                                         │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/loops/op │
│ timization/automatic.py:126 in closure                                                           │
│                                                                                                  │
│   123 │                                                                                          │
│   124 │   @torch.enable_grad()                                                                   │
│   125 │   def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:                         │
│ ❱ 126 │   │   step_output = self._step_fn()                                                      │
│   127 │   │                                                                                      │
│   128 │   │   if step_output.closure_loss is None:                                               │
│   129 │   │   │   self.warning_cache.warn("`training_step` returned `None`. If this was on pur   │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/loops/op │
│ timization/automatic.py:315 in _training_step                                                    │
│                                                                                                  │
│   312 │   │   trainer = self.trainer                                                             │
│   313 │   │                                                                                      │
│   314 │   │   # manually capture logged metrics                                                  │
│ ❱ 315 │   │   training_step_output = call._call_strategy_hook(trainer, "training_step", *kwarg   │
│   316 │   │   self.trainer.strategy.post_training_step()  # unused hook - call anyway for back   │
│   317 │   │                                                                                      │
│   318 │   │   return self.output_result_cls.from_training_step_output(training_step_output, tr   │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/trainer/ │
│ call.py:309 in _call_strategy_hook                                                               │
│                                                                                                  │
│   306 │   │   return None                                                                        │
│   307 │                                                                                          │
│   308 │   with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hoo   │
│ ❱ 309 │   │   output = fn(*args, **kwargs)                                                       │
│   310 │                                                                                          │
│   311 │   # restore current_fx when nested context                                               │
│   312 │   pl_module._current_fx_name = prev_fx_name                                              │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pytorch_lightning/strategi │
│ es/strategy.py:382 in training_step                                                              │
│                                                                                                  │
│   379 │   │   with self.precision_plugin.train_step_context():                                   │
│   380 │   │   │   if self.model != self.lightning_module:                                        │
│   381 │   │   │   │   return self._forward_redirection(self.model, self.lightning_module, "tra   │
│ ❱ 382 │   │   │   return self.lightning_module.training_step(*args, **kwargs)                    │
│   383 │                                                                                          │
│   384 │   def post_training_step(self) -> None:                                                  │
│   385 │   │   """This hook is deprecated.                                                        │
│                                                                                                  │
│ in training_step:23                                                                              │
│                                                                                                  │
│   20 │   def training_step(self, batch, batch_idx):                                              │
│   21 │   │   t_obs, i_obs, x_obs = batch                                                         │
│   22 │   │   x_obs = x_obs + torch.rand_like(x_obs)                                              │
│ ❱ 23 │   │   loss = self.elbo(t_obs, i_obs, x_obs)                                               │
│   24 │   │   self.log("train_loss", loss)                                                        │
│   25 │   │   return loss                                                                         │
│   26                                                                                             │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1501 in _call_impl                                                                              │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pyro/infer/elbo.py:25 in   │
│ forward                                                                                          │
│                                                                                                  │
│    22 │   │   self.elbo = elbo                                                                   │
│    23 │                                                                                          │
│    24 │   def forward(self, *args, **kwargs):                                                    │
│ ❱  25 │   │   return self.elbo.differentiable_loss(self.model, self.guide, *args, **kwargs)      │
│    26                                                                                            │
│    27                                                                                            │
│    28 class ELBO(object, metaclass=ABCMeta):                                                     │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pyro/infer/trace_elbo.py:1 │
│ 21 in differentiable_loss                                                                        │
│                                                                                                  │
│   118 │   │   """                                                                                │
│   119 │   │   loss = 0.0                                                                         │
│   120 │   │   surrogate_loss = 0.0                                                               │
│ ❱ 121 │   │   for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):      │
│   122 │   │   │   loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(   │
│   123 │   │   │   │   model_trace, guide_trace                                                   │
│   124 │   │   │   )                                                                              │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pyro/infer/elbo.py:237 in  │
│ _get_traces                                                                                      │
│                                                                                                  │
│   234 │   │   │   yield self._get_vectorized_trace(model, guide, args, kwargs)                   │
│   235 │   │   else:                                                                              │
│   236 │   │   │   for i in range(self.num_particles):                                            │
│ ❱ 237 │   │   │   │   yield self._get_trace(model, guide, args, kwargs)                          │
│   238                                                                                            │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pyro/infer/trace_elbo.py:5 │
│ 7 in _get_trace                                                                                  │
│                                                                                                  │
│    54 │   │   Returns a single trace from the guide, and the model that is run                   │
│    55 │   │   against it.                                                                        │
│    56 │   │   """                                                                                │
│ ❱  57 │   │   model_trace, guide_trace = get_importance_trace(                                   │
│    58 │   │   │   "flat", self.max_plate_nesting, model, guide, args, kwargs                     │
│    59 │   │   )                                                                                  │
│    60 │   │   if is_validation_enabled():                                                        │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pyro/infer/enum.py:75 in   │
│ get_importance_trace                                                                             │
│                                                                                                  │
│    72 │   guide_trace = prune_subsample_sites(guide_trace)                                       │
│    73 │   model_trace = prune_subsample_sites(model_trace)                                       │
│    74 │                                                                                          │
│ ❱  75 │   model_trace.compute_log_prob()                                                         │
│    76 │   guide_trace.compute_score_parts()                                                      │
│    77 │   if is_validation_enabled():                                                            │
│    78 │   │   for site in model_trace.nodes.values():                                            │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pyro/poutine/trace_struct. │
│ py:236 in compute_log_prob                                                                       │
│                                                                                                  │
│   233 │   │   │   │   │   except ValueError as e:                                                │
│   234 │   │   │   │   │   │   _, exc_value, traceback = sys.exc_info()                           │
│   235 │   │   │   │   │   │   shapes = self.format_shapes(last_site=site["name"])                │
│ ❱ 236 │   │   │   │   │   │   raise ValueError(                                                  │
│   237 │   │   │   │   │   │   │   "Error while computing log_prob at site '{}':\n{}\n{}".forma   │
│   238 │   │   │   │   │   │   │   │   name, exc_value, shapes                                    │
│   239 │   │   │   │   │   │   │   )                                                              │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/pyro/poutine/trace_struct. │
│ py:230 in compute_log_prob                                                                       │
│                                                                                                  │
│   227 │   │   │   if site["type"] == "sample" and site_filter(name, site):                       │
│   228 │   │   │   │   if "log_prob" not in site:                                                 │
│   229 │   │   │   │   │   try:                                                                   │
│ ❱ 230 │   │   │   │   │   │   log_p = site["fn"].log_prob(                                       │
│   231 │   │   │   │   │   │   │   site["value"], *site["args"], **site["kwargs"]                 │
│   232 │   │   │   │   │   │   )                                                                  │
│   233 │   │   │   │   │   except ValueError as e:                                                │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/distributions/transf │
│ ormed_distribution.py:153 in log_prob                                                            │
│                                                                                                  │
│   150 │   │   │   │   │   │   │   │   │   │   │   │    event_dim - transform.domain.event_dim)   │
│   151 │   │   │   y = x                                                                          │
│   152 │   │                                                                                      │
│ ❱ 153 │   │   log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),                   │
│   154 │   │   │   │   │   │   │   │   │   │   │    event_dim - len(self.base_dist.event_shape)   │
│   155 │   │   return log_prob                                                                    │
│   156                                                                                            │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/distributions/indepe │
│ ndent.py:99 in log_prob                                                                          │
│                                                                                                  │
│    96 │   │   return self.base_dist.rsample(sample_shape)                                        │
│    97 │                                                                                          │
│    98 │   def log_prob(self, value):                                                             │
│ ❱  99 │   │   log_prob = self.base_dist.log_prob(value)                                          │
│   100 │   │   return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)                    │
│   101 │                                                                                          │
│   102 │   def entropy(self):                                                                     │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/distributions/normal │
│ .py:79 in log_prob                                                                               │
│                                                                                                  │
│    76 │                                                                                          │
│    77 │   def log_prob(self, value):                                                             │
│    78 │   │   if self._validate_args:                                                            │
│ ❱  79 │   │   │   self._validate_sample(value)                                                   │
│    80 │   │   # compute the variance                                                             │
│    81 │   │   var = (self.scale ** 2)                                                            │
│    82 │   │   log_scale = math.log(self.scale) if isinstance(self.scale, Real) else self.scale   │
│                                                                                                  │
│ /home/chenkai/.virtualenvs/py310torch200/lib/python3.10/site-packages/torch/distributions/distri │
│ bution.py:300 in _validate_sample                                                                │
│                                                                                                  │
│   297 │   │   assert support is not None                                                         │
│   298 │   │   valid = support.check(value)                                                       │
│   299 │   │   if not valid.all():                                                                │
│ ❱ 300 │   │   │   raise ValueError(                                                              │
│   301 │   │   │   │   "Expected value argument "                                                 │
│   302 │   │   │   │   f"({type(value).__name__} of shape {tuple(value.shape)}) "                 │
│   303 │   │   │   │   f"to be within the support ({repr(support)}) "                             │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Error while computing log_prob at site 'T':
Expected value argument (Tensor of shape (128, 1)) to be within the support (Real()) of the distribution 
Normal(loc: torch.Size([128, 1]), scale: torch.Size([128, 1])), but found invalid values:
tensor([[   nan],
        [   nan],
        [   nan],
        [1.3494],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.0990],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.0442],
        [1.1113],
        [1.0319],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.4944],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.0511],
        [   nan],
        [1.1086],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.1693],
        [1.2343],
        [   nan],
        [   nan],
        [   nan],
        [1.2089],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.1464],
        [   nan],
        [   nan],
        [1.0146],
        [   nan],
        [1.0616],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.0513],
        [   nan],
        [   nan],
        [1.0878],
        [   nan],
        [   nan],
        [1.0663],
        [1.0729],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.2027],
        [   nan],
        [1.0763],
        [1.0316],
        [1.0748],
        [1.3685],
        [   nan],
        [   nan],
        [   nan],
        [1.3050],
        [   nan],
        [   nan],
        [   nan],
        [1.2104],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [   nan],
        [1.0835],
        [   nan],
        [   nan]], device='cuda:0', grad_fn=<IndexPutBackward0>)
                                         Trace Shapes:        
                                          Param Sites:        
     model.thickness_transform$$$0.unnormalized_widths   1 8  
    model.thickness_transform$$$0.unnormalized_heights   1 8  
model.thickness_transform$$$0.unnormalized_derivatives   1 7  
    model.thickness_transform$$$0.unnormalized_lambdas   1 8  
                                         Sample Sites:        
                                                T dist 128 | 1
                                                 value 128 | 1

How can I resolve this error and get the correct outputs?

eb8680 commented 9 months ago

@Flawless1202 thanks for the clear and informative report. We've made some changes to this example but haven't re-run it from scratch in a while because it takes quite a long time to train to convergence. It's also somewhat unstable, so I'm not surprised to see some NaNs pop up.

We'll look into it ourselves to see if there are any bugs in the model, but in the meantime you might try reducing the learning rate manually or via a scheduler, stopping early based on a validation loss, or restarting from several different initial points.