good-code ex 6: outsourcing argument checking, trainControl

The train function in the caret package, used for training machine learning models, has a fairly complicated interface. Load the caret package and view the function signature for train.

# look at the train function
# Attaching caret provides train() and trainControl(); as the messages
# below show, it also attaches lattice and ggplot2.
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
# sig() prints a function's signature (its argument list with defaults).
library(sig)
# train() itself only exposes (x, ...) ...
sig(train)
## train <- function(x, ...)
# ... so look at train.default(), which lists the arguments users actually
# set, including trControl = trainControl().
sig(train.default) # more interesting
## train.default <- function(x, y, method = "rf", preProcess = NULL, ...,
##              weights = NULL, metric = ifelse(is.factor(y), "Accuracy",
##              "RMSE"), maximize = ifelse(metric == "RMSE", FALSE, TRUE),
##              trControl = trainControl(), tuneGrid = NULL, tuneLength = 3)

Notice that the default value of the trControl argument is a call to another function, trainControl. trainControl works much like gpar in the grid package: it takes some inputs, checks that they are valid, and returns them as a list for the parent function (here, train) to use.

Compare these uses of trainControl:

# no args: trainControl() returns its default settings as a named list
trainControl()
## $method
## [1] "boot"
## 
## $number
## [1] 25
## 
## $repeats
## [1] 25
## 
## $p
## [1] 0.75
## 
## $initialWindow
## NULL
## 
## $horizon
## [1] 1
## 
## $fixedWindow
## [1] TRUE
## 
## $verboseIter
## [1] FALSE
## 
## $returnData
## [1] TRUE
## 
## $returnResamp
## [1] "final"
## 
## $savePredictions
## [1] FALSE
## 
## $classProbs
## [1] FALSE
## 
## $summaryFunction
## function (data, lev = NULL, model = NULL) 
## {
##     if (is.character(data$obs)) 
##         data$obs <- factor(data$obs, levels = lev)
##     postResample(data[, "pred"], data[, "obs"])
## }
## <environment: namespace:caret>
## 
## $selectionFunction
## [1] "best"
## 
## $preProcOptions
## $preProcOptions$thresh
## [1] 0.95
## 
## $preProcOptions$ICAcomp
## [1] 3
## 
## $preProcOptions$k
## [1] 5
## 
## 
## $index
## NULL
## 
## $indexOut
## NULL
## 
## $timingSamps
## [1] 0
## 
## $predictionBounds
## [1] FALSE FALSE
## 
## $seeds
## [1] NA
## 
## $adaptive
## $adaptive$min
## [1] 5
## 
## $adaptive$alpha
## [1] 0.05
## 
## $adaptive$method
## [1] "gls"
## 
## $adaptive$complete
## [1] TRUE
## 
## 
## $trim
## [1] FALSE
## 
## $allowParallel
## [1] TRUE
# Valid values are accepted and stored in the returned list (see $method
# and $classProbs in the output below).
trainControl(method = "adaptive_LGOCV", classProbs = TRUE)
## $method
## [1] "adaptive_LGOCV"
## 
## $number
## [1] 25
## 
## $repeats
## [1] 25
## 
## $p
## [1] 0.75
## 
## $initialWindow
## NULL
## 
## $horizon
## [1] 1
## 
## $fixedWindow
## [1] TRUE
## 
## $verboseIter
## [1] FALSE
## 
## $returnData
## [1] TRUE
## 
## $returnResamp
## [1] "final"
## 
## $savePredictions
## [1] FALSE
## 
## $classProbs
## [1] TRUE
## 
## $summaryFunction
## function (data, lev = NULL, model = NULL) 
## {
##     if (is.character(data$obs)) 
##         data$obs <- factor(data$obs, levels = lev)
##     postResample(data[, "pred"], data[, "obs"])
## }
## <environment: namespace:caret>
## 
## $selectionFunction
## [1] "best"
## 
## $preProcOptions
## $preProcOptions$thresh
## [1] 0.95
## 
## $preProcOptions$ICAcomp
## [1] 3
## 
## $preProcOptions$k
## [1] 5
## 
## 
## $index
## NULL
## 
## $indexOut
## NULL
## 
## $timingSamps
## [1] 0
## 
## $predictionBounds
## [1] FALSE FALSE
## 
## $seeds
## [1] NA
## 
## $adaptive
## $adaptive$min
## [1] 5
## 
## $adaptive$alpha
## [1] 0.05
## 
## $adaptive$method
## [1] "gls"
## 
## $adaptive$complete
## [1] TRUE
## 
## 
## $trim
## [1] FALSE
## 
## $allowParallel
## [1] TRUE
# An unrecognised returnResamp value is rejected with an error.
trainControl(returnResamp = "any")
## Error in trainControl(returnResamp = "any"): incorrect value of returnResamp

Try passing bad arguments to trainControl to see what errors you can make it throw. Does it catch all possible bad inputs?

# try to break trainControl here
# It doesn't catch, for example, an unknown resampling method: "nonsense" is
# returned unchanged as $method (see the output below) instead of an error.
trainControl(method = "nonsense")
## $method
## [1] "nonsense"
## 
## $number
## [1] 25
## 
## $repeats
## [1] 25
## 
## $p
## [1] 0.75
## 
## $initialWindow
## NULL
## 
## $horizon
## [1] 1
## 
## $fixedWindow
## [1] TRUE
## 
## $verboseIter
## [1] FALSE
## 
## $returnData
## [1] TRUE
## 
## $returnResamp
## [1] "final"
## 
## $savePredictions
## [1] FALSE
## 
## $classProbs
## [1] FALSE
## 
## $summaryFunction
## function (data, lev = NULL, model = NULL) 
## {
##     if (is.character(data$obs)) 
##         data$obs <- factor(data$obs, levels = lev)
##     postResample(data[, "pred"], data[, "obs"])
## }
## <environment: namespace:caret>
## 
## $selectionFunction
## [1] "best"
## 
## $preProcOptions
## $preProcOptions$thresh
## [1] 0.95
## 
## $preProcOptions$ICAcomp
## [1] 3
## 
## $preProcOptions$k
## [1] 5
## 
## 
## $index
## NULL
## 
## $indexOut
## NULL
## 
## $timingSamps
## [1] 0
## 
## $predictionBounds
## [1] FALSE FALSE
## 
## $seeds
## [1] NA
## 
## $adaptive
## $adaptive$min
## [1] 5
## 
## $adaptive$alpha
## [1] 0.05
## 
## $adaptive$method
## [1] "gls"
## 
## $adaptive$complete
## [1] TRUE
## 
## 
## $trim
## [1] FALSE
## 
## $allowParallel
## [1] TRUE

What extra checks could you add to trainControl to improve it?

# write an improved version of trainControl here
# Using match.arg on string arguments would be a good start.
#
# trainControl2(): a stricter front end to caret's trainControl().
# It rejects bad inputs that trainControl() lets through (such as
# method = "nonsense"), then delegates to trainControl(), so the result is
# the same control list that train() accepts as `trControl`.
#
# Arguments:
#   method  Resampling method. Must be one of the choices in the default
#           vector (unique partial matches are accepted); anything else is
#           an error from match.arg().
#   ...     Further arguments, forwarded unchanged to trainControl(). When
#           supplied by their full name they are validated first:
#             number, repeats   a single positive whole number
#             p                 a single number strictly between 0 and 1
#             returnResamp      one of "final", "all", "none"
#             verboseIter, returnData, classProbs, fixedWindow,
#             allowParallel, trim
#                               a single TRUE or FALSE
# Returns: the list returned by trainControl().
trainControl2 <- function(method = c("boot", "boot632", "cv", "repeatedcv",
                                     "LOOCV", "LGOCV", "none", "oob",
                                     "adaptive_cv", "adaptive_boot",
                                     "adaptive_LGOCV"),
                          ...) {
  method <- match.arg(method)

  # Forcing the dots lets us inspect them; promise values are cached, so
  # forwarding `...` below does not evaluate anything a second time.
  dots <- list(...)
  supplied <- names(dots)

  # Raise a uniform "'<arg>' should be <requirement>" error.
  fail <- function(name, requirement) {
    stop("'", name, "' should be ", requirement, call. = FALSE)
  }

  for (name in intersect(c("number", "repeats"), supplied)) {
    value <- dots[[name]]
    is_count <- is.numeric(value) && length(value) == 1L &&
      is.finite(value) && value >= 1 && value == round(value)
    if (!is_count) {
      fail(name, "a single positive whole number")
    }
  }

  if ("p" %in% supplied) {
    value <- dots[["p"]]
    is_fraction <- is.numeric(value) && length(value) == 1L &&
      is.finite(value) && value > 0 && value < 1
    if (!is_fraction) {
      fail("p", "a single number strictly between 0 and 1")
    }
  }

  if ("returnResamp" %in% supplied) {
    choices <- c("final", "all", "none")
    value <- dots[["returnResamp"]]
    if (!(is.character(value) && length(value) == 1L && value %in% choices)) {
      fail(
        "returnResamp",
        paste0("one of ", paste0("\"", choices, "\"", collapse = ", "))
      )
    }
  }

  flags <- c("verboseIter", "returnData", "classProbs", "fixedWindow",
             "allowParallel", "trim")
  for (name in intersect(flags, supplied)) {
    if (!isTRUE(dots[[name]]) && !isFALSE(dots[[name]])) {
      fail(name, "TRUE or FALSE")
    }
  }

  # Everything we can check has been checked; let caret build the list.
  trainControl(method = method, ...)
}

Write some unit tests to make sure that your function really works as you think it should.

# unit tests for updated trainControl
# NOTE(review): these assume testthat is attached earlier in the document
# and that caret is loaded, since valid calls may delegate to trainControl().
test_that(
  "trainControl2, with a nonsense method, throws an error",
  {
    expect_error(
      trainControl2(method = "nonsense"),
      "'arg' should be one of"
    )
  }
)

test_that(
  "trainControl2, with more than one method, throws an error",
  {
    expect_error(
      trainControl2(method = c("cv", "boot")),
      "'arg' must be of length 1"
    )
  }
)

test_that(
  "trainControl2 accepts the default, valid and partially matched methods",
  {
    # expect_error(..., NA) asserts that no error is raised; without these
    # checks a function that always failed would pass the tests above.
    expect_error(trainControl2(), NA)
    expect_error(trainControl2(method = "cv"), NA)
    expect_error(trainControl2(method = "adaptive_LG"), NA)
  }
)