SamDuffield closed this 2 months ago.
It appears this is a more general problem with composing `functools.partial` and `torch.vmap`:
```python
import torch
from functools import partial

def f(a, x):
    return a + (x**2).sum()

a = 3.0
x_all = torch.randn(3, 10)

f(a, x_all)
torch.vmap(f, in_dims=(None, 0))(a, x_all)  # Works fine
partial(torch.vmap(f, in_dims=(None, 0)), x=x_all)(a)  # Throws error
# ValueError: vmap(f, in_dims=(None, 0), ...)(<inputs>): in_dims is not compatible with the structure of `inputs`.
# in_dims has structure TreeSpec(tuple, None, [*, *]) but inputs has structure TreeSpec(tuple, None, [*]).
```
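This is the issue that `in_dims` does not support kwargs: arguments passed by keyword are not covered by `in_dims`, so here `in_dims` describes two inputs while only one (`a`) is passed positionally. See https://github.com/google/jax/issues/20914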
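A minimal sketch of the mismatch and two ways around it, assuming PyTorch 2.x (where `torch.vmap` is available): keyword arguments are forwarded to `f` without being mapped, so every argument that `in_dims` describes has to reach the vmapped function positionally. The name `vmapped_over_x_all` is purely illustrative.

```python
import torch
from functools import partial

def f(a, x):
    return a + (x**2).sum()

a = 3.0
x_all = torch.randn(3, 10)
vmapped = torch.vmap(f, in_dims=(None, 0))

# Passing x by keyword: in_dims=(None, 0) describes two positional inputs,
# but only `a` arrives positionally, so the structures do not match.
raised_value_error = False
try:
    vmapped(a, x=x_all)
except ValueError:
    raised_value_error = True

# Workaround 1: bind the unmapped argument positionally with partial.
out1 = partial(vmapped, a)(x_all)

# Workaround 2: close over x_all but still forward it positionally.
def vmapped_over_x_all(a_):
    return vmapped(a_, x_all)

out2 = vmapped_over_x_all(a)

print(out1.shape)  # torch.Size([3])
```

Both workarounds produce one output per row of `x_all`; the only requirement is that `x_all` is never handed to the vmapped function as a keyword.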
This can be fixed simply with a wrapper in `per_samplify` that ensures no kwargs are passed to the vmapped function.
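The actual change to `per_samplify` is in #79 and is not reproduced here. As a rough sketch of the wrapper idea, the hypothetical helper `vmap_no_kwargs` below uses `inspect.signature(...).bind` to move keyword arguments into their positional slots before calling the vmapped function, so `functools.partial` with keywords composes with it. It assumes the mapped parameters of `func` are positional-or-keyword and that `in_dims` has one entry per such parameter (defaults are filled in before mapping).

```python
import inspect
from functools import partial, wraps

import torch


def vmap_no_kwargs(func, in_dims):
    """Hypothetical helper: torch.vmap `func`, but accept keyword arguments.

    Keyword arguments are bound to their positional slots (using `func`'s
    signature) before the call, so `in_dims` always sees the full tuple of
    positional inputs.
    """
    sig = inspect.signature(func)
    vmapped = torch.vmap(func, in_dims=in_dims)

    @wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()  # in_dims must then cover defaulted params too
        # Keyword-only parameters (if any) stay in bound.kwargs and are
        # passed through unmapped, as torch.vmap does for kwargs.
        return vmapped(*bound.args, **bound.kwargs)

    return wrapper


def f(a, x):
    return a + (x**2).sum()


x_all = torch.randn(3, 10)

# The call pattern from the issue now works: x is supplied by keyword via partial.
out = partial(vmap_no_kwargs(f, in_dims=(None, 0)), x=x_all)(3.0)
print(out.shape)  # torch.Size([3])
```

Binding against `func`'s own signature (rather than the vmapped callable's) is what lets the wrapper know which positional slot each keyword belongs to.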
Fixed in #79