huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Fix dtype casting in swinv2 and swin2sr to allow non-FP32 inference #31589

Closed · aliencaocao closed this 2 days ago

aliencaocao commented 3 days ago

What does this PR do?

The current implementation uses .float() in https://github.com/huggingface/transformers/blob/0f67ba1d741d65b07d549daf4ee157609ce4f9c1/src/transformers/models/swin2sr/modeling_swin2sr.py#L286-L287, which forces the resulting relative_coords_table to always be torch.float32, regardless of the precision of the rest of the model's weights, e.g. torch.float16.
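
As a quick illustration (not taken from the PR itself), .float() always returns an FP32 tensor no matter what dtype it starts from, which is how the mismatch sneaks in:

```python
import torch

# .float() unconditionally casts to torch.float32, even when the module's
# weights have been converted to torch.float16.
table = torch.ones(4, dtype=torch.float16)
print(table.float().dtype)  # torch.float32
```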

This PR adds a cast to the same dtype as the continuous_position_bias_mlp layer, since relative_coords_table is passed directly into that layer at https://github.com/huggingface/transformers/blob/0f67ba1d741d65b07d549daf4ee157609ce4f9c1/src/transformers/models/swin2sr/modeling_swin2sr.py#L349

Same issue & fix for swinv2
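
A minimal sketch of the idea behind the cast, using stand-in objects rather than the actual module attributes (the layer sizes and table shape below are illustrative, not the real values):

```python
import torch
import torch.nn as nn

# Stand-ins for the real attributes in modeling_swin2sr.py / modeling_swinv2.py;
# sizes are made up for illustration only.
continuous_position_bias_mlp = nn.Sequential(nn.Linear(2, 512), nn.ReLU(), nn.Linear(512, 8)).half()
relative_coords_table = torch.rand(1, 15, 15, 2)  # fp32, because of the earlier .float()

# The fix, in spirit: cast the table to the dtype of the MLP it is fed into,
# so fp16 (or bf16) checkpoints no longer hit a dtype mismatch.
target_dtype = next(continuous_position_bias_mlp.parameters()).dtype
relative_coords_table = relative_coords_table.to(target_dtype)
print(relative_coords_table.dtype)  # torch.float16, matching the MLP weights
```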

Prerequisite for the image-to-image pipeline FP16 test in https://github.com/huggingface/transformers/pull/31342 to pass.
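
For context, this is the kind of half-precision usage the fix unblocks; the checkpoint name below is just an example, not something from the linked PR:

```python
import torch
from transformers import pipeline

# Example Swin2SR checkpoint; any Swin2SR checkpoint should behave the same way.
upscaler = pipeline(
    "image-to-image",
    model="caidas/swin2SR-classical-sr-x2-64",
    torch_dtype=torch.float16,
    device=0,
)
upscaled = upscaler("path/to/low_res_image.png")  # runs end-to-end in fp16 with this fix
```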

Who can review?

@amyeroberts

aliencaocao commented 3 days ago

Added the tests; they pass locally.

amyeroberts commented 2 days ago

@aliencaocao Great! Could you push an empty commit with the message [run_slow] swin2sr, swinv2? I trust the tests are passing locally, but differences in hardware and environment set-up can creep in, so the logits can still be slightly different. Let's make sure the numbers match what's going to be running on the CI : )
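
For reference, the empty commit is just standard git; only the commit message (which triggers the slow CI) is transformers-specific:

```bash
git commit --allow-empty -m "[run_slow] swin2sr, swinv2"
git push
```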

HuggingFaceDocBuilderDev commented 2 days ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

aliencaocao commented 2 days ago

@amyeroberts Need your approval for the slow tests to run.

aliencaocao commented 2 days ago

Ugh, it seems the specific GPU used does indeed make a numerical difference... I ran the tests and got the logits on an RTX 3080 Ti, torch 2.3.1+cu121, NVIDIA driver 555.95, Windows 10.

amyeroberts commented 2 days ago

@aliencaocao As the change is simple and looks OK, updating the tests to use the results from the CI runs should be fine.

aliencaocao commented 2 days ago

How do I get the CI outputs? Do I have to print them in CI?

amyeroberts commented 2 days ago

@aliencaocao Good question! Indeed, they're not part of the console output. Let me see if I can SSH in.

amyeroberts commented 2 days ago

@aliencaocao Running on the runners, I get the following logits:

swin2sr

tensor([[0.5454, 0.5542, 0.5640],
        [0.5518, 0.5562, 0.5649],
        [0.5391, 0.5425, 0.5620]], device='cuda:0', dtype=torch.float16)

swinv2

tensor([-0.3938, -0.4290,  0.0020], device='cuda:0', dtype=torch.float16)
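
For anyone reproducing this, here is a rough sketch of the kind of fp16 check these numbers feed into; the checkpoint, slice indices and tolerances below are my assumptions, not the actual test code:

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, Swinv2ForImageClassification

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)

# Assumed checkpoint; the integration tests may use a different one.
checkpoint = "microsoft/swinv2-tiny-patch4-window8-256"
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = Swinv2ForImageClassification.from_pretrained(checkpoint, torch_dtype=torch.float16).to("cuda")

inputs = processor(images=image, return_tensors="pt").to("cuda", torch.float16)
with torch.no_grad():
    logits = model(**inputs).logits

# Expected values taken from the CI runner output above; the slice is an assumption.
expected = torch.tensor([-0.3938, -0.4290, 0.0020], device="cuda", dtype=torch.float16)
torch.testing.assert_close(logits[0, :3], expected, rtol=1e-3, atol=1e-3)
```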

aliencaocao commented 2 days ago

Thanks, triggered the slow tests again.

amyeroberts commented 2 days ago

@aliencaocao Thanks! All looks good - we can merge 🤗