R 编程中的重复 K 折交叉验证
重复 K-fold 是分类和回归机器学习模型最优选的交叉验证技术。数据集的多次混洗和随机采样是重复 K-fold 算法的核心过程,它可以生成一个健壮的模型,因为它涵盖了最大的训练和测试操作。这种用于评估机器学习模型准确性的交叉验证技术的工作取决于两个参数。第一个参数是K ,它是一个整数值,它表明给定的数据集将被分成 K 个折叠(或子集)。在 K 个折叠中,模型在 K-1 个子集上进行训练,剩余的子集将用于评估模型的性能。这些步骤将重复到一定次数,这将由该算法的第二个参数决定,因此它的名称为重复 K-fold,即K-fold 交叉验证算法重复一定次数次。
重复 K 折交叉验证涉及的步骤:
重复K-fold的每次迭代都是正常K-fold算法的实现。在 K 折交叉验证技术中,涉及以下步骤:
- 将数据集随机拆分为 K 个子集
- 对于每个已开发的数据点子集
- 将该子集视为验证集
- 使用所有其余子集进行训练
- 训练模型并在验证集或测试集上对其进行评估
- 计算预测误差
- 重复上述步骤 K 次,即直到模型没有在所有子集上训练和测试
- 通过取每种情况下的预测误差的平均值来生成总体预测误差
因此,在重复 k 折交叉验证方法中,上述步骤将在给定数据集上重复一定次数。在每次迭代中,数据集都会有完全不同的拆分为 K-folds,模型的性能得分也会不同。最后,所有情况下的平均性能得分将给出模型的最终准确度。为了执行重复 K-fold 方法的这些复杂任务,R 语言提供了丰富的内置函数和包库。以下是在分类和回归机器学习模型上实现重复 K 折交叉验证技术的分步方法。
对分类实施重复 K 折交叉验证
当目标变量是分类数据类型时,分类机器学习模型用于预测类标签。在这个例子中,朴素贝叶斯算法将被用作概率分类器来预测目标变量的类标签。
第 1 步:加载所需的包和库
必须导入所有必要的库和包才能执行任务而不会出现任何错误。下面是为重复 K 折算法设置 R 环境的代码。
R
# load the library
# package to perform data manipulation
# and visualization
library(tidyverse)
# package to compute
# cross - validation methods
library(caret)
# loading package to
# import desired dataset
library(ISLR)
R
# assigning the complete dataset
# Smarket to a variable
dataset <- Smarket[complete.cases(Smarket), ]
# display the dataset with details
# like column name and its data type
# along with values in each row
glimpse(dataset)
# checking values present
# in the Direction column
# of the dataset
table(dataset$Direction)
R
# setting seed to generate a
# reproducible random sampling
set.seed(123)
# define training control which
# generates parameters that further
# control how models are created
train_control <- trainControl(method = "repeatedcv",
number = 10, repeats = 3)
# building the model and
# predicting the target variable
# as per the Naive Bayes classifier
model <- train(Direction~., data = dataset,
trControl = train_control, method = "nb")
R
# summarize results of the
# model after calculating
# prediction error in each case
print(model)
R
# loading required packages
# package to perform data manipulation
# and visualization
library(tidyverse)
# package to compute
# cross - validation methods
library(caret)
R
# access the data from R’s datasets package
data(trees)
# look at the first several rows of the data
head(trees)
R
# setting seed to generate a
# reproducible random sampling
set.seed(125)
# defining training control as
# repeated cross-validation and
# value of K is 10 and repetition is 3 times
train_control <- trainControl(method = "repeatedcv",
number = 10, repeats = 3)
# training the model by assigning sales column
# as target variable and rest other column
# as independent variable
model <- train(Volume ~., data = trees,
method = "lm",
trControl = train_control)
R
# printing model performance metrics
# along with other details
print(model)
第 2 步:探索数据集
导入所需的库后,是时候在 R 环境中加载数据集了。数据集的探索也非常重要,因为它可以让您了解在将数据集用于训练和测试目的之前是否需要对数据集进行任何更改。下面是执行此任务的代码。
R
# assigning the complete dataset
# Smarket to a variable
dataset <- Smarket[complete.cases(Smarket), ]
# display the dataset with details
# like column name and its data type
# along with values in each row
glimpse(dataset)
# checking values present
# in the Direction column
# of the dataset
table(dataset$Direction)
输出:
Rows: 1,250
Columns: 9
$ Year 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, 2001, ...
$ Lag1 0.381, 0.959, 1.032, -0.623, 0.614, 0.213, 1.392, -0.403, 0.027, 1.303, 0.287, -0.498, -0.189, 0.680, 0.701, -0.562, 0.546, -1...
$ Lag2 -0.192, 0.381, 0.959, 1.032, -0.623, 0.614, 0.213, 1.392, -0.403, 0.027, 1.303, 0.287, -0.498, -0.189, 0.680, 0.701, -0.562, 0...
$ Lag3 -2.624, -0.192, 0.381, 0.959, 1.032, -0.623, 0.614, 0.213, 1.392, -0.403, 0.027, 1.303, 0.287, -0.498, -0.189, 0.680, 0.701, -...
$ Lag4 -1.055, -2.624, -0.192, 0.381, 0.959, 1.032, -0.623, 0.614, 0.213, 1.392, -0.403, 0.027, 1.303, 0.287, -0.498, -0.189, 0.680, ...
$ Lag5 5.010, -1.055, -2.624, -0.192, 0.381, 0.959, 1.032, -0.623, 0.614, 0.213, 1.392, -0.403, 0.027, 1.303, 0.287, -0.498, -0.189, ...
$ Volume 1.19130, 1.29650, 1.41120, 1.27600, 1.20570, 1.34910, 1.44500, 1.40780, 1.16400, 1.23260, 1.30900, 1.25800, 1.09800, 1.05310, ...
$ Today 0.959, 1.032, -0.623, 0.614, 0.213, 1.392, -0.403, 0.027, 1.303, 0.287, -0.498, -0.189, 0.680, 0.701, -0.562, 0.546, -1.747, 0...
$ Direction Up, Up, Down, Up, Up, Up, Down, Up, Up, Up, Down, Down, Up, Up, Down, Up, Down, Up, Down, Down, Down, Down, Up, Down, Down, Up...
> table(dataset$Direction)
Down Up
602 648
以上信息表明数据集的自变量是
- 下采样
- 上采样
- 使用 SMOTE 和 ROSE 的混合采样
步骤 3:使用重复 K 折算法构建模型
trainControl()函数被定义为设置重复次数和 K 参数的值。之后,按照重复 K 折算法中涉及的步骤开发模型。下面是实现。
R
# setting seed to generate a
# reproducible random sampling
set.seed(123)
# define training control which
# generates parameters that further
# control how models are created
train_control <- trainControl(method = "repeatedcv",
number = 10, repeats = 3)
# building the model and
# predicting the target variable
# as per the Naive Bayes classifier
model <- train(Direction~., data = dataset,
trControl = train_control, method = "nb")
第 4 步:评估模型的准确性
在这最后一步中,模型的性能分数将在对所有可能的验证折叠进行测试后生成。下面是打印开发模型的准确性和总体摘要的代码。
R
# summarize results of the
# model after calculating
# prediction error in each case
print(model)
输出:
Naive Bayes
1250 samples
8 predictor
2 classes: 'Down', 'Up'
No pre-processing
Resampling: Cross-Validated (10 fold, repeated 3 times)
Summary of sample sizes: 1124, 1125, 1126, 1125, 1125, 1126, ...
Resampling results across tuning parameters:
usekernel Accuracy Kappa
FALSE 0.9562616 0.9121273
TRUE 0.9696037 0.9390601
Tuning parameter 'fL' was held constant at a value of 0
Tuning parameter 'adjust' was held constant at a value of 1
Accuracy was used to select the optimal model using the largest value.
The final values used for the model were fL = 0, usekernel = TRUE and adjust = 1.
对回归实施重复 K 折交叉验证
对于目标变量具有连续性(如区域温度、商品成本等)的那些数据集,回归机器学习模型是首选。目标变量的值可以是整数或浮点数。以下是在回归模型中实现重复 k 折算法作为交叉验证技术所需的步骤。
第 1 步:加载数据集和所需的包
作为第一步,R 环境必须加载所有必要的包和库以执行各种操作。下面是导入所有必需库的代码。
R
# loading required packages
# package to perform data manipulation
# and visualization
library(tidyverse)
# package to compute
# cross - validation methods
library(caret)
第 2 步:加载和检查数据集
导入所有包后,就该加载所需的数据集了。这里的“trees”数据集用于回归模型,它是 R 语言的内置数据集。此外,为了建立正确的模型,有必要了解数据集的结构。所有这些任务都可以使用以下代码执行。
R
# access the data from R’s datasets package
data(trees)
# look at the first several rows of the data
head(trees)
输出:
Girth Height Volume
1 8.3 70 10.3
2 8.6 65 10.3
3 8.8 63 10.2
4 10.5 72 16.4
5 10.7 81 18.8
6 10.8 83 19.7
第 3 步:使用重复 K 折算法构建模型
trainControl()函数被定义为设置重复次数和 K 参数的值。之后,按照重复 K 折算法中涉及的步骤开发模型。下面是实现。
R
# setting seed to generate a
# reproducible random sampling
set.seed(125)
# defining training control as
# repeated cross-validation and
# value of K is 10 and repetition is 3 times
train_control <- trainControl(method = "repeatedcv",
number = 10, repeats = 3)
# training the model by assigning sales column
# as target variable and rest other column
# as independent variable
model <- train(Volume ~., data = trees,
method = "lm",
trControl = train_control)
第 4 步:评估模型的准确性
根据重复 K-fold 技术的算法,该模型针对数据集的每个唯一折叠(或子集)进行测试,并且在每种情况下,计算预测误差,最后将所有预测误差的平均值视为最终结果模型的性能得分。因此,下面是打印模型的最终分数和总体摘要的代码。
R
# printing model performance metrics
# along with other details
print(model)
输出:
Linear Regression
31 samples
2 predictor
No pre-processing
Resampling: Cross-Validated (10 fold, repeated 3 times)
Summary of sample sizes: 28, 28, 28, 29, 28, 28, ...
Resampling results:
RMSE Rsquared MAE
4.021691 0.957571 3.362063
Tuning parameter 'intercept' was held constant at a value of TRUE
重复 K 折交叉验证的优点
- 一种非常有效的方法来估计模型的预测误差和准确性。
- 在每次重复中,数据样本都会被打乱,从而导致样本数据出现不同的拆分。
重复 K 折交叉验证的缺点
- 较低的 K 值会导致模型有偏差,而较高的 K 值会导致模型的性能指标发生变化。因此,必须为模型使用正确的 K 值(通常 K = 5 和 K = 10 是可取的)。
- 每次重复时,算法都必须从头开始训练模型,这意味着评估模型的计算时间会随着重复次数的增加而增加。