huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

'AcceleratorState' object has no attribute 'distributed_type' #2564

Closed. Zhuxing01 closed this issue 5 months ago

Zhuxing01 commented 6 months ago

System Info

accelerate == 0.28.0
python == 3.8.18
transformers == 4.37.2
torch == 2.2.1

Reproduction

I was fine-tuning a bert-base model using code from Hugging Face's example. I just copied the code from the website and then the bug occurred. The code and the error are listed below. By the way, PyCharm warned me that accelerator.prepare's arguments should be a list, but I pass it DataLoaders (see the note after the code below). Error:
AttributeError: 'AcceleratorState' object has no attribute 'distributed_type'

Code:

model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2) 
train_dataloader = DataLoader(datasets["train"], shuffle=True, batch_size=8, collate_fn=data_collator) 
eval_dataloader = DataLoader(datasets["validation"], batch_size=8, collate_fn=data_collator) 
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader)
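
As a side note on the PyCharm warning: as far as I can tell it is only a type-hint complaint. accelerator.prepare takes any number of objects positionally and returns them wrapped in the same order, so passing the DataLoaders directly appears to be the intended usage. A tiny standalone sketch (toy model and data, unrelated to my real script):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
toy_model = torch.nn.Linear(4, 2)
toy_optimizer = torch.optim.AdamW(toy_model.parameters(), lr=3e-5)
toy_loader = DataLoader(TensorDataset(torch.randn(16, 4)), batch_size=8)

# prepare() is variadic: objects go in positionally and come back in the same order.
toy_model, toy_optimizer, toy_loader = accelerator.prepare(toy_model, toy_optimizer, toy_loader)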

Expected behavior

The code works properly.

muellerzr commented 6 months ago

Please give us the full stack trace

Zhuxing01 commented 6 months ago

full stack trace:

Traceback (most recent call last):
  File "C:\Users\Noah\PycharmProjects\pythonProject\venv\fine-tune.py", line 44, in <module>
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader)
  File "E:\Anaconda\envs\pytorch\lib\site-packages\accelerate\accelerator.py", line 1219, in prepare
    if self.distributed_type == DistributedType.DEEPSPEED:
  File "E:\Anaconda\envs\pytorch\lib\site-packages\accelerate\accelerator.py", line 509, in distributed_type
    return self.state.distributed_type
AttributeError: 'AcceleratorState' object has no attribute 'distributed_type'

muellerzr commented 6 months ago

What is your full, entire code, please?

Zhuxing01 commented 6 months ago

Entire code:

import torch
import numpy as np
import evaluate
from datasets import load_dataset
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, TrainingArguments
from transformers import Trainer, AdamW, get_scheduler
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from accelerate import Accelerator

def tokenize_function(example):
    return tokenizer(example['sentence1'], example["sentence2"], truncation=True)

checkpoint = r"E:\model\bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
accelerator = Accelerator()
device = accelerator.device
training_args = TrainingArguments("test-trainer", evaluation_strategy="epoch")
raw_dataset = load_dataset(r"E:\date\glue-mrpc")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
datasets = raw_dataset.map(tokenize_function, batched=True)
datasets = datasets.remove_columns(["sentence1", "sentence2", "idx"])
datasets = datasets.rename_column("label", "labels")
datasets.set_format("torch")
print(datasets["train"].column_names)
train_dataloader = DataLoader(datasets["train"], shuffle=True, batch_size=8, collate_fn=data_collator)
eval_dataloader = DataLoader(datasets["validation"], batch_size=8, collate_fn=data_collator)
for batch in train_dataloader:
    break
print({k: v.shape for k, v in batch.items()})
outputs = model(**batch)
print(outputs["loss"], outputs.logits.shape)
optimizer = AdamW(model.parameters(), lr=3e-5)
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader)
num_epoch = 3
num_train_steps = len(train_dataloader)*num_epoch
lr_scheduler = get_scheduler(
    "linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_train_steps
)
print(num_train_steps)

progress_bar = tqdm(range(num_train_steps))
model.train()
for epoch in range(num_epoch):
    for batch in train_dataloader:
       # batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
      #  loss.backward()
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
metric = evaluate.load(r"E:\微调\测试\metrics\glue","mrpc")
model.eval()
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])
temp = metric.compute()

muellerzr commented 6 months ago

Thanks, one more bit we need (this is great, as this issue has been reported before but none of us has been able to reproduce it, so this is very helpful!)

How are you:

  1. Launching the code (python, accelerate launch, or torchrun)?
  2. If you are using accelerate launch, what is the output of accelerate env?

Zhuxing01 commented 6 months ago
  1. I launch the code with python in PyCharm.
  2. Sorry, I'm afraid I can't answer that question right now, but I will answer tomorrow. Thanks for helping me solve this bug!
muellerzr commented 6 months ago

Just running in python I don't see this:

command: python test.py

file:

from datasets import load_dataset
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
from transformers import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from accelerate import Accelerator

def tokenize_function(example):
    return tokenizer(example['sentence1'], example["sentence2"], truncation=True)

checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
accelerator = Accelerator()
device = accelerator.device
raw_dataset = load_dataset("glue", "mrpc")

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
datasets = raw_dataset.map(tokenize_function, batched=True)
datasets = datasets.remove_columns(["sentence1", "sentence2", "idx"])
datasets = datasets.rename_column("label", "labels")
datasets.set_format("torch")
train_dataloader = DataLoader(datasets["train"], shuffle=True, batch_size=8, collate_fn=data_collator)
eval_dataloader = DataLoader(datasets["validation"], batch_size=8, collate_fn=data_collator)
optimizer = AdamW(model.parameters(), lr=3e-5)
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader)

Can you verify if that fails to run for you? Thanks!

muellerzr commented 6 months ago

Also the output of pip freeze, please.

Zhuxing01 commented 6 months ago

accelerate env:

pip freeze:

accelerate==0.28.0
aiohttp @ file:///C:/b/abs_27h_1rpxgd/croot/aiohttp_1707342354614/work
aiosignal @ file:///tmp/build/80754af9/aiosignal_1637843061372/work
antlr4-python3-runtime==4.9.3
appdirs==1.4.4
arrow @ file:///C:/b/abs_cal7u12ktb/croot/arrow_1676588147908/work
asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work
async-timeout @ file:///C:/b/abs_c8fgiuixkq/croot/async-timeout_1703097556097/work
attrs @ file:///C:/b/abs_35n0jusce8/croot/attrs_1695717880170/work
backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
binaryornot @ file:///tmp/build/80754af9/binaryornot_1617751525010/work
bitarray==2.9.2
Bottleneck @ file:///C:/b/abs_f05kqh7yvj/croot/bottleneck_1707864273291/work
Brotli @ file:///C:/Windows/Temp/abs_63l7912z0e/croots/recipe/brotli-split_1659616056886/work
certifi @ file:///C:/b/abs_35d7n66oz9/croot/certifi_1707229248467/work/certifi
cffi==1.16.0
chardet @ file:///C:/ci/chardet_1607690654534/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
click @ file:///C:/b/abs_f9ihnt72pu/croot/click_1698129847492/work
colorama @ file:///C:/b/abs_a9ozq0l032/croot/colorama_1672387194846/work
comm @ file:///C:/b/abs_67a8058udb/croot/comm_1709322909844/work
contourpy @ file:///C:/b/abs_d5rpy288vc/croots/recipe/contourpy_1663827418189/work
cookiecutter @ file:///C:/b/abs_3d1730toam/croot/cookiecutter_1700677089156/work
cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
Cython==3.0.9
datasets==2.18.0
debugpy @ file:///C:/b/abs_c0y1fjipt2/croot/debugpy_1690906864587/work
decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work
dill @ file:///C:/b/abs_42h_07z1yj/croot/dill_1667919550096/work
evaluate @ file:///C:/b/abs_b0u82qzwka/croot/evaluate_1679573623633/work
executing @ file:///opt/conda/conda-bld/executing_1646925071911/work
-e git+https://github.com/pytorch/fairseq@34973a94d09ecc12092a5ecc8afece5e536b7692#egg=fairseq
filelock @ file:///C:/b/abs_f2gie28u58/croot/filelock_1700591233643/work
fonttools==4.25.0
frozenlist @ file:///C:/b/abs_d8e__s1ys3/croot/frozenlist_1698702612014/work
fsspec @ file:///C:/b/abs_97mpfsesn0/croot/fsspec_1701286534629/work
gmpy2 @ file:///C:/ci/gmpy2_1645456279018/work
huggingface-hub @ file:///C:/b/abs_2f65ujfg51/croot/huggingface_hub_1708635586197/work
hydra-core==1.0.7
idna @ file:///C:/b/abs_bdhbebrioa/croot/idna_1666125572046/work
importlib-metadata @ file:///C:/b/abs_c1egths604/croot/importlib_metadata-suite_1704813568388/work
importlib-resources @ file:///C:/b/abs_d0dmp77t95/croot/importlib_resources-suite_1704281892795/work
ipykernel @ file:///C:/b/abs_c2u94kxcy6/croot/ipykernel_1705933907920/work
ipython @ file:///C:/b/abs_254uk73z5b/croot/ipython_1691532131313/work
jedi @ file:///C:/ci/jedi_1644315425835/work
Jinja2 @ file:///C:/b/abs_f7x5a8op2h/croot/jinja2_1706733672594/work
joblib @ file:///C:/b/abs_1anqjntpan/croot/joblib_1685113317150/work
jupyter_client @ file:///C:/b/abs_a6h3c8hfdq/croot/jupyter_client_1699455939372/work
jupyter_core @ file:///C:/b/abs_c769pbqg9b/croot/jupyter_core_1698937367513/work
kiwisolver @ file:///C:/b/abs_88mdhvtahm/croot/kiwisolver_1672387921783/work
lightgbm @ file:///C:/b/abs_40p8j029wh/croot/lightgbm_1700267980215/work
lxml==5.1.0
markdown-it-py @ file:///C:/b/abs_a5bfngz6fu/croot/markdown-it-py_1684279915556/work
MarkupSafe @ file:///C:/b/abs_ecfdqh67b_/croot/markupsafe_1704206030535/work
matplotlib @ file:///C:/b/abs_085jhivdha/croot/matplotlib-suite_1693812524572/work
matplotlib-inline @ file:///C:/ci/matplotlib-inline_1661934035815/work
mdurl @ file:///C:/Windows/TEMP/abs_3197pzpjbi/croots/recipe/mdurl_1659716032440/work
mkl-fft @ file:///C:/b/abs_19i1y8ykas/croot/mkl_fft_1695058226480/work
mkl-random @ file:///C:/b/abs_edwkj1_o69/croot/mkl_random_1695059866750/work
mkl-service==2.4.0
mlxtend @ file:///C:/b/abs_abdi0b9anv/croot/mlxtend_1708368176260/work
mpmath @ file:///C:/b/abs_7833jrbiox/croot/mpmath_1690848321154/work
multidict @ file:///C:/b/abs_44ido987fv/croot/multidict_1701097803486/work
multiprocess @ file:///C:/b/abs_ca0rg6wl6_/croot/multiprocess_1668006439310/work
munkres==1.1.4
nest-asyncio @ file:///C:/b/abs_65d6lblmoi/croot/nest-asyncio_1708532721305/work
networkx @ file:///C:/b/abs_e6gi1go5op/croot/networkx_1690562046966/work
numexpr @ file:///C:/b/abs_afm0oewmmt/croot/numexpr_1683221839116/work
numpy @ file:///C:/Users/dev-admin/mkl/numpy_and_numpy_base_1682982345978/work
omegaconf==2.0.6
packaging @ file:///C:/b/abs_28t5mcoltc/croot/packaging_1693575224052/work
pandas @ file:///C:/miniconda3/conda-bld/pandas_1692299636855/work
parso @ file:///opt/conda/conda-bld/parso_1641458642106/work
pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
pillow @ file:///C:/b/abs_e22m71t0cb/croot/pillow_1707233126420/work
platformdirs @ file:///C:/b/abs_b6z_yqw_ii/croot/platformdirs_1692205479426/work
pooch @ file:///tmp/build/80754af9/pooch_1623324770023/work
portalocker==2.8.2
prompt-toolkit @ file:///C:/b/abs_68uwr58ed1/croot/prompt-toolkit_1704404394082/work
protobuf==3.20.3
psutil @ file:///C:/Windows/Temp/abs_b2c2fd7f-9fd5-4756-95ea-8aed74d0039flsd9qufz/croots/recipe/psutil_1656431277748/work
pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work
pyarrow @ file:///C:/b/abs_93i_y2dub4/croot/pyarrow_1707330894046/work/python
pyarrow-hotfix==0.6
pycparser==2.21
Pygments @ file:///C:/b/abs_fay9dpq4n_/croot/pygments_1684279990574/work
pyparsing @ file:///C:/Users/BUILDE~1/AppData/Local/Temp/abs_7f_7lba6rl/croots/recipe/pyparsing_1661452540662/work
PySocks @ file:///C:/ci/pysocks_1605287845585/work
python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
python-slugify @ file:///tmp/build/80754af9/python-slugify_1620405669636/work
pytz @ file:///C:/b/abs_19q3ljkez4/croot/pytz_1695131651401/work
pywin32==305.1
PyYAML @ file:///C:/b/abs_782o3mbw7z/croot/pyyaml_1698096085010/work
pyzmq @ file:///C:/b/abs_89aq69t0up/croot/pyzmq_1705605705281/work
regex @ file:///C:/b/abs_d5e2e5uqmr/croot/regex_1696515472506/work
requests @ file:///C:/b/abs_474vaa3x9e/croot/requests_1707355619957/work
responses @ file:///tmp/build/80754af9/responses_1619800270522/work
rich @ file:///C:/b/abs_09j2g5qnu8/croot/rich_1684282185530/work
sacrebleu==2.4.0
safetensors @ file:///C:/b/abs_88nwhm1qj3/croot/safetensors_1708633899663/work
scikit-learn @ file:///C:/b/abs_daon7wm2p4/croot/scikit-learn_1694788586973/work
scipy==1.10.1
seaborn @ file:///C:/b/abs_68ltdkoyoo/croot/seaborn_1673479199997/work
six @ file:///tmp/build/80754af9/six_1644875935023/work
stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work
sympy @ file:///C:/b/abs_82njkonm7f/croot/sympy_1701397685028/work
tabulate==0.9.0
tensorboardX @ file:///tmp/build/80754af9/tensorboardx_1621440489103/work
text-unidecode @ file:///Users/ktietz/demo/mc3/conda-bld/text-unidecode_1629401354553/work
threadpoolctl @ file:///Users/ktietz/demo/mc3/conda-bld/threadpoolctl_1629802263681/work
tokenizers @ file:///C:/b/abs_ba1xnwbavr/croot/tokenizers_1708633883417/work
torch==2.2.1
torchaudio==2.2.1
torchvision==0.17.1
tornado @ file:///C:/b/abs_0cbrstidzg/croot/tornado_1696937003724/work
tqdm @ file:///C:/b/abs_f76j9hg7pv/croot/tqdm_1679561871187/work
traitlets @ file:///C:/b/abs_e5m_xjjl94/croot/traitlets_1671143896266/work
transformers @ file:///C:/b/abs_495035xqf4/croot/transformers_1708700635762/work
typing_extensions @ file:///C:/b/abs_72cdotwc_6/croot/typing_extensions_1705599364138/work
tzdata @ file:///croot/python-tzdata_1690578112552/work
Unidecode @ file:///tmp/build/80754af9/unidecode_1614712377438/work
urllib3 @ file:///C:/b/abs_4etpfrkumr/croot/urllib3_1707770616184/work
wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
win-inet-pton @ file:///C:/ci/win_inet_pton_1605306167264/work
xxhash @ file:///C:/b/abs_e59ry5bslh/croot/python-xxhash_1667919514787/work
yarl @ file:///C:/b/abs_8bxwdyhjvp/croot/yarl_1701105248152/work
zipp @ file:///C:/b/abs_b0beoc27oa/croot/zipp_1704206963359/work

When I use your code above, it works properly. And after trying many times, I believe I have found the cause of the bug: training_args = TrainingArguments("test-trainer", evaluation_strategy="epoch"). Every time I add this line to my code, the same bug occurs; when I delete this line, the code works properly.
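
A minimal standalone sketch of what I observe (toy model and data just to isolate the TrainingArguments line; I can only report the behaviour in my environment, not explain it):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator
from transformers import TrainingArguments

accelerator = Accelerator()

# Uncommenting the next line reproduces the AttributeError at prepare() for me;
# with it removed (Trainer is never used in my script anyway), everything runs.
# training_args = TrainingArguments("test-trainer", evaluation_strategy="epoch")

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
loader = DataLoader(TensorDataset(torch.randn(16, 4)), batch_size=8)

model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
print("prepare() succeeded")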