基于联邦学习和改进ResNet的肺炎辅助诊断①

2022-05-10 02:29曹润芝刘嘎琼
计算机系统应用 2022年2期
关键词:残差联邦客户端

曹润芝,韩 斌,刘嘎琼

(江苏科技大学 计算机学院,镇江 212114)

世界卫生组织公布肺炎是导致儿童死亡的主要原因之一[1],据估计每年160 万5 岁以下儿童因患肺炎而死亡.肺炎病患的医学影像需要具备专业知识的医生进行评估,但在实际诊断中医生由于缺少经验和视觉疲惫等主观因素会造成误诊,因此准确高效地通过肺部影像进行辅助诊断对于肺炎的治疗至关重要[2,3].面对医疗诊断过程决策难且不确定性高的问题,Chouhan等人[4]采用迁移学习建立肺炎检测的深度学习框架,利用AlexNet、DenseNet121、Inception V3、GoogLeNet和ResNet18 网络模型提取图像特征,并使用集成学习方法将各网络的输出合并到一个预测向量中,使用多数表决权进行最终预测.潘丽艳等人[5]使用深度学习技术对肺区域病毒或细菌的病原学类型进行判断,从而在临床上规范儿童肺炎治疗.Al-Antari 等人[6]提出一种基于YOLO 预测器的同步深度学习计算机辅助诊断系统,能可靠地将新冠肺炎与其它呼吸系统疾病进行区分.何新宇等人[7]提出基于深度卷积神经网络的肺炎图像识别模型,采用GoogLeNet Incepion V3 网络模型进行特征提取,以及随机森林分类器进行分类预测,最终在识别准确率和敏感度上实现明显提高.

深度学习技术在图像的分类[8]、检测[9]等方面具有高精度的优势,能准确地提取医学图像的特征从而辅助疾病诊断.然而在现实中仍然存在以下问题:(1)训练模型需要大量的优质数据,而医疗数据涉及病人隐私,容易造成数据孤岛现象.(2)传统的深度学习过程将各方数据集汇总到一起进行模型训练,对于单机的算力需求较高且效率较低.(3)经典残差网络忽略了图像的通道特征且对批量大小具有高依赖性.

Mao 等人[10]论述联邦学习具有不共享本地数据实现联合训练的优势,能够解决深度学习过程中的医疗数据隐私泄漏问题.利用联邦学习对医学图像进行分析,能为临床医生提供更好的辅助诊断工具[11].

针对以上问题,本文提出了一种由联邦学习框架(federated learning,FL)、改进残差网络(residual networkgroup normalization,ResNet-GN)和压缩激励网络(squeezeand-excitation network,SE)融合形成的FL-SE-ResNet-GN方法,将残差网络的批量归一化方式(batch normalization,BN)替换为组归一化方式(group normalization,GN),嵌入压缩激励网络,在提取图像深层特征的同时关注通道特征,最终融合联邦学习框架进一步提升肺炎辅助诊断的效果和保护患者数据隐私.

1 相关理论

1.1 残差网络

基于卷积神经网络的模型通过不断地增加卷积层和池化层组合的数量从而学习图像更深层的特征,例如LeNet[12]、AlexNet[13]、VGG[14]等模型.然而当模型的层数不断地增加,浅层网络参数会逐渐趋于零,导致其无法更新,造成梯度消失现象.为了解决上述问题,2016年He 等人[15]提出了深度残差网络(residual network,ResNet)模型.其结构如图1所示,通过在两个权重层外部使用跳跃连接实现恒等映射,从而构成一个残差块.定义该残差块的输入为x,模块最终的整体输出为H(x),则残差模块两层权重层部分的残差函数为F(x)=H(x)−x.其中跳跃连接不会引入新的参数,一定程度上保证了训练的效率.相较于残差模块的整体输出H(x),计算通过优化转换后的残差函数F(x)更容易.

图1 经典残差模块

1.2 压缩激励网络

压缩激励网络[16]结构简单,便于耦合到已有的深度学习网络模型框架中.SE 模块主要关注了通道之间的相关性,通过训练得到各个通道不同的权重信息,含有主要特征的通道分配更高的权重,从而提高模型的表达能力.SE 模块由3 部分组成:压缩(squeeze)、激励(excitation)和重分配(reweight).定义SE 模块的输入为W×H×C的特征,其中,W表示宽,H表示高,C表示通道数;使用缩放参数减少通道个数从而降低计算量.

首先SE 模块利用全局池化对输入特征进行压缩,将特征转化为1×1×C的向量,使其获得全局的感受野;

然后利用两次全连接操作、ReLU 函数和缩放参数实现降维和升维操作.最后实现重分配,使用Sigmoid函数对激励操作的输出实现归一化,从而得到对应通道的归一化权重,并将权重与输入特征相乘,实现了输入特征的重标定,从而增强了有用特征的提取,提高特征提取的准确度.

1.3 联邦学习

联邦学习[10]是一种特殊的分布式机器学习框架,与传统的机器学习的区别在于缺少了中心服务器汇总数据的过程.联邦学习采用客户端-服务器架构如图2所示,联邦学习的组成部分为中心服务器和若干个客户端,其每轮的训练过程是客户端首先从中心服务器下载联邦共享模型和参数,利用本地的数据进行迭代训练,最后将该轮训练得到的参数或模型上传到中心服务器进行聚合,实现联邦共享模型和参数的更新.

图2 基于联邦学习的客户端-服务器架构

2 基于联邦学习与改进ResNet的肺炎辅助诊断方法

本文为了准确地提取肺部图像更深层次的特征同时保护医疗数据的隐私性.首先,在残差网络的基础上引入压缩激励网络.然后,将残差网络中的批量归一化方式转换为组归一化方式形成改进后的深度神经网络模型(SE-ResNet-GN).最后,将改进后的深度学习模型融合联邦学习框架,运用联邦平均算法实现分布式模型训练以获得最终的联邦共享模型.

2.1 改进ResNet 网络

本文以经典的ResNet-18 作为基础网络,总共由17 个卷积层和1 个全连接层组成,第一个卷积层使用64个7×7的卷积核,剩下的卷积层使用 3×3卷积核.

首先在残差块内部的输出后嵌入激励压缩网络,本文中令缩放参数r为16,输入特征依次经过两次全连接层进行16 倍降维和升维操作,利用Sigmoid 激活函数获得通道的归一化权重并与通道特征相乘从而重标定输入的特征.

其次,残差块使用BN 方式实现归一化,在学习过程中,BN 方式对一个批量的数据计算得到均值和方差并通过滑动平均的方式获得训练的全局均值和方差,因此BN 方式对批量大小(batch size)的依赖程度较高,当数据批量较小时难以保证训练效果,而较大批量的数据对于计算机的算力要求较高.因此本文使用组归一化[17]方式替换经典残差块内部的BN.改进后的网络如图3所示.GN 通过对通道进行分组操作,将卷积层的输出特征图的通道分为G组,G为16,图3中表示为64//G,并计算每组的均值和方差,因此使用GN方式性能稳定,对批数量的鲁棒性更强.

图3 融合SE 块的改进残差网络

定义归一化运算的输入的特征图x为[N,C,H,W],其中N表征批量数,C表征通道数,H、W表征高度和宽度,y与 β为学习的映射参数,BN 保留输入通道C的维度,其归一化公式如下所示:

GN 进行归一化首先将输入x的通道划分为G组,则每组包含C/G个通道数,然后计算组内元素的均值和方差,公式如下所示:

2.2 FL-SE-ResNet-GN 学习框架

基于联邦学习的肺炎辅助诊断中对于病人数据的数量和质量要求较高,对于单一的医院或者医疗机构而言本地的数据集数量较少而且类型单一,难以达到深度学习的要求.与此同时,医疗数据因涉及大量病人的隐私,形成大量的数据孤岛.本文联邦学习采用客户端-服务器架构,将单个医院或者医疗机构作为一个客户端以及拥有足够算力的可信第三方作为中心服务器,提供FL-SE-ResNet-GN 框架足够的算力支持并进行数据传输,能够让客户端在不暴露本地的数据的情况下进行深度学习的模型训练,各个客户端将本地训练的参数上传到中心服务器应用联邦平均算法进行聚合.联邦平均算法步骤如算法1 所示.

算法1.联邦平均算法输入:参与运算的客户端比例:;客户端:;学习率:;本地训练批数量B:;本地迭代次数:E;服务端拟合轮数:.输出:全局权重.C(0∼1) client j(1≤j≤J) η b∈Bt∈T Step 1.初始化t=1 w0 Step 2.从 到T m←max(C·J,1)1) S t← m 2) 数量为的客户端子集client j∈S ti∈E 3) 并行进行本地训练,同时本地训练轮数.client j ωi,j,t=ωi−1,j,t−η∇ℓj,t(ωi−1,j,t;b)i∈E 4) 利用本地数据对下载的模型参数进行本地E 轮迭代更新:,client j ω j,t ωt+1←J∑j=1 5)获得各的本地参数,利用本地客户端的数据量占所有参与训练的客户端的总数据量的比值得到聚合模型:nj n ω j,t

在联邦学习过程中,各个客户端在本地进行独立的模型训练,其采用的批处理数量与各自计算机算力相关,同时使用小批量数据在BN 中误差较大,从而影响聚合的联邦共享模型的准确度.因此,为了减少对批处理数量的依赖.本文使用联邦学习框架与改进的SEResNet-GN 网络模型相结合实现分布式训练,当联邦聚合的次数达到限制或者训练模型收敛后将会结束训练.具体的实现过程如图4所示.图中使用clienti和clientj表示不同的客户端.

图4 FL-SE-ResNet-GN 模型的训练流程

改进的深度神经网络模型融合联邦学习框架后的训练算法如算法2 所示.

算法2.FL-SE-ResNet-GN输入:本地数据集,服务端拟合总轮数Smax.w0{Data1,Data2,···,Datan}初始化:模型网络结构,权重.输出:收敛的FL-SE-ResNet-GN 模型.Step 1.客户端从中心服务器下载初始的FL-SE-ResNet-GN 模型和初始参数.

?

3 仿真实验与理论分析

3.1 实验环境

本文实验的环境为Windows 10 操作系统,Intel(R)CPU E5-1620 v3@3.5 GHz 3.5 GHz 处理器,8 GB NVIDIA Quadro M4000 显卡,32 GB 内存,500 GB 硬盘,计算机语言为Python,实验框架基于PyTorch 实现.

3.2 数据集与数据增强

本实验所使用的数据集是来自2018年美国加州大学圣迭戈分校公开的Chest X-Ray Images 图像数据集[18].该数据集由两类图像组成,分别是正常(normal)和患肺炎(pneumonia).训练集和测试集的数量如表1所示.

表1 数据集分布

3.3 实验评估指标

本次实验所使用的评估指标分别为准确率(Accuracy,ACC),召回率(Recall),精确率(Precision,P),F1 分数(F1).以上指标的计算公式如下.

其中,TP表示正常样本被正确分类的数量,TN表示肺炎样本被正确分类的数量,FP表示肺炎样本被分类为正常样本数量,FN表示正常样本中被分类为肺炎样本的数量.

3.4 实验过程和结果分析

本节使用Chest X-Ray Images 图像[18]作为数据集进行实验.首先,为了验证批量归一化和组归一化方式受批处理数量参数的影响,在不同的批处理数量下将SE-ResNet-GN 模型与嵌入SE的原始残差网络(以下简称SE-ResNet-BN) 进行比较,并分析模型的准确率、召回率、精准率和F1 等指标.然后,运用联邦学习框架与改进SE-ResNet-GN 模型相融合,并与经典的神经网络模型进行对比.最后,将本文改进的方法与其它研究者已有的研究成果进行对比.

3.4.1 Batch size 对融合BN和GN 网络的影响

本节将改进后的SE-ResNet-GN 与原始网络对比,批数量分别设定为4、8、16和32,其训练结果如表2.

表2 Batch size 大小对SE-ResNet-BN和SE-ResNet-GN的指标影响

从表2可知,batch size为4–32 之间,SE-ResNet-BN模型的准确率最大差值为11.5 个百分点,准确率最小差值为2 个百分点.在上述batch size 下的SE-ResNet-GN模型的准确率变化较为平稳,上述batch size 下准确率差值最大为3 个百分点,差值最小为0.2 个百分点.当batch size为32 时,SE-ResNet-GN的准确率、召回率、F1 分数分别比SE-ResNet-BN 高1.1 个百分点、5.8 个百分点、1.8 个百分点.以上结果表明,整体上,SE-ResNet-GN 模型不仅具有较好的评价指标效果而且受batch size的影响较少.

3.4.2 本文方法与经典深度神经网络模型对比

本节实验假定有3 个客户端进行20 轮联邦拟合过程,客户端本地进行10 轮迭代训练.实验数据集将根据客户端的数量等量划分成为各自本地的数据.本节实验选用AlexNet[13]、VGG[14]、ResNet[15]三种模型以及相关改进SE-ResNet-BN 模型与本文方法在相同的实验环境和数据集下进行对比实验,其结果如表3.

表3 本文方法与经典深度神经网络模型性能对比

由表3可知,SE-ResNet-BN 模型的指标优于ResNet,表明引入SE 模块实现注意力机制能够提升图像分类的性能.经典深度神经网络模型中准确率最高的SEResNet-BN 模型的准确率比本文方法低1.2 个百分点.本文方法的各项评价指标与AlexNet、ResNet、SEResNet-BN 相比具有明显提升.本文方法利用改进残差网络能提取更深层次的图像特征,嵌入压缩激励模块关注通道特征,并进行特征重标定增强有效特征的提取,能够提高联邦共享模型的准确性,同时本文方法具有本地数据不对外共享的优势,能够有效保护医疗数据的隐私,从而打破医疗数据孤岛现象.因此与上述经典方法相比,在诊断准确率和数据隐私保护方面具有优势.本节实验将联邦学习训练的迭代拟合过程与传统方法迭代过程对比,实验准确率和损失值变化如图5和图6所示,整体上训练过程中的准确率和损失值变化结果均优于其它网络.

图5 迭代次数和准确率关系

图6 迭代次数和损失值关系

3.4.3 与其它方法的对比

通过对Chest X-Ray Images 图像数据集进行训练,文献[18]使用迁移学习方法建立了一种基于深度学习框架的诊断工具.文献[19]提出一种结合残差思想和膨胀卷积[20]的肺炎图像分类方法[19],通过残差网络结构克服模型深度增加引起训练过程中的过拟合和退化问题,并利用膨胀卷积避免肺炎图像分类过程中的损失.文献[21]利用AlexNet和InceptionV3 网络模型并结合知识蒸馏方法提高对肺炎CT 图像的分类性能,提出了AlexNet_S 方法.本文方法与上述3 种方法的实验结果对比如表4所示.实验结果表明,本文方法的准确率和召回率最优.在精度指标上本文方法比文献[19]方法提升4.2 个百分点,但低于迁移学习[18]和AlexNet_S[21]的方法的精度指标.整体上本文方法在准确率、精度、召回率方面性能较好,同时融合联邦学习框架,在医疗数据隐私安全性方面具有较大优势,这在现实应用中具有重要的实际意义.

表4 本文方法与其它方法对比

4 结论与展望

为了提高医生疾病诊断的效率和准确性,本文融合联邦学习架构和改进后的SE-ResNet-GN 模型,利用联邦学习过程中数据保存在本地且不对外共享的优势,保证了医疗数据的隐私性并具有打破数据孤岛的优势.为了提取更深层次的特征同时避免梯度消失,本文以残差网络作为基础模型,在联邦学习过程中,各个客户端在本地进行独立训练,为了避免批处理数量对于联邦共享模型的影响,本文将传统的批量归一化方式转换为组归一化方式,对输入特征的通道进行分组运算,提高了模型的稳定性,同时引入激励压缩网络关注通道间的相关性.经过与其它的深度神经网络以及融合联邦学习框架后的模型对比实验后发现本文提出的模型具有更好的准确率与安全性.

猜你喜欢
残差联邦客户端
你的手机安装了多少个客户端
“人民网+客户端”推出数据新闻
——稳就业、惠民生,“数”读十年成绩单
联邦学习在金融数据安全领域的研究与应用
基于残差-注意力和LSTM的心律失常心拍分类方法研究
基于双向GRU与残差拟合的车辆跟驰建模
一“炮”而红 音联邦SVSound 2000 Pro品鉴会完满举行
基于残差学习的自适应无人机目标跟踪算法
基于深度卷积的残差三生网络研究与应用
媒体客户端的发展策略与推广模式
新华社推出新版客户端 打造移动互联新闻旗舰