Physical Address
304 North Cardinal St.
Dorchester Center, MA 02124
Physical Address
304 North Cardinal St.
Dorchester Center, MA 02124
机器学习实战系列:首篇聚焦项目目标和数据质量评估的关键步骤。
在这个系列中,我们将以 Kaggle 上的 Binary Classification with a Tabular Employee Attrition Dataset 项目为例,逐步介绍机器学习的基础知识和实践应用,包括数据预处理、特征工程、模型选择与调优等核心环节。
该数据集来源于广为流传的 IBM HR Analytics Employee Attrition & Performance 数据集,主要处理一个员工离职预测的二分类问题,我们的目标是构建模型来预测员工是否会离职。
在当今企业环境中,员工流失是一个普遍问题,它显著影响着公司的运营效率和成本。通过对员工特征(如年龄、教育背景、职位)和工作行为(如出差频率、工作满意度、加班情况)进行数字化分析,我们可以找出导致员工流失的潜在因素。
如果能实时获取这些数据,我们就可以利用模型预测员工流失风险,从而及时制定挽留策略,预防实际的流失情况。
基于此背景,我们的算法建模有两个主要目标:
综上所述,我们需要的模型不仅要有强大的预测能力,还要能输出特征重要性排名,并具备充分的可解释性——即能清楚地说明特征变化如何影响员工的离职倾向。
在开始分析之前,我们需要先理解每个数据字段的具体含义。下面列出了数据集中的主要字段及其解释:
字段名称 | 含义 |
---|---|
Age | 年龄 |
Attrition | 离职情况(是否离职) |
BusinessTravel | 出差频率 |
DailyRate | 每日工资 |
Department | 部门 |
DistanceFromHome | 离家距离 |
Education | 教育水平 |
EducationField | 教育领域 |
EmployeeCount | 员工计数 |
id | 员工编号 |
EnvironmentSatisfaction | 环境满意度 |
Gender | 性别 |
HourlyRate | 时薪 |
JobInvolvement | 工作投入度 |
JobLevel | 工作级别 |
JobRole | 职位 |
JobSatisfaction | 工作满意度 |
MaritalStatus | 婚姻状况 |
MonthlyIncome | 月收入 |
MonthlyRate | 月工资 |
NumCompaniesWorked | 曾工作过的公司数量 |
Over18 | 是否成年 |
OverTime | 加班情况 |
PercentSalaryHike | 加薪百分比 |
PerformanceRating | 绩效评级 |
RelationshipSatisfaction | 人际关系满意度 |
StandardHours | 标准工时 |
StockOptionLevel | 股票期权级别 |
TotalWorkingYears | 总工作年限 |
TrainingTimesLastYear | 去年培训次数 |
WorkLifeBalance | 工作与生活的平衡 |
YearsAtCompany | 在公司的年限 |
YearsInCurrentRole | 当前职位的年限 |
YearsSinceLastPromotion | 自上次晋升后的年限 |
YearsWithCurrManager | 与现任经理共事的年限 |
“Garbage In, Garbage Out”是数据科学领域的基本准则。在开展复杂的特征工程和模型训练前,全面检查数据质量是不可或缺的环节。这能帮助我们及早发现并解决潜在问题,避免后续工作的浪费和模型结果的偏差。
数据质量检查主要涵盖以下任务:缺失值分析、唯一值分析、重复值检测、数据类型核对和异常值识别等。
首先,让我们了解数据集的规模。
print(f"训练集数据结构: {train_df.shape}")
print(f"测试集数据结构: {test_df.shape}")
训练集数据结构: (1677 行, 35 列)
测试集数据结构: (1119 行, 34 列) (注:测试集通常不包含目标变量 Attrition)
为了提高效率,我们将对每个字段的缺失情况、唯一值数量和数据类型进行一次性检查。
我在这里构建了一个数据特征分析器,它能同时生成每个数据列的缺失值数量及比例、唯一值数量和数据类型信息。这种自动化的检查方法可以帮我们快速识别数据中的潜在问题。
import pandas as pd
from tabulate import tabulate
def column_info(df, sort_by='default'):
"""
计算并输出每一列的缺失值、唯一值和数据类型。
参数:
df: 数据集(DataFrame)
sort_by: 排序依据,接受的值包括 'missing', 'unique' 或 'default'(默认顺序)
"""
# 创建列信息表,直接计算缺失值和唯一值
column_info_df = pd.DataFrame({
'Missing_Count': df.isnull().sum(),
'Missing_Percent': df.isnull().mean() * 100, # 使用mean()直接得到百分比
'Unique_Count': df.nunique(),
'Data_Type': df.dtypes
})
# 根据指定的排序依据进行排序
if sort_by == 'missing':
column_info_df.sort_values(by='Missing_Count', ascending=False, inplace=True)
elif sort_by == 'unique':
column_info_df.sort_values(by='Unique_Count', ascending=False, inplace=True)
# 输出列信息表格
print("Column Information:")
print(tabulate(column_info_df, headers='keys', tablefmt='pretty'))
# 查看训练集数据质量
column_info(train_df, sort_by='unique')
Column Information:
+--------------------------+---------------+-----------------+--------------+-----------+
| | Missing_Count | Missing_Percent | Unique_Count | Data_Type |
+--------------------------+---------------+-----------------+--------------+-----------+
| id | 0 | 0.0 | 1677 | int64 |
| MonthlyRate | 0 | 0.0 | 903 | int64 |
| MonthlyIncome | 0 | 0.0 | 895 | int64 |
| DailyRate | 0 | 0.0 | 625 | int64 |
| HourlyRate | 0 | 0.0 | 71 | int64 |
| Age | 0 | 0.0 | 43 | int64 |
| TotalWorkingYears | 0 | 0.0 | 41 | int64 |
| YearsAtCompany | 0 | 0.0 | 34 | int64 |
| DistanceFromHome | 0 | 0.0 | 29 | int64 |
| YearsInCurrentRole | 0 | 0.0 | 19 | int64 |
| YearsWithCurrManager | 0 | 0.0 | 18 | int64 |
| YearsSinceLastPromotion | 0 | 0.0 | 16 | int64 |
| PercentSalaryHike | 0 | 0.0 | 15 | int64 |
| NumCompaniesWorked | 0 | 0.0 | 10 | int64 |
| JobRole | 0 | 0.0 | 9 | object |
| TrainingTimesLastYear | 0 | 0.0 | 7 | int64 |
| Education | 0 | 0.0 | 6 | int64 |
| JobLevel | 0 | 0.0 | 6 | int64 |
| EducationField | 0 | 0.0 | 6 | object |
| JobSatisfaction | 0 | 0.0 | 4 | int64 |
| EnvironmentSatisfaction | 0 | 0.0 | 4 | int64 |
| WorkLifeBalance | 0 | 0.0 | 4 | int64 |
| StockOptionLevel | 0 | 0.0 | 4 | int64 |
| RelationshipSatisfaction | 0 | 0.0 | 4 | int64 |
| JobInvolvement | 0 | 0.0 | 4 | int64 |
| Department | 0 | 0.0 | 3 | object |
| BusinessTravel | 0 | 0.0 | 3 | object |
| MaritalStatus | 0 | 0.0 | 3 | object |
| PerformanceRating | 0 | 0.0 | 2 | int64 |
| Gender | 0 | 0.0 | 2 | object |
| OverTime | 0 | 0.0 | 2 | object |
| Attrition | 0 | 0.0 | 2 | int64 |
| StandardHours | 0 | 0.0 | 1 | int64 |
| Over18 | 0 | 0.0 | 1 | object |
| EmployeeCount | 0 | 0.0 | 1 | int64 |
+--------------------------+---------------+-----------------+--------------+-----------+
关键发现:
StandardHours
、Over18
和 EmployeeCount
这三个字段的唯一值均为 1。这表明它们在所有样本中取值相同,对预测员工流失没有任何帮助。这些常量列应在建模前移除。id
列的唯一值数量与样本总数相同(1677),说明它仅作为数据的唯一标识符。由于不含预测信息,也应在建模前移除。int64
、object
)是否符合预期,这将有助于后续的数据处理(如确定哪些需要编码)。接下来,我们检查数据中是否存在重复记录。
下面的函数通过计算数据集中重复记录的数量及其占总记录的比例,帮助我们评估数据的冗余程度。
def duplicate_info(df):
"""
计算并输出数据集的重复值统计。
参数:
df: 数据集(DataFrame)
"""
total_count = len(df)
# 计算数据集的重复值及占比
duplicate_count = df.duplicated().sum()
duplicate_percent = (duplicate_count / total_count) * 100
duplicate_info_df = pd.DataFrame({
'Duplicate_Count': [duplicate_count],
'Duplicate_Percent': [duplicate_percent]
})
# 输出重复值统计
print("Duplicate Rows Information:")
print(tabulate(duplicate_info_df, headers='keys', tablefmt='psql'))
# 检验训练集中是否存在重复数据
duplicate_info(train_df)
Duplicate Rows Information:
+----+-------------------+---------------------+
| | Duplicate_Count | Duplicate_Percent |
|----+-------------------+---------------------|
| 0 | 0 | 0 |
+----+-------------------+---------------------+
关键发现:
数据质量检查不应仅限于训练集。在模型实际部署时,新数据(或测试集)可能出现训练阶段未见过的离散值(Categorical Values),这会影响模型的预测效果。
我们必须检查测试集中的离散字段取值是否都落在训练集的取值范围内。为此,我们实现了一个类别取值比对工具,通过对比训练集和测试集的类别取值来识别测试集中的新值。
def check_categorical_values(train_df, test_df, categorical_cols):
"""
检查测试数据集中指定的离散字段取值是否都包含在训练数据集中。
参数:
categorical_cols (list): 要检查的离散字段列表
"""
issues = {}
for col in categorical_cols:
train_values = set(train_df[col].unique())
test_values = set(test_df[col].unique())
if not test_values.issubset(train_values):
issues[col] = test_values - train_values
if issues:
print("以下离散字段在测试集中包含训练集中未出现的取值:")
for col, values in issues.items():
print(f"{col}: {values}")
else:
print("测试集中所有离散字段的取值都包含在训练集中。")
check_categorical_values(train_df, test_df, categorical_cols)
以下离散字段在测试集中包含训练集中未出现的取值:
EnvironmentSatisfaction: {0}
JobInvolvement: {0}
StockOptionLevel: {4}
关键发现:
EnvironmentSatisfaction
、JobInvolvement
和 StockOptionLevel
字段出现了训练集中未见过的值。异常值(Outliers)是指远离数据主体分布的极端值。这些值可能源于录入错误,也可能反映真实的极端情况。无论成因如何,异常值都可能对某些模型(尤其是对距离或方差敏感的模型)产生过度影响,降低模型性能和稳定性。
我们通常采用两步法进行异常值分析:先用统计方法(如三倍标准差法)进行初步筛选,再结合可视化工具(箱线图、直方图)和业务理解做出最终判断。
为了识别异常值,我们采用三倍标准差法(即认为数据超出平均值±3倍标准差范围的为异常值)。
我们定义了一个函数,它有两个输出:一个用于显示每个字段的异常值数量,另一个返回包含异常值的字段名称列表,便于后续的数据可视化分析。
def find_outliers(df, cols=None):
"""
利用三倍标准差法来识别数据集中每列的异常值。如果没有传入列列表,则自动检测所有数值型字段。
参数:
df (pd.DataFrame): 要分析的数据集。
cols (list, optional): 要分析的字段列表。如果没有提供,则自动检测所有数值型列。
返回:
tuple: 包含两个元素的元组:
- outlier_counts (dict): 每个列的异常值数量。
- outlier_columns (list): 存在异常值的字段名称列表。
"""
if cols is None:
cols = df.select_dtypes(include=['int64', 'float64']).columns
outlier_counts = {}
outlier_columns = []
for col in cols:
mean = df[col].mean()
std = df[col].std()
lower_bound = mean - 3 * std
upper_bound = mean + 3 * std
outliers = df[(df[col] < lower_bound) | (df[col] > upper_bound)]
outlier_count = outliers.shape[0]
if outlier_count > 0:
outlier_columns.append(col)
outlier_counts[col] = outlier_count
return outlier_counts, outlier_columns
# 检查异常值
outlier_counts, outlier_columns = find_outliers(train_df)
# 打印每个字段的异常值数量
for col, count in outlier_counts.items():
if count > 0:
print(f"字段: {col} - 异常值数量: {count}")
字段: DailyRate - 异常值数量: 1
字段: Education - 异常值数量: 1
字段: JobLevel - 异常值数量: 1
字段: MonthlyIncome - 异常值数量: 9
字段: TotalWorkingYears - 异常值数量: 21
字段: YearsAtCompany - 异常值数量: 26
字段: YearsInCurrentRole - 异常值数量: 21
字段: YearsSinceLastPromotion - 异常值数量: 34
字段: YearsWithCurrManager - 异常值数量: 16
注意:
统计方法识别的”异常值”并非都需要处理。例如,高收入或长工龄可能是合理的极端情况。
对筛选出的可疑字段,我们需要通过箱线图和直方图进行可视化检查。
def visualize_outliers(df, cols):
"""
使用箱线图和直方图可视化数据集中可能存在异常值的字段。
参数:
df (pd.DataFrame): 要分析的数据集。
cols (list): 要可视化的列的列表。
"""
for col in cols:
plt.rcParams['axes.facecolor'] = '#F4F2F0'
plt.rcParams['figure.facecolor'] = '#F4F2F0'
plt.rcParams['axes.grid'] = False
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
# 绘制箱线图
sns.boxplot(y=df[col], color='#91ABB4', ax=axes[0])
axes[0].set_title(f'Boxplot of {col}')
# 绘制直方图
sns.histplot(df[col], bins=30, kde=True, color='#214D5C', ax=axes[1])
axes[1].set_title(f'Histogram of {col}')
plt.tight_layout()
plt.show()
# 可视化存在异常值的列
visualize_outliers(train_df, outlier_columns)
结合图表分析:
DailyRate
:箱线图显示一个异常高的值(接近4000),这可能是录入错误或特殊情况,需要重点关注。Education
:出现值为15的异常点,而其他编码都在1-5范围内,这很可能是异常值。JobLevel
:出现值为7的点,如果职级体系中不存在7级,则属于异常值。MonthlyIncome
、TotalWorkingYears
、YearsAtCompany
等字段虽有统计意义上的”异常点”(如特高收入、特长工龄),但这些可能是正常的长尾分布。处理这类值时,应结合业务逻辑(如收入是否合理)和对模型的影响来决定。处理策略考量: 对确认的异常值,我们有以下处理方法:
具体采用哪种方法,需要根据异常值的性质、数量及模型对异常值的敏感程度来决定。
本文作为系列的第一篇,重点介绍了启动员工流失预测项目的两个核心前期步骤:
这些基础工作是确保后续模型可靠性和有效性的关键保障。只有建立在高质量且充分理解的数据之上,模型才能真正发挥其预测和洞察的价值。
在下一篇文章中,我们将进入特征工程阶段,基于对数据和业务的深入理解,创建、转换和选择特征,为模型训练做好充分准备。