Skip to content

Commit d032a17

Browse files
Merge pull request #65 from toubinaattori/master
Fix an error in CVaR formulation
2 parents 23d5875 + a63f6d8 commit d032a17

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/decision_model.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using JuMP
22

3-
function decision_variable(model::Model, S::States, d::Node, I_d::Vector{Node}, names::Bool, base_name::String="")
3+
function decision_variable(model::Model, S::States, d::Node, I_d::Vector{Node}, names::Bool, base_name::String="z")
44
# Create decision variables.
55
dims = S[[I_d; d]]
66
z_d = Array{VariableRef}(undef, dims...)
@@ -378,7 +378,6 @@ function ID_to_RJT(diagram::InfluenceDiagram)
378378
C_j_aux = sort([(elem, findfirst(isequal(elem), names)) for elem in C_j], by = last)
379379
C_j = [C_j_tuple[1] for C_j_tuple in C_j_aux]
380380
C_rjt[names[j]] = C_j
381-
382381
if length(C_rjt[names[j]]) > 1
383382
u = maximum([findfirst(isequal(name), names) for name in setdiff(C_j, [names[j]])])
384383
push!(A_rjt, (names[u], names[j]))
@@ -573,15 +572,16 @@ function conditional_value_at_risk(model::Model,
573572

574573
#Finding the name and index of differing element between value nodes' information set and its preceding nodes rjt cluster.
575574
#This is needed in conditional sums for constraints.
576-
missing_element = setdiff(diagram.RJT.clusters[preceding_node_name], diagram.Nodes[value_node_name].I_j)[1]
577-
index_to_remove = findfirst(x -> x == missing_element, diagram.RJT.clusters[preceding_node_name])
575+
missing_element = setdiff(diagram.RJT.clusters[preceding_node_name], diagram.Nodes[value_node_name].I_j)
576+
index_to_remove = findall(x -> x in missing_element, diagram.RJT.clusters[preceding_node_name])
578577

579578
statevars = μVars.data[preceding_node_name].statevars
580579
statevars_dims = collect(size(statevars))
581580
statevars_dims_ranges = [1:d for d in statevars_dims]
582581

583-
function remove_index(old_tuple::NTuple{N, Int64}, index::Int64) where N
584-
return collect(ntuple(i -> i >= index ? old_tuple[i + 1] : old_tuple[i], N-1))
582+
function remove_index(old_tuple::NTuple{N, Int64}, index::Vector{Int64}) where N
583+
vector = [old_tuple[i] for i in 1:length(old_tuple) if !(i in index)]
584+
return collect(ntuple(i -> vector[i], N-length(index)))
585585
end
586586

587587
for u in unique(diagram.U.Y[1])

0 commit comments

Comments
 (0)