<output id="qn6qe"></output>

    1. <output id="qn6qe"><tt id="qn6qe"></tt></output>
    2. <strike id="qn6qe"></strike>

      亚洲 日本 欧洲 欧美 视频,日韩中文字幕有码av,一本一道av中文字幕无码,国产线播放免费人成视频播放,人妻少妇偷人无码视频,日夜啪啪一区二区三区,国产尤物精品自在拍视频首页,久热这里只有精品12

      Ctorch開發(fā)日志——矩陣乘法優(yōu)化及數(shù)學(xué)原理

      隨著項(xiàng)目的推進(jìn),本作者遇到了目前最棘手的問題,即矩陣乘法的優(yōu)化

      但是有句話說得好

      “你越棘手,我越興奮”

      那么,如下是本作者如何把\(O(MNK)\)\(O(n^3)\))的樸素矩陣乘法一步一步優(yōu)化到\(O(n^{2.81})\) 的全過程

      測(cè)試環(huán)境

      macOS Tahoe 26 Beta 2
      M3 Pro 11核
      18GB
      CLion & Cmake
      計(jì)時(shí)器:ctime
      矩陣:1024 * 1024 @ 1024 * 1024
      為保證準(zhǔn)確,時(shí)間均為5次測(cè)量取平均值

      樸素算法實(shí)現(xiàn) & 測(cè)速

      樸素實(shí)現(xiàn)的數(shù)學(xué)原理

      其實(shí)就是把矩陣乘的數(shù)學(xué)公式重寫一遍:

      \[A={\left[ a_{ij}\right]_{m \times n}} \]

      \[B={\left[ b_{ij}\right]_{n \times s}} \]

      \[C= {A \times B}=\left[ c_{ij}\right]_{m \times s} \]

      \[= \left[ \sum \limits_{k=1}^{n}a_{ik}b_{kj}\right]_{m \times s} \]

      注意:矩陣乘法不滿足交換律
      只有左矩陣的列數(shù)與右矩陣的行數(shù)相同的兩個(gè)矩陣才能相乘
      乘積矩陣的行數(shù)等于左矩陣的行數(shù),列數(shù)等于右矩陣的列數(shù)

      概括一下就是

      乘積矩陣第i行第j列處的元素等于左矩陣的第i行與右矩陣的第j列對(duì)應(yīng)元素乘積之和

      那么,這個(gè)很簡(jiǎn)單,上代碼吧
      為了測(cè)試,所有的矩陣保證滿足乘法條件且為2維

      時(shí)間復(fù)雜度 \(O(n^3)\)

      // 原始版本(未優(yōu)化)
      void matrix_mult(float* A, float* B, float* C, int N) {
          for (int i = 0; i < N; i++)
              for (int j = 0; j < N; j++)
                  for (int k = 0; k < N; k++)
                      C[i*N + j] += A[i*N + k] * B[k*N + j];
      }
      

      在實(shí)際測(cè)試中,此算法跑出了1898.46ms優(yōu)秀成績(jī)

      優(yōu)化一:循環(huán)優(yōu)化

      你可能會(huì)疑惑,循環(huán)優(yōu)化是什么
      故名思義,就是對(duì)原有的ijk的循環(huán)重新更換順序?yàn)閕kj

      一個(gè)更大的問題來了,憑什么僅僅改變順序就快了許多

      那么不妨看看訪問順序

      算法1(樸素實(shí)現(xiàn)):

      在這個(gè)順序中,最內(nèi)層循環(huán)是k,它遍歷A的一行和B的一列。
      對(duì)于A的訪問是連續(xù)的(因?yàn)锳[i][k]在內(nèi)存中是按行存儲(chǔ)的,所以k增加時(shí)是連續(xù)訪問),
      但是B的訪問是不連續(xù)的(因?yàn)锽[k][j]在內(nèi)存中是按行存儲(chǔ),k增加時(shí)訪問的是不同行的同一列,所以是跳躍訪問)。這樣對(duì)B的訪問會(huì)導(dǎo)致緩存失效。

      so,真正影響到速度的,就是緩存,在順序讀取中,緩存可以加載一整行,不必跳躍元素訪問

      那么,有沒有一種循環(huán)順序,使得對(duì)三個(gè)數(shù)組均為順序訪問呢
      有的兄弟,有的,讓我們歡迎仍為\(O(n^{3})\)的優(yōu)化算法出場(chǎng)
      ——“ikj”循環(huán)優(yōu)化

      在這個(gè)順序中,最內(nèi)層循環(huán)是j。對(duì)于A的訪問,固定i和k,所以每次內(nèi)層循環(huán)A[i][k]是常數(shù)。對(duì)于B的訪問,是B[k][j],由于j是連續(xù)的,所以B的訪問是連續(xù)的(因?yàn)橥恍羞B續(xù)列)。同時(shí),C的訪問也是連續(xù)的(C[i][j])。這樣,所有的內(nèi)存訪問都是連續(xù)的,因此性能更好。

      可以自行驗(yàn)證,對(duì)于三個(gè)數(shù)組,均為順序訪問
      給出如下代碼:
      時(shí)間復(fù)雜度 \(O(n^3)\)

      // 優(yōu)化后(行優(yōu)先訪問)
      void matrix_mult_opt1(float* A, float* B, float* C, int N) {
          for (int i = 0; i < N; i++)
              for (int k = 0; k < N; k++)  // k循環(huán)提到中間
                  for (int j = 0; j < N; j++)
                      C[i*N + j] += A[i*N + k] * B[k*N + j];
      }
      

      在實(shí)際測(cè)試中,此算法跑出了1462.89ms的優(yōu)秀成績(jī)
      提升:1898.46ms-1462.89ms = 435.57ms 提高22.9%

      進(jìn)階提升 優(yōu)化二:矩陣分塊算法

      顧名思義,分塊算法即是把矩陣分為多個(gè)小矩陣,對(duì)每個(gè)矩陣操作后再組合出結(jié)果,類似分塊算法

      那么,它為什么快呢

      在計(jì)算機(jī)中,共有三種CPU緩存以及普通內(nèi)存,即L1、L2、L3 Cache和內(nèi)存
      前三種的速度要比內(nèi)存快很多很多,大概只有一個(gè)CPU周期的延遲,而普通內(nèi)存可以達(dá)到上百周期延遲
      而比他們更快的就是寄存器,直接接觸CPU,0延遲
      唯一的問題是,寄存器的大小只夠存儲(chǔ)單個(gè)值

      新的問題來了,怎樣把一個(gè)巨大的矩陣放到只有128k-1m的L1緩存中呢

      聰明的你一定想到了把原矩陣劃分為多個(gè)小矩陣,每一個(gè)都能放到L1內(nèi)進(jìn)行運(yùn)算
      那么恭喜你,你已經(jīng)知道了矩陣分塊算法的原理

      更確切的說,先定義一個(gè)常數(shù) \(blocks \in N^{+}\) 作為劃分的單位矩陣的行列,
      對(duì)于原矩陣A,我們把其中的 \({blocks \times blocks}\) 個(gè)元素劃分為一個(gè)新矩陣,記為\(a^{'}_{\cdots}\),我們定義:

      \[ a^{'}_{11} = \begin{pmatrix} a_{11} & \cdots & a_{1blocks} \\ \vdots & \ddots &\vdots \\ a_{blocks1} & \cdots & a_{blocks_{ }blocks} \end{pmatrix} \]

      其余以此類推,新矩陣\(A^{'}\)即變?yōu)?/p>

      \[A^{'}=\begin{pmatrix} a^{'}_{11} & \cdots & a^{'}_{1n} \\ \vdots & \ddots & \vdots \\ a^{'}_{m1} & \cdots & a^{'}_{mn} \end{pmatrix} \]

      正如同我們可以把 \(f(x)\) 中的 \(x\) 替換為任意多項(xiàng)式(函數(shù)),我們同樣也可以把矩陣中的每個(gè)元素?fù)Q為一個(gè)矩陣,其運(yùn)算規(guī)則仍然成立


      so,顯而易見的,分塊矩陣算法有如下公式:

      對(duì)于整體而言,

      \[\begin{array}{c} C={A^{'}\times B^{'}}=\left[ c_{ij}\right]_{m \times s} = \left[ \sum \limits_{k=1}^{n}a^{'}_{ik}b^{'}_{kj}\right]_{m \times s} \end{array}\]

      其中:

      \[ a^{'}_{ij} = \begin{pmatrix} a_{[(i-1) \times blocks]{[(j-1)\times blocks]}} & \cdots & a_{[(i-1) \times blocks]{[j\times blocks]}} \\ \vdots & \ddots &\vdots \\ a_{[i\times blocks]{[(j-1)\times blocks]}} & \cdots & a_{[i\times blocks]{[j\times blocks]}} \end{pmatrix} \]

      其中的每個(gè)乘法 \(a'{ik} \times b'{kj}\) 是子矩陣乘法
      這里為了方便看,默認(rèn)原矩陣的行列均為 blocks 的倍數(shù)
      那么下一個(gè)很自然的問題就是

      若行列不為blocks的倍數(shù),怎么辦

      分兩種情況:

      1. $ min(m,n) < blocks $
      2. $ \exists m,n \nmid blocks $

      對(duì)于1,直接執(zhí)行普通矩陣乘法即可,因?yàn)檎麄€(gè)矩陣均可放于L1、L2 Cache中
      對(duì)于2,我們定義分塊矩陣的大小為$ p,q $

      \[p = min(blocks,M- i \times blocks) \]

      \[q = min(blocks,N - j \times blocks) \]

      其中,

      \[M,N為被分塊矩陣的行列 \]

      \[i,j為分塊矩陣a^{'}_{ij}的下標(biāo) \]

      至此,分塊矩陣的全部問題已經(jīng)解決

      給出如下代碼:
      時(shí)間復(fù)雜度\(O(n^3)\)

      void block_mult(float* A, float* B, float* C, int N, int BLOCK) {
          // 清除結(jié)果矩陣
          memset(C, 0, N*N*sizeof(float));
          // 三層分塊循環(huán)
          for (int i0 = 0; i0 < N; i0 += BLOCK) {
              int i_end = min(i0 + BLOCK, N);  // 計(jì)算行邊界
              for (int k0 = 0; k0 < N; k0 += BLOCK) {
                  int k_end = min(k0 + BLOCK, N);  // 計(jì)算中間維度邊界
                  for (int j0 = 0; j0 < N; j0 += BLOCK) {
                      int j_end = min(j0 + BLOCK, N);  // 計(jì)算列邊界
                      // 核心計(jì)算:只處理完整塊內(nèi)的元素
                      for (int i = i0; i < i_end; i++) {
                          for (int k = k0; k < k_end; k++) {
                              float a_val = A[i*N + k];  // 一次加載A元素
                              // 內(nèi)層循環(huán):連續(xù)訪問B和C
                              for (int j = j0; j < j_end; j++) {
                                  C[i*N + j] += a_val * B[k*N + j];
                              }
                          }
                      }
                  }
              }
          }
      }
      

      由于矩陣過小時(shí),分塊算法優(yōu)勢(shì)不大,且會(huì)增加調(diào)用開銷,因此,這里的測(cè)試,\(m,n\)為2048

      實(shí)測(cè)結(jié)果:\(blocks = 512\) 時(shí),用時(shí) 11515.2ms

      而不使用分塊僅循環(huán)優(yōu)化的算法 用時(shí) 16028.9ms
      樸素實(shí)現(xiàn) 用時(shí) 18053.45ms
      提升:16028.9ms-11515.2ms = \(4513.2ms\) 提高:\(28.1\%\)

      高手過招 優(yōu)化三 :并行與SIMD

      何為并行與SIMD
      并行:多線程同時(shí)處理多個(gè)分塊
      SIMD:乘加一體,即一條CPU指令同時(shí)處理乘與加
      我們這里使用Apple的AMX(Apple Matrix協(xié)處理器)(也屬于CPU的一部分,并非GPU優(yōu)化)
      對(duì)于x86架構(gòu)和其余ARM架構(gòu)的處理器,可以使用AVX、AVX_512、SSE等SIMP指令集

      它的特性有:

      • Apple Silicon芯片(M1/M2/M3等)內(nèi)置的專用矩陣運(yùn)算單元
      • 可并行處理大量16位浮點(diǎn)(FP16)或整數(shù)(INT8)運(yùn)算

      每個(gè)AMX單元包含:

      • 8個(gè)32KB的寄存器文件
      • 可同時(shí)執(zhí)行2048次乘加運(yùn)算(16x16x8矩陣)
      • 專用數(shù)據(jù)通路減少內(nèi)存訪問延遲

      以及自動(dòng)分塊
      如下是使用AMX的SIMP的代碼:
      時(shí)間復(fù)雜度:\(O(n^3)\)

      // 使用Apple的AMX加速BLAS庫
                  cblas_sgemm(CblasRowMajor,   // 行主序存儲(chǔ)
                              CblasNoTrans,   // 不轉(zhuǎn)置A
                              CblasNoTrans,   // 不轉(zhuǎn)置B
                              M,              // A的行數(shù)
                              N,              // B的列數(shù)
                              K,              // 公共維度
                              1.0f,           // alpha系數(shù)
                              a_data,         // A數(shù)據(jù)指針
                              K,              // A的列步幅(lda)
                              b_data,         // B數(shù)據(jù)指針
                              N,              // B的列步幅(ldb)
                              0.0f,           // beta系數(shù)
                              r_data,         // 結(jié)果數(shù)據(jù)指針
                              N);             // 結(jié)果的列步幅(ldc)
      

      那么,本次優(yōu)化最嚇人、最恐怖的一次數(shù)據(jù)來了:
      實(shí)測(cè)數(shù)據(jù):202.831ms(4096*4096)
      而標(biāo)準(zhǔn)分塊+循環(huán)優(yōu)化算法,在2048*2048時(shí),就已經(jīng)11515.2ms
      提升:11312.37 ms 提高:98.2%

      數(shù)學(xué)手段 優(yōu)化4:Strassen算法 & 變種

      溫馨提示:到這里已經(jīng)是高等數(shù)學(xué)內(nèi)容了(實(shí)不相瞞,前面其實(shí)也是),有點(diǎn)小燒腦,不過歡迎各位繼續(xù)跟作者一起嘗試,本作者大約花了3小時(shí)搞完這一部分的證明

      介紹:

      Strassen算法是一種通過數(shù)學(xué)變換減少乘法次數(shù)的高效矩陣乘/卷積算法

      簡(jiǎn)要推導(dǎo):

      我們?cè)O(shè)有如下兩個(gè)矩陣相乘:

      \[\begin{pmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \end{pmatrix} \times \begin{pmatrix} b_{11} & b_{12} \\ b_{21} & b_{22} \end{pmatrix} = \begin{pmatrix} c_{11} & c_{12} \\ c_{21} &c_{22} \end{pmatrix} \]

      傳統(tǒng)計(jì)算需要8次乘法:

      \[c_{11} = a_{11}\times b_{11} + a_{12}\times b_{21} \\ \]

      \[c_{12} = a_{11}\times b_{12} + a_{12}\times b_{22} \\ \]

      \[c_{21} = a_{21}\times b_{11} + a_{22}\times b_{21} \\ \]

      \[c_{22} = a_{21}\times b_{12} + a_{22}\times b_{22} \\ \]

      而Strassen算法只需7次乘
      接下來,是Strassen算法最精妙絕倫的一步:
      作者定義了7個(gè)矩陣:

      \[\begin{align*} M_1 &= (a_{11} + a_{22})(b_{11} + b_{22}) \\ M_2 &= (a_{21} + a_{22})b_{11} \\ M_3 &= a_{11}(b_{12} - b_{22}) \\ M_4 &= a_{22}(b_{21} - b_{11}) \\ M_5 &= (a_{11} + a_{12})b_{22} \\ M_6 &= (a_{21} - a_{11})(b_{11} + b_{12}) \\ M_7 &= (a_{12} - a_{22})(b_{21} + b_{22}) \end{align*} \]

      真正讓人驚訝的是下一步:
      作者構(gòu)建的7個(gè)矩陣,可以通過有限次的組合成為結(jié)果矩陣的一個(gè)元素
      什么意思呢,讓我們嘗試展開其中一項(xiàng):

      \[c_{11} = M_1 + M_4 - M_5 + M_7 \\ \]

      \[\begin{align*} &=(a_{11} + a_{22})(b_{11} + b_{22}) + a_{22}(b_{21} - b_{11}) - (a_{11} + a_{12})b_{22} + (a_{12} - a_{22})(b_{21} + b_{22}) \\ \end{align*}\]

      \[\begin{align*} = & \ \ a_{11}b_{11} + \cancel{a_{11}b_{22}} + \cancel{a_{22}b_{11}} + \cancel{a_{22}b_{22}} \\ &+ \ \cancel{a_{22}b_{21}} - \cancel{a_{22}b_{11}} \\ &- \ \cancel{a_{11}b_{22}} - a_{12}b_{22} \\ &+ \ a_{12}b_{21} + \cancel{a_{12}b_{22}} - \cancel{a_{22}b_{21}} - \cancel{a_{22}b_{22}} \\ = & \ \ a_{11}b_{11} + a_{12}b_{21} \end{align*}\]

      由此,可以類似的推出\(c_{12}、c_{21}、c_{22}\)均與正常計(jì)算一致
      最后的結(jié)果矩陣為

      \[C= \begin{pmatrix} M_{1}+M_{4}-M_{5}+M_{7} & M_{3}+M_{5} \\ M_{2}+M_{4} & M_{1}-M_{2}+M_{3}+M_{6} \end{pmatrix} \]

      接下來,我們證明其對(duì)于n>2時(shí)仍然成立:
      由于( n=2 )時(shí),我們已經(jīng)證明其正確
      所以,在n>2時(shí),我們采取分塊
      將原矩陣分為4塊,此時(shí)我們將其中的子矩陣看為一個(gè)元素
      那么此時(shí)又回歸到了標(biāo)準(zhǔn)的2x2的Strassen算法
      由此在$n,m \mid 2 $時(shí)Strassen算法正確
      那么,下一個(gè)很自然的問題就是

      \(n,m \nmid {2} ,結(jié)論是否成立\)

      我們的做法是,將矩陣分塊,分為幾個(gè)\(2^k \times 2^k\)的子矩陣以及幾個(gè)符合矩陣乘規(guī)則的小矩陣
      顯然由于前文的分塊算法的正確性,此時(shí)的分塊仍然正確,對(duì)于幾個(gè)\(2^k \times 2^k\)的矩陣,我們使用Strassen算法進(jìn)行計(jì)算
      現(xiàn)在來計(jì)算一下Strassen算法的時(shí)間復(fù)雜度:
      設(shè)$ T(n) $ 為計(jì)算 $ n \times n $ 矩陣乘法的時(shí)間:
      \(T(n) = 7T\left(\frac{n}{2}\right) + O(n^2)\)

      • $ 7T(n/2) $:7 個(gè)子問題遞歸計(jì)算
      • $O(n^2) $:矩陣加減法開銷(共 18 次加減法)

      根據(jù)主定理1,
      $ a = 7,b = 2,f(n)=\Theta(n^2) $
      \(log _b a = log_27 \approx 2.807\)
      由于\(log _b a > 2\),所以\(f(n) = O({n^{log_b a-\epsilon})} = O(n^{log_27}) \approx O(n^{2.807})\)
      更精確的,復(fù)雜度為\(\Theta(n^{2.807})\)

      具體的算法為:
      1.先將矩陣AB分塊,分成大小為 \({blocks \times blocks}\) 的若干塊以及幾個(gè)任意大小的子塊
      2.對(duì)于 \({blocks \times blocks}\) 的子塊,我們使用Strassen算法計(jì)算
      3.在遞歸過程中,若方陣大小(因?yàn)镾trassen算法開始時(shí)為方陣)n = 128,則使用循環(huán)優(yōu)化的矩陣乘法
      4.否則,繼續(xù)按照Strassen算法遞歸計(jì)算直至n = 128
      5.對(duì)于不是 \({blocks \times blocks}\) 的子塊,使用循環(huán)優(yōu)化的矩陣乘法計(jì)算
      給出如下代碼:
      近似時(shí)間復(fù)雜度\(O(n^{2.81})\)

      const int BLOCK_SIZE = 2048;
      const int STRASSEN_THRESHOLD = 128;
      
      // 標(biāo)準(zhǔn)矩陣乘法 (用于小矩陣和邊界處理)
      void standard_matmul(const float* A, const float* B, float* C, int n, int m, int p, int lda, int ldb, int ldc) {
          for (int i = 0; i < n; ++i) {
              for (int k = 0; k < m; ++k) {
                  float a = A[i * lda + k];
                  for (int j = 0; j < p; ++j) {
                      C[i * ldc + j] += a * B[k * ldb + j];
                  }
              }
          }
      }
      
      // 循環(huán)優(yōu)化矩陣乘法 (n=128時(shí)使用)
      void optimized_matmul(const float* A, const float* B, float* C, int n, int lda, int ldb, int ldc) {
          for (int i = 0; i < n; ++i) {
              for (int k = 0; k < n; ++k) {
                  float a = A[i * lda + k];
                  for (int j = 0; j < n; ++j) {
                      C[i * ldc + j] += a * B[k * ldb + j];
                  }
              }
          }
      }
      
      // 矩陣加法
      void matrix_add(const float* A, const float* B, float* C, int n, int lda, int ldb, int ldc) {
          for (int i = 0; i < n; ++i) {
              for (int j = 0; j < n; ++j) {
                  C[i * ldc + j] = A[i * lda + j] + B[i * ldb + j];
              }
          }
      }
      
      // 矩陣減法
      void matrix_subtract(const float* A, const float* B, float* C, int n, int lda, int ldb, int ldc) {
          for (int i = 0; i < n; ++i) {
              for (int j = 0; j < n; ++j) {
                  C[i * ldc + j] = A[i * lda + j] - B[i * ldb + j];
              }
          }
      }
      
      // 結(jié)果累加到目標(biāo)矩陣
      void matrix_add_to_target(float* T, const float* S, int n, int ldt, int lds) {
          for (int i = 0; i < n; ++i) {
              for (int j = 0; j < n; ++j) {
                  T[i * ldt + j] += S[i * lds + j];
              }
          }
      }
      
      // Strassen 矩陣乘法 (遞歸實(shí)現(xiàn))
      void strassen_matmul(const float* A, const float* B, float* C, int n, int lda, int ldb, int ldc) {
          // 遞歸基: n <= 128 使用優(yōu)化乘法
          if (n <= STRASSEN_THRESHOLD) {
              optimized_matmul(A, B, C, n, lda, ldb, ldc);
              return;
          }
      
          int half = n / 2;
          // 定義子矩陣指針
          const float* A11 = A;
          const float* A12 = A + half;
          const float* A21 = A + half * lda;
          const float* A22 = A + half * lda + half;
          
          const float* B11 = B;
          const float* B12 = B + half;
          const float* B21 = B + half * ldb;
          const float* B22 = B + half * ldb + half;
          
          float* C11 = C;
          float* C12 = C + half;
          float* C21 = C + half * ldc;
          float* C22 = C + half * ldc + half;
      
          // 分配臨時(shí)矩陣
          std::vector<float> S1(half * half);
          std::vector<float> S2(half * half);
          std::vector<float> S3(half * half);
          std::vector<float> S4(half * half);
          std::vector<float> S5(half * half);
          std::vector<float> S6(half * half);
          std::vector<float> S7(half * half);
          std::vector<float> S8(half * half);
          std::vector<float> S9(half * half);
          std::vector<float> S10(half * half);
          
          std::vector<float> P1(half * half);
          std::vector<float> P2(half * half);
          std::vector<float> P3(half * half);
          std::vector<float> P4(half * half);
          std::vector<float> P5(half * half);
          std::vector<float> P6(half * half);
          std::vector<float> P7(half * half);
      
          // 計(jì)算S矩陣
          matrix_subtract(B12, B22, S1.data(), half, ldb, ldb, half);    // S1 = B12 - B22
          matrix_add(A11, A12, S2.data(), half, lda, lda, half);         // S2 = A11 + A12
          matrix_add(A21, A22, S3.data(), half, lda, lda, half);         // S3 = A21 + A22
          matrix_subtract(B21, B11, S4.data(), half, ldb, ldb, half);    // S4 = B21 - B11
          matrix_add(A11, A22, S5.data(), half, lda, lda, half);         // S5 = A11 + A22
          matrix_add(B11, B22, S6.data(), half, ldb, ldb, half);         // S6 = B11 + B22
          matrix_subtract(A12, A22, S7.data(), half, lda, lda, half);    // S7 = A12 - A22
          matrix_add(B21, B22, S8.data(), half, ldb, ldb, half);         // S8 = B21 + B22
          matrix_subtract(A11, A21, S9.data(), half, lda, lda, half);    // S9 = A11 - A21
          matrix_add(B11, B12, S10.data(), half, ldb, ldb, half);        // S10 = B11 + B12
      
          // 遞歸計(jì)算P矩陣
          strassen_matmul(A11, S1.data(), P1.data(), half, lda, half, half);      // P1 = A11 * S1
          strassen_matmul(S2.data(), B22, P2.data(), half, half, ldb, half);      // P2 = S2 * B22
          strassen_matmul(S3.data(), B11, P3.data(), half, half, ldb, half);      // P3 = S3 * B11
          strassen_matmul(A22, S4.data(), P4.data(), half, lda, half, half);      // P4 = A22 * S4
          strassen_matmul(S5.data(), S6.data(), P5.data(), half, half, half, half); // P5 = S5 * S6
          strassen_matmul(S7.data(), S8.data(), P6.data(), half, half, half, half); // P6 = S7 * S8
          strassen_matmul(S9.data(), S10.data(), P7.data(), half, half, half, half);// P7 = S9 * S10
      
          // 組合結(jié)果矩陣 (累加到C)
          // C11 = P5 + P4 - P2 + P6
          matrix_add_to_target(C11, P5.data(), half, ldc, half);
          matrix_add_to_target(C11, P4.data(), half, ldc, half);
          matrix_add_to_target(C11, P6.data(), half, ldc, half);
          for (int i = 0; i < half; ++i) {
              for (int j = 0; j < half; ++j) {
                  C11[i * ldc + j] -= P2[i * half + j];
              }
          }
          
          // C12 = P1 + P2
          matrix_add_to_target(C12, P1.data(), half, ldc, half);
          matrix_add_to_target(C12, P2.data(), half, ldc, half);
          
          // C21 = P3 + P4
          matrix_add_to_target(C21, P3.data(), half, ldc, half);
          matrix_add_to_target(C21, P4.data(), half, ldc, half);
          
          // C22 = P5 + P1 - P3 - P7
          matrix_add_to_target(C22, P5.data(), half, ldc, half);
          matrix_add_to_target(C22, P1.data(), half, ldc, half);
          for (int i = 0; i < half; ++i) {
              for (int j = 0; j < half; ++j) {
                  C22[i * ldc + j] -= (P3[i * half + j] + P7[i * half + j]);
              }
          }
      }
      
      // 分塊矩陣乘法
      void matrix_multiply(const float* A, const float* B, float* C, int n, int m, int p, int lda, int ldb, int ldc) {
          // 初始化輸出矩陣為0
          std::memset(C, 0, n * ldc * sizeof(float));
          
          // 分塊處理
          for (int i = 0; i < n; i += BLOCK_SIZE) {
              int i_end = std::min(i + BLOCK_SIZE, n);
              int i_size = i_end - i;
              
              for (int k = 0; k < m; k += BLOCK_SIZE) {
                  int k_end = std::min(k + BLOCK_SIZE, m);
                  int k_size = k_end - k;
                  
                  for (int j = 0; j < p; j += BLOCK_SIZE) {
                      int j_end = std::min(j + BLOCK_SIZE, p);
                      int j_size = j_end - j;
                      
                      // 當(dāng)前塊指針
                      const float* A_block = A + i * lda + k;
                      const float* B_block = B + k * ldb + j;
                      float* C_block = C + i * ldc + j;
                      
                      // 完整塊使用Strassen算法
                      if (i_size == BLOCK_SIZE && k_size == BLOCK_SIZE && j_size == BLOCK_SIZE) {
                          strassen_matmul(A_block, B_block, C_block, BLOCK_SIZE, lda, ldb, ldc);
                      } 
                      // 非完整塊使用標(biāo)準(zhǔn)乘法
                      else {
                          standard_matmul(A_block, B_block, C_block, i_size, k_size, j_size, lda, ldb, ldc);
                      }
                  }
              }
          }
      }
      }
      

      當(dāng)然,實(shí)際測(cè)試中,我們使用Ctorch框架的Tensor類Op::Add,與此代碼會(huì)略有差異
      最終測(cè)試結(jié)果:1005.65ms
      P.S.提升不明顯的原因是矩陣過小,如果使用Tranformer架構(gòu)的巨型矩陣測(cè)試,\(O(n^{2.81})\)的優(yōu)勢(shì)會(huì)非常明顯

      最終的方案:

      我們使用多函數(shù)策略:
      1.若dim<128 此時(shí)的拷貝開銷已經(jīng)大于AMX的優(yōu)化,因此使用循環(huán)優(yōu)化
      2.128<dim<4096 AMX的最優(yōu)區(qū)間,使用純AMX
      3.dim>4096 分塊,對(duì)于能夠分為\(2^k\)的塊,使用Strassen算法,遞歸到2048使用AMX
      對(duì)于不是\(2^k\)的塊,直接使用AMX計(jì)算,同時(shí),每個(gè)分塊使用單一線程

      最后的結(jié)果:

      測(cè)試矩陣:16384163842(\(2^{14}\)
      標(biāo)準(zhǔn):30251ms
      循環(huán)優(yōu)化:24580ms
      分塊:20498ms
      最終方案(SIMP+多線程+分塊+Strassen):8267.97ms

      最后

      如果你希望既追求高性能又追求簡(jiǎn)潔的框架,那么Ctorch將是你的最優(yōu)選擇
      盡管這個(gè)項(xiàng)目還在開發(fā)中,但是可以先小小的期待一下
      歡迎貢獻(xiàn),如有錯(cuò)誤,請(qǐng)各位不吝賜教,謝謝
      2025.8.2

      posted @ 2025-08-02 21:48  Ghost-Face  閱讀(159)  評(píng)論(0)    收藏  舉報(bào)
      主站蜘蛛池模板: 亚洲VA成无码人在线观看天堂| 一本大道av人久久综合| 亚洲午夜精品国产电影在线观看| 国产睡熟迷奷系列网站| 国产精品成人一区二区不卡| 又黄又爽又色的少妇毛片| 东京热高清无码精品| 亚洲精品无码人妻无码| 国产不卡在线一区二区| 野花香视频在线观看免费高清版| 一级女性全黄久久生活片| 人成午夜免费大片| 亚洲精品麻豆一二三区| 日韩av一中美av一中文字慕| 日本久久香蕉一本一道| 日韩精品久久久肉伦网站| 天堂av色综合久久天堂| 97精品人妻系列无码人妻| 精品国产中文字幕av| 亚洲综合色婷婷中文字幕| 亚洲午夜成人精品电影在线观看 | 大陆精大陆国产国语精品| 亚洲性av网站| 亚洲日韩国产二区无码| 九九热在线免费精品视频| 97久久精品人人做人人爽| 亚洲成A人片在线观看的电影| 丰满人妻一区二区三区无码AV| 精精国产XXX在线观看| 封开县| 日韩一区二区三区av在线| 亚洲精品一区二区18禁| 精品久久精品午夜精品久久| 欧美一区二区三区欧美日韩亚洲| 老熟妇国产一区二区三区| 国产精品18久久久久久麻辣| 一区二区三区不卡国产| аⅴ天堂国产最新版在线中文| 无码欧亚熟妇人妻AV在线外遇 | 人妻少妇无码精品专区| 日本伊人色综合网|