lucidrains / axial-attention

Implementation of Axial attention - attending to multi-dimensional data efficiently
MIT License

Problem of ParameterList with nn.DataParallel #11

Closed resuly closed 3 years ago

resuly commented 3 years ago

https://github.com/lucidrains/axial-attention/blob/a1a483c0f4a3922eef8f9a857dc1a802523bd437/axial_attention/axial_attention.py#L100

This line triggers the following warning: "UserWarning: nn.ParameterList is being used with DataParallel but this is not supported. This list will appear empty for the models replicated on each GPU except the original one."

This is a known PyTorch issue.
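For context, a minimal sketch (the module and the tensor shapes are made up for illustration) of the pattern that produces the warning on a multi-GPU machine:

import torch
from torch import nn

# hypothetical toy module that keeps its embeddings in an nn.ParameterList
class ListEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(1, 8, 16))])

    def forward(self, x):
        return x + self.params[0]

model = nn.DataParallel(ListEmbedding().cuda())
# with more than one visible GPU this emits the UserWarning quoted above,
# and the replicas on the non-default GPUs see an empty ParameterList
out = model(torch.randn(4, 8, 16).cuda())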

A simple fix is to register each Parameter directly on the Module as its own attribute:

import torch
from torch import nn

class AxialPositionalEmbedding(nn.Module):
    def __init__(self, dim, shape, emb_dim_index = 1):
        super().__init__()
        total_dimensions = len(shape) + 2
        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]
        self.param_num = len(shape)

        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
            # broadcastable shape with `dim` on the embedding axis and `axial_dim` on one spatial axis
            axial_shape = [1] * total_dimensions
            axial_shape[emb_dim_index] = dim
            axial_shape[axial_dim_index] = axial_dim
            # register each parameter directly on the module, so DataParallel replicates it like any other Parameter
            setattr(self, f'param_{i}', nn.Parameter(torch.randn(*axial_shape)))

    def forward(self, x):
        # add each axial positional embedding to the input via broadcasting
        for i in range(self.param_num):
            x = x + getattr(self, f'param_{i}')
        return x
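A quick sanity check along these lines (the shapes are only illustrative) should run and no longer trigger the ParameterList warning when wrapped in DataParallel:

# illustrative shapes: batch 4, embedding dim 8 at index 1, a 16 x 16 grid
emb = AxialPositionalEmbedding(dim = 8, shape = (16, 16))
x = torch.randn(4, 8, 16, 16)
print(emb(x).shape)  # torch.Size([4, 8, 16, 16])

# since the param_{i} attributes are ordinary Parameters on the module,
# DataParallel replicas receive copies of them on every GPU
# model = nn.DataParallel(emb.cuda())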
lucidrains commented 3 years ago

@resuly thank you! i decided to go with your solution in 0.6.1 :)