這次把GRPO訓(xùn)練推理模型的上下文變長(zhǎng)10倍,同時(shí)需要的顯存少了90%。 使用最新的Unsloth,只要5GB顯存就能訓(xùn)練自己的推理模型,而且Qwen2.5-1.5B不會(huì)損失準(zhǔn)確率。 5GB顯存什么概念呢? 16年開(kāi)始發(fā)售的GPU比如GTX 1060的顯存都有8GB。16年GTX 1060放到現(xiàn)在,堪稱電子古董! 目前,實(shí)現(xiàn)更長(zhǎng)的上下文是GRPO面臨的最大挑戰(zhàn)之一。 與其他GRPO LoRA/QLoRA實(shí)現(xiàn)相比,即使是基于Flash Attention 2(FA2)的實(shí)現(xiàn),Unsloth新推出的高效GRPO算法上下文長(zhǎng)度增加了10倍,同時(shí)使用的VRAM只要10%。 在配備TRL+FA2的GRPO設(shè)置中,Llama 3.1(8B)在20K上下文長(zhǎng)度下,訓(xùn)練需要510.8GB的VRAM。 而Unsloth將VRAM減少了90%,降至僅54.3GB。 減少長(zhǎng)上下文90%VRAM 和使用Flash Attention 2的標(biāo)準(zhǔn)實(shí)現(xiàn)相比,Unsloth使用多種技巧,巧妙地把GRPO的VRAM使用量減少了90%多! 在20K的上下文長(zhǎng)度下,每個(gè)提示生成8次,Unsloth在Llama-3.1-8B模型上僅使用54.3GB的VRAM,而標(biāo)準(zhǔn)實(shí)現(xiàn)需要510.8GB(Unsloth減少了90%)。這一切得益于下列3項(xiàng)突破: 全新設(shè)計(jì)的內(nèi)存高效線性算法:將GRPO的內(nèi)存使用量削減了8倍以上,節(jié)省了68.5GB的內(nèi)存。借助torch.compile,在num_generations=8和20K上下文長(zhǎng)度下,實(shí)際上還更快。 利用了Unsloth已發(fā)布的智能梯度checkpoint算法:將中間激活值異步卸載到系統(tǒng)RAM中,速度僅慢了1%。由于需要num_generations=8,這節(jié)省了高達(dá)372GB的VRAM。通過(guò)中間梯度累積,甚至可以進(jìn)一步減少內(nèi)存使用。 與底層推理引擎(vLLM)共享相同的GPU/CUDA內(nèi)存空間,不像其他包中的實(shí)現(xiàn)那樣。這又節(jié)省了16GB的VRAM。 Unsloth和基于Flash Attention 2(FA2)的標(biāo)準(zhǔn)實(shí)現(xiàn)內(nèi)存比較 在典型的GRPO標(biāo)準(zhǔn)實(shí)現(xiàn)中,需要?jiǎng)?chuàng)建兩個(gè)大小為(8,20K)的logits來(lái)計(jì)算GRPO損失。這需要2*2字節(jié)*8(生成次數(shù))*20K(上下文長(zhǎng)度)*128256(詞匯表大。=78.3GB的VRAM。 Unsloth將長(zhǎng)上下文GRPO的內(nèi)存使用量削減了8倍,因此對(duì)于20K的上下文長(zhǎng)度,只需要額外的9.8GBVRAM! 還需要以16位格式存儲(chǔ)KV緩存。Llama3.18B有32層,K和V的大小均為1024。因此,對(duì)于20K的上下文長(zhǎng)度,內(nèi)存使用量=2*2字節(jié)*32層*20K上下文長(zhǎng)度*1024=每個(gè)批次2.5GB。 可以將vLLM的批次大小設(shè)置為8,但為了節(jié)省VRAM,在計(jì)算中將其保持為1。否則,需要20GB來(lái)存儲(chǔ)KV緩存。 數(shù)學(xué)原理 分組相對(duì)策略優(yōu)化(Group Relative Policy Optimization,GRPO),出自DeepSeek去年發(fā)表的論文。 如果一生只能讀一篇DeepSeek的論文,網(wǎng)友建議選擇首次提出GRPO的DeepSeekMath論文。 論文鏈接:https://arxiv.org/abs/2402.03300 隨后在DeepSeek的論文中,利用GRPO算法創(chuàng)建了DeepSeek-R1。 發(fā)現(xiàn)的問(wèn)題在這里利用了Hugging Face的TRL GRPO實(shí)現(xiàn)。 注意到,TRL實(shí)現(xiàn)的公式如下: 其中使用的是反向KL散度(而不是正向KL散度)。β是一個(gè)設(shè)為0.04的縮放因子,A是考慮所有獎(jiǎng)勵(lì)函數(shù)后得到的優(yōu)勢(shì)值。q是新訓(xùn)練的模型,P是原始參考模型。 然后注意到,該實(shí)現(xiàn)將反向KL散度計(jì)算為: 但這真的是正確的嗎? 首先嘗試推導(dǎo)并整理類似項(xiàng): 這意味著什么?實(shí)現(xiàn)中可能缺少一個(gè)與q(新分布項(xiàng))的乘法嗎? 但這似乎是正確的,和DeepSeek-Math論文第14頁(yè)首次引入GRPO時(shí)一樣。 DeepSeek-Math論文第14頁(yè):在損失函數(shù)中添加KL散度,正則化GRPO算法 同樣,John Schulman的博客也提到,反向KL項(xiàng)的無(wú)偏估計(jì),實(shí)際上并不需要額外的q項(xiàng)。 鏈接地址:http://joschu.net/blog/kl-approx.html 在博客中看到: 還發(fā)現(xiàn)了一個(gè)有趣的現(xiàn)象:
這應(yīng)該等于1,對(duì)嗎? Hugging Face的TRL GRPO實(shí)現(xiàn) 實(shí)際上,發(fā)現(xiàn)這是必要的——似乎自動(dòng)梯度autograd引擎可能無(wú)法正確傳播梯度。 因此,進(jìn)行了4個(gè)實(shí)驗(yàn): 使用參考實(shí)現(xiàn)的常規(guī)GRPO(紅線) 移除detach代碼(藍(lán)線) 按照之前討論的完整反向KL,添加額外項(xiàng)(黃線) 使用正向KL散度代替(綠線) 總體來(lái)說(shuō),移除detach顯然會(huì)破壞訓(xùn)練,所以必須保留它——這很可能需要進(jìn)一步調(diào)查。其他實(shí)現(xiàn)似乎也類似?可能需要運(yùn)行模型更長(zhǎng)時(shí)間,以觀察不同的效果。 在所有實(shí)現(xiàn)中,還利用了logsumexp技巧: Unsloth高效GRPO算法 但沒(méi)想到華人工程師Horace He的線性交叉熵實(shí)現(xiàn),帶給unsloth靈感并成功應(yīng)用于GRPO!Horace He,在Meta從事PyTorch相關(guān)工作 實(shí)際上,unsloth發(fā)現(xiàn)了一些令人驚訝的要點(diǎn): 1 GRPO參考實(shí)現(xiàn)使用的是反向KL散度,而不是正向KL散度。 2 如果不正確處理,在float16混合精度(以及float8)上直接實(shí)現(xiàn)線性交叉熵,并使用自動(dòng)混合精度縮放機(jī)制,會(huì)導(dǎo)致崩潰。 3 發(fā)現(xiàn)了GRPO損失實(shí)現(xiàn)中的其他一些奇怪之處,主要是在反向KL散度的公式表述方面。 線性交叉商鏈接:https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899 本文來(lái)源:新智元 |
原創(chuàng)欄目
IT百科
網(wǎng)友評(píng)論
聚超值•精選