學(xué)習(xí)Sci. Adv. 關(guān)于AMP_generator文章-復(fù)現(xiàn)
- 環(huán)境配置:在Anaconda Prompt中創(chuàng)建虛擬環(huán)境“AMPdesign”,python=3.9。本人新手,根據(jù)大佬帖子安裝pytorch:“https://zhuanlan.zhihu.com/p/1897261918172987396”,本次安裝的是2.5.1版本,conda途徑安裝,安裝代碼:
conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=11.8 -c pytorch -c nvidia
![image]()
- pycharm中打開項目,配置剛剛新建的虛擬環(huán)境:文件→設(shè)置→項目→Python解釋器→添加解釋器→添加本地解釋器→Conda環(huán)境:使用現(xiàn)有環(huán)境→選擇AMPdesign
![image]()
- 先閱讀README,再開始代碼復(fù)現(xiàn)
3.1 首先運行train_AMP_GPT.py文件。這個文件有幾個小錯誤需要改,按照要求改就好了,比如文件名不統(tǒng)一,np.Inf需改為np.inf等。訓(xùn)練中的狀態(tài):
本地太慢確實不適合跑完整的訓(xùn)練,修改參數(shù)epoch=2,batch_size=4, warmup_steps=10進行測試。
3.2
本地先測試。首先將data/prompt_data里面的兩個帶標(biāo)簽數(shù)據(jù)集合并,并加上3個表頭:comment_text,id,label。
第二至關(guān)重要。修改測試參數(shù),將batch_size調(diào)為8,epoch調(diào)為1,--log_step調(diào)為10.不然本地跑不動,會顯存報錯。
第三,把Save model checkpoint注釋掉。不然硬盤存儲真的扛不動?。?!
注釋完后貌似有格式問題,新代碼如下:

第四,把17行的from pytorchtools import EarlyStopping修改為from early_stop.pytorchtools import EarlyStopping。
還有一些細節(jié)的改動,可以直接debug。訓(xùn)練結(jié)果:

但是好像沒有驗證集?看不出來有沒有過擬合。這樣的話最好把帶標(biāo)簽的數(shù)據(jù)給劃分一下。
3.3
這一步我們根據(jù)步驟先把模型蒸餾一下。
持續(xù)更新中...
歡迎指正交流,共同學(xué)習(xí)進步。郵箱:z1437143688@126.com


浙公網(wǎng)安備 33010602011771號