相关文章推荐
爱运动的小虾米  ·  Does ...·  5 月前    · 
暗恋学妹的铅笔  ·  Black Duck ...·  7 月前    · 
安静的吐司  ·  广东省发展和改革委员会 - 业务通知·  10 月前    · 
冷冷的小笼包  ·  九民纪要全文:《全国法院民商事审判工作会议纪 ...·  10 月前    · 
千杯不醉的上铺  ·  小程序简介 - SuperApp - 阿里云·  11 月前    · 
Code  ›  mlr3_训练和测试开发者社区
https://cloud.tencent.com/developer/article/1786164
高兴的四季豆
3 年前
作者头像
火星娃统计
0 篇文章

mlr3_训练和测试

前往专栏
腾讯云
备案 控制台
开发者社区
学习
实践
活动
专区
工具
TVP
文章/答案/技术大牛
写文章
社区首页 > 专栏 > 火星娃统计 > 正文

mlr3_训练和测试

发布 于 2021-02-05 16:39:11
530 0
举报

mlr3_训练和测试

概述

之前的章节中,我们已经建立了task和learner,接下来利用这两个R6对象,建立模型,并使用新的数据集对模型进行评估

建立task和learner

这里使用简单的tsk和lrn方法建立

task = tsk("sonar")
learner = lrn("classif.rpart")

设置训练和测试数据

这里设置的其实是task里面数据的行数目

train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)

训练learner

$model 是learner中用来存储训练好的模型

# 可以看到目前是没有模型训练好的
learner$model
## NULL

接下来使用任务来训练learner

# 这里使用row_ids选择训练数据
learner$train(task, row_ids = train_set)
# 训练完成后查看模型
print(learner$model)

预测

使用剩余的数据进行预测 predict

# 返回每一个个案的预测结果
prediction = learner$predict(task, row_ids = test_set)
## <PredictionClassif> for 42 observations:
##     row_id truth response
##          2     R        R
##          6     R        R
##         12     R        M
## ---                      
##        191     M        M
##        199     M        M
##        204     M        M
# 为了提取预测后的数据,最好的办法是转换为data.table
head(as.data.table(prediction))
# 同时,我们需要计算混淆矩阵
prediction$confusion
##         truth
## response  M  R
##        M 15  3
##        R  8 16

改变预测的类型

这个部分主要是计算每一种类型的概率,有时候用于roc曲线的绘制

learner$predict_type = "prob"
# 重新训练
learner$train(task, row_ids = train_set)
# 重新预测
prediction = learner$predict(task, row_ids = test_set)
# 查看结果
head(as.data.table(prediction))
##    row_id truth response prob.M  prob.R
## 1:      2     R        R 0.2222 0.77778
## 2:      6     R        R 0.2222 0.77778
## 3:     12     R        M 0.9375 0.06250
## 4:     13     R        R 0.1429 0.85714
## 5:     30     R        R 0.2222 0.77778
## 6:     31     R        M 0.9535 0.04651

可以看到,里面出现了新的两列,用于描述各自的概率大小

绘制预测图

library("mlr3viz")
task = tsk("sonar")
learner = lrn("classif.rpart", predict_type = "prob")
learner$train(task)
prediction = learner$predict(task)
# 绘制默认图
autoplot(prediction)
# 绘制roc图
autoplot(prediction, type = "roc")

对于回归任务

library("mlr3viz")
library("mlr3learners")
task = tsk("mtcars")
learner = lrn("regr.lm")
learner$train(task)
prediction = learner$predict(task)
autoplot(prediction)

模型评估

mlr3 自带一系列的评估方法,如

mlr_measures
## <DictionaryMeasure> with 54 stored values
## Keys: classif.acc, classif.auc, classif.bacc, classif.bbrier,
##   classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
##   classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
##   classif.logloss, classif.mbrier, classif.mcc, classif.npv,
##   classif.ppv, classif.prauc, classif.precision, classif.recall,
##   classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
##   classif.tp, classif.tpr, debug, oob_error, regr.bias, regr.ktau,
##   regr.mae, regr.mape, regr.maxae, regr.medae, regr.medse, regr.mse,
##   regr.msle, regr.pbias, regr.rae, regr.rmse, regr.rmsle, regr.rrse,
##   regr.rse, regr.rsq, regr.sae, regr.smape, regr.srho, regr.sse,
##   selected_features, time_both, time_predict, time_train
# 使用msr获取评估的方法,这里是准确率
 
推荐文章
爱运动的小虾米  ·  Does sqlalchemy.sql.functions really need now() and current_timestamp()? · sqlalchemy/sqlalchemy · D
5 月前
暗恋学妹的铅笔  ·  Black Duck Documentation Portal
7 月前
安静的吐司  ·  广东省发展和改革委员会 - 业务通知
10 月前
冷冷的小笼包  ·  九民纪要全文:《全国法院民商事审判工作会议纪要》-甘肃省人民检察院
10 月前
千杯不醉的上铺  ·  小程序简介 - SuperApp - 阿里云
11 月前
今天看啥   ·   Py中国   ·   codingpro   ·   小百科   ·   link之家   ·   卧龙AI搜索
删除内容请联系邮箱 2879853325@qq.com
Code - 代码工具平台
© 2024 ~ 沪ICP备11025650号