yaoyang33 opened 2 years ago
Hi, for nb211, did you download the checkpoints from this link? I can load them without any issues, so it might be a joblib version mismatch.
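Because a joblib version mismatch is the suspected cause, one way to compare environments is to print the versions of the libraries that unpickle these checkpoints. The tracebacks in this thread pass through joblib, scikit-learn and lightgbm. This is a minimal sketch: `installed_versions` is a hypothetical helper, not part of nas-bench-x11. It only reads each module's `__version__` attribute.

```python
import importlib
from typing import Dict, Iterable, Optional


def installed_versions(modules: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each module name to its __version__ string.

    Modules that fail to import, or that define no __version__, map to None.
    """
    versions: Dict[str, Optional[str]] = {}
    for name in modules:
        try:
            module = importlib.import_module(name)
        except ImportError:
            versions[name] = None
            continue
        versions[name] = getattr(module, "__version__", None)
    return versions
```

Calling `installed_versions(["joblib", "sklearn", "lightgbm"])` on both machines and comparing the output shows which library differs. Note that scikit-learn is imported as `sklearn`.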
For nb111, you are using the genotype of a DARTS cell for prediction, which is why there is an input shape mismatch. We will add examples of how to predict learning curves with nb101, nb201, and nbnlp.
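The mismatch happens because each surrogate's regressor expects a fixed number of input features. An architecture encoded for a different search space, such as a DARTS genotype sent to the nb111 model, produces a vector of the wrong length. The traceback in this thread reports 30 expected features against 56 supplied. Below is a minimal sketch of an early, readable check. It assumes the encoded features are available as a flat sequence and that the expected count is known. `check_feature_count` is hypothetical and not part of nas-bench-x11.

```python
from typing import Sequence


def check_feature_count(features: Sequence[float], expected: int, search_space: str) -> None:
    """Raise a descriptive ValueError if `features` has the wrong length.

    `expected` is the number of features the surrogate was trained on;
    `search_space` is only used to make the error message readable.
    """
    got = len(features)
    if got != expected:
        raise ValueError(
            f"the {search_space} surrogate expects {expected} features but got {got}; "
            "check that the architecture is encoded for this search space"
        )
```

This fails with the same underlying cause as the LightGBM error, but the message points at the encoding rather than at the model.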
Hi, I downgraded TF to the version you specified and have nasbench installed. The nb311-v0.5 surrogate model loads, but nb111 and nb211 do not.
When I load nb211, I run into a KeyError:

Traceback (most recent call last):
  File "example.py", line 5, in
    nb311_surrogate_model = load_ensemble('/Users/yao.a.yang/Documents/nas-bench-x11/checkpoints/nb211-v0.5')
  File "/Users/yao.a.yang/Documents/nas-bench-x11/nas_bench_x11/api.py", line 68, in load_ensemble
    surrogate_model.load(model_paths=ensemble_member_dirs)
  File "/Users/yao.a.yang/Documents/nas-bench-x11/nas_bench_x11/ensemble.py", line 142, in load
    ens_mem.load(os.path.join(member_logdir, 'surrogate_model.model'))
  File "/Users/yao.a.yang/Documents/nas-bench-x11/nas_bench_x11/models/svd_lgb.py", line 118, in load
    if len(joblib.load(model_path)) == 5:
  File "/usr/local/anaconda3/envs/py36/lib/python3.6/site-packages/joblib/numpy_pickle.py", line 587, in load
    obj = _unpickle(fobj, filename, mmap_mode)
  File "/usr/local/anaconda3/envs/py36/lib/python3.6/site-packages/joblib/numpy_pickle.py", line 506, in _unpickle
    obj = unpickler.load()
  File "/usr/local/anaconda3/envs/py36/lib/python3.6/pickle.py", line 1050, in load
    dispatch[key[0]](self)
KeyError: 239
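`KeyError: 239` means the pure-Python unpickler read the byte 0xEF (239) where it expected a pickle opcode. In other words, the bytes it received are not a pickle stream it understands. A joblib version mismatch can cause this, but a file that is incomplete, or that is not the checkpoint at all, could also fail inside `unpickler.load()`. Looking at the file's first bytes is a quick way to tell these cases apart. Below is a minimal sketch. The prefix table is an assumption based on common formats joblib reads and is not exhaustive; for example, pickles written with protocol 0 or 1 start differently. `guess_format` and `read_head` are hypothetical helpers.

```python
from typing import Tuple

# Leading bytes of formats a joblib checkpoint is likely to use.
# Assumption: illustrative only, not a complete list of what joblib accepts.
KNOWN_PREFIXES: Tuple[Tuple[bytes, str], ...] = (
    (b"\x80", "pickle, protocol 2 or later"),
    (b"\x1f\x8b", "gzip-compressed"),
    (b"\x78", "zlib-compressed"),
    (b"BZh", "bz2-compressed"),
    (b"\xfd7zXZ", "xz-compressed"),
)


def guess_format(head: bytes) -> str:
    """Guess a file's format from its leading bytes; "unknown" if no prefix matches."""
    for prefix, label in KNOWN_PREFIXES:
        if head.startswith(prefix):
            return label
    return "unknown"


def read_head(path: str, n: int = 8) -> bytes:
    """Return the first `n` bytes of the file at `path`."""
    with open(path, "rb") as f:
        return f.read(n)
```

For example, run `guess_format(read_head(path))` on each `surrogate_model.model` file under the nb211 checkpoint directory. Comparing the result with a copy downloaded on another machine helps separate a bad download from a library mismatch.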
With surrogate_model.model from nb111-v0.5, loading prints version warnings and then prediction fails:

/usr/local/anaconda3/envs/py36/lib/python3.6/site-packages/sklearn/base.py:315: UserWarning: Trying to unpickle estimator RegressorChain from version 0.23.2 when using version 0.24.2. This might lead to breaking code or invalid results. Use at your own risk.
  UserWarning)
/usr/local/anaconda3/envs/py36/lib/python3.6/site-packages/sklearn/base.py:315: UserWarning: Trying to unpickle estimator StandardScaler from version 0.23.2 when using version 0.24.2. This might lead to breaking code or invalid results. Use at your own risk.
  UserWarning)
<nas_bench_x11.models.svd_lgb.SVDLGBModel object at 0x7fa5ac487f28>
[<nas_bench_x11.models.svd_lgb.SVDLGBModel object at 0x7fa5ac487f28>]
Traceback (most recent call last):
  File "example.py", line 17, in
    learning_curve = nb311_surrogate_model.predict(config=arch, representation="genotype", with_noise=True)
  File "/Users/yao.a.yang/Documents/nas-bench-x11/nas_bench_x11/api.py", line 113, in predict
    pred = self.model.query(config_dict, search_space=search_space)
  File "/Users/yao.a.yang/Documents/nas-bench-x11/nas_bench_x11/ensemble.py", line 290, in query
    use_noise=True)
  File "/Users/yao.a.yang/Documents/nas-bench-x11/nas_bench_x11/models/svd_lgb.py", line 161, in query
    comp = self.model.predict(X)
  File "/usr/local/anaconda3/envs/py36/lib/python3.6/site-packages/sklearn/multioutput.py", line 549, in predict
    Y_pred_chain[:, chain_idx] = estimator.predict(X_aug)
  File "/usr/local/anaconda3/envs/py36/lib/python3.6/site-packages/lightgbm/sklearn.py", line 800, in predict
    raise ValueError("Number of features of the model must "
ValueError: Number of features of the model must match the input. Model nfeatures is 30 and input n_features is 56
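The scikit-learn warnings in this log show that the nb111 estimators were pickled with scikit-learn 0.23.2 and loaded with 0.24.2. They are easy to miss among the other output. Below is a minimal sketch that makes such warnings stop the load instead. `load_strict` is a hypothetical wrapper. Passing `joblib.load` as the loader is an assumption based on the nb211 traceback, where `svd_lgb.py` calls `joblib.load`. The wrapper escalates every `UserWarning` raised during the load, not only version warnings.

```python
import warnings
from typing import Any, Callable


def load_strict(loader: Callable[[str], Any], path: str) -> Any:
    """Call loader(path), raising instead of warning on any UserWarning.

    scikit-learn 0.24 reports "Trying to unpickle estimator ... from version
    X when using version Y" as a UserWarning, so a version mismatch now
    aborts the load.
    """
    with warnings.catch_warnings():
        # Escalate UserWarning (and its subclasses) to exceptions, only inside this block.
        warnings.simplefilter("error", UserWarning)
        return loader(path)
```

Usage would look like `load_strict(joblib.load, "path/to/surrogate_model.model")`, where the path is a placeholder.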
Any guidance? Thanks!