microsoft / dp-transformers

Differentially-private transformers using HuggingFace and Opacus
MIT License

Failed to run the example #26

Open xiehuanyi opened 1 year ago

xiehuanyi commented 1 year ago

I ran the example as given:

import os
os.environ["WANDB_DISABLED"] = "true"
!python examples/nlg-reddit/sample-level-dp/fine-tune-dp.py \
--output_dir scratch \
--model_name sshleifer/tiny-gpt2 \
--sequence_len 128 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 2 \
--evaluation_strategy steps \
--eval_steps 45 \
--log_level info \
--per_device_eval_batch_size 32 \
--eval_accumulation_steps 1 \
--seed 42 \
--target_epsilon 8 \
--per_sample_max_grad_norm 1.0 \
--prediction_loss_only \
--weight_decay 0.01 \
--remove_unused_columns False \
--num_train_epochs 3 \
--logging_steps 5 \
--max_grad_norm 0 \
--lr_scheduler_type constant \
--learning_rate 1e-4 \
--disable_tqdm True \
--dataloader_num_workers 2 
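Outside a notebook, the same invocation can be reproduced with Python's `subprocess` module instead of the `!python` shell escape. This is a minimal sketch: the names `SCRIPT`, `FLAGS`, `build_command`, and `run_example` are hypothetical helpers, and it assumes the current working directory is the root of the dp-transformers checkout, as in the cell above.

```python
import os
import subprocess
import sys

# Path relative to the dp-transformers repository root (assumed working directory).
SCRIPT = "examples/nlg-reddit/sample-level-dp/fine-tune-dp.py"

# Flags copied from the notebook cell above; None marks a flag that takes no value.
FLAGS = [
    ("--output_dir", "scratch"),
    ("--model_name", "sshleifer/tiny-gpt2"),
    ("--sequence_len", "128"),
    ("--per_device_train_batch_size", "16"),
    ("--gradient_accumulation_steps", "2"),
    ("--evaluation_strategy", "steps"),
    ("--eval_steps", "45"),
    ("--log_level", "info"),
    ("--per_device_eval_batch_size", "32"),
    ("--eval_accumulation_steps", "1"),
    ("--seed", "42"),
    ("--target_epsilon", "8"),
    ("--per_sample_max_grad_norm", "1.0"),
    ("--prediction_loss_only", None),
    ("--weight_decay", "0.01"),
    ("--remove_unused_columns", "False"),
    ("--num_train_epochs", "3"),
    ("--logging_steps", "5"),
    ("--max_grad_norm", "0"),
    ("--lr_scheduler_type", "constant"),
    ("--learning_rate", "1e-4"),
    ("--disable_tqdm", "True"),
    ("--dataloader_num_workers", "2"),
]


def build_command(python: str = sys.executable) -> list[str]:
    """Return the argv list equivalent to the `!python ...` cell."""
    cmd = [python, SCRIPT]
    for flag, value in FLAGS:
        cmd.append(flag)
        if value is not None:
            cmd.append(value)
    return cmd


def run_example() -> int:
    """Launch the script with W&B logging disabled and return its exit code."""
    # Equivalent to os.environ["WANDB_DISABLED"] = "true" in the cell above,
    # but scoped to the child process only.
    env = dict(os.environ, WANDB_DISABLED="true")
    return subprocess.run(build_command(), env=env, check=False).returncode
```

Calling `run_example()` from the repository root should hit the same code path, and therefore the same traceback, as long as the installed packages are unchanged.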

but got the following error:

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /kaggle/working/dp-transformers/examples/nlg-reddit/sample-level-dp/fine-tun │
│ e-dp.py:137 in <module>                                                      │
│                                                                              │
│   134 if __name__ == "__main__":                                             │
│   135 │   arg_parser = transformers.HfArgumentParser((dp_transformers.Traini │
│   136 │   train_args, privacy_args, model_args = arg_parser.parse_args_into_ │
│ ❱ 137 │   main(Arguments(train=train_args, privacy=privacy_args, model=model │
│   138                                                                        │
│                                                                              │
│ /kaggle/working/dp-transformers/examples/nlg-reddit/sample-level-dp/fine-tun │
│ e-dp.py:125 in main                                                          │
│                                                                              │
│   122 │   )                                                                  │
│   123 │                                                                      │
│   124 │   try:                                                               │
│ ❱ 125 │   │   trainer.train()                                                │
│   126 │   finally:                                                           │
│   127 │   │   eps_prv = trainer.get_prv_epsilon()                            │
│   128 │   │   eps_rdp = trainer.get_rdp_epsilon()                            │
│                                                                              │
│ /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1645 in      │
│ train                                                                        │
│                                                                              │
│   1642 │   │   inner_training_loop = find_executable_batch_size(             │
│   1643 │   │   │   self._inner_training_loop, self._train_batch_size, args.a │
│   1644 │   │   )                                                             │
│ ❱ 1645 │   │   return inner_training_loop(                                   │
│   1646 │   │   │   args=args,                                                │
│   1647 │   │   │   resume_from_checkpoint=resume_from_checkpoint,            │
│   1648 │   │   │   trial=trial,                                              │
│                                                                              │
│ /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1938 in      │
│ _inner_training_loop                                                         │
│                                                                              │
│   1935 │   │   │   │   │   self.control = self.callback_handler.on_step_begi │
│   1936 │   │   │   │                                                         │
│   1937 │   │   │   │   with self.accelerator.accumulate(model):              │
│ ❱ 1938 │   │   │   │   │   tr_loss_step = self.training_step(model, inputs)  │
│   1939 │   │   │   │                                                         │
│   1940 │   │   │   │   if (                                                  │
│   1941 │   │   │   │   │   args.logging_nan_inf_filter                       │
│                                                                              │
│ /opt/conda/lib/python3.10/site-packages/dp_transformers/dp_utils.py:263 in   │
│ training_step                                                                │
│                                                                              │
│   260 │   │   │   raise NotImplementedError("DP currently doesn't support th │
│   261 │   │   elif self.use_apex:                                            │
│   262 │   │   │   raise NotImplementedError("DP currently doesn't support th │
│ ❱ 263 │   │   elif self.deepspeed:                                           │
│   264 │   │   │   raise NotImplementedError("DP currently doesn't support th │
│   265 │   │   else:                                                          │
│   266 │   │   │   loss.backward()                                            │
╰──────────────────────────────────────────────────────────────────────────────╯
AttributeError: 'OpacusDPTrainer' object has no attribute 'deepspeed'
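The error means that `training_step` in `dp_transformers/dp_utils.py` reads `self.deepspeed`, an attribute that the installed `transformers.Trainer` (which `OpacusDPTrainer` subclasses) does not set. A plausible cause, assumed here rather than confirmed by this report, is a version mismatch: newer `transformers` releases appear to track DeepSpeed through an `is_deepspeed_enabled` flag instead. The sketch below uses toy stand-in classes (`OldTrainer`, `NewTrainer`, `deepspeed_active`, and `uses_old_check` are all hypothetical, not the library's API) to show the failure mode and a version-tolerant check.

```python
class OldTrainer:
    """Stand-in for a Trainer version that still sets a `deepspeed` attribute."""

    def __init__(self):
        self.deepspeed = None  # None means DeepSpeed is not in use


class NewTrainer:
    """Stand-in for a Trainer version that only exposes `is_deepspeed_enabled` (assumed name)."""

    def __init__(self):
        self.is_deepspeed_enabled = False


def deepspeed_active(trainer) -> bool:
    """Return True if DeepSpeed is in use, whichever attribute the trainer defines."""
    # Prefer the newer boolean flag when it exists.
    if hasattr(trainer, "is_deepspeed_enabled"):
        return bool(trainer.is_deepspeed_enabled)
    # Otherwise fall back to the older attribute; if it is missing, treat DeepSpeed as unused.
    return getattr(trainer, "deepspeed", None) is not None


def uses_old_check(trainer) -> bool:
    """Mirrors the failing `elif self.deepspeed:` line: raises AttributeError on NewTrainer."""
    return bool(trainer.deepspeed)
```

An alternative workaround is to install a `transformers` release that still defines `Trainer.deepspeed`; which release that is would need to be checked against the requirements of the installed `dp-transformers` version.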

Here is my environment:

Package                                Version              Editable project location
-------------------------------------- -------------------- -------------------------
absl-py                                1.4.0
accelerate                             0.20.3
access                                 1.1.9
affine                                 2.4.0
aiobotocore                            2.5.0
aiofiles                               22.1.0
aiohttp                                3.8.4
aiohttp-cors                           0.7.0
aioitertools                           0.11.0
aiorwlock                              1.3.0
aiosignal                              1.3.1
aiosqlite                              0.19.0
albumentations                         1.3.1
alembic                                1.11.1
altair                                 5.0.1
annoy                                  1.17.2
ansiwrap                               0.8.4
anyio                                  3.6.2
apache-beam                            2.46.0
aplus                                  0.11.0
appdirs                                1.4.4
argon2-cffi                            21.3.0
argon2-cffi-bindings                   21.2.0
array-record                           0.2.0
arrow                                  1.2.3
arviz                                  0.12.1
astroid                                2.15.5
astropy                                5.3
asttokens                              2.2.1
astunparse                             1.6.3
async-timeout                          4.0.2
atpublic                               3.1.2
attrs                                  23.1.0
audioread                              3.0.0
autopep8                               2.0.2
Babel                                  2.12.1
backcall                               0.2.0
backoff                                2.2.1
backports.functools-lru-cache          1.6.4
bayesian-optimization                  1.4.3
bayespy                                0.5.26
beatrix-jupyterlab                     2023.58.190319
beautifulsoup4                         4.12.2
bidict                                 0.22.1
biopython                              1.81
blake3                                 0.2.1
bleach                                 6.0.0
blessed                                1.20.0
blinker                                1.6.2
blis                                   0.7.9
blosc2                                 2.0.0
bokeh                                  3.1.1
boltons                                23.0.0
Boruta                                 0.3
boto3                                  1.26.100
botocore                               1.29.76
bq-helper                              0.4.1                /src/bq-helper
bqplot                                 0.12.39
branca                                 0.6.0
brewer2mpl                             1.4.1
brotlipy                               0.7.0
cached-property                        1.5.2
cachetools                             4.2.4
Cartopy                                0.21.1
catalogue                              2.0.8
catalyst                               22.4
catboost                               1.2
category-encoders                      2.6.1
certifi                                2023.5.7
cesium                                 0.12.1
cffi                                   1.15.1
cftime                                 1.6.2
charset-normalizer                     2.1.1
chex                                   0.1.7
cleverhans                             4.0.0
click                                  8.1.3
click-plugins                          1.1.1
cligj                                  0.7.2
cloud-tpu-client                       0.10
cloud-tpu-profiler                     2.4.0
cloudpickle                            2.2.1
cmaes                                  0.9.1
cmdstanpy                              1.1.0
cmudict                                1.0.13
colorama                               0.4.6
colorcet                               3.0.1
colorful                               0.5.5
colorlog                               6.7.0
colorlover                             0.3.0
comm                                   0.1.3
commonmark                             0.9.1
conda                                  23.3.1
conda-content-trust                    0+unknown
conda-package-handling                 2.0.2
conda_package_streaming                0.7.0
confection                             0.0.4
contextily                             1.3.0
contourpy                              1.0.7
convertdate                            2.4.0
crcmod                                 1.7
cryptography                           40.0.2
cubinlinker                            0.3.0
cuda-python                            11.8.2
cudf                                   23.6.0
cufflinks                              0.17.3
cuml                                   23.6.0
cupy                                   12.0.0
CVXcanon                               0.1.2
cycler                                 0.11.0
cymem                                  2.0.7
cysignals                              1.11.2
Cython                                 0.29.34
cytoolz                                0.12.0
daal                                   2023.1.1
daal4py                                2023.1.1
dask                                   2023.6.0
dask-cuda                              23.6.0
dask-cudf                              23.6.0
dataclasses                            0.8
dataclasses-json                       0.5.8
datasets                               2.1.0
datashader                             0.15.0
datashape                              0.5.2
datatile                               1.0.3
db-dtypes                              1.1.1
deap                                   1.3.3
debugpy                                1.6.7
decorator                              5.1.1
deepspeed                              0.9.5
defusedxml                             0.7.1
Delorean                               1.0.0
deprecat                               2.1.1
Deprecated                             1.2.13
deprecation                            2.1.0
descartes                              1.1.0
dill                                   0.3.6
dipy                                   1.7.0
distlib                                0.3.6
distributed                            2023.3.2.1
dm-tree                                0.1.8
docker                                 6.1.1
docker-pycreds                         0.4.0
docopt                                 0.6.2
docstring-parser                       0.15
docstring-to-markdown                  0.12
docutils                               0.20.1
dp-transformers                        1.0.0
earthengine-api                        0.1.356
easydict                               1.10
easyocr                                1.7.0
ecos                                   2.0.12
eli5                                   0.13.0
emoji                                  2.5.0
en-core-web-lg                         3.5.0
en-core-web-sm                         3.5.0
entrypoints                            0.4
ephem                                  4.1.4
esda                                   2.4.3
essentia                               2.1b6.dev1034
et-xmlfile                             1.1.0
etils                                  1.2.0
exceptiongroup                         1.1.1
executing                              1.2.0
explainable-ai-sdk                     1.3.3
fastai                                 2.7.12
fastapi                                0.95.1
fastavro                               1.7.4
fastcore                               1.5.29
fastdownload                           0.0.7
fasteners                              0.18
fastjsonschema                         2.16.3
fastprogress                           1.0.3
fastrlock                              0.8
fasttext                               0.9.2
fbpca                                  1.0
feather-format                         0.4.1
featuretools                           1.26.0
filelock                               3.12.0
Fiona                                  1.9.4.post1
fire                                   0.5.0
fitter                                 1.5.2
flake8                                 6.0.0
flashtext                              2.7
Flask                                  2.3.2
flatbuffers                            23.3.3
flax                                   0.6.10
flit_core                              3.8.0
folium                                 0.14.0
fonttools                              4.39.3
fqdn                                   1.5.1
frozendict                             2.3.8
frozenlist                             1.3.3
fsspec                                 2023.6.0
functorch                              0.2.1
funcy                                  2.0
fury                                   0.9.0
future                                 0.18.3
fuzzywuzzy                             0.18.0
gast                                   0.4.0
gatspy                                 0.3
gcsfs                                  2023.5.0
gensim                                 4.3.1
geographiclib                          2.0
Geohash                                1.0
geojson                                3.0.1
geopandas                              0.13.2
geoplot                                0.5.1
geopy                                  2.3.0
geoviews                               1.10.0
ggplot                                 0.11.5
giddy                                  2.3.4
gitdb                                  4.0.10
GitPython                              3.1.31
google-api-core                        1.33.2
google-api-python-client               2.88.0
google-apitools                        0.5.31
google-auth                            2.17.3
google-auth-httplib2                   0.1.0
google-auth-oauthlib                   1.0.0
google-cloud-aiplatform                0.6.0a1
google-cloud-artifact-registry         1.8.1
google-cloud-automl                    1.0.1
google-cloud-bigquery                  2.34.4
google-cloud-bigtable                  1.7.3
google-cloud-core                      2.3.2
google-cloud-datastore                 2.15.2
google-cloud-dlp                       3.12.1
google-cloud-language                  2.6.1
google-cloud-monitoring                2.14.2
google-cloud-pubsub                    2.16.1
google-cloud-pubsublite                1.8.1
google-cloud-recommendations-ai        0.7.1
google-cloud-resource-manager          1.10.0
google-cloud-spanner                   3.33.0
google-cloud-storage                   1.44.0
google-cloud-translate                 3.8.4
google-cloud-videointelligence         2.8.3
google-cloud-vision                    2.8.0
google-crc32c                          1.5.0
google-pasta                           0.2.0
google-resumable-media                 2.5.0
googleapis-common-protos               1.57.1
gplearn                                0.4.2
gpustat                                1.0.0
gpxpy                                  1.5.0
graphviz                               0.20.1
greenlet                               2.0.2
grpc-google-iam-v1                     0.12.6
grpcio                                 1.51.1
grpcio-status                          1.48.1
gviz-api                               1.10.0
gym                                    0.26.2
gym-notices                            0.0.8
Gymnasium                              0.26.3
gymnasium-notices                      0.0.1
h11                                    0.14.0
h2o                                    3.40.0.4
h5py                                   3.8.0
haversine                              2.8.0
hdfs                                   2.7.0
hep-ml                                 0.7.2
hijri-converter                        2.3.1
hjson                                  3.1.0
hmmlearn                               0.3.0
holidays                               0.24
holoviews                              1.16.2
hpsklearn                              0.1.0
html5lib                               1.1
htmlmin                                0.1.12
httplib2                               0.21.0
httptools                              0.5.0
huggingface-hub                        0.15.1
humanize                               4.6.0
hunspell                               0.5.5
husl                                   4.0.3
hydra-slayer                           0.4.1
hyperopt                               0.2.7
hypertools                             0.8.0
ibis-framework                         5.1.0
idna                                   3.4
igraph                                 0.10.4
imagecodecs                            2023.3.16
ImageHash                              4.3.1
imageio                                2.28.1
imbalanced-learn                       0.10.1
imgaug                                 0.4.0
implicit                               0.5.2
importlib-metadata                     5.2.0
importlib-resources                    5.12.0
inequality                             1.0.0
iniconfig                              2.0.0
ipydatawidgets                         4.3.4
ipykernel                              6.23.0
ipyleaflet                             0.17.3
ipympl                                 0.7.0
ipython                                8.13.2
ipython-genutils                       0.2.0
ipython-sql                            0.5.0
ipyvolume                              0.6.3
ipyvue                                 1.9.1
ipyvuetify                             1.8.10
ipywebrtc                              0.6.0
ipywidgets                             7.7.1
isoduration                            20.11.0
isort                                  5.12.0
isoweek                                1.3.3
itsdangerous                           2.1.2
Janome                                 0.4.2
jaraco.classes                         3.2.3
jax                                    0.4.8
jaxlib                                 0.4.7+cuda11.cudnn86
jedi                                   0.18.2
jeepney                                0.8.0
jieba                                  0.42.1
Jinja2                                 3.1.2
jmespath                               1.0.1
joblib                                 1.2.0
json5                                  0.9.11
jsonpatch                              1.32
jsonpointer                            2.0
jsonschema                             4.17.3
jupyter_client                         7.4.9
jupyter-console                        6.6.3
jupyter_core                           5.3.0
jupyter-events                         0.6.3
jupyter-http-over-ws                   0.0.8
jupyter-lsp                            1.5.1
jupyter_server                         2.5.0
jupyter_server_fileid                  0.9.0
jupyter-server-mathjax                 0.2.6
jupyter_server_proxy                   4.0.0
jupyter_server_terminals               0.4.4
jupyter_server_ydoc                    0.8.0
jupyter-ydoc                           0.2.4
jupyterlab                             3.6.4
jupyterlab-git                         0.41.0
jupyterlab-lsp                         4.2.0
jupyterlab-pygments                    0.2.2
jupyterlab_server                      2.22.1
jupyterlab-widgets                     3.0.7
jupytext                               1.14.5
kaggle                                 1.5.13
kaggle-environments                    1.12.0
keras                                  2.12.0
keras-tuner                            1.3.5
keyring                                23.13.1
keyrings.google-artifactregistry-auth  1.1.2
kfp                                    1.8.21
kfp-pipeline-spec                      0.1.16
kfp-server-api                         1.8.5
kiwisolver                             1.4.4
kmapper                                2.0.1
kmodes                                 0.12.2
korean-lunar-calendar                  0.3.1
kornia                                 0.6.12
kt-legacy                              1.0.5
kubernetes                             25.3.0
langcodes                              3.3.0
langid                                 1.1.6
lazy_loader                            0.2
lazy-object-proxy                      1.9.0
learntools                             0.3.4
leven                                  1.0.4
Levenshtein                            0.21.1
libclang                               16.0.0
libmambapy                             1.4.2
libpysal                               4.7.0
librosa                                0.10.0.post2
lightgbm                               3.3.2
lightning-utilities                    0.8.0
lime                                   0.2.0.1
line-profiler                          4.0.3
linkify-it-py                          2.0.2
llvmlite                               0.40.0
lml                                    0.1.0
locket                                 1.0.0
LunarCalendar                          0.0.9
lxml                                   4.9.2
lz4                                    4.3.2
Mako                                   1.2.4
mamba                                  1.4.2
mapclassify                            2.5.0
marisa-trie                            0.8.0
Markdown                               3.4.3
markdown-it-py                         2.2.0
markovify                              0.9.4
MarkupSafe                             2.1.2
marshmallow                            3.19.0
marshmallow-enum                       1.5.1
matplotlib                             3.6.3
matplotlib-inline                      0.1.6
matplotlib-venn                        0.11.9
mccabe                                 0.7.0
mdit-py-plugins                        0.3.5
mdurl                                  0.1.2
memory-profiler                        0.61.0
mercantile                             1.2.1
mgwr                                   2.1.2
missingno                              0.5.2
mistune                                0.8.4
mizani                                 0.9.2
ml-dtypes                              0.1.0
mlcrate                                0.2.0
mlens                                  0.2.3
mlxtend                                0.22.0
mmh3                                   4.0.0
mne                                    1.4.2
mnist                                  0.2.2
mock                                   5.0.2
momepy                                 0.6.0
more-itertools                         9.1.0
mpld3                                  0.5.9
mpmath                                 1.3.0
msgpack                                1.0.5
msgpack-numpy                          0.4.8
multidict                              6.0.4
multimethod                            1.9.1
multipledispatch                       0.6.0
multiprocess                           0.70.14
munkres                                1.1.4
murmurhash                             1.0.9
mypy-extensions                        1.0.0
nb-conda                               2.2.1
nb-conda-kernels                       2.3.1
nbclassic                              1.0.0
nbclient                               0.5.13
nbconvert                              6.4.5
nbdime                                 3.2.0
nbformat                               5.8.0
nest-asyncio                           1.5.6
netCDF4                                1.6.4
networkx                               3.1
nibabel                                5.1.0
nilearn                                0.10.1
ninja                                  1.11.1
nltk                                   3.2.4
nose                                   1.3.7
notebook                               6.5.4
notebook-executor                      0.2
notebook_shim                          0.2.3
numba                                  0.57.0
numexpr                                2.8.4
numpy                                  1.23.5
nvidia-ml-py                           11.495.46
nvtx                                   0.2.5
oauth2client                           4.1.3
oauthlib                               3.2.2
objsize                                0.6.1
odfpy                                  1.4.1
olefile                                0.46
onnx                                   1.14.0
opacus                                 1.2.0
opencensus                             0.11.2
opencensus-context                     0.1.3
opencv-contrib-python                  4.7.0.72
opencv-python                          4.7.0.72
opencv-python-headless                 4.7.0.72
openpyxl                               3.1.2
openslide-python                       1.2.0
opentelemetry-api                      1.17.0
opentelemetry-exporter-otlp            1.17.0
opentelemetry-exporter-otlp-proto-grpc 1.17.0
opentelemetry-exporter-otlp-proto-http 1.17.0
opentelemetry-proto                    1.17.0
opentelemetry-sdk                      1.17.0
opentelemetry-semantic-conventions     0.38b0
opt-einsum                             3.3.0
optax                                  0.1.5
optuna                                 3.2.0
orbax-checkpoint                       0.2.2
orderedmultidict                       1.0.1
orjson                                 3.8.12
ortools                                9.4.1874
osmnx                                  1.1.1
overrides                              6.5.0
packaging                              21.3
pandas                                 1.5.3
pandas-datareader                      0.10.0
pandas-profiling                       3.6.6
pandas-summary                         0.2.0
pandasql                               0.7.3
pandocfilters                          1.5.0
panel                                  1.1.0
papermill                              2.4.0
param                                  1.13.0
parso                                  0.8.3
parsy                                  2.1
partd                                  1.4.0
path                                   16.6.0
path.py                                12.5.0
pathos                                 0.3.0
pathtools                              0.1.2
pathy                                  0.10.1
patsy                                  0.5.3
pdf2image                              1.16.3
pexpect                                4.8.0
phik                                   0.12.3
pickleshare                            0.7.5
Pillow                                 9.5.0
pip                                    23.1.2
pkgutil_resolve_name                   1.3.10
platformdirs                           3.5.0
plotly                                 5.14.1
plotly-express                         0.4.1
plotnine                               0.10.1
pluggy                                 1.0.0
pointpats                              2.3.0
polars                                 0.18.2
polyglot                               16.7.4
pooch                                  1.6.0
pox                                    0.3.2
ppca                                   0.0.4
ppft                                   1.7.6.6
preprocessing                          0.1.13
preshed                                3.0.8
prettytable                            3.7.0
progressbar2                           4.2.0
prometheus-client                      0.16.0
promise                                2.3
prompt-toolkit                         3.0.38
pronouncing                            0.2.0
prophet                                1.1.1
proto-plus                             1.22.2
protobuf                               3.20.3
prv-accountant                         0.1.1.post1
psutil                                 5.9.3
ptxcompiler                            0.8.1
ptyprocess                             0.7.0
pudb                                   2022.1.3
PuLP                                   2.7.0
pure-eval                              0.2.2
py-cpuinfo                             9.0.0
py-lz4framed                           0.14.0
py-spy                                 0.3.14
py4j                                   0.10.9.7
pyaml                                  23.5.9
PyArabic                               0.6.15
pyarrow                                11.0.0
pyasn1                                 0.4.8
pyasn1-modules                         0.2.7
PyAstronomy                            0.19.0
pybind11                               2.10.4
pyclipper                              1.3.0.post4
pycodestyle                            2.10.0
pycolmap                               0.4.0
pycosat                                0.6.4
pycparser                              2.21
pycryptodome                           3.18.0
pyct                                   0.5.0
pycuda                                 2022.2.2
pydantic                               1.10.7
pydegensac                             0.1.2
pydicom                                2.3.1
pydocstyle                             6.3.0
pydot                                  1.4.2
pydub                                  0.25.1
pyemd                                  1.0.0
pyerfa                                 2.0.0.3
pyexcel-io                             0.6.6
pyexcel-ods                            0.6.0
pyfasttext                             0.4.6
pyflakes                               3.0.1
pygltflib                              1.15.6
Pygments                               2.15.1
PyJWT                                  2.6.0
pykalman                               0.9.5
pyLDAvis                               3.2.2
pylibraft                              23.6.1
pylint                                 2.17.4
pymc3                                  3.11.5
PyMeeus                                0.5.12
pymongo                                3.13.0
Pympler                                1.0.1
pynndescent                            0.5.10
pynvml                                 11.4.1
pynvrtc                                9.2
pyocr                                  0.8.3
pyOpenSSL                              23.1.1
pyparsing                              3.0.9
pypdf                                  3.9.1
pyproj                                 3.6.0
pyrsistent                             0.19.3
pysal                                  23.1
pyshp                                  2.3.1
PySocks                                1.7.1
pytesseract                            0.3.10
pytest                                 7.3.2
python-bidi                            0.4.2
python-dateutil                        2.8.2
python-dotenv                          1.0.0
python-igraph                          0.10.4
python-json-logger                     2.0.7
python-Levenshtein                     0.21.1
python-louvain                         0.16
python-lsp-jsonrpc                     1.0.0
python-lsp-server                      1.7.3
python-slugify                         8.0.1
python-utils                           3.6.0
pythreejs                              2.4.2
pytoolconfig                           1.2.5
pytools                                2022.1.14
pytorch-ignite                         0.4.12
pytorch-lightning                      2.0.3
pytz                                   2023.3
pyu2f                                  0.1.5
PyUpSet                                0.1.1.post7
pyviz-comms                            2.3.1
PyWavelets                             1.4.1
PyYAML                                 5.4.1
pyzmq                                  25.0.2
qgrid                                  1.3.1
qtconsole                              5.4.3
QtPy                                   2.3.1
quantecon                              0.7.1
quantities                             0.14.1
qudida                                 0.0.4
raft-dask                              23.6.1
randomgen                              1.23.1
rapidfuzz                              3.1.1
rasterio                               1.3.7
rasterstats                            0.19.0
ray                                    2.4.0
ray-cpp                                2.4.0
regex                                  2023.5.5
requests                               2.28.2
requests-oauthlib                      1.3.1
requests-toolbelt                      0.10.1
responses                              0.18.0
retrying                               1.3.3
rfc3339-validator                      0.1.4
rfc3986-validator                      0.1.1
rgf-python                             3.12.0
rich                                   12.6.0
rmm                                    23.6.0
rope                                   1.8.0
rsa                                    4.9
Rtree                                  1.0.1
ruamel.yaml                            0.17.24
ruamel.yaml.clib                       0.2.7
ruamel-yaml-conda                      0.15.100
s2sphere                               0.2.5
s3fs                                   2023.6.0
s3transfer                             0.6.1
safetensors                            0.3.1
scattertext                            0.1.19
scikit-image                           0.20.0
scikit-learn                           1.2.2
scikit-learn-intelex                   2023.1.1
scikit-multilearn                      0.2.0
scikit-optimize                        0.9.0
scikit-plot                            0.3.7
scikit-surprise                        1.1.3
scipy                                  1.10.1
seaborn                                0.12.2
SecretStorage                          3.3.3
segment-anything                       1.0
segregation                            2.4.2
semver                                 3.0.0
Send2Trash                             1.8.2
sentencepiece                          0.1.99
sentry-sdk                             1.25.1
setproctitle                           1.3.2
setuptools                             59.8.0
setuptools-git                         1.2
setuptools-scm                         7.1.0
shap                                   0.41.0
Shapely                                1.8.5.post1
shellingham                            1.5.1
simpervisor                            0.4
SimpleITK                              2.2.1
simplejson                             3.19.1
six                                    1.16.0
sklearn-pandas                         2.2.0
slicer                                 0.0.7
smart-open                             6.3.0
smhasher                               0.150.1
smmap                                  5.0.0
sniffio                                1.3.0
snowballstemmer                        2.2.0
snuggs                                 1.4.7
sortedcontainers                       2.4.0
soundfile                              0.12.1
soupsieve                              2.3.2.post1
soxr                                   0.3.5
spacy                                  3.5.3
spacy-legacy                           3.0.12
spacy-loggers                          1.0.4
spaghetti                              1.7.3
spectral                               0.23.1
spglm                                  1.0.8
sphinx-rtd-theme                       0.2.4
spint                                  1.0.7
splot                                  1.1.5.post1
spopt                                  0.5.0
spreg                                  1.3.2
spvcm                                  0.3.0
SQLAlchemy                             2.0.12
sqlglot                                11.7.1
sqlparse                               0.4.4
squarify                               0.4.3
srsly                                  2.4.6
stack-data                             0.6.2
starlette                              0.26.1
statsmodels                            0.13.5
stemming                               1.0.1
stop-words                             2018.7.23
stopit                                 1.1.2
strip-hints                            0.1.10
stumpy                                 1.11.1
sympy                                  1.12
tables                                 3.8.0
tabulate                               0.9.0
tangled-up-in-unicode                  0.2.0
tbb                                    2021.9.0
tblib                                  1.7.0
tenacity                               8.2.2
tensorboard                            2.12.3
tensorboard-data-server                0.7.0
tensorboard-plugin-profile             2.11.2
tensorboardX                           2.6
tensorflow                             2.12.0
tensorflow-addons                      0.20.0
tensorflow-cloud                       0.1.16
tensorflow-datasets                    4.9.2
tensorflow-decision-forests            1.3.0
tensorflow-estimator                   2.12.0
tensorflow-gcs-config                  2.12.0
tensorflow-hub                         0.12.0
tensorflow-io                          0.31.0
tensorflow-io-gcs-filesystem           0.31.0
tensorflow-metadata                    0.14.0
tensorflow-probability                 0.20.0
tensorflow-serving-api                 2.12.1
tensorflow-text                        2.12.1
tensorflow-transform                   0.14.0
tensorflowjs                           3.15.0
tensorpack                             0.11
tensorstore                            0.1.37
termcolor                              2.3.0
terminado                              0.17.1
testpath                               0.6.0
text-unidecode                         1.3
textblob                               0.17.1
texttable                              1.6.7
textwrap3                              0.9.2
Theano                                 1.0.5
Theano-PyMC                            1.1.2
thinc                                  8.1.10
threadpoolctl                          3.1.0
tifffile                               2023.4.12
timm                                   0.9.2
tinycss2                               1.2.1
tobler                                 0.10
tokenizers                             0.13.3
toml                                   0.10.2
tomli                                  2.0.1
tomlkit                                0.11.8
toolz                                  0.12.0
torch                                  1.12.1
torchaudio                             2.0.1
torchdata                              0.6.0
torchinfo                              1.8.0
torchmetrics                           0.11.4
torchtext                              0.15.1
torchvision                            0.15.1
tornado                                6.3.1
TPOT                                   0.12.0
tqdm                                   4.64.1
traceml                                1.0.8
traitlets                              5.9.0
traittypes                             0.2.1
transformers                           4.30.1
treelite                               3.2.0
treelite-runtime                       3.2.0
trueskill                              0.4.5
tsfresh                                0.20.0
typeguard                              2.13.3
typer                                  0.7.0
typing_extensions                      4.5.0
typing-inspect                         0.9.0
tzlocal                                5.0.1
uc-micro-py                            1.0.2
ucx-py                                 0.32.0
ujson                                  5.8.0
umap-learn                             0.5.3
unicodedata2                           15.0.0
Unidecode                              1.3.6
update-checker                         0.18.0
uri-template                           1.2.0
uritemplate                            3.0.1
urllib3                                1.26.15
urwid                                  2.1.2
urwid-readline                         0.13
uvicorn                                0.22.0
uvloop                                 0.17.0
vaex                                   4.16.0
vaex-astro                             0.9.3
vaex-core                              4.16.1
vaex-hdf5                              0.14.1
vaex-jupyter                           0.8.1
vaex-ml                                0.18.1
vaex-server                            0.8.1
vaex-viz                               0.5.4
vecstack                               0.4.0
virtualenv                             20.21.0
visions                                0.7.5
vowpalwabbit                           9.8.0
vtk                                    9.2.6
Wand                                   0.6.11
wandb                                  0.15.4
wasabi                                 1.1.2
watchfiles                             0.19.0
wavio                                  0.0.7
wcwidth                                0.2.6
webcolors                              1.13
webencodings                           0.5.1
websocket-client                       1.5.1
websockets                             11.0.3
Werkzeug                               2.3.6
wfdb                                   4.1.1
whatthepatch                           1.0.5
wheel                                  0.40.0
widgetsnbextension                     3.6.4
witwidget                              1.8.1
woodwork                               0.24.0
Wordbatch                              1.4.9
wordcloud                              1.9.2
wordsegment                            1.3.1
wrapt                                  1.14.1
wurlitzer                              3.0.3
xarray                                 2023.5.0
xarray-einstats                        0.5.1
xgboost                                1.7.5
xvfbwrapper                            0.2.9
xxhash                                 3.2.0
xyzservices                            2023.5.0
y-py                                   0.5.9
yapf                                   0.33.0
yarl                                   1.9.1
ydata-profiling                        4.1.2
yellowbrick                            1.5
ypy-websocket                          0.8.2
zict                                   3.0.0
zipp                                   3.15.0
zstandard                              0.19.0

Could anyone help me with this?

huseyinatahaninan commented 12 months ago

Hi @xiehuanyi, our example was based on transformers==4.20.1, and although I am sure it would work with more recent versions, unfortunately we have not had time to test the newest ones. I'll try to do that sometime this week, but in the meantime, if you are okay with using older versions, that would be the fastest route right now.

huseyinatahaninan commented 12 months ago

Just FYI, it works up to transformers==4.28.1, but beyond that it may need some changes. I see that the "elif self.deepspeed: ..." part is causing the issue, so you can comment it out to see if it works with transformers==4.30.1. I have not fully tested with this version, though, so I don't know yet whether it requires further changes.
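If you would rather have the example stop early on an untested transformers release than fail deep inside the Trainer, a guard along these lines could run before training. This is a hypothetical helper, not part of dp-transformers; the 4.28.1 ceiling is taken from the comment above, and the version parsing deliberately ignores pre-release suffixes.

```python
from importlib import metadata

# Newest transformers release reported in this thread as working with the example (assumption: adjust as needed)
LAST_KNOWN_GOOD = (4, 28, 1)


def parse_version(text):
    """Turn '4.30.1' into (4, 30, 1) and '4.30.0.dev0' into (4, 30, 0); suffixes are ignored."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def is_known_good(version_text):
    """True if the version is not newer than the last release reported to work."""
    return parse_version(version_text) <= LAST_KNOWN_GOOD


def check_installed_transformers():
    """Return None if transformers is not installed, otherwise whether its version is known good."""
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        return None
    return is_known_good(installed)


if __name__ == "__main__":
    if check_installed_transformers() is False:
        print("Warning: installed transformers is newer than 4.28.1 and untested with this example.")
```

Pinning the dependency (pip install transformers==4.28.1) is still the simpler option; the guard only makes the mismatch explicit.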

xiehuanyi commented 12 months ago

Great! I will try it. Thanks ~

xiehuanyi commented 12 months ago

I just tried it out on a toy dataset and found something very strange. I ran it three times, and it worked fine every time without differential privacy. However, with differential privacy it fails from time to time. My code is shown below.

# Imported only for its side effect: registers a per-sample-gradient sampler for GPT-2's Conv1D layers
from dp_transformers.grad_sample.transformers import conv_1d  # noqa: F401
import torch
from opacus import PrivacyEngine
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM


class ToyData(Dataset):
    """100 identical samples of (input_ids, labels, position_ids)."""

    def __getitem__(self, index):
        return (
            torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
            torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
            torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
        )

    def __len__(self):
        return 100


def run(use_dp):
    data_loader = DataLoader(ToyData(), batch_size=8)

    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)

    model = model.train()
    if use_dp:
        # Wraps model, optimizer and data loader; by default the returned loader uses Poisson sampling
        pe = PrivacyEngine()
        model, opt, data_loader = pe.make_private(
            module=model,
            optimizer=opt,
            data_loader=data_loader,
            noise_multiplier=1.3,
            max_grad_norm=1.0,
        )

    for epoch in range(100):
        for batch in data_loader:
            loss = model(input_ids=batch[0], labels=batch[1], position_ids=batch[2]).loss
            loss.backward()
            opt.step()
            opt.zero_grad()
    print(loss.item())


# Three runs with DP, then three without, reporting any exception instead of stopping
for use_dp in [True, False]:
    for i in range(3):
        try:
            run(use_dp)
            print(f"use_dp: {use_dp} success")
        except Exception as e:
            print(f"use_dp: {use_dp} error msg: {e}")

And here is the output:

/home/huanyi/miniconda3/envs/dlenv/lib/python3.10/site-packages/opacus/privacy_engine.py:141: UserWarning: Secure RNG turned off. This is perfectly fine for experimentation as it allows for much faster training performance, but remember to turn it on and retrain one last time before production with ``secure_mode`` turned on.
  warnings.warn(
/home/huanyi/miniconda3/envs/dlenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1053: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
  warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "
/home/huanyi/miniconda3/envs/dlenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1018: UserWarning: Using non-full backward hooks on a Module that does not return a single Tensor or a tuple of Tensors is deprecated and will be removed in future versions. This hook will be missing some of the grad_output. Please use register_full_backward_hook to get the documented behavior.
  warnings.warn("Using non-full backward hooks on a Module that does not return a "
/home/huanyi/miniconda3/envs/dlenv/lib/python3.10/site-packages/torch/autograd/__init__.py:173: UserWarning: CUDA initialization: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 804: forward compatibility was attempted on non supported HW (Triggered internally at  ../c10/cuda/CUDAFunctions.cpp:109.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
1.9542039632797241
use_dp: True success
1.9521187543869019
use_dp: True success
use_dp: True error msg: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)
1.9460290670394897
use_dp: False success
1.9460164308547974
use_dp: False success
1.9460010528564453
use_dp: False success
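One possible explanation for the intermittent DP-only failure, not confirmed anywhere in this thread: by default make_private swaps the DataLoader for one that uses Poisson sampling, so each step's batch size is random and can occasionally be zero. If the installed Opacus builds that empty batch with torch.zeros and no explicit dtype (which gives float32), the embedding lookup would raise exactly this Long/Int-versus-FloatTensor error. The sketch below only estimates how often an empty batch would occur in the toy setup; it assumes a per-sample inclusion rate of 1 / len(data_loader) and about len(data_loader) steps per epoch.

```python
import math
import random


def empty_batch_probability(dataset_size, batch_size):
    """Probability that one Poisson-sampled batch is empty.

    Assumes every sample is included independently with rate
    q = 1 / number_of_batches, where number_of_batches = ceil(N / B).
    """
    num_batches = math.ceil(dataset_size / batch_size)
    q = 1.0 / num_batches
    return (1.0 - q) ** dataset_size


def prob_run_hits_empty_batch(dataset_size, batch_size, epochs):
    """Probability that at least one step of the whole run sees an empty batch."""
    steps = math.ceil(dataset_size / batch_size) * epochs  # assumed steps per epoch times epochs
    p_empty = empty_batch_probability(dataset_size, batch_size)
    return 1.0 - (1.0 - p_empty) ** steps


def sample_poisson_batch(dataset_size, q, rng):
    """Draw one batch: every index is kept independently with probability q."""
    return [i for i in range(dataset_size) if rng.random() < q]


if __name__ == "__main__":
    # Toy setup from the script above: 100 samples, batch_size=8, 100 epochs
    print(f"P(empty batch in one step) = {empty_batch_probability(100, 8):.6f}")
    print(f"P(run hits an empty batch) = {prob_run_hits_empty_batch(100, 8, 100):.3f}")

    rng = random.Random(0)
    sizes = [len(sample_poisson_batch(100, 1 / 13, rng)) for _ in range(20)]
    print("Example Poisson batch sizes:", sizes)
```

Under these assumptions each DP run has roughly a 35% chance of meeting at least one empty batch, which would be in line with one failure out of three DP runs. The per-step probability is close to e^(-batch_size), so it shrinks quickly as the batch size grows.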

I followed the tips given:

xiehuanyi commented 11 months ago

Excuse me, which Reddit dataset do you use? I didn't find 'reddit' on Hugging Face, but I found a similar one here: https://huggingface.co/datasets/solomonk/reddit. However, it seems to have some compatibility problems. Since the connection from mainland China to Hugging Face is unstable, I cloned the dataset with git-lfs:

git-lfs clone https://huggingface.co/datasets/solomonk/reddit

It is stored under the 'dp-transformers' directory, and it is found properly. However, I got an error while running the script. My command and output are shown below.

Command:

python examples/nlg-reddit/sample-level-dp/fine-tune-dp.py \
--output_dir scratch \
--model_name tiny-gpt2 \
--sequence_len 128 \
--per_device_train_batch_size 32 \
--gradient_accumulation_steps 2 \
--evaluation_strategy steps \
--eval_steps 45 \
--log_level info \
--per_device_eval_batch_size 64 \
--eval_accumulation_steps 1 \
--seed 42 \
--target_epsilon 8 \
--per_sample_max_grad_norm 1.0 \
--prediction_loss_only \
--weight_decay 0.01 \
--remove_unused_columns False \
--num_train_epochs 3 \
--logging_steps 5 \
--max_grad_norm 0 \
--lr_scheduler_type constant \
--learning_rate 1e-4 \
--disable_tqdm True \
--dataloader_num_workers 2

output:

07/13/2023 12:06:24:WARNING:Process rank: 0, device: cuda:0, n_gpu: 1, distributed training: True, 16-bits training: False
07/13/2023 12:06:24:INFO:Training/evaluation parameters TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=2,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=True,
do_eval=True,
do_predict=False,
do_train=False,
dry_run=False,
eval_accumulation_steps=1,
eval_delay=0,
eval_steps=45,
evaluation_strategy=steps,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'fsdp_min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=2,
gradient_checkpointing=False,
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_model_id=None,
hub_private_repo=False,
hub_strategy=every_save,
hub_token=<HUB_TOKEN>,
ignore_data_skip=False,
include_inputs_for_metrics=False,
jit_mode_eval=False,
label_names=None,
label_smoothing_factor=0.0,
learning_rate=0.0001,
length_column_name=length,
load_best_model_at_end=False,
local_rank=0,
log_level=info,
log_level_replica=warning,
log_on_each_node=True,
logging_dir=scratch/runs/Jul13_12-06-23_df0caa500212d011ee0917a0c7f822b9ff09-task1-0,
logging_first_step=False,
logging_nan_inf_filter=True,
logging_steps=5,
logging_strategy=steps,
lr_scheduler_type=constant,
max_grad_norm=0.0,
max_steps=-1,
metric_for_best_model=None,
mp_parameters=,
no_cuda=False,
num_train_epochs=3.0,
optim=adamw_hf,
optim_args=None,
output_dir=scratch,
overwrite_output_dir=False,
past_index=-1,
per_device_eval_batch_size=64,
per_device_train_batch_size=32,
prediction_loss_only=True,
push_to_hub=False,
push_to_hub_model_id=None,
push_to_hub_organization=None,
push_to_hub_token=<PUSH_TO_HUB_TOKEN>,
ray_scope=last,
remove_unused_columns=False,
report_to=[],
resume_from_checkpoint=None,
run_name=scratch,
save_on_each_node=False,
save_safetensors=False,
save_steps=500,
save_strategy=steps,
save_total_limit=None,
seed=42,
sharded_ddp=[],
skip_memory_metrics=True,
tf32=None,
torch_compile=False,
torch_compile_backend=None,
torch_compile_mode=None,
torchdynamo=None,
tpu_metrics_debug=False,
tpu_num_cores=None,
use_ipex=False,
use_legacy_prediction_loop=False,
use_mps_device=False,
warmup_ratio=0.0,
warmup_steps=0,
weight_decay=0.01,
xpu_backend=None,
)
07/13/2023 12:06:24:INFO:Privacy parameters PrivacyArguments(per_sample_max_grad_norm=1.0, noise_multiplier=None, target_epsilon=8.0, target_delta=None, disable_dp=False)
[INFO|configuration_utils.py:667] 2023-07-13 12:06:24,025 >> loading configuration file tiny-gpt2/config.json
[INFO|configuration_utils.py:725] 2023-07-13 12:06:24,026 >> Model config GPT2Config {
  "_name_or_path": "tiny-gpt2",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 2,
  "n_head": 2,
  "n_inner": null,
  "n_layer": 2,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.30.2",
  "use_cache": true,
  "vocab_size": 50257
}

[INFO|modeling_utils.py:2575] 2023-07-13 12:06:24,052 >> loading weights file tiny-gpt2/pytorch_model.bin
[INFO|configuration_utils.py:577] 2023-07-13 12:06:24,059 >> Generate config GenerationConfig {
  "_from_model_config": true,
  "bos_token_id": 50256,
  "eos_token_id": 50256,
  "transformers_version": "4.30.2"
}

[INFO|modeling_utils.py:3295] 2023-07-13 12:06:24,213 >> All model checkpoint weights were used when initializing GPT2LMHeadModel.

[INFO|modeling_utils.py:3304] 2023-07-13 12:06:24,213 >> All the weights of GPT2LMHeadModel were initialized from the model checkpoint at tiny-gpt2.
If your task is similar to the task the model of the checkpoint was trained on, you can already use GPT2LMHeadModel for predictions without further training.
[INFO|modeling_utils.py:2928] 2023-07-13 12:06:24,214 >> Generation config file not found, using a generation config created from the model config.
07/13/2023 12:06:27:INFO:Some files matched the pattern 'reddit/**' at /code/dp-transformers/reddit but don't have valid data file extensions: [PosixPath('/code/dp-transformers/reddit/RS_2006-01.zst'), PosixPath('/code/dp-transformers/reddit/RC_2006-01.bz2')]
07/13/2023 12:06:27:WARNING:Using custom data configuration reddit-622990d947526d4c
07/13/2023 12:06:27:INFO:Loading Dataset Infos from /opt/conda/lib/python3.7/site-packages/datasets/packaged_modules/json
07/13/2023 12:06:27:INFO:Generating dataset json (/root/.cache/huggingface/datasets/json/reddit-622990d947526d4c/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)
Downloading and preparing dataset json/reddit to /root/.cache/huggingface/datasets/json/reddit-622990d947526d4c/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab...
07/13/2023 12:06:28:INFO:Dataset not on Hf google storage. Downloading and preparing it from source
07/13/2023 12:06:28:INFO:Downloading took 0.0 min
07/13/2023 12:06:28:INFO:Checksum Computation took 0.0 min
07/13/2023 12:06:28:INFO:Unable to verify checksums.
07/13/2023 12:06:28:INFO:Generating train split
Traceback (most recent call last):
  File "examples/nlg-reddit/sample-level-dp/fine-tune-dp.py", line 141, in <module>
    main(Arguments(train=train_args, privacy=privacy_args, model=model_args))
  File "examples/nlg-reddit/sample-level-dp/fine-tune-dp.py", line 86, in main
    dataset = datasets.load_dataset('reddit')
  File "/opt/conda/lib/python3.7/site-packages/datasets/load.py", line 1747, in load_dataset
    use_auth_token=use_auth_token,
  File "/opt/conda/lib/python3.7/site-packages/datasets/builder.py", line 818, in download_and_prepare
    **download_and_prepare_kwargs,
  File "/opt/conda/lib/python3.7/site-packages/datasets/builder.py", line 905, in _download_and_prepare
    self._prepare_split(split_generator, **prepare_split_kwargs)
  File "/opt/conda/lib/python3.7/site-packages/datasets/builder.py", line 1520, in _prepare_split
    writer.write_table(table)
  File "/opt/conda/lib/python3.7/site-packages/datasets/arrow_writer.py", line 540, in write_table
    pa_table = table_cast(pa_table, self._schema)
  File "/opt/conda/lib/python3.7/site-packages/datasets/table.py", line 2068, in table_cast
    return cast_table_to_schema(table, schema)
  File "/opt/conda/lib/python3.7/site-packages/datasets/table.py", line 2029, in cast_table_to_schema
    raise ValueError(f"Couldn't cast\n{table.schema}\nto\n{features}\nbecause column names don't match")
ValueError: Couldn't cast
archived: bool
author: string
author_flair_background_color: string
author_flair_css_class: null
author_flair_richtext: list<item: null>
  child 0, item: null
author_flair_text: null
author_flair_text_color: string
author_flair_type: string
brand_safe: bool
can_gild: bool
contest_mode: bool
created_utc: int64
distinguished: null
domain: string
edited: bool
gilded: int64
hidden: bool
hide_score: bool
id: string
is_crosspostable: bool
is_reddit_media_domain: bool
is_self: bool
is_video: bool
link_flair_css_class: null
link_flair_richtext: list<item: null>
  child 0, item: null
link_flair_text: null
link_flair_text_color: string
link_flair_type: string
locked: bool
media: null
media_embed: struct<>
no_follow: bool
num_comments: int64
num_crossposts: int64
over_18: bool
parent_whitelist_status: string
permalink: string
rte_mode: string
score: int64
secure_media: null
secure_media_embed: struct<>
selftext: string
send_replies: bool
spoiler: bool
stickied: bool
subreddit: string
subreddit_id: string
subreddit_name_prefixed: string
subreddit_type: string
suggested_sort: null
thumbnail: string
thumbnail_height: int64
thumbnail_width: int64
title: string
url: string
whitelist_status: string
post_hint: string
preview: struct<enabled: bool, images: list<item: struct<id: string, resolutions: list<item: struct<height: int64, url: string, width: int64>>, source: struct<height: int64, url: string, width: int64>, variants: struct<nsfw: struct<resolutions: list<item: struct<height: int64, url: string, width: int64>>, source: struct<height: int64, url: string, width: int64>>, obfuscated: struct<resolutions: list<item: struct<height: int64, url: string, width: int64>>, source: struct<height: int64, url: string, width: int64>>>>>>
  child 0, enabled: bool
  child 1, images: list<item: struct<id: string, resolutions: list<item: struct<height: int64, url: string, width: int64>>, source: struct<height: int64, url: string, width: int64>, variants: struct<nsfw: struct<resolutions: list<item: struct<height: int64, url: string, width: int64>>, source: struct<height: int64, url: string, width: int64>>, obfuscated: struct<resolutions: list<item: struct<height: int64, url: string, width: int64>>, source: struct<height: int64, url: string, width: int64>>>>>
      child 0, item: struct<id: string, resolutions: list<item: struct<height: int64, url: string, width: int64>>, source: struct<height: int64, url: string, width: int64>, variants: struct<nsfw: struct<resolutions: list<item: struct<height: int64, url: string, width: int64>>, source: struct<height: int64, url: string, width: int64>>, obfuscated: struct<resolutions: list<item: struct<height: int64, url: string, width: int64>>, source: struct<height: int64, url: string, width: int64>>>>
          child 0, id: string
          child 1, resolutions: list<item: struct<height: int64, url: string, width: int64>>
              child 0, item: struct<height: int64, url: string, width: int64>
                  child 0, height: int64
                  child 1, url: string
                  child 2, width: int64
          child 2, source: struct<height: int64, url: string, width: int64>
              child 0, height: int64
              child 1, url: string
              child 2, width: int64
          child 3, variants: struct<nsfw: struct<resolutions: list<item: struct<height: int64, url: string, width: int64>>, source: struct<height: int64, url: string, width: int64>>, obfuscated: struct<resolutions: list<item: struct<height: int64, url: string, width: int64>>, source: struct<height: int64, url: string, width: int64>>>
              child 0, nsfw: struct<resolutions: list<item: struct<height: int64, url: string, width: int64>>, source: struct<height: int64, url: string, width: int64>>
                  child 0, resolutions: list<item: struct<height: int64, url: string, width: int64>>
                      child 0, item: struct<height: int64, url: string, width: int64>
                          child 0, height: int64
                          child 1, url: string
                          child 2, width: int64
                  child 1, source: struct<height: int64, url: string, width: int64>
                      child 0, height: int64
                      child 1, url: string
                      child 2, width: int64
              child 1, obfuscated: struct<resolutions: list<item: struct<height: int64, url: string, width: int64>>, source: struct<height: int64, url: string, width: int64>>
                  child 0, resolutions: list<item: struct<height: int64, url: string, width: int64>>
                      child 0, item: struct<height: int64, url: string, width: int64>
                          child 0, height: int64
                          child 1, url: string
                          child 2, width: int64
                  child 1, source: struct<height: int64, url: string, width: int64>
                      child 0, height: int64
                      child 1, url: string
                      child 2, width: int64
retrieved_on: int64
to
{'gilded': Value(dtype='int64', id=None), 'distinguished': Value(dtype='null', id=None), 'retrieved_on': Value(dtype='int64', id=None), 'author_flair_text': Value(dtype='null', id=None), 'author': Value(dtype='string', id=None), 'edited': Value(dtype='bool', id=None), 'id': Value(dtype='string', id=None), 'parent_id': Value(dtype='string', id=None), 'subreddit': Value(dtype='string', id=None), 'score': Value(dtype='int64', id=None), 'ups': Value(dtype='int64', id=None), 'created_utc': Value(dtype='int64', id=None), 'author_flair_css_class': Value(dtype='null', id=None), 'body': Value(dtype='string', id=None), 'controversiality': Value(dtype='int64', id=None), 'subreddit_id': Value(dtype='string', id=None), 'stickied': Value(dtype='bool', id=None), 'link_id': Value(dtype='string', id=None)}
because column names don't match

It seems I used the wrong version of the dataset.
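The cast error above fails "because column names don't match": the columns in the downloaded Arrow data differ from the features the loader expects. One quick way to see which columns differ is to compare the two name sets. This is an illustrative sketch only. `diff_columns` is a hypothetical helper, the expected names are copied from the error message, and the "actual" list is a made-up example of a mismatched file, not the real schema.

```python
def diff_columns(expected, actual):
    """Return (missing, unexpected) column names, each sorted.

    missing: names the loader expects but the data lacks.
    unexpected: names present in the data but not expected.
    """
    expected_set, actual_set = set(expected), set(actual)
    missing = sorted(expected_set - actual_set)
    unexpected = sorted(actual_set - expected_set)
    return missing, unexpected


# Feature names the loader expects (copied from the error message above).
EXPECTED = [
    "gilded", "distinguished", "retrieved_on", "author_flair_text", "author",
    "edited", "id", "parent_id", "subreddit", "score", "ups", "created_utc",
    "author_flair_css_class", "body", "controversiality", "subreddit_id",
    "stickied", "link_id",
]

# Hypothetical column list from a mismatched data file, for illustration only.
ACTUAL = [c for c in EXPECTED if c != "gilded"] + ["preview"]

missing, unexpected = diff_columns(EXPECTED, ACTUAL)
print("missing:", missing)        # columns the loader wanted but did not find
print("unexpected:", unexpected)  # extra columns present in the data file
```

An empty result on both sides means the names agree, and any remaining cast failure would then come from column types rather than names.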

huseyinatahaninan commented 11 months ago

This is where we load the dataset:

dataset = datasets.load_dataset('reddit', split="train[:500000]").train_test_split(0.02, seed=args.train.seed)

We were using datasets==2.0.0. Does this not work for you?
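For readers unfamiliar with the call: split="train[:500000]" keeps the first 500,000 rows, and train_test_split(0.02, seed=...) holds out 2% of them as a test set, with the seed making the split reproducible. The pure-Python sketch below mimics that idea on a list of row indices. It is an analogy, not the `datasets` implementation: `seeded_split` is a hypothetical name, rounding the test size up is an assumption of this sketch, and the shuffle order will not match what `datasets` produces.

```python
import math
import random


def seeded_split(n_rows, test_size, seed):
    """Shuffle row indices with a fixed seed and hold out a test fraction.

    Returns (train_indices, test_indices). The test set size is
    ceil(test_size * n_rows), an assumption made for this sketch.
    """
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)  # same seed -> same permutation
    n_test = math.ceil(test_size * n_rows)
    return indices[n_test:], indices[:n_test]


# Keep the first 500,000 rows (like split="train[:500000]") and hold out 2%.
train_idx, test_idx = seeded_split(500_000, 0.02, seed=42)
print(len(train_idx), len(test_idx))
```

Using a dedicated random.Random(seed) instance rather than the global random module keeps the split reproducible even if other code draws random numbers in between.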

xiehuanyi commented 11 months ago

It turns out that my network connection was unstable, which caused the failure. I cloned the dataset with git-lfs instead, and it works fine for me now. Thanks a lot!

huseyinatahaninan commented 11 months ago

Glad that it helps! Sorry I did not get a chance to look at the other error you got, but it does not really look like an error related to DP.

ooolivia2333 commented 2 weeks ago

Hi, I ran the sample-level DP example on my local machine (in a conda environment) using the following command:

python -m torch.distributed.run --nproc_per_node 1 fine-tune-dp.py \
--output_dir scratch \
--sequence_len 128 \
--per_device_train_batch_size 64 \
--gradient_accumulation_steps 1 \
--evaluation_strategy steps \
--eval_steps 45 \
--log_level info \
--per_device_eval_batch_size 64 \
--eval_accumulation_steps 1 \
--seed 42 \
--target_epsilon 8 \
--per_sample_max_grad_norm 1.0 \
--prediction_loss_only \
--weight_decay 0.01 \
--remove_unused_columns False \
--num_train_epochs 3 \
--logging_steps 5 \
--lora_dim 4 \
--lora_alpha 32 \
--lora_dropout 0.0 \
--max_grad_norm 0 \
--lr_scheduler_type constant \
--learning_rate 3e-4 \
--disable_tqdm True \
--dataloader_num_workers 2 \
--label_names labels \
--enable_lora

but got the following error when attempting to train:

Traceback (most recent call last):
  File "/home/wentao/shiqi/dp-transformers/examples/nlg-reddit/sample-level-dp/fine-tune-dp.py", line 146, in <module>
    main(Arguments(train=train_args, privacy=privacy_args, model=model_args, lora=lora_args))
  File "/home/wentao/shiqi/dp-transformers/examples/nlg-reddit/sample-level-dp/fine-tune-dp.py", line 134, in main
    trainer.train()
  File "/home/wentao/anaconda3/envs/dp-transformers/lib/python3.10/site-packages/transformers/trainer.py", line 1537, in train
    return inner_training_loop(
  File "/home/wentao/anaconda3/envs/dp-transformers/lib/python3.10/site-packages/transformers/trainer.py", line 1851, in _inner_training_loop
    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
  File "/home/wentao/anaconda3/envs/dp-transformers/lib/python3.10/site-packages/transformers/trainer_callback.py", line 386, in on_step_begin
    return self.call_event("on_step_begin", args, state, control)
  File "/home/wentao/anaconda3/envs/dp-transformers/lib/python3.10/site-packages/transformers/trainer_callback.py", line 414, in call_event
    result = getattr(callback, event)(
  File "/home/wentao/anaconda3/envs/dp-transformers/lib/python3.10/site-packages/dp_transformers/dp_utils.py", line 61, in on_step_begin
    optimizer.signal_skip_step(do_skip=False)
AttributeError: 'AcceleratedOptimizer' object has no attribute 'signal_skip_step'
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 6753) of binary: /home/wentao/anaconda3/envs/dp-transformers/bin/python
Traceback (most recent call last):
  File "/home/wentao/anaconda3/envs/dp-transformers/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/wentao/anaconda3/envs/dp-transformers/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/wentao/anaconda3/envs/dp-transformers/lib/python3.10/site-packages/torch/distributed/run.py", line 765, in <module>
    main()
  File "/home/wentao/anaconda3/envs/dp-transformers/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
    return f(*args, **kwargs)
  File "/home/wentao/anaconda3/envs/dp-transformers/lib/python3.10/site-packages/torch/distributed/run.py", line 761, in main
    run(args)
  File "/home/wentao/anaconda3/envs/dp-transformers/lib/python3.10/site-packages/torch/distributed/run.py", line 752, in run
    elastic_launch(
  File "/home/wentao/anaconda3/envs/dp-transformers/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/wentao/anaconda3/envs/dp-transformers/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

I changed the dataset to ag_news because reddit is too large. Can you suggest what the issue is? Also, after I activate the environment and install the dp_transformers library, pip reinstalls torch 1.12.1, which is incompatible with peft:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
triton 2.0.0 requires cmake, which is not installed.
triton 2.0.0 requires lit, which is not installed.
peft 0.4.0 requires torch>=1.13.0, but you have torch 1.12.1 which is incompatible.
Successfully installed dp_transformers-1.0.0 functorch-0.2.1 opacus-1.3.0 torch-1.12.1

And here are the package versions in my environment:

Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
accelerate 0.21.0 pypi_0 pypi
aiohttp 3.8.5 pypi_0 pypi
aiosignal 1.3.1 pypi_0 pypi
alembic 1.11.2 pypi_0 pypi
async-timeout 4.0.3 pypi_0 pypi
attrs 23.1.0 pypi_0 pypi
azure-common 1.1.28 pypi_0 pypi
azure-core 1.29.2 pypi_0 pypi
azure-identity 1.14.0 pypi_0 pypi
azure-mgmt-core 1.4.0 pypi_0 pypi
azure-storage-blob 12.13.0 pypi_0 pypi
azureml-mlflow 1.52.0 pypi_0 pypi
blas 1.0 mkl
blinker 1.6.2 pypi_0 pypi
bzip2 1.0.8 h5eee18b_6
ca-certificates 2023.05.30 h06a4308_0
certifi 2023.7.22 pypi_0 pypi
cffi 1.15.1 pypi_0 pypi
charset-normalizer 3.2.0 pypi_0 pypi
click 8.1.6 pypi_0 pypi
cloudpickle 2.2.1 pypi_0 pypi
contourpy 1.1.0 pypi_0 pypi
cryptography 41.0.3 pypi_0 pypi
cuda-cudart 11.8.89 0 nvidia
cuda-cupti 11.8.87 0 nvidia
cuda-libraries 11.8.0 0 nvidia
cuda-nvrtc 11.8.89 0 nvidia
cuda-nvtx 11.8.86 0 nvidia
cuda-runtime 11.8.0 0 nvidia
cycler 0.11.0 pypi_0 pypi
databricks-cli 0.17.7 pypi_0 pypi
datasets 2.14.4 pypi_0 pypi
dill 0.3.7 pypi_0 pypi
docker 6.1.3 pypi_0 pypi
dp-transformers 1.0.0 pypi_0 pypi
entrypoints 0.4 pypi_0 pypi
exceptiongroup 1.1.3 pypi_0 pypi
filelock 3.9.0 py310h06a4308_0
flask 2.3.2 pypi_0 pypi
fonttools 4.42.0 pypi_0 pypi
frozenlist 1.4.0 pypi_0 pypi
fsspec 2023.6.0 pypi_0 pypi
functorch 0.2.1 pypi_0 pypi
gitdb 4.0.10 pypi_0 pypi
gitpython 3.1.32 pypi_0 pypi
gmp 6.2.1 h295c915_3
gmpy2 2.1.2 py310heeb90bb_0
greenlet 2.0.2 pypi_0 pypi
gunicorn 21.2.0 pypi_0 pypi
huggingface-hub 0.19.4 pypi_0 pypi
idna 3.4 pypi_0 pypi
importlib-metadata 6.8.0 pypi_0 pypi
iniconfig 2.0.0 pypi_0 pypi
intel-openmp 2023.1.0 hdb19cb5_46306
isodate 0.6.1 pypi_0 pypi
itsdangerous 2.1.2 pypi_0 pypi
jinja2 3.1.2 py310h06a4308_0
joblib 1.3.2 pypi_0 pypi
jsonpickle 3.0.2 pypi_0 pypi
kiwisolver 1.4.4 pypi_0 pypi
ld_impl_linux-64 2.38 h1181459_1
libcublas 11.11.3.6 0 nvidia
libcufft 10.9.0.58 0 nvidia
libcufile 1.7.1.12 0 nvidia
libcurand 10.3.3.129 0 nvidia
libcusolver 11.4.1.48 0 nvidia
libcusparse 11.7.5.86 0 nvidia
libffi 3.4.4 h6a678d5_1
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libnpp 11.8.0.86 0 nvidia
libnvjpeg 11.9.0.86 0 nvidia
libstdcxx-ng 11.2.0 h1234567_1
libuuid 1.41.5 h5eee18b_0
mako 1.2.4 pypi_0 pypi
markdown 3.4.4 pypi_0 pypi
markupsafe 2.1.1 py310h7f8727e_0
matplotlib 3.7.2 pypi_0 pypi
mkl 2023.1.0 h213fc3f_46344
mlflow 2.6.0 pypi_0 pypi
mlflow-skinny 2.6.0 pypi_0 pypi
mpc 1.1.0 h10f8cd9_1
mpfr 4.0.2 hb69a4c5_1
mpmath 1.3.0 py310h06a4308_0
msal 1.23.0 pypi_0 pypi
msal-extensions 1.0.0 pypi_0 pypi
msrest 0.7.1 pypi_0 pypi
multidict 6.0.4 pypi_0 pypi
multiprocess 0.70.15 pypi_0 pypi
ncurses 6.4 h6a678d5_0
networkx 3.1 py310h06a4308_0
numpy 1.25.2 pypi_0 pypi
oauthlib 3.2.2 pypi_0 pypi
opacus 1.3.0 pypi_0 pypi
openssl 3.0.10 h7f8727e_2
opt-einsum 3.3.0 pypi_0 pypi
packaging 23.1 pypi_0 pypi
pandas 2.0.3 pypi_0 pypi
peft 0.4.0 pypi_0 pypi
pillow 10.0.0 pypi_0 pypi
pip 23.2.1 py310h06a4308_0
pluggy 1.2.0 pypi_0 pypi
portalocker 2.7.0 pypi_0 pypi
protobuf 4.24.0 pypi_0 pypi
prv-accountant 0.1.1.post1 pypi_0 pypi
psutil 5.9.5 pypi_0 pypi
pyarrow 12.0.1 pypi_0 pypi
pycparser 2.21 pypi_0 pypi
pyjwt 2.8.0 pypi_0 pypi
pyparsing 3.0.9 pypi_0 pypi
pytest 7.4.0 pypi_0 pypi
python 3.10.12 h955ad1f_0
python-dateutil 2.8.2 pypi_0 pypi
pytorch-cuda 11.8 h7e8668a_5 pytorch
pytorch-mutex 1.0 cuda pytorch
pytz 2023.3 pypi_0 pypi
pyyaml 6.0.1 pypi_0 pypi
querystring-parser 1.2.4 pypi_0 pypi
readline 8.2 h5eee18b_0
regex 2023.8.8 pypi_0 pypi
requests 2.31.0 pypi_0 pypi
requests-oauthlib 1.3.1 pypi_0 pypi
safetensors 0.3.2 pypi_0 pypi
scikit-learn 1.3.0 pypi_0 pypi
scipy 1.11.1 pypi_0 pypi
setuptools 68.0.0 py310h06a4308_0
six 1.16.0 pypi_0 pypi
smmap 5.0.0 pypi_0 pypi
sqlalchemy 2.0.19 pypi_0 pypi
sqlite 3.41.2 h5eee18b_0
sqlparse 0.4.4 pypi_0 pypi
sympy 1.11.1 py310h06a4308_0
tabulate 0.9.0 pypi_0 pypi
tbb 2021.8.0 hdb19cb5_0
threadpoolctl 3.2.0 pypi_0 pypi
tk 8.6.12 h1ccaba5_0
tokenizers 0.15.0 pypi_0 pypi
tomli 2.0.1 pypi_0 pypi
torch 1.12.1 pypi_0 pypi
torchtriton 2.0.0 py310 pytorch
tqdm 4.66.1 pypi_0 pypi
transformers 4.36.1 pypi_0 pypi
typing_extensions 4.7.1 py310h06a4308_0
tzdata 2023.3 pypi_0 pypi
urllib3 1.26.16 pypi_0 pypi
websocket-client 1.6.1 pypi_0 pypi
werkzeug 2.3.7 pypi_0 pypi
wheel 0.38.4 py310h06a4308_0
xxhash 3.3.0 pypi_0 pypi
xz 5.4.2 h5eee18b_0
yarl 1.9.2 pypi_0 pypi
zipp 3.16.2 pypi_0 pypi
zlib 1.2.13 h5eee18b_1
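When reporting version problems like this, it can help to print only the packages that matter rather than the full conda listing. A minimal sketch using the standard library's importlib.metadata. The package list is simply the distributions discussed in this thread, and a missing package is reported as None instead of raising an exception.

```python
from importlib.metadata import PackageNotFoundError, version

# Distributions relevant to this thread (names as published on PyPI).
PACKAGES = ["torch", "transformers", "accelerate", "peft", "opacus",
            "datasets", "dp-transformers"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None  # not installed in this environment
    return result


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:16} {ver}")
```

Because this reads the metadata of the active interpreter, running it with the same python that launches fine-tune-dp.py shows exactly what the training script will import.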

Thank you so much!

huseyinatahaninan commented 1 week ago

Hi @ooolivia2333, apologies for the late response :/ Which version of our repo are you currently using? I see that the error comes from on_step_begin, but I removed it a while ago (https://github.com/microsoft/dp-transformers/commit/0358f699be12e8503251e131c6fbf25590cf35eb), and dp_utils.py currently has no on_step_begin. If you use the latest version of our repo, you should not encounter this issue. Please let us know if you run into any further issues.
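The AttributeError in the traceback is a generic wrapper problem: the callback called signal_skip_step on whatever optimizer the Trainer handed it, and, per the error message, that object was accelerate's AcceleratedOptimizer wrapper rather than the DP optimizer it presumably wraps. The classes below are hypothetical stand-ins (not Opacus or accelerate code) that reproduce the failure mode and show one defensive pattern: unwrapping through an `.optimizer` attribute before calling, where the attribute name is an assumption of this sketch. The maintainers' actual fix was to remove on_step_begin, so upgrading remains the recommended route.

```python
class FakeDPOptimizer:
    """Hypothetical stand-in for a DP optimizer that supports step skipping."""

    def __init__(self):
        self.skip_next = False

    def signal_skip_step(self, do_skip=True):
        self.skip_next = do_skip


class FakeWrapper:
    """Hypothetical wrapper that stores the inner optimizer but does not
    forward unknown attributes to it."""

    def __init__(self, optimizer):
        self.optimizer = optimizer


def unwrap_optimizer(opt):
    """Follow `.optimizer` links until reaching an object without one."""
    while hasattr(opt, "optimizer"):
        opt = opt.optimizer
    return opt


wrapped = FakeWrapper(FakeDPOptimizer())

try:
    wrapped.signal_skip_step(do_skip=False)  # fails like the traceback above
except AttributeError as err:
    print("AttributeError:", err)

inner = unwrap_optimizer(wrapped)
inner.signal_skip_step(do_skip=False)  # works on the unwrapped optimizer
print(inner.skip_next)
```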

ooolivia2333 commented 6 days ago

Thanks for your reply! I am trying to reinstall dp_transformers together with peft, but I encountered the following error:

The conflict is caused by:
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.11.1 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.11.0 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.10.0 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.9.0 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.8.2 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.8.1 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.8.0 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.7.1 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.7.0 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.6.2 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.6.1 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.6.0 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.5.0 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.4.0 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.3.0 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.2.0 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.1.0 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.0.2 depends on torch>=1.13.0
    dp-transformers 1.0.0 depends on torch<=1.12.1 and >=1.9.1
    peft 0.0.1 depends on torch>=1.13.0

Can you suggest which versions I should use?

huseyinatahaninan commented 5 days ago

I think you can use the latest version of our repo (1.0.1) by cloning the repository and running pip install . inside it. That version should not lead to such conflicts, because it requires torch>=1.13.1:

https://github.com/microsoft/dp-transformers/blob/main/setup.py
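The resolver output above boils down to two torch version ranges with no overlap: dp-transformers 1.0.0 allows torch >=1.9.1 and <=1.12.1, while every listed peft release requires torch>=1.13.0. With the torch>=1.13.1 requirement quoted for 1.0.1, both constraints can be met. The stdlib-only sketch below checks this by intersecting simple version ranges. It handles only plain X.Y.Z bounds, not the full PEP 440 grammar, and it models the 1.0.1 requirement as only the lower bound quoted above; the helper names are hypothetical.

```python
def parse(v):
    """Turn a plain 'X.Y.Z' version string into a comparable tuple of ints."""
    return tuple(int(part) for part in v.split("."))


def intersect(*ranges):
    """Intersect closed ranges given as (lower, upper) strings; None = unbounded.

    Returns the combined (lower, upper) tuple, or None if no version fits.
    """
    lowers = [parse(lo) for lo, _ in ranges if lo is not None]
    uppers = [parse(hi) for _, hi in ranges if hi is not None]
    lower = max(lowers) if lowers else None
    upper = min(uppers) if uppers else None
    if lower is not None and upper is not None and lower > upper:
        return None  # empty intersection: the constraints conflict
    return lower, upper


dp_1_0_0 = ("1.9.1", "1.12.1")  # torch>=1.9.1, <=1.12.1
dp_1_0_1 = ("1.13.1", None)     # torch>=1.13.1 (lower bound only, as quoted)
peft = ("1.13.0", None)         # torch>=1.13.0

print(intersect(dp_1_0_0, peft))  # None: no torch version satisfies both
print(intersect(dp_1_0_1, peft))  # ((1, 13, 1), None): any torch >= 1.13.1
```

Versions are compared as integer tuples rather than strings, because string comparison would wrongly rank "1.9.1" above "1.13.0".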