-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtestingRandomForest.R
More file actions
79 lines (64 loc) · 1.94 KB
/
testingRandomForest.R
File metadata and controls
79 lines (64 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# ---- Setup and four-class (class4) random forest ----

# NOTE(review): setwd() ties the script to a machine-specific directory; it is
# kept so that the relative CSV path below keeps resolving as before.
setwd("~/IMLproject2021")
library(randomForest)
library(caret)

# Seed the RNG so sample.int() and the CV resampling below are repeatable.
set.seed(42)

npf <- read.csv("npf_Train2mean.csv")

# Binary target derived from class4: every row starts as "event", then rows
# whose class4 is "nonevent" are relabelled.
npf$class2 <- factor("event", levels = c("nonevent", "event"))
npf$class2[npf$class4 == "nonevent"] <- "nonevent"

# Use the date as row name, then drop the first column
# (presumably the date column itself -- verify against the CSV layout).
rownames(npf) <- npf[, "date"]
npf <- npf[, -1]

# 92 randomly chosen rows form the training set; all other rows are held out.
# NOTE(review): 92 is hard-coded; confirm this is the intended split size.
idx <- sample.int(nrow(npf), 92)
training_set <- npf[idx, ]
validation_set <- npf[-idx, ]

# 10-fold cross-validation repeated 10 times, keeping class probabilities and
# the hold-out predictions of the final model.
ctrl <- trainControl(
  method = "repeatedcv",
  number = 10,
  repeats = 10,
  classProbs = TRUE,
  savePredictions = "final"
)

# class2 is computed from class4, so leaving it in the data would hand the
# label to the model through `.`; remove it from the predictors.
training_set4 <- training_set[, names(training_set) != "class2", drop = FALSE]

rFClass4 <- train(
  factor(class4) ~ .,
  method = "rf",
  data = training_set4,
  trControl = ctrl,
  preProcess = c("pca", "center", "scale")
)
rFClass4

pred4 <- predict(rFClass4, newdata = validation_set)
probs4 <- predict(rFClass4, newdata = validation_set, type = "prob")

# confusionMatrix() expects predictions as `data` and the truth as `reference`.
# The reference is given the model's levels so both factors line up even if a
# class happens to be absent from the validation rows.
confusionMatrix(
  data = pred4,
  reference = factor(validation_set$class4, levels = levels(pred4))
)
#' Count rows whose predicted class probability reaches the 0.5 cutoff.
#'
#' @param p Data frame (or list) of class probabilities with the columns
#'   `Ia`, `Ib`, `II` and `nonevent`.
#' @return An unnamed list of four numeric counts, in the order
#'   Ia, Ib, II, nonevent. A count is `NA` if its column contains `NA`.
testClass4 <- function(p) {
  stopifnot(all(c("Ia", "Ib", "II", "nonevent") %in% names(p)))

  # Counts are taken over the rows of `p` itself; the loop bound previously
  # read a variable `dataset` that is not a parameter of this function.
  # as.numeric() keeps the counts as doubles, as the incrementing loop did.
  count_hits <- function(hit) as.numeric(sum(hit))

  list(
    count_hits(p$Ia >= 0.5),
    count_hits(p$Ib >= 0.5),
    count_hits(p$II >= 0.5),
    # NOTE(review): strict `>` here versus `>=` for the other classes is
    # preserved from the original; confirm whether it is intentional.
    count_hits(p$nonevent > 0.5)
  )
}
# NOTE(review): accClass4() and accClass2() (used further down) are not
# defined in this file; they must already be in scope when the script runs
# (presumably sourced from elsewhere -- verify).
accurracy4 <- accClass4(probs4, validation_set)

# ---- Binary (class2) random forest ----

# 10-fold cross-validation repeated 10 times, keeping class probabilities and
# the hold-out predictions of the final model.
ctrl <- trainControl(
  method = "repeatedcv",
  number = 10,
  repeats = 10,
  classProbs = TRUE,
  savePredictions = "final"
)

# class2 is computed from class4, so class4 must not be offered as a
# predictor; otherwise the model can read the label directly through `.`.
training_set2 <- training_set[, names(training_set) != "class4", drop = FALSE]

rFClass2 <- train(
  factor(class2) ~ .,
  method = "rf",
  data = training_set2,
  trControl = ctrl,
  preProcess = c("pca", "center", "scale")
)
rFClass2

pred2 <- predict(rFClass2, newdata = validation_set)
probs2 <- predict(rFClass2, newdata = validation_set, type = "prob")

# confusionMatrix() expects predictions as `data` and the truth as `reference`.
# The reference is given the model's levels so both factors line up.
confusionMatrix(
  data = pred2,
  reference = factor(validation_set$class2, levels = levels(pred2))
)

accuracy2 <- accClass2(probs2, validation_set)
accuracy2