首先，配置文件 config.yaml 必不可少（注意：YAML 依靠缩进表示层级，network、dataset、train 下面的键必须缩进）：
# Model Parameters
network:
  num_cascades: 6
  num_layers: 5  # Number of layers in the CNN per cascade
  num_filters: 64
  kernel_size: 3
  stride: 1
  padding: 1  # With kernel_size 3 and stride 1, a padding of 1 keeps the image the same size
  noise: null  # Noise in the measurements. To be used in the data consistency step
# Dataset parameters -- every key below is passed as a keyword argument to OCMRDataset
dataset:
  data_path: 'data1/'
  acceleration_factor: 4.0  # About Nx / acceleration_factor k-space rows are kept
  fraction: 0.8  # Fraction of the files used for the training set
  shuffle: 3  # Seed for numpy random generator (0/null: files are not shuffled)
  sample_n: 10  # Number of central k-space rows that are always sampled
  acq_noise: 0  # Acquisition noise power
  centred: False
  norm: 'ortho'  # 'ortho' or null. If 'ortho', performs a unitary transform, otherwise the normal DFT
# Training parameters
train:
  batch_size: 1
  num_epochs: 5
  early_stop: 100
  # Adam Optimizer Parameters
  learning_rate: 0.001
  b_1: 0.9
  b_2: 0.999
  l2: 0.0000001
# Miscellaneous
# NOTE(review): the original nesting of these two keys was lost when the
# indentation was stripped; confirm whether they belong under `train`.
output_path: 'logs'
cuda: False
单步执行（逐步跟踪 dataset[0] 的调用过程）：
import os
import torch
import numpy as np
from math import ceil
from helpers_1 import *
from scipy.io import loadmat
from numpy.lib.stride_tricks import as_strided
import yaml
# Load the configuration. safe_load is enough for plain config data, and the
# `with` block closes the file handle (the original left it open).
with open('config.yaml', 'r') as config_file:
    args = yaml.safe_load(config_file)
# Build the training split; every key under `dataset:` in config.yaml becomes
# a keyword argument of OCMRDataset. (Plain ASCII quotes: the typographic
# quotes used before were a SyntaxError.)
dataset = OCMRDataset(fold='train', **args['dataset'])
# NOTE(review): fragment of OCMRDataset.__init__ -- its `def` line and
# parameter list are not shown in these notes.
self.evalset = evalset
self.data_path = data_path
self.acc = acceleration_factor
self.sample_n = sample_n
self.noise = acq_noise
self.centred = centred
self.norm = norm
# os.listdir returns entries in an arbitrary, filesystem-dependent order; sort
# first so the seeded shuffle (and therefore the train split) is reproducible
# across machines.
self.files = sorted(os.listdir(self.data_path))
if shuffle:
    # Seeds numpy's *global* RNG; the later mask/noise sampling
    # (np.random.choice / np.random.normal) draws from the same generator.
    np.random.seed(shuffle)
    np.random.shuffle(self.files)
if fold == 'train':
    # Keep the first `fraction` of the (shuffled) file list for training.
    self.files = self.files[:int(len(self.files) * fraction)]
def __getitem__(self, idx):
# Build one sample for file `idx`: load it, undersample it and convert the
# complex arrays to real 2-channel arrays (traced step by step below).
# For the evaluation set, numpy's global RNG is reseeded when idx == 0 --
# presumably so every evaluation pass draws the same random masks; TODO confirm.
if self.evalset and idx == 0:
np.random.seed(9001)
# Read variable 'xn' from the .mat file; the reason for the 1e3 intensity
# scaling is not shown in these notes (TODO confirm).
data = loadmat(os.path.join(self.data_path, self.files[idx]))['xn'] * 1e3
# Add a leading axis of length 1, e.g. (256, 256) -> (1, 256, 256).
data = np.expand_dims(data, 0)
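np.expand_dims(data, 0) 在最前面增加一个长度为 1 的轴，data 的形状由 (256, 256) 变为 (1, 256, 256)。注意：这个长度为 1 的轴来自 expand_dims，与 batch_size=1 无关；后面 format_data 中的 squeeze(0) 会把它去掉。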
# Draw a random Cartesian undersampling mask with the same shape as `data`.
mask = self.cartesian_mask(data.shape)
def normal_pdf(length, sensitivity):
    """Return an unnormalised Gaussian profile centred at ``length / 2``.

    Args:
        length: number of samples.
        sensitivity: coefficient of the squared distance from the centre;
            larger values give a narrower peak.
    """
    return np.exp(-sensitivity * (np.arange(length) - length / 2) ** 2)


def cartesian_mask(self, shape):
    """Draw a random Cartesian (row-wise) undersampling mask.

    Whole rows along axis -2 are sampled. A row's probability is a Gaussian
    centred on the middle row plus a uniform floor, and the central
    ``self.sample_n`` rows are always sampled. Every row is constant along
    axis -1.

    Args:
        shape: target shape ``(..., Nx, Ny)``; all leading axes are flattened
            into ``N`` independently drawn masks.

    Reads:
        self.acc: acceleration factor; ``int(Nx / acc)`` rows are kept in total.
        self.sample_n: number of always-sampled central rows (0/None: none).
        self.centred: if False, the mask is ``ifftshift``-ed so the sampled
            centre moves to index 0.

    Returns:
        float64 array of 0s and 1s with shape ``shape``.

    Raises:
        ValueError: if ``sample_n`` exceeds the row budget ``int(Nx / acc)``.
    """
    N, Nx, Ny = int(np.prod(shape[:-2])), shape[-2], shape[-1]
    sample_n = self.sample_n or 0
    n_lines = int(Nx / self.acc)
    if sample_n:
        n_lines -= sample_n
    if n_lines < 0:
        raise ValueError(
            f"sample_n={sample_n} exceeds the row budget int({Nx}/{self.acc})")

    pdf_x = normal_pdf(Nx, 0.5 / (Nx / 10.) ** 2)
    lmda = Nx / (2. * self.acc)
    # add uniform distribution so every row has a non-zero chance
    pdf_x += lmda * 1. / Nx

    # Central band [lo, hi) of exactly `sample_n` rows. The old
    # `Nx//2 -/+ sample_n//2` slice held only sample_n - 1 rows for odd sample_n.
    lo = Nx // 2 - sample_n // 2
    hi = lo + sample_n
    if sample_n:
        # the band is set deterministically below, so exclude it from the draw
        pdf_x[lo:hi] = 0
    # np.random.choice needs probabilities that sum to 1 -- also when no
    # central band is reserved (previously only normalised when sample_n was set).
    pdf_x /= np.sum(pdf_x)

    mask = np.zeros((N, Nx))
    for i in range(N):
        idx = np.random.choice(Nx, n_lines, False, pdf_x)
        mask[i, idx] = 1
    if sample_n:
        mask[:, lo:hi] = 1

    # Repeat every row along Ny as a real copy (not a writable stride-0 alias).
    mask = np.repeat(mask[:, :, np.newaxis], Ny, axis=2)
    mask = mask.reshape(shape)
    # In this centred layout the rows near the top and bottom are mostly 0 and
    # the sampled rows (1s) get denser towards the middle.

    if not self.centred:
        mask = ifftshift(mask, axes=(-1, -2))
    return mask
ifftshift 之后看起来像“反过来了”，但它并不会把 1 变成 0、0 变成 1：ifftshift 只是对 mask 做循环平移，把原来位于中间的低频（密集采样）行移到数组的首尾，中间变成了高频（稀疏采样）行。
# Apply the mask in k-space; returns the zero-filled image and the masked k-space.
data_und, k_und = self.undersample(data, mask)
# NOTE(review): body of self.undersample(x, mask) -- its `def` line is not shown
# here; `x` is `data` and `mask` is the mask from the call above.
# Returns (x_u, x_fu): the zero-filled image and the masked k-space.
# NOTE(review): assert is stripped under `python -O`; raise ValueError instead
# if this check must always run.
assert x.shape == mask.shape
# zero mean complex Gaussian noise
noise_power = self.noise
# Drawn even when noise_power is 0, so it always consumes numpy's global RNG.
nz = np.sqrt(.5)*(np.random.normal(0, 1, x.shape) + 1j * np.random.normal(0, 1, x.shape))
nz = nz * np.sqrt(noise_power)
if self.norm == 'ortho':
# multiplicative factor
# presumably offsets the 1/sqrt(H*W) scaling of the orthonormal FFT -- TODO confirm
nz = nz * np.sqrt(np.prod(mask.shape[-2:]))
if self.centred:
# fft2c/ifft2c are not defined in these notes; presumably centred-FFT helpers
x_f = fft2c(x, norm=self.norm)
# noise is added to every k-space sample, then unsampled entries are zeroed
x_fu = mask * (x_f + nz)
x_u = ifft2c(x_fu, norm=self.norm)
return x_u, x_fu
else:
x_f = fft2(x, norm=self.norm)
x_fu = mask * (x_f + nz)
x_u = ifft2(x_fu, norm=self.norm)
return x_u, x_fu
# Back in __getitem__: undersample returned (x_u, x_fu).
data_und, k_und = x_u, x_fu
# Convert the complex ground truth `data` (1, H, W) into a real (2, H, W) array.
data_gnd = format_data(data)
def format_data(data, mask=False):
    """Convert a complex array into a real array with a (real, imag) channel axis.

    Args:
        data: complex array of shape ``(1, ...)`` with ``ndim >= 3``.
        mask: if True, ``data`` is a real 0/1 mask; multiplying it by
            ``1 + 1j`` copies it into both the real and the imaginary channel.

    Returns:
        ``complex2real(data)`` with the leading length-1 axis squeezed away,
        e.g. (1, H, W) -> (2, H, W).
    """
    if mask:
        data = data * (1 + 1j)
    data = complex2real(data)
    return data.squeeze(0)


def complex2real(x):
    """Stack the real and imaginary parts of ``x`` into a float64 array.

    The result has shape ``(2, *x.shape)``; when ``x.ndim >= 3`` the first two
    axes are swapped, giving ``(x.shape[0], 2, *x.shape[1:])``.
    """
    # ``np.float`` was just an alias of the builtin float (float64) and was
    # removed in NumPy 1.24, so name the dtype explicitly.
    y = np.array([np.real(x), np.imag(x)]).astype(np.float64)
    # re-order in convenient order: channel axis right after the leading axis
    if x.ndim >= 3:
        y = y.swapaxes(0, 1)
    return y
def format_data(data, mask=False):
    """Return ``complex2real(data)`` with the leading length-1 axis removed.

    When ``mask`` is True, ``data`` is first multiplied by ``1 + 1j`` so a
    real 0/1 mask ends up in both the real and the imaginary channel. In the
    call traced here, ``format_data(data)``, ``mask`` is False and that branch
    is skipped.
    """
    if mask:
        data = data * (1 + 1j)
    data = complex2real(data)
    return data.squeeze(0)
# format_data returned complex2real(data).squeeze(0): the ground truth, (2, 256, 256).
data_gnd = data.squeeze(0)
# Same conversion for the zero-filled undersampled image x_u.
data_und = format_data(data_und)
def format_data(data, mask=False):
    """Turn a complex ``(1, ...)`` array into a real ``(2, ...)`` array.

    Traced call: ``format_data(data_und)``, so ``mask`` is False here. With
    ``mask=True`` the input is first multiplied by ``1 + 1j`` so a real mask
    fills both channels.
    """
    if mask:
        data = data * (1 + 1j)
    data = complex2real(data)
    return data.squeeze(0)


def complex2real(x):
    """Return a float64 array holding ``x.real`` and ``x.imag`` on a new axis.

    Shape ``(2, *x.shape)``, or ``(x.shape[0], 2, *x.shape[1:])`` when
    ``x.ndim >= 3`` (the first two axes are swapped).
    """
    parts = [np.real(x), np.imag(x)]
    # np.float64 instead of np.float, which NumPy 1.24 removed.
    y = np.array(parts).astype(np.float64)
    # re-order in convenient order
    if x.ndim >= 3:
        y = y.swapaxes(0, 1)
    return y
此时 data = complex2real(data)，即 data 就是 complex2real 返回的 y，形状为 (1, 2, 256, 256)；squeeze(0) 之后得到 (2, 256, 256)。
# Result: the zero-filled image as a real (2, H, W) array.
data_und = data.squeeze(0)
# Same conversion for the masked k-space x_fu.
k_und = format_data(k_und)
def format_data(data, mask=False):
    """Convert complex k-space ``(1, H, W)`` into a real ``(2, H, W)`` array.

    Traced call: ``format_data(k_und)`` (``mask`` is False). With
    ``mask=True`` the input is multiplied by ``1 + 1j`` first so a real 0/1
    mask is copied into both channels.
    """
    if mask:
        data = data * (1 + 1j)
    data = complex2real(data)
    return data.squeeze(0)


def complex2real(x):
    """Split ``x`` into float64 real/imaginary channels.

    Output shape is ``(2, *x.shape)``; for ``x.ndim >= 3`` the first two axes
    are swapped so the channel axis comes second.
    """
    # np.float was removed in NumPy 1.24; it always meant float64.
    y = np.array([np.real(x), np.imag(x)]).astype(np.float64)
    # re-order in convenient order
    if x.ndim >= 3:
        y = y.swapaxes(0, 1)
    return y
# Result: the masked k-space as a real (2, H, W) array.
k_und = data.squeeze(0)
# mask=True: the real 0/1 mask is copied into both the real and imaginary channel.
mask = format_data(mask, mask=True)
def format_data(data, mask=False):
    """Convert an array of shape ``(1, H, W)`` into a real ``(2, H, W)`` array.

    Traced call: ``format_data(mask, mask=True)``. The ``mask`` branch
    multiplies the real 0/1 mask by ``1 + 1j``, so after ``complex2real``
    both output channels equal the mask.
    """
    if mask:
        data = data * (1 + 1j)
    data = complex2real(data)
    return data.squeeze(0)


def complex2real(x):
    """Return ``x.real`` and ``x.imag`` stacked on a new float64 channel axis.

    Shape ``(2, *x.shape)``; when ``x.ndim >= 3`` the first two axes are
    swapped, giving ``(x.shape[0], 2, *x.shape[1:])``.
    """
    # Use np.float64: the np.float alias was removed in NumPy 1.24.
    y = np.array([np.real(x), np.imag(x)]).astype(np.float64)
    # re-order in convenient order
    if x.ndim >= 3:
        y = y.swapaxes(0, 1)
    return y
此时（mask=True 分支）data 先被乘以 (1+1j)，再经过 complex2real，即 data 等于返回的 y：实部和虚部两个通道都等于 mask，形状为 (1, 2, 256, 256)。
# Result: the mask as a real (2, H, W) array whose two channels are identical.
mask = data.squeeze(0)
# __getitem__ returns a dict of numpy arrays: 'image' and 'full' stay
# (2, H, W), while 'k' and 'mask' are transposed to channels-last (H, W, 2).
return {
'image': data_und,
'k': k_und.transpose(1,2,0),
'mask': mask.transpose(1,2,0),
'full': data_gnd
}
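sample = dataset[0] 返回一个字典，包含 4 个 numpy 数组（不是 torch tensor，__getitem__ 中没有转换成 tensor）：
第一个：'image'：data_und（欠采样后的零填充图像），形状 (2, 256, 256)
第二个：'k'：k_und.transpose(1,2,0)（欠采样后的 k 空间），形状 (256, 256, 2)
第三个：'mask'：mask.transpose(1,2,0)，形状 (256, 256, 2)
第四个：'full'：data_gnd（全采样的真值图像），形状 (2, 256, 256)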
输出:
Sample image shape: (2, 256, 256)
Sample full shape: (2, 256, 256)