lucidrains / meshgpt-pytorch

Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch
MIT License

Error when calling `transformer.generate(prompt=None)`: empty codes are passed to the decoder #55

Closed fighting-Zhang closed 6 months ago

fighting-Zhang commented 7 months ago

Thank you very much for your work!!!

When I try to generate with the trained model without a prompt, the codes are generated one by one, starting from an empty tensor: `codes = default(prompt, torch.empty((batch_size, 0), dtype = torch.long, device = self.device))`. When this empty tensor reaches the model's decoder, an error is raised. The error goes away only when face_codes.size(1) is not 0. How can I solve this?

When I pass a prompt, generation succeeds.

With transformer.generate(prompt=None), face_codes.size() = [1, 0, 512].
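The zero-length face dimension appears to follow directly from the grouping arithmetic in the `forward_on_codes` source pasted further down. Here is a minimal pure-Python sketch of that arithmetic only. The helper name `coarse_face_token_count` is hypothetical, and `num_quantizers = 2` is an assumed value for illustration:

```python
from math import ceil

def coarse_face_token_count(code_len: int, num_quantizers: int) -> int:
    """Number of face tokens handed to the coarse decoder for `code_len` codes.

    Mirrors the shape arithmetic in the pasted forward_on_codes: codes are
    padded up to a whole number of faces, and a trailing partial face is
    dropped before the coarse attention stage.
    """
    num_tokens_per_face = num_quantizers * 3
    padded_len = ceil(code_len / num_tokens_per_face) * num_tokens_per_face
    num_faces = padded_len // num_tokens_per_face
    is_whole_face = code_len % num_tokens_per_face == 0
    return num_faces if is_whole_face else num_faces - 1

# num_quantizers = 2 is an assumption, not a value taken from the checkpoint
for code_len in (0, 1, 5, 6, 7, 12):
    print(code_len, coarse_face_token_count(code_len, num_quantizers=2))
```

Under these assumptions, every code length shorter than one full face (including the promptless case, `code_len = 0`) gives zero coarse tokens, so a sequence dimension of 0 at `self.decoder` would be expected on the first step rather than being a bug in the calling script.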

Error: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
  File "/ssd1/meshgpt-pytorch/meshgpt_pytorch/meshgpt_pytorch.py", line 1436, in forward_on_codes
    attended_face_codes, coarse_cache = self.decoder(
  File "/ssd1/meshgpt-pytorch/meshgpt_pytorch/meshgpt_pytorch.py", line 1189, in generate
    output = self.forward_on_codes(
  File "/ssd1/meshgpt-pytorch/generate_samples_v1.py", line 109, in <module>
    face_coords, face_mask = transformer.generate(temperature=r, texts=texts_list)
RuntimeError: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
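The error text itself suggests `CUDA_LAUNCH_BLOCKING=1`, which makes kernel launches synchronous so the reported stack frame is more likely to be the one that actually failed. A minimal sketch of enabling it from inside the script, assuming the variable has to be in place before CUDA is initialized (setting it before importing torch is the cautious option):

```python
import os

# Set this before CUDA is initialized; doing it before `import torch`
# is the cautious choice, since the value is read when CUDA starts up.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Only after this point:
#   import torch
#   ... build the model and call transformer.generate() as before
print("CUDA_LAUNCH_BLOCKING =", os.environ["CUDA_LAUNCH_BLOCKING"])
```

The same effect can be had by exporting the variable in the shell before launching Python. Expect the run to be slower while it is enabled.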

lucidrains commented 7 months ago

can you show the full condensed script for which it errors?

fighting-Zhang commented 7 months ago

generate_samples_v1.py:

from meshgpt_pytorch import (
    MeshAutoencoder,
    MeshTransformer
)

autoencoder = MeshAutoencoder.init_and_load('./exps/mesh-autoencoder.ckpt.90.pt')

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    max_seq_len = 12000, #8192,
    flash_attn = True,
    gateloop_use_heinsen = False, 
    condition_on_text = False
).cuda()
transformer.load('./checkpoints/mesh-transformer.ckpt.5.pt')

face_coords, face_mask = transformer.generate(temperature=0.5)

The meshgpt_pytorch version is 0.6.7.

MeshTransformer.generate():

    @eval_decorator
    @torch.no_grad()
    @beartype
    def generate(
        self,
        prompt: Optional[Tensor] = None,
        batch_size: Optional[int] = None,
        filter_logits_fn: Callable = top_k,
        filter_kwargs: dict = dict(),
        temperature = 1.,
        return_codes = False,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[Tensor] = None,
        cond_scale = 1.,
        cache_kv = True,
        max_seq_len = None,
        face_coords_to_file: Optional[Callable[[Tensor], Any]] = None
    ):
        max_seq_len = default(max_seq_len, self.max_seq_len)

        if exists(prompt):
            assert not exists(batch_size)

            prompt = rearrange(prompt, 'b ... -> b (...)')
            assert prompt.shape[-1] <= self.max_seq_len

            batch_size = prompt.shape[0]

        if self.condition_on_text:
            assert exists(texts) ^ exists(text_embeds), '`text` or `text_embeds` must be passed in if `condition_on_text` is set to True'
            if exists(texts):
                text_embeds = self.embed_texts(texts)

            batch_size = default(batch_size, text_embeds.shape[0])

        batch_size = default(batch_size, 1)

        codes = default(prompt, torch.empty((batch_size, 0), dtype = torch.long, device = self.device))

        curr_length = codes.shape[-1]

        cache = (None, None)

        for i in tqdm(range(curr_length, max_seq_len)):
            # v1([q1] [q2] [q1] [q2] [q1] [q2]) v2([eos| q1] [q2] [q1] [q2] [q1] [q2]) -> 0 1 2 3 4 5 6 7 8 9 10 11 12 -> v1(F F F F F F) v2(T F F F F F) v3(T F F F F F)

            can_eos = i != 0 and divisible_by(i, self.num_quantizers * 3)  # only allow for eos to be decoded at the end of each face, defined as 3 vertices with D residual VQ codes

            output = self.forward_on_codes(
                codes,
                text_embeds = text_embeds,
                return_loss = False,
                return_cache = cache_kv,
                append_eos = False,
                cond_scale = cond_scale,
                cfg_routed_kwargs = dict(
                    cache = cache
                )
            )

            if cache_kv:
                logits, cache = output

                if cond_scale == 1.:
                    cache = (cache, None)
            else:
                logits = output

            logits = logits[:, -1]

            if not can_eos:
                logits[:, -1] = -torch.finfo(logits.dtype).max

            filtered_logits = filter_logits_fn(logits, **filter_kwargs)

            if temperature == 0.:
                sample = filtered_logits.argmax(dim = -1)
            else:
                probs = F.softmax(filtered_logits / temperature, dim = -1)
                sample = torch.multinomial(probs, 1)

            codes, _ = pack([codes, sample], 'b *')

            # check for all rows to have [eos] to terminate

            is_eos_codes = (codes == self.eos_token_id)

            if is_eos_codes.any(dim = -1).all():
                break

        # mask out to padding anything after the first eos

        mask = is_eos_codes.float().cumsum(dim = -1) >= 1
        codes = codes.masked_fill(mask, self.pad_id)

        # remove a potential extra token from eos, if breaked early

        code_len = codes.shape[-1]
        round_down_code_len = code_len // self.num_quantizers * self.num_quantizers
        codes = codes[:, :round_down_code_len]

        # early return of raw residual quantizer codes

        if return_codes:
            codes = rearrange(codes, 'b (n q) -> b n q', q = self.num_quantizers)
            return codes

        self.autoencoder.eval()
        face_coords, face_mask = self.autoencoder.decode_from_codes_to_faces(codes)

        if not exists(face_coords_to_file):
            return face_coords, face_mask

        files = [face_coords_to_file(coords[mask]) for coords, mask in zip(face_coords, face_mask)]
        return files

MeshTransformer.forward_on_codes():

    @classifier_free_guidance
    def forward_on_codes(
        self,
        codes = None,
        return_loss = True,
        return_cache = False,
        append_eos = True,
        cache = None,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[Tensor] = None,
        cond_drop_prob = 0.
    ):
        # handle text conditions

        attn_context_kwargs = dict()

        if self.condition_on_text:
            assert exists(texts) ^ exists(text_embeds), '`text` or `text_embeds` must be passed in if `condition_on_text` is set to True'

            if exists(texts):
                text_embeds = self.conditioner.embed_texts(texts)

            if exists(codes):
                assert text_embeds.shape[0] == codes.shape[0], 'batch size of texts or text embeddings is not equal to the batch size of the mesh codes'

            _, maybe_dropped_text_embeds = self.conditioner(
                text_embeds = text_embeds,
                cond_drop_prob = cond_drop_prob
            )

            attn_context_kwargs = dict(
                context = maybe_dropped_text_embeds.embed,
                context_mask = maybe_dropped_text_embeds.mask
            )

        # take care of codes that may be flattened

        if codes.ndim > 2:
            codes = rearrange(codes, 'b ... -> b (...)')

        # get some variable

        batch, seq_len, device = *codes.shape, codes.device

        assert seq_len <= self.max_seq_len, f'received codes of length {seq_len} but needs to be less than or equal to set max_seq_len {self.max_seq_len}'

        # auto append eos token

        if append_eos:
            assert exists(codes)

            code_lens = ((codes == self.pad_id).cumsum(dim = -1) == 0).sum(dim = -1)

            codes = F.pad(codes, (0, 1), value = 0)

            batch_arange = torch.arange(batch, device = device)

            batch_arange = rearrange(batch_arange, '... -> ... 1')
            code_lens = rearrange(code_lens, '... -> ... 1')

            codes[batch_arange, code_lens] = self.eos_token_id

        # if returning loss, save the labels for cross entropy

        if return_loss:
            assert seq_len > 0
            codes, labels = codes[:, :-1], codes

        # token embed (each residual VQ id)

        codes = codes.masked_fill(codes == self.pad_id, 0)
        codes = self.token_embed(codes)

        # codebook embed + absolute positions

        seq_arange = torch.arange(codes.shape[-2], device = device)

        codes = codes + self.abs_pos_emb(seq_arange)

        # embedding for quantizer level

        code_len = codes.shape[1]

        level_embed = repeat(self.quantize_level_embed, 'q d -> (r q) d', r = ceil(code_len / self.num_quantizers))
        codes = codes + level_embed[:code_len]

        # embedding for each vertex

        vertex_embed = repeat(self.vertex_embed, 'nv d -> (r nv q) d', r = ceil(code_len / (3 * self.num_quantizers)), q = self.num_quantizers)
        codes = codes + vertex_embed[:code_len]

        # create a token per face, by summarizing the 3 vertices
        # this is similar in design to the RQ transformer from Lee et al. https://arxiv.org/abs/2203.01941

        num_tokens_per_face = self.num_quantizers * 3

        curr_vertex_pos = code_len % num_tokens_per_face # the current intra-face vertex-code position id, needed for caching at the fine decoder stage

        code_len_is_multiple_of_face = divisible_by(code_len, num_tokens_per_face)

        next_multiple_code_len = ceil(code_len / num_tokens_per_face) * num_tokens_per_face

        codes = pad_to_length(codes, next_multiple_code_len, dim = -2)

        # grouped codes will be used for the second stage

        grouped_codes = rearrange(codes, 'b (nf n) d -> b nf n d', n = num_tokens_per_face)

        # create the coarse tokens for the first attention network

        face_codes = grouped_codes if code_len_is_multiple_of_face else grouped_codes[:, :-1]

        face_codes = rearrange(face_codes, 'b nf n d -> b nf (n d)')
        face_codes = self.to_face_tokens(face_codes)

        face_codes_len = face_codes.shape[-2]

        # cache logic

        (
            cached_attended_face_codes,
            coarse_cache,
            fine_cache,
            coarse_gateloop_cache,
            fine_gateloop_cache
        ) = cache if exists(cache) else ((None,) * 5)

        if exists(cache):
            cached_face_codes_len = cached_attended_face_codes.shape[-2]
            need_call_first_transformer = face_codes_len > cached_face_codes_len
        else:
            need_call_first_transformer = True

        should_cache_fine = not divisible_by(curr_vertex_pos + 1, num_tokens_per_face)

        # attention on face codes (coarse)

        if need_call_first_transformer:
            if exists(self.coarse_gateloop_block):
                face_codes, coarse_gateloop_cache = self.coarse_gateloop_block(face_codes, cache = coarse_gateloop_cache)

            attended_face_codes, coarse_cache = self.decoder(
                face_codes,
                cache = coarse_cache,
                return_hiddens = True,
                **attn_context_kwargs
            )

            attended_face_codes = safe_cat((cached_attended_face_codes, attended_face_codes), dim = -2)
        else:
            attended_face_codes = cached_attended_face_codes

        # maybe project from coarse to fine dimension for hierarchical transformers

        attended_face_codes = self.maybe_project_coarse_to_fine(attended_face_codes)

        # auto prepend sos token

        sos = repeat(self.sos_token, 'd -> b d', b = batch)

        attended_face_codes_with_sos, _ = pack([sos, attended_face_codes], 'b * d')

        grouped_codes = pad_to_length(grouped_codes, attended_face_codes_with_sos.shape[-2], dim = 1)
        fine_vertex_codes, _ = pack([attended_face_codes_with_sos, grouped_codes], 'b n * d')

        fine_vertex_codes = fine_vertex_codes[..., :-1, :]

        # gateloop layers

        if exists(self.fine_gateloop_block):
            fine_vertex_codes = rearrange(fine_vertex_codes, 'b nf n d -> b (nf n) d')
            orig_length = fine_vertex_codes.shape[-2]
            fine_vertex_codes = fine_vertex_codes[:, :(code_len + 1)]

            fine_vertex_codes, fine_gateloop_cache = self.fine_gateloop_block(fine_vertex_codes, cache = fine_gateloop_cache)

            fine_vertex_codes = pad_to_length(fine_vertex_codes, orig_length, dim = -2)
            fine_vertex_codes = rearrange(fine_vertex_codes, 'b (nf n) d -> b nf n d', n = num_tokens_per_face)

        # fine attention - 2nd stage

        if exists(cache):
            fine_vertex_codes = fine_vertex_codes[:, -1:]

            if exists(fine_cache):
                for attn_intermediate in fine_cache.attn_intermediates:
                    ck, cv = attn_intermediate.cached_kv
                    ck, cv = map(lambda t: rearrange(t, '(b nf) ... -> b nf ...', b = batch), (ck, cv))
                    ck, cv = map(lambda t: t[:, -1, :, :curr_vertex_pos], (ck, cv))
                    attn_intermediate.cached_kv = (ck, cv)

        one_face = fine_vertex_codes.shape[1] == 1

        fine_vertex_codes = rearrange(fine_vertex_codes, 'b nf n d -> (b nf) n d')

        if one_face:
            fine_vertex_codes = fine_vertex_codes[:, :(curr_vertex_pos + 1)]

        attended_vertex_codes, fine_cache = self.fine_decoder(
            fine_vertex_codes,
            cache = fine_cache,
            return_hiddens = True
        )

        if not should_cache_fine:
            fine_cache = None

        if not one_face:
            # reconstitute original sequence

            embed = rearrange(attended_vertex_codes, '(b nf) n d -> b (nf n) d', b = batch)
            embed = embed[:, :(code_len + 1)]
        else:
            embed = attended_vertex_codes

        # logits

        logits = self.to_logits(embed)

        if not return_loss:
            if not return_cache:
                return logits

            next_cache = (
                attended_face_codes,
                coarse_cache,
                fine_cache,
                coarse_gateloop_cache,
                fine_gateloop_cache
            )

            return logits, next_cache

        # loss

        ce_loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),
            labels,
            ignore_index = self.pad_id
        )

        return ce_loss

lucidrains commented 7 months ago

@fighting-Zhang it works fine for me

can you update to 1.0 and retry?

fighting-Zhang commented 7 months ago

My problem mainly occurs when an empty tensor is passed into self.decoder.
In the code below, face_codes.size() = [1, 0, 512]. What are your dimensions?

        if need_call_first_transformer:
            if exists(self.coarse_gateloop_block):
                face_codes, coarse_gateloop_cache = self.coarse_gateloop_block(face_codes, cache = coarse_gateloop_cache)

            attended_face_codes, coarse_cache = self.decoder(
                face_codes,
                cache = coarse_cache,
                return_hiddens = True,
                **attn_context_kwargs
            )

            attended_face_codes = safe_cat((cached_attended_face_codes, attended_face_codes), dim = -2)
        else:
            attended_face_codes = cached_attended_face_codes

lucidrains commented 7 months ago

@fighting-Zhang what version of x-transformers are you using?

fighting-Zhang commented 7 months ago

1.27.3

lucidrains commented 7 months ago

@fighting-Zhang does the very first example in the readme run for you?

fighting-Zhang commented 7 months ago

wow, amazing! The very first example works fine.

fighting-Zhang commented 7 months ago

But when I put the data and model on CUDA, I got the above error.

lucidrains commented 7 months ago

> wow, amazing! The very first example works fine.

well, the very first example is also promptless. so i don't think that's the issue

fighting-Zhang commented 7 months ago

Thank you for your patient answer. I will continue to look for ways to solve the cuda error.

MarcusLoppe commented 7 months ago

@fighting-Zhang Is the Autoencoder also on the GPU?

fighting-Zhang commented 7 months ago

@MarcusLoppe yes.

generate_sample_v0.py:

import torch

from meshgpt_pytorch import (
    MeshAutoencoder,
    MeshTransformer
)

# autoencoder

autoencoder = MeshAutoencoder(
    num_discrete_coors = 128
).cuda()

# mock inputs

vertices = torch.randn((2, 121, 3)).cuda()            # (batch, num vertices, coor (3))
faces = torch.randint(0, 121, (2, 64, 3)).cuda()      # (batch, num faces, vertices (3))

# make sure faces are padded with `-1` for variable lengthed meshes

# forward in the faces

loss = autoencoder(
    vertices = vertices,
    faces = faces
)

loss.backward()

# after much training...
# you can pass in the raw face data above to train a transformer to model this sequence of face vertices

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    max_seq_len = 768
).cuda()

loss = transformer(
    vertices = vertices,
    faces = faces
)

loss.backward()

# after much training of transformer, you can now sample novel 3d assets

faces_coordinates, face_mask = transformer.generate()

Error message:

Traceback (most recent call last):
  File "/code/mesh-auto/generate_sample_v0.py", line 48, in <module>
    faces_coordinates, face_mask = transformer.generate()
  File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/x_transformers/autoregressive_wrapper.py", line 27, in inner
    out = fn(self, *args, **kwargs)
  File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "<@beartype(meshgpt_pytorch.meshgpt_pytorch.MeshTransformer.generate) at 0x7f05a2f83760>", line 170, in generate
  File "/code/mesh-auto/meshgpt_pytorch/meshgpt_pytorch.py", line 1186, in generate
    output = self.forward_on_codes(
  File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py", line 153, in inner
    outputs = fn_maybe_with_text(self, *args, **fn_kwargs, **kwargs_without_cond_dropout)
  File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py", line 131, in fn_maybe_with_text
    return fn(self, *args, **kwargs)
  File "/code/mesh-auto/meshgpt_pytorch/meshgpt_pytorch.py", line 1413, in forward_on_codes
    attended_face_codes, coarse_cache = self.decoder(
  File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/x_transformers/x_transformers.py", line 1336, in forward
    out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True)
  File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/x_transformers/x_transformers.py", line 944, in forward
    out, intermediates = self.attend(
  File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/x_transformers/attend.py", line 274, in forward
    return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
  File "/opt/miniconda/envs/meshtransformer/lib/python3.10/site-packages/x_transformers/attend.py", line 214, in flash_attn
    out = F.scaled_dot_product_attention(
RuntimeError: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

MarcusLoppe commented 7 months ago

> My problem mainly occurs when entering empty into self.decoder. In the code below, face_codes.size()= [1,0,512]. What are your dimensions?

@fighting-Zhang I've checked and I also get the same shape, but I ran the example you provided and it works for me.

I'm guessing you'll need to reinstall meshgpt with all its dependencies, or the GPU you're using isn't compatible. Try the reinstall first; if that doesn't help, can you say which GPU you are using, along with the PyTorch and CUDA versions?
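To gather the details asked for above in one go, something like the following could be pasted into the issue. It is only a sketch: the package list is taken from the names that appear in this thread, the helper names `collect_versions` and `collect_gpu_info` are hypothetical, and the torch section is skipped if torch cannot be imported.

```python
import platform
from importlib import metadata

# PyPI distribution names of the packages that show up in this thread
PACKAGES = (
    "torch",
    "meshgpt-pytorch",
    "x-transformers",
    "classifier-free-guidance-pytorch",
)

def collect_versions(packages=PACKAGES):
    """Return installed versions, or 'not installed' for missing packages."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
    return report

def collect_gpu_info():
    """Return CUDA build and GPU details if torch is importable."""
    try:
        import torch
    except ImportError:
        return {"cuda_available": False}
    info = {
        "cuda_available": torch.cuda.is_available(),
        "cuda_build": torch.version.cuda,  # None on CPU-only builds
    }
    if info["cuda_available"]:
        info["gpu"] = torch.cuda.get_device_name(0)
        info["compute_capability"] = torch.cuda.get_device_capability(0)
    return info

if __name__ == "__main__":
    for key, value in {**collect_versions(), **collect_gpu_info()}.items():
        print(f"{key}: {value}")
```

Comparing this output between a machine where the README example works and one where it fails should narrow down whether the difference is in the library versions or in the GPU and CUDA build.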