Softmax 是机器学习中常用的函数,广泛用于多分类问题的输出层。它可以将一组实数转换为一个概率分布,使得结果满足“非负”和“总和为1”的要求。在分类问题中,Softmax 让模型预测的每个类别概率都易于解释。本文将详细讲解 Softmax 的原理、公式推导、Numpy 实现及其在 Pytorch 中的实际应用。
Softmax 原理
给定一个类别集合 { y 1 , y 2 , … , y n } \{y_1, y_2, \dots, y_n\} {y1,y2,…,yn},Softmax 将模型输出的每个数值(称为“得分”或“logits”)转换为概率值。假设模型输出 z i z_i zi 为第 i i i 类的得分,Softmax 将所有的得分映射到概率空间,使每个得分转化为该类的预测概率。
Softmax 函数的公式为:
P
(
y
i
)
=
e
z
i
∑
j
=
1
n
e
z
j
P(y_i) = \frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}}
P(yi)=∑j=1nezjezi
其中
z
i
z_i
zi 表示模型为第
i
i
i 类输出的得分,
n
n
n 是类别的数量。通过对指数值的归一化处理,Softmax 函数输出的概率满足:
- 所有概率值都为非负数;
- 概率总和为 1。
Softmax 计算中的数值稳定性
在计算中,Softmax 可能会因为指数运算导致数值溢出,为了减小这种风险,可以对每个 (z_i) 值减去一个常数
max
(
z
)
\max(z)
max(z):
P
(
y
i
)
=
e
z
i
−
max
(
z
)
∑
j
=
1
n
e
z
j
−
max
(
z
)
P(y_i) = \frac{e^{z_i - \max(z)}}{\sum_{j=1}^n e^{z_j - \max(z)}}
P(yi)=∑j=1nezj−max(z)ezi−max(z)
这种转换不会改变概率的分布,避免了指数函数产生的大数值溢出问题。
Numpy 实现 Softmax 函数
下面通过 Numpy 实现 Softmax,并进行数据可视化以更直观地理解 Softmax 对得分的转换过程。
import numpy as np
import matplotlib.pyplot as plt
# 定义 Softmax 函数
def softmax(logits):
"""
使用数值稳定性的 Softmax 函数实现
参数:
- logits: 模型输出得分向量(shape: (n,),表示 n 个类别的得分)
返回:
- probs: 转换后的概率向量,shape: (n,)
"""
exp_shifted = np.exp(logits - np.max(logits)) # 减去 max(logits) 以确保数值稳定性
probs = exp_shifted / np.sum(exp_shifted) # 归一化为概率
return probs
# 示例输入的分类得分
logits = np.array([2.0, 1.0, 0.1])
# 使用 Softmax 函数计算各类别的概率
probs = softmax(logits)
# 输出各类的预测概率
print("分类得分:", logits)
print("预测概率:", probs)
Softmax 输出可视化
我们可以用图像展示 Softmax 如何将得分转化为概率,假设输入的分类得分范围为 -2 到 4。
# 生成模拟的分类得分范围
logit_range = np.linspace(-2, 4, 100)
all_probs = np.array([softmax([l, 1.0, 0.1]) for l in logit_range])
# 可视化不同类别的预测概率随得分变化的趋势
plt.plot(logit_range, all_probs[:, 0], label="类别 1")
plt.plot(logit_range, all_probs[:, 1], label="类别 2")
plt.plot(logit_range, all_probs[:, 2], label="类别 3")
plt.xlabel("得分 (logits)")
plt.ylabel("概率")
plt.title("Softmax 函数输出的概率分布")
plt.legend()
plt.show()
Softmax 损失函数:交叉熵损失
在多分类任务中,常用 交叉熵损失函数 来衡量模型预测概率分布与真实标签的匹配程度。对于单个样本,交叉熵损失定义为:
L
=
−
∑
i
=
1
n
y
i
⋅
log
(
P
(
y
i
)
)
L = -\sum_{i=1}^{n} y_i \cdot \log(P(y_i))
L=−i=1∑nyi⋅log(P(yi))
其中 (y_i) 是真实标签的 one-hot 编码,(P(y_i)) 是 Softmax 转换后的预测概率。
# 定义交叉熵损失函数
def cross_entropy_loss(probs, y_true):
"""
计算交叉熵损失
参数:
- probs: Softmax 预测概率 (shape: (n,))
- y_true: 实际标签 (shape: (n,)),one-hot 编码
返回:
- loss: 交叉熵损失
"""
loss = -np.sum(y_true * np.log(probs + 1e-10)) # 加1e-10防止 log(0)
return loss
# 示例计算
y_true = np.array([1, 0, 0]) # 假设类别 1 为正确类别
loss = cross_entropy_loss(probs, y_true)
print("交叉熵损失:", loss)
在 PyTorch 中使用 Softmax
在 PyTorch 中,我们可以直接调用 torch.nn.functional.softmax
来实现 Softmax。此外,PyTorch 提供的 torch.nn.CrossEntropyLoss
函数在内部自动包含了 Softmax 和交叉熵的计算,无需显式计算。
import torch
import torch.nn.functional as F
# 示例:在 PyTorch 中实现 Softmax 和交叉熵损失
logits_torch = torch.tensor([2.0, 1.0, 0.1])
# 使用 PyTorch 的 Softmax 函数
probs_torch = F.softmax(logits_torch, dim=0)
print("PyTorch 预测概率:", probs_torch.numpy())
# 使用交叉熵损失函数
y_true_index = torch.tensor([0]) # 假设第一个类别为正确类别
loss_fn = torch.nn.CrossEntropyLoss()
loss_torch = loss_fn(logits_torch.unsqueeze(0), y_true_index)
print("PyTorch 交叉熵损失:", loss_torch.item())
在 PyTorch 中,torch.nn.CrossEntropyLoss
在传入 logits 后自动应用 Softmax 和交叉熵计算,为多分类问题提供了便捷的计算方式。
总结
本文介绍了 Softmax 的原理、公式、Numpy 实现、可视化以及在 PyTorch 中的使用。Softmax 是将得分转化为概率分布的关键函数,尤其适用于多分类任务。我们还探讨了数值稳定性的处理以及交叉熵损失在多分类中的作用,理解并实现 Softmax 有助于构建更稳定且易解释的分类模型。