patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0

`copy.deepcopy` cannot copy `torchtyping.utils.frozendict` objects #32

Closed: anivegesana closed this issue 2 years ago

anivegesana commented 2 years ago

`copy.deepcopy` is used in some training loops to save a copy of the model parameters with the best validation loss. Currently, the following code raises an error:

import copy
from torchtyping import TensorType
copy.deepcopy(TensorType["batch", "embedding"])
# raises: TypeError: __setitem__() takes 2 positional arguments but 3 were given
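The error message hints at the cause. Below is a minimal reproduction without torchtyping. It assumes, without confirmation from the torchtyping source, that `frozendict` is a `dict` subclass whose `__setitem__` accepts only one argument besides `self`. `copy.deepcopy` rebuilds a `dict` subclass by creating an empty instance and then replaying every item as `y[key] = value`. That assignment calls `__setitem__(self, key, value)` with three positional arguments. The class name `BrokenFrozenDict` is hypothetical.

```python
import copy


class BrokenFrozenDict(dict):
    """Hypothetical stand-in for an immutable dict with a mis-declared __setitem__."""

    # Assumption: the signature takes one argument instead of (key, value).
    def __setitem__(self, item):
        raise RuntimeError(f"Cannot add items to a {type(self).__name__}.")


# dict.__init__ fills the dict at C level, so __setitem__ is not called here.
d = BrokenFrozenDict({"batch": 0})

try:
    # deepcopy reconstructs the subclass via __reduce_ex__, then re-inserts
    # each item with y[key] = value, which has the wrong arity for __setitem__.
    copy.deepcopy(d)
except TypeError as err:
    message = str(err)
    print(message)
```

The `RuntimeError` in the body is never reached. The `TypeError` comes from the argument-count mismatch alone.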

Instead, copying a `frozendict` should just return the same object unaltered. Because it cannot be modified, a copy would be indistinguishable from the original, so there is nothing for the copy to do.
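Here is a minimal sketch of the proposed behaviour. It uses a hypothetical `FrozenDict` written from scratch, not the actual `torchtyping.utils.frozendict` code. Defining `__copy__` and `__deepcopy__` to return `self` means neither `copy.copy` nor `copy.deepcopy` ever reaches `__setitem__`. This is only safe when the values are immutable too, for example strings and integers. A frozen dict holding mutable values would share them between the original and its "copy". The name `dims` is also illustrative.

```python
import copy


class FrozenDict(dict):
    """Hypothetical immutable dict; illustrative only, not torchtyping's class.

    Only item assignment and deletion are blocked here; other dict methods
    such as update() or pop() are not overridden in this sketch.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cache an order-independent hash once; valid because the contents
        # are not supposed to change (assumes keys and values are hashable).
        self._hash = hash(frozenset(self.items()))

    def __setitem__(self, key, value):
        raise RuntimeError(f"Cannot add items to a {type(self).__name__}.")

    def __delitem__(self, key):
        raise RuntimeError(f"Cannot remove items from a {type(self).__name__}.")

    def __hash__(self):
        return self._hash

    # An immutable object can serve as its own copy, so both hooks return
    # self instead of letting the copy module rebuild it item by item.
    def __copy__(self):
        return self

    def __deepcopy__(self, memo):
        return self


dims = FrozenDict(batch=0, embedding=1)
print(copy.deepcopy(dims) is dims)  # True
print(copy.copy(dims) is dims)      # True
```

The `memo` argument goes unused because no new object is created. When `dims` is nested inside a larger structure being deep-copied, such as a saved model state, the same object is simply reused.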

patrick-kidger commented 2 years ago

LGTM, thanks!