关于模型训练中的EMA参数
什么是EMA,有什么作用
详细的方式来讲解模型训练中 指数移动平均(Exponential Moving Average,简称 EMA) 的原理、作用、公式、实现和注意事项
🧠 一句话总结:
指数移动平均(EMA)用于在模型训练过程中对模型参数进行“平滑”,以获得更稳定、更泛化更强的“平均模型”用于验证或测试。
🧩 一、为什么需要“指数移动平均”?
在训练神经网络时,模型的参数(权重)会随着每个 batch 的梯度更新不停变化,尤其是在训练后期,参数的震荡变得频繁甚至不稳定:
🔁 一次训练迭代后的参数 ≠ 上一次迭代后的参数
这种震荡可能会导致:
- 模型在验证集上表现波动大
- 测试集上性能不稳定
- 模型泛化性较差
于是我们就想到一个主意:
👉 别只用最后一个模型参数,用所有历史参数的“加权平均”!
这就是:指数移动平均(EMA)
🧪 二、什么是指数移动平均(EMA)?
✅ 定义:
EMA 是一种对过去数据赋予指数衰减权重的加权平均方式。
最新值的权重最大,越早的值权重越小。
📐 三、公式(最重要)
设:
- θ t \theta_t θt 是第 t t t 次迭代的 当前模型参数
- θ ^ t \hat{\theta}_t θ^t 是第 t t t 次迭代的 EMA 平均后的参数
则 EMA 的更新公式是:
θ ^ t = α ⋅ θ ^ t − 1 + ( 1 − α ) ⋅ θ t \hat{\theta}_t = \alpha \cdot \hat{\theta}_{t-1} + (1 - \alpha) \cdot \theta_t θ^t=α⋅θ^t−1+(1−α)⋅θt
- α ∈ [ 0 , 1 ] \alpha \in [0, 1] α∈[0,1]:衰减率(decay rate)
- 通常 α \alpha α 很接近于 1(如 0.999)
或者也写成:
EMA t = decay × EMA t − 1 + ( 1 − decay ) × current_value \text{EMA}_{t} = \text{decay} \times \text{EMA}_{t-1} + (1 - \text{decay}) \times \text{current\_value} EMAt=decay×EMAt−1+(1−decay)×current_value
🎯 四、EMA 有哪些实际作用?
作用 | 说明 |
---|---|
平滑参数更新 | 抑制训练中参数的剧烈变化 |
提高泛化性能 | 测试/验证时使用平均后的模型表现通常更好 |
稳定评估指标 | 使用 EMA 模型计算验证集 loss/accuracy 更稳定 |
更好的收敛效果 | 比“最后一版模型”表现更鲁棒 |
🛠️ 五、实际使用方式(以 PyTorch 为例)
# 初始化
ema_model = deepcopy(model)
ema_decay = 0.999
# 每一步更新
def update_ema(model, ema_model, decay):
with torch.no_grad():
msd = model.state_dict()
for key, value in ema_model.state_dict().items():
if key in msd:
ema_value = decay * value + (1 - decay) * msd[key]
value.copy_(ema_value)
✅ 你在训练过程中每一步都执行 update_ema(...)
📦 最后验证/保存时用 ema_model
而不是原始的 model
🧮 六、手动模拟一个例子
假设我们有个参数初始为 0,当前每步训练它逐渐上升:
步数 t t t | 当前值 θ t \theta_t θt | EMA ( α = 0.9 \alpha=0.9 α=0.9) |
---|---|---|
0 | 0.0 | 0.0 |
1 | 0.5 | 0.9 × 0 + 0.1 × 0.5 = 0.05 0.9×0 + 0.1×0.5 = 0.05 0.9×0+0.1×0.5=0.05 |
2 | 1.0 | 0.9 × 0.05 + 0.1 × 1.0 = 0.145 0.9×0.05 + 0.1×1.0 = 0.145 0.9×0.05+0.1×1.0=0.145 |
3 | 1.5 | 0.9 × 0.145 + 0.1 × 1.5 ≈ 0.280 0.9×0.145 + 0.1×1.5 ≈ 0.280 0.9×0.145+0.1×1.5≈0.280 |
可以看到,EMA 总是落后于真实值,但趋近稳定,且更“平滑”。
🧠 七、EMA 与其他方法对比
方法 | 是否平滑 | 是否可用作最终模型 | 额外内存开销 |
---|---|---|---|
原始模型 | ❌ 否 | ✅ 是 | 无 |
EMA | ✅ 是 | ✅ 是(常用于测试) | 要保存一份平均参数 |
Moving Average(滑动平均) | ✅ 是 | ✅(少用) | 较大(需保留多个历史) |
🧯 八、注意事项与小贴士
-
EMA 只在验证或测试时使用
- 不要用 EMA 的参数来更新梯度
- 训练依然用原始模型继续 forward 和 backward
-
EMA 不适用于 early stage 训练
- 因为早期模型参数还太“差”,平均它们意义不大
-
EMA 参数不能保存过早
- 最好是在训练快完成后再开始保存 EMA 模型
-
可以冻结 BatchNorm 统计量
- 用
model.eval()
+torch.no_grad()
评估 EMA 模型时注意
- 用
🚀 九、在真实模型中的应用案例
案例 | 模型或库 | 说明 |
---|---|---|
YOLOv5 | 检测目标 | 使用 EMA 平滑模型参数,提升精度 |
DeepLabV3+ | 语义分割 | 验证阶段用 EMA 模型评估更稳健 |
StyleGAN | 生成模型 | 用 EMA 的生成器参数生成图像效果更佳 |
更多推荐
所有评论(0)