PyTorch笔记——FX

官方文档链接:https://pytorch.org/docs/master/fx.html#

概述

FX是供开发人员用于转换nn.Module实例的工具包。FX由三个主要组件组成:符号追踪:symbolic tracer, 中间层表示:intermediate representation, Python代码生成:Python code generation。这些组件的运行演示:

import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

符号追踪(symbolic tracer): 对Python代码进行"符号执行"。它以构造的值(也叫作:代理Proxies)为输入,贯穿运行所有代码。记录下对这些Proxie的操作。更多的符号追踪的信息可见 symbolic_trace()Tracer的相关文档。

**中间层表示(intermediate representation): ** 它里面保存了在符号追中中记录下的运算操作。它由表示函数输入、调用哪些对象(函数、方法或torch.nn.Module实例)和返回值的节点列表组成。关于IR的更多信息可以在Graph的文档中找到。IR是应用转换的格式。

**Python代码生成(Python code generation): ** Python代码生成使FX成为Python代码到Python代码(或模块到模块)转换工具包。对于每个 Graph IR,我们可以创建与图的语义匹配的有效Python代码。此功能包含在GraphModule中,GraphModule是一个torch.nn.Module实例,它包含一个图以及从该图生成的正向方法。

综合起来,这个组件的流水线(符号跟踪 -> 中间表示 -> 转换 -> Python 代码生成)构成了 FX 的 Python-to-Python 转换通道。 此外,这些组件可以单独使用。 例如,可以单独使用符号跟踪来捕获代码形式以用于分析(而不是转换)。 代码生成可用于以编程方式生成模型,例如从配置文件生成模型。 FX 有很多用途!

在示例库中有几个转换的样例

API

symbolic_trace

torch.fx.symbolic_trace(root, concrete_args=None, enable_cpatching=False)

符号追踪的函数,以nn.Module或者函数实例为输入,然后将追踪过程中记录的操作记录下来构造一个GraphModule对象并返回。

concrete_args的作用是根据函数中的分支和参数进行定制化,无论是删除控制流还是数据结构。

例如:

def f(a, b):
    if b == True:
        return a
    else:
        return a*2

由于控制流的存在,FX通常无法正常的追踪。但是,我们可以使用concrete_args指定b的值来解决该问题。

f = fx.symbolic_trace(f, concrete_args={‘b’: False}) assert f(3, False) == 6

注意,虽然你仍然可以给b传不同的值,但是这些值都会被忽略掉。

我们还可以使用concrete_args来消除函数中的数据结构处理。这将使用pytrees将输入展开。为避免过度定制,请为不应指定固定值的传入fx.PH。例如:

def f(x):
    out = 0
    for v in x.values():
        out += v
    return out
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7

参数

  • **root(torch.nn.Module或者可调用对象): ** 要跟踪并转换为Graph的Module或函数。
  • **concrete_args (可选[Dict[str, any]]): **定制化部分输入
  • **enable_cpatching: ** 启用C级功能补全(捕获类似torch.randn的内容)

返回值

通过遍历root获得的相关计算操作创建出的Module。

返回值类型

GraphModule

注意:保证此 API 的向后兼容性。

CLASS torch.fx.Graph

CLASS torch.fx.Graph(owning_module=None, tracer_cls=None)

Graph是FX中间层表示的主要数据结构。它包含一组Node,每个Node表示了一个调用关系(或其他语法结构)。这些Node组合在一起构成了完整的Python功能。

样例:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

这样我们就构造了下面的Graph。

> print(gm.graph)
> graph(x):
    %linear_weight : [#users=1] = self.linear.weight
    %add_1 : [#users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
    %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
    %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
    %topk_1 : [#users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
    return topk_1

call_function(the_function, args=None, kwargs=None, type_expr=None)

在Graph中插入一个call_function的节点。call_function 节点表示对 Python 可调用对象的调用,被调用对象由 the_function 指定。
插入的位置选择同create_node。

call_method(method_name, args=None, kwargs=None, type_expr=None)

在Graph中插入一个call_method的节点。call_method 节点表示对 args 的第 0 个元素上的给定方法的调用。
插入的位置选择同create_node。

call_module(module_name, args=None, kwargs=None, type_expr=None)

在Graph中插入call_module 节点。 call_module 节点表示对模块层次结构中模块的 forward()函数的调用。
插入的位置选择同create_node。

create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)

创建一个节点并将其添加到当前插入点的图形中。 请注意,当前插入点可以通过 Graph.inserting_before() 和 Graph.inserting_after() 设置。

eliminate_dead_code()

根据每个节点的用户数以及删除节点是否有任何其他影响,从图中删除所有无用代码。调用之前,必须对图形进行拓扑排序。

erase_node(to_erase)

从图形中删除节点。如果图中仍在使用该节点,则抛出异常。

flatten_inps(*args)

get_attr(qualified_name, type_expr=None)

在Graph中插入一个get_attr节点。get_attr节点表示从模块层次结构中获取属性。
插入的位置选择同create_node。

graph_copy(g, val_map, return_output_node=False)

将所给的Graph中所有节点拷贝一份。

inserting_after(n=None)

设置 create_node 和配套方法将插入图中的点。
使用with语句时,临时设置插入点,然后在 with 语句退出时恢复原来的值:

with g.inserting_after(n):
    ... # inserting after node n
... # insert point restored to what it was previously
g.inserting_after(n) #  set the insert point permanently

inserting_before(n=None)

类似after

lint()

对此图运行各种检查以确保其格式正确。 特别是: - 检查节点是否具有正确的所有权(由该图拥有) - 检查节点是否按拓扑顺序出现 - 如果该图有一个拥有的 GraphModule,则检查该 GraphModule 中是否存在目标

node_copy(node, arg_transform=<function Graph.>)

将一个节点从一个图中复制到另一个图中。 arg_transform 需要将参数从源Graph转换到目的Graph。 例子:

# Copying all the nodes in `g` into `new_graph`
g : torch.fx.Graph = ...
new_graph = torch.fx.graph()
value_remap = {}
for node in g.nodes:
    value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])

PROPERTY nodes

获取Graph的Node列表。
请注意,此节点列表表示是一个双向链表。 迭代期间的突变(例如删除节点、添加节点)是安全的。

output(result, type_expr=None)

在Graph中插入一个output节点。output节点表示 Python 代码中的返回语句。 result是应该返回的值。

PROPERTY owning_module

如果有拥有此 GraphModule 的模块,则返回该模块,如果没有或有多个则返回 None。

placeholder(name, type_expr=None)

在Graph中插入一个placeholder。placeholder表示函数的输入。

print_tabular()

以表格格式打印图形的IR。 请注意,此 API 需要安装 tabulate 模块。

python_code(root_module)

将该Graph转换为有效的Python代码。

unflatten_outs(out)

CLASS torch.fx.Node

CLASStorch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)

Node是Graph中表示一个独立计算单元的数据结构。在大多数情况下,Node代表对各种entity的调用关系,例如运算符、方法和模块(一些例外包括指定函数输入和输出的节点)。 每个节点都有一个由其 op 属性指定的函数。 op的每个值的Node语义如下:

  • **placeholder : ** 表示函数输入。 name 属性指定此值将采用的名称。 target 同样是参数的名称。 args 包含:1) 什么都没有,或 2) 表示函数输入的默认参数的单个参数。 kwargs 不用关心的。 占位符对应于图形打印输出中的函数参数(例如 x)。
  • **get_attr: ** 从模块层次结构中检索参数。 name 与提取结果的名称类似。 target 是模块层次结构中参数位置的完全限定名称。 args 和 kwargs不用关心.
  • **call_function: ** 对某些值应用的*函数也就是非成员函数。 name 同样是要分配给的值的名称。 target 是要应用的函数。 args 和 kwargs 表示函数的参数,遵循 Python 调用约定
  • **call_module: ** 将模块层次结构的forward()方法中的模块应用于给定参数。name和前面一样。target是要调用的模块层次结构中模块的完全限定名。args和kwargs表示要在其上调用模块的参数,包括self参数。
  • **call_method: ** 对值调用方法。name的含义一样。target是要应用于自参数的方法的字符串名称。args和kwargs表示要在其上调用模块的参数,包括self参数
  • **output: ** 在其args[0]属性中包含跟踪函数的输出。这对应于图形打印输出中的"return"语句。

PROPERTY all_input_nodes

获取该节点的所有输入节点。这相当于找出 args 和 kwargs中值为Node的参数。

append(x)

在图中的节点列表中,在此节点后插入x。与self.next.prepend(x)功能相同。

PROPERTY args

此节点的参数元组。参数的解释取决于节点的操作码。有关详细信息,请参阅节点docstring。
允许对此属性进行赋值。使用和用户的所有记帐在分配时自动更新。

format_node(placeholder_names=None, maybe_return_typename=None)

返回描述本Node的字符串。

is_impure()

返回此op是否是纯操作,即其op是否为占位符或输出,或者是否是纯的call_module或call_function。

PROPERTY kwargs

此节点的关键字参数的dict。参数的解释取决于节点的op代码实现。有关详细信息,请参阅node docstring。

PROPERTY next

返回该节点的下一个Node

normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)

将规范化参数返回到Python目标。这意味着,如果normalize_to_only_use_kwargs为true,则args/kwargs将与模块/函数的签名匹配,并以位置顺序专门返回kwargs。还填充默认值。不支持仅位置参数或varargs参数。
支持模块调用。
可能需要arg_类型和kwarg_类型以消除重载的歧义。

prepend(x)

在该节点前插入x节点

replace_all_uses_with(replace_with)

将图中所有用本节点的地方替换为节点replace_with。

replace_input_with(old_input, new_input)

遍历Node的所有输入,并将old_input都替换为new_input。

PROPERTY stack_trace

返回跟踪期间记录的Python堆栈跟踪(如果有)。此属性通常由Tracer.create_proxy填充。要在跟踪过程中记录堆栈跟踪以进行调试,请在跟踪程序实例上设置record_stack_traces=True。

update_arg(idx, arg)

更新现有参数使第inx个参数值为arg。调用后,self.args[idx]==arg。

update_kwarg(key, arg)

更新现有kwarg参数新增键值为key对应值为arg的参数。调用后,self.kwargs[key]==arg。

torch.fx.replace_pattern(gm, pattern, replacement)

找到GraphModule中符合pattern匹配规则的所有运算符集,然后用replacement替换掉。

参数

  • gm: 要操作的GraphModule
  • **pattern: ** 匹配的模式
  • **replacement: ** 要替换成的目的子图

返回值

匹配对象列表,表示模式匹配到的原始图形中的位置。如果没有匹配项,则列表为空。匹配定义为:

class Match(NamedTuple):
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]

返回值

List[Match]

例子:

import torch
from torch.fx import symbolic_trace, subgraph_rewriter

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)

def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()

def replacement(w1, w2):
    return torch.stack([w1, w2])

traced_module = symbolic_trace(M())

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

上面的代码将首先在traced_module的forward方法中对pattern进行匹配。模式匹配是基于use-def关系而不是节点名称完成的。例如,如果模式中有p=torch.cat([a,b]),则可以在原始forward函数中匹配m=torch.cat([a,b]),尽管变量名称不同(p与m)。

模式中的return语句仅基于其值进行匹配;它可能与较大图形中的return语句匹配,也可能不匹配。换句话说,模式不必延伸到较大图形的末尾。

当pattern匹配成功时,它将从较大的函数中删除,并用replacement来替换。如果在较大的函数中有匹配成功多个,则将替换每个不重叠的匹配。在匹配重叠的情况下,将替换重叠匹配集中找到的第一个匹配。(“第一个”在这里被定义为节点use-def关系拓扑顺序中的第一个。在大多数情况下,第一个节点是直接出现在self之后的参数,而最后一个节点是函数返回的任何值。)

如果pattern是可调用的,则它的参数必须在pattern里面使用,replacement的参数必须与pattern的参数匹配。第一条规则,为什么在上面的代码块中,forward函数有参数x、w1、w2,而pattern函数只有参数w1、w2。因为pattern不使用x,所以不应该将x指定为参数。作为第二条规则的一个例子:

def pattern(x, y):
    return torch.neg(x) + torch.relu(y)

替换为

def replacement(x, y):
    return torch.relu(x)

在这种情况下,替换需要与pattern相同数量的参数(x和y),即使在替换中没有使用参数y。
调用subgraph_rewriter.replace_pattern后,生成的Python代码如下所示:

def forward(self, x, w1, w2):
    stack_1 = torch.stack([w1, w2])
    sum_1 = stack_1.sum()
    stack_2 = torch.stack([w1, w2])
    sum_2 = stack_2.sum()
    max_1 = torch.max(sum_1)
    add_1 = x + max_1
    max_2 = torch.max(sum_2)
    add_2 = add_1 + max_2
    return add_2
上一篇:DRF源码分析


下一篇:python中 类中def extend(self, *args, **kwargs)理解