📜  Python – PyTorch 的钳位()方法(1)

📅  最后修改于: 2023-12-03 14:46:07.238000             🧑  作者: Mango

Python – PyTorch 的 clamp() 方法

介绍

PyTorch 是一个基于 Python 的科学计算包,具有以下两个特点:

  1. 类似于 NumPy 的张量,但能在 GPU 上利用其速度进行计算。
  2. 深度学习研究平台,提供了最大的灵活性和速度。

其中,clamp() 方法是 PyTorch 在张量上提供的基本方法之一。本文将介绍 clamp() 方法的基本用法和几个示例。

clamp() 方法的基本用法

clamp() 方法是用于将张量(tensor)中的元素限制在一定范围内的方法。它接受三个参数:min_value、max_value 和 out。

该方法限制张量中的所有元素在 min_value 和 max_value 之间,并返回一个新的张量。

clamp() 方法的一般形式为:

torch.clamp(input, min_value, max_value, out=None) → Tensor

其中,

  • input:必需参数,指定输入的张量。
  • min_value:必需参数,指定限制元素的最小值。如果指定为 None,则不会使用最小值限制。
  • max_value:必需参数,指定限制元素的最大值。如果指定为 None,则不会使用最大值限制。
  • out:可选参数,指定返回结果的张量。
示例

下面几个示例未使用 out 参数,因此返回结果均为新的张量。

示例 1

限制张量中的元素在 0 到 10 之间。

import torch

x = torch.randn(3, 3)
y = torch.clamp(x, 0, 10)
print(x)
print(y)

输出:

tensor([[ 0.6563, -0.6803, -0.5293],
        [-0.2628,  0.2662,  0.8974],
        [ 1.4794,  0.7744,  1.5192]])
tensor([[ 0.6563,  0.0000,  0.0000],
        [ 0.0000,  0.2662,  0.8974],
        [10.0000,  0.7744, 10.0000]])
示例 2

将张量中小于 0 的元素设为 0,将大于 1 的元素设为 1。

import torch

x = torch.tensor([[-1.2, 0.5], [0.7, 2.3]])
y = torch.clamp(x, 0, 1)
print(x)
print(y)

输出:

tensor([[-1.2000,  0.5000],
        [ 0.7000,  2.3000]])
tensor([[0.0000, 0.5000],
        [0.7000, 1.0000]])
示例 3

限制张量中的元素在 -1 到 1 之间。

import torch

x = torch.randn(2, 2)
y = torch.clamp(x, -1, 1)
print(x)
print(y)

输出:

tensor([[-0.1715,  0.0367],
        [ 1.1771,  0.2416]])
tensor([[-0.1715,  0.0367],
        [ 1.0000,  0.2416]])
总结

clamp() 方法用于将张量中的元素限制在一定范围内。通过指定 min_value 和 max_value,可以限制张量中的元素的最小值和最大值。clamp() 方法返回一个新的张量,该张量限制了输入张量中的元素。