pytorch / torchrec

Pytorch domain library for recommendation systems
https://pytorch.org/torchrec/
BSD 3-Clause "New" or "Revised" License

[Bug] state_dict returns wrong path when EC is internally DMP wrapped #2584

Open JacoCheung opened 7 hours ago

JacoCheung commented 7 hours ago

Because torchrec registers a state_dict hook, every key path gets a long prefix prepended.

For example, if I have a module structure like

A.B.C.ec = EmbeddingCollection() # A is not DMP wrapped; only ec is DMP wrapped.

# then I'll get a state dict key
A.B.C.ec.A.B.C.ec.embeddings.weight

For some reason, I cannot use DMP for my whole model. Is there anything I can do?
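
One possible stopgap is to post-process the keys after calling state_dict. The sketch below is plain Python and does not touch torchrec. `strip_duplicated_prefix` is a hypothetical helper, not a torchrec API, and it assumes the unwanted prefix is always an exact repeat of the module path (here `A.B.C.ec`). The integer values stand in for tensors.

```python
from collections import OrderedDict


def strip_duplicated_prefix(state_dict, prefix):
    """Collapse keys of the form '<prefix>.<prefix>.rest' to '<prefix>.rest'.

    Hypothetical helper: assumes the duplication is an exact repeat of `prefix`.
    Keys that do not start with the doubled prefix are left untouched.
    """
    doubled = f"{prefix}.{prefix}."
    fixed = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(doubled):
            # Drop the first copy of the prefix plus its trailing dot.
            key = key[len(prefix) + 1:]
        fixed[key] = value
    return fixed


# Placeholder values instead of real tensors, to keep the sketch self-contained.
sd = OrderedDict(
    [
        ("A.B.C.ec.A.B.C.ec.embeddings.weight", 0),
        ("A.B.other.bias", 1),
    ]
)
cleaned = strip_duplicated_prefix(sd, "A.B.C.ec")
print(list(cleaned))
```

This only renames keys in the returned dict. Loading such a dict back would need the inverse mapping, so it is a workaround rather than a fix for the hook itself.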