I'm not sure if this is resolvable, but the UNI weights result in nan values when training or running inference in float16. The repro below uses an H&E stain image with ImageNet normalization:
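For context, batch_img comes from a standard ImageNet-normalized transform along the lines below; the patch path and exact resize are placeholders rather than my exact pipeline:

import torch
from PIL import Image
from torchvision import transforms

device = torch.device('cuda')

# Standard ImageNet preprocessing at 224x224
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])
img = Image.open('he_patch.png').convert('RGB')      # placeholder H&E patch
batch_img = preprocess(img).unsqueeze(0).to(device)  # (1, 3, 224, 224)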
import timm
import torch
from huggingface_hub import hf_hub_download

path = hf_hub_download("MahmoodLab/UNI", filename="pytorch_model.bin")
model = timm.create_model("vit_large_patch16_224", init_values=1e-5, num_classes=0).to(device)
missing_k, unexpected_k = model.load_state_dict(torch.load(path), strict=False)
print(f'Missing keys: {missing_k}')
print(f'Unexpected keys: {unexpected_k}')

# UNI weights: float32 baseline vs. float16 autocast
with torch.autocast(device_type='cuda', dtype=torch.float32):
    print(f'UNI float32 output: {model(batch_img)}')
with torch.autocast(device_type='cuda', dtype=torch.float16):
    print(f'UNI float16 output: {model(batch_img)}')

# Same architecture with the default ImageNet-pretrained weights
model_imagenet = timm.create_model("vit_large_patch16_224", init_values=1e-5, num_classes=0, pretrained=True).to(device)
with torch.autocast(device_type='cuda', dtype=torch.float32):
    print(f'ImageNet float32 output: {model_imagenet(batch_img)}')
with torch.autocast(device_type='cuda', dtype=torch.float16):
    print(f'ImageNet float16 output: {model_imagenet(batch_img)}')
Output: half precision works with the default ImageNet-pretrained weights but produces nan values with UNI.
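In case it helps triage, here is a minimal diagnostic sketch (not part of the repro above) that registers forward hooks to report the first module whose output goes non-finite under float16. My guess is that some intermediate activations in the ViT-L blocks exceed float16's max (~65504) and overflow to inf, which bfloat16's wider exponent range would avoid:

import torch

def find_first_nonfinite(model, batch_img):
    # Report the first module whose output contains NaN/inf under
    # float16 autocast; overflow propagates downstream, so only the
    # first hit matters.
    found = []
    def make_hook(name):
        def hook(module, inputs, output):
            if not found and isinstance(output, torch.Tensor) \
                    and not torch.isfinite(output).all():
                found.append(name)
                print(f'first non-finite output in: {name}')
        return hook
    handles = [m.register_forward_hook(make_hook(n))
               for n, m in model.named_modules()]
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
        model(batch_img)
    for h in handles:
        h.remove()
    return found[0] if found else None

Calling find_first_nonfinite(model, batch_img) on the UNI model above should point at the first overflowing block; re-running the same forward pass under torch.bfloat16 autocast would confirm whether it is a range (overflow) issue rather than a precision issue.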