Closed: splendourbell closed this issue 4 months ago
self.out_features_ls should be [8, 16, 32, 64, 128] based on the current hyperparameters.
I'm not too sure what the weights_logits you are referring to is.
out_feat_size is a tensor holding the patch size for each token. torch.eq(...) acts as a mask, so only the projection for the matching feat_size is added to out. out should therefore be the prediction for each token under its appropriate patch size, zero-padded up to the largest output size.
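For concreteness, here is a minimal sketch of the masking pattern described above (illustrative only; the per-size weight/bias lists and the zero-padding to the largest size are assumptions, not the actual library code):

```python
import torch
import torch.nn.functional as F

# Illustrative sketch (not the library code): each token is projected with
# the weight matrix whose output size matches its patch size; all other
# positions contribute zero, leaving zero padding beyond the patch size.
def multi_out_size_forward(x, out_feat_size, weights, biases, out_features_ls):
    # x: (..., in_features); out_feat_size: (...,) patch size per token
    max_feat = max(out_features_ls)
    out = x.new_zeros(*x.shape[:-1], max_feat)
    for weight, bias, feat_size in zip(weights, biases, out_features_ls):
        # True only where the token's patch size equals this feat_size
        mask = torch.eq(out_feat_size, feat_size).unsqueeze(-1)
        y = F.pad(F.linear(x, weight, bias), (0, max_feat - feat_size))
        out = out + torch.where(mask, y, torch.zeros_like(y))
    return out
```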
In the DistrParamProj.__init__ function:

```python
print(args_dim)
# {'weights_logits': 4, 'components': [{'df': 1, 'loc': 1, 'scale': 1},
#  {'loc': 1}, {'total_count': 1, 'logits': 1}, {'loc': 1, 'scale': 1}]}
```
the code hits this branch:

```python
else proj_layer(
    in_features, tuple(dim * of for of in out_features), **kwargs
)
```
```python
class DistrParamProj(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int | tuple[int, ...] | list[int],
        args_dim: PyTree[int, "T"],
        domain_map: PyTree[Callable[[torch.Tensor], torch.Tensor], "T"],
        proj_layer: Callable[..., nn.Module] = MultiOutSizeLinear,
        **kwargs: Any,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.args_dim = args_dim
        self.domain_map = domain_map
        self.proj = convert_to_module(
            tree_map(
                lambda dim: (
                    proj_layer(in_features, dim * out_features, **kwargs)
                    if isinstance(out_features, int)
                    else proj_layer(
                        in_features, tuple(dim * of for of in out_features), **kwargs
                    )
                ),
                args_dim,
            )
        )
        self.out_size = (
            out_features if isinstance(out_features, int) else max(out_features)
        )
```
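Plugging the printed args_dim into that branch makes the mismatch concrete (toy arithmetic, with the values from this thread):

```python
# For the 'weights_logits' entry, tree_map passes dim = 4, so the proj_layer
# is constructed with scaled output sizes rather than the patch sizes:
dim = 4
out_features = (8, 16, 32, 64, 128)
print(tuple(dim * of for of in out_features))  # (32, 64, 128, 256, 512)
# These scaled sizes become self.out_features_ls inside MultiOutSizeLinear,
# while forward() is still called with the raw patch sizes as out_feat_size.
```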
I see... I think I get what you mean; will look into it, thanks!
Seems like this is a pretty major bug; fixing it would make predictions with patch sizes 8 and 16 (under the current configuration) produce better outputs, and improve performance on low-frequency data. Thanks for catching this!
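One possible direction for a fix (a sketch under assumptions only, not the actual patch: the dim argument and parameter layout here are invented for illustration) is to tell the projection layer about the multiplier, so the mask compares scaled sizes with scaled sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiOutSizeLinearSketch(nn.Module):
    """Sketch of a possible fix (assumed design): the layer records the
    per-arg multiplier `dim` so forward() can match out_feat_size * dim
    against the scaled feature sizes."""

    def __init__(self, in_features, out_features_ls, dim=1):
        super().__init__()
        self.dim = dim
        # scaled sizes, e.g. (32, 64, 128, 256, 512) for dim=4
        self.out_features_ls = tuple(dim * of for of in out_features_ls)
        self.weights = nn.ParameterList(
            nn.Parameter(torch.randn(of, in_features) * in_features**-0.5)
            for of in self.out_features_ls
        )
        self.biases = nn.ParameterList(
            nn.Parameter(torch.zeros(of)) for of in self.out_features_ls
        )

    def forward(self, x, out_feat_size):
        max_feat = max(self.out_features_ls)
        out = x.new_zeros(*x.shape[:-1], max_feat)
        for weight, bias, feat_size in zip(
            self.weights, self.biases, self.out_features_ls
        ):
            # patch size 8 with dim=4 now correctly selects the 32-dim head
            mask = torch.eq(out_feat_size * self.dim, feat_size).unsqueeze(-1)
            y = F.pad(F.linear(x, weight, bias), (0, max_feat - feat_size))
            out = out + torch.where(mask, y, torch.zeros_like(y))
        return out
```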
In MultiOutSizeLinear.forward, self.out_features_ls is [32, 64, 128, 256, 512],
because the weights_logits dim of 4 multiplies [8, 16, 32, 64, 128].
So when out_feat_size is 8,
torch.eq(out_feat_size, feat_size).unsqueeze(-1) is always False, and out is always zero. Is that right?