ROC曲线

最后发布时间:2023-04-25 17:02:38 浏览量:

ROC(receiver operating characteristic)

  • 应用于某指标对某疾病的诊断价值或者分类器预测结果评价
  • 横坐标是敏感性、假阳性率(sensitivity, False Positive Rate, FPR),纵坐标是特异性、真阳性率(specificity, True Positive Rate, TPR)
  • 假阳性率=预测结果中负样本个数 / 负样本的个数
  • 真阳性率=预测结果中正样本个数 / 正样本个数

假阳性:实际上是,但是检测出来不是
真阳性:实际上是,检测出来是

二分类模型的ROC

python程序求解案例

from sklearn.metrics import roc_curve, auc
import numpy as np
y = np.array([1, 1, 2, 2])
scores = np.array([0.1, 0.4, 0.35, 0.8])
fpr, tpr, thresholds = roc_curve(y, scores, pos_label=2)
auc = auc(fpr, tpr)
>>> fpr 
array([ 0. ,  0.5,  0.5,  1. ])
>>> tpr
array([ 0.5,  0.5,  1. ,  1. ])
>>> thresholds
array([ 0.8 ,  0.4 ,  0.35,  0.1 ])
  • 该示例有4个样本,参数说明如下:
    • y:样本的真值
    • pos_label=2:表明取值为2的样本是正样本。
    • scores:预测出的某样本是正样本的概率。
    • fpr、tpr:每个(fpr[i], tpr[i])都表示ROC曲线上的一个点,一共求解出了4个点。
    • thresholds:求解(fpr[i], tpr[i])时使用的阈值。

求解步骤

可以看出,阈值thresholds就是对概率scores进行了排序(倒序)。不断改变阈值,得到ROC曲线上不同的点。步骤如下:

  • threshold取0.8:也就是说预测概率大于等于0.8时,我们将样本预测为正样本。那么4个样本的预测结果为[1, 1, 1, 2]。负样本全部预测正确,正样本全部找到了。从而得到ROC曲线上一个点(0, 0.5)
  • threshold取0.4:预测概率大于等于0.4时,认为是正样本。预测结果为[1, 2, 1, 2]。结果比上次糟糕,负样本一个预测错误,正样本一个没有找到,从而得到ROC上面的(0.5, 0.5)点。
  • threshold取0.35:预测概率大于等于0.35时,认为是正样本。得到预测结果[1, 2, 2, 2]。负样本一个预测错误,正样本全部找出来了,从而得到(0.5, 1)
  • threshold取0.1:预测大于等于0.1时,就认为是正样本。尽管召回率很高,但预测结果再次变差,把所有样本都预测为了正样本,从而得到(1, 1)点。

可视化

import matplotlib.pyplot as plt
plt.figure()
lw = 2
plt.plot(fpr, tpr, color='darkorange',
         lw=lw, label='ROC curve (area = %0.2f)' % auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()

图片alt

图片alt

完整的ROC代码

from sklearn.metrics import roc_curve, auc


cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)
# cv = KFold(n_splits=5, shuffle=True, random_state=1)
tprs = []
aucs = []
mean_fpr = np.linspace(0, 1, 100)
fig, ax = plt.subplots(figsize=(6, 6))
for fold, (train, test) in enumerate(cv.split(select_X, Y)):
    clf.fit(select_X[train], Y[train])
    scores = clf.predict_proba(select_X[test])
    fpr, tpr, thresholds = roc_curve(Y[test], scores[:,0], pos_label=0)
    roc_auc = auc(fpr, tpr)
    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    interp_tpr[0] = 0.0
    tprs.append(interp_tpr)
    aucs.append(roc_auc)
    
ax.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)")

mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
mean_auc = auc(mean_fpr, mean_tpr)
std_auc = np.std(aucs)




ax.plot(
    mean_fpr,
    mean_tpr,
    color="b",
    label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
    lw=2,
    alpha=0.8,
)
std_tpr = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
ax.fill_between(
    mean_fpr,
    tprs_lower,
    tprs_upper,
    color="grey",
    alpha=0.2,
    label=r"$\pm$ 1 std. dev.",
)

ax.set(
    xlim=[-0.05, 1.05],
    ylim=[-0.05, 1.05],
    xlabel="False Positive Rate",
    ylabel="True Positive Rate",
    title=f"Mean ROC curve with variability",
)
ax.axis("square")
ax.legend(loc="lower right")
plt.show()

带有生存信息ROC

options(repr.plot.width=9, repr.plot.height=9)
(function(lasso){
    png(filename = "COAD_ROC.png", width = 7, height = 7, res = 300, units = "in")

 color <- brewer.pal(8, "Dark2")
  plot(0,0,type="l",xlab="False postive rate",ylab="True positive rate",xlim=c(0,1),ylim=c(0,1), cex.lab =1.6,cex.axis=1.5)
  abline(0,1)
  legend_vector <- vector()
  color_vector <- vector()
  for(i in 1:5){
    roc <- survivalROC(Stime=lasso@clinical$futime_year ,
                       status=lasso@clinical$fustat,
                       marker=lasso@clinical$riskScoreNum,
                       predict.time=i,
                       method="KM")
    if(roc$AUC>0.7){
      lines(roc$FP,roc$TP,col=color[i],lwd=2)
      legend_vector <- c(legend_vector,paste0(i,"-year surval:",round(roc$AUC,3)))
      color_vector <- c(color_vector,color[i])
    }
  }

  legend(border=NA,0.5,0.4,
         legend_vector,
         col=color_vector,
         text.col=color_vector,lty=c(1,1,1),inset=.5,cex=1.3)
 dev.off()
})(lncRNA_lasso)

参考