substratusai / images

Official Substratus Container Images

Issue with PyTorch-model-only filtering #24

Closed: samos123 closed this issue 1 year ago

samos123 commented 1 year ago

When safetensors weights are available, transformers tries to load them first, which fails here:

ta, _fast_init, low_cpu_mem_usage, device_map, offload_folder, offload_state_dict, dtype, is_quantized, keep_in_fp32_modules)
   3288 if shard_file in disk_only_shard_files:
   3289     continue
-> 3290 state_dict = load_state_dict(shard_file)
   3292 # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
   3293 # matching the weights in the model.
   3294 mismatched_keys += _find_mismatched_keys(
   3295     state_dict,
   3296     model_state_dict,
   (...)
   3300     ignore_mismatched_sizes,
   3301 )

File /usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:447, in load_state_dict(checkpoint_file)
    442 """
    443 Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
    444 """
    445 if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
    446     # Check format of the archive
--> 447     with safe_open(checkpoint_file, framework="pt") as f:
    448         metadata = f.metadata()
    449     if metadata.get("format") not in ["pt", "tf", "flax"]:

FileNotFoundError: No such file or directory: "/content/saved-model/model-00001-of-00002.safetensors" 
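A minimal diagnostic sketch, assuming a local model directory laid out like the one in the traceback (the file names in the demo are made up for illustration). It reads model.safetensors.index.json, whose weight_map maps each parameter name to the shard file that stores it, and lists the shards that are missing on disk. A non-empty result is consistent with the error above: the index survives the filtering, so transformers takes the sharded-safetensors path, but the shards it names were filtered out. `missing_safetensors_shards` is a hypothetical helper, not part of transformers or the Substratus images.

```python
import json
import os
import tempfile


def missing_safetensors_shards(model_dir: str) -> list[str]:
    """Return shard files named in model.safetensors.index.json that are absent from model_dir."""
    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    if not os.path.exists(index_path):
        return []  # no sharded safetensors index, so nothing can be missing
    with open(index_path) as f:
        index = json.load(f)
    # weight_map: parameter name -> shard file; deduplicate the shard names
    shards = sorted(set(index.get("weight_map", {}).values()))
    return [s for s in shards if not os.path.exists(os.path.join(model_dir, s))]


if __name__ == "__main__":
    # Simulate a directory where only the PyTorch shards and the safetensors index were kept.
    with tempfile.TemporaryDirectory() as d:
        index = {"metadata": {}, "weight_map": {
            "a.weight": "model-00001-of-00002.safetensors",
            "b.weight": "model-00002-of-00002.safetensors",
        }}
        with open(os.path.join(d, "model.safetensors.index.json"), "w") as f:
            json.dump(index, f)
        for name in ("pytorch_model-00001-of-00002.bin", "pytorch_model-00002-of-00002.bin"):
            open(os.path.join(d, name), "wb").close()
        print(missing_safetensors_shards(d))
```

On the simulated directory this prints both safetensors shard names, mirroring the missing model-00001-of-00002.safetensors in the error.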

Possible fixes: either also filter out the model.safetensors.index.json file, which prevents this issue, or, when safetensors files are present, make sure the other model files don't get used.
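The first option can be sketched as a filename filter. This is a hypothetical helper operating on a plain list of repository file names; the issue does not show how the images actually filter files. The key detail is that a pattern like "*.safetensors" does not match "model.safetensors.index.json", because that name ends in ".json", so the index needs its own pattern.

```python
from fnmatch import fnmatch

# Hypothetical exclusion patterns. "*.safetensors" alone does NOT match
# "model.safetensors.index.json" (it ends in ".json"), which is how the index
# can survive a filter meant to drop every safetensors artifact.
SAFETENSORS_PATTERNS = ["*.safetensors", "*.safetensors.index.json"]


def keep_pytorch_weights_only(filenames: list[str]) -> list[str]:
    """Drop safetensors shards *and* their index so only the PyTorch weights remain."""
    return [
        name
        for name in filenames
        if not any(fnmatch(name, pattern) for pattern in SAFETENSORS_PATTERNS)
    ]


if __name__ == "__main__":
    repo_files = [
        "config.json",
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
        "model.safetensors.index.json",
        "pytorch_model-00001-of-00002.bin",
        "pytorch_model-00002-of-00002.bin",
        "pytorch_model.bin.index.json",
    ]
    print(keep_pytorch_weights_only(repo_files))
```

With both the shards and the index gone, transformers should no longer find any safetensors entry point and should fall back to the sharded .bin files via pytorch_model.bin.index.json.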

I think the right fix is to keep only the safetensors files when they are available, since transformers loads safetensors by default when both PyTorch and safetensors weights are present.
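A minimal sketch of the preferred fix, again as a hypothetical filter over repository file names: if any .safetensors file is present, drop the PyTorch weight files and their index; otherwise leave the list unchanged so PyTorch-only models still work. The "*.bin" pattern is an assumption and is broad; it would also drop non-weight files such as training_args.bin.

```python
from fnmatch import fnmatch

# Hypothetical patterns for PyTorch weight artifacts. "*.bin" is deliberately
# broad; tighten it if non-weight .bin files (e.g. training_args.bin) must be kept.
PYTORCH_PATTERNS = ["*.bin", "pytorch_model.bin.index.json"]


def prefer_safetensors(filenames: list[str]) -> list[str]:
    """Keep only safetensors weights when any are present; otherwise return the list unchanged."""
    if not any(fnmatch(name, "*.safetensors") for name in filenames):
        return list(filenames)  # no safetensors available: keep the PyTorch weights
    return [
        name
        for name in filenames
        if not any(fnmatch(name, pattern) for pattern in PYTORCH_PATTERNS)
    ]


if __name__ == "__main__":
    repo_files = [
        "config.json",
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
        "model.safetensors.index.json",
        "pytorch_model-00001-of-00002.bin",
        "pytorch_model-00002-of-00002.bin",
        "pytorch_model.bin.index.json",
    ]
    print(prefer_safetensors(repo_files))
```

This keeps the safetensors index together with its shards, so whatever transformers picks by default is actually on disk. If the weights are downloaded with huggingface_hub.snapshot_download (an assumption; the issue doesn't say), the same patterns could be passed through its ignore_patterns argument instead of filtering afterwards.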

samos123 commented 1 year ago

Impact: currently, any model that ships both safetensors and PyTorch weights won't work with the trainer.