📅  最后修改于: 2023-12-03 14:53:15.553000             🧑  作者: Mango
在 PyTorch 中,我们可以通过模型的 named_modules
或 named_children
方法获取模型的模块或子模块的名称。下面是具体的用法和示例代码。
named_modules
named_modules
方法可以递归地迭代模型的所有模块,并返回每个模块的名称和模块本身的迭代器。下面是 named_modules
方法的语法:
named_modules(prefix='')
其中,prefix
参数表示要迭代的模块的名称前缀(默认为空字符串)。返回值是一个生成器对象,可以遍历所有的子模块。下面是示例代码:
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 10, kernel_size=3)
self.conv2 = nn.Conv2d(10, 20, kernel_size=3)
self.fc = nn.Linear(20 * 26 * 26, 10)
def forward(self, x):
x = nn.functional.relu(self.conv1(x))
x = nn.functional.relu(self.conv2(x))
x = x.view(-1, 20 * 26 * 26)
x = self.fc(x)
return x
net = Net()
for name, module in net.named_modules():
print(name, module)
输出结果如下:
('', Net(
(conv1): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))
(conv2): Conv2d(10, 20, kernel_size=(3, 3), stride=(1, 1))
(fc): Linear(in_features=13520, out_features=10, bias=True)
))
('conv1', Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1)))
('conv2', Conv2d(10, 20, kernel_size=(3, 3), stride=(1, 1)))
('fc', Linear(in_features=13520, out_features=10, bias=True))
可以看到,named_modules
方法遍历了所有子模块,并返回了它们的名称和模块本身。
named_children
named_children
方法和 named_modules
方法类似,但它只迭代直接子模块,不会递归遍历所有子模块。下面是 named_children
方法的语法:
named_children()
该方法不需要参数,返回值是一个可迭代的生成器对象,可以遍历所有的直接子模块。下面是示例代码:
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 10, kernel_size=3)
self.conv2 = nn.Conv2d(10, 20, kernel_size=3)
self.fc = nn.Linear(20 * 26 * 26, 10)
def forward(self, x):
x = nn.functional.relu(self.conv1(x))
x = nn.functional.relu(self.conv2(x))
x = x.view(-1, 20 * 26 * 26)
x = self.fc(x)
return x
net = Net()
for name, module in net.named_children():
print(name, module)
输出结果如下:
conv1 Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))
conv2 Conv2d(10, 20, kernel_size=(3, 3), stride=(1, 1))
fc Linear(in_features=13520, out_features=10, bias=True)
可以看到,named_children
方法只遍历了所有直接子模块,并返回了它们的名称和模块本身。
在 PyTorch 中,我们可以通过 named_modules
或 named_children
方法获取模型的模块或子模块的名称。这些方法可以帮助我们更好地理解模型的结构,并在需要的时候方便地进行调试和修改。