-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add RandomUnderSampling functionality #89
Open
vmpyr
wants to merge
6
commits into
bcbi:master
Choose a base branch
from
vmpyr:vmpyr/random_undersampler
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
59c3703
change package version and setup initial files
vmpyr 6ff5d68
add: sampling_strategy functionality
vmpyr ad90695
change version number for passing checks
vmpyr 716b1ec
add: random_undersampler functionality
vmpyr d92b4bc
add: scribbled random_undersampler in join X and y are passed
vmpyr 874a54b
chore(Project.toml): change version number
vmpyr File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
*.jl.cov | ||
*.jl.*.cov | ||
*.jl.mem | ||
Manifest.toml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,26 @@ | ||
name = "ClassImbalance" | ||
uuid = "04a18a73-7590-580c-b363-eeca0919eb2a" | ||
authors = ["Paul Stey <[email protected]>", "Dilum Aluthge <[email protected]>", "Brown Center for Biomedical Informatics <[email protected]>"] | ||
version = "0.8.7" | ||
version = "0.9.0-dev" | ||
|
||
[deps] | ||
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" | ||
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" | ||
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | ||
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" | ||
|
||
[compat] | ||
Compat = "3" | ||
DataFrames = "0.20" | ||
Distributions = "0.21.3, 0.22" | ||
StatsBase = "0.32" | ||
julia = "1.1" | ||
Compat = "4.6" | ||
DataFrames = "1.5" | ||
Distributions = "0.25" | ||
StatsBase = "0.33" | ||
julia = "1.8" | ||
|
||
[extras] | ||
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import Random | ||
import StatsBase | ||
import DataFrames | ||
import Tables | ||
import MLUtils | ||
|
||
function random_undersampler( | ||
X, | ||
y::T; | ||
sampling_strategy::Union{AbstractFloat, String, Dict{A, S}} = "auto", | ||
random_state::Union{Nothing, S} = nothing, | ||
replacement::Bool = false | ||
) where T <: AbstractVector where S <: Integer where A <: Any | ||
# check if X implements getobs | ||
@assert Tables.istable(X) "$X is not implementing the MLUtils.jl getobs interface" | ||
|
||
classes = unique(y) | ||
classcount = Dict(c => count(y .== c) for c in classes) | ||
|
||
# checking classes in y | ||
@assert length(classes) > 1 "$y must have more than one class" | ||
# checking sampling_strategy | ||
if typeof(sampling_strategy) <: String | ||
@assert sampling_strategy in ["auto", "not minority", "not majority", "all", "majority"] "sampling_strategy must be one of \"auto\", \"not minority\", \"not majority\", \"all\", \"majority\"" | ||
elseif typeof(sampling_strategy) <: AbstractFloat | ||
@assert length(classes) == 2 "sampling_strategy of type float is supported only for binary classification" | ||
@assert 0 < sampling_strategy <= 1 "sampling_strategy must be between 0 and 1" | ||
elseif typeof(sampling_strategy) <: Dict | ||
@assert all(c in classes for c in keys(sampling_strategy)) "sampling_strategy must have keys that are classes in $y" | ||
@assert all(sampling_strategy[c] <= classcount[c] for c in keys(sampling_strategy)) "sampling_strategy must have values less than or equal to current number of samples for a particular class" | ||
end | ||
|
||
sampling_strategy = undersampling_strategy!(sampling_strategy, classes, classcount) | ||
|
||
if !isnothing(random_state) | ||
rng = Random.MersenneTwister(UInt(random_state)) | ||
else | ||
rng = Random.GLOBAL_RNG | ||
end | ||
|
||
undersampled_idx = [] | ||
for target_class in classes | ||
if target_class in keys(sampling_strategy) | ||
n_samples = sampling_strategy[target_class] | ||
target_class_idx = findall(y .== target_class) | ||
target_class_idx_sampled = StatsBase.sample(rng, target_class_idx, n_samples, replace=replacement) | ||
append!(undersampled_idx, target_class_idx_sampled) | ||
else | ||
append!(undersampled_idx, findall(y .== target_class)) | ||
end | ||
end | ||
|
||
return MLUtils.getobs(X, undersampled_idx), MLUtils.getobs(y, undersampled_idx) | ||
end | ||
|
||
function random_undersampler( | ||
df, | ||
label::Union{Symbol, String, S}; | ||
sampling_strategy::Union{AbstractFloat, String, Dict{A, S}} = "auto", | ||
random_state::Union{Nothing, S} = nothing, | ||
replacement::Bool = false | ||
) where S <: Integer where A <: Any | ||
df = DataFrames.DataFrame(df) | ||
@assert label in names(df) "label or index $label does not exist in $df" | ||
|
||
Xover, yover = random_undersampler(DataFrames.select(df, DataFrames.Not(label)), df[!, label], sampling_strategy=sampling_strategy, random_state=random_state, replacement=replacement) | ||
DataFrames.join(Xover, DataFrames.DataFrame(label = yover)) | ||
return Xover | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you have any example inputs? Here is what I tried to get around this
@assert
:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes I do have one I tried. I had .csv files for it too btw, you can find them here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, I've written the function such that
X
contains only the features and not the labels. I'm writing a overloaded function for a data matrix containing both features and labels.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In your example, the function does work correctly anyway, you can access the columns and check the data or the count of it.