[LLM] LLM后量化(PTQ)總結及原理實現
LLM后量化(PTQ)總結及原理實現
weight only
per_channel:按照每個channel的方式,計算得到scale和zero參數,通過weight = weight * scale + zero的方式進行還原。
per_channel_group_wise:按照每個channel的方式,在per_channel的基礎上產生一個scale,再增加了group_wise, 即每個channel內部再進行一次group的scale和zero,相當于更細粒度的量化.
TensorRT-LLM中的量化的gemm實現
以下是加載half 權重和反量化的代碼,在TensorRT中,兩個half在一個32bit中存儲,形成half2數據類型,以便于混合計算。
for (int idx = 0; idx < NPerBlock; ++idx)
{
for (int i = 0; i < Details::kShuffleContinous; ++i)
{
for (int j = 0; j < Details::kShuffleStrided; ++j)
{
// Dequantize the weights and arrange the shuffled elements back to the correct order in the
// register array
half2 v = *reinterpret_cast<half2*>(weights_vec + i * Details::kShuffleBasicTile
+ j * Details::kShuffleContinous * Details::kShuffleBasicTile);
v = __hfma2(v, __half2half2(scale[idx]), __half2half2(zero[idx]));
weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile
+ j * Details::kShuffleBasicTile + 0)
* NPerBlock
+ idx]
= v.x;
weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile
+ j * Details::kShuffleBasicTile + 1)
* NPerBlock
+ idx]
= v.y;
}
}
}
其中__hfma2是half2下的乘加運算,那么v = v * scale + zero;只不過這里同時計算了兩個half。
v = __hfma2(v, __half2half2(scale[idx]), __half2half2(zero[idx]));
下面這段代碼是計算 v = w * v + v 的過程,也即是將v累加到每一個batch結果中的過程,當每個block計算N個時可直接累加,否則則需要拆開計算。有所不同的是,使用的是half2數據類型進行計算。
這里accumulator是half類型的全局內存。
half accumulator[Num];
for (int b = 0; b < Batch; ++b)
{
half in_v[Details::kElemsPerThread];
// Perform vector inner product and accumulate
if constexpr (NPerBlock == 1)
{
half2 v = __float2half2_rn(0.f);
for (int y = 0; y < Details::kElemsPerThread; y += 2)
{
v = __hfma2(*reinterpret_cast<half2*>(weights_f16 + y), *reinterpret_cast<half2*>(in_v + y), v);
}
accumulator[b] += __hadd(v.x, v.y);
}
else
{
for (int x = 0; x < NPerBlock / 2; ++x)
{
for (int y = 0; y < Details::kElemsPerThread; ++y)
{
*reinterpret_cast<half2*>(accumulator + b * NPerBlock + x * 2)
= __hfma2(*reinterpret_cast<half2*>(weights_f16 + y * NPerBlock + x * 2),
__half2half2(in_v[y]), *reinterpret_cast<half2*>(accumulator + b * NPerBlock + x * 2));
}
}
}
}
float reses[Num];
for (int i = 0; i < Num; ++i)
{
reses[i] = __half2float(accumulator[i]);
}
// Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the
// corresponding address in shared memory
Details::Layout::sync<Num, WarpSize>(reses, sm);
// Each thread is responsible for the accumulation and store to global memory of one element
for (int i = tid; i < Num * Interleave; i += BlockSize)
{
int nid = i % (NPerBlock * Interleave);
float v = 0.f;
for (int j = 0; j < BlockSize / WarpSize; ++j)
{
v += sm[j][i];
}
float bias_v = 0.f;
if constexpr (Bias)
{
bias_v = __half2float(bias[n_start_id + nid]);
}
int b = i / NPerBlock / Interleave;
out[b * n + n_start_id + nid] = __float2half_rn(ActOp<float>::apply(v + bias_v));
}
smooth_quant
主要基于以下幾點觀察:
- 激活值比權重更加難以量化
- 異常值的存在讓激活值更加難以量化
- 異常值通常出現在固定的channel中
所以smooth_quant的做法是將激活值和weight同時放縮一定倍數,這樣異常激活值就可以被平滑,進而使得激活的量化不那么困難。
因為放縮本身就是乘加操作,所以可以將attn中量化操作融合到前一步中的RMSNORM操作中,節省量化開銷。
按照粒度的不同,可以分為per-channel、per-token、per-tensor等幾種不同的粒度。
其計算公式如下:
act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})
# 統計的是每一層中x,y,w的最大值
weight_scales = max_abs_value in per channel
scale = activation^alpha / (weights ^ (1-alpha))
gemm *= scale
rmsnorm_weights /= scale
計算出scale保存,在推理階段,因為rms_norm已經除了scale,所以X不需要額外的操作,使用普通的gemm計算即可。
GPTQ
GPTQ: ACCURATE POST-TRAINING QUANTIZATION FOR GENERATIVE PR E-TRAINED TRANSFORMERS
GPTQ 將權重分組(如:128列為一組)為多個子矩陣(block)。對某個 block 內的所有參數逐個量化,每個參數量化后,需要適當調整這個 block 內其他未量化的參數,以彌補量化造成的精度損失。因此,GPTQ 量化需要準備校準數據集。
使用 Cholesky 分解中 Hessian 矩陣的逆,在給定的step中對連續列的塊進行量化,并在step結束時更新剩余的權重。
取消貪心算法:OBS 采用貪心策略,先量化對目標影響最小的參數;但 GPTQ 發現直接按順序做參數量化,對精度影響也不大。這項改進使得參數矩陣每一行的量化可以做并行的矩陣計算。
Lazy Batch-Updates:OBQ 對權重一個個進行單獨更新,作者發現性能瓶頸實際在于GPU的內存帶寬,而且同一個特征矩陣W不同列間的權重更新是不會互相影響的。因此作者提出了延遲批處理的方法,通過延遲一部分參數的更新,一次處理多個(如:128)列,來緩解帶寬的壓力,大幅提升了計算速度。
Cholesky 分解:用 Cholesky 分解求海森矩陣的逆,提前計算好所有需要的信息,在增強數值穩定性的同時,后續更新的過程中再計算,進一步減少了計算量。
columns = w.shape[1]
H = torch.zeros((self.columns, self.columns), device=self.dev)
# 計算海森矩陣的逆向
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=self.dev)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
# 利用海森矩陣的逆依次計算誤差,并更新后續參數
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]
for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]
# 量化當前的weight
q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d**2
err1 = (w - q) / d
# 在同一block內, 更新后續的weight
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
Err1[:, i] = err1
Q[:, i1:i2] = Q1
Losses[:, i1:i2] = Losses1 / 2
# 在不同blcok間更新后續weight
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
上述代碼實現是Auto_GPTQ中的實現,代碼可分為兩個部分,第一部分使用cholesky分解方法求得海森矩陣的逆;第二部分則根據scale和zero對w進行量化,依次得到err和loss,并據此依次更新后續的矩陣。注意這里我把group_size部分縮減了。
其計算公式為\(q = quantize(w.unsqueeze(1)); err1 = (w - q) / d; W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))\),這里d為海森逆矩陣對角線上的值,需要在同一block內和不同block間順序更新后續權重參數。
quantizer中為普通的量化實現,quantize如下所示,會將原有的浮點數scale到對應的int8或in4區間。實現如下:
# 對稱量化操作
scale = (xmax - xmin) / self.maxq
zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
# 非對稱量化操作
zero = torch.round(-xmin / self.scale)
def quantize(x, scale, zero, maxq):
if maxq < 0:
return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
return scale * (q - zero)
AWQ
權重對于LLM的性能并不同等重要,但要找到顯著的權重通道,我們應該根據激活分布而不是權重分布,AWQ可以看作是smooth_quant的改進版。
自動搜索最優縮放,使全部權重下的量化誤差最小,采用grid_search的方法對scale進行搜索,以保證最終Loss的損失值最小。
只測量每個通道的平均幅度誤差來確定每個通道權重的重要性。
autoawq中的代碼實現如下所示,所有過程分為4步:
- 計算每個channel weight的最大值
- 計算x的最大值
- 計算當前module的fp16輸出
- 計算更新最大的scale
# [STEP 1]: Compute maximum of weight
weight = torch.cat([_m.weight for _m in layers], dim=0)
org_shape = weight.shape
weight = weight.view(-1, self.group_size)
w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
w_scale = w_scale.view(org_shape)
w_max = w_scale.mean(0)
clear_memory(weight)
# [STEP 2]: Compute maximum of x
x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)
# [STEP 3]: Compute output of module
with torch.no_grad():
fp16_output = module2inspect(inp, **kwargs)
if isinstance(fp16_output, tuple):
fp16_output = fp16_output[0]
# [STEP 4]: Compute loss
best_scales = self._compute_best_scale(
inp, w_max, x_max, module2inspect,
layers, fp16_output, kwargs
)
具體的網格搜索算法如下所示,將按照網格對[0, 1]內的數進行遍歷scale,計算出使得L2(fp16 - quant_v)的最佳scale。
- 網格遍歷scale,這里n_grid=20
- 按照smooth_quant中的方式計算scale,這里多了一步平滑
- weight *= scale;將偽量化算子插入到模型中
- 計算fp16和quant_v之間的誤差,保留誤差最小時的scale
for ratio in range(n_grid):
# create new scales
ratio = ratio / n_grid
# NOTE: s^-1 * x is fused here, according to paper
scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4)
scales = scales / (scales.max() * scales.min()).sqrt()
scales_view = scales.view(1, -1).to(device)
# Q(W * s)
for fc in linears2scale:
fc.weight.mul_(scales_view)
fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view
# W * X
int_w_output = module2inspect(x, **kwargs)
if isinstance(int_w_output, tuple):
int_w_output = int_w_output[0]
# compute mean squared error (L2 norm)
loss = (fp16_output - int_w_output).float().pow(2).mean().item() # NOTE: float prevents overflow
history.append(loss)
if loss < best_error:
best_error = loss
best_ratio = ratio
best_scales = scales.clone()

浙公網安備 33010602011771號