bigdata-ustc / EduKTM

The Model Zoo of Knowledge Tracing Models
Apache License 2.0

DKT training fails for batch size of 1 #19

Closed: lukas-a-olson closed this issue 2 years ago

lukas-a-olson commented 2 years ago

🐛 Description

DKT training fails with a batch size of 1 but works with larger batch sizes (e.g., 64).

Error Message

RuntimeError                              Traceback (most recent call last)
<ipython-input-15-0a528344bb33> in <module>
      5 # Initialize and train model
      6 dkt = DKT(NUM_QUESTIONS, HIDDEN_SIZE, NUM_LAYERS)
----> 7 dkt.train(train_loader, epoch=50)
      8 
      9 # Save weights

/usr/local/lib/python3.6/dist-packages/EduKTM/DKT/DKT.py in train(self, train_data, test_data, epoch, lr)
     61                 # back propagation
     62                 optimizer.zero_grad()
---> 63                 loss.backward()
     64                 optimizer.step()
     65 

/usr/local/lib/python3.6/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    305                 create_graph=create_graph,
    306                 inputs=inputs)
--> 307         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    308 
    309     def register_hook(self, hook):

/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    154     Variable._execution_engine.run_backward(
    155         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 156         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
    157 
    158 

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
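For context, PyTorch raises this error whenever `backward()` is called on a tensor that is not connected to the autograd graph, meaning it has `requires_grad=False` and `grad_fn=None`. The minimal sketch below reproduces the same message with plain PyTorch. The tensors are illustrative and are not taken from EduKTM's DKT code. The idea that the batch-size-1 path ends up building the loss from such disconnected tensors is an assumption, not something established in this issue.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# A tiny model whose output is tracked by autograd.
model = nn.Linear(4, 1)
x = torch.randn(2, 4)
target = torch.tensor([[0.0], [1.0]])

# Connected loss: computed from the model output, so it has a grad_fn.
connected = F.binary_cross_entropy(torch.sigmoid(model(x)), target)

# Disconnected loss: computed only from constant tensors. This stands in
# (by assumption) for a loss that never touched the network's predictions.
constant_pred = torch.tensor([[0.3], [0.7]])
disconnected = F.binary_cross_entropy(constant_pred, target)

print(connected.grad_fn is not None, disconnected.grad_fn is not None)  # True False

connected.backward()  # works: gradients flow into model.weight and model.bias

error_message = None
try:
    disconnected.backward()  # same failure mode as the traceback above
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```

Checking `loss.grad_fn` (or `loss.requires_grad`) right before `loss.backward()` in the training loop is a quick way to confirm whether a given batch produces a disconnected loss.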

To Reproduce

Run the example notebook (https://github.com/bigdata-ustc/EduKTM/blob/main/examples/DKT/DKT.ipynb) and set the BATCH_SIZE variable to 1.

What have you tried to solve it?

Increasing the batch size avoids the error.
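Besides enlarging the batch, a defensive training step can check whether the loss is attached to the autograd graph before calling `backward()`. The sketch below is hypothetical: `train_step` is not part of EduKTM, and the model and data are placeholders. Silently skipping a batch only hides the symptom, so a guard like this is better for finding which batches produce a disconnected loss than as a permanent fix.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def train_step(model: nn.Module, optimizer: torch.optim.Optimizer, loss: torch.Tensor) -> bool:
    """Hypothetical helper: run one optimizer step, or skip it when `loss`
    is not connected to the autograd graph. Returns True if a step ran."""
    if not loss.requires_grad:
        # Nothing to backpropagate into; report and skip instead of crashing.
        print("skipping batch: loss has no grad_fn")
        return False
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return True


torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(2, 4)
target = torch.tensor([[0.0], [1.0]])

# Loss built from the model output: has a grad_fn, so a step is taken.
good_loss = F.binary_cross_entropy(torch.sigmoid(model(x)), target)
# Loss built from a constant tensor: no grad_fn, so the step is skipped.
bad_loss = F.binary_cross_entropy(torch.full((2, 1), 0.5), target)

stepped_good = train_step(model, optimizer, good_loss)
stepped_bad = train_step(model, optimizer, bad_loss)
print(stepped_good, stepped_bad)  # True False
```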

Environment

Environment Information

**Operating System:** Ubuntu in Docker image tensorflow/tensorflow:2.6.0-gpu-jupyter. Also tested in a Google Colab notebook.

**Python Version:** Python 3.6.9

pip freeze:

```
absl-py==0.13.0
aiocontextvars==0.2.2
altair==4.1.0
ansi2html==1.6.0
anyio==3.3.4
argon2-cffi==20.1.0
asgiref==3.4.1
asn1crypto==0.24.0
astor==0.8.1
astunparse==1.6.3
async-generator==1.10
attrs==21.2.0
Babel==2.9.1
backcall==0.2.0
backports.zoneinfo==0.2.1
base58==2.1.0
beautifulsoup4==4.10.0
bleach==4.0.0
blinker==1.4
Brotli==1.0.9
bs4==0.0.1
cached-property==1.5.2
cachetools==4.2.2
certifi==2021.5.30
cffi==1.14.6
charset-normalizer==2.0.4
clang==5.0
click==7.1.2
contextlib2==21.6.0
contextvars==2.4
cryptography==2.1.4
cycler==0.10.0
dash==2.0.0
dash-core-components==2.0.0
dash-html-components==2.0.0
dash-table==5.0.0
dataclasses==0.8
decorator==4.4.2
defusedxml==0.7.1
EduData==0.0.18
EduKTM==0.0.9
entrypoints==0.3
et-xmlfile==1.1.0
fastapi==0.70.0
fire==0.4.0
Flask==2.0.2
Flask-Compress==1.10.1
flatbuffers==1.12
gast==0.4.0
gitdb==4.0.8
GitPython==3.1.18
google-auth==1.34.0
google-auth-oauthlib==0.4.5
google-pasta==0.2.0
grpcio==1.39.0
h11==0.12.0
h5py==3.1.0
idna==3.3
immutables==0.16
importlib-metadata==4.6.3
importlib-resources==5.3.0
ipykernel==5.5.6
ipython==7.16.1
ipython-genutils==0.2.0
ipywidgets==7.6.3
itsdangerous==2.0.1
jedi==0.18.0
Jinja2==3.0.1
joblib==1.1.0
json5==0.9.6
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==6.1.12
jupyter-console==6.4.0
jupyter-core==4.7.1
jupyter-dash==0.4.0
jupyter-http-over-ws==0.0.8
jupyter-server==1.11.1
jupyterlab==3.0.16
jupyterlab-pygments==0.1.2
jupyterlab-server==2.8.2
jupyterlab-widgets==1.0.0
keras==2.6.0
Keras-Preprocessing==1.1.2
keyring==10.6.0
keyrings.alt==3.0
kiwisolver==1.3.1
loguru==0.5.3
longling==1.3.32
lxml==4.6.3
Markdown==3.3.4
MarkupSafe==2.0.1
matplotlib==3.3.4
mistune==0.8.4
nbclassic==0.3.3
nbclient==0.5.3
nbconvert==6.0.7
nbformat==5.1.3
nest-asyncio==1.5.1
networkx==2.5.1
notebook==6.4.3
numpy==1.19.5
oauthlib==3.1.1
openpyxl==3.0.9
opt-einsum==3.3.0
opyrator==0.0.12
packaging==21.0
pandas==1.1.5
pandocfilters==1.4.3
parso==0.8.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.3.1
plotly==5.3.1
prometheus-client==0.11.0
prompt-toolkit==3.0.19
protobuf==3.17.3
ptyprocess==0.7.0
pyarrow==5.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
pycrypto==2.6.1
pydantic==1.8.2
pydeck==0.6.2
Pygments==2.9.0
PyGObject==3.26.1
pygraphviz==1.6
pyparsing==2.4.7
pyrsistent==0.18.0
python-apt==1.6.5+ubuntu0.7
python-dateutil==2.8.2
pytz==2021.3
pytz-deprecation-shim==0.1.0.post0
pyxdg==0.25
PyYAML==6.0
pyzmq==22.2.1
qtconsole==5.1.1
QtPy==1.9.0
rarfile==4.0
requests==2.26.0
requests-oauthlib==1.3.0
requests-unixsocket==0.2.0
retrying==1.3.3
rsa==4.7.2
scikit-learn==0.24.2
scipy==1.5.4
seaborn==0.11.2
SecretStorage==2.3.1
Send2Trash==1.8.0
six==1.15.0
sklearn==0.0
smmap==5.0.0
sniffio==1.2.0
soupsieve==2.2.1
starlette==0.16.0
streamlit==1.1.0
tenacity==8.0.1
tensorboard==2.6.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.6.0
tensorflow-estimator==2.6.0
termcolor==1.1.0
terminado==0.10.1
testpath==0.5.0
threadpoolctl==3.0.0
toml==0.10.2
toolz==0.11.1
torch==1.10.0
tornado==6.1
tqdm==4.62.3
traitlets==4.3.3
typer==0.4.0
typing-extensions==3.7.4.3
tzdata==2021.4
tzlocal==4.0.1
urllib3==1.26.6
uvicorn==0.15.0
validators==0.18.2
watchdog==2.1.6
wcwidth==0.2.5
webencodings==0.5.1
websocket-client==1.2.1
Werkzeug==2.0.1
widgetsnbextension==3.5.1
wrapt==1.12.1
zipp==3.5.0
```
tswsxk commented 2 years ago

@sone47