lukalabs / cakechat

CakeChat: Emotional Generative Dialog System
Apache License 2.0

Training own model #16

Closed: jacobdanovitch closed this issue 6 years ago

jacobdanovitch commented 6 years ago

Hi,

I loaded my own training and validation corpora, ran prepare_index_files.py, and trained the model with no issues. Afterwards, whenever I ran python bin/cakechat_server.py, it threw this error:

Traceback (most recent call last):
  File "bin/cakechat_server.py", line 10, in <module>
    from cakechat.api.v1.server import app
  File "C:\...\cakechat\cakechat\api\v1\server.py", line 3, in <module>
    from cakechat.api.response import get_response
  File "C:\...\cakechat\cakechat\api\response.py", line 14, in <module>
    _cakechat_model = get_trained_model(fetch_from_s3=False)
  File "C:\...\cakechat\cakechat\dialog_model\factory.py", line 53, in get_trained_model
    raise Exception('Can\'t get the model. '
Exception: Can't get the model. Run tools/download_model.py first to get all required files or train it by yourself.

I messed around with get_nn_model() in dialog_model/model.py a bit and realized it was looking for a file named: processed_dialogs_gru_hd512_drop0.2_encd2_decd2_il30_cs3_ansl32_lr1.0_gc_5.0_learnemb_cdim128_window10_voc11786_vec128_sgTrue

My file in data/nn_models was called: processed_dialogs_gru_hd7_drop0.2_encd2_decd2_il7_cs3_ansl9_lr1.0_gc_5.0_learnemb_cdim128_window10_voc11786_vec15_sgTrue_pp_free2926.32_sensitive3066.80

I made a copy of my file and renamed it to the name the code was looking for. I then tried both running the server and training again, and in both cases, as soon as it tried to load the model, I got: ValueError: mismatch: parameter has shape (11786L, 128L) but value to set has shape (11786L, 15L)

Not really sure where to go from here; thanks in advance. Running Windows / Anaconda with py2.7. All dependencies installed and everything else is running fine thus far. I had it working with the pre-trained model. Tried running the server through both Git Bash and cmd, if that makes a difference. Trained through Bash.
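
For readers comparing the two filenames above: they differ only in a few underscore-separated fields. The following is a minimal sketch, not part of cakechat, that lists those differences. The helper name diff_model_names is hypothetical, and it assumes both names use the same field order, which holds for the two names quoted in this post.

```python
def diff_model_names(expected, actual):
    """Compare two model filenames field by field.

    Hypothetical helper (not part of cakechat). Assumes both names use the
    same underscore-separated field order. Returns a list of
    (expected_field, actual_field) pairs that differ, plus any trailing
    fields that appear only in the actual name.
    """
    exp_fields = expected.split("_")
    act_fields = actual.split("_")
    mismatches = [(e, a) for e, a in zip(exp_fields, act_fields) if e != a]
    extra = act_fields[len(exp_fields):]
    return mismatches, extra


# The two filenames quoted in the post above.
EXPECTED = ("processed_dialogs_gru_hd512_drop0.2_encd2_decd2_il30_cs3_ansl32"
            "_lr1.0_gc_5.0_learnemb_cdim128_window10_voc11786_vec128_sgTrue")
ACTUAL = ("processed_dialogs_gru_hd7_drop0.2_encd2_decd2_il7_cs3_ansl9"
          "_lr1.0_gc_5.0_learnemb_cdim128_window10_voc11786_vec15_sgTrue"
          "_pp_free2926.32_sensitive3066.80")

mismatches, extra = diff_model_names(EXPECTED, ACTUAL)
for exp_field, act_field in mismatches:
    print("expected %s, found %s" % (exp_field, act_field))
print("extra trailing fields: %s" % extra)
```

The vec128 versus vec15 difference appears to line up with the (11786, 128) versus (11786, 15) shapes in the ValueError, which suggests that renaming the file alone cannot work when the trained parameters differ from the ones the code expects.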

rodart commented 6 years ago

Hi Jacob,

It seems that you trained your model with the IS_DEV env flag set. The filename gives you a clue about the params of your NN. I see two ways to solve the problem:

1) Rename your file
processed_dialogs_gru_hd7_drop0.2_encd2_decd2_il7_cs3_ansl9_lr1.0_gc_5.0_learnemb_cdim128_window10_voc11786_vec15_sgTrue_pp_free2926.32_sensitive3066.80
->
processed_dialogs_gru_hd7_drop0.2_encd2_decd2_il7_cs3_ansl9_lr1.0_gc_5.0_learnemb_cdim128_window10_voc11786_vec15_sgTrue
and run IS_DEV=1 python bin/cakechat_server.py

2) Train your model without the IS_DEV flag in the first place.
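
The rename in option 1 amounts to dropping the trailing _pp_free..._sensitive... part of the filename. A minimal sketch of that step follows. The helper strip_metrics_suffix and its regular expression are hypothetical and not part of cakechat; they assume the suffix always has the form seen in this thread.

```python
import re

# Assumption: the trailing suffix always looks like
# "_pp_free<number>_sensitive<number>", as in the filename from this thread.
_METRICS_SUFFIX = re.compile(r"_pp_free[0-9.]+_sensitive[0-9.]+$")


def strip_metrics_suffix(filename):
    """Return the filename without the trailing suffix (hypothetical helper).

    Names that do not end with such a suffix are returned unchanged.
    """
    return _METRICS_SUFFIX.sub("", filename)


trained = ("processed_dialogs_gru_hd7_drop0.2_encd2_decd2_il7_cs3_ansl9"
           "_lr1.0_gc_5.0_learnemb_cdim128_window10_voc11786_vec15_sgTrue"
           "_pp_free2926.32_sensitive3066.80")
print(strip_metrics_suffix(trained))
```

The printed name can then be used as the target when copying the file in data/nn_models, for example with shutil.copy from the standard library. The server still has to be started with IS_DEV=1 so that it looks for a model with the same parameters.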

jacobdanovitch commented 6 years ago

Thank you, I'd forgotten I used IS_DEV, and I didn't realize I had to precede the server call with it as well. Works perfectly.

rodart commented 6 years ago

cool! let us know if you have other questions about cakechat