broadinstitute / CellBender

CellBender is a software package for eliminating technical artifacts from high-throughput single-cell RNA sequencing (scRNA-seq) data.
https://cellbender.rtfd.io
BSD 3-Clause "New" or "Revised" License

Can't pickle `weakref` objects when saving checkpoints #212

Open tzeitim opened 1 year ago

tzeitim commented 1 year ago

I hope this is not too bleeding-edge but I have no other versioning options due to the combination of the GPU-nodes I have access to and software dependencies.

To do a quick recap

I am pulling cellbender 0.3.0 from the branch sf_dev_0.3.0_postreg_posterior_format_h5 since I had the same two issues raised in PR #193. This follows @sjfleming's suggestion to pull the latest pytorch 2.0.0 from their dev branch for pytorch.

git clone -b sf_dev_0.3.0_postreg_posterior_format_h5 https://github.com/broadinstitute/CellBender.git

Unfortunately, the issue with learning rate schedulers persisted even after pulling cellbender's 736d6.

After some digging, I managed to solve the pytorch-pyro scheduler issue by pulling pyro from this commit instead.

pip install git+https://github.com/ilia-kats/pyro/@c9ed43a1f90d2f9a92278c68319eb68962b29013

The main issue

cellbender was able to finish training, but it raised a new error when trying to write the final checkpoint (the only checkpoint attempted for this data set, as far as I am aware).

The code in cellbender's checkpoint.py, in its current form, only reports on exit that the attempt to write the checkpoint failed.

'Could not save checkpoint'

I had to remove the try block in order to reveal the real issue.

*** TypeError: cannot pickle 'weakref' object
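
For reference, a minimal sketch of how a save routine could surface the underlying exception instead of only a generic message. This is not CellBender's checkpoint code: `try_save` is a hypothetical helper, and it uses plain `pickle` as a stand-in for `torch.save`.

```python
import logging
import pickle

logger = logging.getLogger("checkpoint_sketch")


def try_save(obj, path):
    """Try to pickle obj to path (hypothetical helper).

    Returns None on success, or the caught exception on failure.
    """
    try:
        with open(path, "wb") as f:
            pickle.dump(obj, f)
        return None
    except Exception as exc:
        # logger.exception records the full traceback along with the
        # message, so the real cause is visible in the log.
        logger.exception("Could not save checkpoint: %s", exc)
        return exc
```

With something like this, the `TypeError` would appear in the log without having to edit out the try block.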

I dissected the individual lines that would trigger the error on their own.

torch.save(model_obj, filebase + '_model.torch')
torch.save(scheduler, filebase + '_optim.torch')
scheduler.save(filebase + '_optim.pyro')  # use PyroOptim method
pyro.get_param_store().save(filebase + '_params.pyro')

Interestingly the model object can be saved by invoking its .state_dict() method.

torch.save(model_obj.state_dict(), filebase + '_model.torch') 

No .state_dict() exists for the scheduler object, though.

To understand the problem a bit better, I omitted the method objects within the scheduler, and then torch.save could run! This strategy indicated that the weakref in anneal_func was to blame.
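
The same kind of dissection can be automated. Below is a rough sketch, using only the standard library, that walks a nested dict and reports which entries fail to pickle. `find_unpicklable`, `Owner`, and the example `state` dict are hypothetical stand-ins, not the actual scheduler state.

```python
import pickle
import weakref


def find_unpicklable(obj, path=()):
    """Return a list of key paths (tuples) whose values fail to pickle.

    Dicts are searched recursively; any other value is tested as a whole.
    """
    if isinstance(obj, dict):
        bad = []
        for key, value in obj.items():
            bad.extend(find_unpicklable(value, path + (key,)))
        return bad
    try:
        pickle.dumps(obj)
        return []
    except Exception:
        return [path]


class Owner:
    """Hypothetical stand-in for an object that is referenced weakly."""


owner = Owner()
# Made-up state: one entry holds a weak reference, the rest are plain values.
state = {
    "optim": {"lr": 0.001, "anneal_func": weakref.ref(owner)},
    "step": 10,
}
print(find_unpicklable(state))  # [('optim', 'anneal_func')]
```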

I did a little bit of googling with this information and I think that the weakref issue is very similar (maybe identical to) this pytorch issue #42376 .

I have decided to open this issue and document it here, as it has gone beyond my ability to resolve for now.


As a footnote, and just for the record, I wrote this non-fancy routine to eliminate the methods mentioned above:

import weakref

def remove_weakrefs(aa):
    """Drop bound-method entries from a three-level nested state dict."""
    remove_keys = []
    for i in aa.keys():
        for s in aa[i].keys():
            for k in aa[i][s].keys():
                # debug output: weakref checks and class names at each level
                print(f'{isinstance(aa[i], weakref.ReferenceType)} {isinstance(aa[i][s], weakref.ReferenceType)} {isinstance(aa[i][s][k], weakref.ReferenceType)} {i} {s} {k} ')
                print(f'{aa[i].__class__.__name__} {aa[i][s].__class__.__name__} {aa[i][s][k].__class__.__name__}  ')
                if aa[i][s][k].__class__.__name__ == "method":
                    remove_keys.append((i, s, k))

    # pop after iterating so the dicts are not modified during the loops
    for i, s, k in remove_keys:
        aa[i][s].pop(k)
    return aa

aa = remove_weakrefs(scheduler.get_state())

torch.save(aa, filebase + '_optim.torch')  # this works

sjfleming commented 1 year ago

@tzeitim thank you very much for writing in with this! This is the same problem I was running into when trying to move to pytorch 2.0.0, and I have not yet been able to figure it out either. You got farther than I did! So thank you, I appreciate it.

Thanks for pointing out the ilia-kats fix for the pyro issue, I had not seen that yet, and that seems promising.

I hadn't run into that optimizer saving issue yet, but I did run into another weakref pickling issue here: https://github.com/pyro-ppl/pyro/issues/3201

I wonder what the deal is with this weakref stuff in pytorch 2.0.0. They must have refactored some things in a way I don't understand. I wonder why I'm seeing it now with v2.0.0, but never saw it before? But that pytorch issue you linked seems like the right thing.

This seems to agree with your fix, and I think I might try it out: https://github.com/numenta/nupic.research/pull/328/files

state_dict = self.lr_scheduler.state_dict()
if "anneal_func" in state_dict:
    del state_dict["anneal_func"]
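
A small variation on the snippet above, sketched with the standard library only: it builds a filtered copy rather than deleting in place, so the original dict is left untouched. `without_keys`, `Owner`, and the example `state_dict` are hypothetical stand-ins, and `pickle` stands in for `torch.save`.

```python
import pickle
import weakref


def without_keys(state_dict, keys=("anneal_func",)):
    """Return a shallow copy of state_dict with the given keys left out."""
    return {k: v for k, v in state_dict.items() if k not in keys}


class Owner:
    """Hypothetical stand-in for whatever the weak reference points to."""


owner = Owner()
# Made-up scheduler-like state with one entry that cannot be pickled.
state_dict = {"last_epoch": 5, "anneal_func": weakref.ref(owner)}

cleaned = without_keys(state_dict)
payload = pickle.dumps(cleaned)   # succeeds once the entry is gone
restored = pickle.loads(payload)
```

Whatever gets dropped this way is not in the saved file, so it would presumably need to be recreated when the checkpoint is loaded.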

In my own development work, I am currently still using python 3.7 with pytorch < 2.0.0, for the reasons you pointed out. I will be working on these kinds of fixes on the sf_dev_0.3.0_postreg_python3.8 branch to enable python 3.8 and pytorch 2.0.0 compatibility in the future.

sjfleming commented 1 year ago

Here's some tracking for this stuff: https://github.com/broadinstitute/CellBender/pull/203

tzeitim commented 1 year ago

Hi @sjfleming - Thanks for your answer and the references. To be honest, I was just lucky to find that solution by ilia-kats; it was just a couple of days old when I found it.

Regarding the source of this issue ... I don't know ... I've spent a lot of time trying to understand the chain of events that leads to a weakref, without any success. I'll report any progress when possible.

I am glad I could help you save some time or identify potential solutions for the future.

Keep up the great work!