Open vitkl opened 1 year ago
Hi @vitkl, could you provide more details about the memory leak? Does the memory use happen during training/inference or during prediction? Where are the observed data being saved? Are you `poutine.trace`-ing the guide, or somehow saving sample statements beyond the invocation of a single `.guide()` call? Could you point to some places in the code where the problem occurs?

EDIT to clarify this bug: is it an issue with prediction in `.median()` and `.quantiles()`, or is it an issue with the usual training method `.__call__()`?
The goal of this function is to generate samples from the posterior; however, the same problem exists with `.median()` and `.quantiles()`.
This is not really a memory leak (GPU memory use doesn't increase with time). It's simply that

```python
if isinstance(self.module.guide, poutine.messenger.Messenger):
    # This already includes trace-replay behavior.
    sample = self.module.guide(*args, **kwargs)
```

returns all sites, including observed sites. For gene expression data, that means returning the densified count matrix `n_samples` times. That's a huge matrix. During training, the model obviously needs to compute the likelihood of this data, and that's not a problem because of minibatch training and because during training only one sample is generated. However, there is no need to return observed sites when generating posterior samples or when computing the median and quantiles.
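To make the scale concrete, here is a back-of-the-envelope estimate of what returning the densified count matrix with every posterior sample costs; the sizes below are illustrative, not taken from this issue:

```python
# Rough memory cost of returning the dense count matrix per posterior sample.
# Sizes are illustrative (a mid-sized single-cell dataset), not from this issue.
n_cells, n_genes = 50_000, 20_000
n_samples = 25
bytes_per_float32 = 4

per_sample_gib = n_cells * n_genes * bytes_per_float32 / 2**30
total_gib = per_sample_gib * n_samples
print(f"{per_sample_gib:.1f} GiB per sample, {total_gib:.1f} GiB total")
# → 3.7 GiB per sample, 93.1 GiB total
```

One dense copy already costs a few GiB; multiplying by the number of posterior samples is what blows memory up.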
An alternative is the `poutine.trace` approach (https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_pyromixin.py#L186-L208) for sampling, which doesn't have this problem because observed sites can be excluded (https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_pyromixin.py#L200). It doesn't address the issue with `.median()` and `.quantiles()`, though.
To add my "solution" to this: it uses `infer.Predictive` and checks once whether each variable is on a per-cell basis rather than a global parameter. It records all samples. I am not sure what problem there is with `.median()`.
```python
adata = self._validate_anndata(adata)
train_dl = self._make_data_loader(
    adata=adata, indices=indices, shuffle=False, batch_size=batch_size
)
self.to_device(device)
model = model if model else self.module.model
samples = {}  # accumulator; initialized here so the first-batch check below is well defined
# sample local parameters
for tensor_dict in track(
    train_dl,
    style="tqdm",
    description="Sampling local variables, batch: ",
):
    args, kwargs = self.module._get_fn_args_from_batch(tensor_dict)
    args = [a.to(device) for a in args]
    kwargs = {k: v.to(device) for k, v in kwargs.items()}
    if library_size is not None:
        kwargs["library"] = torch.full_like(
            kwargs["library"], torch.log(torch.tensor(library_size))
        )
    samples_ = infer.Predictive(
        model,
        num_samples=num_samples,
        guide=self.module.guide,
        return_sites=return_sites,
    )(*args, **kwargs)
    if not samples:
        # First batch: record which sites sit inside a plate
        # (non-empty cond_indep_stack) as the minibatch-dependent sites.
        model_trace = poutine.trace(model).get_trace(*args, **kwargs)
        return_sites = [
            i
            for i in model_trace.nodes.keys()
            if model_trace.nodes[i].pop("cond_indep_stack", None)
        ]
        samples = {k: [v.cpu()] for k, v in samples_.items()}
    else:
        # Record minibatches if variable is minibatch dependent
        samples = {
            k: v + [samples_[k].cpu()] if k in return_sites else v
            for k, v in samples.items()
        }
samples = {
    k: torch.cat(v, axis=1).numpy() for k, v in samples.items()
}  # concatenate minibatches for each variable
```
It is critical for me to copy to CPU during execution (GPU memory is too small). I have to run `gc.collect()` afterwards to get the GPU memory freed, so I think there is an additional leak.
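For reference, the post-sampling cleanup can be sketched like this. Note that `torch.cuda.empty_cache()` only returns PyTorch's cached blocks to the driver; it cannot free tensors that are still referenced, which is why dropping references and running `gc.collect()` first matters (the helper name is illustrative):

```python
import gc
import torch

def free_gpu_memory() -> None:
    """Drop dangling references, then return cached CUDA blocks to the driver."""
    gc.collect()  # break reference cycles that may still pin GPU tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # safe no-op on CPU-only machines

free_gpu_memory()
```

If `gc.collect()` is what actually releases the memory, some object (often a traceback or a closure from the sampling loop) is still holding a reference to a GPU tensor, which is consistent with the suspected additional leak.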
The code in scvi-tools already checks which variables are in minibatch plate and aggregates minibatched variables along the correct dimensions in https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_pyromixin.py#L357-L409. This code also copies each sample to CPU to avoid keeping 1000 samples on GPU: https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_pyromixin.py#L210 .
Your code assumes that the presence of a plate means that a variable is minibatched, and assumes dimension -1. The code above implements a more general solution that checks the plate name and fetches the dimension.
I don't see how your code solves the problem of not saving observed variables. Does `infer.Predictive` return only unobserved variables?
When I was implementing this, @martinjankowiak recommended not using `infer.Predictive` because we have a simpler problem.
There is an additional bug: `_get_obs_plate_sites` excludes observed variables from its list of obs_plate_sites, which means that observed sites end up treated as global variables. This might explain why memory use increased from 3.8 GB to 40 GB rather than >> 40 GB. Created a PR to fix both issues: https://github.com/scverse/scvi-tools/pull/1805
Hi
Posterior sampling with Messenger Pyro guides does not remove observed variables, leading to huge memory use: https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_pyromixin.py#L184. I don't know whether, when sampling from Messenger guides, it is possible to easily detect and exclude observed variables. @fritzo, any recommendations?
This can be addressed
Related to https://github.com/BayraktarLab/cell2location/pull/144/