动态与静态计算图——PyTorch 和 TensorFlow
TensorFlow 和 Pytorch 是最近最流行的两个深度学习库。这两个库都在主流深度学习中发展了各自的利基,拥有出色的文档、教程,最重要的是,它们背后有一个充满活力和支持的社区。
TensorFlow 中的静态计算图和 Pytorch 中的动态计算图的区别
尽管这两个库都使用有向无环图(或 DAG)来表示它们的机器学习和深度学习模型,但它们让数据和计算流过图形的方式仍然存在很大差异。这两个库之间的细微差别在于 Tensorflow(v < 2.0) 允许静态图计算,而 Pytorch 允许动态图计算。本文将通过代码示例以直观的方式介绍这些差异。本文假设您具备计算图的工作知识以及对 TensorFlow 和 Pytorch 模块的基本了解。为了快速复习这些概念,建议读者阅读以下文章:
- 深度学习中的计算图
- PyTorch 入门
- TensorFlow 简介
Tensorflow 中的静态计算图
节点和边的属性:节点表示直接应用于通过边流入和流出的数据的操作。对于上述方程组,我们在 TensorFlow 中实现时可以牢记以下几点:
- 由于输入充当图的边缘,我们可以使用tf.Placeholder()对象,它可以接受所需数据类型的任何输入。
- 为了计算输出“c”,我们定义了一个简单的乘法运算并启动一个 tensorflow 会话,我们通过session.run()方法中的feed_dict属性传入所需的输入值,以计算输出和梯度。
现在让我们在 TensorFlow 中实现上述计算,并观察操作是如何发生的:
Python3
# Importing tensorflow version 1
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
# Initializing placeholder variables of
# the graph
a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
# Defining the operation
c = tf.multiply(a, b)
# Instantiating a tensorflow session
with tf.Session() as sess:
# Computing the output of the graph by giving
# respective input values
out = sess.run(, feed_dict={a: [15.0], b: [20.0]})[0][0]
# Computing the output gradient of the output with
# respect to the input 'a'
derivative_out_a = sess.run(tf.gradients(c, a), feed_dict={
a: [15.0], b: [20.0]})[0][0]
# Computing the output gradient of the output with
# respect to the input 'b'
derivative_out_b = sess.run(tf.gradients(c, b), feed_dict={
a: [15.0], b: [20.0]})[0][0]
# Displaying the outputs
print(f'c = {out}')
print(f'Derivative of c with respect to a = {derivative_out_a}')
print(f'Derivative of c with respect to b = {derivative_out_b}')
Python3
# Importing torch
import torch
# Initializing input tensors
a = torch.tensor(15.0, requires_grad=True)
b = torch.tensor(20.0, requires_grad=True)
# Computing the output
c = a * b
# Computing the gradients
c.backward()
# Collecting the output gradient of the
# output with respect to the input 'a'
derivative_out_a = a.grad
# Collecting the output gradient of the
# output with respect to the input 'b'
derivative_out_b = b.grad
# Displaying the outputs
print(f'c = {c}')
print(f'Derivative of c with respect to a = {derivative_out_a}')
print(f'Derivative of c with respect to b = {derivative_out_b}')
输出:
c = 300.0
Derivative of c with respect to a = 20.0
Derivative of c with respect to b = 15.0
正如我们所看到的,输出与我们在简介部分中的计算正确匹配,从而表明成功完成。从代码中可以看出静态结构,我们可以看到,有一次,在会话中,我们不能定义新的操作(或节点),但我们肯定可以使用sess.run()中的feed_dict属性更改输入变量方法。
好处:
- 由于图是静态的,它提供了许多优化结构和资源分配的可能性。
- 由于固定结构,计算比动态图略快。
缺点:
- 对可变维度输入的扩展性很差。例如,在没有大量预处理样板代码的情况下,具有在 28×28 图像上训练的静态计算图的 CNN(卷积神经网络)架构在 100×100 等不同尺寸的图像上表现不佳。
- 调试不好。这些很难调试,主要是因为用户无法访问信息流是如何发生的。 erg:假设用户创建了一个格式错误的静态图,用户无法直接跟踪错误,直到 TensorFlow 会话在计算反向传播和前向传播时发现错误。当模型很大时,这会成为一个主要问题,因为它会浪费用户的时间和计算资源。
Pytorch 中的动态计算图
节点和边的属性:节点代表数据(以张量的形式),边代表应用于输入数据的操作。
对于引言中给出的方程,我们在 Pytorch 中实现时可以牢记以下几点:
- 由于 Pytorch 中的所有内容都是动态创建的,因此我们不需要任何占位符,并且可以动态定义我们的输入和操作。
- 在定义输入并计算输出'c'之后,我们调用backward()方法,该方法计算相对于通过.grad说明符可访问的两个输入的相应偏导数。
现在让我们查看一个代码示例来验证我们的发现:
Python3
# Importing torch
import torch
# Initializing input tensors
a = torch.tensor(15.0, requires_grad=True)
b = torch.tensor(20.0, requires_grad=True)
# Computing the output
c = a * b
# Computing the gradients
c.backward()
# Collecting the output gradient of the
# output with respect to the input 'a'
derivative_out_a = a.grad
# Collecting the output gradient of the
# output with respect to the input 'b'
derivative_out_b = b.grad
# Displaying the outputs
print(f'c = {c}')
print(f'Derivative of c with respect to a = {derivative_out_a}')
print(f'Derivative of c with respect to b = {derivative_out_b}')
输出:
c = 300.0
Derivative of c with respect to a = 20.0
Derivative of c with respect to b = 15.0
正如我们所看到的,输出与我们在简介部分中的计算正确匹配,从而表明成功完成。从代码中可以看出动态结构。我们可以看到,所有输入和输出只能在运行时访问和更改,这与 Tensorflow 使用的方法完全不同。
好处:
- 对不同维度输入的可扩展性:对于不同维度的输入,可以很好地扩展,因为可以将新的预处理层动态添加到网络本身。
- 易于调试:这些非常容易调试,也是许多人从 TensorFlow 转向 Pytorch 的原因之一。由于节点是在任何信息流过它们之前动态创建的,因此由于用户完全控制了训练过程中使用的变量,因此很容易发现错误。
缺点:
- 由于需要为每个训练实例/批次创建一个新的图,因此图优化的空间非常小。
结论
本文阐明了 Tensorflow 和 Pytorch 的建模结构之间的区别。本文还通过代码示例列出了这两种方法的一些优点和缺点。这些库开发背后的各个组织在后续迭代中不断改进,但读者现在可以在为下一个项目选择最佳框架之前做出更明智的决定。