Closed jmif closed 1 year ago
Hey, thank you for pointing this out! I think this went through a slightly unexpected code path. In particular, our repo llm-foundry is how we typically interact with HF models. For example:
Here we have code that makes it easier to interact with HF models + FSDP:
We are open to PRs to help smooth out the wrapping process.
Ah ok, I see the code, thanks for the reference. I'll close this for now but will keep it in mind as I continue to understand the project; I'd be happy to contribute once I have a better grasp of things. Thanks!
I've got a setup that roughly looks like this:
When I go to train this, I get an error that says `FullyShardedDataParallel has no method len()`. Upon digging into the model code, I found that the underlying GPT2 model has an `nn.ModuleList` as one of its members. The `ModuleList` gets wrapped, and this ends up breaking the training run because the training run references `len(self.h)`. I was able to fix this by setting `_fsdp_wrap` to `False` on the `ModuleList` and `_fsdp_wrap` to `True` on all of the modules in the `ModuleList`. It occurred to me that it may not make sense to wrap `ModuleList` by default, so I'm opening an issue in case there is an opportunity to improve the default wrap implementation or fix a bug there. I'm fairly new at this, so this may be uninformed :).
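For what it's worth, the workaround can be sketched roughly like this. This is a minimal, self-contained illustration using stand-in classes rather than `torch.nn` (the real container would be the `nn.ModuleList` of GPT2 blocks, e.g. `model.transformer.h`); `_fsdp_wrap` is the attribute consulted during auto-wrapping, and `apply_workaround` is a hypothetical helper name:

```python
class Block:
    """Stand-in for a single GPT2 transformer block (would be an nn.Module)."""
    pass


class ModuleList(list):
    """Stand-in for torch.nn.ModuleList, which also supports len()."""
    pass


def apply_workaround(block_list):
    """Mark the container as not-to-be-wrapped, but wrap each child."""
    # Don't shard the ModuleList itself, so references like len(self.h)
    # still see the plain container rather than an FSDP wrapper...
    block_list._fsdp_wrap = False
    # ...but do shard each block inside it, so the parameters are still
    # split across ranks.
    for block in block_list:
        block._fsdp_wrap = True


# Three stand-in transformer blocks, mirroring model.transformer.h.
h = ModuleList(Block() for _ in range(3))
apply_workaround(h)
```

After this, `len(h)` works as before because the container is left unwrapped, while each block is individually eligible for wrapping.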