model = Model(); model(input) directly calls the forward(input) method defined in the Model class, because nn.Module implements __call__ (which dispatches to forward).
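A minimal sketch of that dispatch (the Model class and its layer sizes below are made up purely for illustration):

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = Model()
input = torch.randn(3, 4)
# nn.Module.__call__ runs registered hooks and then dispatches to forward,
# so these two calls compute the same thing:
out1 = model(input)
out2 = model.forward(input)
print(torch.equal(out1, out2))  # True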
Here is an example:
import math, random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.nn.functional as F

USE_CUDA = torch.cuda.is_available()
# Note: autograd.Variable is deprecated in modern PyTorch; plain tensors carry autograd state.
Variable = lambda *args, **kwargs: autograd.Variable(*args, **kwargs).cuda() if USE_CUDA else autograd.Variable(*args, **kwargs)

class Encoder(nn.Module):
    def __init__(self, din=32, hidden_dim=128):
        super(Encoder, self).__init__()
        self.fc = nn.Linear(din, hidden_dim)

    def forward(self, x):
        # project each agent's observation into the hidden space
        embedding = F.relu(self.fc(x))
        return embedding

class AttModel(nn.Module):
    def __init__(self, n_node, din, hidden_dim, dout):
        super(AttModel, self).__init__()
        self.fcv = nn.Linear(din, hidden_dim)   # value projection
        self.fck = nn.Linear(din, hidden_dim)   # key projection
        self.fcq = nn.Linear(din, hidden_dim)   # query projection
        self.fcout = nn.Linear(hidden_dim, dout)

    def forward(self, x, mask):
        v = F.relu(self.fcv(x))
        q = F.relu(self.fcq(x))
        k = F.relu(self.fck(x)).permute(0, 2, 1)
        # scores of non-neighbors (mask == 0) are pushed to -9e15,
        # so softmax assigns them ~0 attention weight
        att = F.softmax(torch.mul(torch.bmm(q, k), mask) - 9e15 * (1 - mask), dim=2)

        out = torch.bmm(att, v)
        # out = torch.add(out, v)
        out = F.relu(self.fcout(out))
        return out

class Q_Net(nn.Module):
    def __init__(self, hidden_dim, dout):
        super(Q_Net, self).__init__()
        self.fc = nn.Linear(hidden_dim, dout)

    def forward(self, x):
        # one Q-value per action for each agent
        q = self.fc(x)
        return q
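One detail worth noting in AttModel.forward: multiplying the raw scores by mask and then subtracting 9e15 * (1 - mask) pushes every non-neighbor score to a huge negative number, so softmax gives those positions essentially zero weight. A small standalone check of this trick (the score values are arbitrary):

import torch
import torch.nn.functional as F

scores = torch.tensor([[[1.0, 2.0, 3.0]]])   # (batch=1, queries=1, keys=3)
mask = torch.tensor([[[1.0, 1.0, 0.0]]])     # third key is not a neighbor
att = F.softmax(scores * mask - 9e15 * (1 - mask), dim=2)
print(att)  # ~[[[0.2689, 0.7311, 0.0000]]]: the masked entry gets no weight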
class DGN(nn.Module):
    def __init__(self, n_agent, num_inputs, hidden_dim, num_actions):
        super(DGN, self).__init__()

        self.encoder = Encoder(num_inputs, hidden_dim)
        self.att_1 = AttModel(n_agent, hidden_dim, hidden_dim, hidden_dim)
        self.att_2 = AttModel(n_agent, hidden_dim, hidden_dim, hidden_dim)
        self.q_net = Q_Net(hidden_dim, num_actions)

    def forward(self, x, mask):
        h1 = self.encoder(x)
        h2 = self.att_1(h1, mask)
        h3 = self.att_2(h2, mask)
        q = self.q_net(h3)
        return q
Inspecting this in the debugger's watch window: the result of model(input) is a Tensor (model itself is an nn.Module, not a Tensor), so model(input)[0] indexes dimension 0 of that output, i.e. it takes the first sample of the batch.
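To make this concrete, a small usage sketch; the sizes below (3 agents, 32-dim observations, 5 actions) are arbitrary choices for the demo, not values fixed by DGN:

import torch

n_agent, num_inputs, hidden_dim, num_actions = 3, 32, 128, 5
model = DGN(n_agent, num_inputs, hidden_dim, num_actions)

obs = torch.randn(1, n_agent, num_inputs)   # (batch, n_agent, num_inputs)
mask = torch.ones(1, n_agent, n_agent)      # fully-connected adjacency for the demo

q = model(obs, mask)           # runs DGN.forward via nn.Module.__call__
print(type(q))                 # <class 'torch.Tensor'>
print(q.shape)                 # torch.Size([1, 3, 5]) -> (batch, n_agent, num_actions)
print(q[0].shape)              # torch.Size([3, 5]): Q-values of the first batch sample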