PiotrNawrot / nanoT5

Fast & Simple repository for pre-training and fine-tuning T5-style models

self-defined loss function failed to work (torch._dynamo.exc.InternalTorchDynamoError: ln_encoder) #24

Closed · QinengWang-Aiden closed this 1 year ago

QinengWang-Aiden commented 1 year ago

I'm trying to add my own loss function that uses the encoder's hidden states, and I added a new linear layer, similar to your self.lm_head layer, to obtain the corresponding logits. However, training fails every time, and it seems I am not using the linear layer correctly, but I don't know why. Here is the modified part of my MyT5 module:

class MyT5(nn.Module):
    def __init__(self, config: T5Config):
        ...
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.ln_encode = nn.Linear(config.d_model, 3, bias=False)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        encoder_outputs = None,
        myencode_attention_mask: Optional[torch.BoolTensor] = None,
        seq_order: Optional[torch.LongTensor] = None,
    ) -> Seq2SeqLMOutput:
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )

        hidden_states = encoder_outputs.hidden_states
        myencode_attention_mask_extended = myencode_attention_mask.unsqueeze(-1).expand_as(encoder_outputs[0])
        myencode_hidden_states = encoder_outputs[0][myencode_attention_mask_extended].view(-1, self.model_dim)
        seq_loss_ft = CrossEntropyLoss(ignore_index=-100)
        ln_encode_logits = self.ln_encoder(myencode_hidden_states)
        seq_loss = seq_loss_ft(ln_encode_logits.view(-1, ln_encode_logits.size(-1)), seq_order[seq_order != -100].view(-1))
        ...
        loss += seq_loss
        return ...
    def _init_weights(self, module):
        factor = self.config.initializer_factor  # Used for testing weights initialization
        ...
        elif isinstance(module, (MyT5)):
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
                print("lm initialized")
            if hasattr(module, "ln_encode"):
                module.ln_encode.weight.data.normal_(mean=0.0, std=factor * 1.0)
            ...

And here is the error message:

Error executing job with overrides: []
Traceback (most recent call last):
  File "/data2/usr/projects/nanoT5/nanoT5/main.py", line 85, in main
    train(model, train_dataloader, test_dataloader, accelerator,
  File "/data2/usr/projects/nanoT5/nanoT5/utils/train_utils.py", line 237, in train
    loss, stats = forward(model, batch)
  File "/data2/usr/projects/nanoT5/nanoT5/utils/train_utils.py", line 90, in forward
    outputs = model(**batch)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 333, in _fn
    return fn(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1521, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1357, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/accelerate/utils/operations.py", line 581, in forward
    return model_forward(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/accelerate/utils/operations.py", line 569, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/data2/usr/projects/nanoT5/nanoT5/utils/t5_model.py", line 477, in forward
    myencode_hidden_states = encoder_outputs[0][myencode_attention_mask_extended].view(-1, self.model_dim)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return hijacked_callback(frame, cache_size, hooks, frame_state)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 637, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 371, in _convert_frame_assert
    return _compile(
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 584, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(e.__traceback__) from None
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 567, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 181, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 466, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 433, in transform
    tracer.run()
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2071, in run
    super().run()
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1191, in LOAD_ATTR
    result = BuiltinVariable(getattr).call_function(
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 608, in call_function
    result = handler(tx, *args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1074, in call_getattr
    return obj.var_getattr(tx, name).add_options(options)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 192, in var_getattr
    subobj = inspect.getattr_static(base, name)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/inspect.py", line 1769, in getattr_static
    raise AttributeError(attr)
torch._dynamo.exc.InternalTorchDynamoError: ln_encoder

from user code:
   File "/data2/usr/projects/nanoT5/nanoT5/utils/t5_model.py", line 479, in <resume in forward>
    ln_encode_logits = self.ln_encoder(myencode_hidden_states)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Looking forward to your assistance :)

PiotrNawrot commented 1 year ago

Does it work without PyTorch compile?

QinengWang-Aiden commented 1 year ago

Firstly, I'd like to apologize for providing a misleading error message earlier. The error in the code above was caused by a typo: I defined the layer as self.ln_encode in __init__ but called it as self.ln_encoder in forward (a minimal sketch of the fix is at the end of this comment). So the actual error is not the one above; instead, it is the following:

ERROR RUNNING GUARDS __init__ <string>:2
lambda L, **___kwargs_ignored:
  ___guarded_code.valid and
  hasattr(L['loss'], '_dynamo_dynamic_indices') == False and
  ___check_type_id(L['self'], 142654656) and
  hasattr(L['logits'], '_dynamo_dynamic_indices') == False and
  ___check_obj_id(L['self'].loss, 7628576) and
  ___check_obj_id(L['self'].logits, 7628576) and
  ___check_obj_id(L['self'].encoder_outputs, 7628576) and
  hasattr(L['encoder_outputs'].hidden_states, '_dynamo_dynamic_indices') == False and
  hasattr(L['encoder_outputs'].attention_mask, '_dynamo_dynamic_indices') == False and
  ___is_grad_enabled() and
  not ___are_deterministic_algorithms_enabled() and
  ___is_torch_function_enabled() and
  utils_device.CURRENT_DEVICE == None and
  ___check_tensors(L['loss'], L['logits'], L['encoder_outputs'].hidden_states, L['encoder_outputs'].attention_mask)
Error executing job with overrides: []
Traceback (most recent call last):
  File "/data2/usr/projects/nanoT5/nanoT5/main.py", line 85, in main
    train(model, train_dataloader, test_dataloader, accelerator,
  File "/data2/usr/projects/nanoT5/nanoT5/utils/train_utils.py", line 237, in train
    loss, stats = forward(model, batch)
  File "/data2/usr/projects/nanoT5/nanoT5/utils/train_utils.py", line 90, in forward
    outputs = model(**batch)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 333, in _fn
    return fn(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/accelerate/utils/operations.py", line 581, in forward
    return model_forward(*args, **kwargs)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/accelerate/utils/operations.py", line 569, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/accelerate/utils/operations.py", line 569, in <resume in __call__>
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/accelerate/utils/operations.py", line 548, in convert_to_fp32
    return recursively_apply(_convert_to_fp32, tensor, test_type=_is_fp16_bf16_tensor)
  File "/home/usr/miniconda3/envs/nanoT5/lib/python3.10/site-packages/accelerate/utils/operations.py", line 119, in recursively_apply
    return type(data)(
  File "<string>", line 21, in guard
AttributeError: 'NoneType' object has no attribute 'hidden_states'

Secondly, I just attempted to run the code with torch.compile=False and found that it runs successfully using both the python -m nanoT5.main and accelerate launch -m nanoT5.main commands (at least for the first 100 steps).

Since I haven't delved deeply into the mechanics of torch.compile, I would like to inquire about the possible reasons for the previous error message. Thank you in advance :)
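For reference, here is a minimal, self-contained sketch of the naming fix (the extra head is defined and called under one and the same attribute name); ToyModel and the dimensions are placeholders for illustration, not the actual nanoT5 classes:

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self, d_model: int = 16, n_classes: int = 3):
        super().__init__()
        # Define the extra classification head once, under a single name ...
        self.ln_encode = nn.Linear(d_model, n_classes, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # ... and call it under exactly the same attribute name.
        return self.ln_encode(hidden_states)

model = ToyModel()
print(model(torch.randn(4, 16)).shape)  # torch.Size([4, 3])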

PiotrNawrot commented 1 year ago

It's good that it works without torch.compile. Unfortunately, I don't know the internals of torch.compile well enough to tell you what the issue is. You could try to reproduce the error with a small sketch model and then raise an issue on the official PyTorch repo.
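For what it's worth, a rough starting point for such a repro might look like the sketch below. It only mirrors the boolean-mask selection from your forward under torch.compile; ToyEncoder and all dimensions are made up for illustration, and whether it actually triggers the guard error will depend on the PyTorch version:

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    def __init__(self, d_model: int = 16, n_classes: int = 3):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.head = nn.Linear(d_model, n_classes, bias=False)
        self.d_model = d_model

    def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        hidden_states = self.proj(hidden_states)
        # Boolean-mask indexing yields a data-dependent shape, which is a common
        # source of torch.compile graph breaks and guard failures.
        mask_extended = mask.unsqueeze(-1).expand_as(hidden_states)
        selected = hidden_states[mask_extended].view(-1, self.d_model)
        return self.head(selected)

model = torch.compile(ToyEncoder())
x = torch.randn(2, 8, 16)
mask = torch.rand(2, 8) > 0.5
print(model(x, mask).shape)  # runs in eager; check whether compile complains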

QinengWang-Aiden commented 1 year ago

Okay, thanks for your prompt responses!