A Guide to Using nn.Module in PyTorch, with Examples
In PyTorch, nn.Module is the core base class from which every model is built. Understanding and using nn.Module fluently is key to mastering PyTorch.
1. What Is nn.Module
nn.Module is the base class for all neural network modules in PyTorch. You can think of it as a "container for a neural network" that takes care of the following:
- Network layers (such as Linear, Conv2d, and so on)
- The forward-pass logic (the forward function)
- Model parameters (registered automatically and trainable)
- Nesting (a module can contain any number of submodules)
- Convenience utilities such as saving and loading the model
2. Basic Usage
2.1 Defining a Custom Model Class
```python
import torch
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
```
2.2 Instantiation and Calling
```python
model = MyNet()
x = torch.randn(32, 784)   # batch_size = 32
output = model(x)          # calling the model invokes forward automatically
```
3. __init__() and forward() in Detail
3.1 __init__()
- Defines the structure of the model: its submodules and layers.
- For example, self.conv1 = nn.Conv2d(...) is registered automatically, so its parameters become part of the model.
3.2 forward()
- Defines the forward-pass logic.
- Do not call it by hand; call the model as model(x), which dispatches to forward() and also runs any registered hooks. See the sketch below.
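A minimal sketch tying both points together, using the MyNet class defined earlier:

```python
model = MyNet()

# Every `self.xxx = nn.Xxx(...)` assignment in __init__ is registered automatically,
# so its weights show up here and can be handed straight to an optimizer.
for name, param in model.named_parameters():
    print(name, tuple(param.shape))

x = torch.randn(4, 784)
y = model(x)             # preferred: __call__ runs forward() plus any registered hooks
# y = model.forward(x)   # works, but bypasses hooks -- avoid
```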
4. Common Layers

| Module | Purpose | Example |
|---|---|---|
| nn.Linear | Fully connected layer | nn.Linear(128, 64) |
| nn.Conv2d | Convolution layer | nn.Conv2d(3, 16, 3) |
| nn.ReLU | Activation function | nn.ReLU() |
| nn.Sigmoid | Activation function | nn.Sigmoid() |
| nn.BatchNorm2d | Batch normalization | nn.BatchNorm2d(16) |
| nn.Dropout | Dropout layer | nn.Dropout(0.5) |
| nn.LSTM | LSTM layer | nn.LSTM(10, 20) |
| nn.Sequential | Sequential container of layers | see the example below |
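Each of these layers is itself an nn.Module and can be called directly on a tensor, which is a handy way to check shapes. The sizes below are arbitrary and only for illustration:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 16, 3)        # 3 input channels, 16 output channels, 3x3 kernel
bn = nn.BatchNorm2d(16)
fc = nn.Linear(128, 64)

img = torch.randn(8, 3, 32, 32)   # (batch, channels, height, width)
out = bn(conv(img))
print(out.shape)                  # torch.Size([8, 16, 30, 30]) -- no padding, so 32 -> 30

vec = torch.randn(8, 128)
print(fc(vec).shape)              # torch.Size([8, 64])
```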
5. Nested Models (Submodules)
An nn.Module can itself be nested inside another module as a submodule:
```python
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(64, 64),
            nn.ReLU()
        )

    def forward(self, x):
        return self.layer(x)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = Block()
        self.block2 = Block()
        self.output = nn.Linear(64, 10)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        return self.output(x)
```
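Because block1, block2 and output are assigned as attributes, they are registered recursively. A quick sketch that inspects the resulting module tree of the Net class above:

```python
net = Net()
for name, module in net.named_modules():
    print(name or "(root)", "->", module.__class__.__name__)

# Parameters of the nested blocks are collected automatically:
print(sum(p.numel() for p in net.parameters()), "trainable parameters")
```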
6. Built-in Methods and Attributes

| Method / Attribute | Description |
|---|---|
| model.parameters() | Returns all trainable parameters (pass them to the optimizer) |
| model.named_parameters() | Iterator over (name, parameter) pairs |
| model.children() | Iterator over direct child modules |
| model.eval() | Switch to evaluation mode (Dropout disabled, BatchNorm uses running statistics) |
| model.train() | Switch to training mode |
| model.to(device) | Move the model to GPU/CPU |
| model.state_dict() | Get the parameter dictionary (for saving) |
| model.load_state_dict() | Load a parameter dictionary |
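A short sketch that ties several of these together; the file name and device choice here are arbitrary placeholders:

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyNet().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Save and restore only the parameters (the usual pattern):
torch.save(model.state_dict(), "mynet.pt")
model.load_state_dict(torch.load("mynet.pt", map_location=device))

model.eval()   # before validation / inference
model.train()  # back to training mode
```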
7. Using nn.Sequential
nn.Sequential is an ordered container that simplifies the definition of a network:
```python
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
```
This is equivalent to the hand-written nn.Module subclass above, and it is a good fit when the forward pass is a simple linear "flow" of layers.
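If you want readable layer names (handy when inspecting state_dict()), nn.Sequential also accepts an OrderedDict; a small sketch:

```python
from collections import OrderedDict

model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(784, 128)),
    ("relu", nn.ReLU()),
    ("fc2", nn.Linear(128, 10)),
]))
print(model.fc1)   # layers are then accessible by name
```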
8. A Complete Example: an MNIST Classifier
```python
class MNISTNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        return self.net(x)

# Instantiate the model
model = MNISTNet()
print(model)

# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Example training loop (train_loader is a DataLoader over the MNIST training set, defined elsewhere)
for epoch in range(10):
    for images, labels in train_loader:
        output = model(images)
        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
9. Common Pitfalls and Tips
| Problem | Explanation |
|---|---|
| forward() seems not to be used | Call model(x) instead of invoking model.forward(x) by hand |
| Forgetting super().__init__() | Submodules will not be registered |
| Parameters not registered | Layers/modules must be assigned as attributes: self.xxx = ... |
| Mixing up training/evaluation mode | Remember to switch with model.train() and model.eval() (see the sketch below) |
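To make the last two pitfalls concrete, here is a minimal evaluation sketch; it assumes the MNISTNet model from section 8 and a test_loader defined elsewhere:

```python
model.eval()                      # Dropout off, BatchNorm uses running statistics
correct, total = 0, 0
with torch.no_grad():             # no gradients needed for inference
    for images, labels in test_loader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"accuracy: {correct / total:.4f}")
model.train()                     # switch back before resuming training
```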
10. Summary

| Item | Description |
|---|---|
| __init__() | Defines the model structure (submodules, layers) |
| forward() | Defines the forward pass |
| Automatic parameter registration | Every self.xxx = nn.XXX(...) assignment is tracked |
| Nested modules | Submodules are traversed recursively |
| Convenience methods | .parameters(), .to(), .eval(), and so on |
11. Comprehensive Examples
Below are concise yet complete implementations of three classic deep learning architectures (ResNet18, UNet, and a Transformer encoder), all built on nn.Module and suitable for beginners to get started quickly.
1. A Concise ResNet18 (Image Classification)
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        return F.relu(out)

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_planes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion)
            )
        layers = [block(self.in_planes, planes, stride, downsample)]
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x).flatten(1)
        return self.fc(x)

def ResNet18(num_classes=1000):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
```
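A quick shape check, as a sketch with a dummy batch:

```python
model = ResNet18(num_classes=10)
x = torch.randn(2, 3, 224, 224)   # two dummy RGB images
print(model(x).shape)             # torch.Size([2, 10])
```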
2. UNet (Image Segmentation)
```python
class UNetBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)

class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.enc1 = UNetBlock(in_channels, 64)
        self.enc2 = UNetBlock(64, 128)
        self.enc3 = UNetBlock(128, 256)
        self.enc4 = UNetBlock(256, 512)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = UNetBlock(512, 1024)
        self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = UNetBlock(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = UNetBlock(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = UNetBlock(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = UNetBlock(128, 64)
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder path, with max-pooling between stages
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b = self.bottleneck(self.pool(e4))
        # Decoder path: upsample, concatenate the skip connection, then convolve
        d4 = self.upconv4(b)
        d4 = self.dec4(torch.cat([d4, e4], dim=1))
        d3 = self.upconv3(d4)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = self.upconv2(d3)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = self.upconv1(d2)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))
        return self.final(d1)
```
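A quick shape check for the UNet; the spatial size should be divisible by 16 so the skip connections line up:

```python
model = UNet(in_channels=1, out_channels=1)
x = torch.randn(1, 1, 128, 128)   # one single-channel 128x128 image
print(model(x).shape)             # torch.Size([1, 1, 128, 128])
```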
3. A Simplified Transformer Encoder (Sequence Modeling)
```python
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, heads, ff_hidden_dim, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with a residual connection and layer norm
        attn_out, _ = self.attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_out))
        # Position-wise feed-forward network, again with residual + norm
        ff_out = self.ff(x)
        x = self.norm2(x + self.dropout(ff_out))
        return x

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim=512, n_heads=8, ff_dim=2048, num_layers=6, max_len=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoding = self._generate_positional_encoding(max_len, embed_dim)
        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, n_heads, ff_dim)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(0.1)

    def _generate_positional_encoding(self, max_len, d_model):
        # Standard sinusoidal positional encoding, shape (1, max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        i = torch.arange(0, d_model, 2)
        angle_rates = 1 / torch.pow(10000, (i / d_model))
        pos_enc = torch.zeros(max_len, d_model)
        pos_enc[:, 0::2] = torch.sin(pos * angle_rates)
        pos_enc[:, 1::2] = torch.cos(pos * angle_rates)
        return pos_enc.unsqueeze(0)

    def forward(self, x):
        B, T = x.shape
        x = self.embedding(x) + self.pos_encoding[:, :T].to(x.device)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x)
        return x
```
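And a quick shape check for the encoder, with hyperparameters shrunk so it runs quickly:

```python
encoder = TransformerEncoder(vocab_size=10000, embed_dim=128, n_heads=4,
                             ff_dim=512, num_layers=2)
tokens = torch.randint(0, 10000, (2, 20))   # batch of 2 sequences of 20 token ids
print(encoder(tokens).shape)                # torch.Size([2, 20, 128])
```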
4. Comparison Summary

| Model | Typical use | Characteristics |
|---|---|---|
| ResNet18 | Image classification | Deep residual structure, well suited to transfer learning |
| UNet | Image segmentation | Symmetric encoder-decoder with skip connections |
| Transformer | NLP / sequence modeling | Pure attention, no convolutions or recurrence |
That wraps up this walkthrough of nn.Module in PyTorch and the accompanying code examples.