-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
183 lines (148 loc) · 5.81 KB
/
main.py
File metadata and controls
183 lines (148 loc) · 5.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import json
import sys
import os
# 添加当前目录到Python路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from data_fetcher import get_stock_data, convert_to_backtrader_data
from grid_strategy import GridStrategy
from backtest_engine import BacktestEngine
from visualization import YieldVisualization
import backtrader as bt
import pandas as pd
def load_config(config_file='config.json'):
"""
加载配置文件
参数:
config_file: 配置文件路径
返回:
dict: 配置参数字典
"""
try:
with open(config_file, 'r', encoding='utf-8') as f:
config = json.load(f)
return config
except Exception as e:
print(f"加载配置文件时出错: {e}")
return {}
def main():
# 加载配置
config = load_config('config.json')
if not config:
print("无法加载配置文件,程序退出")
return
print("配置加载成功:")
for key, value in config.items():
print(f" {key}: {value}")
# 获取股票数据
print("\n正在获取股票数据...")
stock_data = get_stock_data(
symbol=config['symbol'],
start_date=config['start_date'],
end_date=config['end_date'],
period=config['period']
)
if stock_data is None or stock_data.empty:
print("获取股票数据失败,程序退出")
return
print(f"成功获取股票数据,共 {len(stock_data)} 条记录")
# 转换为Backtrader数据格式
bt_data = convert_to_backtrader_data(stock_data)
if bt_data is None:
print("转换数据格式失败,程序退出")
return
# 获取基准股票数据
print("\n正在获取基准股票数据...")
benchmark_data = get_stock_data(
symbol=config['benchmark_symbol'],
start_date=config['start_date'],
end_date=config['end_date'],
period=config['period']
)
if benchmark_data is None or benchmark_data.empty:
print("获取基准股票数据失败")
benchmark_bt_data = None
else:
print(f"成功获取基准股票数据,共 {len(benchmark_data)} 条记录")
benchmark_bt_data = convert_to_backtrader_data(benchmark_data)
# 初始化回测引擎
print("\n正在初始化回测引擎...")
engine = BacktestEngine()
# 添加策略
engine.add_strategy(
GridStrategy,
grid_num=config['grid_num'],
grid_spacing=config['grid_spacing'],
allocation_ratio=config['allocation_ratio']
)
# 添加数据
engine.add_data(bt_data)
# 设置初始资金
engine.set_initial_cash(config['initial_cash'])
# 添加基准数据(如果有)
if benchmark_bt_data is not None:
# 注意:这里需要修改BacktestEngine以支持添加基准数据
print("基准数据已添加(注意:实际实现中需要修改BacktestEngine以完全支持基准对比)")
# 运行回测
print("\n正在运行回测...")
results = engine.run()
if not results:
print("回测运行失败,程序退出")
return
print("回测运行完成")
# 获取回测结果
result_dict = engine.get_results(results)
print("\n回测结果:")
total_return = result_dict.get('total_return', 0)
annual_return = result_dict.get('annual_return', 0)
sharpe_ratio = result_dict.get('sharpe_ratio', 0)
max_drawdown = result_dict.get('max_drawdown', 0)
print(f" 总收益率: {total_return:.4f}" if total_return is not None else " 总收益率: N/A")
print(f" 年化收益率: {annual_return:.4f}" if annual_return is not None else " 年化收益率: N/A")
print(f" 夏普比率: {sharpe_ratio:.4f}" if sharpe_ratio is not None else " 夏普比率: N/A")
print(f" 最大回撤: {max_drawdown:.4f}" if max_drawdown is not None else " 最大回撤: N/A")
# 如果有基准数据,计算基准收益率
benchmark_returns = None
if benchmark_data is not None and hasattr(benchmark_data, 'close'):
# 计算基准收益率
benchmark_returns = (benchmark_data['close'] / benchmark_data['close'].iloc[0]) - 1
print(f"\n基准股票({config['benchmark_symbol']})累计收益率: {benchmark_returns.iloc[-1]:.4f}")
# 可视化结果
print("\n正在生成可视化图表...")
viz = YieldVisualization()
# 计算策略收益率
value_history = result_dict.get('value_history', [])
if value_history and len(value_history) > 0:
strategy_returns = [(v / config['initial_cash']) - 1 for v in value_history]
else:
# 如果没有价值历史,创建一个简单的模拟序列
strategy_returns = [0.0] * len(stock_data) if len(stock_data) > 0 else [0.0] * 100
# 创建收益率对比图
if benchmark_returns is not None and len(benchmark_returns) > 0:
# 确保两个序列长度一致
min_len = min(len(strategy_returns), len(benchmark_returns))
if min_len > 0:
strategy_returns_trimmed = strategy_returns[:min_len]
benchmark_returns_trimmed = benchmark_returns[:min_len]
else:
strategy_returns_trimmed = [0.0] * 100
benchmark_returns_trimmed = [0.0] * 100
fig = viz.create_yield_comparison_chart(
strategy_returns_trimmed,
benchmark_returns_trimmed.values,
"网格策略",
f"基准({config['benchmark_symbol']})"
)
else:
fig = viz.create_yield_comparison_chart(
strategy_returns,
[0] * len(strategy_returns),
"网格策略",
"基准(无数据)"
)
# 添加交互功能
viz.add_interactive_features()
# 保存图表
viz.save_chart("yield_comparison.html")
print("\n程序运行完成,结果已保存到 yield_comparison.html")
if __name__ == "__main__":
main()