参考资料:
https://pytorch.org/docs/stable/generated/torch.meshgrid.html
在此记录下torch.meshgrid的用法,该函数常常用于生成二维的网格:
>>> x = torch.tensor([1, 2, 3]) >>> y = torch.tensor([4, 5, 6]) >>> grid_x, grid_y = torch.meshgrid(x, y) >>> grid_x tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]]) >>> grid_y tensor([[4, 5, 6], [4, 5, 6], [4, 5, 6]])
另一个例子:
>>> import torch >>> h = 6 >>> w = 10 >>> ys,xs = torch.meshgrid(torch.arange(h), torch.arange(w)) >>> xs.shape torch.Size([6, 10]) >>> ys.shape torch.Size([6, 10]) >>> xs tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]) >>> ys tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5, 5, 5, 5, 5]]) >>> xys = torch.stack([xs, ys], dim=-1) >>> xys.shape torch.Size([6, 10, 2])
需要注意的点:
1. torch.meshgrid函数的输入是若干个(N个)一维Tensor或者若干个标量。
2. torch.meshgrid函数的输出有N个,每个输出都是N维的。
3. torch.meshgrid函数的每个输出tensor的shape都为$(d_1, d_2, d_3 ... d_N)$,其中$d_i$为第i个输入向量的长度。
4. torch.meshgrid函数的每个输出有什么不同?答:为该输出对应输入向量在其他维度舒展开的结果。
5. torch的meshgrid实现和numpy的meshgrid实现有所不同,后者“可能”能够更直接地获取我们需要的东西,而torch的meshgrid调用后可能还需要做一个转置。