在NumPy中,可以使用__array_priority__属性来控制作用于ndarray和用户定义类型的二元运算符.例如:
class Foo(object):
def __radd__(self, lhs): return 0
__array_priority__ = 100
a = np.random.random((100,100))
b = Foo()
a + b # calls b.__radd__(a) -> 0
然而,同样的事情似乎不适用于比较运算符.例如,如果我将以下行添加到Foo,那么它永远不会从表达式a< b:
def __rlt__(self, lhs): return 0
我意识到__rlt__并不是真正的Python特殊名称,但我认为它可能有用.我尝试了所有的__lt __,__ le __,__ eq __,__ ne__,__ ge __,__ gt__,有或没有前面的r,加上__cmp__,但我永远无法让NumPy调用它们中的任何一个.
这些比较可以被覆盖吗?
UPDATE
为了避免混淆,这里有一个更长的描述NumPy的行为.对于初学者来说,这是NumPy指南中所说的内容:
If the ufunc has 2 inputs and 1 output and the second input is an Object array
then a special-case check is performed so that NotImplemented is returned if the
second input is not an ndarray, has the array priority attribute, and has an
r<op> special method.
我认为这是制定工作的规则.这是一个例子:
import numpy as np
a = np.random.random((2,2))
class Bar0(object):
def __add__(self, rhs): return 0
def __radd__(self, rhs): return 1
b = Bar0()
print a + b # Calls __radd__ four times, returns an array
# [[1 1]
# [1 1]]
class Bar1(object):
def __add__(self, rhs): return 0
def __radd__(self, rhs): return 1
__array_priority__ = 100
b = Bar1()
print a + b # Calls __radd__ once, returns 1
# 1
如您所见,在没有__array_priority__的情况下,NumPy将用户定义的对象解释为标量类型,并在数组中的每个位置应用该操作.那不是我想要的.我的类型是数组(但不应该从ndarray派生).
这是一个较长的示例,显示了在定义所有比较方法时如何失败:
class Foo(object):
def __cmp__(self, rhs): return 0
def __lt__(self, rhs): return 1
def __le__(self, rhs): return 2
def __eq__(self, rhs): return 3
def __ne__(self, rhs): return 4
def __gt__(self, rhs): return 5
def __ge__(self, rhs): return 6
__array_priority__ = 100
b = Foo()
print a < b # Calls __cmp__ four times, returns an array
# [[False False]
# [False False]]
解决方法:
看起来我自己可以回答这个问题. np.set_numeric_ops可以使用如下:
class Foo(object):
def __lt__(self, rhs): return 0
def __le__(self, rhs): return 1
def __eq__(self, rhs): return 2
def __ne__(self, rhs): return 3
def __gt__(self, rhs): return 4
def __ge__(self, rhs): return 5
__array_priority__ = 100
def override(name):
def ufunc(x,y):
if isinstance(y,Foo): return NotImplemented
return np.getattr(name)(x,y)
return ufunc
np.set_numeric_ops(
** {
ufunc : override(ufunc) for ufunc in (
"less", "less_equal", "equal", "not_equal", "greater_equal"
, "greater"
)
}
)
a = np.random.random((2,2))
b = Foo()
print a < b
# 4