Qihoo360 / tensornet

Code from the `01-begin-with-wide-deep.ipynb` document throws an error when run with MPI #31

Open zhangys-lucky opened 3 years ago

zhangys-lucky commented 3 years ago

Saving the code from the `begin-with-wide-deep` document as `wide_deep.py` and running it fails with the following error:

750/750 [==============================] - 5s 6ms/step - loss: 0.6378 - accuracy: 0.6670 - auc: 0.5043
    715/Unknown - 5s 7ms/step - loss: 0.6367 - accuracy: 0.6668 - auc: 0.5175
F1029 14:52:05.389852 13099 core/ps/ps_local_server.cc:61] Check failed: nullptr != table.
#0 0x7fbfe75d2d2a tensornet::PsLocalServer::DensePushPullAsync()
#1 0x7fbfe75d6b3d tensornet::PsServiceImpl::DensePushPull()
#2 0x7fbfe75ecc6c tensornet::PsService::CallMethod()
#3 0x7fbfe7639d9e brpc::policy::ProcessRpcRequest()
#4 0x7fbfe7631547 brpc::ProcessInputMessage()
#5 0x7fbfe76324d7 brpc::InputMessenger::OnNewMessages()
#6 0x7fbfe76d268d brpc::Socket::ProcessEvent()
#7 0x7fbfe775cfc1 bthread::TaskGroup::task_runner()
#8 0x7fbfe7746731 bthread_make_fcontext

zhangys-lucky commented 3 years ago

Fix: let `train` accept callbacks and pass them to `model.fit`:

def train(strategy, callbacks=None):
    # callbacks: optional list of callbacks forwarded to model.fit
    with strategy.scope():
        wide_column, deep_column = columns_builder()
        model = create_model(wide_column, deep_column)

        train_dataset = read_dataset(TEST_DATA_PATH, C.FILE_MATCH_PATTERN)
        # Pass the callbacks through so they run during training
        model.fit(train_dataset, epochs=1, verbose=1, callbacks=callbacks)

Then train tensornet with the `PsWeightCheckpoint` callback:

cp_cb = tn.callbacks.PsWeightCheckpoint("./model")
train(tn.distribute.PsStrategy(), [cp_cb])
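
For readers who want to see the forwarding pattern in isolation, here is a minimal sketch that runs without tensornet or TensorFlow. `FakeModel` and `RecordingCallback` are hypothetical stand-ins invented for this illustration, and the sketch makes no claim about what `PsWeightCheckpoint` does internally. It only shows that a callback takes effect when `train` forwards it to `fit`, and that a `None` default avoids sharing one mutable list across calls.

```python
# Hypothetical stand-ins for illustration only; not tensornet or Keras classes.

class RecordingCallback:
    """Records the index of every epoch it is notified about."""

    def __init__(self):
        self.epochs_seen = []

    def on_epoch_end(self, epoch):
        self.epochs_seen.append(epoch)


class FakeModel:
    """Mimics only the callback-forwarding part of a fit() loop."""

    def fit(self, dataset, epochs=1, callbacks=None):
        for epoch in range(epochs):
            for _batch in dataset:
                pass  # a real model would run a training step here
            # Treat None as "no callbacks"
            for cb in callbacks or []:
                cb.on_epoch_end(epoch)


def train(model, dataset, callbacks=None):
    # A None default avoids the shared-mutable-default pitfall of callbacks=[]
    model.fit(dataset, epochs=1, callbacks=callbacks)


recorder = RecordingCallback()
train(FakeModel(), dataset=[1, 2, 3], callbacks=[recorder])  # recorder is notified
train(FakeModel(), dataset=[1, 2, 3])  # no callbacks: nothing is notified
```

After both calls, `recorder.epochs_seen` holds a single entry from the first call. The second call leaves it untouched because no callback was passed.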