samuela / torch2jax

Run PyTorch in JAX. 🤝

XlaRuntimeError: UNIMPLEMENTED: from_dlpack got array with non-default layout with minor-to-major dimensions (2,0,1), expected (2,1,0) #6

Closed truncs closed 1 week ago

truncs commented 1 week ago
In [32]: vit = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg').cuda().eval()
Using cache found in /home/truncs/.cache/torch/hub/facebookresearch_dinov2_main

In [33]: image = torch.randn((2, 3, 70, 70)).cuda()

In [34]: jax_vit = t2j(vit)

In [35]: jax_image = t2j(image)

In [36]: params = {k: t2j(v) for k, v in vit.named_parameters()}

samuela commented 1 week ago

This looks like a jax.dlpack.from_dlpack issue; we don't do anything interesting here besides https://github.com/samuela/torch2jax/blob/bd7bd9c95253c89ffb7a25cc0ff2ccb296f6cfbf/torch2jax/__init__.py#L18-L19
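
For anyone reading the error message: "minor-to-major" lists the dimension indices from the fastest-varying (smallest stride) to the slowest-varying, so the default row-major layout of a 3-D array is (2,1,0). The sketch below is pure Python with hypothetical helper names (`contiguous_strides`, `transpose_strides`, `minor_to_major` are not part of torch2jax, JAX, or PyTorch). It shows how swapping the first two dimensions of a contiguous 3-D array without copying would give exactly the (2,0,1) order from the title. Which tensor in the DINOv2 forward pass actually triggers this is not shown in the thread, so the shapes here are illustrative only.

```python
def contiguous_strides(shape):
    """Row-major (C-order) strides, in elements, for the given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def transpose_strides(strides, dim_a, dim_b):
    """Strides of a no-copy transpose view: the two entries are swapped."""
    swapped = list(strides)
    swapped[dim_a], swapped[dim_b] = swapped[dim_b], swapped[dim_a]
    return tuple(swapped)


def minor_to_major(strides):
    """Dimension indices ordered from smallest stride to largest.

    Assumes all strides are distinct; size-1 dimensions can make the
    order ambiguous and are not handled in this sketch.
    """
    return tuple(sorted(range(len(strides)), key=lambda d: strides[d]))


# Illustrative shapes only, not taken from the model in this issue.
base = contiguous_strides((3, 2, 4))   # contiguous array of shape (3, 2, 4)
view = transpose_strides(base, 0, 1)   # transposed view of shape (2, 3, 4)

print(minor_to_major(base))  # (2, 1, 0)
print(minor_to_major(view))  # (2, 0, 1)
```

If a strided view like this is what reaches the conversion, calling `.contiguous()` on the PyTorch tensor before handing it to DLPack should produce the default layout, at the cost of a copy. That is an assumption about the cause, not something confirmed in this thread.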