📜  如何获取 pytroch 模型层名称 - Python (1)

📅  最后修改于: 2023-12-03 14:53:15.553000             🧑  作者: Mango

如何获取 PyTorch 模型层名称

在 PyTorch 中,我们可以通过模型的 named_modulesnamed_children 方法获取模型的模块或子模块的名称。下面是具体的用法和示例代码。

named_modules

named_modules 方法可以递归地迭代模型的所有模块,并返回每个模块的名称和模块本身的迭代器。下面是 named_modules 方法的语法:

named_modules(prefix='')

其中,prefix 参数表示要迭代的模块的名称前缀(默认为空字符串)。返回值是一个生成器对象,可以遍历所有的子模块。下面是示例代码:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=3)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)
        self.fc = nn.Linear(20 * 26 * 26, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = x.view(-1, 20 * 26 * 26)
        x = self.fc(x)
        return x

net = Net()

for name, module in net.named_modules():
    print(name, module)

输出结果如下:

('', Net(
  (conv1): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(3, 3), stride=(1, 1))
  (fc): Linear(in_features=13520, out_features=10, bias=True)
))
('conv1', Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1)))
('conv2', Conv2d(10, 20, kernel_size=(3, 3), stride=(1, 1)))
('fc', Linear(in_features=13520, out_features=10, bias=True))

可以看到,named_modules 方法遍历了所有子模块,并返回了它们的名称和模块本身。

named_children

named_children 方法和 named_modules 方法类似,但它只迭代直接子模块,不会递归遍历所有子模块。下面是 named_children 方法的语法:

named_children()

该方法不需要参数,返回值是一个可迭代的生成器对象,可以遍历所有的直接子模块。下面是示例代码:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=3)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)
        self.fc = nn.Linear(20 * 26 * 26, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = x.view(-1, 20 * 26 * 26)
        x = self.fc(x)
        return x

net = Net()

for name, module in net.named_children():
    print(name, module)

输出结果如下:

conv1 Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))
conv2 Conv2d(10, 20, kernel_size=(3, 3), stride=(1, 1))
fc Linear(in_features=13520, out_features=10, bias=True)

可以看到,named_children 方法只遍历了所有直接子模块,并返回了它们的名称和模块本身。

结论

在 PyTorch 中,我们可以通过 named_modulesnamed_children 方法获取模型的模块或子模块的名称。这些方法可以帮助我们更好地理解模型的结构,并在需要的时候方便地进行调试和修改。