📌  相关文章
📜  计数 0 到 N-1 的排列,其中至少 K 个元素与位置相同

📅  最后修改于: 2022-05-13 01:56:10.277000             🧑  作者: Mango

计数 0 到 N-1 的排列,其中至少 K 个元素与位置相同

给定两个整数 N 和 K,任务是求由数字 0 到 N – 1 构成的排列的个数,要求排列中至少有 K 个位置满足 arr[i] = i(0 <= i < N)。由于答案可能非常大,请输出对 10^9 + 7 取模后的结果。

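例子:

输入:N = 4,K = 2
输出:7
解释:恒等排列 [0, 1, 2, 3] 的 4 个位置都满足 arr[i] = i,共 1 个;恰好交换两个元素的排列有 2 个不动点,共 C(4, 2) = 6 个;不可能恰好有 3 个不动点。因此答案为 1 + 6 = 7。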

朴素方法:基本思路是枚举数组的所有排列,并逐一检查每个排列是否满足条件。

请按照以下步骤解决问题:

  • 递归地生成所有可能的排列;
  • 对每个排列,统计满足 arr[i] = i 的位置个数,并检查它是否至少为 K;
  • 维护一个计数器,当前排列满足条件时将其加一。

以下是上述方法的实现:

C++
// C++ code for the above approach:
 
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
 
// Recursively enumerates every permutation of `arr` by swapping each
// remaining candidate into position `index`. For every complete
// permutation it increments `ans` if at least `k` positions satisfy
// arr[i] == i.
//
// arr   - working array; permuted in place, but every swap is undone, so
//         it is back in its original order when the call returns.
// index - first position not yet assigned by the recursion (start at 0).
// k     - minimum number of fixed points a permutation needs to count.
// ans   - running counter, incremented once per qualifying permutation.
void getPermutations(std::vector<int>& arr,
                     int index, int k,
                     int& ans)
{
    const std::size_t size = arr.size();

    // Base case: every position has been assigned, so `arr` holds one
    // complete permutation.
    if (static_cast<std::size_t>(index) >= size) {
        int fixedPoints = 0;
        for (std::size_t i = 0; i < size; ++i) {
            if (arr[i] == static_cast<int>(i)) {
                ++fixedPoints;
            }
        }
        if (fixedPoints >= k) {
            ++ans;
        }
        return;
    }

    // Try every remaining element at position `index`.
    for (std::size_t i = index; i < size; ++i) {
        std::swap(arr[index], arr[i]);
        getPermutations(arr, index + 1, k, ans);
        // Undo the swap so the next candidate starts from the same state.
        std::swap(arr[index], arr[i]);
    }
}
 
int numberOfPermutations(long long n,
                         long long k)
{
 
    // Initializing the variables
    //'mod' and 'ans'.
    int mod = 1e9 + 7;
    int ans = 0;
 
    // Initializing the array 'arr'.
    vector arr;
 
    // Pushing numbers in the array.
    for (int i = 0; i < n; i++) {
        arr.push_back(i);
    }
 
    // Calling recursive function 'getPermutations'
    getPermutations(arr, 0, k, ans);
 
    // Returning 'ans'.
    return ans % mod;
}
 
// Driver code: prints how many permutations of 0..3 keep at least two
// elements in place (expected output: 7, no trailing newline).
int main()
{
    const long long N = 4, K = 2;
    std::cout << numberOfPermutations(N, K);
    return 0;
}


Java
// Java code for the above approach:
import java.util.*;
 
class GFG {

  // Number of qualifying permutations found by the current search.
  // numberOfPermutations resets it before each enumeration.
  static int ans = 0;

  // Recursively enumerates every permutation of arr by swapping each
  // remaining candidate into position `index`. For every complete
  // permutation, increments `ans` if at least k positions satisfy
  // arr.get(i) == i. Every swap is undone, so arr is back in its
  // original order when the call returns.
  static void getPermutations(Vector<Integer> arr,
                              int index, long k)
  {
    int size = arr.size();

    // Base case: all positions assigned, arr holds one full permutation.
    if (index >= size) {
      int count = 0;
      for (int i = 0; i < size; i++) {
        // Unbox explicitly so the comparison is numeric.
        if (arr.get(i).intValue() == i) {
          count++;
        }
      }
      if (count >= k) {
        ans++;
      }
      return;
    }

    // Try every remaining element at position `index`.
    for (int i = index; i < size; i++) {
      Collections.swap(arr, index, i);
      getPermutations(arr, index + 1, k);
      // Undo the swap so the next candidate starts from the same state.
      Collections.swap(arr, index, i);
    }
  }

  // Brute-force count of permutations of 0..n-1 with at least k fixed
  // points, modulo 1e9+7. Enumerates all n! permutations, so it is only
  // practical for small n.
  static int numberOfPermutations(long n,
                                  long k)
  {
    int mod = 1000000000 + 7;

    // Reset the shared counter; without this, a second call would add
    // to the total left over from the previous call.
    ans = 0;

    // Start from the identity permutation 0, 1, ..., n-1.
    Vector<Integer> arr = new Vector<>();
    for (int i = 0; i < n; i++) {
      arr.add(i);
    }

    getPermutations(arr, 0, k);
    return ans % mod;
  }

  // Driver Code
  public static void main (String[] args) {
    long N = 4;
    long K = 2;
    System.out.println(numberOfPermutations(N, K));
  }
}
 
// This code is contributed by hrithikgarg03188.


Python3
# Python code for the approach
 
# Count of qualifying permutations found by the current search.
# numberOfPermutations resets it before each enumeration.
ans = 0


def getPermutations(arr, index, k):
    """Recursively enumerate all permutations of ``arr`` in place.

    Each element from ``index`` onward is swapped into position ``index``
    in turn. When a full permutation is formed, the module-level ``ans``
    is incremented if at least ``k`` positions satisfy ``arr[i] == i``.
    Every swap is undone, so ``arr`` is back in its original order when
    the call returns.
    """
    global ans

    # Base case: every position has been assigned.
    if index >= len(arr):
        fixed_points = sum(1 for i, value in enumerate(arr) if value == i)
        if fixed_points >= k:
            ans += 1
        return

    for i in range(index, len(arr)):
        arr[index], arr[i] = arr[i], arr[index]
        getPermutations(arr, index + 1, k)
        # Undo the swap so the next candidate starts from the same state.
        arr[index], arr[i] = arr[i], arr[index]


def numberOfPermutations(n, k):
    """Brute-force count of permutations of 0..n-1 with >= k fixed points.

    Enumerates all n! permutations, so it is only practical for small n.
    The result is reduced modulo 10**9 + 7.
    """
    global ans
    mod = 10**9 + 7  # exact int; 1e9 + 7 would make the result a float

    # Reset the shared counter so repeated calls do not accumulate.
    ans = 0

    arr = list(range(n))
    getPermutations(arr, 0, k)
    return ans % mod
 
# Driver code: permutations of 0..3 with at least two fixed points
# (expected output: 7).
N, K = 4, 2
print(numberOfPermutations(N, K))
 
# This code is contributed by shinjanpatra


Javascript


C++
// C++ code for the above approach:
 
#include 
using namespace std;
 
// Modular-arithmetic helpers. All use the prime modulus 1e9 + 7 and
// expect non-negative arguments, so every result lies in [0, mod).

// Returns (a + b) mod 1e9+7.
int add(long long a, long long b)
{
    const int mod = 1000000007;
    return ((a % mod) + (b % mod)) % mod;
}

// Returns (a * b) mod 1e9+7. Both factors are reduced first, so their
// product (< mod^2, about 1e18) fits in a long long.
int mul(long long a, long long b)
{
    const int mod = 1000000007;
    return ((a % mod) * (b % mod)) % mod;
}

// Returns a^b mod 1e9+7 using binary (square-and-multiply)
// exponentiation: O(log b) multiplications.
int bin_pow(long long a, long long b)
{
    const int mod = 1000000007;
    a %= mod;
    long long res = 1;
    while (b > 0) {
        if (b & 1) {
            res = res * a % mod;
        }
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

// Returns the modular inverse of x as x^(mod-2) (Fermat's little
// theorem; valid because mod is prime). x must not be a multiple of mod.
int reverse(long long x)
{
    const int mod = 1000000007;
    return bin_pow(x, mod - 2);
}

// Counts permutations of 0..n-1 with at least k fixed points
// (arr[i] == i), modulo 1e9+7.
//
// A permutation qualifies iff at most n - k positions hold a wrong
// element. Choosing the m moved positions (C(n, m) ways) and deranging
// them (D(m) ways) gives
//     answer = sum_{m=0}^{min(n, n-k)} C(n, m) * D(m),
// with D(0) = 1, D(1) = 0 and D(m) = (m - 1) * (D(m-1) + D(m-2)).
// Returns 0 when k > n. Runs in O((n - k) * log mod); assumes
// 0 <= n < 1e9+7.
int numberOfPermutations(long long n, long long k)
{
    // At most this many positions may hold a wrong element.
    long long maxMoved = n - k;

    // More fixed points than positions is impossible.
    if (maxMoved < 0) {
        return 0;
    }
    // k <= 0: every permutation qualifies (the sum then equals n!).
    if (maxMoved > n) {
        maxMoved = n;
    }

    int ans = 1;      // m = 0: the identity permutation.
    int binom = 1;    // C(n, m), updated incrementally.
    int derPrev = 0;  // D(m - 2); its first value is multiplied by 0.
    int derCur = 1;   // D(m - 1); starts as D(0) = 1.

    for (long long m = 1; m <= maxMoved; ++m) {
        const int derNext = mul(m - 1, add(derCur, derPrev));
        derPrev = derCur;
        derCur = derNext;

        // C(n, m) = C(n, m - 1) * (n - m + 1) / m.
        binom = mul(binom, mul(n - m + 1, reverse(m)));

        // Reduce on every step so the result always lies in [0, mod).
        ans = add(ans, mul(binom, derCur));
    }
    return ans;
}
 
// Driver Code
int main()
{
    long long N = 4;
    long long K = 2;
 
    cout << numberOfPermutations(N, K);
    return 0;
}


Java
// Java program for the above approach
import java.util.ArrayList;
 
class GFG {

  // Prime modulus for every result; primality makes the Fermat inverse
  // in reverse() valid.
  private static final long MOD = 1_000_000_007L;

  // Returns (a + b) mod MOD for non-negative a and b.
  static long add(long a, long b)
  {
    return ((a % MOD) + (b % MOD)) % MOD;
  }

  // Returns (a * b) mod MOD for non-negative a and b. Both factors are
  // reduced first, so the product (< MOD^2, about 1e18) fits in a long.
  static long mul(long a, long b)
  {
    return ((a % MOD) * (b % MOD)) % MOD;
  }

  // Returns a^b mod MOD by binary (square-and-multiply) exponentiation.
  static long bin_pow(long a, long b)
  {
    a %= MOD;
    long res = 1;
    while (b > 0) {
      if ((b & 1) != 0) {
        res = res * a % MOD;
      }
      a = a * a % MOD;
      b >>= 1;
    }
    return res;
  }

  // Returns the modular inverse of x as x^(MOD-2) (Fermat's little
  // theorem); x must not be a multiple of MOD.
  static long reverse(long x)
  {
    return bin_pow(x, MOD - 2);
  }

  // Counts permutations of 0..n-1 with at least k fixed points
  // (arr[i] == i), modulo MOD.
  //
  // A permutation qualifies iff at most n - k positions hold a wrong
  // element. Choosing the m moved positions (C(n, m) ways) and deranging
  // them (D(m) ways) gives
  //     answer = sum_{m=0}^{min(n, n-k)} C(n, m) * D(m),
  // with D(0) = 1, D(1) = 0 and D(m) = (m - 1) * (D(m-1) + D(m-2)).
  // Returns 0 when k > n. Runs in O((n - k) * log MOD); assumes
  // 0 <= n < MOD.
  static long numberOfPermutations(long n, long k)
  {
    // At most this many positions may hold a wrong element (never more
    // than n, which covers k <= 0).
    long maxMoved = Math.min(n - k, n);

    // More fixed points than positions is impossible.
    if (maxMoved < 0) {
      return 0;
    }

    long ans = 1;      // m = 0: the identity permutation.
    long binom = 1;    // C(n, m), updated incrementally.
    long derPrev = 0;  // D(m - 2); its first value is multiplied by 0.
    long derCur = 1;   // D(m - 1); starts as D(0) = 1.

    for (long m = 1; m <= maxMoved; m++) {
      long derNext = mul(m - 1, add(derCur, derPrev));
      derPrev = derCur;
      derCur = derNext;

      // C(n, m) = C(n, m - 1) * (n - m + 1) / m.
      binom = mul(binom, mul(n - m + 1, reverse(m)));

      // Reduce on every step so the result always lies in [0, MOD).
      ans = add(ans, mul(binom, derCur));
    }
    return ans;
  }

  // Driver Code
  public static void main(String args[]) {
    long N = 4;
    long K = 2;

    System.out.print( numberOfPermutations(N, K));
  }
}
 
// This code is contributed by sanjoy_62.


Python3
# Python3 code for the above approach:
 
# Modular-arithmetic helpers. The modulus 10**9 + 7 is prime, which is
# what makes the Fermat-based inverse in reverse() valid.

def add(a, b):
    """Return (a + b) modulo 10**9 + 7."""
    mod = 10**9 + 7
    return ((a % mod) + (b % mod)) % mod


def mul(a, b):
    """Return (a * b) modulo 10**9 + 7."""
    mod = 10**9 + 7
    return ((a % mod) * (b % mod)) % mod


def bin_pow(a, b):
    """Return a**b modulo 10**9 + 7 by binary (square-and-multiply)
    exponentiation, using O(log b) multiplications."""
    mod = 10**9 + 7
    a %= mod
    res = 1
    while b > 0:
        if b & 1:
            res = res * a % mod
        a = a * a % mod
        b >>= 1
    return res


def reverse(x):
    """Return the modular inverse of x as x**(mod - 2) (Fermat's little
    theorem); x must not be a multiple of 10**9 + 7."""
    mod = 10**9 + 7
    return bin_pow(x, mod - 2)


def numberOfPermutations(n, k):
    """Count permutations of 0..n-1 with at least k fixed points, mod 10**9 + 7.

    A permutation qualifies iff at most n - k positions hold a wrong
    element. Choosing the m moved positions (C(n, m) ways) and deranging
    them (D(m) ways) gives

        answer = sum_{m=0}^{min(n, n-k)} C(n, m) * D(m),

    with D(0) = 1, D(1) = 0 and D(m) = (m - 1) * (D(m - 1) + D(m - 2)).
    Returns 0 when k > n. Runs in O((n - k) * log mod); assumes
    0 <= n < 10**9 + 7.
    """
    # At most this many positions may hold a wrong element (never more
    # than n, which covers k <= 0).
    max_moved = min(n - k, n)
    if max_moved < 0:
        return 0  # more fixed points than positions is impossible

    ans = 1        # m = 0: the identity permutation
    binom = 1      # C(n, m), updated incrementally
    der_prev = 0   # D(m - 2); its first value is multiplied by 0
    der_cur = 1    # D(m - 1); starts as D(0) = 1
    for m in range(1, max_moved + 1):
        der_prev, der_cur = der_cur, mul(m - 1, add(der_cur, der_prev))
        # C(n, m) = C(n, m - 1) * (n - m + 1) / m
        binom = mul(binom, mul(n - m + 1, reverse(m)))
        # Reducing on every step keeps the result in [0, mod).
        ans = add(ans, mul(binom, der_cur))
    return ans
 
# Driver code: N = 4, K = 2 (expected output: 7).
if __name__ == "__main__":
    N, K = 4, 2
    print(numberOfPermutations(N, K))

    # This code is contributed by rakeshsahni


输出
7

时间复杂度: O(N * N!)

  • 递归枚举全部排列会产生 N! 个完整排列(递归的叶子);
  • 对每个完整排列,都需要一次 N 次迭代的循环来统计不动点个数;
  • 因此整体时间复杂度为 O(N * N!)。

辅助空间: O(N)

高效方法:思路是借助错位排列(derangement,即没有任何元素位于自身位置上的排列)的个数来计数。

请按照以下步骤解决问题:

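  • 一个排列满足条件,当且仅当满足 arr[i] != i 的位置数 M 满足 0 <= M <= N – K(且 M 不超过 N)。
  • 对固定的 M,先选出哪些下标满足 arr[i] != i,共有组合数 C(N, M) 种选法。
  • 再对选出的 M 个下标重新排列,使其中每个下标都满足 arr[i] != i。这正是 M 个元素的错位排列,其个数记为 D(M)。已知 D(0) = 1,D(1) = 0,D(2) = 1,D(3) = 2,D(4) = 9;一般地,D(M) = (M – 1) · (D(M – 1) + D(M – 2))。
  • 因此答案为 Σ_{M=0}^{N–K} C(N, M) · D(M)(对 10^9 + 7 取模)。当 N – K 很小(例如 N – K <= 4)时,可以按 N – K 的取值分情况直接累加各项。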

以下是上述方法的实现:

C++

// C++ code for the above approach:
 
#include 
using namespace std;
 
// Modular-arithmetic helpers. All use the prime modulus 1e9 + 7 and
// expect non-negative arguments, so every result lies in [0, mod).

// Returns (a + b) mod 1e9+7.
int add(long long a, long long b)
{
    const int mod = 1000000007;
    return ((a % mod) + (b % mod)) % mod;
}

// Returns (a * b) mod 1e9+7. Both factors are reduced first, so their
// product (< mod^2, about 1e18) fits in a long long.
int mul(long long a, long long b)
{
    const int mod = 1000000007;
    return ((a % mod) * (b % mod)) % mod;
}

// Returns a^b mod 1e9+7 using binary (square-and-multiply)
// exponentiation: O(log b) multiplications.
int bin_pow(long long a, long long b)
{
    const int mod = 1000000007;
    a %= mod;
    long long res = 1;
    while (b > 0) {
        if (b & 1) {
            res = res * a % mod;
        }
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

// Returns the modular inverse of x as x^(mod-2) (Fermat's little
// theorem; valid because mod is prime). x must not be a multiple of mod.
int reverse(long long x)
{
    const int mod = 1000000007;
    return bin_pow(x, mod - 2);
}

// Counts permutations of 0..n-1 with at least k fixed points
// (arr[i] == i), modulo 1e9+7.
//
// A permutation qualifies iff at most n - k positions hold a wrong
// element. Choosing the m moved positions (C(n, m) ways) and deranging
// them (D(m) ways) gives
//     answer = sum_{m=0}^{min(n, n-k)} C(n, m) * D(m),
// with D(0) = 1, D(1) = 0 and D(m) = (m - 1) * (D(m-1) + D(m-2)).
// Returns 0 when k > n. Runs in O((n - k) * log mod); assumes
// 0 <= n < 1e9+7.
int numberOfPermutations(long long n, long long k)
{
    // At most this many positions may hold a wrong element.
    long long maxMoved = n - k;

    // More fixed points than positions is impossible.
    if (maxMoved < 0) {
        return 0;
    }
    // k <= 0: every permutation qualifies (the sum then equals n!).
    if (maxMoved > n) {
        maxMoved = n;
    }

    int ans = 1;      // m = 0: the identity permutation.
    int binom = 1;    // C(n, m), updated incrementally.
    int derPrev = 0;  // D(m - 2); its first value is multiplied by 0.
    int derCur = 1;   // D(m - 1); starts as D(0) = 1.

    for (long long m = 1; m <= maxMoved; ++m) {
        const int derNext = mul(m - 1, add(derCur, derPrev));
        derPrev = derCur;
        derCur = derNext;

        // C(n, m) = C(n, m - 1) * (n - m + 1) / m.
        binom = mul(binom, mul(n - m + 1, reverse(m)));

        // Reduce on every step so the result always lies in [0, mod).
        ans = add(ans, mul(binom, derCur));
    }
    return ans;
}
 
// Driver Code
int main()
{
    long long N = 4;
    long long K = 2;
 
    cout << numberOfPermutations(N, K);
    return 0;
}

Java

// Java program for the above approach
import java.util.ArrayList;
 
class GFG {

  // Prime modulus for every result; primality makes the Fermat inverse
  // in reverse() valid.
  private static final long MOD = 1_000_000_007L;

  // Returns (a + b) mod MOD for non-negative a and b.
  static long add(long a, long b)
  {
    return ((a % MOD) + (b % MOD)) % MOD;
  }

  // Returns (a * b) mod MOD for non-negative a and b. Both factors are
  // reduced first, so the product (< MOD^2, about 1e18) fits in a long.
  static long mul(long a, long b)
  {
    return ((a % MOD) * (b % MOD)) % MOD;
  }

  // Returns a^b mod MOD by binary (square-and-multiply) exponentiation.
  static long bin_pow(long a, long b)
  {
    a %= MOD;
    long res = 1;
    while (b > 0) {
      if ((b & 1) != 0) {
        res = res * a % MOD;
      }
      a = a * a % MOD;
      b >>= 1;
    }
    return res;
  }

  // Returns the modular inverse of x as x^(MOD-2) (Fermat's little
  // theorem); x must not be a multiple of MOD.
  static long reverse(long x)
  {
    return bin_pow(x, MOD - 2);
  }

  // Counts permutations of 0..n-1 with at least k fixed points
  // (arr[i] == i), modulo MOD.
  //
  // A permutation qualifies iff at most n - k positions hold a wrong
  // element. Choosing the m moved positions (C(n, m) ways) and deranging
  // them (D(m) ways) gives
  //     answer = sum_{m=0}^{min(n, n-k)} C(n, m) * D(m),
  // with D(0) = 1, D(1) = 0 and D(m) = (m - 1) * (D(m-1) + D(m-2)).
  // Returns 0 when k > n. Runs in O((n - k) * log MOD); assumes
  // 0 <= n < MOD.
  static long numberOfPermutations(long n, long k)
  {
    // At most this many positions may hold a wrong element (never more
    // than n, which covers k <= 0).
    long maxMoved = Math.min(n - k, n);

    // More fixed points than positions is impossible.
    if (maxMoved < 0) {
      return 0;
    }

    long ans = 1;      // m = 0: the identity permutation.
    long binom = 1;    // C(n, m), updated incrementally.
    long derPrev = 0;  // D(m - 2); its first value is multiplied by 0.
    long derCur = 1;   // D(m - 1); starts as D(0) = 1.

    for (long m = 1; m <= maxMoved; m++) {
      long derNext = mul(m - 1, add(derCur, derPrev));
      derPrev = derCur;
      derCur = derNext;

      // C(n, m) = C(n, m - 1) * (n - m + 1) / m.
      binom = mul(binom, mul(n - m + 1, reverse(m)));

      // Reduce on every step so the result always lies in [0, MOD).
      ans = add(ans, mul(binom, derCur));
    }
    return ans;
  }

  // Driver Code
  public static void main(String args[]) {
    long N = 4;
    long K = 2;

    System.out.print( numberOfPermutations(N, K));
  }
}
 
// This code is contributed by sanjoy_62.

Python3

# Python3 code for the above approach:
 
# Modular-arithmetic helpers. The modulus 10**9 + 7 is prime, which is
# what makes the Fermat-based inverse in reverse() valid.

def add(a, b):
    """Return (a + b) modulo 10**9 + 7."""
    mod = 10**9 + 7
    return ((a % mod) + (b % mod)) % mod


def mul(a, b):
    """Return (a * b) modulo 10**9 + 7."""
    mod = 10**9 + 7
    return ((a % mod) * (b % mod)) % mod


def bin_pow(a, b):
    """Return a**b modulo 10**9 + 7 by binary (square-and-multiply)
    exponentiation, using O(log b) multiplications."""
    mod = 10**9 + 7
    a %= mod
    res = 1
    while b > 0:
        if b & 1:
            res = res * a % mod
        a = a * a % mod
        b >>= 1
    return res


def reverse(x):
    """Return the modular inverse of x as x**(mod - 2) (Fermat's little
    theorem); x must not be a multiple of 10**9 + 7."""
    mod = 10**9 + 7
    return bin_pow(x, mod - 2)


def numberOfPermutations(n, k):
    """Count permutations of 0..n-1 with at least k fixed points, mod 10**9 + 7.

    A permutation qualifies iff at most n - k positions hold a wrong
    element. Choosing the m moved positions (C(n, m) ways) and deranging
    them (D(m) ways) gives

        answer = sum_{m=0}^{min(n, n-k)} C(n, m) * D(m),

    with D(0) = 1, D(1) = 0 and D(m) = (m - 1) * (D(m - 1) + D(m - 2)).
    Returns 0 when k > n. Runs in O((n - k) * log mod); assumes
    0 <= n < 10**9 + 7.
    """
    # At most this many positions may hold a wrong element (never more
    # than n, which covers k <= 0).
    max_moved = min(n - k, n)
    if max_moved < 0:
        return 0  # more fixed points than positions is impossible

    ans = 1        # m = 0: the identity permutation
    binom = 1      # C(n, m), updated incrementally
    der_prev = 0   # D(m - 2); its first value is multiplied by 0
    der_cur = 1    # D(m - 1); starts as D(0) = 1
    for m in range(1, max_moved + 1):
        der_prev, der_cur = der_cur, mul(m - 1, add(der_cur, der_prev))
        # C(n, m) = C(n, m - 1) * (n - m + 1) / m
        binom = mul(binom, mul(n - m + 1, reverse(m)))
        # Reducing on every step keeps the result in [0, mod).
        ans = add(ans, mul(binom, der_cur))
    return ans
 
# Driver code: N = 4, K = 2 (expected output: 7).
if __name__ == "__main__":
    N, K = 4, 2
    print(numberOfPermutations(N, K))

    # This code is contributed by rakeshsahni
输出
7

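时间复杂度:O((N – K) · log MOD),其中 MOD = 10^9 + 7

  • 求和式至多有 N – K + 1 项,计算每一项中的组合数需要一次模逆元;
  • 模逆元通过快速幂(二进制取幂)计算 x^(MOD – 2),每次耗时 O(log MOD),与 N 无关;
  • 当 N – K <= 4 时项数为常数,总时间复杂度即为 O(log MOD)。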

辅助空间: O(1)