在 PyTorch 中,当你定义一个模型,即使在模型定义时没有显式声明接收输入数据的参数,模型的使用仍然可以通过直接传入输入数据来进行。这是因为模型类继承自 torch.nn.Module
,而 torch.nn.Module
已经预定义了如何处理输入数据的方式。
继承自 torch.nn.Module
当你创建一个新的模型类并使其继承自 torch.nn.Module
,你的模型类就自动继承了 torch.nn.Module
的方法和属性。torch.nn.Module
提供了一种机制来定义模型应如何接受输入并进行前向传播。这主要通过重写 forward
方法实现。
forward
方法
在 PyTorch 中,你不需要在模型的构造函数 (__init__
) 中指定输入参数。相反,你需要定义一个 forward
方法,它将自动接收模型的输入并定义模型的前向传播行为:
class MyModel(nn.Module):