龙空技术网

MATLAB环境下基于深度学习的新型冠状病毒肺炎 (COVID-19) 检测

哥本哈根诠释2023 159

前言:

当前你们对“衡量算法性能的主要指标”都比较注意,小伙伴们都想要知道一些“衡量算法性能的主要指标”的相关文章。那么小编在网络上搜集了一些对于“衡量算法性能的主要指标””的相关知识,希望同学们能喜欢,咱们快快来了解一下吧!

作者:来自代顿大学研究院(UDRI))的 Barath Narayanan 博士

新型冠状病毒肺炎 (COVID-19) 是 2019 年发现的一种新型人类疾病,无任何先例可循。冠状病毒是一个庞大的病毒家族,可导致患者出现轻重不一的病症,轻至普通感冒,重至急性呼吸道综合症,如中东呼吸综合症 (MERS-COV) 或严重急性呼吸道综合症 (SARS-COV)。该肺炎已造成全球大流行,目前世界各地均有大量人口发生感染和接受治疗。仅美国一地,COVID-19 疫情或可导致 1.6 亿到 2.14 亿人感染 。一些国家和地区已宣布进入紧急状态,隔离人数多达数百万。

检测和诊断工具可以为医生提供宝贵的第二诊疗意见,协助他们完成筛查。与此同时,此类机制还有助于快速向医生呈现检测结果。来自代顿大学研究院(UDRI))的 Barath Narayanan 博士将为我们介绍他的团队如何应用基于深度学习的技术,使用 MATLAB 根据胸片检测 COVID-19

背景

本文使用的 COVID-19 数据集由蒙特利尔大学博士后研究员 Joseph Cohen 博士管理,下载 ZIP 文件后,将文件解压到名为 “Covid 19” 的文件夹中,所得的每个子文件夹对应 “dataset” 中的一个类。标签 “covid” 表示在患者体内检出 COVID-19,“normal” 则表示未检出。数据均等分布在两个类中(各 25 张影像),所以此处不存在类不均衡问题。

加载数据集

首先,使用 imageDatastore 加载数据库。该函数用于加载影像及其标签以执行分析,具有较高的计算效率。

%Clear workspaceclear; close all; clc; %ImagesDatapath–Please modify your path accordinglydatapath='dataset'; %ImageDatastoreimds=imageDatastore(datapath,...    'IncludeSubfolders',true,...'LabelSource','foldernames');%Determine the split uptotal_split=countEachLabel(imds)

影像可视化

可视化影像,了解各个类之间的影像差异。另外,这也有助于我们确定采用何种分类方法区分这两个类。根据影像,我们可以选择适当的预处理技术来帮助我们完成分类。根据类内相似性及类间差异性,我们可以确定研究所需的 CNN 架构类型。

%Number of Imagesnum_images=length(imds.Labels);%Visualize random imagesperm=randperm(num_images,6);figure;for idx=1:length(perm)        subplot(2,3,idx);    imshow(imread(imds.Files{perm(idx)}));    title(sprintf('%s',imds.Labels(perm(idx))))    end

K 折验证

上文已述,此数据集提供的影像数量有限,因此,我们将数据集拆分为 10 折进行分析,也就是使用数据集中的各组影像分别训练 10 个不同算法。此验证方法相比常用的留出验证法提供更为准确的性能预估。

本文采用 ResNet-50 架构,因为该架构经证实对各类医学成像应用均十分有效 [1,2]。

%Number of foldsnum_folds=10;%Loopfor each foldfor fold_idx=1:num_folds        fprintf        fprintf('Processing %d among %d folds \n',fold_idx,num_folds);      %TestIndicesfor current fold    test_idx=fold_idx:num_folds:num_images;    %Test cases for current fold    imdsTest = subset(imds,test_idx);       %Train indices for current fold    train_idx=setdiff(1:length(imds.Files),test_idx);       %Train cases for current fold    imdsTrain = subset(imds,train_idx);    %ResNetArchitecture    net=resnet50;    lgraph = layerGraph(net);    clear net;        %Number of categories    numClasses = numel(categories(imdsTrain.Labels));        %NewLearnableLayer    newLearnableLayer = fullyConnectedLayer(numClasses,...        'Name','new_fc',...        'WeightLearnRateFactor',10,...        'BiasLearnRateFactor',10);        %Replacing the last layers withnew layers    lgraph = replaceLayer(lgraph,'fc1000',newLearnableLayer);    newsoftmaxLayer = softmaxLayer('Name','new_softmax');    lgraph = replaceLayer(lgraph,'fc1000_softmax',newsoftmaxLayer);    newClassLayer = classificationLayer('Name','new_classoutput');    lgraph = replaceLayer(lgraph,'ClassificationLayer_fc1000',newClassLayer);          %PreprocessingTechnique    imdsTrain.ReadFcn=@(filename)preprocess_Xray(filename);    imdsTest.ReadFcn=@(filename)preprocess_Xray(filename);       %TrainingOptions, we choose a small mini-batch size due to limited images    options = trainingOptions('adam',...        'MaxEpochs',30,'MiniBatchSize',8,...        'Shuffle','every-epoch',...        'InitialLearnRate',1e-4,...        'Verbose',false,...        'Plots','training-progress');       %DataAugumentation    augmenter = imageDataAugmenter(...        'RandRotation',[-55],'RandXReflection',1,...        'RandYReflection',1,'RandXShear',[-0.050.05],'RandYShear',[-0.050.05]);        %Resizing all training images to [224224]forResNet architecture    auimds = augmentedImageDatastore([224224],imdsTrain,'DataAugmentation',augmenter);       %Training    netTransfer = trainNetwork(auimds,lgraph,options);      %Resizing all testing images to [224224]forResNet architecture    augtestimds = augmentedImageDatastore([224224],imdsTest);      %Testingand their corresponding LabelsandPosteriorfor each Case    [predicted_labels(test_idx),posterior(test_idx,:)]= classify(netTransfer,augtestimds);        %Save the IndependentResNetArchitectures obtained for each Fold    save(sprintf('ResNet50_%d_among_%d_folds',fold_idx,num_folds),'netTransfer','test_idx','train_idx');       %Clearing unnecessary variables    clearvars -except fold_idx num_folds num_images predicted_labels posterior imds netTransfer;    end

性能研究

我们通过混淆矩阵衡量算法性能,该指标同时也反映了查准率和查全率方面的性能。我们认为总体准确度是一个有效指标,因为本研究中使用的测试数据集为均匀分布(每个类别的图像数均等)。

混淆矩阵

%ActualLabelsactual_labels=imds.Labels;%ConfusionMatrixfigure;plotconfusion(actual_labels,predicted_labels')title('ConfusionMatrix:ResNet');

ROC 曲线

ROC 协助医生根据误报率和检测率选择工作点。

test_labels=double(nominal(imds.Labels));% ROC Curve-Our target classis the first classinthis scenario [fp_rate,tp_rate,T,AUC]=perfcurve(test_labels,posterior(:,1),1);figure;plot(fp_rate,tp_rate,'b-');grid on;xlabel('False Positive Rate');ylabel('Detection Rate');
%Area under the ROC curve valueAUC

AUC = 0.9776

类激活映射

将不同 COVID-19 病例经过这些网络处理后得到的类激活映射 (CAM) 结果可视化,这有助于医生了解算法决策背后的依据。以下是不同病例的相应结果:

基于其他公开数据集进行测试

为了进一步研究和分析算法性能,我们需要确定从不含 COVID-19 标签的其他公开数据集检测出 COVID-19 的概率。在此,我们使用 [2] 提供的病例,病例由放射科医生标记为正常、细菌性肺炎或病毒性肺炎。前文已述,每个网络分别使用 COVID-19 数据集中的一组不同影像进行训练。只要影像的冠状病毒后验概率大于 0.5,即视为假阳性 (FP)。结果清楚地表明,我们的算法具有较高的特异度和敏感度。在单核 GPU 上,每个测试用例的用时约 13 毫秒

结论

本文介绍了一种基于深度学习的简单分类方法,可用于 COVID-19 的计算机辅助诊断。基于 ResNet 的分类算法表现相对出色,总体准确度和 AUC 较高。迁移学习方法的良好性能再次印证,基于 CNN 的分类模型适于执行特征提取。使用几组新的带标签影像即可轻松重新训练算法,从而进一步增强性能。将上述结果与其他现有架构相结合,可以从 AUC 和总体准确度两方面提高性能。如果能就计算量(内存和时间)和性能两方面综合研究这些算法,将有助于相关专家有所侧重地选择算法。计算机辅助诊断不仅是医生进行 COVID-19 筛查的好帮手,还有助于提供宝贵的第二诊疗意见。

参考文献

面包多代码

[1] Narayanan, B. N., De Silva, M. S., Hardie, R. C., Kueterman, N. K., & Ali, R. (2019)."Understanding Deep Neural Network Predictions for Medical Imaging Applications". arXiv preprint arXiv:1912.09621.

[2] Narayanan, B. N., Davuluru, V. S. P., & Hardie, R. C. (2020, March)."Two-stage deep learning architecture for pneumonia detection and its diagnosis in chest radiographs".In Medical Imaging 2020:Imaging Informatics for Healthcare, Research, and Applications (Vol. 11318, p. 113180G).International Society for Optics and Photonics.

此外,知乎付费咨询:哥廷根数学学派

擅长现代信号处理(改进小波分析系列,改进变分模态分解,改进经验小波变换,改进辛几何模态分解等等),改进机器学习,改进深度学习,机械故障诊断,改进时间序列分析(金融信号,心电信号,振动信号等)

标签: #衡量算法性能的主要指标