计数 0 到 N-1 的排列,其中至少 K 个元素与位置相同
给定两个整数N和K,任务是找到从0到N – 1 的数字的排列数,使得数组中至少有K个位置使得arr[i] = i ( 0 <= i < N)。由于答案可能非常大,请以 10^9+7 为模计算结果。
例子:
Input: N = 4, K = 3
Output: 1
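Explanation: The only permutation with at least K = 3 positions where arr[i] = i is the identity [0, 1, 2, 3] (it actually has 4 such positions; no permutation can have exactly 3, because a single element cannot be the only one out of place).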
Input: N = 4, K = 2
Output: 7
Explanation: There are 7 permutations satisfying the condition, as follows:
- [0, 1, 2, 3]
- [0, 1, 3, 2]
- [0, 3, 2, 1]
- [0, 2, 1, 3]
- [3, 1, 2, 0]
- [2, 1, 0, 3]
- [1, 0, 2, 3]
朴素方法:解决该问题的基本思想是枚举数组的所有排列。
请按照以下步骤解决问题:
- 递归地找到所有可能的排列,然后,
- 检查它们中的每一个是否符合条件。
- 在此基础上维护一个计数器,并在当前排列满足条件时将其递增。
以下是上述方法的实现:
C++
// C++ code for the above approach:
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
// Recursively enumerates every permutation of arr[index..] in place (by
// swapping) and counts the complete permutations that have at least k
// positions with arr[i] == i.
//
// arr   : array being permuted; every swap is undone, so arr is back in its
//         original order when the call returns.
// index : first position not yet fixed by the recursion.
// k     : minimum number of fixed points (arr[i] == i) required.
// ans   : incremented once for every qualifying complete permutation.
void getPermutations(std::vector<int>& arr,
                     int index, int k,
                     int& ans)
{
    // Use a signed size so comparisons with the int indices below are not
    // signed/unsigned mixes.
    const int n = static_cast<int>(arr.size());

    // Base case: every position is fixed, so arr holds one full permutation.
    if (index >= n) {
        int count = 0;
        for (int i = 0; i < n; i++) {
            if (arr[i] == i) {
                count++;
            }
        }
        if (count >= k) {
            ans++;
        }
        return;
    }

    // Try each remaining element at position `index`.
    for (int i = index; i < n; i++) {
        std::swap(arr[index], arr[i]);
        getPermutations(arr, index + 1, k, ans);
        // Undo the swap so the caller sees arr unchanged.
        std::swap(arr[index], arr[i]);
    }
}
// Counts, by brute force, the permutations of 0..n-1 that have at least k
// positions with arr[i] == i, and returns the count modulo 1e9+7.
// All n! permutations are enumerated, so this is only practical for small n.
int numberOfPermutations(long long n,
                         long long k)
{
    const int mod = 1000000007;
    int ans = 0;

    // getPermutations takes k as int. Clamp it into [0, n + 1] first:
    // k <= 0 accepts every permutation and k > n accepts none, so clamping
    // does not change the answer but avoids a lossy long long -> int cast.
    const long long clampedK = k < 0 ? 0 : (k > n + 1 ? n + 1 : k);

    // Build the identity permutation 0, 1, ..., n-1.
    std::vector<int> arr;
    if (n > 0) {
        arr.reserve(static_cast<std::size_t>(n));
    }
    for (int i = 0; i < n; i++) {
        arr.push_back(i);
    }

    getPermutations(arr, 0, static_cast<int>(clampedK), ans);
    return ans % mod;
}
// Driver code: counts permutations of 0..3 with at least 2 fixed points
// and prints the result without a trailing newline.
int main()
{
    const long long N = 4;
    const long long K = 2;
    const int result = numberOfPermutations(N, K);
    std::cout << result;
    return 0;
}
Java
// Java code for the above approach:
import java.util.*;
class GFG {
    // Number of qualifying permutations found by the current search.
    // numberOfPermutations resets it before every run.
    static int ans = 0;

    // Recursively enumerates every permutation of arr[index..] in place
    // (by swapping) and increments ans for each complete permutation with
    // at least k positions where arr[i] == i. Every swap is undone, so arr
    // is back in its original order on return.
    static void getPermutations(Vector<Integer> arr,
                                int index, long k)
    {
        // Base case: all positions fixed, arr holds one full permutation.
        if (index >= arr.size()) {
            int count = 0;
            for (int i = 0; i < arr.size(); i++) {
                // Integer is unboxed here, so this compares values.
                if (arr.get(i) == i) {
                    count++;
                }
            }
            if (count >= k) {
                ans++;
            }
            return;
        }
        // Try each remaining element at position `index`.
        for (int i = index; i < arr.size(); i++) {
            Collections.swap(arr, index, i);
            getPermutations(arr, index + 1, k);
            // Undo the swap so the caller sees arr unchanged.
            Collections.swap(arr, index, i);
        }
    }

    // Counts, by brute force, the permutations of 0..n-1 with at least k
    // fixed points, modulo 1e9+7. Only practical for small n (n! work).
    static int numberOfPermutations(long n,
                                    long k)
    {
        final int mod = 1000000000 + 7;
        // Reset the shared counter; otherwise repeated calls accumulate.
        ans = 0;
        // Build the identity permutation 0, 1, ..., n-1.
        Vector<Integer> arr = new Vector<>();
        for (int i = 0; i < n; i++) {
            arr.add(i);
        }
        getPermutations(arr, 0, k);
        return ans % mod;
    }

    // Driver Code
    public static void main (String[] args) {
        long N = 4;
        long K = 2;
        System.out.println(numberOfPermutations(N, K));
    }
}
// This code is contributed by hrithikgarg03188.
Python3
# Python code for the approach
# Recursively enumerates every permutation of arr[index:] in place (by
# swapping) and adds one to the module-level counter ``ans`` for each
# complete permutation having at least k positions where arr[i] == i.
# Every swap is undone, so arr is back in its original order on return.
ans = 0
def getPermutations(arr, index, k):
    global ans
    # Base case: all positions fixed, arr holds one complete permutation.
    if index >= len(arr):
        fixed_points = sum(1 for pos, val in enumerate(arr) if val == pos)
        if fixed_points >= k:
            ans += 1
        return
    # Place each remaining element at position `index` in turn.
    for j in range(index, len(arr)):
        arr[index], arr[j] = arr[j], arr[index]
        getPermutations(arr, index + 1, k)
        # Undo the swap before trying the next candidate.
        arr[index], arr[j] = arr[j], arr[index]
def numberOfPermutations(n, k):
    """Count permutations of 0..n-1 with at least k positions where arr[i] == i.

    Brute force over all n! permutations via getPermutations, so only
    practical for small n.

    Returns:
        The count modulo 1e9+7, as an int.
    """
    global ans
    # Integer modulus: a float such as 1e9 + 7 turns the arithmetic into
    # floating point and can lose precision.
    mod = 10**9 + 7
    # Reset the shared counter; otherwise repeated calls accumulate.
    ans = 0
    # Start from the identity permutation 0, 1, ..., n-1.
    arr = list(range(n))
    getPermutations(arr, 0, k)
    return ans % mod
# Driver code: count permutations of 0..3 with at least 2 fixed points.
N, K = 4, 2
print(numberOfPermutations(N, K))
# This code is contributed by shinjanpatra
Javascript
C++
// C++ code for the above approach:
#include
using namespace std;
// Helper: modular addition, (a + b) mod 1e9+7.
// Each operand is reduced first, so the sum cannot overflow. As with the
// C++ % operator, the result may be negative if an operand is negative.
int add(long long a, long long b)
{
    constexpr long long kMod = 1000000007;
    const long long lhs = a % kMod;
    const long long rhs = b % kMod;
    return (lhs + rhs) % kMod;
}
// Helper: modular multiplication, (a * b) mod 1e9+7.
// Both factors are reduced first, so their 64-bit product cannot overflow.
int mul(long long a, long long b)
{
    constexpr long long kMod = 1000000007;
    long long product = a % kMod;
    product *= b % kMod;
    return product % kMod;
}
// Helper: modular exponentiation, a^b mod 1e9+7, by repeated squaring
// (O(log b) multiplications). Returns 1 when b <= 0.
int bin_pow(long long a, long long b)
{
    constexpr long long kMod = 1000000007;
    long long base = a % kMod;
    long long result = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1) {
            result = result * base % kMod;
        }
        base = base * base % kMod;
    }
    return result;
}
// Helper: modular inverse of x modulo the prime 1e9+7, via Fermat's little
// theorem (x^(mod-2)). Yields 0 when x is a multiple of the modulus.
int reverse(long long x)
{
    constexpr long long kExponent = 1000000007LL - 2;
    return bin_pow(x, kExponent);
}
// Counts permutations of 0..n-1 with at least k positions where
// arr[i] == i, modulo 1e9+7.
//
// Exactly m positions are "displaced" (arr[i] != i) for some 0 <= m <= n-k.
// Those positions can be chosen in C(n, m) ways and must form a derangement,
// counted by D(m) with D(0) = 1, D(1) = 0, D(m) = (m-1)*(D(m-1) + D(m-2)).
// The answer is therefore  sum_{m=0}^{n-k} C(n, m) * D(m).
// Each term needs one modular inverse, so this runs in
// O((n - k) * log mod) time and O(1) space.
int numberOfPermutations(long long n, long long k)
{
    // No permutation has more than n fixed points.
    if (n < 0 || k > n) {
        return 0;
    }
    // Every permutation has at least 0 fixed points.
    if (k < 0) {
        k = 0;
    }
    // Largest number of displaced positions allowed.
    const long long limit = n - k;

    int ans = 1;         // m = 0 term: C(n,0) * D(0) = 1 (the identity)
    int binom = 1;       // C(n, m), updated incrementally
    int derangPrev = 0;  // D(m - 2); D(-1) is taken as 0 (multiplied by 0 anyway)
    int derangCur = 1;   // D(m - 1); starts at D(0) = 1
    for (long long m = 1; m <= limit; m++) {
        // C(n, m) = C(n, m - 1) * (n - m + 1) / m
        binom = mul(mul(binom, n - m + 1), reverse(m));
        // D(m) = (m - 1) * (D(m - 1) + D(m - 2))
        const int derang = mul(m - 1, add(derangCur, derangPrev));
        derangPrev = derangCur;
        derangCur = derang;
        ans = add(ans, mul(binom, derang));
    }
    return ans;
}
// Driver code: counts permutations of 0..3 with at least 2 fixed points
// and prints the result without a trailing newline.
int main()
{
    const long long N = 4;
    const long long K = 2;
    const int answer = numberOfPermutations(N, K);
    std::cout << answer;
    return 0;
}
Java
// Java program for the above approach
import java.util.ArrayList;
class GFG {
    // Prime modulus shared by all helpers below.
    static final long MOD = 1_000_000_007L;

    // Helper: (a + b) mod MOD, reducing each operand first.
    static long add(long a, long b)
    {
        return ((a % MOD) + (b % MOD)) % MOD;
    }

    // Helper: (a * b) mod MOD; reducing first keeps the product below 2^63.
    static long mul(long a, long b)
    {
        return ((a % MOD) * (b % MOD)) % MOD;
    }

    // Helper: a^b mod MOD by binary exponentiation; returns 1 when b <= 0.
    static long bin_pow(long a, long b)
    {
        a %= MOD;
        long res = 1;
        while (b > 0) {
            if ((b & 1) != 0) {
                res = res * a % MOD;
            }
            a = a * a % MOD;
            b >>= 1;
        }
        return res;
    }

    // Helper: modular inverse of x via Fermat's little theorem (MOD is prime).
    static long reverse(long x)
    {
        return bin_pow(x, MOD - 2);
    }

    // Counts permutations of 0..n-1 with at least k positions where
    // arr[i] == i, modulo MOD.
    //
    // Exactly m positions are displaced for some 0 <= m <= n-k. They are
    // chosen in C(n, m) ways and must form a derangement, counted by D(m)
    // with D(0) = 1, D(1) = 0, D(m) = (m-1)*(D(m-1) + D(m-2)), so the
    // answer is sum_{m=0}^{n-k} C(n, m) * D(m). Runs in
    // O((n - k) * log MOD) time.
    static long numberOfPermutations(long n, long k)
    {
        // No permutation has more than n fixed points.
        if (n < 0 || k > n) {
            return 0;
        }
        // Every permutation has at least 0 fixed points.
        if (k < 0) {
            k = 0;
        }
        final long limit = n - k;

        long ans = 1;         // m = 0 term: the identity permutation
        long binom = 1;       // C(n, m), updated incrementally
        long derangPrev = 0;  // D(m - 2); D(-1) taken as 0
        long derangCur = 1;   // D(m - 1); starts at D(0) = 1
        for (long m = 1; m <= limit; m++) {
            // C(n, m) = C(n, m - 1) * (n - m + 1) / m
            binom = mul(mul(binom, n - m + 1), reverse(m));
            // D(m) = (m - 1) * (D(m - 1) + D(m - 2))
            long derang = mul(m - 1, add(derangCur, derangPrev));
            derangPrev = derangCur;
            derangCur = derang;
            ans = add(ans, mul(binom, derang));
        }
        return ans;
    }

    // Driver Code
    public static void main(String args[]) {
        long N = 4;
        long K = 2;
        System.out.print( numberOfPermutations(N, K));
    }
}
// This code is contributed by sanjoy_62.
Python3
# Python3 code for the above approach:
# Helper: modular addition, (a + b) modulo 1e9+7.
def add(a, b):
    mod = int(1e9 + 7)
    x, y = a % mod, b % mod
    return (x + y) % mod
# Helper: modular multiplication, (a * b) modulo 1e9+7.
def mul(a, b):
    mod = int(1e9 + 7)
    x, y = a % mod, b % mod
    return x * y % mod
# Helper: modular exponentiation, a ** b modulo 1e9+7.
# Returns 1 when b <= 0; otherwise delegates to the built-in
# three-argument pow, which also uses square-and-multiply.
def bin_pow(a, b):
    mod = int(1e9 + 7)
    base = a % mod
    if b <= 0:
        return 1
    return pow(base, b, mod)
# Helper: modular inverse of x modulo the prime 1e9+7 via Fermat's little
# theorem (x ** (mod - 2) % mod). Yields 0 when x is a multiple of mod.
def reverse(x):
    exponent = int(1e9 + 7) - 2
    return bin_pow(x, exponent)
def numberOfPermutations(n, k):
    """Count permutations of 0..n-1 with at least k positions where arr[i] == i.

    Exactly m positions are displaced for some 0 <= m <= n - k. They can be
    chosen in C(n, m) ways and must form a derangement, counted by D(m) with
    D(0) = 1, D(1) = 0 and D(m) = (m - 1) * (D(m - 1) + D(m - 2)). The answer
    is therefore sum_{m=0}^{n-k} C(n, m) * D(m).

    Runs in O((n - k) * log mod) time (one modular inverse per term).

    Returns:
        The count modulo 1e9+7.
    """
    # No permutation has more than n fixed points.
    if n < 0 or k > n:
        return 0
    # Every permutation has at least 0 fixed points.
    k = max(k, 0)

    ans = 1                       # m = 0 term: the identity permutation
    binom = 1                     # C(n, m), updated incrementally
    derang_prev, derang_cur = 0, 1  # D(m - 2), D(m - 1); D(-1) taken as 0
    for m in range(1, n - k + 1):
        # C(n, m) = C(n, m - 1) * (n - m + 1) / m
        binom = mul(mul(binom, n - m + 1), reverse(m))
        # D(m) = (m - 1) * (D(m - 1) + D(m - 2))
        derang = mul(m - 1, add(derang_cur, derang_prev))
        derang_prev, derang_cur = derang_cur, derang
        ans = add(ans, mul(binom, derang))
    return ans
# Driver code: count permutations of 0..3 with at least 2 fixed points.
if __name__ == "__main__":
    N, K = 4, 2
    print(numberOfPermutations(N, K))
# This code is contributed by rakeshsahni
输出
7
时间复杂度: O(N * N!)
- 递归枚举所有排列时会生成 N! 个完整排列(即 N! 个叶子递归调用),
- 对于每次调用,都会有一个运行循环,其中包含 N 次迭代。
- 因此,整体时间复杂度将为 O(N * N!)。
辅助空间: O(N)
有效方法:高效解决该问题的思路是借助错位排列(derangement)的个数进行计数。
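请按照以下步骤解决问题:
- 设恰好有 M 个位置满足 arr[i] != i(0 <= M <= N – K),其余 N − M 个位置满足 arr[i] = i。
- 对于固定的 M,先选出这 M 个满足 arr[i] != i 的下标,方法数为组合数 C(N, M)。
- 然后为选出的下标构造排列,使每个被选下标都满足 arr[i] != i。这正是 M 个元素的错位排列(derangement),其个数记为 D(M)。
- D(M) 可以通过穷举搜索求得,也可以用递推式 D(M) = (M − 1)·(D(M − 1) + D(M − 2)) 计算,其中 D(0) = 1,D(1) = 0(于是 D(2) = 1,D(3) = 2,D(4) = 9)。
- 因此答案为 Σ_{M=0}^{N−K} C(N, M)·D(M),即根据 N − K 的值分情况逐项累加。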
以下是上述方法的实现:
C++
// C++ code for the above approach:
#include
using namespace std;
// Helper: modular addition, (a + b) mod 1e9+7.
// Each operand is reduced first, so the sum cannot overflow. As with the
// C++ % operator, the result may be negative if an operand is negative.
int add(long long a, long long b)
{
    constexpr long long kMod = 1000000007;
    const long long lhs = a % kMod;
    const long long rhs = b % kMod;
    return (lhs + rhs) % kMod;
}
// Helper: modular multiplication, (a * b) mod 1e9+7.
// Both factors are reduced first, so their 64-bit product cannot overflow.
int mul(long long a, long long b)
{
    constexpr long long kMod = 1000000007;
    long long product = a % kMod;
    product *= b % kMod;
    return product % kMod;
}
// Helper: modular exponentiation, a^b mod 1e9+7, by repeated squaring
// (O(log b) multiplications). Returns 1 when b <= 0.
int bin_pow(long long a, long long b)
{
    constexpr long long kMod = 1000000007;
    long long base = a % kMod;
    long long result = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1) {
            result = result * base % kMod;
        }
        base = base * base % kMod;
    }
    return result;
}
// Helper: modular inverse of x modulo the prime 1e9+7, via Fermat's little
// theorem (x^(mod-2)). Yields 0 when x is a multiple of the modulus.
int reverse(long long x)
{
    constexpr long long kExponent = 1000000007LL - 2;
    return bin_pow(x, kExponent);
}
// Counts permutations of 0..n-1 with at least k positions where
// arr[i] == i, modulo 1e9+7.
//
// Exactly m positions are "displaced" (arr[i] != i) for some 0 <= m <= n-k.
// Those positions can be chosen in C(n, m) ways and must form a derangement,
// counted by D(m) with D(0) = 1, D(1) = 0, D(m) = (m-1)*(D(m-1) + D(m-2)).
// The answer is therefore  sum_{m=0}^{n-k} C(n, m) * D(m).
// Each term needs one modular inverse, so this runs in
// O((n - k) * log mod) time and O(1) space.
int numberOfPermutations(long long n, long long k)
{
    // No permutation has more than n fixed points.
    if (n < 0 || k > n) {
        return 0;
    }
    // Every permutation has at least 0 fixed points.
    if (k < 0) {
        k = 0;
    }
    // Largest number of displaced positions allowed.
    const long long limit = n - k;

    int ans = 1;         // m = 0 term: C(n,0) * D(0) = 1 (the identity)
    int binom = 1;       // C(n, m), updated incrementally
    int derangPrev = 0;  // D(m - 2); D(-1) is taken as 0 (multiplied by 0 anyway)
    int derangCur = 1;   // D(m - 1); starts at D(0) = 1
    for (long long m = 1; m <= limit; m++) {
        // C(n, m) = C(n, m - 1) * (n - m + 1) / m
        binom = mul(mul(binom, n - m + 1), reverse(m));
        // D(m) = (m - 1) * (D(m - 1) + D(m - 2))
        const int derang = mul(m - 1, add(derangCur, derangPrev));
        derangPrev = derangCur;
        derangCur = derang;
        ans = add(ans, mul(binom, derang));
    }
    return ans;
}
// Driver code: counts permutations of 0..3 with at least 2 fixed points
// and prints the result without a trailing newline.
int main()
{
    const long long N = 4;
    const long long K = 2;
    const int answer = numberOfPermutations(N, K);
    std::cout << answer;
    return 0;
}
Java
// Java program for the above approach
import java.util.ArrayList;
class GFG {
    // Prime modulus shared by all helpers below.
    static final long MOD = 1_000_000_007L;

    // Helper: (a + b) mod MOD, reducing each operand first.
    static long add(long a, long b)
    {
        return ((a % MOD) + (b % MOD)) % MOD;
    }

    // Helper: (a * b) mod MOD; reducing first keeps the product below 2^63.
    static long mul(long a, long b)
    {
        return ((a % MOD) * (b % MOD)) % MOD;
    }

    // Helper: a^b mod MOD by binary exponentiation; returns 1 when b <= 0.
    static long bin_pow(long a, long b)
    {
        a %= MOD;
        long res = 1;
        while (b > 0) {
            if ((b & 1) != 0) {
                res = res * a % MOD;
            }
            a = a * a % MOD;
            b >>= 1;
        }
        return res;
    }

    // Helper: modular inverse of x via Fermat's little theorem (MOD is prime).
    static long reverse(long x)
    {
        return bin_pow(x, MOD - 2);
    }

    // Counts permutations of 0..n-1 with at least k positions where
    // arr[i] == i, modulo MOD.
    //
    // Exactly m positions are displaced for some 0 <= m <= n-k. They are
    // chosen in C(n, m) ways and must form a derangement, counted by D(m)
    // with D(0) = 1, D(1) = 0, D(m) = (m-1)*(D(m-1) + D(m-2)), so the
    // answer is sum_{m=0}^{n-k} C(n, m) * D(m). Runs in
    // O((n - k) * log MOD) time.
    static long numberOfPermutations(long n, long k)
    {
        // No permutation has more than n fixed points.
        if (n < 0 || k > n) {
            return 0;
        }
        // Every permutation has at least 0 fixed points.
        if (k < 0) {
            k = 0;
        }
        final long limit = n - k;

        long ans = 1;         // m = 0 term: the identity permutation
        long binom = 1;       // C(n, m), updated incrementally
        long derangPrev = 0;  // D(m - 2); D(-1) taken as 0
        long derangCur = 1;   // D(m - 1); starts at D(0) = 1
        for (long m = 1; m <= limit; m++) {
            // C(n, m) = C(n, m - 1) * (n - m + 1) / m
            binom = mul(mul(binom, n - m + 1), reverse(m));
            // D(m) = (m - 1) * (D(m - 1) + D(m - 2))
            long derang = mul(m - 1, add(derangCur, derangPrev));
            derangPrev = derangCur;
            derangCur = derang;
            ans = add(ans, mul(binom, derang));
        }
        return ans;
    }

    // Driver Code
    public static void main(String args[]) {
        long N = 4;
        long K = 2;
        System.out.print( numberOfPermutations(N, K));
    }
}
// This code is contributed by sanjoy_62.
Python3
# Python3 code for the above approach:
# Helper: modular addition, (a + b) modulo 1e9+7.
def add(a, b):
    mod = int(1e9 + 7)
    x, y = a % mod, b % mod
    return (x + y) % mod
# Helper: modular multiplication, (a * b) modulo 1e9+7.
def mul(a, b):
    mod = int(1e9 + 7)
    x, y = a % mod, b % mod
    return x * y % mod
# Helper: modular exponentiation, a ** b modulo 1e9+7.
# Returns 1 when b <= 0; otherwise delegates to the built-in
# three-argument pow, which also uses square-and-multiply.
def bin_pow(a, b):
    mod = int(1e9 + 7)
    base = a % mod
    if b <= 0:
        return 1
    return pow(base, b, mod)
# Helper: modular inverse of x modulo the prime 1e9+7 via Fermat's little
# theorem (x ** (mod - 2) % mod). Yields 0 when x is a multiple of mod.
def reverse(x):
    exponent = int(1e9 + 7) - 2
    return bin_pow(x, exponent)
def numberOfPermutations(n, k):
    """Count permutations of 0..n-1 with at least k positions where arr[i] == i.

    Exactly m positions are displaced for some 0 <= m <= n - k. They can be
    chosen in C(n, m) ways and must form a derangement, counted by D(m) with
    D(0) = 1, D(1) = 0 and D(m) = (m - 1) * (D(m - 1) + D(m - 2)). The answer
    is therefore sum_{m=0}^{n-k} C(n, m) * D(m).

    Runs in O((n - k) * log mod) time (one modular inverse per term).

    Returns:
        The count modulo 1e9+7.
    """
    # No permutation has more than n fixed points.
    if n < 0 or k > n:
        return 0
    # Every permutation has at least 0 fixed points.
    k = max(k, 0)

    ans = 1                       # m = 0 term: the identity permutation
    binom = 1                     # C(n, m), updated incrementally
    derang_prev, derang_cur = 0, 1  # D(m - 2), D(m - 1); D(-1) taken as 0
    for m in range(1, n - k + 1):
        # C(n, m) = C(n, m - 1) * (n - m + 1) / m
        binom = mul(mul(binom, n - m + 1), reverse(m))
        # D(m) = (m - 1) * (D(m - 1) + D(m - 2))
        derang = mul(m - 1, add(derang_cur, derang_prev))
        derang_prev, derang_cur = derang_cur, derang
        ans = add(ans, mul(binom, derang))
    return ans
# Driver code: count permutations of 0..3 with at least 2 fixed points.
if __name__ == "__main__":
    N, K = 4, 2
    print(numberOfPermutations(N, K))
# This code is contributed by rakeshsahni
输出
7
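时间复杂度:O(log MOD)(其中 MOD = 10^9 + 7)
- 每个模逆元通过指数为 MOD − 2 的二进制快速幂求得,耗时 O(log MOD)。
- 当 N − K 不超过一个常数(例如 4)时只需常数个逆元,总时间复杂度为 O(log MOD);若对一般的 N − K 逐项累加 Σ C(N, M)·D(M),则需要 O((N − K)·log MOD)。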
辅助空间: O(1)