机器学习模型调优的重要性不容小觑,它是将模型从“能用”提升到“好用”、“可靠”甚至“卓越”的关键步骤。GridSearchCV 是 Scikit-learn 库中用于超参数调优的核心工具,它通过交叉验证和网格搜索的组合,自动寻找机器学习模型的最佳超参数组合。本文以KNN算法和鸢尾花数据集为例向大家介绍相关的使用方法。

一.核心功能

基本说明

网格搜索(Grid Search):遍历所有预定义的超参数组合(穷举搜索)

交叉验证(Cross-Validation):对每一组参数,使用交叉验证评估模型性能(如 5 折交叉验证),避免过拟合,确保参数泛化性。

最优模型选择

  • 自动选择最佳配置:返回交叉验证得分最高的超参数组合。

  • 返回训练好的模型:通过 best_estimator_ 直接获取最优模型,无需重新训练。

核心优势

优势 说明
全面性 检查所有可能的参数组合,确保不遗漏最优解。
客观性 基于数据驱动的评估(交叉验证),减少人工主观偏差。
一站式解决方案 集成参数搜索、模型训练、验证和最优模型选择,简化工作流。
结果可解释性 提供详细的搜索结果(如 cv_results_),支持深度分析不同参数的影响。
兼容性 支持所有 Scikit-learn 的评估指标(如 scoring='roc_auc'),适配分类/回归任务。

二.关键参数

参数 说明
estimator 模型对象(如 RandomForestClassifier()
param_grid 超参数网格(字典或字典列表)
cv 交叉验证策略(整数=折数,默认 5 折)
scoring 评估指标(如 'accuracy''f1''roc_auc'
n_jobs 并行任务数(-1 使用所有 CPU 核心)
verbose 输出详细程度(值越大输出越多)
refit 是否用最佳参数在整个数据集上重训练(默认 True

1.estimator:要优化的机器学习模型对象

from sklearn.ensemble import RandomForestClassifier
estimator = RandomForestClassifier(random_state=42)

2.param_grid:定义要搜索的超参数空间

字典类型:键为参数名,值为要尝试的参数值列表

支持复杂搜索策略:

  • 单字典:

param_grid = {'C': [0.1, 1, 10], 'kernel': ['linear', 'rbf']}
  • 字典列表:允许不同参数组合(适合条件参数)

param_grid = [
    {'kernel': ['linear'], 'C': [0.1, 1, 10]},
    {'kernel': ['rbf'], 'C': [1, 10, 100], 'gamma': [0.01, 0.1, 1]}
]

注意:

  • 参数名必须与估计器中的参数名完全一致

  • 参数值可以是各种类型(数值、字符串、布尔值等)

3.cv:交叉验证策略

  • 整数:指定K折交叉验证的折数(默认5)

  • 可迭代对象:生成训练/测试集索引的自定义交叉验证器

  • 支持多种交叉验证方法:

    • KFold:标准K折交叉验证

    • StratifiedKFold:分层K折(保持类别比例)

    • TimeSeriesSplit:时间序列数据的交叉验证

from sklearn.model_selection import StratifiedKFold

# 使用分层5折交叉验证
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

4. refit:找到最佳参数后是否重新训练模型

说明

  • True(默认):使用整个数据集和最佳参数重新训练模型

  • False:不重新训练,只进行参数搜索

  • 字符串:当使用多个评分指标时,指定使用哪个指标重新训练

重要提示

  • 设置为 True 时,可以通过 best_estimator_ 获取训练好的模型

  • 如果后续需要预测,必须设置为 True

5.n_jobs:并行计算设置

说明:

  • 整数:指定并行使用的CPU核心数

  • -1:使用所有可用核心

  • None:不使用并行(默认)

注意:

  • 但可能增加内存消耗,特别是大数据集时

  • 并行化可以显著加速搜索过程

6.scoring:模型评估指标

指标 参数值 说明
准确率 'accuracy' 正确预测的比例
精确率 'precision' 正例预测的准确度
召回率 'recall' 正例被正确识别的比例
F1分数 'f1' 精确率和召回率的调和平均
AUC-ROC 'roc_auc' ROC曲线下面积

字符串形式(内置指标):

# 分类问题
GridSearchCV(scoring='accuracy')

# 回归问题
GridSearchCV(scoring='neg_mean_squared_error')

 可调用对象(自定义指标):

from sklearn.metrics import make_scorer

def custom_metric(y_true, y_pred):
    return ...  # 自定义计算逻辑

custom_scorer = make_scorer(custom_metric, greater_is_better=True)
GridSearchCV(scoring=custom_scorer)

多指标评估 :

scoring = {
    'accuracy': 'accuracy',
    'precision': 'precision_macro',
    'recall': 'recall_macro',
    'f1': 'f1_macro'
}
GridSearchCV(scoring=scoring, refit='accuracy')  # 用accuracy选择最佳模型

 

三.结果获取

属性/方法 说明
best_params_ 最佳超参数组合
best_score_ 最佳交叉验证平均得分
best_estimator_ 使用最佳参数训练的模型(若 refit=True
cv_results_ 所有参数组合的详细结果(可转 DataFrame 分析)

1.best_params_:以字典返回最佳参数组合

{'C': 10, 'gamma': 0.1, 'kernel': 'rbf'}

2.best_score_:最佳参数组合的平均交叉验证得分

  • 基于交叉验证,不是在整个数据集上的得分

  • 使用 scoring 参数指定的评估指标

3.best_estimator_:使用最佳参数在整个数据集上重新训练的模型,仅在 refit=True 时可用,可直接用于预测

best_model = grid_search.best_estimator_
predictions = best_model.predict(X_test)

4. cv_results_: 包含所有参数组合详细结果的字典

重要键

  • mean_test_score:各参数组合的平均验证得分

  • std_test_score:各参数组合得分的标准差

  • params:所有尝试的参数组合

  • rank_test_score:各参数组合的排名(1为最佳)

import pandas as pd
results = pd.DataFrame(grid_search.cv_results_)
results.sort_values(by='rank_test_score').head()

四.用法示例

from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV

# 获取数据集
iris = load_iris()

# 数据基本处理
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)

# 特征工程
# 1.实例化一个转化器
transfer = StandardScaler()
# 2.调用fit_transform,转换数据
x_train = transfer.fit_transform(x_train)
x_test = transfer.fit_transform(x_test)

# 机器学习(模型选择,模型训练)
# 1.实例化一个训练模型
estimator=KNeighborsClassifier()

param_grid = {"n_neighbors":[3,5,7,9]}
estimator = GridSearchCV(estimator,param_grid,cv=10,n_jobs=-1)

# 2.模型训练
estimator.fit(x_train,y_train)

# 模型评估
# y_predict=estimator.predict(x_test)
# print("预测结果为:",y_predict==y_test)
# print("预测结果为:",y_predict)

score=estimator.score(x_test,y_test)
print("模型的准确率为:",score)

# 模型的参数
print("模型的最佳参数为:",estimator.best_params_)
print("模型最佳结果为:",estimator.best_score_)
print("模型在交叉验证的结果为:")
for i in estimator.cv_results_:
    print("参数:",i)
    print(estimator.cv_results_[ i])
Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐