
      ControlNet-trt Optimization Summary 2: Building the ControlNet Network from Scratch with the TRT API


      As mentioned in the previous part, the ControlNet network can be rebuilt by constructing the TRT network by hand. This avoids the precision loss introduced by the intermediate ONNX conversion, avoids operators being shredded into many small ops during ONNX export, and unsupported operators can still be added through plugins.

      Basic Concepts

      tensorrt.INetworkDefinition: the network definition object; it can be produced by a parser or built directly with the TensorRT API
      tensorrt.Builder: builds a CudaEngine from a NetworkDefinition and the corresponding BuilderConfig; the CudaEngine is the compiled binary compute graph
      tensorrt.IExecutionContext: created from a CudaEngine; each CudaEngine can create multiple ExecutionContexts

      Notes:

      1. In the code below, network always refers to a tensorrt.INetworkDefinition object.
      2. x comes in two forms: a tensorrt.ITensor (typically the very first input) or a tensorrt.ILayer (typically the output of an intermediate layer). An ITensor can be viewed as an edge of the compute graph, an ILayer as a node.
      3. All operator helpers take weight_map together with a parameter-name prefix, and they all return a tensorrt.ILayer object; a sketch of how such a weight_map could be prepared follows.
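
      The post does not show how weight_map is produced. A minimal sketch, assuming the weights come from a PyTorch ControlNet checkpoint (the file name and key layout are assumptions; embed_weight, the list of precomputed per-timestep embedding tables, is prepared separately and not covered here):

      import numpy as np
      import torch

      def load_weight_map(ckpt_path='control_sd15_canny.pth'):  # path is an assumption
          """Flatten a PyTorch state_dict into {name: np.ndarray} for the TRT build code below."""
          state_dict = torch.load(ckpt_path, map_location='cpu')
          if 'state_dict' in state_dict:  # some checkpoints nest the weights one level down
              state_dict = state_dict['state_dict']
          weight_map = {}
          for name, tensor in state_dict.items():
              # TensorRT consumes contiguous float32 host buffers; numpy arrays are accepted
              # wherever trt.Weights are expected.
              weight_map[name] = tensor.detach().cpu().numpy().astype(np.float32).copy()
          return weight_map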

      Commonly Used TRT API Functions

      add_input(self: tensorrt.tensorrt.INetworkDefinition, 
                name: str, 
               dtype: tensorrt.tensorrt.DataType, 
               shape: tensorrt.tensorrt.Dims) -> tensorrt.tensorrt.ITensor
      
      Purpose: add an input tensor to the network
      Parameters:   name  - name of the input
                    dtype - data type of the tensor, e.g. trt.float32
                    shape - shape of the tensor; must contain fewer than 2^30 elements
      Returns: a new tensor
      
      add_scale(self: tensorrt.tensorrt.INetworkDefinition, 
               input: tensorrt.tensorrt.ITensor, 
                mode: tensorrt.tensorrt.ScaleMode, 
               shift: tensorrt.tensorrt.Weights , 
               scale: tensorrt.tensorrt.Weights , 
               power: tensorrt.tensorrt.Weights) -> tensorrt.tensorrt.IScaleLayer
      Purpose: per-element scaling; the computation is $output=(input*scale+shift)^{power}$
      Parameters:   input - input tensor; must have at least three dimensions
                    mode  - the scale mode, e.g. trt.ScaleMode.UNIFORM, which applies the same coefficients to every element
                    shift - Weights object holding the shift values of the formula
                    scale - Weights object holding the scale values of the formula
                    power - Weights object holding the power values of the formula
      When a Weights argument is supplied, its shape depends on the mode:
              UNIFORM:     a single element
              CHANNEL:     one element per channel
              ELEMENTWISE: the same shape as the input
      Returns: a new layer, or None on failure
      
      add_slice(self: tensorrt.tensorrt.INetworkDefinition, 
               input: tensorrt.tensorrt.ITensor, 
               start: tensorrt.tensorrt.Dims, 
               shape: tensorrt.tensorrt.Dims, 
              stride: tensorrt.tensorrt.Dims) -> tensorrt.tensorrt.ISliceLayer
      
      Purpose: slice a tensor
      Parameters:   input  - input tensor
                    start  - start indices
                    shape  - output shape
                    stride - slice strides
      
      Returns: a new layer, or None on failure
      
      add_constant(self: tensorrt.tensorrt.INetworkDefinition, 
                  shape: tensorrt.tensorrt.Dims, 
                weights: tensorrt.tensorrt.Weights) → tensorrt.tensorrt.IConstantLayer
      
      Purpose: add a constant layer; this turns a Weights object into a layer and hence into a tensor
      Parameters:   shape   - shape of the constant
                    weights - the Weights object (a numpy array also works)
      Returns: a new layer, or None on failure
      
      add_elementwise(self: tensorrt.tensorrt.INetworkDefinition, 
                    input1: tensorrt.tensorrt.ITensor, 
                    input2: tensorrt.tensorrt.ITensor, 
                    op: tensorrt.tensorrt.ElementWiseOperation) → tensorrt.tensorrt.IElementWiseLayer
      
      Purpose: element-wise binary operation
      Parameters:   input1 (input2) - input tensors; their shapes must match
                    op - the binary operator, a member of ElementWiseOperation, e.g.:
                          trt.ElementWiseOperation.PROD (multiplication)
                          trt.ElementWiseOperation.SUM (addition)
      
      Returns: a new layer, or None on failure
      
      add_unary(self: tensorrt.tensorrt.INetworkDefinition,
               input: tensorrt.tensorrt.ITensor, 
               op: tensorrt.tensorrt.UnaryOperation) → tensorrt.tensorrt.IUnaryLayer
      Purpose: element-wise unary operation
      Parameters:   input - input tensor
                    op - the unary operator, a member of UnaryOperation, e.g.:
                          trt.UnaryOperation.EXP (natural exponential)
                          trt.UnaryOperation.LOG (natural logarithm)
      
      Returns: a new layer, or None on failure
      
      add_convolution(self: tensorrt.tensorrt.INetworkDefinition, 
                  input: tensorrt.tensorrt.ITensor, 
                  num_output_maps: int, 
                  kernel_shape: tensorrt.tensorrt.DimsHW, 
                  kernel: tensorrt.tensorrt.Weights, 
                  bias: tensorrt.tensorrt.Weights = None)→ tensorrt.tensorrt.IConvolutionLayer
      Purpose: add a 2D convolution
      Parameters:   input - input tensor (4-D)
                    num_output_maps - number of output feature maps, i.e. the channel count of the next layer
                    kernel_shape - size of the convolution kernel
                    kernel - the kernel weights
                    bias - the bias weights
      Returns: a new layer, or None on failure
      
      add_activation(self: tensorrt.tensorrt.INetworkDefinition, 
                  input: tensorrt.tensorrt.ITensor, 
                  type: tensorrt.tensorrt.ActivationType) → tensorrt.tensorrt.IActivationLayer
      Purpose: add an activation layer that applies an element-wise activation; the output shape equals the input shape
      Parameters:   input – input tensor
                    type – the activation type: RELU, SIGMOID, TANH, LEAKY_RELU, etc.; see tensorrt.ActivationType
      Returns: a new layer, or None on failure
      
      add_normalization(self: tensorrt.tensorrt.INetworkDefinition,
                      input: tensorrt.tensorrt.ITensor, 
                      scale: tensorrt.tensorrt.ITensor, 
                      bias: tensorrt.tensorrt.ITensor, 
                      axesMask: int)→ tensorrt.tensorrt.INormalizationLayer
      Purpose: add a normalization layer computing $Y = (X - Mean(X, axes)) / Sqrt(Variance(X) + epsilon) * S + B$; internally TRT implements this with an instance-norm style kernel, so in some cases it has to be replaced with a hand-written version
      Parameters:   input – input tensor
                    scale – the scale parameter S of the normalization
                    bias – the bias parameter B of the normalization
                    axesMask – the axes the mean/variance are computed over, passed bit-packed as (1<<i)
      Returns: a new layer, or None on failure
      
      add_matrix_multiply(self: tensorrt.tensorrt.INetworkDefinition,
                      input0: tensorrt.tensorrt.ITensor, 
                      op0: tensorrt.tensorrt.MatrixOperation, 
                      input1: tensorrt.tensorrt.ITensor, 
                      op1: tensorrt.tensorrt.MatrixOperation) → tensorrt.tensorrt.IMatrixMultiplyLayer
      Purpose: add a matrix-multiplication operation; four cases are supported: matrix x matrix, matrix x vector, vector x matrix and vector x vector
      Parameters:   input0 – the first input tensor
                    op0 – how input0 is treated: as-is, transposed, or as a vector
                    input1 – the second input tensor
                    op1 – how input1 is treated: as-is, transposed, or as a vector
      Returns: a new layer, or None on failure
      
      add_shuffle(self: tensorrt.tensorrt.INetworkDefinition, 
                  input: tensorrt.tensorrt.ITensor)→ tensorrt.tensorrt.IShuffleLayer
      Purpose: add a shuffle layer, corresponding to the transpose and reshape operators
      Parameters:   input – the input tensor
      Returns: a new layer, or None on failure
      
      add_softmax(self: tensorrt.tensorrt.INetworkDefinition, 
                  input: tensorrt.tensorrt.ITensor)→ tensorrt.tensorrt.ISoftMaxLayer
      Purpose: add a softmax layer; the softmax is applied along the axes selected by the layer's axes attribute, a bit-packed mask
      Parameters:   input – the input tensor
      Returns: a new layer, or None on failure
      
      add_gather(self: tensorrt.tensorrt.INetworkDefinition, 
                  input: tensorrt.tensorrt.ITensor, 
                  indices: tensorrt.tensorrt.ITensor, 
                  axis: int)→ tensorrt.tensorrt.IGatherLayer
      Purpose: add a gather layer that picks entries of input along the given axis at the positions given by indices
      Parameters:   input – the input tensor
                    indices – the index tensor used to produce the output
                    axis – the axis to gather along; it cannot be the batch dimension
      Returns: a new layer, or None on failure
      
      add_einsum(self: tensorrt.tensorrt.INetworkDefinition, 
                 inputs: List[tensorrt.tensorrt.ITensor], 
                 equation: str) → tensorrt.tensorrt.IEinsumLayer
      Purpose: add an Einsum layer, corresponding to einsum; here it is used mainly for batched matrix multiplication
      Parameters:   inputs – the list of input tensors
                    equation – the Einstein-summation equation
      Returns: a new layer, or None on failure
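
      To make the flow concrete, here is a tiny self-contained example (not from the original post; TensorRT 8.x style API) that uses a few of the functions above, builds an engine and creates an execution context:

      import numpy as np
      import tensorrt as trt

      logger = trt.Logger(trt.Logger.WARNING)
      builder = trt.Builder(logger)
      network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

      # y = relu(x + b)
      x = network.add_input('x', trt.float32, (1, 4))
      b = network.add_constant((1, 4), np.array([[1.0, -1.0, 2.0, -2.0]], dtype=np.float32))
      s = network.add_elementwise(x, b.get_output(0), trt.ElementWiseOperation.SUM)
      y = network.add_activation(s.get_output(0), trt.ActivationType.RELU)
      network.mark_output(y.get_output(0))

      config = builder.create_builder_config()
      engine = trt.Runtime(logger).deserialize_cuda_engine(
          builder.build_serialized_network(network, config))
      context = engine.create_execution_context()
      # Running it requires device buffers (e.g. via cuda-python or pycuda):
      # context.execute_v2([d_input_ptr, d_output_ptr])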
      

      Key TRT Operators

      Convolution operator

      Since TRT natively supports convolution, add_convolution is called directly here. Note that conv may also receive the network's raw input tensor as its first argument, hence the ITensor/ILayer check below.

      def conv(network, weight_map, x, ch, pre, kernel, padding, stride):
          x = network.add_convolution(
                  input=x if isinstance(x, trt.ITensor) else x.get_output(0),
                  num_output_maps=ch,
                  kernel_shape=(kernel, kernel),
                  kernel=weight_map['{}.weight'.format(pre)],
                  bias=weight_map['{}.bias'.format(pre)])
          assert x
          x.padding = (padding, padding)
          x.stride = (stride, stride)
          return x
      

      Activation operator

      The SiLU operator is decomposed into a SIGMOID activation followed by an element-wise PROD, which is essentially the same as what the ONNX export produces.

      def silu(network, x):
          y = network.add_activation(x.get_output(0), trt.ActivationType.SIGMOID)
          assert y
          x = network.add_elementwise(x.get_output(0), y.get_output(0), trt.ElementWiseOperation.PROD)
          return x
      

      Normalization operators

      The group norm here calls a plugin. Two attributes are defined via PluginField: epsilon and bSwish, i.e. the numerical epsilon and whether to fuse a Swish (SiLU) activation.
      Its inputs are the previous layer's output, the weights and the bias; its output is the group-normalized result.

      import ctypes
      import numpy as np
      import tensorrt as trt

      # Load the prebuilt plugin library so its plugin creators get registered.
      ctypes.CDLL('./trt/libmyplugins.so.1', mode=ctypes.RTLD_GLOBAL)

      EPS = 1e-5  # default epsilon for the norm helpers below (assumed value, not given in the post)

      TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
      trt.init_libnvinfer_plugins(TRT_LOGGER, '')
      gn_plugin_creator = trt.get_plugin_registry().get_plugin_creator('GroupNorm', "1")
      
      def group_norm(network, weight_map, h, pre, epsilon=EPS, silu=False):
          ch = h.get_output(0).shape[1]
          # plugin_creator = trt.get_plugin_registry().get_plugin_creator('GroupNorm', "1")
          plugin_creator = gn_plugin_creator
          s = network.add_constant([1, ch, 1, 1], weight_map['{}.weight'.format(pre)])
          b = network.add_constant([1, ch, 1, 1], weight_map['{}.bias'.format(pre)])
      
          eps_attr = trt.PluginField("epsilon", np.array([epsilon], dtype=np.float32), type=trt.PluginFieldType.FLOAT32)
          silu_attr = trt.PluginField("bSwish", np.array([1 if silu else 0], dtype=np.int32), type=trt.PluginFieldType.INT32)
          field_collection = trt.PluginFieldCollection([eps_attr, silu_attr])
      
          plugin = plugin_creator.create_plugin(name='{}.group_norm'.format(pre), field_collection=field_collection)
          n = network.add_plugin_v2(inputs=[h.get_output(0), s.get_output(0), b.get_output(0)], plugin=plugin)
          return n
      

      The layer_norm here performs the following computation:
      Y = (X - Mean(X, axes)) / Sqrt(Variance(X) + epsilon) * S + B
      The result depends on which axes are chosen. The axesMask here is set to normalize along the third dimension (axis 2, the hidden/channel dimension); for the sequence-shaped tensors used here the first dimension is the batch and the second is the sequence length.
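
      For reference, the bit-packed axes convention used by add_normalization and add_softmax can be written as a tiny helper (axes_mask is only an illustrative name, not part of the post's code):

      def axes_mask(dims):
          # bit i of the mask selects dimension i
          mask = 0
          for d in dims:
              mask |= 1 << d
          return mask

      # For a (batch, seq, hidden) tensor, normalizing over the hidden dimension:
      assert axes_mask([2]) == 1 << 2
      # Group norm over the spatial dims of an NCHW tensor would use axes_mask([2, 3]).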

      def layer_norm(network, weight_map, h, pre, epsilon=EPS):
          scale_np = weight_map['{}.weight'.format(pre)]
          ch = scale_np.shape[0]
          scale = network.add_constant([1, 1, ch], scale_np)
          bias_np = weight_map['{}.bias'.format(pre)]
          bias = network.add_constant([1, 1, ch], bias_np)
          n = network.add_normalization(
              h.get_output(0),
              scale=scale.get_output(0),
              bias=bias.get_output(0),
              axesMask=1 << 2)
          assert n
          n.epsilon = epsilon
      
          return n    
      

      Attention operators

      Because TRT does not directly handle the 4-D matrix multiply-add needed here, H and W are merged into one dimension. The MHA uses 8 heads, and at compute time the heads are folded into the batch dimension, which gives the following transformation:
      [2, h * w, c] -> [2, h * w, 8, d] -> [2, 8, h * w, d] -> [16, h * w, d]

      In the actual computation, q, k and v are each obtained by a matrix multiplication. There is an optimization opportunity here: the three projections could be fused into a single matmul instead of three separate ones, which parallelizes better (see the sketch below).
      The q·k product is computed with add_einsum; after the softmax, the result is multiplied with v. Note that the final result has to be reshaped back to [2, h * w, c].
      The remaining part is a residual connection, which produces the final output.
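
      A small NumPy check (not part of the original code) of the two points above: the reshape/transpose chain is just a head split folded into the batch dimension, and fusing the three projection weights into one matmul yields the same q, k and v:

      import numpy as np

      b, hw, c, heads = 2, 6, 16, 8
      d = c // heads
      rng = np.random.default_rng(0)
      x = rng.standard_normal((b, hw, c)).astype(np.float32)
      wq, wk, wv = (rng.standard_normal((c, c)).astype(np.float32) for _ in range(3))

      # Separate projection, as in self_attention below.
      q = x @ wq.T
      # Head split: [b, hw, c] -> [b, hw, heads, d] -> [b, heads, hw, d] -> [b*heads, hw, d]
      q_heads = q.reshape(b, hw, heads, d).transpose(0, 2, 1, 3).reshape(b * heads, hw, d)
      assert q_heads.shape == (b * heads, hw, d)

      # Fused projection: one [3c, c] weight, one matmul, then split the result.
      w_qkv = np.concatenate([wq, wk, wv], axis=0)
      q_fused, k_fused, v_fused = np.split(x @ w_qkv.T, 3, axis=-1)
      assert np.allclose(q, q_fused, atol=1e-5)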

      def self_attention(network, weight_map, i, ch, x):   
          heads = 8
          dim_head = ch / heads
          scale = dim_head ** -0.5
      
          wq = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_q.weight'.format(i)])
          wk = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_k.weight'.format(i)])
          wv = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_v.weight'.format(i)])
      
          q = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                          wq.get_output(0), trt.MatrixOperation.TRANSPOSE)
          k = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                          wk.get_output(0), trt.MatrixOperation.TRANSPOSE)
          v = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                          wv.get_output(0), trt.MatrixOperation.TRANSPOSE)
      
          # q [2, h * w, c] -> [2, h * w, 8, d] -> [2, 8, h * w, d] -> [16, h * w, d]
          q = network.add_shuffle(q.get_output(0))
          q.reshape_dims = (2, -1, 8, ch // 8)
          q.second_transpose = trt.Permutation([0, 2, 1, 3])
          q = network.add_shuffle(q.get_output(0))
          q.reshape_dims = (16, -1, ch // 8)
      
          k = network.add_shuffle(k.get_output(0))
          k.reshape_dims = (2, -1, 8, ch // 8)
          k.second_transpose = trt.Permutation([0, 2, 1, 3])
          k = network.add_shuffle(k.get_output(0))
          k.reshape_dims = (16, -1, ch // 8)
      
          v = network.add_shuffle(v.get_output(0))
          v.reshape_dims = (2, -1, 8, ch // 8)
          v.second_transpose = trt.Permutation([0, 2, 1, 3])
          v = network.add_shuffle(v.get_output(0))
          v.reshape_dims = (16, -1, ch // 8)
      
          s = network.add_einsum([q.get_output(0), k.get_output(0)], 'b i d, b j d -> b i j')
          print(s.get_output(0).shape)
      
          s = network.add_scale(s.get_output(0), mode=trt.ScaleMode.UNIFORM,
                                scale=trt.Weights(np.array([scale], np.float32)))
      
          s = network.add_softmax(s.get_output(0))
          s.axes = 1<<2
      
          out = network.add_einsum([s.get_output(0), v.get_output(0)], 'b i j, b j d -> b i d')
          # [16, h * w, d] -> [2, 8, h * w, d] -> [2, h * w, 8, d] -> [2, h * w, c]
          out = network.add_shuffle(out.get_output(0))
          out.reshape_dims = (2, 8, -1, ch // 8)
          out.second_transpose = trt.Permutation([0, 2, 1, 3])
          out = network.add_shuffle(out.get_output(0))
          out.reshape_dims = (2, -1, ch)
      
          # to_out
          outw = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_out.0.weight'.format(i)])
          outb = network.add_constant((1, 1, ch), weight_map['{}.transformer_blocks.0.attn1.to_out.0.bias'.format(i)])
      
          out = network.add_matrix_multiply(out.get_output(0), trt.MatrixOperation.NONE,
                                            outw.get_output(0), trt.MatrixOperation.TRANSPOSE)
      
          out = network.add_elementwise(out.get_output(0), outb.get_output(0), trt.ElementWiseOperation.SUM)
      
          return out
      

      cross attention is similar to the self attention operator; the difference is that its k and v are sliced out of context, which holds the result of an earlier context computation, while only q is computed from the weights and the previous layer's output. A sketch of how such a packed context tensor could be laid out is given below.
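
      The post does not show how the packed context tensor is produced. Judging from the slice offsets (k at 8*start, v at 8*(start+dim), with start advancing by 2*dim per layer), it concatenates the precomputed k and v of every cross-attention layer along the last axis. A hedged sketch of that packing (the helper name, the to_k/to_v key pattern and the call order are assumptions):

      import numpy as np

      def pack_cross_attention_context(text_context, weight_map, attn2_prefixes):
          """text_context: [2, 77, context_dim] text embedding as a numpy array.
          attn2_prefixes: the '{...}.transformer_blocks.0.attn2' prefixes, in the same
          order in which cross_attention() is called while the network is built."""
          chunks = []
          for pre in attn2_prefixes:
              wk = weight_map['{}.to_k.weight'.format(pre)]  # [ch, context_dim]
              wv = weight_map['{}.to_v.weight'.format(pre)]
              chunks.append(text_context @ wk.T)             # k slice, [2, 77, ch]
              chunks.append(text_context @ wv.T)             # v slice, [2, 77, ch]
          # Layout along the last axis: k0, v0, k1, v1, ... which matches the slice
          # offsets used by cross_attention below.
          return np.concatenate(chunks, axis=-1)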

      def cross_attention(network, weight_map, i, ch, x, context):
          heads = 8
          dim_head = ch / heads
          scale = dim_head ** -0.5
      
          wq = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn2.to_q.weight'.format(i)])
      
          q = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                          wq.get_output(0), trt.MatrixOperation.TRANSPOSE)
          # [2, h*w, c]
      
          dim = ch // 8
          k = network.add_slice(context['context'],
                                trt.Dims([0, 0, 8 * context['start']]),
                                trt.Dims([2, 77, ch]),
                                trt.Dims([1, 1, 1]))
          v = network.add_slice(context['context'],
                                trt.Dims([0, 0, 8 * (context['start'] + dim)]),
                                trt.Dims([2, 77, ch]),
                                trt.Dims([1, 1, 1]))
          context['start'] += 2 * dim
      
          q = network.add_shuffle(q.get_output(0))
          q.reshape_dims = (2, -1, 8, ch // 8)
          q.second_transpose = trt.Permutation([0, 2, 1, 3])
          q = network.add_shuffle(q.get_output(0))
          q.reshape_dims = (16, -1, ch // 8)
      
          k = network.add_shuffle(k.get_output(0))
          k.reshape_dims = (2, -1, 8, ch // 8)
          k.second_transpose = trt.Permutation([0, 2, 1, 3])
          k = network.add_shuffle(k.get_output(0))
          k.reshape_dims = (16, -1, ch // 8)
      
          v = network.add_shuffle(v.get_output(0))
          v.reshape_dims = (2, -1, 8, ch // 8)
          v.second_transpose = trt.Permutation([0, 2, 1, 3])
          v = network.add_shuffle(v.get_output(0))
          v.reshape_dims = (16, -1, ch // 8)
      
          s = network.add_einsum([q.get_output(0), k.get_output(0)], 'b i d, b j d -> b i j')
          print(s.get_output(0).shape)
      
          # scale = network.add_constant((1, 1, 1), np.array([scale], np.float32))
          # s = network.add_elementwise(s.get_output(0), scale.get_output(0), trt.ElementWiseOperation.PROD)
          s = network.add_scale(s.get_output(0), mode=trt.ScaleMode.UNIFORM,
                                scale=trt.Weights(np.array([scale], np.float32)))
      
          s = network.add_softmax(s.get_output(0))
          s.axes = 1<<2
      
          out = network.add_einsum([s.get_output(0), v.get_output(0)], 'b i j, b j d -> b i d')
          out = network.add_shuffle(out.get_output(0))
          out.reshape_dims = (2, 8, -1, ch // 8)
          out.second_transpose = trt.Permutation([0, 2, 1, 3])
      
          out = network.add_shuffle(out.get_output(0))
          out.reshape_dims = (2, -1, ch)
      
          # to_out
          outw = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn2.to_out.0.weight'.format(i)])
          outb = network.add_constant((1, 1, ch), weight_map['{}.transformer_blocks.0.attn2.to_out.0.bias'.format(i)])
      
          out = network.add_matrix_multiply(out.get_output(0), trt.MatrixOperation.NONE,
                                            outw.get_output(0), trt.MatrixOperation.TRANSPOSE)
      
          out = network.add_elementwise(out.get_output(0), outb.get_output(0), trt.ElementWiseOperation.SUM)
      
          return out
      

      The FFN is grouped with the attention operators here as well. It consists of a linear projection whose output is split into two halves, a GELU gate (one half is passed through GELU and multiplied with the other), and a second linear layer; note that the matmul and the bias addition are computed as separate layers.
      add_unary is the unary operator; here it computes the ERF needed by the GELU. The same computation is written out as a NumPy reference below.
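
      For readability, the same computation in NumPy (a reference sketch; geglu_reference is an illustrative name, not part of the original code):

      import math
      import numpy as np

      def geglu_reference(x, w1, b1, w2, b2):
          """proj -> split -> value * gelu(gate) -> second linear, mirroring feed_forward below."""
          n = x @ w1.T + b1                      # [2, hw, 8*ch]
          n1, n2 = np.split(n, 2, axis=-1)       # value half, gate half
          erf = np.vectorize(math.erf)
          gelu_n2 = n2 * (0.5 * erf(n2 * 2 ** -0.5) + 0.5)   # SCALE -> ERF -> SCALE/SHIFT -> PROD
          return (n1 * gelu_n2) @ w2.T + b2      # [2, hw, ch]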

      def feed_forward(network, weight_map, i, ch, x):
          w1 = network.add_constant((1, ch * 8, ch), weight_map['{}.transformer_blocks.0.ff.net.0.proj.weight'.format(i)])
          b1 = network.add_constant((1, 1, ch * 8), weight_map['{}.transformer_blocks.0.ff.net.0.proj.bias'.format(i)])
          n = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                          w1.get_output(0), trt.MatrixOperation.TRANSPOSE)
          n = network.add_elementwise(n.get_output(0), b1.get_output(0), trt.ElementWiseOperation.SUM)
      
          hw = n.get_output(0).shape[1]
          # w = n.get_output(0).shape[3]
          n1 = network.add_slice(n.get_output(0), trt.Dims([0, 0, 0]), trt.Dims([2, hw, ch * 4]), trt.Dims([1, 1, 1]))
          n2 = network.add_slice(n.get_output(0), trt.Dims([0, 0, ch * 4]), trt.Dims([2, hw, ch * 4]), trt.Dims([1, 1, 1]))
      
          # gelu
          e = network.add_scale(n2.get_output(0), mode=trt.ScaleMode.UNIFORM, scale=trt.Weights(np.array([2 ** -0.5], np.float32)))
          e = network.add_unary(e.get_output(0), trt.UnaryOperation.ERF)
          e = network.add_scale(e.get_output(0), mode=trt.ScaleMode.UNIFORM,
                                scale=trt.Weights(np.array([0.5], np.float32)),
                                shift=trt.Weights(np.array([0.5], np.float32)))
      
          n = network.add_elementwise(n2.get_output(0), e.get_output(0), trt.ElementWiseOperation.PROD)
          n = network.add_elementwise(n.get_output(0), n1.get_output(0), trt.ElementWiseOperation.PROD)
      
          w2 = network.add_constant((1, ch, ch * 4), weight_map['{}.transformer_blocks.0.ff.net.2.weight'.format(i)])
          b2 = network.add_constant((1, 1, ch), weight_map['{}.transformer_blocks.0.ff.net.2.bias'.format(i)])
          n = network.add_matrix_multiply(n.get_output(0), trt.MatrixOperation.NONE,
                                          w2.get_output(0), trt.MatrixOperation.TRANSPOSE)
          n = network.add_elementwise(n.get_output(0), b2.get_output(0), trt.ElementWiseOperation.SUM)
      
          return n
      

      Key Modules

      Transformer module

      The basic transformer itself is not discussed in detail here; it is the standard attn1 -> attn2 -> ffn sequence. Note that the attention layers used here do not operate on the 4-D feature map directly, so an extra reshape is needed before and after the block (NCHW <-> (N, H*W, C)).

      def basic_transformer(network, weight_map, i, ch, x, context):
          H = x.get_output(0).shape[2]
          W = x.get_output(0).shape[3]
      
          # n c h w -> b (h w) c
          x = network.add_shuffle(x.get_output(0))
          x.first_transpose = trt.Permutation([0, 2, 3, 1])
          x.reshape_dims = (2, -1, ch)
      
          # attn1
          n = layer_norm(network, weight_map, x, '{}.transformer_blocks.0.norm1'.format(i))
          
          attn1 = self_attention(network, weight_map, i, ch, n)
          x = network.add_elementwise(attn1.get_output(0), x.get_output(0), trt.ElementWiseOperation.SUM)
      
          # attn2
          n = layer_norm(network, weight_map, x, '{}.transformer_blocks.0.norm2'.format(i))
          attn2 = cross_attention(network, weight_map, i, ch, n, context)
          x = network.add_elementwise(attn2.get_output(0), x.get_output(0), trt.ElementWiseOperation.SUM)
      
          # ff
          n = layer_norm(network, weight_map, x, '{}.transformer_blocks.0.norm3'.format(i))
          ff = feed_forward(network, weight_map, i, ch, n)
          
          x = network.add_elementwise(ff.get_output(0), x.get_output(0), trt.ElementWiseOperation.SUM)
      
          # n (h w) c -> n c h w
          x = network.add_shuffle(x.get_output(0))
          x.first_transpose = trt.Permutation([0, 2, 1])
          x.reshape_dims = (2, ch, H, W)
          return x
      

      spatial_transformer wraps basic_transformer with a group norm and two 1x1 conv projections (proj_in / proj_out), plus a residual connection.

      def spatial_transformer(network, weight_map, i, ch, h, context):
          # return h
          # norm
          n = group_norm(network, weight_map, h, '{}.norm'.format(i), 1e-6)
          # proj_in
          n = conv(network, weight_map, n, ch, '{}.proj_in'.format(i), 1, 0, 1)
      
          # BasicTransformerBlock
          n = basic_transformer(network, weight_map, i, ch, n, context)
      
          # proj_out
          n = conv(network, weight_map, n, ch, '{}.proj_out'.format(i), 1, 0, 1)
      
          h = network.add_elementwise(n.get_output(0), h.get_output(0), trt.ElementWiseOperation.SUM)
          return h
      

      Sampling modules

      Downsampling is a stride-2 convolution; upsampling is a 2x nearest-neighbor resize followed by a convolution; zero_convs are 1x1 convolutions that keep the feature-map size unchanged.

      def input_first(network, weight_map, pre, h):
          h = conv(network, weight_map, h, 320, '{}.input_blocks.0.0'.format(pre), 3, 1, 1)
          return h
      
      def downsample(network, weight_map, i, ch, x):
          x = conv(network, weight_map, x, ch, '{}.op'.format(i), 3, 1, 2)
          return x
      
      def upsample(network, weight_map, i, ch, x):
          x = network.add_resize(x.get_output(0))
          x.scales = [1, 1, 2, 2]
          x.resize_mode = trt.ResizeMode.NEAREST
      
          x = conv(network, weight_map, x, ch, '{}.conv'.format(i), 3, 1, 1)
      
          return x
      
      def zero_convs(network, weight_map, x, i):
          ch = x.get_output(0).shape[1]
          x = conv(network, weight_map, x, ch, 'control_model.zero_convs.{}.0'.format(i), 1, 0, 1)
          return x
      

      Block modules

      resblock is a residual module: each branch is group norm + SiLU + 3x3 convolution, the timestep embedding is added in between the two stages, and a 1x1 skip connection is applied when the channel count changes.

      def resblock(network, weight_map, embed_weight, i, ch, h, emb):
          print('resblock: ', h.get_output(0).shape, '{}.in_layers.0'.format(i))
          ## in_layers
          # group_norm
          n = group_norm(network, weight_map, h, '{}.in_layers.0'.format(i), silu=True)
          # silu
          # n = silu(network, n)
          # conv_nd
          n = conv(network, weight_map, n, ch, '{}.in_layers.2'.format(i), 3, 1, 1)
      
          print('in_layers: ', n.get_output(0).shape)
      
          ## emb_layers
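          # embed_weight presumably holds per-timestep embedding projections precomputed
          # offline (here for 20 timesteps); the gather below then selects the row for the
          # current timestep index passed in as 'emb'.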
          m = network.add_constant([20, ch, 1, 1], embed_weight.pop(0))
          m = network.add_gather(m.get_output(0), emb, axis=0)
          print('emb_layers: ', m.get_output(0).shape)
      
          n = network.add_elementwise(n.get_output(0), m.get_output(0), trt.ElementWiseOperation.SUM)
      
          ## out_layers
          n = group_norm(network, weight_map, n, '{}.out_layers.0'.format(i), silu=True)
          # n = silu(network, n)
          n = conv(network, weight_map, n, ch, '{}.out_layers.3'.format(i), 3, 1, 1)
      
          print('out_layers: ', n.get_output(0).shape)
      
          in_ch = h.get_output(0).shape[1]
          if in_ch != ch:
              # skip_connection
              h = conv(network, weight_map, h, ch, '{}.skip_connection'.format(i), 1, 0, 1)
      
          h = network.add_elementwise(n.get_output(0), h.get_output(0), trt.ElementWiseOperation.SUM)
          return h
      

      input_block is built from resblocks and spatial_transformers at different levels and with different channel widths.
      middle_block is a resblock + spatial_transformer + resblock combination.
      output_blocks mirrors input_block, except that the downsampling in input_block becomes upsampling in output_blocks.
      These three blocks are the main components of the UNet and correspond to downsampling to the feature state and then upsampling back to the image resolution.

      def input_block(network, weight_map, embed_weight, h, emb, context, model_name):
          hs = []
          h = input_first(network, weight_map, model_name, h)
          h = network.add_slice(h.get_output(0), trt.Dims([0, 0, 0, 0]), trt.Dims([2, 320, 32, 48]), trt.Dims([1, 1, 1, 1]))
          h.mode = trt.SliceMode.WRAP
      
          #return h
          hs.append(h)
      
          channel_mult = [1, 2, 4, 4]
          num_res_blocks = [2] * 4
      
          model_channels = 320
          index = 1
          for level, mult in enumerate(channel_mult):
              ch = model_channels * mult
              for nr in range(num_res_blocks[level]):
                  pre = '{}.input_blocks.{}'.format(model_name, index)
                  h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), ch, h, emb)
                  print('resblock: ', h.get_output(0).shape)
                  if level != len(channel_mult) -1:
                      h = spatial_transformer(network, weight_map, '{}.1'.format(pre), ch, h, context)
                  hs.append(h)
      
                  # ch = mult * model_channels
                  index = index + 1
      
              if level != len(channel_mult) - 1:
                  pre = '{}.input_blocks.{}'.format(model_name, index)
                  out_ch = ch
                  h = downsample(network, weight_map, '{}.0'.format(pre), out_ch, h)
                  hs.append(h)
                  index = index + 1
              
              # if index == 10:
          return hs, h
      
      def middle_block(network, weight_map, embed_weight, h, emb, context, model_name):
          pre = '{}.middle_block'.format(model_name)
          h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), 1280, h, emb)
          h = spatial_transformer(network, weight_map, '{}.1'.format(pre), 1280, h, context)
          h = resblock(network, weight_map, embed_weight, '{}.2'.format(pre), 1280, h, emb)
          return h
      
      def output_blocks(network, weight_map, embed_weight, h, emb, context, control, hs):
          channel_mult = [1, 2, 4, 4]
          num_res_blocks = [2] * 4
      
          model_channels = 320
          index = 0
          for level, mult in list(enumerate(channel_mult))[::-1]:
              ch = model_channels * mult
              for i in range(num_res_blocks[level] + 1):
                  print(control[-1].shape, hs[-1].shape, len(hs), h.get_output(0).shape)
                  c = network.add_elementwise(control.pop(), hs.pop(), trt.ElementWiseOperation.SUM)
                  h = network.add_concatenation([h.get_output(0), c.get_output(0)])
                  print('output: ', index, h.get_output(0).shape)
                  pre = 'model.diffusion_model.output_blocks.{}'.format(index)
                  h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), ch, h, emb)
                  print('resblock: ', h.get_output(0).shape)
                  if level != len(channel_mult) -1:
                      h = spatial_transformer(network, weight_map, '{}.1'.format(pre), ch, h, context)
                  
                  if level and i == num_res_blocks[level]:
                      h = upsample(network, weight_map,
                                   '{}.{}'.format(pre, 1 if level == len(channel_mult) - 1 else 2), ch, h)
                  index = index + 1
          print(h.get_output(0).shape, len(hs), len(control), index)
          return h
      

      input_block_control is the upper half of the ControlNet. Its structural parameters are the same as the UNet's, but a learnable zero_convs layer is added after every stage.

      def input_block_control(network, weight_map, embed_weight, h, emb, context, hint):
          hs = []
          h = input_first(network, weight_map, 'control_model', h)
          h = network.add_elementwise(h.get_output(0), hint, trt.ElementWiseOperation.SUM)
      
          h = network.add_slice(h.get_output(0), trt.Dims([0, 0, 0, 0]), trt.Dims([2, 320, 32, 48]), trt.Dims([1, 1, 1, 1]))
          h.mode = trt.SliceMode.WRAP
          hs.append(zero_convs(network, weight_map, h, 0))
          # h [2, 320, 32, 48]
      
          channel_mult = [1, 2, 4, 4]
          num_res_blocks = [2] * 4
      
          model_channels = 320
          index = 1
          for level, mult in enumerate(channel_mult):
              ch = model_channels * mult
              for nr in range(num_res_blocks[level]):
                  pre = 'control_model.input_blocks.{}'.format(index)
                  h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), ch, h, emb)
                  print('resblock: ', h.get_output(0).shape)
                  if level != len(channel_mult) -1:
                      h = spatial_transformer(network, weight_map, '{}.1'.format(pre), ch, h, context)
                  hs.append(zero_convs(network, weight_map, h, index))
      
                  # ch = mult * model_channels
                  index = index + 1
      
              if level != len(channel_mult) - 1:
                  pre = 'control_model.input_blocks.{}'.format(index)
                  out_ch = ch
                  h = downsample(network, weight_map, '{}.0'.format(pre), out_ch, h)
                  hs.append(zero_convs(network, weight_map, h, index))
                  index = index + 1
              
              # if index == 10:
          return hs, h
      

      Network Construction Modules

      controlnet

      Here h, hint and emb pass through input_block_control to obtain the control features and h; h then passes through middle_block (plus a final 1x1 conv), yielding control features at the different scales.

      def control_net(network, weight_map, embed_weight, h, hint, emb, context):
          # #####################
          # # time_embed
          # #####################
      
          #####################
          # input_blocks
          #####################
          control, h = input_block_control(network, weight_map, embed_weight, h, emb, context, hint)
          print(h.get_output(0).shape)
      
          #####################
          # middle_blocks
          #####################   
          h = middle_block(network, weight_map, embed_weight, h, emb, context, 'control_model')
          h = conv(network, weight_map, h, 1280, 'control_model.middle_block_out.0', 1, 0, 1)
      
          control.append(h)
          return control
      

      Unet

      The UNet itself is comparatively simple: it passes through input_block, middle_block and output_blocks (adding in the control features along the way) and returns the final result.

      def unet(network, weight_map, embed_weight, h, emb, context, control):
          # #####################
          # # time_embed
          # #####################
      
      
          #####################
          # input_blocks
          #####################
          hs, h = input_block(network, weight_map, embed_weight, h, emb, context, 'model.diffusion_model')
          print(h.get_output(0).shape)
      
          #####################
          # middle_blocks
          #####################   
          h = middle_block(network, weight_map, embed_weight, h, emb, context, 'model.diffusion_model')
          print(h.get_output(0).shape)
      
          h = network.add_elementwise(h.get_output(0), control.pop().get_output(0), trt.ElementWiseOperation.SUM)
      
          #####################
          # output_blocks
          #####################
          h = output_blocks(network, weight_map, embed_weight, h, emb, context, control, hs)
      
          # out
          # group_norm
          # h = group_norm_sile(network, weight_map, h)
          h = group_norm(network, weight_map, h, 'model.diffusion_model.out.0', silu=True)
          # silu
          # h = silu(network, h)
          # conv_nd
          h = conv(network, weight_map, h, 4, 'model.diffusion_model.out.2', 3, 1, 1)
      
          return h
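
      Putting the pieces together: a minimal sketch (TensorRT 8.x style API; the input shapes and builder settings are illustrative assumptions) of how these builders could be assembled into an engine. The packed context width depends on how the per-layer k/v context was packed, and embed_weight has to be prepared separately:

      import tensorrt as trt

      def build_engine(weight_map, embed_weight, context_width):
          """context_width: last-dim size of the packed k/v context tensor."""
          logger = trt.Logger(trt.Logger.INFO)
          builder = trt.Builder(logger)
          network = builder.create_network(
              1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

          # Static shapes matching the 2 x 32 x 48 latent used throughout this post.
          x = network.add_input('x_in', trt.float32, (2, 4, 32, 48))
          hint = network.add_input('hint_in', trt.float32, (2, 320, 32, 48))
          emb = network.add_input('t_in', trt.int32, (2,))
          ctx = network.add_input('context_in', trt.float32, (2, 77, context_width))

          context = {'context': ctx, 'start': 0}
          control = control_net(network, weight_map, embed_weight, x, hint, emb, context)
          out = unet(network, weight_map, embed_weight, x, emb, context, control)
          network.mark_output(out.get_output(0))

          config = builder.create_builder_config()
          config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)
          # config.set_flag(trt.BuilderFlag.FP16)  # optionally build in FP16
          return builder.build_serialized_network(network, config)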
      
