📅  最后修改于: 2023-12-03 14:47:54.912000             🧑  作者: Mango
tf.eye()
是Tensorflow.js中的一个函数,用于创建一个指定大小的单位矩阵。
tf.eye(
numRows: number,
numColumns: number,
batchShape?: number[],
dtype?: 'float32'|'int32'|'bool'|'complex64'
): tf.Tensor
参数:
numRows
(必选):生成矩阵的行数。numColumns
(必选):生成矩阵的列数。batchShape
(可选):生成矩阵的批次形状(batch shape),默认为null
。dtype
(可选):生成矩阵的数据类型,可以是'float32'
、'int32'
、'bool'
或'complex64'
,默认为'float32'
。返回值:
[batchShape, numRows, numColumns]
。const eye = tf.eye(3);
// 输出:
// [
// [1, 0, 0],
// [0, 1, 0],
// [0, 0, 1]
// ]
eye.print();
const eye2 = tf.eye(2, 3);
// 输出:
// [
// [1, 0, 0],
// [0, 1, 0],
// ]
eye2.print();
const eye3 = tf.eye(2, 2, [2]);
// 输出:
// [
// [
// [1, 0],
// [0, 1]
// ],
// [
// [1, 0],
// [0, 1]
// ]
// ]
eye3.print();
[2]
的张量。输出其内容并打印。