pinellolab / dictys

Context specific and dynamic gene regulatory network reconstruction and analysis
GNU Affero General Public License v3.0

NaNs arise in dictys.network.reconstruct #66

Closed: ekernf01 closed this issue 2 months ago

ekernf01 commented 2 months ago

Dear Dr. Wang, I am encountering some NaNs when running a portion of the Dictys workflow on simple inputs through the Python interface. I am not sure if this type of analysis will be supported, but here are details that I hope will enable you to reproduce the issue. Thanks very much.


Describe the error

This error happens with the CPU version of Dictys (1.0.0, build hg9576260_0) installed via conda, and also when running Dictys from the provided Docker image lfwa/dictys-cpu.

Full error message:

/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:246: UserWarning: Encountered NaN: log_prob_sum at site 'G_0'
  warn_if_nan(
/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:246: UserWarning: Encountered NaN: log_prob_sum at site 'G_obs'
  warn_if_nan(
/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:285: UserWarning: Encountered NaN: log_prob_sum at site 'G_0'
  warn_if_nan(
/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/dictys/network.py:92: UserWarning: Encountered NaN: loss[G_0]
  warn_if_nan(surrogate_loss[xi], f"loss[{xi}]")
/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/dictys/network.py:92: UserWarning: Encountered NaN: loss[G_obs]
  warn_if_nan(surrogate_loss[xi], f"loss[{xi}]")
Traceback (most recent call last):
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/dictys/network.py", line 585, in guide
    t1=dist.Normal(G_0_mean[ind_c],G_0_std[ind_c]).to_event(1)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pyro/distributions/distribution.py", line 24, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/torch/distributions/normal.py", line 56, in __init__
    super().__init__(batch_shape, validate_args=validate_args)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/torch/distributions/distribution.py", line 62, in __init__
    raise ValueError(
ValueError: Expected parameter loc (Tensor of shape (1385, 1385)) of distribution Normal(loc: torch.Size([1385, 1385]), scale: torch.Size([1385, 1385])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<IndexBackward0>)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/dictys/network.py", line 790, in reconstruct
    model.train_svi(nstep,nstep_report=nstep_report)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/dictys/network.py", line 378, in train_svi
    self.stat_loss.append(self.svi.step())
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pyro/infer/svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/dictys/network.py", line 117, in _loss_and_grads
    loss_val = tuple(losses(*args, **kwargs))
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/dictys/network.py", line 84, in differentiable_loss
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pyro/infer/elbo.py", line 237, in _get_traces
    yield self._get_trace(model, guide, args, kwargs)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pyro/infer/trace_elbo.py", line 57, in _get_trace
    model_trace, guide_trace = get_importance_trace(
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pyro/infer/enum.py", line 60, in get_importance_trace
    guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 180, in __call__
    raise exc from e
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/dictys/network.py", line 585, in guide
    t1=dist.Normal(G_0_mean[ind_c],G_0_std[ind_c]).to_event(1)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pyro/distributions/distribution.py", line 24, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/torch/distributions/normal.py", line 56, in __init__
    super().__init__(batch_shape, validate_args=validate_args)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/torch/distributions/distribution.py", line 62, in __init__
    raise ValueError(
ValueError: Expected parameter loc (Tensor of shape (1385, 1385)) of distribution Normal(loc: torch.Size([1385, 1385]), scale: torch.Size([1385, 1385])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<IndexBackward0>)
Trace Shapes:          
 Param Sites:          
     G_0_mean 1385 1385
      G_0_std 1385 1385
Sample Sites:          
   cell1 dist         |
        value 1385    |

Where: The error happens when I run this code with the attached input files in the working directory.

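import dictys.network

# Input files are read from, and output files written to, the working directory.
dictys.network.reconstruct(
    fi_exp="exp.txt",
    fi_mask="mask.txt",
    fo_weight="weight.tsv",
    fo_meanvar="meanvar.tsv",
    fo_covfactor="covfactor.tsv",
    fo_loss="loss.tsv",
    fo_stats="stats.tsv",
)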

The expression data are real 10X 3' scRNA-seq data from a human cell line. I selected ~1500 variable genes and subset the cells to a single stage of differentiation. The network is a subset of CellOracle's default base network, which comes from motif analysis of human promoters.

Attachments: exp.txt, mask.txt

lingfeiwang commented 2 months ago

Hi ekernf01,

You need to perform QC first, e.g. to remove genes with no or low read counts. The recommended way to use a custom mask file (binlinking.tsv.gz) is to first run the original pipeline, replace the mask file with your own (make sure its modification time is updated, e.g. with touch), remove the output h5 file, and then run the pipeline again (dictys_helper network_inference.sh ...). If you want, you can look into the makefiles to understand the file dependencies at https://github.com/pinellolab/dictys/blob/master/src/dictys/scripts/makefiles/common.mk.
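As a rough illustration of the QC step mentioned above, the sketch below drops genes with no or low read counts from a small expression table. It is not Dictys's own QC routine: the function names, the thresholds (min_total, min_cells), and the assumed layout (tab-separated, one gene per row, one cell per column) are all hypothetical and would need to be adapted to the real exp.txt. Any mask file would presumably need to be subset to the same genes.

```python
import csv
import io


def read_matrix(handle):
    """Read a tab-separated table: header row, then one gene per row (assumed layout)."""
    reader = csv.reader(handle, delimiter="\t")
    header = next(reader)
    rows = [(r[0], [float(x) for x in r[1:]]) for r in reader]
    return header, rows


def filter_low_count_genes(rows, min_total=1.0, min_cells=1):
    """Keep genes whose total count and number of expressing cells meet the thresholds.

    The thresholds are hypothetical placeholders, not values recommended by Dictys.
    """
    kept = []
    for gene, counts in rows:
        total = sum(counts)
        n_expressing = sum(1 for c in counts if c > 0)
        if total >= min_total and n_expressing >= min_cells:
            kept.append((gene, counts))
    return kept


# Toy data: gene A has no reads, gene C has a single read in a single cell.
example = "gene\tcell1\tcell2\tcell3\nA\t0\t0\t0\nB\t5\t0\t2\nC\t0\t1\t0\n"
header, rows = read_matrix(io.StringIO(example))
kept = filter_low_count_genes(rows, min_total=2, min_cells=2)
kept_genes = [gene for gene, _ in kept]
print(kept_genes)
```

With these toy thresholds only gene B survives; on real data the cutoffs would depend on sequencing depth and cell number.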

If that doesn't work, please let us know with the error message and updated input files.

Lingfei
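The file-replacement steps described in the reply above could be scripted roughly as follows. This is a sketch under stated assumptions rather than part of Dictys: the helper swap_mask_and_invalidate is hypothetical, and so are the file names my_mask.tsv.gz and output.h5 (the reply names only binlinking.tsv.gz and does not give the name of the output h5 file). The demo runs in a temporary directory with placeholder bytes instead of real gzip or HDF5 content.

```python
import shutil
import tempfile
from pathlib import Path


def swap_mask_and_invalidate(custom_mask, pipeline_mask, stale_output):
    """Overwrite the pipeline's mask with a custom one and force a rebuild.

    Hypothetical helper: copies the custom mask over the pipeline's mask,
    refreshes the mask's modification time so make treats it as newer than
    its dependents, and deletes the stale output file.
    """
    custom_mask = Path(custom_mask)
    pipeline_mask = Path(pipeline_mask)
    stale_output = Path(stale_output)
    shutil.copyfile(custom_mask, pipeline_mask)
    pipeline_mask.touch()  # update the modification time explicitly
    stale_output.unlink(missing_ok=True)  # requires Python 3.8+


# Demo with placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    custom = tmp / "my_mask.tsv.gz"    # hypothetical name for the custom mask
    mask = tmp / "binlinking.tsv.gz"   # mask file name given in the reply
    output = tmp / "output.h5"         # hypothetical name for the output h5 file
    custom.write_bytes(b"custom mask")
    mask.write_bytes(b"original mask")
    output.write_bytes(b"stale result")

    swap_mask_and_invalidate(custom, mask, output)

    mask_replaced = mask.read_bytes() == b"custom mask"
    output_removed = not output.exists()
```

After a swap like this, the pipeline would be rerun with dictys_helper network_inference.sh ... as described in the reply, so that the output is rebuilt from the new mask.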

ekernf01 commented 2 months ago

Thank you. I will read more of the documentation and try to do it this way, if my data allow it.