📜  矩阵链乘法(1)

📅  最后修改于: 2023-12-03 15:27:17.830000             🧑  作者: Mango

矩阵链乘法

矩阵链乘法(Matrix Chain Multiplication)是指将多个矩阵相乘的运算顺序确定下来,使得计算次数最少。这个问题也被称为最优区间连乘问题(Optimal Interval Matrix Multiplication)。矩阵链乘法是一个经典的动态规划问题。

假设有 $n$ 个矩阵,分别用 $A_1, A_2, ..., A_n$ 来表示,其中第 $i$ 个矩阵的规模为 $p_{i-1}\times p_i$(注意这里的下标是从 $1$ 开始的,因为第 $0$ 个矩阵的规模不存在),则这 $n$ 个矩阵相乘的顺序共有 $(n-1)!$ 种,因为第一个矩阵只能乘以第二个矩阵,第二个矩阵只能乘以第三个矩阵,以此类推。但是,由于矩阵相乘具有结合律,因此运算的顺序不会影响最终结果,即运算次数一定是相同的。

动态规划

设 $m_{i,j}$ 表示从第 $i$ 个矩阵乘到第 $j$ 个矩阵所需的最少乘法次数,可以得到以下递推公式:

$$ m_{i,j}=\begin{cases} 0 & i=j \ \min\limits_{i\leq k<j}{m_{i,k}+m_{k+1,j}+p_{i-1}p_kp_j} & i<j \end{cases} $$

当 $i=j$ 时,显然 $m_{i,j}=0$,因为一个矩阵相乘不需要任何操作次数。

当 $i<j$ 时,设 $k\in [i,j-1]$,则 $A_i,A_{i+1},...,A_k$ 和 $A_{k+1},A_{k+2},...,A_j$ 这两部分矩阵先相乘,需要的乘法次数为 $m_{i,k}+m_{k+1,j}$,而将这两部分矩阵相乘的乘法次数为 $p_{i-1}p_kp_j$,因此这种情况下的乘法次数为 $m_{i,j}=\min\limits_{i\leq k<j}{m_{i,k}+m_{k+1,j}+p_{i-1}p_kp_j}$。

最终结果即为 $m_{1,n}$,也就是从第 $1$ 个矩阵乘到第 $n$ 个矩阵所需的最少乘法次数。

代码实现
def matrix_chain_order(p):
    n = len(p) - 1
    m = [[0] * (n+1) for _ in range(n+1)]
    s = [[0] * n for _ in range(n)]

    for gap in range(1, n):
        for i in range(1, n-gap+1):
            j = i + gap
            m[i][j] = float('inf')
            for k in range(i, j):
                q = m[i][k] + m[k+1][j] + p[i-1] * p[k] * p[j]
                if q < m[i][j]:
                    m[i][j] = q
                    s[i-1][j-2] = k

    return m, s

def print_optimal_parens(s, i, j):
    if i == j:
        print(f"A_{i}", end="")
    else:
        print("(", end="")
        print_optimal_parens(s, i, s[i-1][j-2])
        print_optimal_parens(s, s[i-1][j-2]+1, j)
        print(")", end="")
示例

假设有 $5$ 个矩阵 $A_1$、$A_2$、$A_3$、$A_4$、$A_5$,其规模分别为 $10\times 20$、$20\times 30$、$30\times 40$、$40\times 30$、$30\times 15$,则其乘法顺序为 $(A_1A_2)(A_3A_4)A_5$,此时的最少乘法次数为 $10\times 20\times 30 + 10\times 30\times 40 + 10\times 40\times 30 + 10\times 30\times 15 = 18000$。

# 示例
p = [10, 20, 30, 40, 30, 15]
m, s = matrix_chain_order(p)
print_optimal_parens(s, 1, 5)
# 输出结果为: ((A_1(A_2A_3))((A_4A_5)A_6))