From 59c37034fd8ba211172eae8cbc64ab0667b1da42 Mon Sep 17 00:00:00 2001 From: Ujjwal Sarswat Date: Fri, 7 Apr 2023 15:55:52 +0530 Subject: [PATCH 1/6] change package version and setup initial files --- .gitignore | 1 + Project.toml | 10 +++++----- src/ClassImbalance.jl | 1 + src/random_undersampler.jl | 0 4 files changed, 7 insertions(+), 5 deletions(-) create mode 100644 src/random_undersampler.jl diff --git a/.gitignore b/.gitignore index 8c960ec..0ee3d17 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.jl.cov *.jl.*.cov *.jl.mem +Manifest.toml \ No newline at end of file diff --git a/Project.toml b/Project.toml index 95202f4..39b51a3 100644 --- a/Project.toml +++ b/Project.toml @@ -13,11 +13,11 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] -Compat = "3" -DataFrames = "0.20" -Distributions = "0.21.3, 0.22" -StatsBase = "0.32" -julia = "1.1" +Compat = "4.6" +DataFrames = "1.5" +Distributions = "0.25" +StatsBase = "0.33" +julia = "1.8" [extras] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/ClassImbalance.jl b/src/ClassImbalance.jl index 4da861e..32e689c 100644 --- a/src/ClassImbalance.jl +++ b/src/ClassImbalance.jl @@ -8,5 +8,6 @@ include("utils.jl") include("smote_exs.jl") include("ub_smote.jl") include("rose.jl") +include("random_undersampler.jl") end # end module ClassImbalance diff --git a/src/random_undersampler.jl b/src/random_undersampler.jl new file mode 100644 index 0000000..e69de29 From 6ff5d68b03a5f96796e1222ee3cb8f55f8446982 Mon Sep 17 00:00:00 2001 From: Ujjwal Sarswat Date: Sat, 8 Apr 2023 18:58:06 +0530 Subject: [PATCH 2/6] add: sampling_strategy functionality --- Project.toml | 1 + src/random_undersampler.jl | 37 +++++++++++++++++++++++++++++++++++++ src/utils.jl | 18 ++++++++++++++++++ 3 files changed, 56 insertions(+) diff --git a/Project.toml b/Project.toml index 39b51a3..baee9b8 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] Compat = "4.6" diff --git a/src/random_undersampler.jl b/src/random_undersampler.jl index e69de29..42654bf 100644 --- a/src/random_undersampler.jl +++ b/src/random_undersampler.jl @@ -0,0 +1,37 @@ +import DataFrames +import Distributions +import LinearAlgebra +import Statistics +import StatsBase +import Random +import Tables + +function random_undersampler( + X, + y::T; + sampling_strategy::Union{AbstractFloat, String, Dict{Any, Int}} = "auto", + random_state=nothing + ) where T <: AbstractVector + # check if X implements getobs + @assert Tables.istable(X) "$X is not implementing the MLUtils.jl getobs interface" + + classes = unique(y) + classpos = Dict(c => findall(y .== c) for c in classes) + classcount = Dict(c => length(classpos[c]) for c in classes) + + # checking classes in y + @assert length(classes) > 1 "$y must have more than one class" + # checking sampling_strategy + if typeof(sampling_strategy) <: String + @assert sampling_strategy in ["auto", "not minority", "not majority", "all", "majority"] "sampling_strategy must be one of \"auto\", \"not minority\", \"not majority\", \"all\", \"majority\"" + elseif typeof(sampling_strategy) <: AbstractFloat + @assert length(classes) == 2 "sampling_strategy of type float is supported only for binary classification" + @assert 0 < sampling_strategy <= 1 "sampling_strategy must be between 0 and 1" + elseif typeof(sampling_strategy) <: Dict + @assert all(c in classes for c in keys(sampling_strategy)) "sampling_strategy must have keys that are classes in $y" + @assert all(sampling_strategy[c] <= classcount[c] for c in keys(sampling_strategy)) "sampling_strategy must have values less than or equal to current number of samples for a particular class" + end + + X_new = DataFrames.DataFrame(X) + y_new = copy(y) +end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index 3e103ad..d1ab917 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -86,3 +86,21 @@ function calculate_smote_pct_under( result = 100*minority_to_majority_ratio*(100+pct_over)/pct_over return result end + +function undersampling_strategy!( + sampling_strategy::String, + classes::T, + classcount::Dict{Any, Int}, + ) where T <: AbstractVector + mincount = minimum(values(classcount)) + maxcount = maximum(values(classcount)) + + if sampling_strategy == "majority" + sampling_strategy = Dict(c => mincount for c in classes if classcount[c] == maxcount) + elseif sampling_strategy == "auto" || sampling_strategy == "not minority" + sampling_strategy = Dict(c => mincount for c in classes if classcount[c] != mincount) + elseif sampling_strategy == "not majority" + sampling_strategy = Dict(c => mincount for c in classes if classcount[c] != maxcount) + elseif sampling_strategy == "all" + sampling_strategy = Dict(c => mincount for c in classes) +end \ No newline at end of file From ad906955f4baafb1edf010fe7c33435b510758da Mon Sep 17 00:00:00 2001 From: Ujjwal Sarswat Date: Mon, 10 Apr 2023 17:32:43 +0530 Subject: [PATCH 3/6] change version number for passing checks --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index baee9b8..394c18f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ClassImbalance" uuid = "04a18a73-7590-580c-b363-eeca0919eb2a" authors = ["Paul Stey ", "Dilum Aluthge ", "Brown Center for Biomedical Informatics "] -version = "0.8.7" +version = "0.8.8" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" From 716b1ec7818728521e0be90105fc5ea08c5ff81d Mon Sep 17 00:00:00 2001 From: Ujjwal Sarswat Date: Mon, 10 Apr 2023 20:02:22 +0530 Subject: [PATCH 4/6] add: random_undersampler functionality --- Project.toml | 1 + src/ClassImbalance.jl | 2 +- src/random_undersampler.jl | 43 ++++++++++++++++++++++++++------------ src/utils.jl | 7 ++++--- 4 files changed, 36 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index 394c18f..1d7e950 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" diff --git a/src/ClassImbalance.jl b/src/ClassImbalance.jl index 32e689c..8fbce48 100644 --- a/src/ClassImbalance.jl +++ b/src/ClassImbalance.jl @@ -2,7 +2,7 @@ __precompile__(true) module ClassImbalance -export smote, rose +export smote, rose, random_undersampler include("utils.jl") include("smote_exs.jl") diff --git a/src/random_undersampler.jl b/src/random_undersampler.jl index 42654bf..5c0e87d 100644 --- a/src/random_undersampler.jl +++ b/src/random_undersampler.jl @@ -1,23 +1,21 @@ -import DataFrames -import Distributions -import LinearAlgebra -import Statistics -import StatsBase import Random +import StatsBase +import DataFrames import Tables +import MLUtils function random_undersampler( X, y::T; - sampling_strategy::Union{AbstractFloat, String, Dict{Any, Int}} = "auto", - random_state=nothing - ) where T <: AbstractVector + sampling_strategy::Union{AbstractFloat, String, Dict{A, S}} = "auto", + random_state::Union{Nothing, S} = nothing, + replacement::Bool = false + ) where T <: AbstractVector where S <: Integer where A <: Any # check if X implements getobs @assert Tables.istable(X) "$X is not implementing the MLUtils.jl getobs interface" classes = unique(y) - classpos = Dict(c => findall(y .== c) for c in classes) - classcount = Dict(c => length(classpos[c]) for c in classes) + classcount = Dict(c => count(y .== c) for c in classes) # checking classes in y @assert length(classes) > 1 "$y must have more than one class" @@ -32,6 +30,25 @@ function random_undersampler( @assert all(sampling_strategy[c] <= classcount[c] for c in keys(sampling_strategy)) "sampling_strategy must have values less than or equal to current number of samples for a particular class" end - X_new = DataFrames.DataFrame(X) - y_new = copy(y) -end \ No newline at end of file + sampling_strategy = undersampling_strategy!(sampling_strategy, classes, classcount) + + if !isnothing(random_state) + rng = Random.MersenneTwister(UInt(random_state)) + else + rng = Random.GLOBAL_RNG + end + + undersampled_idx = [] + for target_class in classes + if target_class in keys(sampling_strategy) + n_samples = sampling_strategy[target_class] + target_class_idx = findall(y .== target_class) + target_class_idx_sampled = StatsBase.sample(rng, target_class_idx, n_samples, replace=replacement) + append!(undersampled_idx, target_class_idx_sampled) + else + append!(undersampled_idx, findall(y .== target_class)) + end + end + + return MLUtils.getobs(X, undersampled_idx), MLUtils.getobs(y, undersampled_idx) +end diff --git a/src/utils.jl b/src/utils.jl index d1ab917..95b5a42 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -90,8 +90,8 @@ end function undersampling_strategy!( sampling_strategy::String, classes::T, - classcount::Dict{Any, Int}, - ) where T <: AbstractVector + classcount::Dict{A, S}, + ) where T <: AbstractVector where S <: Integer where A <: Any mincount = minimum(values(classcount)) maxcount = maximum(values(classcount)) @@ -103,4 +103,5 @@ function undersampling_strategy!( sampling_strategy = Dict(c => mincount for c in classes if classcount[c] != maxcount) elseif sampling_strategy == "all" sampling_strategy = Dict(c => mincount for c in classes) -end \ No newline at end of file + end +end From d92b4bc7818388542337060c5d9661ad870645cb Mon Sep 17 00:00:00 2001 From: Ujjwal Sarswat Date: Tue, 11 Apr 2023 18:38:58 +0530 Subject: [PATCH 5/6] add: scribbled random_undersampler in join X and y are passed --- Project.toml | 1 + src/random_undersampler.jl | 15 +++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/Project.toml b/Project.toml index 1d7e950..4fe8d5d 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Paul Stey ", "Dilum Aluthge " version = "0.8.8" [deps] +CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/random_undersampler.jl b/src/random_undersampler.jl index 5c0e87d..be7c198 100644 --- a/src/random_undersampler.jl +++ b/src/random_undersampler.jl @@ -52,3 +52,18 @@ function random_undersampler( return MLUtils.getobs(X, undersampled_idx), MLUtils.getobs(y, undersampled_idx) end + +function random_undersampler( + df, + label::Union{Symbol, String, S}; + sampling_strategy::Union{AbstractFloat, String, Dict{A, S}} = "auto", + random_state::Union{Nothing, S} = nothing, + replacement::Bool = false + ) where S <: Integer where A <: Any + df = DataFrames.DataFrame(df) + @assert label in names(df) "label or index $label does not exist in $df" + + Xover, yover = random_undersampler(DataFrames.select(df, DataFrames.Not(label)), df[!, label], sampling_strategy=sampling_strategy, random_state=random_state, replacement=replacement) + DataFrames.join(Xover, DataFrames.DataFrame(label = yover)) + return Xover +end From 874a54b8bc8487eddf9a9cd9e4ecca2cfe70106b Mon Sep 17 00:00:00 2001 From: Ujjwal Sarswat Date: Thu, 13 Apr 2023 11:48:58 +0530 Subject: [PATCH 6/6] chore(Project.toml): change version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4fe8d5d..4ab45cc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ClassImbalance" uuid = "04a18a73-7590-580c-b363-eeca0919eb2a" authors = ["Paul Stey ", "Dilum Aluthge ", "Brown Center for Biomedical Informatics "] -version = "0.8.8" +version = "0.9.0-dev" [deps] CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"