Closed by steveepreston 1 month ago
@steveepreston you confirmed it works on 4.44.0?
Hey @muellerzr Yes.
I tested the versions below, from 4.43.1 to 4.45.2, one by one; for each test I fully restarted the session/kernel.
Each of v4.45.0 / v4.45.1 / v4.45.2 throws:
RuntimeError: There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'
Same error on the dev build:
RuntimeError: There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'
!pip install -qq git+https://github.com/huggingface/transformers.git@###
- (Aug 16) `0b066be` Revert PR 32299, flag users when Zero-3 > RuntimeError: There are currently no available devices
- (Aug 13) `481e156` Add support for GrokAdamW optimizer > RuntimeError: There are currently no available devices
- (Aug 12) `f1c8542` ("to be not" -> "not to be") > RuntimeError: There are currently no available devices
- (Aug 6) `194cf1f` Migrate import checks not need accelerate > Success
- (Aug 1) `82efc53` Yell at the user if zero-3 init wasn't performed > Success

Broken after `194cf1f`.
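The manual process above (install a commit, restart the kernel, rerun) can be sketched as a bisection over the chronological commit list. `first_bad_commit` and the `BAD` predicate below are hypothetical helpers, written only to illustrate the search; the hashes are the ones recorded in this thread:

```python
def first_bad_commit(commits, is_bad):
    """Binary-search a chronological commit list (oldest -> newest) for the
    first commit where is_bad(commit) flips to True. Assumes monotonicity:
    once a commit is broken, every later commit is also broken."""
    lo, hi = 0, len(commits) - 1
    first_bad = None
    while lo <= hi:
        mid = (lo + hi) // 2
        if is_bad(commits[mid]):
            first_bad = commits[mid]
            hi = mid - 1  # a bad commit: look for an even earlier one
        else:
            lo = mid + 1  # a good commit: the break is later
    return first_bad

# Predicate mirroring the test results recorded in this thread.
BAD = {"46d09af", "f1c8542", "481e156", "0b066be"}
commits = ["82efc53", "194cf1f", "7ad784a", "46d09af",
           "f1c8542", "481e156", "0b066be"]
print(first_bad_commit(commits, lambda c: c in BAD))  # -> 46d09af
```

With monotonic breakage this takes O(log n) installs instead of one per commit.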
!pip install -qq git+https://github.com/huggingface/transformers.git@###
- (Aug 7) `46d09af` enable xla fsdp > RuntimeError: There are currently no available devices
- (Aug 6) `194cf1f` Migrate import checks not need accelerate > Success

Broken after `194cf1f`.
Commits in transformers on Aug 6 - Aug 7:

- (Aug 7) `46d09af` enable xla fsdp > RuntimeError: There are currently no available devices
- (Aug 6) `7ad784a` Gemma2: add cache warning > Success
- (Aug 6) `194cf1f` Migrate import checks not need accelerate > Success
Problem found: the commit that causes the error is:

`46d09af` enable xla fsdp > RuntimeError: There are currently no available devices
⚠️ @hanwen-sun Your commit causes an error for Trainer on TPU VM. Please fix it.
Hey @steveepreston, we probably need to revert this commit, as I just checked that the FSDP integration in accelerate does not support XLA yet. We only have this integration in Trainer, as you can see here. Another solution would be to add the integration in accelerate. Would you like to open a PR to revert this commit first?
Hey @SunMarc. Thanks for the attention!
I'm not deeply familiar with FSDP. I just tested and saw that `SFTTrainer` worked like a charm on TPU VM on 4.44.2, and due to the super fast training speed I thought it was using the power of the TPU.
Btw, I created a PR to revert the error thrown by the `46d09af` commit.
@SunMarc Thank you for your support. The error is gone now and Trainer works again ✅
But I'm confused after your explanation. Was that past commit in fact correct, and was `RuntimeError: There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'` an expected, correct error for running this example on a TPU VM v3-8?
So are we now bypassing `accelerate`? If yes, does this mean we are not using the TPU's power? No parallel/distributed training?
So how is the model trained now? Is it trained on cpu0, ignoring [xla0, xla1, xla2, xla3, xla4, xla5, xla6, xla7]?
Sorry for this newbie question. Please explain a little, I'm really confused. Thanks.
I wonder if `accelerate` supports NPU and XPU but not TPU.
And what about the official blog post for Fine-Tuning Gemma Models on the Hugging Face website?
Hi, actually accelerate supports XLA FSDP as of this PR: https://github.com/huggingface/accelerate/pull/2176. But we only integrate it in transformers: https://github.com/huggingface/transformers/pull/29334.
TRL reads the `fsdp_plugin` from accelerate to decide whether to use FSDP: https://github.com/huggingface/trl/blob/02f4e750c07c5a470f2d82a3a59e011401b5c63e/trl/trainer/ppo_trainer.py#L204, and accelerate's `fsdp_plugin` is initialized from the env variable `ACCELERATE_USE_FSDP`: https://github.com/huggingface/accelerate/blob/a84327e59652b79b1f6e3be58be634fbd35184f3/src/accelerate/accelerator.py#L348, which is set here: https://github.com/huggingface/transformers/blob/3f06f95ebe617b192251ef756518690f5bc7ff76/src/transformers/training_args.py#L1939.
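As a rough sketch of the chain described above (simplified, assumed shapes, not the real library code): transformers exports the flag, and accelerate reads it back to decide whether to build an FSDP plugin.

```python
import os

def export_fsdp_flag(fsdp_enabled: bool) -> None:
    """Simplified stand-in for the training_args.py line linked above."""
    if fsdp_enabled:
        os.environ["ACCELERATE_USE_FSDP"] = "true"

def accelerator_uses_fsdp() -> bool:
    """Simplified stand-in for the accelerator.py check linked above."""
    return os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"

os.environ.pop("ACCELERATE_USE_FSDP", None)
print(accelerator_uses_fsdp())  # -> False
export_fsdp_flag(True)
print(accelerator_uses_fsdp())  # -> True
```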
@hanwen-sun Hey, thanks for the explanation.
Am I understanding correctly:

- `Accelerate` supports XLA FSDP
- Transformers `Trainer` supports XLA FSDP
- TRL `SFTTrainer` inherits from Transformers `Trainer`, so it supports XLA FSDP too
- But `self.fsdp_config["xla"]` in the `if` check now seems to bypass setting `os.environ["ACCELERATE_USE_FSDP"] = "true"`
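That last point can be sketched like this (a hypothetical simplification of the `if` in question, not the actual transformers source): when `fsdp_config["xla"]` is true, FSDP is handled inside Trainer itself, so the accelerate flag is never exported.

```python
import os

def maybe_export_accelerate_fsdp(fsdp: str, fsdp_config: dict) -> bool:
    """Return True if ACCELERATE_USE_FSDP was exported (non-XLA path only)."""
    if fsdp and not fsdp_config.get("xla", False):
        os.environ["ACCELERATE_USE_FSDP"] = "true"
        return True
    return False  # XLA path: FSDP is wrapped inside Trainer instead

print(maybe_export_accelerate_fsdp("full_shard", {"xla": True}))  # -> False
print(maybe_export_accelerate_fsdp("full_shard", {}))             # -> True
```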
Correct me if I'm wrong @hanwen-sun, but XLA FSDP requires using `torch_xla` modules such as `from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP`. We have that in transformers but not in accelerate. TRL is a wrapper around Trainer, so it will use the same code path.
@steveepreston @SunMarc I will take some time to check this and give you a reply tomorrow.
@SunMarc @hanwen-sun Thank you both!
I agree with @SunMarc. I think the problem is on the `accelerate` side.
Once again, see the error trace:
SFTTrainer()
SFTTrainer.__init__()
Trainer.__init__()
Accelerator.__init__()
FullyShardedDataParallelPlugin.__post_init__()
> ⛔ must be one of 'XPU', 'CUDA', or 'NPU'
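The failing check at the bottom of that chain behaves roughly like this (a minimal sketch with assumed names, not the real `accelerate` source): TPU/XLA is simply not among the accepted backends, so a TPU VM falls through to the error.

```python
def pick_fsdp_backend(xpu: bool, cuda: bool, npu: bool) -> str:
    """Simplified stand-in for the device check in
    FullyShardedDataParallelPlugin.__post_init__()."""
    if xpu:
        return "XPU"
    if cuda:
        return "CUDA"
    if npu:
        return "NPU"
    # On a TPU VM none of the three flags is set, so we end up here:
    raise RuntimeError(
        "There are currently no available devices found, "
        "must be one of 'XPU', 'CUDA', or 'NPU'"
    )

print(pick_fsdp_backend(False, True, False))  # -> CUDA
```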
@steveepreston @SunMarc sorry, I made a mistake. The Accelerator does not support XLA FSDP; instead, FSDP is wrapped within transformers/trainer.py. The Accelerator checks the device in `FullyShardedDataParallelPlugin.__post_init__()`. Previously, we used GPU as the backend for XLA, which allowed us to run the code successfully. However, this approach will not work correctly for TPU. @steveepreston, I am not sure if SFTTrainer correctly wraps XLA FSDP; you might want to perform some checks. Setting `accelerator_use_fsdp=False` could potentially cause issues with the `accelerator.clip_grad_norm` method for XLA FSDP. I will open a new issue later and keep you informed.
@hanwen-sun Thank you for checking.
Can you please check whether I'm correct? Then I can dig into this issue and debug it:
Without XLA, torch operations run on cpu0 and ignore [xla0, xla1, xla2, xla3, xla4, xla5, xla6, xla7].
XLA's `xr.use_spmd()` enables us to distribute the training process across all TPU cores.
So what is the point of FSDP? Just to optimize and speed up this distribution?
Sorry for the newbie question.
@steveepreston FSDP is a distributed training strategy that aims to fully utilize the hardware's compute resources. You can refer to https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html. I'm not familiar with `use_spmd()`, but you are right in general.
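To make the FSDP idea concrete, here is a toy, framework-free sketch of the core trick (illustrative only, no relation to the torch_xla implementation): each of N workers permanently stores only 1/N of the parameters, gathering the rest on demand during compute.

```python
def shard_params(params, world_size):
    """Round-robin a flat parameter list into world_size shards, the way
    FSDP splits model state so no single worker holds the full model."""
    shards = [[] for _ in range(world_size)]
    for i, p in enumerate(params):
        shards[i % world_size].append(p)
    return shards

# 8 "parameters" across 8 TPU cores: each core stores just one.
shards = shard_params(list(range(8)), 8)
print([len(s) for s in shards])  # -> [1, 1, 1, 1, 1, 1, 1, 1]
```

Per-worker memory shrinks roughly linearly with world size, which is what lets a model that does not fit on one device train across eight.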
@hanwen-sun Thanks for the note
System Info
- transformers: v4.45.0 and up (any of v4.45.0 / v4.45.1 / v4.45.2)
- accelerate: v1.0.1 (same result on v0.34.2)
Who can help?
Trainer experts: @muellerzr @SunMarc; accelerate expert: @muellerzr; text models expert: @ArthurZucker. Thank you guys!
Information
Tasks
Reproduction
Minimal working code is here. The code follows the GoogleCloudPlatform example.
On a TPU VM, training works like a charm on transformers v4.43.1 through v4.44.2, but upgrading to any of v4.45.0 / v4.45.1 / v4.45.2 throws this error:
RuntimeError: There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'.
Error Traceback:
The general traceback is: calling `SFTTrainer()` > `self.accelerator = Accelerator(**args)` (transformers/trainer.py)
Full error traceback:
```python
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[48], line 4
      1 from trl import SFTTrainer
      2 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
----> 4 trainer = SFTTrainer(
      5     model=base_model,
      6     train_dataset=data,
      7     args=TrainingArguments(
      8         per_device_train_batch_size=BATCH_SIZE,  # This is actually the global batch size for SPMD.
      9         num_train_epochs=1,
     10         max_steps=-1,
     11         output_dir="/output_dir",
     12         optim="adafactor",
     13         logging_steps=1,
     14         dataloader_drop_last = True,  # Required for SPMD.
     15         fsdp="full_shard",
     16         fsdp_config=fsdp_config,
     17     ),
     18     peft_config=lora_config,
     19     dataset_text_field="quote",
     20     max_seq_length=max_seq_length,
     21     packing=True,
     22 )

File /usr/local/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py:101, in _deprecate_arguments.
```

My observation and guess
I tested multiple times, and can confirm that this error is directly caused only by changing the version of `transformers`. Since the `accelerate` version was fixed during all runs, my guess is that something changed in v4.45.0 (maybe in `trainer.py`) that affects `args` in `self.accelerator = Accelerator(**args)`, so the error is raised by `accelerate`.

Expected behavior
My guess: `args` should be corrected so that `self.accelerator = Accelerator(**args)` is called correctly and `accelerate` can work on TPU.