详细的方式来讲解模型训练中 指数移动平均(Exponential Moving Average,简称 EMA) 的原理、作用、公式、实现和注意事项


🧠 一句话总结:

指数移动平均(EMA)用于在模型训练过程中对模型参数进行“平滑”,以获得更稳定、更泛化更强的“平均模型”用于验证或测试。


🧩 一、为什么需要“指数移动平均”?

在训练神经网络时,模型的参数(权重)会随着每个 batch 的梯度更新不停变化,尤其是在训练后期,参数的震荡变得频繁甚至不稳定:

🔁 一次训练迭代后的参数 ≠ 上一次迭代后的参数

这种震荡可能会导致:

  • 模型在验证集上表现波动大
  • 测试集上性能不稳定
  • 模型泛化性较差

于是我们就想到一个主意:

👉 别只用最后一个模型参数,用所有历史参数的“加权平均”!

这就是:指数移动平均(EMA)


🧪 二、什么是指数移动平均(EMA)?

✅ 定义:

EMA 是一种对过去数据赋予指数衰减权重的加权平均方式。

最新值的权重最大,越早的值权重越小。


📐 三、公式(最重要)

设:

  • θ t \theta_t θt 是第 t t t 次迭代的 当前模型参数
  • θ ^ t \hat{\theta}_t θ^t 是第 t t t 次迭代的 EMA 平均后的参数

则 EMA 的更新公式是:

θ ^ t = α ⋅ θ ^ t − 1 + ( 1 − α ) ⋅ θ t \hat{\theta}_t = \alpha \cdot \hat{\theta}_{t-1} + (1 - \alpha) \cdot \theta_t θ^t=αθ^t1+(1α)θt

  • α ∈ [ 0 , 1 ] \alpha \in [0, 1] α[0,1]衰减率(decay rate)
  • 通常 α \alpha α 很接近于 1(如 0.999)

或者也写成:

EMA t = decay × EMA t − 1 + ( 1 − decay ) × current_value \text{EMA}_{t} = \text{decay} \times \text{EMA}_{t-1} + (1 - \text{decay}) \times \text{current\_value} EMAt=decay×EMAt1+(1decay)×current_value


🎯 四、EMA 有哪些实际作用?

作用 说明
平滑参数更新 抑制训练中参数的剧烈变化
提高泛化性能 测试/验证时使用平均后的模型表现通常更好
稳定评估指标 使用 EMA 模型计算验证集 loss/accuracy 更稳定
更好的收敛效果 比“最后一版模型”表现更鲁棒

🛠️ 五、实际使用方式(以 PyTorch 为例)

# 初始化
ema_model = deepcopy(model)
ema_decay = 0.999

# 每一步更新
def update_ema(model, ema_model, decay):
    with torch.no_grad():
        msd = model.state_dict()
        for key, value in ema_model.state_dict().items():
            if key in msd:
                ema_value = decay * value + (1 - decay) * msd[key]
                value.copy_(ema_value)

✅ 你在训练过程中每一步都执行 update_ema(...)
📦 最后验证/保存时用 ema_model 而不是原始的 model


🧮 六、手动模拟一个例子

假设我们有个参数初始为 0,当前每步训练它逐渐上升:

步数 t t t 当前值 θ t \theta_t θt EMA ( α = 0.9 \alpha=0.9 α=0.9)
0 0.0 0.0
1 0.5 0.9 × 0 + 0.1 × 0.5 = 0.05 0.9×0 + 0.1×0.5 = 0.05 0.9×0+0.1×0.5=0.05
2 1.0 0.9 × 0.05 + 0.1 × 1.0 = 0.145 0.9×0.05 + 0.1×1.0 = 0.145 0.9×0.05+0.1×1.0=0.145
3 1.5 0.9 × 0.145 + 0.1 × 1.5 ≈ 0.280 0.9×0.145 + 0.1×1.5 ≈ 0.280 0.9×0.145+0.1×1.50.280

可以看到,EMA 总是落后于真实值,但趋近稳定,且更“平滑”。


🧠 七、EMA 与其他方法对比

方法 是否平滑 是否可用作最终模型 额外内存开销
原始模型 ❌ 否 ✅ 是
EMA ✅ 是 ✅ 是(常用于测试) 要保存一份平均参数
Moving Average(滑动平均) ✅ 是 ✅(少用) 较大(需保留多个历史)

🧯 八、注意事项与小贴士

  1. EMA 只在验证或测试时使用

    • 不要用 EMA 的参数来更新梯度
    • 训练依然用原始模型继续 forward 和 backward
  2. EMA 不适用于 early stage 训练

    • 因为早期模型参数还太“差”,平均它们意义不大
  3. EMA 参数不能保存过早

    • 最好是在训练快完成后再开始保存 EMA 模型
  4. 可以冻结 BatchNorm 统计量

    • model.eval() + torch.no_grad() 评估 EMA 模型时注意

🚀 九、在真实模型中的应用案例

案例 模型或库 说明
YOLOv5 检测目标 使用 EMA 平滑模型参数,提升精度
DeepLabV3+ 语义分割 验证阶段用 EMA 模型评估更稳健
StyleGAN 生成模型 用 EMA 的生成器参数生成图像效果更佳
Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐