-
-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathtest_tree.py
29 lines (24 loc) · 1004 Bytes
/
test_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np
from pymc_bart.tree import Node, get_depth, get_idx_left_child, get_idx_right_child
def test_split_node():
    """Check a node built with a split variable, plus the index helpers at position 5.

    The index helpers are expected to report depth 2 for position 5 and
    children at positions 11 (left) and 12 (right). A node constructed
    with ``idx_split_variable=2`` and no data points must report itself as
    a split node, not a leaf.
    """
    position = 5
    node = Node(idx_split_variable=2, value=3.0)

    # Index helpers: each (helper, expected) pair is evaluated at `position`.
    index_expectations = (
        (get_depth, 2),
        (get_idx_left_child, 11),
        (get_idx_right_child, 12),
    )
    for helper, expected in index_expectations:
        assert helper(position) == expected

    # Attributes are stored as passed; idx_data_points was not supplied.
    assert node.value == 3.0
    assert node.idx_split_variable == 2
    assert node.idx_data_points is None

    # Identity checks pin the predicates to the bool singletons.
    assert node.is_split_node() is True
    assert node.is_leaf_node() is False
def test_leaf_node():
    """Check a node built via ``Node.new_leaf_node``, plus the index helpers at position 5.

    The index helpers are expected to report depth 2 for position 5 and
    children at positions 11 (left) and 12 (right). A leaf created with a
    value and data points, but no split variable, must report
    ``idx_split_variable == -1`` and classify itself as a leaf.
    """
    position = 5
    leaf = Node.new_leaf_node(value=3.14, idx_data_points=[1, 2, 3])

    # Index helpers: each (helper, expected) pair is evaluated at `position`.
    index_expectations = (
        (get_depth, 2),
        (get_idx_left_child, 11),
        (get_idx_right_child, 12),
    )
    for helper, expected in index_expectations:
        assert helper(position) == expected

    assert leaf.value == 3.14
    # No split variable was passed; the test expects -1 in that case.
    assert leaf.idx_split_variable == -1
    # np.array_equal compares element-wise and accepts either a list or an
    # ndarray, so the check does not depend on how the node stores the points.
    assert np.array_equal(leaf.idx_data_points, [1, 2, 3])

    # Identity checks pin the predicates to the bool singletons.
    assert leaf.is_split_node() is False
    assert leaf.is_leaf_node() is True