论文地址: https://arxiv.org/abs/2206.02425
代码仓库: https://github.com/YaoZhang93/mmFormer

摘要

从磁共振成像(MRI)图像中准确地分割出脑肿瘤是多模态图像联合学习地前提。然而,在临床实践中,并不总是能够获得一张完整的MRI图像,并且模态缺失问题导致现有的多模态分割方法性能严重下降。在这项工作中,首次尝试利用Transformer进行多模态脑肿瘤分割,该方法对任何可用模态的子集都具有鲁棒性。具体来说,本文提出了一种新颖的多模态医学Transformer(mmFormer),用于不完整多模态学习,他主要由三部分组成:混合模态特异性编码器,该编码器桥接了一个卷积编码器和一个模态内的Transformer,用于每个模态内的局部和全局上下文建模;跨模态Transformer用于构建和对齐模态不变特征与肿瘤区域对应的全局语义的跨模态长程相关性;一个解码器,与模态不变的特征进行渐进的上采样和融合,以产生鲁棒的分割。此外,在编码器和解码器中都加入了辅助正则器,以进一步增强模型对不完整模态的鲁棒性。我们在公开的用于脑肿瘤分割实验的BraTS 2018数据集上进行了大量实验。结果表明,本文提出的mmFormer方法在几乎所有的不完整模态子集上都优于现有的不完整多模态脑肿瘤分割SOTA方法,特别是在仅有一个模态的情况下,对肿瘤分割的Dice平均改进为19.07%

引言

脑肿瘤的自动、准确分割在临床评估和诊断中起着至关重要的作用。磁共振成像(MRI)是临床上常用的定量评估脑肿瘤的神经成像技术,可提供多种成像模态:T1、T1ce、T2、FLAIR图像,每一种成像模态都提供了不同的大脑结构和病理对比。多模态图像的联合学习对脑肿瘤分割至关重要,可以显著提高分割性能。为了有效地融合多模态MRI进行脑肿瘤分割,许多方法已经被广泛的探索,例如,在通道维度将多模态图像拼接作为输入或在潜在空间中融合特征,然而,在临床实践中,由于数据损坏、各种扫描协议和患者不适合的条件,并不总是能够获得一套完整的MRI。在这种情况下,现有的大多数多模态方法可能无法完全处理不完整的成像模态,并面临严重的分割性能下降,因此,对于一个或多个模态缺失的灵活和实际的临床应用,一个鲁棒的多模态方法是非常有必要的。

不完整多模态学习

也称为异模态学习,旨在设计在推理时对可用模态的任何子集都具有鲁棒性的方法。 脑肿瘤分割的不完全多模态学习的一个直接策略是通过生成模型来合成缺失的模态。另一类方法探索从完整模态到不完整模态的知识蒸馏。尽管这些方法取得了很好的效果,但这些方法都必须为每个缺失的模态子集训练和部署特定的模型,这在临床应用中是复杂而繁重的。Zhang等人提出了一种自适应融合单模态模型的集成学习来实现多模态分割,然而,只有当一种或所有模态都可用时,它才会起作用。同时,所有的这些方法在训练的过程中都需要完整的模态。
最近的方法侧重于学习一个统一的模型,而不是一堆蒸馏的网络,用于不完整的多模态分割。例如,HeMIS通过计算来自任意可用模态的特征的均值和方差来学习多模态信息的嵌入。U-HVED进一步引入多模态变分自编码器,以利于生成缺失模态的不完整多模态分割。最近的一些方法也提出利用特征解耦和注意力机制来进行鲁棒的多模态脑肿瘤分割。全卷积网络(FCN)在医学图像分割中取得了巨大的成功,被广泛应用于上述方法中的特征提取。尽管性能优越,但卷积的归纳性偏差,
即局部性,使得FCN难以显式地构建长程依赖关系。在脑肿瘤分割的不完整多模态学习中,有限感受野提取的特征在处理不同模态时往往存在偏差。相比之下,具有跨不同模态的肿瘤区域全局语义信息的模态不变性嵌入可能有助于更鲁棒的分割,特别是当一个或多个模态缺失时。
Transformer最初被提出用于建模序列到序列的长程依赖关系,并且在各种计算机视觉任务上也表现出最先进的性能。并行工作从主干网络的角度利用Transformer进行脑肿瘤分割。然而,用于脑肿瘤分割的多模态建模专用Transformer尚未被仔细挖掘,更不用说不完整的多模态分割。
本文做法: 本文旨在利用Transformer为不完整多模态学习的脑肿瘤分割构建一个统一的模型。提出了多模态医学Transformer,它利用混合模态特异性编码器和模态相关的编码器来建立不同模态内部和跨模态之间的长程依赖关系。通过显式构建和对齐不同模态之间的全局相关性来提取模态的不变表示,所提出的mmFormer在脑肿瘤分割的不完整多模态学习中表现出了更好地鲁棒性。同时,在mmFormer中引入辅助正则项,以鼓励编码器和解码器在一定数量的模态缺失情况下学习具有判别性的特征。我们在BraTS 2018数据集上验证了mmFormer在多模态脑肿瘤分割任务上的有效性。在所有缺失模态的情形下,本文方法在平均Dice度量上都优于SOTA方法,特别是在仅有一个可用模态的情况下,Dice在增强肿瘤分割方面平均提高了19.07%,据我们所知,这是首次尝试将Transformer用于脑肿瘤分割的不完整多模态学习。

方法

在本文中,我们提出了mmFormer用于脑肿瘤分割的不完整多模态学习。我们采用编码器-解码器架构来构建mmFormer,包括每个模态的混合模态特异性编码器、模态相关编码器和卷积编码器。此外,在编码器和解码器中都引入了辅助正则项,mmFormer的概述如图1所示。下面对每个组件的细节进行详细阐述。

Figure 1:mmFormer的概述,它由四个混合模态特异性编码器,一个模态相关编码器和一个卷积解码器组成。同时,在编码器和解码器中都引入了辅助正则器。隐藏了卷积编码器和解码器之间的跳跃连接,以便清晰显示

混合模态特异性编码器

混合模态特异性编码器旨在通过桥接卷积编码器和模态内Transformer来提取特定模态内的局部和全局上下文信息。用M={FLAIR,T1c,T1,T2}M=\{FLAIR,T1c,T1,T2\}M={FLAIR,T1c,T1,T2}表示完整的模态集合,给定输入为Xm∈R1×D×H×W\mathbf{X}_m\in\mathbb{R}^{1\times D\times H\times W}XmR1×D×H×W,大小为D×H×W,m∈MD\times H\times W,m\in MD×H×W,mM,我们首先利用卷积编码器生成具有局部上下文的紧凑特征图,然后利用模态内Transformer来建模全局空间中的长程相关性。

卷积编码器

卷积编码器由卷积块堆叠而成,类似于U-Net的编码器部分。由卷积编码器Fmconv\mathcal{F}_{m}^{conv}Fmconv产生的每个模态内具有局部上下文的特征图可以表示Fmlocal=Fmconv(Xm;θmconv)(1)\mathbf{F}_m^{local}=\mathcal{F}_m^{conv}(\mathbf{X}_m;\theta_m^{conv})\quad\quad\quad\quad(1)Fmlocal=Fmconv(Xm;θmconv)(1)其中Fmlocal∈RC×D2l−1×H2l−1×W2l−1\mathbf{F}_{m}^{local}\in\mathbb{R}^{C\times}\frac{D}{2^{l-1}}\times\frac{H}{2^{l-1}}\times\frac{W}{2^{l-1}}FmlocalRC×2l1D×2l1H×2l1WCCC为channel维数,lll为编码器的级数。具体来说,我们构建了一个五阶段的编码器,每个阶段由两个卷积块组成。每个块包含组归一化、RELU和卷积核大小为3的卷积层,而第一阶段的第一个卷积块只包含一个卷积层。在连续的两个块之间,使用步长为2的卷积层对特征图进行降采样。编码器每一级的卷积核个数分别为16,32,64,128和256 。

模态内Transformer

受限于卷积网络固有的局部性,卷积编码器未能有效构建各模态内部的长程依赖关系。因此,我们利用模态内Transformer进行显式的长程上下文建模。模态内Transformer包含一个标记器、一个多头自注意力(MSA)和一个前馈网络(FFN)。
因为Transforemer以序列到序列的方式处理嵌入,由卷积编码器产生的局部特征映射Fmlocal\mathbf{F}_{m}^{local}Fmlocal首先被展平化为1D序列,并通过线性投影转换到token空间,然而,扁平化操作不可避免地会导致空间信息的塌陷,这对图像分割至关重要。为了解决这个问题,我们引入了一个可学习的位置嵌入Pm\mathbf{P}_{m}Pm,通过逐元素求和来补充扁平化的特征,表示为Fmtoken=FmlocalWm+Pm,(2)\mathbf{F}_m^{tok\boldsymbol{e}n}=\mathbf{F}_m^{loc\boldsymbol{a}l}\mathbf{W}_m+\mathbf{P}_m,\quad\quad\quad\quad\quad\quad\quad\quad(2)Fmtoken=FmlocalWm+Pm,(2)其中,Fmtoken∈RC′×DHW23(l−1)\mathbf{F}_m^{token}\in\mathbb{R}^{C^{\prime}\times\frac{DHW}{2^{3(l-1)}}}FmtokenRC×23(l1)DHW表示token,Wm\mathbf{W}_{m}Wm表示线性投影的权重。MSA通过查看特征图中所有可能的位置来建立每个模态内的关系,表示为headmi=Attention(Qmi,Kmi,Vmi)=softmax(QmiKmiTdk)Vmi,(3)head_{m}^{i}=Attention(\mathbf{Q}_{m}^{i},\mathbf{K}_{m}^{i},\mathbf{V}_{m}^{i})=softmax(\frac{\mathbf{Q}_{m}^{i}\mathbf{K}_{m}^{i\text{T}} }{ \sqrt { d _ { k }}})\mathbf{V}_{m}^{i},\quad\quad(3)headmi=Attention(Qmi,Kmi,Vmi)=softmax(dk QmiKmiT)Vmi,(3)MSAm=[headm1,...,headmN]Wmo,(4)MSA_m=[head_m^1,...,head_m^N]\text{W}_m^o,\quad\quad\quad\quad(4)MSAm=[headm1,...,headmN]Wmo,(4)其中Qmi=LN(Fmtoken)WmQi,Kmi=LN(Fmtoken)WmKi,Vmi=LN(Fmtoken)WmVi,\mathbf{Q}_m^i=LN(\mathbf{F}_m^{token})\mathbf{W}_m^{Qi},\mathbf{K}_m^i=LN(\mathbf{F}_m^{token})\mathbf{W}_m^{Ki},\mathbf{V}_m^i=LN(\mathbf{F}_m^{token})\mathbf{W}_m^{Vi},Qmi=LN(Fmtoken)WmQi,Kmi=LN(Fmtoken)WmKi,Vmi=LN(Fmtoken)WmVi,LN(⋅)LN(\cdot)LN()为层归一化,dkd_{k}dkKm\mathbf{K}_{m}Km的维度,N=8N = 8N=8是注意力头的数目,[⋅,⋅][\cdot,\cdot][,]是concat操作,FFN是具有GELU激活的双层感知机。模态内Transformer产生的每个模态内具有全局上下文的特征映射定义为Fmglobal=FFNm(LN(z))+z,z=MSAm(LN(Fmtoken))+Fmtoken,(5)\mathbf{F}_m^{\boldsymbol{global}}=FFN_m(LN(z))+z,z=MSA_m(LN(\mathbf{F}_m^{\boldsymbol{token}}))+\mathbf{F}_m^{\boldsymbol{token}},\quad(5)Fmglobal=FFNm(LN(z))+z,z=MSAm(LN(Fmtoken))+Fmtoken,(5)其中Fmglobal∈RC′×DHW23(l−1)\mathbf{F}_m^{\boldsymbol{global}}\in\mathbb{R}^{C^{\prime}\times\frac{DHW}{2^{3(l-1)}}}FmglobalRC×23(l1)DHW

模态相关编码器

模态相关编码器旨在为与肿瘤区域对应的具有全局语义的模态不变特征建立跨模态的长程相关性。用一种模态间Transformer来实现。

模态间Transformer

与模态内Transformer不同,模态间Transformer将所有模态特异性编码器的嵌入通过concat的方式组合成输入的多模态token,其定义为Ftoken=[δFLAIRFFLAIRglobal,δT1cFT1cglobal,δT1FT1global,δT2FT2global]W+P,(6)\mathbf{F}^{token}=[\delta_{FLAIR}\mathbf{F}_{FLAIR}^{global},\delta_{T1c}\mathbf{F}_{T1c}^{global},\delta_{T1}\mathbf{F}_{T1}^{global},\delta_{T2}\mathbf{F}_{T2}^{global}]\mathbf{W}+\mathbf{P},\quad(6)Ftoken=[δFLAIRFFLAIRglobal,δT1cFT1cglobal,δT1FT1global,δT2FT2global]W+P,(6)其中δm∈{0,1}\delta_{m}\in\{0,1\}δm{0,1}是一个伯努利指标,其目的是在构建不同模态之间长程依赖关系时,即使在某些模态缺失的情况下,也能提供鲁棒性。这种模态级别的dropout是在训练过程中随机进行的,通过将δm\delta_{m}δm设置为0.在模态缺失的情况下,缺失模态的多模态token被0向量取代。随后,对跨模态的模态不变特征进行进行MSD和FFN处理,表示为Fglobal=FFN(LN(z))+z,z=MSA(LN(Ftoken))+Ftoken,(7)\mathbf{F}^{\boldsymbol{global}}=FFN(LN(z))+z,z=MSA(LN(\mathbf{F}^{\boldsymbol{toke}n}))+\mathbf{F}^{\boldsymbol{toke}n},\quad(7)Fglobal=FFN(LN(z))+z,z=MSA(LN(Ftoken))+Ftoken,(7),其中Fglobal∈RC′×DHW2(l−1).\mathbf{F}^{global}\in\mathbb{R}^{C^{\prime}\times\frac{DHW}{2^{(l-1)}}}.FglobalRC×2(l1)DHW.

卷积解码器

卷积解码器的设计目的是将空间分辨率从高层潜在空间逐步恢复到原始掩码空间。模态相关Transformer的输出序列Fglobal\mathbf{F}^{global}Fglobal被重塑为扁平化前的尺寸对应的特征图。卷积解码器具有与U-Net类似地卷积编码器的对称结构。此外,编码器和解码器之间的跳跃连接也被添加,以保留更多的低层细节以获得更好的分割效果。来自特定层次不同模态的卷积编码器的特征被concat并作为跳跃特征转发到卷积解码器。

辅助正则化器

传统的多模态学习模型倾向于依靠判别模态(T2、FLAIR)来识别脑肿瘤。当判别模态缺失时,这类模型可能面临严重的退化。因此,鼓励每个卷积编码器在没有其他模态辅助的情况下分割脑肿瘤是至关重要的。为此,卷积编码器的输出通过共享权重解码器 (所有模态使用一种解码器,同时对该解码器进行权重的更新) 进行上采样,以分别从每个模态中分割肿瘤。共享权重解码器与卷积解码器具有相同的架构。此外,我们还在卷积解码器中引入辅助正则项,以迫使解码器即使在某些模态缺失的情况下也能产生准确的分割。它是通过对卷积解码器每个阶段的特征图进行插值来实现的,通过深度监督来分割肿瘤。采用Dice损失作为正则项,结合网络输出的训练损失和辅助正则项,定义整体损失函数为L=1−Dice=1−2∑c=1C∑i=1Ncgicpic∑c=1C∑i=1Ncgic2+∑c=1C∑i=1Ncpic2,(8)\mathcal{L}=1-Dice=1-\frac{2\sum_{c=1}^C\sum_{i=1}^{N_c}g_i^cp_i^c}{\sum_{c=1}^C\sum_{i=1}^{N_c}g_i^{c2}+\sum_{c=1}^C\sum_{i=1}^{N_c}p_i^{c2}},\quad\quad(8)L=1Dice=1c=1Ci=1Ncgic2+c=1Ci=1Ncpic22c=1Ci=1Ncgicpic,(8)Ltotal =∑i∈MLiencoder+∑i=1l−1Lidecoder+Loutput,(9)\mathcal{L}_\text{total }=\sum_{i\in M}\mathcal{L}_i^{encoder}+\sum_{i=1}^{l-1}\mathcal{L}_i^{decoder}+\mathcal{L}^{output},\quad\quad\quad(9)Ltotal =iMLiencoder+i=1l1Lidecoder+Loutput,(9)其中CCC是分割类别数,NcN_{c}Nc为类别ccc的体素数,gicg_{i}^{c}gic是一个二值指标,如果类别标记ccc是像素iii的正确分类,picp_{i}^{c}pic是相应的预测概率,M={FLAIR,T1c,T1,T2}M=\{FLAIR,T1c,T1,T2\}M={FLAIR,T1c,T1,T2}lll是卷积解码器中的阶段数

实验及结果

数据集与实现

数据集:BraTS 2018数据集
硬件配置:4块NVDIA Tesla V100 GPUs
实现环境:PyTorch 1.7
Batch size: 1
Learning rat:0.0002
epoch:1000

结果:

在这里插入图片描述
Table 1:在BraTS 2018数据集上本文方法与SOTA方法HeMIS和U-HVED的结果,采用Dice相似性系数(DSC)对每个模态的组合设置进行评估,用●和o分别表示可用和缺失模态
在这里插入图片描述
Fig 2:mmFormer在多种可用模态下的分割结果

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐