gpleiss / efficient_densenet_pytorch

A memory-efficient implementation of DenseNets
MIT License

How can I apply this to my own model? #70

Open CXMANDTXW opened 3 years ago

CXMANDTXW commented 3 years ago

Thank you for your nice work. My model uses DenseNet-style connections like this:

    tensorFeat = torch.cat([self.moduleOne(tensorFeat), tensorFeat], 1)
    tensorFeat = torch.cat([self.moduleTwo(tensorFeat), tensorFeat], 1)
    tensorFeat = torch.cat([self.moduleThr(tensorFeat), tensorFeat], 1)
    tensorFeat = torch.cat([self.moduleFou(tensorFeat), tensorFeat], 1)
    tensorFeat = torch.cat([self.moduleFiv(tensorFeat), tensorFeat], 1)
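For reference, here is a minimal self-contained module with the same connection pattern. It is only a sketch: the issue does not say what `moduleOne` … `moduleFiv` contain, so the class name `DenseChain`, the `growth` and `num_steps` parameters, and the conv + ReLU sub-modules are all hypothetical stand-ins. Each step puts the new features in front of its input, so the channel count grows by `growth` per step.

```python
import torch
import torch.nn as nn


class DenseChain(nn.Module):
    """Hypothetical stand-in for the dense connections described in the issue."""

    def __init__(self, in_channels: int, growth: int = 16, num_steps: int = 5):
        super().__init__()
        # Step i sees all channels concatenated so far and emits `growth` new ones.
        self.steps = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels + i * growth, growth, kernel_size=3, padding=1),
                nn.ReLU(),
            )
            for i in range(num_steps)
        ])

    def forward(self, tensorFeat: torch.Tensor) -> torch.Tensor:
        for step in self.steps:
            # Same pattern as the issue: new features first, then the running input.
            tensorFeat = torch.cat([step(tensorFeat), tensorFeat], 1)
        return tensorFeat


model = DenseChain(in_channels=8)
out = model(torch.randn(2, 8, 32, 32))
print(out.shape)  # torch.Size([2, 88, 32, 32])
```

Every `torch.cat` here allocates a new, larger tensor that autograd keeps for the backward pass, so the memory held by the intermediate concatenations grows roughly quadratically with the number of steps.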

What do I need to do to apply the memory-efficient technique to this part and reduce its memory consumption? The DenseNet connections are only one part of my full model.

gpleiss commented 3 years ago

It really depends on the other aspects of your model. This implementation uses PyTorch's gradient checkpointing feature (https://github.com/gpleiss/efficient_densenet_pytorch/blob/master/models/densenet.py#L38), which trades extra computation time for lower memory usage.
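A possible way to carry this idea over to the snippet in the question is sketched below; it is an illustration, not code from the repository. The repository's dense layers checkpoint a function that concatenates the previous features before applying the first layer, and this sketch does the same: it keeps the small per-step outputs in a Python list and wraps "concatenate, then run the sub-module" in `torch.utils.checkpoint.checkpoint`, so the large concatenated tensors are rebuilt during the backward pass instead of being stored. The class `CheckpointedDenseChain`, its `efficient` flag, and the conv + ReLU sub-modules are hypothetical. `use_reentrant=False` assumes a reasonably recent PyTorch release.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedDenseChain(nn.Module):
    """Hypothetical dense chain that recomputes its concatenations during backward."""

    def __init__(self, in_channels: int, growth: int = 16, num_steps: int = 5,
                 efficient: bool = True):
        super().__init__()
        self.efficient = efficient
        # Hypothetical sub-modules standing in for moduleOne ... moduleFiv.
        self.steps = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels + i * growth, growth, kernel_size=3, padding=1),
                nn.ReLU(),
            )
            for i in range(num_steps)
        ])

    @staticmethod
    def _concat_then_apply(step: nn.Module, *features: torch.Tensor) -> torch.Tensor:
        # Newest features first, matching torch.cat([module(x), x], 1) in the issue.
        return step(torch.cat(features[::-1], 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = [x]  # only the small per-step outputs are kept alive
        for step in self.steps:
            if self.efficient and torch.is_grad_enabled():
                # The concatenated input is not saved for backward; it is
                # rebuilt from `features` when gradients are computed.
                new = checkpoint(self._concat_then_apply, step, *features,
                                 use_reentrant=False)
            else:
                new = self._concat_then_apply(step, *features)
            features.append(new)
        # One final concatenation reproduces the original output layout.
        return torch.cat(features[::-1], 1)


# Sanity check: checkpointed and plain paths should agree on outputs and gradients.
torch.manual_seed(0)
model = CheckpointedDenseChain(in_channels=8)
x = torch.randn(2, 8, 16, 16)

out_cp = model(x)
out_cp.sum().backward()
grads_cp = [p.grad.clone() for p in model.parameters()]

model.zero_grad()
model.efficient = False
out_plain = model(x)
out_plain.sum().backward()
grads_plain = [p.grad for p in model.parameters()]

same_out = torch.allclose(out_cp, out_plain)
same_grads = all(torch.allclose(a, b, atol=1e-6) for a, b in zip(grads_cp, grads_plain))
print(out_cp.shape, same_out, same_grads)  # torch.Size([2, 88, 16, 16]) True True
```

Each checkpointed step runs its sub-module twice (once in the forward pass, once again during backward), which is the time-for-memory trade-off mentioned above. If other parts of the model dominate memory use, checkpointing only this block may not save much.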

See these docs for more information.