Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add RandomUnderSampling functionality #89

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*.jl.cov
*.jl.*.cov
*.jl.mem
Manifest.toml
15 changes: 9 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
name = "ClassImbalance"
uuid = "04a18a73-7590-580c-b363-eeca0919eb2a"
authors = ["Paul Stey <[email protected]>", "Dilum Aluthge <[email protected]>", "Brown Center for Biomedical Informatics <[email protected]>"]
version = "0.8.7"
version = "0.9.0-dev"

[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Compat = "3"
DataFrames = "0.20"
Distributions = "0.21.3, 0.22"
StatsBase = "0.32"
julia = "1.1"
Compat = "4.6"
DataFrames = "1.5"
Distributions = "0.25"
StatsBase = "0.33"
julia = "1.8"

[extras]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
3 changes: 2 additions & 1 deletion src/ClassImbalance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ __precompile__(true)

module ClassImbalance

export smote, rose
export smote, rose, random_undersampler

include("utils.jl")
include("smote_exs.jl")
include("ub_smote.jl")
include("rose.jl")
include("random_undersampler.jl")

end # end module ClassImbalance
69 changes: 69 additions & 0 deletions src/random_undersampler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import Random
import StatsBase
import DataFrames
import Tables
import MLUtils

function random_undersampler(
X,
y::T;
sampling_strategy::Union{AbstractFloat, String, Dict{A, S}} = "auto",
random_state::Union{Nothing, S} = nothing,
replacement::Bool = false
) where T <: AbstractVector where S <: Integer where A <: Any
# check if X implements getobs
@assert Tables.istable(X) "$X is not implementing the MLUtils.jl getobs interface"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have any example inputs? Here is what I tried to get around this @assert:

pkg> activate --temp
pkg> add https://github.com/vmpyr/ClassImbalance.jl#vmpyr/random_undersampler
# Example from README: 
julia> using ClassImbalance
julia> y = vcat(ones(20), zeros(180)); # 0 = majority, 1 = minority
julia> X = hcat(rand(200, 10), y)
julia> X2, y2 = smote(X, y, k = 5, pct_under = 100, pct_over = 200)
([0.8133030487718941 0.6399427612098323 … 0.18570608874157368 1.0; 0.7749414282075959 0.46482119778838893 … 0.17369768251844653 1.0; … ; 0.8348062387369848 0.04923338419004153 … 0.8706114262388865 1.0; 0.8582028352406533 0.39648688626915957 … 0.46337293622094944 1.0], [1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0  …  1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
julia> using Tables
julia> random_undersampler(Tables.table(X),y)
(Tables.DictColumnTable with 40 rows, 11 columns, and schema:
 :Column1   Float64
 :Column2   Float64
 :Column3   Float64
 :Column4   Float64
 :Column5   Float64
 :Column6   Float64
 :Column7   Float64
 :Column8   Float64
 :Column9   Float64
 :Column10  Float64
 :Column11  Float64, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

julia> random_undersampler(Tables.table(X2),y2)
(Tables.DictColumnTable with 80 rows, 11 columns, and schema:
 :Column1   Float64
 :Column2   Float64
 :Column3   Float64
 :Column4   Float64
 :Column5   Float64
 :Column6   Float64
 :Column7   Float64
 :Column8   Float64
 :Column9   Float64
 :Column10  Float64
 :Column11  Float64, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I do have one I tried. I had .csv files for it too btw, you can find them here.

# import the required packages
import CSV
using ClassImbalance, DataFrames, DelimitedFiles

# load the data as defined in the task guidelines
# load X as a DataFrame
X = CSV.read("./data/pima-indians-diabetes-X.csv", DataFrame; header=["x1","x2","x3","x4","x5","x6","x7","x8"])
# load y as a CategoricalVector
y = CategoricalVector(vec(readdlm("./data/pima-indians-diabetes-y.csv", Int)))

# call the undersampler
Xover, yover = random_undersampler(X, y)

# print the results
show(Xover)
println()
show(IOContext(stdout, :limit => true), yover)
println()

# print the counts of the classes
println("initial count of class 0: ", count(y .== 0))
println("initial count of class 1: ", count(y .== 1))
println("count of class 0 after undersampling: ", count(yover .== 0))
println("count of class 1 after undersampling: ", count(yover .== 1))

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I've written the function such that X contains only the features and not the labels. I'm writing a overloaded function for a data matrix containing both features and labels.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In your example, the function does work correctly anyway, you can access the columns and check the data or the count of it.

julia> Xover, yover = random_undersampler(Tables.table(X),y)
(Tables.DictColumnTable with 40 rows, 10 columns, and schema:
 :Column1   Float64
 :Column2   Float64
 :Column3   Float64
 :Column4   Float64
 :Column5   Float64
 :Column6   Float64
 :Column7   Float64
 :Column8   Float64
 :Column9   Float64
 :Column10  Float64, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0    0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

julia> count(y .== 1)
20

julia> count(y .== 0)
180

julia> count(yover .== 0)
20

julia> count(yover .== 1)
20

julia> Xover
Tables.DictColumnTable with 40 rows, 10 columns, and schema:
 :Column1   Float64
 :Column2   Float64
 :Column3   Float64
 :Column4   Float64
 :Column5   Float64
 :Column6   Float64
 :Column7   Float64
 :Column8   Float64
 :Column9   Float64
 :Column10  Float64

julia> Xover.Column1
40-element Vector{Float64}:
 0.3952844545871751
 0.8704752637842621
 0.6724007567263786
 0.35193156297873285
 0.7208683859930942
 0.4178980161173723
 0.33217406300791696
 0.07224589352246602
 0.09320295190634142
 0.10316414877241253
 0.6166656578236519
 0.20575631236600955
 0.5488182158068842
 0.3156879765220615
 0.8174789837160467
 
 0.10887263478147902
 0.6100936867437683
 0.9340285563941725
 0.11318763378056362
 0.24546432607219804
 0.2788921342943955
 0.6103736533699433
 0.05058497458051314
 0.9006314393478084
 0.3487415101299778
 0.693372173075407
 0.4980414320194523
 0.9952081568741031
 0.8699331124771522
 0.11711882057394163


classes = unique(y)
classcount = Dict(c => count(y .== c) for c in classes)

# checking classes in y
@assert length(classes) > 1 "$y must have more than one class"
# checking sampling_strategy
if typeof(sampling_strategy) <: String
@assert sampling_strategy in ["auto", "not minority", "not majority", "all", "majority"] "sampling_strategy must be one of \"auto\", \"not minority\", \"not majority\", \"all\", \"majority\""
elseif typeof(sampling_strategy) <: AbstractFloat
@assert length(classes) == 2 "sampling_strategy of type float is supported only for binary classification"
@assert 0 < sampling_strategy <= 1 "sampling_strategy must be between 0 and 1"
elseif typeof(sampling_strategy) <: Dict
@assert all(c in classes for c in keys(sampling_strategy)) "sampling_strategy must have keys that are classes in $y"
@assert all(sampling_strategy[c] <= classcount[c] for c in keys(sampling_strategy)) "sampling_strategy must have values less than or equal to current number of samples for a particular class"
end

sampling_strategy = undersampling_strategy!(sampling_strategy, classes, classcount)

if !isnothing(random_state)
rng = Random.MersenneTwister(UInt(random_state))
else
rng = Random.GLOBAL_RNG
end

undersampled_idx = []
for target_class in classes
if target_class in keys(sampling_strategy)
n_samples = sampling_strategy[target_class]
target_class_idx = findall(y .== target_class)
target_class_idx_sampled = StatsBase.sample(rng, target_class_idx, n_samples, replace=replacement)
append!(undersampled_idx, target_class_idx_sampled)
else
append!(undersampled_idx, findall(y .== target_class))
end
end

return MLUtils.getobs(X, undersampled_idx), MLUtils.getobs(y, undersampled_idx)
end

function random_undersampler(
df,
label::Union{Symbol, String, S};
sampling_strategy::Union{AbstractFloat, String, Dict{A, S}} = "auto",
random_state::Union{Nothing, S} = nothing,
replacement::Bool = false
) where S <: Integer where A <: Any
df = DataFrames.DataFrame(df)
@assert label in names(df) "label or index $label does not exist in $df"

Xover, yover = random_undersampler(DataFrames.select(df, DataFrames.Not(label)), df[!, label], sampling_strategy=sampling_strategy, random_state=random_state, replacement=replacement)
DataFrames.join(Xover, DataFrames.DataFrame(label = yover))
return Xover
end
19 changes: 19 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,22 @@ function calculate_smote_pct_under(
result = 100*minority_to_majority_ratio*(100+pct_over)/pct_over
return result
end

function undersampling_strategy!(
sampling_strategy::String,
classes::T,
classcount::Dict{A, S},
) where T <: AbstractVector where S <: Integer where A <: Any
mincount = minimum(values(classcount))
maxcount = maximum(values(classcount))

if sampling_strategy == "majority"
sampling_strategy = Dict(c => mincount for c in classes if classcount[c] == maxcount)
elseif sampling_strategy == "auto" || sampling_strategy == "not minority"
sampling_strategy = Dict(c => mincount for c in classes if classcount[c] != mincount)
elseif sampling_strategy == "not majority"
sampling_strategy = Dict(c => mincount for c in classes if classcount[c] != maxcount)
elseif sampling_strategy == "all"
sampling_strategy = Dict(c => mincount for c in classes)
end
end