1
1
using System ;
2
2
using System . Collections . Generic ;
3
+ using System . ComponentModel ;
3
4
using System . IO ;
4
5
using System . Numerics ;
5
6
using Microsoft . Mesh . CloudScripting ;
@@ -9,27 +10,38 @@ namespace Presentation1
9
10
{
10
11
public class QLearning
11
12
{
12
// Side length of the square grid. MoveNpc accepts X/Z coordinates in
// [-GRID_SIZE / 2, GRID_SIZE / 2) and adds GRID_SIZE / 2 to index gridObstacles.
const int GRID_SIZE = 60;
// NOTE(review): presumably used when encoding a position into a state id; the
// code that uses it is not visible from this section - confirm.
const int STATE_MODULUS = 100000;
// Distance below which two positions are treated as the same spot
// (arrival at the destination, collision with another NPC).
const float DISTANCE_THRESHOLD = 0.1f;
// Reward returned when the NPC is exactly on the destination.
const int REWARD_GOAL = 100;
// NOTE(review): REWARD_FAR and REWARD_CLOSE are not referenced in the visible code.
const int REWARD_FAR = -1;
const int REWARD_CLOSE = 1;

// Q-value table indexed as [state, action]; allocated as numStates x numActions.
private readonly double[,] qTable;
// numStates: number of possible states | numActions: number of possible actions
private readonly int numStates, numActions;
// Parameters of the Q-learning algorithm
private readonly double learningRate, discountFactor, explorationRate;
// Obstacle map returned by LoadGrid; true marks a blocked cell.
private readonly bool[,] gridObstacles;
// NOTE(review): not used in the visible code - confirm it is still needed.
private readonly object lockObject = new object();
private readonly Random rnd = new Random();

// Positions of NPCs that MoveNpc must not move into.
private List<Vector3> npcPositions = new List<Vector3>();

// Stores the last distance from the chosen destination
private float lastDistance = 9999;

// Per-state cache of the best Q-value and of the action that achieves it.
private readonly Dictionary<int, double> maxQValues = new Dictionary<int, double>();
private readonly Dictionary<int, int> maxQActions = new Dictionary<int, int>();
// Maps an action id to the unit step it performs on the X/Z plane.
// Keyed by int because action ids are ints everywhere else in this class.
private readonly Dictionary<int, Vector3> actionDirections = new Dictionary<int, Vector3>
{
    { 0, new Vector3(0, 0, 1f) },    // Move up
    { 1, new Vector3(0, 0, -1f) },   // Move down
    { 2, new Vector3(-1f, 0, 0) },   // Move left
    { 3, new Vector3(1f, 0, 0) },    // Move right
    { 4, new Vector3(1f, 0f, 1f) },  // Move up-right
    { 5, new Vector3(-1f, 0f, 1f) }, // Move up-left
    { 6, new Vector3(1f, 0f, -1f) }, // Move down-right
    { 7, new Vector3(-1f, 0f, -1f) } // Move down-left
};
33
45
public QLearning ( int numStates , int numActions , double learningRate , double discountFactor , double explorationRate , int npcNum )
34
46
{
35
47
this . numStates = numStates ;
@@ -40,7 +52,6 @@ public QLearning(int numStates, int numActions, double learningRate, double disc
40
52
41
53
qTable = new double [ numStates , numActions ] ;
42
54
43
- LoadQTable ( npcNum ) ;
44
55
gridObstacles = LoadGrid ( ) ;
45
56
}
46
57
@@ -55,22 +66,23 @@ public int ChooseAction(int state)
55
66
}
56
67
else
57
68
{
69
+ maxQActions . TryGetValue ( state , out int action ) ;
58
70
// Exploit: select the action with max value
59
- return maxQActions . ContainsKey ( state ) ? maxQActions [ state ] : 0 ;
71
+ return action ;
60
72
}
61
73
}
62
74
63
75
// Update the Q value in the Q table with the reward it gets
64
76
/// <summary>
/// Applies one Q-learning step to the entry for (<paramref name="prevState"/>,
/// <paramref name="action"/>):
/// Q += learningRate * (reward + discountFactor * max Q(nextState, *) - Q),
/// then refreshes the cached best value/action of <paramref name="prevState"/>.
/// </summary>
/// <param name="prevState">State the action was taken from (row index of the Q table).</param>
/// <param name="action">Action that was taken (column index of the Q table).</param>
/// <param name="reward">Reward observed after taking the action.</param>
/// <param name="nextState">State reached after the action (row index of the Q table).</param>
public void UpdateQValue(int prevState, int action, float reward, int nextState)
{
    // Index of the highest-valued action of a state; ties resolve to the lowest index.
    int BestAction(int state)
    {
        int best = 0;
        for (int candidate = 1; candidate < numActions; candidate++)
        {
            if (qTable[state, candidate] > qTable[state, best])
            {
                best = candidate;
            }
        }
        return best;
    }

    // Read the successor's best value from the table itself (an untouched row is
    // all zeros, matching the previous default of 0). Done before writing, because
    // prevState and nextState can be the same state.
    double bestNextValue = qTable[nextState, BestAction(nextState)];

    // The update is unconditional. Previously it only ran while the entry was
    // strictly greater than the cached maximum, so the cached best action could
    // never be updated again once it had been stored.
    double oldValue = qTable[prevState, action];
    double learnedValue = reward + discountFactor * bestNextValue;
    qTable[prevState, action] = oldValue + learningRate * (learnedValue - oldValue);

    // Re-derive the cache from the row so it always names the real best action,
    // not merely the action that was updated last.
    int bestAction = BestAction(prevState);
    maxQValues[prevState] = qTable[prevState, bestAction];
    maxQActions[prevState] = bestAction;
}
75
87
76
88
// Get the position/state of the npc
@@ -86,17 +98,22 @@ public int GetState(TransformNode npc)
86
98
}
87
99
88
100
// Main function that makes the NPC move and calls all the sub-functions
89
- public async void MoveAction ( TransformNode npc , Vector3 destination , int npcNum , int numIterations )
101
+ public async void MoveAction ( TransformNode npc , Vector3 destination , string destinationName , int npcNum , int numIterations , CancellationToken cancellationToken )
90
102
{
103
+ LoadQTable ( npcNum , destinationName ) ;
91
104
for ( int i = 0 ; i < numIterations ; i ++ )
92
105
{
106
+ if ( cancellationToken . IsCancellationRequested )
107
+ {
108
+ return ;
109
+ }
93
110
int prevState = GetState ( npc ) ;
94
111
int action = ChooseAction ( prevState ) ;
95
112
96
113
Vector3 direction = actionDirections [ action ] ;
97
114
98
- await RotateNpc ( npc , direction ) ;
99
- await MoveNpc ( npc , direction , npcNum ) ;
115
+ await RotateNpc ( npc , direction , cancellationToken ) ;
116
+ await MoveNpc ( npc , direction , npcNum , cancellationToken ) ;
100
117
101
118
// Calculate the reward
102
119
float reward = CalculateReward ( npc , destination ) ;
@@ -107,39 +124,51 @@ public async void MoveAction(TransformNode npc, Vector3 destination, int npcNum,
107
124
UpdateQValue ( prevState , action , reward , nextState ) ;
108
125
if ( i % 100 == 0 )
109
126
{
110
- SaveQTable ( npcNum ) ;
127
+ SaveQTable ( npcNum , destinationName ) ;
128
+ }
129
+ if ( Vector3 . Distance ( npc . Position , destination ) <= DISTANCE_THRESHOLD ) // 0.1f is a small threshold to account for floating point precision
130
+ {
131
+ break ; // Stop the movement
111
132
}
112
133
}
113
134
}
114
135
115
/// <summary>
/// Moves <paramref name="npc"/> by <paramref name="direction"/> in small animated steps.
/// The move is skipped when the target is within DISTANCE_THRESHOLD of a position in
/// npcPositions, lies outside the grid, or is marked as an obstacle.
/// </summary>
/// <param name="npc">Node whose Position is animated.</param>
/// <param name="direction">Offset added to the current position to get the target.</param>
/// <param name="npcNum">Not used by this method; kept for the callers.</param>
/// <param name="cancellationToken">Polled before every step; the method returns quietly
/// (without throwing) when cancellation is requested, leaving the NPC where it is.</param>
/// <returns>A task that completes when the move finishes, is skipped, or is cancelled.</returns>
public async Task MoveNpc(TransformNode npc, Vector3 direction, int npcNum, CancellationToken cancellationToken)
{
    const float duration = 2f;
    const float deltaTime = 0.02f;
    const float stepSize = deltaTime / duration;
    // NOTE(review): deltaTime * 100 is a 2 ms delay, not the 20 ms a seconds-to-
    // milliseconds conversion (* 1000) would give - confirm the speed-up is intended.
    int delayMilliseconds = (int)(deltaTime * 100);

    Vector3 desiredPosition = npc.Position + direction;

    // If there is a collision with another NPC, return without moving.
    if (npcPositions.Any(pos => Vector3.Distance(pos, desiredPosition) < DISTANCE_THRESHOLD))
    {
        return;
    }

    // The bounds test must precede the obstacle lookup so the array index is valid.
    bool insideGrid = desiredPosition.X >= -GRID_SIZE / 2 && desiredPosition.X < GRID_SIZE / 2
        && desiredPosition.Z >= -GRID_SIZE / 2 && desiredPosition.Z < GRID_SIZE / 2;
    if (!insideGrid || gridObstacles[(int)desiredPosition.X + GRID_SIZE / 2, (int)desiredPosition.Z + GRID_SIZE / 2])
    {
        return;
    }

    float t = 0f;
    while (t < 1f)
    {
        if (cancellationToken.IsCancellationRequested)
        {
            return;
        }

        // Clamp before interpolating. Vector3.Lerp does not clamp its factor, so the
        // accumulated float could otherwise exceed 1 on the last step and extrapolate
        // past the target.
        t = MathF.Min(t + stepSize, 1f);
        npc.Position = Vector3.Lerp(npc.Position, desiredPosition, t);

        await Task.Delay(delayMilliseconds);
    }
}
138
167
139
- public async Task RotateNpc ( TransformNode npc , Vector3 direction )
168
+ public async Task RotateNpc ( TransformNode npc , Vector3 direction , CancellationToken cancellationToken )
140
169
{
141
170
Vector3 normalizedDirection = Vector3 . Normalize ( direction ) ;
142
-
171
+
143
172
float rotationAngleRadians = MathF . Atan2 ( normalizedDirection . X , normalizedDirection . Z ) ;
144
173
if ( rotationAngleRadians == 0 )
145
174
{
@@ -151,21 +180,23 @@ public async Task RotateNpc(TransformNode npc, Vector3 direction)
151
180
Quaternion rotation = new Quaternion ( rotationAxis . X , rotationAxis . Y , rotationAxis . Z , MathF . Cos ( rotationAngleRadians / 2 ) ) ;
152
181
if ( Equals ( npc . Rotation , rotation ) ) return ;
153
182
154
- float duration = 1f ;
155
- float remainingTime = duration ;
183
+ float duration = 0.5f ;
156
184
float t = 0f ;
157
185
while ( t < 1f )
158
186
{
159
- float deltaTime = 0.01f ; // Adjust as needed
187
+ if ( cancellationToken . IsCancellationRequested )
188
+ {
189
+ return ;
190
+ }
191
+ float deltaTime = 0.02f ; // Adjust as needed
160
192
float stepSize = deltaTime / duration ;
161
193
162
194
t += stepSize ;
163
- npc . Rotation = Quaternion . Slerp ( npc . Rotation , rotation , t ) ;
195
+ npc . Rotation = Quaternion . Slerp ( npc . Rotation , rotation , t * t * ( 3 - 2 * t ) ) ;
164
196
165
197
if ( t > 1f ) t = 1f ;
166
198
167
- await Task . Delay ( ( int ) ( deltaTime * 1000 ) ) ; // Convert deltaTime to milliseconds
168
- remainingTime -= deltaTime ;
199
+ await Task . Delay ( ( int ) ( deltaTime * 100 ) ) ;
169
200
}
170
201
}
171
202
@@ -177,52 +208,47 @@ public int CalculateReward(TransformNode npc, Vector3 destination)
177
208
178
209
if ( distance == 0 )
179
210
{
180
- return 100 ; // Big reward for reaching the goal
181
- }
182
- else if ( distance >= lastDistance )
183
- {
184
- lastDistance = distance ;
185
- return - 1 ;
211
+ return REWARD_GOAL ; // Big reward for reaching the goal
186
212
}
213
+ else
187
214
{
188
- lastDistance = distance ;
189
- return 1 ;
215
+ return - 1 * ( int ) distance ;
190
216
}
191
217
}
192
218
193
219
// Save the Q table to a JSON file so it can be reused on subsequent launches
194
/// <summary>
/// Serializes the whole Q table (numStates rows of numActions values) to JSON and
/// writes it to "qtable{npcNum}{destinationName}.json" in the application's base
/// directory, replacing any existing file.
/// </summary>
/// <param name="npcNum">NPC number; part of the file name.</param>
/// <param name="destinationName">Destination name; part of the file name.</param>
public void SaveQTable(int npcNum, string destinationName)
{
    // Copy the rectangular array into nested lists, one inner list per state.
    var rows = new List<List<double>>();
    int state = 0;
    while (state < numStates)
    {
        var actionValues = new List<double>();
        int actionIndex = 0;
        while (actionIndex < numActions)
        {
            actionValues.Add(qTable[state, actionIndex]);
            actionIndex++;
        }
        rows.Add(actionValues);
        state++;
    }

    string fileName = "qtable" + npcNum + destinationName + ".json";
    string serialized = JsonConvert.SerializeObject(rows);
    File.WriteAllText(Path.Combine(AppDomain.CurrentDomain.BaseDirectory, fileName), serialized);
}
209
235
210
236
211
237
// Load the Q table
212
/// <summary>
/// Fills the Q table from "qtable{npcNum}{destinationName}.json" in the application's
/// base directory. Does nothing when the file does not exist or contains no table.
/// </summary>
/// <remarks>
/// Only the part of the stored table that fits in numStates x numActions is copied,
/// so a file written with other dimensions no longer throws an index error; entries
/// missing from the file keep their current value.
/// </remarks>
/// <param name="npcNum">NPC number; part of the file name.</param>
/// <param name="destinationName">Destination name; part of the file name.</param>
public void LoadQTable(int npcNum, string destinationName)
{
    string finalFilePath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "qtable" + npcNum + destinationName + ".json");
    if (!File.Exists(finalFilePath))
    {
        return;
    }

    var qTableList = JsonConvert.DeserializeObject<List<List<double>>>(File.ReadAllText(finalFilePath));
    if (qTableList == null)
    {
        // An empty file or a literal "null" deserializes to null: nothing to load.
        return;
    }

    int rowCount = Math.Min(numStates, qTableList.Count);
    for (int i = 0; i < rowCount; i++)
    {
        List<double> row = qTableList[i];
        if (row == null)
        {
            continue;
        }

        int columnCount = Math.Min(numActions, row.Count);
        for (int j = 0; j < columnCount; j++)
        {
            qTable[i, j] = row[j];
        }
    }
}
227
253
228
254
public bool [ , ] LoadGrid ( )
0 commit comments