背景
需要對3B模型進行蒸餾,一張4090的卡無法完成實驗。完成這個實驗的前提是需要兩張卡,一張用來加載學生模型,一張用來加載教師模型。
多卡使用
這里的多卡使用并不是像以往的方式,使用dataloaderparallel等方式,這種是數據并行的策略,不適合蒸餾的場景,因為蒸餾是一個模型做推理,一個模型做訓練,并非數據并行計算。因此分開加載模型,一個用來訓練,一個用來推理,訓練的數據和訓練卡放在同一個設備上即可。
device_stu = "cuda:0"
device_teh = "cuda:1"
# 模型加載
student_model.to(device_stu)
teacher_model.to(device_teh)
student_model.train()
teacher_model.eval()
# 數據加載
for batch_stu in dataloader(text):
batch_teh = copy.deepcopy(batch_stu)
batch_stu.to(device_stu)
batch_teh.to(device_teh)
logits_stu = student_model(**batch_stu)
logits_teh = teacher_molde(**batch_teh)
loss = kl(logits_stu, logits_teh, device_stu)
loss.backend()
代碼分析:
- 學生模型和教師模型分開加載
- 數據需要深度拷貝,否會出現設備不一致的錯誤
- 把logits放在相同的設備,并計算損失
- 反向傳播
浙公網安備 33010602011771號