isl-org / ZoeDepth

Metric depth estimation from a single image
MIT License
2.25k stars 207 forks source link

Error(s) in loading state_dict for ZoeDepth #72

Open andysingal opened 1 year ago

andysingal commented 1 year ago

Hi, I ran the following commands:

# Notebook setup: clone ZoeDepth and work from the repo root so that
# `torch.hub.load(".", ..., source="local")` picks up the repo's hubconf.py.
# NOTE: loading the pretrained ZoeD_N weights here fails with
# "Unexpected key(s) in state_dict: ...attn.relative_position_index".
# The fix reported in the issue thread is to run `!pip install timm==0.6.7`
# and restart the runtime before executing this cell.
!git clone https://github.com/isl-org/ZoeDepth.git
%cd ZoeDepth
import torch
import matplotlib
import matplotlib.cm
import numpy as np

# NOTE: this `colorize` is shadowed by the local `def colorize` in this notebook.
from zoedepth.utils.misc import get_image_from_url, colorize
from PIL import Image
import matplotlib.pyplot as plt

torch.hub.help("intel-isl/MiDaS", "DPT_BEiT_L_384", force_reload=True)  # Triggers fresh download of MiDaS repo
# Build ZoeD_N from the local hubconf.py and download its pretrained weights.
zoe = torch.hub.load(".", "ZoeD_N", source="local", pretrained=True)
zoe = zoe.to("cuda")

def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
    """Colorize a single-channel value map (e.g. a depth map) into an RGBA PIL image.

    Args:
        value: numpy array or torch tensor of values. Singleton dimensions are
            squeezed away before colorizing.
        vmin: Value mapped to the low end of the colormap. Defaults to the 2nd
            percentile of the valid values.
        vmax: Value mapped to the high end of the colormap. Defaults to the
            85th percentile of the valid values.
        cmap: Matplotlib colormap name.
        invalid_val: Sentinel marking invalid pixels when ``invalid_mask`` is None.
        invalid_mask: Boolean mask of invalid pixels; overrides ``invalid_val``.
        background_color: RGBA color painted over invalid pixels.
        gamma_corrected: If True, raise the output colors to the power 2.2
            (in normalized [0, 1] space) before building the image.
        value_transform: Optional callable applied to the normalized values
            before the colormap lookup.

    Returns:
        PIL.Image.Image built from the (H, W, 4) uint8 RGBA colormap output.
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()

    value = value.squeeze()
    if invalid_mask is None:
        invalid_mask = value == invalid_val
    mask = np.logical_not(invalid_mask)

    # Normalize using robust percentiles of the valid pixels only. If there are
    # no valid pixels np.percentile would raise on the empty selection, so fall
    # back to a degenerate range (every pixel then becomes background below).
    valid = value[mask]
    if vmin is None:
        vmin = np.percentile(valid, 2) if valid.size else 0.
    if vmax is None:
        vmax = np.percentile(valid, 85) if valid.size else 0.
    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)  # vmin..vmax -> 0..1
    else:
        # Avoid 0-division
        value = value * 0.

    # Mark invalid pixels as NaN; they are overwritten with background_color below.
    value[invalid_mask] = np.nan

    try:
        # matplotlib >= 3.6; matplotlib.cm.get_cmap was removed in matplotlib 3.9.
        cmapper = matplotlib.colormaps.get_cmap(cmap)
    except AttributeError:
        # Older matplotlib without the colormap registry's get_cmap().
        cmapper = matplotlib.cm.get_cmap(cmap)

    if value_transform:
        value = value_transform(value)
    img = cmapper(value, bytes=True)  # (n x m x 4) uint8 RGBA
    img[invalid_mask] = background_color

    if gamma_corrected:
        # Apply the 2.2 power curve in normalized [0, 1] space, then back to uint8.
        img = img / 255
        img = np.power(img, 2.2)
        img = img * 255
        img = img.astype(np.uint8)
    return Image.fromarray(img)

def get_zoe_depth_map(image, model=None):
    """Run ZoeDepth on a PIL image and return a grayscale-colorized depth map.

    Args:
        image: Input image, passed to the model's ``infer_pil``.
        model: ZoeDepth model to run. Defaults to the notebook-level ``zoe``
            model loaded with ``torch.hub.load(".", "ZoeD_N", ...)``.

    Returns:
        The image produced by ``colorize(depth, cmap="gray_r")``.
    """
    if model is None:
        # The notebook loads the model into `zoe`; the previously referenced
        # `model_zoe_n` is not defined in this notebook and raised NameError.
        model = zoe
    # Mixed-precision inference; assumes a CUDA device (the model is moved to
    # "cuda" when it is loaded).
    with torch.autocast("cuda", enabled=True):
        depth = model.infer_pil(image)
    return colorize(depth, cmap="gray_r")

ERROR:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[11], line 11
      7 from PIL import Image
      8 import matplotlib.pyplot as plt
---> 11 zoe = torch.hub.load(".", "ZoeD_N", source="local", pretrained=True)
     13 torch.hub.help("intel-isl/MiDaS", "DPT_BEiT_L_384", force_reload=True)  # Triggers fresh download of MiDaS repo
     14 zoe = torch.hub.load(".", "ZoeD_N", source="local", pretrained=True)

File /usr/local/lib/python3.10/dist-packages/torch/hub.py:542, in load(repo_or_dir, model, source, trust_repo, force_reload, verbose, skip_validation, *args, **kwargs)
    538 if source == 'github':
    539     repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, trust_repo, "load",
    540                                        verbose=verbose, skip_validation=skip_validation)
--> 542 model = _load_local(repo_or_dir, model, *args, **kwargs)
    543 return model

File /usr/local/lib/python3.10/dist-packages/torch/hub.py:572, in _load_local(hubconf_dir, model, *args, **kwargs)
    569 hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
    571 entry = _load_entry_from_hubconf(hub_module, model)
--> 572 model = entry(*args, **kwargs)
    574 sys.path.remove(hubconf_dir)
    576 return model

File /workspace/ZoeDepth/./hubconf.py:69, in ZoeD_N(pretrained, midas_model_type, config_mode, **kwargs)
     66     pretrained_resource = "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_N.pt"
     68 config = get_config("zoedepth", config_mode, pretrained_resource=pretrained_resource, **kwargs)
---> 69 model = build_model(config)
     70 return model

File ~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/builder.py:51, in build_model(config)
     48 except AttributeError as e:
     49     raise ValueError(
     50         f"Model {config.model} has no get_version function.") from e
---> 51 return get_version(config.version_name).build_from_config(config)

File ~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/zoedepth/zoedepth_v1.py:250, in ZoeDepth.build_from_config(config)
    248 @staticmethod
    249 def build_from_config(config):
--> 250     return ZoeDepth.build(**config)

File ~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/zoedepth/zoedepth_v1.py:245, in ZoeDepth.build(midas_model_type, pretrained_resource, use_pretrained_midas, train_midas, freeze_midas_bn, **kwargs)
    243 if pretrained_resource:
    244     assert isinstance(pretrained_resource, str), "pretrained_resource must be a string"
--> 245     model = load_state_from_resource(model, pretrained_resource)
    246 return model

File ~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/model_io.py:84, in load_state_from_resource(model, resource)
     82 if resource.startswith('url::'):
     83     url = resource.split('url::')[1]
---> 84     return load_state_dict_from_url(model, url, progress=True)
     86 elif resource.startswith('local::'):
     87     path = resource.split('local::')[1]

File ~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/model_io.py:61, in load_state_dict_from_url(model, url, **kwargs)
     59 def load_state_dict_from_url(model, url, **kwargs):
     60     state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs)
---> 61     return load_state_dict(model, state_dict)

File ~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/model_io.py:49, in load_state_dict(model, state_dict)
     45         k = 'module.' + k
     47     state[k] = v
---> 49 model.load_state_dict(state)
     50 print("Loaded successfully")
     51 return model

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1671, in Module.load_state_dict(self, state_dict, strict)
   1666         error_msgs.insert(
   1667             0, 'Missing key(s) in state_dict: {}. '.format(
   1668                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   1670 if len(error_msgs) > 0:
-> 1671     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1672                        self.__class__.__name__, "\n\t".join(error_msgs)))
   1673 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for ZoeDepth:
    Unexpected key(s) in state_dict: "core.core.pretrained.model.blocks.0.attn.relative_position_index", "core.core.pretrained.model.blocks.1.attn.relative_position_index", "core.core.pretrained.model.blocks.2.attn.relative_position_index", "core.core.pretrained.model.blocks.3.attn.relative_position_index", "core.core.pretrained.model.blocks.4.attn.relative_position_index", "core.core.pretrained.model.blocks.5.attn.relative_position_index", "core.core.pretrained.model.blocks.6.attn.relative_position_index", "core.core.pretrained.model.blocks.7.attn.relative_position_index", "core.core.pretrained.model.blocks.8.attn.relative_position_index", "core.core.pretrained.model.blocks.9.attn.relative_position_index", "core.core.pretrained.model.blocks.10.attn.relative_position_index", "core.core.pretrained.model.blocks.11.attn.relative_position_index", "core.core.pretrained.model.blocks.12.attn.relative_position_index", "core.core.pretrained.model.blocks.13.attn.relative_position_index", "core.core.pretrained.model.blocks.14.attn.relative_position_index", "core.core.pretrained.model.blocks.15.attn.relative_position_index", "core.core.pretrained.model.blocks.16.attn.relative_position_index", "core.core.pretrained.model.blocks.17.attn.relative_position_index", "core.core.pretrained.model.blocks.18.attn.relative_position_index", "core.core.pretrained.model.blocks.19.attn.relative_position_index", "core.core.pretrained.model.blocks.20.attn.relative_position_index", "core.core.pretrained.model.blocks.21.attn.relative_position_index", "core.core.pretrained.model.blocks.22.attn.relative_position_index", "core.core.pretrained.model.blocks.23.attn.relative_position_index". 
jay-sign commented 1 year ago

Run this:

!pip install timm==0.6.7

and then restart the runtime.

fenneishi commented 11 months ago

Same problem here; fixed, thanks!

NoSuchObjectException commented 6 months ago

fixed the problem!

ashwon13 commented 6 months ago

It solved the same problem for me as well, Thanks so much!

talkativewarrior commented 6 months ago

jay-sign, thanks a lot!!!!

ddyaoshang commented 5 months ago

运行这个 (Run this) -->

!pip install timm==0.6.7

并重新启动运行时 (and restart the runtime)

Thank you very much

cookie-yu commented 4 months ago

Thank you so much!!!

cohnt commented 4 months ago

This fixed it for me. Thank you!