Divide and Conquer | Set 5 (Strassen's Matrix Multiplication)

Last modified: 2021-09-16 11:12:21             Author: Mango

Given two square matrices A and B, each of size n x n, find their product matrix.
Naive Method
The following is a simple way to multiply two matrices.

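C
/* N is the matrix dimension; the value 4 is only an example. */
#define N 4

/* Computes C = A * B for N x N integer matrices. */
void multiply(int A[][N], int B[][N], int C[][N])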
{
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
        {
            C[i][j] = 0;
            for (int k = 0; k < N; k++)
            {
                C[i][j] += A[i][k]*B[k][j];
            }
        }
    }
}


The time complexity of the above method is O(N^3).
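As a cross-check of the O(N^3) bound, here is a minimal Python sketch of the same triple loop that also counts scalar multiplications. The name `naive_multiply` and the counter are illustrative assumptions, not part of the original code.

```python
def naive_multiply(A, B):
    """Multiply two n x n matrices given as lists of lists.

    Returns (C, mults), where mults is the number of scalar
    multiplications performed (an illustrative counter).
    """
    n = len(A)
    C = [[0] * n for _ in range(n)]
    mults = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                C[i][j] += A[i][k] * B[k][j]
                mults += 1
    return C, mults


if __name__ == "__main__":
    A = [[1, 2], [3, 4]]
    B = [[5, 6], [7, 8]]
    C, mults = naive_multiply(A, B)
    print(C)      # [[19, 22], [43, 50]]
    print(mults)  # 8, which is 2**3
```

The innermost statement runs once for every (i, j, k) triple, so the counter always ends at n * n * n.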

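Divide and Conquer
The following is a simple divide and conquer method to multiply two square matrices.
1) Divide matrices A and B into 4 submatrices of size N/2 x N/2, as shown below.
2) Recursively compute the following values: ae + bg, af + bh, ce + dg and cf + dh.

A = | a  b |    B = | e  f |    C = A x B = | ae + bg   af + bh |
    | c  d |        | g  h |                | ce + dg   cf + dh |

Here a, b, c, d, e, f, g and h are submatrices of size N/2 x N/2.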
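The two steps above can be sketched in Python with NumPy as follows. This is only an illustration under the assumption that N is a power of 2; the name `dc_multiply` is hypothetical and does not come from the original article.

```python
import numpy as np


def dc_multiply(x, y):
    """Plain divide and conquer product with 8 recursive calls.

    Assumes x and y are n x n NumPy arrays and n is a power of 2.
    """
    n = x.shape[0]
    if n == 1:
        return x * y  # 1x1 base case

    m = n // 2
    # Quadrants of x (a, b, c, d) and of y (e, f, g, h)
    a, b, c, d = x[:m, :m], x[:m, m:], x[m:, :m], x[m:, m:]
    e, f, g, h = y[:m, :m], y[:m, m:], y[m:, :m], y[m:, m:]

    # 8 recursive multiplications and 4 additions
    c11 = dc_multiply(a, e) + dc_multiply(b, g)
    c12 = dc_multiply(a, f) + dc_multiply(b, h)
    c21 = dc_multiply(c, e) + dc_multiply(d, g)
    c22 = dc_multiply(c, f) + dc_multiply(d, h)

    # Reassemble the four quadrants into one matrix
    return np.vstack((np.hstack((c11, c12)), np.hstack((c21, c22))))


if __name__ == "__main__":
    x = np.arange(16).reshape(4, 4)
    y = np.eye(4, dtype=int)
    print(np.array_equal(dc_multiply(x, y), x))  # multiplying by the identity
```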

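In the above method, we do 8 multiplications of matrices of size N/2 x N/2 and 4 additions. Adding two matrices takes O(N^2) time. So the time complexity can be written as

T(N) = 8T(N/2) + O(N^2)

From the Master Theorem (a = 8, b = 2, so N^(log2 8) = N^3 dominates the O(N^2) term), the time complexity of the above method is O(N^3),
which is unfortunately the same as the naive method above.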

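Simple divide and conquer also leads to O(N^3). Can there be a better way?
In the divide and conquer method above, the main contributor to the high time complexity is the 8 recursive calls. The idea of Strassen's method is to reduce the number of recursive calls to 7. Strassen's method is similar to the simple divide and conquer method in that it also divides the matrices into submatrices of size N/2 x N/2 as described above, but in Strassen's method the four submatrices of the result are computed using the following formulas.

p1 = a(f - h)          p2 = (a + b)h
p3 = (c + d)e          p4 = d(g - e)
p5 = (a + d)(e + h)    p6 = (b - d)(g + h)
p7 = (a - c)(e + f)

C = | p5 + p4 - p2 + p6    p1 + p2           |
    | p3 + p4              p1 + p5 - p3 - p7 |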

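Time Complexity of Strassen's Method
Addition and subtraction of two matrices takes O(N^2) time. So the time complexity can be written as

T(N) = 7T(N/2) + O(N^2)

From the Master Theorem, the time complexity of the above method is
O(N^(log2 7)), which is approximately O(N^2.8074).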
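To make the gap between the two recurrences concrete, the following sketch counts only the scalar multiplications implied by T(N) = 8T(N/2) and T(N) = 7T(N/2), ignoring the O(N^2) additions. The helper `scalar_mults` is a hypothetical name introduced for this illustration, and N is assumed to be a power of 2.

```python
import math


def scalar_mults(n, branches):
    """Scalar multiplications implied by T(n) = branches * T(n/2), T(1) = 1.

    Assumes n is a power of 2; additions are deliberately ignored.
    """
    if n == 1:
        return 1
    return branches * scalar_mults(n // 2, branches)


if __name__ == "__main__":
    for n in (2, 64, 1024):
        # 8 branches: simple divide and conquer; 7 branches: Strassen
        print(n, scalar_mults(n, 8), scalar_mults(n, 7))
    print(math.log2(7))  # the exponent, roughly 2.807
```

With 8 branches the count is exactly N^3, and with 7 branches it is 7^(log2 N), which equals N^(log2 7).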

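Strassen's method is generally not preferred for practical applications, for the following reasons.
1) The constants used in Strassen's method are high, and for typical applications the naive method works better.
2) For sparse matrices, there are better methods designed specifically for them.
3) The submatrices in the recursion take extra space.
4) Because of the limited precision of computer arithmetic on non-integer values, larger errors accumulate in Strassen's algorithm than in the naive method (source: CLRS book).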
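One common way to soften the first point is to stop recursing once the blocks are small and fall back to conventional multiplication. The sketch below illustrates that idea; the name `strassen_hybrid` and the `CUTOFF` value of 64 are assumptions made for this example, and a suitable threshold would have to be found by benchmarking on the target machine.

```python
import numpy as np

CUTOFF = 64  # hypothetical threshold, not a recommended value


def strassen_hybrid(x, y, cutoff=CUTOFF):
    """Strassen recursion that falls back to NumPy's product for small
    or odd-sized blocks. Assumes x and y are n x n NumPy arrays."""
    n = x.shape[0]
    if n <= cutoff or n % 2 == 1:
        return x @ y  # conventional multiplication

    m = n // 2
    a, b, c, d = x[:m, :m], x[:m, m:], x[m:, :m], x[m:, m:]
    e, f, g, h = y[:m, :m], y[:m, m:], y[m:, :m], y[m:, m:]

    # The 7 Strassen products
    p1 = strassen_hybrid(a, f - h, cutoff)
    p2 = strassen_hybrid(a + b, h, cutoff)
    p3 = strassen_hybrid(c + d, e, cutoff)
    p4 = strassen_hybrid(d, g - e, cutoff)
    p5 = strassen_hybrid(a + d, e + h, cutoff)
    p6 = strassen_hybrid(b - d, g + h, cutoff)
    p7 = strassen_hybrid(a - c, e + f, cutoff)

    c11 = p5 + p4 - p2 + p6
    c12 = p1 + p2
    c21 = p3 + p4
    c22 = p1 + p5 - p3 - p7
    return np.vstack((np.hstack((c11, c12)), np.hstack((c21, c22))))
```

Because odd-sized blocks also go to the fallback, this sketch accepts any square size, at the cost of doing less of the work with Strassen's formulas.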

Python

# Version 3.6
 
import numpy as np
 
def split(matrix):
    """
    Splits a given matrix into quarters.
    Input: nxn matrix
    Output: tuple containing 4 n/2 x n/2 matrices corresponding to a, b, c, d
    """
    row, col = matrix.shape
    row2, col2 = row//2, col//2
    return matrix[:row2, :col2], matrix[:row2, col2:], matrix[row2:, :col2], matrix[row2:, col2:]
 
def strassen(x, y):
    """
    Computes matrix product by divide and conquer approach, recursively.
    Input: nxn matrices x and y
    Output: nxn matrix, product of x and y
    """
 
    # Base case when size of matrices is 1x1
    if len(x) == 1:
        return x * y
 
    # Splitting the matrices into quadrants. This will be done recursively
    # until the base case is reached.
    a, b, c, d = split(x)
    e, f, g, h = split(y)
 
    # Computing the 7 products, recursively (p1, p2...p7)
    p1 = strassen(a, f - h) 
    p2 = strassen(a + b, h)       
    p3 = strassen(c + d, e)       
    p4 = strassen(d, g - e)       
    p5 = strassen(a + d, e + h)       
    p6 = strassen(b - d, g + h) 
    p7 = strassen(a - c, e + f) 
 
    # Computing the values of the 4 quadrants of the final matrix c
    c11 = p5 + p4 - p2 + p6 
    c12 = p1 + p2          
    c21 = p3 + p4           
    c22 = p1 + p5 - p3 - p7 
 
    # Combining the 4 quadrants into a single matrix by stacking horizontally and vertically.
    c = np.vstack((np.hstack((c11, c12)), np.hstack((c21, c22))))
 
    return c
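The recursion halves the matrices down to 1x1, so it only works when n is a power of 2. One possible workaround, sketched here as an assumption and not taken from the original article, is to zero-pad the inputs to the next power of 2 and crop the result. The names `strassen_pow2` and `strassen_padded` are hypothetical; `strassen_pow2` is a compact restatement of the recursion so that the block runs on its own.

```python
import numpy as np


def strassen_pow2(x, y):
    """Compact Strassen recursion; n must be a power of 2."""
    n = x.shape[0]
    if n == 1:
        return x * y
    m = n // 2
    a, b, c, d = x[:m, :m], x[:m, m:], x[m:, :m], x[m:, m:]
    e, f, g, h = y[:m, :m], y[:m, m:], y[m:, :m], y[m:, m:]
    p1 = strassen_pow2(a, f - h)
    p2 = strassen_pow2(a + b, h)
    p3 = strassen_pow2(c + d, e)
    p4 = strassen_pow2(d, g - e)
    p5 = strassen_pow2(a + d, e + h)
    p6 = strassen_pow2(b - d, g + h)
    p7 = strassen_pow2(a - c, e + f)
    top = np.hstack((p5 + p4 - p2 + p6, p1 + p2))
    bottom = np.hstack((p3 + p4, p1 + p5 - p3 - p7))
    return np.vstack((top, bottom))


def strassen_padded(x, y):
    """Multiply n x n matrices of any size n >= 1 by zero-padding."""
    n = x.shape[0]
    size = 1
    while size < n:  # smallest power of 2 that is >= n
        size *= 2
    dtype = np.result_type(x, y)
    xp = np.zeros((size, size), dtype=dtype)
    yp = np.zeros((size, size), dtype=dtype)
    xp[:n, :n] = x
    yp[:n, :n] = y
    # The extra zero rows and columns do not affect the top-left n x n block
    return strassen_pow2(xp, yp)[:n, :n]


if __name__ == "__main__":
    x = np.arange(9).reshape(3, 3)
    y = np.arange(9, 18).reshape(3, 3)
    print(np.array_equal(strassen_padded(x, y), x @ y))
```

Padding can nearly double the dimension in the worst case, so this is a convenience for checking results more than a performance technique.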

