shaibagon opened this issue 3 years ago
Thanks for your great suggestions. Will update them accordingly after testing.
`torch.ones_like(input[:1, :1, ...])`

Does `input[:1, :1, ...]` take a slice of the tensor without making a copy?
@ivanstepanovftw It does not make a copy: basic slicing such as `input[:1, :1, ...]` returns a view that shares storage with `input`. In any case, since the slice is only passed to `torch.ones_like`, its values are never read; only its shape, dtype, and device are used.
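A quick way to check both points, assuming any recent PyTorch; the tensor shapes below are arbitrary and chosen only for illustration:

```python
import torch

x = torch.randn(2, 3, 8, 8)

# Basic slicing returns a view: s shares x's storage, nothing is copied.
s = x[:1, :1, ...]
print(s.data_ptr() == x.data_ptr())  # True (the slice starts at offset 0)

# Writing through the view is visible in the original tensor.
s.fill_(0.0)
print(bool((x[0, 0] == 0).all()))  # True

# torch.ones_like builds a new tensor from the argument's metadata
# (shape, dtype, device, layout); the argument's values are never read.
m = torch.ones_like(s)
print(m.shape, m.dtype, m.device)  # torch.Size([1, 1, 8, 8]) torch.float32 cpu
print(bool((m == 1).all()))  # True, even though s now holds zeros
```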
Also, I think this operation is redundant:
https://github.com/NVIDIA/partialconv/blob/610d373f35257887d45adae84c86d0ce7ad808ec/models/partialconv2d.py#L75
because `input` is multiplied by the mask anyway here:
https://github.com/NVIDIA/partialconv/blob/610d373f35257887d45adae84c86d0ce7ad808ec/models/partialconv2d.py#L70
Hi, I really like your work - thank you for sharing the code.
I have some remarks on the code of `PartialConv2d`:
1. Using buffers instead of simple "member tensors":
The class attribute `self.weight_maskUpdater` is defined as a "plain" tensor:
https://github.com/NVIDIA/partialconv/blob/610d373f35257887d45adae84c86d0ce7ad808ec/models/partialconv2d.py#L32-L35
As a result, when the model is moved to the GPU or its data type changes, you need to check for this explicitly and convert the tensor:
https://github.com/NVIDIA/partialconv/blob/610d373f35257887d45adae84c86d0ce7ad808ec/models/partialconv2d.py#L49-L50

A more elegant way is to use non-persistent buffers:
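A minimal sketch of this suggestion, assuming a PyTorch version whose `register_buffer` accepts the `persistent` argument (1.6 or later). `PartialConv2dSketch` and `mask_coverage` are hypothetical names, not the repository's code, and only the single-channel mask updater of shape `(1, 1, kH, kW)` is shown:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PartialConv2dSketch(nn.Conv2d):
    """Hypothetical, stripped-down layer showing only the buffer registration."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        kh, kw = self.kernel_size
        # Non-persistent buffer: it follows .to()/.cuda()/.cpu()/.double()
        # applied to the module, but is not written to state_dict().
        self.register_buffer(
            "weight_maskUpdater", torch.ones(1, 1, kh, kw), persistent=False
        )

    def mask_coverage(self, mask):
        # No explicit dtype/device check is needed here: the buffer already
        # matches the module's dtype and device.
        return F.conv2d(mask, self.weight_maskUpdater, stride=self.stride,
                        padding=self.padding, dilation=self.dilation)


layer = PartialConv2dSketch(3, 8, kernel_size=3, padding=1).double()
print(layer.weight_maskUpdater.dtype)  # torch.float64
print("weight_maskUpdater" in layer.state_dict())  # False
```

Making the buffer non-persistent keeps checkpoints unchanged: like the current plain attribute, it is rebuilt in `__init__` rather than loaded from the state dict.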
This way, `self.weight_maskUpdater` will follow any `.to(...)`/`.cuda()`/`.cpu()` call on the model hosting this layer, making the condition on line 49 redundant.

2. Use of `torch.ones_like`
Instead of `torch.ones(...).to(...)`:
https://github.com/NVIDIA/partialconv/blob/610d373f35257887d45adae84c86d0ce7ad808ec/models/partialconv2d.py#L54-L57
you can use `torch.ones_like`, which is simpler and easier to read.

3. No need to import `Variable`
You import `Variable` but never use it:
https://github.com/NVIDIA/partialconv/blob/610d373f35257887d45adae84c86d0ce7ad808ec/models/partialconv2d.py#L12

BTW, all these comments also apply to the code of `PartialConv3d`.
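The `torch.ones_like` suggestion carries over to `PartialConv3d` unchanged, because the trailing `...` in `input[:1, :1, ...]` covers however many spatial dimensions follow. A minimal sketch, where `default_mask` is a hypothetical helper standing in for the mask-creation branch linked in item 2, and its `multi_channel` argument is assumed to mirror the layer option of that name:

```python
import torch


def default_mask(input: torch.Tensor, multi_channel: bool) -> torch.Tensor:
    """Hypothetical helper: all-ones mask matching input's dtype and device.

    The trailing `...` makes the slice rank-agnostic, so the same code serves
    4D (N, C, H, W) inputs of PartialConv2d and 5D (N, C, D, H, W) inputs
    of PartialConv3d.
    """
    if multi_channel:
        # Same shape as the input, replacing torch.ones(N, C, ...).to(input)
        return torch.ones_like(input)
    # Shape (1, 1, *spatial), replacing torch.ones(1, 1, ...).to(input)
    return torch.ones_like(input[:1, :1, ...])


x2d = torch.randn(2, 3, 8, 8, dtype=torch.float64)
x3d = torch.randn(2, 3, 4, 8, 8)
print(default_mask(x2d, multi_channel=False).shape)  # torch.Size([1, 1, 8, 8])
print(default_mask(x3d, multi_channel=False).shape)  # torch.Size([1, 1, 4, 8, 8])
print(default_mask(x2d, multi_channel=True).dtype)   # torch.float64
```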