研究小组里初学习深度学习的同学我都布置写过 HybridSN 的代码:https://github.com/OUCTheoryGroup/colab_demo/blob/master/202003_models/HybridSN_GRSL2020.ipynb
最近做SRDP的同学反映,跑 Pavia 数据集的时候内存会爆,主要原因是 createImageCubes 这个函数有个地方:
patchesData = np.zeros([ width*height, windowSize, windowsSize, spectral_num])
因为 Pavia 数据集尺寸较大,width*height 就比较大了,内存会爆掉。
其实图像中大部分为是0,没有label,我们要取的,只是有label的部分。现在我改了改,先做个循环,看看有多少个像素有 label,然后记录在 count 里,分配内存时:
patchesData = np.zeros([count, windowSize, windowSize, spectral_num])
这样 count 比以前的 width*height 要小很多,内存就不会爆了
修改后的代码如下,供感兴趣的同学参考(Github上的我就不改了,留给以后新同学排雷):
# 在每个像素周围提取 patch
def createImageCubes(X, y, windowSize=5, removeZeroLabels = True):
# 给 X 做 padding
margin = int((windowSize - 1) / 2)
zeroPaddedX = padWithZeros(X, margin=margin)
# 获得 y 中的标记样本数
count = 0
for r in range(0, y.shape[0]):
for c in range(0, y.shape[1]):
if y[r, c] != 0:
count = count+1
# split patches
patchesData = np.zeros([count, windowSize, windowSize, X.shape[2]])
patchesLabels = np.zeros(count)
count = 0
for r in range(margin, zeroPaddedX.shape[0] - margin):
for c in range(margin, zeroPaddedX.shape[1] - margin):
if y[r-margin, c-margin] != 0:
patch = zeroPaddedX[r - margin:r + margin + 1, c - margin:c + margin + 1]
patchesData[count, :, :, :] = patch
patchesLabels[count] = y[r-margin, c-margin]
count = count + 1
return patchesData, patchesLabels