intel-analytics / analytics-zoo

Distributed Tensorflow, Keras and PyTorch on Apache Spark/Flink & Ray
https://analytics-zoo.readthedocs.io/
Apache License 2.0

Orca Pytorch Estimator can't fit RNN Models #521

Open jing-xu opened 3 years ago

jing-xu commented 3 years ago

Currently, both the torch_distributed backend and the bigdl backend support only applications of the form output = model(data). RNN models need output, hidden = model(data, hidden) during training and evaluation.
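To make the calling convention concrete, here is a minimal plain-PyTorch sketch of the loop an RNN model needs. It does not use Orca at all; TinyRNN, init_hidden, and all sizes are hypothetical names and values chosen for illustration, not anything from analytics-zoo.

```python
import torch
import torch.nn as nn


class TinyRNN(nn.Module):
    """Hypothetical toy model whose forward takes and returns a hidden state."""

    def __init__(self, vocab_size=20, embed_dim=8, hidden_dim=16):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, data, hidden):
        emb = self.embed(data)               # (batch, seq, embed_dim)
        out, hidden = self.rnn(emb, hidden)  # out: (batch, seq, hidden_dim)
        return self.decoder(out), hidden     # logits: (batch, seq, vocab_size)

    def init_hidden(self, batch_size):
        # GRU hidden state has shape (num_layers, batch, hidden_dim)
        return torch.zeros(1, batch_size, self.hidden_dim)


torch.manual_seed(0)
vocab_size, batch_size, seq_len = 20, 4, 5
model = TinyRNN(vocab_size=vocab_size)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

hidden = model.init_hidden(batch_size)
losses = []
for _ in range(3):
    # Random token ids stand in for a real dataset.
    data = torch.randint(0, vocab_size, (batch_size, seq_len))
    target = torch.randint(0, vocab_size, (batch_size, seq_len))

    hidden = hidden.detach()  # stop gradients flowing into earlier batches
    optimizer.zero_grad()
    output, hidden = model(data, hidden)  # two inputs, two outputs
    loss = criterion(output.reshape(-1, vocab_size), target.reshape(-1))
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

The point is the output, hidden = model(data, hidden) line: the second return value has to be fed back in on the next iteration, which a loop built around output = model(data) has no place for.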

hkvision commented 3 years ago

@jason-dai @yangw1234 @qiuxin2012 Any ideas?

qiuxin2012 commented 3 years ago

How about the loss? Do you need the hidden state to compute the loss?

jing-xu commented 3 years ago

> How about the loss? Do you need the hidden state to compute the loss?

@qiuxin2012 As I understand it, the hidden state is not used directly to compute the loss, but it is passed back into the model at each step of training and prediction.
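A small plain-PyTorch sketch of what I mean, with made-up shapes and a hypothetical linear head: the loss function only ever sees the output, but the output itself changes depending on which hidden state was passed in.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=3, hidden_size=4, batch_first=True)
head = nn.Linear(4, 2)  # hypothetical classifier on the last time step
criterion = nn.CrossEntropyLoss()

first = torch.randn(2, 5, 3)   # (batch, seq, features), random stand-in data
second = torch.randn(2, 5, 3)
target = torch.tensor([0, 1])

# Run the first batch; the hidden state defaults to zeros when omitted.
_, carried = rnn(first)

# Same second batch, once with the carried state and once with a fresh one.
features_carried, _ = rnn(second, carried.detach())
features_fresh, _ = rnn(second)

# The loss takes only the output and the target, never the hidden state.
loss_carried = criterion(head(features_carried[:, -1]), target)
loss_fresh = criterion(head(features_fresh[:, -1]), target)

# The hidden state still matters, because it changes the output.
state_matters = not torch.allclose(features_carried, features_fresh)
```

So the estimator would not need to hand the hidden state to the loss, only to thread it from one forward call to the next.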

hkvision commented 3 years ago

https://github.com/pytorch/examples/blob/master/word_language_model/main.py