LLM采樣后處理總結:LLM的后處理的cpp實現(xiàn)
LLM采樣后處理總結:LLM的后處理的cpp實現(xiàn)
在經(jīng)過LLM的lm_head之后,會得到[batch, vocab_size]大小的矩陣向量,此時需要對輸出的邏輯張量進行采樣,除了beam_search的貪心策略,還有repetition_penalty、temperature、top_k、top_p等幾種控制采樣的方法。
repetition_penalty
repetition_penalty的主要作用是控制重復,這里first和last分別為vocab中的第一個元素和最后一個元素的位置,input_ids為之前輸出的文本id。
也即是把之前輸出過的內容全部變小,那么就可以防止文本出現(xiàn)不斷重復的情況,penalty越小,懲罰力度越大,penalty越大,懲罰力度越小,重復概率就會增加。
void sampling_repetition_penalty(float *first, float *last, const std::vector<int> &input_ids,
float penalty) {
std::unordered_set<int> unique_input_ids(input_ids.begin(), input_ids.end());
for (int id : unique_input_ids) {
if (first[id] > 0) {
first[id] /= penalty;
} else {
first[id] *= penalty;
}
}
}
temperature
temperature是控制softmax下的平滑參數(shù),相當于在softmax前每個邏輯值都進行了放縮。
當temp越大的時候,此時softmax值之間的差距會減小,分布就越均勻,此時采樣出的結果就越隨機,反之就會使得原本高概率的的變得更高低的更低減少了隨機性。
void sampling_temperature(float *first, float *last, float temp) {
float inv_temp = 1.f / temp;
for (float *it = first; it != last; it++) {
*it *= inv_temp;
}
}
top_k
top_k是取前k個,直接排序拿到概率最大的前k個。
void sampling_top_k(TokenIdScore *first, TokenIdScore *kth, TokenIdScore *last) {
std::nth_element(first, kth, last, std::greater<TokenIdScore>());
}
top_p
top_p是先對所有的值進行softmax,然后找到滿足sum_p <= top_p的最小集合,然后對這個集合內的數(shù)再進行softmax和采樣。
一種簡單的做法是將所有值進行排序,然后貪心找到滿足條件的前k個。
示例代碼中使用了一種類似于快速排序的方法,每次找mid點,將大于mid和小于mid的分為兩堆,要么在大的一堆要么在小的一堆。
當在大的一堆中時就mid往前移動,在小的一堆時則更新top_p = top_p-sum_p,直至找到對應的位置。
時間復雜度上會稍微比先排序快一些。
void sampling_softmax_inplace(TokenIdScore *first, TokenIdScore *last) {
float max_score = std::max_element(first, last)->score;
float sum = 0.f;
for (TokenIdScore *p = first; p != last; p++) {
float s = std::exp(p->score - max_score);
p->score = s;
sum += s;
}
float inv_sum = 1.f / sum;
for (TokenIdScore *p = first; p != last; p++) {
p->score *= inv_sum;
}
}
TokenIdScore *sampling_top_p(TokenIdScore *first, TokenIdScore *last, float top_p) {
// fast top_p in expected O(n) time complexity
sampling_softmax_inplace(first, last);
while (first + 1 < last) {
float pivot_score = (last - 1)->score; // use mid score?
TokenIdScore *mid =
std::partition(first, last - 1, [pivot_score](const TokenIdScore &x) { return x.score > pivot_score; });
std::swap(*mid, *(last - 1));
float prefix_sum =
std::accumulate(first, mid, 0.f, [](float sum, const TokenIdScore &x) { return sum + x.score; });
if (prefix_sum >= top_p) {
last = mid;
} else if (prefix_sum + mid->score < top_p) {
first = mid + 1;
top_p -= prefix_sum + mid->score;
} else {
return mid + 1;
}
}
return last;
}

浙公網(wǎng)安備 33010602011771號