文章目录
torch equivalents of numpy functions
Lua + Torch 的社区资源太少了，本文搬运自 Torch 的官网，希望对大家有所帮助。
Types
Numpy |
Torch |
np.ndarray |
torch.Tensor |
np.float32 |
torch.FloatTensor |
np.float64 |
torch.DoubleTensor |
np.int8 |
torch.CharTensor |
np.uint8 |
torch.ByteTensor |
np.int16 |
torch.ShortTensor |
np.int32 |
torch.IntTensor |
np.int64 |
torch.LongTensor |
Constructors
Ones and zeros
Numpy |
Torch |
np.empty([2,2]) |
torch.Tensor(2,2) |
np.empty_like(x) |
x.new(x:size()) |
np.eye |
torch.eye |
np.identity |
torch.eye |
np.ones |
torch.ones |
np.ones_like |
torch.ones(x:size()) |
np.zeros |
torch.zeros |
np.zeros_like |
torch.zeros(x:size()) |
From existing data
Numpy |
Torch |
np.array([ [1,2],[3,4] ]) |
torch.Tensor({{1,2},{3,4}}) |
np.ascontiguousarray(x) |
x:contiguous() |
np.copy(x) |
x:clone() |
np.fromfile(file) |
torch.Tensor(torch.Storage(file)) |
np.frombuffer |
??? |
np.fromfunction |
??? |
np.fromiter |
??? |
np.fromstring |
??? |
np.loadtxt |
??? |
np.concatenate |
torch.cat |
np.multiply |
torch.cmul |
Numerical Ranges
Numpy |
Torch |
np.arange(10) |
torch.range(0,9) |
np.arange(2, 3, 0.1) |
torch.linspace(2, 2.9, 10) |
np.linspace(1, 4, 6) |
torch.linspace(1, 4, 6) |
np.logspace |
torch.logspace |
Building Matrices
Numpy |
Torch |
np.diag |
torch.diag |
np.tril |
torch.tril |
np.triu |
torch.triu |
Attributes
Numpy |
Torch |
x.shape |
x:size() |
x.strides |
x:stride() |
x.ndim |
x:dim() |
x.data |
x:data() |
x.size |
x:nElement() |
x.shape == y.shape |
x:isSameSizeAs(y) |
x.dtype |
x:type() |
Indexing
Shape Manipulation
Numpy |
Torch |
x.reshape |
x:reshape |
x.resize |
x:resize |
??? |
x:resizeAs |
x.transpose |
x:transpose() |
x.flatten |
x:view(x:nElement()) |
x.squeeze |
x:squeeze |
Item selection and manipulation
Numpy |
Torch |
np.take(a, indices) |
a[indices] |
x[:,0] |
x[{{},1}] |
np.put |
??? |
x.repeat |
x:repeatTensor |
x.fill |
x:fill |
np.choose |
??? |
np.sort |
sorted, indices = torch.sort(x, [dim]) |
np.argsort |
sorted, indices = torch.sort(x, [dim]) |
np.nonzero |
torch.find(x:ne(0), 1) (torchx) |
Calculation
Numpy |
Torch |
ndarray.min |
mins, indices = torch.min(x, [dim]) |
ndarray.argmin |
mins, indices = torch.min(x, [dim]) |
ndarray.max |
maxs, indices = torch.max(x, [dim]) |
ndarray.argmax |
maxs, indices = torch.max(x, [dim]) |
ndarray.clip |
torch.clamp |
ndarray.round |
torch.round |
ndarray.trace |
torch.trace |
ndarray.sum |
torch.sum |
ndarray.cumsum |
torch.cumsum |
ndarray.mean |
torch.mean |
ndarray.std |
torch.std |
ndarray.prod |
torch.prod |
ndarray.dot |
torch.mm |
ndarray.cumprod |
torch.cumprod |
ndarray.all |
??? |
ndarray.any |
??? |
Arithmetic and comparison operations
Numpy |
Torch |
ndarray.__lt__ |
torch.lt |
ndarray.__le__ |
torch.le |
ndarray.__gt__ |
torch.gt |
ndarray.__ge__ |
torch.ge |
ndarray.__eq__ |
torch.eq |
ndarray.__ne__ |
torch.ne |