📜  Divide and Conquer | Set 5 (Strassen's Matrix Multiplication)

📅  Last modified: 2023-12-03 15:36:50.703000             🧑  Author: Mango

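Divide and Conquer | Strassen's Matrix Multiplication

Introduction

Strassen's matrix multiplication is a fast matrix multiplication algorithm based on the divide and conquer strategy. For two $n \times n$ matrices it runs in $O(n^{\log_2 7}) \approx O(n^{2.81})$ time, compared with $O(n^3)$ for the standard algorithm.

The core idea is to split each of the two large matrices into four blocks and to obtain the product from seven block multiplications instead of the usual eight, at the cost of some extra block additions and subtractions.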
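
As a brief sketch of where the speed-up comes from, assume $n$ is a power of two and that adding or subtracting two $n/2 \times n/2$ blocks costs $\Theta(n^2)$. The two approaches then satisfy the following recurrences (the symbol $T$ is introduced here only for illustration):

$$ \begin{aligned} T_{\text{standard}}(n) &= 8\,T(n/2) + \Theta(n^2) \;\Rightarrow\; T_{\text{standard}}(n) = \Theta(n^3) \\ T_{\text{Strassen}}(n) &= 7\,T(n/2) + \Theta(n^2) \;\Rightarrow\; T_{\text{Strassen}}(n) = \Theta(n^{\log_2 7}) \end{aligned} $$

Saving one of the eight recursive multiplications at every level is what lowers the exponent from $3$ to $\log_2 7$.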

Algorithm Steps
  1. First, split each of the two $n \times n$ matrices $A$ and $B$ into four equally sized $n/2 \times n/2$ submatrices:

    $$ A = \begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix}, \quad B = \begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix} $$

  2. Next, compute the seven matrices $M_1$ to $M_7$, each of which needs only one block multiplication:

    $$ \begin{aligned} M_1 &= (A_{11}+A_{22})(B_{11}+B_{22}) \\ M_2 &= (A_{21}+A_{22})B_{11} \\ M_3 &= A_{11}(B_{12}-B_{22}) \\ M_4 &= A_{22}(B_{21}-B_{11}) \\ M_5 &= (A_{11}+A_{12})B_{22} \\ M_6 &= (A_{21}-A_{11})(B_{11}+B_{12}) \\ M_7 &= (A_{12}-A_{22})(B_{21}+B_{22}) \end{aligned} $$

  3. Then compute the four submatrices $C_{11}$, $C_{12}$, $C_{21}$, $C_{22}$ of the result matrix $C$:

    $$ \begin{aligned} C_{11} &= M_1 + M_4 - M_5 + M_7 \\ C_{12} &= M_3 + M_5 \\ C_{21} &= M_2 + M_4 \\ C_{22} &= M_1 - M_2 + M_3 + M_6 \end{aligned} $$

  4. Finally, combine the submatrices $C_{11}$, $C_{12}$, $C_{21}$, $C_{22}$ into the final matrix $C = AB$.
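
The formulas in steps 2 and 3 can be sanity-checked on the smallest interesting case, a 2 x 2 product in which every block is a single number. The sketch below is only an illustration; the function name `strassen_2x2` is made up for this example.

```python
def strassen_2x2(a, b):
    """Multiply two 2 x 2 matrices using Strassen's seven products."""
    (a11, a12), (a21, a22) = a
    (b11, b12), (b21, b22) = b

    # Step 2: seven products, one multiplication each
    m1 = (a11 + a22) * (b11 + b22)
    m2 = (a21 + a22) * b11
    m3 = a11 * (b12 - b22)
    m4 = a22 * (b21 - b11)
    m5 = (a11 + a12) * b22
    m6 = (a21 - a11) * (b11 + b12)
    m7 = (a12 - a22) * (b21 + b22)

    # Step 3: combine the products into the four entries of C
    c11 = m1 + m4 - m5 + m7
    c12 = m3 + m5
    c21 = m2 + m4
    c22 = m1 - m2 + m3 + m6
    return [[c11, c12], [c21, c22]]


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
# Reference result from the ordinary row-by-column definition
expected = [
    [1 * 5 + 2 * 7, 1 * 6 + 2 * 8],
    [3 * 5 + 4 * 7, 3 * 6 + 4 * 8],
]
print(strassen_2x2(a, b) == expected)  # True
```

Only seven multiplications appear in the function body, whereas the reference computation uses eight.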

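Code Implementation

Below is a Python implementation of Strassen's matrix multiplication. It works on square matrices stored as lists of lists and assumes that the size n is a power of 2: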

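def strassen(A, B):
    # A and B are n x n lists of lists; n is assumed to be a power of 2
    n = len(A)
    C = [[0] * n for _ in range(n)]

    # Base case: 1 x 1 matrices
    if n == 1:
        C[0][0] = A[0][0] * B[0][0]
        return C

    # Split A and B into four mid x mid quadrants each
    # (rows are selected first, then columns within each row)
    mid = n // 2
    A11 = [row[:mid] for row in A[:mid]]
    A12 = [row[mid:] for row in A[:mid]]
    A21 = [row[:mid] for row in A[mid:]]
    A22 = [row[mid:] for row in A[mid:]]
    B11 = [row[:mid] for row in B[:mid]]
    B12 = [row[mid:] for row in B[:mid]]
    B21 = [row[:mid] for row in B[mid:]]
    B22 = [row[mid:] for row in B[mid:]]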

    # Compute the seven products M1 to M7 recursively
    M1 = strassen(add(A11, A22), add(B11, B22))
    M2 = strassen(add(A21, A22), B11)
    M3 = strassen(A11, sub(B12, B22))
    M4 = strassen(A22, sub(B21, B11))
    M5 = strassen(add(A11, A12), B22)
    M6 = strassen(sub(A21, A11), add(B11, B12))
    M7 = strassen(sub(A12, A22), add(B21, B22))

    # Compute the four quadrants of C
    C11 = add(sub(add(M1, M4), M5), M7)
    C12 = add(M3, M5)
    C21 = add(M2, M4)
    C22 = add(sub(add(M1, M3), M2), M6)

    # Copy the four quadrants into the result matrix C
    for i in range(mid):
        for j in range(mid):
            C[i][j] = C11[i][j]
            C[i][j+mid] = C12[i][j]
            C[i+mid][j] = C21[i][j]
            C[i+mid][j+mid] = C22[i][j]

    return C

# Helper: element-wise matrix addition
def add(A, B):
    return [[A[i][j] + B[i][j] for j in range(len(A[0]))] for i in range(len(A))]

# Helper: element-wise matrix subtraction
def sub(A, B):
    return [[A[i][j] - B[i][j] for j in range(len(A[0]))] for i in range(len(A))]
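
Because `strassen` halves the matrix at every level, it needs a size that is a power of 2. One common workaround, sketched here under that assumption, is to pad the inputs with zeros and crop the result. All helper names below (`naive_matmul`, `next_power_of_two`, `pad_square`, `multiply_any_size`) are hypothetical and not part of the listing above; `naive_matmul` stands in for `strassen` so that the snippet runs on its own.

```python
def naive_matmul(A, B):
    """Reference O(n^3) product of two square matrices (lists of lists)."""
    n = len(A)
    return [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


def next_power_of_two(n):
    """Smallest power of two that is >= n (for n >= 1)."""
    p = 1
    while p < n:
        p *= 2
    return p


def pad_square(M, size):
    """Copy M into the top-left corner of a size x size zero matrix."""
    n = len(M)
    return [[M[i][j] if i < n and j < n else 0 for j in range(size)]
            for i in range(size)]


def multiply_any_size(A, B, multiply):
    """Pad A and B to a power-of-two size, call multiply, crop the result."""
    n = len(A)
    size = next_power_of_two(n)
    C = multiply(pad_square(A, size), pad_square(B, size))
    # The zero padding contributes nothing, so the top-left n x n block is A*B
    return [row[:n] for row in C[:n]]


A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
B = [[9, 8, 7], [6, 5, 4], [3, 2, 1]]
print(multiply_any_size(A, B, naive_matmul) == naive_matmul(A, B))  # True
```

Passing `strassen` as the `multiply` argument would route the padded 4 x 4 matrices through the recursive algorithm instead. Padding can nearly double the size in the worst case, so it is a convenience rather than an optimization.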

Summary

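Strassen's matrix multiplication uses divide and conquer to replace eight block multiplications with seven, which reduces the amount of work and makes multiplying large matrices more efficient. It also has limitations: the extra additions and the recursion overhead mean it only pays off for large matrices, the implementation shown here requires the size to be a power of 2, and with floating-point arithmetic it tends to accumulate more rounding error than the standard algorithm.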