用单张X光进行CT的三维重建
一、课题背景
项目github链接:https://github.com/LijunRio/Xrays_CT
【前言-废话】
这个项目呢是我在研一开学前假期在实验室和师兄学习的一个项目。研一第一个学期的学习生活已经结束啦,然后发现之前做过的好多项目都忘记了。在学习期间也有好多人在CSDN上问我该课题相关的一些技术细节,然后我由于当时上课的确太忙了也没回复。刚好假期期间就重新的好好整理一遍这个项目~
【注意】
此次readme中更新的部分代码可能项目中没有,因为我已经在readme上放了完整代码啦,所以直接复制就可以跑啦~
1.1 背景
X光是一种常用的无痛快速的医学成像手段。X光的原理是在做X光检查的过程中,患者体内不同组织吸收X光的速率是不同的,在X光检查过程中,电磁波通过患者并以不同的速率被吸收,将患者体积的放射密度投影到曝光的照相胶片上,从而生成2D图像。
CT成像基本原理是用X线束对人体检查部位一定厚度的层面进行扫描,简单来说CT就是扫描人体产生多个切片,然后通过对切片信息重构就可以得到人体的详细三维信息。CT虽然能提供详细的三维信息,但是会带来额外的金钱和安全成本。CT成像比照单张X光更昂贵,而且如果会让患者暴露在一个数量级水平更高的辐射中。所以一般在临床医疗中,医生X光会比CT更为普变。
当医生拿到一张二维的X光的时候,医生是可以根据先验知识来理解二维平面X光片中的三维结构的。所以在人类领域通过足够的先验知识,是可以从二维的X光理解得到三维的人体结构的。所以此个课题项目就是希望通过二维的X光图像就可以重建出三维的脊柱骨模型,在空间上获取更多的数据,更好地对脊柱骨的侧弯程度进行分型。
1.2 前期论文研读
-
脊柱侧弯背景知识
- 阅读了Lenke分型的论文,掌握基本的脊柱分型知识
-
STATE OF THE ART 三维重建
- Yoni Kasten et al (2020) 等人提出的端对端的利用双平面X-Ray图像重建三维模型的方法
- Liyue Shen et al 等人提出的通过单个DRR图像生成CT数据的
- Henzler et al. (2018)提出的利用一个 encoder-decoder structure的卷积神经网络对实现2D到3D的重建
-
STATE OF THE ART 合成X光
由于通过X光生成CT需要成对的X光和CT数据,然而由于隐私政策等多重因素,实际上我们是很难获取如此多的数据的,所以打算在数据准备这块打算用CT数据生成X光来产生3D-2D配对的数据对。
- Moturu and Chang (2019) 提处了一种使用ray-tracing和Beer’s Law的从胸部CT扫描切片种合成X光的方法
- Teixeira et al. (2018)提出了一种合成X光图像的框架,该框架基于胸腔和腹部的表面几何形状来合成X光
二、实验部分前期准备
2.1 数据准备部分
为了创建一个可以接收2D-X光图像作为输入并产生3D-CT输出的神经网络,训练数据集必须由成对的2D-3D数据组成。我们课题组有和医院合作来获取数据集,不过在实际的实验过程中医院的数据量是远远不够的。因此我们想利用生成二维X光图像的方法来获取大量的2D-3D配对的数据集。CT数据集选用了公开的肺部数据集 The Cancer Imaging Archive (TCIA)
2.2 对CT数据进行数据处理
从官网下下来的肺部数据集分布比较零散,数据有的存在一级目录下,有的存在二级目录下。所以首先要先从原始的数据中找出所有的CT序列,并存入新的文件夹中
# 从原始数据中找出所有的CT序列,并存入新的文件夹
import os
import csv
import shutil
from tqdm import tqdm
data_dir = "D:\\lung_public_dataset\\LIDC-IDRI"
output_dir = "D:\\lung_public_dataset\\Raw_data"
csv_path = "D:\\lung_public_dataset\\data.csv"
if not os.path.exists(output_dir):
os.mkdir(output_dir)
data_list = []
with open(csv_path, 'w', newline="") as f:
writer = csv.writer(f)
for maindir, subdir, file_name_list in os.walk(data_dir):
if len(file_name_list) > 50:
data_list.append(maindir)
print(os.listdir(data_list[0]))
file_count = 0
for item in tqdm(data_list):
file_dir = output_dir + "\\" + str(file_count)
shutil.copytree(item, file_dir, ignore=shutil.ignore_patterns('*.xml'))
file_count += 1
经过这个脚本之后,所有的数据就会被规整的放到Raw_data文件夹下,提取到了1013个样本。此时每一个文件夹下都已一些列的.dcm切片,但是此时这些切片都是不规整的,有些文件夹下切片比较多,有些比较少。需要进一步将所有的切片都变为一个正方体,即如果切片数过多的进行裁剪,舍弃一部分切片。如果切片数不够的,则在首尾补全黑色切片。并将最后整理得到的CT数据转化为.nii的三维数据类型进行保存。保存到ct_nii文件路径下
import numpy as np
from scipy.ndimage import zoom
import SimpleITK as sitk
import pydicom
import os
import dicom2nifti.patch_pydicom_encodings
from tqdm import tqdm
import csv
raw_ct_dir = "D:\\lung_public_dataset\\Raw_data\\"
raw_list = os.listdir(raw_ct_dir)
output = "D:\\lung_public_dataset\\ct_nii\\"
# resample to 128*128
def ImageResampleSize(sitk_image, img_size=128, is_label=False):
# 获取原来CT的Size和Spacing
size = np.array(sitk_image.GetSize())
spacing = np.array(sitk_image.GetSpacing())
new_size = (img_size, img_size, size[2])
new_spacing_refine = size * spacing / new_size
new_spacing_refine = [float(s) for s in new_spacing_refine]
new_size = [int(s) for s in new_size]
resample = sitk.ResampleImageFilter()
resample.SetOutputDirection(sitk_image.GetDirection())
resample.SetOutputOrigin(sitk_image.GetOrigin())
resample.SetSize(new_size)
resample.SetOutputSpacing(new_spacing_refine)
if is_label:
resample.SetInterpolator(sitk.sitkNearestNeighbor)
else:
# resample.SetInterpolator(sitk.sitkBSpline)
resample.SetInterpolator(sitk.sitkLinear)
newimage = resample.Execute(sitk_image)
return newimage
def saved_preprocessed(savedImg, origin, direction, xyz_thickness):
newImg = sitk.GetImageFromArray(savedImg)
newImg.SetOrigin(origin)
newImg.SetDirection(direction)
newImg.SetSpacing((xyz_thickness[0], xyz_thickness[1], xyz_thickness[2]))
return newImg
f = open('D:\\lung_public_dataset\\img_info.csv', 'w', encoding='utf-8', newline='')
csv_writer = csv.writer(f)
csv_writer.writerow(
["file_name", "origin_spacing", "origin_Origin", "origin_Direction", "origin_Size",
"resize_spacing", "resize_Origin", "resize_Direction", "resize_Size"])
for file in tqdm(raw_list):
# 读取dicom序列
file_path = os.path.join(raw_ct_dir, file)
output_path = output + file + ".nii"
reader = sitk.ImageSeriesReader()
reader.MetaDataDictionaryArrayUpdateOn() # 这一步是加载公开的元信息
reader.LoadPrivateTagsOn() # 这一步是加载私有的元信息
img_names = reader.GetGDCMSeriesFileNames(file_path)
reader.SetFileNames(img_names)
image = reader.Execute()
# 原来的info
origin_spacing = image.GetSpacing()
origin_Origin = image.GetOrigin()
origin_Direction = image.GetDirection()
origin_Size = image.GetSize()
normal = ImageResampleSize(image)
# resize后的Info
resize_spacing = normal.GetSpacing()
resize_Origin = normal.GetOrigin()
resize_Direction = normal.GetDirection()
resize_Size = normal.GetSize()
# 写入信息
csv_writer.writerow([file, origin_spacing, origin_Origin, origin_Direction, origin_Size, resize_spacing,
resize_Origin, resize_Direction, resize_Size])
# 获取CT_array
np_arr = sitk.GetArrayFromImage(normal)
# 去除CT边界图像
np_arr[np_arr < -500] = -500
# 将CT修正为128*128*128
x, y, z = np_arr.shape
if z < x or z < y:
startx = (x - z) // 2
starty = (y - z) // 2
# 裁剪成一个正方体
cubed_arr = np_arr[startx:startx + z, starty:starty + z, :]
else:
max_xyz = max(x, y, z)
x_pre_pad = (max_xyz - x) // 2
# (8-6)//2=1
x_post_pad = max_xyz - x - x_pre_pad
# 8-6-1=1
y_pre_pad = (max_xyz - y) // 2
y_post_pad = max_xyz - y - y_pre_pad
# 填充成一个正方体
cubed_arr = np.pad(np_arr, ((x_pre_pad, x_post_pad), (y_pre_pad, y_post_pad), (0, 0)), mode='constant',
constant_values=0)
print("process-2: resize to cube ! ", cubed_arr.shape)
x1, y1, z1 = cubed_arr.shape
assert x1 == y1, 'x and y dimensions are the same size'
assert y1 == z1, 'y and z dimensions are the same size'
assert z1 == x1, 'z and x dimensions are the same size'
# resize_img = sitk.GetImageFromArray(cubed_arr)
resize_img = saved_preprocessed(cubed_arr, resize_Origin, resize_Direction, resize_spacing)
print(resize_img.GetSize())
sitk.WriteImage(resize_img, output_path)
2.3 生成DRR来作为2D的X光
2.3.1 DRR介绍
DRR图像时医疗图像配准里面的一个重要的前置步骤:它的主要目的是,通过CT三维图像,获取模拟X射线影像,这个过程也被称为数字影像重建。在2D/3D的配准流程里面,需要首先通过CT三维图像,能够获取任意位置的DRR图像,然后去与已经获取的X光平面图像配准。正如前面所说的论文前期准备一样,用CT模拟生成X光的论文有很多。然后此次课题主要的重点时进行三维重建,所以这部分直接选用了ITK的一个相关示例进行DRR投影。这一部分网上的资料比较少,然后本人对DRR投影的原理其实并不是特别了解,所以花费了比较多的时间才跑通。然后下面时实验所用代码。【注意】需要在C++环境下运行
2.3.2 DRR代码即结果
-
这个代码的输入为CT的.nii格式;中间部分的注释为原来用.dicom作为输入。
-
这个代码主要也是基于官方示例教程进行更改的,所以很多地方也并没有十分看懂,但是主要需要调参的地方以及参数意义在代码中已经有注释了。
生成结果示例:好像是倒过来的不过其实并不影响训练:
#include "itkImage.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"
#include "itkResampleImageFilter.h"
#include "itkCenteredEuler3DTransform.h"
#include "itkNearestNeighborInterpolateImageFunction.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itkRescaleIntensityImageFilter.h"
#include <itkGDCMImageIO.h>
#include <itkPNGImageIO.h>
#include "itkPNGImageIOFactory.h"
#include "itkGDCMSeriesFileNames.h"
#include "itkImageSeriesReader.h"
#include "itkImageSeriesWriter.h"
#include "itkNumericSeriesFileNames.h"
#include "itkGiplImageIOFactory.h"
#include "itkCastImageFilter.h"
#include "itkNIFTIImageIO.h"
#include "itkNIFTIImageIOFactory.h"
#include "itkMetaImageIOFactory.h"
#pragma comment(lib,"ws2_32.lib")
#pragma comment(lib,"Rpcrt4.lib")
#pragma once
#pragma comment(lib,"Psapi.lib")
// Software Guide : BeginCodeSnippet
#include "itkRayCastInterpolateImageFunction.h"
#include <iostream>
// Software Guide : EndCodeSnippet
#include <io.h>
#include<iostream>
using namespace std;
int main(int argc, char* argv[])
{
int file_list[1013];
for (int i = 0; i < 1013; i++) {
file_list[i] = i;
//cout << i << endl;
}
// 总共1013个文件,然后懒得大改代码了。如果文件数有不同记得根据自己的文件进行修改
// 输出的文件会保存在文件下为drr文件夹下,为.png数据类型
for (int i = 0; i < 1013; i++) {
int number = file_list[i];
string num = to_string(number);
string in_drr = "D:/7_15data/open_dataset/ct_nii/";
string out_drr = "D:/7_15data/open_dataset/drr/";
string input = in_drr + num;
string output = out_drr + num + ".png";
const char* input_name = input.c_str();
const char* output_name = output.c_str();
cout << input_name << endl;
//cout << output_name << endl;
//const char* input_name = "D:/7_15data/ct/1 ";
//const char* output_name = "D:/7_15data/drr/5.png";
bool ok ; //true
bool verbose = false; //true
// CT volume rotation around isocenter along x,y,z axis in degrees
float rx = 90;
float ry = 0;
//float rz = 30.; //角度调节
//float rz = -90.;//侧着的
float rz = 0.; //正着的
//float rz = 180.; //背面
// Translation parameter of the isocenter in mm
float tx = 0.;
float ty = 0.;
float tz = 0.;
// The pixel indices of the isocenter
// The pixel indices of the isocenter
float cx = 0.;
float cy = 0.;
float cz = 0.;
//1000.
//float sid = 3000; //400. Source to isocenter distance in mm 源到等中心距 1000
float sid = 500; //400. Source to isocenter distance in mm 源到等中心距 1000
//Default pixel spacing in the iso-center plane in mm
//等角点平面中的默认像素间距,以毫米为单位
//float sx = 2.5; //1. 2,5
//float sy = 0.5; //1.
float sx = 2.5; //1.
float sy = 2.5; //1.
// Size of the output image in number of pixels
int dx = 512;
int dy = 512;
// The central axis positions of the 2D images in continuous indices
//连续索引中2D图像的中心轴位置
float o2Dx = 0;
float o2Dy = 0;
double threshold = -600; //1
const unsigned int Dimension = 3;
typedef float InputPixelType;
typedef unsigned char OutputPixelType;
typedef itk::Image< InputPixelType, Dimension > InputImageType;
typedef itk::Image< OutputPixelType, Dimension > OutputImageType;
//设置输出类型
InputImageType::Pointer image;
if (input_name)
{
//typedef itk::ImageSeriesReader< InputImageType > ReaderType;
//typedef itk::GDCMImageIO ImageIOType;
//typedef itk::GDCMSeriesFileNames NamesGeneratorType;
//ImageIOType::Pointer gdcmIO = ImageIOType::New();
//NamesGeneratorType::Pointer namesGenerator = NamesGeneratorType::New();
//namesGenerator->SetUseSeriesDetails(true);
//namesGenerator->SetDirectory("D:/7_9data/compress_ct/11");
namesGenerator->SetDirectory("D:/7_9data/volume-0.nii");
//using SeriesIdContainer = std::vector< std::string >;
//const SeriesIdContainer& seriesUID = namesGenerator->GetSeriesUIDs();
//auto seriesItr = seriesUID.begin();
//auto seriesEnd = seriesUID.end();
//using FileNamesContainer = std::vector< std::string >;
//FileNamesContainer filenames;
//std::string seriesIdentifier;
//while (seriesItr != seriesEnd)
//{
// seriesIdentifier = seriesItr->c_str();
// filenames = namesGenerator->GetFileNames(seriesIdentifier);
// ++seriesItr;
//}
namesGenerator->SetInputDirectory("D:/data/CT/PATIENT_DICOM/PATIENT_DICOM1"); //输入目录
const ReaderType::FileNamesContainer& filenames =
// //namesGenerator->GetInputFileNames();
//unsigned int numberOfFilenames = filenames.size();
//std::cout << numberOfFilenames << std::endl;
//for (unsigned int fni = 0; fni < numberOfFilenames; fni++)
//{
// std::cout << "filename # " << fni << " = ";
// std::cout << filenames[fni] << std::endl;
//}
//ReaderType::Pointer reader = ReaderType::New();
//reader->SetImageIO(gdcmIO);
//reader->SetFileNames(filenames);
/* typedef itk::ImageFileReader< InputImageType > ReaderType;
ReaderType::Pointer reader = ReaderType::New();
itk::MetaImageIOFactory::RegisterOneFactory();
reader->SetFileName("D:/7_9data/volume-0.nii");*/
typedef itk::ImageFileReader< InputImageType > ReaderType;
ReaderType::Pointer reader = ReaderType::New();
itk::NiftiImageIOFactory::RegisterOneFactory();
reader->SetFileName(input_name);
try
{
reader->Update();
}
catch (itk::ExceptionObject& err)
{
std::cerr << "ERROR: ExceptionObject caught !" << std::endl;
std::cerr << err << std::endl;
return EXIT_FAILURE;
}
image = reader->GetOutput();
}
else
{ // No input image specified so create a cube
image = InputImageType::New();
InputImageType::SpacingType spacing;
spacing[0] = 3.;
spacing[1] = 3.;
spacing[2] = 3.;
image->SetSpacing(spacing);
InputImageType::PointType origin;
origin[0] = 0.;
origin[1] = 0.;
origin[2] = 0.;
image->SetOrigin(origin);
InputImageType::IndexType start;
start[0] = 0; // first index on X
start[1] = 0; // first index on Y
start[2] = 0; // first index on Z
InputImageType::SizeType size;
size[0] = 61; // size along X
size[1] = 61; // size along Y
size[2] = 61; // size along Z
InputImageType::RegionType region;
region.SetSize(size);
region.SetIndex(start);
image->SetRegions(region);
image->Allocate(true); // initialize to zero.
image->Update();
typedef itk::ImageRegionIteratorWithIndex< InputImageType > IteratorType;
IteratorType iterate(image, image->GetLargestPossibleRegion());
while (!iterate.IsAtEnd())
{
InputImageType::IndexType idx = iterate.GetIndex();
if ((idx[0] >= 6) && (idx[0] <= 54)
&& (idx[1] >= 6) && (idx[1] <= 54)
&& (idx[2] >= 6) && (idx[2] <= 54)
&& ((((idx[0] <= 11) || (idx[0] >= 49))
&& ((idx[1] <= 11) || (idx[1] >= 49)))
|| (((idx[0] <= 11) || (idx[0] >= 49))
&& ((idx[2] <= 11) || (idx[2] >= 49)))
|| (((idx[1] <= 11) || (idx[1] >= 49))
&& ((idx[2] <= 11) || (idx[2] >= 49)))))
{
iterate.Set(10);
}
else if ((idx[0] >= 18) && (idx[0] <= 42)
&& (idx[1] >= 18) && (idx[1] <= 42)
&& (idx[2] >= 18) && (idx[2] <= 42)
&& ((((idx[0] <= 23) || (idx[0] >= 37))
&& ((idx[1] <= 23) || (idx[1] >= 37)))
|| (((idx[0] <= 23) || (idx[0] >= 37))
&& ((idx[2] <= 23) || (idx[2] >= 37)))
|| (((idx[1] <= 23) || (idx[1] >= 37))
&& ((idx[2] <= 23) || (idx[2] >= 37)))))
{
iterate.Set(60);
}
else if ((idx[0] == 30) && (idx[1] == 30) && (idx[2] == 30))
{
iterate.Set(100);
}
++iterate;
}
#ifdef WRITE_CUBE_IMAGE_TO_FILE
const char* filename = "cube.gipl";
typedef itk::ImageFileWriter< InputImageType > WriterType;
WriterType::Pointer writer = WriterType::New();
itk::GiplImageIOFactory::RegisterOneFactory();
writer->SetFileName(filename);
writer->SetInput(image);
try
{
std::cout << "Writing image: " << filename << std::endl;
writer->Update();
}
catch (itk::ExceptionObject& err)
{
std::cerr << "ERROR: ExceptionObject caught !" << std::endl;
std::cerr << err << std::endl;
return EXIT_FAILURE;
}
#endif
}
// Print out the details of the input volume
if (verbose)
{
unsigned int i;
const InputImageType::SpacingType spacing = image->GetSpacing();
std::cout << std::endl << "Input ";
InputImageType::RegionType region = image->GetBufferedRegion();
region.Print(std::cout);
std::cout << " Resolution: [";
for (i = 0; i < Dimension; i++)
{
std::cout << spacing[i];
if (i < Dimension - 1) std::cout << ", ";
}
std::cout << "]" << std::endl;
const InputImageType::PointType origin = image->GetOrigin();
std::cout << " Origin: [";
for (i = 0; i < Dimension; i++)
{
std::cout << origin[i];
if (i < Dimension - 1) std::cout << ", ";
}
std::cout << "]" << std::endl << std::endl;
}
typedef itk::ResampleImageFilter<InputImageType, InputImageType > FilterType;
FilterType::Pointer filter = FilterType::New();
filter->SetInput(image);
filter->SetDefaultPixelValue(0);
typedef itk::CenteredEuler3DTransform< double > TransformType;
TransformType::Pointer transform = TransformType::New();
transform->SetComputeZYX(true);
TransformType::OutputVectorType translation;
translation[0] = tx;
translation[1] = ty;
translation[2] = tz;
// constant for converting degrees into radians
const double dtr = (std::atan(1.0) * 4.0) / 180.0;
transform->SetTranslation(translation);
transform->SetRotation(dtr * rx, dtr * ry, dtr * rz);
InputImageType::PointType imOrigin = image->GetOrigin();
InputImageType::SpacingType imRes = image->GetSpacing();
typedef InputImageType::RegionType InputImageRegionType;
typedef InputImageRegionType::SizeType InputImageSizeType;
InputImageRegionType imRegion = image->GetBufferedRegion();
InputImageSizeType imSize = imRegion.GetSize();
imOrigin[0] += imRes[0] * static_cast<double>(imSize[0]) / 2.0;
imOrigin[1] += imRes[1] * static_cast<double>(imSize[1]) / 2.0;
imOrigin[2] += imRes[2] * static_cast<double>(imSize[2]) / 2.0;
TransformType::InputPointType center;
center[0] = cx + imOrigin[0];
center[1] = cy + imOrigin[1];
center[2] = cz + imOrigin[2];
transform->SetCenter(center);
if (verbose)
{
std::cout << "Image size: "
<< imSize[0] << ", " << imSize[1] << ", " << imSize[2]
<< std::endl << " resolution: "
<< imRes[0] << ", " << imRes[1] << ", " << imRes[2]
<< std::endl << " origin: "
<< imOrigin[0] << ", " << imOrigin[1] << ", " << imOrigin[2]
<< std::endl << " center: "
<< center[0] << ", " << center[1] << ", " << center[2]
<< std::endl << "Transform: " << transform << std::endl;
}
typedef itk::RayCastInterpolateImageFunction<InputImageType, double>
InterpolatorType;
InterpolatorType::Pointer interpolator = InterpolatorType::New();
interpolator->SetTransform(transform);
//
// We can then specify a threshold above which the volume's
// intensities will be integrated.
interpolator->SetThreshold(threshold);
InterpolatorType::InputPointType focalpoint;
focalpoint[0] = imOrigin[0];
focalpoint[1] = imOrigin[1];
focalpoint[2] = imOrigin[2] - sid / 2.;
interpolator->SetFocalPoint(focalpoint);
// Software Guide : EndCodeSnippet
if (verbose)
{
std::cout << "Focal Point: "
<< focalpoint[0] << ", "
<< focalpoint[1] << ", "
<< focalpoint[2] << std::endl;
}
// Software Guide : BeginLatex
//
// Having initialised the interpolator we pass the object to the
// resample filter.
interpolator->Print(std::cout);
filter->SetInterpolator(interpolator);
filter->SetTransform(transform);
// setup the scene
InputImageType::SizeType size;
size[0] = dx; // number of pixels along X of the 2D DRR image
size[1] = dy; // number of pixels along Y of the 2D DRR image
size[2] = 1; // only one slice
filter->SetSize(size);
InputImageType::SpacingType spacing;
spacing[0] = sx; // pixel spacing along X of the 2D DRR image [mm]
spacing[1] = sy; // pixel spacing along Y of the 2D DRR image [mm]
spacing[2] = 1.0; // slice thickness of the 2D DRR image [mm]
filter->SetOutputSpacing(spacing);
// Software Guide : EndCodeSnippet
if (verbose)
{
std::cout << "Output image size: "
<< size[0] << ", "
<< size[1] << ", "
<< size[2] << std::endl;
std::cout << "Output image spacing: "
<< spacing[0] << ", "
<< spacing[1] << ", "
<< spacing[2] << std::endl;
}
double origin[Dimension];
origin[0] = imOrigin[0] + o2Dx - sx * ((double)dx - 1.) / 2.;
origin[1] = imOrigin[1] + o2Dy - sy * ((double)dy - 1.) / 2.;
origin[2] = imOrigin[2] + sid / 2.;
filter->SetOutputOrigin(origin);
// Software Guide : EndCodeSnippet
if (verbose)
{
std::cout << "Output image origin: "
<< origin[0] << ", "
<< origin[1] << ", "
<< origin[2] << std::endl;
}
// create writer
if (output_name)
{
typedef itk::RescaleIntensityImageFilter<
InputImageType, OutputImageType > RescaleFilterType;
RescaleFilterType::Pointer rescaler = RescaleFilterType::New();
rescaler->SetOutputMinimum(0);
rescaler->SetOutputMaximum(255);
rescaler->SetInput(filter->GetOutput());
//
typedef itk::ImageFileWriter< OutputImageType > WriterType;
WriterType::Pointer writer = WriterType::New();
typedef itk::PNGImageIO pngType;
pngType::Pointer pngIO1 = pngType::New();
itk::PNGImageIOFactory::RegisterOneFactory();
writer->SetFileName(output_name);
writer->SetImageIO(pngIO1);
writer->SetImageIO(itk::PNGImageIO::New());
writer->SetInput(rescaler->GetOutput());
try
{
std::cout << "Writing image: " << output_name << std::endl;
writer->Update();
}
catch (itk::ExceptionObject& err)
{
std::cerr << "ERROR: ExceptionObject caught !" << std::endl;
std::cerr << err << std::endl;
}
}
else
{
filter->Update();
}
/* system("pause");
return EXIT_SUCCESS;*/
}
}
2.4 数据扩增
其实1000多个数据集还是不是很够的,所以这里进行了数据扩增。但是这一步怎么说呢比较微妙,感觉并不是特别的好。所以数据量足够的情况下不建议加这步。
三维重建的思想是想让模型观测一张二维的X光能生成对应的三维模型。实际上我们并不关心看到的这张图像的角度是什么样的,也不关心这张图想的颜色强度是怎样的(对应窗宽等概念),所以是否可以将X光旋转然后进行明暗变化,然后让模型学到其本质之下的特征所对应的三维关系是一样的呢?所以本着这样的假设,可以将我们人造的X光图像进行一定程度上的旋转,颜色明暗程度的变化,然后这些变化后的图像仍然对应原来的三维模型。这里只是实验,可能并不严谨,勿杠 [狗头保命]
这个地方思路简单就不放代码啦,数据扩增脚本很多的~
2.5 注意事项
在这个地方有一个地方没做好,就是没有将CT的体素规整到1:1:1,而是将每个体素的spacing保存下来了。因为这个是公开数据集,所以是不同CT仪器的数据,每一个CT的spacing都是不一样的。所以又要将它变为正方体,又要让他的间隔距离变为1:1:1当时时间紧迫就没有这样做了。而是实现将所有的spacing信息都保存了下来。这个地方就是有bug,假期刚接受的小白新手,真诚的希望大神指点![保护好我的狗头发际线.jpg]
三、网络部分
3.1 基本思路
网络这部分我在Liyue Shen et al 等人提出网络上进行了一些改进,然后由于这部分关乎到我师兄最后的论文部分,所以感觉不太好开源。所以这部分我简要的叙述一下我的思路:
其实这篇论文的网络结构并不是特别复杂,主要分为三个模块:表征网络、转换模块、生成网络。
- 表征层网络在原论文中用到的是一个类残差网络的结构,主要用于提取DRR图像的特征信息
- 在后面的实际测试中,我发现Liyue Shen在表征层的前几层就用了将卷积核的数目设得特别大,这样直接导致表征层的计算参数非常的多! 甚至表征层部分的计算参数比resnet-50还要多。这样网络训练起来是非常慢且很难收敛。所以表征层部分的网络是明显可以优化的!
- 虽然原论文说,这样做能尽可能的保留较多的信息,但是实际训练中可以发现过大的计算量反而不利于网络的训练
- 转化模块
- 经过表征层网络训练后得到4096×4×4的tensor,通过transform模块将这个tensor转为2048×2×4×4
- 实现起来其实很简单就是一个tensor的转化
- 生成网络
- 类似一个解码器,根据二维图像提取的特征生成三维的CT数据
- 生成网络比较简单,是否可以优化生成部分的网络结构(感觉只是单纯的为了将特征还原到CT的数据大小)
- 其实也有点像GAN,将得到的特征生成三维信息(仅个人感觉)
3.2 原论文代码
以下是Liyue Shen et al开源的代码,我稍作了改写方便调试运行
import torch.nn as nn
import torch
import math
from torch.nn import init
from torchsummary import summary
# 2D Conv
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes,
kernel_size=1, stride=stride, padding=0,
bias=False)
def conv2x2(in_planes, out_planes, stride=2):
return nn.Conv2d(in_planes, out_planes,
kernel_size=2, stride=stride, padding=0,
bias=False)
def conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes,
kernel_size=3, stride=stride, padding=1,
bias=False)
def conv4x4(in_planes, out_planes, stride=2):
return nn.Conv2d(in_planes, out_planes,
kernel_size=4, stride=stride, padding=1,
bias=False)
# 3D Conv
def conv1x1x1(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes,
kernel_size=1, stride=stride, padding=0,
bias=False)
def conv3x3x3(in_planes, out_planes, stride=1):
return nn.Conv3d(in_planes, out_planes,
kernel_size=3, stride=stride, padding=1,
bias=False)
def conv4x4x4(in_planes, out_planes, stride=2):
return nn.Conv3d(in_planes, out_planes,
kernel_size=4, stride=stride, padding=1,
bias=False)
# 2D Deconv
def deconv1x1(in_planes, out_planes, stride):
return nn.ConvTranspose2d(in_planes, out_planes,
kernel_size=1, stride=stride, padding=0, output_padding=0,
bias=False)
def deconv2x2(in_planes, out_planes, stride):
return nn.ConvTranspose2d(in_planes, out_planes,
kernel_size=2, stride=stride, padding=0, output_padding=0,
bias=False)
def deconv3x3(in_planes, out_planes, stride):
return nn.ConvTranspose2d(in_planes, out_planes,
kernel_size=3, stride=stride, padding=1, output_padding=0,
bias=False)
def deconv4x4(in_planes, out_planes, stride):
return nn.ConvTranspose2d(in_planes, out_planes,
kernel_size=4, stride=stride, padding=1, output_padding=0,
bias=False)
# 3D Deconv
def deconv1x1x1(in_planes, out_planes, stride):
return nn.ConvTranspose3d(in_planes, out_planes,
kernel_size=1, stride=stride, padding=0, output_padding=0,
bias=False)
def deconv3x3x3(in_planes, out_planes, stride):
return nn.ConvTranspose3d(in_planes, out_planes,
kernel_size=3, stride=stride, padding=1, output_padding=0,
bias=False)
def deconv4x4x4(in_planes, out_planes, stride):
return nn.ConvTranspose3d(in_planes, out_planes,
kernel_size=4, stride=stride, padding=1, output_padding=0,
bias=False)
def _make_layers(in_channels, output_channels, type, batch_norm=False, activation=None):
layers = []
if type == 'conv1_s1':
layers.append(conv1x1(in_channels, output_channels, stride=1))
elif type == 'conv2_s2':
layers.append(conv2x2(in_channels, output_channels, stride=2))
elif type == 'conv3_s1':
layers.append(conv3x3(in_channels, output_channels, stride=1))
elif type == 'conv4_s2':
layers.append(conv4x4(in_channels, output_channels, stride=2))
elif type == 'deconv1_s1':
layers.append(deconv1x1(in_channels, output_channels, stride=1))
elif type == 'deconv2_s2':
layers.append(deconv2x2(in_channels, output_channels, stride=2))
elif type == 'deconv3_s1':
layers.append(deconv3x3(in_channels, output_channels, stride=1))
elif type == 'deconv4_s2':
layers.append(deconv4x4(in_channels, output_channels, stride=2))
elif type == 'conv1x1_s1':
layers.append(conv1x1x1(in_channels, output_channels, stride=1))
elif type == 'deconv1x1_s1':
layers.append(deconv1x1x1(in_channels, output_channels, stride=1))
elif type == 'deconv3x3_s1':
layers.append(deconv3x3x3(in_channels, output_channels, stride=1))
elif type == 'deconv4x4_s2':
layers.append(deconv4x4x4(in_channels, output_channels, stride=2))
else:
raise NotImplementedError('layer type [{}] is not implemented'.format(type))
if batch_norm == '2d':
layers.append(nn.BatchNorm2d(output_channels))
elif batch_norm == '3d':
layers.append(nn.BatchNorm3d(output_channels))
if activation == 'relu':
layers.append(nn.ReLU(inplace=True))
elif activation == 'sigm':
layers.append(nn.Sigmoid())
elif activation == 'leakyrelu':
layers.append(nn.LeakyReLU(0.2, True))
else:
if activation is not None:
raise NotImplementedError('activation function [{}] is not implemented'.format(activation))
return nn.Sequential(*layers)
def _init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=1.0)
elif init_stype == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
print('Initializing Weights: {}...'.format(classname))
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
elif classname.find('Sequential') == -1 and classname.find('Conv5_Deconv5_Local') == -1:
raise NotImplementedError('initialization of [{}] is not implemented'.format(classname))
print('initialize network with {}'.format(init_type))
net.apply(init_func)
def _initialize_weights(net):
for m in net.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
class ReconNet(nn.Module):
def __init__(self, in_planes, out_planes, gain=0.02, init_type='standard'):
super(ReconNet, self).__init__()
######### representation network - convolution layers
self.conv_layer1 = _make_layers(in_planes, 256, 'conv4_s2', False)
self.conv_layer2 = _make_layers(256, 256, 'conv3_s1', '2d')
self.relu2 = nn.ReLU(inplace=True)
self.conv_layer3 = _make_layers(256, 512, 'conv4_s2', '2d', 'relu')
self.conv_layer4 = _make_layers(512, 512, 'conv3_s1', '2d')
self.relu4 = nn.ReLU(inplace=True)
self.conv_layer5 = _make_layers(512, 1024, 'conv4_s2', '2d', 'relu')
self.conv_layer6 = _make_layers(1024, 1024, 'conv3_s1', '2d')
self.relu6 = nn.ReLU(inplace=True)
self.conv_layer7 = _make_layers(1024, 2048, 'conv4_s2', '2d', 'relu')
self.conv_layer8 = _make_layers(2048, 2048, 'conv3_s1', '2d')
self.relu8 = nn.ReLU(inplace=True)
self.conv_layer9 = _make_layers(2048, 4096, 'conv4_s2', '2d', 'relu')
self.conv_layer10 = _make_layers(4096, 4096, 'conv3_s1', '2d')
self.relu10 = nn.ReLU(inplace=True)
######### transform module
self.trans_layer1 = _make_layers(4096, 4096, 'conv1_s1', False, 'relu')
self.trans_layer2 = _make_layers(2048, 2048, 'deconv1x1_s1', False, 'relu')
######### generation network - deconvolution layers
self.deconv_layer10 = _make_layers(2048, 1024, 'deconv4x4_s2', '3d', 'relu')
self.deconv_layer8 = _make_layers(1024, 512, 'deconv4x4_s2', '3d', 'relu')
self.deconv_layer7 = _make_layers(512, 512, 'deconv3x3_s1', '3d', 'relu')
self.deconv_layer6 = _make_layers(512, 256, 'deconv4x4_s2', '3d', 'relu')
self.deconv_layer5 = _make_layers(256, 256, 'deconv3x3_s1', '3d', 'relu')
self.deconv_layer4 = _make_layers(256, 128, 'deconv4x4_s2', '3d', 'relu')
self.deconv_layer3 = _make_layers(128, 128, 'deconv3x3_s1', '3d', 'relu')
self.deconv_layer2 = _make_layers(128, 64, 'deconv4x4_s2', '3d', 'relu')
self.deconv_layer1 = _make_layers(64, 64, 'deconv3x3_s1', '3d', 'relu')
self.deconv_layer0 = _make_layers(64, 1, 'conv1x1_s1', False, 'relu')
self.output_layer = _make_layers(64, out_planes, 'conv1_s1', False)
if init_type == 'standard':
_initialize_weights(self)
else:
_init_weights(self, gain=gain, init_type=init_type)
def forward(self, x):
### representation network
conv1 = self.conv_layer1(x)
conv2 = self.conv_layer2(conv1)
relu2 = self.relu2(conv1 + conv2)
conv3 = self.conv_layer3(relu2)
conv4 = self.conv_layer4(conv3)
relu4 = self.relu4(conv3 + conv4)
conv5 = self.conv_layer5(relu4)
conv6 = self.conv_layer6(conv5)
relu6 = self.relu6(conv5 + conv6)
conv7 = self.conv_layer7(relu6)
conv8 = self.conv_layer8(conv7)
relu8 = self.relu8(conv7 + conv8)
conv9 = self.conv_layer9(relu8)
conv10 = self.conv_layer10(conv9)
relu10 = self.relu10(conv9 + conv10)
### transform module
features = self.trans_layer1(relu10)
trans_features = features.view(-1, 2048, 2, 4, 4)
trans_features = self.trans_layer2(trans_features)
### generation network
deconv10 = self.deconv_layer10(trans_features)
deconv8 = self.deconv_layer8(deconv10)
deconv7 = self.deconv_layer7(deconv8)
deconv6 = self.deconv_layer6(deconv7)
deconv5 = self.deconv_layer5(deconv6)
deconv4 = self.deconv_layer4(deconv5)
deconv3 = self.deconv_layer3(deconv4)
deconv2 = self.deconv_layer2(deconv3)
deconv1 = self.deconv_layer1(deconv2)
### output
out = self.deconv_layer0(deconv1)
out = torch.squeeze(out, 1)
out = self.output_layer(out)
return out
def reconnet(in_channels, out_channels, **kwargs):
model = ReconNet(in_channels, out_channels, **kwargs)
return model
model = reconnet(1, 128)
model.cuda()
summary(model, (1, 128, 128))
四、训练参数设置
个人训练设置:
-
激活函数选用ReLU
-
代价函数选用均方误差(MSE)
-
优化器论文中用了Adam
-
实际用的时候使用Adam在反向传播的时候计算量很大,网络跑不动,所以使用了SGD,且momentum=0才跑得动
-
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0)
-
五、总结
这个项目我是大致都做完了,正如过程中所说还有挺多坑没填的,然后网络我是改了,效果也有明显提升。这里就不细说啦,等师兄的论文有结果在进一步总结。
有待填的坑:
- 数据处理部分的CT间距问题
- 设计新的损失函数
- 更好的DRR生成方法