https://github.com/lucidrains/axial-attention/blob/a1a483c0f4a3922eef8f9a857dc1a802523bd437/axial_attention/axial_attention.py#L100

The nn.ParameterList created on this line leads to the following issue as soon as the model is wrapped in nn.DataParallel:

"UserWarning: nn.ParameterList is being used with DataParallel but this is not supported. This list will appear empty for the models replicated on each GPU except the original one."

It is a known PyTorch issue.
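For context, here is a minimal sketch of the pattern that triggers the warning; the ToyEmbedding module, the tensor shapes, and the two-GPU setup are made up for illustration and are not taken from the repo:

```python
import torch
from torch import nn

class ToyEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        # parameters kept in an nn.ParameterList, as on the linked line
        self.params = nn.ParameterList([nn.Parameter(torch.randn(1, 8, 16))])

    def forward(self, x):
        return x + self.params[0]

# Wrapping in DataParallel and running a forward pass on two or more GPUs
# emits the UserWarning; on the replica GPUs the list appears empty, so the
# indexing above can fail there as well.
model = nn.DataParallel(ToyEmbedding().cuda())
out = model(torch.randn(4, 8, 16).cuda())
```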
The simple solution should be to store the Parameters directly on the Module:

```python
import torch
from torch import nn

class AxialPositionalEmbedding(nn.Module):
    def __init__(self, dim, shape, emb_dim_index = 1):
        super().__init__()
        total_dimensions = len(shape) + 2
        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]
        self.param_num = len(shape)

        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
            # broadcastable shape: 1 everywhere except the embedding dim and this axial dim
            param_shape = [1] * total_dimensions
            param_shape[emb_dim_index] = dim
            param_shape[axial_dim_index] = axial_dim
            # registering the Parameter as a plain attribute (instead of putting it
            # in an nn.ParameterList) lets DataParallel replicate it normally
            setattr(self, f'param_{i}', nn.Parameter(torch.randn(*param_shape)))

    def forward(self, x):
        for i in range(self.param_num):
            x = x + getattr(self, f'param_{i}')
        return x
```
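As a quick sanity check, the rewritten module can then be wrapped in nn.DataParallel without the warning; the dim/shape values and the multi-GPU setup below are just assumptions for illustration:

```python
import torch
from torch import nn

# hypothetical sizes: embedding dim 8, two axial dims of 16; input is (batch, dim, h, w)
emb = AxialPositionalEmbedding(dim = 8, shape = (16, 16))
model = nn.DataParallel(emb.cuda())

x = torch.randn(4, 8, 16, 16).cuda()
out = model(x)       # no ParameterList warning; parameters replicate to all GPUs
print(out.shape)     # torch.Size([4, 8, 16, 16])
```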