NVIDIA / NeMo

A scalable generative AI framework built for researchers and developers working on Large Language Models, Multimodal, and Speech AI (Automatic Speech Recognition and Text-to-Speech)
https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html
Apache License 2.0

When "write_predictions_to_file" is true, generate will fail. #9170

Closed. gaojingwei closed this issue 2 months ago.

gaojingwei commented 4 months ago

I am using the image nvcr.io/nvidia/nemo:24.03.framework and successfully ran /NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py. However, when I run /NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_generate.py, it succeeds with model.data.test_ds.write_predictions_to_file=False and fails with model.data.test_ds.write_predictions_to_file=True. This is the error log:
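To make the comparison concrete, here is a minimal sketch of the two runs, assuming the script path that appears in the traceback and a plain `python` launcher. The report does not say how the 2-node job was actually launched (torchrun, Slurm, etc.), and `build_generate_cmd` is a hypothetical helper. Only the `write_predictions_to_file` override differs between the passing and the failing run.

```python
import shlex

# Script path as it appears in the reported traceback.
GENERATE_SCRIPT = "/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_generate.py"


def build_generate_cmd(restore_from_path: str, write_predictions: bool) -> list[str]:
    """Hypothetical helper: build argv for one generate run using Hydra overrides."""
    return [
        "python",
        GENERATE_SCRIPT,
        f"model.restore_from_path={restore_from_path}",
        # Emit lowercase true/false, which Hydra parses as booleans.
        f"model.data.test_ds.write_predictions_to_file={str(write_predictions).lower()}",
    ]


if __name__ == "__main__":
    ckpt = "/workspace/results/checkpoints/megatron_gpt_peft_none_tuning.nemo"
    # Reported behavior: False runs to completion, True crashes in generate().
    for flag in (False, True):
        print(shlex.join(build_generate_cmd(ckpt, flag)))
```

The multi-node launcher prefix would need to be prepended to this argv; it is left out because the report does not specify it.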


Traceback (most recent call last):
  File "/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_generate.py", line 170, in <module>
    main()
  File "/opt/NeMo/nemo/core/config/hydra_runner.py", line 129, in wrapper
    _run_hydra(
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 389, in _run_hydra
    _run_app(
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 452, in _run_app
    run_and_report(
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 216, in run_and_report
    raise ex
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 213, in run_and_report
    return func()
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 453, in <lambda>
    lambda: hydra.run(
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_generate.py", line 164, in main
    trainer.test(model)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 754, in test
    return call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 794, in _test_impl
    results = self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1026, in _run_stage
    return self._evaluation_loop.run()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 425, in test_step
    return self.lightning_module.test_step(*args, **kwargs)
  File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py", line 431, in test_step
    return self.inference_step(dataloader_iter, 'test')
  File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py", line 438, in inference_step
    outputs = self.inference_step_validation_call(batch, batch_idx, data_cfg, dataloader_idx)
  File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py", line 468, in inference_step_validation_call
    output = self.predict_step(batch, batch_idx, dataloader_idx)
  File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py", line 680, in predict_step
    response = generate(self, **inference_config)
  File "/opt/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 649, in generate
    output = synced_generate(
  File "/opt/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 512, in synced_generate
    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
  File "/opt/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 801, in sample_sequence_batch
    output = inference_strategy.forward_step(batch, tensor_shape)
  File "/opt/NeMo/nemo/collections/nlp/modules/common/text_generation_strategy.py", line 67, in forward_step
    output_tensor = fwd_bwd_function(
  File "/opt/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 1211, in forward_backward_pipelining_without_interleaving
    output_tensor = forward_step(
  File "/opt/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 192, in forward_step
    output_tensor, loss_func = forward_step_func(data_iterator, model)
  File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 1160, in fwd_output_only_func
    output_tensor = model(tokens, position_ids, attention_mask, **extra_arg)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/megatron-lm/megatron/core/models/gpt/gpt_model.py", line 174, in forward
    hidden_states = self.decoder(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/megatron-lm/megatron/core/transformer/transformer_block.py", line 370, in forward
    hidden_states, context = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/megatron-lm/megatron/core/transformer/transformer_layer.py", line 176, in forward
    attention_output_with_bias = self.self_attention(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/megatron-lm/megatron/core/transformer/attention.py", line 254, in forward
    query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
  File "/opt/megatron-lm/megatron/core/transformer/attention.py", line 370, in get_query_key_value_tensors
    mixed_qkv, _ = self.linear_qkv(hidden_states)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/megatron-lm/megatron/core/transformer/custom_layers/transformer_engine.py", line 245, in forward
    out = super().forward(x, is_first_microbatch=_is_first_microbatch)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 417, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/layernorm_linear.py", line 1105, in forward
    out = fwd_fn(*args)
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/layernorm_linear.py", line 127, in forward
    ln_out, mu, rsigma = _apply_normalization(inputmat,
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/_common.py", line 90, in _apply_normalization
    return normalization_func(
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/cpp_extensions/normalization.py", line 179, in rmsnorm_fwd_inf
    return torch.ops.tex_ts.rmsnorm_fwd_inf_ts(
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 825, in __call__
    return self_._op(*args, **(kwargs or {}))
RuntimeError: /opt/TransformerEngine/transformer_engine/common/transformer_engine.cpp:39 in function CheckInputTensor: Assertion failed: t.data.dptr != nullptr. Input x is not allocated!
Error executing job with overrides: ['model.restore_from_path=/workspace/results/checkpoints/megatron_gpt_peft_none_tuning.nemo']

This is my config:

name: megatron_gpt_peft_${model.peft.peft_scheme}_tuning

trainer:
  devices: 8
  accelerator: gpu
  num_nodes: 2
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
  use_distributed_sampler: False
  max_epochs: 9999
  max_steps: 20000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10 # frequency with which training steps are logged 
  val_check_interval: 200 # If an int n > 1, runs validation every n training steps; if a float in 0.0-1.0, runs validation at that fraction of an epoch, e.g. 0.25 runs it every quarter epoch
  gradient_clip_val: 1.0

exp_manager:
  explicit_log_dir: null
  exp_dir: null
  name: ${name}
  create_wandb_logger: False
  wandb_logger_kwargs:
    project: null
    name: null
  resume_if_exists: True
  resume_ignore_no_checkpoint: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: validation_${model.data.test_ds.metric.name}
    save_top_k: 1
    mode: max
    save_nemo_on_train_end: True
    filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}'
    model_parallel_size: ${model.tensor_model_parallel_size}
    always_save_nemo: True
model:
  seed: 1234
  tensor_model_parallel_size: 8 # intra-layer model parallelism
  pipeline_model_parallel_size: 2 # inter-layer model parallelism

  global_batch_size: 16
  micro_batch_size: 1
  restore_from_path: /workspace/results/checkpoints/megatron_gpt_peft_none_tuning.nemo # Path to an existing .nemo model you wish to add new tasks to or run inference with
  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training.
  sync_batch_comm: False
  megatron_amp_O2: False

  ## Sequence Parallelism
  # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially
  # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
  sequence_parallel: False

  ## Activation Checkpoint
  activations_checkpoint_granularity: null # 'selective' or 'full'
  activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
  # 'uniform' divides the total number of transformer layers and checkpoints the input activation
  # of each chunk at the specified granularity
  # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
  activations_checkpoint_num_layers: null # not used with 'selective'
  activations_checkpoint_layers_per_pipeline: null
  answer_only_loss: True
  gradient_as_bucket_view: False

  hidden_dropout: 0.0
  attention_dropout: 0.0
  ffn_dropout: 0.0

  peft:
    peft_scheme: "none"  # can be adapter, ia3, ptuning, lora, or none
    restore_from_path: null
    restore_from_ckpt:
      checkpoint_dir: null
      checkpoint_name: null

    # Used for adapter peft training
    adapter_tuning:
      type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter'
      adapter_dim: 32
      adapter_dropout: 0.0
      norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used.
      column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal
      row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal
      norm_type: 'mixedfusedlayernorm' # IGNORED if linear_adapter is used, options are ['layernorm', 'mixedfusedlayernorm']
      layer_selection: null  # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers
      weight_tying: False
      position_embedding_strategy: null # used only when weight_tying is True

    lora_tuning:
      target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2)
      adapter_dim: 32
      adapter_dropout: 0.0
      column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal
      row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal
      layer_selection:  null  # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers
      weight_tying: False
      position_embedding_strategy: null # used only when weight_tying is True

    # Used for p-tuning peft training
    p_tuning:
      virtual_tokens: 10  # The number of virtual tokens the prompt encoder should add at the start of the sequence
      bottleneck_dim: 1024  # the size of the prompt encoder mlp bottleneck
      embedding_dim: 1024  # the size of the prompt encoder embeddings
      init_std: 0.023

    ia3_tuning:
      layer_selection: null  # selects in which layers to add ia3 adapters, e.g. [1,12] will add ia3 adapters to layer 1 (lowest) and 12. null will apply adapters to all layers

  data:
    test_ds:
      file_names: ['/workspace/databricks-dolly-15k/test.jsonl'] # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds.
      names: ['dolly-15k_test'] # Names of the corresponding datasets used to log metrics.
      global_batch_size: 16
      micro_batch_size: 1
      shuffle: False
      num_workers: 0
      pin_memory: True
      max_seq_length: 2048
      min_seq_length: 1
      drop_last: False
      context_key: 'input'
      label_key: ${data.train_ds.label_key}
      add_eos: ${data.train_ds.add_eos}
      add_sep: ${data.train_ds.add_sep}
      add_bos: ${data.train_ds.add_bos}
      write_predictions_to_file: True
      output_file_path_prefix: /workspace/results/sft_results # Prefix of the file to write predictions to.
      truncation_field: ${data.train_ds.truncation_field} # Options: keys in prompt_template
      index_mapping_dir: null # Path to a directory to write index mapping files.
      prompt_template: ${data.train_ds.prompt_template}
      tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics
      truncation_method: 'right' # Truncation from which position, Options: ['left', 'right']

      metric:
        name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
        average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
        num_classes: null

inference:
  greedy: True # If True, use greedy decoding; otherwise use sampling
  top_k: 0  # The number of highest probability vocabulary tokens to keep for top-k-filtering.
  top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
  temperature: 1.0 # sampling temperature
  all_probs: False  # whether return the log prob for all the tokens in vocab
  repetition_penalty: 1.2  # The parameter for repetition penalty. 1.0 means no penalty.
  min_tokens_to_generate: 0  # The minimum length of the sequence to be generated.
  compute_logprob: False  # a flag used to compute logprob of all the input text, a very special case of running inference, default False
  outfile_path: /workspace/results/output.txt
  compute_attention_mask: True

# server-related configs
server: False  # whether launch the API server
port: 5555 # the port number for the inference server
web_server: False # whether launch the web inference server
share: True  # whether create a public URL
username: test # user name for web client
password: test2  # password for web client
web_port: 9889 # the port number of the web server 1058
chat: False # use the chat interface
chatbot_config:
  value: False   # whether to inject the value attributes
  attributes:
    - name: Quality
      min: 0
      max: 4
      key: quality
      type: int
      default: 4
    - name: Toxicity
      min: 0
      max: 4
      key: toxcity
      type: int
      default: 0
    - name: Humor
      min: 0
      max: 4
      key: humor
      type: int
      default: 0
    - name: Creativity
      min: 0
      max: 4
      key: creativity
      type: int
      default: 0
    - name: Violence
      min: 0
      max: 4
      key: violence
      type: int
      default: 0
    - name: Helpfulness
      min: 0
      max: 4
      key: helpfulness
      type: int
      default: 4
    - name: Not_Appropriate
      min: 0
      max: 4
      key: not_appropriate
      type: int
      default: 0
    - name: Language
      choices: ['ar', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en', 'eo', 'es', 'eu', 'fa', 'fi', 'fr', 'gl', 'he', 'hu', 'id', 'it', 'ja', 'ko', 'nb', 'nl', 'pl', 'pt', 'ro', 'ru', 'sk', 'sv', 'th', 'tr', 'uk', 'vi', 'zh']
      key: lang
      type: list
      default: en

  user: User
  assistant: Assistant
  system: "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
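
For reference, the parallel layout implied by this config can be worked out by hand. This is a minimal sketch, assuming the usual Megatron relationships data_parallel_size = world_size / (TP × PP) and global_batch_size = micro_batch_size × data_parallel_size × number of micro-batches (the last factor playing the role of accumulate_grad_batches in the `max_steps` comment). The function names are hypothetical. With 2 nodes × 8 GPUs, TP=8 and PP=2 use all 16 GPUs, so there is a single data-parallel replica and each global batch of 16 runs as 16 micro-batches of size 1 through the 2-stage pipeline.

```python
def parallel_layout(devices: int, num_nodes: int, tp: int, pp: int,
                    global_batch_size: int, micro_batch_size: int) -> dict:
    """Hypothetical helper: derive data-parallel size and micro-batches per global batch."""
    world_size = devices * num_nodes      # total GPUs in the job
    model_parallel = tp * pp              # GPUs holding one model replica
    if world_size % model_parallel:
        raise ValueError("devices * num_nodes must be divisible by TP * PP")
    dp = world_size // model_parallel     # number of model replicas
    samples_per_micro_step = micro_batch_size * dp
    if global_batch_size % samples_per_micro_step:
        raise ValueError("global_batch_size must be divisible by micro_batch_size * DP")
    return {
        "world_size": world_size,
        "data_parallel_size": dp,
        "num_micro_batches": global_batch_size // samples_per_micro_step,
    }


def consumed_samples(global_step: int, micro_batch_size: int,
                     data_parallel_size: int, accumulate_grad_batches: int) -> int:
    """Formula quoted in the max_steps comment of the config."""
    return global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches


if __name__ == "__main__":
    # Values taken from trainer.* and model.* in the config above.
    layout = parallel_layout(devices=8, num_nodes=2, tp=8, pp=2,
                             global_batch_size=16, micro_batch_size=1)
    print(layout)  # {'world_size': 16, 'data_parallel_size': 1, 'num_micro_batches': 16}
    print(consumed_samples(10, 1, layout["data_parallel_size"], layout["num_micro_batches"]))  # 160
```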

aditya-malte commented 4 months ago

Hi, do you also hit this issue in megatron_gpt_finetuning.py when write_predictions_to_file is set to True? I'm getting the same error that way.

ArthurJiang commented 4 months ago

I also hit this issue, along with a "NumPy array is not writable" message, when write_predictions_to_file is True.

eagle705 commented 3 months ago

I encountered the same issue with the 24.05 image.

hengruizhang98 commented 3 months ago

I'm seeing the same issue.

A-baoYang commented 3 months ago

I'm seeing the same issue with the 24.01 image.

dimapihtar commented 2 months ago

I'm not seeing this issue with our latest dev container, so the fix should already be in main. It will be included in the upcoming 24.07 release container.
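Based on the images reported in this thread (24.01, 24.03 and 24.05 affected) and the maintainer's expectation that the fix ships in 24.07, here is a minimal sketch for flagging an image that predates that release. `image_release` and `predates_expected_fix` are hypothetical helpers, and they assume the tag begins with a YY.MM release number as in `nvcr.io/nvidia/nemo:24.03.framework`; dev container tags would need to be checked separately.

```python
import re


def image_release(image: str) -> tuple[int, int]:
    """Parse the leading YY.MM release from an image reference's tag."""
    tag = image.rsplit(":", 1)[-1]
    match = re.match(r"(\d{2})\.(\d{2})", tag)
    if match is None:
        raise ValueError(f"tag {tag!r} does not start with a YY.MM release")
    return int(match.group(1)), int(match.group(2))


def predates_expected_fix(image: str, fixed_in: tuple[int, int] = (24, 7)) -> bool:
    """True if the image release is older than the release expected to contain the fix."""
    return image_release(image) < fixed_in


if __name__ == "__main__":
    # Image from the original report.
    print(predates_expected_fix("nvcr.io/nvidia/nemo:24.03.framework"))  # True
```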