diff --git a/pytorch/.gitignore b/pytorch/.gitignore new file mode 100644 index 00000000..f9187508 --- /dev/null +++ b/pytorch/.gitignore @@ -0,0 +1,3 @@ +*rl_trader_models +*rl_trader_rewards +*.png diff --git a/pytorch/aapl_msi_sbux.csv b/pytorch/aapl_msi_sbux.csv new file mode 100644 index 00000000..cb98cb88 --- /dev/null +++ b/pytorch/aapl_msi_sbux.csv @@ -0,0 +1,1260 @@ +AAPL,MSI,SBUX +67.8542,60.3,28.185 +68.5614,60.9,28.07 +66.8428,60.83,28.13 +66.7156,60.81,27.915 +66.6556,61.12,27.775 +65.7371,61.43,27.17 +65.7128,62.03,27.225 +64.1214,61.26,26.655 +63.7228,60.88,26.675 +64.4014,61.9,27.085 +63.2571,60.28,26.605 +64.1385,60.63,26.64 +63.5099,62.09,27.285 +63.0571,62.21,27.425 +61.4957,62.03,27.435 +60.0071,62.5,27.85 +61.5919,62.97,28.255 +60.8088,63.11,28.55 +61.5117,62.64,29.125 +61.6742,62.75,29.335 +62.5528,62.56,29.305 +61.2042,62.13,29.14 +61.1928,62.22,29.2925 +61.7857,62.34,28.84 +63.3799,62.07,28.83 +65.1028,61.64,28.465 +64.9271,61.67,28.415 +64.5828,62.4,28.715 +64.6756,62.43,28.525 +65.9871,63.61,28.69 +66.2256,63.29,28.345 +65.8765,63.46,28.525 +64.5828,63.56,28.455 +63.2371,64.03,28.475 +61.2728,63.7,28.435 +61.3988,63.7,29.13 +61.7128,62.8,28.85 +61.1028,62.99,29.055 +60.4571,62.67,28.9 +60.8871,63.17,29.06 +60.9971,63.64,28.705 +62.2414,64.69,28.9 +62.0471,64.63,29.2875 +61.3999,63.87,29.545 +59.9785,61.83,28.855 +60.8914,62.96,29.28 +57.5428,62.13,29.085 +56.0071,61.15,28.86 +55.7899,61.72,29.2025 +56.9528,61.78,29.32 +58.0185,61.75,29.695 +57.9231,56.02,29.915 +58.3399,56.39,30.25 +59.6007,56.8,30.0 +61.4457,57.44,30.29 +63.2542,57.2,30.42 +62.7557,56.37,30.07 +63.6457,56.89,30.19 +64.2828,57.29,30.935 +65.8156,56.95,31.24 +65.5225,56.79,31.095 +66.2628,57.0,31.205 +65.2528,56.78,31.18 +64.7099,56.48,31.5485 +64.9628,56.17,31.41 +63.4085,56.89,31.76 +61.2642,57.1,32.035 +62.0825,57.53,31.775 +61.8942,57.84,32.065 +63.2757,58.25,31.915 +62.8085,57.77,32.125 +63.0505,57.3,32.075 +63.1628,57.48,31.76 +63.5928,57.81,31.68 +63.0627,58.53,32.13 +63.5642,58.32,31.815 +64.5114,58.54,31.735 +64.2478,57.96,31.57 +64.3885,57.83,31.73 +64.1871,57.41,31.665 +63.5871,56.27,31.17 +62.6371,56.92,31.51 +63.1158,56.94,32.52 +62.6985,56.61,33.055 +62.5142,56.38,32.71 +61.7414,56.26,32.225 +62.2807,57.19,32.985 +61.4357,56.93,32.8 +61.7142,57.33,33.015 +61.6814,57.35,33.5475 +60.4285,56.78,33.205 +59.5482,55.5,32.61 +59.0714,55.82,32.345 +57.5057,55.59,32.005 +57.5185,56.35,32.37 +56.8671,57.49,32.9 +56.2542,57.84,32.845 +56.6471,57.73,32.755 +58.4599,57.98,33.12 +59.7842,57.49,33.395 +60.1142,57.26,33.65 +59.6314,57.93,33.86 +59.2928,57.86,34.145 +60.3357,58.03,34.065 +60.1042,58.43,34.05 +61.0411,59.05,34.67 +60.9299,59.54,34.86 +61.0628,59.17,34.83 +61.4564,59.32,34.76 +61.4728,59.42,34.1 +61.6797,59.36,34.24 +60.7071,59.85,34.395 +60.9014,59.87,34.51 +59.8557,59.98,33.83 +62.9299,56.04,33.305 +62.6428,54.25,34.085 +62.9985,54.26,36.68 +63.9699,54.01,36.225 +64.7599,54.35,35.965 +64.6471,54.83,35.6445 +65.2394,55.32,36.74 +66.0771,56.02,37.115 +67.0642,56.1,36.985 +66.4642,56.4,36.4 +66.4256,56.48,36.095 +65.8585,57.13,36.47 +64.9214,57.36,36.4 +66.7656,57.44,36.465 +69.9385,57.84,36.32 +71.2142,57.71,35.925 +71.1299,56.96,35.37 +71.7614,57.15,35.355 +72.5342,57.09,35.145 +71.5814,57.05,35.33 +71.7656,56.06,35.3565 +71.8514,56.33,35.95 +71.5742,56.74,35.985 +71.8528,56.55,35.94 +69.7985,56.12,35.08 +70.1279,56.39,35.48 +70.2428,56.19,35.59 +69.6022,56.01,35.26 +69.7971,56.28,35.8 +71.2415,56.08,36.07 +70.7528,56.17,36.025 +71.1742,56.47,35.785 +72.3099,57.59,36.22 
+70.6628,57.37,37.1075 +66.8156,57.25,37.695 +67.5271,57.5,37.835 +66.4142,57.46,37.785 +64.3028,57.81,37.62 +65.0456,58.28,38.02 +66.3828,59.26,38.665 +67.4714,59.69,38.175 +66.7728,60.39,38.06 +70.0914,60.37,37.68 +69.8714,59.99,38.275 +68.7899,59.85,38.17 +69.4599,59.87,38.59 +68.9642,59.75,38.665 +68.1071,59.38,38.485 +69.7085,60.89,38.58 +69.9371,60.7,38.595 +69.0585,60.56,38.435 +69.0042,61.14,38.7 +69.6785,60.89,38.4305 +68.7056,59.62,37.765 +69.5125,59.39,37.63 +69.9482,60.61,38.56 +70.4016,60.52,38.91 +70.8628,61.03,39.05 +71.2399,60.49,38.355 +71.5876,60.71,39.02 +72.0714,60.92,39.3675 +72.6985,60.81,39.655 +74.4802,61.18,39.73 +74.2667,60.43,40.45 +74.9942,62.4,40.025 +75.9871,62.51,39.525 +75.1368,62.99,39.98 +75.6965,62.44,39.355 +73.8111,62.73,39.81 +74.9851,62.25,40.415 +74.6716,62.52,40.525 +74.2899,62.39,40.185 +75.2499,62.71,40.185 +75.0641,62.68,40.995 +74.4171,62.65,40.565 +73.2131,62.49,39.535 +74.3656,63.12,40.6 +74.1496,63.51,40.495 +74.2871,64.24,40.3075 +74.3762,64.45,40.7305 +75.4514,64.58,40.57 +74.9986,65.57,40.595 +74.0898,65.42,40.27 +74.2214,64.61,39.96 +73.5714,64.58,39.845 +74.4479,65.41,40.765 +74.2571,65.88,40.675 +74.8199,65.79,40.355 +76.1999,65.57,40.755 +77.9942,65.5,40.81 +79.4385,65.88,40.73 +78.7471,65.66,40.535 +80.9031,65.79,40.275 +80.7142,64.93,39.75 +81.1286,65.23,39.86 +80.0028,66.18,39.97 +80.9185,65.79,39.865 +80.7928,65.41,38.69 +80.1942,64.6,38.2 +80.0771,64.86,38.24 +79.2042,65.05,38.175 +79.6428,65.36,38.23 +79.2842,65.52,38.045 +78.6813,66.16,38.84 +77.7799,65.85,38.575 +78.4314,65.61,38.83 +81.4413,66.78,39.16 +81.0956,67.1,39.285 +80.5571,67.18,39.44 +80.0128,67.33,39.285 +79.2171,67.25,39.275 +80.1456,67.5,39.195 +79.0185,66.33,38.585 +77.2828,66.2,38.475 +77.7042,65.92,38.085 +77.1481,66.19,38.605 +77.6371,65.99,39.015 +76.6455,66.5,38.8 +76.1342,66.15,38.835 +76.5328,65.49,37.56 +78.0556,66.35,37.73 +79.6228,65.62,38.095 +79.1785,65.81,37.645 +77.2385,66.1,37.45 +78.4385,67.11,36.825 +78.7871,64.51,36.8 +79.4542,65.34,36.695 +78.0099,64.42,37.49 +78.6428,64.43,37.105 +72.3571,64.34,36.945 +71.5356,63.98,35.78 +71.3974,64.91,35.955 +71.5142,63.8,35.56 +71.6471,62.72,34.485 +72.6842,62.99,35.325 +73.2271,62.89,35.245 +73.2156,63.4,36.18 +74.2399,64.6,37.0175 +75.5699,65.08,37.4 +76.5656,65.03,37.25 +76.5599,65.78,36.955 +77.7756,65.67,37.345 +77.7128,65.61,37.515 +77.9985,65.78,36.985 +76.7671,64.93,36.66 +75.8785,65.22,36.775 +75.0356,65.02,36.28 +75.3642,64.96,36.28 +74.5799,65.1,35.275 +73.9071,65.45,35.89 +75.3814,65.9,36.095 +75.1771,66.2,35.48 +75.3942,65.98,35.235 +75.8914,66.76,35.83 +76.0514,66.33,35.65 +75.8214,66.57,36.345 +75.7771,66.64,36.535 +75.8456,66.43,36.78 +76.5842,66.08,37.515 +76.6585,65.02,37.815 +75.8071,64.21,37.215 +74.9556,63.67,37.135 +75.2485,65.08,37.09 +75.9142,65.72,37.3 +75.8942,65.7,37.955 +75.5285,66.66,38.4775 +76.1242,66.81,38.355 +77.0271,66.05,37.885 +77.8556,66.18,37.305 +77.1114,65.16,36.77 +76.7799,64.36,36.7 +76.6942,64.3,36.85 +76.6771,64.29,36.69 +77.3785,64.91,37.005 +77.5071,65.1,36.835 +76.9699,65.09,36.545 +75.9742,64.26,35.775 +74.7814,64.43,35.215 +74.7771,64.95,35.74 +75.7599,65.26,36.24 +74.7828,63.99,35.11 +74.2299,63.39,34.365 +74.5256,63.78,34.655 +73.9942,63.37,34.445 +74.1442,63.23,35.395 +74.9914,63.15,35.075 +75.8814,62.51,35.24 +75.9569,63.27,35.5745 +74.9642,63.29,35.195 +81.1099,63.0,35.545 +81.7056,62.5,35.725 +84.8699,62.64,35.465 +84.6185,63.43,35.32 +84.2985,63.58,35.31 +84.4971,62.65,35.56 +84.6542,65.51,35.3 +85.8513,66.15,35.46 +84.9156,66.4,34.79 
+84.6185,67.14,34.87 +83.9985,67.38,34.79 +83.6488,67.26,35.145 +84.6899,67.8,35.575 +84.8228,67.75,35.58 +84.8385,67.2,35.085 +84.1171,66.34,34.925 +85.3585,66.3,35.47 +86.3699,66.88,35.51 +86.3871,66.52,35.115 +86.6156,66.89,35.2 +86.7528,66.63,35.7 +87.7328,67.0,35.99 +89.3756,67.02,36.83 +89.1442,66.93,36.635 +90.7685,66.91,36.555 +90.4285,67.42,36.62 +89.8071,67.4,36.925 +91.0771,66.86,37.09 +92.1171,67.23,37.335 +92.4785,67.17,37.36 +92.2242,67.66,37.665 +93.7,67.67,37.59 +94.25,67.7,37.3 +93.86,66.93,37.4 +92.29,66.46,36.98 +91.28,66.78,37.345 +92.2,66.72,37.545 +92.08,66.64,37.655 +92.18,66.62,37.78 +91.86,67.06,38.615 +90.91,67.07,38.3 +90.83,67.1,38.365 +90.28,66.73,38.715 +90.36,66.55,39.06 +90.9,66.56,39.03 +91.98,66.78,38.97 +92.93,66.57,38.69 +93.52,66.96,39.04 +93.48,67.02,39.095 +94.03,67.41,39.53 +95.96799999999999,67.24,39.345 +95.35,66.27,39.28 +95.39,66.58,39.725 +95.035,66.45,39.425 +95.22,66.0,39.3 +96.45,66.08,39.28 +95.32,65.49,39.445 +94.78,65.67,39.365 +93.0899,64.94,38.62 +94.43,65.49,38.97 +93.939,65.74,38.805 +94.72,66.05,39.37 +97.19,65.77,39.57 +97.03,65.61,40.225 +97.671,65.0,39.37 +99.02,65.21,39.18 +98.38,64.74,39.325 +98.15,64.83,39.45 +95.6,63.68,38.84 +96.13,64.11,38.49 +95.59,64.11,38.765 +95.12,61.39,38.395 +94.96,61.21,38.565 +94.48,61.25,38.355 +94.74,62.19,38.81 +95.99,61.73,38.935 +95.97,61.64,38.91 +97.24,62.03,38.62 +97.5,61.52,38.31 +97.98,61.0,38.455 +99.16,60.81,38.795 +100.53,61.37,39.06 +100.57,61.64,39.015 +100.58,61.7,38.735 +101.32,61.23,38.64 +101.54,61.02,38.985 +100.889,60.3,38.895 +102.13,59.68,38.96 +102.25,59.37,38.905 +102.5,59.4,38.905 +103.3,59.01,38.74 +98.94,58.94,38.395 +98.12,58.98,38.58 +98.97,58.89,38.975 +98.36,61.02,38.835 +97.99,61.08,38.56 +101.0,61.22,38.605 +101.43,61.22,38.06 +101.66,61.54,37.735 +101.63,61.42,37.46 +100.86,61.69,37.545 +101.58,61.91,37.67 +101.79,62.04,37.865 +100.96,61.88,38.035 +101.06,61.68,37.3 +102.64,61.57,36.9775 +101.75,61.8,37.66 +97.87,62.24,37.06 +100.75,63.42,37.585 +100.11,63.18,37.635 +100.75,63.28,37.73 +99.18,62.34,37.305 +99.9,61.03,37.225 +99.62,61.3,37.945 +99.62,61.5,37.5725 +98.75,60.47,37.025 +100.8,61.58,37.63 +101.02,60.46,37.24 +100.73,59.05,37.23 +99.81,58.5,36.095 +98.75,58.73,36.37 +97.54,59.32,36.19 +96.26,59.18,36.32 +97.67,60.79,36.77 +99.76,61.25,37.35 +102.47,62.39,37.18 +102.99,61.63,37.3 +104.83,62.25,37.42 +105.22,62.57,37.905 +105.11,62.8,37.985 +106.74,64.06,38.525 +107.34,63.94,38.27 +106.98,63.7,38.66 +108.0,64.5,37.78 +109.4,64.68,38.05 +108.6,66.76,38.355 +108.86,64.46,38.33 +108.7,63.42,38.725 +109.01,64.14,38.895 +108.83,63.94,38.825 +109.7,63.55,38.865 +111.25,63.7,38.925 +112.82,64.43,38.945 +114.18,65.25,39.06 +113.99,65.4,38.915 +115.47,66.0,38.785 +114.67,65.94,38.91 +116.31,65.66,39.1 +116.47,65.27,39.88 +118.625,65.81,40.26 +117.6,65.6,40.105 +119.0,65.56,39.85 +118.93,65.72,40.605 +115.07,65.44,40.425 +114.63,65.51,40.185 +115.93,65.32,40.235 +115.49,65.2,40.655 +115.0,65.0,41.785 +112.4,65.27,41.9 +114.12,65.29,41.515 +111.95,63.52,41.33 +111.62,63.29,41.56 +109.73,62.31,41.625 +108.225,61.91,40.445 +106.745,61.73,39.565 +109.41,63.99,40.2175 +112.65,65.11,40.015 +111.78,65.5,39.72 +112.94,66.53,40.27 +112.54,66.93,40.715 +112.01,67.34,40.635 +113.99,67.49,40.915 +113.91,67.87,41.19 +112.52,67.53,40.895 +110.38,67.08,41.025 +109.33,66.51,40.72 +106.25,65.06,39.94 +106.26,64.51,39.615 +107.75,64.43,40.59 +111.89,65.43,41.245 +112.01,65.11,39.895 +109.25,64.35,40.115 +110.22,64.11,40.435 +109.8,63.76,40.21 +106.82,63.41,39.79 +105.99,64.05,40.305 
+108.72,64.02,40.6125 +109.55,64.31,40.645 +112.4,65.36,41.37 +112.98,65.48,44.11 +113.1,65.71,44.06 +109.14,64.94,44.17 +115.31,63.84,43.7825 +118.9,63.83,44.525 +117.16,62.41,43.765 +118.63,62.81,43.995 +118.65,64.01,44.245 +119.56,63.94,44.35 +119.94,64.0,44.82 +118.93,64.66,44.5 +119.72,67.78,44.41 +122.02,68.22,45.59 +124.88,68.57,45.395 +126.46,70.0,45.9125 +127.08,69.91,45.79 +127.83,69.79,46.015 +128.715,69.12,46.5 +128.45,69.03,46.585 +129.495,69.83,46.755 +133.0,68.63,46.79 +132.17,68.53,46.725 +128.79,68.02,47.13 +130.415,68.47,47.275 +128.46,67.94,46.7425 +129.09,68.89,47.1125 +129.36,68.14,47.0 +128.54,67.64,46.53 +126.41,67.93,46.815 +126.6,66.82,46.1075 +127.14,66.57,46.52 +124.51,65.33,46.09 +122.24,65.31,45.71 +124.45,64.96,46.69 +123.59,64.8,46.645 +124.95,65.86,47.0225 +127.04,65.32,47.1925 +128.47,66.65,47.92 +127.495,66.34,48.88 +125.9,66.83,48.73 +127.21,66.52,48.685 +126.69,66.23,48.9575 +123.38,65.35,47.885 +124.24,65.42,47.54 +123.25,65.38,47.535 +126.37,66.39,47.99 +124.43,66.67,47.35 +124.25,66.67,46.51 +125.32,62.51,47.195 +127.35,61.48,47.26 +126.01,61.99,47.035 +125.6,62.42,47.615 +126.56,62.32,47.96 +127.1,62.53,48.17 +126.85,61.97,48.5 +126.3,61.91,48.3 +126.78,61.82,48.14 +126.17,61.86,48.245 +124.75,60.68,47.62 +127.6,61.16,47.97 +126.91,61.43,48.37 +128.62,61.59,48.335 +129.67,60.84,49.43 +130.28,60.57,51.84 +132.65,60.98,50.87 +130.56,60.69,50.61 +128.64,59.74,50.65 +125.15,59.75,49.58 +128.95,60.28,50.29 +128.7,60.68,50.445 +125.8,58.59,49.405 +125.01,58.75,48.93 +125.26,60.01,49.35 +127.62,60.59,49.78 +126.32,59.8,49.5 +125.865,59.42,49.71 +126.01,59.25,49.59 +128.95,59.79,50.555 +128.77,59.3,50.8 +130.19,60.12,51.18 +130.07,59.8,51.42 +130.06,59.8,51.03 +131.39,59.79,51.33 +132.54,59.66,51.48 +129.62,59.11,50.84 +132.045,59.06,51.59 +131.78,59.63,51.81 +130.28,59.0,51.96 +130.535,59.65,52.22 +129.96,59.19,51.73 +130.12,59.48,52.12 +129.36,58.8,51.72 +128.65,58.61,52.19 +127.8,58.08,51.53 +127.42,57.9,51.54 +128.88,58.49,52.69 +128.59,58.55,52.49 +127.17,57.65,52.63 +126.92,57.95,52.27 +127.6,58.18,52.965 +127.3,57.97,53.24 +127.88,58.39,54.11 +126.6,58.05,53.93 +127.61,59.22,53.9 +127.03,59.12,54.115 +128.11,58.29,53.71 +127.5,58.35,54.07 +126.75,58.38,54.62 +124.53,57.14,53.55 +125.425,57.34,53.615 +126.6,57.6,53.89 +126.44,57.51,54.24 +126.0,57.22,54.305 +125.69,57.49,54.375 +122.57,56.79,53.39 +120.07,56.94,54.05 +123.28,57.48,54.57 +125.66,58.43,55.7 +125.61,58.6,55.75 +126.82,58.89,55.34 +128.51,59.29,55.74 +129.62,58.85,55.69 +132.07,59.4,56.21 +130.75,59.57,56.2 +125.22,59.35,56.69 +125.16,58.85,56.56 +124.5,59.5,57.29 +122.77,58.71,56.98 +123.38,59.11,57.14 +122.99,59.58,57.51 +122.37,59.86,58.06 +121.3,60.16,57.93 +118.44,59.76,58.19 +114.64,60.22,58.7 +115.4,64.04,59.01 +115.13,63.8,57.23 +115.52,64.19,57.2 +119.72,63.99,56.27 +113.49,63.35,56.35 +115.24,64.6,56.38 +115.15,64.34,56.85 +115.96,64.98,57.1 +117.16,65.27,57.74 +116.5,65.77,57.83 +115.01,65.35,57.59 +112.65,63.89,55.81 +105.76,62.45,52.84 +103.12,60.79,50.34 +103.74,60.44,51.09 +109.69,63.14,53.96 +112.92,64.29,55.95 +113.29,64.55,55.63 +112.76,64.82,54.71 +107.72,63.85,53.5 +112.34,64.72,55.26 +110.37,65.11,54.69 +109.27,66.31,54.28 +112.31,69.61,55.21 +110.15,68.3,54.69 +112.57,69.09,55.37 +114.21,67.08,56.53 +115.31,66.84,56.29 +116.28,67.15,56.91 +116.41,67.47,57.26 +113.92,67.03,57.28 +113.45,67.09,56.84 +115.21,67.05,57.54 +113.4,66.58,57.12 +114.32,67.8,57.79 +115.0,67.91,58.37 +114.71,69.2,57.99 +112.44,67.93,55.77 +109.06,67.45,55.72 +110.3,68.38,56.84 
+109.58,67.76,57.48 +110.38,68.4,58.08 +110.78,69.75,59.04 +111.31,69.19,58.69 +110.78,69.79,58.78 +109.5,69.5,59.46 +112.12,68.78,60.07 +111.6,69.43,60.54 +111.79,69.04,60.16 +110.21,68.7,58.82 +111.86,69.27,59.69 +111.04,69.26,59.93 +111.73,69.03,60.97 +113.77,69.48,60.88 +113.76,69.47,60.53 +115.5,70.48,61.49 +119.08,70.48,62.61 +115.28,70.05,63.43 +114.55,69.96,62.71 +119.27,70.37,63.51 +120.53,70.13,62.5 +119.5,69.97,62.57 +121.18,70.73,62.24 +122.57,71.36,62.8 +122.0,65.24,61.96 +120.92,67.4,62.28 +121.06,68.01,61.97 +120.57,68.2,61.34 +116.77,68.34,62.18 +116.11,70.02,61.87 +115.72,69.44,61.07 +112.34,69.03,59.74 +114.175,70.02,60.68 +113.69,71.05,60.55 +117.29,71.98,61.8 +118.78,72.45,61.46 +119.3,72.19,61.99 +117.75,72.24,62.64 +118.88,71.96,61.96 +118.03,71.83,62.19 +117.81,72.02,62.18 +118.3,71.78,61.39 +117.34,72.05,61.37 +116.28,71.89,61.22 +115.2,71.08,59.55 +119.03,72.11,61.75 +118.28,70.38,61.89 +118.23,69.75,62.16 +115.62,69.31,61.18 +116.17,69.37,61.87 +113.18,68.61,59.82 +112.48,68.14,59.92 +110.49,69.13,59.98 +111.34,69.52,60.35 +108.98,68.56,59.515 +106.03,67.58,58.62 +107.33,68.03,59.54 +107.23,68.87,59.99 +108.61,69.21,60.34 +108.03,69.06,60.32 +106.82,69.18,60.19 +108.74,69.64,61.13 +107.32,69.3,60.82 +105.26,68.45,60.03 +105.35,67.13,58.26 +102.71,66.39,58.65 +100.7,65.43,58.13 +96.45,64.11,56.69 +96.96,64.25,56.63 +98.53,64.37,57.82 +99.96,64.91,59.46 +97.39,63.37,57.87 +99.52,63.11,58.98 +97.13,61.59,58.0 +96.66,61.13,58.55 +96.79,60.36,56.92 +96.3,60.82,59.03 +101.42,62.04,59.17 +99.44,62.42,57.71 +99.99,63.16,58.61 +93.42,64.8,57.63 +94.09,64.74,59.285 +97.34,66.77,60.77 +96.43,66.85,61.4 +94.48,64.32,60.695 +96.35,64.88,59.53 +96.6,64.25,58.29 +94.02,62.82,54.49 +95.01,62.09,54.14 +94.99,62.24,54.42 +94.27,60.97,55.14 +93.7,60.52,54.92 +93.99,61.78,55.86 +96.64,63.42,56.41 +98.12,65.05,57.63 +96.26,64.78,56.96 +96.04,66.0,57.67 +96.88,66.75,58.87 +94.69,70.78,58.46 +96.1,72.84,58.11 +96.76,74.06,58.75 +96.91,74.86,58.34 +96.69,73.49,58.21 +100.53,71.19,60.04 +100.75,71.28,59.56 +101.5,71.25,59.04 +103.01,70.95,58.7 +101.87,71.01,58.0 +101.03,71.1,57.6 +101.12,71.48,57.07 +101.17,71.22,57.52 +102.26,71.2,57.59 +102.52,71.83,58.65 +104.58,71.97,59.08 +105.97,72.24,59.67 +105.8,72.83,59.55 +105.92,72.59,59.7 +105.91,73.12,59.1 +106.72,73.71,59.38 +106.13,73.15,58.83 +105.67,72.59,58.36 +105.19,73.37,58.96 +107.68,74.09,59.55 +109.56,74.89,60.01 +108.99,75.7,59.7 +109.99,76.11,61.02 +111.12,76.32,60.25 +109.81,75.71,60.04 +110.96,76.09,60.83 +108.54,74.99,61.17 +108.66,75.24,61.04 +109.02,74.88,60.9 +110.44,75.04,59.5 +112.04,75.37,60.21 +112.1,75.31,60.13 +109.85,75.64,60.51 +107.48,75.69,60.89 +106.91,75.97,60.9 +107.13,75.55,60.9 +105.97,74.99,60.64 +105.68,75.56,57.68 +105.08,75.51,57.77 +104.35,75.9,57.72 +97.82,76.04,56.9 +94.83,75.34,56.42 +93.74,75.19,56.23 +93.64,76.0,57.36 +95.18,74.96,56.25 +94.19,74.22,56.39 +93.24,74.25,56.25 +92.72,70.54,56.31 +92.79,70.82,56.64 +93.42,71.05,57.49 +92.51,70.07,56.23 +90.34,71.11,56.3 +90.52,70.62,55.82 +93.88,70.83,55.53 +93.49,69.89,54.88 +94.56,69.46,54.8 +94.2,68.72,54.55 +95.22,68.75,54.62 +96.43,68.78,54.6 +97.9,69.68,55.44 +99.62,69.35,55.15 +100.41,69.4,55.29 +100.35,69.5,55.15 +99.86,69.27,54.89 +98.46,69.06,54.82 +97.72,68.8,54.62 +97.92,68.47,54.61 +98.63,68.77,55.59 +99.03,68.16,55.3 +98.94,69.05,55.22 +99.65,68.56,55.58 +98.83,67.45,54.865 +97.34,66.82,55.04 +97.46,67.24,55.57 +97.14,67.54,55.35 +97.55,67.8,55.53 +95.33,67.33,55.31 +95.1,68.35,55.38 +95.91,67.81,55.81 +95.55,67.43,55.61 +96.1,68.01,56.13 
+93.4,64.73,54.68 +92.04,63.08,53.69 +93.59,63.69,54.85 +94.4,64.55,56.74 +95.6,65.97,57.12 +95.89,66.01,56.99 +94.99,64.77,56.77 +95.53,65.3,56.75 +95.94,65.05,56.91 +96.68,66.38,56.51 +96.98,66.62,56.32 +97.42,67.4,57.48 +96.87,67.46,56.48 +98.79,67.58,57.59 +98.78,67.4,57.41 +99.83,67.55,56.92 +99.87,67.5,56.76 +99.96,67.93,57.54 +99.43,67.55,57.6 +98.66,68.25,57.9 +97.34,68.09,57.95 +96.67,68.42,58.31 +102.95,69.26,57.85 +104.34,69.58,58.21 +104.21,69.38,58.05 +106.05,69.63,57.63 +104.48,68.84,56.73 +105.79,69.29,55.94 +105.87,70.24,55.42 +107.48,73.5,55.9 +108.37,73.93,55.36 +108.81,74.28,55.2 +108.0,74.28,55.62 +107.93,75.52,55.47 +108.18,74.54,55.47 +109.48,75.44,55.25 +109.38,75.58,55.37 +109.22,75.68,55.8 +109.08,75.99,55.53 +109.36,76.34,54.94 +108.51,76.49,55.85 +108.85,76.99,56.4 +108.03,77.12,57.09 +107.57,77.18,57.29 +106.94,77.2,57.29 +106.82,77.29,56.8 +106.0,77.51,56.4 +106.1,76.99,56.23 +106.73,76.8,56.31 +107.73,77.95,56.18 +107.7,78.32,56.02 +108.36,78.08,56.32 +105.52,77.37,55.3 +103.13,76.65,54.35 +105.44,77.23,54.71 +107.95,76.09,53.98 +111.77,75.47,53.9 +115.57,76.04,54.11 +114.92,75.63,53.74 +113.58,75.76,53.01 +113.57,75.21,53.3 +113.55,75.73,53.98 +114.62,76.19,54.39 +112.71,76.11,54.43 +112.88,75.95,54.04 +113.09,76.32,54.19 +113.95,76.79,53.98 +112.18,77.21,53.45 +113.05,76.28,54.14 +112.52,75.25,53.84 +113.0,74.42,53.53 +113.05,74.35,53.35 +113.89,74.64,53.14 +114.06,74.48,53.46 +116.05,74.67,53.3 +116.3,73.5,52.92 +117.34,73.76,53.16 +116.98,73.06,52.95 +117.63,73.58,53.08 +117.55,73.13,52.76 +117.47,73.8,52.61 +117.12,73.8,53.15 +117.06,73.57,53.59 +116.6,73.62,53.63 +117.65,74.49,54.18 +118.25,74.16,53.67 +115.59,73.58,53.63 +114.48,73.48,53.59 +113.72,72.83,53.53 +113.54,72.58,53.07 +111.49,72.32,52.5 +111.59,71.57,52.98 +109.83,71.29,51.77 +108.84,75.9,52.75 +110.41,77.71,54.49 +111.06,78.56,54.62 +110.88,78.96,54.58 +107.79,79.19,53.57 +108.43,80.38,53.93 +105.71,80.6,54.22 +107.11,81.8,54.59 +109.99,80.51,55.44 +109.95,80.35,55.85 +110.06,79.98,55.77 +111.73,79.83,56.1 +111.8,80.31,57.12 +111.23,80.26,57.59 +111.79,80.98,57.43 +111.57,80.86,57.59 +111.46,81.11,58.17 +110.52,80.25,57.97 +109.49,79.19,58.51 +109.9,79.5,57.21 +109.11,80.92,57.5 +109.95,82.22,57.44 +111.03,83.27,58.76 +112.12,83.3,58.65 +113.95,82.79,58.75 +113.3,82.6,58.77 +115.19,83.24,59.31 +115.19,82.9,58.75 +115.82,83.46,57.71 +115.97,83.4,57.66 +116.64,83.93,57.65 +116.95,83.76,57.7 +117.06,84.0,57.44 +116.29,83.72,57.11 +116.52,83.41,57.01 +117.26,83.52,56.86 +116.76,82.86,56.35 +116.73,82.87,56.32 +115.82,82.89,55.52 +116.15,83.6,55.35 +116.02,83.49,55.99 +116.61,82.64,56.46 +117.91,82.89,57.13 +118.99,83.02,58.2 +119.11,82.63,57.88 +119.75,82.88,58.1 +119.25,82.18,58.03 +119.04,82.27,57.85 +120.0,80.73,58.0 +119.99,81.65,58.45 +119.78,81.86,57.89 +120.0,82.36,57.66 +120.08,82.44,57.76 +119.97,84.35,58.44 +121.88,85.29,58.7 +121.94,83.36,58.46 +121.95,82.98,56.12 +121.63,81.7,55.9 +121.35,80.71,55.22 +128.75,80.03,53.9 +128.53,81.0,53.87 +129.08,81.6,55.06 +130.29,81.73,55.73 +131.53,77.34,55.24 +132.04,78.25,55.22 +132.42,77.81,55.81 +132.12,78.37,56.22 +133.29,78.48,56.11 +135.02,78.68,56.58 +135.51,79.4,56.86 +135.345,78.66,56.73 +135.72,79.31,57.35 +136.7,80.15,57.54 +137.11,79.65,57.57 +136.53,79.36,57.64 +136.66,80.27,57.48 +136.93,79.28,56.78 +136.99,78.97,56.87 +139.79,79.98,57.14 +138.96,80.02,57.12 +139.78,80.55,57.1 +139.34,79.97,56.68 +139.52,79.66,56.2 +139.0,80.2,55.74 +138.68,81.37,55.19 +139.14,82.1,54.53 +139.2,81.65,54.63 +138.99,83.36,54.27 +140.46,85.24,54.54 
+140.69,85.15,54.8 +139.99,84.72,55.78 +141.46,84.3,55.81 +139.84,83.76,55.54 +141.42,83.59,55.89 +140.92,83.74,55.85 +140.64,83.67,56.81 +140.88,84.0,57.23 +143.8,84.0,57.35 +144.12,84.13,57.54 +143.93,84.87,58.16 +143.66,86.22,58.39 +143.7,84.83,58.44 +144.77,84.52,58.32 +144.02,83.83,58.22 +143.66,84.2,57.92 +143.34,84.25,58.02 +143.17,83.71,57.95 +141.63,83.45,57.88 +141.8,82.84,57.58 +141.05,82.34,57.51 +141.83,83.08,58.08 +141.2,82.64,58.35 +140.68,83.37,59.04 +142.44,84.1,60.08 +142.27,83.72,60.61 +143.64,84.72,61.11 +144.53,85.39,60.96 +143.68,85.38,61.56 +143.79,86.07,61.3 +143.65,85.97,60.06 +146.58,86.16,60.18 +147.51,85.92,60.5 +147.06,86.37,60.59 +146.53,86.1,60.83 +148.96,84.44,60.95 +153.01,83.59,60.94 +153.99,84.77,60.98 +153.26,85.77,60.66 +153.95,85.36,60.27 +156.1,84.21,59.93 +155.7,84.48,60.45 +155.47,83.7,59.98 +150.25,81.85,59.73 +152.54,80.83,59.82 +153.06,80.83,61.36 +153.99,82.93,61.23 +153.8,82.11,61.15 +153.34,82.22,61.89 +153.87,82.27,62.9 +153.61,81.86,63.3 +153.67,82.83,63.26 +152.76,83.57,63.61 +153.18,85.64,63.75 +155.45,86.62,64.57 +153.93,87.31,64.27 +154.45,87.46,64.16 +155.37,86.18,63.5 +154.99,86.65,62.24 +148.98,86.17,62.19 +145.42,86.11,61.29 +146.59,86.04,60.92 +145.16,84.87,60.27 +144.29,84.45,60.09 +142.27,84.72,60.14 +146.34,86.2,60.9 +145.01,85.74,59.86 +145.87,86.24,59.96 +145.63,87.36,59.51 +146.28,88.64,59.81 +145.82,88.42,59.64 +143.73,87.72,58.96 +145.83,88.13,59.18 +143.68,86.8,58.36 +144.02,86.74,58.31 +143.5,86.68,58.25 +144.09,86.83,57.94 +142.73,85.96,57.6 +144.18,87.31,58.04 +145.06,87.23,57.81 +145.53,87.65,57.9 +145.74,88.65,58.54 +147.77,88.33,58.38 +149.04,88.61,58.76 +149.56,88.28,58.33 +150.08,88.45,58.21 +151.02,89.78,58.11 +150.34,89.96,58.03 +150.27,90.52,57.98 +152.09,90.67,58.02 +152.74,91.39,58.55 +153.46,91.84,57.94 +150.56,92.21,59.5 +149.5,91.01,54.0 +148.73,90.68,53.98 +158.59,90.43,54.73 +157.14,90.43,55.43 +155.57,90.4,55.68 +156.39,90.37,55.44 +158.81,89.2,55.63 +160.08,88.58,54.52 +161.06,88.51,53.74 +155.32,86.99,53.07 +157.48,87.48,53.18 +159.85,88.6,53.22 +161.6,87.92,53.15 +160.95,88.19,53.5 +157.86,87.13,53.04 +157.5,87.37,52.7 +157.21,87.2,53.15 +159.78,86.41,54.45 +159.98,86.21,54.08 +159.27,86.51,53.94 +159.86,86.88,54.36 +161.47,87.3,54.4 +162.91,86.94,54.1 +163.35,87.66,54.52 +164.0,88.12,54.86 +164.05,87.87,54.93 +162.08,86.66,55.13 +161.91,85.72,54.31 +161.26,86.35,53.47 +158.63,85.12,53.49 +161.5,87.01,54.02 +160.86,87.74,53.54 +159.65,85.97,54.29 +158.28,84.58,54.53 +159.88,85.48,54.67 +158.67,85.48,54.69 +158.73,85.84,54.62 +156.07,85.65,55.15 +153.39,84.99,55.01 +151.89,84.29,55.09 +150.55,83.5,54.95 +153.14,83.02,55.13 +154.23,84.1,54.99 +153.28,83.69,54.5 +154.12,84.87,53.71 +153.81,85.84,53.81 +154.48,85.69,53.99 +153.48,85.64,53.93 +155.39,86.0,54.6 +155.3,89.44,55.17 +155.84,89.1,55.02 +155.9,89.08,55.42 +156.55,89.26,55.64 +156.0,89.8,55.97 +156.99,89.93,55.72 +159.88,89.36,54.91 +160.47,88.88,54.51 +159.76,89.09,55.21 +155.98,89.65,55.4 +156.25,90.0,54.57 +156.17,89.94,54.27 +157.1,90.23,54.28 +156.41,90.04,54.16 +157.41,90.28,54.91 +163.05,91.19,54.88 +166.72,90.37,55.17 +169.04,90.54,54.84 +166.89,90.56,55.13 +168.11,90.02,54.87 +172.5,94.25,56.03 +174.25,92.43,56.57 +174.81,92.11,57.22 +176.24,92.66,57.91 +175.88,91.61,57.36 +174.67,91.07,57.04 +173.97,91.37,56.64 +171.34,91.02,56.93 +169.08,90.39,56.7 +171.1,90.97,57.24 +170.15,90.95,56.93 +169.98,92.33,56.81 +173.14,92.45,57.26 +174.96,91.83,57.14 +174.97,92.36,56.8 +174.09,92.88,55.91 +173.07,94.53,56.66 +169.48,94.17,57.51 
+171.85,94.11,57.82 +171.05,93.03,57.32 +169.8,93.63,58.76 +169.64,90.66,59.34 +169.01,91.29,59.28 +169.32,92.8,59.14 +169.37,92.52,58.61 +172.67,92.33,59.07 +171.7,93.37,59.27 +172.27,93.94,59.49 +172.22,92.2,59.7 +173.97,93.15,58.29 +176.42,94.49,58.03 +174.54,93.28,58.01 +174.35,92.1,57.73 +175.01,91.62,57.58 +175.01,90.76,57.3 +170.57,90.67,57.14 +170.6,90.8,57.27 +171.08,90.57,57.81 +169.23,90.34,57.43 +172.26,90.55,57.63 +172.23,89.91,58.71 +173.03,90.66,58.93 +175.0,91.88,59.61 +174.35,92.82,59.31 +174.33,92.12,59.18 +174.29,92.38,59.82 +175.28,93.55,60.0 +177.09,96.57,60.4 +176.19,95.86,60.56 +179.1,97.28,60.66 +179.26,97.5,61.09 +178.46,97.8,61.26 +177.0,97.33,61.41 +177.04,96.76,61.69 +174.22,95.84,60.83 +171.11,97.68,60.55 +171.51,99.0,57.99 +167.96,99.18,57.02 +166.97,99.8,57.19 +167.43,99.46,56.81 +167.78,99.12,56.0 +160.5,103.87,55.77 +156.49,101.06,54.69 +163.03,102.76,55.61 +159.54,102.63,54.46 diff --git a/pytorch/ann_regression.py b/pytorch/ann_regression.py new file mode 100644 index 00000000..65d66250 --- /dev/null +++ b/pytorch/ann_regression.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +"""PyTorch Regression.ipynb + +Automatically generated by Colaboratory. + +Original file is located at + https://colab.research.google.com/drive/1pEjzEmbnu2wXAhIaBS8PSpi-0cWtR6ov +""" + +import torch +import torch.nn as nn +import numpy as np +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D + +# Make the dataset +N = 1000 +X = np.random.random((N, 2)) * 6 - 3 # uniformly distributed between (-3, +3) +Y = np.cos(2*X[:,0]) + np.cos(3*X[:,1]) + +# Plot it +fig = plt.figure() +ax = fig.add_subplot(111, projection='3d') +ax.scatter(X[:,0], X[:,1], Y) +plt.show() + +# Build the model +model = nn.Sequential( + nn.Linear(2, 128), + nn.ReLU(), + nn.Linear(128, 1) +) + +# Loss and optimizer +criterion = nn.MSELoss() +optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + +# Train the model +def full_gd(model, criterion, optimizer, X_train, y_train, epochs=1000): + # Stuff to store + train_losses = np.zeros(epochs) + + for it in range(epochs): + # zero the parameter gradients + optimizer.zero_grad() + + # Forward pass + outputs = model(X_train) + loss = criterion(outputs, y_train) + + # Backward and optimize + loss.backward() + optimizer.step() + + # Save losses + train_losses[it] = loss.item() + + if (it + 1) % 50 == 0: + print(f'Epoch {it+1}/{epochs}, Train Loss: {loss.item():.4f}') + + return train_losses + +X_train = torch.from_numpy(X.astype(np.float32)) +y_train = torch.from_numpy(Y.astype(np.float32).reshape(-1, 1)) +train_losses = full_gd(model, criterion, optimizer, X_train, y_train) + +plt.plot(train_losses) +plt.show() + + +# Plot the prediction surface +fig = plt.figure() +ax = fig.add_subplot(111, projection='3d') +ax.scatter(X[:,0], X[:,1], Y) + +# surface plot +with torch.no_grad(): + line = np.linspace(-3, 3, 50) + xx, yy = np.meshgrid(line, line) + Xgrid = np.vstack((xx.flatten(), yy.flatten())).T + Xgrid_torch = torch.from_numpy(Xgrid.astype(np.float32)) + Yhat = model(Xgrid_torch).numpy().flatten() + ax.plot_trisurf(Xgrid[:,0], Xgrid[:,1], Yhat, linewidth=0.2, antialiased=True) + plt.show() + +# Can it extrapolate? 
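+# Short answer, based on the training setup above: the model only saw inputs in [-3, +3], +# and a ReLU MLP is piecewise linear, so outside that square the predicted surface +# flattens into linear pieces rather than continuing the cosine pattern.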
+# Plot the prediction surface +fig = plt.figure() +ax = fig.add_subplot(111, projection='3d') +ax.scatter(X[:,0], X[:,1], Y) + +# surface plot +with torch.no_grad(): + line = np.linspace(-5, 5, 50) + xx, yy = np.meshgrid(line, line) + Xgrid = np.vstack((xx.flatten(), yy.flatten())).T + Xgrid_torch = torch.from_numpy(Xgrid.astype(np.float32)) + Yhat = model(Xgrid_torch).numpy().flatten() + ax.plot_trisurf(Xgrid[:,0], Xgrid[:,1], Yhat, linewidth=0.2, antialiased=True) + plt.show() \ No newline at end of file diff --git a/pytorch/extra_reading.txt b/pytorch/extra_reading.txt new file mode 100644 index 00000000..7d5afcf1 --- /dev/null +++ b/pytorch/extra_reading.txt @@ -0,0 +1,27 @@ +Gradient Descent: Convergence Analysis +http://www.stat.cmu.edu/~ryantibs/convexopt-F13/scribes/lec6.pdf + +Deep learning improved by biological activation functions +https://arxiv.org/pdf/1804.11237.pdf + +Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift +Sergey Ioffe, Christian Szegedy +https://arxiv.org/abs/1502.03167 + +Dropout: A Simple Way to Prevent Neural Networks from Overfitting +https://www.cs.toronto.edu/~hinton/absps/JMLRdropout.pdf + +Convolution arithmetic tutorial +http://deeplearning.net/software/theano_versions/dev/tutorial/conv_arithmetic.html + +On the Practical Computational Power of Finite Precision RNNs for Language Recognition +https://arxiv.org/abs/1805.04908 + +Massive Exploration of Neural Machine Translation Architectures +https://arxiv.org/abs/1703.03906 + +Practical Deep Reinforcement Learning Approach for Stock Trading +https://arxiv.org/abs/1811.07522 + +Inceptionism: Going Deeper into Neural Networks +https://ai.googleblog.com/2015/06/inceptionism-going-deeper-into-neural.html \ No newline at end of file diff --git a/pytorch/plot_rl_rewards.py b/pytorch/plot_rl_rewards.py new file mode 100644 index 00000000..85cc1b2e --- /dev/null +++ b/pytorch/plot_rl_rewards.py @@ -0,0 +1,16 @@ +import matplotlib.pyplot as plt +import numpy as np +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('-m', '--mode', type=str, required=True, + help='either "train" or "test"') +args = parser.parse_args() + +a = np.load(f'rl_trader_rewards/{args.mode}.npy') + +print(f"average reward: {a.mean():.2f}, min: {a.min():.2f}, max: {a.max():.2f}") + +plt.hist(a, bins=20) +plt.title(args.mode) +plt.show() \ No newline at end of file diff --git a/pytorch/rl_trader.py b/pytorch/rl_trader.py new file mode 100644 index 00000000..fbb96c91 --- /dev/null +++ b/pytorch/rl_trader.py @@ -0,0 +1,441 @@ +import numpy as np +import pandas as pd + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from datetime import datetime +import itertools +import argparse +import re +import os +import pickle + +from sklearn.preprocessing import StandardScaler + + +# Let's use AAPL (Apple), MSI (Motorola), SBUX (Starbucks) +def get_data(): + # returns a T x 3 list of stock prices + # each row is a different stock + # 0 = AAPL + # 1 = MSI + # 2 = SBUX + df = pd.read_csv('aapl_msi_sbux.csv') + return df.values + + + +### The experience replay memory ### +class ReplayBuffer: + def __init__(self, obs_dim, act_dim, size): + self.obs1_buf = np.zeros([size, obs_dim], dtype=np.float32) + self.obs2_buf = np.zeros([size, obs_dim], dtype=np.float32) + self.acts_buf = np.zeros(size, dtype=np.uint8) + self.rews_buf = np.zeros(size, dtype=np.float32) + self.done_buf = np.zeros(size, dtype=np.uint8) + self.ptr, self.size, self.max_size = 0, 0, size + + def 
store(self, obs, act, rew, next_obs, done): + self.obs1_buf[self.ptr] = obs + self.obs2_buf[self.ptr] = next_obs + self.acts_buf[self.ptr] = act + self.rews_buf[self.ptr] = rew + self.done_buf[self.ptr] = done + self.ptr = (self.ptr+1) % self.max_size + self.size = min(self.size+1, self.max_size) + + def sample_batch(self, batch_size=32): + idxs = np.random.randint(0, self.size, size=batch_size) + return dict(s=self.obs1_buf[idxs], + s2=self.obs2_buf[idxs], + a=self.acts_buf[idxs], + r=self.rews_buf[idxs], + d=self.done_buf[idxs]) + + + + + +def get_scaler(env): + # return scikit-learn scaler object to scale the states + # Note: you could also populate the replay buffer here + + states = [] + for _ in range(env.n_step): + action = np.random.choice(env.action_space) + state, reward, done, info = env.step(action) + states.append(state) + if done: + break + + scaler = StandardScaler() + scaler.fit(states) + return scaler + + + + +def maybe_make_dir(directory): + if not os.path.exists(directory): + os.makedirs(directory) + + + + +class MLP(nn.Module): + def __init__(self, n_inputs, n_action, n_hidden_layers=1, hidden_dim=32): + super(MLP, self).__init__() + + M = n_inputs + self.layers = [] + for _ in range(n_hidden_layers): + layer = nn.Linear(M, hidden_dim) + M = hidden_dim + self.layers.append(layer) + self.layers.append(nn.ReLU()) + + # final layer + self.layers.append(nn.Linear(M, n_action)) + self.layers = nn.Sequential(*self.layers) + + def forward(self, X): + return self.layers(X) + + def save_weights(self, path): + torch.save(self.state_dict(), path) + + def load_weights(self, path): + self.load_state_dict(torch.load(path)) + + + +def predict(model, np_states): + with torch.no_grad(): + inputs = torch.from_numpy(np_states.astype(np.float32)) + output = model(inputs) + # print("output:", output) + return output.numpy() + + + +def train_one_step(model, criterion, optimizer, inputs, targets): + # convert to tensors + inputs = torch.from_numpy(inputs.astype(np.float32)) + targets = torch.from_numpy(targets.astype(np.float32)) + + # zero the parameter gradients + optimizer.zero_grad() + + # Forward pass + outputs = model(inputs) + loss = criterion(outputs, targets) + + # Backward and optimize + loss.backward() + optimizer.step() + + + +class MultiStockEnv: + """ + A 3-stock trading environment. + State: vector of size 7 (n_stock * 2 + 1) + - # shares of stock 1 owned + - # shares of stock 2 owned + - # shares of stock 3 owned + - price of stock 1 (using daily close price) + - price of stock 2 + - price of stock 3 + - cash owned (can be used to purchase more stocks) + Action: categorical variable with 27 (3^3) possibilities + - for each stock, you can: + - 0 = sell + - 1 = hold + - 2 = buy + """ + def __init__(self, data, initial_investment=20000): + # data + self.stock_price_history = data + self.n_step, self.n_stock = self.stock_price_history.shape + + # instance attributes + self.initial_investment = initial_investment + self.cur_step = None + self.stock_owned = None + self.stock_price = None + self.cash_in_hand = None + + self.action_space = np.arange(3**self.n_stock) + + # action permutations + # returns a nested list with elements like: + # [0,0,0] + # [0,0,1] + # [0,0,2] + # [0,1,0] + # [0,1,1] + # etc. 
+ # 0 = sell + # 1 = hold + # 2 = buy + self.action_list = list(map(list, itertools.product([0, 1, 2], repeat=self.n_stock))) + + # calculate size of state + self.state_dim = self.n_stock * 2 + 1 + + self.reset() + + + def reset(self): + self.cur_step = 0 + self.stock_owned = np.zeros(self.n_stock) + self.stock_price = self.stock_price_history[self.cur_step] + self.cash_in_hand = self.initial_investment + return self._get_obs() + + + def step(self, action): + assert action in self.action_space + + # get current value before performing the action + prev_val = self._get_val() + + # update price, i.e. go to the next day + self.cur_step += 1 + self.stock_price = self.stock_price_history[self.cur_step] + + # perform the trade + self._trade(action) + + # get the new value after taking the action + cur_val = self._get_val() + + # reward is the increase in porfolio value + reward = cur_val - prev_val + + # done if we have run out of data + done = self.cur_step == self.n_step - 1 + + # store the current value of the portfolio here + info = {'cur_val': cur_val} + + # conform to the Gym API + return self._get_obs(), reward, done, info + + + def _get_obs(self): + obs = np.empty(self.state_dim) + obs[:self.n_stock] = self.stock_owned + obs[self.n_stock:2*self.n_stock] = self.stock_price + obs[-1] = self.cash_in_hand + return obs + + + + def _get_val(self): + return self.stock_owned.dot(self.stock_price) + self.cash_in_hand + + + def _trade(self, action): + # index the action we want to perform + # 0 = sell + # 1 = hold + # 2 = buy + # e.g. [2,1,0] means: + # buy first stock + # hold second stock + # sell third stock + action_vec = self.action_list[action] + + # determine which stocks to buy or sell + sell_index = [] # stores index of stocks we want to sell + buy_index = [] # stores index of stocks we want to buy + for i, a in enumerate(action_vec): + if a == 0: + sell_index.append(i) + elif a == 2: + buy_index.append(i) + + # sell any stocks we want to sell + # then buy any stocks we want to buy + if sell_index: + # NOTE: to simplify the problem, when we sell, we will sell ALL shares of that stock + for i in sell_index: + self.cash_in_hand += self.stock_price[i] * self.stock_owned[i] + self.stock_owned[i] = 0 + if buy_index: + # NOTE: when buying, we will loop through each stock we want to buy, + # and buy one share at a time until we run out of cash + can_buy = True + while can_buy: + for i in buy_index: + if self.cash_in_hand > self.stock_price[i]: + self.stock_owned[i] += 1 # buy one share + self.cash_in_hand -= self.stock_price[i] + else: + can_buy = False + + + + + +class DQNAgent(object): + def __init__(self, state_size, action_size): + self.state_size = state_size + self.action_size = action_size + self.memory = ReplayBuffer(state_size, action_size, size=500) + self.gamma = 0.95 # discount rate + self.epsilon = 1.0 # exploration rate + self.epsilon_min = 0.01 + self.epsilon_decay = 0.995 + self.model = MLP(state_size, action_size) + + # Loss and optimizer + self.criterion = nn.MSELoss() + self.optimizer = torch.optim.Adam(self.model.parameters()) + + + def update_replay_memory(self, state, action, reward, next_state, done): + self.memory.store(state, action, reward, next_state, done) + + + def act(self, state): + if np.random.rand() <= self.epsilon: + return np.random.choice(self.action_size) + act_values = predict(self.model, state) + return np.argmax(act_values[0]) # returns action + + + def replay(self, batch_size=32): + # first check if replay buffer contains enough data + if 
self.memory.size < batch_size: + return + + # sample a batch of data from the replay memory + minibatch = self.memory.sample_batch(batch_size) + states = minibatch['s'] + actions = minibatch['a'] + rewards = minibatch['r'] + next_states = minibatch['s2'] + done = minibatch['d'] + + # Calculate the target: Q(s',a) + target = rewards + (1 - done) * self.gamma * np.amax(predict(self.model, next_states), axis=1) + + # With the PyTorch API, it is simplest to have the target be the + # same shape as the predictions. + # However, we only need to update the network for the actions + # which were actually taken. + # We can accomplish this by setting the target to be equal to + # the prediction for all values. + # Then, only change the targets for the actions taken. + # Q(s,a) + target_full = predict(self.model, states) + target_full[np.arange(batch_size), actions] = target + + # Run one training step + train_one_step(self.model, self.criterion, self.optimizer, states, target_full) + + if self.epsilon > self.epsilon_min: + self.epsilon *= self.epsilon_decay + + + def load(self, name): + self.model.load_weights(name) + + + def save(self, name): + self.model.save_weights(name) + + +def play_one_episode(agent, env, is_train): + # note: after transforming, states are already 1xD + state = env.reset() + state = scaler.transform([state]) + done = False + + while not done: + action = agent.act(state) + next_state, reward, done, info = env.step(action) + next_state = scaler.transform([next_state]) + if is_train == 'train': + agent.update_replay_memory(state, action, reward, next_state, done) + agent.replay(batch_size) + state = next_state + + return info['cur_val'] + + + +if __name__ == '__main__': + + # config + models_folder = 'rl_trader_models' + rewards_folder = 'rl_trader_rewards' + num_episodes = 2000 + batch_size = 32 + initial_investment = 20000 + + + parser = argparse.ArgumentParser() + parser.add_argument('-m', '--mode', type=str, required=True, + help='either "train" or "test"') + args = parser.parse_args() + + maybe_make_dir(models_folder) + maybe_make_dir(rewards_folder) + + data = get_data() + n_timesteps, n_stocks = data.shape + + n_train = n_timesteps // 2 + + train_data = data[:n_train] + test_data = data[n_train:] + + env = MultiStockEnv(train_data, initial_investment) + state_size = env.state_dim + action_size = len(env.action_space) + agent = DQNAgent(state_size, action_size) + scaler = get_scaler(env) + + # store the final value of the portfolio (end of episode) + portfolio_value = [] + + if args.mode == 'test': + # then load the previous scaler + with open(f'{models_folder}/scaler.pkl', 'rb') as f: + scaler = pickle.load(f) + + # remake the env with test data + env = MultiStockEnv(test_data, initial_investment) + + # make sure epsilon is not 1!
+ # no need to run multiple episodes if epsilon = 0, it's deterministic + agent.epsilon = 0.01 + + # load trained weights + agent.load(f'{models_folder}/dqn.ckpt') + + # play the game num_episodes times + for e in range(num_episodes): + t0 = datetime.now() + val = play_one_episode(agent, env, args.mode) + dt = datetime.now() - t0 + print(f"episode: {e + 1}/{num_episodes}, episode end value: {val:.2f}, duration: {dt}") + portfolio_value.append(val) # append episode end portfolio value + + # save the weights when we are done + if args.mode == 'train': + # save the DQN + agent.save(f'{models_folder}/dqn.ckpt') + + # save the scaler + with open(f'{models_folder}/scaler.pkl', 'wb') as f: + pickle.dump(scaler, f) + + + # save portfolio value for each episode + np.save(f'{rewards_folder}/{args.mode}.npy', portfolio_value)
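+ +# Typical workflow (assumed, based on the argparse flags above and plot_rl_rewards.py): +# python rl_trader.py -m train -> trains the DQN, saves weights, scaler, and per-episode portfolio values +# python plot_rl_rewards.py -m train -> plots a histogram of those values +# python rl_trader.py -m test -> evaluates the saved model on the held-out second half of the data +# python plot_rl_rewards.py -m test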