torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
Purpose: extracts sliding local blocks from a batched input tensor.
Parameters (an int is used for every spatial dimension; a tuple gives one value per dimension):
- kernel_size (int or tuple) – the size of the sliding blocks.
- stride (int or tuple, optional) – the stride of the sliding blocks in the input spatial dimensions. Default: 1
- padding (int or tuple, optional) – the number of implicit zeros added on both sides of each spatial dimension of the input before the blocks are extracted. Default: 0
- dilation (int or tuple, optional) – the spacing between the kernel points, i.e. the stride of elements within each block; also known as the à trous algorithm. Default: 1. It is harder to describe, but this link has a nice visualization of what dilation does.
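The parameters are easier to understand on a small input. The sketch below assumes a 1×1×5×5 input holding the values 1..25. These values are chosen only so that each one reveals its position in the grid. The code shows what kernel_size, dilation, padding and stride each change in the extracted blocks.

```python
import torch
import torch.nn as nn

# 1x1x5x5 input holding 1..25, so value = row * 5 + col + 1.
x = torch.arange(1.0, 26.0).reshape(1, 1, 5, 5)

# An int argument applies to both spatial dims: kernel_size=2 means (2, 2).
plain = nn.Unfold(kernel_size=2)(x)
assert torch.equal(plain, nn.Unfold(kernel_size=(2, 2))(x))

# dilation=2 skips one element between kernel taps, so a 2x2 kernel spans
# a 3x3 window but reads only its four corners.
dilated = nn.Unfold(kernel_size=2, dilation=2)(x)
print(dilated[0, :, 0])  # reads x[0,0], x[0,2], x[2,0], x[2,2] -> 1, 3, 11, 13

# padding=1 surrounds the input with zeros, so the first block starts
# one row above and one column left of the image.
padded = nn.Unfold(kernel_size=2, padding=1)(x)
print(padded[0, :, 0])  # three padded zeros, then x[0,0]

# stride=2 moves the window two positions at a time, giving fewer blocks.
strided = nn.Unfold(kernel_size=2, stride=2)(x)

print(plain.shape, dilated.shape, padded.shape, strided.shape)
```

All four variants keep N = 1 × 2 × 2 = 4. Only the number of blocks L changes.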
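Input: inputs (B, C, H, W)
B: batch size, C: channels, H: height, W: width
Currently, only 4-D input tensors (batched image-like tensors) are supported.
Output: outputs (B, N, L)
N: the number of values in each block, N = C × ∏(kernel_size) = C × kernel_h × kernel_w
L: the total number of blocks,
$$L = \prod_d \left\lfloor \frac{\mathrm{spatial\_size}[d] + 2 \times \mathrm{padding}[d] - \mathrm{dilation}[d] \times (\mathrm{kernel\_size}[d] - 1) - 1}{\mathrm{stride}[d]} + 1 \right\rfloor$$
where spatial_size is the spatial shape of the input tensor, here spatial_size = (H, W), and d runs over the spatial dimensions, here {0, 1}.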
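The helper below predicts the output shape for a 4-D input and compares the prediction with what nn.Unfold actually returns. It computes N as C times the kernel area. It computes L with the block-count formula from the PyTorch documentation that the definitions of spatial_size and d refer to. The name unfold_shape is hypothetical; it is not a PyTorch function.

```python
import math
import torch
import torch.nn as nn

def unfold_shape(in_shape, kernel_size, dilation=1, padding=0, stride=1):
    """Hypothetical helper: predict the (B, N, L) output shape of nn.Unfold for a 4-D input."""
    def pair(v):
        # An int is used for both spatial dims, a tuple gives one value per dim.
        return (v, v) if isinstance(v, int) else tuple(v)

    b, c, *spatial = in_shape
    k, d, p, s = pair(kernel_size), pair(dilation), pair(padding), pair(stride)
    n = c * math.prod(k)
    # floor(x / s + 1) == x // s + 1 with Python's floor division.
    l = math.prod(
        (spatial[i] + 2 * p[i] - d[i] * (k[i] - 1) - 1) // s[i] + 1
        for i in range(2)
    )
    return (b, n, l)

x = torch.randn(2, 3, 10, 12)
params = dict(kernel_size=(4, 5), dilation=(1, 2), padding=1, stride=(2, 3))
pred = unfold_shape(tuple(x.shape), **params)
actual = tuple(nn.Unfold(**params)(x).shape)
print(pred, actual)
```

Here N = 3 × 4 × 5 = 60. The height gives (10 + 2 − 3 − 1) // 2 + 1 = 5 block positions and the width gives (12 + 2 − 8 − 1) // 3 + 1 = 2, so L = 10.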
import torch
import torch.nn as nn
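inp = torch.tensor([[[[1.0, 2, 3, 4, 5, 6],
                      [7, 8, 9, 10, 11, 12],
                      [13, 14, 15, 16, 17, 18],
                      [19, 20, 21, 22, 23, 24],
                      [25, 26, 27, 28, 29, 30]]]])  # shape (1, 1, 5, 6): 5 rows, 6 columns
print('inp=')
print(inp)
# 3x3 blocks, moving 2 rows down and 1 column right at a time
unfold = nn.Unfold(kernel_size=(3, 3), dilation=1, padding=0, stride=(2, 1))
inp_unf = unfold(inp)
print('inp_unf=')
print(inp_unf)  # shape (1, 9, 8): N = 1*3*3 = 9, L = 2 row positions * 4 column positions = 8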
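To make the layout of inp_unf explicit, the sketch below rebuilds it with plain loops. The helper naive_unfold is hypothetical and skips padding and dilation for brevity. In the result, column l holds the l-th block, with blocks ordered left to right and then top to bottom. Within a column, values are flattened by channel first, then kernel row, then kernel column.

```python
import torch
import torch.nn as nn

def naive_unfold(x, kh, kw, sh, sw):
    """Hypothetical loop-based unfold for a 4-D input (no padding, no dilation)."""
    b, c, h, w = x.shape
    out_h = (h - kh) // sh + 1
    out_w = (w - kw) // sw + 1
    cols = []
    # Visit block positions row by row: left to right, then top to bottom.
    for i in range(out_h):
        for j in range(out_w):
            patch = x[:, :, i * sh:i * sh + kh, j * sw:j * sw + kw]  # (B, C, kh, kw)
            # Flatten channel first, then kernel rows, then kernel columns.
            cols.append(patch.reshape(b, -1))
    return torch.stack(cols, dim=2)  # (B, C*kh*kw, L)

inp = torch.arange(1.0, 31.0).reshape(1, 1, 5, 6)  # same values as the example above
ref = nn.Unfold(kernel_size=(3, 3), stride=(2, 1))(inp)
mine = naive_unfold(inp, 3, 3, 2, 1)
print(torch.equal(ref, mine))
print(mine[0, :, 0])  # first block: rows 0-2, columns 0-2 -> 1, 2, 3, 7, 8, 9, 13, 14, 15
```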
The official documentation explains:
Convolution = Unfold + Matrix Multiplication + Fold
>>> # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape)
>>> inp = torch.randn(1, 3, 10, 12)
>>> w = torch.randn(2, 3, 4, 5)
>>> inp_unf = torch.nn.functional.unfold(inp, (4, 5))
>>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
>>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
>>> # or equivalently (and avoiding a copy),
>>> # out = out_unf.view(1, 2, 7, 8)
>>> (torch.nn.functional.conv2d(inp, w) - out).abs().max()
tensor(1.9073e-06)
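The documentation example uses the default stride and padding. The sketch below applies the same three steps with a bias term, stride=2 and padding=1; these values are arbitrary and chosen only for illustration. As the comment in the example allows, it reshapes the result directly to the output size instead of calling fold. That output size comes from the same block-count formula used for L.

```python
import torch
import torch.nn.functional as F

# Arbitrary illustrative setup: batch 2, 3 input channels, 4 output channels, 3x3 kernel.
inp = torch.randn(2, 3, 10, 12)
w = torch.randn(4, 3, 3, 3)
b = torch.randn(4)
stride, padding = 2, 1
kh, kw = w.shape[2], w.shape[3]

# 1) Unfold: (B, C*kh*kw, L)
inp_unf = F.unfold(inp, kernel_size=(kh, kw), padding=padding, stride=stride)
# 2) Matrix multiplication: (B, L, C*kh*kw) @ (C*kh*kw, O) -> (B, L, O), plus bias
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()) + b
# 3) Reshape to (B, O, out_h, out_w); L = out_h * out_w
out_h = (inp.shape[2] + 2 * padding - kh) // stride + 1
out_w = (inp.shape[3] + 2 * padding - kw) // stride + 1
out = out_unf.transpose(1, 2).reshape(inp.shape[0], w.shape[0], out_h, out_w)

ref = F.conv2d(inp, w, b, stride=stride, padding=padding)
print(out.shape, (ref - out).abs().max())
```

The maximum difference should be on the order of float32 rounding error, as in the documentation example.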