Функция train() в пакете caret служит для тренировки, кросс-валидации (cross-validation) и, в конечном итоге, для выбора лучшей модели.

Простейший пример применения функции:

library(caret)
set.seed(1)
lmfit <- train(mpg~.,          # формула
               data=mtcars,    # данные для формулы
               method="lm")    # метод регрессионного анализа
lmfit
## Linear Regression 
## 
## 32 samples
## 10 predictors
## 
## No pre-processing
## Resampling: Bootstrapped (25 reps) 
## 
## Summary of sample sizes: 32, 32, 32, 32, 32, 32, ... 
## 
## Resampling results
## 
##   RMSE      Rsquared  RMSE SD  Rsquared SD
##   4.572776  0.615759  1.63698  0.1686746  
## 
## 

Результаты:

  • Рузультаты линейного регрессионного анализа были усреднены по 25 случайным выборкам, полученным с помощью бутстреппа (bootstrap).
  • RMSE = 4.57 , Rsquared = 1.64 – показывают качество подгонки модели.
    Кроме того указаны также и соответствующие стандартные отклонения (standard deviations). Они важны для определения, насколько хорошо модель будет работать на сетах, которые она еще “не видела” (generalization ability).

Можно посмотреть все модели, которые имеют отношение к линейной регрессии:

ls(getModelInfo(model = "lm"))
##  [1]"bayesglm"   "elm"        "glm"        "glmboost"   "glmnet"    
##  [6]"glmStepAIC" "lm"         "lmStepAIC"  "plsRglm"    "rlm"

Также можно посмотреть параметры “настройки” интересующей модели:

modelLookup("lm")
##   model parameter     label forReg forClass probModel
## 1    lm parameter parameter   TRUE    FALSE     FALSE

Как мы видим, простая линейная регрессия method = "lm"не имеет параметров для настройки.

Список всех доступных моделей можно посмотреть по следующим адресам:

http://topepo.github.io/caret/modelList.html
http://topepo.github.io/caret/bytag.html

Рассмотрим модель, которая имеет параметры для настройки:

modelLookup("svmRadial")
##       model parameter label forReg forClass probModel
## 1 svmRadial     sigma Sigma   TRUE     TRUE      TRUE
## 2 svmRadial         C  Cost   TRUE     TRUE      TRUE

Как мы видим, модель имеет 2 параметра для настройки:

  • sigma
  • C

Произведем “настройку” svmRadial модели по обоим параметрам, а выбирать модель
будем по минимальному RMSE с помощью 5х10 кросс-валидации:

set.seed(1)
svmfit <- train(mpg~.,
                data=mtcars,
                method="svmRadial",
                metric="RMSE",
                maximize=F,
                tuneGrid=expand.grid(C=c(.25,.5,1,2), sigma = c(.03,.04,.05,.06,.07,.08)),
                trControl=trainControl("repeatedcv", number=10, repeats=5))

svmfit
## Support Vector Machines with Radial Basis Function Kernel 
## 
## 32 samples
## 10 predictors
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 5 times) 
## 
## Summary of sample sizes: 29, 29, 28, 28, 29, 29, ... 
## 
## Resampling results across tuning parameters:
## 
##   C     sigma  RMSE      Rsquared   RMSE SD   Rsquared SD
##   0.25  0.03   3.240881  0.8876284  1.586471  0.1104699  
##   0.25  0.04   3.262631  0.8798178  1.579929  0.1160362  
##   0.25  0.05   3.306468  0.8731589  1.574965  0.1213410  
##   0.25  0.06   3.358232  0.8670291  1.576731  0.1261827  
##   0.25  0.07   3.415166  0.8612436  1.580513  0.1312709  
##   0.25  0.08   3.474735  0.8555915  1.586303  0.1364247  
##   0.50  0.03   2.887490  0.8923670  1.409169  0.1013447  
##   0.50  0.04   2.865294  0.8894495  1.375109  0.1046870  
##   0.50  0.05   2.877322  0.8854410  1.353343  0.1092166  
##   0.50  0.06   2.919128  0.8809504  1.349005  0.1142167  
##   0.50  0.07   2.973591  0.8761904  1.354651  0.1196062  
##   0.50  0.08   3.033687  0.8715886  1.366976  0.1248165  
##   1.00  0.03   2.665711  0.8950718  1.183801  0.1009863  
##   1.00  0.04   2.724610  0.8899812  1.150501  0.1076596  
##   1.00  0.05   2.801872  0.8845433  1.126097  0.1135936  
##   1.00  0.06   2.867553  0.8814501  1.118444  0.1175818  
##   1.00  0.07   2.917247  0.8784264  1.116208  0.1221584  
##   1.00  0.08   2.964989  0.8751478  1.118254  0.1265269  
##   2.00  0.03   2.786457  0.8877993  1.102575  0.1023794  
##   2.00  0.04   2.812728  0.8828893  1.068991  0.1092689  
##   2.00  0.05   2.865070  0.8772023  1.065686  0.1151911  
##   2.00  0.06   2.914954  0.8724187  1.073292  0.1200380  
##   2.00  0.07   2.950067  0.8700023  1.089608  0.1240911  
##   2.00  0.08   2.987405  0.8661810  1.105442  0.1285558  
## 
## RMSE was used to select the optimal model using  the smallest value.
## The final values used for the model were sigma = 0.03 and C = 1.
plot(svmfit, scales=list(x=list(log=2)))

plot of chunk unnamed-chunk-5

Резултат train – это объект класса train, соответствующие элементы которого можно
извлекать с помощью оператора $.

ls(svmfit)
##  [1]"bestTune"     "call"         "coefnames"    "control"     
##  [5]"dots"         "finalModel"   "maximize"     "method"      
##  [9]"metric"       "modelInfo"    "modelType"    "perfNames"   
## [13]"pred"         "preProcess"   "resample"     "resampledCM" 
## [17]"results"      "terms"        "times"        "trainingData"
## [21]"xlevels"      "yLimits"

Например, время настройки модели можно получить таким образом:

svmfit$times
## $everything
##    user  system elapsed 
##  20.207  14.220  21.524 
## 
## $final
##    user  system elapsed 
##   0.008   0.002   0.009 
## 
## $prediction
## [1]NA NA NA

В общем случае функция train() имеет следующие аргументы:

library(caret)
methods(train)
## [1]train.default train.formula
args(train.default)
## function (x, y, method = "rf", preProcess = NULL, ..., weights = NULL, 
##     metric = ifelse(is.factor(y), "Accuracy", "RMSE"), maximize = ifelse(metric == 
##         "RMSE", FALSE, TRUE), trControl = trainControl(), tuneGrid = NULL, 
##     tuneLength = 3) 
## NULL
args(train.formula)
## function (form, data, ..., weights, subset, na.action = na.fail, 
##     contrasts = NULL) 
## NULL

На что следует обратить внимание:

  • т.к. кросс-валидация работает посредством усреднения по случайным выборкам,
    для повторяемости результатов следует зафиксировать сид set.seed()
    или
  • для сравнения разных моделей, оптимальным способом будет задание выборок
    явным способом заранее, и затем использование одинакового набора случайных выборок
    для всех моделей:
set.seed(1)
indx <- createFolds(mtcars$mpg, returnTrain = TRUE)
ctrl <- trainControl(method = "cv", index = indx)
lmfit2 <- train(mpg ~.,
                data=mtcars,
                method="lm",
                trControl=ctrl)
lmfit2
## Linear Regression 
## 
## 32 samples
## 10 predictors
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## 
## Summary of sample sizes: 29, 29, 28, 28, 29, 29, ... 
## 
## Resampling results
## 
##   RMSE      Rsquared   RMSE SD   Rsquared SD
##   3.240244  0.8230029  1.320212  0.1805104  
## 
## 
svmfit2 <- train(mpg ~.,
                data=mtcars,
                method="lm",
                tuneLength=10,
                trControl=ctrl)
svmfit2
## Linear Regression 
## 
## 32 samples
## 10 predictors
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## 
## Summary of sample sizes: 29, 29, 28, 28, 29, 29, ... 
## 
## Resampling results
## 
##   RMSE      Rsquared   RMSE SD   Rsquared SD
##   3.240244  0.8230029  1.320212  0.1805104  
## 
## 

Простая проверка на совпадение выборок для разных моделей:

lmfit2$control$index$Fold05
##  [1]1  2  3  4  5  6  7  8  9 10 11 12 14 15 16 17 18 19 20 23 24 25 26
## [24]27 28 29 30 31 32
svmfit2$control$index$Fold05
##  [1]1  2  3  4  5  6  7  8  9 10 11 12 14 15 16 17 18 19 20 23 24 25 26
## [24]27 28 29 30 31 32
© 2014 In R we trust.
Top
Follow us: