Tune multiple machine learning models using cross-validation to optimize performance

tune_models(
  d,
  outcome,
  models,
  metric,
  positive_class,
  n_folds = 5,
  tune_depth = 10,
  hyperparameters = NULL,
  model_class,
  model_name = NULL,
  allow_parallel = FALSE
)

Arguments

d

A data frame from prep_data. If you want to prepare your data on your own, use prep_data(..., no_prep = TRUE).
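
For example (a sketch; pima_diabetes ships with the package, and my_prepared_df is a hypothetical data frame you have preprocessed yourself):

  d <- prep_data(pima_diabetes, patient_id, outcome = diabetes)
  # Or, if you have already done your own preprocessing:
  d <- prep_data(my_prepared_df, outcome = diabetes, no_prep = TRUE)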

outcome

Optional. Name of the column to predict. When omitted the outcome from prep_data is used; otherwise it must match the outcome provided to prep_data.

models

Names of models to try. See get_supported_models for available models. Default is all available models.
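
For example, to list the available models and then tune only a subset of them (a sketch; the model names assume "rf" and "xgb" appear in get_supported_models()):

  get_supported_models()
  m <- tune_models(d, outcome = diabetes, models = c("rf", "xgb"))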

metric

Which metric should be used to assess model performance? Options for classification: "ROC" (default, area under the receiver operating characteristic curve) or "PR" (area under the precision-recall curve). Options for regression: "RMSE" (default, root mean squared error), "MAE" (mean absolute error), or "Rsquared" (coefficient of determination). Options for multiclass: "Accuracy" (default) or "Kappa" (accuracy, adjusted for class imbalance).
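
For example, to optimize classification models for area under the precision-recall curve instead of ROC (a sketch; d comes from prep_data as in the Examples below):

  m <- tune_models(d, outcome = diabetes, metric = "PR")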

positive_class

For classification only: which outcome level is the "yes" case, i.e. the level that should be associated with high probabilities? Defaults to "Y" or "yes" if present; otherwise the first level of the outcome variable is used (first alphabetically if the training data outcome was not already a factor).
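
For example, to make "N" the positive class in the Examples' data, where it would otherwise default to "Y" (a sketch):

  m <- tune_models(d, outcome = diabetes, positive_class = "N")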

n_folds

How many folds to use in cross-validation? Default = 5.

tune_depth

How many hyperparameter combinations to try? Default = 10. This value is multiplied by 5 for regularized regression. Increasing it may be particularly helpful when tuning XGBoost models.
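
For example, to search more of XGBoost's hyperparameter space (a sketch; assumes "xgb" is the XGBoost entry in get_supported_models()):

  m <- tune_models(d, outcome = diabetes, models = "xgb", tune_depth = 25)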

hyperparameters

Optional. A list of data frames containing hyperparameter values to tune over. If NULL (default), a random search of the hyperparameter space, tune_depth combinations deep, is performed. If provided, this argument overrides tune_depth. It should be a named list of data frames, where the names correspond to models (e.g. "rf") and each column of a data frame contains values of one hyperparameter. See hyperparameters for a template. If only one model is specified to the models argument, the data frame can be passed bare to this argument, as shown in the sketch below.
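
For example, because only one model is requested here, the data frame can be passed bare (a sketch; the hyperparameter names mirror the random forest grid in the Examples below, and the values are illustrative):

  rf_grid <- expand.grid(
    mtry = 2:4,
    splitrule = "gini",
    min.node.size = c(1, 5)
  )
  m <- tune_models(d, outcome = diabetes, models = "rf", hyperparameters = rf_grid)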

model_class

"regression" or "classification". If not provided, this will be determined by the class of `outcome` with the determination displayed in a message.

model_name

Quoted, name of the model. Defaults to the name of the outcome variable.

allow_parallel

Deprecated. Instead, control the number of cores through your parallel back end (e.g. with doMC).
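
For example, to train across four cores, register a parallel back end before calling tune_models (a sketch using doMC, one of several foreach back ends):

  library(doMC)
  registerDoMC(cores = 4)
  m <- tune_models(d, outcome = diabetes)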

Value

A model_list object. You can call plot, summary, evaluate, or predict on a model_list.
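
For example (a sketch; m is a model_list returned by tune_models, and newdata is assumed to be a data frame with the same columns as the training data):

  summary(m)
  plot(m)
  predictions <- predict(m, newdata = pima_diabetes)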

Details

Note that this function trains a lot of models (100 by default) and so can take a while to execute. In general, a model is trained for each hyperparameter combination in each fold for each model, so run time is a function of length(models) x n_folds x tune_depth. At the default settings, a 1,000-row, 10-column data frame should complete in about 30 seconds on a good laptop.
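
As a rough sketch of that arithmetic (the counts here mirror the defaults described above):

  # models trained = length(models) * n_folds * tune_depth
  # e.g. 2 models * 5 folds * 10 hyperparameter combinations = 100 model fits
  2 * 5 * 10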

See also

For setting up model training: prep_data, supported_models, hyperparameters

For evaluating models: plot.model_list, evaluate.model_list

For making predictions: predict.model_list

For faster, but not-optimized model training: flash_models

To prepare data and tune models in a single step: machine_learn

Examples

if (FALSE) {
  ### Examples take about 30 seconds to run

  # Prepare data for tuning
  d <- prep_data(pima_diabetes, patient_id, outcome = diabetes)

  # Tune random forest, xgboost, and regularized regression classification models
  m <- tune_models(d)

  # Get some info about the tuned models
  m

  # Get more detailed info
  summary(m)

  # Plot performance over hyperparameter values for each algorithm
  plot(m)

  # To specify hyperparameter values to tune over, pass a data frame
  # of hyperparameter values to the hyperparameters argument:
  rf_hyperparameters <- expand.grid(
    mtry = 1:5,
    splitrule = c("gini", "extratrees"),
    min.node.size = 1
  )
  grid_search_models <- tune_models(
    d = d,
    outcome = diabetes,
    models = "rf",
    hyperparameters = list(rf = rf_hyperparameters)
  )
  plot(grid_search_models)
}