# Session excerpt: load a DINOv2 model through torch.hub and convert the
# module, a sample input, and the module's parameters with `t2j`.
# NOTE(review): `t2j` is not defined in this excerpt — presumably the
# torch-to-JAX converter from the torch2jax package; confirm against the imports.
#
# Fetch the 'dinov2_vits14_reg' entry point from the 'facebookresearch/dinov2'
# hub repo, move it to the GPU, and switch it to eval mode. The line after the
# input is torch.hub's own message that a cached checkout was reused.
In [32]: vit = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg').cuda().eval()
Using cache found in /home/truncs/.cache/torch/hub/facebookresearch_dinov2_main
# Random normally-distributed tensor of shape (2, 3, 70, 70) placed on the GPU.
# Presumably a batch of two 3-channel 70x70 images, with 70 = 5 * 14 chosen to
# match the patch size suggested by the model name — TODO confirm.
# NOTE(review): no seed is set, so the values differ on every run.
In [33]: image = torch.randn((2, 3, 70, 70)).cuda()
# Apply t2j to the whole module; presumably this yields a JAX-callable forward
# function — verify against the t2j documentation.
In [34]: jax_vit = t2j(vit)
# Apply t2j to the input tensor as well (presumably producing a JAX array).
In [35]: jax_image = t2j(image)
# Map every parameter name reported by vit.named_parameters() to its
# t2j-converted value. named_parameters() yields parameters only, so any
# registered buffers the module may have are not part of this dict.
In [36]: params = {k: t2j(v) for k, v in vit.named_parameters()}