mahmoodlab / UNI

Towards a general-purpose foundation model for computational pathology - Nature Medicine

Model weights result in nan-values for half precision #18

Closed · LostGeorge closed this 8 months ago

LostGeorge commented 8 months ago

I'm not sure whether this is resolvable, but the UNI weights produce NaN values when training or running inference in float16. The following is run on an H&E-stained image with ImageNet normalization:

import timm
import torch
from huggingface_hub import hf_hub_download

# `device` and `batch_img` (an ImageNet-normalized H&E tile batch) are assumed to be defined.

# Download the UNI checkpoint and load it into a timm ViT-L/16 backbone.
path = hf_hub_download("MahmoodLab/UNI", filename="pytorch_model.bin")
model = timm.create_model("vit_large_patch16_224", init_values=1e-5, num_classes=0).to(device)

missing_k, unexpected_k = model.load_state_dict(torch.load(path), strict=False)
print(f'Missing keys: {missing_k}')
print(f'Unexpected keys: {unexpected_k}')

# UNI weights: compare float32 vs. float16 autocast.
with torch.autocast(device_type='cuda', dtype=torch.float32):
    print(f'float32 output: {model(batch_img)}')
with torch.autocast(device_type='cuda', dtype=torch.float16):
    print(f'float16 output: {model(batch_img)}')

# Same architecture with the default ImageNet-pretrained weights, for comparison.
model_imagenet = timm.create_model("vit_large_patch16_224", init_values=1e-5, num_classes=0, pretrained=True).to(device)

with torch.autocast(device_type='cuda', dtype=torch.float32):
    print(f'float32 output: {model_imagenet(batch_img)}')
with torch.autocast(device_type='cuda', dtype=torch.float16):
    print(f'float16 output: {model_imagenet(batch_img)}')

Output: half precision works with the default ImageNet-pretrained weights but not with the UNI weights.

Missing keys: []
Unexpected keys: []
float32 output: tensor([[-0.9344, -0.0447,  2.0671,  ...,  0.1991,  1.0729, -0.1812]],
       device='cuda:0', grad_fn=<SelectBackward0>)
float16 output: tensor([[nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       grad_fn=<SelectBackward0>)
float32 output: tensor([[ 1.3607,  0.1251, -0.2508,  ...,  0.2557, -0.1732,  0.6628]],
       device='cuda:0', grad_fn=<SelectBackward0>)
float16 output: tensor([[ 1.3607,  0.1251, -0.2508,  ...,  0.2557, -0.1732,  0.6628]],
       device='cuda:0', grad_fn=<SelectBackward0>)
LostGeorge commented 8 months ago

Resolved by updating timm (0.6.12 -> 0.9.2).
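
For anyone hitting the same NaNs, a minimal sanity check after upgrading (e.g. pip install -U "timm>=0.9.2") is to confirm the installed timm version and re-run the float16 forward pass; this sketch assumes `model` and `batch_img` are set up as in the snippet above:

import timm
import torch

print(timm.__version__)  # should report 0.9.2 or newer after the upgrade

# `model` and `batch_img` are assumed to be defined as in the original snippet.
with torch.autocast(device_type='cuda', dtype=torch.float16):
    out = model(batch_img)
print(torch.isnan(out).any())  # expect tensor(False) once timm is updated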