OFA-Sys / Chinese-CLIP

Chinese version of CLIP which achieves Chinese cross-modal retrieval and representation generation.
MIT License

Add flash-attn support to load_from_name #312

Open ZechengLi19 opened 2 months ago

ZechengLi19 commented 2 months ago

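Thank you for such a great implementation; it has helped me a lot. However, when using the load_from_name function I found that it does not support flash-attn, so I implemented that part myself. I'm not sure my implementation is correct, although it runs without errors.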


###### ------- ps: add use_flash_attention keyword ------- ######
# Imports this snippet relies on. The other names used below (_MODELS, _MODEL_INFO,
# _download, available_models, image_transform, CLIP, resize_pos_embed,
# convert_state_dict) come from the Chinese-CLIP package itself.
import json
import os
from pathlib import Path
from typing import Union

import torch


def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
                   download_root: str = None, vision_model_name: str = None, text_model_name: str = None,
                   input_resolution: int = None, use_flash_attention: bool = False):
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
        model_name, model_input_resolution = _MODEL_INFO[name]['struct'], _MODEL_INFO[name]['input_resolution']
    elif os.path.isfile(name):
        assert vision_model_name and text_model_name and input_resolution, "Please specify specific 'vision_model_name', 'text_model_name', and 'input_resolution'"
        model_path = name
        model_name, model_input_resolution = f'{vision_model_name}@{text_model_name}', input_resolution
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        # loading saved checkpoint
        checkpoint = torch.load(opened_file, map_location="cpu")

    model = create_model(model_name, checkpoint, use_flash_attention=use_flash_attention)
    if str(device) == "cpu":
        model.float()
    else:
        model.to(device)
    return model, image_transform(model_input_resolution)


###### ------- ps: convert flash_attention weight ------- ######
def create_model(model_name, checkpoint=None, use_flash_attention=False):
    vision_model, text_model = model_name.split('@')
    # Initialize the model.
    vision_model_config_file = Path(
        __file__).parent / f"model_configs/{vision_model.replace('/', '-')}.json"
    print('Loading vision model config from', vision_model_config_file)
    assert os.path.exists(vision_model_config_file)

    text_model_config_file = Path(
        __file__).parent / f"model_configs/{text_model.replace('/', '-')}.json"
    print('Loading text model config from', text_model_config_file)
    assert os.path.exists(text_model_config_file)

    with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
        model_info = json.load(fv)
        for k, v in json.load(ft).items():
            model_info[k] = v
    if isinstance(model_info['vision_layers'], str):
        model_info['vision_layers'] = eval(model_info['vision_layers'])
    print('Model info', model_info)
    if use_flash_attention:
        model_info['use_flash_attention'] = use_flash_attention
    model = CLIP(**model_info)

    if checkpoint:
        sd = checkpoint["state_dict"]
        if use_flash_attention:
            # Drop the 'bert.pooler' weights, then strip the DDP 'module.' prefix if present
            sd = {k: v for k, v in sd.items() if "bert.pooler" not in k}
            if next(iter(sd.items()))[0].startswith('module'):
                sd = {k[len('module.'):]: v for k, v in sd.items()}
            # Resize the positional embedding by interpolation, if needed.
            # The 'module.' prefix has already been stripped, so pass an empty prefix.
            resize_pos_embed(sd, model, prefix="")
            # Adapt flash attention
            sd = convert_state_dict(sd)
        elif next(iter(sd.items()))[0].startswith('module'):
            sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
        # Load the state dict
        model.load_state_dict(sd)
    return model
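The key clean-up in `create_model` (discarding `bert.pooler` keys and stripping the `module.` prefix that `DistributedDataParallel` adds to saved keys) can be checked in isolation on a plain dict. This is a minimal sketch; `normalize_state_dict` is a hypothetical helper name, and the example keys are made up for illustration.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def normalize_state_dict(sd: Dict[str, V]) -> Dict[str, V]:
    """Hypothetical helper mirroring the key clean-up in create_model."""
    # Mirror the snippet above: keys containing 'bert.pooler' are discarded.
    sd = {k: v for k, v in sd.items() if "bert.pooler" not in k}
    # Strip the prefix only when the checkpoint came from a DDP-wrapped model.
    if sd and next(iter(sd)).startswith("module."):
        sd = {k[len("module."):]: v for k, v in sd.items()}
    return sd


if __name__ == "__main__":
    raw = {
        "module.visual.positional_embedding": 1,
        "module.bert.pooler.dense.weight": 2,
        "module.logit_scale": 3,
    }
    print(normalize_state_dict(raw))
```

Because the prefix is gone after this step, any later helper that looks keys up by name (such as the positional-embedding resize) has to be called without the `module.` prefix.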




DtYXs commented 1 month ago

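Hello. Currently, when training with flash-attn enabled, the checkpoint is saved in exactly the same format as without it, so a checkpoint trained with flash-attn should load directly. Could you try that first? https://github.com/OFA-Sys/Chinese-CLIP/blob/8a323f2702ebeca5be864e4cb66a72d9a770f25c/cn_clip/training/train.py#L309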
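Since the saved format is the same either way, one quick way to see which layout a given checkpoint uses is to scan its state-dict keys. This is a minimal sketch under stated assumptions: the regular layout is assumed to store each fused QKV projection as `attn.in_proj_weight` (the parameter name used by PyTorch's `nn.MultiheadAttention`), and the flash-attn layout as `attn.Wqkv.weight` (the attribute name used by flash-attn's `MHA` module). The helper name `checkpoint_attention_format` is hypothetical, and the exact key names in Chinese-CLIP checkpoints should be confirmed against the repository.

```python
from typing import Iterable


def checkpoint_attention_format(keys: Iterable[str]) -> str:
    """Hypothetical helper: guess which attention key layout a state dict uses."""
    has_regular = has_flash = False
    for key in keys:
        # Assumed fused-QKV weight name in the regular (nn.MultiheadAttention) layout.
        if key.endswith("attn.in_proj_weight"):
            has_regular = True
        # Assumed fused-QKV weight name in the flash-attn (MHA) layout.
        elif key.endswith("attn.Wqkv.weight"):
            has_flash = True
    if has_regular and has_flash:
        return "mixed"
    if has_regular:
        return "regular"
    if has_flash:
        return "flash"
    return "unknown"


if __name__ == "__main__":
    keys = ["visual.transformer.resblocks.0.attn.in_proj_weight", "logit_scale"]
    print(checkpoint_attention_format(keys))  # regular
```

In practice the argument would be `checkpoint["state_dict"].keys()` after loading the file with `torch.load`.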

ZechengLi19 commented 1 month ago

@DtYXs Thank you for your reply, but I think you may have misunderstood what I meant.


DtYXs commented 1 month ago

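@ZechengLi19 I do understand what you mean. My understanding is that the flash-attn format defined in the current code applies only within the Chinese-CLIP project, and models trained with Chinese-CLIP are automatically converted from the flash-attn format back to the regular format. So I'd like to know in which situations you would actually need to load a model in the flash-attn format.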
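The conversion between the two formats can be thought of as renaming parameters. This is a minimal sketch under the same assumptions as above (`attn.in_proj_weight`/`attn.in_proj_bias` for the regular layout, `attn.Wqkv.weight`/`attn.Wqkv.bias` for flash-attn). `rename_attention_keys` is a hypothetical helper and not necessarily what the repository's `convert_state_dict` does. A pure rename is only sufficient if the fused QKV tensor has the same shape and row order in both layouts, which is itself an assumption to verify.

```python
from typing import Dict, TypeVar

V = TypeVar("V")

# Assumed suffix pairs: (regular nn.MultiheadAttention name, flash-attn MHA name).
_SUFFIX_PAIRS = (
    ("attn.in_proj_weight", "attn.Wqkv.weight"),
    ("attn.in_proj_bias", "attn.Wqkv.bias"),
)


def rename_attention_keys(sd: Dict[str, V], to_flash: bool) -> Dict[str, V]:
    """Hypothetical helper: return a new dict with attention keys renamed.

    to_flash=True maps regular keys to flash-attn keys; False maps them back.
    Values (tensors) are passed through untouched; only key names change.
    """
    out: Dict[str, V] = {}
    for key, value in sd.items():
        new_key = key
        for regular, flash in _SUFFIX_PAIRS:
            src, dst = (regular, flash) if to_flash else (flash, regular)
            if key.endswith(src):
                new_key = key[: -len(src)] + dst
                break
        out[new_key] = value
    return out
```

Returning a new dict rather than renaming in place leaves the loaded checkpoint untouched, and applying the mapping in one direction and then the other restores the original keys.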

ZechengLi19 commented 1 month ago

@DtYXs For example, I'd like to use your trained Chinese-CLIP in other downstream tasks.

