Description of the bug
When using a state dataclass that includes a torch.Tensor field:
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class ABMCTAState:
    x_t: Optional[torch.Tensor]
Calling the algorithm's step with this state leads to the following error:
Traceback (most recent call last):
  ...
  File "/home/test.py", line 123, in search_until
    self.state = self.algo.step(self.state, generate_map, inplace=True)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/test/.venv/lib/python3.11/site-packages/treequest/algos/ab_mcts_a/algo.py", line 196, in step
    self.tell(state, trial.trial_id, result)
  File "/home/test/.venv/lib/python3.11/site-packages/treequest/algos/ab_mcts_a/algo.py", line 379, in tell
    thompson_state.register_new_child_node(
  File "/home/test/.venv/lib/python3.11/site-packages/treequest/algos/ab_mcts_a/prob_state.py", line 453, in register_new_child_node
    node.parent.children.index(node)
  File "<string>", line 4, in __eq__
  File "<string>", line 4, in __eq__
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
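The failure mechanism can be sketched without torch or treequest. This is a minimal illustration, not the library's code: `FakeTensor`, `AmbiguousBool`, and `Node` are hypothetical stand-ins, assuming that the two `__eq__` frames in the traceback come from a node dataclass comparing its state field, and that tensor `==` returns an elementwise result that cannot be reduced to a single bool.

```python
from dataclasses import dataclass
from typing import Optional


class AmbiguousBool:
    """Stand-in for an elementwise comparison result with several values."""

    def __bool__(self) -> bool:
        raise RuntimeError(
            "Boolean value of Tensor with more than one value is ambiguous"
        )


class FakeTensor:
    """Hypothetical stand-in for a multi-element torch.Tensor."""

    def __eq__(self, other: object) -> AmbiguousBool:
        # Like torch, "==" is elementwise and does not return a plain bool.
        return AmbiguousBool()


@dataclass  # eq=True by default, so an __eq__ comparing all fields is generated
class ABMCTAState:
    x_t: Optional[FakeTensor]


@dataclass
class Node:  # hypothetical tree node that stores the state as a field
    state: ABMCTAState


first = Node(ABMCTAState(FakeTensor()))
second = Node(ABMCTAState(FakeTensor()))
children = [first, second]

try:
    # list.index compares `second` against `first` with ==, which reaches
    # the tensor field and then truth-tests the elementwise result.
    children.index(second)
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)

print(error_message)
```

Looking up the first element of the list would not fail in this sketch, because list.index checks identity before calling `==`. The error only appears once the lookup has to compare two distinct nodes.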
Workaround
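The dataclass-generated __eq__ compares fields with ==, which is elementwise for tensors, so truth-testing the result raises. Passing eq=False disables the generated __eq__, and instances are then compared by identity: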
@dataclass(eq=False)
class ABMCTAState:
    x_t: Optional[torch.Tensor]
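The effect of the workaround can be checked with a small sketch. `FakeTensor` is again a hypothetical stand-in for torch.Tensor; its `__eq__` raises on purpose to show that, with eq=False, the field is never compared.

```python
from dataclasses import dataclass
from typing import Optional


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; comparing it always fails."""

    def __eq__(self, other: object) -> bool:
        raise RuntimeError("field comparison should not happen with eq=False")


@dataclass(eq=False)  # no __eq__ is generated; object identity is used instead
class ABMCTAState:
    x_t: Optional[FakeTensor]


shared = FakeTensor()
a = ABMCTAState(shared)
b = ABMCTAState(shared)

same_object_equal = a == a        # True: same object
distinct_objects_equal = a == b   # False: fields match, but identity differs
index_of_b = [a, b].index(b)      # found by identity, x_t is never compared
hash_is_int = isinstance(hash(a), int)  # __hash__ is inherited from object
```

The trade-off is that two states holding the same tensor no longer compare equal, so this fits code that only needs to locate a specific instance, as `children.index(node)` in the traceback appears to do.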