microsoft / qlib

Qlib is an AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing productions. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and RL.
https://qlib.readthedocs.io/en/latest/
MIT License
15.47k stars, 2.64k forks

method on save and load TFT model #1340

Closed: nkchem09 closed this issue 1 year ago

nkchem09 commented 2 years ago

❓ Questions and Help


Once a TFT model is built, how can I save it and reload it for prediction? The following approach does not seem to work:

    with R.start(experiment_name=experiment_name):
        R.log_params(**flatten_dict(task))
        model.fit(dataset)
        R.save_objects(trained_model=model)

When loading the model to predict on a new dataset:

    with R.start(experiment_name=back_ex_name):
        recorder = R.get_recorder(recorder_id=rid, experiment_name=experiment_name)
        model = recorder.load_object("trained_model")
        pred = model.predict(dataset)

This raises the following error:

    raise ValueError("model is not fitted yet!")
    ValueError: model is not fitted yet!

Thanks for your help.
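One possible explanation for an error like this (an assumption, not something confirmed in this thread) is that the model wrapper drops its underlying network when it is pickled, for example because the backend object cannot be serialized. The reloaded wrapper then looks unfitted. The sketch below reproduces that behaviour in plain Python with a hypothetical `ToyModel` class and a hypothetical `round_trip` helper; it does not use qlib or the real TFT code.

```python
import pickle


class ToyModel:
    """Hypothetical wrapper for illustration; not qlib's TFT model."""

    def __init__(self):
        self.model = None  # populated by fit()

    def fit(self, data):
        # Stand-in for training: keep the mean as the "backend" model.
        self.model = {"mean": sum(data) / len(data)}

    def predict(self, data):
        if self.model is None:
            raise ValueError("model is not fitted yet!")
        return [self.model["mean"] for _ in data]

    def __getstate__(self):
        # Assumption: the backend cannot be pickled, so it is dropped
        # from the saved state and comes back as None after loading.
        state = self.__dict__.copy()
        state["model"] = None
        return state


def round_trip(obj):
    """Serialize and deserialize an object, as a pickle-based store might."""
    return pickle.loads(pickle.dumps(obj))


trained = ToyModel()
trained.fit([1.0, 2.0, 3.0])
restored = round_trip(trained)

try:
    restored.predict([0.0])
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```

If the real cause is similar, the backend weights would need to be saved and restored separately from the pickled wrapper. Whether that applies to the TFT model here would have to be checked against its source.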

github-actions[bot] commented 1 year ago

This issue is stale because it has been open for three months with no activity. Remove the stale label or comment on the issue; otherwise it will be closed in 5 days.

LiuHao-THU commented 1 year ago

I ran into this problem too. I ran this code in JupyterLab, and when I ran it a second time it ran correctly.

Here is the code for prediction:

    ###################################
    # prediction, backtest & analysis
    ###################################
    port_analysis_config = {
        "executor": {
            "class": "SimulatorExecutor",
            "module_path": "qlib.backtest.executor",
            "kwargs": {
                "time_per_step": "day",
                "generate_portfolio_metrics": True,
            },
        },
        "strategy": {
            "class": "TopkDropoutStrategy",
            "module_path": "qlib.contrib.strategy.signal_strategy",
            "kwargs": {
                "model": model,
                "dataset": dataset,
                "topk": 3,
                "n_drop": 1,
            },
        },
        "backtest": {
            "start_time": "2021-01-01",
            "end_time": "2023-03-01",
            "account": 100000000,
            "benchmark": benchmark,
            "exchange_kwargs": {
                "freq": "day",
                "limit_threshold": 0.095,
                "deal_price": "close",
                "open_cost": 0.0005,
                "close_cost": 0.0015,
                "min_cost": 5,
            },
        },
    }

    from qlib.workflow.record_temp import SignalRecord, PortAnaRecord, SigAnaRecord

    rid = '61cc4cd94ae24137ac196d7e82c8e666'

    # backtest and analysis
    with R.start(experiment_name='PatchTsT', recorder_id=rid, resume=True):
        # get model
        model = R.load_object("trained_model")

        # signal-based analysis
        rec = R.get_recorder()
        sar = SigAnaRecord(rec)
        sar.generate()

        # prediction
        sr = SignalRecord(model, dataset, rec)
        sr.generate()
        ba_rid = rec.id

        # portfolio-based analysis: backtest
        par = PortAnaRecord(rec, port_analysis_config, "day")
        par.generate()