SalesforceAIResearch / uni2ts

[ICML2024] Unified Training of Universal Time Series Forecasting Transformers
Apache License 2.0

weights_logits:4 — the MultiOutSizeLinear.forward output is always zero #44

Closed: splendourbell closed this issue 4 months ago

splendourbell commented 4 months ago

In MultiOutSizeLinear.forward, self.out_features_ls is [32, 64, 128, 256, 512],
because weights_logits:4 * [8, 16, 32, 64, 128].

When out_feat_size is 8,
"torch.eq(out_feat_size, feat_size).unsqueeze(-1)" is always False, so the output is always zero. Is that right?

def forward(
    self,
    x: Float[torch.Tensor, "*batch in_feat"],
    out_feat_size: Int[torch.Tensor, "*batch"],
) -> Float[torch.Tensor, "*batch max_feat"]:
    out = 0
    for idx, feat_size in enumerate(self.out_features_ls):
        weight = self.weight[idx] * self.mask[idx]
        bias = self.bias[idx] if self.bias is not None else 0
        out = out + (
            torch.eq(out_feat_size, feat_size).unsqueeze(-1)
            * (einsum(weight, x, "out inp, ... inp -> ... out") + bias)
        )
    return out
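To illustrate the report, here is a minimal, self-contained sketch of the masking loop in forward (toy shapes; a constant stands in for the einsum projection, and the values are assumed from the discussion): when out_features_ls holds the scaled sizes, no mask ever matches a patch size of 8 or 16, so the accumulated output stays zero for those tokens.

```python
import torch

# Assumed toy setup: out_features_ls as reported (scaled by 4), and a batch
# of three tokens whose patch sizes are 8, 8 and 16.
out_features_ls = [32, 64, 128, 256, 512]
out_feat_size = torch.tensor([8, 8, 16])

out = torch.zeros(3, 1)
for feat_size in out_features_ls:
    # Same mask as in forward; a constant stands in for the linear projection.
    mask = torch.eq(out_feat_size, feat_size).unsqueeze(-1)
    out = out + mask * torch.ones(3, 1)

print(out.squeeze(-1))  # tensor([0., 0., 0.]) -- no mask ever matched
```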
gorold commented 4 months ago

self.out_features_ls should be [8, 16, 32, 64, 128] based on the current hyperparameters.

Not too sure what the weights_logits you are referring to is.

out_feat_size is a tensor holding the patch size for each token. torch.eq(...) behaves as a mask, so only the projection for the matching feat_size is added to out. Thus, out should be the prediction for each token at its appropriate patch size, with zero padding.
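As a quick sketch of this intended behaviour (toy values; a distinct constant per index stands in for each projection): with the correct out_features_ls, each token's output comes from exactly one branch of the loop.

```python
import torch

out_features_ls = [8, 16, 32, 64, 128]  # intended, unscaled patch sizes
out_feat_size = torch.tensor([8, 16])   # patch size of each token

out = torch.zeros(2, 1)
for idx, feat_size in enumerate(out_features_ls):
    mask = torch.eq(out_feat_size, feat_size).unsqueeze(-1)
    out = out + mask * float(idx + 1)   # stand-in for projection idx

print(out.squeeze(-1))  # tensor([1., 2.]) -- each token routed to one branch
```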

splendourbell commented 4 months ago

In the DistrParamProj.__init__ function:

print(args_dim)
# {'weights_logits': 4, 'components': [{'df': 1, 'loc': 1, 'scale': 1}, {'loc': 1}, {'total_count': 1, 'logits': 1}, {'loc': 1, 'scale': 1}]}

the code

else proj_layer(
                        in_features, tuple(dim * of for of in out_features), **kwargs
                    ) 

when dim is 4 (the weights_logits param), the tuple becomes (32, 64, 128, 256, 512), so the proj_layer (MultiOutSizeLinear) forward output is always zero

class DistrParamProj(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int | tuple[int, ...] | list[int],
        args_dim: PyTree[int, "T"],
        domain_map: PyTree[Callable[[torch.Tensor], torch.Tensor], "T"],
        proj_layer: Callable[..., nn.Module] = MultiOutSizeLinear,
        **kwargs: Any,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.args_dim = args_dim
        self.domain_map = domain_map
        self.proj = convert_to_module(
            tree_map(
                lambda dim: (
                    proj_layer(in_features, dim * out_features, **kwargs)
                    if isinstance(out_features, int)
                    else proj_layer(
                        in_features, tuple(dim * of for of in out_features), **kwargs
                    )
                ),
                args_dim,
            )
        )
        self.out_size = (
            out_features if isinstance(out_features, int) else max(out_features)
        )
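The mismatch described above can be reproduced in isolation (values taken from the printed args_dim and the current config): the scaled tuple no longer contains 8 or 16, so tokens with those patch sizes never match any mask and get zero output, while 32, 64 and 128 still match, but against the wrong weight slot.

```python
# Values from the discussion above: patch sizes [8, 16, 32, 64, 128] and a
# weights_logits argument dimension of 4.
out_features = [8, 16, 32, 64, 128]
dim = 4  # args_dim["weights_logits"]

# This is what gets passed to MultiOutSizeLinear as its out_features_ls.
# Note that 8 and 16 are missing (zero output for those patch sizes), while
# 32/64/128 now sit at shifted indices (wrong weight slot).
scaled = tuple(dim * of for of in out_features)
print(scaled)  # (32, 64, 128, 256, 512)
```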

gorold commented 4 months ago

I see.. I think I get what you mean, will look into it, thanks!

gorold commented 4 months ago

Seems like this is a pretty major bug; fixing it would make predictions with patch sizes 8 and 16 (with the current configuration) have better outputs, and improve performance on low-frequency data. Thanks for catching this!