📜  Tensorflow.js tf.signal.stft()函数(1)

📅  最后修改于: 2023-12-03 15:35:17.961000             🧑  作者: Mango

TensorFlow.js tf.signal.stft()函数介绍

在 TensorFlow.js 中,tf.signal.stft() 函数用于对时间序列信号执行短时傅里叶变换(Short-time Fourier Transform,缩写为STFT),将时间域的信号转换到频域来进行处理和分析。

函数定义

以下是 tf.signal.stft() 函数的定义:

function stft(
    signal: Tensor1D, 
    frameLength: number, 
    frameStep: number, 
    fftLength?: number, 
    windowFn?: ((length: number) => Tensor1D), 
    padEnd?: boolean
): Tensor2D

参数介绍:

  • signal: Tensor1D,输入信号。
  • frameLength: number,每个帧的长度,以样本点数表示。
  • frameStep: number,每个帧之间的步长,以样本点数表示。
  • fftLength: number,FFT 的长度,如果不提供,则默认为 frameLength。
  • windowFn: (length: number) => Tensor1D,窗函数的类型,以离散 Fourier 变换(DFT)长度为参数,返回具有长度为 length 的一维 Tensor 的函数。如果不提供窗口函数,则使用矩形窗口。
  • padEnd: boolean,是否在信号末尾填充零以避免截断误差。

返回值:

  • 一个二维 Tensor2D,其 shape 为 [numFrames, fftLength / 2 + 1],其中 numFrames 表示信号被分成的帧数。可以通过 tf.realtf.imag 函数分别获取实部和虚部来获取 STFT 矩阵的实际值。
示例代码

下面是一个计算音频文件 STFT 的示例代码:

import * as tf from '@tensorflow/tfjs';
import * as fs from 'fs';
import * as wav from 'node-wav';

const signal = wav.decode(fs.readFileSync('example.wav')).channelData[0];
const frameLength = 256;
const frameStep = 128;
const stftMatrix = tf.signal.stft(signal, frameLength, frameStep);
console.log(stftMatrix.shape);
console.log(tf.real(stftMatrix).arraySync());
console.log(tf.imag(stftMatrix).arraySync());

在这个例子中,我们首先使用 node-wav 包读取了一个示例音频文件,然后将其解码成 signal 数组。然后,我们指定每个帧的长度为 256,每个帧之间的步长为 128,并调用 tf.signal.stft() 函数来计算 STFT,最后输出矩阵的 shape 和实部数组以及虚部数组。