从理论到代码手把手实现Evidential Deep Learning中的Dirichlet分布分类器附PyTorch示例在构建一个图像分类模型时我们通常关心的是它的准确率。模型在测试集上达到了99%的准确率这听起来很棒不是吗但当我们把模型部署到现实世界面对一张模糊不清、光线诡异或者压根不属于训练类别的图片时模型依然会“自信满满”地给出一个预测。这种“过度自信”在自动驾驶、医疗诊断等高风险领域是致命的。模型不仅需要知道“是什么”更需要知道“什么时候不知道”。这正是不确定性量化Uncertainty Quantification的核心价值。传统的深度学习模型如使用Softmax的分类器输出的是一个归一化的概率分布但这个概率更多反映的是模型在已知类别间的相对偏好而非模型对自身预测的“信心”或“证据”强度。一个输出为[0.51, 0.49]的模型和一个输出为[0.99, 0.01]的模型其“自信”程度天差地别但Softmax本身无法区分这种内在的认知状态。今天我们将深入探讨一种前沿的解决方案证据深度学习。我们将聚焦于分类任务从Dirichlet分布的数学原理出发一步步推导出证据神经网络的损失函数并用PyTorch构建一个完整的、可运行的分类器。我们将以经典的MNIST数据集为基础通过改造实验例如引入旋转、噪声或OOD样本直观地展示模型如何量化其偶然不确定性和认知不确定性。本文面向有一定机器学习实践经验的开发者我们将避开繁复的数学证明直击代码实现与调优中的核心陷阱。1. 理论基础从Softmax到Dirichlet证据先验要理解证据深度学习我们必须先跳出Softmax的思维定式。在传统的多分类神经网络中最后一层通常是一个线性层其输出称为“logits”。Softmax函数将这些logits转化为一个概率向量p [p1, p2, ..., pK]其中∑pi 1pi 0。这个p被解释为样本属于各个类别的概率。然而这里存在一个根本性的混淆这个p描述的是数据的似然分布即给定模型参数观察到某类标签的概率而非模型对自身参数的信念分布。模型对自己的预测有多确定是铁证如山还是模棱两可Softmax输出无法直接告诉我们。证据深度学习的核心思想是我们不再让网络直接输出一个固定的概率向量p而是让网络输出一个先验分布的参数这个先验分布描述了模型对概率向量p的“信念”或“证据”。对于分类问题这个先验分布的自然选择是Dirichlet分布。Dirichlet分布是定义在K维概率单纯形上的一个连续多元概率分布。它是Beta分布的高维推广。其概率密度函数为Dir(p | α) (1 / B(α)) * ∏_{i1}^{K} p_i^{α_i - 1}其中α [α1, α2, ..., αK]是浓度参数Concentration Parameters且α_i 0。B(α)是多元Beta函数作为归一化常数。提示α_i可以被直观地理解为模型为第i个类别收集到的“证据”数量。α_i越大意味着模型越“相信”样本属于类别i。Dirichlet分布有一个美妙的性质它的期望值E[p_i]正好是α_i / S其中S ∑α_i被称为证据总量。当我们设置所有α_i 1时就得到了一个均匀的Dirichlet分布这对应于“完全无知”的先验状态——模型对任何类别都没有偏好。那么网络如何工作呢对于一个输入样本我们的证据神经网络Evidential Neural Network, ENN的输出不再是K个logits而是K个正的证据值e [e1, e2, ..., eK]。通常我们使用一个Softplus激活函数来确保ei 0。然后我们令α e 1。这里的“1”对应于均匀先验α1。因此α就是我们的Dirichlet分布参数。有了α我们可以做三件事计算预测概率p̂_i α_i / S。这类似于Softmax的输出但含义更深。计算偶然不确定性这源于数据本身的噪声可以用预测概率的方差来估计。对于Dirichlet分布类别i的方差为(α_i (S - α_i)) / (S^2 (S 1))。计算认知不确定性这源于模型知识的缺乏。一个简单而有效的度量是K / S。证据总量S越小模型总体证据不足认知不确定性就越高。当面对一个完全陌生的OOD样本时网络无法为任何类别收集到强证据所有e_i都接近0S ≈ K此时认知不确定性K/S ≈ 1达到很高水平。下表对比了传统Softmax分类器与Dirichlet证据分类器的关键区别特性传统Softmax分类器Dirichlet证据分类器输出概率向量pDirichlet浓度参数α不确定性来源隐含于概率值中难以分离可明确区分为偶然和认知不确定性OOD检测能力弱常过度自信强证据总量低时认知不确定性高训练目标最小化交叉熵损失最小化证据损失后文详述计算开销低略高需计算Psi函数等解释性概率输出解释性一般证据输出可解释性更强2. 损失函数设计原理、推导与陷阱训练一个证据神经网络最大的挑战在于设计合适的损失函数。我们的目标不再是简单地让预测概率匹配one-hot标签而是要让网络输出的Dirichlet分布参数α能够合理地反映模型的“信念”。损失函数由两部分组成拟合项和正则项。1. 拟合项最大化似然我们希望模型预测的Dirichlet分布能够使得观察到的真实标签y一个one-hot向量具有高的期望似然。这推导出的损失项是负的对数似然期望。经过数学推导具体过程可参考相关论文对于单个样本其损失为L_i ∑_{j1}^{K} y_j ( log(S) - log(α_j) )其中y_j是标签的第j个分量0或1。这个形式非常简洁它鼓励模型为真实类别分配更大的α_j从而降低-log(α_j)同时控制总证据S不要无意义地膨胀log(S)项起到一定的抑制作用。2. 正则项最小化错误证据仅有拟合项是不够的。想象一下如果网络学会了对所有样本都输出一个巨大的α向量那么预测概率α/S会非常尖锐接近one-hot方差会很小看起来模型非常“自信”。同时对于错误类别其α值也可能很大这积累了大量的“错误证据”。为了防止这种退化解我们需要一个正则项来惩罚那些没有观测到标签的类别上的证据。正则项的形式为R_i ∑_{j1}^{K} (α_j - 1) * (ψ(S) - ψ(α_j))其中ψ(·)是Digamma函数Gamma函数对数的导数。这一项可以理解为KL散度的一种近似它惩罚了预测的Dirichlet分布与一个“无知”的均匀先验α1之间的偏差但更侧重于惩罚错误类别上的证据积累。因此总损失函数为Loss L_i λ * R_i这里的λ是一个超参数用于平衡拟合精度和不确定性校准。λ的选择至关重要是实践中最需要调优的部分之一。注意Digamma函数ψ(x)在x较小时变化剧烈。在PyTorch中我们可以直接使用torch.digamma()。计算时需确保α的值不会导致数值不稳定例如非常接近0。常见的陷阱与调优技巧λ过大正则项过强模型倾向于输出接近均匀先验的α所有α_i ≈ 1导致预测概率模糊认知不确定性始终很高模型变得“过于谦虚”分类准确率下降。λ过小正则项太弱模型会倾向于堆积证据特别是堆积在真实类别上导致认知不确定性被不合理地压低模型在OOD数据上可能依然“过度自信”。初始化与激活函数输出证据e的层其权重初始化不宜过大否则初始证据可能太强。使用Softplus(如log(1 exp(x))) 作为最终激活函数比ReLU更平滑能提供更稳定的梯度。数值稳定性计算log(α_j)和ψ(α_j)时确保α_j有一个小的下界如1e-8以防止取对数或Digamma函数时出现-inf。一个实用的调优策略是在验证集上不仅监控准确率还要监控模型在已知分布数据和OOD数据上认知不确定性的分布差异。一个好的模型应该在已知数据上表现出较低的认知不确定性在OOD数据上表现出较高的认知不确定性。3. PyTorch实现构建Dirichlet证据分类器理论说得再多不如一行代码。让我们开始用PyTorch构建一个用于MNIST分类的Dirichlet证据网络。我们将设计一个简单的卷积神经网络CNN作为骨干并将其最后一层替换为证据输出层。首先定义网络结构import torch import torch.nn as nn import torch.nn.functional as F class EvidentialCNN(nn.Module): def __init__(self, num_classes10): super(EvidentialCNN, self).__init__() self.num_classes num_classes # 特征提取骨干网络 (一个简单的CNN) self.conv1 nn.Conv2d(1, 32, kernel_size3, stride1, padding1) self.conv2 nn.Conv2d(32, 64, kernel_size3, stride1, padding1) self.pool nn.MaxPool2d(kernel_size2, stride2) self.dropout1 nn.Dropout2d(0.25) self.dropout2 nn.Dropout(0.5) # 计算展平后的特征维度 # MNIST图像为28x28经过两次池化后为7x7 self.feature_size 64 * 7 * 7 # 全连接层 self.fc1 nn.Linear(self.feature_size, 128) # **关键证据输出层** # 输出K个证据值使用Softplus确保为正 self.evidence_layer nn.Linear(128, num_classes) def forward(self, x): x self.pool(F.relu(self.conv1(x))) x self.dropout1(x) x self.pool(F.relu(self.conv2(x))) x self.dropout1(x) x x.view(-1, self.feature_size) x F.relu(self.fc1(x)) x self.dropout2(x) # 输出证据 (evidence) evidence F.softplus(self.evidence_layer(x)) # Dirichlet参数 α evidence 1 alpha evidence 1.0 return alpha接下来实现核心的损失函数。我们将严格遵循上一节推导的公式并加入数值稳定措施。class DirichletLoss(nn.Module): def __init__(self, lambda_reg0.01): super(DirichletLoss, self).__init__() self.lambda_reg lambda_reg def forward(self, alpha, targets): alpha: (batch_size, num_classes) Dirichlet浓度参数 targets: (batch_size, num_classes) one-hot编码的标签 # 确保数值稳定给alpha一个很小的下界 alpha torch.clamp(alpha, min1e-8) # 计算证据总量 S sum(alpha_i) S torch.sum(alpha, dim1, keepdimTrue) # (batch_size, 1) # 拟合项 L_i sum(y_j * (log(S) - log(alpha_j))) log_S torch.log(S) log_alpha torch.log(alpha) fit_loss torch.sum(targets * (log_S - log_alpha), dim1) # 正则项 R_i sum( (alpha_j - 1) * (psi(S) - psi(alpha_j)) ) # 注意对所有类别求和包括真实类别和错误类别 psi_S torch.digamma(S) psi_alpha torch.digamma(alpha) reg_loss torch.sum((alpha - 1.0) * (psi_S - psi_alpha), dim1) # 总损失 total_loss fit_loss self.lambda_reg * reg_loss return total_loss.mean() # 返回批次平均损失现在我们可以组装训练流程。我们将使用标准的MNIST数据集但为了后续的OOD检测实验我们保留原始的测试集作为“已知分布”数据并创建一个“OOD测试集”——例如将MNIST测试集图像随机旋转一个较大角度如45度。import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader, Subset import numpy as np def train_epoch(model, device, train_loader, optimizer, criterion, epoch): model.train() total_loss 0 correct 0 total 0 for batch_idx, (data, target) in enumerate(train_loader): data, target data.to(device), target.to(device) optimizer.zero_grad() # 前向传播得到alpha alpha model(data) # 将标签转换为one-hot target_onehot F.one_hot(target, num_classes10).float() # 计算损失 loss criterion(alpha, target_onehot) # 反向传播与优化 loss.backward() optimizer.step() total_loss loss.item() # 计算准确率 (预测类别为最大概率的类别) S torch.sum(alpha, dim1, keepdimTrue) p alpha / S # 预测概率 pred p.argmax(dim1, keepdimTrue) correct pred.eq(target.view_as(pred)).sum().item() total target.size(0) avg_loss total_loss / len(train_loader) accuracy 100. * correct / total print(fTrain Epoch: {epoch} \tLoss: {avg_loss:.6f}\tAccuracy: {correct}/{total} ({accuracy:.2f}%)) return avg_loss, accuracy # 数据准备 transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) train_dataset datasets.MNIST(./data, trainTrue, downloadTrue, transformtransform) test_dataset datasets.MNIST(./data, trainFalse, transformtransform) # 创建OOD数据集旋转的MNIST ood_transform transforms.Compose([ transforms.RandomRotation(degrees45), # 随机旋转-45到45度 transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) ood_dataset datasets.MNIST(./data, trainFalse, transformood_transform) train_loader DataLoader(train_dataset, batch_size128, shuffleTrue) test_loader DataLoader(test_dataset, batch_size1000, shuffleFalse) ood_loader DataLoader(ood_dataset, batch_size1000, shuffleFalse) # 初始化模型、损失、优化器 device torch.device(cuda if torch.cuda.is_available() else cpu) model EvidentialCNN(num_classes10).to(device) criterion DirichletLoss(lambda_reg0.01) # 尝试不同的lambda_reg optimizer optim.Adam(model.parameters(), lr1e-3) # 训练循环 num_epochs 10 for epoch in range(1, num_epochs 1): train_epoch(model, device, train_loader, optimizer, criterion, epoch)4. 实验设计与效果对比不确定性可视化与分析模型训练完成后真正的乐趣在于观察它如何量化不确定性。我们需要设计实验对比我们的Dirichlet证据分类器与一个结构完全相同、但使用标准交叉熵损失和Softmax输出的传统分类器。首先我们定义一个传统的Softmax CNN作为基线class SoftmaxCNN(nn.Module): # ... 结构与EvidentialCNN的卷积部分完全相同 ... def forward(self, x): # ... 卷积层和全连接层 ... x self.fc2(x) # 输出10维logits return x # 不在这里做Softmax损失函数里做 # 训练时使用CrossEntropyLoss它内部包含了Softmax和NLLLoss。关键对比实验1在标准MNIST测试集上的表现我们评估两个模型的分类准确率。理想情况下证据分类器的准确率应与Softmax分类器相当或略低因为损失函数目标不同。更重要的是我们计算每个样本的认知不确定性U_cognitive K / S。def evaluate_uncertainty(model, data_loader, device, is_evidentialTrue): model.eval() all_probs [] all_labels [] all_uncertainty [] with torch.no_grad(): for data, target in data_loader: data, target data.to(device), target.to(device) if is_evidential: alpha model(data) S torch.sum(alpha, dim1) probs alpha / S.unsqueeze(1) # 认知不确定性 u_cognitive 10.0 / S # K10 all_uncertainty.append(u_cognitive.cpu()) else: logits model(data) probs F.softmax(logits, dim1) # 对于Softmax可以用预测概率的最大熵或最大概率的倒数作为不确定性的简单代理 max_probs, _ torch.max(probs, dim1) u_proxy 1.0 - max_probs # 简单代理1 - 最大概率 all_uncertainty.append(u_proxy.cpu()) all_probs.append(probs.cpu()) all_labels.append(target.cpu()) all_probs torch.cat(all_probs, dim0) all_labels torch.cat(all_labels, dim0) all_uncertainty torch.cat(all_uncertainty, dim0) # 计算准确率 preds all_probs.argmax(dim1) accuracy (preds all_labels).float().mean().item() return accuracy, all_uncertainty.numpy(), all_probs.numpy() # 评估两个模型 acc_evi, unc_evi, _ evaluate_uncertainty(model_evi, test_loader, device, is_evidentialTrue) acc_soft, unc_soft, _ evaluate_uncertainty(model_soft, test_loader, device, is_evidentialFalse) print(f证据模型准确率: {acc_evi:.4f}, Softmax模型准确率: {acc_soft:.4f})关键对比实验2在OOD数据旋转MNIST上的不确定性反应这是检验模型“自知之明”的核心。我们将两个模型在OOD测试集上运行并绘制它们预测不确定性的分布直方图。import matplotlib.pyplot as plt import numpy as np acc_evi_ood, unc_evi_ood, _ evaluate_uncertainty(model_evi, ood_loader, device, is_evidentialTrue) acc_soft_ood, unc_soft_ood, _ evaluate_uncertainty(model_soft, ood_loader, device, is_evidentialFalse) print(f在OOD数据上 - 证据模型准确率: {acc_evi_ood:.4f}, 不确定性均值: {unc_evi_ood.mean():.4f}) print(f在OOD数据上 - Softmax模型准确率: {acc_soft_ood:.4f}, 不确定性代理均值: {unc_soft_ood.mean():.4f}) # 绘制不确定性分布对比图 plt.figure(figsize(12, 5)) plt.subplot(1, 2, 1) plt.hist(unc_evi, bins50, alpha0.7, labelIn-Dist (MNIST Test), densityTrue) plt.hist(unc_evi_ood, bins50, alpha0.7, labelOOD (Rotated MNIST), densityTrue) plt.xlabel(Cognitive Uncertainty (K/S)) plt.ylabel(Density) plt.title(Evidential Model: Uncertainty Distribution) plt.legend() plt.grid(True, alpha0.3) plt.subplot(1, 2, 2) plt.hist(unc_soft, bins50, alpha0.7, labelIn-Dist (MNIST Test), densityTrue) plt.hist(unc_soft_ood, bins50, alpha0.7, labelOOD (Rotated MNIST), densityTrue) plt.xlabel(Uncertainty Proxy (1 - Max Prob)) plt.ylabel(Density) plt.title(Softmax Model: Uncertainty Proxy Distribution) plt.legend() plt.grid(True, alpha0.3) plt.tight_layout() plt.show()预期结果与分析准确率在标准测试集上两个模型的准确率应非常接近例如都99%。在OOD集上准确率都会显著下降但这不是重点。不确定性分布这是关键区别所在。对于证据模型我们希望看到在标准测试集In-Distribution上认知不确定性K/S的值集中在一个较低的区域例如大部分0.2。而在OOD数据集上不确定性分布明显右移均值大幅提高例如0.5。这清晰地表明模型“知道”自己遇到了没见过的数据。对于Softmax模型其不确定性代理1 - max(p)在ID和OOD数据上的分布可能重叠严重。即使OOD数据上模型已经预测错误其输出的最大概率max(p)仍然可能很高过度自信导致1 - max(p)依然很低无法有效区分ID和OOD样本。可视化单个样本我们可以选取几个OOD样本查看模型的具体输出。def visualize_sample(model, images, labels, is_evidentialTrue, num_samples5): model.eval() with torch.no_grad(): if is_evidential: alpha model(images[:num_samples]) S torch.sum(alpha, dim1, keepdimTrue) probs (alpha / S).cpu().numpy() uncertainty (10.0 / S.squeeze()).cpu().numpy() evidence (alpha - 1.0).cpu().numpy() else: logits model(images[:num_samples]) probs F.softmax(logits, dim1).cpu().numpy() uncertainty 1.0 - np.max(probs, axis1) evidence None fig, axes plt.subplots(num_samples, 2, figsize(10, num_samples*2)) for i in range(num_samples): ax_img axes[i, 0] ax_bar axes[i, 1] # 显示图像 ax_img.imshow(images[i].squeeze().cpu().numpy(), cmapgray) ax_img.set_title(fTrue: {labels[i].item()}) ax_img.axis(off) # 显示概率条和不确定性 classes list(range(10)) ax_bar.bar(classes, probs[i], alpha0.7, labelProbability) ax_bar.set_ylim(0, 1) ax_bar.set_xlabel(Class) ax_bar.set_ylabel(Probability) title_str fPred: {np.argmax(probs[i])}, if is_evidential: title_str fUnc: {uncertainty[i]:.3f}, S: {S[i].item():.1f} # 可选在图上用小字显示证据值 for j, ev in enumerate(evidence[i]): ax_bar.text(j, probs[i][j]0.02, f{ev:.1f}, hacenter, fontsize8) else: title_str fUnc(proxy): {uncertainty[i]:.3f} ax_bar.set_title(title_str) ax_bar.grid(True, alpha0.3, axisy) plt.tight_layout() plt.show() # 从OOD数据加载器中取一个批次 ood_images, ood_labels next(iter(ood_loader)) ood_images, ood_labels ood_images.to(device), ood_labels.to(device) print(证据模型对OOD样本的预测) visualize_sample(model_evi, ood_images, ood_labels, is_evidentialTrue) print(\nSoftmax模型对OOD样本的预测) visualize_sample(model_soft, ood_images, ood_labels, is_evidentialFalse)通过这样的可视化你可以直观地看到对于一张旋转的“1”证据模型可能会给出一个相对平坦的概率分布例如每个类别都在0.1左右同时总证据S很小认知不确定性很高。而Softmax模型可能依然会“固执”地将很高的概率如0.9分配给某个错误的类别如“7”其不确定性代理值很低完全无法反映实际情况的风险。在实际项目中这种不确定性估计可以用于构建一个拒绝机制当模型的认知不确定性超过某个阈值时系统可以将该样本交给人类专家处理而不是盲目相信模型的预测。这为构建安全、可靠的AI系统提供了关键的技术支撑。调优λ参数的过程本质上就是在平衡模型的分类精度和这种“自知之明”的敏感度。我个人的经验是从一个较小的λ如0.001开始逐步增加观察模型在ID和OOD数据上不确定性分布的分离程度直到找到一个在保持合理准确率的前提下能清晰区分两者的最佳点。这个过程没有银弹需要根据具体的数据集和应用场景进行反复实验和验证。