A Detailed Explanation of torch.nn.Unfold()

torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)

Purpose: extracts sliding local blocks from a batched input tensor.

Parameters:

  • kernel_size (int or tuple) – the size of the sliding blocks.
  • stride (int or tuple, optional) – the stride of the sliding blocks along the input's spatial dimensions. Default: 1
  • padding (int or tuple, optional) – the amount of implicit zero padding added on both sides of each spatial dimension of the input before the blocks are extracted. Default: 0
  • dilation (int or tuple, optional) – the spacing between the elements sampled inside each block, also known as the à trous algorithm. With dilation=1 the sampled elements are adjacent. Default: 1
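The effect of each parameter is easiest to see from the output shapes. The following is a minimal sketch, assuming a 6×6 single-channel input chosen only for illustration; it changes one parameter at a time and prints the resulting shapes.

```python
import torch
import torch.nn as nn

# A dummy batch: 1 image, 1 channel, 6 rows, 6 columns, values 0..35.
x = torch.arange(36, dtype=torch.float32).reshape(1, 1, 6, 6)

# Baseline: 3x3 blocks, stride 1 -> 4 * 4 = 16 blocks of 9 values each.
base = nn.Unfold(kernel_size=3)(x)

# stride=2 moves the window two positions at a time -> 2 * 2 = 4 blocks.
strided = nn.Unfold(kernel_size=3, stride=2)(x)

# padding=1 adds a border of zeros on every side -> 6 * 6 = 36 blocks.
padded = nn.Unfold(kernel_size=3, padding=1)(x)

# dilation=2 spreads the 3x3 sample points over a 5x5 area -> 2 * 2 = 4 blocks.
dilated = nn.Unfold(kernel_size=3, dilation=2)(x)

print(base.shape, strided.shape, padded.shape, dilated.shape)

# The first dilated block samples rows 0, 2, 4 and columns 0, 2, 4.
print(dilated[0, :, 0])
```

Note that dilation does not change the number of values per block (still 9 here); it only changes which input positions those values come from.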

Input: inputs (B, C, H, W)

B: batch size, C: number of channels, H: height, W: width

Note: currently, only 4-D input tensors (batched image-like tensors) are supported.

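Output: outputs (B, N, L)

N: the number of values in each block. N = C × ∏(kernel_size), i.e. N = C × k_h × k_w for kernel_size = (k_h, k_w). It depends on the kernel size, not on the spatial size of the input.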

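L: the total number of blocks, given by

$$L = \prod_{d} \left\lfloor \frac{\text{spatial\_size}[d] + 2 \times \text{padding}[d] - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1 \right\rfloor$$

Here spatial_size is formed by the spatial dimensions of the input, i.e. the last two dimensions, spatial_size = (H, W), and d ranges over those dimensions, here {0, 1}.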
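The formula for L can be checked numerically. The sketch below uses a hypothetical helper, num_blocks (not part of PyTorch), that evaluates the formula one dimension at a time and compares the result with the shape that nn.Unfold actually returns. The input size is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn

def num_blocks(spatial_size, kernel_size, dilation=(1, 1), padding=(0, 0), stride=(1, 1)):
    """Hypothetical helper: evaluate the formula for L, one spatial dimension at a time."""
    total = 1
    for d in range(len(spatial_size)):
        # Extent covered by the (possibly dilated) kernel along dimension d.
        span = dilation[d] * (kernel_size[d] - 1) + 1
        # floor((size + 2*pad - span) / stride) + 1 positions along dimension d.
        total *= (spatial_size[d] + 2 * padding[d] - span) // stride[d] + 1
    return total

# Arbitrary example: batch of 2, 3 channels, 5 rows, 6 columns.
x = torch.randn(2, 3, 5, 6)
out = nn.Unfold(kernel_size=(3, 3), stride=(2, 1))(x)

n_expected = 3 * 3 * 3                                   # C * k_h * k_w
l_expected = num_blocks((5, 6), (3, 3), stride=(2, 1))   # 2 * 4

print(out.shape, (2, n_expected, l_expected))
```

Both printed shapes should agree: 27 values per block and 8 blocks per sample.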

import torch
import torch.nn as nn
inp = torch.tensor([[[[1.0, 2, 3, 4, 5, 6],
                      [7, 8, 9, 10, 11, 12],
                      [13, 14, 15, 16, 17, 18],
                      [19, 20, 21, 22, 23, 24],
                      [25, 26, 27, 28, 29, 30],
                      ]]])
print('inp=')
print(inp)

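# 3x3 blocks, moving 2 rows at a time vertically and 1 column at a time horizontally.
# The 5x6 input yields 2 * 4 = 8 blocks of 9 values, so inp_unf has shape (1, 9, 8).
unfold = nn.Unfold(kernel_size=(3, 3), dilation=1, padding=0, stride=(2, 1))
inp_unf = unfold(inp)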
print('inp_unf=')
print(inp_unf)
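To read the printed result, it helps to remember that each column of the output is one flattened block, and that the blocks are ordered row by row over the input. The following sketch rebuilds the same input with torch.arange (an equivalent construction chosen for brevity) and reshapes the columns back into 3×3 blocks. The variable name blocks is introduced here only for illustration.

```python
import torch
import torch.nn as nn

# Same values as the example above: 1..30 laid out as 5 rows by 6 columns.
inp = torch.arange(1.0, 31.0).reshape(1, 1, 5, 6)
inp_unf = nn.Unfold(kernel_size=(3, 3), stride=(2, 1))(inp)
print(inp_unf.shape)  # torch.Size([1, 9, 8])

# Turn each column back into a 3x3 block: (9, 8) -> (8, 9) -> (8, 3, 3).
blocks = inp_unf[0].T.reshape(8, 3, 3)

# Block 0 is the top-left window (rows 0-2, columns 0-2).
print(blocks[0])

# Blocks 0-3 slide along the first row of positions; block 4 starts
# two rows further down (vertical stride 2), at rows 2-4, columns 0-2.
print(blocks[4])
```

The first block contains 1, 2, 3 / 7, 8, 9 / 13, 14, 15, and the fifth contains 13, 14, 15 / 19, 20, 21 / 25, 26, 27, which matches the first and fifth columns of inp_unf.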

The official documentation explains:

Convolution = Unfold + Matrix Multiplication + Fold

>>> # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape)
>>> inp = torch.randn(1, 3, 10, 12)
>>> w = torch.randn(2, 3, 4, 5)
>>> inp_unf = torch.nn.functional.unfold(inp, (4, 5))
>>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
>>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
>>> # or equivalently (and avoiding a copy),
>>> # out = out_unf.view(1, 2, 7, 8)
>>> (torch.nn.functional.conv2d(inp, w) - out).abs().max()
tensor(1.9073e-06)
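The same idea extends to strided and padded convolutions. The function below, conv2d_via_unfold, is a hypothetical helper written for illustration (it is not a PyTorch API) and assumes no bias, no dilation, no groups, and integer stride and padding. It reshapes the result with the output size computed from the formula for L instead of calling fold.

```python
import torch
import torch.nn.functional as F

def conv2d_via_unfold(inp, weight, stride=1, padding=0):
    """Illustrative only: conv2d (no bias, dilation, or groups) as unfold + matmul."""
    b, _, h, w = inp.shape
    out_c, _, kh, kw = weight.shape
    # Each column holds one flattened receptive field: (B, C*kh*kw, L).
    cols = F.unfold(inp, (kh, kw), padding=padding, stride=stride)
    # One matrix product applies every filter to every block: (B, out_c, L).
    out = weight.reshape(out_c, -1) @ cols
    # Output spatial size, one factor of L per dimension.
    out_h = (h + 2 * padding - kh) // stride + 1
    out_w = (w + 2 * padding - kw) // stride + 1
    return out.reshape(b, out_c, out_h, out_w)

torch.manual_seed(0)
inp = torch.randn(2, 3, 10, 12)
w = torch.randn(4, 3, 3, 3)

out = conv2d_via_unfold(inp, w, stride=2, padding=1)
ref = F.conv2d(inp, w, stride=2, padding=1)

# The two results should differ only by floating-point rounding.
print(out.shape, (out - ref).abs().max())
```

This is useful for understanding what a convolution computes, but the built-in conv2d is generally the better choice in practice because unfold materializes every block in memory.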
