QData / TextAttack

TextAttack 🐙 is a Python framework for adversarial attacks, data augmentation, and model training in NLP https://textattack.readthedocs.io/en/master/
MIT License

HuggingFaceDataset doesn't shuffle the dataset #791

Open ToldoDM opened 6 months ago


**Describe the bug**
`HuggingFaceDataset` doesn't shuffle the dataset, neither when `shuffle=True` is passed to the constructor nor when the `shuffle()` method is called.

**To Reproduce**
Steps to reproduce the behavior:

1. Run this code:

   ```python
   from textattack.datasets import HuggingFaceDataset

   # Baseline: IMDB test split, no shuffling requested
   normal_ds = HuggingFaceDataset("imdb", split="test")
   print("Normal")
   for i in range(10):
       print(f"shuffle: {normal_ds.shuffled}, label: {normal_ds[i][1]}, Text: {normal_ds[i][0]}")

   # Request shuffling in the constructor
   shuffle_ds = HuggingFaceDataset("imdb", split="test", shuffle=True)
   print("shuffled")
   for i in range(10):
       print(f"shuffle: {shuffle_ds.shuffled}, label: {shuffle_ds[i][1]}, Text: {shuffle_ds[i][0]}")

   # Shuffle the first dataset after construction
   normal_ds.shuffle()
   print("Normal shuffled")
   for i in range(10):
       print(f"shuffle: {normal_ds.shuffled}, label: {normal_ds[i][1]}, Text: {normal_ds[i][0]}")
   ```

2. Output:
![Screenshot from 2024-05-14 15-23-24](https://github.com/QData/TextAttack/assets/38348465/9b83da82-4117-4551-b913-2f42fe2aa0d1)

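**Expected behavior**
The dataset should be shuffled: with `shuffle=True`, or after calling `shuffle()`, indexing the dataset should return examples in a random order instead of the original order of the split.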

**Screenshots or Traceback**
No traceback; the code runs without raising an error.

**System Information:**
 - OS: Linux
 - Library versions (`pip freeze`):
 absl-py==2.1.0
accelerate==0.30.0
aiohttp==3.9.5
aiosignal==1.3.1
anyio==4.3.0
anytree==2.12.1
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
astunparse==1.6.3
async-lru==2.0.4
async-timeout==4.0.3
attrs==23.2.0
Babel==2.15.0
beautifulsoup4==4.12.3
bert-score==0.3.13
bleach==6.1.0
boto3==1.34.101
botocore==1.34.101
bpemb==0.3.5
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
comm==0.2.2
conllu==4.5.3
contourpy==1.2.1
cycler==0.12.1
datasets==2.19.1
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
Deprecated==1.2.14
dill==0.3.8
docopt==0.6.2
editdistance==0.8.1
exceptiongroup==1.2.1
executing==2.0.1
fastjsonschema==2.19.1
filelock==3.14.0
flair==0.13.1
flatbuffers==24.3.25
fonttools==4.51.0
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.3.1
ftfy==6.2.0
gast==0.5.4
gdown==5.1.0
gensim==4.3.2
google-pasta==0.2.0
grpcio==1.63.0
h11==0.14.0
h5py==3.11.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.23.0
idna==3.7
ipykernel==6.29.4
ipython==8.24.0
ipywidgets==8.1.2
isoduration==20.11.0
Janome==0.5.0
jedi==0.19.1
jieba==0.42.1
Jinja2==3.1.4
jmespath==1.0.1
joblib==1.4.2
json5==0.9.25
jsonpointer==2.4
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.1
jupyter_core==5.7.2
jupyter_server==2.14.0
jupyter_server_terminals==0.5.3
jupyterlab==4.1.8
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.1
jupyterlab_widgets==3.0.10
keras==3.3.3
kiwisolver==1.4.5
langdetect==1.0.9
language-tool-python==2.8
lemminflect==0.2.3
libclang==18.1.1
lru-dict==1.3.0
lxml==5.2.1
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.8.4
matplotlib-inline==0.1.7
mdurl==0.1.2
mistune==3.0.2
ml-dtypes==0.3.2
more-itertools==10.2.0
mpld3==0.5.10
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
namex==0.0.8
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.3
nltk==3.8.1
notebook==7.1.3
notebook_shim==0.2.4
num2words==0.5.13
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
OpenHowNet==2.0
opt-einsum==3.3.0
optree==0.11.0
overrides==7.7.0
packaging==24.0
pandas==2.2.2
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==10.3.0
pinyin==0.4.0
platformdirs==4.2.1
pptree==3.1
prometheus_client==0.20.0
prompt-toolkit==3.0.43
protobuf==4.25.3
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==16.0.0
pyarrow-hotfix==0.6
pycparser==2.22
Pygments==2.18.0
pyparsing==3.1.2
PySocks==1.7.1
python-dateutil==2.9.0.post0
python-json-logger==2.0.7
pytorch_revgrad==0.2.0
pytz==2024.1
PyYAML==6.0.1
pyzmq==26.0.3
qtconsole==5.5.2
QtPy==2.4.1
referencing==0.35.1
regex==2024.4.28
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rpds-py==0.18.1
s3transfer==0.10.1
safetensors==0.4.3
scikit-learn==1.4.2
scipy==1.11.4
segtok==1.5.11
semver==3.0.2
Send2Trash==1.8.3
sentence-transformers==2.7.0
sentencepiece==0.2.0
six==1.16.0
smart-open==7.0.4
sniffio==1.3.1
soupsieve==2.5
sqlitedict==2.1.0
stack-data==0.6.3
sympy==1.12
tabulate==0.9.0
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow==2.16.1
tensorflow-hub==0.16.1
tensorflow-io-gcs-filesystem==0.37.0
termcolor==2.4.0
terminado==0.18.1
terminaltables==3.1.10
textattack==0.3.10
tf_keras==2.16.0
threadpoolctl==3.5.0
tinycss2==1.3.0
tokenizers==0.19.1
tomli==2.0.1
torch==2.3.0
tornado==6.4
tqdm==4.66.4
traitlets==5.14.3
transformer-smaller-training-vocab==0.4.0
transformers==4.40.2
triton==2.3.0
types-python-dateutil==2.9.0.20240316
typing_extensions==4.11.0
tzdata==2024.1
uri-template==1.3.0
urllib3==1.26.18
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.8.0
Werkzeug==3.0.3
widgetsnbextension==4.0.10
Wikipedia-API==0.6.0
word2number==1.1
wrapt==1.16.0
xxhash==3.4.1
yarl==1.9.4

 - TextAttack version: 0.3.10

**Additional context**
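In the constructor (line 152 of `textattack/datasets/huggingface_dataset.py`), `self._dataset.shuffle()` returns a new, shuffled dataset object, but the return value is never assigned back to `self._dataset`, so the wrapped dataset keeps its original order.
Note that the `self._dataset.shuffle()` call resolves to `Dataset.shuffle()` in the `datasets` library's `arrow_dataset.py`, which does not shuffle in place but returns a new `Dataset`.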
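The failure mode can be reproduced without TextAttack or `datasets` installed. The sketch below is a minimal illustration under stated assumptions: `ImmutableDataset` and `Wrapper` are hypothetical stand-ins (not real library classes) that mimic a `Dataset.shuffle()` returning a new object and a wrapper that holds it in `self._dataset`. `shuffle_buggy` discards the return value, as described above; `shuffle_fixed` rebinds it, which suggests the fix is likely along the lines of `self._dataset = self._dataset.shuffle()`.

```python
import random
from typing import Iterable, List, Optional


class ImmutableDataset:
    """Hypothetical stand-in for datasets.Dataset: shuffle() returns a NEW object."""

    def __init__(self, rows: Iterable[int]):
        self._rows = list(rows)

    def shuffle(self, seed: Optional[int] = None) -> "ImmutableDataset":
        rows = list(self._rows)
        random.Random(seed).shuffle(rows)
        return ImmutableDataset(rows)  # the original object is left untouched

    def __getitem__(self, i: int) -> int:
        return self._rows[i]

    def __len__(self) -> int:
        return len(self._rows)


class Wrapper:
    """Hypothetical stand-in for a wrapper such as HuggingFaceDataset."""

    def __init__(self, dataset: ImmutableDataset):
        self._dataset = dataset
        self.shuffled = False

    def shuffle_buggy(self, seed: Optional[int] = None) -> None:
        # The shuffled copy is discarded, so self._dataset keeps its original order.
        self._dataset.shuffle(seed)
        self.shuffled = True

    def shuffle_fixed(self, seed: Optional[int] = None) -> None:
        # Keep the shuffled copy by rebinding self._dataset.
        self._dataset = self._dataset.shuffle(seed)
        self.shuffled = True

    def first(self, n: int) -> List[int]:
        return [self._dataset[i] for i in range(n)]


if __name__ == "__main__":
    buggy = Wrapper(ImmutableDataset(range(100)))
    buggy.shuffle_buggy(seed=0)
    print(buggy.shuffled, buggy.first(5))  # → True [0, 1, 2, 3, 4]

    fixed = Wrapper(ImmutableDataset(range(100)))
    fixed.shuffle_fixed(seed=0)
    print(fixed.shuffled, fixed.first(5))  # flag is True and the order has changed
```

As in the reported output, the `shuffled` flag is set in both cases, so it cannot be used to detect the bug; only comparing the actual order of the examples reveals it.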