R 语言实现 k 近邻法 (kNN)

最后发布时间:2022-03-23 16:33:19 浏览量:

输入数据 input 的格式(每行一个样本;group 列为分类标签,其余列作为特征):

               group     BTG1       IDS      GPI
GSM5576716 treatment 10.66123 10.789608 7.520969
GSM5576717 treatment 10.52216  9.432275 6.718029
GSM5576718 treatment 10.43048 10.974920 7.966851
GSM5576719 treatment 10.68696  9.568500 6.514932
 # Hold-out check: draw round(m / 3) row indices without replacement (equal
 # weight for every row) as a validation set, fit kknn on the remaining rows,
 # and report the classification error on the held-out rows.
 m <- nrow(input)
 val <- sample(
   x = 1:m,
   size = round(m / 3),
   replace = FALSE,
   # Explicit uniform weights are kept as in the original: passing `prob`
   # selects a different sampling routine, so removing it would change which
   # rows are drawn for a given seed.
   prob = rep(1 / m, m)
 )
 imodel <- kknn(
   formula = group ~ .,
   train = input[-val, ],
   test = input[val, ]
 )
 ModelMetrics::ce(actual = input$group[val], predicted = imodel$fit)

参数调优

 # Tune the kknn classifier: train.kknn tries k up to kmax = 100 with each of
 # the six candidate kernels; the lowest-misclassification pair is stored in
 # train_kk$best.parameters. The kernel vector is written inline so the call
 # shown by print(train_kk) lists the kernel names.
 train_kk <- train.kknn(
   formula = group ~ .,
   data = input,
   kmax = 100,
   kernel = c("rectangular", "triangular", "epanechnikov",
              "gaussian", "rank", "optimal")
 )

 # Selected kernel and k (autoprinted when run at top level).
 train_kk$best.parameters[["kernel"]]
 train_kk$best.parameters[["k"]]
Call:
train.kknn(formula = group ~ ., data = input, kmax = 100, kernel = c("rectangular",     "triangular", "epanechnikov", "gaussian", "rank", "optimal"))

Type of response variable: nominal
Minimal misclassification: 0.2413793
Best kernel: rectangular
Best k: 7

图形展示

 # Plot the tuning curve from train.kknn: misclassification rate against k
 # for every kernel, with dashed guides at the selected k (vertical) and at
 # the minimum error (horizontal). best_kernel and best_k stay as top-level
 # variables because the cross-validation code reuses them.
 best_kernel <- train_kk$best.parameters$kernel
 best_k <- train_kk$best.parameters$k
 ce_kk <- train_kk$MISCLASS  # one row per k, one column per kernel
 min_ce <- min(ce_kk)
 as.data.frame(ce_kk) |>
   mutate(k = seq_len(nrow(ce_kk))) |>
   tidyr::pivot_longer(-k, names_to = "kernel", values_to = "ce") |>
   ggplot(aes(x = k, y = ce, colour = kernel)) +
   geom_line() +
   geom_point(aes(shape = kernel)) +
   # Constant guides are passed as fixed parameters; mapping them inside aes()
   # would draw one identical line per data row.
   geom_vline(xintercept = best_k, linetype = "dashed") +
   geom_hline(yintercept = min_ce, linetype = "dashed") +
   # NOTE(review): ggplot2 >= 3.5.0 deprecates a numeric legend.position in
   # favour of legend.position = "inside" + legend.position.inside; kept
   # as-is for compatibility with older ggplot2 versions.
   theme(legend.position = c(0.9, 0.8))

图片alt

图片alt

k 折交叉验证

  global_performance <- NULL

  # Record one accuracy row for `method` / `type` into `global_performance`
  # in the global environment.
  #   method, type  labels stored in the row (e.g. "kknn", "Train").
  #   predicted, actual  predicted and true class labels.
  # Accuracy is the diagonal of the predicted-vs-actual confusion table
  # divided by its total. The column is named `accuray` (sic); the name is kept
  # because code reading global_performance may rely on it.
  # Returns the updated global_performance invisibly (the value of assign()).
  imetrics <- function(method, type, predicted, actual) {
    confusion <- table(predicted, actual)
    n_correct <- sum(diag(confusion))
    new_row <- data.frame(
      method = method,
      type = type,
      accuray = n_correct / sum(confusion)
    )
    so_far <- get("global_performance", envir = .GlobalEnv)
    assign("global_performance", rbind(so_far, new_row), envir = .GlobalEnv)
  }
  # k-fold cross-validation of kknn using the tuned kernel (best_kernel) and
  # neighbour count (best_k). For each fold the model is fit on the other
  # rows; imetrics() then records accuracy on those training rows ("Train")
  # and on the held-out fold ("test"), appending to global_performance.
  # Start/end timestamps are printed, and the elapsed seconds are the block's
  # final value.
  # NOTE(review): cv_kfold() is defined elsewhere; presumably it returns a
  # list of row-index vectors, one per fold - confirm.
  sp <- Sys.time()
  cat(as.character(sp), "\n")
  kfolds <- cv_kfold(input)
  for (i in seq_along(kfolds)) {
    curr_fold <- kfolds[[i]]
    train_set <- input[-curr_fold, ]
    test_set <- input[curr_fold, ]
    predicted_train <- kknn(group ~ .,
                            train = train_set,
                            test = train_set,
                            k = best_k,
                            kernel = best_kernel)$fit
    imetrics("kknn", "Train", predicted_train, train_set$group)
    predicted_test <- kknn(group ~ .,
                           train = train_set,
                           test = test_set,
                           k = best_k,
                           kernel = best_kernel)$fit
    # Score the held-out fold: the "test" row must compare test-set
    # predictions with test-set labels, otherwise it just repeats the
    # training accuracy.
    imetrics("kknn", "test", predicted_test, test_set$group)
  }
  ep <- Sys.time()
  cat(as.character(ep), "\n")
  difftime(ep, sp, units = "secs")

图片alt

图片alt

参考

快捷入口
R 思维导图 浏览PDF 下载PDF
分享到:
标签