数据挖掘与R语言

第15讲:决策树与回归树——分类决策树

2026年05月20日

上讲回顾

  • 多变量逻辑回归:右边线性组合加入多个自变量,每个系数是偏效应
  • step() 逐步回归:按 AIC 自动加/删变量,推荐 direction = "both"
  • 预测五步predict(type="response")ifelse 截断 → 混淆矩阵 → 指标解读
  • 评估指标:准确率、灵敏度、特异度、Kappa,截断点选择是业务决策
  • 逻辑回归本质上是线性分类器。今天学习决策树——一种基于规则划分的非线性方法。

本讲内容

  • Part 1:决策树基本概念 ——生活例子、树结构与核心术语
  • Part 2:基尼指数(不纯度) ——如何量化"纯度"、手工计算示例
  • Part 3:分裂过程 ——递归分裂的完整逻辑
  • Part 4:过拟合与剪枝 ——cp 参数的作用
  • Part 5:Titanic 实战演练 ——完整 R 操作流程
  • Part 6:你必须掌握什么 ——学习路线图

Part 1:决策树基本概念

像流程图一样做决策

生活中的决策树

我们每天都在"用决策树"做判断——只是没有意识到:

例:今天要不要带伞?

今天会下雨吗?
├─ 是 ──→ 带伞 ☂️
└─ 否 ──→ 出门有太阳吗?
           ├─ 是 ──→ 带防晒 🧴
           └─ 否 ──→ 什么都不带 😎

例:这封邮件是垃圾邮件吗?

包含"免费领取"字样?
├─ 是 ──→ 发件人是陌生人?
│          ├─ 是 ──→ 🗑️ 垃圾邮件
│          └─ 否 ──→ 📬 正常邮件
└─ 否 ──→ 📬 正常邮件

提示

核心思想: 每次提一个问题,把数据"分成两堆",直到每堆里的样本属于同一类别为止。

决策树的三个基本要素

  • 根节点:包含整个样本集
  • 内部节点:对应一次决策过程(一次属性测试)
  • 叶节点:对应一个决策结果,不能再分裂

分类决策树:因变量为因子变量

生成过程

  • 特征选择
  • 树生成
  • 树剪枝

与逻辑回归的对比

逻辑回归 决策树
决策边界 线性(直线/超平面) 非线性(轴对齐的矩形区域)
可解释性 系数需要解读 直接读树形规则
特征缩放 需要标准化 不需要
处理类别变量 需要虚拟变量 直接处理
过拟合风险 较低 较高,需剪枝
输出 概率值 类别标签(或概率)

. . .

注记

决策树最大的优势是可解释性——每一条规则都可以用自然语言表达。这在医疗诊断、信用评分等需要"说清楚为什么"的场景中极为重要。

Part 2:基尼指数(Gini index) (基尼不纯度)

如何用数字衡量"混乱程度"

什么是基尼指数?

直觉理解: 一个节点越"纯",基尼指数越小,里面的样本越整齐,预测就越容易。

三个装球的盒子:

盒子 内容 纯度
盒子 A 全是红球(10个红) 最纯
盒子 B 红球多,绿球少(8红2绿) 较纯
盒子 C 红绿各半(5红5绿) 最混

基尼不纯度(Gini Impurity) 把这个直觉量化:

\[\text{Gini}(t) = 1 - \sum_{k=1}^{K} p_k^2\]

其中 \(p_k\) 是节点中第 \(k\) 类样本的比例,\(K\) 是类别总数。

手工计算:三种情况

情况一:盒子 A(全是红球)

\(p_{\text{红}} = 1,\quad p_{\text{绿}} = 0\)

\[\text{Gini} = 1 - (1^2 + 0^2) = 1 - 1 = \mathbf{0} \quad \text{(最纯)}\]

情况二:盒子 C(红绿各半)

\(p_{\text{红}} = 0.5,\quad p_{\text{绿}} = 0.5\)

\[\text{Gini} = 1 - (0.5^2 + 0.5^2) = 1 - 0.5 = \mathbf{0.5} \quad \text{(最混)}\]

情况三:三类均匀分布(如 iris 三个品种各占 1/3,或三类等比例的多分类问题)

\(p_1 = p_2 = p_3 = 1/3\)

\[\text{Gini} = 1 - \left[\left(\frac{1}{3}\right)^2 \times 3\right] = 1 - \frac{1}{3} = \mathbf{0.667}\]

Gini 值的范围与含义

贷款数据:认识数据集

图片来源:教材表 5-1。共 10 条记录,变量如下:

变量 类型 取值
房产状况 类别 是 / 否
婚姻状况 类别 未婚 / 已婚
年收入 有序类别 差 / 良 / 优
类别(因变量) 类别 是(贷款)/ 否(不贷款)
▶️ 查看代码
# 10条贷款训练数据(与教材表5-1一致)
print(loan)
   id house marriage income label
1   1    是     未婚     良    否
2   2    否     已婚     良    否
3   3    否     未婚     差    否
4   4    是     已婚     良    否
5   5    否     未婚     良    否
6   6    否     已婚     差    否
7   7    是     已婚     优    否
8   8    否     未婚     良    是
9   9    否     已婚     良    否
10 10    否     未婚     良    是

贷款数据:根节点 Gini 计算

整个数据集(根节点,10条记录):

  • 类别"是"(贷款):2 条(序号 8、10)
  • 类别"否"(不贷款):8 条

\[p_{\text{是}} = \frac{2}{10} = 0.2, \quad p_{\text{否}} = \frac{8}{10} = 0.8\]

\[\text{Gini}_{\text{根}} = 1 - (0.2^2 + 0.8^2) = 1 - (0.04 + 0.64) = \mathbf{0.32}\]

用"房产状况"分裂:

子节点 内容 Gini
左(有房产) 序号 1、4、7 0 3 0.000
右(无房产) 序号 2、3、5、6、8、9、10 2 5 0.408

\[\text{Gini}_{\text{房产}} = \frac{3}{10}\times 0 + \frac{7}{10}\times 0.408 = \mathbf{0.286}\]

Gini 下降: \(0.32 - 0.286 = 0.034\)

贷款数据:比较三个特征的分裂效果

提示

结论: 房产状况的 Gini 下降量最大(0.034),因此被选为根节点的分裂变量。有房产 → 全部拒绝贷款(Gini=0,纯节点);无房产 → 继续分裂。

贷款数据:决策树结构(教材示例)

Part 3:分裂过程

递归分裂的完整逻辑

决策树的生长过程:以iris数据为例

  1. 从根节点出发:整个训练集放入根节点,计算 Gini = 0.667
  2. 寻找最佳分裂:遍历所有特征(Sepal.Length / Sepal.Width / Petal.Length / Petal.Width)的所有可能阈值,找使加权 Gini 下降最多的那个
  3. 执行分裂:将数据一分为二,生成两个子节点
  4. 递归重复:对每个子节点再次执行步骤 2–3
  5. 停止条件:节点已经"纯"(Gini = 0)、样本数太少、或 cp 参数阻止继续分裂

提示

为什么 Petal.Length 最先被选中?

花瓣长度(Petal.Length)能完美分离 setosa(≤ 2.45 cm)与另外两种,Gini 下降幅度最大(0.334),远超其他特征——所以它成为根节点的分裂变量。

Titanic 数据可视化:谁更容易幸存?

决策树的生长过程(以 Titanic 为例)

  1. 从根节点出发:全部乘客(约710人),幸存率约 38%,计算 Gini
  2. 寻找最佳分裂:遍历 Sex / Pclass / Age / Fare 的所有阈值,找Gini 下降最大的那个
  3. 第一次分裂Sex(性别)效果最强——男女幸存率差异最大,Gini 下降最多
  4. 递归重复:左子树(女性)再按 Pclass / Age 分裂;右子树(男性)再按 Age 分裂
  5. 停止条件:样本数太少,或 cp 参数阻止继续分裂

提示

为什么 Sex 最先被选中?

女性幸存率 ≈ 74%,男性幸存率 ≈ 19%,差异极大。用 Sex 分裂后,两个子节点的组内纯度大幅提升,Gini 下降最多——所以 Sex 成为根节点的分裂变量。

Part 4:过拟合与剪枝

cp 参数控制树的复杂度

树太深会怎样?

复杂度参数 cp 详解

重要

rpart 的分裂判断标准:

\[\text{只有当某次分裂能使相对误差下降} \geq cp \text{ 时,才执行该分裂}\]

cp 值 含义 树的形态
cp = 0.01(默认) 只要误差下降 ≥ 1%,就分裂 较深,节点多
cp = 0.1 误差下降 ≥ 10% 才分裂 中等深度
cp = 0.5 误差下降 ≥ 50% 才分裂 极浅,通常仅 1–2 层
cp = 0 任何分裂都执行 最深(完全生长)

选择 cp 的科学方法:

printcp() → 找 xerror(交叉验证误差)最小的行 → 取对应 CP 值 → prune(fit, cp = ?)

Part 5:Titanic 实战演练

用泰坦尼克号数据完整走一遍

数据介绍:Titanic 泰坦尼克号数据集

▶️ 查看代码
# 读入数据并清理
titanic_raw <- read.csv("titanic.csv")

titanic <- titanic_raw |>
  select(Survived, Pclass, Sex, Age, Fare, Embarked) |>
  mutate(
    Survived = factor(Survived, levels=c(0,1), labels=c("遇难","幸存")),
    Pclass   = factor(Pclass,   levels=c(1,2,3), labels=c("一等舱","二等舱","三等舱")),
    Sex      = factor(Sex)
  ) |>
  filter(!is.na(Age))

str(titanic)
'data.frame':   714 obs. of  6 variables:
 $ Survived: Factor w/ 2 levels "遇难","幸存": 1 2 2 2 1 1 1 2 2 2 ...
 $ Pclass  : Factor w/ 3 levels "一等舱","二等舱",..: 3 1 3 1 3 1 3 3 2 3 ...
 $ Sex     : Factor w/ 2 levels "female","male": 2 1 1 1 2 2 2 1 1 1 ...
 $ Age     : num  22 38 26 35 35 54 2 27 14 4 ...
 $ Fare    : num  7.25 71.28 7.92 53.1 8.05 ...
 $ Embarked: chr  "S" "C" "S" "S" ...

数据概览

▶️ 查看代码
# 因变量:幸存/遇难分布
table(titanic$Survived)

遇难 幸存 
 424  290 
▶️ 查看代码
prop.table(table(titanic$Survived)) |> round(3)

 遇难  幸存 
0.594 0.406 

步骤一:划分训练集与测试集

▶️ 查看代码
library(rsample)

set.seed(42)

# initial_split() 按 Survived 分层,70% 训练,30% 测试
split    <- initial_split(titanic, prop = 0.7, strata = "Survived")
train_df <- training(split)
test_df  <- testing(split)
▶️ 查看代码
# 验证幸存比例是否均衡(分层抽样效果)
prop.table(table(train_df$Survived)) |> round(3)

 遇难  幸存 
0.593 0.407 
▶️ 查看代码
prop.table(table(test_df$Survived))  |> round(3)

 遇难  幸存 
0.595 0.405 

提示

rsample::initial_split() 是 tidymodels 生态的数据划分函数,strata 参数实现分层抽样,确保训练集和测试集中幸存比例一致,比 sample() 随机划分更稳健。

步骤二:建立分类决策树

▶️ 查看代码
library(rpart)

# 以 Survived 为因变量,Pclass / Sex / Age / Fare 为自变量
fit1 <- rpart(Survived ~ Pclass + Sex + Age + Fare,
              data   = train_df,
              method = "class")

# 查看 cp 表
printcp(fit1)

Classification tree:
rpart(formula = Survived ~ Pclass + Sex + Age + Fare, data = train_df, 
    method = "class")

Variables actually used in tree construction:
[1] Age    Fare   Pclass Sex   

Root node error: 203/499 = 0.4

n= 499 

    CP nsplit rel error xerror xstd
1 0.52      0       1.0    1.0 0.05
2 0.03      1       0.5    0.5 0.04
3 0.03      2       0.4    0.5 0.04
4 0.01      4       0.4    0.4 0.04

注记

Variables actually used:观察模型实际使用了哪些变量。Sex(性别)和 Pclass(舱位) 通常排在最前面——这与我们在数据探索中看到的规律一致。

步骤三:查看模型摘要(部分)

▶️ 查看代码
# 查看节点分裂细节
fit1$splits
       count ncat improve index    adj
Sex      499    2  83.488  1.00 0.0000
Fare     499   -1  24.407 52.28 0.0000
Pclass   499    3  23.563  2.00 0.0000
Age      499    1   6.988  5.50 0.0000
Fare       0   -1   0.663 56.71 0.0820
Age        0    1   0.637  0.79 0.0109
Age      316    1   9.411  5.50 0.0000
Fare     316   -1   5.224 29.85 0.0000
Pclass   316    3   4.169  3.00 0.0000
Pclass   183    3  14.276  4.00 0.0000
Fare     183   -1   5.630 48.20 0.0000
Age      183   -1   2.915 12.00 0.0000
Fare       0   -1   0.776 22.51 0.4306
Age        0   -1   0.710 22.50 0.2639
Fare      72    1   6.750 20.66 0.0000
Age       72    1   1.531 36.50 0.0000
Age        0    1   0.792 37.50 0.1667
▶️ 查看代码
fit1
n= 499 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 499 203 遇难 (0.5932 0.4068)  
   2) Sex=male 316  59 遇难 (0.8133 0.1867)  
     4) Age>=5.5 301  48 遇难 (0.8405 0.1595) *
     5) Age< 5.5 15   4 幸存 (0.2667 0.7333) *
   3) Sex=female 183  39 幸存 (0.2131 0.7869)  
     6) Pclass=三等舱 72  33 幸存 (0.4583 0.5417)  
      12) Fare>=20.7 18   3 遇难 (0.8333 0.1667) *
      13) Fare< 20.7 54  18 幸存 (0.3333 0.6667) *
     7) Pclass=一等舱,二等舱 111   6 幸存 (0.0541 0.9459) *

注记

fit1$splits 第一行是根节点的分裂变量——预期是 Sex,因为性别对幸存率的影响最大(女性幸存率远高于男性),分裂后 Gini 下降幅度最大。

步骤四:决策树可视化

▶️ 查看代码
library(rpart.plot)

rpart.plot(fit1,
           type          = 4,
           extra         = 104,
           fallen.leaves = TRUE,
           box.palette   = "RdYlGn",
           cex           = 0.85,
           main          = "分类决策树:Titanic 幸存预测(fit1)")

解读决策路径

根据上图,从树中读出典型的决策路径:

路径一(预测为幸存):

Sex = female(女性) → 预测为 幸存 ✅ 女性幸存率高,该叶节点多数为幸存

路径二(预测为遇难):

Sex = male(男性)且 Age ≥ 某阈值(年龄较大) → 预测为 遇难

路径三(预测为遇难):

Sex = male(男性)且 Age < 某阈值(年龄较小)且 Pclass = 三等舱 → 预测为 遇难

提示

读树口诀: 从根节点出发,满足条件走,不满足走,到达叶节点即为预测类别。实际路径以 rpart.plot() 图形为准。

步骤五:在测试集上预测

▶️ 查看代码
# type = "class" 输出类别标签
pred1 <- predict(fit1,
                 newdata = test_df,
                 type    = "class")

# 查看预测结果分布
table(pred1)
pred1
遇难 幸存 
 134   81 
▶️ 查看代码
# 对比真实值与预测值的前10行
data.frame(
  真实结果 = test_df$Survived,
  预测结果 = pred1
) |> head(10)
   真实结果 预测结果
1      遇难     幸存
2      幸存     幸存
3      遇难     遇难
4      遇难     遇难
5      遇难     幸存
6      遇难     幸存
7      幸存     遇难
8      遇难     遇难
9      遇难     遇难
10     遇难     遇难

步骤六:混淆矩阵与准确性评估

▶️ 查看代码
cm1 <- confusionMatrix(pred1, test_df$Survived, positive = "幸存")

# 混淆矩阵
cm1
Confusion Matrix and Statistics

          Reference
Prediction 遇难 幸存
      遇难  104   30
      幸存   24   57
                                        
               Accuracy : 0.749         
                 95% CI : (0.685, 0.805)
    No Information Rate : 0.595         
    P-Value [Acc > NIR] : 1.69e-06      
                                        
                  Kappa : 0.473         
                                        
 Mcnemar's Test P-Value : 0.496         
                                        
            Sensitivity : 0.655         
            Specificity : 0.812         
         Pos Pred Value : 0.704         
         Neg Pred Value : 0.776         
             Prevalence : 0.405         
         Detection Rate : 0.265         
   Detection Prevalence : 0.377         
      Balanced Accuracy : 0.734         
                                        
       'Positive' Class : 幸存          
                                        

注记

混淆矩阵读法: 行是真实结果,列是预测结果。对角线数字越大越好。召回率/灵敏度(Sensitivity) = 实际幸存中被正确预测的比例;特异度(Specificity) = 实际遇难中被正确预测的比例。

步骤七:cp = 0.5 剪枝建立 fit2

▶️ 查看代码
# 激进剪枝:cp = 0.5
fit2 <- rpart(Survived ~ Pclass + Sex + Age + Fare,
              data    = train_df,
              method  = "class",
              control = rpart.control(cp = 0.5))

printcp(fit2)

Classification tree:
rpart(formula = Survived ~ Pclass + Sex + Age + Fare, data = train_df, 
    method = "class", control = rpart.control(cp = 0.5))

Variables actually used in tree construction:
[1] Sex

Root node error: 203/499 = 0.4

n= 499 

   CP nsplit rel error xerror xstd
1 0.5      0       1.0    1.0 0.05
2 0.5      1       0.5    0.6 0.05

重要

cp = 0.5 表示只有分裂能使误差下降 ≥ 50% 时才执行。在 Titanic 数据中,通常只有第一次分裂(用 Sex 分离男女)满足条件,因此 fit2 的树只有1次分裂、2个叶节点:男性预测遇难,女性预测幸存。

fit2 可视化

▶️ 查看代码
rpart.plot(fit2,
           type          = 4,
           extra         = 104,
           fallen.leaves = TRUE,
           box.palette   = "RdYlGn",
           cex           = 1.0,
           main          = "剪枝后的决策树(cp = 0.5)—— fit2")

注记

fit2 只有一个分裂点:Sex。女性 → 幸存;男性 → 遇难。这与历史记录"妇女儿童优先"的救援原则吻合,但牺牲了对男性幸存者的识别能力。

步骤八:对比 fit1 与 fit2

▶️ 查看代码
pred2 <- predict(fit2, newdata = test_df, type = "class")
cm2   <- confusionMatrix(pred2, test_df$Survived, positive = "幸存")

# 混淆矩阵对比
cm2
Confusion Matrix and Statistics

          Reference
Prediction 遇难 幸存
      遇难  103   34
      幸存   25   53
                                        
               Accuracy : 0.726         
                 95% CI : (0.661, 0.784)
    No Information Rate : 0.595         
    P-Value [Acc > NIR] : 4.67e-05      
                                        
                  Kappa : 0.421         
                                        
 Mcnemar's Test P-Value : 0.298         
                                        
            Sensitivity : 0.609         
            Specificity : 0.805         
         Pos Pred Value : 0.679         
         Neg Pred Value : 0.752         
             Prevalence : 0.405         
         Detection Rate : 0.247         
   Detection Prevalence : 0.363         
      Balanced Accuracy : 0.707         
                                        
       'Positive' Class : 幸存          
                                        
▶️ 查看代码
modelsummary(list(cm1,cm2))
(1) (2)
accuracy 0.749 0.726
kappa 0.473 0.421
sensitivity 0.655 0.609
specificity 0.812 0.805
pos_pred_value 0.704 0.679
neg_pred_value 0.776 0.752
precision 0.704 0.679
recall 0.655 0.609
f1 0.679 0.642
prevalence 0.405 0.405
detection_rate 0.265 0.247
detection_prevalence 0.377 0.363
balanced_accuracy 0.734 0.707

两模型对比分析

重要

结论: cp = 0.5 使树退化为仅用性别做判断(男→遇难,女→幸存),忽略了舱位、年龄等重要信息,准确率和 Kappa 均下降。实际应用中应用 printcp() 找最优 cp。

科学选择最优 cp

▶️ 查看代码
# 找 xerror 最小的 cp
best_cp <- fit1$cptable[which.min(fit1$cptable[, "xerror"]), "CP"]
cat("交叉验证误差最小的 cp:", best_cp, "\n")
交叉验证误差最小的 cp: 0.01 
▶️ 查看代码
# 用最优 cp 剪枝
fit_best <- prune(fit1, cp = best_cp)
▶️ 查看代码
rpart.plot(fit_best,
           type = 4, extra = 104, fallen.leaves = TRUE,
           box.palette = "RdYlGn",
           cex = 0.85, main = "最优剪枝树(xerror 最小)")

▶️ 查看代码
pred_best <- predict(fit_best, newdata = test_df, type = "class")
cm3   <- confusionMatrix(pred_best, test_df$Survived, positive = "幸存")

# 混淆矩阵对比
cm3
Confusion Matrix and Statistics

          Reference
Prediction 遇难 幸存
      遇难  104   30
      幸存   24   57
                                        
               Accuracy : 0.749         
                 95% CI : (0.685, 0.805)
    No Information Rate : 0.595         
    P-Value [Acc > NIR] : 1.69e-06      
                                        
                  Kappa : 0.473         
                                        
 Mcnemar's Test P-Value : 0.496         
                                        
            Sensitivity : 0.655         
            Specificity : 0.812         
         Pos Pred Value : 0.704         
         Neg Pred Value : 0.776         
             Prevalence : 0.405         
         Detection Rate : 0.265         
   Detection Prevalence : 0.377         
      Balanced Accuracy : 0.734         
                                        
       'Positive' Class : 幸存          
                                        
▶️ 查看代码
modelsummary(list(cm1, cm2, cm3))
(1) (2) (3)
accuracy 0.749 0.726 0.749
kappa 0.473 0.421 0.473
sensitivity 0.655 0.609 0.655
specificity 0.812 0.805 0.812
pos_pred_value 0.704 0.679 0.704
neg_pred_value 0.776 0.752 0.776
precision 0.704 0.679 0.704
recall 0.655 0.609 0.655
f1 0.679 0.642 0.679
prevalence 0.405 0.405 0.405
detection_rate 0.265 0.247 0.265
detection_prevalence 0.377 0.363 0.377
balanced_accuracy 0.734 0.707 0.734

Part 6:你必须掌握什么?

学习路线图

完整建模流程回顾

必须掌握

重要

以下是本讲分类决策树的完整核心,期末考试必考:

  1. 决策树生成过程: 特征选择、树生成、树剪枝

  2. 决策树原理:基尼指数(基尼不纯度)的含义;树如何递归分裂;某个特征为何最先被选中(Gini指数 下降最大)

  3. R 操作rpart(y ~ ., data, method = "class") 建树;printcp() 查 cp 表

  4. 可视化与路径解读rpart.plot(type=4, extra=104) 画树;从根节点读出"若……则……"规则

  5. 预测与评估predict(type = "class")confusionMatrix() → 准确率、Kappa、各类灵敏度

  6. 剪枝:cp 的含义;prune(fit, cp=?) 后剪枝;用 which.min(cptable[,"xerror"]) 找最优 cp

  7. 模型对比

  8. 决策树优点

  • 模型构建和预测过程可以图形方式显示,易于理解和解释
  • 不需要大规模的训练数据,不需要进行数据标准化操作
  • 能够处理数值型和分类型数据
  • 能够处理多输出问题,同时预测多个结果
  • 是白盒模型中的一种,相对于黑盒模型结果更容易解释
  • 对于异常的数据点有更强的包容性

常见错误清单

警告

请对照检查:

错误 正确做法
忘记 method = "class" 不写默认为回归树;因变量是类别时必须method = "class"
predict() 不写 type = "class" 默认输出概率矩阵;分类评估必须type = "class"
只看准确率判断模型好坏 还需比较 Kappa、各类灵敏度,并考虑可解释性
选 cp = 0.5 就以为是合理剪枝 0.5 是极端值,应用 xerror 最小原则科学选 cp
混淆矩阵行列方向读错 行 = 预测值,列 = 真实值;对角线 = 预测正确

本讲小结

  • 基尼指数(不纯度):衡量节点混乱程度;Gini = 0 最纯,越大越混;每次分裂选 Gini 下降最大的特征与阈值

  • Titanic 规律:Sex(性别)是最强的分裂变量;女性幸存率远高于男性;舱位(Pclass)和年龄(Age)在后续分裂中发挥作用

  • 建模流程rsample::initial_split(strata=) 分层抽样 → rpart(method="class") 建树 → rpart.plot() 可视化 → predict() + confusionMatrix() 评估

  • 剪枝对比:fit1(cp=0.01)细化利用多个特征;fit2(cp=0.5)只用性别做单次分裂,准确率下降;科学做法:which.min(xerror) 找最优 cp

课后作业

作业说明

项目介绍: 现有一组学生信息,分别为学生的跑步成绩 jogging、游泳成绩 swim、跳高成绩 jump 和成绩等级 Grade,现需要通过学生的跑步成绩、游泳成绩、跳高成绩建立决策树模型,以预测学生的成绩等级。将 70% 数据进行模型训练,30% 数据进行模型验证,分析模型预测的准确性。

步骤 1:读入数据与划分数据集

▶️ 查看代码
library(tidyverse)
library(caret)
library(rpart)
library(rpart.plot)

# 读入数据
students <- read.csv("students.csv")

# 划分数据集(70% 训练,30% 测试)
set.seed(123)
train_index <- createDataPartition(students$Grade, p = 0.7, list = FALSE)
train_df    <- students[ train_index, ] |> mutate(Grade = factor(Grade))
test_df     <- students[-train_index, ] |>
                 mutate(Grade = factor(Grade, levels = levels(train_df$Grade)))

# 查看训练集结构
str(train_df)
'data.frame':   105 obs. of  4 variables:
 $ jogging: num  4.7 4.6 5 4.6 5 4.4 4.9 5.4 4.8 4.8 ...
 $ swim   : num  3.2 3.1 3.6 3.4 3.4 2.9 3.1 3.7 3.4 3 ...
 $ jump   : num  1.3 1.5 1.4 1.4 1.5 1.4 1.5 1.5 1.6 1.4 ...
 $ Grade  : Factor w/ 3 levels "A","B","C": 1 1 1 1 1 1 1 1 1 1 ...
▶️ 查看代码
table(train_df$Grade)

 A  B  C 
35 35 35 

要求: 说明训练集有多少行、多少列;各变量类型;各等级(Grade)的样本数量及比例。

步骤 2:建立分类决策树模型

▶️ 查看代码
# 建立分类决策树 fit1
fit1 <- rpart(Grade ~ jogging + swim + jump,
              data   = train_df,
              method = "class")

# 查看 cp 表(截取输出即可)
printcp(fit1)

Classification tree:
rpart(formula = Grade ~ jogging + swim + jump, data = train_df, 
    method = "class")

Variables actually used in tree construction:
[1] jump

Root node error: 70/105 = 0.7

n= 105 

    CP nsplit rel error xerror xstd
1 0.50      0      1.00    1.3 0.05
2 0.44      1      0.50    0.8 0.07
3 0.01      2      0.06    0.1 0.04

要求: 截取 printcp() 输出,说明模型用到了哪些变量进行分裂;报告根节点误差(Root node error)是多少。

步骤 3:可视化并总结决策路径

▶️ 查看代码
# 可视化
rpart.plot(fit1, type = 4, extra = 104,
           fallen.leaves = TRUE, box.palette = "RdYlGn", cex = 0.8,
           main = "分类决策树:学生成绩等级预测")

要求: 根据可视化结果,至少总结 一条 完整的决策路径(从根节点到叶节点),用文字描述"若……则预测等级为……"的规则,并说明该叶节点包含多少比例的训练样本。

步骤 4:测试集预测与混淆矩阵

▶️ 查看代码
# 预测测试集
pred1 <- predict(fit1, newdata = test_df, type = "class")

# 混淆矩阵
cm1 <- table(pred1, test_df$Grade)
confusionMatrix(cm1)
Confusion Matrix and Statistics

     
pred1  A  B  C
    A 15  0  0
    B  0 13  1
    C  0  2 14

Overall Statistics
                                        
               Accuracy : 0.933         
                 95% CI : (0.817, 0.986)
    No Information Rate : 0.333         
    P-Value [Acc > NIR] : <2e-16        
                                        
                  Kappa : 0.9           
                                        
 Mcnemar's Test P-Value : NA            

Statistics by Class:

                     Class: A Class: B Class: C
Sensitivity             1.000    0.867    0.933
Specificity             1.000    0.967    0.933
Pos Pred Value          1.000    0.929    0.875
Neg Pred Value          1.000    0.935    0.966
Prevalence              0.333    0.333    0.333
Detection Rate          0.333    0.289    0.311
Detection Prevalence    0.333    0.311    0.356
Balanced Accuracy       1.000    0.917    0.933

要求: 列出混淆矩阵;报告整体准确率(Accuracy)与 Kappa 系数;分析各等级的预测情况,指出哪个等级预测效果较差及可能原因。

步骤 5:剪枝建立 fit2

▶️ 查看代码
# 以 cp = 0.5 建立剪枝模型 fit2
fit2 <- rpart(Grade ~ jogging + swim + jump,
              data    = train_df,
              method  = "class",
              control = rpart.control(cp = 0.5))

# 可视化新模型
rpart.plot(fit2, type = 4, extra = 104,
           fallen.leaves = TRUE, box.palette = "RdYlGn", cex = 0.9,
           main = "剪枝决策树(cp = 0.5)")

要求: 描述 fit2 的树形结构(有几个节点、几个叶节点);与 fit1 对比,说明 cp = 0.5 对树的形状产生了什么影响。

步骤 6:对比 fit1 与 fit2 的预测准确性

▶️ 查看代码
# fit2 在测试集上预测
pred2 <- predict(fit2, newdata = test_df, type = "class")
cm2   <- table(pred2, test_df$Grade)
confusionMatrix(cm2)
Confusion Matrix and Statistics

     
pred2  A  B  C
    A 15 15 15
    B  0  0  0
    C  0  0  0

Overall Statistics
                                     
               Accuracy : 0.333      
                 95% CI : (0.2, 0.49)
    No Information Rate : 0.333      
    P-Value [Acc > NIR] : 0.556      
                                     
                  Kappa : 0          
                                     
 Mcnemar's Test P-Value : NA         

Statistics by Class:

                     Class: A Class: B Class: C
Sensitivity             1.000    0.000    0.000
Specificity             0.000    1.000    1.000
Pos Pred Value          0.333      NaN      NaN
Neg Pred Value            NaN    0.667    0.667
Prevalence              0.333    0.333    0.333
Detection Rate          0.333    0.000    0.000
Detection Prevalence    1.000    0.000    0.000
Balanced Accuracy       0.500    0.500    0.500

要求: 对比两个模型的准确率和 Kappa 系数,分析剪枝后模型性能的变化;结合偏差–方差权衡,说明在该预测任务中你会选择哪个模型,并给出理由。

谢谢!

第15讲:决策树与回归树——分类决策树


「决策树是一面镜子——它把数据的结构映射成人类可以理解的规则。」