wesselb / lab

A generic interface for linear algebra backends
MIT License

mean with squeeze=False works incorrectly under torch #21

Open vabor112 opened 1 week ago

vabor112 commented 1 week ago

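If axis is None (or not provided), mean(torch_tensor, squeeze=squeeze) ignores the value of squeeze: even with squeeze=False it returns a 0-dimensional tensor instead of keeping every reduced dimension with size 1.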

This is obvious from the implementation:

@dispatch
def mean(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True):
    if axis is None:
        return torch.mean(a)
    else:
        return torch.mean(a, dim=axis, keepdim=not squeeze)
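
For illustration, here is a minimal reproduction sketch. It assumes only that torch is installed; `mean_current` is a hypothetical standalone copy of the branch logic quoted above (without the `@dispatch` decorator and type annotations), not the library function itself.

```python
import torch


def mean_current(a, axis=None, squeeze=True):
    # Hypothetical stand-in that mirrors the quoted branch logic.
    if axis is None:
        # squeeze is never consulted on this path.
        return torch.mean(a)
    else:
        return torch.mean(a, dim=axis, keepdim=not squeeze)


a = torch.ones(2, 3)

# axis=None: squeeze=False has no effect, the result is 0-dimensional.
full = mean_current(a, squeeze=False)
print(tuple(full.shape))  # ()

# axis given: squeeze=False keeps the reduced dimension with size 1.
per_axis = mean_current(a, axis=0, squeeze=False)
print(tuple(per_axis.shape))  # (1, 3)
```

The two calls differ only in whether axis is given, yet only the second one honours squeeze=False.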

Why not just change the implementation to the following? It works as expected for me and handles axis=None correctly. (Disclaimer: I tested it only under torch 2.2.0.)

@dispatch
def mean(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True):
    return torch.mean(a, dim=axis, keepdim=not squeeze)
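
Since the one-liner was only tested under torch 2.2.0, a possible variant that does not rely on dim=None being accepted is to pass every dimension explicitly. This is only a sketch: `mean_all_dims` is a hypothetical helper name, the `@dispatch` decorator and type annotations are left out, and it assumes the input has at least one dimension.

```python
import torch


def mean_all_dims(a, axis=None, squeeze=True):
    # Hypothetical helper: when no axis is given, reduce over all dimensions
    # explicitly so that keepdim is honoured on the full reduction too.
    # Assumes a.dim() >= 1.
    dims = tuple(range(a.dim())) if axis is None else axis
    return torch.mean(a, dim=dims, keepdim=not squeeze)


a = torch.arange(6, dtype=torch.float32).reshape(2, 3)

kept = mean_all_dims(a, squeeze=False)
print(tuple(kept.shape))  # (1, 1)

squeezed = mean_all_dims(a)
print(tuple(squeezed.shape))  # ()

row_means = mean_all_dims(a, axis=1, squeeze=False)
print(tuple(row_means.shape))  # (2, 1)
```

With squeeze=False and no axis, the result keeps one size-1 dimension per input dimension, which matches what the axis-specific path already does.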