Skip to content

Commit 93b6ae7

Browse files
add all tutorials
1 parent 36e1010 commit 93b6ae7

File tree

290 files changed

+621941
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

290 files changed

+621941
-0
lines changed

.DS_Store

0 Bytes
Binary file not shown.

1-Introduction/.ipynb_checkpoints/Stock_NeurIPS2018_ElegantRL-checkpoint.ipynb

+3,783
Large diffs are not rendered by default.

1-Introduction/.ipynb_checkpoints/Stock_NeurIPS2018_SB3-checkpoint.ipynb

+3,783
Large diffs are not rendered by default.

1-Introduction/China_A_share_market_tushare.ipynb

+3,012
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
import warnings
2+
3+
warnings.filterwarnings("ignore")
4+
5+
import pandas as pd
6+
from IPython import display
7+
8+
display.set_matplotlib_formats("svg")
9+
10+
from meta import config
11+
from meta.data_processor import DataProcessor
12+
from main import check_and_make_directories
13+
from meta.data_processors.tushare import Tushare, ReturnPlotter
14+
from meta.env_stock_trading.env_stocktrading_China_A_shares import (
15+
StockTradingEnv,
16+
)
17+
from agents.stablebaselines3_models import DRLAgent
18+
import os
19+
from typing import List
20+
from argparse import ArgumentParser
21+
from meta import config
22+
from meta.config_tickers import DOW_30_TICKER
23+
from meta.config import (
24+
DATA_SAVE_DIR,
25+
TRAINED_MODEL_DIR,
26+
TENSORBOARD_LOG_DIR,
27+
RESULTS_DIR,
28+
INDICATORS,
29+
TRAIN_START_DATE,
30+
TRAIN_END_DATE,
31+
TEST_START_DATE,
32+
TEST_END_DATE,
33+
TRADE_START_DATE,
34+
TRADE_END_DATE,
35+
ERL_PARAMS,
36+
RLlib_PARAMS,
37+
SAC_PARAMS,
38+
ALPACA_API_KEY,
39+
ALPACA_API_SECRET,
40+
ALPACA_API_BASE_URL,
41+
)
42+
import pyfolio
43+
from pyfolio import timeseries
44+
45+
pd.options.display.max_columns = None
46+
47+
print("ALL Modules have been imported!")
48+
49+
50+
### Create folders
51+
52+
import os
53+
54+
"""
55+
use check_and_make_directories() to replace the following
56+
57+
if not os.path.exists("./datasets"):
58+
os.makedirs("./datasets")
59+
if not os.path.exists("./trained_models"):
60+
os.makedirs("./trained_models")
61+
if not os.path.exists("./tensorboard_log"):
62+
os.makedirs("./tensorboard_log")
63+
if not os.path.exists("./results"):
64+
os.makedirs("./results")
65+
"""
66+
67+
check_and_make_directories(
68+
[DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR]
69+
)
70+
71+
72+
### Download data, cleaning and feature engineering
73+
74+
ticker_list = [
75+
"600000.SH",
76+
"600009.SH",
77+
"600016.SH",
78+
"600028.SH",
79+
"600030.SH",
80+
"600031.SH",
81+
"600036.SH",
82+
"600050.SH",
83+
"600104.SH",
84+
"600196.SH",
85+
"600276.SH",
86+
"600309.SH",
87+
"600519.SH",
88+
"600547.SH",
89+
"600570.SH",
90+
]
91+
92+
TRAIN_START_DATE = "2015-01-01"
93+
TRAIN_END_DATE = "2019-08-01"
94+
TRADE_START_DATE = "2019-08-01"
95+
TRADE_END_DATE = "2020-01-03"
96+
97+
98+
TIME_INTERVAL = "1d"
99+
kwargs = {}
100+
kwargs["token"] = "27080ec403c0218f96f388bca1b1d85329d563c91a43672239619ef5"
101+
p = DataProcessor(
102+
data_source="tushare",
103+
start_date=TRAIN_START_DATE,
104+
end_date=TRADE_END_DATE,
105+
time_interval=TIME_INTERVAL,
106+
**kwargs,
107+
)
108+
109+
110+
# download and clean
111+
p.download_data(ticker_list=ticker_list)
112+
113+
114+
p.clean_data()
115+
116+
117+
# add_technical_indicator
118+
p.add_technical_indicator(config.INDICATORS)
119+
p.clean_data()
120+
print(f"p.dataframe: {p.dataframe}")
121+
122+
123+
### Split traning dataset
124+
125+
train = p.data_split(p.dataframe, TRAIN_START_DATE, TRAIN_END_DATE)
126+
print(f"len(train.tic.unique()): {len(train.tic.unique())}")
127+
128+
print(f"train.tic.unique(): {train.tic.unique()}")
129+
130+
print(f"train.head(): {train.head()}")
131+
132+
print(f"train.shape: {train.shape}")
133+
134+
stock_dimension = len(train.tic.unique())
135+
state_space = stock_dimension * (len(config.INDICATORS) + 2) + 1
136+
print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")
137+
138+
### Train
139+
140+
env_kwargs = {
141+
"stock_dim": stock_dimension,
142+
"hmax": 1000,
143+
"initial_amount": 1000000,
144+
"buy_cost_pct": 6.87e-5,
145+
"sell_cost_pct": 1.0687e-3,
146+
"reward_scaling": 1e-4,
147+
"state_space": state_space,
148+
"action_space": stock_dimension,
149+
"tech_indicator_list": config.INDICATORS,
150+
"print_verbosity": 1,
151+
"initial_buy": True,
152+
"hundred_each_trade": True,
153+
}
154+
155+
e_train_gym = StockTradingEnv(df=train, **env_kwargs)
156+
157+
## DDPG
158+
159+
env_train, _ = e_train_gym.get_sb_env()
160+
print(f"print(type(env_train)): {print(type(env_train))}")
161+
162+
agent = DRLAgent(env=env_train)
163+
DDPG_PARAMS = {
164+
"batch_size": 256,
165+
"buffer_size": 50000,
166+
"learning_rate": 0.0005,
167+
"action_noise": "normal",
168+
}
169+
POLICY_KWARGS = dict(net_arch=dict(pi=[64, 64], qf=[400, 300]))
170+
model_ddpg = agent.get_model(
171+
"ddpg", model_kwargs=DDPG_PARAMS, policy_kwargs=POLICY_KWARGS
172+
)
173+
174+
trained_ddpg = agent.train_model(
175+
model=model_ddpg, tb_log_name="ddpg", total_timesteps=10000
176+
)
177+
178+
## A2C
179+
180+
agent = DRLAgent(env=env_train)
181+
model_a2c = agent.get_model("a2c")
182+
183+
trained_a2c = agent.train_model(
184+
model=model_a2c, tb_log_name="a2c", total_timesteps=50000
185+
)
186+
187+
### Trade
188+
189+
trade = p.data_split(p.dataframe, TRADE_START_DATE, TRADE_END_DATE)
190+
env_kwargs = {
191+
"stock_dim": stock_dimension,
192+
"hmax": 1000,
193+
"initial_amount": 1000000,
194+
"buy_cost_pct": 6.87e-5,
195+
"sell_cost_pct": 1.0687e-3,
196+
"reward_scaling": 1e-4,
197+
"state_space": state_space,
198+
"action_space": stock_dimension,
199+
"tech_indicator_list": config.INDICATORS,
200+
"print_verbosity": 1,
201+
"initial_buy": False,
202+
"hundred_each_trade": True,
203+
}
204+
e_trade_gym = StockTradingEnv(df=trade, **env_kwargs)
205+
206+
df_account_value, df_actions = DRLAgent.DRL_prediction(
207+
model=trained_ddpg, environment=e_trade_gym
208+
)
209+
210+
df_actions.to_csv("action.csv", index=False)
211+
print(f"df_actions: {df_actions}")
212+
213+
### Backtest
214+
215+
# matplotlib inline
216+
plotter = ReturnPlotter(df_account_value, trade, TRADE_START_DATE, TRADE_END_DATE)
217+
# plotter.plot_all()
218+
219+
plotter.plot()
220+
221+
# matplotlib inline
222+
# # ticket: SSE 50:000016
223+
# plotter.plot("000016")
224+
225+
#### Use pyfolio
226+
227+
# CSI 300
228+
baseline_df = plotter.get_baseline("399300")
229+
230+
231+
daily_return = plotter.get_return(df_account_value)
232+
daily_return_base = plotter.get_return(baseline_df, value_col_name="close")
233+
234+
perf_func = timeseries.perf_stats
235+
perf_stats_all = perf_func(
236+
returns=daily_return,
237+
factor_returns=daily_return_base,
238+
positions=None,
239+
transactions=None,
240+
turnover_denom="AGB",
241+
)
242+
print("==============DRL Strategy Stats===========")
243+
print(f"perf_stats_all: {perf_stats_all}")
244+
245+
246+
daily_return = plotter.get_return(df_account_value)
247+
daily_return_base = plotter.get_return(baseline_df, value_col_name="close")
248+
249+
perf_func = timeseries.perf_stats
250+
perf_stats_all = perf_func(
251+
returns=daily_return_base,
252+
factor_returns=daily_return_base,
253+
positions=None,
254+
transactions=None,
255+
turnover_denom="AGB",
256+
)
257+
print("==============Baseline Strategy Stats===========")
258+
259+
print(f"perf_stats_all: {perf_stats_all}")

0 commit comments

Comments
 (0)