The `train` function in the `caret` package, used for training machine learning models, has a fairly complicated interface. Load the `caret` package and view the function signature for `train`.
# Load caret and look at the signature of its train function.
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
# The sig package supplies sig(), used below to print function signatures.
library(sig)
# The signature of train itself only shows (x, ...).
sig(train)
## train <- function(x, ...)
# train.default lists the individual arguments and their defaults.
sig(train.default) # more interesting
## train.default <- function(x, y, method = "rf", preProcess = NULL, ...,
## weights = NULL, metric = ifelse(is.factor(y), "Accuracy",
## "RMSE"), maximize = ifelse(metric == "RMSE", FALSE, TRUE),
## trControl = trainControl(), tuneGrid = NULL, tuneLength = 3)
Notice that by default, the `trControl` argument calls another function, `trainControl`. `trainControl` works very similarly to `gpar` in the `grid` package: it takes some inputs, checks that they are OK, then returns them as a list, in order to be used by the parent function.
Compare these uses of `trainControl`:
# no args
trainControl()
## $method
## [1] "boot"
##
## $number
## [1] 25
##
## $repeats
## [1] 25
##
## $p
## [1] 0.75
##
## $initialWindow
## NULL
##
## $horizon
## [1] 1
##
## $fixedWindow
## [1] TRUE
##
## $verboseIter
## [1] FALSE
##
## $returnData
## [1] TRUE
##
## $returnResamp
## [1] "final"
##
## $savePredictions
## [1] FALSE
##
## $classProbs
## [1] FALSE
##
## $summaryFunction
## function (data, lev = NULL, model = NULL)
## {
## if (is.character(data$obs))
## data$obs <- factor(data$obs, levels = lev)
## postResample(data[, "pred"], data[, "obs"])
## }
## <environment: namespace:caret>
##
## $selectionFunction
## [1] "best"
##
## $preProcOptions
## $preProcOptions$thresh
## [1] 0.95
##
## $preProcOptions$ICAcomp
## [1] 3
##
## $preProcOptions$k
## [1] 5
##
##
## $index
## NULL
##
## $indexOut
## NULL
##
## $timingSamps
## [1] 0
##
## $predictionBounds
## [1] FALSE FALSE
##
## $seeds
## [1] NA
##
## $adaptive
## $adaptive$min
## [1] 5
##
## $adaptive$alpha
## [1] 0.05
##
## $adaptive$method
## [1] "gls"
##
## $adaptive$complete
## [1] TRUE
##
##
## $trim
## [1] FALSE
##
## $allowParallel
## [1] TRUE
# valid values: the returned list stores the supplied method and classProbs
trainControl(method = "adaptive_LGOCV", classProbs = TRUE)
## $method
## [1] "adaptive_LGOCV"
##
## $number
## [1] 25
##
## $repeats
## [1] 25
##
## $p
## [1] 0.75
##
## $initialWindow
## NULL
##
## $horizon
## [1] 1
##
## $fixedWindow
## [1] TRUE
##
## $verboseIter
## [1] FALSE
##
## $returnData
## [1] TRUE
##
## $returnResamp
## [1] "final"
##
## $savePredictions
## [1] FALSE
##
## $classProbs
## [1] TRUE
##
## $summaryFunction
## function (data, lev = NULL, model = NULL)
## {
## if (is.character(data$obs))
## data$obs <- factor(data$obs, levels = lev)
## postResample(data[, "pred"], data[, "obs"])
## }
## <environment: namespace:caret>
##
## $selectionFunction
## [1] "best"
##
## $preProcOptions
## $preProcOptions$thresh
## [1] 0.95
##
## $preProcOptions$ICAcomp
## [1] 3
##
## $preProcOptions$k
## [1] 5
##
##
## $index
## NULL
##
## $indexOut
## NULL
##
## $timingSamps
## [1] 0
##
## $predictionBounds
## [1] FALSE FALSE
##
## $seeds
## [1] NA
##
## $adaptive
## $adaptive$min
## [1] 5
##
## $adaptive$alpha
## [1] 0.05
##
## $adaptive$method
## [1] "gls"
##
## $adaptive$complete
## [1] TRUE
##
##
## $trim
## [1] FALSE
##
## $allowParallel
## [1] TRUE
# an invalid returnResamp value is rejected with an error
trainControl(returnResamp = "any")
## Error in trainControl(returnResamp = "any"): incorrect value of returnResamp
Try passing bad arguments to `trainControl` to see what errors you can make it throw. Does it catch all possible bad inputs?
# try to break trainControl here
# It doesn't catch, for example, an unrecognised method name: the call below
# raises no error and returns the settings list with $method set to "nonsense".
trainControl(method = "nonsense")
## $method
## [1] "nonsense"
##
## $number
## [1] 25
##
## $repeats
## [1] 25
##
## $p
## [1] 0.75
##
## $initialWindow
## NULL
##
## $horizon
## [1] 1
##
## $fixedWindow
## [1] TRUE
##
## $verboseIter
## [1] FALSE
##
## $returnData
## [1] TRUE
##
## $returnResamp
## [1] "final"
##
## $savePredictions
## [1] FALSE
##
## $classProbs
## [1] FALSE
##
## $summaryFunction
## function (data, lev = NULL, model = NULL)
## {
## if (is.character(data$obs))
## data$obs <- factor(data$obs, levels = lev)
## postResample(data[, "pred"], data[, "obs"])
## }
## <environment: namespace:caret>
##
## $selectionFunction
## [1] "best"
##
## $preProcOptions
## $preProcOptions$thresh
## [1] 0.95
##
## $preProcOptions$ICAcomp
## [1] 3
##
## $preProcOptions$k
## [1] 5
##
##
## $index
## NULL
##
## $indexOut
## NULL
##
## $timingSamps
## [1] 0
##
## $predictionBounds
## [1] FALSE FALSE
##
## $seeds
## [1] NA
##
## $adaptive
## $adaptive$min
## [1] 5
##
## $adaptive$alpha
## [1] 0.05
##
## $adaptive$method
## [1] "gls"
##
## $adaptive$complete
## [1] TRUE
##
##
## $trim
## [1] FALSE
##
## $allowParallel
## [1] TRUE
What extra checks could you add to `trainControl` to improve it?
# write an improved version of trainControl here
# Using match.arg on string arguments would be a good start
# Sketch of a stricter trainControl: `method` must be one of the names listed
# in its default vector, otherwise match.arg() raises an error.
trainControl2 <- function(
    method = c(
      "boot", "boot632", "cv", "repeatedcv", "LOOCV", "LGOCV", "none",
      "oob", "adaptive_cv", "adaptive_boot", "adaptive_LGOCV"
    ),
    ...) { # plus the other args
  # When `method` is not supplied, match.arg() selects the first choice.
  method <- match.arg(arg = method)
  # rest of the function
}
Write some unit tests to make sure that your function really works as you think it should.
# unit tests for updated trainControl
test_that("trainControl, with a nonsense method, throws an error", {
  # The expected message fragment is the one match.arg() emits when the
  # supplied value is not among the allowed choices.
  expect_error(
    object = trainControl2(method = "nonsense"),
    regexp = "'arg' should be one of"
  )
})