OpenGVLab / InternVL

[CVPR 2024 Oral] InternVL Family: A Pioneering Open-Source Alternative to GPT-4o. An open-source multimodal dialogue model approaching the performance of GPT-4o.
https://internvl.readthedocs.io/en/latest/
MIT License

Why does Python crash with a segmentation fault when I move InternVL to CUDA? #271

Closed qyr0403 closed 2 months ago

qyr0403 commented 4 months ago

bash video_retrieval.sh

Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]/opt/conda/envs/umt/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 100%|██████████████| 3/3 [00:00<00:00, 4.59it/s]
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the legacy (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set legacy=False. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Computing text features
0%| | 0/1000 [00:00<?, ?it/s]Fatal Python error: Segmentation fault

Thread 0x00007f6076956700 (most recent call first):
File "/opt/conda/envs/umt/lib/python3.10/threading.py", line 324 in wait
File "/opt/conda/envs/umt/lib/python3.10/threading.py", line 607 in wait
File "/opt/conda/envs/umt/lib/python3.10/site-packages/tqdm/_monitor.py", line 60 in run
File "/opt/conda/envs/umt/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
File "/opt/conda/envs/umt/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f60b6a7a700 (most recent call first):
File "/opt/conda/envs/umt/lib/python3.10/threading.py", line 324 in wait
File "/opt/conda/envs/umt/lib/python3.10/threading.py", line 607 in wait
File "/opt/conda/envs/umt/lib/python3.10/site-packages/tqdm/_monitor.py", line 60 in run
File "/opt/conda/envs/umt/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
File "/opt/conda/envs/umt/lib/python3.10/threading.py", line 973 in _bootstrap

Current thread 0x00007f621261e100 (most recent call first):
File "/opt/conda/envs/umt/lib/python3.10/site-packages/torch/cuda/init.py", line 298 in _lazy_init
File "/data2/dy/code/InternVideo/InternVideo2/multi_modality/InternVL_video_retrieval.py", line 65 in validate_msrvtt
File "/data2/dy/code/InternVideo/InternVideo2/multi_modality/InternVL_video_retrieval.py", line 166 in <module>

Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, yaml._yaml, torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, sentencepiece._sentencepiece, PIL._imaging, zope.interface._zope_interface_coptimizations, sqlalchemy.cyextension.collections, sqlalchemy.cyextension.immutabledict, sqlalchemy.cyextension.processors, sqlalchemy.cyextension.resultproxy, sqlalchemy.cyextension.util, greenlet._greenlet, google._upb._message, psutil._psutil_linux, psutil._psutil_posix, PIL._imagingft, av._core, av.logging, av.bytesource, av.buffer, av.audio.format, av.enum, av.error, av.utils, av.option, av.descriptor, av.container.pyio, av.dictionary, av.format, av.stream, av.container.streams, av.sidedata.motionvectors, av.sidedata.sidedata, av.packet, av.container.input, av.container.output, av.container.core, av.codec.context, av.video.format, av.video.reformatter, av.plane, av.video.plane, av.video.frame, av.video.stream, av.codec.codec, av.frame, av.audio.layout, av.audio.plane, av.audio.frame, av.audio.stream, av.audio.fifo, av.filter.pad, av.filter.link, av.filter.context, av.filter.graph, av.filter.filter, av.audio.resampler (total: 75)
video_retrieval.sh: line 6: 1129879 Segmentation fault (core dumped) python -X faulthandler InternVL_video_retrieval.py --video-root /data1/DATASET/MSRVTT/videos --metadata /data2/dy/code/unmasked_teacher/umt_data/anno_downstream/msrvtt_ret_test1k(umt)

The code file is as follows:

import argparse
import io
import json
import math
import os
os.environ['http_proxy'] = "http://127.0.0.1:1081"
os.environ['https_proxy'] = "http://127.0.0.1:1081"
import decord
import mmengine
import numpy as np
import torch
import tqdm
from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor

def recall_at_k(scores, positive_pairs, k):
    """
    Compute the recall at k for each sample
    :param scores: compatibility score between text and image embeddings (nb texts, nb images)
    :param k: number of images to consider per text, for retrieval
    :param positive_pairs: boolean matrix of positive pairs (nb texts, nb images)
    :return: recall at k averaged over all texts
    """
    nb_texts, nb_images = scores.shape
    # for each text, sort according to image scores in decreasing order
    topk_indices = torch.topk(scores, k, dim=1)[1]
    # compute number of positives for each text
    nb_positive = positive_pairs.sum(dim=1)
    # nb_texts, k, nb_images
    topk_indices_onehot = torch.nn.functional.one_hot(topk_indices, num_classes=nb_images)
    # compute number of true positives
    positive_pairs_reshaped = positive_pairs.view(nb_texts, 1, nb_images)
    # a true positive means a positive among the topk
    nb_true_positive = (topk_indices_onehot * positive_pairs_reshaped).sum(dim=(1, 2))
    # compute recall at k
    recall_at_k = (nb_true_positive / nb_positive)
    return recall_at_k

def batchify(func, X, Y, batch_size, device, *args, **kwargs):
    results = []
    for start in range(0, len(X), batch_size):
        end = start + batch_size
        x = X[start:end].to(device)
        y = Y[start:end].to(device)
        result = func(x, y, *args, **kwargs).cpu()
        results.append(result)
    return torch.cat(results)

def validate_msrvtt(model, tokenizer, image_processor, root, metadata,
                    num_frames=1, prefix='summarize:', mode='InternVL-G', recall_k_list=[1, 5, 10],
                    use_dsl=True, eval_batch_size=32):
    metadata = json.load(open(metadata))

    video_features = []
    text_features = []

    # compute text features
    print('Computing text features', flush=True)
    for data in tqdm.tqdm(metadata):
        caption = prefix + data['caption']
        input_ids = tokenizer(caption, return_tensors='pt', max_length=80,
                        truncation=True, padding='max_length').input_ids.cuda()
        with torch.no_grad():
            feat = model.encode_text(input_ids)
        text_features.append(feat.cpu())
    text_features = torch.cat(text_features)

    # compute video features
    print('Computing video features', flush=True)
    for data in tqdm.tqdm(metadata):
        video_id = data['video']
        video_path = os.path.join(root, video_id)
        video_data = mmengine.get(video_path)
        video_data = io.BytesIO(video_data)
        video_reader = decord.VideoReader(video_data)

        # uniformly sample frames
        interval = math.ceil(len(video_reader) / num_frames)
        frames_id = np.arange(0, len(video_reader), interval) + interval // 2
        assert len(frames_id) == num_frames and frames_id[-1] < len(video_reader)

        frames = video_reader.get_batch(frames_id).asnumpy()

        pixel_values = image_processor(images=frames, return_tensors='pt').pixel_values
        with torch.no_grad():
            pixel_values = pixel_values.to(torch.bfloat16).cuda()
            feat = model.encode_image(pixel_values, mode=mode)
            feat = feat.mean(dim=0, keepdim=True)
        video_features.append(feat.cpu())
    video_features = torch.cat(video_features)

    print('Computing metrics', flush=True)
    texts_emb = text_features / text_features.norm(dim=-1, keepdim=True)
    images_emb = video_features / video_features.norm(dim=-1, keepdim=True)

    # get the score for each text and image pair
    scores = texts_emb @ images_emb.t()

    # construct the positive pair matrix, which tells whether each text-image pair is a positive or not
    positive_pairs = torch.zeros_like(scores, dtype=bool)
    positive_pairs[torch.arange(len(scores)), torch.arange(len(scores))] = True

    scores_T = scores.T
    positive_pairs_T = positive_pairs.T

    if use_dsl:
        scores = scores * scores.softmax(dim=0)
        scores_T = scores_T * scores_T.softmax(dim=0)

    metrics = {}
    for recall_k in recall_k_list:
        # Note that recall_at_k computes the **actual** recall, i.e. nb_true_positive / nb_positive, where the number
        # of true positives, e.g. for text retrieval, is, for each image, the number of retrieved texts matching that
        # image among the top-k. The number of positives is the total number of texts matching the image in the
        # dataset; since there is a set of captions per image, that number can be greater than 1 for text retrieval.
        # However, image/text retrieval recall@k, as reported in CLIP-like papers, is slightly different: for each
        # image it is either 1 or 0, being 1 if at least one matching text appears among the top-k. We can derive it
        # from the actual recall by checking whether the recall is greater than 0, and then averaging over the
        # dataset (a short toy example follows this script).
        metrics[f't2v_retrieval_recall@{recall_k}'] = (
                    batchify(recall_at_k, scores, positive_pairs, eval_batch_size, scores.device,
                             k=recall_k) > 0).float().mean().item()
        metrics[f'v2t_retrieval_recall@{recall_k}'] = (
                    batchify(recall_at_k, scores_T, positive_pairs_T, eval_batch_size, scores.device,
                             k=recall_k) > 0).float().mean().item()

    print(metrics)
    return metrics

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='validate MSR-VTT', add_help=False)
    parser.add_argument('--video-root', type=str)
    parser.add_argument('--metadata', type=str)
    parser.add_argument('--mode', type=str, default='InternVL-C', choices=['InternVL-C', 'InternVL-G'])
    parser.add_argument('--num-frames', type=int, default=1)
    args = parser.parse_args()

    # try:
    #     model = AutoModel.from_pretrained(
    #         'OpenGVLab/InternVL-14B-224px',
    #         cache_dir="/datassd2/pretrained_models/InternVL",
    #         torch_dtype=torch.bfloat16,
    #         low_cpu_mem_usage=True,
    #         trust_remote_code=True).to("cuda").eval()
    #     import pdb; pdb.set_trace()
    # except Exception as e:
    #     import pdb; pdb.set_trace()
    #     print(e)
    model = AutoModel.from_pretrained(
        'OpenGVLab/InternVL-14B-224px',
        cache_dir="/datassd2/pretrained_models/InternVL",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True).to("cuda").eval()

    image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternVL-14B-224px', cache_dir="/datassd2/pretrained_models/InternVL")

    tokenizer = AutoTokenizer.from_pretrained(
        'OpenGVLab/InternVL-14B-224px', cache_dir="/datassd2/pretrained_models/InternVL", use_fast=False, add_eos_token=True)
    tokenizer.pad_token_id = 0  # set pad_token_id to 0

    metrics = validate_msrvtt(model, tokenizer, image_processor,
                              root=args.video_root,
                              metadata=args.metadata,
                              mode=args.mode,
                              num_frames=args.num_frames,)
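
For reference, the CLIP-style recall@k described in the comments above can be checked with a tiny, made-up example (it assumes the recall_at_k function from this script is in scope; the scores below are hypothetical):

import torch  # recall_at_k as defined in the script above

# 3 texts vs. 3 videos; text i matches video i, but text 2 ranks video 0 highest.
scores = torch.tensor([[0.9, 0.1, 0.0],
                       [0.2, 0.8, 0.0],
                       [0.7, 0.1, 0.2]])
positive_pairs = torch.eye(3, dtype=torch.bool)

per_text_recall = recall_at_k(scores, positive_pairs, k=1)    # tensor([1., 1., 0.])
print((per_text_recall > 0).float().mean().item())            # 0.666..., i.e. t2v recall@1
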
czczup commented 3 months ago

Has this problem been resolved? It might be caused by running out of GPU memory.
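
If GPU memory is the suspect, a minimal sketch along these lines can help confirm it before running the full script (this assumes a fairly recent PyTorch, where torch.cuda.mem_get_info and torch.cuda.OutOfMemoryError are available, and a single-GPU setup; the ~28 GB figure is only a rough estimate for a ~14B-parameter model in bfloat16):

import torch
from transformers import AutoModel

# InternVL-14B-224px in bfloat16 takes roughly 2 bytes per parameter (~28 GB) plus
# activations, so e.g. a 24 GB card will not fit it.
free, total = torch.cuda.mem_get_info(0)   # bytes free / total on GPU 0
print(f'GPU 0: {free / 1e9:.1f} GB free of {total / 1e9:.1f} GB')

try:
    model = AutoModel.from_pretrained(
        'OpenGVLab/InternVL-14B-224px',
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True).to('cuda').eval()
except torch.cuda.OutOfMemoryError as err:
    # A plain allocation failure surfaces here as a catchable exception; a segmentation
    # fault inside torch.cuda._lazy_init (as in the log above) can also come from a
    # CUDA/driver mismatch rather than from memory pressure.
    print('Not enough GPU memory:', err)
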

zmyzxb commented 2 months ago

This issue has been inactive for over two weeks. If the problem is still unresolved, please feel free to open a new issue to ask your question. Thank you.