Skip to content

Commit 2ae9534

Browse files
authored
Added test to linear regression (#12353)
1 parent 7647181 commit 2ae9534

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

machine_learning/linear_regression.py

+22
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ def run_steep_gradient_descent(data_x, data_y, len_data, alpha, theta):
4141
:param theta : Feature vector (weight's for our model)
4242
;param return : Updated Feature's, using
4343
curr_features - alpha_ * gradient(w.r.t. feature)
44+
>>> import numpy as np
45+
>>> data_x = np.array([[1, 2], [3, 4]])
46+
>>> data_y = np.array([5, 6])
47+
>>> len_data = len(data_x)
48+
>>> alpha = 0.01
49+
>>> theta = np.array([0.1, 0.2])
50+
>>> run_steep_gradient_descent(data_x, data_y, len_data, alpha, theta)
51+
array([0.196, 0.343])
4452
"""
4553
n = len_data
4654

@@ -58,6 +66,12 @@ def sum_of_square_error(data_x, data_y, len_data, theta):
5866
:param len_data : len of the dataset
5967
:param theta : contains the feature vector
6068
:return : sum of square error computed from given feature's
69+
70+
Example:
71+
>>> vc_x = np.array([[1.1], [2.1], [3.1]])
72+
>>> vc_y = np.array([1.2, 2.2, 3.2])
73+
>>> round(sum_of_square_error(vc_x, vc_y, 3, np.array([1])),3)
74+
np.float64(0.005)
6175
"""
6276
prod = np.dot(theta, data_x.transpose())
6377
prod -= data_y.transpose()
@@ -93,6 +107,11 @@ def mean_absolute_error(predicted_y, original_y):
93107
:param predicted_y : contains the output of prediction (result vector)
94108
:param original_y : contains values of expected outcome
95109
:return : mean absolute error computed from given feature's
110+
111+
>>> predicted_y = [3, -0.5, 2, 7]
112+
>>> original_y = [2.5, 0.0, 2, 8]
113+
>>> mean_absolute_error(predicted_y, original_y)
114+
0.5
96115
"""
97116
total = sum(abs(y - predicted_y[i]) for i, y in enumerate(original_y))
98117
return total / len(original_y)
@@ -114,4 +133,7 @@ def main():
114133

115134

116135
if __name__ == "__main__":
136+
import doctest
137+
138+
doctest.testmod()
117139
main()

0 commit comments

Comments
 (0)