-
Notifications
You must be signed in to change notification settings - Fork 2
/
driver.R
executable file
·72 lines (57 loc) · 1.89 KB
/
driver.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/Rscript
source("get_sepsis_score.R")
# Read one pipe-delimited ('|') challenge data file into a numeric matrix.
#
# @param file Path to the file. Its first line is treated as a header
#   (read.csv default), giving the matrix's column names.
# @return A matrix produced by data.matrix(): one row per data line and one
#   column per field. If the last column is named 'SepsisLabel', it is
#   removed. The result is always a matrix, even for a single data row.
load_challenge_data <- function(file) {
  data <- data.matrix(read.csv(file, sep = '|'))
  n_cols <- ncol(data)
  # Ignore SepsisLabel column if present.
  if (n_cols > 0 && colnames(data)[n_cols] == 'SepsisLabel') {
    # drop = FALSE keeps a one-row file as a 1 x (n_cols - 1) matrix instead
    # of collapsing it to a vector. A collapsed vector would make nrow() and
    # ncol() return NULL for callers.
    data <- data[, -n_cols, drop = FALSE]
  }
  data
}
# Write a two-column prediction matrix to `file` as a pipe-delimited table
# with a header row and no row names or quoting.
#
# @param file Output path.
# @param predictions Two-column matrix; columns are labelled
#   'PredictedProbability' and 'PredictedLabel' (presumably a probability and
#   a binary label, as produced by get_sepsis_score; confirm against that file).
# @return The (invisible NULL) value of write.table.
save_challenge_predictions <- function(file, predictions) {
  header <- c('PredictedProbability', 'PredictedLabel')
  labelled <- predictions
  colnames(labelled) <- header
  write.table(labelled, file = file, sep = '|', quote = FALSE,
              row.names = FALSE)
}
# Parse arguments: exactly two trailing command-line arguments are required,
# the input directory followed by the output directory.
args <- commandArgs(trailingOnly = TRUE)
if (length(args) != 2L) {
  stop('Include the input and output directories as arguments, e.g., Rscript driver.R input output.')
}
input_directory <- args[[1L]]
output_directory <- args[[2L]]
# Find files: keep regular files (not directories) in input_directory whose
# names do not start with '.' and end in 'psv'. This is the same suffix check
# as before; no leading dot is required. Done vectorised rather than by
# growing a vector with c() inside a loop.
candidates <- list.files(input_directory)
candidate_paths <- file.path(input_directory, candidates)
is_input <- file_test('-f', candidate_paths) &
  !startsWith(candidates, '.') &
  endsWith(candidates, 'psv')
files <- candidates[is_input]

# Create the output directory, including any missing parent directories.
if (!dir.exists(output_directory)) {
  dir.create(output_directory, recursive = TRUE)
}
# Load model.
print('Loading sepsis model...')
model <- load_sepsis_model()

# Iterate over files. seq_len() makes an empty file list (and, below, a file
# with no data rows) a no-op, whereas 1:0 would iterate over c(1, 0) and fail.
print('Predicting sepsis labels...')
num_files <- length(files)
for (i in seq_len(num_files)) {
  print(paste0(' ', i, '/', num_files, '...'))

  # Load data.
  input_file <- file.path(input_directory, files[i])
  data <- load_challenge_data(input_file)

  # Make predictions: the score for row t is computed from rows 1..t only.
  num_rows <- nrow(data)
  num_cols <- ncol(data)
  predictions <- matrix(NA_real_, nrow = num_rows, ncol = 2)
  for (t in seq_len(num_rows)) {
    # matrix() yields a t x num_cols matrix even when t == 1 (where
    # data[1, ] is a plain vector). It also drops column names, matching the
    # input get_sepsis_score has received so far.
    current_data <- matrix(data[seq_len(t), ], nrow = t, ncol = num_cols)
    predictions[t, ] <- get_sepsis_score(current_data, model)
  }

  # Save results under the same file name in the output directory.
  output_file <- file.path(output_directory, files[i])
  save_challenge_predictions(output_file, predictions)
}
print('Done.')