Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
319 changes: 319 additions & 0 deletions notebooks/sandbox_sma_implementation_comparison.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "7d40767d",
"metadata": {},
"outputs": [],
"source": [
"# filename: sandbox_sma_implementation_comparison.ipynb\n",
"\n",
"# --- Imports and Setup ---\n",
"import numpy as np\n",
"import pandas as pd\n",
"import vectorbt as vbt\n",
"from numba import njit\n",
"\n",
"# --- Data Loading ---\n",
"symbol = \"EURUSD=X\"\n",
"end_date = pd.Timestamp.now(tz=\"UTC\")\n",
"start_date = end_date - pd.Timedelta(days=729)\n",
"timeframe = \"1h\"\n",
"\n",
"print(\n",
" f\"Downloading {symbol} {timeframe} data from {start_date.date()} to {end_date.date()}...\"\n",
")\n",
"try:\n",
" ohlc_data = vbt.YFData.download(\n",
" symbol, start=start_date, end=end_date, interval=timeframe\n",
" ).get([\"Open\", \"High\", \"Low\", \"Close\"])\n",
"\n",
" if ohlc_data.empty:\n",
" raise ValueError(\"No data returned from yfinance.\")\n",
"\n",
" close_price = ohlc_data[\"Close\"]\n",
" print(\"Data download complete.\\n\")\n",
"\n",
"except Exception as e:\n",
" print(f\"An error occurred during data download: {e}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4145c6c",
"metadata": {},
"outputs": [],
"source": [
"# --- Method 1: TA-Lib ---\n",
"print(\"--- Running Optimization using 'TA-Lib' ---\")\n",
"\n",
"# 1. Create indicator factory\n",
"sma_factory_talib = vbt.IndicatorFactory.from_talib(\"SMA\")\n",
"\n",
"# 2. Define parameter ranges\n",
"fast_sma_range = np.arange(10, 30, 2)\n",
"slow_sma_range = np.arange(40, 60, 2)\n",
"\n",
"# 3. Calculate indicators\n",
"fast_sma_talib = sma_factory_talib.run(\n",
" close_price, timeperiod=fast_sma_range, short_name=\"fast\"\n",
")\n",
"slow_sma_talib = sma_factory_talib.run(\n",
" close_price, timeperiod=slow_sma_range, short_name=\"slow\"\n",
")\n",
"\n",
"# 4. Generate signals\n",
"entries = fast_sma_talib.real.vbt.crossed_above(slow_sma_talib.real)\n",
"exits = fast_sma_talib.real.vbt.crossed_below(slow_sma_talib.real)\n",
"\n",
"# 5. Run portfolio backtest\n",
"portfolio_talib = vbt.Portfolio.from_signals(\n",
" close_price, entries, exits, freq=timeframe, init_cash=10000\n",
")\n",
"\n",
"# 6. Analyze and display results\n",
"if portfolio_talib.trades.count().sum() == 0:\n",
" print(\"No trades were executed for any parameter combination.\")\n",
"else:\n",
" # Find the best parameter combination\n",
" best_params_talib = portfolio_talib.sharpe_ratio().idxmax()\n",
"\n",
" # Select the single best portfolio instance\n",
" best_portfolio = portfolio_talib[best_params_talib]\n",
"\n",
" # Get the stats for that single portfolio\n",
" best_stats_talib = best_portfolio.stats()\n",
"\n",
" print(f\"Best Parameters (fast_period, slow_period): {best_params_talib}\")\n",
" print(\"Performance for Best Parameters:\")\n",
" print(\n",
" best_stats_talib[\n",
" [\"Total Return [%]\", \"Sharpe Ratio\", \"Max Drawdown [%]\", \"Win Rate [%]\"]\n",
" ]\n",
" )\n",
"\n",
" # 7. Visualize the best strategy\n",
" # --- CORRECTED, LAYERED PLOTTING LOGIC ---\n",
"\n",
" # Get the indicator series for the best parameters\n",
" best_fast_sma = fast_sma_talib.real[best_params_talib[0]]\n",
" best_slow_sma = slow_sma_talib.real[best_params_talib[1]]\n",
"\n",
" # Step A: Create the base candlestick chart from the OHLC data\n",
" fig = ohlc_data.vbt.ohlc.plot(\n",
" title_text=f\"TA-Lib SMA Strategy | Best Params: {best_params_talib}\",\n",
" template=\"plotly_dark\",\n",
" )\n",
"\n",
" # Step B: Add the SMA lines to the existing figure\n",
" best_fast_sma.vbt.plot(\n",
" fig=fig, trace_kwargs=dict(name=\"Fast SMA\", line=dict(color=\"cyan\"))\n",
" )\n",
" best_slow_sma.vbt.plot(\n",
" fig=fig, trace_kwargs=dict(name=\"Slow SMA\", line=dict(color=\"orange\"))\n",
" )\n",
"\n",
" # Step C: Add only the trade markers to the existing figure\n",
" best_portfolio.trades.plot(fig=fig)\n",
"\n",
" # Step D: Show the final, combined figure\n",
" fig.show()\n",
"\n",
"print(\"-\" * 50 + \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7cc35354",
"metadata": {},
"outputs": [],
"source": [
"# --- Method 2: pandas-ta ---\n",
"print(\"--- Running Optimization using 'pandas-ta' ---\")\n",
"\n",
"# 1. Create indicator factory, naming the output 'sma'\n",
"# This tells vectorbt to create a '.sma' attribute on the output object\n",
"sma_factory_pdta = vbt.IndicatorFactory.from_pandas_ta(\"sma\", output_names=[\"sma\"])\n",
"\n",
"# 2. Define parameter ranges (using the same as before for comparison)\n",
"fast_sma_range = np.arange(10, 30, 2)\n",
"slow_sma_range = np.arange(40, 60, 2)\n",
"\n",
"# 3. Calculate indicators\n",
"# pandas-ta's period parameter is named 'length'\n",
"fast_sma_pdta = sma_factory_pdta.run(\n",
" close_price, length=fast_sma_range, short_name=\"fast\"\n",
")\n",
"slow_sma_pdta = sma_factory_pdta.run(\n",
" close_price, length=slow_sma_range, short_name=\"slow\"\n",
")\n",
"\n",
"# 4. Generate signals using the '.sma' output attribute\n",
"entries = fast_sma_pdta.sma.vbt.crossed_above(slow_sma_pdta.sma)\n",
"exits = fast_sma_pdta.sma.vbt.crossed_below(slow_sma_pdta.sma)\n",
"\n",
"# 5. Run portfolio backtest\n",
"portfolio_pdta = vbt.Portfolio.from_signals(\n",
" close_price, entries, exits, freq=timeframe, init_cash=10000\n",
")\n",
"\n",
"# 6. Analyze and display results\n",
"if portfolio_pdta.trades.count().sum() == 0:\n",
" print(\"No trades were executed for any parameter combination.\")\n",
"else:\n",
" # Find the best parameter combination\n",
" best_params_pdta = portfolio_pdta.sharpe_ratio().idxmax()\n",
"\n",
" # Select the single best portfolio instance\n",
" best_portfolio = portfolio_pdta[best_params_pdta]\n",
"\n",
" # Get the stats for that single portfolio\n",
" best_stats_pdta = best_portfolio.stats()\n",
"\n",
" print(f\"Best Parameters (fast_period, slow_period): {best_params_pdta}\")\n",
" print(\"Performance for Best Parameters:\")\n",
" print(\n",
" best_stats_pdta[\n",
" [\"Total Return [%]\", \"Sharpe Ratio\", \"Max Drawdown [%]\", \"Win Rate [%]\"]\n",
" ]\n",
" )\n",
"\n",
" # 7. Visualize the best strategy using the same layered approach\n",
"\n",
" # Get the indicator series for the best parameters\n",
" best_fast_sma = fast_sma_pdta.sma[best_params_pdta[0]]\n",
" best_slow_sma = slow_sma_pdta.sma[best_params_pdta[1]]\n",
"\n",
" # Step A: Create the base candlestick chart\n",
" fig = ohlc_data.vbt.ohlc.plot(\n",
" title_text=f\"pandas-ta SMA Strategy | Best Params: {best_params_pdta}\",\n",
" template=\"plotly_dark\",\n",
" )\n",
"\n",
" # Step B: Add the SMA lines\n",
" best_fast_sma.vbt.plot(\n",
" fig=fig, trace_kwargs=dict(name=\"Fast SMA\", line=dict(color=\"cyan\"))\n",
" )\n",
" best_slow_sma.vbt.plot(\n",
" fig=fig, trace_kwargs=dict(name=\"Slow SMA\", line=dict(color=\"orange\"))\n",
" )\n",
"\n",
" # Step C: Add the trade markers\n",
" best_portfolio.trades.plot(fig=fig)\n",
"\n",
" # Step D: Show the final figure\n",
" fig.show()\n",
"\n",
"print(\"-\" * 50 + \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ecf78ab3",
"metadata": {},
"outputs": [],
"source": [
"# --- Method 3: Custom Implementation (Numba Accelerated) ---\n",
"print(\"--- Running Optimization using 'Custom Numba Function' ---\")\n",
"\n",
"# # Import Numba and vectorbt's numba-aware functions\n",
"# import vectorbt as vbt\n",
"\n",
"\n",
"# CORRECTED: The custom function must handle 2D numpy arrays.\n",
"# We use vectorbt's built-in, numba-accelerated rolling_mean_nb for this.\n",
"# The @njit decorator compiles this function for massive speedup.\n",
"@njit\n",
"def custom_sma_func_nb(close, n):\n",
" \"\"\"Calculates SMA using vectorbt's numba-accelerated rolling mean.\"\"\"\n",
" # vbt.nb.rolling_mean_nb correctly handles the 2D array passed by the factory\n",
" return vbt.nb.rolling_mean_nb(close, window=n, minp=n)\n",
"\n",
"\n",
"# 1. Create the indicator class from our numba-jitted function\n",
"CustomSMA = vbt.IndicatorFactory(\n",
" input_names=[\"close\"], param_names=[\"n\"], output_names=[\"sma\"]\n",
").from_apply_func(\n",
" custom_sma_func_nb # Use the new, numba-aware function\n",
")\n",
"\n",
"# 2. Define parameter ranges\n",
"fast_sma_range = np.arange(10, 30, 2)\n",
"slow_sma_range = np.arange(40, 60, 2)\n",
"\n",
"# 3. Calculate indicators\n",
"fast_sma_custom = CustomSMA.run(close_price, n=fast_sma_range, short_name=\"fast\")\n",
"slow_sma_custom = CustomSMA.run(close_price, n=slow_sma_range, short_name=\"slow\")\n",
"\n",
"# 4. Generate signals\n",
"entries = fast_sma_custom.sma.vbt.crossed_above(slow_sma_custom.sma)\n",
"exits = fast_sma_custom.sma.vbt.crossed_below(slow_sma_custom.sma)\n",
"\n",
"# 5. Run portfolio backtest\n",
"portfolio_custom = vbt.Portfolio.from_signals(\n",
" close_price, entries, exits, freq=timeframe, init_cash=10000\n",
")\n",
"\n",
"# 6. Analyze and display results\n",
"if portfolio_custom.trades.count().sum() == 0:\n",
" print(\"No trades were executed for any parameter combination.\")\n",
"else:\n",
" best_params_custom = portfolio_custom.sharpe_ratio().idxmax()\n",
" best_portfolio = portfolio_custom[best_params_custom]\n",
" best_stats_custom = best_portfolio.stats()\n",
"\n",
" print(f\"Best Parameters (fast_period, slow_period): {best_params_custom}\")\n",
" print(\"Performance for Best Parameters:\")\n",
" print(\n",
" best_stats_custom[\n",
" [\"Total Return [%]\", \"Sharpe Ratio\", \"Max Drawdown [%]\", \"Win Rate [%]\"]\n",
" ]\n",
" )\n",
"\n",
" # 7. Visualize the best strategy\n",
" best_fast_sma = fast_sma_custom.sma[best_params_custom[0]]\n",
" best_slow_sma = slow_sma_custom.sma[best_params_custom[1]]\n",
"\n",
" fig = ohlc_data.vbt.ohlc.plot(\n",
" title_text=f\"Custom Numba SMA Strategy | Best Params: {best_params_custom}\",\n",
" template=\"plotly_dark\",\n",
" )\n",
" best_fast_sma.vbt.plot(\n",
" fig=fig, trace_kwargs=dict(name=\"Fast SMA\", line=dict(color=\"cyan\"))\n",
" )\n",
" best_slow_sma.vbt.plot(\n",
" fig=fig, trace_kwargs=dict(name=\"Slow SMA\", line=dict(color=\"orange\"))\n",
" )\n",
" best_portfolio.trades.plot(fig=fig)\n",
" fig.show()\n",
"\n",
"print(\"-\" * 50 + \"\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "strategy-optimizer_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}