Skip to content

Commit

Permalink
Fix bug when giving strategy to value_iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
Zinoex committed Oct 11, 2024
1 parent 530f14c commit a90f234
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/bellman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ function state_bellman!(
)
@inbounds begin
s₁ = stateptr[jₛ]
jₐ = s₁ + strategy_cache[s₁] - 1
jₐ = s₁ + strategy_cache[jₛ] - 1
Vres[jₛ] = state_action_bellman(workspace, V, prob, jₐ, upper_bound)
end
end
Expand Down
2 changes: 2 additions & 0 deletions src/strategy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct StationaryStrategy{A <: AbstractArray{Int32}} <: AbstractStrategy
strategy::A
end
Base.getindex(strategy::StationaryStrategy, k) = strategy.strategy
time_length(::StationaryStrategy) = typemax(Int64)

function checkstrategy!(strategy::StationaryStrategy, system)
checkstrategy!(strategy.strategy, system)
Expand Down Expand Up @@ -131,6 +132,7 @@ end

construct_strategy_cache(::IntervalMarkovProcess, ::GivenStrategyConfig, strategy, dims) =
GivenStrategyCache(strategy)
time_length(cache::GivenStrategyCache) = time_length(cache.strategy)

struct ActiveGivenStrategyCache{A <: AbstractArray{Int32}} <: NonOptimizingStrategyCache
strategy::A
Expand Down
2 changes: 1 addition & 1 deletion src/value_iteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ function value_iteration(problem::Problem)
return V, k, res
end
whichstrategyconfig(::Problem{S, F, <:NoStrategy}) where {S, F} = NoStrategyConfig()
whichstrategyconfig(::Problem{S, F, <:AbstractStrategy}) where {S, F} = NoStrategyConfig()
whichstrategyconfig(::Problem{S, F, <:AbstractStrategy}) where {S, F} = GivenStrategyConfig()

function _value_iteration!(strategy_config::AbstractStrategyConfig, problem::Problem)
mp = system(problem)
Expand Down

0 comments on commit a90f234

Please sign in to comment.