详解 Scikit-learn 的 model_selection.GridSearchCV函数:网格搜索超参数

  • Post category:Python

Scikit-learn 是一个用于机器学习的 Python 库,提供了大量的算法和工具,可以帮助我们快速地完成各种机器学习任务。其中,sklearn.model_selection.GridSearchCV 函数是一个重要的工具,它能够在指定的参数空间内进行网格搜索,从而寻找最优的参数组合,提供了一种简洁而易于使用的调参方式。

1. GridSearchCV函数的作用:

sklearn.model_selection.GridSearchCV 函数的作用是通过网格搜索来自动调整算法的超参数,例如我们可以使用 GridSearchCV 函数找到最优的模型参数,从而让模型拥有更高的泛化能力和更好的性能。在实际使用中,我们常常使用交叉验证的方式来避免过拟合。

2. GridSearchCV函数的使用方法:

  • 导入相关库和数据集
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV, train_test_split

# 载入鸢尾花数据集
iris = load_iris()
X, y = iris.data, iris.target
  • 将数据集分成训练集和测试集(可选)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
  • 定义网格搜索的超参数空间,以决策树分类器为例
param_grid = {'max_depth': [3, 4, 5, 6], 'min_samples_leaf': [1, 2, 3]}
  • 实例化分类器和网格搜索对象
dtc = DecisionTreeClassifier()
grid_search = GridSearchCV(dtc, param_grid, cv=10)
  • 对训练集进行训练,寻找最优的参数组合
grid_search.fit(X_train, y_train)
  • 查看最优参数和得分
print('Best parameters:', grid_search.best_params_)
print('Best score:', grid_search.best_score_)

例如,我们使用上述代码对鸢尾花数据集进行分类,得到的最优参数为{‘max_depth’: 3, ‘min_samples_leaf’: 3},最优得分为0.952。

3. 实例分析

下面是两个实例分析,说明了如何使用 GridSearchCV 函数进行超参数的调优。

实例1:K近邻算法的超参数调优

from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV, train_test_split

# 载入鸢尾花数据集
iris = load_iris()
X, y = iris.data, iris.target

# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

# 定义超参数空间
param_grid = {'n_neighbors': [3, 5, 7], 'weights': ['uniform', 'distance']}

# 实例化分类器和网格搜索对象
knn = KNeighborsClassifier()
grid_search = GridSearchCV(knn, param_grid, cv=10)

# 对训练集进行训练,寻找最优的参数组合
grid_search.fit(X_train, y_train)

# 打印最优参数和最优得分
print('Best parameters:', grid_search.best_params_)
print('Best score:', grid_search.best_score_)

输出结果为:
Best parameters: {‘n_neighbors’: 5, ‘weights’: ‘uniform’}
Best score: 0.9619047619047618

实例2:SVM算法的超参数调优

from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, train_test_split

# 载入鸢尾花数据集
iris = load_iris()
X, y = iris.data, iris.target

# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

# 定义超参数空间
param_grid = {'C': [0.1, 1, 10], 'kernel': ['linear', 'rbf', 'poly']}

# 实例化分类器和网格搜索对象
svc = SVC()
grid_search = GridSearchCV(svc, param_grid, cv=10)

# 对训练集进行训练,寻找最优的参数组合
grid_search.fit(X_train, y_train)

# 打印最优参数和最优得分
print('Best parameters:', grid_search.best_params_)
print('Best score:', grid_search.best_score_)

输出结果为:
Best parameters: {‘C’: 1, ‘kernel’: ‘linear’}
Best score: 0.980952380952381