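A wavelet tree is a data structure that recursively partitions a stream into two parts until what remains is homogeneous. The name comes from an analogy with the wavelet transform of signals, which recursively decomposes a signal into low-frequency and high-frequency components. Wavelet trees can be used to answer range queries efficiently.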
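Consider the problem of finding the number of elements less than or equal to x in the range [L, R] of a given array A. One efficient way to solve it is with a persistent segment tree. But we can also solve it easily with a wavelet tree. Let us see how!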
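To make the problem concrete, here is a minimal brute-force baseline that scans the range directly. The function name countAtMostNaive and the 1-indexed, inclusive positions are assumptions made for this sketch. It counts elements less than or equal to x, which is the variant the implementation later in this article answers.

```cpp
#include <vector>

// Hypothetical helper (not from the article): counts the elements <= x
// among positions l..r of a, where positions are 1-indexed and inclusive.
// Runs in O(r - l + 1) time per query.
int countAtMostNaive(const std::vector<int>& a, int l, int r, int x) {
    int count = 0;
    for (int i = l; i <= r; ++i) {
        if (a[i - 1] <= x) ++count;
    }
    return count;
}
```

A linear scan like this is useful mainly as a reference to check a faster structure against.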
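Constructing a wavelet tree
Every node in the wavelet tree is represented by an array that is a subsequence of the original array, together with a range [L, R]. Here [L, R] is the range of values into which the elements of the array fall: R denotes the maximum element in the array and L the minimum. The root node therefore contains the original array, whose elements lie in [L, R]. We now compute the midpoint of the range [L, R] and stably partition the array into two parts, one for each child. The left child contains the elements in the range [L, mid] and the right child contains the elements in the range [mid + 1, R].
Suppose we are given an array of integers. We compute mid = (max + min) / 2 and form two children.
Left child: integers less than or equal to mid
Right child: integers greater than mid
We do this recursively until every node contains only identical elements.
Given array: 0 0 9 1 2 1 7 6 4 8 9 4 3 7 5 9 2 7 0 5 1 0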
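The split at a single node can be sketched as follows. The helper name splitNode is hypothetical, and for clarity it copies elements into two new vectors rather than partitioning in place. For the array above, min = 0 and max = 9, so mid = 4.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Hypothetical helper (not from the article): splits one node's sequence
// into the sequences of its left and right children, preserving order.
std::pair<std::vector<int>, std::vector<int>> splitNode(const std::vector<int>& s) {
    std::vector<int> left, right;
    if (s.empty()) return {left, right};

    // Value range of this node and its midpoint.
    auto mm = std::minmax_element(s.begin(), s.end());
    int lo = *mm.first, hi = *mm.second;
    int mid = lo + (hi - lo) / 2;

    // Values <= mid go left, the rest go right; relative order is kept.
    for (int v : s) (v <= mid ? left : right).push_back(v);
    return {left, right};
}
```

Applied to the given array, the left child receives 0 0 1 2 1 4 4 3 2 0 1 0 and the right child receives 9 7 6 8 9 7 5 9 7 5.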
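To construct a wavelet tree, let us see what we need to store at each node. At every node we store two arrays, S[] and freq[]. S[] is a subsequence of the original array A[], and freq[] records how the elements are distributed between the left and right children. Specifically, freq[i] is the number of elements among the first i elements of S[] that go to the left child. The number of elements that go to the right child is therefore simply (i - freq[i]).
The example below shows how the freq[] array is maintained: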
Array : 1 5 2 6 4 4
Mid = (1 + 6) / 2 = 3
Left Child : 1 2
Right Child : 5 6 4 4
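To maintain the freq[] array, we check whether each element is less than or equal to mid. If it is, we push back the last entry of freq[] plus 1; otherwise we push back the last entry plus 0.
For the array above:
Frequency array: {1, 1, 2, 2, 2, 2}
These are freq[1] through freq[6]; the implementation also stores freq[0] = 0. This means that among the first 1 or 2 elements, 1 element goes to the left child of this node, and among the first 3 to 6 elements, 2 elements go to the left child. This can be checked directly against the array given above.
To compute the number of elements that go to the right subtree, we subtract freq[i] from i.
Among the first 1 element, 0 elements go to the right subtree.
Among the first 2 elements, 1 element goes to the right subtree.
Among the first 3 elements, 1 element goes to the right subtree.
Among the first 4 elements, 2 elements go to the right subtree.
Among the first 5 elements, 3 elements go to the right subtree.
Among the first 6 elements, 4 elements go to the right subtree.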
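The bookkeeping above can be reproduced with a short sketch. The helper names buildFreq and rightCount are hypothetical. As in the article's implementation, the returned vector carries a leading 0, so that freq[i] refers to the first i elements.

```cpp
#include <vector>

// Hypothetical helper (not from the article): builds the prefix-count array
// where freq[i] is the number of the first i elements of s that are <= mid.
std::vector<int> buildFreq(const std::vector<int>& s, int mid) {
    std::vector<int> freq;
    freq.reserve(s.size() + 1);
    freq.push_back(0);  // no elements seen yet
    for (int v : s) freq.push_back(freq.back() + (v <= mid ? 1 : 0));
    return freq;
}

// Hypothetical helper: number of the first i elements that go to the right child.
int rightCount(const std::vector<int>& freq, int i) {
    return i - freq[i];
}
```

For the array 1 5 2 6 4 4 with mid = 3, buildFreq returns {0, 1, 1, 2, 2, 2, 2}, and rightCount gives 0, 1, 1, 2, 3, 4 for i = 1 to 6.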
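We can use the stable_partition function from the C++ STL together with a lambda expression to partition the array around the pivot value stably, that is, without disturbing the relative order of the elements from the original sequence. It is strongly recommended to read about stable_partition and lambda expressions before continuing.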
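As a minimal illustration of that call, the sketch below partitions a vector in place around mid. The wrapper name partitionAroundMid is hypothetical; std::stable_partition and the lambda capture are standard C++.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical wrapper (not from the article): reorders s in place so that
// all values <= mid come first, keeping the relative order inside each group.
// Returns the number of elements in the first (left) group.
std::size_t partitionAroundMid(std::vector<int>& s, int mid) {
    auto pivot = std::stable_partition(s.begin(), s.end(),
                                       [mid](int v) { return v <= mid; });
    return static_cast<std::size_t>(pivot - s.begin());
}
```

Calling it on 1 5 2 6 4 4 with mid = 3 rearranges the vector to 1 2 5 6 4 4 and returns 2, the size of the left child.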
Below is an implementation of wavelet tree construction:
// CPP code to implement wavelet trees
#include <algorithm>
#include <climits>
#include <vector>
using namespace std;

// wavelet tree class
class wavelet_tree {
public:
    // Range of values stored in this node
    int low, high;

    // Left and right children (null for leaves and empty nodes)
    wavelet_tree *l = nullptr, *r = nullptr;

    // freq[i] = number of the first i elements that go to the left child
    vector<int> freq;

    // Constructor
    // Values are in range [x, y]
    // Elements are in the half-open pointer range [from, to)
    wavelet_tree(int* from, int* to, int x, int y)
    {
        // Initialising low and high
        low = x, high = y;

        // Array is of 0 length
        if (from >= to)
            return;

        // Array is homogeneous
        // Example : 1 1 1 1 1
        if (high == low) {
            // Assigning storage to freq array
            freq.reserve(to - from + 1);

            // Initialising the freq array
            freq.push_back(0);

            // freq keeps increasing as there is no further sub-tree
            for (auto it = from; it != to; it++)
                freq.push_back(freq.back() + 1);
            return;
        }

        // Computing mid; this form rounds down for negative values too,
        // so the range always shrinks
        int mid = low + (high - low) / 2;

        // Lambda function to check if a number is
        // less than or equal to mid
        auto lessThanMid = [mid](int x) {
            return x <= mid;
        };

        // Assigning storage to freq array
        freq.reserve(to - from + 1);

        // Initialising the freq array
        freq.push_back(0);

        // If lessThanMid returns true (1), we add 1 to the previous entry.
        // Otherwise we add 0 (the element goes to the right sub-tree)
        for (auto it = from; it != to; it++)
            freq.push_back(freq.back() + lessThanMid(*it));

        // std::stable_partition partitions the array w.r.t. mid
        auto pivot = stable_partition(from, to, lessThanMid);

        // Left sub-tree's object
        l = new wavelet_tree(from, pivot, low, mid);

        // Right sub-tree's object
        r = new wavelet_tree(pivot, to, mid + 1, high);
    }

    // Release the children allocated with new
    ~wavelet_tree()
    {
        delete l;
        delete r;
    }
};

// Driver code
int main()
{
    int size = 5, high = INT_MIN;
    int arr[] = { 1, 2, 3, 4, 5 };
    for (int i = 0; i < size; i++)
        high = max(high, arr[i]);

    // Object of class wavelet tree; the values lie in [1, high]
    wavelet_tree obj(arr, arr + size, 1, high);
    return 0;
}
Height of the tree: O(log(max(A))), where max(A) is the maximum element of the array A[].
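The bound follows because the value range is halved at every level. A small sketch, assuming a hypothetical helper treeHeight that counts how many times a range [low, high] can be halved before it holds a single value:

```cpp
// Hypothetical helper (not from the article): upper bound on the number of
// edges on any root-to-leaf path for a tree built over values in [low, high].
// The left half [low, mid] is never smaller than the right half, so following
// it gives the longest possible path.
int treeHeight(int low, int high) {
    int height = 0;
    while (low < high) {
        high = low + (high - low) / 2;  // descend into [low, mid]
        ++height;
    }
    return height;
}
```

For the value range [1, 5] used in the driver code this gives 3, and for [0, 9] it gives 4. The actual tree can be shallower, because construction stops at any node that receives no elements.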
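Querying the wavelet tree
We have constructed the wavelet tree for the given array. Now we return to our problem: counting the elements less than or equal to x in the range [L, R] of the given array.
At every node we have a subsequence of the original array, the lowest and highest values that may appear in it, and the counts of elements that go to the left and right children.
Now,
If L > R or x < low,
we return 0, since no element in the current range can qualify.
If high <= x,
we return R - L + 1,
i.e. all the elements in the current range are less than or equal to x.
Otherwise, we use the variables LtCount = freq[L - 1] (the number of elements among the first L - 1 that go to the left sub-tree) and RtCount = freq[R] (the number of elements among the first R that go to the left sub-tree).
Now we recurse and add the values returned by:
left sub-tree with range [LtCount + 1, RtCount] and,
right sub-tree with range [L - LtCount, R - RtCount]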
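The index mapping can be checked in isolation with a short sketch. The struct ChildRanges and the function childRanges are hypothetical names. The freq vector is assumed to carry a leading 0 as in the article's implementation, and positions are 1-indexed.

```cpp
#include <utility>
#include <vector>

// Hypothetical type (not from the article): the 1-indexed, inclusive
// position ranges that a query [L, R] maps to in the two children.
struct ChildRanges {
    std::pair<int, int> left;
    std::pair<int, int> right;
};

// Maps the range [L, R] of a node onto its children using the prefix counts.
ChildRanges childRanges(const std::vector<int>& freq, int L, int R) {
    int ltCount = freq[L - 1];  // elements before L that went left
    int rtCount = freq[R];      // elements up to R that went left
    return {{ltCount + 1, rtCount}, {L - ltCount, R - rtCount}};
}
```

For the array 1 5 2 6 4 4 with mid = 3, freq is {0, 1, 1, 2, 2, 2, 2}. The range [2, 5] covers 5 2 6 4, which maps to [2, 2] in the left child (the element 2) and to [1, 3] in the right child (the elements 5 6 4). An empty mapping shows up as a range whose start exceeds its end.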
Below is the C++ implementation:
// CPP program for querying in
// wavelet tree Data Structure
#include <algorithm>
#include <climits>
#include <iostream>
#include <vector>
using namespace std;

// wavelet tree class
class wavelet_tree {
public:
    // Range of values stored in this node
    int low, high;

    // Left and right children (null for leaves and empty nodes)
    wavelet_tree *l = nullptr, *r = nullptr;

    // freq[i] = number of the first i elements that go to the left child
    vector<int> freq;

    // Constructor
    // Values are in range [x, y]
    // Elements are in the half-open pointer range [from, to)
    wavelet_tree(int* from, int* to, int x, int y)
    {
        // Initialising low and high
        low = x, high = y;

        // Array is of 0 length
        if (from >= to)
            return;

        // Array is homogeneous
        // Example : 1 1 1 1 1
        if (high == low) {
            // Assigning storage to freq array
            freq.reserve(to - from + 1);

            // Initialising the freq array
            freq.push_back(0);

            // freq keeps increasing as there is no further sub-tree
            for (auto it = from; it != to; it++)
                freq.push_back(freq.back() + 1);
            return;
        }

        // Computing mid; this form rounds down for negative values too,
        // so the range always shrinks
        int mid = low + (high - low) / 2;

        // Lambda function to check if a number
        // is less than or equal to mid
        auto lessThanMid = [mid](int x) {
            return x <= mid;
        };

        // Assigning storage to freq array
        freq.reserve(to - from + 1);

        // Initialising the freq array
        freq.push_back(0);

        // If lessThanMid returns true (1), we add 1 to the previous entry.
        // Otherwise we add 0 (the element goes to the right sub-tree)
        for (auto it = from; it != to; it++)
            freq.push_back(freq.back() + lessThanMid(*it));

        // std::stable_partition partitions the array w.r.t. mid
        auto pivot = stable_partition(from, to, lessThanMid);

        // Left sub-tree's object
        l = new wavelet_tree(from, pivot, low, mid);

        // Right sub-tree's object
        r = new wavelet_tree(pivot, to, mid + 1, high);
    }

    // Release the children allocated with new
    ~wavelet_tree()
    {
        delete l;
        delete r;
    }

    // Count of numbers in range [l..r] (1-indexed) that are
    // less than or equal to k
    int kOrLess(int l, int r, int k)
    {
        // No element in the range is less than or equal to k
        if (l > r || k < low)
            return 0;

        // All elements in the range are less than or equal to k
        if (high <= k)
            return r - l + 1;

        // Computing LtCount and RtCount
        int LtCount = freq[l - 1];
        int RtCount = freq[r];

        // Answer is (no. of elements <= k) in
        // left + (those <= k) in right
        return (this->l->kOrLess(LtCount + 1, RtCount, k) +
                this->r->kOrLess(l - LtCount, r - RtCount, k));
    }
};

// Driver code
int main()
{
    int size = 5, high = INT_MIN;
    int arr[] = { 1, 2, 3, 4, 5 };

    // Array : 1 2 3 4 5
    for (int i = 0; i < size; i++)
        high = max(high, arr[i]);

    // Object of class wavelet tree; the values lie in [1, high]
    wavelet_tree obj(arr, arr + size, 1, high);

    // Count of elements less than or equal to 2 in range [1, 3].
    // Positions are 1-indexed, because kOrLess reads freq[l - 1].
    cout << obj.kOrLess(1, 3, 2) << '\n';
    return 0;
}
Output:
2
Time complexity: O(log(max(A))) per query, where max(A) is the maximum element of the array A[].
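For comparison, here is a compact variant of the same idea that manages child nodes with std::unique_ptr instead of raw pointers. This is a sketch under stated assumptions rather than the article's implementation: the names WaveletNode and countAtMost are hypothetical, leaves store no freq array because a query never reads it there, and all values are assumed to lie in the [lo, hi] range passed to the root.

```cpp
#include <algorithm>
#include <memory>
#include <vector>

// Hypothetical compact variant (names are illustrative, not from the article).
class WaveletNode {
public:
    using Iter = std::vector<int>::iterator;

    // Builds the node for values in [lo, hi] from the sequence [from, to).
    // The sequence is reordered in place by the stable partition.
    WaveletNode(Iter from, Iter to, int lo, int hi) : low(lo), high(hi) {
        if (from >= to || lo == hi) return;  // empty node or leaf: no split
        int mid = lo + (hi - lo) / 2;
        auto goesLeft = [mid](int v) { return v <= mid; };

        // freq[i] = number of the first i elements that go to the left child.
        freq.reserve(to - from + 1);
        freq.push_back(0);
        for (auto it = from; it != to; ++it)
            freq.push_back(freq.back() + (goesLeft(*it) ? 1 : 0));

        auto pivot = std::stable_partition(from, to, goesLeft);
        left = std::make_unique<WaveletNode>(from, pivot, lo, mid);
        right = std::make_unique<WaveletNode>(pivot, to, mid + 1, hi);
    }

    // Number of elements <= k among positions l..r (1-indexed, inclusive).
    int countAtMost(int l, int r, int k) const {
        if (l > r || k < low) return 0;
        if (high <= k) return r - l + 1;
        int lt = freq[l - 1], rt = freq[r];
        return left->countAtMost(lt + 1, rt, k) +
               right->countAtMost(l - lt, r - rt, k);
    }

private:
    int low, high;
    std::vector<int> freq;
    std::unique_ptr<WaveletNode> left, right;
};
```

Built over a copy of the 22-element array from the construction section with the value range [0, 9], countAtMost(3, 10, 4) returns 4, since positions 3 to 10 hold 9 1 2 1 7 6 4 8. Because construction reorders its input, pass a copy if the original order is still needed.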
In this post we have discussed a single problem on range queries without updates. In a further post we will discuss range updates as well.