在 PyTorch 中,size
方法和属性用于获取张量的维度信息。下面是它们的用法和区别:
-
node_features.size
:- 这是一个属性(attribute ),返回一个
torch.Size
对象,表示张量的维度。这是不可调用的,因此不能直接用于获取特定维度的大小。 - 示例:
size_attr = node_features.size print(size_attr) # 输出: torch.Size([3, 4])
- 这是一个属性(attribute ),返回一个
-
node_features.size()
:- 这是一个方法(method),返回一个
torch.Size
对象,本质上是一个包含张量维度的元组(tuple)。这个方法是可调用的,返回的结果与size
属性相同。 - 示例:
size_method = node_features.size() print(size_method) # 输出: torch.Size([3, 4])
- 这是一个方法(method),返回一个
-
node_features.size(1)
:- 这是一个方法调用,它接受一个整数参数(维度索引)并返回该维度的大小。这个方法用于直接获取特定维度的大小。
- This is a method call that takes an integer argument (the dimension index) and returns the size of that specific dimension. This is useful for obtaining the size of a particular dimension directly.
- 示例:
size_dim1 = node_features.size(1) print(size_dim1) # 输出: 4
总结
-
node_features.size
:属性,返回维度信息作为torch.Size
对象。Attribute that returns the dimensions as a torch.Size object. -
node_features.size()
:方法,返回维度信息作为torch.Size
对象(与属性相同)。 -
node_features.size(dimension)
:方法,返回指定维度的大小。 Method that returns the size of the specified dimension.
示例使用
下面是一个完整的示例来展示用法:
import torch
node_features = torch.tensor([[1.0, 2.0, 3.0, 4.0],
[2.0, 3.0, 4.0, 5.0],
[3.0, 4.0, 5.0, 6.0]])
# 使用 size 属性
size_attr = node_features.size
print(f"使用 size 属性: {size_attr}") # 输出: 使用 size 属性: torch.Size([3, 4])
# 使用 size 方法(无参数)
size_method = node_features.size()
print(f"使用 size 方法: {size_method}") # 输出: 使用 size 方法: torch.Size([3, 4])
# 使用 size 方法(带维度参数)
size_dim1 = node_features.size(1)
print(f"维度 1 的大小: {size_dim1}") # 输出: 维度 1 的大小: 4
为什么能输入 1
在 node_features.size(1)
中,参数 1
表示你想获取张量的第 1 个维度(从 0 开始计数)。对于这个特定的张量 node_features
,它的形状是 [3, 4]
,其中:
-
0
维度的大小是3
(行数) -
1
维度的大小是4
(列数)
因此,node_features.size(1)
返回 4
,因为第 1 个维度有 4 个元素。
使用 size 属性: <built-in method size of Tensor object at 0x7f0254cef400>
使用 size 方法: torch.Size([3, 4])
维度 1 的大小: 4
- 打印type(size_attr)得到<class ‘builtin_function_or_method’>
- 打印type(size_method)得到<class ‘torch.Size’>
- 如果调用不存在的维度会报错
size_dim1 = node_features.size(2)
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)