详解 Scikit-learn 的 datasets.fetch_california_housing函数:加载加利福尼亚房价数据集

  • Post category:Python

fetch_california_housing() 是 Scikit-learn 内置的一个函数,用于从互联网下载加州住房价格数据集(California Housing dataset),并将其作为 Bunch 对象返回。

作用

fetch_california_housing() 的作用是获取加州住房价格数据集,用于机器学习模型的训练和评估。

这个数据集是由加利福尼亚州住房局提供的,包含了加州 1990 年的普查数据,其中共有 20640 条数据,每个数据表示一个街区,有 8 个属性:1) MedInc(街区收入的中位数); 2) HouseAge(街区房屋年龄中位数); 3) AveRooms(街区平均房间数); 4) AveBedrms(街区平均卧室数); 5) Population(街区人口总数); 6) AveOccup(街区平均入住率); 7) Latitude(街区的纬度); 8) Longitude(街区的经度)。

这个数据集通常用于回归问题的机器学习任务,例如预测街区房价中位数。

使用方法

在使用之前,需要先导入 fetch_california_housing() 函数:

from sklearn.datasets import fetch_california_housing

然后,可以通过下面的代码来获取数据集:

california_housing = fetch_california_housing()

默认情况下,该函数会将数据集下载到当前目录(~/scikit_learn_data),并将数据集从原始数据格式转换为以 Bunch 对象表示。其中,Bunch 对象类似于字典对象,具有 .data.target 属性。其中,.data 是数据矩阵,.target 是目标值。

此外,还可以通过设置参数 return_X_y=True 来直接返回数据矩阵和目标值,如下所示:

data, target = fetch_california_housing(return_X_y=True)

这样就可以将获取到的数据集作为数据矩阵和目标值拆分开来,方便后续的模型训练和评估。

实例说明

下面是两个实例说明 fetch_california_housing() 的使用方法。

实例1:将数据集保存到本地

import numpy as np
from sklearn.datasets import fetch_california_housing

# 设置数据集保存路径
data_path = './california_housing/'

# 获取数据集并保存到本地
california_housing = fetch_california_housing(data_home=data_path, download_if_missing=True)
data = np.hstack([california_housing.data, california_housing.target.reshape(-1, 1)])
np.savetxt(data_path + 'california_housing.csv', data, delimiter=',')

这个例子演示了如何将从 fetch_california_housing() 获取的数据集保存到本地。

首先,设置数据集保存路径 data_path,然后将其作为参数传递给 fetch_california_housing() 函数。在函数中,将参数 download_if_missing=True 设置为 True,表示如果本地已经存在该数据集,则不再下载,否则会从互联网上下载该数据集到本地。随后,将数据集保存到本地文件 california_housing.csv 中。

实例2:使用交叉验证评估1个线性回归模型

from sklearn.datasets import fetch_california_housing
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score

# 获取数据集
california_housing = fetch_california_housing()

# 创建模型对象
lr = LinearRegression()

# 进行交叉验证评估
scores = cross_val_score(lr, california_housing.data, california_housing.target, cv=10, scoring='neg_mean_squared_error')
print('scores:', scores)
print('mean score:', scores.mean())

这个例子演示了如何使用 fetch_california_housing() 函数获取数据集,并使用交叉验证评估一个线性回归模型。

首先,获取数据集。随后,创建一个线性回归模型 lr。接着,使用 cross_val_score() 函数对该模型进行交叉验证评估。在 cross_val_score() 函数中,将参数 cv 设置为 10,表示进行 10 折交叉验证;将参数 scoring 设置为 neg_mean_squared_error,表示评估指标为均方误差(MSE)的负数,即越小越好。最后,打印出所有折验证分的得分以及平均得分。