A = torch.tensor([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]
])
# 创建一个布尔张量,选择所有大于 5 的元素
mask = A > 5
# 使用布尔张量作为索引
filtered_elements = A[mask]
print("原始张量 A:")
print(A)
print("\n布尔张量 mask:")
print(mask)
print("\n过滤后的元素:")
print(filtered_elements)
原始张量 A:
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
布尔张量 mask:
tensor([[False, False, False, False],
[False, True, True, True],
[ True, True, True, True]])
过滤后的元素:
tensor([ 6, 7, 8, 9, 10, 11, 12])