Skip to content

Commit e3da1a1

Browse files
committed
Fixes in tensors from graph building
1 parent ec7eb30 commit e3da1a1

File tree

1 file changed

+42
-42
lines changed

1 file changed

+42
-42
lines changed

VSharp.Explorer/AISearcher.fs

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ open VSharp
66
open VSharp.IL.Serializer
77
open VSharp.ML.GameServer.Messages
88

9-
type internal AISearcher(oracle:Oracle, aiAgentTrainingOptions: Option<AIAgentTrainingOptions>) =
9+
type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentTrainingOptions>) =
1010
let stepsToSwitchToAI =
1111
match aiAgentTrainingOptions with
1212
| None -> 0u<step>
@@ -162,15 +162,15 @@ type internal AISearcher(oracle:Oracle, aiAgentTrainingOptions: Option<AIAgentTr
162162
let attributes = Array.zeroCreate (gameState.GraphVertices.Length * numOfVertexAttributes)
163163
for i in 0..gameState.GraphVertices.Length - 1 do
164164
let v = gameState.GraphVertices.[i]
165-
verticesIds.Add(v.Id,i)
166-
let i = i*numOfVertexAttributes
167-
attributes.[i] <- float32 <| if v.InCoverageZone then 1u else 0u
168-
attributes.[i + 1] <- float32 <| v.BasicBlockSize
169-
attributes.[i + 2] <- float32 <| if v.CoveredByTest then 1u else 0u
170-
attributes.[i + 3] <- float32 <| if v.VisitedByState then 1u else 0u
171-
attributes.[i + 4] <- float32 <| if v.TouchedByState then 1u else 0u
172-
attributes.[i + 5] <- float32 <| if v.ContainsCall then 1u else 0u
173-
attributes.[i + 6] <- float32 <| if v.ContainsThrow then 1u else 0u
165+
verticesIds.Add(v.Id, i)
166+
let j = i * numOfVertexAttributes
167+
attributes.[j] <- float32 <| if v.InCoverageZone then 1u else 0u
168+
attributes.[j + 1] <- float32 <| v.BasicBlockSize
169+
attributes.[j + 2] <- float32 <| if v.CoveredByTest then 1u else 0u
170+
attributes.[j + 3] <- float32 <| if v.VisitedByState then 1u else 0u
171+
attributes.[j + 4] <- float32 <| if v.TouchedByState then 1u else 0u
172+
attributes.[j + 5] <- float32 <| if v.ContainsCall then 1u else 0u
173+
attributes.[j + 6] <- float32 <| if v.ContainsThrow then 1u else 0u
174174
OrtValue.CreateTensorValueFromMemory(attributes, shape)
175175

176176
let states, numOfParentOfEdges, numOfHistoryEdges =
@@ -183,14 +183,14 @@ type internal AISearcher(oracle:Oracle, aiAgentTrainingOptions: Option<AIAgentTr
183183
numOfHistoryEdges <- numOfHistoryEdges + v.History.Length
184184
numOfParentOfEdges <- numOfParentOfEdges + v.Children.Length
185185
stateIds.Add(v.Id,i)
186-
let i = i*numOfStateAttributes
187-
attributes.[i] <- float32 v.Position
188-
attributes.[i + 1] <- float32 v.PathConditionSize
189-
attributes.[i + 2] <- float32 v.VisitedAgainVertices
190-
attributes.[i + 3] <- float32 v.VisitedNotCoveredVerticesInZone
191-
attributes.[i + 4] <- float32 v.VisitedNotCoveredVerticesOutOfZone
192-
attributes.[i + 6] <- float32 v.StepWhenMovedLastTime
193-
attributes.[i + 5] <- float32 v.InstructionsVisitedInCurrentBlock
186+
let j = i * numOfStateAttributes
187+
attributes.[j] <- float32 v.Position
188+
attributes.[j + 1] <- float32 v.PathConditionSize
189+
attributes.[j + 2] <- float32 v.VisitedAgainVertices
190+
attributes.[j + 3] <- float32 v.VisitedNotCoveredVerticesInZone
191+
attributes.[j + 4] <- float32 v.VisitedNotCoveredVerticesOutOfZone
192+
attributes.[j + 5] <- float32 v.StepWhenMovedLastTime
193+
attributes.[j + 6] <- float32 v.InstructionsVisitedInCurrentBlock
194194
OrtValue.CreateTensorValueFromMemory(attributes, shape)
195195
,numOfParentOfEdges
196196
,numOfHistoryEdges
@@ -222,26 +222,26 @@ type internal AISearcher(oracle:Oracle, aiAgentTrainingOptions: Option<AIAgentTr
222222
let mutable firstFreePositionInHistoryIndex = 0
223223
let mutable firstFreePositionInHistoryAttributes = 0
224224
gameState.States
225-
|> Array.iter (fun v ->
226-
v.Children
227-
|> Array.iteri (fun i s ->
228-
let i = firstFreePositionInParentsOf + i
229-
parentOf[i] <- int64 stateIds[v.Id]
230-
parentOf[numOfParentOfEdges + i] <- int64 stateIds[s]
225+
|> Array.iter (fun state ->
226+
state.Children
227+
|> Array.iteri (fun i children ->
228+
let j = firstFreePositionInParentsOf + i
229+
parentOf[j] <- int64 stateIds[state.Id]
230+
parentOf[numOfParentOfEdges + j] <- int64 stateIds[children]
231231
)
232-
firstFreePositionInParentsOf <- firstFreePositionInParentsOf + v.Children.Length
233-
v.History
234-
|> Array.iteri (fun i s ->
232+
firstFreePositionInParentsOf <- firstFreePositionInParentsOf + state.Children.Length
233+
state.History
234+
|> Array.iteri (fun i historyElem ->
235235
let j = firstFreePositionInHistoryIndex + i
236-
historyIndex_vertexToState[j] <- int64 verticesIds[s.GraphVertexId]
237-
historyIndex_vertexToState[numOfHistoryEdges + j] <- int64 stateIds[v.Id]
236+
historyIndex_vertexToState[j] <- int64 verticesIds[historyElem.GraphVertexId]
237+
historyIndex_vertexToState[numOfHistoryEdges + j] <- int64 stateIds[state.Id]
238238

239-
let j = firstFreePositionInHistoryAttributes + i
240-
historyAttributes[j] <- int64 s.NumOfVisits
241-
historyAttributes[j + 1] <- int64 s.StepWhenVisitedLastTime
239+
let j = firstFreePositionInHistoryAttributes + numOfHistoryEdgeAttributes * i
240+
historyAttributes[j] <- int64 historyElem.NumOfVisits
241+
historyAttributes[j + 1] <- int64 historyElem.StepWhenVisitedLastTime
242242
)
243-
firstFreePositionInHistoryIndex <- firstFreePositionInHistoryIndex + v.History.Length
244-
firstFreePositionInHistoryAttributes <- firstFreePositionInHistoryAttributes + numOfHistoryEdgeAttributes * v.History.Length
243+
firstFreePositionInHistoryIndex <- firstFreePositionInHistoryIndex + state.History.Length
244+
firstFreePositionInHistoryAttributes <- firstFreePositionInHistoryAttributes + numOfHistoryEdgeAttributes * state.History.Length
245245
)
246246

247247
OrtValue.CreateTensorValueFromMemory(historyIndex_vertexToState, shapeOfHistory)
@@ -257,15 +257,15 @@ type internal AISearcher(oracle:Oracle, aiAgentTrainingOptions: Option<AIAgentTr
257257
|> Array.iter (
258258
fun v ->
259259
v.States
260-
|> Array.iteri (fun i s ->
261-
let startPos = firstFreePosition + i
262-
let s = stateIds[s]
263-
let v' = verticesIds[v.Id]
264-
data_stateToVertex[startPos] <- int64 s
265-
data_stateToVertex[stateIds.Count + i] <- int64 v'
260+
|> Array.iteri (fun i stateId ->
261+
let j = firstFreePosition + i
262+
let stateIndex = int64 stateIds[stateId]
263+
let vertexIndex = int64 verticesIds[v.Id]
264+
data_stateToVertex[j] <- stateIndex
265+
data_stateToVertex[stateIds.Count + j] <- vertexIndex
266266

267-
data_vertexToState[i] <- int64 v'
268-
data_vertexToState[stateIds.Count + i] <- int64 s
267+
data_vertexToState[j] <- vertexIndex
268+
data_vertexToState[stateIds.Count + j] <- stateIndex
269269
)
270270
firstFreePosition <- firstFreePosition + v.States.Length
271271
)

0 commit comments

Comments
 (0)