tf.repeat实现单一元素扩展为其重复元素的二维图像
目标:将一维数据[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]扩展为9张二维图像(第一张图像中元素全为0,第二张的全为1,以此类推…)
代码:
import numpy as np
import tensorflow as tf
inpt = np.arange(10)
inpt = inpt.T # 生成[0, 1, 2 ... , 9]的列向量
inpt = np.expand_dims(inpt, axis=1) # 对inpt先扩一维,或者后面对std扩维也成
inpt.shape
输出:
(10, 1)
代码:
std = np.expand_dims(inpt, axis=2) # 继续扩维,将std扩为3维,便于后续repeat
std = tf.repeat(std, repeats=5, axis=1) # 对第1维重复
std = tf.repeat(std, repeats=4, axis=2) # 对第2维重复
std
输出:
<tf.Tensor: shape=(10, 5, 4), dtype=int32, numpy=
array([[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]],
...
[[9, 9, 9, 9],
[9, 9, 9, 9],
[9, 9, 9, 9],
[9, 9, 9, 9],
[9, 9, 9, 9]]])>