sklearn.model_selection.GridSearchCV

GridSearchCVsklearn.model_selection 提供的 超参数优化工具,用于 遍历所有可能的超参数组合,通过 交叉验证 选择 最佳超参数


1. GridSearchCV 作用

  • 自动搜索最佳超参数组合,提高模型性能。
  • 使用交叉验证(默认 cv=5 评估不同超参数的效果。
  • 适用于分类和回归任务,支持 不同评分指标accuracyf1roc_aucr2 等)。

2. GridSearchCV 代码示例

(1) 超参数优化 SVM

from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.datasets import load_iris

# 加载数据
iris = load_iris()
X, y = iris.data, iris.target

# 设定超参数搜索空间
param_grid = {
    "C": [0.1, 1, 10],  # 正则化参数
    "kernel": ["linear", "rbf"]  # 选择不同核函数
}

# 初始化 SVM 并进行网格搜索
grid_search = GridSearchCV(SVC(), param_grid, cv=5)
grid_search.fit(X, y)

print("最佳参数:", grid_search.best_params_)
print("最佳得分:", grid_search.best_score_)

输出

最佳参数: {'C': 1, 'kernel': 'linear'}
最佳得分: 0.98

解释

  • GridSearchCV 遍历所有可能的 Ckernel 组合,选择最优超参数。

(2) 使用 StratifiedKFold 进行分层交叉验证

from sklearn.model_selection import StratifiedKFold

cv = StratifiedKFold(n_splits=5)
grid_search = GridSearchCV(SVC(), param_grid, cv=cv)
grid_search.fit(X, y)

print("最佳参数:", grid_search.best_params_)

解释

  • 适用于类别不均衡数据,保证交叉验证每折类别比例一致。

(3) 选择不同评分指标

grid_search = GridSearchCV(SVC(), param_grid, cv=5, scoring="accuracy")
grid_search.fit(X, y)
print("最佳参数:", grid_search.best_params_)

可选评分指标

任务 评分指标 (scoring) 说明
分类 "accuracy" 准确率
分类 "f1" F1-score
分类 "roc_auc" ROC AUC
回归 "r2" R²(决定系数)
回归 "neg_mean_absolute_error" 负 MAE
回归 "neg_mean_squared_error" 负 MSE

示例:

grid_search = GridSearchCV(SVC(), param_grid, cv=5, scoring="f1_macro")
grid_search.fit(X, y)
print("最佳参数:", grid_search.best_params_)

(4) 训练多个模型

from sklearn.ensemble import RandomForestClassifier

param_grid = {
    "n_estimators": [10, 50, 100],
    "max_depth": [None, 10, 20]
}

grid_search = GridSearchCV(RandomForestClassifier(), param_grid, cv=5)
grid_search.fit(X, y)

print("最佳参数:", grid_search.best_params_)

解释

  • GridSearchCV 可用于多种模型,如 RandomForestClassifier

3. GridSearchCV 的参数

GridSearchCV(estimator, param_grid, scoring=None, cv=None, n_jobs=None, verbose=0)
参数 说明
estimator 评估器(模型),如 SVC()
param_grid 需要搜索的超参数
scoring 评分指标(如 "accuracy""f1""roc_auc""r2"
cv 交叉验证策略(默认 5,可传 KFold()StratifiedKFold()
n_jobs 并行计算(-1 表示使用所有 CPU 核心)
verbose 是否打印搜索过程(0=不输出,1=简单输出,2=详细输出)

4. 适用场景

  • 超参数优化,提高模型性能
  • 分类/回归任务的模型调优
  • 结合 StratifiedKFold 处理类别不均衡数据

5. GridSearchCV vs. RandomizedSearchCV vs. train_test_split

方法 适用情况 作用
GridSearchCV 参数范围较小,计算量可控 遍历所有参数组合
RandomizedSearchCV 参数范围较大 随机选择部分参数搜索
train_test_split 简单模型训练 训练集/测试集划分

示例:

from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import uniform

# 设定参数分布
param_dist = {'C': uniform(0.1, 10), 'kernel': ['linear', 'rbf']}

# 进行随机搜索
random_search = RandomizedSearchCV(SVC(), param_dist, n_iter=5, cv=5, random_state=42)
random_search.fit(X, y)

print("最佳参数:", random_search.best_params_)

解释

  • RandomizedSearchCV 随机选取参数组合,适用于 大参数空间

6. 结论

  • GridSearchCV 遍历所有超参数组合,通过 交叉验证选择最佳参数,适用于 分类和回归任务
  • 如果参数空间 较大,可使用 RandomizedSearchCV 进行随机搜索
  • 如果数据 类别不均衡,应结合 StratifiedKFold 进行分层交叉验证
Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐