bucky527 / scME


Problem getting the latent space representation after training #1

Open 4ivage32 opened 7 months ago

4ivage32 commented 7 months ago

I want to integrate my paired RNA and protein data as described in the tutorial, but I ran into an error when trying to get the latent-space representation after training. The message is

RuntimeError                              Traceback (most recent call last)
Cell In[15], line 7
      5 rnatorch,proteintorch=rnatorch.to(model.device),proteintorch.to(model.device)
      6 model.eval()
----> 7 zm=model.inference(rnatorch, proteintorch)

File ~/Workspace/Multiomics/benchmark/script/./scME/pyroMethod.py:653, in ScMESVI_2.inference(self, rna, protein)
    650 self.eval() 
    651 # zr_loc,zr_scale,l_loc,l_scale=self.zr_encoder((rna,yr))
    652 # zp_loc,zp_scale,c_loc,c_scale,pi=self.zp_encoder((protein,yp))
--> 653 zr_loc,zr_scale,l_loc,l_scale=self.zr_encoder(rna)
    654 zp_loc,zp_scale,c_loc,c_scale,pi=self.zp_encoder(protein)
    655 zm_loc,zm_scale=self.zm_encoder((zr_loc,zp_loc))

RuntimeError: mat1 and mat2 must have the same dtype

rnatorch.shape is torch.Size([16204, 2000]), and model.zr_encoder looks like this:

MLP(
  (sequential_mlp): Sequential(
    (0): ConcatModule()
    (1): DataParallel(
      (module): Linear(in_features=2000, out_features=1000, bias=True)
    )
    (2): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
    (4): DataParallel(
      (module): Linear(in_features=1000, out_features=256, bias=True)
    )
    (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): DataParallel(
      (module): Linear(in_features=256, out_features=64, bias=True)
    )
    (8): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU()
    (10): ListOutModule(
      (0): Sequential(
        (0): Linear(in_features=64, out_features=24, bias=True)
      )
      (1): Sequential(
        (0): Linear(in_features=64, out_features=24, bias=True)
        (1): Softplus(beta=1, threshold=20)
      )
      (2): Sequential(
        (0): Linear(in_features=64, out_features=1, bias=True)
      )
      (3): Sequential(
        (0): Linear(in_features=64, out_features=1, bias=True)
        (1): Softplus(beta=1, threshold=20)
      )
    )
  )
)

As far as I can tell, the shape of rnatorch matches the MLP's input layer (in_features=2000), so I don't see where the problem is. What can I do to fix it? I'd appreciate any help with this issue!
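
For completeness, here is a minimal check I can run, assuming the encoder parameters are float32 (the PyTorch default) and rnatorch/proteintorch are the tensors from the cell above:

# Compare the input dtypes with the encoder weights; "mat1 and mat2 must
# have the same dtype" usually means these differ (e.g. int64 vs. float32).
print(rnatorch.dtype, proteintorch.dtype)
print(next(model.zr_encoder.parameters()).dtype)

# If the inputs turn out to be integer tensors, casting them to float32
# should let the Linear layers accept them.
rnatorch = rnatorch.float()
proteintorch = proteintorch.float()
zm = model.inference(rnatorch, proteintorch)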

4ivage32 commented 7 months ago

I got it working with another dataset! But with the dataset above, I had originally run into this error during training:

#train the scme model
model=scme.train_model(model,max_epochs=10)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:230, in Trace.compute_log_prob(self, site_filter)
    229 try:
--> 230     log_p = site["fn"].log_prob(
    231         site["value"], *site["args"], **site["kwargs"]
    232     )
    233 except ValueError as e:

File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/torch/distributions/independent.py:99, in Independent.log_prob(self, value)
     98 def log_prob(self, value):
---> 99     log_prob = self.base_dist.log_prob(value)
    100     return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)

File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/distributions/zero_inflated.py:71, in ZeroInflatedDistribution.log_prob(self, value)
     70 if self._validate_args:
---> 71     self._validate_sample(value)
     73 if "gate" in self.__dict__:

File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/torch/distributions/distribution.py:300, in Distribution._validate_sample(self, value)
    299 if not valid.all():
--> 300     raise ValueError(
    301         "Expected value argument "
    302         f"({type(value).__name__} of shape {tuple(value.shape)}) "
    303         f"to be within the support ({repr(support)}) "
    304         f"of the distribution {repr(self)}, "
    305         f"but found invalid values:\n{value}"
    306     )

ValueError: Expected value argument (Tensor of shape (256, 2000)) to be within the support (IntegerGreaterThan(lower_bound=0)) of the distribution ZeroInflatedNegativeBinomial(gate_logits: torch.Size([256, 2000])), but found invalid values:
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.1207, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 1.0324, 1.0324, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 1.1781, 1.1781, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[74], line 2
      1 #train the scme model
----> 2 model=scme.train_model(model,max_epochs=10)

File ~/Workspace/Multiomics/benchmark/script/./scME/scme.py:136, in train_model(model, max_epochs, batchsize, lr, lr_cla, milestones, save_model, save_dir)
    134 yr=yr.to(device)
    135 yp=yp.to(device)
--> 136 loss1=svi.step(rna,protein,yr,yp)
    137 loss2=svi2.step(rna,protein,yr,yp)
    138 losses_ae.append(loss1)

File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
    143 # get loss and compute gradients
    144 with poutine.trace(param_only=True) as param_capture:
--> 145     loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    147 params = set(
    148     site["value"].unconstrained() for site in param_capture.trace.nodes.values()
    149 )
    151 # actually perform gradient steps
    152 # torch.optim objects gets instantiated for any params that haven't been seen yet

File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
    138 loss = 0.0
    139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
    141     loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
    142         model_trace, guide_trace
    143     )
    144     loss += loss_particle / self.num_particles

File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/infer/elbo.py:237, in ELBO._get_traces(self, model, guide, args, kwargs)
    235 else:
    236     for i in range(self.num_particles):
--> 237         yield self._get_trace(model, guide, args, kwargs)

File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
     52 def _get_trace(self, model, guide, args, kwargs):
     53     """
     54     Returns a single trace from the guide, and the model that is run
     55     against it.
     56     """
---> 57     model_trace, guide_trace = get_importance_trace(
     58         "flat", self.max_plate_nesting, model, guide, args, kwargs
     59     )
     60     if is_validation_enabled():
     61         check_if_enumerated(guide_trace)

File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/infer/enum.py:75, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     72 guide_trace = prune_subsample_sites(guide_trace)
     73 model_trace = prune_subsample_sites(model_trace)
---> 75 model_trace.compute_log_prob()
     76 guide_trace.compute_score_parts()
     77 if is_validation_enabled():

File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:236, in Trace.compute_log_prob(self, site_filter)
    234     _, exc_value, traceback = sys.exc_info()
    235     shapes = self.format_shapes(last_site=site["name"])
--> 236     raise ValueError(
    237         "Error while computing log_prob at site '{}':\n{}\n{}".format(
    238             name, exc_value, shapes
    239         )
    240     ).with_traceback(traceback) from e
    241 site["unscaled_log_prob"] = log_p
    242 log_p = scale_and_mask(log_p, site["scale"], site["mask"])

File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:230, in Trace.compute_log_prob(self, site_filter)
    228 if "log_prob" not in site:
    229     try:
--> 230         log_p = site["fn"].log_prob(
    231             site["value"], *site["args"], **site["kwargs"]
    232         )
    233     except ValueError as e:
    234         _, exc_value, traceback = sys.exc_info()

File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/torch/distributions/independent.py:99, in Independent.log_prob(self, value)
     98 def log_prob(self, value):
---> 99     log_prob = self.base_dist.log_prob(value)
    100     return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)

File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/distributions/zero_inflated.py:71, in ZeroInflatedDistribution.log_prob(self, value)
     69 def log_prob(self, value):
     70     if self._validate_args:
---> 71         self._validate_sample(value)
     73     if "gate" in self.__dict__:
     74         gate, value = broadcast_all(self.gate, value)

File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/torch/distributions/distribution.py:300, in Distribution._validate_sample(self, value)
    298 valid = support.check(value)
    299 if not valid.all():
--> 300     raise ValueError(
    301         "Expected value argument "
    302         f"({type(value).__name__} of shape {tuple(value.shape)}) "
    303         f"to be within the support ({repr(support)}) "
    304         f"of the distribution {repr(self)}, "
    305         f"but found invalid values:\n{value}"
    306     )

ValueError: Error while computing log_prob at site 'rna_count':
Expected value argument (Tensor of shape (256, 2000)) to be within the support (IntegerGreaterThan(lower_bound=0)) of the distribution ZeroInflatedNegativeBinomial(gate_logits: torch.Size([256, 2000])), but found invalid values:
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.1207, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 1.0324, 1.0324, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 1.1781, 1.1781, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])
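
For reference, the support check that raises this error only accepts non-negative integer counts, and my values (1.1207, 1.0324, ...) look normalized rather than raw. A small standalone reproduction, assuming only torch and pyro are installed (the distribution parameters here are arbitrary):

import torch
import pyro.distributions as dist

# ZeroInflatedNegativeBinomial only supports non-negative integer values,
# so normalized expression values such as 1.12 fail the validation check.
zinb = dist.ZeroInflatedNegativeBinomial(
    total_count=torch.ones(3),
    logits=torch.zeros(3),
    gate_logits=torch.zeros(3),
    validate_args=True,
)
print(zinb.log_prob(torch.tensor([0.0, 1.0, 2.0])))   # integer-valued: OK
print(zinb.log_prob(torch.tensor([0.0, 1.12, 2.0])))  # raises ValueError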

I fixed it by converting the data to integers:

# create the training dataset from integer-cast counts
rna.X=rna.layers["counts"].astype(int)
protein.X=protein.layers["counts"].astype(int)
# previously: rna.X=rna.layers["counts"]
# previously: protein.X=protein.layers["counts"]
traindataset=scme.AnnDataset(rna,protein,to_onehot=True)

which then caused the new error I described in my first comment above.
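
So my current understanding is: the ZeroInflatedNegativeBinomial likelihood needs integer-valued counts (hence the astype(int) fix for training), but that also makes the tensors I build for inference int64, which the float32 Linear layers then reject with the dtype error above. A sketch of the inference cell I plan to try next, assuming rna.X and protein.X are dense arrays holding the int-cast counts:

import torch

# Build the inference inputs from the int-cast counts, then cast them to
# float32 so they match the dtype of the encoders' Linear weights.
rnatorch = torch.from_numpy(rna.X).float().to(model.device)
proteintorch = torch.from_numpy(protein.X).float().to(model.device)

model.eval()
zm = model.inference(rnatorch, proteintorch)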