-
Notifications
You must be signed in to change notification settings - Fork 0
/
Decision Tree.py
33 lines (27 loc) · 890 Bytes
/
Decision Tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt
# Create a random dataset: N_SAMPLES inputs drawn uniformly from [0, 5) and
# sorted ascending, with target sin(x); every NOISE_STEP-th target (indices
# 0, NOISE_STEP, 2*NOISE_STEP, ...) is perturbed by uniform noise in (-1.5, 1.5].
N_SAMPLES = 80
NOISE_STEP = 5
rng = np.random.RandomState(1)  # fixed seed so the example is reproducible
X = np.sort(5 * rng.rand(N_SAMPLES, 1), axis=0)
y = np.sin(X).ravel()
# Size the noise draw from the strided slice itself rather than hard-coding 16
# (= ceil(80 / 5)), so changing N_SAMPLES or NOISE_STEP cannot break the shapes.
n_noisy = y[::NOISE_STEP].size
y[::NOISE_STEP] += 3 * (0.5 - rng.rand(n_noisy))
# Fit regression model: two trees on the same (X, y) data, one limited to
# depth 2 and one allowed to grow to depth 5, so their fits can be compared.
regr_1 = DecisionTreeRegressor(max_depth=2).fit(X, y)
regr_2 = DecisionTreeRegressor(max_depth=5).fit(X, y)
# Predict: evaluate both fitted trees on an evenly spaced grid over [0, 5)
# with step 0.01, laid out as a single-feature column.
X_test = np.arange(0.0, 5.0, 0.01).reshape(-1, 1)
y_1, y_2 = (model.predict(X_test) for model in (regr_1, regr_2))
# Plot the result: the training samples as a scatter, then one prediction
# curve per tree over X_test, followed by axis labels, title and legend.
plt.figure()
plt.scatter(X, y, c="darkorange", edgecolor="black", s=20, label="data")
plt.plot(X_test, y_1, linewidth=2, color="cornflowerblue", label="max_depth=2")
plt.plot(X_test, y_2, linewidth=2, color="yellowgreen", label="max_depth=5")
plt.gca().set(xlabel="data", ylabel="target", title="Decision Tree Regression")
plt.legend()
plt.show()