TRI-ML / prismatic-vlms

A flexible and efficient codebase for training visually-conditioned language models (VLMs)
MIT License

`unpack_tuple()` is no longer correct with timm v1.0.3 #34

Open yukw777 opened 1 month ago

yukw777 commented 1 month ago

timm v1.0.3 was released just 2 hours ago (https://github.com/huggingface/pytorch-image-models/releases/tag/v1.0.3), and it reworks the API for `forward_intermediates()`: it now returns a list instead of a tuple. As a result, when I run `scripts/generate.py` with all the default settings and the simple question "Is the coffee cup empty?", I get the following error:

Traceback (most recent call last):
  File "/home/peter/repos/prismatic-vlms/scripts/generate.py", line 133, in <module>
    generate()
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/draccus/argparsing.py", line 203, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/home/peter/repos/prismatic-vlms/scripts/generate.py", line 116, in generate
    generated_text = vlm.generate(
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/peter/repos/prismatic-vlms/prismatic/models/vlms/prismatic.py", line 553, in generate
    generated_ids = super().generate(
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 1576, in generate
    result = self._greedy_search(
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 2494, in _greedy_search
    outputs = self(
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/peter/repos/prismatic-vlms/prismatic/models/vlms/prismatic.py", line 311, in forward
    patch_features = self.vision_backbone({k: pixel_values[k][multimodal_indices] for k in pixel_values})
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/peter/repos/prismatic-vlms/prismatic/models/backbones/vision/dinosiglip_vit.py", line 147, in forward
    return torch.cat([dino_patches, siglip_patches], dim=2)
TypeError: expected Tensor as element 0 in argument 0, but got list
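The failure mode is mechanical: `torch.cat` requires every element of its input sequence to be a `Tensor`, so when `forward_intermediates()` hands back a list that `unpack_tuple()` never unwraps (the old `isinstance(result, tuple)` check lets a list through unmodified), that list itself ends up as an element of the sequence passed to `torch.cat`. A minimal standalone repro (hypothetical shapes, unrelated to the actual backbone dimensions):

```python
import torch

# Under timm v1.0.3, forward_intermediates() returns a list, which the old
# isinstance(result, tuple) check in unpack_tuple() passes through unchanged.
dino_patches = [torch.randn(1, 16, 8)]    # list that was never unwrapped
siglip_patches = [torch.randn(1, 16, 8)]

try:
    torch.cat([dino_patches, siglip_patches], dim=2)
except TypeError as exc:
    # Same error as in the traceback: element 0 is a list, not a Tensor.
    print(exc)
```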

The following diff fixes the issue:

diff --git a/prismatic/models/backbones/vision/base_vision.py b/prismatic/models/backbones/vision/base_vision.py
index e9ccade..cf67351 100644
--- a/prismatic/models/backbones/vision/base_vision.py
+++ b/prismatic/models/backbones/vision/base_vision.py
@@ -11,7 +11,7 @@ Transformer model for feature extraction.
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union, Sequence

 import timm
 import torch
@@ -27,7 +27,7 @@ from torchvision.transforms import Compose, Resize
 def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
     def wrapper(*args: Any, **kwargs: Any) -> Any:
         result = fn(*args, **kwargs)
-        return result[0] if isinstance(result, tuple) else result
+        return result[0] if isinstance(result, Sequence) else result

     return wrapper
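As a sanity check on the widened guard: `typing.Sequence` aliases `collections.abc.Sequence` for `isinstance` purposes, and both `tuple` and `list` register as sequences, so the patched wrapper behaves identically for pre- and post-1.0 timm return types. A quick standalone sketch:

```python
from typing import Any, Callable, Sequence


def unpack_tuple(fn: Callable[..., Any]) -> Callable[..., Any]:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        # tuple (timm < 1.0) and list (timm >= 1.0.3) are both Sequences,
        # so a single check covers either return type.
        return result[0] if isinstance(result, Sequence) else result

    return wrapper


assert unpack_tuple(lambda: ("a", "b"))() == "a"   # old tuple return
assert unpack_tuple(lambda: ["a", "b"])() == "a"   # new list return
assert unpack_tuple(lambda: 42)() == 42            # non-sequence passes through
```

One caveat: `str` is also a `Sequence`, so a string return value would be truncated to its first character. That is harmless here, where the wrapped calls only ever return tuples or lists of tensors, but worth remembering if the helper is reused elsewhere.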

I'm happy to submit a PR for this, but since the fix touches the monkey-patching used for FSDP support, I wanted to discuss the proper approach before moving forward.

siddk commented 1 month ago

I just pushed a commit to pin `timm==0.9.10` for the time being to make sure this doesn't break things for other folks.

I'd love it if you could push a PR, ideally with a test verifying that different timm versions return the same output. Based on your PR, I can then test FSDP functionality and make sure everything checks out!

yukw777 commented 1 month ago

Great! A few questions for you:

  1. Should we drop support for timm < 1.0.0 now that timm has reached 1.0.0? Depending on a (supposedly) stable API would significantly reduce the ongoing maintenance effort.
  2. If we do decide to keep supporting timm < 1.0.0, it would be a good idea to write regression tests, but how do you usually write tests? I haven't been able to find an example in the repo.
  3. Does my quick fix look good to you? I may also rename the function to `unpack_seq()`, since it would support both pre-1.0 and post-1.0 timm. If we do decide to drop support for timm < 1.0.0, I may instead check only for `list` (and rename the function to `unpack_list()`) and let the error bubble up rather than swallowing it.
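On question 2, one possible shape for such a regression test (a hypothetical pytest-style sketch; the wrapper from the diff above is repeated inline so the snippet is self-contained, and `features` stands in for a real patch-feature tensor):

```python
from typing import Any, Callable, Sequence


def unpack_tuple(fn: Callable[..., Any]) -> Callable[..., Any]:
    # Same wrapper as the patched base_vision.py version above.
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, Sequence) else result

    return wrapper


def test_unpack_is_timm_version_agnostic() -> None:
    features = object()  # stands in for a patch-feature tensor

    old_style = unpack_tuple(lambda: (features,))  # timm < 1.0: tuple
    new_style = unpack_tuple(lambda: [features])   # timm >= 1.0.3: list

    # Both return styles must unwrap to the very same object.
    assert old_style() is new_style() is features


test_unpack_is_timm_version_agnostic()
```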