-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add RandomUnderSampling functionality #89
base: master
Are you sure you want to change the base?
Changes from 5 commits
59c3703
6ff5d68
ad90695
716b1ec
d92b4bc
874a54b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
*.jl.cov | ||
*.jl.*.cov | ||
*.jl.mem | ||
Manifest.toml |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,26 @@ | ||
name = "ClassImbalance" | ||
uuid = "04a18a73-7590-580c-b363-eeca0919eb2a" | ||
authors = ["Paul Stey <[email protected]>", "Dilum Aluthge <[email protected]>", "Brown Center for Biomedical Informatics <[email protected]>"] | ||
version = "0.8.7" | ||
version = "0.8.8" | ||
|
||
[deps] | ||
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" | ||
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" | ||
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | ||
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" | ||
|
||
[compat] | ||
Compat = "3" | ||
DataFrames = "0.20" | ||
Distributions = "0.21.3, 0.22" | ||
StatsBase = "0.32" | ||
julia = "1.1" | ||
Compat = "4.6" | ||
DataFrames = "1.5" | ||
Distributions = "0.25" | ||
StatsBase = "0.33" | ||
julia = "1.8" | ||
|
||
[extras] | ||
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import Random | ||
import StatsBase | ||
import DataFrames | ||
import Tables | ||
import MLUtils | ||
|
||
function random_undersampler( | ||
X, | ||
y::T; | ||
sampling_strategy::Union{AbstractFloat, String, Dict{A, S}} = "auto", | ||
random_state::Union{Nothing, S} = nothing, | ||
replacement::Bool = false | ||
) where T <: AbstractVector where S <: Integer where A <: Any | ||
# check if X implements getobs | ||
@assert Tables.istable(X) "$X is not implementing the MLUtils.jl getobs interface" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you have any example inputs? Here is what I tried to get around this
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes I do have one I tried. I had .csv files for it too btw, you can find them here. # import the required packages
import CSV
using ClassImbalance, DataFrames, DelimitedFiles
# load the data as defined in the task guidelines
# load X as a DataFrame
X = CSV.read("./data/pima-indians-diabetes-X.csv", DataFrame; header=["x1","x2","x3","x4","x5","x6","x7","x8"])
# load y as a CategoricalVector
y = CategoricalVector(vec(readdlm("./data/pima-indians-diabetes-y.csv", Int)))
# call the undersampler
Xover, yover = random_undersampler(X, y)
# print the results
show(Xover)
println()
show(IOContext(stdout, :limit => true), yover)
println()
# print the counts of the classes
println("initial count of class 0: ", count(y .== 0))
println("initial count of class 1: ", count(y .== 1))
println("count of class 0 after undersampling: ", count(yover .== 0))
println("count of class 1 after undersampling: ", count(yover .== 1)) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, I've written the function such that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In your example, the function does work correctly anyway, you can access the columns and check the data or the count of it. julia> Xover, yover = random_undersampler(Tables.table(X),y)
(Tables.DictColumnTable with 40 rows, 10 columns, and schema:
:Column1 Float64
:Column2 Float64
:Column3 Float64
:Column4 Float64
:Column5 Float64
:Column6 Float64
:Column7 Float64
:Column8 Float64
:Column9 Float64
:Column10 Float64, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
julia> count(y .== 1)
20
julia> count(y .== 0)
180
julia> count(yover .== 0)
20
julia> count(yover .== 1)
20
julia> Xover
Tables.DictColumnTable with 40 rows, 10 columns, and schema:
:Column1 Float64
:Column2 Float64
:Column3 Float64
:Column4 Float64
:Column5 Float64
:Column6 Float64
:Column7 Float64
:Column8 Float64
:Column9 Float64
:Column10 Float64
julia> Xover.Column1
40-element Vector{Float64}:
0.3952844545871751
0.8704752637842621
0.6724007567263786
0.35193156297873285
0.7208683859930942
0.4178980161173723
0.33217406300791696
0.07224589352246602
0.09320295190634142
0.10316414877241253
0.6166656578236519
0.20575631236600955
0.5488182158068842
0.3156879765220615
0.8174789837160467
⋮
0.10887263478147902
0.6100936867437683
0.9340285563941725
0.11318763378056362
0.24546432607219804
0.2788921342943955
0.6103736533699433
0.05058497458051314
0.9006314393478084
0.3487415101299778
0.693372173075407
0.4980414320194523
0.9952081568741031
0.8699331124771522
0.11711882057394163 |
||
|
||
classes = unique(y) | ||
classcount = Dict(c => count(y .== c) for c in classes) | ||
|
||
# checking classes in y | ||
@assert length(classes) > 1 "$y must have more than one class" | ||
# checking sampling_strategy | ||
if typeof(sampling_strategy) <: String | ||
@assert sampling_strategy in ["auto", "not minority", "not majority", "all", "majority"] "sampling_strategy must be one of \"auto\", \"not minority\", \"not majority\", \"all\", \"majority\"" | ||
elseif typeof(sampling_strategy) <: AbstractFloat | ||
@assert length(classes) == 2 "sampling_strategy of type float is supported only for binary classification" | ||
@assert 0 < sampling_strategy <= 1 "sampling_strategy must be between 0 and 1" | ||
elseif typeof(sampling_strategy) <: Dict | ||
@assert all(c in classes for c in keys(sampling_strategy)) "sampling_strategy must have keys that are classes in $y" | ||
@assert all(sampling_strategy[c] <= classcount[c] for c in keys(sampling_strategy)) "sampling_strategy must have values less than or equal to current number of samples for a particular class" | ||
end | ||
|
||
sampling_strategy = undersampling_strategy!(sampling_strategy, classes, classcount) | ||
|
||
if !isnothing(random_state) | ||
rng = Random.MersenneTwister(UInt(random_state)) | ||
else | ||
rng = Random.GLOBAL_RNG | ||
end | ||
|
||
undersampled_idx = [] | ||
for target_class in classes | ||
if target_class in keys(sampling_strategy) | ||
n_samples = sampling_strategy[target_class] | ||
target_class_idx = findall(y .== target_class) | ||
target_class_idx_sampled = StatsBase.sample(rng, target_class_idx, n_samples, replace=replacement) | ||
append!(undersampled_idx, target_class_idx_sampled) | ||
else | ||
append!(undersampled_idx, findall(y .== target_class)) | ||
end | ||
end | ||
|
||
return MLUtils.getobs(X, undersampled_idx), MLUtils.getobs(y, undersampled_idx) | ||
end | ||
|
||
function random_undersampler( | ||
df, | ||
label::Union{Symbol, String, S}; | ||
sampling_strategy::Union{AbstractFloat, String, Dict{A, S}} = "auto", | ||
random_state::Union{Nothing, S} = nothing, | ||
replacement::Bool = false | ||
) where S <: Integer where A <: Any | ||
df = DataFrames.DataFrame(df) | ||
@assert label in names(df) "label or index $label does not exist in $df" | ||
|
||
Xover, yover = random_undersampler(DataFrames.select(df, DataFrames.Not(label)), df[!, label], sampling_strategy=sampling_strategy, random_state=random_state, replacement=replacement) | ||
DataFrames.join(Xover, DataFrames.DataFrame(label = yover)) | ||
return Xover | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should be a minor release (
"0.9.0"
or"0.9.0-dev"
)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
alright, I'll do that