Skip to content

Commit

Permalink
Move Shapley fix from #167 (#168)
Browse files Browse the repository at this point in the history
* Update shapley_sensitivity.jl

* Update shapley_method.jl
  • Loading branch information
Vaibhavdixit02 authored May 30, 2024
1 parent 394dc74 commit 5a389f3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
5 changes: 5 additions & 0 deletions src/shapley_sensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,11 @@ function gsa(f, method::Shapley, input_distribution::SklarDist; batch = false)
sample_complement = rand(
Copulas.subsetdims(input_distribution, idx_minus), n_outer)

if size(sample_complement, 2) == 1
sample_complement = reshape(
sample_complement, (1, length(sample_complement)))
end

for l in 1:n_outer
curr_sample = @view sample_complement[:, l]
# Sampling of the set conditionally to the complementary element
Expand Down
14 changes: 6 additions & 8 deletions test/shapley_method.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,33 +25,31 @@ n_perms = -1;
n_var = 10_000;
n_outer = 1000;
n_inner = 3;
dim = 3;
margins = (Uniform(-pi, pi), Uniform(-pi, pi), Uniform(-pi, pi));
dim = 4;
margins = (Uniform(-pi, pi), Uniform(-pi, pi), Uniform(-pi, pi), Uniform(-pi, pi));
dependency_matrix = Matrix(4 * I, dim, dim);
C = GaussianCopula(dependency_matrix);
input_distribution = SklarDist(C, margins);

method = Shapley(n_perms = n_perms,
n_var = n_var,
n_outer = n_outer,
n_inner = n_inner);

#---> non batch
@time result = gsa(ishi, method, input_distribution, batch = false)

@test result.shapley_effects[1]0.43813841765976547 atol=1e-1
@test result.shapley_effects[2]0.44673952698721386 atol=1e-1
@test result.shapley_effects[3]0.23144736934254417 atol=1e-1
# @test result.shapley_effects[4]≈0.0 atol=1e-1
@test result.shapley_effects[3]0.11855122481995543 atol=1e-1
@test result.shapley_effects[4]0.0 atol=1e-1
#<---- non batch

#---> batch
result = gsa(ishi_batch, method, input_distribution, batch = true);

@test result.shapley_effects[1]0.44080027198796035 atol=1e-1
@test result.shapley_effects[2]0.43029987176805085 atol=1e-1
@test result.shapley_effects[3]0.23144736934254417 atol=1e-1
# @test result.shapley_effects[4]≈0.0 atol=1e-1
@test result.shapley_effects[3]0.11855122481995543 atol=1e-1
@test result.shapley_effects[4]0.0 atol=1e-1
#<--- batch

d = 3
Expand Down

0 comments on commit 5a389f3

Please sign in to comment.