-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmcts.cpp
66 lines (54 loc) · 1.52 KB
/
mcts.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include "mcts.h"
/// Builds a fresh, unvisited tree node reached from `parent` by playing `move`.
///
/// The parent's game state is duplicated with `copy()` and `move` is applied
/// to the duplicate, so the parent's own state is not modified by this call.
///
/// @param parent Node the new node hangs under; must not be null. Only the
///               raw pointer is stored, so the caller has to keep the parent
///               alive (and at the same address) while the child is in use.
/// @param move   Move applied on top of the parent's state.
/// @return A node with both reward slots and the visit count set to zero.
MCTSNode new_node(MCTSNode* parent, Move move) {
    // Derive the successor position first; the node is assembled afterwards.
    GameState successor = parent->state.copy();
    successor.doMove(move);

    MCTSNode child;
    child.parent = parent;
    child.state = successor;
    child.visits = 0.0;
    child.rewards[0] = 0.0;
    child.rewards[1] = 0.0;
    return child;
}
/// Entry point for picking a move in `state`.
///
/// NOTE(review): this is currently a stub. It only fetches the move list
/// returned by `state->getMoves()` and then discards it; `bestMove` is never
/// written to.
///
/// @param state    Position to choose a move for; must not be null.
/// @param bestMove Output parameter, presumably meant to receive the chosen
///                 move - it is left untouched for now.
void MCTS::get_move(GameState* state, Move* bestMove) {
    std::vector<Move> availableMoves = state->getMoves();
}
/// Computes the UCB1 selection score of child `c` as seen from its parent `p`.
///
///     score = rewards(c) / visits(c) + C * sqrt(ln(visits(p)) / visits(c))
///
/// The first term (exploitation) is the child's average reward, the second
/// term (exploration) grows for children that were visited rarely compared
/// to their parent. The two terms are added; multiplying them, or taking the
/// exploitation term from the parent, would give every sibling the same
/// average and make the score useless for ranking children.
///
/// @param p Parent node; supplies the total visit count and the player to move.
/// @param c Child node being scored.
/// @return The UCB1 score, or `FLT_MAX` for a child that has never been
///         visited so that unexplored children are always tried first.
float MCTS::ucb_score(MCTSNode* p, MCTSNode* c) {
    // An unvisited child would cause a division by zero below; treat it as
    // maximally attractive instead.
    if (c->visits <= 0) {
        return FLT_MAX;
    }

    // Rewards are indexed by the player who chooses between the children,
    // i.e. the player to move in the parent's state.
    // NOTE(review): assumes rewards[i] holds the reward accumulated for
    // player i - confirm against the backpropagation code once it exists.
    const float exploitation = c->rewards[p->state.getPlayerTurn()] / c->visits;

    // log() of a visit count below 1 is negative (or -inf) and would turn the
    // square root into NaN, so the exploration bonus is dropped in that case.
    float exploration = 0.0f;
    if (p->visits >= 1) {
        exploration = C * sqrt(log(p->visits) / c->visits);
    }

    return exploitation + exploration;
}
/// Walks down the tree from `tree` to a leaf, always following the child
/// with the highest `ucb_score`.
///
/// @param tree Node to start from; must not be null.
/// @return The first node reached that has no children (this is `tree`
///         itself when `tree` has no children).
MCTSNode* MCTS::traverse(MCTSNode* tree) {
    MCTSNode* current = tree;

    while (!current->children.empty()) {
        // The best candidate is re-seeded from the first child on every
        // level: a score carried over from the level above could otherwise
        // never be beaten, leaving the selected index stale. Seeding from a
        // real child also guarantees a valid index even if every score
        // compares false (e.g. NaN).
        int bestNode = 0;
        float bestScore = ucb_score(current, &current->children[0]);

        const int childCount = static_cast<int>(current->children.size());
        for (int i = 1; i < childCount; i++) {
            // Score the i-th *child*; `&current[i]` would index past the
            // current node itself and read unrelated memory.
            const float score = ucb_score(current, &current->children[i]);
            if (score > bestScore) {
                bestScore = score;
                bestNode = i;
            }
        }

        current = &current->children[bestNode];
    }

    return current;
}
/// Adds a child for `move` to `tree`.
///
/// The child is created with `new_node` and appended to `tree->children`.
///
/// NOTE(review): nodes are stored by value in `children`. If `children` is a
/// std::vector, appending can reallocate and relocate the existing children,
/// which would invalidate any `parent` pointers that refer to them - confirm
/// before growing a node that already has grandchildren.
///
/// @param tree Node to expand; must not be null.
/// @param move Move leading from `tree` to the new child.
void expand(MCTSNode* tree, Move move) {
    tree->children.push_back(new_node(tree, move));
}
/// Runs `simulations_per_move` search iterations starting from `state`.
///
/// A root node is built around a copy of `state`, so the caller's state is
/// not modified through the root. Each iteration currently only performs the
/// selection step (`traverse`).
///
/// NOTE(review): expansion, simulation and backpropagation are not
/// implemented yet, so the returned result is value-initialised and
/// `bestMove` is never written to.
///
/// @param state    Position to search from; must not be null.
/// @param bestMove Output parameter, presumably meant to receive the best
///                 move found - left untouched for now.
/// @return The search result (currently value-initialised).
struct MCTSResult MCTS::search(GameState* state, Move* bestMove) {
    // Value-initialise so the caller never reads indeterminate members.
    struct MCTSResult result = MCTSResult();

    // Set up the root explicitly, mirroring what new_node() does for
    // children; nothing else gives the root a state or zeroed statistics.
    MCTSNode root;
    root.parent = nullptr;
    root.state = state->copy();
    root.visits = 0.0;
    root.rewards[0] = 0.0;
    root.rewards[1] = 0.0;

    for (int simulation = 0; simulation < simulations_per_move; simulation++) {
        // Selection: descend to the leaf that looks most promising.
        MCTSNode* leaf = traverse(&root);
        // TODO: expand `leaf`, run a playout and backpropagate the reward.
        (void) leaf;
    }

    // Flowing off the end of a non-void function is undefined behaviour, so
    // the result has to be returned explicitly.
    return result;
}
/// Initialisation hook for the searcher. Currently performs no work.
void MCTS::setup() {
}