根号分解算法(Python/Java/C++/Go)
算法一:暴力
暴力处理每个询问,把下标为 $l,l+k,l+2k,\dots$ 的数都乘以 $v$。
最坏情况每次需要 $\mathcal{O}\left(\dfrac{n}{k}\right)$ 的时间,整体 $\mathcal{O}\left(\dfrac{nq}{k}\right)$ 时间。其中 $n$ 是 $\textit{nums}$ 的长度,$q$ 是 $\textit{queries}$ 的长度。
特点:当 $k$ 比较大时,算法比较快。
算法二:差分数组(商分数组)
前置知识:差分数组
如果 $k=1$,我们可以用差分数组(准确来说叫商分数组)记录询问,然后计算商分数组的前缀积,即可得到最终的数组。
商分数组 $d$ 与差分数组的区别是,初始值每一项都是 $1$(乘法单位元);记录询问时,$d[l]$ 乘以 $v$,$d[r+1]$ 除以 $v$,即乘以 $v$ 的逆元。关于逆元,请看 模运算的世界:当加减乘除遇上取模。
对于其他 $k$ 呢?
比如 $k=3$。我们可以把所有询问分为 $k=3$ 组:
- 作用在下标 $0,3,6,\dots$ 上的询问。
- 作用在下标 $1,4,7,\dots$ 上的询问。
- 作用在下标 $2,5,8,\dots$ 上的询问。
比如 $l=1$,$r=9$,更新的下标是 $1,4,7$。在左端点 $1$ 处乘以 $v$,右端点 $7+k=10$ 处除以 $v$(乘以 $v$ 的逆元)。这样我们计算 $1,4,7,10,\dots$ 的前缀积,就可以正确地得到最终数组每一项要乘的数了。
这里的 $7$ 是怎么算的?我们要找 $\le r$ 的最大的形如 $3m+1$(即与 $l$ 模 $k$ 同余)的下标,或者说,要把 $r$ 减少多少。把 $l$ 平移到 $0$,相当于 $l=0$,$r=8$,此时这个减少量就是 $r$ 到 $\le r$ 的最近的 $k$ 的倍数的距离,即 $8\bmod k = 2$。一般地,更新的最大下标是 $r-(r-l)\bmod k$。再加上 $k$,得到要做商分标记的位置。
一般地,在左端点 $l$ 处乘以 $v$,右端点 $r-(r-l)\bmod k+k$ 处除以 $v$(乘以 $v$ 的逆元)。
处理每个询问只需要 $\mathcal{O}(\log M)$ 时间计算逆元,其中 $M=10^9+7$。然而,我们需要遍历 $\mathcal{O}(K)$ 个长为 $\mathcal{O}(n)$ 的商分数组,总体需要 $\mathcal{O}(nK + q\log M)$ 的时间。其中 $K$ 是 $k_i$ 的最大值。
特点:当 $K$ 比较小时,算法比较快。
「平衡」两个算法
根据这两个算法的特点,我们可以规定一个阈值 $B$:
- 对于 $k\ge B$ 的询问,使用算法一,即暴力计算。
- 对于 $k < B$ 的询问,使用算法二,即用商分数组记录询问。
总体时间复杂度为
$$
\mathcal{O}\left(\dfrac{nq}{B} + nB + q\log M\right)
$$
根据基本不等式,$\dfrac{nq}{B}+nB\ge 2n\sqrt q$,当且仅当 $B=\sqrt q$ 时取等号,即上式取到最小值
$$
\mathcal{O}(n\sqrt q + q\log M)
$$
足以通过本题。
优化:比如没有 $k=3$ 的询问,那么对于 $k=3$ 的商分数组,我们既不创建,也不遍历。
本题视频讲解,欢迎点赞关注~
写法一
###py
class Solution:
    def xorAfterQueries(self, nums: List[int], queries: List[List[int]]) -> int:
        """Apply every query to ``nums`` and return the XOR of the result.

        A query ``(l, r, k, v)`` multiplies ``nums[l], nums[l+k], ...``
        (indices ``<= r``) by ``v`` modulo 1e9+7.  ``nums`` is modified in
        place.

        Queries whose step is at least ``isqrt(len(queries))`` are applied
        directly.  Smaller steps are recorded in a per-step multiplicative
        difference ("quotient") array that is created lazily and resolved
        with one prefix-product sweep per residue chain at the end.
        """
        MOD = 1_000_000_007
        size = len(nums)
        threshold = isqrt(len(queries))
        # quot[step] stays None until some query with that step shows up.
        quot = [None] * threshold
        for left, right, step, val in queries:
            if step >= threshold:
                # Large step: only about size / step indices, touch them all.
                for idx in range(left, right + 1, step):
                    nums[idx] = nums[idx] * val % MOD
                continue
            if not quot[step]:
                quot[step] = [1] * (size + step)
            arr = quot[step]
            arr[left] = arr[left] * val % MOD
            # The last index hit is right - (right - left) % step; the
            # cancelling factor (modular inverse of val) goes one step later.
            stop = right - (right - left) % step + step
            arr[stop] = arr[stop] * pow(val, -1, MOD) % MOD
        for step in range(threshold):
            arr = quot[step]
            if not arr:
                continue
            for first in range(step):
                acc = 1
                for idx in range(first, size, step):
                    acc = acc * arr[idx] % MOD
                    nums[idx] = nums[idx] * acc % MOD
        return reduce(xor, nums)
###java
class Solution {
    private static final int MOD = 1_000_000_007;

    /**
     * Applies every query (l, r, k, v) -- multiply nums[l], nums[l + k], ... (indices <= r)
     * by v modulo 1e9+7 -- and returns the XOR of all elements afterwards.
     * nums is modified in place.
     *
     * Steps >= (int) sqrt(queries.length) are applied directly; smaller steps are recorded in
     * lazily allocated per-step quotient (multiplicative difference) arrays.
     */
    public int xorAfterQueries(int[] nums, int[][] queries) {
        final int size = nums.length;
        final int threshold = (int) Math.sqrt(queries.length);
        // quot[step] stays null until a query with that step appears.
        int[][] quot = new int[threshold][];
        for (int[] query : queries) {
            final int left = query[0], right = query[1], step = query[2];
            final long val = query[3];
            if (step >= threshold) {
                // Large step: few indices, update them directly.
                for (int idx = left; idx <= right; idx += step) {
                    nums[idx] = (int) (nums[idx] * val % MOD);
                }
                continue;
            }
            if (quot[step] == null) {
                quot[step] = new int[size + step];
                Arrays.fill(quot[step], 1);
            }
            int[] arr = quot[step];
            arr[left] = (int) (arr[left] * val % MOD);
            // Last touched index is right - (right - left) % step; cancel one step after it.
            int stop = right - (right - left) % step + step;
            arr[stop] = (int) (arr[stop] * modPow(val, MOD - 2) % MOD);
        }
        for (int step = 0; step < threshold; step++) {
            int[] arr = quot[step];
            if (arr == null) continue;
            for (int first = 0; first < step; first++) {
                long acc = 1;
                for (int idx = first; idx < size; idx += step) {
                    acc = acc * arr[idx] % MOD;
                    nums[idx] = (int) (nums[idx] * acc % MOD);
                }
            }
        }
        int result = 0;
        for (int x : nums) result ^= x;
        return result;
    }

    /** Binary exponentiation: base^exp mod MOD for exp >= 0. */
    private long modPow(long base, int exp) {
        long result = 1;
        while (exp > 0) {
            if ((exp & 1) == 1) result = result * base % MOD;
            base = base * base % MOD;
            exp >>= 1;
        }
        return result;
    }
}
###cpp
class Solution {
    static constexpr int MOD = 1'000'000'007;

    // Binary exponentiation: base^exp mod MOD for exp >= 0.
    long long mod_pow(long long base, int exp) {
        long long result = 1;
        while (exp > 0) {
            if (exp & 1) result = result * base % MOD;
            base = base * base % MOD;
            exp >>= 1;
        }
        return result;
    }

public:
    // Applies every query (l, r, k, v): multiply nums[l], nums[l + k], ...
    // (indices <= r) by v modulo MOD, then returns the XOR of all elements.
    // nums is modified in place. Steps below sqrt(q) use lazily created
    // per-step quotient arrays; larger steps are applied directly.
    int xorAfterQueries(vector<int>& nums, vector<vector<int>>& queries) {
        const int size = nums.size();
        const int threshold = sqrt(queries.size());
        vector<vector<int>> quot(threshold);
        for (const auto& query : queries) {
            const int left = query[0], right = query[1], step = query[2];
            const long long val = query[3];
            if (step >= threshold) {
                for (int idx = left; idx <= right; idx += step) {
                    nums[idx] = nums[idx] * val % MOD;
                }
                continue;
            }
            auto& arr = quot[step];
            if (arr.empty()) {
                arr.assign(size + step, 1);
            }
            arr[left] = arr[left] * val % MOD;
            // One step past the last touched index right - (right - left) % step.
            const int stop = right - (right - left) % step + step;
            arr[stop] = arr[stop] * mod_pow(val, MOD - 2) % MOD;
        }
        for (int step = 1; step < threshold; step++) {
            const auto& arr = quot[step];
            if (arr.empty()) continue;
            for (int first = 0; first < step; first++) {
                long long acc = 1;
                for (int idx = first; idx < size; idx += step) {
                    acc = acc * arr[idx] % MOD;
                    nums[idx] = nums[idx] * acc % MOD;
                }
            }
        }
        int result = 0;
        for (int x : nums) result ^= x;
        return result;
    }
};
###go
const mod = 1_000_000_007

// xorAfterQueries applies every query (l, r, k, v) -- multiply nums[l],
// nums[l+k], ... (indices <= r) by v modulo mod -- and returns the XOR of all
// elements afterwards. nums is modified in place. Steps below sqrt(q) are
// recorded in lazily created per-step quotient arrays; larger steps are
// applied directly.
func xorAfterQueries(nums []int, queries [][]int) (ans int) {
	size := len(nums)
	threshold := int(math.Sqrt(float64(len(queries))))
	// quot[step] stays nil until a query with that step appears.
	quot := make([][]int, threshold)
	for _, query := range queries {
		left, right, step, val := query[0], query[1], query[2], query[3]
		if step >= threshold {
			for idx := left; idx <= right; idx += step {
				nums[idx] = nums[idx] * val % mod
			}
			continue
		}
		if quot[step] == nil {
			fresh := make([]int, size+step)
			for i := range fresh {
				fresh[i] = 1
			}
			quot[step] = fresh
		}
		arr := quot[step]
		arr[left] = arr[left] * val % mod
		// One step past the last touched index right - (right-left)%step.
		stop := right - (right-left)%step + step
		arr[stop] = arr[stop] * pow(val, mod-2) % mod
	}
	for step, arr := range quot {
		if arr == nil {
			continue
		}
		for first := 0; first < step; first++ {
			acc := 1
			for idx := first; idx < size; idx += step {
				acc = acc * arr[idx] % mod
				nums[idx] = nums[idx] * acc % mod
			}
		}
	}
	for _, x := range nums {
		ans ^= x
	}
	return
}

// pow returns x^n mod mod by binary exponentiation (n >= 0).
func pow(x, n int) int {
	res := 1
	for n > 0 {
		if n&1 == 1 {
			res = res * x % mod
		}
		x = x * x % mod
		n >>= 1
	}
	return res
}
写法一的优化
把懒初始化的想法进一步扩展。比如 $k=3$ 时,如果没有任何询问满足 $l\bmod k=2$,那么下标 $2,5,8,\dots$ 这一组的商分数组全为 $1$,无需遍历。
用二维布尔数组记录询问是否有 $(k,l\bmod k)$。
###py
class Solution:
    def xorAfterQueries(self, nums: List[int], queries: List[List[int]]) -> int:
        """Apply every query to ``nums`` and return the XOR of the result.

        A query ``(l, r, k, v)`` multiplies ``nums[l], nums[l+k], ...``
        (indices ``<= r``) by ``v`` modulo 1e9+7.  ``nums`` is modified in
        place.

        Like the plain lazy version, but it also remembers which residues
        ``l % k`` were started on for each small step.  Residue chains no
        query touched are all-ones and are skipped in the final sweep.
        """
        MOD = 1_000_000_007
        size = len(nums)
        threshold = isqrt(len(queries))
        quot = [None] * threshold
        # used[step][rem] becomes True once a query with this step starts at
        # an index congruent to rem (mod step).
        used = [None] * threshold
        for left, right, step, val in queries:
            if step >= threshold:
                for idx in range(left, right + 1, step):
                    nums[idx] = nums[idx] * val % MOD
                continue
            arr = quot[step]
            if not arr:
                arr = quot[step] = [1] * (size + step)
                used[step] = [False] * step
            used[step][left % step] = True
            arr[left] = arr[left] * val % MOD
            # Cancel one step past the last touched index.
            stop = right - (right - left) % step + step
            arr[stop] = arr[stop] * pow(val, -1, MOD) % MOD
        for step in range(threshold):
            arr = quot[step]
            if not arr:
                continue
            for first in (rem for rem, flag in enumerate(used[step]) if flag):
                acc = 1
                for idx in range(first, size, step):
                    acc = acc * arr[idx] % MOD
                    nums[idx] = nums[idx] * acc % MOD
        return reduce(xor, nums)
###java
class Solution {
    private static final int MOD = 1_000_000_007;

    /**
     * Applies each query (l, r, k, v) -- multiply nums[l], nums[l + k], ... (indices <= r) by v
     * modulo 1e9+7 -- and returns the XOR of all elements afterwards. nums is updated in place.
     *
     * Small steps (k < (int) sqrt(queries.length)) go through lazily created quotient arrays;
     * only residue chains that some query actually starts on are swept at the end.
     */
    public int xorAfterQueries(int[] nums, int[][] queries) {
        final int size = nums.length;
        final int threshold = (int) Math.sqrt(queries.length);
        int[][] quot = new int[threshold][];
        // used[step][rem]: some query with this step starts at an index == rem (mod step).
        boolean[][] used = new boolean[threshold][];
        for (int[] query : queries) {
            final int left = query[0], right = query[1], step = query[2];
            final long val = query[3];
            if (step >= threshold) {
                for (int idx = left; idx <= right; idx += step) {
                    nums[idx] = (int) (nums[idx] * val % MOD);
                }
                continue;
            }
            if (quot[step] == null) {
                quot[step] = new int[size + step];
                Arrays.fill(quot[step], 1);
                used[step] = new boolean[step];
            }
            used[step][left % step] = true;
            int[] arr = quot[step];
            arr[left] = (int) (arr[left] * val % MOD);
            // One step past the last touched index right - (right - left) % step.
            int stop = right - (right - left) % step + step;
            arr[stop] = (int) (arr[stop] * modPow(val, MOD - 2) % MOD);
        }
        for (int step = 0; step < threshold; step++) {
            int[] arr = quot[step];
            if (arr == null) continue;
            boolean[] flags = used[step];
            for (int first = 0; first < step; first++) {
                if (!flags[first]) continue;
                long acc = 1;
                for (int idx = first; idx < size; idx += step) {
                    acc = acc * arr[idx] % MOD;
                    nums[idx] = (int) (nums[idx] * acc % MOD);
                }
            }
        }
        int result = 0;
        for (int x : nums) result ^= x;
        return result;
    }

    /** Binary exponentiation: base^exp mod MOD for exp >= 0. */
    private long modPow(long base, int exp) {
        long result = 1;
        while (exp > 0) {
            if ((exp & 1) == 1) result = result * base % MOD;
            base = base * base % MOD;
            exp >>= 1;
        }
        return result;
    }
}
###cpp
class Solution {
    static constexpr int MOD = 1'000'000'007;

    // Binary exponentiation: base^exp mod MOD for exp >= 0.
    long long mod_pow(long long base, int exp) {
        long long result = 1;
        while (exp > 0) {
            if (exp & 1) result = result * base % MOD;
            base = base * base % MOD;
            exp >>= 1;
        }
        return result;
    }

public:
    // Applies every query (l, r, k, v): multiply nums[l], nums[l + k], ...
    // (indices <= r) by v modulo MOD, then returns the XOR of all elements.
    // nums is modified in place. Small steps use lazily created quotient
    // arrays, and only residue chains some query started on are swept.
    int xorAfterQueries(vector<int>& nums, vector<vector<int>>& queries) {
        const int size = nums.size();
        const int threshold = sqrt(queries.size());
        vector<vector<int>> quot(threshold);
        // used[step][rem] != 0: a query with this step starts at index == rem (mod step).
        vector<vector<char>> used(threshold);
        for (const auto& query : queries) {
            const int left = query[0], right = query[1], step = query[2];
            const long long val = query[3];
            if (step >= threshold) {
                for (int idx = left; idx <= right; idx += step) {
                    nums[idx] = nums[idx] * val % MOD;
                }
                continue;
            }
            auto& arr = quot[step];
            if (arr.empty()) {
                arr.assign(size + step, 1);
                used[step].assign(step, 0);
            }
            used[step][left % step] = 1;
            arr[left] = arr[left] * val % MOD;
            // One step past the last touched index right - (right - left) % step.
            const int stop = right - (right - left) % step + step;
            arr[stop] = arr[stop] * mod_pow(val, MOD - 2) % MOD;
        }
        for (int step = 1; step < threshold; step++) {
            const auto& arr = quot[step];
            if (arr.empty()) continue;
            const auto& flags = used[step];
            for (int first = 0; first < step; first++) {
                if (!flags[first]) continue;
                long long acc = 1;
                for (int idx = first; idx < size; idx += step) {
                    acc = acc * arr[idx] % MOD;
                    nums[idx] = nums[idx] * acc % MOD;
                }
            }
        }
        int result = 0;
        for (int x : nums) result ^= x;
        return result;
    }
};
###go
const mod = 1_000_000_007

// xorAfterQueries applies every query (l, r, k, v) -- multiply nums[l],
// nums[l+k], ... (indices <= r) by v modulo mod -- and returns the XOR of all
// elements afterwards. nums is modified in place. Small steps use lazily
// created quotient arrays, and only residue chains that some query started
// on are swept at the end.
func xorAfterQueries(nums []int, queries [][]int) (ans int) {
	size := len(nums)
	threshold := int(math.Sqrt(float64(len(queries))))
	quot := make([][]int, threshold)
	// used[step][rem]: a query with this step starts at an index == rem (mod step).
	used := make([][]bool, threshold)
	for _, query := range queries {
		left, right, step, val := query[0], query[1], query[2], query[3]
		if step >= threshold {
			for idx := left; idx <= right; idx += step {
				nums[idx] = nums[idx] * val % mod
			}
			continue
		}
		if quot[step] == nil {
			fresh := make([]int, size+step)
			for i := range fresh {
				fresh[i] = 1
			}
			quot[step] = fresh
			used[step] = make([]bool, step)
		}
		used[step][left%step] = true
		arr := quot[step]
		arr[left] = arr[left] * val % mod
		// One step past the last touched index right - (right-left)%step.
		stop := right - (right-left)%step + step
		arr[stop] = arr[stop] * pow(val, mod-2) % mod
	}
	for step, arr := range quot {
		if arr == nil {
			continue
		}
		for first, ok := range used[step] {
			if !ok {
				continue
			}
			acc := 1
			for idx := first; idx < size; idx += step {
				acc = acc * arr[idx] % mod
				nums[idx] = nums[idx] * acc % mod
			}
		}
	}
	for _, x := range nums {
		ans ^= x
	}
	return
}

// pow returns x^n mod mod by binary exponentiation (n >= 0).
func pow(x, n int) int {
	res := 1
	for n > 0 {
		if n&1 == 1 {
			res = res * x % mod
		}
		x = x * x % mod
		n >>= 1
	}
	return res
}
复杂度分析
- 时间复杂度:$\mathcal{O}(n\sqrt q + q\log M)$,其中 $n$ 是 $\textit{nums}$ 的长度,$q$ 是 $\textit{queries}$ 的长度,$M=10^9+7$。
- 空间复杂度:$\mathcal{O}(n\sqrt q)$。
写法二
把询问按照 $(k,l\bmod k)$ 分组,对于每一组计算商分。这样空间复杂度更小。
###py
class Solution:
    def xorAfterQueries(self, nums: List[int], queries: List[List[int]]) -> int:
        """Apply every query to ``nums`` and return the XOR of the result.

        A query ``(l, r, k, v)`` multiplies ``nums[l], nums[l+k], ...``
        (indices ``<= r``) by ``v`` modulo 1e9+7.  ``nums`` is modified in
        place.

        Steps ``>= isqrt(len(queries))`` are applied directly.  Smaller steps
        are grouped by ``(k, l % k)``.  A group holding one query is applied
        directly.  Otherwise the group is resolved with a quotient array
        indexed by position along its residue chain, which keeps extra
        memory proportional to a single chain.
        """
        MOD = 1_000_000_007
        size = len(nums)
        threshold = isqrt(len(queries))
        by_step = [[] for _ in range(threshold)]
        for left, right, step, val in queries:
            if step < threshold:
                by_step[step].append((left, right, val))
                continue
            for idx in range(left, right + 1, step):
                nums[idx] = nums[idx] * val % MOD
        for step in range(threshold):
            pending = by_step[step]
            if not pending:
                continue
            chains = [[] for _ in range(step)]
            for item in pending:
                chains[item[0] % step].append(item)
            for first in range(step):
                chain = chains[first]
                if not chain:
                    continue
                if len(chain) == 1:
                    # A lone query is cheaper to apply directly.
                    left, right, val = chain[0]
                    for idx in range(left, right + 1, step):
                        nums[idx] = nums[idx] * val % MOD
                    continue
                # Number of indices first, first + step, ... that are < size.
                length = (size - first - 1) // step + 1
                quot = [1] * (length + 1)
                for left, right, val in chain:
                    # left % step == first, so left // step is its chain position.
                    pos = left // step
                    quot[pos] = quot[pos] * val % MOD
                    end = (right - first) // step + 1
                    quot[end] = quot[end] * pow(val, -1, MOD) % MOD
                acc = 1
                for pos in range(length):
                    acc = acc * quot[pos] % MOD
                    idx = first + pos * step
                    nums[idx] = nums[idx] * acc % MOD
        return reduce(xor, nums)
###java
class Solution {
    private static final int MOD = 1_000_000_007;

    /**
     * Applies every query (l, r, k, v) -- multiply nums[l], nums[l + k], ... (indices <= r)
     * by v modulo 1e9+7 -- and returns the XOR of all elements afterwards.
     * nums is modified in place.
     *
     * Steps >= (int) sqrt(queries.length) are applied directly. Smaller steps are grouped by
     * (k, l % k). A group holding one query is applied directly. Otherwise the group is resolved
     * with one shared quotient buffer indexed by position along its residue chain.
     */
    public int xorAfterQueries(int[] nums, int[][] queries) {
        final int size = nums.length;
        final int threshold = (int) Math.sqrt(queries.length);
        List<List<int[]>> byStep = new ArrayList<>(threshold);
        for (int i = 0; i < threshold; i++) byStep.add(new ArrayList<>());
        for (int[] query : queries) {
            int left = query[0], right = query[1], step = query[2], val = query[3];
            if (step < threshold) {
                byStep.get(step).add(new int[]{left, right, val});
                continue;
            }
            for (int idx = left; idx <= right; idx += step) {
                nums[idx] = (int) ((long) nums[idx] * val % MOD);
            }
        }
        // Reused for every chain: a chain has at most size positions plus one sentinel.
        int[] quot = new int[size + 1];
        for (int step = 1; step < threshold; step++) {
            List<int[]> pending = byStep.get(step);
            if (pending.isEmpty()) continue;
            List<List<int[]>> chains = new ArrayList<>(step);
            for (int i = 0; i < step; i++) chains.add(new ArrayList<>());
            for (int[] item : pending) chains.get(item[0] % step).add(item);
            for (int first = 0; first < step; first++) {
                List<int[]> chain = chains.get(first);
                if (chain.isEmpty()) continue;
                if (chain.size() == 1) {
                    // A lone query is cheaper to apply directly.
                    int[] only = chain.get(0);
                    long val = only[2];
                    for (int idx = only[0]; idx <= only[1]; idx += step) {
                        nums[idx] = (int) (nums[idx] * val % MOD);
                    }
                    continue;
                }
                // Number of indices first, first + step, ... that are < size.
                int length = (size - first - 1) / step + 1;
                Arrays.fill(quot, 0, length, 1);
                for (int[] item : chain) {
                    long val = item[2];
                    // item[0] % step == first, so item[0] / step is its chain position.
                    int pos = item[0] / step;
                    quot[pos] = (int) (quot[pos] * val % MOD);
                    int end = (item[1] - first) / step + 1;
                    quot[end] = (int) (quot[end] * modPow(val, MOD - 2) % MOD);
                }
                long acc = 1;
                for (int pos = 0; pos < length; pos++) {
                    acc = acc * quot[pos] % MOD;
                    int idx = first + pos * step;
                    nums[idx] = (int) (nums[idx] * acc % MOD);
                }
            }
        }
        int result = 0;
        for (int x : nums) result ^= x;
        return result;
    }

    /** Binary exponentiation: base^exp mod MOD for exp >= 0. */
    private long modPow(long base, int exp) {
        long result = 1;
        while (exp > 0) {
            if ((exp & 1) == 1) result = result * base % MOD;
            base = base * base % MOD;
            exp >>= 1;
        }
        return result;
    }
}
###cpp
class Solution {
    const int MOD = 1'000'000'007;

    // Binary exponentiation: returns x^n mod MOD (n >= 0). Called with
    // n = MOD - 2 to obtain modular inverses (MOD is prime).
    long long pow(long long x, int n) {
        long long res = 1;
        for (; n; n /= 2) {
            if (n % 2) {
                res = res * x % MOD;
            }
            x = x * x % MOD;
        }
        return res;
    }

public:
    // Applies every query (l, r, k, v): multiply nums[l], nums[l + k], ...
    // (indices <= r) by v modulo MOD, then returns the XOR of all elements.
    // nums is modified in place.
    //
    // Queries with k >= B = (int) sqrt(q) are applied directly. The rest are
    // grouped by (k, l % k). A group with a single query is applied directly.
    // Any other group is resolved with one shared quotient (multiplicative
    // difference) buffer indexed by position along its residue chain.
    int xorAfterQueries(vector<int>& nums, vector<vector<int>>& queries) {
        int n = nums.size();
        // Truncated sqrt, the same threshold the other implementations use.
        int B = sqrt(queries.size());
        vector<vector<tuple<int, int, int>>> groups(B);
        for (auto& q : queries) {
            int l = q[0], r = q[1], k = q[2], v = q[3];
            if (k < B) {
                groups[k].emplace_back(l, r, v);
            } else {
                for (int i = l; i <= r; i += k) {
                    nums[i] = 1LL * nums[i] * v % MOD;
                }
            }
        }
        // Shared buffer: a chain has at most n positions plus one sentinel slot.
        vector<int> diff(n + 1);
        for (int k = 1; k < B; k++) {
            auto& g = groups[k];
            if (g.empty()) {
                continue;
            }
            vector<vector<tuple<int, int, int>>> buckets(k);
            for (auto& t : g) {
                buckets[get<0>(t) % k].emplace_back(t);
            }
            for (int start = 0; start < k; start++) {
                auto& bucket = buckets[start];
                if (bucket.empty()) {
                    continue;
                }
                if (bucket.size() == 1) {
                    // A lone query is cheaper to apply directly.
                    const auto& [l, r, v] = bucket[0];
                    for (int i = l; i <= r; i += k) {
                        nums[i] = 1LL * nums[i] * v % MOD;
                    }
                    continue;
                }
                // Number of indices start, start + k, ... that are < n.
                int m = (n - start - 1) / k + 1;
                fill(diff.begin(), diff.begin() + m, 1);
                for (const auto& [l, r, v] : bucket) {
                    // l % k == start, so l / k is l's position in the chain.
                    diff[l / k] = 1LL * diff[l / k] * v % MOD;
                    // One past the last chain position this query covers. Kept in
                    // a local so the stored query tuple is not overwritten.
                    int end = (r - start) / k + 1;
                    diff[end] = diff[end] * pow(v, MOD - 2) % MOD;
                }
                long long mul_d = 1;
                for (int i = 0; i < m; i++) {
                    mul_d = mul_d * diff[i] % MOD;
                    int j = start + i * k;
                    nums[j] = nums[j] * mul_d % MOD;
                }
            }
        }
        return reduce(nums.begin(), nums.end(), 0, bit_xor());
    }
};
###go
const mod = 1_000_000_007

// xorAfterQueries applies every query (l, r, k, v) -- multiply nums[l],
// nums[l+k], ... (indices <= r) by v modulo mod -- and returns the XOR of all
// elements afterwards. nums is modified in place.
//
// Steps >= int(sqrt(q)) are applied directly. Smaller steps are grouped by
// (k, l%k). A group with a single query is applied directly. Any other group
// is resolved with one shared quotient buffer indexed by chain position.
func xorAfterQueries(nums []int, queries [][]int) (ans int) {
	size := len(nums)
	threshold := int(math.Sqrt(float64(len(queries))))
	type pending struct{ l, r, v int }
	byStep := make([][]pending, threshold)
	for _, query := range queries {
		left, right, step, val := query[0], query[1], query[2], query[3]
		if step < threshold {
			byStep[step] = append(byStep[step], pending{left, right, val})
			continue
		}
		for idx := left; idx <= right; idx += step {
			nums[idx] = nums[idx] * val % mod
		}
	}
	// Reused for every chain: at most size positions plus one sentinel.
	quot := make([]int, size+1)
	for step, group := range byStep {
		if group == nil {
			continue
		}
		chains := make([][]pending, step)
		for _, p := range group {
			chains[p.l%step] = append(chains[p.l%step], p)
		}
		for first, chain := range chains {
			if chain == nil {
				continue
			}
			if len(chain) == 1 {
				// A lone query is cheaper to apply directly.
				p := chain[0]
				for idx := p.l; idx <= p.r; idx += step {
					nums[idx] = nums[idx] * p.v % mod
				}
				continue
			}
			// Number of indices first, first+step, ... that are < size.
			length := (size-first-1)/step + 1
			for i := 0; i < length; i++ {
				quot[i] = 1
			}
			for _, p := range chain {
				// p.l%step == first, so p.l/step is its chain position.
				quot[p.l/step] = quot[p.l/step] * p.v % mod
				end := (p.r-first)/step + 1
				quot[end] = quot[end] * pow(p.v, mod-2) % mod
			}
			acc := 1
			for i := 0; i < length; i++ {
				acc = acc * quot[i] % mod
				idx := first + i*step
				nums[idx] = nums[idx] * acc % mod
			}
		}
	}
	for _, x := range nums {
		ans ^= x
	}
	return
}

// pow returns x^n mod mod by binary exponentiation (n >= 0).
func pow(x, n int) int {
	res := 1
	for n > 0 {
		if n&1 == 1 {
			res = res * x % mod
		}
		x = x * x % mod
		n >>= 1
	}
	return res
}
进一步优化
如果询问的前三项是一样的,就把这样的询问合并在一起。
###py
class Solution:
    def xorAfterQueries(self, nums: List[int], queries: List[List[int]]) -> int:
        """Apply every query to ``nums`` and return the XOR of the result.

        A query ``(l, r, k, v)`` multiplies ``nums[l], nums[l+k], ...``
        (indices ``<= r``) by ``v`` modulo 1e9+7.  ``nums`` is modified in
        place.

        Queries sharing the same ``(l, r, k)`` touch the same indices, so
        their multipliers are first folded together.  The merged queries are
        then handled like the grouped version: large steps directly, small
        steps via per-``(k, l % k)`` quotient arrays.
        """
        MOD = 1_000_000_007
        merged = defaultdict(lambda: 1)
        for l, r, k, v in queries:
            merged[l, r, k] = merged[l, r, k] * v % MOD
        size = len(nums)
        # Threshold is based on the number of distinct (l, r, k) triples.
        threshold = isqrt(len(merged))
        by_step = [[] for _ in range(threshold)]
        for (left, right, step), val in merged.items():
            if step < threshold:
                by_step[step].append((left, right, val))
                continue
            for idx in range(left, right + 1, step):
                nums[idx] = nums[idx] * val % MOD
        for step in range(threshold):
            pending = by_step[step]
            if not pending:
                continue
            chains = [[] for _ in range(step)]
            for item in pending:
                chains[item[0] % step].append(item)
            for first in range(step):
                chain = chains[first]
                if not chain:
                    continue
                if len(chain) == 1:
                    # A lone query is cheaper to apply directly.
                    left, right, val = chain[0]
                    for idx in range(left, right + 1, step):
                        nums[idx] = nums[idx] * val % MOD
                    continue
                # Number of indices first, first + step, ... that are < size.
                length = (size - first - 1) // step + 1
                quot = [1] * (length + 1)
                for left, right, val in chain:
                    pos = left // step
                    quot[pos] = quot[pos] * val % MOD
                    end = (right - first) // step + 1
                    quot[end] = quot[end] * pow(val, -1, MOD) % MOD
                acc = 1
                for pos in range(length):
                    acc = acc * quot[pos] % MOD
                    idx = first + pos * step
                    nums[idx] = nums[idx] * acc % MOD
        return reduce(xor, nums)
复杂度分析
- 时间复杂度:$\mathcal{O}(n\sqrt q + q\log M)$,其中 $n$ 是 $\textit{nums}$ 的长度,$q$ 是 $\textit{queries}$ 的长度,$M=10^9+7$。
- 空间复杂度:$\mathcal{O}(n + q)$。
分类题单
- 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
- 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
- 单调栈(基础/矩形面积/贡献法/最小字典序)
- 网格图(DFS/BFS/综合应用)
- 位运算(基础/性质/拆位/试填/恒等式/思维)
- 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
- 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
- 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
- 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
- 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
- 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
- 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)
欢迎关注 B站@灵茶山艾府