Learning to Generalize:Meta-Learning for Domain Geralization 深度理解(公式详细推导,包含二阶导)
Learning to Generalize:Meta-Learning for Domain Geralization 深度理解
1.算法整体流程
输入:训练域为: SSS
初始化:模型参数θ\thetaθ,学习率:β,γ\beta,\gammaβ,γ
for ite in iterations do
Split:Sˉ\bar{S}Sˉ and Sˇ\check{S}Sˇ →\rightarrow→ SSS
Meta-train:
Gradients ∇θ1=Fθ′(Sˉ;θ)\nabla_{\theta_1}=\mathcal{F^{'}_\theta}(\bar{S};\theta)∇θ1=Fθ′(Sˉ;θ)
Updated parameters θ′=θ−β∇θ1\theta^{'}=\theta-\beta\nabla_{\theta_1}θ′=θ−β∇θ1
Meta-test:
Loss is G(Sˇ;θ′)\mathcal{G}(\check{S};\theta^{'})G(Sˇ;θ′)
Gradients Update θ:\theta:θ:
∇θ2=dG(Sˇ;θ′)dθ=dG(Sˇ;θ−β∇θ1)dθ′dθ′dθ=Gθ′′(Sˇ;θ−β∇θ1)d(θ−β∇θ1)dθ=Gθ′′(Sˇ;θ−β∇θ1)(1−βd∇θ1dθ)\nabla_{\theta_2}=\frac{d\mathcal{G}(\check{S};\theta^{'})}{d\theta}=\frac{d\mathcal{G}(\check{S};\theta-\beta\nabla_{\theta_1})}{d\theta^{'}}\frac{d\theta^{'}}{d\theta}=\mathcal{G^{'}_{\theta^{'}}}(\check{S};\theta-\beta\nabla_{\theta_1})\frac{d(\theta-\beta\nabla_{\theta_1})}{d\theta}=\mathcal{G^{'}_{\theta^{'}}}(\check{S};\theta-\beta\nabla_{\theta_1})(1-\beta\frac{d\nabla_{\theta_1}}{d\theta})∇θ2=dθdG(Sˇ;θ′)=dθ′dG(Sˇ;θ−β∇θ1)dθdθ′=Gθ′′(Sˇ;θ−β∇θ1)dθd(θ−β∇θ1)=Gθ′′(Sˇ;θ−β∇θ1)(1−βdθd∇θ1)
Meta-optimization: Update θ:csdn在线markdnow的latex不支持begin{align}标签,贴图如下\theta:{\color{Red}csdn在线markdnow的latex不支持begin\{align\}标签,贴图如下}θ:csdn在线markdnow的latex不支持begin{align}标签,贴图如下
2.总结
MLDG是第一篇将meta learning引入domain generalization的论文。在这篇论文中,G(;)=F(;)=CrossEntropy\mathcal{G}(;)=\mathcal{F}(;)=CrossEntropyG(;)=F(;)=CrossEntropy,二阶导几乎不起作用,训练模型的时候可以将代码中二阶导计算关掉(没用还很耗时)。
loss.backward(retain_graph=True, create_graph=True)
->改为
loss.backward(retain_graph=True, create_graph=False)
楼主跑过这篇文章的代码,性能与普通训练方式差不太多,即先在Sˉ\bar{S}Sˉ训练更新,然后再在 Sˇ\check{S}Sˇ上训练更新。个人认为他的insight在于启发了后续的meta-dg方法,后续方法重新设计能够约束特征空间的G(;)\mathcal{G}(;)G(;),后续方法是否有效等楼主验证后再更新。
更多推荐
所有评论(0)