【tensorflow】tf.repeat实现单一元素扩展为具有重复元素的二维图像

tf.repeat实现单一元素扩展为其重复元素的二维图像

目标:将一维数据[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]扩展为9张二维图像(第一张图像中元素全为0,第二张的全为1,以此类推…)

代码

import numpy as np
import tensorflow as tf
inpt = np.arange(10)
inpt = inpt.T  # 生成[0, 1, 2 ... , 9]的列向量
inpt = np.expand_dims(inpt, axis=1)  # 对inpt先扩一维,或者后面对std扩维也成
inpt.shape

输出

(10, 1)

代码

std = np.expand_dims(inpt, axis=2)  # 继续扩维,将std扩为3维,便于后续repeat
std = tf.repeat(std, repeats=5, axis=1)  # 对第1维重复
std = tf.repeat(std, repeats=4, axis=2)  # 对第2维重复
std

输出

<tf.Tensor: shape=(10, 5, 4), dtype=int32, numpy=
array([[[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]],

       [[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]],

       ...

       [[9, 9, 9, 9],
        [9, 9, 9, 9],
        [9, 9, 9, 9],
        [9, 9, 9, 9],
        [9, 9, 9, 9]]])>
上一篇:第十一届蓝桥杯真题解析JavaC组


下一篇:CSS基本样式