THUDM / VisualGLM-6B

Chinese and English multimodal conversational language model | 多模态中英双语对话语言模型
Apache License 2.0

Reading images is too slow #81

Open freelancerllm opened 1 year ago

freelancerllm commented 1 year ago

One image per second.

freelancerllm commented 1 year ago

[2023-05-30 17:16:28,151] [INFO] [RANK 0] > Set tokenizer as a /mnt/benteng.bt/chatglm-6b tokenizer! Now you can get_tokenizer() everywhere. Image processing: 0%|▌ | 82/33319 [01:09<7:50:37, 1.18it/s]

freelancerllm commented 1 year ago

How do I configure it to use the GPU for image processing?

CaicaiJason commented 1 year ago

I ran into the same problem. Has it been solved?

sssssshf commented 1 year ago

> I ran into the same problem. Has it been solved?

What dataset is this?

CaicaiJason commented 1 year ago

> What dataset is this?

Some data I generated myself.

CiciCR7 commented 1 year ago

Same question here. Data loading is too slow.

CiciCR7 commented 1 year ago

> I ran into the same problem. Has it been solved?

How are you able to train on this much data? With around 10,000 samples mine just stalls and stops making progress.

Sleepychord commented 1 year ago

Large datasets are usually read in lmdb or webdataset format; there are also many other approaches you can find online, and you are free to modify the dataset-creation function in the code yourself. Because those formats take some effort to learn, this project reads images directly.
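The keyed-record idea behind lmdb and webdataset can be illustrated with the standard library alone. In the sketch below, sqlite3 stands in for lmdb purely for illustration; the `records` table layout and the `build_store`/`read_record` helpers are assumptions for this example, not anything VisualGLM ships. Samples are serialized once into a keyed on-disk store and fetched one at a time, so memory use stays flat no matter how large the dataset is.

```python
import json
import sqlite3


def build_store(db_path, samples):
    """Write samples once into a keyed on-disk store (the lmdb-writing step)."""
    conn = sqlite3.connect(db_path)
    conn.execute(
        "CREATE TABLE IF NOT EXISTS records (key INTEGER PRIMARY KEY, value BLOB)"
    )
    conn.executemany(
        "INSERT OR REPLACE INTO records VALUES (?, ?)",
        [(i, json.dumps(s).encode("utf-8")) for i, s in enumerate(samples)],
    )
    conn.commit()
    conn.close()


def read_record(db_path, key):
    """Fetch a single sample by key; nothing else is loaded into memory."""
    conn = sqlite3.connect(db_path)
    row = conn.execute(
        "SELECT value FROM records WHERE key = ?", (key,)
    ).fetchone()
    conn.close()
    return json.loads(row[0])
```

In practice lmdb provides the same keyed access via memory-mapped reads, while webdataset streams tar shards sequentially; either way, decoded images are never all held in RAM at once.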

af-74413592 commented 11 months ago

The bottleneck is mainly the processor(Image.open(item['img']).convert('RGB')) preprocessing step, since the BLIP image embeddings have to be loaded first. You can wrap the `for item in data` line with tqdm to watch the progress; which storage format you read from is not the key issue.

af-74413592 commented 11 months ago

> The bottleneck is mainly the processor(Image.open(item['img']).convert('RGB')) preprocessing step.

Also, this dataset simply loads every image array into memory up front, so with a large dataset it really will blow up.

af-74413592 commented 11 months ago

I tried rewriting it in the load_dataset(streaming=True) style. The change itself works, but sat.training_main does not support streaming datasets yet...
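The streaming idea can be illustrated without the datasets library: yield samples one at a time instead of materializing the whole list. This sketch assumes a JSON-Lines file (one sample per line), which differs from the single JSON array the fine-tuning example uses; `stream_samples` is an illustrative helper, not project code.

```python
import json


def stream_samples(path):
    """Yield one sample at a time; the full dataset never sits in memory."""
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)
```

An iterable-style dataset (or load_dataset with streaming=True) wraps exactly this kind of generator, which is why the training loop also has to be able to consume data sequentially rather than by random index.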

CaicaiJason commented 11 months ago

Change the dataset method: read only the index up front, then load the images batch by batch. The method originally provided reads the entire dataset into memory at once, which goes straight to OOM.

```python
import json

import torch
from PIL import Image
from torch.utils.data import Dataset


class FewShotDataset(Dataset):
    def __init__(self, path, processor, tokenizer, args):
        self.max_seq_length = args.max_source_length + args.max_target_length
        # Only the JSON metadata (paths, prompts, labels) is kept in memory.
        with open(path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.processor = processor
        self.tokenizer = tokenizer
        self.args = args

        self.images = [i['img'] for i in self.data]
        self.prompt = [i['prompt'] for i in self.data]
        self.label = [i['label'] for i in self.data]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        item = self.data[idx]
        try:
            # Decode and preprocess the image lazily, one sample at a time.
            image = self.processor(Image.open(item['img']).convert('RGB'))
        except Exception:
            # Unreadable image: substitute a blank tensor so training can continue.
            image = torch.zeros((3, 224, 224))
            print('read img error', item['img'])
        input0 = self.tokenizer.encode("<img>", add_special_tokens=False)
        input1 = [self.tokenizer.pad_token_id] * self.args.image_length
        input2 = self.tokenizer.encode("</img>问:" + item['prompt'] + "\n答:", add_special_tokens=False)
        a_ids = sum([input0, input1, input2], [])
        b_ids = self.tokenizer.encode(text=item['label'], add_special_tokens=False)
        # Truncate source/target to the configured maximum lengths.
        if len(a_ids) > self.args.max_source_length - 1:
            a_ids = a_ids[: self.args.max_source_length - 1]
        if len(b_ids) > self.args.max_target_length - 2:
            b_ids = b_ids[: self.args.max_target_length - 2]
        pre_image = len(input0)
        input_ids = self.tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)

        # Mask out the context so the loss is computed only on the answer tokens.
        context_length = input_ids.index(self.tokenizer.bos_token_id)
        mask_position = context_length - 1
        labels = [-100] * context_length + input_ids[mask_position + 1:]

        pad_len = self.max_seq_length - len(input_ids)
        input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
        labels = labels + [self.tokenizer.pad_token_id] * pad_len
        if self.args.ignore_pad_token_for_loss:
            labels = [(l if l != self.tokenizer.pad_token_id else -100) for l in labels]

        return {
            "image": image,
            "input_ids": input_ids,
            "labels": labels,
            "pre_image": pre_image
        }
```
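Stripped of the torch/PIL specifics, this fix is just the index-first pattern: __init__ keeps only lightweight metadata, and the expensive decode moves into __getitem__. A minimal pure-Python sketch of that pattern (LazyDataset and the loader callback are illustrative names, not project code):

```python
class LazyDataset:
    """__init__ stores only metadata; the costly per-sample work runs in __getitem__."""

    def __init__(self, records, loader):
        self.records = records  # small: just paths/prompts/labels
        self.loader = loader    # expensive decode, invoked per sample

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        item = self.records[idx]
        try:
            return self.loader(item['img'])
        except Exception:
            # Mirrors the torch.zeros fallback above: keep training alive.
            return None
```

Construction is instant regardless of dataset size; the decoding cost is paid only when a DataLoader worker actually requests a sample.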
af-74413592 commented 11 months ago

> Change the dataset method: read only the index up front, then load the images batch by batch. [...]

Yeah, your approach works too. I also got mine changed: adding the --iterable-dataset argument makes it support a streaming dataset, and training now runs normally.