SQL读取类
官方暂时没有一键读取的功能类,因此需要自己写,一个简单的例子如下,运用请参考最下面的示例代码:
class SQLiteData(DataBase):
"""自定义的SQL lite数据格式"""
params = (
('dataname', None), # 策略中读取数据库是用到的名称
('name', ''), # 绘图时用到的名称
('timeframe', TimeFrame.Days), # 每条K线代表的时间长短
('fromdate', None), # 从什么时候开始
('todate', None), # 到什么时候截止
)
def __init__(self):
self.engine = create_engine('sqlite:///local_sql_lite.db') # 初始化数据库连接
self.tabel_name = "my_stock_code" # 数据表名称
def start(self): # 只会在加载数据前执行一次,常用于初始化参数
self.conn = self.engine.connect()
sql_query = "SELECT `date`,`open`,`high`,`low`,`close`,`volume`,`turnover` FROM `{}` ORDER BY `date` ASC" \
.format(self.tabel_name)
self.result = self.conn.execute(sql_query)
def stop(self): # 结束数据加载程序之后执行一次,常用于关闭数据库链接
self.engine.dispose()
def _load(self): # 类似于策略的 next(),预期执行几次 next(),就会执行几次 _load()
one_row = self.result.fetchone()
if one_row is None:
return False
self.lines.datetime[0] = date2num(dt.datetime.strptime(str(one_row[0]), '%Y-%m-%d %H:%M:%S')) # date parsing
self.lines.open[0] = float(one_row[1])
self.lines.high[0] = float(one_row[2])
self.lines.low[0] = float(one_row[3])
self.lines.close[0] = float(one_row[4])
self.lines.volume[0] = int(one_row[5])
self.lines.turnover[0] = float(one_row[6])
self.lines.openinterest[0] = -1
return True
其中有几个比较重要的函数:
def start()
:在加载数据前执行一次,常用于初始化参数def stop()
:结束数据加载程序之后执行一次,常用于关闭数据库链接def _load()
:在策略中的next()拿到的数据其实就是这里传过去的数据,会循环执行多次这个函数,直到取得全部数据或接收到False
或None
的返回值params
:这是定义数据集本身的一些参数,重要参数已在示例代码中解释,更多参数请参考官网
注意:def _load() 函数中:
self.lines
:表示第一个数据集的列,等同于self.datas[0].lines
,是一种简写形式self.lines.open
:指代数据中的开盘价那一列self.lines.open[0]
:特指当天的开盘价,如果是self.lines.open[-1],就是昨天的,如果是self.lines.open[1] 就是明天的
示例代码
import backtrader
import efinance
import pandas as pd
from datetime import datetime
import sqlite3
import datetime as dt
from backtrader import TimeFrame
from backtrader.feed import DataBase
from backtrader import date2num
from sqlalchemy import create_engine
def get_k_data(stock_code, begin: datetime, end: datetime) -> pd.DataFrame:
"""根据efinance工具包获取股票数据
:param stock_code:股票代码
:param begin: 开始日期
:param end: 结束日期
"""
# stock_code = '600519' # 股票代码,茅台
k_dataframe: pd.DataFrame = efinance.stock.get_quote_history(
stock_code, beg=begin.strftime("%Y%m%d"), end=end.strftime("%Y%m%d"))
k_dataframe = k_dataframe.iloc[:, :9]
k_dataframe.columns = ['name', 'code', 'date', 'open', 'close', 'high', 'low', 'volume', 'turnover']
k_dataframe.index = pd.to_datetime(k_dataframe.date)
k_dataframe.drop(['name', 'code', "date"], axis=1, inplace=True)
return k_dataframe
def write_sql_lite_from_pandas(stock_code, begin: datetime, end: datetime):
"""获取K线数据,并保存到SQL lite数据库"""
conn = sqlite3.connect('local_sql_lite.db')
dataframe = get_k_data(stock_code, begin=begin, end=end)
dataframe.to_sql("my_stock_code", conn, if_exists="replace") # 保存数据到数据库中,表名称:my_stock_code
class SQLiteData(DataBase):
"""自定义的SQL lite数据格式"""
params = (
('dataname', None), # 策略中读取数据库是用到的名称
('name', ''), # 绘图时用到的名称
('timeframe', TimeFrame.Days), # 每条K线代表的时间长短
('fromdate', None), # 从什么时候开始
('todate', None), # 到什么时候截止
# 下面是除了默认的open,close,high,low,volume外,新添加的维度
('turnover', -1),
)
# 新添加数据列用法相同
lines = ('turnover',)
def __init__(self):
self.engine = create_engine('sqlite:///local_sql_lite.db')
self._timeframe = self.p.timeframe
self._compression = self.p.compression
self._dataname = "my_stock_code"
def start(self):
self.conn = self.engine.connect()
sql_query = "SELECT `date`,`open`,`high`,`low`,`close`,`volume`,`turnover` FROM `{}` ORDER BY `date` ASC" \
.format(self._dataname)
self.result = self.conn.execute(sql_query)
def stop(self):
self.engine.dispose()
def _load(self):
# 会全部循环完毕,然后再读取
one_row = self.result.fetchone()
if one_row is None:
return False
self.lines.datetime[0] = date2num(dt.datetime.strptime(str(one_row[0]), '%Y-%m-%d %H:%M:%S')) # date parsing
self.lines.open[0] = float(one_row[1])
self.lines.high[0] = float(one_row[2])
self.lines.low[0] = float(one_row[3])
self.lines.close[0] = float(one_row[4])
self.lines.volume[0] = int(one_row[5])
self.lines.turnover[0] = float(one_row[6])
self.lines.openinterest[0] = -1
return True
class MyStrategy1(backtrader.Strategy): # 策略
def __init__(self):
# 初始化交易指令、买卖价格和手续费
self.close_price = self.datas[0].close # 这里加一个数据引用,方便后续操作
this_data = self.getdatabyname("stock_600519") # 获取传入的 name = stock_600519 的数据
print("全部列名:", this_data.getlinealiases()) # 全部的列名称
def next(self): # 框架执行过程中会不断循环next(),过一个K线,执行一次next()
print('=======================')
print("今天是:", self.datetime.date())
print("当前的值:", dict(zip(self.datas[0].getlinealiases(), [i[0] for i in list(self.datas[0].lines)])))
def main():
# 获取数据
start_time = datetime(2015, 1, 1)
end_time = datetime(2015, 1, 10)
# 先保存数据到 SQL lite 数据库,用于后续的读取
write_sql_lite_from_pandas("600519", start_time, end_time)
# 从SQL lite数据库读取数据
data = SQLiteData()
# =============== 为系统注入数据 =================
# 加载数据
# data = PandasDataPlus(dataname=dataframe, fromdate=start_time, todate=end_time)
# 初始化cerebro回测系统
cerebral_system = backtrader.Cerebro() # Cerebro引擎在后台创建了broker(经纪人)实例,系统默认每个broker的初始资金量为10000
# 将数据传入回测系统
cerebral_system.adddata(data, name="stock_600519") # 导入数据,在策略中使用 self.datas 来获取数据源
# 将交易策略加载到回测系统中
cerebral_system.addstrategy(MyStrategy1)
# =============== 系统设置 ==================
# 运行回测系统
cerebral_system.run()
if __name__ == '__main__':
main()