jupyter notebook加載DDP預訓練模型
最近遇到了一個問題,模型是用DistributedDataParallel一機多卡分布式訓練的,然后作為一個jupyter notebook重度用戶,我想用它來加載這個模型,搞點預測的例子可視化看看。
但是這會碰到一個問題,我們都知道通常加載預訓練模型的方法是:
pretrained_dict = torch.load(pretrained_path, map_location=device)
model.load_state_dict(pretrained_dict,strict=True)
但是要想load_state_dict在DDP下訓練的模型參數,首先初始化的模型model也需要在DDP下初始化,而我嘗試了很久,發現沒法在jupyter上初始化分布式環境:
torch.distributed.init_process_group(backend='nccl', init_method='env://')
然后想了一個解決辦法,查看一下torch.load之后的pretrained_dict字典參數,其中有很多項內容,可以看下我保存模型的時候:
save_checkpoint({ 'epoch': epoch + 1, 'arch': args.arch, 'state_dict': model.state_dict(), 'best_score': best_prec1, 'optimizer': optimizer.state_dict(), }, is_best, work_dir = args.work_dir)
所以想要加載模型參數,首先就要取出'state_dict',然后看下model.state_dict()里的數據結構,發現參數變量名是套在module.model下的,而我們初始化的模型結構model, 其model.state_dict()參數變量是直接model.XX的,所以就把預訓練模型的參數變量名過濾掉moduel,然后用初始化模型model去load_state_dict它就好了;
整體代碼:
model = XXX #初始化模型結構 pretrained_path = 'XXX' device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') pretrained_dict = torch.load(pretrained_path, map_location=device)['state_dict'] pretrained_dict = {k[7:]:v for k,v in pretrained_dict.items()} #k[X:]看情況調整 model.load_state_dict(pretrained_dict,strict=True)
注意,load_state_dict里的參數strict還是需要True來嚴格對齊,如果False的話,預訓練的模型參數就會不嚴格加載,導致后續性能出現偏差。
人生苦短,何不用python

浙公網安備 33010602011771號