ControlNet-trt Optimization Notes 2: Building the ControlNet Network from Scratch with the TensorRT API
As mentioned in the previous post, the ControlNet network can be rebuilt by constructing the TensorRT network by hand. This avoids the precision loss introduced by the intermediate ONNX conversion, avoids operators being shredded into many small pieces during ONNX export, and operators that TensorRT does not support can still be handled by adding plugins.
Basic Concepts
tensorrt.INetworkDefinition: the network definition object; it can be produced by a parser or built layer by layer with the TensorRT API
tensorrt.Builder: builds a CudaEngine from a NetworkDefinition and the corresponding BuilderConfig; the CudaEngine is the compiled binary computation graph
tensorrt.IExecutionContext: created from a CudaEngine; one CudaEngine can create multiple ExecutionContexts
Notes:
- network below always refers to a tensorrt.INetworkDefinition object.
- x comes in two forms: a tensorrt.ITensor, usually the very first input, or a tensorrt.ILayer, usually an intermediate layer's input. A tensorrt.ITensor can be viewed as an edge of the computation graph, a tensorrt.ILayer as a node.
- Every operator wrapper takes weight_map and the corresponding weight-name prefix, and returns a tensorrt.ILayer object.
A minimal sketch of how the three objects above fit together is shown right after these notes.
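Below is a minimal sketch of that relationship (the input shape and the identity layer are placeholders; a real network is populated with the operator wrappers described in the rest of this post):

import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
# explicit-batch network definition: the object all the add_* wrappers below operate on
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
# a trivial graph, just to show the object relationships: one input through an identity layer
x = network.add_input('x', trt.float32, (2, 4, 32, 48))
layer = network.add_identity(x)
network.mark_output(layer.get_output(0))
config = builder.create_builder_config()
engine_bytes = builder.build_serialized_network(network, config)   # build the CudaEngine (serialized)
engine = trt.Runtime(logger).deserialize_cuda_engine(engine_bytes)
context = engine.create_execution_context()                        # one engine can create several contexts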
Commonly Used TRT API Functions
add_input(self: tensorrt.tensorrt.INetworkDefinition,
name: str,
dtype: tensorrt.tensorrt.DataType,
shape: tensorrt.tensorrt.Dims) -> tensorrt.tensorrt.ITensor
Purpose: add an input tensor to the network
Parameters: name - name of the input
dtype - data type of the tensor, e.g. trt.float32
shape - shape of the tensor; it must have fewer than 2^30 elements
Returns: a new ITensor
add_scale(self: tensorrt.tensorrt.INetworkDefinition,
input: tensorrt.tensorrt.ITensor,
mode: tensorrt.tensorrt.ScaleMode,
shift: tensorrt.tensorrt.Weights ,
scale: tensorrt.tensorrt.Weights ,
power: tensorrt.tensorrt.Weights) -> tensorrt.tensorrt.IScaleLayer
Purpose: element-wise scaling; the computation is $output=(input*scale+shift)^{power}$
Parameters: input - input tensor, with at least 3 dimensions
mode - scaling mode, e.g. trt.ScaleMode.UNIFORM, which applies the same coefficients to every element
shift - Weights object, the shift term in the formula
scale - Weights object, the scale term in the formula
power - Weights object, the power term in the formula
When a Weights argument is supplied, its shape depends on the mode:
UNIFORM: a single value
CHANNEL: one value per channel (the shape equals the channel dimension)
ELEMENTWISE: the same shape as the input
Returns: a new layer, or None on failure
A small UNIFORM-mode example follows this entry.
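As an illustration (assuming network is the INetworkDefinition, x is an ITensor with at least 3 dimensions and numpy is imported as np), a UNIFORM-mode scale computing y = 0.5*x + 0.5, the same scale-and-shift pattern the GELU code later in this post uses:

layer = network.add_scale(x, mode=trt.ScaleMode.UNIFORM,
                          shift=trt.Weights(np.array([0.5], np.float32)),
                          scale=trt.Weights(np.array([0.5], np.float32)))
# power is omitted here, as in the code later in this post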
add_slice(self: tensorrt.tensorrt.INetworkDefinition,
input: tensorrt.tensorrt.ITensor,
start: tensorrt.tensorrt.Dims,
shape: tensorrt.tensorrt.Dims,
stride: tensorrt.tensorrt.Dims) -> tensorrt.tensorrt.ISliceLayer
Purpose: slice a tensor
Parameters: input - input tensor
start - start indices
shape - output shape
stride - slice strides
Returns: a new layer, or None on failure
The returned ISliceLayer also exposes a mode attribute, illustrated right below.
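The mode attribute does not appear in the signature above but is used later in this post: trt.SliceMode.WRAP wraps out-of-range reads around the input, which the input_block code presumably relies on to tile a smaller batch up to the requested batch size. A minimal illustration (x is assumed to be an ITensor of shape [1, 320, 32, 48]):

s = network.add_slice(x, trt.Dims([0, 0, 0, 0]), trt.Dims([2, 320, 32, 48]), trt.Dims([1, 1, 1, 1]))
s.mode = trt.SliceMode.WRAP  # out-of-range indices wrap around, duplicating the batch dimension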
add_constant(self: tensorrt.tensorrt.INetworkDefinition,
shape: tensorrt.tensorrt.Dims,
weights: tensorrt.tensorrt.Weights) → tensorrt.tensorrt.IConstantLayer
Purpose: add a constant layer, which turns a Weights object into a layer and hence into a tensor
Parameters: shape - shape of the constant
weights - the Weights object holding the constant data
Returns: a new layer, or None on failure
add_elementwise(self: tensorrt.tensorrt.INetworkDefinition,
input1: tensorrt.tensorrt.ITensor,
input2: tensorrt.tensorrt.ITensor,
op: tensorrt.tensorrt.ElementWiseOperation) → tensorrt.tensorrt.IElementWiseLayer
Purpose: element-wise binary operation
Parameters: input1 (input2) - input tensors; they must have the same number of dimensions, and size-1 dimensions are broadcast
op - the binary operator, one of ElementWiseOperation, e.g.:
trt.ElementWiseOperation.PROD (multiplication)
trt.ElementWiseOperation.SUM (addition)
Returns: a new layer, or None on failure
add_unary(self: tensorrt.tensorrt.INetworkDefinition,
input: tensorrt.tensorrt.ITensor,
op: tensorrt.tensorrt.UnaryOperation) → tensorrt.tensorrt.IUnaryLayer
Purpose: element-wise unary operation
Parameters: input - input tensor
op - the unary operator, one of UnaryOperation, e.g.:
trt.UnaryOperation.EXP (natural exponential)
trt.UnaryOperation.LOG (natural logarithm)
Returns: a new layer, or None on failure
add_convolution(self: tensorrt.tensorrt.INetworkDefinition,
input: tensorrt.tensorrt.ITensor,
num_output_maps: int,
kernel_shape: tensorrt.tensorrt.DimsHW,
kernel: tensorrt.tensorrt.Weights,
bias: tensorrt.tensorrt.Weights = None)→ tensorrt.tensorrt.IConvolutionLayer
Purpose: add a 2D convolution
Parameters: input - input tensor, 4-dimensional
num_output_maps - number of output feature maps, i.e. the channel count of the following layer
kernel_shape - convolution kernel size
kernel - kernel weights
bias - bias weights
Returns: a new layer, or None on failure
add_activation(self: tensorrt.tensorrt.INetworkDefinition,
input: tensorrt.tensorrt.ITensor,
type: tensorrt.tensorrt.ActivationType) → tensorrt.tensorrt.IActivationLayer
Purpose: add an activation layer that applies an element-wise activation; the output shape equals the input shape
Parameters: input – input tensor
type – the activation type, e.g. RELU, SIGMOID, TANH, LEAKY_RELU; see tensorrt.ActivationType
Returns: a new layer, or None on failure
add_normalization(self: tensorrt.tensorrt.INetworkDefinition,
input: tensorrt.tensorrt.ITensor,
scale: tensorrt.tensorrt.ITensor,
bias: tensorrt.tensorrt.ITensor,
axesMask: int)→ tensorrt.tensorrt.INormalizationLayer
Purpose: add a normalization layer computing $Y = (X - Mean(X, axes)) / Sqrt(Variance(X, axes) + epsilon) * S + B$; internally TensorRT implements it via InstanceNorm, so in some cases it has to be replaced with a hand-written version
Parameters: input – input tensor
scale – the scale parameter S of the normalization
bias – the bias parameter B of the normalization
axesMask – the axes to reduce over, passed as a bit mask where bit i (1<<i) selects dimension i
Returns: a new layer, or None on failure
A short example of building the mask follows this entry.
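For reference, the mask is simply a bit pattern over dimension indices; the values below mirror how the layer is used later in this post:

axes_layernorm = 1 << 2                  # reduce over dim 2, e.g. the feature dim of a [batch, seq, channel] tensor
axes_instancenorm = (1 << 2) | (1 << 3)  # reduce over H and W of an NCHW tensor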
add_matrix_multiply(self: tensorrt.tensorrt.INetworkDefinition,
input0: tensorrt.tensorrt.ITensor,
op0: tensorrt.tensorrt.MatrixOperation,
input1: tensorrt.tensorrt.ITensor,
op1: tensorrt.tensorrt.MatrixOperation) → tensorrt.tensorrt.IMatrixMultiplyLayer
Purpose: add a matrix-multiplication layer; four cases are supported: matrix×matrix, matrix×vector, vector×matrix and vector×vector
Parameters: input0 – the first input tensor
op0 – how to treat input0: as-is, transposed, or as a vector
input1 – the second input tensor
op1 – how to treat input1: as-is, transposed, or as a vector
Returns: a new layer, or None on failure
add_shuffle(self: tensorrt.tensorrt.INetworkDefinition,
input: tensorrt.tensorrt.ITensor) → tensorrt.tensorrt.IShuffleLayer
Purpose: add a shuffle layer, which corresponds to the transpose and reshape operators
Parameters: input – the input tensor
Returns: a new layer, or None on failure
add_softmax(self: tensorrt.tensorrt.INetworkDefinition,
input: tensorrt.tensorrt.ITensor) → tensorrt.tensorrt.ISoftMaxLayer
Purpose: add a softmax layer; the softmax is applied along the axes selected by the layer's axes attribute, a bit-packed mask like axesMask above
Parameters: input – input tensor
Returns: a new layer, or None on failure
add_gather(self: tensorrt.tensorrt.INetworkDefinition,
input: tensorrt.tensorrt.ITensor,
indices: tensorrt.tensorrt.ITensor,
axis: int)→ tensorrt.tensorrt.IGatherLayer
Purpose: add a gather layer that picks entries from input along axis at the given indices
Parameters: input – input tensor
indices – index tensor used to produce the output
axis – the gather axis; it cannot be the batch axis
Returns: a new layer, or None on failure
add_einsum
Purpose: add an Einstein-summation layer corresponding to einsum; here it is mainly used for batched matrix multiplication
Parameters: inputs – list of input tensors
equation – the einsum equation string
Returns: a new layer, or None on failure
The snippet below shows the equivalence relied on by the attention code later.
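As a sanity check, the two equations used in the attention code are just batched matrix products; a small self-contained numpy illustration (the shapes are illustrative: 16 = 2 batches × 8 heads, 40 = 320 / 8 per-head dim, and the spatial size is reduced for brevity):

import numpy as np

q = np.random.rand(16, 96, 40)
k = np.random.rand(16, 96, 40)
v = np.random.rand(16, 96, 40)
s = np.einsum('bid,bjd->bij', q, k)    # 'b i d, b j d -> b i j'  ==  q @ k^T per batch
out = np.einsum('bij,bjd->bid', s, v)  # 'b i j, b j d -> b i d'  ==  s @ v per batch
assert np.allclose(s, q @ k.transpose(0, 2, 1))
assert np.allclose(out, s @ v)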
Key TRT Operators
Convolution operator
Since TRT supports convolution natively, the wrapper simply calls add_convolution. Note that conv also accepts the raw first-layer input (an ITensor) as well as a previous layer.
def conv(network, weight_map, x, ch, pre, kernel, padding, stride):
x = network.add_convolution(
input=x if isinstance(x, trt.ITensor) else x.get_output(0),
num_output_maps=ch,
kernel_shape=(kernel, kernel),
kernel=weight_map['{}.weight'.format(pre)],
bias=weight_map['{}.bias'.format(pre)])
assert x
x.padding = (padding, padding)
x.stride = (stride, stride)
return x
Activation operator
The SiLU operator is decomposed into a SIGMOID and an element-wise PROD ($silu(x)=x \cdot sigmoid(x)$), which is essentially the same decomposition the ONNX export produces.
def silu(network, x):
y = network.add_activation(x.get_output(0), trt.ActivationType.SIGMOID)
assert y
x = network.add_elementwise(x.get_output(0), y.get_output(0), trt.ElementWiseOperation.PROD)
return x
Normalization operators
GroupNorm is implemented through a plugin here. Two attributes are defined via PluginField: epsilon (the numerical-stability term) and bSwish (whether to fuse a Swish/SiLU activation after the normalization).
Its inputs are the previous layer's output, the weights and the bias; its output is the group-normalized (and optionally activated) value.
import ctypes
import numpy as np
import tensorrt as trt

EPS = 1e-5  # not defined in the original snippet; 1e-5 is the usual GroupNorm/LayerNorm default

ctypes.CDLL('./trt/libmyplugins.so.1', mode=ctypes.RTLD_GLOBAL)
TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
trt.init_libnvinfer_plugins(TRT_LOGGER, '')
gn_plugin_creator = trt.get_plugin_registry().get_plugin_creator('GroupNorm', "1")
def group_norm(network, weight_map, h, pre, epsilon=EPS, silu=False):
ch = h.get_output(0).shape[1]
# plugin_creator = trt.get_plugin_registry().get_plugin_creator('GroupNorm', "1")
plugin_creator = gn_plugin_creator
s = network.add_constant([1, ch, 1, 1], weight_map['{}.weight'.format(pre)])
b = network.add_constant([1, ch, 1, 1], weight_map['{}.bias'.format(pre)])
eps_attr = trt.PluginField("epsilon", np.array([epsilon], dtype=np.float32), type=trt.PluginFieldType.FLOAT32)
silu_attr = trt.PluginField("bSwish", np.array([1 if silu else 0], dtype=np.int32), type=trt.PluginFieldType.INT32)
field_collection = trt.PluginFieldCollection([eps_attr, silu_attr])
plugin = plugin_creator.create_plugin(name='{}.group_norm'.format(pre), field_collection=field_collection)
n = network.add_plugin_v2(inputs=[h.get_output(0), s.get_output(0), b.get_output(0)], plugin=plugin)
return n
layer_norm performs the following computation:
$Y = (X - Mean(X, axes)) / Sqrt(Variance(X, axes) + epsilon) * S + B$
The result depends on which axes are reduced over. Here axesMask = 1 << 2 selects dimension 2, the last (feature) dimension of the [batch, seq, channel] tensors used in this code; for sequence inputs the first dimension is the batch and the second is the sequence length.
def layer_norm(network, weight_map, h, pre, epsilon=EPS):
scale_np = weight_map['{}.weight'.format(pre)]
ch = scale_np.shape[0]
scale = network.add_constant([1, 1, ch], scale_np)
bias_np = weight_map['{}.bias'.format(pre)]
bias = network.add_constant([1, 1, ch], bias_np)
n = network.add_normalization(
h.get_output(0),
scale=scale.get_output(0),
bias=bias.get_output(0),
axesMask=1 << 2)
assert n
n.epsilon = epsilon
return n
Attention operators
Because TRT does not directly handle the multiply-add pattern on 4-D matrices here, H and W are merged into one dimension. The MHA has 8 heads, and the heads are folded into the batch dimension for the computation, which gives the following reshapes:
[2, h * w, c] -> [2, h * w, 8, d] -> [2, 8, h * w, d] -> [16, h * w, d]
Concretely, q, k and v are each produced by a matrix multiply. There is an optimization opportunity here: the three projections could be fused and computed as a single matmul instead of three separate ones, which parallelizes better; a sketch of that fusion is given right after the self_attention code.
The q·k product is computed with add_einsum; the result after softmax is multiplied with v, and note that the final result has to be reshaped back to [2, h * w, c].
The remaining part is the to_out output projection; the residual addition itself happens in the surrounding basic_transformer.
def self_attention(network, weight_map, i, ch, x):
heads = 8
dim_head = ch / heads
scale = dim_head ** -0.5
wq = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_q.weight'.format(i)])
wk = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_k.weight'.format(i)])
wv = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_v.weight'.format(i)])
q = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
wq.get_output(0), trt.MatrixOperation.TRANSPOSE)
k = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
wk.get_output(0), trt.MatrixOperation.TRANSPOSE)
v = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
wv.get_output(0), trt.MatrixOperation.TRANSPOSE)
# q [2, h * w, c] -> [2, h * w, 8, d] -> [2, 8, h * w, d] -> [16, h * w, d]
q = network.add_shuffle(q.get_output(0))
q.reshape_dims = (2, -1, 8, ch // 8)
q.second_transpose = trt.Permutation([0, 2, 1, 3])
q = network.add_shuffle(q.get_output(0))
q.reshape_dims = (16, -1, ch // 8)
k = network.add_shuffle(k.get_output(0))
k.reshape_dims = (2, -1, 8, ch // 8)
k.second_transpose = trt.Permutation([0, 2, 1, 3])
k = network.add_shuffle(k.get_output(0))
k.reshape_dims = (16, -1, ch // 8)
v = network.add_shuffle(v.get_output(0))
v.reshape_dims = (2, -1, 8, ch // 8)
v.second_transpose = trt.Permutation([0, 2, 1, 3])
v = network.add_shuffle(v.get_output(0))
v.reshape_dims = (16, -1, ch // 8)
s = network.add_einsum([q.get_output(0), k.get_output(0)], 'b i d, b j d -> b i j')
print(s.get_output(0).shape)
s = network.add_scale(s.get_output(0), mode=trt.ScaleMode.UNIFORM,
scale=trt.Weights(np.array([scale], np.float32)))
s = network.add_softmax(s.get_output(0))
s.axes = 1<<2
out = network.add_einsum([s.get_output(0), v.get_output(0)], 'b i j, b j d -> b i d')
# [16, h * w, d] -> [2, 8, h * w, d] -> [2, h * w, 8, d] -> [2, h * w, c]
out = network.add_shuffle(out.get_output(0))
out.reshape_dims = (2, 8, -1, ch // 8)
out.second_transpose = trt.Permutation([0, 2, 1, 3])
out = network.add_shuffle(out.get_output(0))
out.reshape_dims = (2, -1, ch)
# to_out
outw = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_out.0.weight'.format(i)])
outb = network.add_constant((1, 1, ch), weight_map['{}.transformer_blocks.0.attn1.to_out.0.bias'.format(i)])
out = network.add_matrix_multiply(out.get_output(0), trt.MatrixOperation.NONE,
outw.get_output(0), trt.MatrixOperation.TRANSPOSE)
out = network.add_elementwise(out.get_output(0), outb.get_output(0), trt.ElementWiseOperation.SUM)
return out
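As noted above, the three q/k/v projections could be fused into one matmul. The following is only a sketch of that idea, not part of the original code: it assumes weight_map stores numpy arrays, concatenates the three weight matrices offline into a single (3*ch, ch) constant, runs one matmul and slices the result back apart (the helper name fused_qkv is hypothetical):

def fused_qkv(network, weight_map, i, ch, x):
    # concatenate to_q / to_k / to_v weights offline so one matmul replaces three
    wq = weight_map['{}.transformer_blocks.0.attn1.to_q.weight'.format(i)]
    wk = weight_map['{}.transformer_blocks.0.attn1.to_k.weight'.format(i)]
    wv = weight_map['{}.transformer_blocks.0.attn1.to_v.weight'.format(i)]
    w = network.add_constant((1, 3 * ch, ch), np.concatenate([wq, wk, wv], axis=0))
    qkv = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                      w.get_output(0), trt.MatrixOperation.TRANSPOSE)
    hw = qkv.get_output(0).shape[1]
    # slice the fused [2, h*w, 3*ch] result back into q, k, v of shape [2, h*w, ch]
    q = network.add_slice(qkv.get_output(0), trt.Dims([0, 0, 0]), trt.Dims([2, hw, ch]), trt.Dims([1, 1, 1]))
    k = network.add_slice(qkv.get_output(0), trt.Dims([0, 0, ch]), trt.Dims([2, hw, ch]), trt.Dims([1, 1, 1]))
    v = network.add_slice(qkv.get_output(0), trt.Dims([0, 0, 2 * ch]), trt.Dims([2, hw, ch]), trt.Dims([1, 1, 1]))
    return q, k, v

Whether this actually helps depends on how well TensorRT already fuses the three separate matmuls, so it is worth profiling.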
cross_attention is similar to the self_attention operator; the difference is that k and v are sliced out of the context input, a pre-packed tensor from which each cross-attention layer takes its own K/V slice at a running offset (context['start']), while only q is computed from the previous layer's output and a weight matrix.
def cross_attention(network, weight_map, i, ch, x, context):
heads = 8
dim_head = ch / heads
scale = dim_head ** -0.5
wq = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn2.to_q.weight'.format(i)])
q = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
wq.get_output(0), trt.MatrixOperation.TRANSPOSE)
# [2, h*w, c]
dim = ch // 8
k = network.add_slice(context['context'],
trt.Dims([0, 0, 8 * context['start']]),
trt.Dims([2, 77, ch]),
trt.Dims([1, 1, 1]))
v = network.add_slice(context['context'],
trt.Dims([0, 0, 8 * (context['start'] + dim)]),
trt.Dims([2, 77, ch]),
trt.Dims([1, 1, 1]))
context['start'] += 2 * dim
q = network.add_shuffle(q.get_output(0))
q.reshape_dims = (2, -1, 8, ch // 8)
q.second_transpose = trt.Permutation([0, 2, 1, 3])
q = network.add_shuffle(q.get_output(0))
q.reshape_dims = (16, -1, ch // 8)
k = network.add_shuffle(k.get_output(0))
k.reshape_dims = (2, -1, 8, ch // 8)
k.second_transpose = trt.Permutation([0, 2, 1, 3])
k = network.add_shuffle(k.get_output(0))
k.reshape_dims = (16, -1, ch // 8)
v = network.add_shuffle(v.get_output(0))
v.reshape_dims = (2, -1, 8, ch // 8)
v.second_transpose = trt.Permutation([0, 2, 1, 3])
v = network.add_shuffle(v.get_output(0))
v.reshape_dims = (16, -1, ch // 8)
s = network.add_einsum([q.get_output(0), k.get_output(0)], 'b i d, b j d -> b i j')
print(s.get_output(0).shape)
# scale = network.add_constant((1, 1, 1), np.array([scale], np.float32))
# s = network.add_elementwise(s.get_output(0), scale.get_output(0), trt.ElementWiseOperation.PROD)
s = network.add_scale(s.get_output(0), mode=trt.ScaleMode.UNIFORM,
scale=trt.Weights(np.array([scale], np.float32)))
s = network.add_softmax(s.get_output(0))
s.axes = 1<<2
out = network.add_einsum([s.get_output(0), v.get_output(0)], 'b i j, b j d -> b i d')
out = network.add_shuffle(out.get_output(0))
out.reshape_dims = (2, 8, -1, ch // 8)
out.second_transpose = trt.Permutation([0, 2, 1, 3])
out = network.add_shuffle(out.get_output(0))
out.reshape_dims = (2, -1, ch)
# to_out
outw = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn2.to_out.0.weight'.format(i)])
outb = network.add_constant((1, 1, ch), weight_map['{}.transformer_blocks.0.attn2.to_out.0.bias'.format(i)])
out = network.add_matrix_multiply(out.get_output(0), trt.MatrixOperation.NONE,
outw.get_output(0), trt.MatrixOperation.TRANSPOSE)
out = network.add_elementwise(out.get_output(0), outb.get_output(0), trt.ElementWiseOperation.SUM)
return out
The ffn is also grouped with the attention operators here. It is a GEGLU block: one linear projection to 8*ch whose output is split into two halves, a GELU applied to one half and multiplied with the other, followed by a second linear projection back to ch; note that the matmul and the bias addition are computed as separate layers.
add_unary is the unary operator and is used here for the ERF computation inside GELU; the worked form of the computation is given below, followed by the code.
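For reference, the scale/ERF/scale sequence in the code implements the erf-based GELU, and the whole block computes a GEGLU (with $a$ the first half of the projection, n1 in the code, and $g$ the second half, n2):
$GELU(g) = g \cdot \frac{1}{2}\left(1 + erf\left(\frac{g}{\sqrt{2}}\right)\right)$
$FF(x) = \left(a \odot GELU(g)\right) W_2 + b_2$, where $[a, g] = x W_1 + b_1$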
def feed_forward(network, weight_map, i, ch, x):
w1 = network.add_constant((1, ch * 8, ch), weight_map['{}.transformer_blocks.0.ff.net.0.proj.weight'.format(i)])
b1 = network.add_constant((1, 1, ch * 8), weight_map['{}.transformer_blocks.0.ff.net.0.proj.bias'.format(i)])
n = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
w1.get_output(0), trt.MatrixOperation.TRANSPOSE)
n = network.add_elementwise(n.get_output(0), b1.get_output(0), trt.ElementWiseOperation.SUM)
hw = n.get_output(0).shape[1]
# w = n.get_output(0).shape[3]
n1 = network.add_slice(n.get_output(0), trt.Dims([0, 0, 0]), trt.Dims([2, hw, ch * 4]), trt.Dims([1, 1, 1]))
n2 = network.add_slice(n.get_output(0), trt.Dims([0, 0, ch * 4]), trt.Dims([2, hw, ch * 4]), trt.Dims([1, 1, 1]))
# gelu
e = network.add_scale(n2.get_output(0), mode=trt.ScaleMode.UNIFORM, scale=trt.Weights(np.array([2 ** -0.5], np.float32)))
e = network.add_unary(e.get_output(0), trt.UnaryOperation.ERF)
e = network.add_scale(e.get_output(0), mode=trt.ScaleMode.UNIFORM,
scale=trt.Weights(np.array([0.5], np.float32)),
shift=trt.Weights(np.array([0.5], np.float32)))
n = network.add_elementwise(n2.get_output(0), e.get_output(0), trt.ElementWiseOperation.PROD)
n = network.add_elementwise(n.get_output(0), n1.get_output(0), trt.ElementWiseOperation.PROD)
w2 = network.add_constant((1, ch, ch * 4), weight_map['{}.transformer_blocks.0.ff.net.2.weight'.format(i)])
b2 = network.add_constant((1, 1, ch), weight_map['{}.transformer_blocks.0.ff.net.2.bias'.format(i)])
n = network.add_matrix_multiply(n.get_output(0), trt.MatrixOperation.NONE,
w2.get_output(0), trt.MatrixOperation.TRANSPOSE)
n = network.add_elementwise(n.get_output(0), b2.get_output(0), trt.ElementWiseOperation.SUM)
return n
Key Modules
Transformer module
The basic transformer block needs no detailed discussion: it is the standard attn1 → attn2 → ffn sequence. The thing to note is that TRT does not handle the 4-D layout in these ops, so an extra reshape is needed before and after the block (NCHW → [batch, h*w, c] and back).
def basic_transformer(network, weight_map, i, ch, x, context):
H = x.get_output(0).shape[2]
W = x.get_output(0).shape[3]
# n c h w -> b (h w) c
x = network.add_shuffle(x.get_output(0))
x.first_transpose = trt.Permutation([0, 2, 3, 1])
x.reshape_dims = (2, -1, ch)
# attn1
n = layer_norm(network, weight_map, x, '{}.transformer_blocks.0.norm1'.format(i))
attn1 = self_attention(network, weight_map, i, ch, n)
x = network.add_elementwise(attn1.get_output(0), x.get_output(0), trt.ElementWiseOperation.SUM)
# attn2
n = layer_norm(network, weight_map, x, '{}.transformer_blocks.0.norm2'.format(i))
attn2 = cross_attention(network, weight_map, i, ch, n, context)
x = network.add_elementwise(attn2.get_output(0), x.get_output(0), trt.ElementWiseOperation.SUM)
# ff
n = layer_norm(network, weight_map, x, '{}.transformer_blocks.0.norm3'.format(i))
ff = feed_forward(network, weight_map, i, ch, n)
x = network.add_elementwise(ff.get_output(0), x.get_output(0), trt.ElementWiseOperation.SUM)
# n (h w) c -> n c h w
x = network.add_shuffle(x.get_output(0))
x.first_transpose = trt.Permutation([0, 2, 1])
x.reshape_dims = (2, ch, H, W)
return x
spatial_transformer wraps basic_transformer with a GroupNorm, two conv projections (proj_in / proj_out) and a residual connection.
def spatial_transformer(network, weight_map, i, ch, h, context):
# return h
# norm
n = group_norm(network, weight_map, h, '{}.norm'.format(i), 1e-6)
# proj_in
n = conv(network, weight_map, n, ch, '{}.proj_in'.format(i), 1, 0, 1)
# BasicTransformerBlock
n = basic_transformer(network, weight_map, i, ch, n, context)
# proj_out
n = conv(network, weight_map, n, ch, '{}.proj_out'.format(i), 1, 0, 1)
h = network.add_elementwise(n.get_output(0), h.get_output(0), trt.ElementWiseOperation.SUM)
return h
Sampling modules
Downsampling is a strided convolution; upsampling is a nearest-neighbor resize followed by a convolution; zero_convs are 1x1 convolutions that keep the feature-map size unchanged.
def input_first(network, weight_map, pre, h):
h = conv(network, weight_map, h, 320, '{}.input_blocks.0.0'.format(pre), 3, 1, 1)
return h
def downsample(network, weight_map, i, ch, x):
x = conv(network, weight_map, x, ch, '{}.op'.format(i), 3, 1, 2)
return x
def upsample(network, weight_map, i, ch, x):
x = network.add_resize(x.get_output(0))
x.scales = [1, 1, 2, 2]
x.resize_mode = trt.ResizeMode.NEAREST
x = conv(network, weight_map, x, ch, '{}.conv'.format(i), 3, 1, 1)
return x
def zero_convs(network, weight_map, x, i):
ch = x.get_output(0).shape[1]
x = conv(network, weight_map, x, ch, 'control_model.zero_convs.{}.0'.format(i), 1, 0, 1)
return x
Block modules
resblock is a residual module built from two GroupNorm + SiLU + Conv blocks, with the timestep-embedding term added in between and a skip connection (a 1x1 conv when the channel count changes); the embedding term itself is fetched with a gather from a pre-computed constant, see the sketch after the resblock code.
def resblock(network, weight_map, embed_weight, i, ch, h, emb):
print('resblock: ', h.get_output(0).shape, '{}.in_layers.0'.format(i))
## in_layers
# group_norm
n = group_norm(network, weight_map, h, '{}.in_layers.0'.format(i), silu=True)
# silu
# n = silu(network, n)
# conv_nd
n = conv(network, weight_map, n, ch, '{}.in_layers.2'.format(i), 3, 1, 1)
print('in_layers: ', n.get_output(0).shape)
## emb_layers
m = network.add_constant([20, ch, 1, 1], embed_weight.pop(0))
m = network.add_gather(m.get_output(0), emb, axis=0)
print('emb_layers: ', m.get_output(0).shape)
n = network.add_elementwise(n.get_output(0), m.get_output(0), trt.ElementWiseOperation.SUM)
## out_layers
n = group_norm(network, weight_map, n, '{}.out_layers.0'.format(i), silu=True)
# n = silu(network, n)
n = conv(network, weight_map, n, ch, '{}.out_layers.3'.format(i), 3, 1, 1)
print('out_layers: ', n.get_output(0).shape)
in_ch = h.get_output(0).shape[1]
if in_ch != ch:
# skip_connection
h = conv(network, weight_map, h, ch, '{}.skip_connection'.format(i), 1, 0, 1)
h = network.add_elementwise(n.get_output(0), h.get_output(0), trt.ElementWiseOperation.SUM)
return h
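The [20, ch, 1, 1] constant plus add_gather in the emb branch above suggests that the whole timestep-embedding path (the time_embed MLP followed by each block's emb_layers) has been pre-computed offline for a fixed set of 20 sampling timesteps, so that at runtime only a constant and a gather on the timestep index remain. A rough sketch of such an offline pre-computation follows; it is purely an assumption about how embed_weight could be produced, the module names follow the ldm/ControlNet PyTorch code and the 20-step schedule is hypothetical:

import torch
from ldm.modules.diffusionmodules.util import timestep_embedding  # assumes the ControlNet/ldm repo is importable

@torch.no_grad()
def precompute_embed_weight(diffusion_model, timesteps):
    # timesteps: the 20 fixed sampling steps; emb has shape [20, time_embed_dim]
    t = torch.tensor(timesteps, dtype=torch.long)
    emb = diffusion_model.time_embed(timestep_embedding(t, diffusion_model.model_channels))
    embed_weight = []
    for module in diffusion_model.modules():
        if hasattr(module, 'emb_layers'):                              # every ResBlock
            w = module.emb_layers(emb)                                 # [20, ch]
            embed_weight.append(w[:, :, None, None].float().numpy())   # [20, ch, 1, 1]
    return embed_weight

The traversal order has to match the order in which the build code pops from embed_weight, so in practice the list would be collected block by block.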
input_block is composed of resblocks and spatial_transformers at different levels and channel widths.
middle_block is a combination of resblock and spatial_transformer.
output_blocks mirrors input_block, except that the downsampling in input_block becomes upsampling in output_blocks.
These three blocks are the main components of the UNet, corresponding to the UNet's pattern of downsampling to a latent feature state and then upsampling back to the image resolution.
def input_block(network, weight_map, embed_weight, h, emb, context, model_name):
hs = []
h = input_first(network, weight_map, model_name, h)
h = network.add_slice(h.get_output(0), trt.Dims([0, 0, 0, 0]), trt.Dims([2, 320, 32, 48]), trt.Dims([1, 1, 1, 1]))
h.mode = trt.SliceMode.WRAP
#return h
hs.append(h)
channel_mult = [1, 2, 4, 4]
num_res_blocks = [2] * 4
model_channels = 320
index = 1
for level, mult in enumerate(channel_mult):
ch = model_channels * mult
for nr in range(num_res_blocks[level]):
pre = '{}.input_blocks.{}'.format(model_name, index)
h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), ch, h, emb)
print('resblock: ', h.get_output(0).shape)
if level != len(channel_mult) -1:
h = spatial_transformer(network, weight_map, '{}.1'.format(pre), ch, h, context)
hs.append(h)
# ch = mult * model_channels
index = index + 1
if level != len(channel_mult) - 1:
pre = '{}.input_blocks.{}'.format(model_name, index)
out_ch = ch
h = downsample(network, weight_map, '{}.0'.format(pre), out_ch, h)
hs.append(h)
index = index + 1
# if index == 10:
return hs, h
def middle_block(network, weight_map, embed_weight, h, emb, context, model_name):
pre = '{}.middle_block'.format(model_name)
h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), 1280, h, emb)
h = spatial_transformer(network, weight_map, '{}.1'.format(pre), 1280, h, context)
h = resblock(network, weight_map, embed_weight, '{}.2'.format(pre), 1280, h, emb)
return h
def output_blocks(network, weight_map, embed_weight, h, emb, context, control, hs):
channel_mult = [1, 2, 4, 4]
num_res_blocks = [2] * 4
model_channels = 320
index = 0
for level, mult in list(enumerate(channel_mult))[::-1]:
ch = model_channels * mult
for i in range(num_res_blocks[level] + 1):
print(control[-1].shape, hs[-1].shape, len(hs), h.get_output(0).shape)
c = network.add_elementwise(control.pop(), hs.pop(), trt.ElementWiseOperation.SUM)
h = network.add_concatenation([h.get_output(0), c.get_output(0)])
print('output: ', index, h.get_output(0).shape)
pre = 'model.diffusion_model.output_blocks.{}'.format(index)
h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), ch, h, emb)
print('resblock: ', h.get_output(0).shape)
if level != len(channel_mult) -1:
h = spatial_transformer(network, weight_map, '{}.1'.format(pre), ch, h, context)
if level and i == num_res_blocks[level]:
h = upsample(network, weight_map,
'{}.{}'.format(pre, 1 if level == len(channel_mult) - 1 else 2), ch, h)
index = index + 1
print(h.get_output(0).shape, len(hs), len(control), index)
return h
input_block_control is the upper half of control_net. Its structural hyper-parameters are the same as the UNet's, but a zero_convs layer with learned parameters is added at every level, and the hint feature is added right after the first convolution.
def input_block_control(network, weight_map, embed_weight, h, emb, context, hint):
hs = []
h = input_first(network, weight_map, 'control_model', h)
h = network.add_elementwise(h.get_output(0), hint, trt.ElementWiseOperation.SUM)
h = network.add_slice(h.get_output(0), trt.Dims([0, 0, 0, 0]), trt.Dims([2, 320, 32, 48]), trt.Dims([1, 1, 1, 1]))
h.mode = trt.SliceMode.WRAP
hs.append(zero_convs(network, weight_map, h, 0))
# h [2, 320, 32, 48]
channel_mult = [1, 2, 4, 4]
num_res_blocks = [2] * 4
model_channels = 320
index = 1
for level, mult in enumerate(channel_mult):
ch = model_channels * mult
for nr in range(num_res_blocks[level]):
pre = 'control_model.input_blocks.{}'.format(index)
h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), ch, h, emb)
print('resblock: ', h.get_output(0).shape)
if level != len(channel_mult) -1:
h = spatial_transformer(network, weight_map, '{}.1'.format(pre), ch, h, context)
hs.append(zero_convs(network, weight_map, h, index))
# ch = mult * model_channels
index = index + 1
if level != len(channel_mult) - 1:
pre = 'control_model.input_blocks.{}'.format(index)
out_ch = ch
h = downsample(network, weight_map, '{}.0'.format(pre), out_ch, h)
hs.append(zero_convs(network, weight_map, h, index))
index = index + 1
# if index == 10:
return hs, h
Network Construction
ControlNet
Here h, hint and emb go through input_block_control to produce the per-scale control features together with the running feature h; h then goes through middle_block and a final 1x1 conv (middle_block_out), which is appended as the last control feature.
def control_net(network, weight_map, embed_weight, h, hint, emb, context):
# #####################
# # time_embed
# #####################
#####################
# input_blocks
#####################
control, h = input_block_control(network, weight_map, embed_weight, h, emb, context, hint)
print(h.get_output(0).shape)
#####################
# middle_blocks
#####################
h = middle_block(network, weight_map, embed_weight, h, emb, context, 'control_model')
h = conv(network, weight_map, h, 1280, 'control_model.middle_block_out.0', 1, 0, 1)
control.append(h)
return control
Unet
The UNet composition is comparatively simple: it runs input_block, middle_block and output_blocks (adding the corresponding control features along the way) and returns the final result.
def unet(network, weight_map, embed_weight, h, emb, context, control):
# #####################
# # time_embed
# #####################
#####################
# input_blocks
#####################
hs, h = input_block(network, weight_map, embed_weight, h, emb, context, 'model.diffusion_model')
print(h.get_output(0).shape)
#####################
# middle_blocks
#####################
h = middle_block(network, weight_map, embed_weight, h, emb, context, 'model.diffusion_model')
print(h.get_output(0).shape)
h = network.add_elementwise(h.get_output(0), control.pop().get_output(0), trt.ElementWiseOperation.SUM)
#####################
# output_blocks
#####################
h = output_blocks(network, weight_map, embed_weight, h, emb, context, control, hs)
# out
# group_norm
# h = group_norm_sile(network, weight_map, h)
h = group_norm(network, weight_map, h, 'model.diffusion_model.out.0', silu=True)
# silu
# h = silu(network, h)
# conv_nd
h = conv(network, weight_map, h, 4, 'model.diffusion_model.out.2', 3, 1, 1)
return h
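One piece the post does not show is how weight_map is produced. A rough sketch of one common approach, flattening the ControlNet/Stable-Diffusion checkpoint's state_dict into float32 numpy arrays keyed by the same names used above ('control_model.*', 'model.diffusion_model.*'); the function name and checkpoint handling here are assumptions:

import numpy as np
import torch

def load_weight_map(ckpt_path):
    state_dict = torch.load(ckpt_path, map_location='cpu')
    state_dict = state_dict.get('state_dict', state_dict)   # some checkpoints nest the weights
    weight_map = {}
    for name, tensor in state_dict.items():
        # TensorRT consumes plain contiguous float32 host buffers
        weight_map[name] = np.ascontiguousarray(tensor.float().cpu().numpy())
    return weight_map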
References
- NVIDIA TensorRT Python API: https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/
- xiatwhu: https://github.com/deeplearning/xiatwhu/trt2023
