PyTorch中的广播机制
文章目录
- PyTorch中的广播机制
- 1. 广播
- 代码示例
- 2. 不适合广播机制的情况:in-place操作
1. 广播
-
“广播”这一术语用于描述如何在形状不一的数组上应用算术运算。
-
在满足特定限制的前提下,较小的数组“广播至”较大的数组,使两者形状互相兼容。广播提供了一个向量化数组操作的机制,这样遍历就发生在C层面,而不是Python层面。广播可以避免不必要的数据复制,通常导向高效的算法实现。不过,也存在不适用广播的情形(可能导致拖慢计算过程的低效内存使用)。
-
可广播的一对张量需满足以下规则:
- 每个张量至少有一个维度。
- 遍历tensor所有维度,从尾部的维度开始,两个tensor的维度尺寸存在:
-
tensor
维度相等。 -
tensor
维度不等且其中一个维度为1。 -
tensor
维度不等且其中一个维度不存在。
-
-
如果两个tensor是“可广播的”,则计算过程遵循下列规则:
- 如果两个
tensor
的维度不同,则在维度较小的tensor
的前面增加维度,使它们维度相等。 - 对于每个维度,计算结果的维度值取两个
tensor
中较大的那个值。 - 两个
tensor
扩展维度的过程是将数值进行复制。
- 如果两个
代码示例
import torch
# 示例1:相同形状的张量总是可广播的,因为总能满足以上规则。
x = torch.empty(5, 7, 3)
y = torch.empty(5, 7, 3)
# 示例2:不可广播( a 不满足第一条规则)。
a = torch.empty((0,))
b = torch.empty(2, 2)
# 示例3:m 和 n 可广播:
m = torch.empty(5, 3, 4, 1)
n = torch.empty( 3, 1, 1)
# 倒数第一个维度:两者的尺寸均为1
# 倒数第二个维度:n尺寸为1
# 倒数第三个维度:两者尺寸相同
# 倒数第四个维度:n该维度不存在
# 示例4:不可广播,因为倒数第三个维度:2 != 3
p = torch.empty(5, 2, 4, 1)
q = torch.empty( 3, 1, 1)
- 现在你对“可广播”这一概念已经有所了解了,让我们看下,广播后的张量是什么样的。
- 如果张量x和张量y是可广播的,那么广播后的张量尺寸按照如下方法计算:
- 如果x和y的维数不等,在维数较少的张量上添加尺寸为 1 的维度。结果维度尺寸是x和y相应维度尺寸的较大者。
# 示例5:可广播
c = torch.empty(5, 1, 4, 1)
d = torch.empty( 3, 1, 1)
(c + d).size() # torch.Size([5, 3, 4, 1])
# 示例6:可广播
f = torch.empty( 1)
g = torch.empty(3, 1, 7)
(f + g).size() # torch.Size([3, 1, 7])
# 示例7:不可广播
o = torch.empty(5, 2, 4, 1)
u = torch.empty( 3, 1, 1)
(o + u).size()
# 报错:
# ---------------------------------------------------------------------------
#
# RuntimeError Traceback (most recent call last)
#
# <ipython-input-17-72fb34250db7> in <module>()
# 1 o=torch.empty(5,2,4,1)
# 2 u=torch.empty(3,1,1)
# ----> 3 (o+u).size()
#
# RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at no
2. 不适合广播机制的情况:in-place操作
in-place operation称为原地操作符,在pytorch中是指改变一个tensor的值的时候,不经过复制操作,而是直接在原来的内存上改变它的值。在pytorch中经常加后缀_
来代表原地操作符,例:.add_()、.scatter()
。in-place操作不允许tensor像广播那样改变形状