失眠的红豆 · FFmpeg开发第四讲:FFmpeg + ...· 1 年前 · |
灰常酷的柳树 · 使用DownloadManager下载大文件 ...· 1 年前 · |
睿智的荒野 · html初始化调用js函数_51CTO博客_ ...· 1 年前 · |
独立的滑板 · 最适合入门的100个深度学习实战项目_深度学 ...· 1 年前 · |
严肃的鸵鸟 · 【天风金工吴先兴团队】FOF专题研究(三): ...· 1 年前 · |
evalCpp()
转换单一计算表达式
cppFunction()
转换简单的C++函数—Fibnacci例子
sourceCpp()
转换C++程序—正负交替迭代例子
sourceCpp()
转换C++源文件中的程序—正负交替迭代例子
sourceCpp()
转换C++源程序文件—卷积例子
wrap()
把C++变量返回到R中
as()
函数把R变量转换为C++类型
as()
和
wrap()
的隐含调用
//[[Rcpp::export]]
sourceCpp()
函数中直接包含C++源程序字符串
cppFunction()
函数中直接包含C++函数源程序字符串
evalCpp()
函数中直接包含C++源程序表达式字符串
depends
指定要链接的库
invisible
要求函数结果不自动显示
clone
函数
is_na
seq_along
seq_len
pmin
和
pmax
ifelse
sapply
和
lapply
sign
diff
kable()
函数制作表格
统计学习介绍的主要参考书: ( James et al. 2013 ) : Gareth James, Daniela Witten, Trevor Hastie, Robert Tibshirani(2013) An Introduction to Statistical Learning: with Applications in R, Springer.
调入需要的扩展包:
## Loading required package: Matrix
## Loaded glmnet 4.1-4
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
## Loaded gbm 2.1.8.1
统计学习(statistical learning), 也有数据挖掘(data mining),机器学习(machine learning)等称呼。 主要目的是用一些计算机算法从大量数据中发现知识。 方兴未艾的数据科学就以统计学习为重要支柱。 方法分为有监督(supervised)学习与无监督(unsupervised)学习。
无监督学习方法如聚类问题、购物篮问题、主成分分析等。
有监督学习 即统计中回归分析和判别分析解决的问题, 现在又有树回归、树判别、随机森林、lasso、支持向量机、 神经网络、贝叶斯网络、排序算法等许多方法。
无监督学习 在给了数据之后, 直接从数据中发现规律, 比如聚类分析是发现数据中的聚集和分组现象, 购物篮分析是从数据中找到更多的共同出现的条目 (比如购买啤酒的用户也有较大可能购买火腿肠)。
有监督学习方法众多。 通常,需要把数据分为训练样本和检验样本, 训练样本的因变量(数值型或分类型)是已知的, 根据训练样本中自变量和因变量的关系训练出一个回归函数, 此函数以自变量为输入, 可以输出因变量的预测值。
训练出的函数有可能是有简单表达式的(例如,logistic回归)、 有参数众多的表达式的(如神经网络), 也有可能是依赖于所有训练样本而无法写出表达式的(例如k近邻分类)。
对回归问题,经常使用均方误差 \(E|Ey - \hat y|^2\) 来衡量精度。 对分类问题,经常使用分类准确率等来衡量精度。 易见 \(E|Ey - \hat y|^2 = \text{Var}(\hat y) + (E\hat y - E y)^2\) ,所以均方误差可以分解为 \text{均方误差} = \text{方差} + \text{偏差}^2,
训练的回归函数如果仅考虑对训练样本解释尽可能好, 就会使得估计结果方差很大,在对检验样本进行计算时因方差大而导致很大的误差, 所以选取的回归函数应该尽可能简单。
如果选取的回归函数过于简单而实际上自变量与因变量关系比较复杂, 就会使得估计的回归函数偏差比较大, 这样在对检验样本进行计算时也会有比较大的误差。
所以,在有监督学习时, 回归函数的复杂程度是一个很关键的量, 太复杂和太简单都可能导致差的结果, 需要找到一个折衷的值。
复杂程度在线性回归中就是自变量个数, 在一元曲线拟合中就是曲线的不光滑程度。 在其它指标类似的情况下,简单的模型更稳定、可解释更好, 所以统计学特别重视模型的简化。
即使是在从训练样本中修炼(估计)回归函数时, 也需要适当地选择模型的复杂度。 仅考虑对训练数据的拟合程度是不够的, 这会造成过度拟合问题。
为了相对客观地度量模型的预报误差, 假设训练样本有 \(n\) 个观测, 可以留出第一个观测不用, 用剩余的 \(n-1\) 个观测建模,然后预测第一个观测的因变量值, 得到一个误差;对每个观测都这样做, 就可以得到 \(n\) 个误差。 这样的方法叫做留一法。
更常用的是五折或十折交叉验证。 假设训练集有 \(n\) 个观测, 将其均分成 \(10\) 分, 保留第一份不用, 将其余九份合并在一起用来建模,然后预报第一份; 对每一份都这样做, 也可以得到 \(n\) 个误差, 这叫做十折(ten-fold)交叉验证方法。
因为要预报的数据没有用来建模, 交叉验证得到的误差估计更准确。
一个有监督的统计学习项目, 大致上按如下步骤进行:
## [1] "AtBat" "Hits" "HmRun" "Runs" "RBI" "Walks" "Years" "CAtBat" "CHits" "CHmRun" "CRuns" "CRBI" "CWalks" "League" "Division" "PutOuts" "Assists" "Errors" "Salary" "NewLeague"
数据集的详细变量信息如下:
## 'data.frame': 322 obs. of 20 variables:
## $ AtBat : int 293 315 479 496 321 594 185 298 323 401 ...
## $ Hits : int 66 81 130 141 87 169 37 73 81 92 ...
## $ HmRun : int 1 7 18 20 10 4 1 0 6 17 ...
## $ Runs : int 30 24 66 65 39 74 23 24 26 49 ...
## $ RBI : int 29 38 72 78 42 51 8 24 32 66 ...
## $ Walks : int 14 39 76 37 30 35 21 7 8 65 ...
## $ Years : int 1 14 3 11 2 11 2 3 2 13 ...
## $ CAtBat : int 293 3449 1624 5628 396 4408 214 509 341 5206 ...
## $ CHits : int 66 835 457 1575 101 1133 42 108 86 1332 ...
## $ CHmRun : int 1 69 63 225 12 19 1 0 6 253 ...
## $ CRuns : int 30 321 224 828 48 501 30 41 32 784 ...
## $ CRBI : int 29 414 266 838 46 336 9 37 34 890 ...
## $ CWalks : int 14 375 263 354 33 194 24 12 8 866 ...
## $ League : Factor w/ 2 levels "A","N": 1 2 1 2 2 1 2 1 2 1 ...
## $ Division : Factor w/ 2 levels "E","W": 1 2 2 1 1 2 1 2 2 1 ...
## $ PutOuts : int 446 632 880 200 805 282 76 121 143 0 ...
## $ Assists : int 33 43 82 11 40 421 127 283 290 0 ...
## $ Errors : int 20 10 14 3 4 25 7 9 19 0 ...
## $ Salary : num NA 475 480 500 91.5 750 70 100 75 1100 ...
## $ NewLeague: Factor w/ 2 levels "A","N": 1 2 1 2 2 1 1 1 2 1 ...
希望以Salary为因变量,查看其缺失值个数:
## [1] 59
为简单起见,去掉有缺失值的观测:
## [1] 263 20
用leaps包的
regsubsets()
函数计算最优子集回归,
办法是对某个试验性的子集自变量个数
\(\hat p\)
值,
都找到
\(\hat p\)
固定情况下残差平方和最小的变量子集,
这样只要在这些不同
\(\hat p\)
的最优子集中挑选就可以了。
挑选可以用AIC、BIC等方法。
可以先进行一个包含所有自变量的全集回归:
regfit.full <- regsubsets(Salary ~ ., data=d, nvmax=19) reg.summary <- summary(regfit.full) reg.summary
## Subset selection object
## Call: regsubsets.formula(Salary ~ ., data = d, nvmax = 19)
## 19 Variables (and intercept)
## Forced in Forced out
## AtBat FALSE FALSE
## Hits FALSE FALSE
## HmRun FALSE FALSE
## Runs FALSE FALSE
## RBI FALSE FALSE
## Walks FALSE FALSE
## Years FALSE FALSE
## CAtBat FALSE FALSE
## CHits FALSE FALSE
## CHmRun FALSE FALSE
## CRuns FALSE FALSE
## CRBI FALSE FALSE
## CWalks FALSE FALSE
## LeagueN FALSE FALSE
## DivisionW FALSE FALSE
## PutOuts FALSE FALSE
## Assists FALSE FALSE
## Errors FALSE FALSE
## NewLeagueN FALSE FALSE
## 1 subsets of each size up to 19
## Selection Algorithm: exhaustive
## AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun CRuns CRBI CWalks LeagueN DivisionW PutOuts Assists Errors NewLeagueN
## 1 ( 1 ) " " " " " " " " " " " " " " " " " " " " " " "*" " " " " " " " " " " " " " "
## 2 ( 1 ) " " "*" " " " " " " " " " " " " " " " " " " "*" " " " " " " " " " " " " " "
## 3 ( 1 ) " " "*" " " " " " " " " " " " " " " " " " " "*" " " " " " " "*" " " " " " "
## 4 ( 1 ) " " "*" " " " " " " " " " " " " " " " " " " "*" " " " " "*" "*" " " " " " "
## 5 ( 1 ) "*" "*" " " " " " " " " " " " " " " " " " " "*" " " " " "*" "*" " " " " " "
## 6 ( 1 ) "*" "*" " " " " " " "*" " " " " " " " " " " "*" " " " " "*" "*" " " " " " "
## 7 ( 1 ) " " "*" " " " " " " "*" " " "*" "*" "*" " " " " " " " " "*" "*" " " " " " "
## 8 ( 1 ) "*" "*" " " " " " " "*" " " " " " " "*" "*" " " "*" " " "*" "*" " " " " " "
## 9 ( 1 ) "*" "*" " " " " " " "*" " " "*" " " " " "*" "*" "*" " " "*" "*" " " " " " "
## 10 ( 1 ) "*" "*" " " " " " " "*" " " "*" " " " " "*" "*" "*" " " "*" "*" "*" " " " "
## 11 ( 1 ) "*" "*" " " " " " " "*" " " "*" " " " " "*" "*" "*" "*" "*" "*" "*" " " " "
## 12 ( 1 ) "*" "*" " " "*" " " "*" " " "*" " " " " "*" "*" "*" "*" "*" "*" "*" " " " "
## 13 ( 1 ) "*" "*" " " "*" " " "*" " " "*" " " " " "*" "*" "*" "*" "*" "*" "*" "*" " "
## 14 ( 1 ) "*" "*" "*" "*" " " "*" " " "*" " " " " "*" "*" "*" "*" "*" "*" "*" "*" " "
## 15 ( 1 ) "*" "*" "*" "*" " " "*" " " "*" "*" " " "*" "*" "*" "*" "*" "*" "*" "*" " "
## 16 ( 1 ) "*" "*" "*" "*" "*" "*" " " "*" "*" " " "*" "*" "*" "*" "*" "*" "*" "*" " "
## 17 ( 1 ) "*" "*" "*" "*" "*" "*" " " "*" "*" " " "*" "*" "*" "*" "*" "*" "*" "*" "*"
## 18 ( 1 ) "*" "*" "*" "*" "*" "*" "*" "*" "*" " " "*" "*" "*" "*" "*" "*" "*" "*" "*"
## 19 ( 1 ) "*" "*" "*" "*" "*" "*" "*" "*" "*" "*" "*" "*" "*" "*" "*" "*" "*" "*" "*"
这里用
nvmax=
指定了允许所有的自变量都参加,
缺省行为是限制最多个数的。
上述结果表格中每一行给出了固定
\(\hat p\)
条件下的最优子集。
试比较这些最优模型的BIC值:
## [1] -90.84637 -128.92622 -135.62693 -141.80892 -144.07143 -147.91690 -145.25594 -147.61525 -145.44316 -143.21651 -138.86077 -133.87283 -128.77759 -123.64420 -118.21832 -112.81768 -107.35339 -101.86391 -96.30412
图38.1: Hitters数据最优子集回归BIC
其中
\(\hat p=6, 8\)
的值相近,都很低,
取
\(\hat p=6\)
。
用
coef()
加
id=6
指定第六种子集:
## (Intercept) AtBat Hits Walks CRBI DivisionW PutOuts
## 91.5117981 -1.8685892 7.6043976 3.6976468 0.6430169 -122.9515338 0.2643076
这种方法实现了选取BIC最小的自变量子集。
在用做了全集回归后, 把全集回归结果输入到函数中可以执行逐步回归。
## Call: ## lm(formula = Salary ~ ., data = d) ## Residuals: ## Min 1Q Median 3Q Max ## -907.62 -178.35 -31.11 139.09 1877.04 ## Coefficients: ## Estimate Std. Error t value Pr(>|t|) ## (Intercept) 163.10359 90.77854 1.797 0.073622 . ## AtBat -1.97987 0.63398 -3.123 0.002008 ** ## Hits 7.50077 2.37753 3.155 0.001808 ** ## HmRun 4.33088 6.20145 0.698 0.485616 ## Runs -2.37621 2.98076 -0.797 0.426122 ## RBI -1.04496 2.60088 -0.402 0.688204 ## Walks 6.23129 1.82850 3.408 0.000766 *** ## Years -3.48905 12.41219 -0.281 0.778874 ## CAtBat -0.17134 0.13524 -1.267 0.206380 ## CHits 0.13399 0.67455 0.199 0.842713 ## CHmRun -0.17286 1.61724 -0.107 0.914967 ## CRuns 1.45430 0.75046 1.938 0.053795 . ## CRBI 0.80771 0.69262 1.166 0.244691 ## CWalks -0.81157 0.32808 -2.474 0.014057 * ## LeagueN 62.59942 79.26140 0.790 0.430424 ## DivisionW -116.84925 40.36695 -2.895 0.004141 ** ## PutOuts 0.28189 0.07744 3.640 0.000333 *** ## Assists 0.37107 0.22120 1.678 0.094723 . ## Errors -3.36076 4.39163 -0.765 0.444857 ## NewLeagueN -24.76233 79.00263 -0.313 0.754218 ## --- ## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 ## Residual standard error: 315.6 on 243 degrees of freedom ## Multiple R-squared: 0.5461, Adjusted R-squared: 0.5106 ## F-statistic: 15.39 on 19 and 243 DF, p-value: < 2.2e-16## Start: AIC=3046.02
## Salary ~ AtBat + Hits + HmRun + Runs + RBI + Walks + Years +
## CAtBat + CHits + CHmRun + CRuns + CRBI + CWalks + League +
## Division + PutOuts + Assists + Errors + NewLeague
## Df Sum of Sq RSS AIC
## - CHmRun 1 1138 24201837 3044.0
## - CHits 1 3930 24204629 3044.1
## - Years 1 7869 24208569 3044.1
## - NewLeague 1 9784 24210484 3044.1
## - RBI 1 16076 24216776 3044.2
## - HmRun 1 48572 24249272 3044.6
## - Errors 1 58324 24259023 3044.7
## - League 1 62121 24262821 3044.7
## - Runs 1 63291 24263990 3044.7
## - CRBI 1 135439 24336138 3045.5
## - CAtBat 1 159864 24360564 3045.8
## <none> 24200700 3046.0
## - Assists 1 280263 24480963 3047.1
## - CRuns 1 374007 24574707 3048.1
## - CWalks 1 609408 24810108 3050.6
## - Division 1 834491 25035190 3052.9
## - AtBat 1 971288 25171987 3054.4
## - Hits 1 991242 25191941 3054.6
## - Walks 1 1156606 25357305 3056.3
## - PutOuts 1 1319628 25520328 3058.0
## Step: AIC=3044.03
## Salary ~ AtBat + Hits + HmRun + Runs + RBI + Walks + Years +
## CAtBat + CHits + CRuns + CRBI + CWalks + League + Division +
## PutOuts + Assists + Errors + NewLeague
## Df Sum of Sq RSS AIC
## - Years 1 7609 24209447 3042.1
## - NewLeague 1 10268 24212106 3042.2
## - CHits 1 14003 24215840 3042.2
## - RBI 1 14955 24216793 3042.2
## - HmRun 1 52777 24254614 3042.6
## - Errors 1 59530 24261367 3042.7
## - League 1 63407 24265244 3042.7
## - Runs 1 64860 24266698 3042.7
## - CAtBat 1 174992 24376830 3043.9
## <none> 24201837 3044.0
## - Assists 1 285766 24487603 3045.1
## - CRuns 1 611358 24813196 3048.6
## - CWalks 1 645627 24847464 3049.0
## - Division 1 834637 25036474 3050.9
## - CRBI 1 864220 25066057 3051.3
## - AtBat 1 970861 25172699 3052.4
## - Hits 1 1025981 25227819 3052.9
## - Walks 1 1167378 25369216 3054.4
## - PutOuts 1 1325273 25527110 3056.1
## Step: AIC=3042.12
## Salary ~ AtBat + Hits + HmRun + Runs + RBI + Walks + CAtBat +
## CHits + CRuns + CRBI + CWalks + League + Division + PutOuts +
## Assists + Errors + NewLeague
## Df Sum of Sq RSS AIC
## - NewLeague 1 9931 24219377 3040.2
## - RBI 1 15989 24225436 3040.3
## - CHits 1 18291 24227738 3040.3
## - HmRun 1 54144 24263591 3040.7
## - Errors 1 57312 24266759 3040.7
## - Runs 1 63172 24272619 3040.8
## - League 1 65732 24275178 3040.8
## <none> 24209447 3042.1
## - CAtBat 1 266205 24475652 3043.0
## - Assists 1 293479 24502926 3043.3
## - CRuns 1 646350 24855797 3047.1
## - CWalks 1 649269 24858716 3047.1
## - Division 1 827511 25036958 3049.0
## - CRBI 1 872121 25081568 3049.4
## - AtBat 1 968713 25178160 3050.4
## - Hits 1 1018379 25227825 3050.9
## - Walks 1 1164536 25373983 3052.5
## - PutOuts 1 1334525 25543972 3054.2
## Step: AIC=3040.22
## Salary ~ AtBat + Hits + HmRun + Runs + RBI + Walks + CAtBat +
## CHits + CRuns + CRBI + CWalks + League + Division + PutOuts +
## Assists + Errors
## Df Sum of Sq RSS AIC
## - RBI 1 15800 24235177 3038.4
## - CHits 1 15859 24235237 3038.4
## - Errors 1 54505 24273883 3038.8
## - HmRun 1 54938 24274316 3038.8
## - Runs 1 62294 24281671 3038.9
## - League 1 107479 24326856 3039.4
## <none> 24219377 3040.2
## - CAtBat 1 261336 24480713 3041.1
## - Assists 1 295536 24514914 3041.4
## - CWalks 1 648860 24868237 3045.2
## - CRuns 1 661449 24880826 3045.3
## - Division 1 824672 25044049 3047.0
## - CRBI 1 880429 25099806 3047.6
## - AtBat 1 999057 25218434 3048.9
## - Hits 1 1034463 25253840 3049.2
## - Walks 1 1157205 25376583 3050.5
## - PutOuts 1 1335173 25554550 3052.3
## Step: AIC=3038.4
## Salary ~ AtBat + Hits + HmRun + Runs + Walks + CAtBat + CHits +
## CRuns + CRBI + CWalks + League + Division + PutOuts + Assists +
## Errors
## Df Sum of Sq RSS AIC
## - CHits 1 13483 24248660 3036.5
## - HmRun 1 44586 24279763 3036.9
## - Runs 1 54057 24289234 3037.0
## - Errors 1 57656 24292833 3037.0
## - League 1 108644 24343821 3037.6
## <none> 24235177 3038.4
## - CAtBat 1 252756 24487934 3039.1
## - Assists 1 294674 24529851 3039.6
## - CWalks 1 639690 24874868 3043.2
## - CRuns 1 693535 24928712 3043.8
## - Division 1 808984 25044161 3045.0
## - CRBI 1 893830 25129008 3045.9
## - Hits 1 1034884 25270061 3047.4
## - AtBat 1 1042798 25277975 3047.5
## - Walks 1 1145013 25380191 3048.5
## - PutOuts 1 1340713 25575890 3050.6
## Step: AIC=3036.54
## Salary ~ AtBat + Hits + HmRun + Runs + Walks + CAtBat + CRuns +
## CRBI + CWalks + League + Division + PutOuts + Assists + Errors
## Df Sum of Sq RSS AIC
## - HmRun 1 40487 24289148 3035.0
## - Errors 1 51930 24300590 3035.1
## - Runs 1 79343 24328003 3035.4
## - League 1 114742 24363402 3035.8
## <none> 24248660 3036.5
## - Assists 1 283442 24532103 3037.6
## - CAtBat 1 613356 24862016 3041.1
## - Division 1 801474 25050134 3043.1
## - CRBI 1 903248 25151908 3044.2
## - CWalks 1 1011953 25260613 3045.3
## - Walks 1 1246164 25494824 3047.7
## - AtBat 1 1339620 25588280 3048.7
## - CRuns 1 1390808 25639469 3049.2
## - PutOuts 1 1406023 25654684 3049.4
## - Hits 1 1607990 25856650 3051.4
## Step: AIC=3034.98
## Salary ~ AtBat + Hits + Runs + Walks + CAtBat + CRuns + CRBI +
## CWalks + League + Division + PutOuts + Assists + Errors
## Df Sum of Sq RSS AIC
## - Errors 1 44085 24333232 3033.5
## - Runs 1 49068 24338215 3033.5
## - League 1 103837 24392985 3034.1
## <none> 24289148 3035.0
## - Assists 1 247002 24536150 3035.6
## - CAtBat 1 652746 24941894 3040.0
## - Division 1 795643 25084791 3041.5
## - CWalks 1 982896 25272044 3043.4
## - Walks 1 1205823 25494971 3045.7
## - AtBat 1 1300972 25590120 3046.7
## - CRuns 1 1351200 25640348 3047.2
## - CRBI 1 1353507 25642655 3047.2
## - PutOuts 1 1429006 25718154 3048.0
## - Hits 1 1574140 25863288 3049.5
## Step: AIC=3033.46
## Salary ~ AtBat + Hits + Runs + Walks + CAtBat + CRuns + CRBI +
## CWalks + League + Division + PutOuts + Assists
## Df Sum of Sq RSS AIC
## - Runs 1 54113 24387345 3032.0
## - League 1 91269 24424501 3032.4
## <none> 24333232 3033.5
## - Assists 1 220010 24553242 3033.8
## - CAtBat 1 650513 24983746 3038.4
## - Division 1 799455 25132687 3040.0
## - CWalks 1 971260 25304493 3041.8
## - Walks 1 1239533 25572765 3044.5
## - CRBI 1 1331672 25664904 3045.5
## - CRuns 1 1361070 25694302 3045.8
## - AtBat 1 1378592 25711824 3045.9
## - PutOuts 1 1391660 25724892 3046.1
## - Hits 1 1649291 25982523 3048.7
## Step: AIC=3032.04
## Salary ~ AtBat + Hits + Walks + CAtBat + CRuns + CRBI + CWalks +
## League + Division + PutOuts + Assists
## Df Sum of Sq RSS AIC
## - League 1 113056 24500402 3031.3
## <none> 24387345 3032.0
## - Assists 1 280689 24668034 3033.1
## - CAtBat 1 596622 24983967 3036.4
## - Division 1 780369 25167714 3038.3
## - CWalks 1 946687 25334032 3040.1
## - Walks 1 1212997 25600342 3042.8
## - CRuns 1 1334397 25721742 3044.1
## - CRBI 1 1361339 25748684 3044.3
## - PutOuts 1 1455210 25842555 3045.3
## - AtBat 1 1522760 25910105 3046.0
## - Hits 1 1718870 26106215 3047.9
## Step: AIC=3031.26
## Salary ~ AtBat + Hits + Walks + CAtBat + CRuns + CRBI + CWalks +
## Division + PutOuts + Assists
## Df Sum of Sq RSS AIC
## <none> 24500402 3031.3
## - Assists 1 313650 24814051 3032.6
## - CAtBat 1 534156 25034558 3034.9
## - Division 1 798473 25298875 3037.7
## - CWalks 1 965875 25466276 3039.4
## - CRuns 1 1265082 25765484 3042.5
## - Walks 1 1290168 25790569 3042.8
## - CRBI 1 1326770 25827172 3043.1
## - PutOuts 1 1551523 26051925 3045.4
## - AtBat 1 1589780 26090181 3045.8
## - Hits 1 1716068 26216469 3047.1
## Call:
## lm(formula = Salary ~ AtBat + Hits + Walks + CAtBat + CRuns +
## CRBI + CWalks + Division + PutOuts + Assists, data = d)
## Coefficients:
## (Intercept) AtBat Hits Walks CAtBat CRuns CRBI CWalks DivisionW PutOuts Assists
## 162.5354 -2.1687 6.9180 5.7732 -0.1301 1.4082 0.7743 -0.8308 -112.3801 0.2974 0.2832
最后保留了10个自变量。
在整个数据集中随机选取一部分作为训练集,其余作为测试集。 下面的程序把原始数据一分为二:
仅用训练集估计模型。 为了在测试集上用模型进行预报并估计预测均方误差, 需要自己写一个预测函数:
predict.regsubsets <- function(object, newdata, id, ...){ form <- as.formula(object$call[[2]]) mat <- model.matrix(form, newdata) coefi <- coef(object, id=id) xvars <- names(coefi) mat[, xvars] %*% coefi
然后,对每个子集大小,用最优子集在测试集上进行预报, 计算均方误差:
regfit.best <- regsubsets( Salary ~ ., data=d[train,], nvmax=19 ) val.errors <- rep(as.numeric(NA), 19) for(i in 1:19){ #pred <- predict.regsubsets(regfit.best, newdata=d[test,], id=i) pred <- predict(regfit.best, newdata=d[test,], id=i) val.errors[i] <- mean( (d[test, 'Salary'] - pred)^2 ) print(val.errors)
## [1] 188190.9 163306.2 152365.0 164857.0 152100.7 147120.0 148833.0 155546.5 167429.2 169949.1 173607.9 173039.5 168450.4 169300.5 169139.3 173575.1 175216.2 175080.2 175057.5
## [1] 6
用测试集得到的最优子集大小为6。 模型子集和回归系数为:
## (Intercept) Walks CAtBat CHits CHmRun DivisionW PutOuts ## 179.4442609 4.1205817 -0.5508342 2.1670021 2.3479409 -126.3067258 0.1840943
38.2.1.4 用10折交叉验证方法选择最优子集
下列程序对数据中每一行分配一个折号:
下面,对10折中每一折都分别当作测试集一次, 得到不同子集大小的均方误差:
cv.errors <- matrix( as.numeric(NA), k, 19, dimnames=list(NULL, paste(1:19)) ) for(j in 1:k){ # 对 best.fit <- regsubsets(Salary ~ ., data=d[folds != j,], nvmax=19) for(i in 1:19){ pred <- predict( best.fit, d[folds==j,], id=i) cv.errors[j, i] <- mean( (d[folds==j, 'Salary'] - pred)^2 ) head(cv.errors)
## 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 ## [1,] 98623.24 115600.61 120884.31 113831.63 120728.51 122922.93 155507.25 137753.36 149198.01 153332.89 155702.91 155842.88 158755.87 156037.17 157739.46 155548.96 156688.01 156860.92 156976.98 ## [2,] 155320.11 100425.87 168838.35 159729.47 145895.71 123555.25 119983.35 96609.16 99057.32 80375.78 91290.74 92292.69 100498.84 101562.45 104621.38 100922.27 102198.69 105318.26 106064.89 ## [3,] 124151.77 68833.50 69392.29 77221.37 83802.82 70125.41 68997.77 64143.70 65813.14 65120.27 68160.94 70263.77 69765.81 68987.54 69471.32 69294.21 69199.91 68866.84 69195.74 ## [4,] 232191.41 279001.29 294568.10 288765.81 276972.83 260121.22 276413.09 259923.88 270151.18 263492.31 259154.53 269017.80 265468.90 269666.65 265518.87 267240.44 267771.74 267670.66 267717.80 ## [5,] 115397.35 96807.44 108421.66 104933.55 99561.69 86103.05 89345.61 87693.15 91631.88 88763.37 89801.07 91070.44 92429.43 92821.15 95849.81 96513.70 95209.20 94952.21 94951.70 ## [6,] 103839.30 75652.50 69962.31 58291.91 65893.45 64215.56 65800.88 61413.45 60200.70 59599.54 59831.90 60081.48 59662.51 60618.91 62540.03 62776.81 62717.77 62354.97 62268.97
cv.errors
是一个\(10\times 19\)矩阵, 每行对应一折作为测试集的情形, 每列是一个子集大小, 元素值是测试均方误差。对每列的10个元素求平均, 可以得到每个子集大小的平均均方误差:
图38.2: Hitters数据CV均方误差## 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 ## 149821.1 130922.0 139127.0 131028.8 131050.2 119538.6 124286.1 113580.0 115556.5 112216.7 113251.2 115755.9 117820.8 119481.2 120121.6 120074.3 120084.8 120085.8 120403.5
这样找到的最优子集大小是10。 用这种方法找到最优子集大小后, 可以对全数据集重新建模但是选择最优子集大小为10:
划分训练集和验证集与交叉验证方法经常联合运用。 取一个固定的较小规模的测试集, 此测试集不用来作子集选择, 对训练集用交叉验证方法选择最优子集, 然后再测试集上验证。## (Intercept) AtBat Hits Walks CAtBat CRuns CRBI CWalks DivisionW PutOuts Assists ## 162.5354420 -2.1686501 6.9180175 5.7732246 -0.1300798 1.4082490 0.7743122 -0.8308264 -112.3800575 0.2973726 0.2831680
38.2.2 岭回归
当自变量个数太多时,模型复杂度高, 可能有过度拟合, 模型不稳定。
一种方法是对较大的模型系数施加二次惩罚, 把最小二乘问题变成带有二次惩罚项的惩罚最小二乘问题: \[\begin{aligned} \min\; \sum_{i=1}^n \left( y_i - \beta_0 - \beta_1 x_{i1} - \dots - \beta_p x_{ip} \right)^2 + \lambda \sum_{j=1}^p \beta_j^2 . \end{aligned}\] 这比通常最小二乘得到的回归系数绝对值变小, 但是求解的稳定性增加了,避免了共线问题。
与线性模型\(\boldsymbol Y = \boldsymbol X \boldsymbol\beta + \boldsymbol\varepsilon\) 的普通最小二乘解 \(\hat{\boldsymbol\beta} = (\boldsymbol X^T \boldsymbol X)^{-1} \boldsymbol X^T \boldsymbol Y\) 岭回归问题的解为 \tilde{\boldsymbol\beta} = (\boldsymbol X^T \boldsymbol X + s \boldsymbol I)^{-1} \boldsymbol X^T \boldsymbol Y 其中\(\boldsymbol I\)为单位阵,\(s>0\)与\(\lambda\)有关。\(\lambda\)称为调节参数,\(\lambda\)越大,相当于模型复杂度越低。 适当选择\(\lambda\)可以在方差与偏差之间找到适当的折衷, 从而减小预测误差。
由于量纲问题,在不同自变量不可比时,数据集应该进行标准化。
用R的glmnet包计算岭回归。 用
glmnet()
函数, 指定参数alpha=0
时执行的是岭回归。 用参数lambda=
指定一个调节参数网格, 岭回归将在这些调节参数上计算。 用coef()
从回归结果中取得不同调节参数对应的回归系数估计, 结果是一个矩阵,每列对应于一个调节参数。仍采用上面去掉了缺失值的Hitters数据集结果d。
如下程序把回归的设计阵与因变量提取出来:
岭回归涉及到调节参数\(\lambda\)的选择, 为了绘图, 先选择\(\lambda\)的一个网格:
用所有数据针对这样的调节参数网格计算岭回归结果, 注意
glmnet()
函数允许调节参数\(\lambda\)输入多个值:## [1] 20 100
glmnet()
函数默认对数据进行标准化。
coef()
的结果是一个矩阵,每列对应一个调节参数值。38.2.2.1 划分训练集与测试集
如下程序把数据分为一半训练、一半测试:
仅用测试集建立岭回归:
用建立的模型对测试集进行预测,并计算调节参数等于4时的均方误差:
## [1] 142199.2
如果用因变量平均值作预测, 这是最差的预测:
## [1] 224669.9
\(\lambda=4\)的结果要好得多。 事实上,取\(\lambda\)接近正无穷时模型就相当于用因变量平均值预测。 取\(\lambda=0\)就相当于普通最小二乘回归(但是
glmnet()
是对输入数据要做标准化的)。38.2.2.2 用10折交叉验证选取调节参数
仍使用训练集, 但训练集再进行交叉验证。
图38.3: Hitters数据岭回归参数选择cv.glmnet()
函数可以执行交叉验证。这样获得了最优调节参数\(\lambda=\) 326.0827865。 用最优调节参数对测试集作预测, 得到预测均方误差:
## [1] 139856.6
结果比\(\lambda=4\)略有改进。
最后,用选取的最优调节系数对全数据集建模, 得到相应的岭回归系数估计:
## (Intercept) AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun CRuns CRBI CWalks LeagueN DivisionW PutOuts Assists Errors NewLeagueN ## 15.44383120 0.07715547 0.85911582 0.60103106 1.06369007 0.87936105 1.62444617 1.35254778 0.01134999 0.05746654 0.40680157 0.11456224 0.12116504 0.05299202 22.09143197 -79.04032656 0.16619903 0.02941950 -1.36092945 9.12487765
38.2.3 Lasso回归
另一种对回归系数的惩罚是\(L_1\)惩罚: \[\begin{align} \min\; \sum_{i=1}^n \left( y_i - \beta_0 - \beta_1 x_{i1} - \dots - \beta_p x_{ip} \right)^2 + \lambda \sum_{j=1}^p |\beta_j| . \tag{38.1} \end{align}\] 奇妙地是,适当选择调节参数\(\lambda\),可以使得部分回归系数变成零, 达到了即减小回归系数的绝对值又挑选重要变量子集的效果。
事实上,(38.1)等价于约束最小值问题 \[\begin{aligned} & \min\; \sum_{i=1}^n \left( y_i - \beta_0 - \beta_1 x_{i1} - \dots - \beta_p x_{ip} \right)^2 \quad \text{s.t.} \\ & \sum_{j=1}^p |\beta_j| \leq s \end{aligned}\] 其中\(s\)与\(\lambda\)一一对应。 这样的约束区域是带有顶点的凸集, 而目标函数是二次函数, 最小值点经常在约束区域顶点达到, 这些顶点是某些坐标等于零的点。 见图38.4。
图38.4: Lasso约束优化问题图示对于每个调节参数\(\lambda\), 都应该解出(38.1)的相应解, 记为\(\hat{\boldsymbol\beta}(\lambda)\)。 幸运的是, 不需要对每个\(\lambda\)去解最小值问题(38.1), 存在巧妙的算法使得问题的计算量与求解一次最小二乘相仿。
通常选取\(\lambda\)的格子点,计算相应的惩罚回归系数。 用交叉验证方法估计预测的均方误差。 选取使得交叉验证均方误差最小的调节参数(一般R函数中已经作为选项)。
用R的glmnet包计算lasso。 用
glmnet()
函数, 指定参数alpha=1
时执行的是lasso。 用参数lambda=
指定一个调节参数网格, lasso将输出这些调节参数对应的结果。 对回归结果使用plot()
函数可以画出调节参数变化时系数估计的变化情况。仍使用gmlnet包的
glmnet()
函数计算Lasso回归, 指定一个调节参数网格(沿用前面的网格):图38.5: Hitters数据lasso轨迹## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm): collapsing to unique 'x' values
对lasso结果使用
plot()
函数可以绘制延调节参数网格变化的各回归系数估计,横坐标不是调节参数而是调节参数对应的系数绝对值和, 可以看出随着系数绝对值和增大,实际是调节参数变小, 更多地自变量进入模型。38.2.3.1 用交叉验证估计调节参数
按照前面划分的训练集与测试集, 仅使用训练集数据做交叉验证估计最优调节参数:
## [1] 9.286955
得到调节参数估计后,对测试集计算预测均方误差:
## [1] 143673.6
这个效果比岭回归效果略差。
为了充分利用数据, 使用前面获得的最优调节参数, 对全数据集建模:
out <- glmnet(x, y, alpha=1, lambda=grid) lasso.coef <- predict(out, type='coefficients', s=bestlam)[1:20,]; lasso.coef
## (Intercept) AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun CRuns CRBI CWalks LeagueN DivisionW PutOuts Assists Errors NewLeagueN ## 1.27479059 -0.05497143 2.18034583 0.00000000 0.00000000 0.00000000 2.29192406 -0.33806109 0.00000000 0.00000000 0.02825013 0.21628385 0.41712537 0.00000000 20.28615023 -116.16755870 0.23752385 0.00000000 -0.85629148 0.00000000
## (Intercept) AtBat Hits Walks Years CHmRun CRuns CRBI LeagueN DivisionW PutOuts Errors ## 1.27479059 -0.05497143 2.18034583 2.29192406 -0.33806109 0.02825013 0.21628385 0.41712537 20.28615023 -116.16755870 0.23752385 -0.85629148
选择的自变量子集有11个自变量。
38.2.4 树回归的简单演示
决策树方法按不同自变量的不同值, 分层地把训练集分组。 每层使用一个变量, 所以这样的分组构成一个二叉树表示。 为了预测一个观测的类归属, 找到它所属的组, 用组的类归属或大多数观测的类归属进行预测。 这样的方法称为决策树(decision tree)。 决策树方法既可以用于判别问题, 也可以用于回归问题,称为回归树。
决策树的好处是容易解释, 在自变量为分类变量时没有额外困难。 但预测准确率可能比其它有监督学习方法差。
改进方法包括装袋法(bagging)、随机森林(random forests)、 提升法(boosting)。 这些改进方法都是把许多棵树合并在一起, 通常能改善准确率但是可解释性变差。
对Hitters数据,用Years和Hits作因变量预测log(Salaray)。
仅取Hitters数据集的Salary, Years, Hits三个变量, 并仅保留完全观测:
## 'data.frame': 263 obs. of 3 variables: ## $ Salary: num 475 480 500 91.5 750 ... ## $ Years : int 14 3 11 2 11 2 3 2 13 10 ... ## $ Hits : int 81 130 141 87 169 37 73 81 92 159 ... ## - attr(*, "na.action")= 'omit' Named int [1:59] 1 16 19 23 31 33 37 39 40 42 ... ## ..- attr(*, "names")= chr [1:59] "-Andy Allanson" "-Billy Beane" "-Bruce Bochte" "-Bob Boone" ... ## NULL
建立完整的树:
剪枝为只有3个叶结点:
## node), split, n, deviance, yval ## * denotes terminal node ## 1) root 263 207.20 5.927 ## 2) Years < 4.5 90 42.35 5.107 * ## 3) Years > 4.5 173 72.71 6.354 ## 6) Hits < 117.5 90 28.09 5.998 * ## 7) Hits > 117.5 83 20.88 6.740 *
显示概括:
## Regression tree: ## snip.tree(tree = tr1, nodes = c(6L, 2L)) ## Number of terminal nodes: 3 ## Residual mean deviance: 0.3513 = 91.33 / 260 ## Distribution of residuals: ## Min. 1st Qu. Median Mean 3rd Qu. Max. ## -2.24000 -0.39580 -0.03162 0.00000 0.33380 2.55600
把数据随机地分成一半训练集,一半测试集:
对训练集,建立未剪枝的树:
图38.6: Hitters数据训练集未剪枝树对训练集上的未剪枝树用交叉验证方法寻找最优大小:
## $size
## [1] 8 7 6 5 4 3 2 1
## $dev
## [1] 44.55223 44.45312 44.57906 44.53469 46.93001 54.03823 57.53660 105.17743
## $k
## [1] -Inf 1.679266 1.750440 1.836204 3.300858 6.230249 7.420672 56.727362
## $method
## [1] "deviance"
## attr(,"class")
## [1] "prune" "tree.sequence"
plot(cv1$size, cv1$dev, type='b') best.size <- cv1$size[which.min(cv1$dev)[1]] abline(v=best.size, col='gray')
最优大小为7。 获得训练集上构造的树剪枝后的结果:
在测试集上计算预测均方误差:
pred.test <- predict(tr1b, newdata=d[test,]) test.mse <- mean( (d[test, 'Salary'] - exp(pred.test))^2 ) test.mse
## [1] 128224.1
如果用训练集的因变量平均值估计测试集的因变量值, 均方误差为:
## [1] 224692.1
用所有数据来构造未剪枝树:
用训练集上得到的子树大小剪枝:
判别树在不同的训练集、测试集划分上可以产生很大变化, 说明其预测值方差较大。 利用bootstrap的思想, 可以随机选取许多个训练集, 把许多个训练集的模型结果平均, 就可以降低预测值的方差。
办法是从一个训练集中用有放回抽样的方法抽取 \(B\) 个训练集, 设第 \(b\) 个抽取的训练集得到的回归函数为 \(\hat f^{*b}(\cdot)\) , 则最后的回归函数是这些回归函数的平均值: \[\begin{aligned} \hat f_{\text{bagging}}(x) = \frac{1}{B} \sum_{b=1}^b \hat f^{*b}(x) \end{aligned}\] 这称为装袋法(bagging)。 装袋法对改善判别与回归树的精度十分有效。
装袋法的步骤如下:
装袋法也可以用来改进其他的回归和判别方法。
装袋后不能再用图形表示,模型可解释性较差。 但是,可以度量自变量在预测中的重要程度。 在回归问题中, 可以计算每个自变量在所有 \(B\) 个树种平均减少的残差平方和的量, 以此度量其重要度。 在判别问题中, 可以计算每个自变量在所有 \(B\) 个树种平均减少的基尼系数的量, 以此度量其重要度。
除了可以用测试集、交叉验证方法以外, 还可以使用袋外观测预测误差。 用bootstrap再抽样获得多个训练集时每个bootstrap训练集总会遗漏一些观测, 平均每个bootstrap训练集会遗漏三分之一的观测。 对每个观测,大约有 \(B/3\) 棵树没有用到此观测, 可以用这些树的预测值平均来预测此观测,得到一个误差估计, 这样得到的均方误差估计或错判率称为 袋外观测估计 (OOB估计)。 好处是不用很多额外的工作。
对训练集用装袋法:
bag1 <- randomForest(log(Salary) ~ ., data=d, subset=train, mtry=ncol(d)-1, importance=TRUE) ## Call: ## randomForest(formula = log(Salary) ~ ., data = d, mtry = ncol(d) - 1, importance = TRUE, subset = train) ## Type of random forest: regression ## Number of trees: 500 ## No. of variables tried at each split: 19 ## Mean of squared residuals: 0.2549051 ## % Var explained: 67.32
注意
randomForest()
函数实际是随机森林法,
但是当
mtry
的值取为所有自变量个数时就是装袋法。
对测试集进行预报:
pred2 <- predict(bag1, newdata=d[test,]) test.mse2 <- mean( (d[test, 'Salary'] - exp(pred2))^2 ) test.mse2
## [1] 89851.48
结果与剪枝过的单课树相近。
在全集上使用装袋法:
bag2 <- randomForest(log(Salary) ~ ., data=d, mtry=ncol(d)-1, importance=TRUE) ## Call: ## randomForest(formula = log(Salary) ~ ., data = d, mtry = ncol(d) - 1, importance = TRUE) ## Type of random forest: regression ## Number of trees: 500 ## No. of variables tried at each split: 19 ## Mean of squared residuals: 0.1894008 ## % Var explained: 75.95
变量的重要度数值和图形: 各变量的重要度数值及其图形:
## %IncMSE IncNodePurity
## AtBat 10.4186792 8.9315248
## Hits 8.0033436 7.5938472
## HmRun 3.6180992 1.9689157
## Runs 7.2586283 3.9341954
## RBI 5.9223739 5.9201328
## Walks 7.6449979 6.6988173
## Years 9.4732817 2.2977104
## CAtBat 28.2381456 84.4338845
## CHits 13.8405414 26.2455674
## CHmRun 6.7109109 3.8246805
## CRuns 13.6067783 29.1340423
## CRBI 14.1694017 10.9852537
## CWalks 7.7656943 4.1725799
## League -1.0964577 0.2146176
## Division 0.5534057 0.2307037
## PutOuts 0.3157195 4.1169655
## Assists -1.7730978 1.6599765
## Errors 2.3420783 1.6852796
## NewLeague -0.3091532 0.3747445
图38.7: Hitters数据装袋法的变量重要性结果
最重要的自变量是CAtBats, 其次有CRuns, CHits等。
随机森林的思想与装袋法类似, 但是试图使得参加平均的各个树之间变得比较独立。 仍采用有放回抽样得到的多个bootstrap训练集, 但是对每个bootstrap训练集构造判别树时, 每次分叉时不考虑所有自变量, 而是仅考虑随机选取的一个自变量子集。
对判别树,每次分叉时选取的自变量个数通常取 \(m \approx \sqrt{p}\) 个。 比如,对Heart数据的13个自变量, 每次分叉时仅随机选取4个纳入考察范围。
随机森林的想法是基于正相关的样本在平均时并不能很好地降低方差, 独立样本能比较好地降低方差。 如果存在一个最重要的变量, 如果不加限制这个最重要的变量总会是第一个分叉, 使得 \(B\) 棵树相似程度很高。 随机森林解决这个问题的办法是限制分叉时可选的变量子集。
随机森林也可以用来改进其他的回归和判别方法。
装袋法和随机森林都可以用R扩展包randomForest的
randomForest()
函数实现。
当此函数的
mtry
参数取为自变量个数时,执行的就是装袋法;
mtry
取缺省值时,执行随机森林算法。
执行随机森林算法时,
randomForest()
函数在回归问题时分叉时考虑的自变量个数取
\(m \approx p/3\)
,
在判别问题时取
\(m \approx \sqrt{p}\)
。
对训练集用随机森林法:
rf1 <- randomForest(log(Salary) ~ ., data=d, subset=train, importance=TRUE) ## Call: ## randomForest(formula = log(Salary) ~ ., data = d, importance = TRUE, subset = train) ## Type of random forest: regression ## Number of trees: 500 ## No. of variables tried at each split: 6 ## Mean of squared residuals: 0.2422914 ## % Var explained: 68.94
当
mtry
的值取为缺省值时执行随机森林算法。
对测试集进行预报:
pred3 <- predict(rf1, newdata=d[test,]) test.mse3 <- mean( (d[test, 'Salary'] - exp(pred3))^2 ) test.mse3
## [1] 95455.53
结果与剪枝过的单课树、装袋法相近。
在全集上使用随机森林:
rf2 <- randomForest(log(Salary) ~ ., data=d, importance=TRUE) ## Call: ## randomForest(formula = log(Salary) ~ ., data = d, importance = TRUE) ## Type of random forest: regression ## Number of trees: 500 ## No. of variables tried at each split: 6 ## Mean of squared residuals: 0.1789257 ## % Var explained: 77.28
各变量的重要度数值及其图形:
## %IncMSE IncNodePurity
## AtBat 10.1010786 7.6235571
## Hits 8.3614365 7.9201425
## HmRun 3.7354302 2.4553471
## Runs 7.9446786 4.6566217
## RBI 6.6246470 6.0251832
## Walks 8.8848565 5.9901840
## Years 11.3153391 8.4097999
## CAtBat 18.0377154 41.0851904
## CHits 17.5110322 37.8686798
## CHmRun 8.4476983 6.9078686
## CRuns 14.8793512 30.8423409
## CRBI 14.8308800 20.5339522
## CWalks 9.7578555 14.9467745
## League -0.6619015 0.3041459
## Division -0.3583808 0.2954341
## PutOuts 2.4366422 3.4901980
## Assists -0.0965240 1.8105096
## Errors 1.5007791 1.7068960
## NewLeague 1.0593259 0.3453514
图38.8: Hitters数据随机森林法的变量重要度结果
最重要的自变量是CAtBats, CRuns, CHits, CWalks, CRBI等。
Heart数据是心脏病诊断的数据, 因变量AHD为是否有心脏病, 试图用各个自变量预测(判别)。
读入Heart数据集,并去掉有缺失值的观测:
Heart <- read.csv( "data/Heart.csv", header=TRUE, row.names=1, stringsAsFactors=TRUE) Heart <- na.omit(Heart) str(Heart)
## 'data.frame': 297 obs. of 14 variables:
## $ Age : int 63 67 67 37 41 56 62 57 63 53 ...
## $ Sex : int 1 1 1 1 0 1 0 0 1 1 ...
## $ ChestPain: Factor w/ 4 levels "asymptomatic",..: 4 1 1 2 3 3 1 1 1 1 ...
## $ RestBP : int 145 160 120 130 130 120 140 120 130 140 ...
## $ Chol : int 233 286 229 250 204 236 268 354 254 203 ...
## $ Fbs : int 1 0 0 0 0 0 0 0 0 1 ...
## $ RestECG : int 2 2 2 0 2 0 2 0 2 2 ...
## $ MaxHR : int 150 108 129 187 172 178 160 163 147 155 ...
## $ ExAng : int 0 1 1 0 0 0 0 1 0 1 ...
## $ Oldpeak : num 2.3 1.5 2.6 3.5 1.4 0.8 3.6 0.6 1.4 3.1 ...
## $ Slope : int 3 2 2 3 1 1 3 1 2 3 ...
## $ Ca : int 0 3 2 0 0 0 2 0 1 0 ...
## $ Thal : Factor w/ 3 levels "fixed","normal",..: 1 2 3 2 2 2 2 2 3 3 ...
## $ AHD : Factor w/ 2 levels "No","Yes": 1 2 2 1 1 1 2 1 2 2 ...
## - attr(*, "na.action")= 'omit' Named int [1:6] 88 167 193 267 288 303
## ..- attr(*, "names")= chr [1:6] "88" "167" "193" "267" ...
## Age Min. :29.00 1st Qu.:48.00 Median :56.00 Mean :54.54 3rd Qu.:61.00 Max. :77.00
## Sex Min. :0.0000 1st Qu.:0.0000 Median :1.0000 Mean :0.6768 3rd Qu.:1.0000 Max. :1.0000
## ChestPain asymptomatic:142 nonanginal : 83 nontypical : 49 typical : 23
## RestBP Min. : 94.0 1st Qu.:120.0 Median :130.0 Mean :131.7 3rd Qu.:140.0 Max. :200.0
## Chol Min. :126.0 1st Qu.:211.0 Median :243.0 Mean :247.4 3rd Qu.:276.0 Max. :564.0
## Fbs Min. :0.0000 1st Qu.:0.0000 Median :0.0000 Mean :0.1448 3rd Qu.:0.0000 Max. :1.0000
## RestECG Min. :0.0000 1st Qu.:0.0000 Median :1.0000 Mean :0.9966 3rd Qu.:2.0000 Max. :2.0000
## MaxHR Min. : 71.0 1st Qu.:133.0 Median :153.0 Mean :149.6 3rd Qu.:166.0 Max. :202.0
## ExAng Min. :0.0000 1st Qu.:0.0000 Median :0.0000 Mean :0.3266 3rd Qu.:1.0000 Max. :1.0000
## Oldpeak Min. :0.000 1st Qu.:0.000 Median :0.800 Mean :1.056 3rd Qu.:1.600 Max. :6.200
## Slope Min. :1.000 1st Qu.:1.000 Median :2.000 Mean :1.603 3rd Qu.:2.000 Max. :3.000
## Ca Min. :0.0000 1st Qu.:0.0000 Median :0.0000 Mean :0.6768 3rd Qu.:1.0000 Max. :3.0000
## Thal fixed : 18 normal :164 reversable:115
## AHD No :160 Yes:137
简单地把观测分为一半训练集、一半测试集:
set.seed(1) train <- sample(nrow(Heart), size=round(nrow(Heart)/2)) test <- (-train) test.y <- Heart[test, 'AHD']
在训练集上建立未剪枝的判别树:
用交叉验证方法确定剪枝保留的叶子个数,剪枝时按照错判率执行:
cv1 <- cv.tree(tr1, FUN=prune.misclass)
## $size ## [1] 12 9 6 4 2 1 ## $dev ## [1] 42 44 47 44 57 69 ## $k ## [1] -Inf 0.000000 1.666667 3.000000 7.000000 26.000000 ## $method ## [1] "misclass" ## attr(,"class") ## [1] "prune" "tree.sequence"
最优的大小是12。但是从图上看,4个叶结点已经足够好,所以取为4。
对训练集生成剪枝结果:
图38.9: Heart数据回归树注意剪枝后树的显示中, 内部节点的自变量存在分类变量, 这时按照这个自变量分叉时, 取指定的某几个分类值时对应分支Yes, 取其它的分类值时对应分支No。
38.3.1.3 对测试集计算误判率
## test.y ## pred1 No Yes ## No 56 17 ## Yes 21 55
## [1] 0.2550336
对测试集的错判率约26%。
利用未剪枝的树对测试集进行预测, 一般比剪枝后的结果差:
## test.y ## pred1a No Yes ## No 58 21 ## Yes 19 51
## [1] 0.2684564
38.3.1.4 利用全集数据建立剪枝判别树
tr2 <- tree(AHD ~ ., data=Heart) tr2b <- prune.misclass(tr2, best=best.size) plot(tr2b); text(tr2b, pretty=0)
38.3.2 用装袋法
对训练集用装袋法:
bag1 <- randomForest(AHD ~ ., data=Heart, subset=train, mtry=13, importance=TRUE) ## Call: ## randomForest(formula = AHD ~ ., data = Heart, mtry = 13, importance = TRUE, subset = train) ## Type of random forest: classification ## Number of trees: 500 ## No. of variables tried at each split: 13 ## OOB estimate of error rate: 22.3% ## Confusion matrix: ## No Yes class.error ## No 71 12 0.1445783 ## Yes 21 44 0.3230769
注意
randomForest()
函数实际是随机森林法, 但是当mtry
的值取为所有自变量个数时就是装袋法。 袋外观测得到的错判率比较差。对测试集进行预报:
## test.y ## pred2 No Yes ## No 66 17 ## Yes 11 55
## [1] 0.1879195
测试集的错判率约为19%。
对全集用装袋法:
## Call: ## randomForest(formula = AHD ~ ., data = Heart, mtry = 13, importance = TRUE) ## Type of random forest: classification ## Number of trees: 500 ## No. of variables tried at each split: 13 ## OOB estimate of error rate: 20.88% ## Confusion matrix: ## No Yes class.error ## No 131 29 0.1812500 ## Yes 33 104 0.2408759
各变量的重要度数值及其图形:
## No Yes MeanDecreaseAccuracy MeanDecreaseGini
## Age 6.5766876 5.12005531 8.7542379 12.2956568
## Sex 11.2077275 4.48390165 11.2853739 3.7278623
## ChestPain 13.0268932 17.89348038 20.4292863 23.3424850
## RestBP 2.6203153 0.05626759 2.0521195 9.7650173
## Chol -0.8712348 -4.23294461 -3.0733270 11.5911988
## Fbs -0.6941335 -1.16860850 -1.2288380 0.6775051
## RestECG -1.4881617 0.23292163 -0.8772267 1.8426038
## MaxHR 7.7625054 2.34660468 7.5122314 13.2101707
## ExAng 2.7926364 5.45108497 5.7525854 3.5491718
## Oldpeak 14.8193517 14.67748373 20.2425364 14.5480191
## Slope 2.5189935 5.73789018 5.9744484 4.2777028
## Ca 23.0513399 18.01671793 27.4320740 20.0564750
## Thal 20.1968435 18.74418431 25.0618361 28.2479833
最重要的变量是Thal, ChestPain, Ca。
对训练集用随机森林法:
rf1 <- randomForest(AHD ~ ., data=Heart, subset=train, importance=TRUE) ## Call: ## randomForest(formula = AHD ~ ., data = Heart, importance = TRUE, subset = train) ## Type of random forest: classification ## Number of trees: 500 ## No. of variables tried at each split: 3 ## OOB estimate of error rate: 21.62% ## Confusion matrix: ## No Yes class.error ## No 71 12 0.1445783 ## Yes 20 45 0.3076923
这里
mtry
取缺省值,对应于随机森林法。
对测试集进行预报:
## test.y
## pred3 No Yes
## No 70 16
## Yes 7 56
## [1] 0.1543624
测试集的错判率约为15%。
对全集用随机森林:
rf1b <- randomForest(AHD ~ ., data=Heart, importance=TRUE) ## Call: ## randomForest(formula = AHD ~ ., data = Heart, importance = TRUE) ## Type of random forest: classification ## Number of trees: 500 ## No. of variables tried at each split: 3 ## OOB estimate of error rate: 16.5% ## Confusion matrix: ## No Yes class.error ## No 140 20 0.1250000 ## Yes 29 108 0.2116788
各变量的重要度数值及其图形:
## No Yes MeanDecreaseAccuracy MeanDecreaseGini
## Age 7.2380857 5.4451404 9.1859647 12.908917
## Sex 10.1973138 8.0790483 12.6929315 4.938266
## ChestPain 10.4623927 16.7054395 18.8771946 18.218363
## RestBP 1.2157266 1.8875229 2.1025511 10.624864
## Chol -1.2630538 -0.4285615 -1.3028275 11.470420
## Fbs 0.4417651 -2.6574949 -1.4327524 1.418137
## RestECG -1.1149040 1.4649476 0.2220661 2.840670
## MaxHR 9.3788412 6.0542618 10.7139921 17.383623
## ExAng 3.3923281 9.8037831 9.3523828 6.715947
## Oldpeak 10.1617047 14.3372404 17.5616061 15.403935
## Slope 2.6703016 9.3147738 8.5774100 6.752552
## Ca 21.1750038 20.2285033 26.7114362 18.388524
## Thal 18.0250446 16.6731737 22.5365589 18.351175
图38.10: Heart数据随机森林方法得到的变量重要度
最重要的变量是ChestPain, Thal, Ca。
Carseats是ISLR包的一个数据集,基本情况如下:
{rstatl-car-summ01, cache=TRUE} str(Carseats) summary(Carseats)
把Salses变量按照大于8与否分成两组, 结果存入变量High,以High为因变量作判别分析。
## [1] 400 12
对全体数据建立未剪枝的判别树:
## Classification tree: ## tree(formula = High ~ . - Sales, data = d) ## Variables actually used in tree construction: ## [1] "ShelveLoc" "Price" "Income" "CompPrice" "Population" "Advertising" "Age" "US" ## Number of terminal nodes: 27 ## Residual mean deviance: 0.4575 = 170.7 / 373 ## Misclassification error rate: 0.09 = 36 / 400把输入数据集随机地分一半当作训练集,另一半当作测试集:
set.seed(2) train <- sample(nrow(d), size=round(nrow(d)/2)) test <- (-train) test.high <- d[test, 'High']
用训练数据建立未剪枝的判别树:
## Classification tree: ## tree(formula = High ~ . - Sales, data = d, subset = train) ## Variables actually used in tree construction: ## [1] "Price" "Population" "ShelveLoc" "Age" "Education" "CompPrice" "Advertising" "Income" "US" ## Number of terminal nodes: 21 ## Residual mean deviance: 0.5543 = 99.22 / 179 ## Misclassification error rate: 0.115 = 23 / 200用未剪枝的树对测试集进行预测,并计算误判率:
## test.high
## pred2 No Yes
## No 104 33
## Yes 13 50
## [1] 0.23
set.seed(3) cv1 <- cv.tree(tr2, FUN=prune.misclass)
## $size ## [1] 21 19 14 9 8 5 3 2 1 ## $dev ## [1] 74 76 81 81 75 77 78 85 81 ## $k ## [1] -Inf 0.0 1.0 1.4 2.0 3.0 4.0 9.0 18.0 ## $method ## [1] "misclass" ## attr(,"class") ## [1] "prune" "tree.sequence"
用交叉验证方法自动选择的最佳树大小为21。
## Classification tree: ## tree(formula = High ~ . - Sales, data = d, subset = train) ## Variables actually used in tree construction: ## [1] "Price" "Population" "ShelveLoc" "Age" "Education" "CompPrice" "Advertising" "Income" "US" ## Number of terminal nodes: 21 ## Residual mean deviance: 0.5543 = 99.22 / 179 ## Misclassification error rate: 0.115 = 23 / 200
用剪枝后的树对测试集进行预测,计算误判率:
## test.high
## pred3 No Yes
## No 104 32
## Yes 13 51
## [1] 0.225
对训练集用随机森林法:
rf4 <- randomForest(High ~ . - Sales, data=d, subset=train, importance=TRUE) ## Call: ## randomForest(formula = High ~ . - Sales, data = d, importance = TRUE, subset = train) ## Type of random forest: classification ## Number of trees: 500 ## No. of variables tried at each split: 3 ## OOB estimate of error rate: 25.5% ## Confusion matrix: ## No Yes class.error ## No 102 17 0.1428571 ## Yes 34 47 0.4197531
这里
mtry
取缺省值,对应于随机森林法。
对测试集进行预报:
## test.high
## pred4 No Yes
## No 109 24
## Yes 8 59
## [1] 0.16
注意错判率结果依赖于训练集和测试集的划分, 另行选择训练集与测试集可能会得到很不一样的错判率结果。
对全集用随机森林:
rf5 <- randomForest(High ~ . - Sales, data=d, importance=TRUE) ## Call: ## randomForest(formula = High ~ . - Sales, data = d, importance = TRUE) ## Type of random forest: classification ## Number of trees: 500 ## No. of variables tried at each split: 3 ## OOB estimate of error rate: 18.25% ## Confusion matrix: ## No Yes class.error ## No 213 23 0.09745763 ## Yes 50 114 0.30487805
各变量的重要度数值及其图形:
## No Yes MeanDecreaseAccuracy MeanDecreaseGini
## CompPrice 11.0998129 5.4168875 11.65469579 21.820876
## Income 3.2897388 4.7177705 5.75865440 20.384692
## Advertising 10.8093624 16.0263308 18.31175988 23.350563
## Population -3.1872660 -1.7367798 -3.63402082 15.670307
## Price 30.0864270 28.7929995 37.44125656 43.492787
## ShelveLoc 30.2789749 33.8109594 39.67983055 30.053785
## Age 9.7116826 9.0261373 12.78808426 22.578000
## Education 0.2214031 -0.3203644 0.06365633 9.899447
## Urban 1.3826674 1.4199879 1.98859615 2.128048
## US 3.7289827 5.1909662 6.83788775 3.405420
图38.11: Carseats数据随机森林法得到的变量重要度
重要的自变量为Price, ShelfLoc, 其次有Age, Advertising, CompPrice, Income等。
MASS包的Boston数据包含了波士顿地区郊区房价的若干数据。 以中位房价medv为因变量建立回归模型。 首先把缺失值去掉后存入数据集d:
数据集概况:
## 'data.frame': 506 obs. of 14 variables:
## $ crim : num 0.00632 0.02731 0.02729 0.03237 0.06905 ...
## $ zn : num 18 0 0 0 0 0 12.5 12.5 12.5 12.5 ...
## $ indus : num 2.31 7.07 7.07 2.18 2.18 2.18 7.87 7.87 7.87 7.87 ...
## $ chas : int 0 0 0 0 0 0 0 0 0 0 ...
## $ nox : num 0.538 0.469 0.469 0.458 0.458 0.458 0.524 0.524 0.524 0.524 ...
## $ rm : num 6.58 6.42 7.18 7 7.15 ...
## $ age : num 65.2 78.9 61.1 45.8 54.2 58.7 66.6 96.1 100 85.9 ...
## $ dis : num 4.09 4.97 4.97 6.06 6.06 ...
## $ rad : int 1 2 2 3 3 3 5 5 5 5 ...
## $ tax : num 296 242 242 222 222 222 311 311 311 311 ...
## $ ptratio: num 15.3 17.8 17.8 18.7 18.7 18.7 15.2 15.2 15.2 15.2 ...
## $ black : num 397 397 393 395 397 ...
## $ lstat : num 4.98 9.14 4.03 2.94 5.33 ...
## $ medv : num 24 21.6 34.7 33.4 36.2 28.7 22.9 27.1 16.5 18.9 ...
## crim zn indus chas nox rm age dis rad tax ptratio black lstat medv
## Min. : 0.00632 Min. : 0.00 Min. : 0.46 Min. :0.00000 Min. :0.3850 Min. :3.561 Min. : 2.90 Min. : 1.130 Min. : 1.000 Min. :187.0 Min. :12.60 Min. : 0.32 Min. : 1.73 Min. : 5.00
## 1st Qu.: 0.08205 1st Qu.: 0.00 1st Qu.: 5.19 1st Qu.:0.00000 1st Qu.:0.4490 1st Qu.:5.886 1st Qu.: 45.02 1st Qu.: 2.100 1st Qu.: 4.000 1st Qu.:279.0 1st Qu.:17.40 1st Qu.:375.38 1st Qu.: 6.95 1st Qu.:17.02
## Median : 0.25651 Median : 0.00 Median : 9.69 Median :0.00000 Median :0.5380 Median :6.208 Median : 77.50 Median : 3.207 Median : 5.000 Median :330.0 Median :19.05 Median :391.44 Median :11.36 Median :21.20
## Mean : 3.61352 Mean : 11.36 Mean :11.14 Mean :0.06917 Mean :0.5547 Mean :6.285 Mean : 68.57 Mean : 3.795 Mean : 9.549 Mean :408.2 Mean :18.46 Mean :356.67 Mean :12.65 Mean :22.53
## 3rd Qu.: 3.67708 3rd Qu.: 12.50 3rd Qu.:18.10 3rd Qu.:0.00000 3rd Qu.:0.6240 3rd Qu.:6.623 3rd Qu.: 94.08 3rd Qu.: 5.188 3rd Qu.:24.000 3rd Qu.:666.0 3rd Qu.:20.20 3rd Qu.:396.23 3rd Qu.:16.95 3rd Qu.:25.00
## Max. :88.97620 Max. :100.00 Max. :27.74 Max. :1.00000 Max. :0.8710 Max. :8.780 Max. :100.00 Max. :12.127 Max. :24.000 Max. :711.0 Max. :22.00 Max. :396.90 Max. :37.97 Max. :50.00
对训练集建立未剪枝的树:
## Regression tree: ## tree(formula = medv ~ ., data = d, subset = train) ## Variables actually used in tree construction: ## [1] "rm" "lstat" "crim" "age" ## Number of terminal nodes: 7 ## Residual mean deviance: 10.38 = 2555 / 246 ## Distribution of residuals: ## Min. 1st Qu. Median Mean 3rd Qu. Max. ## -10.1800 -1.7770 -0.1775 0.0000 1.9230 16.5800用未剪枝的树对测试集进行预测,计算均方误差:
yhat <-predict(tr1, newdata=d[test,]) mse1 <- mean((yhat - d[test, 'medv'])^2)
## [1] 35.28688
38.5.1.2 用交叉验证方法确定剪枝复杂度
## [1] 7
剪枝并对测试集进行预测:
yhat <-predict(tr2, newdata=d[test,]) mse2 <- mean((yhat - d[test, 'medv'])^2)
## [1] 35.28688
剪枝后效果没有改善。
38.5.2 装袋法
用randomForest包计算。 当参数
mtry
取为自变量个数时按照装袋法计算。 对训练集计算。
set.seed(1) bag1 <- randomForest( medv ~ ., data=d, subset=train, mtry=ncol(d)-1, importance=TRUE) ## Call: ## randomForest(formula = medv ~ ., data = d, mtry = ncol(d) - 1, importance = TRUE, subset = train) ## Type of random forest: regression ## Number of trees: 500 ## No. of variables tried at each split: 13 ## Mean of squared residuals: 11.39601 ## % Var explained: 85.17
在测试集上计算装袋法的均方误差:
## [1] 23.59273
比单棵树的结果有明显改善。
38.5.3 随机森林
用randomForest包计算。 当参数
mtry
取为缺省值时按照随机森林方法计算。 对训练集计算。
set.seed(1) rf1 <- randomForest( medv ~ ., data=d, subset=train, importance=TRUE) ## Call: ## randomForest(formula = medv ~ ., data = d, importance = TRUE, subset = train) ## Type of random forest: regression ## Number of trees: 500 ## No. of variables tried at each split: 4 ## Mean of squared residuals: 10.23441 ## % Var explained: 86.69
在测试集上计算随机森林法的均方误差:
## [1] 18.11686
比单棵树的结果有明显改善, 比装袋法的结果也好一些。
各变量的重要度数值及其图形:
图38.12: Boston数据用随机森林法得到的变量重要度## %IncMSE IncNodePurity ## crim 15.372334 1220.14856 ## zn 3.335435 194.85945 ## indus 6.964559 1021.94751 ## chas 2.059298 69.68099 ## nox 14.009761 1005.14707 ## rm 28.693900 6162.30720 ## age 13.832143 708.55138 ## dis 10.317731 852.33701 ## rad 4.390624 162.22597 ## tax 7.536563 564.60422 ## ptratio 9.333716 1163.39624 ## black 8.341316 355.62445 ## lstat 27.132450 5549.25088
[(1)] 对训练集,设置\(r_i = y_i\),并令初始回归函数为\(\hat f(\cdot)=0\)。
[(2)] 对\(b=1,2,\dots,B\)重复执行:
[(a)] 以训练集的自变量为自变量,以\(r\)为因变量,拟合一个仅有\(d\)个分叉的简单树回归函数, 设为\(\hat f_b\); [(b)] 更新回归函数,添加一个压缩过的树回归函数: \[\begin{aligned} \hat f(x) \leftarrow \hat f(x) + \lambda \hat f_b(x); \end{aligned}\] [(c)] 更新残差: \[\begin{aligned} r_i \leftarrow r_i - \lambda \hat f_b(x_i). \end{aligned}\] [(3)] 提升法的回归函数为 \[\begin{aligned} \hat f(x) = \sum_{b=1}^B \lambda \hat f_b(x) . \end{aligned}\]
用多少个回归函数做加权和,即\(B\)的选取问题。 取得\(B\)太大也会有过度拟合, 但是只要\(B\)不太大这个问题不严重。 可以用交叉验证选择\(B\)的值。
收缩系数\(\lambda\)。 是一个小的正数, 控制学习速度, 经常用0.01, 0.001这样的值, 与要解决的问题有关。 取\(\lambda\)很小,就需要取\(B\)很大。
用来控制每个回归函数复杂度的参数, 对树回归而言就是树的大小。 一个分叉的树往往就很好。 取单个分叉时结果模型是可加模型, 没有交互项, 这是因为每个加权相加得回归函数都只依赖于单一自变量。 \(d>1\)时就加入了交互项。
使用gbm包。 在训练集上拟合:
set.seed(1) bst1 <- gbm(medv ~ ., data=d[train,], distribution='gaussian', n.trees=5000, interaction.depth=4) summary(bst1)
## var rel.inf ## rm rm 43.9919329 ## lstat lstat 33.1216941 ## crim crim 4.2604167 ## dis dis 4.0111090 ## nox nox 3.4353017 ## black black 2.8267554 ## age age 2.6113938 ## ptratio ptratio 2.5403035 ## tax tax 1.4565654 ## indus indus 0.8008740 ## rad rad 0.6546400 ## zn zn 0.1446149 ## chas chas 0.1443986
lstat和rm是最重要的变量。
在测试集上预报,并计算均方误差:
## [1] 18.84709
与随机森林方法结果相近。
如果提高学习速度:
bst2 <- gbm(medv ~ ., data=d[train,], distribution='gaussian', n.trees=5000, interaction.depth=4, shrinkage=0.2) yhat <- predict(bst2, newdata=d[test,], n.trees=5000) mean( (yhat - d[test, 'medv'])^2 )
## [1] 18.33455
均方误差有改善。
38.6 支持向量机方法
支持向量机是1990年代有计算机科学家发明的一种有监督学习方法, 使用范围较广,预测精度较高。
支持向量机利用了Hilbert空间的方法将线性问题扩展为非线性问题。 线性的支持向量判别法, 可以通过\(\mathbb R^p\)的内积将线性的判别函数转化为如下的表示:
\[\begin{aligned} f(\boldsymbol x) = \beta_0 + \sum_{i=1}^n \alpha_i \langle \boldsymbol x, \boldsymbol x_i \rangle \end{aligned}\] 其中\(\beta_0, \alpha_1, \dots, \alpha_n\)是待定参数。 为了估计参数, 不需要用到各\(\boldsymbol x_i\)的具体值, 而只需要其两两的内积值, 而且在判别函数中只有支持向量对应的\(\alpha_i\)才非零, 记\(\mathcal S\)为支持向量点集, 则线性判别函数为 \[\begin{aligned} f(\boldsymbol x) = \beta_0 + \sum_{i \in \mathcal S} \alpha_i \langle \boldsymbol x, \boldsymbol x_i \rangle \end{aligned}\]
支持向量机方法将\(\mathbb R^p\)中的内积推广为如下的核函数值: \[\begin{aligned} K(\boldsymbol x, \boldsymbol x') \end{aligned}\] 核函数\(K(\boldsymbol x, \boldsymbol x')\), \(\boldsymbol x, \boldsymbol x' \in \mathbb R^p\) 是度量两个观测点\(\boldsymbol x, \boldsymbol x'\)的相似程度的函数。 \[\begin{aligned} K(\boldsymbol x, \boldsymbol x') = \sum_{j=1}^p x_j x_j' \end{aligned}\] 就又回到了线性的支持向量判别法。
核有多种取法。 \[\begin{aligned} K(\boldsymbol x, \boldsymbol x') = \left\{ 1 + \sum_{j=1}^p x_j x_j' \right\}^d \end{aligned}\] 其中\(d>1\)为正整数, 称为多项式核, 则结果是多项式边界的判别法, 本质上是对线性的支持向量方法添加了高次项和交叉项。
利用核代替内积后, 判别法的判别函数变成 \[\begin{aligned} f(\boldsymbol x) = \beta_0 + \sum_{i \in \mathcal S} K(\boldsymbol x, \boldsymbol x_i) \end{aligned}\]
另一种常用的核是径向核(radial kernel), \[\begin{aligned} K(\boldsymbol x, \boldsymbol x') = \exp\left\{ - \gamma \sum_{j=1}^p (x_j - x_j')^2 \right\} \end{aligned}\] \(\gamma\)为正常数。 当\(\boldsymbol x\)和\(\boldsymbol x'\)分别落在以原点为中心的两个超球面上时, 其核函数值不变。
使用径向核时, 判别函数为 \[\begin{aligned} f(\boldsymbol x) = \beta_0 + \sum_{i \in \mathcal S} \exp\left\{ - \gamma \sum_{j=1}^p (x_{j} - x_{ij})^2 \right\} \end{aligned}\] 对一个待判别的观测\(\boldsymbol x^*\), 如果\(\boldsymbol x^*\)距离训练观测点\(\boldsymbol x_i\)较远, 则\(K(\boldsymbol x^*, \boldsymbol x_i)\)的值很小, \(\boldsymbol x_i\)对\(\boldsymbol x^*\)的判别基本不起作用。 这样的性质使得径向核方法具有很强的局部性, 只有离\(\boldsymbol x^*\)很近的点才对其判别起作用。
为什么采用核函数计算观测两两的\(\binom{n}{2}\)个核函数值, 而不是直接增加非线性项? 原因是计算这些核函数值计算量是确定的, 而增加许多非线性项, 则可能有很大的计算量, 而且某些核如径向核对应的自变量空间维数是无穷维的, 不能通过添加维度的办法解决。
支持向量机的理论基于再生核希尔伯特空间(RKHS), 可参见(Trevor Hastie 2009)节5.8和节12.3.3。
38.6.1 支持向量机用于Heart数据
考虑心脏病数据Heart的判别。 共297个观测, 随机选取其中207个作为训练集, 90个作为测试集。
set.seed(1) Heart <- read.csv( "data/Heart.csv", header=TRUE, row.names=1, stringsAsFactors=TRUE) d <- na.omit(Heart) train <- sample(nrow(d), size=207) test <- -train d[["AHD"]] <- factor(d[["AHD"]], levels=c("No", "Yes"))
定义一个错判率函数:
classifier.error <- function(truth, pred){ tab1 <- table(truth, pred) err <- 1 - sum(diag(tab1))/sum(c(tab1))
38.6.1.1 线性的SVM
支持向量判别法就是SVM取多项式核, 阶数\(d=1\)的情形。 需要一个调节参数
cost
,cost
越大, 分隔边界越窄, 过度拟合危险越大。先随便取调节参数
cost=1
试验支持向量判别法:## Call: ## svm(formula = AHD ~ ., data = d[train, ], kernel = "linear", cost = 1, scale = TRUE) ## Parameters: ## SVM-Type: C-classification ## SVM-Kernel: linear ## cost: 1 ## Number of Support Vectors: 79 ## ( 38 41 ) ## Number of Classes: 2 ## Levels: ## No Yesres.svc <- svm(AHD ~ ., data=d[train,], kernel="linear", cost=1, scale=TRUE) fit.svc <- predict(res.svc) summary(res.svc)
计算拟合结果并计算错判率:
## fitted ## truth No Yes ## No 105 9 ## Yes 18 75
## SVC错判率: 0.13
e1071函数提供了
tune()
函数, 可以在训练集上用十折交叉验证选择较好的调节参数。## Parameter tuning of 'svm': ## - sampling method: 10-fold cross validation ## - best parameters: ## cost ## 0.1 ## - best performance: 0.1542857 ## - Detailed performance results: ## cost error dispersion ## 1 1e-03 0.4450000 0.08509809 ## 2 1e-02 0.1695238 0.07062868 ## 3 1e-01 0.1542857 0.07006458 ## 4 1e+00 0.1590476 0.07793796 ## 5 5e+00 0.1590476 0.08709789 ## 6 1e+01 0.1590476 0.08709789 ## 7 1e+02 0.1590476 0.08709789 ## 8 1e+03 0.1590476 0.08709789set.seed(101) res.tune <- tune(svm, AHD ~ ., data=d[train,], kernel="linear", scale=TRUE, ranges=list(cost=c(0.001, 0.01, 0.1, 1, 5, 10, 100, 1000))) summary(res.tune)
找到的最优调节参数为0.1, 可以用
## Call: ## best.tune(method = svm, train.x = AHD ~ ., data = d[train, ], ranges = list(cost = c(0.001, 0.01, 0.1, 1, 5, 10, 100, 1000)), kernel = "linear", scale = TRUE) ## Parameters: ## SVM-Type: C-classification ## SVM-Kernel: linear ## cost: 0.1 ## Number of Support Vectors: 90 ## ( 44 46 ) ## Number of Classes: 2 ## Levels: ## No Yesres.tune$best.model
获得对应于最优调节参数的模型:
在测试集上测试:
pred.svc <- predict(res.tune$best.model, newdata=d[test,]) tab1 <- table(truth=d[test,"AHD"], predict=pred.svc); tab1
## predict
## truth No Yes
## No 43 3
## Yes 11 33
## SVC错判率: 0.16
res.svm1 <- svm(AHD ~ ., data=d[train,], kernel="polynomial", order=2, cost=0.1, scale=TRUE) fit.svm1 <- predict(res.svm1) summary(res.svm1)
## fitted
## truth No Yes
## No 114 0
## Yes 82 11
## 2阶多项式核SVM错判率: 0.4
尝试找到调节参数
cost
的最优值:
set.seed(101) res.tune2 <- tune(svm, AHD ~ ., data=d[train,], kernel="polynomial", order=2, scale=TRUE, ranges=list(cost=c(0.001, 0.01, 0.1, 1, 5, 10, 100, 1000))) summary(res.tune2)
fit.svm2 <- predict(res.tune2$best.model) tab1 <- table(truth=d[train,"AHD"], fitted=fit.svm2); tab1
## fitted
## truth No Yes
## No 111 3
## Yes 4 89
## 2阶多项式核最优参数SVM错判率: 0.03
看这个最优调节参数的模型在测试集上的表现:
pred.svm2 <- predict(res.tune2$best.model, d[test,]) tab1 <- table(truth=d[test,"AHD"], predict=pred.svm2); tab1
## predict
## truth No Yes
## No 43 3
## Yes 10 34
## 2阶多项式核最优参数SVM测试集错判率: 0.14
在测试集上的表现与线性方法相近。
径向核需要的参数为
\(\gamma\)
值。
取参数
gamma=0.1
。
res.svm3 <- svm(AHD ~ ., data=d[train,], kernel="radial", gamma=0.1, cost=0.1, scale=TRUE) fit.svm3 <- predict(res.svm3) summary(res.svm3)
## fitted
## truth No Yes
## No 108 6
## Yes 26 67
## 径向核(gamma=0.1, cost=0.1)SVM错判率: 0.15
选取最优
cost
,
gamma
调节参数:
set.seed(101) res.tune4 <- tune(svm, AHD ~ ., data=d[train,], kernel="radial", scale=TRUE, ranges=list(cost=c(0.001, 0.01, 0.1, 1, 5, 10, 100, 1000), gamma=c(0.1, 0.01, 0.001))) summary(res.tune4)
fit.svm4 <- predict(res.tune4$best.model) tab1 <- table(truth=d[train,"AHD"], fitted=fit.svm4); tab1
## fitted
## truth No Yes
## No 107 7
## Yes 18 75
## 径向核最优参数SVM错判率: 0.12
看这个最优调节参数的模型在测试集上的表现:
pred.svm4 <- predict(res.tune4$best.model, d[test,]) tab1 <- table(truth=d[test,"AHD"], predict=pred.svm2); tab1
## predict
## truth No Yes
## No 43 3
## Yes 10 34
## 径向核最优参数SVM测试集错判率: 0.14
与线性方法结果相近。
James, Gareth, Daniela Witten, Trevor Hastie, and Robert Tibshirani. 2013. An Introduction to Statistical Learning with Applications in r . Springer. Trevor Hastie, Jerome Friedman, Robert Tibshirani. 2009. The Elements of Statistical Learning . 2nd Ed. Springer.