3.2export_onnx
?? 1. 模型定義與導出代碼
import torch
import torch.nn as nn
import torch.onnx
import onnxsim
import onnx
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(num_features=16)
self.act1 = nn.ReLU()
self.conv2 = nn.Conv2d(in_channels=16, out_channels=64, kernel_size=5, padding=2)
self.bn2 = nn.BatchNorm2d(num_features=64)
self.act2 = nn.ReLU()
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(in_features=64, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.act2(x)
# flatten: B×C×H×W → B×C×L(L=H×W)
x = torch.flatten(x, 2, 3)
# 平均池化:B×C×L → B×C×1
x = self.avgpool(x)
# 再次 flatten:B×C×1 → B×C
x = torch.flatten(x, 1)
# 全連接層分類:B×C → B×10
x = self.head(x)
return x
?? 2. 導出為 ONNX 并簡化
def export_norm_onnx():
input = torch.rand(1, 3, 64, 64) # 輸入:B×3×64×64
model = Model()
file = "./sample-reshape.onnx"
# 導出 ONNX 模型
torch.onnx.export(
model = model,
args = (input,),
f = file,
input_names = ["input0"],
output_names = ["output0"],
opset_version = 15
)
print("Finished normal onnx export")
# 檢查模型結構合法性
model_onnx = onnx.load(file)
onnx.checker.check_model(model_onnx)
# 使用 onnx-simplifier 進行圖結構簡化
print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...")
model_onnx, check = onnxsim.simplify(model_onnx)
assert check, "assert check failed"
onnx.save(model_onnx, file)
?? 小提示:為什么 flatten 會生成多個節點?
x = torch.flatten(x, 2, 3)
ONNX 中不支持 flatten(x, start_dim=2) 這樣的高維展開直接表示,因此 PyTorch 導出時會轉換為:
Shape:獲取張量形狀Slice:提取要 flatten 的維度Concat:拼接新 shapeReshape:完成 flatten 動作
使用 onnxsim 簡化后,這些操作通常會被合并為一個簡單的 Flatten 或 Reshape。
? 3. 主函數執行導出流程
if __name__ == "__main__":
export_norm_onnx()
?? 代碼結構整體說明:
?? 模型結構(Model 類):
x -> conv1 -> bn1 -> relu1
-> conv2 -> bn2 -> relu2
-> flatten -> avgpool -> flatten -> linear -> output
其中重點在于:
?? 第一段 flatten:
x = torch.flatten(x, 2, 3) # B, C, H, W -> B, C, L
這個操作會導致導出的 ONNX 圖中生成:
ShapeSliceConcatReshape
等一系列輔助節點。為什么?
?? 為什么 flatten(x, 2, 3) 會變成這么多 ONNX 節點?
PyTorch 的 torch.flatten(x, 2, 3) 表示:
- 把
x從第 2 維(H)到第 3 維(W)展平為一維 - 舉個例子:輸入
x是[B, C, H, W],flatten 后變成[B, C, H*W]
但是在 ONNX 中:
- ONNX 不支持 “動態切片 + flatten” 作為單一原始操作
- 所以需要分解為多個步驟來實現:
1. Shape:先獲取 x 的形狀
2. Slice:抽取你需要的維度值(這里是 H 和 W)
3. Concat:拼接出新 shape,例如 [B, C, H*W]
4. Reshape:應用這個新 shape
這就是你看到的:
Shape -> Slice -> Slice -> Mul -> Concat -> Reshape
的由來。
?? 為什么導出前后圖不一樣?
你有兩個版本:
? 原始導出圖:
- 有上述所有細化節點(Slice/Shape/Reshape 等)
- 這對于 動態輸入尺寸 很重要,但會讓圖復雜
? 簡化后的 ONNX(使用 onnxsim.simplify):
- 會自動識別這部分是一個 flatten 動作
- 用更簡潔的方式重新表達(甚至直接用一個
Flatten節點)
這是為什么你寫了:
# onnx中其實會有一些constant value,以及不需要計算圖跟蹤的節點
# 大家可以一起從netron中看看這些節點都在干什么
?? 平鋪流程:flatten + avgpool + flatten + fc
你原始的網絡有這幾步轉換:
| 步驟 | 輸入維度 | 輸出維度 | 說明 |
|---|---|---|---|
flatten(x, 2, 3) |
[B,C,H,W] |
[B,C,L] |
H × W 展平為 L |
AdaptiveAvgPool1d(1) |
[B,C,L] |
[B,C,1] |
類似全局平均池化 |
flatten(x, 1) |
[B,C,1] |
[B,C] |
去掉最后一維 |
Linear |
[B,C] |
[B,10] |
最終全連接層分類 |
?? 建議你動手做以下實驗理解更深:
- 注釋掉
onnxsim.simplify(),用 Netron 打開.onnx文件,看看flatten變成了哪些低層操作? - 然后再運行一次
simplify,看看有沒有把它們合并成一個Flatten或更簡潔的結構? - 把
torch.flatten(x, 2, 3)換成.view(b, c, -1)或.reshape(...),看看導出的結構是否更簡潔?
? 總結重點
| 項 | 內容 |
|---|---|
| flatten 操作為什么變復雜? | 因為 ONNX 中 flatten 只支持從第 1 個維度開始,如果你指定的是 2~3,會生成 shape/slice/reshape |
onnxsim.simplify 作用? |
自動識別復雜邏輯并簡化(合并 slice、reshape 等) |
| 推薦做法? | 導出前先理解動態維度怎么計算,導出后建議簡化以減小模型體積、提升兼容性 |
| 哪些操作最容易生成冗余圖? | flatten、transpose、reshape、permute、expand 等涉及動態 shape 的操作 |

浙公網安備 33010602011771號