aliencaocao closed this 2 days ago.
Added them; they pass locally.
@aliencaocao Great! Could you push an empty commit with the message "[run_slow] swin2sr, swinv2"? I trust the tests are passing locally, but differences in hardware and environment setup can still make the logits slightly different, so let's make sure the numbers match what's going to run on the CI :)
@amyeroberts I need your approval to run the slow tests.
Ugh, it seems the specific GPU I used does produce numerical differences... I ran the tests and got the logits on an RTX 3080 Ti with torch 2.3.1+cu121, NVIDIA driver 555.95, and Windows 10.
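To illustrate why the same model can print slightly different float16 logits on different hardware, here is a stdlib-only Python sketch of half-precision rounding. It assumes nothing about the actual CUDA kernels: real GPU kernels often accumulate in higher precision and differ in how they split and order reductions, so rounding after every addition (as sum_fp16 does) is only a simplified model. It shows two things: near 0.55, adjacent float16 values are 2**-11 (about 0.0005) apart, so differences surface around the 4th printed decimal, and the order of a float16 summation can change its result. The helpers to_fp16 and sum_fp16 are hypothetical names used only for illustration.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value
    (struct's 'e' format rounds ties to even) and return it as a float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def sum_fp16(values):
    """Accumulate left to right, rounding to float16 after every addition
    (a simplified model of a naive half-precision reduction)."""
    total = 0.0
    for v in values:
        total = to_fp16(total + to_fp16(v))
    return total


# Between 0.5 and 1.0, neighbouring float16 values are 2**-11 apart.
print(to_fp16(0.5454))           # 0.54541015625
print(to_fp16(0.5454 + 2**-11))  # 0.5458984375 (the next float16 value up)

small = 2.0 ** -11
print(sum_fp16([1.0, small, small]))  # 1.0: each tiny term is rounded away
print(sum_fp16([small, small, 1.0]))  # 1.0009765625: tiny terms combine first
```

The same inputs summed in a different order give results one float16 step apart, which is the kind of discrepancy that can separate a local RTX 3080 Ti run from the CI runners.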
@aliencaocao As the change is simple and looks OK, updating the tests to use the results from the CI runs should be fine.
How do I get the CI outputs? Do I have to print them in the CI?
@aliencaocao Good question! Indeed, they're not part of the console output. Let me see if I can SSH in.
@aliencaocao Running on the runners, I get the following logits
swin2sr
tensor([[0.5454, 0.5542, 0.5640],
[0.5518, 0.5562, 0.5649],
[0.5391, 0.5425, 0.5620]], device='cuda:0', dtype=torch.float16)
swinv2
tensor([-0.3938, -0.4290, 0.0020], device='cuda:0', dtype=torch.float16)
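Updating the tests then amounts to hard-coding these CI values and comparing new outputs against them within a tolerance. The stdlib-only sketch below shows that kind of absolute-tolerance check on the slices above. slices_close and flatten are hypothetical helpers; the actual transformers tests presumably use PyTorch's own comparison utilities with a tolerance that this thread does not state. The "local" values in the usage lines are made up to show a small difference; they are not the numbers measured on the RTX 3080 Ti.

```python
import math

# Expected slices copied from the CI runner output above (float16 values as printed).
EXPECTED_SWIN2SR = [
    [0.5454, 0.5542, 0.5640],
    [0.5518, 0.5562, 0.5649],
    [0.5391, 0.5425, 0.5620],
]
EXPECTED_SWINV2 = [-0.3938, -0.4290, 0.0020]


def flatten(x):
    """Flatten arbitrarily nested lists/tuples into a flat list of floats."""
    if isinstance(x, (list, tuple)):
        return [v for item in x for v in flatten(item)]
    return [float(x)]


def slices_close(actual, expected, atol):
    """True if both slices have the same number of elements and every element
    of `actual` is within `atol` of the matching element of `expected`."""
    a, e = flatten(actual), flatten(expected)
    return len(a) == len(e) and all(
        math.isclose(x, y, rel_tol=0.0, abs_tol=atol) for x, y in zip(a, e)
    )


# Made-up local result: the first value is off by 0.0005.
local_swinv2 = [-0.3933, -0.4290, 0.0020]
print(slices_close(local_swinv2, EXPECTED_SWINV2, atol=1e-3))  # True
print(slices_close(local_swinv2, EXPECTED_SWINV2, atol=1e-4))  # False
```

A tight tolerance catches differences of a few float16 steps, which is why copying the CI's own numbers into the tests is more robust than loosening the tolerance.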
Thanks, I've triggered the slow tests again.
@aliencaocao Thanks! All looks good - we can merge 🤗
What does this PR do?
The current implementation uses .float() in https://github.com/huggingface/transformers/blob/0f67ba1d741d65b07d549daf4ee157609ce4f9c1/src/transformers/models/swin2sr/modeling_swin2sr.py#L286-L287, which causes the subsequent relative_coords_table to always be in torch.float32, regardless of the precision of the other weights (e.g. torch.float16). In a half-precision model, the layer that consumes the table therefore receives float32 input while its own weights are float16.

This PR adds a cast to the same dtype as the continuous_position_bias_mlp layer, since relative_coords_table is passed directly into that layer at https://github.com/huggingface/transformers/blob/0f67ba1d741d65b07d549daf4ee157609ce4f9c1/src/transformers/models/swin2sr/modeling_swin2sr.py#L349.

Swinv2 has the same issue and gets the same fix.
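The bug and the fix can be reproduced in miniature with the PyTorch sketch below. RelativeCoordsBiasSketch and its arguments are hypothetical: the class is a stripped-down stand-in for the Swin2SR/Swinv2 attention module that keeps only the table construction and the MLP call, and it omits the normalization and log scaling the real models apply to the table. Creating the MLP with dtype=torch.bfloat16 imitates weights materialized directly in reduced precision (roughly what loading with a reduced torch_dtype does; that mapping is an assumption). bfloat16 stands in for float16 only because half-precision matmuls are not supported on CPU in every PyTorch build.

```python
import torch
from torch import nn


class RelativeCoordsBiasSketch(nn.Module):
    """Hypothetical stand-in: builds relative_coords_table and feeds it to the bias MLP."""

    def __init__(self, window_size=(2, 2), num_heads=2, dtype=torch.float32, cast_to_mlp_dtype=True):
        super().__init__()
        # Weights created directly in `dtype` (imitating low-precision loading).
        self.continuous_position_bias_mlp = nn.Sequential(
            nn.Linear(2, 8, bias=True, dtype=dtype),
            nn.ReLU(inplace=True),
            nn.Linear(8, num_heads, bias=False, dtype=dtype),
        )
        # .float() pins the coordinates to float32, as described in the PR.
        coords_h = torch.arange(-(window_size[0] - 1), window_size[0], dtype=torch.int64).float()
        coords_w = torch.arange(-(window_size[1] - 1), window_size[1], dtype=torch.int64).float()
        table = (
            torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
            .permute(1, 2, 0)
            .contiguous()
            .unsqueeze(0)
        )  # shape [1, 2*Wh - 1, 2*Ww - 1, 2]
        if cast_to_mlp_dtype:
            # The fix: match the dtype of the layer the table is passed into.
            table = table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
        self.register_buffer("relative_coords_table", table, persistent=False)

    def forward(self):
        return self.continuous_position_bias_mlp(self.relative_coords_table)


fixed = RelativeCoordsBiasSketch(dtype=torch.bfloat16)
print(fixed.relative_coords_table.dtype, fixed().shape)  # torch.bfloat16 torch.Size([1, 3, 3, 2])

unfixed = RelativeCoordsBiasSketch(dtype=torch.bfloat16, cast_to_mlp_dtype=False)
try:
    unfixed()
except RuntimeError as err:  # float32 input vs bfloat16 weights
    print("dtype mismatch:", err)
```

Note that calling .to(torch.float16) on an already-built module would also convert the float32 buffer, because nn.Module.to casts floating-point buffers along with parameters; the mismatch only shows up when the weights are created in reduced precision while the buffer is explicitly pinned to float32.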
This is a prerequisite for the image-to-image pipeline FP16 test in https://github.com/huggingface/transformers/pull/31342 to pass.
Who can review?
@amyeroberts