zvtvz / zvt

modular quant framework.
https://zvt.readthedocs.io/en/latest/
MIT License
3.13k stars 856 forks source link

实现插件的Stock K线 Recorder 出现问题 #99

Closed Zkffkah closed 4 years ago

Zkffkah commented 4 years ago

尝试接入 baostock 的日 K 线数据,测试用的 recorder 如下:

# -*- coding: utf-8 -*-

import baostock as bs
import pandas as pd

from zvt.api import get_kdata, AdjustType
from zvt.api.quote import generate_kdata_id, get_kdata_schema, StockKdataCommon, Stock
from zvt.contract import IntervalLevel
from zvt.contract.api import df_to_db
from zvt.contract.recorder import FixedCycleDataRecorder
from zvt.domain import register_schema, declarative_base

from zvt_tm.domain import Stock1dKdata
from zvt_tm.recorders.baostock.common import to_bs_trading_level, to_bs_entity_id
from zvt.utils.pd_utils import pd_is_not_null
from zvt.utils.time_utils import to_time_str, TIME_FORMAT_DAY, TIME_FORMAT_ISO8601

class BaoStockChinaStockKdataRecorder(FixedCycleDataRecorder):
    """Record China A-share kdata from baostock into zvt's kdata schemas.

    The entity list (``Stock``) is read from the ``eastmoney`` provider; the
    candlesticks are fetched with ``bs.query_history_k_data`` and written with
    ``df_to_db`` under the ``baostock`` provider. The concrete target table is
    chosen per instance from ``level`` and ``adjust_type``.
    """
    entity_provider = 'eastmoney'
    entity_schema = Stock

    # kdata is fetched from baostock
    provider = 'baostock'

    # Only here so the recorder is registered against a data_schema; the
    # concrete per-level / per-adjust-type schema replaces it in __init__.
    data_schema = StockKdataCommon

    # Fields requested from baostock ('date' and 'amount' are renamed in record()).
    _bs_fields = "date,code,open,high,low,close,volume,amount"
    # baostock's get_data() yields the fields as text; these columns (after the
    # 'amount' -> 'turnover' rename) are converted so the DB stores numbers.
    _numeric_columns = ['open', 'high', 'low', 'close', 'volume', 'turnover']

    def __init__(self,
                 exchanges=['sh', 'sz'],
                 entity_ids=None,
                 codes=None,
                 batch_size=10,
                 force_update=True,
                 sleeping_time=0,
                 default_size=2000,
                 real_time=False,
                 fix_duplicate_way='ignore',
                 start_timestamp=None,
                 end_timestamp=None,
                 level=IntervalLevel.LEVEL_1WEEK,
                 kdata_use_begin_time=False,
                 close_hour=15,
                 close_minute=0,
                 one_day_trading_minutes=4 * 60,
                 adjust_type=AdjustType.qfq) -> None:
        level = IntervalLevel(level)
        adjust_type = AdjustType(adjust_type)
        # Select the concrete table for this level/adjust type before super().__init__().
        self.data_schema = get_kdata_schema(entity_type='stock', level=level, adjust_type=adjust_type)
        self.bs_trading_level = to_bs_trading_level(level)

        super().__init__('stock', exchanges, entity_ids, codes, batch_size, force_update, sleeping_time,
                         default_size, real_time, fix_duplicate_way, start_timestamp, end_timestamp, close_hour,
                         close_minute, level, kdata_use_begin_time, one_day_trading_minutes)
        self.adjust_type = adjust_type

        # Log in to baostock with the anonymous account.
        self.logger.info("尝试登陆baostock")
        lg = bs.login(user_id="anonymous", password="123456")
        if lg.error_code == '0':
            self.logger.info("登陆成功")
        else:
            # Not fatal here: record() checks and logs each failing query.
            self.logger.error("登录失败: error_code={}, error_msg={}".format(lg.error_code, lg.error_msg))

    def generate_domain_id(self, entity, original_data):
        """Build the row id from the entity id, the row's timestamp and self.level."""
        return generate_kdata_id(entity_id=entity.id, timestamp=original_data['timestamp'], level=self.level)

    def recompute_qfq(self, entity, qfq_factor, last_timestamp):
        """Rescale previously saved qfq (forward-adjusted) prices by ``qfq_factor``.

        Rows of ``entity`` strictly before ``last_timestamp`` have open/close/high/low
        multiplied by the factor (rounded to 2 decimals) and are committed.
        A factor of 0 is ignored.
        """
        if qfq_factor != 0:
            kdatas = get_kdata(provider=self.provider, entity_id=entity.id, level=self.level.value,
                               # read from the same adjust-type table the filter/order refer to
                               adjust_type=self.adjust_type,
                               order=self.data_schema.timestamp.asc(),
                               return_type='domain',
                               session=self.session,
                               filters=[self.data_schema.timestamp < last_timestamp])
            if kdatas:
                self.logger.info('recomputing {} qfq kdata,factor is:{}'.format(entity.code, qfq_factor))
                for kdata in kdatas:
                    kdata.open = round(kdata.open * qfq_factor, 2)
                    kdata.close = round(kdata.close * qfq_factor, 2)
                    kdata.high = round(kdata.high * qfq_factor, 2)
                    kdata.low = round(kdata.low * qfq_factor, 2)
                self.session.add_all(kdatas)
                self.session.commit()

    def on_finish(self):
        """Run the base-class finish hook, then always log out of baostock."""
        try:
            super().on_finish()
        finally:
            bs.logout()

    def _bs_adjust_flag(self):
        """Map ``self.adjust_type`` to baostock's ``adjustflag`` value.

        baostock: '1' = hfq (后复权), '2' = qfq (前复权), '3' = unadjusted.
        """
        if self.adjust_type == AdjustType.hfq:
            return '1'
        if self.adjust_type == AdjustType.qfq:
            return '2'
        # Any other adjust type (bfq) is fetched unadjusted rather than as qfq.
        return '3'

    def record(self, entity, start, end, size, timestamps):
        """Fetch kdata of ``entity`` from ``start`` and persist it with df_to_db.

        ``end``, ``size`` and ``timestamps`` are unused here; the upper bound is
        ``self.end_timestamp`` when set. Always returns None because rows are
        written directly.
        """
        query_kwargs = {
            'start_date': to_time_str(start),
            'frequency': self.bs_trading_level,
            'adjustflag': self._bs_adjust_flag(),
        }
        if self.end_timestamp:
            query_kwargs['end_date'] = to_time_str(self.end_timestamp)

        rs = bs.query_history_k_data(to_bs_entity_id(entity), self._bs_fields, **query_kwargs)
        if rs.error_code != '0':
            self.logger.error('baostock query failed for {}: error_code={}, error_msg={}'.format(
                entity.id, rs.error_code, rs.error_msg))
            return None

        df = rs.get_data()
        if not pd_is_not_null(df):
            return None

        df = df.rename(columns={'amount': 'turnover', 'date': 'timestamp'})
        for column in self._numeric_columns:
            # Empty strings (no data) become NaN instead of being stored as text.
            df[column] = pd.to_numeric(df[column], errors='coerce')

        df['name'] = entity.name
        df['entity_id'] = entity.id
        df['timestamp'] = pd.to_datetime(df['timestamp'])
        df['provider'] = self.provider
        df['level'] = self.level.value
        df['code'] = entity.code

        # TODO: for qfq data, compare the first fetched close with the stored close at
        # the same timestamp; if they differ, call
        # recompute_qfq(entity, qfq_factor=new / old, last_timestamp=that_timestamp)
        # to rescale the previously saved rows. Not enabled yet.

        # Reuse generate_domain_id so ids match the rest of the recorder.
        df['id'] = df.apply(lambda row: self.generate_domain_id(entity, row), axis=1)

        df_to_db(df=df, data_schema=self.data_schema, provider=self.provider, force_update=self.force_update)

        return None

__all__ = ['BaoStockChinaStockKdataRecorder']

if __name__ == '__main__':
    # Script entry point: record daily stock kdata through the 'baostock'
    # provider (sleeping_time=1 is forwarded to record_data).
    #
    # To run the recorder directly for specific codes / levels instead:
    #   BaoStockChinaStockKdataRecorder(level=IntervalLevel.LEVEL_1DAY, sleeping_time=0,
    #                                   codes=['000001'], real_time=False,
    #                                   adjust_type=AdjustType.qfq).run()
    Stock1dKdata.record_data(sleeping_time=1, provider='baostock')
def to_bs_trading_level(trading_level: IntervalLevel):
    """Translate a zvt ``IntervalLevel`` into baostock's ``frequency`` argument.

    baostock's minute frequencies are bare numbers ('5', '15', '30', '60'),
    not zvt level values such as '5m'; day/week/month are 'd'/'w'/'m'.

    :param trading_level: the zvt interval level to translate
    :return: the baostock frequency string, or None when no mapping exists
    """
    # Keyed by the enum's value so the lookup does not depend on how
    # IntervalLevel implements hashing/comparison.
    frequencies = {
        IntervalLevel.LEVEL_5MIN.value: '5',
        IntervalLevel.LEVEL_15MIN.value: '15',
        IntervalLevel.LEVEL_30MIN.value: '30',
        IntervalLevel.LEVEL_1HOUR.value: '60',
        # NOTE(review): kept from the original mapping; baostock does not appear
        # to document a 240-minute frequency — confirm.
        IntervalLevel.LEVEL_4HOUR.value: '240',
        IntervalLevel.LEVEL_1DAY.value: 'd',
        IntervalLevel.LEVEL_1WEEK.value: 'w',
        IntervalLevel.LEVEL_1MON.value: 'm',
    }
    return frequencies.get(trading_level.value)

def to_bs_entity_id(security_item):
    """Return baostock's code ('sh.600000' style) for a stock or index entity.

    Only entity types 'stock'/'index' on exchanges 'sh'/'sz' are mapped;
    anything else yields None.
    """
    if security_item.entity_type not in ('stock', 'index'):
        return None
    exchange = security_item.exchange
    if exchange not in ('sh', 'sz'):
        return None
    return '{}.{}'.format(exchange, security_item.code)

尝试直接修改zvt/domain/quotes/stock/stock_1d_kdata.py

# -*- coding: utf-8 -*-
# this file is generated by the gen_kdata_schema function, don't change it
from sqlalchemy.ext.declarative import declarative_base

from zvt.contract.register import register_schema
from zvt.domain.quotes import StockKdataCommon

KdataBase = declarative_base()


class Stock1dKdata(KdataBase, StockKdataCommon):
    """Daily stock kdata table (columns come from StockKdataCommon)."""
    __tablename__ = 'stock_1d_kdata'


# Register every provider that uses this db in a single call. 'baostock' is
# listed next to 'joinquant' so the baostock recorder/factors can use the table.
register_schema(providers=['joinquant', 'baostock'], db_name='stock_1d_kdata', schema_base=KdataBase)

__all__ = ['Stock1dKdata']

抓取数据正常,使用factor没问题

 # Compute the MA factor on baostock daily kdata (entities from eastmoney) and print it.
 factor = ImprovedMaFactor(entity_schema=Stock, start_timestamp='2020-01-01',
                           end_timestamp=now_pd_timestamp(), need_persist=False,
                           provider='baostock', entity_provider='eastmoney',
                           level=IntervalLevel.LEVEL_1DAY)
 print(factor.result_df)

尝试将

register_schema(db_name='stock_1d_kdata', schema_base=KdataBase, providers=['baostock'])

移动到单独的插件中后,就出现了问题。

报错堆栈

Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/Users/xxx/Library/Application Support/JetBrains/Toolbox/apps/PyCharm-P/ch-0/201.8538.36/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/Users/xxx/Library/Application Support/JetBrains/Toolbox/apps/PyCharm-P/ch-0/201.8538.36/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/xxx/zvt-tm/zvt_tm/reports/report_crypto_tm.py", line 79, in <module>
    level=IntervalLevel.LEVEL_1DAY)
  File "/usr/local/lib/python3.7/site-packages/zvt/factors/ma/ma_factor.py", line 122, in __init__
    accumulator, need_persist, dry_run)
  File "/usr/local/lib/python3.7/site-packages/zvt/factors/technical_factor.py", line 49, in __init__
    effective_number, transformer, accumulator, need_persist, dry_run)
  File "/usr/local/lib/python3.7/site-packages/zvt/factors/factor.py", line 120, in __init__
    category_field, time_field, computing_window)
  File "/usr/local/lib/python3.7/site-packages/zvt/contract/reader.py", line 139, in __init__
    self.load_data()
  File "/usr/local/lib/python3.7/site-packages/zvt/contract/reader.py", line 170, in load_data
    time_field=self.time_field)
  File "/usr/local/lib/python3.7/site-packages/zvt/contract/schema.py", line 85, in query_data
    filters=filters, session=session, order=order, limit=limit, index=index, time_field=time_field)
  File "/usr/local/lib/python3.7/site-packages/zvt/contract/api.py", line 328, in get_data
    df = pd.read_sql(query.statement, query.session.bind)
  File "/usr/local/lib/python3.7/site-packages/pandas/io/sql.py", line 410, in read_sql
    chunksize=chunksize,
  File "/usr/local/lib/python3.7/site-packages/pandas/io/sql.py", line 1645, in read_query
    cursor = self.execute(*args)
  File "/usr/local/lib/python3.7/site-packages/pandas/io/sql.py", line 1590, in execute
    cur = self.con.cursor()
AttributeError: 'NoneType' object has no attribute 'cursor'

调试信息如下(截图略):session 的 bind 为空,正常情况下不应为空。

Zkffkah commented 4 years ago

解决了,register_schema理解错了。