Programing/R- programming

R - Cross-validation 평균 ROC 그리기

sosal 2022. 10. 2. 23:56
반응형

Split validation이나, Leave-One-Out Cross-validation (LOOCV) 를 하게 되면 적절하게 ROC curve를 그릴 수 있다.

 

Split validation은 test set이 명확하게 있으니, 해당 샘플에 대해 ROC curve를 그리면 되고,

LOOCV의 경우 데이터 하나당 Prediction 값을 저장해놓고 ROC curve를 그리면 된다.

 

그러면 Cross-validation은??

AUC 구하는거야, MRMC 기법이든 뭐든, 어쨌건 결과 값이 각 fold 별 AUC의 평균과 동일하기에 계산하면 되는데

ROC가 항상 문제이다.

 

Multiple ROC curve에 대한 평균을 구하는 방식으로,

Cross-validation의 mean ROC를 시각화 할 수 있는데, 2가지를 써봤다.

 

두가지 mean ROC를 만들어 내는 방법은 다르지만,

간단하게 정리하면 ROCR은 각 fold별 ROC의 tpr, fpr의 구간 평균값을 활용하여 그리는 방식,

cutpointr 방식은 특정 threshold 이내에 포함되는 데이터의 youden's index를 매번 찾아가며 그리는 방식인듯 하다.

(따라서 cutpointr 방식은, fold의 mean AUC와 mean ROC의 넓이가 다를 수도 있을 것 같다)

 

# Label, Pred 두 변수는 각 fold별 데이터를 갖는다고 가정한다.

ex)

#Fold: Cross validation vector
for(fold in 1:5){
    Label[fold] = Data[ Fold == fold ]$Label
    Pred = predict(model)
}

 

 

1. ROCR 패키지를 활용한 mean ROC

library(ROCR)
pred <- prediction(Pred1, Label)
perf <- performance(pred,"tpr","fpr")
plot(perf, col="grey82", lty=3)
plot(perf, lwd=3,avg="vertical",spread.estimate="boxplot",add=TRUE)

 

 

 

2. cutpointr을 활용한 mean ROC

library(cutpointr)
library(tidyverse)
mean_roc <- function(data, cutoffs = seq(from = -5, to = 5, by = 0.5)) {
    map_df(cutoffs, function(cp) {
        out <- cutpointr(data = data, x = Pred, class = Label, subgroup = Fold,
                         method = oc_manual, cutpoint = cp,
                         pos_class = TRUE, direction = ">=")
        data.frame(cutoff = cp, 
                   sensitivity = mean(out$sensitivity),
                   specificity = mean(out$specificity))
    })
}

 

MeanROCData <- data.frame(
    Fold = Fold,
    Pred  = unlist(Pred),
    Label = unlist(Label)
)

mr = mean_roc(MeanROCData, cutoffs = seq(from=0, to=1, by=0.01))
ggplot(mr[order(mr[,2], mr[,1]),], aes(x = 1 - specificity, y = sensitivity)) + geom_line()

cutpointr(data = MeanROCData, 
          x = Pred, class = Label, subgroup = Fold,
          pos_class = TRUE, direction = ">=") %>% 
    plot_roc(display_cutpoint = F) + theme(legend.position="none") +
    geom_line(data = mr[order(mr[,2], mr[,1]),], mapping = aes(x = 1 - specificity, y = sensitivity), 
              color = "black") + theme_bw()

 

 

 

그림은 매우 훌륭!