提升主要来源于编译层面的算子融合、代码编译优化,总体可以带来一定幅度的提升
但其具有一定缺点:
首先,每次进程启动都需要预编译和预热,需要一定时间,若训练轮数少、时间短,其加速效果一般
其次,其暂时不支持 Python 3.12 + 版本
最后,针对性价比,个人认为需要看消融结果、继续实验,不同实验效果不同,很难一概而论
总结:
总体可以认为是一个提速的方案,但具体在项目中的应用需要根据具体情况斟酌/消融
其应用方案:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 import torch import torch.nn as nn import torch.optim as optim # 定义简单模型 class SimpleNet(nn.Module): def __init__(self): super().__init__() self.net = nn.Sequential( nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(32, 10) ) def forward(self, x): return self.net(x) # 构造数据 x = torch.randn(16, 3, 224, 224).cuda() y = torch.randint(0, 10, (16,)).cuda() # 定义模型、优化器、损失函数 model = SimpleNet().cuda() optimizer = optim.Adam(model.parameters()) criterion = nn.CrossEntropyLoss() # 使用 torch.compile 包装模型 compiled_model = torch.compile(model) # 默认后端 'inductor'// 还有其他选项 # 训练一步 for _ in range(5): optimizer.zero_grad() out = compiled_model(x) loss = criterion(out, y) loss.backward() optimizer.step() print("loss:", loss.item())
补注:
1.性价比这一块依旧需要做大量实验来确定效果,如:
显存占用情况和影响因素
内存占用情况和影响因素
提速情况和影响因素(理论上轻量算子多效果好,重量算子多效果一般)
2.编译对算子融合的优化逻辑比较建议加以研究,同时进一步完善算子融合相关的理解
何种算子应当被优化,何种不应当被优化
何种算子理应被优化但并没有被编译优化,这些都是值得研究的
3.编译优化实际上还有很多模式,同时也不止包裹这一种方案,都可以研究、实验,总结影响因素