microsoft / satclip

PyTorch implementation of SatCLIP
MIT License

Unable to Load Locally Stored SatCLIP Model #14

Closed. veeralakrishna closed this issue 1 month ago.

veeralakrishna commented 2 months ago

Issue Description: I downloaded the SatCLIP checkpoint satclip-vit16-l40.ckpt from Hugging Face and stored it at /pretrained_models/satclip/resnet16-l40/satclip-vit16-l40.ckpt. I then tried to load the locally stored model with the code below in a working environment without internet access, but encountered the error shown after it. I am seeking guidance on how to load and use a locally downloaded SatCLIP model without any internet access.

import sys
sys.path.append("./satclip")
from load import get_satclip

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# c = torch.randn(32, 2)  # Represents a batch of 32 locations (lon/lat)
# Represents a batch of 4 locations (lon/lat)
a = [
    [44.963320, -93.244523],
    [33.872022, -84.464836],
    [30.237592, -95.177780],
    [34.738666, -86.646624],
]

c = torch.as_tensor(a).float()

model = get_satclip(
    '/pretrained_models/satclip/resnet16-l40/satclip-vit16-l40.ckpt',
    device=device,
)  # Only loads location encoder by default
model.eval()
with torch.no_grad():
    emb = model(c.to(device)).detach().cpu()
using pretrained moco vit16
Downloading: "https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth" to /home/ubuntu/.cache/torch/hub/checkpoints/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth
---------------------------------------------------------------------------
HTTPError                                 Traceback (most recent call last)
Cell In[14], line 1
----> 1 model = get_satclip(
      2     '/pretrained_models/satclip/resnet16-l40/satclip-vit16-l40.ckpt',
      3     device=device,
      4 )  # Only loads location encoder by default
      5 model.eval()
      6 with torch.no_grad():

File /mnt/satclip/./satclip/load.py:8, in get_satclip(ckpt_path, device, return_all)
      6 ckpt['hyper_parameters'].pop('air_temp_data_path')
      7 ckpt['hyper_parameters'].pop('election_data_path')
----> 8 lightning_model = SatCLIPLightningModule(**ckpt['hyper_parameters']).to(device)
     10 lightning_model.load_state_dict(ckpt['state_dict'])
     11 lightning_model.eval()

File /mnt/satclip/./satclip/main.py:39, in SatCLIPLightningModule.__init__(self, embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, in_channels, le_type, pe_type, frequency_num, max_radius, min_radius, legendre_polys, harmonics_calculation, sh_embedding_dims, learning_rate, weight_decay, num_hidden_layers, capacity)
     16 def __init__(
     17     self,
     18     embed_dim=512,
   (...)
     35     capacity=256,        
     36 ) -> None:
     37     super().__init__()
---> 39     self.model = SatCLIP(
     40         embed_dim=embed_dim,
     41         image_resolution=image_resolution,
     42         vision_layers=vision_layers,
     43         vision_width=vision_width,
     44         vision_patch_size=vision_patch_size,
     45         in_channels=in_channels,
     46         le_type=le_type,
     47         pe_type=pe_type,
     48         frequency_num=frequency_num,
     49         max_radius=max_radius,
     50         min_radius=min_radius,
     51         legendre_polys=legendre_polys,
     52         harmonics_calculation=harmonics_calculation,
     53         sh_embedding_dims=sh_embedding_dims,
     54         num_hidden_layers=num_hidden_layers,
     55         capacity=capacity,
     56     )
     58     self.loss_fun = SatCLIPLoss()
     59     self.learning_rate = learning_rate

File /mnt/satclip/./satclip/model.py:309, in SatCLIP.__init__(self, embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, in_channels, le_type, pe_type, frequency_num, max_radius, min_radius, harmonics_calculation, legendre_polys, sh_embedding_dims, ffn, num_hidden_layers, capacity, *args, **kwargs)
    307 in_chans = weights.meta["in_chans"]
    308 self.visual = timm.create_model("vit_small_patch16_224", in_chans=in_chans, num_classes=embed_dim)
--> 309 self.visual.load_state_dict(weights.get_state_dict(progress=True), strict=False)
    310 self.visual.requires_grad_(False)
    311 self.visual.head.requires_grad_(True)

File /opt/conda/lib/python3.9/site-packages/torchvision/models/_api.py:90, in WeightsEnum.get_state_dict(self, *args, **kwargs)
     89 def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
---> 90     return load_state_dict_from_url(self.url, *args, **kwargs)

File /opt/conda/lib/python3.9/site-packages/torch/hub.py:760, in load_state_dict_from_url(url, model_dir, map_location, progress, check_hash, file_name, weights_only)
    758         r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
    759         hash_prefix = r.group(1) if r else None
--> 760     download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    762 if _is_legacy_zip_format(cached_file):
    763     return _legacy_zip_load(cached_file, model_dir, map_location, weights_only)

File /opt/conda/lib/python3.9/site-packages/torch/hub.py:622, in download_url_to_file(url, dst, hash_prefix, progress)
    620 file_size = None
    621 req = Request(url, headers={"User-Agent": "torch.hub"})
--> 622 u = urlopen(req)
    623 meta = u.info()
    624 if hasattr(meta, 'getheaders'):

File /opt/conda/lib/python3.9/urllib/request.py:214, in urlopen(url, data, timeout, cafile, capath, cadefault, context)
    212 else:
    213     opener = _opener
--> 214 return opener.open(url, data, timeout)

File /opt/conda/lib/python3.9/urllib/request.py:523, in OpenerDirector.open(self, fullurl, data, timeout)
    521 for processor in self.process_response.get(protocol, []):
    522     meth = getattr(processor, meth_name)
--> 523     response = meth(req, response)
    525 return response

File /opt/conda/lib/python3.9/urllib/request.py:632, in HTTPErrorProcessor.http_response(self, request, response)
    629 # According to RFC 2616, "2xx" code indicates that the client's
    630 # request was successfully received, understood, and accepted.
    631 if not (200 <= code < 300):
--> 632     response = self.parent.error(
    633         'http', request, response, code, msg, hdrs)
    635 return response

File /opt/conda/lib/python3.9/urllib/request.py:561, in OpenerDirector.error(self, proto, *args)
    559 if http_err:
    560     args = (dict, 'default', 'http_error_default') + orig_args
--> 561     return self._call_chain(*args)

File /opt/conda/lib/python3.9/urllib/request.py:494, in OpenerDirector._call_chain(self, chain, kind, meth_name, *args)
    492 for handler in handlers:
    493     func = getattr(handler, meth_name)
--> 494     result = func(*args)
    495     if result is not None:
    496         return result

File /opt/conda/lib/python3.9/urllib/request.py:641, in HTTPDefaultErrorHandler.http_error_default(self, req, fp, code, msg, hdrs)
    640 def http_error_default(self, req, fp, code, msg, hdrs):
--> 641     raise HTTPError(req.full_url, code, msg, hdrs, fp)

HTTPError: HTTP Error 503: Service Unavailable

Looking forward to your assistance in resolving this issue promptly. Thank you.

konstantinklemmer commented 2 months ago

The problem arises because the code also tries to load the pretrained vision encoder (e.g. a ResNet50), which it likewise downloads from Hugging Face.
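
If the goal is only to unblock that download on an offline machine, one option is to fetch the vision-encoder weights on a machine with internet access and copy them into torch hub's checkpoint cache, since torch.hub only downloads a file when it is missing from that cache. A rough sketch, where the source path is a placeholder for wherever you copied the .pth file:

import os
import shutil
import torch

# torch.hub.load_state_dict_from_url skips the download if the target file
# already exists in this cache directory.
hub_cache = os.path.join(torch.hub.get_dir(), 'checkpoints')
os.makedirs(hub_cache, exist_ok=True)

# The file name must match the one in the download URL from the error message.
shutil.copy(
    '/pretrained_models/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth',  # placeholder source path
    os.path.join(hub_cache, 'vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth'),
)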

To load just the location encoder, you'll have to write a custom function that instantiates the LocationEncoder() class (from here: https://github.com/microsoft/satclip/blob/main/satclip/location_encoder.py) and then loads only the part of the state_dict from satclip-vit16-l40.ckpt that belongs to the location encoder.
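
For illustration, here is a minimal sketch of the state_dict side of that approach. The "model.location." prefix is an assumption and should be confirmed against the keys actually present in the checkpoint; how to construct the LocationEncoder itself should be copied from the way SatCLIP builds its location submodule in satclip/model.py.

import torch

ckpt = torch.load(
    '/pretrained_models/satclip/resnet16-l40/satclip-vit16-l40.ckpt',
    map_location='cpu',
)

# The hyperparameters needed to rebuild the location encoder (le_type, pe_type,
# legendre_polys, capacity, ...) are stored alongside the weights.
print(ckpt['hyper_parameters'])

# Pick out only the location-encoder weights from the Lightning state_dict.
# NOTE: "model.location." is an assumed prefix -- verify it by printing the keys.
prefix = 'model.location.'
loc_state = {k[len(prefix):]: v
             for k, v in ckpt['state_dict'].items()
             if k.startswith(prefix)}
print(list(loc_state)[:5])

# loc_state can then be passed to load_state_dict() of a LocationEncoder built
# with the same hyperparameters that were used during training.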

Hope this helps!

konstantinklemmer commented 1 month ago

Lightweight loading that supports offline use is now available here: https://github.com/microsoft/satclip/blob/main/satclip/load_lightweight.py. Thanks to @crastoru.