howard-yen closed this issue 1 year ago
cc @pacman100
Hello @howard-yen, I don't see any issues when using the main branch of Accelerate together with Transformers PR https://github.com/huggingface/transformers/pull/25820, which fixes a bug with efficient model loading when using FSDP.
Command:
cd transformers
export CUDA_VISIBLE_DEVICES="0,1"
torchrun --nnodes 1 --nproc-per-node 2 ./examples/pytorch/language-modeling/run_clm.py --model_name_or_path facebook/opt-350m --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --do_train --do_eval --output_dir output --fsdp "shard_grad_op auto_wrap"
Edit run_clm.py to print the FSDP config:
...
    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        # Data collator will default to DataCollatorWithPadding, so we change it.
        data_collator=default_data_collator,
        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics
        if training_args.do_eval and not is_torch_tpu_available()
        else None,
    )
+   print(f"{trainer.accelerator.state.fsdp_plugin}")
...
Output:
[2023-08-29 12:10:31,536] torch.distributed.run: [WARNING]
[2023-08-29 12:10:31,536] torch.distributed.run: [WARNING] *****************************************
[2023-08-29 12:10:31,536] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
[2023-08-29 12:10:31,536] torch.distributed.run: [WARNING] *****************************************
[2023-08-29 12:10:34,727] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2023-08-29 12:10:34,742] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
08/29/2023 12:10:36 - WARNING - __main__ - Process rank: 1, device: cuda:1, n_gpu: 1distributed training: True, 16-bits training: False
08/29/2023 12:10:36 - WARNING - __main__ - Process rank: 0, device: cuda:0, n_gpu: 1distributed training: True, 16-bits training: False
08/29/2023 12:10:36 - INFO - __main__ - Training/evaluation parameters TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=True,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=no,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[<FSDPOption.SHARD_GRAD_OP: 'shard_grad_op'>, <FSDPOption.AUTO_WRAP: 'auto_wrap'>],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_always_push=False,
hub_model_id=None,
hub_private_repo=False,
hub_strategy=every_save,
hub_token=<HUB_TOKEN>,
ignore_data_skip=False,
include_inputs_for_metrics=False,
jit_mode_eval=False,
label_names=None,
label_smoothing_factor=0.0,
learning_rate=5e-05,
length_column_name=length,
load_best_model_at_end=False,
local_rank=0,
log_level=passive,
log_level_replica=warning,
log_on_each_node=True,
logging_dir=output/runs/Aug29_12-10-36_hf-dgx-01,
logging_first_step=False,
logging_nan_inf_filter=True,
logging_steps=500,
logging_strategy=steps,
lr_scheduler_type=linear,
max_grad_norm=1.0,
max_steps=-1,
metric_for_best_model=None,
mp_parameters=,
no_cuda=False,
num_train_epochs=3.0,
optim=adamw_torch,
optim_args=None,
output_dir=output,
overwrite_output_dir=False,
past_index=-1,
per_device_eval_batch_size=8,
per_device_train_batch_size=8,
prediction_loss_only=False,
push_to_hub=False,
push_to_hub_model_id=None,
push_to_hub_organization=None,
push_to_hub_token=<PUSH_TO_HUB_TOKEN>,
ray_scope=last,
remove_unused_columns=True,
report_to=[],
resume_from_checkpoint=None,
run_name=output,
save_on_each_node=False,
save_safetensors=False,
save_steps=500,
save_strategy=steps,
save_total_limit=None,
seed=42,
sharded_ddp=[],
skip_memory_metrics=True,
tf32=None,
torch_compile=False,
torch_compile_backend=None,
torch_compile_mode=None,
torchdynamo=None,
tpu_metrics_debug=False,
tpu_num_cores=None,
use_cpu=False,
use_ipex=False,
use_legacy_prediction_loop=False,
use_mps_device=False,
warmup_ratio=0.0,
warmup_steps=0,
weight_decay=0.0,
)
08/29/2023 12:10:37 - WARNING - datasets.builder - Found cached dataset wikitext (/raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
08/29/2023 12:10:37 - INFO - datasets.builder - Using custom data configuration wikitext-2-raw-v1-ddf29beda1b1b3d3
08/29/2023 12:10:37 - INFO - datasets.info - Loading Dataset Infos from /raid/sourab/.cache/huggingface/modules/datasets_modules/datasets/wikitext/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126
100%|██████████| 3/3 [00:00<00:00, 1225.69it/s]
08/29/2023 12:10:37 - INFO - datasets.builder - Overwrite dataset info from restored data version if exists.
08/29/2023 12:10:37 - INFO - datasets.info - Loading Dataset info from /raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126
08/29/2023 12:10:37 - WARNING - datasets.builder - Found cached dataset wikitext (/raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
08/29/2023 12:10:37 - INFO - datasets.info - Loading Dataset info from /raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126
100%|██████████| 3/3 [00:00<00:00, 1283.71it/s]
[INFO|configuration_utils.py:715] 2023-08-29 12:10:37,431 >> loading configuration file config.json from cache at /raid/sourab/.cache/huggingface/models--facebook--opt-350m/snapshots/cb32f77e905cccbca1d970436fb0f5e6b58ee3c5/config.json
[INFO|configuration_utils.py:775] 2023-08-29 12:10:37,431 >> Model config OPTConfig {
"_name_or_path": "facebook/opt-350m",
"_remove_final_layer_norm": false,
"activation_dropout": 0.0,
"activation_function": "relu",
"architectures": [
"OPTForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 2,
"do_layer_norm_before": false,
"dropout": 0.1,
"enable_bias": true,
"eos_token_id": 2,
"ffn_dim": 4096,
"hidden_size": 1024,
"init_std": 0.02,
"layer_norm_elementwise_affine": true,
"layerdrop": 0.0,
"max_position_embeddings": 2048,
"model_type": "opt",
"num_attention_heads": 16,
"num_hidden_layers": 24,
"pad_token_id": 1,
"prefix": "</s>",
"torch_dtype": "float16",
"transformers_version": "4.33.0.dev0",
"use_cache": true,
"vocab_size": 50272,
"word_embed_proj_dim": 512
}
[INFO|configuration_utils.py:715] 2023-08-29 12:10:37,546 >> loading configuration file config.json from cache at /raid/sourab/.cache/huggingface/models--facebook--opt-350m/snapshots/cb32f77e905cccbca1d970436fb0f5e6b58ee3c5/config.json
[INFO|configuration_utils.py:775] 2023-08-29 12:10:37,547 >> Model config OPTConfig {
"_name_or_path": "facebook/opt-350m",
"_remove_final_layer_norm": false,
"activation_dropout": 0.0,
"activation_function": "relu",
"architectures": [
"OPTForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 2,
"do_layer_norm_before": false,
"dropout": 0.1,
"enable_bias": true,
"eos_token_id": 2,
"ffn_dim": 4096,
"hidden_size": 1024,
"init_std": 0.02,
"layer_norm_elementwise_affine": true,
"layerdrop": 0.0,
"max_position_embeddings": 2048,
"model_type": "opt",
"num_attention_heads": 16,
"num_hidden_layers": 24,
"pad_token_id": 1,
"prefix": "</s>",
"torch_dtype": "float16",
"transformers_version": "4.33.0.dev0",
"use_cache": true,
"vocab_size": 50272,
"word_embed_proj_dim": 512
}
[INFO|tokenization_utils_base.py:1852] 2023-08-29 12:10:37,556 >> loading file vocab.json from cache at /raid/sourab/.cache/huggingface/models--facebook--opt-350m/snapshots/cb32f77e905cccbca1d970436fb0f5e6b58ee3c5/vocab.json
[INFO|tokenization_utils_base.py:1852] 2023-08-29 12:10:37,556 >> loading file merges.txt from cache at /raid/sourab/.cache/huggingface/models--facebook--opt-350m/snapshots/cb32f77e905cccbca1d970436fb0f5e6b58ee3c5/merges.txt
[INFO|tokenization_utils_base.py:1852] 2023-08-29 12:10:37,556 >> loading file tokenizer.json from cache at None
[INFO|tokenization_utils_base.py:1852] 2023-08-29 12:10:37,556 >> loading file added_tokens.json from cache at None
[INFO|tokenization_utils_base.py:1852] 2023-08-29 12:10:37,556 >> loading file special_tokens_map.json from cache at /raid/sourab/.cache/huggingface/models--facebook--opt-350m/snapshots/cb32f77e905cccbca1d970436fb0f5e6b58ee3c5/special_tokens_map.json
[INFO|tokenization_utils_base.py:1852] 2023-08-29 12:10:37,556 >> loading file tokenizer_config.json from cache at /raid/sourab/.cache/huggingface/models--facebook--opt-350m/snapshots/cb32f77e905cccbca1d970436fb0f5e6b58ee3c5/tokenizer_config.json
[INFO|configuration_utils.py:715] 2023-08-29 12:10:37,556 >> loading configuration file config.json from cache at /raid/sourab/.cache/huggingface/models--facebook--opt-350m/snapshots/cb32f77e905cccbca1d970436fb0f5e6b58ee3c5/config.json
[INFO|configuration_utils.py:775] 2023-08-29 12:10:37,557 >> Model config OPTConfig {
"_name_or_path": "facebook/opt-350m",
"_remove_final_layer_norm": false,
"activation_dropout": 0.0,
"activation_function": "relu",
"architectures": [
"OPTForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 2,
"do_layer_norm_before": false,
"dropout": 0.1,
"enable_bias": true,
"eos_token_id": 2,
"ffn_dim": 4096,
"hidden_size": 1024,
"init_std": 0.02,
"layer_norm_elementwise_affine": true,
"layerdrop": 0.0,
"max_position_embeddings": 2048,
"model_type": "opt",
"num_attention_heads": 16,
"num_hidden_layers": 24,
"pad_token_id": 1,
"prefix": "</s>",
"torch_dtype": "float16",
"transformers_version": "4.33.0.dev0",
"use_cache": true,
"vocab_size": 50272,
"word_embed_proj_dim": 512
}
[INFO|configuration_utils.py:715] 2023-08-29 12:10:37,810 >> loading configuration file config.json from cache at /raid/sourab/.cache/huggingface/models--facebook--opt-350m/snapshots/cb32f77e905cccbca1d970436fb0f5e6b58ee3c5/config.json
[INFO|configuration_utils.py:775] 2023-08-29 12:10:37,811 >> Model config OPTConfig {
"_name_or_path": "facebook/opt-350m",
"_remove_final_layer_norm": false,
"activation_dropout": 0.0,
"activation_function": "relu",
"architectures": [
"OPTForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 2,
"do_layer_norm_before": false,
"dropout": 0.1,
"enable_bias": true,
"eos_token_id": 2,
"ffn_dim": 4096,
"hidden_size": 1024,
"init_std": 0.02,
"layer_norm_elementwise_affine": true,
"layerdrop": 0.0,
"max_position_embeddings": 2048,
"model_type": "opt",
"num_attention_heads": 16,
"num_hidden_layers": 24,
"pad_token_id": 1,
"prefix": "</s>",
"torch_dtype": "float16",
"transformers_version": "4.33.0.dev0",
"use_cache": true,
"vocab_size": 50272,
"word_embed_proj_dim": 512
}
[INFO|modeling_utils.py:2855] 2023-08-29 12:10:37,874 >> loading weights file pytorch_model.bin from cache at /raid/sourab/.cache/huggingface/models--facebook--opt-350m/snapshots/cb32f77e905cccbca1d970436fb0f5e6b58ee3c5/pytorch_model.bin
[INFO|configuration_utils.py:768] 2023-08-29 12:10:38,153 >> Generate config GenerationConfig {
"_from_model_config": true,
"bos_token_id": 2,
"eos_token_id": 2,
"pad_token_id": 1,
"transformers_version": "4.33.0.dev0"
}
[INFO|modeling_utils.py:3635] 2023-08-29 12:10:39,049 >> All model checkpoint weights were used when initializing OPTForCausalLM.
[INFO|modeling_utils.py:3643] 2023-08-29 12:10:39,049 >> All the weights of OPTForCausalLM were initialized from the model checkpoint at facebook/opt-350m.
If your task is similar to the task the model of the checkpoint was trained on, you can already use OPTForCausalLM for predictions without further training.
[INFO|configuration_utils.py:730] 2023-08-29 12:10:39,160 >> loading configuration file generation_config.json from cache at /raid/sourab/.cache/huggingface/models--facebook--opt-350m/snapshots/cb32f77e905cccbca1d970436fb0f5e6b58ee3c5/generation_config.json
[INFO|configuration_utils.py:768] 2023-08-29 12:10:39,161 >> Generate config GenerationConfig {
"_from_model_config": true,
"bos_token_id": 2,
"eos_token_id": 2,
"pad_token_id": 1,
"transformers_version": "4.33.0.dev0"
}
08/29/2023 12:10:39 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-fe29956cc983b056.arrow
08/29/2023 12:10:39 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-cecdce462984f30c.arrow
08/29/2023 12:10:39 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-ea4ff1f9a70456e2.arrow
08/29/2023 12:10:40 - WARNING - __main__ - The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can override this default with `--block_size xxx`.
08/29/2023 12:10:40 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-0ec6d944334ff56a.arrow
08/29/2023 12:10:40 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-52830176c9b26401.arrow
08/29/2023 12:10:40 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-32ee55629de51a05.arrow
08/29/2023 12:10:40 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-fe29956cc983b056.arrow
08/29/2023 12:10:40 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-cecdce462984f30c.arrow
08/29/2023 12:10:40 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-ea4ff1f9a70456e2.arrow
08/29/2023 12:10:40 - WARNING - __main__ - The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can override this default with `--block_size xxx`.
08/29/2023 12:10:40 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-0ec6d944334ff56a.arrow
08/29/2023 12:10:40 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-52830176c9b26401.arrow
08/29/2023 12:10:40 - WARNING - datasets.arrow_dataset - Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1-ddf29beda1b1b3d3/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-32ee55629de51a05.arrow
FullyShardedDataParallelPlugin(sharding_strategy=<ShardingStrategy.SHARD_GRAD_OP: 2>, backward_prefetch=None, mixed_precision_policy=None, auto_wrap_policy=None, cpu_offload=CPUOffload(offload_params=False), ignored_modules=None, state_dict_type=<StateDictType.FULL_STATE_DICT: 1>, state_dict_config=FullStateDictConfig(offload_to_cpu=True, use_dtensor=False, rank0_only=True), optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, use_dtensor=False, rank0_only=True), limit_all_gathers=False, use_orig_params=False, param_init_fn=<function FullyShardedDataParallelPlugin.__post_init__.<locals>.<lambda> at 0x7f81d8c71630>, sync_module_states=True, forward_prefetch=False, activation_checkpointing=False)
FullyShardedDataParallelPlugin(sharding_strategy=<ShardingStrategy.SHARD_GRAD_OP: 2>, backward_prefetch=None, mixed_precision_policy=None, auto_wrap_policy=None, cpu_offload=CPUOffload(offload_params=False), ignored_modules=None, state_dict_type=<StateDictType.FULL_STATE_DICT: 1>, state_dict_config=FullStateDictConfig(offload_to_cpu=True, use_dtensor=False, rank0_only=True), optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, use_dtensor=False, rank0_only=True), limit_all_gathers=False, use_orig_params=False, param_init_fn=<function FullyShardedDataParallelPlugin.__post_init__.<locals>.<lambda> at 0x7ff575be9630>, sync_module_states=True, forward_prefetch=False, activation_checkpointing=False)
[INFO|trainer.py:1714] 2023-08-29 12:10:41,785 >> ***** Running training *****
[INFO|trainer.py:1715] 2023-08-29 12:10:41,785 >> Num examples = 2,355
[INFO|trainer.py:1716] 2023-08-29 12:10:41,785 >> Num Epochs = 3
[INFO|trainer.py:1717] 2023-08-29 12:10:41,785 >> Instantaneous batch size per device = 8
[INFO|trainer.py:1720] 2023-08-29 12:10:41,785 >> Total train batch size (w. parallel, distributed & accumulation) = 16
[INFO|trainer.py:1721] 2023-08-29 12:10:41,785 >> Gradient Accumulation steps = 1
[INFO|trainer.py:1722] 2023-08-29 12:10:41,785 >> Total optimization steps = 444
[INFO|trainer.py:1723] 2023-08-29 12:10:41,786 >> Number of trainable parameters = 165,598,208
23%|██▎       | 104/444 [02:42<08:48, 1.56s/it]
Notice that the sharding strategy is correctly set to ShardingStrategy.SHARD_GRAD_OP, as passed in the command-line args.
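To check this on every rank without reading the log by eye, the printed plugin lines can be scanned with a small helper. This is only a sketch: the function names `strategies_in_log` and `all_ranks_use` are made up for illustration, and the regex assumes the plugin repr keeps the `sharding_strategy=<ShardingStrategy.NAME: N>` shape shown in the output above.

```python
import re

# Matches the enum repr inside the printed plugin, e.g.
# "sharding_strategy=<ShardingStrategy.SHARD_GRAD_OP: 2>".
# Assumption: the repr format stays as it appears in the log above.
_STRATEGY_RE = re.compile(r"sharding_strategy=<ShardingStrategy\.(\w+): (\d+)>")


def strategies_in_log(log_text):
    """Return the sharding strategy name from every printed plugin line."""
    return [match.group(1) for match in _STRATEGY_RE.finditer(log_text)]


def all_ranks_use(log_text, expected):
    """True if at least one plugin line was found and all of them name `expected`."""
    found = strategies_in_log(log_text)
    return bool(found) and all(name == expected for name in found)


# Shortened copies of the two plugin lines printed by rank 0 and rank 1.
sample = (
    "FullyShardedDataParallelPlugin("
    "sharding_strategy=<ShardingStrategy.SHARD_GRAD_OP: 2>, backward_prefetch=None)\n"
    "FullyShardedDataParallelPlugin("
    "sharding_strategy=<ShardingStrategy.SHARD_GRAD_OP: 2>, backward_prefetch=None)\n"
)

print(strategies_in_log(sample))
print(all_ranks_use(sample, "SHARD_GRAD_OP"))
```

With two processes, two matching entries are expected, one per rank.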
@pacman100 thanks for taking a look!
System Info
transformers version: 4.31.0
Who can help?
No response
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)
Reproduction
Using the official run_clm.py script with FSDP enabled:
where fsdp_config.json looks like:
Expected behavior
We expect to use the sharding strategy shard_grad_op, but the accelerator is not instantiated with the FSDP config in create_accelerator_and_postprocess(). As a result, if we print self.accelerator.state.fsdp_plugin.sharding_strategy at the end of __init__, we get the default sharding strategy full_shard, even though self.fsdp == shard_grad_op.
I did not set the sharding strategy using accelerate config since I'm experimenting with different strategies, and I believe it would make sense to overwrite the default strategy with the input config.
I'm not completely sure if this would be the correct fix, but I found the following to work with the intended behavior:
where we update the sharding_strategy after determining the strategy used for FSDP from args in __init__().
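The patch itself did not survive on this page, so the following is only a rough illustration of the idea described, not the actual change. It assumes a stand-in enum that mirrors a subset of torch.distributed.fsdp.ShardingStrategy and a hypothetical FakeFSDPPlugin in place of the Accelerate plugin; strategy_from_args and sync_plugin are made-up names.

```python
from dataclasses import dataclass
from enum import Enum


class ShardingStrategy(Enum):
    """Stand-in for torch.distributed.fsdp.ShardingStrategy (subset of members)."""

    FULL_SHARD = 1
    SHARD_GRAD_OP = 2
    NO_SHARD = 3


@dataclass
class FakeFSDPPlugin:
    """Hypothetical stand-in for the Accelerate FSDP plugin; defaults to full_shard."""

    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD


# --fsdp option strings that select a sharding strategy.
# Other options, such as auto_wrap, are modifiers and are ignored here.
_OPTION_TO_STRATEGY = {
    "full_shard": ShardingStrategy.FULL_SHARD,
    "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
    "no_shard": ShardingStrategy.NO_SHARD,
}


def strategy_from_args(fsdp_options):
    """Return the first strategy named in the options, or None if there is none."""
    for option in fsdp_options:
        if option in _OPTION_TO_STRATEGY:
            return _OPTION_TO_STRATEGY[option]
    return None


def sync_plugin(plugin, fsdp_options):
    """Overwrite the plugin's default only when the args name a strategy."""
    chosen = strategy_from_args(fsdp_options)
    if chosen is not None:
        plugin.sharding_strategy = chosen
    return plugin


plugin = sync_plugin(FakeFSDPPlugin(), "shard_grad_op auto_wrap".split())
print(plugin.sharding_strategy)  # ShardingStrategy.SHARD_GRAD_OP
```

Leaving the plugin untouched when the args name no strategy keeps whatever accelerate config set, so only an explicit --fsdp choice overrides the default.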