普通视图

发现新文章,点击刷新页面。
今天 — 2025年11月10日首页

从分治到单调栈,简洁写法(Python/Java/C++/Go)

作者 endlesscheng
2025年5月11日 07:18

分治

回顾示例 3 $\textit{nums}=[1,2,1,2,1,2]$ 的操作过程:

  • 首先,只需要一次操作(选择整个数组),就可以把所有的最小值 $1$ 都变成 $0$。现在数组是 $[0,2,0,2,0,2]$。
  • 这些被 $0$ 分割开的 $2$,无法合在一起操作(因为子数组会包含 $0$,导致 $2$ 无法变成 $0$),只能一个一个操作。

一般地:

  1. 先通过一次操作,把 $\textit{nums}$ 的最小值都变成 $0$(如果最小值已经是 $0$ 则跳过这步)。
  2. 此时 $\textit{nums}$ 被这些 $0$ 划分成了若干段,后续操作只能在每段内部,不能跨段操作(否则子数组会包含 $0$)。每一段是规模更小的子问题,可以用第一步的方法解决。这样我们可以写一个递归去处理。递归边界:如果操作后全为 $0$,直接返回。

找最小值可以用 ST 表或者线段树,但这种做法很麻烦。有没有简单的做法呢?

单调栈

从左往右遍历数组,只在「必须要操作」的时候,才把答案加一。

什么时候必须要操作?

示例 3 $\textit{nums}=[1,2,1,2,1,2]$,因为 $2$ 左右两侧都有小于 $2$ 的数,需要单独操作。

又例如 $\textit{nums}=[1,2,3,2,1]$:

  • 遍历到第二个 $2$ 时,可以知道 $3$ 左右两侧都有小于 $3$ 的数,所以 $3$ 必须要操作一次,答案加一。注意这不表示第一次操作的是 $3$,而是某次操作会把 $3$ 变成 $0$。
  • 遍历到末尾 $1$ 时,可以知道中间的两个 $2$,左边有 $1$,右边也有 $1$,必须操作一次,答案加一。比如选择 $[2,3,2]$ 可以把这两个 $2$ 都变成 $0$。
  • 最后,数组中的 $1$ 需要操作一次都变成 $0$。

我们怎么知道「$3$ 左右两侧都有小于 $3$ 的数」?

遍历数组的同时,把遍历过的元素用栈记录:

  • 如果当前元素比栈顶大(或者栈为空),那么直接入栈。
  • 如果当前元素比栈顶小,那么对于栈顶来说,左边(栈顶倒数第二个数)比栈顶小(原因后面解释),右边(当前元素)也比栈顶小,所以栈顶必须操作一次。然后弹出栈顶。
  • 如果当前元素等于栈顶,可以在同一次操作中把当前元素与栈顶都变成 $0$,所以无需入栈。注意这保证了栈中没有重复元素。

如果当前元素比栈顶小,就弹出栈顶,我们会得到一个底小顶大的单调栈,这就保证了「对于栈顶来说,左边(栈顶倒数第二个数)比栈顶小」。

遍历结束后,因为栈是严格递增的,所以栈中每个非零数字都需要操作一次。

代码实现时,可以直接把 $\textit{nums}$ 当作栈。

具体请看 视频讲解,欢迎点赞关注~

###py

class Solution:
    def minOperations(self, nums: List[int]) -> int:
        ans = 0
        st = []
        for x in nums:
            while st and x < st[-1]:
                st.pop()
                ans += 1
            # 如果 x 与栈顶相同,那么 x 与栈顶可以在同一次操作中都变成 0,x 无需入栈
            if not st or x != st[-1]:
                st.append(x)
        return ans + len(st) - (st[0] == 0)  # 0 不需要操作

###py

class Solution:
    def minOperations(self, nums: List[int]) -> int:
        ans = 0
        top = -1  # 栈顶下标(把 nums 当作栈)
        for x in nums:
            while top >= 0 and x < nums[top]:
                top -= 1  # 出栈
                ans += 1
            # 如果 x 与栈顶相同,那么 x 与栈顶可以在同一次操作中都变成 0,x 无需入栈
            if top < 0 or x != nums[top]:
                top += 1
                nums[top] = x  # 入栈
        return ans + top + (nums[0] > 0)

###java

class Solution {
    public int minOperations(int[] nums) {
        int ans = 0;
        int top = -1; // 栈顶下标(把 nums 当作栈)
        for (int x : nums) {
            while (top >= 0 && x < nums[top]) {
                top--; // 出栈
                ans++;
            }
            // 如果 x 与栈顶相同,那么 x 与栈顶可以在同一次操作中都变成 0,x 无需入栈
            if (top < 0 || x != nums[top]) {
                nums[++top] = x; // 入栈
            }
        }
        return ans + top + (nums[0] > 0 ? 1 : 0);
    }
}

###cpp

class Solution {
public:
    int minOperations(vector<int>& nums) {
        int ans = 0;
        int top = -1; // 栈顶下标(把 nums 当作栈)
        for (int x : nums) {
            while (top >= 0 && x < nums[top]) {
                top--; // 出栈
                ans++;
            }
            // 如果 x 与栈顶相同,那么 x 与栈顶可以在同一次操作中都变成 0,x 无需入栈
            if (top < 0 || x != nums[top]) {
                nums[++top] = x; // 入栈
            }
        }
        return ans + top + (nums[0] > 0);
    }
};

###go

func minOperations(nums []int) (ans int) {
st := nums[:0] // 原地
for _, x := range nums {
for len(st) > 0 && x < st[len(st)-1] {
st = st[:len(st)-1]
ans++
}
// 如果 x 与栈顶相同,那么 x 与栈顶可以在同一次操作中都变成 0,x 无需入栈
if len(st) == 0 || x != st[len(st)-1] {
st = append(st, x)
}
}
if st[0] == 0 { // 0 不需要操作
ans--
}
return ans + len(st)
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。每个元素至多入栈出栈各一次,所以二重循环的循环次数是 $\mathcal{O}(n)$。
  • 空间复杂度:$\mathcal{O}(n)$ 或 $\mathcal{O}(1)$。原地做法可以做到 $\mathcal{O}(1)$ 空间。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

昨天 — 2025年11月9日首页

O(log) 辗转相除(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2022年2月13日 12:13

下文把 $\textit{num}_1$ 和 $\textit{num}_2$ 分别记作 $x$ 和 $y$。

根据题意,如果 $x\ge y$,那么 $x$ 要不断减去 $y$,直到小于 $y$。这和商的定义是一样的,所以把 $x$ 减少到小于 $y$ 的操作数为 $\left\lfloor\dfrac{x}{y}\right\rfloor$。

$x<y$ 后,$x$ 变成 $x\bmod y$。

我们可以交换 $x$ 和 $y$,重复上述过程,这样无需实现 $x<y$ 时把 $y$ 变小的逻辑。

循环直到 $y=0$ 为止。

累加所有操作数,即为答案。

class Solution:
    def countOperations(self, x: int, y: int) -> int:
        ans = 0
        while y > 0:
            ans += x // y  # x 变成 x%y
            x, y = y, x % y
        return ans
class Solution {
    public int countOperations(int x, int y) {
        int ans = 0;
        while (y > 0) {
            ans += x / y;
            int r = x % y; // x 变成 r
            x = y; // 交换 x 和 y
            y = r;
        }
        return ans;
    }
}
class Solution {
public:
    int countOperations(int x, int y) {
        int ans = 0;
        while (y > 0) {
            ans += x / y;
            x %= y;
            swap(x, y);
        }
        return ans;
    }
};
int countOperations(int x, int y) {
    int ans = 0;
    while (y > 0) {
        ans += x / y;
        int r = x % y; // x 变成 r
        x = y; // 交换 x 和 y
        y = r;
    }
    return ans;
}
func countOperations(x, y int) (ans int) {
for y > 0 {
ans += x / y // x 变成 x%y
x, y = y, x%y
}
return
}
var countOperations = function(x, y) {
    let ans = 0;
    while (y > 0) {
        ans += Math.floor(x / y); // x 变成 x%y
        [x, y] = [y, x % y];
    }
    return ans;
};
impl Solution {
    pub fn count_operations(mut x: i32, mut y: i32) -> i32 {
        let mut ans = 0;
        while y > 0 {
            ans += x / y; // x 变成 x%y
            (x, y) = (y, x % y);
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(\log\max(x,y))$。当 $x$ 和 $y$ 为斐波那契数列中的相邻两项时,达到最坏情况。
  • 空间复杂度:$\mathcal{O}(1)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

昨天以前首页

推式子(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2025年11月5日 13:32

九连环,玩具也,以铜制之。欲使九环同贯于柱上,则先上第一环,再上第二环,而下其第一环,更上第三环,而下其第一二环,再上第四环,如是更迭上下,凡八十一次,而九环毕上矣。解之之法,先下其第一环,次下其第三环,更上第一环,而并下其第一二环,又下其第三环,如是更迭上下,凡八十一次,而九环毕下矣。

——《清稗类钞》

下文的 $n$ 均为二进制。

题意解读:第二种操作,翻转的是 $n$ 最低的 $1$ 左侧相邻的比特位。例如 $10010$ 操作后是 $10110$。注意操作是可逆的,$10110$ 执行第二种操作得到 $10010$。

逆向思维,从 $0$ 开始操作,我们依次得到的数字是什么?

由于连续两次相同操作后数字不变,如果要最小化操作次数,不能连续执行相同的操作,所以只能第一种操作和第二种操作交替执行。这意味着最优操作方案是唯一的,如下:

$$
0\to 1\to 11\to 10\to 110\to 111\to 101 \to 100 \to 1100\to \cdots
$$

从特殊到一般,考察从 $0$ 到 $2^k$ 的操作次数:

  • 从 $0$ 到 $10$ 需要操作 $3$ 次。
  • 从 $0$ 到 $100$ 需要操作 $7$ 次。
  • 从 $0$ 到 $1000$ 需要操作多少次?

仔细考察从 $0$ 到 $100$ 的过程,其中后半段从 $110$ 到 $100$ 的过程是 $110\to 111\to 101 \to 100$,只看低 $2$ 位是 $10\to 11\to 01 \to 00$,倒过来看是 $00\to 01\to 11\to 10$,这和 $0$ 到 $10$ 的过程是完全一样的!

定义 $f(n)$ 表示把 $0$ 变成 $n$ 的最小操作次数,这也等于把 $n$ 变成 $0$ 最小操作次数。那么有

$$
0 \xrightarrow{操作\ f(10)\ 次} 10 \xrightarrow{操作\ 1\ 次} 110 \xrightarrow{操作\ f(10)\ 次} 100
$$

同理可得

$$
0 \xrightarrow{操作\ f(100)\ 次} 100 \xrightarrow{操作\ 1\ 次} 1100 \xrightarrow{操作\ f(100)\ 次} 1000
$$

所以从 $0$ 到 $1000$ 需要操作 $f(100) + 1 + f(100) = 7+1+7=15$ 次。

一般地,我们有

$$
f(2^k) = 2f(2^{k-1}) + 1
$$

两边同时加一,得

$$
f(2^k) + 1 = 2(f(2^{k-1}) + 1)
$$

所以 $f(2^k) + 1$ 是个等比数列,公比为 $2$,初始值为 $f(1)+1 = 2$,得

$$
f(2^k) + 1 = 2^{k+1}
$$

$$
f(2^k) = 2^{k+1} - 1
$$

注:另一种理解角度是,从 $0$ 到 $2^k$ 的过程中,恰好访问了 $[0,2^{k+1}-1]$ 中的每个整数各一次,所以需要操作 $2^{k+1}-1$ 次。

我们解决了 $n$ 是 $2$ 的幂的情况。下面考虑一般情况。

再来看这个过程

$$
0\to 1\to 11\to 10\to 110\to 111\to 101 \to 100
$$

其中从 $0$ 到 $111$ 需要操作多少次?

  • 先计算从 $0$ 到 $100$ 的操作次数 $f(100)$。
  • 然后减去从 $111$ 到 $100$ 的操作次数。这等于从 $11$ 到 $00$ 的操作次数,即 $f(11)$。

所以 $f(111) = f(100) - f(11)$。

一般地,设 $n$ 的二进制长度为 $k$,我们有

$$
\begin{aligned}
f(n) &= f(2^{k-1}) - f(n - 2^{k-1}) \
&= 2^k - 1 - f(n - 2^{k-1}) \
\end{aligned}
$$

其中 $n - 2^{k-1}$ 表示 $n$ 去掉最高的 $1$ 后的值。

递归边界:$f(0) = 0$。

注:九连环需要 $f(2^9-1) = 341$ 次操作。开头那段文言由于把多次操作算作一次,给出的操作次数比实际的少。

写法一:自顶向下

###py

class Solution:
    def minimumOneBitOperations(self, n: int) -> int:
        if n == 0:
            return 0
        k = n.bit_length()
        return (1 << k) - 1 - self.minimumOneBitOperations(n - (1 << (k - 1)))

###java

class Solution {
    public int minimumOneBitOperations(int n) {
        if (n == 0) {
            return 0;
        }
        int k = 32 - Integer.numberOfLeadingZeros(n);
        return (1 << k) - 1 - minimumOneBitOperations(n - (1 << (k - 1)));
    }
}

###java

class Solution {
    public int minimumOneBitOperations(int n) {
        if (n == 0) {
            return 0;
        }
        int hb = Integer.highestOneBit(n);
        return (hb << 1) - 1 - minimumOneBitOperations(n - hb);
    }
}

###cpp

class Solution {
public:
    int minimumOneBitOperations(int n) {
        if (n == 0) {
            return 0;
        }
        int k = bit_width((uint32_t) n);
        return (1 << k) - 1 - minimumOneBitOperations(n - (1 << (k - 1)));
    }
};

###c

int minimumOneBitOperations(int n) {
    if (n == 0) {
        return 0;
    }
    int k = 32 - __builtin_clz(n);
    return (1 << k) - 1 - minimumOneBitOperations(n - (1 << (k - 1)));
}

###go

func minimumOneBitOperations(n int) int {
if n == 0 {
return 0
}
k := bits.Len(uint(n))
return 1<<k - 1 - minimumOneBitOperations(n-1<<(k-1))
}

###js

var minimumOneBitOperations = function(n) {
    if (n === 0) {
        return 0;
    }
    let k = 32 - Math.clz32(n);
    return (1 << k) - 1 - minimumOneBitOperations(n - (1 << (k - 1)));
};

###rust

impl Solution {
    pub fn minimum_one_bit_operations(n: i32) -> i32 {
        if n == 0 {
            return 0;
        }
        let k = 32 - n.leading_zeros();
        (1 << k) - 1 - Self::minimum_one_bit_operations(n - (1 << (k - 1)))
    }
}

写法二:自底向上

递归是从高到低遍历 $n$ 的值为 $1$ 的比特位。

也可以从低到高遍历这些 $1$。

初始化答案 $\textit{ans}= 0$,即递归边界。

计算 $n$ 的最低位的 $1$,即 $\text{lowbit}$,原理见 从集合论到位运算,常见位运算技巧分类总结!

这里 $\text{lowbit}$ 相当于上面的 $2^{k-1}$。

然后更新 $\textit{ans}$ 为 $\text{lowbit}\cdot 2 - 1 - \textit{ans}$,相当于去掉递归的「递」,只在「归」的过程中计算答案。

###py

class Solution:
    def minimumOneBitOperations(self, n: int) -> int:
        ans = 0
        while n > 0:
            lb = n & -n  # n 的最低 1
            ans = (lb << 1) - 1 - ans
            n ^= lb  # 去掉 n 的最低 1
        return ans

###java

class Solution {
    public int minimumOneBitOperations(int n) {
        int ans = 0;
        while (n > 0) {
            int lb = n & -n; // n 的最低 1
            ans = (lb << 1) - 1 - ans;
            n ^= lb; // 去掉 n 的最低 1
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int minimumOneBitOperations(int n) {
        int ans = 0;
        while (n > 0) {
            int lb = n & -n; // n 的最低 1
            ans = (lb << 1) - 1 - ans;
            n ^= lb; // 去掉 n 的最低 1
        }
        return ans;
    }
};

###c

int minimumOneBitOperations(int n) {
    int ans = 0;
    while (n > 0) {
        int lb = n & -n; // n 的最低 1
        ans = (lb << 1) - 1 - ans;
        n ^= lb; // 去掉 n 的最低 1
    }
    return ans;
}

###go

func minimumOneBitOperations(n int) (ans int) {
for n > 0 {
lb := n & -n // n 的最低 1
ans = lb<<1 - 1 - ans
n ^= lb // 去掉 n 的最低 1
}
return
}

###js

var minimumOneBitOperations = function(n) {
    let ans = 0;
    while (n > 0) {
        const lb = n & -n; // n 的最低 1
        ans = (lb << 1) - 1 - ans;
        n ^= lb; // 去掉 n 的最低 1
    }
    return ans;
};

###rust

impl Solution {
    pub fn minimum_one_bit_operations(mut n: i32) -> i32 {
        let mut ans = 0;
        while n > 0 {
            let lb = n & -n; // n 的最低 1
            ans = (lb << 1) - 1 - ans;
            n ^= lb; // 去掉 n 的最低 1
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(\log n)$。循环次数为 $n$ 的二进制中的 $1$ 的个数。特别地,如果 $n$ 是 $2$ 的幂,这个做法只需循环一次。
  • 空间复杂度:$\mathcal{O}(1)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

二分答案+前缀和+差分数组+贪心(Python/Java/C++/Go)

作者 endlesscheng
2023年1月8日 00:28

转化

如果存在一种方案,可以让所有城市的电量都 $\ge \textit{low}$,那么也可以 $\ge \textit{low}-1$ 或者更小的值(要求更宽松)。

如果不存在让所有城市的电量都 $\ge \textit{low}$ 的方案,那么也不存在 $\ge \textit{low}+1$ 或者更大的方案(要求更苛刻)。

据此,可以二分猜答案。关于二分算法的原理,请看 二分查找 红蓝染色法【基础算法精讲 04】

现在问题转化成一个判定性问题:

  • 给定 $\textit{low}$,是否存在建造 $k$ 座额外电站的方案,使得所有城市的电量都 $\ge \textit{low}$?

如果存在,说明答案 $\ge \textit{low}$,否则答案 $<\textit{low}$。

思路

由于已经建造的电站是可以发电的,我们需要在二分之前,用 $\textit{stations}$ 计算每个城市的初始电量 $\textit{power}$。这可以用前缀和或者滑动窗口做,具体后面解释。

然后从左到右遍历 $\textit{power}$,挨个处理。如果 $\textit{power}[i] < \textit{low}$,就需要建造电站,额外提供 $\textit{low} - \textit{power}[i]$ 的电力。

在哪建造电站最好呢?

由于我们是从左到右遍历的,在 $i$ 左侧的城市已经处理好了,所以建造的电站越靠右越好,尽可能多地覆盖没遍历到的城市。具体地,$i$ 应当恰好在电站供电范围的边缘上,也就是把电站建在 $i+r$ 的位置,使得电站覆盖范围为 $[i,i+2r]$。

我们要建 $m = \textit{low} - \textit{power}[i]$ 个供电站,也就是把下标在 $[i,i+2r]$ 中的城市的电量都增加 $m$。

这里有一个「区间增加 $m$」的需求,用差分数组实现,原理见 差分数组原理讲解

我们要一边做差分更新,一边计算差分数组的前缀和,以得到当前城市的实际电量。

遍历的同时,累计额外建造的电站数量,如果超过 $k$,不满足要求,可以提前跳出循环。

细节

下面代码采用开区间二分,这仅仅是二分的一种写法,使用闭区间或者半闭半开区间都是可以的,喜欢哪种写法就用哪种。

  • 开区间左端点初始值:$\min(\textit{power}) + \left\lfloor\dfrac{k}{n}\right\rfloor$。把 $k$ 均摊,即使 $r=0$,每个城市都能至少分到 $\left\lfloor\dfrac{k}{n}\right\rfloor$ 的额外电量,所以 $\textit{low} = \min(\textit{power}) + \left\lfloor\dfrac{k}{n}\right\rfloor$ 时一定满足要求。
  • 开区间右端点初始值:$\min(\textit{power}) + k + 1$。即使把所有额外电站都建给电量最小的城市,也无法满足要求。

对于开区间写法,简单来说 check(mid) == true 时更新的是谁,最后就返回谁。相比其他二分写法,开区间写法不需要思考加一减一等细节,更简单。推荐使用开区间写二分。

写法一:前缀和

能覆盖城市 $i$ 的电站下标范围是 $[i-r,i+r]$。注意下标不能越界,所以实际范围是 $[\max(i-r,0),\min(i+r,n-1)]$。

我们需要计算 $\textit{stations}$ 的这个范围(子数组)的和。

计算 $\textit{stations}$ 的 前缀和 数组后,可以 $\mathcal{O}(1)$ 计算 $\textit{stations}$ 的任意子数组的和。

class Solution:
    def maxPower(self, stations: List[int], r: int, k: int) -> int:
        n = len(stations)
        # 前缀和
        s = list(accumulate(stations, initial=0))
        # 初始电量
        power = [s[min(i + r + 1, n)] - s[max(i - r, 0)] for i in range(n)]

        def check(low: int) -> bool:
            diff = [0] * n  # 差分数组
            sum_d = built = 0
            for i, p in enumerate(power):
                sum_d += diff[i]  # 累加差分值
                m = low - (p + sum_d)
                if m <= 0:
                    continue
                # 需要在 i+r 额外建造 m 个供电站
                built += m
                if built > k:  # 不满足要求
                    return False
                # 把区间 [i, i+2r] 增加 m
                sum_d += m  # 由于 diff[i] 后面不会再访问,我们直接加到 sum_d 中
                if (right := i + r * 2 + 1) < n:
                    diff[right] -= m
            return True

        # 开区间二分
        mn = min(power)
        left, right = mn + k // n, mn + k + 1
        while left + 1 < right:
            mid = (left + right) // 2
            if check(mid):
                left = mid
            else:
                right = mid
        return left
class Solution:
    def maxPower(self, stations: List[int], r: int, k: int) -> int:
        n = len(stations)
        # 前缀和
        s = list(accumulate(stations, initial=0))
        # 初始电量
        power = [s[min(i + r + 1, n)] - s[max(i - r, 0)] for i in range(n)]

        def check(low: int) -> bool:
            low += 1  # 二分最小的不满足要求的 low(符合库函数),这样最终返回的就是最大的满足要求的 low
            diff = [0] * n  # 差分数组
            sum_d = built = 0
            for i, p in enumerate(power):
                sum_d += diff[i]  # 累加差分值
                m = low - (p + sum_d)
                if m <= 0:
                    continue
                # 需要在 i+r 额外建造 m 个供电站
                built += m
                if built > k:  # 不满足要求
                    return True
                # 把区间 [i, i+2r] 增加 m
                sum_d += m  # 由于 diff[i] 后面不会再访问,我们直接加到 sum_d 中
                if (right := i + r * 2 + 1) < n:
                    diff[right] -= m
            return False

        mn = min(power)
        left, right = mn + k // n, mn + k
        return bisect_left(range(right), True, lo=left, key=check)
class Solution {
    public long maxPower(int[] stations, int r, int k) {
        int n = stations.length;
        // 前缀和
        long[] sum = new long[n + 1];
        for (int i = 0; i < n; i++) {
            sum[i + 1] = sum[i] + stations[i];
        }

        // 初始电量
        long[] power = new long[n];
        long mn = Long.MAX_VALUE;
        for (int i = 0; i < n; i++) {
            power[i] = sum[Math.min(i + r + 1, n)] - sum[Math.max(i - r, 0)];
            mn = Math.min(mn, power[i]);
        }

        // 开区间二分
        long left = mn + k / n;
        long right = mn + k + 1;
        while (left + 1 < right) {
            long mid = left + (right - left) / 2;
            if (check(mid, power, r, k)) {
                left = mid;
            } else {
                right = mid;
            }
        }
        return left;
    }

    private boolean check(long low, long[] power, int r, int k) {
        int n = power.length;
        long[] diff = new long[n + 1];
        long sumD = 0;
        long built = 0;
        for (int i = 0; i < n; i++) {
            sumD += diff[i]; // 累加差分值
            long m = low - (power[i] + sumD);
            if (m <= 0) {
                continue;
            }
            // 需要在 i+r 额外建造 m 个供电站
            built += m;
            if (built > k) { // 不满足要求
                return false;
            }
            // 把区间 [i, i+2r] 增加 m
            sumD += m; // 由于 diff[i] 后面不会再访问,我们直接加到 sumD 中
            diff[Math.min(i + r * 2 + 1, n)] -= m;
        }
        return true;
    }
}
class Solution {
public:
    long long maxPower(vector<int>& stations, int r, int k) {
        int n = stations.size();
        // 前缀和
        vector<long long> sum(n + 1);
        for (int i = 0; i < n; i++) {
            sum[i + 1] = sum[i] + stations[i];
        }

        // 初始电量
        vector<long long> power(n);
        long long mn = LLONG_MAX;
        for (int i = 0; i < n; i++) {
            power[i] = sum[min(i + r + 1, n)] - sum[max(i - r, 0)];
            mn = min(mn, power[i]);
        }

        auto check = [&](long long low) -> bool {
            vector<long long> diff(n + 1);
            long long sum_d = 0, built = 0;
            for (int i = 0; i < n; i++) {
                sum_d += diff[i]; // 累加差分值
                long long m = low - (power[i] + sum_d);
                if (m <= 0) {
                    continue;
                }
                // 需要在 i+r 额外建造 m 个供电站
                built += m;
                if (built > k) { // 不满足要求
                    return false;
                }
                // 把区间 [i, i+2r] 增加 m
                sum_d += m; // 由于 diff[i] 后面不会再访问,我们直接加到 sum_d 中
                diff[min(i + r * 2 + 1, n)] -= m;
            }
            return true;
        };

        // 开区间二分
        long long left = mn + k / n, right = mn + k + 1;
        while (left + 1 < right) {
            long long mid = left + (right - left) / 2;
            (check(mid) ? left : right) = mid;
        }
        return left;
    }
};
func maxPower(stations []int, r int, k int) int64 {
n := len(stations)
// 前缀和
sum := make([]int, n+1)
for i, x := range stations {
sum[i+1] = sum[i] + x
}

// 初始电量
power := make([]int, n)
mn := math.MaxInt
for i := range power {
power[i] = sum[min(i+r+1, n)] - sum[max(i-r, 0)]
mn = min(mn, power[i])
}

// 二分答案
left := mn + k/n
right := mn + k
ans := left + sort.Search(right-left, func(low int) bool {
// 这里 +1 是为了二分最小的不满足要求的 low(符合库函数),这样最终返回的就是最大的满足要求的 low
low += left + 1
diff := make([]int, n+1) // 差分数组
sumD, built := 0, 0
for i, p := range power {
sumD += diff[i] // 累加差分值
m := low - (p + sumD)
if m <= 0 {
continue
}
// 需要在 i+r 额外建造 m 个供电站
built += m
if built > k { // 不满足要求
return true
}
// 把区间 [i, i+2r] 增加 m
sumD += m // 由于 diff[i] 后面不会再访问,我们直接加到 sumD 中
diff[min(i+r*2+1, n)] -= m
}
return false
})
return int64(ans)
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log k)$,其中 $n$ 是 $\textit{stations}$ 的长度。二分 $\mathcal{O}(\log k)$ 次,每次 $\mathcal{O}(n)$。
  • 空间复杂度:$\mathcal{O}(n)$。

写法二:滑动窗口

滑动窗口 计算 $\textit{power}$。

class Solution:
    def maxPower(self, stations: List[int], r: int, k: int) -> int:
        n = len(stations)
        # 滑动窗口
        s = sum(stations[:r])  # 先计算 [0, r-1] 的发电量,为第一个窗口做准备
        power = [0] * n
        for i in range(n):
            # 右边进
            if (right := i + r) < n:
                s += stations[right]
            # 左边出
            if (left := i - r - 1) >= 0:
                s -= stations[left]
            power[i] = s

        def check(low: int) -> bool:
            diff = [0] * n  # 差分数组
            sum_d = built = 0
            for i, p in enumerate(power):
                sum_d += diff[i]  # 累加差分值
                m = low - (p + sum_d)
                if m <= 0:
                    continue
                # 需要在 i+r 额外建造 m 个供电站
                built += m
                if built > k:  # 不满足要求
                    return False
                # 把区间 [i, i+2r] 增加 m
                sum_d += m  # 由于 diff[i] 后面不会再访问,我们直接加到 sum_d 中
                if (right := i + r * 2 + 1) < n:
                    diff[right] -= m
            return True

        # 开区间二分
        mn = min(power)
        left, right = mn + k // n, mn + k + 1
        while left + 1 < right:
            mid = (left + right) // 2
            if check(mid):
                left = mid
            else:
                right = mid
        return left
class Solution:
    def maxPower(self, stations: List[int], r: int, k: int) -> int:
        n = len(stations)
        # 滑动窗口
        s = sum(stations[:r])  # 先计算 [0, r-1] 的发电量,为第一个窗口做准备
        power = [0] * n
        for i in range(n):
            # 右边进
            if (right := i + r) < n:
                s += stations[right]
            # 左边出
            if (left := i - r - 1) >= 0:
                s -= stations[left]
            power[i] = s

        def check(low: int) -> bool:
            low += 1  # 二分最小的不满足要求的 low(符合库函数),这样最终返回的就是最大的满足要求的 low
            diff = [0] * n  # 差分数组
            sum_d = built = 0
            for i, p in enumerate(power):
                sum_d += diff[i]  # 累加差分值
                m = low - (p + sum_d)
                if m <= 0:
                    continue
                # 需要在 i+r 额外建造 m 个供电站
                built += m
                if built > k:  # 不满足要求
                    return True
                # 把区间 [i, i+2r] 增加 m
                sum_d += m  # 由于 diff[i] 后面不会再访问,我们直接加到 sum_d 中
                if (right := i + r * 2 + 1) < n:
                    diff[right] -= m
            return False

        mn = min(power)
        left, right = mn + k // n, mn + k
        return bisect_left(range(right), True, lo=left, key=check)
class Solution {
    public long maxPower(int[] stations, int r, int k) {
        int n = stations.length;
        // 滑动窗口
        // 先计算 [0, r-1] 的发电量,为第一个窗口做准备
        long sum = 0;
        for (int i = 0; i < r; i++) {
            sum += stations[i];
        }
        long[] power = new long[n];
        long mn = Long.MAX_VALUE;
        for (int i = 0; i < n; i++) {
            // 右边进
            if (i + r < n) {
                sum += stations[i + r];
            }
            // 左边出
            if (i - r - 1 >= 0) {
                sum -= stations[i - r - 1];
            }
            power[i] = sum;
            mn = Math.min(mn, sum);
        }

        // 开区间二分
        long left = mn + k / n;
        long right = mn + k + 1;
        while (left + 1 < right) {
            long mid = left + (right - left) / 2;
            if (check(mid, power, r, k)) {
                left = mid;
            } else {
                right = mid;
            }
        }
        return left;
    }

    private boolean check(long low, long[] power, int r, int k) {
        int n = power.length;
        long[] diff = new long[n + 1];
        long sumD = 0;
        long built = 0;
        for (int i = 0; i < n; i++) {
            sumD += diff[i]; // 累加差分值
            long m = low - (power[i] + sumD);
            if (m <= 0) {
                continue;
            }
            // 需要在 i+r 额外建造 m 个供电站
            built += m;
            if (built > k) { // 不满足要求
                return false;
            }
            // 把区间 [i, i+2r] 增加 m
            sumD += m; // 由于 diff[i] 后面不会再访问,我们直接加到 sumD 中
            diff[Math.min(i + r * 2 + 1, n)] -= m;
        }
        return true;
    }
}
class Solution {
public:
    long long maxPower(vector<int>& stations, int r, int k) {
        int n = stations.size();
        // 滑动窗口
        // 先计算 [0, r-1] 的发电量,为第一个窗口做准备
        long long sum = reduce(stations.begin(), stations.begin() + r, 0LL);
        vector<long long> power(n);
        long long mn = LLONG_MAX;
        for (int i = 0; i < n; i++) {
            // 右边进
            if (i + r < n) {
                sum += stations[i + r];
            }
            // 左边出
            if (i - r - 1 >= 0) {
                sum -= stations[i - r - 1];
            }
            power[i] = sum;
            mn = min(mn, sum);
        }

        auto check = [&](long long low) -> bool {
            vector<long long> diff(n + 1);
            long long sum_d = 0, built = 0;
            for (int i = 0; i < n; i++) {
                sum_d += diff[i]; // 累加差分值
                long long m = low - (power[i] + sum_d);
                if (m <= 0) {
                    continue;
                }
                // 需要在 i+r 额外建造 m 个供电站
                built += m;
                if (built > k) { // 不满足要求
                    return false;
                }
                // 把区间 [i, i+2r] 增加 m
                sum_d += m; // 由于 diff[i] 后面不会再访问,我们直接加到 sum_d 中
                diff[min(i + r * 2 + 1, n)] -= m;
            }
            return true;
        };

        // 开区间二分
        long long left = mn + k / n, right = mn + k + 1;
        while (left + 1 < right) {
            long long mid = left + (right - left) / 2;
            (check(mid) ? left : right) = mid;
        }
        return left;
    }
};
func maxPower(stations []int, r int, k int) int64 {
n := len(stations)
// 滑动窗口
// 先计算 [0, r-1] 的发电量,为第一个窗口做准备
sum := 0
for _, x := range stations[:r] {
sum += x
}
power := make([]int, n)
mn := math.MaxInt
for i := range power {
// 右边进
if i+r < n {
sum += stations[i+r]
}
// 左边出
if i-r-1 >= 0 {
sum -= stations[i-r-1]
}
power[i] = sum
mn = min(mn, sum)
}

// 二分答案
left := mn + k/n
right := mn + k
ans := left + sort.Search(right-left, func(low int) bool {
// 这里 +1 是为了二分最小的不满足要求的 low(符合库函数),这样最终返回的就是最大的满足要求的 low
low += left + 1
diff := make([]int, n+1) // 差分数组
sumD, built := 0, 0
for i, p := range power {
sumD += diff[i] // 累加差分值
m := low - (p + sumD)
if m <= 0 {
continue
}
// 需要在 i+r 额外建造 m 个供电站
built += m
if built > k { // 不满足要求
return true
}
// 把区间 [i, i+2r] 增加 m
sumD += m // 由于 diff[i] 后面不会再访问,我们直接加到 sumD 中
diff[min(i+r*2+1, n)] -= m
}
return false
})
return int64(ans)
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log k)$,其中 $n$ 是 $\textit{stations}$ 的长度。二分 $\mathcal{O}(\log k)$ 次,每次 $\mathcal{O}(n)$。
  • 空间复杂度:$\mathcal{O}(n)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

两种方法:懒删除堆 / 倒序处理(Python/Java/C++/Go)

作者 endlesscheng
2025年7月6日 12:08

方法一:懒删除堆

首先,建图 + DFS,把每个连通块中的节点加到各自的最小堆中。每个最小堆维护对应连通块的节点编号。

然后处理询问。

对于类型二,用一个 $\textit{offline}$ 布尔数组表示离线的电站。这一步不修改堆。

对于类型一:

  • 如果电站 $x$ 在线,那么答案为 $x$。
  • 否则检查 $x$ 所处堆的堆顶是否在线。若离线,则弹出堆顶,重复该过程。如果堆为不空,那么答案为堆顶,否则为 $-1$。

为了找到 $x$ 所属的堆,还需要一个数组 $\textit{belong}$ 记录每个节点在哪个堆中。

具体请看 视频讲解,欢迎点赞关注~

###py

class Solution:
    def processQueries(self, c: int, connections: List[List[int]], queries: List[List[int]]) -> List[int]:
        g = [[] for _ in range(c + 1)]
        for x, y in connections:
            g[x].append(y)
            g[y].append(x)

        belong = [-1] * (c + 1)
        heaps = []

        def dfs(x: int) -> None:
            belong[x] = len(heaps)  # 记录节点 x 在哪个堆
            h.append(x)
            for y in g[x]:
                if belong[y] < 0:
                    dfs(y)

        for i in range(1, c + 1):
            if belong[i] >= 0:
                continue
            h = []
            dfs(i)
            heapify(h)
            heaps.append(h)

        ans = []
        offline = [False] * (c + 1)
        for op, x in queries:
            if op == 2:
                offline[x] = True
                continue
            if not offline[x]:
                ans.append(x)
                continue
            h = heaps[belong[x]]
            # 懒删除:取堆顶的时候,如果离线,才删除
            while h and offline[h[0]]:
                heappop(h)
            ans.append(h[0] if h else -1)
        return ans

###java

class Solution {
    public int[] processQueries(int c, int[][] connections, int[][] queries) {
        List<Integer>[] g = new ArrayList[c + 1];
        Arrays.setAll(g, i -> new ArrayList<>());
        for (int[] e : connections) {
            int x = e[0], y = e[1];
            g[x].add(y);
            g[y].add(x);
        }

        int[] belong = new int[c + 1];
        Arrays.fill(belong, -1);
        List<PriorityQueue<Integer>> heaps = new ArrayList<>();
        PriorityQueue<Integer> pq;
        for (int i = 1; i <= c; i++) {
            if (belong[i] >= 0) {
                continue;
            }
            pq = new PriorityQueue<>();
            dfs(i, g, belong, heaps.size(), pq);
            heaps.add(pq);
        }

        int ansSize = 0;
        for (int[] q : queries) {
            if (q[0] == 1) {
                ansSize++;
            }
        }

        int[] ans = new int[ansSize];
        int idx = 0;
        boolean[] offline = new boolean[c + 1];
        for (int[] q : queries) {
            int x = q[1];
            if (q[0] == 2) {
                offline[x] = true;
                continue;
            }
            if (!offline[x]) {
                ans[idx++] = x;
                continue;
            }
            pq = heaps.get(belong[x]);
            // 懒删除:取堆顶的时候,如果离线,才删除
            while (!pq.isEmpty() && offline[pq.peek()]) {
                pq.poll();
            }
            ans[idx++] = pq.isEmpty() ? -1 : pq.peek();
        }
        return ans;
    }

    private void dfs(int x, List<Integer>[] g, int[] belong, int compId, PriorityQueue<Integer> pq) {
        belong[x] = compId; // 记录节点 x 在哪个堆
        pq.offer(x);
        for (int y : g[x]) {
            if (belong[y] < 0) {
                dfs(y, g, belong, compId, pq);
            }
        }
    }
}

###cpp

class Solution {
public:
    vector<int> processQueries(int c, vector<vector<int>>& connections, vector<vector<int>>& queries) {
        vector<vector<int>> g(c + 1);
        for (auto& e : connections) {
            int x = e[0], y = e[1];
            g[x].push_back(y);
            g[y].push_back(x);
        }

        vector<int> belong(c + 1, -1);
        vector<priority_queue<int, vector<int>, greater<>>> heaps;
        priority_queue<int, vector<int>, greater<>> pq;

        auto dfs = [&](this auto&& dfs, int x) -> void {
            belong[x] = heaps.size(); // 记录节点 x 在哪个堆
            pq.push(x);
            for (int y : g[x]) {
                if (belong[y] < 0) {
                    dfs(y);
                }
            }
        };

        for (int i = 1; i <= c; i++) {
            if (belong[i] < 0) {
                dfs(i);
                heaps.emplace_back(move(pq));
            }
        }

        vector<int> ans;
        vector<int8_t> offline(c + 1);
        for (auto& q : queries) {
            int x = q[1];
            if (q[0] == 2) {
                offline[x] = true;
                continue;
            }
            if (!offline[x]) {
                ans.push_back(x);
                continue;
            }
            auto& h = heaps[belong[x]];
            // 懒删除:取堆顶的时候,如果离线,才删除
            while (!h.empty() && offline[h.top()]) {
                h.pop();
            }
            ans.push_back(h.empty() ? -1 : h.top());
        }
        return ans;
    }
};

###go

func processQueries(c int, connections [][]int, queries [][]int) (ans []int) {
g := make([][]int, c+1)
for _, e := range connections {
x, y := e[0], e[1]
g[x] = append(g[x], y)
g[y] = append(g[y], x)
}

belong := make([]int, c+1)
for i := range belong {
belong[i] = -1
}
heaps := []hp{}
var h hp

var dfs func(int)
dfs = func(x int) {
belong[x] = len(heaps) // 记录节点 x 在哪个堆
h.IntSlice = append(h.IntSlice, x)
for _, y := range g[x] {
if belong[y] < 0 {
dfs(y)
}
}
}
for i := 1; i <= c; i++ {
if belong[i] >= 0 {
continue
}
h = hp{}
dfs(i)
heap.Init(&h)
heaps = append(heaps, h)
}

offline := make([]bool, c+1)
for _, q := range queries {
x := q[1]
if q[0] == 2 {
offline[x] = true
continue
}
if !offline[x] {
ans = append(ans, x)
continue
}
// 懒删除:取堆顶的时候,如果离线,才删除
h := &heaps[belong[x]]
for h.Len() > 0 && offline[h.IntSlice[0]] {
heap.Pop(h)
}
if h.Len() > 0 {
ans = append(ans, h.IntSlice[0])
} else {
ans = append(ans, -1)
}
}
return
}

type hp struct{ sort.IntSlice }
func (h *hp) Push(v any) { h.IntSlice = append(h.IntSlice, v.(int)) }
func (h *hp) Pop() any   { a := h.IntSlice; v := a[len(a)-1]; h.IntSlice = a[:len(a)-1]; return v }

复杂度分析

  • 时间复杂度:$\mathcal{O}(c\log c+n + q\log c)$ 或者 $\mathcal{O}(c+n + q\log c)$,取决于实现,其中 $n$ 是 $\textit{connections}$ 的长度,$q$ 是 $\textit{queries}$ 的长度。
  • 空间复杂度:$\mathcal{O}(c+n)$。返回值不计入。

方法二:倒序处理 + 维护最小值

倒序处理询问,离线变成在线,删除变成添加,每个连通块只需要一个 $\texttt{int}$ 变量就可以维护最小值。

注意可能存在同一个节点多次离线的情况,我们需要记录节点离线的最早时间(询问的下标)。对于倒序处理来说,离线的最早时间才是真正的在线时间。

###py

class Solution:
    def processQueries(self, c: int, connections: List[List[int]], queries: List[List[int]]) -> List[int]:
        g = [[] for _ in range(c + 1)]
        for x, y in connections:
            g[x].append(y)
            g[y].append(x)

        belong = [-1] * (c + 1)
        cc = 0  # 连通块编号

        def dfs(x: int) -> None:
            belong[x] = cc  # 记录节点 x 在哪个连通块
            for y in g[x]:
                if belong[y] < 0:
                    dfs(y)

        for i in range(1, c + 1):
            if belong[i] < 0:
                dfs(i)
                cc += 1

        # 记录每个节点的离线时间,初始为无穷大(始终在线)
        offline_time = [inf] * (c + 1)
        for i in range(len(queries) - 1, -1, -1):
            t, x = queries[i]
            if t == 2:
                offline_time[x] = i  # 记录离线时间

        # 每个连通块中仍在线的电站的最小编号
        mn = [inf] * cc
        for i in range(1, c + 1):
            if offline_time[i] == inf:  # 最终仍在线
                j = belong[i]
                mn[j] = min(mn[j], i)

        ans = []
        for i in range(len(queries) - 1, -1, -1):
            t, x = queries[i]
            j = belong[x]
            if t == 2:
                if offline_time[x] == i:
                    mn[j] = min(mn[j], x)  # 变回在线
            elif i < offline_time[x]:  # 已经在线(写 < 或者 <= 都可以)
                ans.append(x)
            elif mn[j] != inf:
                ans.append(mn[j])
            else:
                ans.append(-1)
        ans.reverse()
        return ans

###java

class Solution {
    public int[] processQueries(int c, int[][] connections, int[][] queries) {
        List<Integer>[] g = new ArrayList[c + 1];
        Arrays.setAll(g, i -> new ArrayList<>());
        for (int[] e : connections) {
            int x = e[0], y = e[1];
            g[x].add(y);
            g[y].add(x);
        }

        int[] belong = new int[c + 1];
        Arrays.fill(belong, -1);
        int cc = 0; // 连通块编号
        for (int i = 1; i <= c; i++) {
            if (belong[i] < 0) {
                dfs(i, g, belong, cc);
                cc++;
            }
        }

        int[] offlineTime = new int[c + 1];
        Arrays.fill(offlineTime, Integer.MAX_VALUE);
        int q1 = 0;
        for (int i = queries.length - 1; i >= 0; i--) {
            int[] q = queries[i];
            if (q[0] == 2) {
                offlineTime[q[1]] = i; // 记录最早离线时间
            } else {
                q1++;
            }
        }

        // 维护每个连通块的在线电站的最小编号
        int[] mn = new int[cc];
        Arrays.fill(mn, Integer.MAX_VALUE);
        for (int i = 1; i <= c; i++) {
            if (offlineTime[i] == Integer.MAX_VALUE) { // 最终仍然在线
                int j = belong[i];
                mn[j] = Math.min(mn[j], i);
            }
        }

        int[] ans = new int[q1];
        for (int i = queries.length - 1; i >= 0; i--) {
            int[] q = queries[i];
            int x = q[1];
            int j = belong[x];
            if (q[0] == 2) {
                if (offlineTime[x] == i) { // 变回在线
                    mn[j] = Math.min(mn[j], x);
                }
            } else {
                q1--;
                if (i < offlineTime[x]) { // 已经在线(写 < 或者 <= 都可以)
                    ans[q1] = x;
                } else if (mn[j] != Integer.MAX_VALUE) {
                    ans[q1] = mn[j];
                } else {
                    ans[q1] = -1;
                }
            }
        }
        return ans;
    }

    private void dfs(int x, List<Integer>[] g, int[] belong, int compId) {
        belong[x] = compId;
        for (int y : g[x]) {
            if (belong[y] < 0) {
                dfs(y, g, belong, compId);
            }
        }
    }
}

###cpp

class Solution {
public:
    vector<int> processQueries(int c, vector<vector<int>>& connections, vector<vector<int>>& queries) {
        vector<vector<int>> g(c + 1);
        for (auto& e : connections) {
            int x = e[0], y = e[1];
            g[x].push_back(y);
            g[y].push_back(x);
        }

        vector<int> belong(c + 1, -1);
        int cc = 0; // 连通块编号
        auto dfs = [&](this auto&& dfs, int x) -> void {
            belong[x] = cc; // 记录节点 x 在哪个连通块
            for (int y : g[x]) {
                if (belong[y] < 0) {
                    dfs(y);
                }
            }
        };

        for (int i = 1; i <= c; i++) {
            if (belong[i] < 0) {
                dfs(i);
                cc++;
            }
        }

        vector<int> offline_time(c + 1, INT_MAX);
        for (int i = queries.size() - 1; i >= 0; i--) {
            auto& q = queries[i];
            if (q[0] == 2) {
                offline_time[q[1]] = i; // 记录最早离线时间
            }
        }

        // 维护每个连通块的在线电站的最小编号
        vector<int> mn(cc, INT_MAX);
        for (int i = 1; i <= c; i++) {
            if (offline_time[i] == INT_MAX) { // 最终仍然在线
                int j = belong[i];
                mn[j] = min(mn[j], i);
            }
        }

        vector<int> ans;
        for (int i = queries.size() - 1; i >= 0; i--) {
            auto& q = queries[i];
            int x = q[1];
            int j = belong[x];
            if (q[0] == 2) {
                if (offline_time[x] == i) { // 变回在线
                    mn[j] = min(mn[j], x);
                }
            } else if (i < offline_time[x]) { // 已经在线(写 < 或者 <= 都可以)
                ans.push_back(x);
            } else if (mn[j] != INT_MAX) {
                ans.push_back(mn[j]);
            } else {
                ans.push_back(-1);
            }
        }
        ranges::reverse(ans);
        return ans;
    }
};

###go

func processQueries(c int, connections [][]int, queries [][]int) []int {
g := make([][]int, c+1)
for _, e := range connections {
x, y := e[0], e[1]
g[x] = append(g[x], y)
g[y] = append(g[y], x)
}

belong := make([]int, c+1)
for i := range belong {
belong[i] = -1
}
cc := 0 // 连通块编号

var dfs func(int)
dfs = func(x int) {
belong[x] = cc // 记录节点 x 在哪个连通块
for _, y := range g[x] {
if belong[y] < 0 {
dfs(y)
}
}
}
for i := 1; i <= c; i++ {
if belong[i] < 0 {
dfs(i)
cc++
}
}

offlineTime := make([]int, c+1)
for i := range offlineTime {
offlineTime[i] = math.MaxInt
}
q1 := 0
for i, q := range slices.Backward(queries) {
if q[0] == 2 {
offlineTime[q[1]] = i // 记录最早离线时间
} else {
q1++
}
}

// 维护每个连通块的在线电站的最小编号
mn := make([]int, cc)
for i := range mn {
mn[i] = math.MaxInt
}
for i := 1; i <= c; i++ {
if offlineTime[i] == math.MaxInt { // 最终仍然在线
j := belong[i]
mn[j] = min(mn[j], i)
}
}

ans := make([]int, q1)
for i, q := range slices.Backward(queries) {
x := q[1]
j := belong[x]
if q[0] == 2 {
if offlineTime[x] == i { // 变回在线
mn[j] = min(mn[j], x)
}
} else {
q1--
if i < offlineTime[x] { // 已经在线(写 < 或者 <= 都可以)
ans[q1] = x
} else if mn[j] != math.MaxInt {
ans[q1] = mn[j]
} else {
ans[q1] = -1
}
}
}
return ans
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(c+n + q)$,其中 $n$ 是 $\textit{connections}$ 的长度,$q$ 是 $\textit{queries}$ 的长度。
  • 空间复杂度:$\mathcal{O}(c+n)$。返回值不计入。

相似题目

3108. 带权图里旅途的最小代价

专题训练

  1. 图论题单的「§1.1 DFS 基础」。
  2. 数据结构题单的「§5.6 懒删除堆」。
  3. 数据结构题单的「专题:离线算法」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

两个有序集合维护前 x 大二元组(Python/Java/C++/Go)

作者 endlesscheng
2024年10月13日 12:36

前置题目

  1. 295. 数据流的中位数我的题解
  2. 480. 滑动窗口中位数我的题解
  3. 3013. 将数组分成最小总代价的子数组 II我的题解

在 3013 题中,我们用两个有序集合维护前 $k-1$ 小元素及其总和。

本题要维护前 $x$ 大的二元组 $(\textit{cnt}[x], x)$,以及 $\textit{cnt}[x]\cdot x$ 的总和。其中 $\textit{cnt}[x]$ 表示 $x$ 在子数组(滑动窗口)中的出现次数。

当元素进入窗口时:

  1. 把 $(\textit{cnt}[x], x)$ 从有序集合中移除。
  2. 把 $\textit{cnt}[x]$ 加一。
  3. 把 $(\textit{cnt}[x], x)$ 加入有序集合。

当元素离开窗口时:

  1. 把 $(\textit{cnt}[x], x)$ 从有序集合中移除。
  2. 把 $\textit{cnt}[x]$ 减一。
  3. 把 $(\textit{cnt}[x], x)$ 加入有序集合。

添加删除的同时维护 $\textit{cnt}[x]\cdot x$ 的总和。

其余逻辑同 3013 题

本题视频讲解,欢迎点赞关注~

###py

from sortedcontainers import SortedList

class Solution:
    def findXSum(self, nums: List[int], k: int, x: int) -> List[int]:
        cnt = defaultdict(int)
        L = SortedList()  # 保存 tuple (出现次数,元素值)
        R = SortedList()
        sum_l = 0  # L 的元素和

        def add(val: int) -> None:
            if cnt[val] == 0:
                return
            p = (cnt[val], val)
            if L and p > L[0]:  # p 比 L 中最小的还大
                nonlocal sum_l
                sum_l += p[0] * p[1]
                L.add(p)
            else:
                R.add(p)

        def remove(val: int) -> None:
            if cnt[val] == 0:
                return
            p = (cnt[val], val)
            if p in L:
                nonlocal sum_l
                sum_l -= p[0] * p[1]
                L.remove(p)
            else:
                R.remove(p)

        def l2r() -> None:
            nonlocal sum_l
            p = L[0]
            sum_l -= p[0] * p[1]
            L.remove(p)
            R.add(p)

        def r2l() -> None:
            nonlocal sum_l
            p = R[-1]
            sum_l += p[0] * p[1]
            R.remove(p)
            L.add(p)

        ans = [0] * (len(nums) - k + 1)
        for r, in_ in enumerate(nums):
            # 添加 in_
            remove(in_)
            cnt[in_] += 1
            add(in_)

            l = r + 1 - k
            if l < 0:
                continue

            # 维护大小
            while R and len(L) < x:
                r2l()
            while len(L) > x:
                l2r()
            ans[l] = sum_l

            # 移除 out
            out = nums[l]
            remove(out)
            cnt[out] -= 1
            add(out)
        return ans

###java

class Solution {
    private final TreeSet<int[]> L = new TreeSet<>((a, b) -> a[0] != b[0] ? a[0] - b[0] : a[1] - b[1]);
    private final TreeSet<int[]> R = new TreeSet<>(L.comparator());
    private final Map<Integer, Integer> cnt = new HashMap<>();
    private long sumL = 0;

    public long[] findXSum(int[] nums, int k, int x) {
        long[] ans = new long[nums.length - k + 1];
        for (int r = 0; r < nums.length; r++) {
            // 添加 in
            int in = nums[r];
            del(in);
            cnt.merge(in, 1, Integer::sum); // cnt[in]++
            add(in);

            int l = r + 1 - k;
            if (l < 0) {
                continue;
            }

            // 维护大小
            while (!R.isEmpty() && L.size() < x) {
                r2l();
            }
            while (L.size() > x) {
                l2r();
            }
            ans[l] = sumL;

            // 移除 out
            int out = nums[l];
            del(out);
            cnt.merge(out, -1, Integer::sum); // cnt[out]--
            add(out);
        }
        return ans;
    }

    // 添加元素
    private void add(int val) {
        int c = cnt.get(val);
        if (c == 0) {
            return;
        }
        int[] p = new int[]{c, val};
        if (!L.isEmpty() && L.comparator().compare(p, L.first()) > 0) { // p 比 L 中最小的还大
            sumL += (long) p[0] * p[1];
            L.add(p);
        } else {
            R.add(p);
        }
    }

    // 删除元素
    private void del(int val) {
        int c = cnt.getOrDefault(val, 0);
        if (c == 0) {
            return;
        }
        int[] p = new int[]{c, val};
        if (L.contains(p)) {
            sumL -= (long) p[0] * p[1];
            L.remove(p);
        } else {
            R.remove(p);
        }
    }

    // 从 L 移动一个元素到 R
    private void l2r() {
        int[] p = L.pollFirst();
        sumL -= (long) p[0] * p[1];
        R.add(p);
    }

    // 从 R 移动一个元素到 L
    private void r2l() {
        int[] p = R.pollLast();
        sumL += (long) p[0] * p[1];
        L.add(p);
    }
}

###cpp

class Solution {
public:
    vector<long long> findXSum(vector<int>& nums, int k, int x) {
        using pii = pair<int, int>; // 出现次数,元素值
        set<pii> L, R;
        long long sum_l = 0; // L 的元素和
        unordered_map<int, int> cnt;
        auto add = [&](int x) {
            pii p = {cnt[x], x};
            if (p.first == 0) {
                return;
            }
            if (!L.empty() && p > *L.begin()) { // p 比 L 中最小的还大
                sum_l += (long long) p.first * p.second;
                L.insert(p);
            } else {
                R.insert(p);
            }
        };
        auto del = [&](int x) {
            pii p = {cnt[x], x};
            if (p.first == 0) {
                return;
            }
            auto it = L.find(p);
            if (it != L.end()) {
                sum_l -= (long long) p.first * p.second;
                L.erase(it);
            } else {
                R.erase(p);
            }
        };
        auto l2r = [&]() {
            pii p = *L.begin();
            sum_l -= (long long) p.first * p.second;
            L.erase(p);
            R.insert(p);
        };
        auto r2l = [&]() {
            pii p = *R.rbegin();
            sum_l += (long long) p.first * p.second;
            R.erase(p);
            L.insert(p);
        };

        vector<long long> ans(nums.size() - k + 1);
        for (int r = 0; r < nums.size(); r++) {
            // 添加 in
            int in = nums[r];
            del(in);
            cnt[in]++;
            add(in);

            int l = r + 1 - k;
            if (l < 0) {
                continue;
            }

            // 维护大小
            while (!R.empty() && L.size() < x) {
                r2l();
            }
            while (L.size() > x) {
                l2r();
            }
            ans[l] = sum_l;

            // 移除 out
            int out = nums[l];
            del(out);
            cnt[out]--;
            add(out);
        }
        return ans;
    }
};

###go

import "github.com/emirpasic/gods/v2/trees/redblacktree"

type pair struct{ c, x int } // 出现次数,元素值

func less(p, q pair) int {
return cmp.Or(p.c-q.c, p.x-q.x)
}

func findXSum(nums []int, k, x int) []int64 {
L := redblacktree.NewWith[pair, struct{}](less)
R := redblacktree.NewWith[pair, struct{}](less)

sumL := 0 // L 的元素和
cnt := map[int]int{}
add := func(x int) {
p := pair{cnt[x], x}
if p.c == 0 {
return
}
if !L.Empty() && less(p, L.Left().Key) > 0 { // p 比 L 中最小的还大
sumL += p.c * p.x
L.Put(p, struct{}{})
} else {
R.Put(p, struct{}{})
}
}
del := func(x int) {
p := pair{cnt[x], x}
if p.c == 0 {
return
}
if _, ok := L.Get(p); ok {
sumL -= p.c * p.x
L.Remove(p)
} else {
R.Remove(p)
}
}
l2r := func() {
p := L.Left().Key
sumL -= p.c * p.x
L.Remove(p)
R.Put(p, struct{}{})
}
r2l := func() {
p := R.Right().Key
sumL += p.c * p.x
R.Remove(p)
L.Put(p, struct{}{})
}

ans := make([]int64, len(nums)-k+1)
for r, in := range nums {
// 添加 in
del(in)
cnt[in]++
add(in)

l := r + 1 - k
if l < 0 {
continue
}

// 维护大小
for !R.Empty() && L.Size() < x {
r2l()
}
for L.Size() > x {
l2r()
}
ans[l] = int64(sumL)

// 移除 out
out := nums[l]
del(out)
cnt[out]--
add(out)
}
return ans
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log k)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

专题训练

见下面数据结构题单的「§5.7 对顶堆」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

贪心,每段保留最大的(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2025年10月17日 17:06

为了不让相邻气球颜色相同,对于 $\textit{colors}$ 的每个连续同色段,只能保留一个气球。

贪心地,保留其中耗时最大的气球。

答案为 $\textit{neededTime}$ 的总和,减去每段的最大耗时。

###py

class Solution:
    def minCost(self, colors: str, neededTime: List[int]) -> int:
        ans = max_t = 0
        for i, t in enumerate(neededTime):
            ans += t
            if t > max_t:  # 手写 if 比调用 max 快
                max_t = t
            if i == len(colors) - 1 or colors[i] != colors[i + 1]:
                # 遍历到了连续同色段的末尾
                ans -= max_t  # 保留耗时最大的气球
                max_t = 0  # 准备计算下一段的最大耗时
        return ans

###java

class Solution {
    public int minCost(String colors, int[] neededTime) {
        int n = neededTime.length;
        int ans = 0;
        int maxT = 0;
        for (int i = 0; i < n; i++) {
            int t = neededTime[i];
            ans += t;
            maxT = Math.max(maxT, t);
            if (i == n - 1 || colors.charAt(i) != colors.charAt(i + 1)) {
                // 遍历到了连续同色段的末尾
                ans -= maxT; // 保留耗时最大的气球
                maxT = 0; // 准备计算下一段的最大耗时
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int minCost(string colors, vector<int>& neededTime) {
        int n = colors.size();
        int ans = 0, max_t = 0;
        for (int i = 0; i < n; i++) {
            int t = neededTime[i];
            ans += t;
            max_t = max(max_t, t);
            if (i == n - 1 || colors[i] != colors[i + 1]) {
                // 遍历到了连续同色段的末尾
                ans -= max_t; // 保留耗时最大的气球
                max_t = 0; // 准备计算下一段的最大耗时
            }
        }
        return ans;
    }
};

###c

#define MAX(a, b) ((b) > (a) ? (b) : (a))

int minCost(char* colors, int* neededTime, int neededTimeSize) {
    int ans = 0, max_t = 0;
    for (int i = 0; i < neededTimeSize; i++) {
        int t = neededTime[i];
        ans += t;
        max_t = MAX(max_t, t);
        if (i == neededTimeSize - 1 || colors[i] != colors[i + 1]) {
            // 遍历到了连续同色段的末尾
            ans -= max_t; // 保留耗时最大的气球
            max_t = 0; // 准备计算下一段的最大耗时
        }
    }
    return ans;
}

###go

func minCost(colors string, neededTime []int) (ans int) {
maxT := 0
for i, t := range neededTime {
ans += t
maxT = max(maxT, t)
if i == len(colors)-1 || colors[i] != colors[i+1] {
// 遍历到了连续同色段的末尾
ans -= maxT // 保留耗时最大的气球
maxT = 0    // 准备计算下一段的最大耗时
}
}
return
}

###js

var minCost = function(colors, neededTime) {
    const n = colors.length;
    let ans = 0, maxT = 0;
    for (let i = 0; i < n; i++) {
        const t = neededTime[i];
        ans += t;
        maxT = Math.max(maxT, t);
        if (i === n - 1 || colors[i] !== colors[i + 1]) {
            // 遍历到了连续同色段的末尾
            ans -= maxT; // 保留耗时最大的气球
            maxT = 0; // 准备计算下一段的最大耗时
        }
    }
    return ans;
};

###rust

impl Solution {
    pub fn min_cost(colors: String, needed_time: Vec<i32>) -> i32 {
        let s = colors.as_bytes();
        let mut ans = 0;
        let mut max_t = 0;
        for (i, t) in needed_time.into_iter().enumerate() {
            ans += t;
            max_t = max_t.max(t);
            if i + 1 == s.len() || s[i] != s[i + 1] {
                // 遍历到了连续同色段的末尾
                ans -= max_t; // 保留耗时最大的气球
                max_t = 0; // 准备计算下一段的最大耗时
            }
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{colors}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

专题训练

见下面双指针题单的「六、分组循环」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

哨兵节点+一次遍历(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2024年7月14日 12:10

如何在遍历链表的同时,删除链表节点?请看【基础算法精讲 08】

对于本题,由于直接判断节点值是否在 $\textit{nums}$ 中,需要遍历 $\textit{nums}$,时间复杂度为 $\mathcal{O}(n)$。把 $\textit{nums}$ 中的元素保存一个哈希集合中,然后判断节点值是否在哈希集合中,这样可以做到 $\mathcal{O}(1)$。

具体做法:

  1. 把 $\textit{nums}$ 中的元素保存到一个哈希集合中。
  2. 由于头节点可能会被删除,在头节点前面插入一个哨兵节点 $\textit{dummy}$,以简化代码逻辑。
  3. 初始化 $\textit{cur} = \textit{dummy}$。
  4. 遍历链表,如果 $\textit{cur}$ 的下一个节点的值在哈希集合中,则需要删除,更新 $\textit{cur}.\textit{next}$ 为 $\textit{cur}.\textit{next}.\textit{next}$;否则不删除,更新 $\textit{cur}$ 为 $\textit{cur}.\textit{next}$。
  5. 循环结束后,返回 $\textit{dummy}.\textit{next}$。

注:$\textit{dummy}$ 和 $\textit{cur}$ 是同一个节点的引用,修改 $\textit{cur}.\textit{next}$ 也会修改 $\textit{dummy}.\textit{next}$。

本题视频讲解,欢迎点赞关注~

###py

class Solution:
    def modifiedList(self, nums: List[int], head: Optional[ListNode]) -> Optional[ListNode]:
        st = set(nums)
        cur = dummy = ListNode(next=head)
        while cur.next:
            nxt = cur.next
            if nxt.val in st:
                cur.next = nxt.next  # 从链表中删除 nxt 节点
            else:
                cur = nxt  # 不删除 nxt,继续向后遍历链表
        return dummy.next

###java

class Solution {
    public ListNode modifiedList(int[] nums, ListNode head) {
        Set<Integer> set = new HashSet<>(nums.length, 1); // 预分配空间
        for (int x : nums) {
            set.add(x);
        }

        ListNode dummy = new ListNode(0, head);
        ListNode cur = dummy;
        while (cur.next != null) {
            ListNode nxt = cur.next;
            if (set.contains(nxt.val)) {
                cur.next = nxt.next; // 从链表中删除 nxt 节点
            } else {
                cur = nxt; // 不删除 nxt,继续向后遍历链表
            }
        }
        return dummy.next;
    }
}

###cpp

class Solution {
public:
    ListNode* modifiedList(vector<int>& nums, ListNode* head) {
        unordered_set<int> st(nums.begin(), nums.end());
        ListNode dummy(0, head);
        ListNode* cur = &dummy;
        while (cur->next) {
            ListNode* nxt = cur->next;
            if (st.contains(nxt->val)) {
                cur->next = nxt->next; // 从链表中删除 nxt 节点
            } else {
                cur = nxt; // 不删除 nxt,继续向后遍历链表
            }
        }
        return dummy.next;
    }
};

###c

#define MAX(a, b) ((b) > (a) ? (b) : (a))

struct ListNode* modifiedList(int* nums, int numsSize, struct ListNode* head) {
    int mx = 0;
    for (int i = 0; i < numsSize; i++) {
        mx = MAX(mx, nums[i]);
    }

    bool* has = calloc(mx + 1, sizeof(bool));
    for (int i = 0; i < numsSize; i++) {
        has[nums[i]] = true;
    }

    struct ListNode dummy = {0, head};
    struct ListNode* cur = &dummy;
    while (cur->next) {
        struct ListNode* nxt = cur->next;
        if (nxt->val <= mx && has[nxt->val]) {
            cur->next = nxt->next; // 从链表中删除 nxt 节点
            free(nxt);
        } else {
            cur = nxt; // 不删除 nxt,继续向后遍历链表
        }
    }

    free(has);
    return dummy.next;
}

###go

func modifiedList(nums []int, head *ListNode) *ListNode {
has := make(map[int]bool, len(nums)) // 预分配空间
for _, x := range nums {
has[x] = true
}

dummy := &ListNode{Next: head}
cur := dummy
for cur.Next != nil {
nxt := cur.Next
if has[nxt.Val] {
cur.Next = nxt.Next // 从链表中删除 nxt 节点
} else {
cur = nxt // 不删除 nxt,继续向后遍历链表
}
}
return dummy.Next
}

###js

var modifiedList = function(nums, head) {
    const set = new Set(nums);
    const dummy = new ListNode(0, head);
    let cur = dummy;
    while (cur.next) {
        const nxt = cur.next;
        if (set.has(nxt.val)) {
            cur.next = nxt.next; // 从链表中删除 nxt 节点
        } else {
            cur = nxt; // 不删除 nxt,继续向后遍历链表
        }
    }
    return dummy.next;
};

###rust

use std::collections::HashSet;

impl Solution {
    pub fn modified_list(nums: Vec<i32>, head: Option<Box<ListNode>>) -> Option<Box<ListNode>> {
        let set = nums.into_iter().collect::<HashSet<_>>();
        let mut dummy = Box::new(ListNode { val: 0, next: head });
        let mut cur = &mut dummy;
        while let Some(ref mut nxt) = cur.next {
            if set.contains(&nxt.val) {
                cur.next = nxt.next.take(); // 从链表中删除 nxt 节点
            } else {
                cur = cur.next.as_mut()?; // 不删除 nxt,继续向后遍历链表
            }
        }
        dummy.next
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n + m)$,其中 $n$ 是 $\textit{nums}$ 的长度,$m$ 是链表的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

专题训练

见下面链表题单的「§1.2 删除节点」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

O(1) 空间做法:位运算 / 数学(Python/Java/C++/Go)

作者 endlesscheng
2024年9月15日 12:09

本题和 2965. 找出缺失和重复的数字 本质是一样的,见 我的题解,有位运算和数学两种做法。

位运算

需要两次遍历。一次遍历见下面的数学做法。

###py

class Solution:
    def getSneakyNumbers(self, nums: List[int]) -> List[int]:
        n = len(nums) - 2
        xor_all = n ^ (n + 1)  # n 和 n+1 多异或了
        for i, x in enumerate(nums):
            xor_all ^= i ^ x
        shift = xor_all.bit_length() - 1

        ans = [0, 0]
        for i, x in enumerate(nums):
            if i < n:
                ans[i >> shift & 1] ^= i
            ans[x >> shift & 1] ^= x
        return ans

###java

class Solution {
    public int[] getSneakyNumbers(int[] nums) {
        int n = nums.length - 2;
        int xorAll = n ^ (n + 1); // n 和 n+1 多异或了
        for (int i = 0; i < nums.length; i++) {
            xorAll ^= i ^ nums[i];
        }
        int shift = Integer.numberOfTrailingZeros(xorAll);

        int[] ans = new int[2];
        for (int i = 0; i < nums.length; i++) {
            if (i < n) {
                ans[i >> shift & 1] ^= i;
            }
            ans[nums[i] >> shift & 1] ^= nums[i];
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    vector<int> getSneakyNumbers(vector<int>& nums) {
        int n = nums.size() - 2;
        int xor_all = n ^ (n + 1); // n 和 n+1 多异或了
        for (int i = 0; i < nums.size(); i++) {
            xor_all ^= i ^ nums[i];
        }
        int shift = __builtin_ctz(xor_all);

        vector<int> ans(2);
        for (int i = 0; i < nums.size(); i++) {
            if (i < n) {
                ans[i >> shift & 1] ^= i;
            }
            ans[nums[i] >> shift & 1] ^= nums[i];
        }
        return ans;
    }
};

###go

func getSneakyNumbers(nums []int) []int {
n := len(nums) - 2
xorAll := n ^ (n + 1) // n 和 n+1 多异或了
for i, x := range nums {
xorAll ^= i ^ x
}
shift := bits.TrailingZeros(uint(xorAll))

ans := make([]int, 2)
for i, x := range nums {
if i < n {
ans[i>>shift&1] ^= i
}
ans[x>>shift&1] ^= x
}
return ans
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

数学

设多出的两个数分别为 $x$ 和 $y$。

也就是说,$\textit{nums} = [0,1,2,\cdots,n-1,x,y]$。

设 $\textit{nums}$ 的元素和为 $s$,$\textit{nums}$ 的元素平方之和为 $s_2$,那么有

$$
\begin{aligned}
&x+y = s - (0 + 1 + 2 + \cdots + n-1) = a \
&x^2+y^2 = s_2 - (0^2 + 1^2 + 2^2 + \cdots + (n-1)^2) = b \
\end{aligned}
$$

解得

$$
\begin{cases}
x = \dfrac{a-\sqrt{2b-a^2}}{2} \
y = \dfrac{a+\sqrt{2b-a^2}}{2} \
\end{cases}
$$

也可以先算出 $x$,然后算出 $y=a-x$。

本题视频讲解,欢迎点赞关注~

###py

class Solution:
    def getSneakyNumbers(self, nums: List[int]) -> List[int]:
        n = len(nums) - 2
        a = -n * (n - 1) // 2
        b = -n * (n - 1) * (n * 2 - 1) // 6
        for x in nums:
            a += x
            b += x * x
        x = (a - isqrt(b * 2 - a * a)) // 2
        return [x, a - x]

###java

class Solution {
    public int[] getSneakyNumbers(int[] nums) {
        int n = nums.length - 2;
        int a = -n * (n - 1) / 2;
        int b = -n * (n - 1) * (n * 2 - 1) / 6;
        for (int x : nums) {
            a += x;
            b += x * x;
        }
        int x = (a - (int) Math.sqrt(b * 2 - a * a)) / 2;
        return new int[]{x, a - x};
    }
}

###cpp

class Solution {
public:
    vector<int> getSneakyNumbers(vector<int>& nums) {
        int n = nums.size() - 2;
        int a = -n * (n - 1) / 2;
        int b = -n * (n - 1) * (n * 2 - 1) / 6;
        for (int x : nums) {
            a += x;
            b += x * x;
        }
        int x = (a - (int) sqrt(b * 2 - a * a)) / 2;
        return {x, a - x};
    }
};

###go

func getSneakyNumbers(nums []int) []int {
n := len(nums) - 2
a := -n * (n - 1) / 2
b := -n * (n - 1) * (n*2 - 1) / 6
for _, x := range nums {
a += x
b += x * x
}
x := (a - int(math.Sqrt(float64(b*2-a*a)))) / 2
return []int{x, a - x}
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

阅读理解(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2022年1月2日 12:13

题意:去掉没有安全设备的行,计算相邻行之间的激光束的数量之和。

lc2125.jpg{:width=250px}

示例 1 有 $4$ 行,安全设备的个数分别为 $3,0,2,1$。

去掉没有安全设备的行,剩下 $3,2,1$。

计算相邻行之间的激光束的数量之和,即 $3\times 2+ 2\times 1 = 8$。

class Solution:
    def numberOfBeams(self, bank: List[str]) -> int:
        ans = pre_cnt = 0
        for row in bank:
            cnt = row.count('1')
            if cnt > 0:
                ans += pre_cnt * cnt
                pre_cnt = cnt
        return ans
class Solution {
    public int numberOfBeams(String[] bank) {
        int ans = 0;
        int preCnt = 0;
        for (String row : bank) {
            int cnt = 0;
            for (char ch : row.toCharArray()) {
                cnt += ch - '0';
            }
            if (cnt > 0) {
                ans += preCnt * cnt;
                preCnt = cnt;
            }
        }
        return ans;
    }
}
class Solution {
public:
    int numberOfBeams(vector<string>& bank) {
        int ans = 0, pre_cnt = 0;
        for (auto& row : bank) {
            int cnt = ranges::count(row, '1');
            if (cnt > 0) {
                ans += pre_cnt * cnt;
                pre_cnt = cnt;
            }
        }
        return ans;
    }
};
int numberOfBeams(char **bank, int bankSize) {
    int ans = 0, pre_cnt = 0;
    for (int i = 0; i < bankSize; i++) {
        char* row = bank[i];
        int cnt = 0;
        for (int j = 0; row[j]; j++) {
            cnt += row[j] - '0';
        }
        if (cnt > 0) {
            ans += pre_cnt * cnt;
            pre_cnt = cnt;
        }
    }
    return ans;
}
func numberOfBeams(bank []string) (ans int) {
preCnt := 0
for _, row := range bank {
cnt := strings.Count(row, "1")
if cnt > 0 {
ans += preCnt * cnt
preCnt = cnt
}
}
return
}
var numberOfBeams = function(bank) {
    let ans = 0, preCnt = 0;
    for (const row of bank) {
        const cnt = row.split('1').length - 1;
        if (cnt > 0) {
            ans += preCnt * cnt;
            preCnt = cnt;
        }
    }
    return ans;
};
impl Solution {
    pub fn number_of_beams(bank: Vec<String>) -> i32 {
        let mut ans = 0;
        let mut pre_cnt = 0;
        for row in bank {
            let cnt = row.bytes().filter(|&c| c == b'1').count() as i32;
            if cnt > 0 {
                ans += pre_cnt * cnt;
                pre_cnt = cnt;
            }
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn)$,其中 $m$ 和 $n$ 分别是 $\textit{bank}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(1)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

O(1) 等差数列求和(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2025年10月25日 08:09

设一周的天数为 $D=7$。

根据题意,第一周存的钱数为

$$
\textit{base} = 1+2+\cdots+D = \dfrac{D(D+1)}{2}
$$

第二周每天存的钱都比 $D$ 天前多 $1$,所以第二周存的钱数为

$$
\textit{base} + D
$$

第三周每天存的钱都比 $D$ 天前多 $1$,所以第三周存的钱数为

$$
\textit{base} + 2D
$$

依此类推。我们存了完整的 $w=\left\lfloor\dfrac{n}{D}\right\rfloor$ 周,在第 $w$ 周存的钱数为

$$
\textit{base} + (w-1)\cdot D
$$

$w$ 周一共存的钱数为

$$
\begin{aligned}
\sum_{i=0}^{w-1} \textit{base} + i\cdot D &= w\cdot \textit{base} + \dfrac{w(w-1)}{2}\cdot D \
&= w\cdot \dfrac{D(D+1)}{2} + \dfrac{w(w-1)}{2}\cdot D \
&= \dfrac{wD(w+D)}{2} \
\end{aligned}
$$

如果 $n$ 不是 $D$ 的倍数,还有 $r = n\bmod D$ 天,存的钱数为

$$
(w+1) + (w+2) + \cdots + (w+r) = rw + \dfrac{r(r+1)}{2} = \dfrac{r(2w+r+1)}{2}
$$

最终答案为

$$
\dfrac{wD(w+D)}{2} + \dfrac{r(2w+r+1)}{2} = \dfrac{wD(w+D) + r(2w+r+1)}{2}
$$

其中 $D=7$,$w=\left\lfloor\dfrac{n}{D}\right\rfloor$,$r = n\bmod D$。

class Solution:
    def totalMoney(self, n: int) -> int:
        D = 7
        w, r = divmod(n, D)
        return (w * D * (w + D) + r * (w * 2 + r + 1)) // 2
class Solution {
    public int totalMoney(int n) {
        final int D = 7;
        int w = n / D;
        int r = n % D;
        return (w * D * (w + D) + r * (w * 2 + r + 1)) / 2;
    }
}
class Solution {
public:
    int totalMoney(int n) {
        constexpr int D = 7;
        int w = n / D, r = n % D;
        return (w * D * (w + D) + r * (w * 2 + r + 1)) / 2;
    }
};
int totalMoney(int n) {
    const int D = 7;
    int w = n / D, r = n % D;
    return (w * D * (w + D) + r * (w * 2 + r + 1)) / 2;
}
func totalMoney(n int) int {
const d = 7
w, r := n/d, n%d
return (w*d*(w+d) + r*(w*2+r+1)) / 2
}
var totalMoney = function(n) {
    const D = 7;
    const w = Math.floor(n / D), r = n % D;
    return (w * D * (w + D) + r * (w * 2 + r + 1)) / 2;
};
impl Solution {
    pub fn total_money(n: i32) -> i32 {
        const D: i32 = 7;
        let w = n / D;
        let r = n % D;
        (w * D * (w + D) + r * (w * 2 + r + 1)) / 2
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(1)$。
  • 空间复杂度:$\mathcal{O}(1)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

两种方法:枚举 / 倒序贪心 + 0-1 背包(Python/Java/C++/Go)

作者 endlesscheng
2025年10月24日 10:36

方法一:枚举

设 $n$ 的十进制长度为 $m$。

对于本题的数据范围,一定存在十进制长为 $m+1$ 的数值平衡数。

例如 $n=999$,答案为 $1333$。

例如 $n=999999$,答案为 $1224444$。

对于本题的数据范围,$m+1$ 一定可以分解为 $[1,9]$ 中的不同元素之和。

所以枚举 $\mathcal{O}(n)$ 次就能找到答案。

###py

class Solution:
    def nextBeautifulNumber(self, n: int) -> int:
        while True:
            n += 1
            cnt = Counter(str(n))
            if all(int(d) == c for d, c in cnt.items()):
                return n

###java

class Solution {
    public int nextBeautifulNumber(int n) {
        next:
        while (true) {
            n++;

            int[] cnt = new int[10];
            for (int x = n; x > 0; x /= 10) {
                cnt[x % 10]++;
            }

            for (int x = n; x > 0; x /= 10) {
                if (cnt[x % 10] != x % 10) {
                    continue next;
                }
            }

            return n;
        }
    }
}

###cpp

class Solution {
public:
    int nextBeautifulNumber(int n) {
        while (true) {
            n++;

            int cnt[10]{};
            for (int x = n; x > 0; x /= 10) {
                cnt[x % 10]++;
            }

            bool ok = true;
            for (int x = n; x > 0; x /= 10) {
                if (cnt[x % 10] != x % 10) {
                    ok = false;
                    break;
                }
            }
            if (ok) {
                return n;
            }
        }
    }
};

###go

func nextBeautifulNumber(n int) int {
next:
for {
n++
cnt := [10]int{}
for x := n; x > 0; x /= 10 {
cnt[x%10]++
}
for x := n; x > 0; x /= 10 {
if cnt[x%10] != x%10 {
continue next
}
}
return n
}
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log n)$ 或 $\mathcal{O}(n(D + \log n))$,其中 $D=10$。枚举 $\mathcal{O}(n)$ 个数,每个数需要 $\mathcal{O}(\log n)$ 的时间判断是否为数值平衡数。
  • 空间复杂度:$\mathcal{O}(\log n)$ 或 $\mathcal{O}(D)$。

方法二:倒序贪心

如果数据范围扩大到 $n\le 10^{18}$,上面的做法就超时了。

下面介绍一个更快的做法,对于更大的数据范围也能瞬间得出结果,且可以扩展到其他情况,例如值域改成小写字母 $\texttt{a}$ 到 $\texttt{z}$。

请先完成更简单的 3720. 大于目标字符串的最小字典序排列我的题解

本题用同样的方法解决。

先把 $n$ 转成十进制字符串 $s$。为简化代码,在 $s$ 前面加一个前导零。设 $s$(补前导零后)的长度为 $m$。

倒着遍历 $s$,设 $d = s[i]$,尝试把 $s[i]$ 增大为 $d+1,d+2,\ldots,9$。那么:

  1. 对于下标在 $[1,i]$ 中的数字,不能存在数字 $x$ 的出现次数超过 $x$ 的情况。换句话说,设 $\textit{cnt}[x]$ 是数字 $x$ 在 $s$ 的 $[1,i]$ 中的出现次数,不能出现 $\textit{cnt}[x] > x$ 的情况。
  2. 剩余位置 $[i+1,m-1]$ 可以随便填,我们需要补满 $0 < \textit{cnt}[x] < x$ 的数字 $x$,把 $x$ 的出现次数变成 $x$。
  3. 补满数字后,如果还有剩余位置没有填数字,那么从剩余的满足 $\textit{cnt}[x] = 0$ 的非零数字中,选择一个字典序最小的序列,使得序列的和恰好等于剩余位置的个数。这是恰好装满型 0-1 背包。算完 DP 后,如何得到具体方案?类似 1449. 数位成本和为目标值的最大数字,见 我的题解
  4. 把第二步和第三步的数字从小到大排序,填在 $[i+1,m-1]$ 中。

###py

class Solution:
    # 从 a 中选一个字典序最小的、元素和等于 target 的子序列
    # a 已经从小到大排序
    # 无解返回 None
    def zeroOneKnapsack(self, a: List[int], target: int) -> Optional[List[int]]:
        n = len(a)
        f = [[False] * (target + 1) for _ in range(n + 1)]
        f[n][0] = True

        # 倒着 DP,这样后面可以正着(从小到大)选
        for i in range(n - 1, -1, -1):
            v = a[i]
            for j in range(target + 1):
                if j < v:
                    f[i][j] = f[i + 1][j]
                else:
                    f[i][j] = f[i + 1][j] or f[i + 1][j - v]

        if not f[0][target]:
            return None

        ans = []
        j = target
        for i, v in enumerate(a):
            if j >= v and f[i + 1][j - v]:
                ans.append(v)
                j -= v
        return ans

    def nextBeautifulNumber(self, n: int) -> int:
        s = "0" + str(n)  # 补一个前导零,方便处理答案十进制串比 n 的十进制串长的情况
        s = list(map(int, s))  # 避免在后续循环中反复调用 int
        m = len(s)

        MX = 10
        cnt = [0] * MX
        for i in range(1, m):
            cnt[s[i]] += 1

        # 从右往左尝试
        for i in range(m - 1, -1, -1):
            if i > 0:
                cnt[s[i]] -= 1  # 撤销

            # 增大 s[i] 为 j
            for j in range(s[i] + 1, MX):
                cnt[j] += 1

                # 后面 [i+1, m-1] 需要补满 0 < cnt[k] < k 的数字 k,剩余数位可以随便填
                free = m - 1 - i  # 统计可以随便填的数位个数
                for k, c in enumerate(cnt):
                    if k < c:  # 不合法
                        free = -1
                        break
                    if c > 0:
                        free -= k - c
                if free < 0:  # 不合法,继续枚举
                    cnt[j] -= 1
                    continue

                # 对于可以随便填的数位,计算字典序最小的填法
                a = [k for k in range(1, MX) if cnt[k] == 0]
                missing = self.zeroOneKnapsack(a, free)
                if missing is None:  # 无解,继续枚举
                    cnt[j] -= 1
                    continue

                for v in missing:
                    cnt[v] = -v  # 用负数表示可以随便填的数

                s[i] = j
                del s[i + 1:]
                for k, c in enumerate(cnt):
                    s += [k] * (k - c if c > 0 else -c)
                return int(''.join(map(str, s)))

        return -1  # 无解(本题不会发生,但为了可扩展性保留)

###java

class Solution {
    public int nextBeautifulNumber(int n) {
        // 补一个前导零,方便处理答案十进制串比 n 的十进制串长的情况
        char[] s = ("0" + n).toCharArray();
        int m = s.length;

        final int MX = 10;
        int[] cnt = new int[MX];
        for (int i = 1; i < m; i++) {
            cnt[s[i] - '0']++;
        }

        // 从右往左尝试
        for (int i = m - 1; i >= 0; i--) {
            if (i > 0) {
                cnt[s[i] - '0']--; // 撤销
            }

            // 增大 s[i] 为 j
            for (int j = s[i] - '0' + 1; j < MX; j++) {
                cnt[j]++;

                // 后面 [i+1, m-1] 需要补满 0 < cnt[k] < k 的数字 k,剩余数位可以随便填
                int free = m - 1 - i; // 统计可以随便填的数位个数
                for (int k = 0; k < MX; k++) {
                    int c = cnt[k];
                    if (k < c) { // 不合法
                        free = -1;
                        break;
                    }
                    if (c > 0) {
                        free -= k - c;
                    }
                }
                if (free < 0) { // 不合法,继续枚举
                    cnt[j]--;
                    continue;
                }

                // 对于可以随便填的数位,计算字典序最小的填法
                List<Integer> a = new ArrayList<>();
                for (int k = 1; k < MX; k++) {
                    if (cnt[k] == 0) {
                        a.add(k);
                    }
                }

                List<Integer> missing = zeroOneKnapsack(a, free);
                if (missing == null) { // 无解,继续枚举
                    cnt[j]--;
                    continue;
                }

                for (int v : missing) {
                    cnt[v] = -v; // 用负数表示可以随便填的数
                }

                StringBuilder ans = new StringBuilder("0" + n);
                ans.setCharAt(i, (char) ('0' + j));
                ans.setLength(i + 1);
                for (int k = 1; k < MX; k++) {
                    int c = cnt[k];
                    ans.repeat('0' + k, c > 0 ? k - c : -c);
                }
                return Integer.parseInt(ans.toString());
            }
        }
        return -1; // 无解(本题不会发生,但为了可扩展性保留)
    }

    // 从 a 中选一个字典序最小的、元素和等于 target 的子序列
    // a 已经从小到大排序
    // 无解返回 null
    private List<Integer> zeroOneKnapsack(List<Integer> a, int target) {
        int n = a.size();
        boolean[][] f = new boolean[n + 1][target + 1];
        f[n][0] = true;

        // 倒着 DP,这样后面可以正着(从小到大)选
        for (int i = n - 1; i >= 0; i--) {
            int v = a.get(i);
            for (int j = 0; j <= target; j++) {
                if (j < v) {
                    f[i][j] = f[i + 1][j];
                } else {
                    f[i][j] = f[i + 1][j] || f[i + 1][j - v];
                }
            }
        }

        if (!f[0][target]) {
            return null;
        }

        List<Integer> ans = new ArrayList<>();
        int j = target;
        for (int i = 0; i < n; i++) {
            int v = a.get(i);
            if (j >= v && f[i + 1][j - v]) {
                ans.add(v);
                j -= v;
            }
        }
        return ans;
    }
}

###cpp

class Solution {
    // 从 a 中选一个字典序最小的、元素和等于 target 的子序列
    // a 已经从小到大排序
    // 无解返回 {} 和 false
    pair<vector<int>, bool> zeroOneKnapsack(vector<int>& a, int target) {
        int n = a.size();
        vector f(n + 1, vector<int8_t>(target + 1)); 
        f[n][0] = true;

        // 倒着 DP,这样后面可以正着(从小到大)选
        for (int i = n - 1; i >= 0; i--) {
            int v = a[i];
            for (int j = 0; j <= target; j++) {
                if (j < v) {
                    f[i][j] = f[i + 1][j];
                } else {
                    f[i][j] = f[i + 1][j] || f[i + 1][j - v];
                }
            }
        }

        if (!f[0][target]) {
            return {};
        }

        vector<int> ans;
        int j = target;
        for (int i = 0; i < n; i++) {
            int v = a[i];
            if (j >= v && f[i + 1][j - v]) {
                ans.push_back(v);
                j -= v;
            }
        }
        return {ans, true};
    }

public:
    int nextBeautifulNumber(int n) {
        // 补一个前导零,方便处理答案十进制串比 n 的十进制串长的情况
        string s = "0" + to_string(n);
        int m = s.size();

        constexpr int MX = 10;
        int cnt[MX]{};
        for (int i = 1; i < m; i++) {
            cnt[s[i] - '0']++;
        }

        // 从右往左尝试
        for (int i = m - 1; i >= 0; i--) {
            if (i > 0) {
                cnt[s[i] - '0']--; // 撤销
            }

            // 增大 s[i] 为 j
            for (int j = s[i] - '0' + 1; j < MX; j++) {
                cnt[j]++;

                // 后面 [i+1, m-1] 需要补满 0 < cnt[k] < k 的数字 k,剩余数位可以随便填
                int free = m - 1 - i; // 统计可以随便填的数位个数
                for (int k = 0; k < MX; k++) {
                    int c = cnt[k];
                    if (k < c) { // 不合法
                        free = -1;
                        break;
                    }
                    if (c > 0) {
                        free -= k - c;
                    }
                }
                if (free < 0) { // 不合法,继续枚举
                    cnt[j]--;
                    continue;
                }

                // 对于可以随便填的数位,计算字典序最小的填法
                vector<int> a;
                for (int k = 1; k < MX; k++) {
                    if (cnt[k] == 0) {
                        a.push_back(k);
                    }
                }
                auto [missing, ok] = zeroOneKnapsack(a, free);
                if (!ok) { // 无解,继续枚举
                    cnt[j]--;
                    continue;
                }

                for (int v : missing) {
                    cnt[v] = -v; // 用负数表示可以随便填的数
                }

                s[i] = '0' + j;
                s.resize(i + 1);
                for (int k = 1; k < MX; k++) {
                    int c = cnt[k];
                    c = c > 0 ? k - c : -c;
                    s += string(c, '0' + k);
                }
                return stoi(s);
            }
        }
        return -1; // 无解(本题不会发生,但为了可扩展性保留)
    }
};

###go

// 从 a 中选一个字典序最小的、元素和等于 target 的子序列
// a 已经从小到大排序
// 无解返回 nil
func zeroOneKnapsack(a []int, target int) []int {
n := len(a)
f := make([][]bool, n+1)
for i := range f {
f[i] = make([]bool, target+1)
}
f[n][0] = true

// 倒着 DP,这样后面可以正着(从小到大)选
for i := n - 1; i >= 0; i-- {
v := a[i]
for j := range f[i] {
if j < v {
f[i][j] = f[i+1][j]
} else {
f[i][j] = f[i+1][j] || f[i+1][j-v]
}
}
}

if !f[0][target] {
return nil
}

ans := []int{}
j := target
for i, v := range a {
if j >= v && f[i+1][j-v] {
ans = append(ans, v)
j -= v
}
}
return ans
}

func nextBeautifulNumber(n int) int {
// 补一个前导零,方便处理答案十进制比 n 的十进制长的情况
s := "0" + strconv.Itoa(n)
m := len(s)

const mx = 10
cnt := make([]int, mx)
for i := 1; i < m; i++ {
cnt[s[i]-'0']++
}

// 从右往左尝试
for i := m - 1; i >= 0; i-- {
if i > 0 {
cnt[s[i]-'0']-- // 撤销
}

// 增大 s[i] 为 j
for j := s[i] - '0' + 1; j < mx; j++ {
cnt[j]++

// 后面 [i+1, m-1] 需要补满 0 < cnt[k] < k 的数字 k,剩余数位可以随便填
free := m - 1 - i // 统计可以随便填的数位个数
for k, c := range cnt {
if k < c { // 不合法
free = -1
break
}
if c > 0 {
free -= k - c
}
}
if free < 0 { // 不合法,继续枚举
cnt[j]--
continue
}

// 对于可以随便填的数位,计算字典序最小的填法
a := []int{}
for k := 1; k < mx; k++ {
if cnt[k] == 0 {
a = append(a, k)
}
}
missing := zeroOneKnapsack(a, free)
if missing == nil { // 无解,继续枚举
cnt[j]--
continue
}

for _, v := range missing {
cnt[v] = -v // 用负数表示可以随便填的数
}

t := []byte(s[:i+1])
t[i] = '0' + byte(j)
for k, c := range cnt {
if c > 0 {
c = k - c
} else {
c = -c
}
d := []byte{'0' + byte(k)}
t = append(t, bytes.Repeat(d, c)...)
}
ans, _ := strconv.Atoi(string(t))
return ans
}
}
return -1 // 无解(本题不会发生,但为了可扩展性保留)
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(D^2\log^2 n)$,其中 $D=10$,$\log n$ 是 $n$ 的十进制长度。枚举 $\mathcal{O}(D\log n)$ 种把 $s[i]$ 增大的情况,每次需要 $\mathcal{O}(D\log n)$ 的时间计算 0-1 背包。
  • 空间复杂度:$\mathcal{O}(D\log n)$。

相似题目

专题训练

见下面贪心题单的「§3.1 字典序最小/最大」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

两种方法:差分/三指针+双指针(Python/Java/C++/Go)

作者 endlesscheng
2024年11月10日 09:14

方法一:差分

前置知识差分原理讲解

设 $x = \textit{nums}[i]$。根据题意,$x$ 可以变成 $[x-k,x+k]$ 中的整数。

题目让我们计算操作后,最多有多少个数相同。

例如 $\textit{nums}=[2,4]$,$k=1$,$\textit{numOperations}=2$。$2$ 可以变成 $[1,3]$ 中的整数,$4$ 可以变成 $[3,5]$ 中的整数。$2$ 和 $4$ 都可以变成 $3$,所以答案是 $2$。

一般地,$x$ 可以变成 $[x-k,x+k]$ 中的整数,我们可以把 $[x-k,x+k]$ 中的每个整数的出现次数都加一,然后统计出现次数的最大值。这可以用差分实现。

计算差分的前缀和。设有 $\textit{sumD}$ 个数可以变成 $y$。

如果 $y$ 不在 $\textit{nums}$ 中,那么 $y$ 的最大出现次数不能超过 $\textit{numOperations}$,与 $\textit{sumD}$ 取最小值,得 $\min(\textit{sumD}, \textit{numOperations})$。

如果 $y$ 在 $\textit{nums}$ 中,且出现了 $\textit{cnt}$ 次,那么有 $\textit{sumD}-\textit{cnt}$ 个其他元素(不等于 $y$ 的数)可以变成 $y$,但这不能超过 $\textit{numOperations}$,所以有

$$
\min(\textit{sumD}-\textit{cnt}, \textit{numOperations})
$$

个其他元素可以变成 $y$,再加上 $y$ 自身出现的次数 $\textit{cnt}$,得到 $y$ 的最大频率为

$$
\textit{cnt} + \min(\textit{sumD}-\textit{cnt}, \textit{numOperations}) = \min(\textit{sumD}, \textit{cnt}+\textit{numOperations})
$$

注意上式兼容 $y$ 不在 $\textit{nums}$ 中的情况,此时 $\textit{cnt}=0$。

本题视频讲解,欢迎点赞关注~

答疑

:为什么代码只考虑在 $\textit{diff}$ 和 $\textit{nums}$ 中的数字?

:比如 $x$ 在 $\textit{diff}$ 中,$x+1,x+2,\ldots$ 不在 $\textit{diff}$ 中,那么 $x+1,x+2,\ldots$ 的 $\textit{sumD}$ 和 $\textit{x}$ 的是一样的,无需重复计算。此外,要想算出比 $\min(\textit{sumD}, \textit{cnt}+\textit{numOperations})$ 更大的数,要么 $\textit{sumD}$ 变大,要么 $\textit{cnt}$ 变大。「变大」时的 $x$ 必然在 $\textit{diff}$ 或 $\textit{nums}$ 中。

###py

class Solution:
    def maxFrequency(self, nums: List[int], k: int, numOperations: int) -> int:
        cnt = defaultdict(int)
        diff = defaultdict(int)
        for x in nums:
            cnt[x] += 1
            diff[x]  # 把 x 插入 diff,以保证下面能遍历到 x
            diff[x - k] += 1  # 把 [x-k,x+k] 中的每个整数的出现次数都加一
            diff[x + k + 1] -= 1

        ans = sum_d = 0
        for x, d in sorted(diff.items()):
            sum_d += d
            ans = max(ans, min(sum_d, cnt[x] + numOperations))
        return ans

###java

class Solution {
    int maxFrequency(int[] nums, int k, int numOperations) {
        Map<Integer, Integer> cnt = new HashMap<>();
        Map<Integer, Integer> diff = new TreeMap<>();
        for (int x : nums) {
            cnt.merge(x, 1, Integer::sum); // cnt[x]++
            diff.putIfAbsent(x, 0); // 把 x 插入 diff,以保证下面能遍历到 x
            // 把 [x-k, x+k] 中的每个整数的出现次数都加一
            diff.merge(x - k, 1, Integer::sum); // diff[x-k]++
            diff.merge(x + k + 1, -1, Integer::sum); // diff[x+k+1]--
        }

        int ans = 0;
        int sumD = 0;
        for (Map.Entry<Integer, Integer> e : diff.entrySet()) {
            sumD += e.getValue();
            ans = Math.max(ans, Math.min(sumD, cnt.getOrDefault(e.getKey(), 0) + numOperations));
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int maxFrequency(vector<int>& nums, int k, int numOperations) {
        unordered_map<int, int> cnt;
        map<int, int> diff;
        for (int x : nums) {
            cnt[x]++;
            diff[x]; // 把 x 插入 diff,以保证下面能遍历到 x
            diff[x - k]++; // 把 [x-k, x+k] 中的每个整数的出现次数都加一
            diff[x + k + 1]--;
        }

        int ans = 0, sum_d = 0;
        for (auto& [x, d] : diff) {
            sum_d += d;
            ans = max(ans, min(sum_d, cnt[x] + numOperations));
        }
        return ans;
    }
};

###go

func maxFrequency(nums []int, k, numOperations int) (ans int) {
cnt := map[int]int{}
diff := map[int]int{}
for _, x := range nums {
cnt[x]++
diff[x] += 0 // 把 x 插入 diff,以保证下面能遍历到 x
diff[x-k]++  // 把 [x-k, x+k] 中的每个整数的出现次数都加一
diff[x+k+1]--
}

sumD := 0
for _, x := range slices.Sorted(maps.Keys(diff)) {
sumD += diff[x]
ans = max(ans, min(sumD, cnt[x]+numOperations))
}
return
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log n)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

方法二:同向三指针 + 同向双指针

核心思路

  1. 计算有多少个数能变成 $x$,其中 $x = \textit{nums}[i]$。用同向三指针实现。
  2. 计算有多少个数能变成 $x$,其中 $x$ 不一定在 $\textit{nums}$ 中。用同向双指针实现。

同向三指针

把 $\textit{nums}$ 从小到大排序。

遍历 $\textit{nums}$。设 $x=\textit{nums}[i]$,计算元素值在 $[x-k,x+k]$ 中的元素个数,这些元素都可以变成 $x$。

遍历的同时,维护左指针 $\textit{left}$,它是最小的满足

$$
\textit{nums}[\textit{left}] \ge x - k
$$

的下标。

遍历的同时,维护右指针 $\textit{right}$,它是最小的满足

$$
\textit{nums}[\textit{right}] > x + k
$$

的下标。如果不存在,则 $\textit{right}=n$。

下标在左闭右开区间 $[\textit{left},\textit{right})$ 中的元素,都可以变成 $x$。这有

$$
\textit{sumD} = \textit{right} - \textit{left}
$$

个。

遍历的同时,求出 $x$ 有 $\textit{cnt}$ 个。然后用方法一的公式,更新答案的最大值。

同向双指针

同向三指针没有考虑「变成不在 $\textit{nums}$ 中的数」这种情况。

然而,不在 $\textit{nums}$ 中的数太多了!怎么减少计算量?

假设都变成 $y$,那么只有 $[y-k,y+k]$ 中的数能变成 $y$。

把 $\textit{nums}[i]$ 视作一维坐标轴上的点,想象一个窗口 $[y-k,y+k]$ 在不断地向右滑动,什么时候窗口内的元素个数才会变大?

  • 如果我们从 $[y-k-1,y+k-1]$ 移动到 $[y-k,y+k]$,且 $y+k$ 不在 $\textit{nums}$ 中,此时窗口内的元素个数不会变大,甚至因为左端点收缩了,元素个数可能会变小。
  • 所以,只有当 $y+k$ 恰好在 $\textit{nums}$ 中时,窗口内的元素个数才可能会变大。
  • 结论:我们只需考虑 $y+k$ 在 $\textit{nums}$ 中时的 $y$!

于是,枚举 $x=\textit{nums}[\textit{right}]$,计算元素值在 $[x-2k,x]$ 中的元素个数,这些元素都可以变成同一个数 $y=x-k$。

左指针 $\textit{left}$ 是最小的满足

$$
\textit{nums}[\textit{left}] \ge x-2k
$$

的下标。

计算好 $\textit{left}$ 后,下标在 $[\textit{left}, \textit{right}]$ 中的数可以变成一样的,这有

$$
\textit{right} - \textit{left} + 1
$$

个。注意上式不能超过 $\textit{numOperations}$。

小优化

由于同向双指针算出的结果不超过 $\textit{numOperations}$,所以当同向三指针计算完毕后,如果发现答案已经 $\ge \textit{numOperations}$,那么无需计算同向双指针。

###py

class Solution:
    def maxFrequency(self, nums: List[int], k: int, numOperations: int) -> int:
        nums.sort()

        n = len(nums)
        ans = cnt = left = right = 0
        for i, x in enumerate(nums):
            cnt += 1
            # 循环直到连续相同段的末尾,这样可以统计出 x 的出现次数
            if i < n - 1 and x == nums[i + 1]:
                continue
            while nums[left] < x - k:
                left += 1
            while right < n and nums[right] <= x + k:
                right += 1
            ans = max(ans, min(right - left, cnt + numOperations))
            cnt = 0

        if ans >= numOperations:
            return ans

        left = 0
        for right, x in enumerate(nums):
            while nums[left] < x - k * 2:
                left += 1
            ans = max(ans, right - left + 1)
        return min(ans, numOperations)  # 最后和 numOperations 取最小值

###java

class Solution {
    public int maxFrequency(int[] nums, int k, int numOperations) {
        Arrays.sort(nums);

        int n = nums.length;
        int ans = 0;
        int cnt = 0;
        int left = 0;
        int right = 0;
        for (int i = 0; i < n; i++) {
            int x = nums[i];
            cnt++;
            // 循环直到连续相同段的末尾,这样可以统计出 x 的出现次数
            if (i < n - 1 && x == nums[i + 1]) {
                continue;
            }
            while (nums[left] < x - k) {
                left++;
            }
            while (right < n && nums[right] <= x + k) {
                right++;
            }
            ans = Math.max(ans, Math.min(right - left, cnt + numOperations));
            cnt = 0;
        }

        if (ans >= numOperations) {
            return ans;
        }

        left = 0;
        for (right = 0; right < n; right++) {
            int x = nums[right];
            while (nums[left] < x - k * 2) {
                left++;
            }
            ans = Math.max(ans, right - left + 1);
        }
        return Math.min(ans, numOperations); // 最后和 numOperations 取最小值
    }
}

###cpp

class Solution {
public:
    int maxFrequency(vector<int>& nums, int k, int numOperations) {
        ranges::sort(nums);

        int n = nums.size();
        int ans = 0, cnt = 0, left = 0, right = 0;
        for (int i = 0; i < n; i++) {
            int x = nums[i];
            cnt++;
            // 循环直到连续相同段的末尾,这样可以统计出 x 的出现次数
            if (i < n - 1 && x == nums[i + 1]) {
                continue;
            }
            while (nums[left] < x - k) {
                left++;
            }
            while (right < n && nums[right] <= x + k) {
                right++;
            }
            ans = max(ans, min(right - left, cnt + numOperations));
            cnt = 0;
        }

        if (ans >= numOperations) {
            return ans;
        }

        left = 0;
        for (int right = 0; right < n; right++) {
            int x = nums[right];
            while (nums[left] < x - k * 2) {
                left++;
            }
            ans = max(ans, right - left + 1);
        }
        return min(ans, numOperations); // 最后和 numOperations 取最小值
    }
};

###go

func maxFrequency(nums []int, k, numOperations int) (ans int) {
slices.Sort(nums)

n := len(nums)
var cnt, left, right int
for i, x := range nums {
cnt++
// 循环直到连续相同段的末尾,这样可以统计出 x 的出现次数
if i < n-1 && x == nums[i+1] {
continue
}
for nums[left] < x-k {
left++
}
for right < n && nums[right] <= x+k {
right++
}
ans = max(ans, min(right-left, cnt+numOperations))
cnt = 0
}

if ans >= numOperations {
return ans
}

left = 0
for right, x := range nums {
for nums[left] < x-k*2 {
left++
}
ans = max(ans, right-left+1)
}
return min(ans, numOperations) // 最后和 numOperations 取最小值
}

也可以把两个 for 循环合起来。

###py

class Solution:
    def maxFrequency(self, nums: List[int], k: int, numOperations: int) -> int:
        nums.sort()

        n = len(nums)
        ans = cnt = left = right = left2 = 0
        for i, x in enumerate(nums):
            while nums[left2] < x - k * 2:
                left2 += 1
            ans = max(ans, min(i - left2 + 1, numOperations))

            cnt += 1
            # 循环直到连续相同段的末尾,这样可以统计出 x 的出现次数
            if i < n - 1 and x == nums[i + 1]:
                continue
            while nums[left] < x - k:
                left += 1
            while right < n and nums[right] <= x + k:
                right += 1
            ans = max(ans, min(right - left, cnt + numOperations))
            cnt = 0

        return ans

###java

class Solution {
    public int maxFrequency(int[] nums, int k, int numOperations) {
        Arrays.sort(nums);

        int n = nums.length;
        int ans = 0;
        int cnt = 0;
        int left = 0;
        int right = 0;
        int left2 = 0;
        for (int i = 0; i < n; i++) {
            int x = nums[i];
            while (nums[left2] < x - k * 2) {
                left2++;
            }
            ans = Math.max(ans, Math.min(i - left2 + 1, numOperations));

            cnt++;
            // 循环直到连续相同段的末尾,这样可以统计出 x 的出现次数
            if (i < n - 1 && x == nums[i + 1]) {
                continue;
            }
            while (nums[left] < x - k) {
                left++;
            }
            while (right < n && nums[right] <= x + k) {
                right++;
            }
            ans = Math.max(ans, Math.min(right - left, cnt + numOperations));
            cnt = 0;
        }

        return ans;
    }
}

###cpp

class Solution {
public:
    int maxFrequency(vector<int>& nums, int k, int numOperations) {
        ranges::sort(nums);

        int n = nums.size();
        int ans = 0, cnt = 0, left = 0, right = 0, left2 = 0;
        for (int i = 0; i < n; i++) {
            int x = nums[i];
            while (nums[left2] < x - k * 2) {
                left2++;
            }
            ans = max(ans, min(i - left2 + 1, numOperations));

            cnt++;
            // 循环直到连续相同段的末尾,这样可以统计出 x 的出现次数
            if (i < n - 1 && x == nums[i + 1]) {
                continue;
            }
            while (nums[left] < x - k) {
                left++;
            }
            while (right < n && nums[right] <= x + k) {
                right++;
            }
            ans = max(ans, min(right - left, cnt + numOperations));
            cnt = 0;
        }

        return ans;
    }
};

###go

func maxFrequency(nums []int, k, numOperations int) (ans int) {
slices.Sort(nums)

n := len(nums)
var cnt, left, right, left2 int
for i, x := range nums {
for nums[left2] < x-k*2 {
left2++
}
ans = max(ans, min(i-left2+1, numOperations))

cnt++
// 循环直到连续相同段的末尾,这样可以统计出 x 的出现次数
if i < n-1 && x == nums[i+1] {
continue
}
for nums[left] < x-k {
left++
}
for right < n && nums[right] <= x+k {
right++
}
ans = max(ans, min(right-left, cnt+numOperations))
cnt = 0
}

return
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log n)$,其中 $n$ 是 $\textit{nums}$ 的长度。瓶颈在排序上。
  • 空间复杂度:$\mathcal{O}(1)$。忽略排序的栈开销。

专题训练

  1. 下面数据结构题单的「§2.1 一维差分」。
  2. 下面双指针题单的「§3.2 同向双指针」和「五、三指针」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/最短路/最小生成树/二分图/基环树/欧拉路径)
  7. 动态规划(入门/背包/状态机/划分/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

枚举轮转到最左边的下标 + 裴蜀定理(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2025年10月3日 10:35

:题干中的「数字一旦超过 $9$ 就会变成 $0$」的意思是,数字 $x$ 加上 $a$ 后,会变成 $(x+a)\bmod 10$。

不轮转

从特殊到一般,先考虑只累加、不轮转的情况。

为了让字典序尽量小,第一个奇数下标 $1$ 上的数字 $s_1$,越小越好。一旦我们确定了 $s_1$ 的最终值,就确定了一共累加的值。由于所有奇数下标都要累加同一个数,所以也就确定了其余奇数下标的值。

那么,$s_1$ 最小可以是多少?可以是 $0$ 吗?

比如 $s_1 = 5$,$a=2$。不断累加 $a$,$s_1$ 的变化情况如下:

$$
5\to 7\to 9\to 1\to 3\to 5\to \cdots
$$

这种情况 $s_1$ 只能是奇数,最小是 $1$。

而如果 $s_1 = 5$,$a=3$,$s_1$ 的变化情况如下:

$$
5\to 8\to 1\to 4\to 7\to 0\to 3\to 6\to 9\to 2\to 5 \cdots
$$

这种情况 $s_1$ 可以变成 $[0,9]$ 中的任意整数,最小是 $0$。

一般地,设累加操作执行了 $k\ (k\ge 0)$ 次,那么 $s_1$ 变成

$$
r = (s_1 + ak)\bmod 10
$$

也就是 $s_1 + ak$ 减去若干个 $10$ 等于 $r$,即

$$
s_1 + ak - 10q = r
$$

变形得

$$
ak - 10q = r-s_1
$$

裴蜀定理 指出,该方程有解,当且仅当 $r-s_1$ 是 $g = \gcd(a,10)$ 的倍数,即

$$
r \equiv s_1 \pmod g
$$

其中 $\equiv$ 是同余符号,详细解释见 模运算的世界:当加减乘除遇上取模

上式表明,$s_1$ 通过累加操作变成的数,必须与 $s_1$ 关于模 $g$ 同余,所以 $s_1$ 可以变成的最小值为

$$
s_1\bmod g
$$

从 $s_1$ 到 $s_1\bmod g$,一共要累加的值为

$$
s_1\bmod g - s_1 + 10
$$

其中 $+10$ 保证减法结果非负。

枚举轮转到最左边的下标

例如 $s=\texttt{012345}$,$b=4$,执行轮转操作,得到

$$
\texttt{012345}\to\texttt{234501}\to\texttt{450123}\to\texttt{012345}\to\cdots
$$

只有 $s_0,s_2,s_4$ 可以轮转到最左边。

类似上文的思路,根据裴蜀定理,可以轮转到最左边的下标,必须是 $\textit{step} = \gcd(b,n)$ 的倍数,其中 $n$ 是 $s$ 的长度。

枚举 $i = 0,\textit{step},2\cdot \textit{step}, 3\cdot \textit{step},\dots$ 作为轮转到最左边的下标。

分类讨论:

  • 如果 $\gcd(b,n)$ 是偶数,无论如何轮转,我们只能对奇数下标执行累加操作。
  • 如果 $\gcd(b,n)$ 是奇数,轮转一次后,原来的偶数下标变成奇数下标。可以先轮转一次,执行累加,再轮转到我们想要的位置。可以视作我们拥有了「对偶数下标执行累加操作」的能力。

###py

class Solution:
    def findLexSmallestString(self, s: str, a: int, b: int) -> str:
        s = list(map(int, s))
        n = len(s)
        step = gcd(b, n)
        g = gcd(a, 10)
        ans = [inf]

        def modify(start: int) -> None:
            ch = t[start]  # 最靠前的数字,越小越好
            # ch 可以变成的最小值为 ch%g
            # 例如 ch=5,g=2,那么 ch+2+2+2(模 10)后变成 1,不可能变得更小
            # 从 ch 到 ch%g,需要增加 inc(循环中会 %10 保证结果在 [0,9] 中)
            inc = ch % g - ch
            if inc:  # 优化:inc 为 0 时,t[j] 不变,无需执行 for 循环
                for j in range(start, n, 2):
                    t[j] = (t[j] + inc) % 10

        for i in range(0, n, step):      
            t = s[i:] + s[:i]  # 轮转
            modify(1)  # 累加操作(所有奇数下标)
            if step % 2:  # 能对偶数下标执行累加操作
                modify(0)  # 累加操作(所有偶数下标)
            ans = min(ans, t)

        return ''.join(map(str, ans))

###java

class Solution {
    public String findLexSmallestString(String S, int a, int b) {
        char[] s = S.toCharArray();
        int n = s.length;
        char[] t = new char[n];
        int step = gcd(b, n);
        int g = gcd(a, 10);
        String ans = null;

        for (int i = 0; i < n; i += step) {
            // t = s[i,n) + s[0,i)
            System.arraycopy(s, i, t, 0, n - i);
            System.arraycopy(s, 0, t, n - i, i);

            modify(t, 1, g); // 累加操作(所有奇数下标)
            if (step % 2 > 0) { // 能对偶数下标执行累加操作
                modify(t, 0, g); // 累加操作(所有偶数下标)
            }

            String str = new String(t);
            if (ans == null || str.compareTo(ans) < 0) {
                ans = str;
            }
        }

        return ans;
    }

    private void modify(char[] t, int start, int g) {
        int ch = t[start] - '0'; // 最靠前的数字,越小越好
        // ch 可以变成的最小值为 ch%g
        // 例如 ch=5,g=2,那么 ch+2+2+2(模 10)后变成 1,不可能变得更小
        // 从 ch 到 ch%g,需要增加 inc,其中 +10 保证 inc 非负(循环中会 %10 保证结果在 [0,9] 中)
        int inc = ch % g - ch + 10;
        for (int j = start; j < t.length; j += 2) {
            t[j] = (char) ('0' + (t[j] - '0' + inc) % 10);
        }
    }

    private int gcd(int a, int b) {
        while (a != 0) {
            int tmp = a;
            a = b % a;
            b = tmp;
        }
        return b;
    }
}

###cpp

class Solution {
public:
    string findLexSmallestString(string s, int a, int b) {
        int n = s.size();
        int step = gcd(b, n);
        int g = gcd(a, 10);
        string ans;

        for (int i = 0; i < n; i += step) {
            string t = s.substr(i) + s.substr(0, i); // 轮转

            auto modify = [&](int start) -> void {
                int ch = t[start] - '0'; // 最靠前的数字,越小越好
                // ch 可以变成的最小值为 ch%g
                // 例如 ch=5,g=2,那么 ch+2+2+2(模 10)后变成 1,不可能变得更小
                // 从 ch 到 ch%g,需要增加 inc,其中 +10 保证 inc 非负(循环中会 %10 保证结果在 [0,9] 中)
                int inc = ch % g - ch + 10;
                for (int j = start; j < n; j += 2) {
                    t[j] = '0' + (t[j] - '0' + inc) % 10;
                }
            };

            modify(1); // 累加操作(所有奇数下标)
            if (step % 2) { // 能对偶数下标执行累加操作
                modify(0); // 累加操作(所有偶数下标)
            }

            if (ans.empty() || t < ans) {
                ans = move(t);
            }
        }

        return ans;
    }
};

###c

int gcd(int a, int b) {
    while (a) {
        int tmp = a;
        a = b % a;
        b = tmp;
    }
    return b;
}

void modify(char* t, int n, int start, int g) {
    int ch = t[start] - '0'; // 最靠前的数字,越小越好
    // ch 可以变成的最小值为 ch%g
    // 例如 ch=5,g=2,那么 ch+2+2+2(模 10)后变成 1,不可能变得更小
    // 从 ch 到 ch%g,需要增加 inc,其中 +10 保证 inc 非负(循环中会 %10 保证结果在 [0,9] 中)
    int inc = ch % g - ch + 10;
    for (int j = start; j < n; j += 2) {
        t[j] = '0' + (t[j] - '0' + inc) % 10;
    }
}

char* findLexSmallestString(char* s, int a, int b) {
    int n = strlen(s);
    int step = gcd(b, n);
    int g = gcd(a, 10);

    char* ans = malloc((n + 1) * sizeof(char));
    ans[0] = CHAR_MAX;
    ans[1] = '\0';

    char* t = malloc((n + 1) * sizeof(char));
    t[n] = '\0';

    for (int i = 0; i < n; i += step) {
        // t = s[i,n) + s[0,i)
        strncpy(t, s + i, n - i);
        strncpy(t + n - i, s, i);

        modify(t, n, 1, g); // 累加操作(所有奇数下标)
        if (step % 2) { // 能对偶数下标执行累加操作
            modify(t, n, 0, g); // 累加操作(所有偶数下标)
        }

        if (strcmp(t, ans) < 0) {
            strcpy(ans, t);
        }
    }

    free(t);
    return ans;
}

###go

func findLexSmallestString(s string, a int, b int) string {
n := len(s)
step := gcd(b, n)
g := gcd(a, 10)
var ans []byte

for i := 0; i < n; i += step {
t := []byte(s[i:] + s[:i]) // 轮转
modify := func(start int) {
ch := t[start] - '0' // 最靠前的数字,越小越好
// ch 可以变成的最小值为 ch%g
// 例如 ch=5,g=2,那么 ch+2+2+2(模 10)后变成 1,不可能变得更小
// 从 ch 到 ch%g,需要增加 inc,其中 +10 保证 inc 非负(循环中会 %10 保证结果在 [0,9] 中)
inc := ch%byte(g) + 10 - ch
for j := start; j < n; j += 2 {
t[j] = '0' + (t[j]-'0'+inc)%10
}
}
modify(1) // 累加操作(所有奇数下标)
if step%2 > 0 { // 能对偶数下标执行累加操作
modify(0) // 累加操作(所有偶数下标)
}
if ans == nil || bytes.Compare(t, ans) < 0 {
ans = t
}
}

return string(ans)
}

func gcd(a, b int) int {
for a != 0 {
a, b = b%a, a
}
return b
}

###js

var findLexSmallestString = function(s, a, b) {
    const arr = s.split('').map(ch => parseInt(ch));
    const n = arr.length;
    const step = gcd(b, n);
    const g = gcd(a, 10);
    let ans = null;

    function modify(t, start) {
        const ch = t[start]; // 最靠前的数字,越小越好
        // ch 可以变成的最小值为 ch%g
        // 例如 ch=5,g=2,那么 ch+2+2+2(模 10)后变成 1,不可能变得更小
        // 从 ch 到 ch%g,需要增加 inc(循环中会 %10 保证结果在 [0,9] 中)
        let inc = ch % g - ch + 10;
        if (inc === 0) { // 优化:inc 为 0 时,t[j] 不变,无需执行 for 循环
            return;
        }
        for (let j = start; j < n; j += 2) {
            t[j] = (t[j] + inc) % 10;
        }
    }

    for (let i = 0; i < n; i += step) {
        const t = arr.slice(i).concat(arr.slice(0, i)); // 轮转
        modify(t, 1); // 累加操作(所有奇数下标)
        if (step % 2) { // 能对偶数下标执行累加操作
            modify(t, 0); // 累加操作(所有偶数下标)
        }
        if (ans === null || compareArray(t, ans) < 0) {
            ans = t;
        }
    }

    return ans.join('');
};

function gcd(a, b) {
    while (a) {
        [a, b] = [b % a, a];
}
return b;
}

function compareArray(a, b) {
    const n = a.length;
    for (let i = 0; i < n; i++) {
        if (a[i] !== b[i]) {
            return a[i] - b[i];
        }
    }
    return 0;
}

###rust

impl Solution {
    pub fn find_lex_smallest_string(s: String, a: i32, b: i32) -> String {
        let n = s.len();
        let step = gcd(b, n as i32) as usize;
        let g = gcd(a, 10) as u8;
        let mut ans = vec![u8::MAX];

        let modify = |t: &mut [u8], start: usize| {
            let ch = t[start] - b'0'; // 最靠前的数字,越小越好
            // ch 可以变成的最小值为 ch%g
            // 例如 ch=5,g=2,那么 ch+2+2+2(模 10)后变成 1,不可能变得更小
            // 从 ch 到 ch%g,需要增加 inc,其中 +10 保证 inc 非负(循环中会 %10 保证结果在 [0,9] 中)
            let inc = ch % g + 10 - ch;
            for j in (start..n).step_by(2) {
                t[j] = b'0' + (t[j] - b'0' + inc) % 10;
            }
        };

        for i in (0..n).step_by(step) {
            let mut t = format!("{}{}", &s[i..], &s[..i]).into_bytes(); // 轮转
            modify(&mut t, 1); // 累加操作(所有奇数下标)
            if step % 2 != 0 { // 能对偶数下标执行累加操作
                modify(&mut t, 0); // 累加操作(所有偶数下标)
            }
            ans = ans.min(t);
        }

        unsafe { String::from_utf8_unchecked(ans) }
    }
}

fn gcd(mut a: i32, mut b: i32) -> i32 {
    while a != 0 {
        (a, b) = (b % a, a);
    }
    b
}

复杂度分析

  • 时间复杂度:$\mathcal{O}\left(\dfrac{n^2}{\gcd(b,n)}\right)$,其中 $n$ 是 $s$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

:还可以枚举奇数下标累加值,枚举偶数下标累加值,然后用类似 最小表示法 的思想,计算在固定累加值的情况下,轮转后的最小字典序。时间复杂度 $\mathcal{O}(D^2n)$,$D=10$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

从小到大贪心(Python/Java/C++/Go)

作者 endlesscheng
2024年12月22日 12:05

一个现实中的例子:

  • 军训的某一天,同学们在操场上。现在教官吹响了口哨,同学们集合,排成一排。对于最靠左的同学 $A$,他需要尽量往左移动,给其他同学腾出位置。$A$ 旁边的同学,可以紧挨着 $A$。依此类推。

把 $\textit{nums}$ 视作 $n$ 个同学在一维数轴中的位置,从最左边的同学($\textit{nums}$ 的最小值)开始思考。

设最左边的同学的位置为 $a$,他尽量往左移,位置变成 $a-k$。

$\textit{nums}$ 的次小值 $b$ 呢?这位同学也尽量往左移:

  • 比如 $a=4,b=6,k=3$,那么 $a$ 变成 $a-k=1$,$b$ 变成 $b-k=3$。
  • 比如 $a=4,b=4,k=3$,那么 $a$ 变成 $a'=a-k=1$,$b$ 变成 $b-k=1$ 就和 $a'$ 一样了,可以稍微大一点(紧挨着 $a'$),把 $b$ 变成 $a'+1=2$。

一般地,$b$ 变成

$$
\max(b-k,a'+1)
$$

但这不能超过 $b+k$,所以最终要变成

$$
\min(\max(b-k,a'+1),b+k)
$$

相当于让 $a'+1$ 落在 $[b-k,b+k]$ 中,若超出范围则修正。

第三小的数也同理,通过前一个数可以算出当前元素能变成多少。

最后答案为 $\textit{nums}$ 中的不同元素个数。我们可以在修改的同时统计,如果发现当前元素修改后的值,比上一个元素修改后的值大,那么答案加一。

为方便计算,把 $\textit{nums}$ 从小到大排序。排序后,从左到右遍历数组,就相当于从最左边的人开始计算了。

本题视频讲解,欢迎点赞关注~

###py

class Solution:
    def maxDistinctElements(self, nums: List[int], k: int) -> int:
        nums.sort()
        ans = 0
        pre = -inf  # 记录每个人左边的人的位置
        for x in nums:
            x = min(max(x - k, pre + 1), x + k)
            if x > pre:
                ans += 1
                pre = x
        return ans

###java

class Solution {
    public int maxDistinctElements(int[] nums, int k) {
        Arrays.sort(nums);
        int ans = 0;
        int pre = Integer.MIN_VALUE; // 记录每个人左边的人的位置
        for (int x : nums) {
            x = Math.min(Math.max(x - k, pre + 1), x + k);
            if (x > pre) {
                ans++;
                pre = x;
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int maxDistinctElements(vector<int>& nums, int k) {
        ranges::sort(nums);
        int ans = 0;
        int pre = INT_MIN; // 记录每个人左边的人的位置
        for (int x : nums) {
            x = clamp(pre + 1, x - k, x + k); // min(max(x - k, pre + 1), x + k)
            if (x > pre) {
                ans++;
                pre = x;
            }
        }
        return ans;
    }
};

###go

func maxDistinctElements(nums []int, k int) (ans int) {
slices.Sort(nums)
pre := math.MinInt // 记录每个人左边的人的位置
for _, x := range nums {
x = min(max(x-k, pre+1), x+k)
if x > pre {
ans++
pre = x
}
}
return
}

优化

什么情况下,可以直接返回 $n$?

先考虑 $\textit{nums}$ 所有元素都相同的情况(同学们都挤在一起)。我们可以把元素 $x$ 变成 $[x-k,x+k]$ 中的整数,这一共有 $2k+1$ 个。如果 $2k+1 \ge n$,就可以让所有元素互不相同。

如果 $\textit{nums}$ 有不同元素,当 $2k+1 \ge n$ 时,更加可以让所有元素互不相同。

所以只要 $2k+1 \ge n$,就可以直接返回 $n$。

###py

class Solution:
    def maxDistinctElements(self, nums: List[int], k: int) -> int:
        if k * 2 + 1 >= len(nums):
            return len(nums)

        nums.sort()
        ans = 0
        pre = -inf  # 记录每个人左边的人的位置
        for x in nums:
            x = min(max(x - k, pre + 1), x + k)
            if x > pre:
                ans += 1
                pre = x
        return ans

###java

class Solution {
    public int maxDistinctElements(int[] nums, int k) {
        int n = nums.length;
        if (k * 2 + 1 >= n) {
            return n;
        }

        Arrays.sort(nums);
        int ans = 0;
        int pre = Integer.MIN_VALUE; // 记录每个人左边的人的位置
        for (int x : nums) {
            x = Math.min(Math.max(x - k, pre + 1), x + k);
            if (x > pre) {
                ans++;
                pre = x;
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int maxDistinctElements(vector<int>& nums, int k) {
        int n = nums.size();
        if (k * 2 + 1 >= n) {
            return n;
        }

        ranges::sort(nums);
        int ans = 0;
        int pre = INT_MIN; // 记录每个人左边的人的位置
        for (int x : nums) {
            x = clamp(pre + 1, x - k, x + k); // min(max(x - k, pre + 1), x + k)
            if (x > pre) {
                ans++;
                pre = x;
            }
        }
        return ans;
    }
};

###go

func maxDistinctElements(nums []int, k int) (ans int) {
n := len(nums)
if k*2+1 >= n {
return n
}

slices.Sort(nums)
pre := math.MinInt // 记录每个人左边的人的位置
for _, x := range nums {
x = min(max(x-k, pre+1), x+k)
if x > pre {
ans++
pre = x
}
}
return
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log n)$,其中 $n$ 是 $\textit{nums}$ 的长度。瓶颈在排序上。
  • 空间复杂度:$\mathcal{O}(1)$。忽略排序的栈开销。

专题训练

见下面贪心题单的「§1.1 从最小/最大开始贪心」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

同余分组(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2023年3月19日 12:06

下文记 $m=\textit{value}$。

由于同一个数可以加减任意倍的 $m$,我们可以先把每个 $\textit{nums}[i]$ 变成与 $\textit{nums}[i]$ 关于模 $m$ 同余的最小非负整数,以备后用。关于同余的介绍,请看 模运算的世界:当加减乘除遇上取模

本题有负数,根据这篇文章中的公式,我们可以把每个 $\textit{nums}[i]$ 变成

$$
(\textit{nums}[i]\bmod m + m)\bmod m
$$

从而保证取模结果在 $[0,m)$ 中。

例如 $\textit{nums}=[1,-6,-4,3,5]$,$m=3$,取模后变成 $[1,0,2,0,2]$。

然后枚举答案:

  • 有没有与 $0$ 关于模 $m$ 同余的数?有,我们消耗掉一个 $0$。
  • 有没有与 $1$ 关于模 $m$ 同余的数?有,我们消耗掉一个 $1$。
  • 有没有与 $2$ 关于模 $m$ 同余的数?有,我们消耗掉一个 $2$。
  • 有没有与 $3$ 关于模 $m$ 同余的数?有,我们消耗掉一个 $0$。这个取模后等于 $0$ 的数,可以继续操作,变成 $3$。
  • 有没有与 $4$ 关于模 $m$ 同余的数?也就是看是否还有 $1$,没有,那么答案等于 $4$。

怎么知道还有没有剩余元素?用一个哈希表 $\textit{cnt}$ 统计 $(\textit{nums}[i]\bmod m + m) \bmod m$ 的个数。

本题视频讲解,欢迎点赞关注~

写法一

class Solution:
    def findSmallestInteger(self, nums: List[int], m: int) -> int:
        cnt = Counter(x % m for x in nums)
        mex = 0
        while cnt[mex % m]:
            cnt[mex % m] -= 1
            mex += 1
        return mex
class Solution {
    public int findSmallestInteger(int[] nums, int m) {
        int[] cnt = new int[m];
        for (int x : nums) {
            cnt[(x % m + m) % m]++; // 保证取模结果在 [0, m) 中
        }

        int mex = 0;
        while (cnt[mex % m]-- > 0) {
            mex++;
        }
        return mex;
    }
}
class Solution {
    public int findSmallestInteger(int[] nums, int m) {
        Map<Integer, Integer> cnt = new HashMap<>();
        for (int x : nums) {
            cnt.merge((x % m + m) % m, 1, Integer::sum);
        }

        int mex = 0;
        while (cnt.merge(mex % m, -1, Integer::sum) >= 0) {
            mex++;
        }
        return mex;
    }
}
class Solution {
public:
    int findSmallestInteger(vector<int>& nums, int m) {
        unordered_map<int, int> cnt;
        for (int x : nums) {
            cnt[(x % m + m) % m]++; // 保证取模结果在 [0, m) 中
        }

        int mex = 0;
        while (cnt[mex % m]-- > 0) {
            mex++;
        }
        return mex;
    }
};
int findSmallestInteger(int* nums, int numsSize, int m) {
    int* cnt = calloc(m, sizeof(int));
    for (int i = 0; i < numsSize; i++) {
        cnt[(nums[i] % m + m) % m]++; // 保证取模结果在 [0, m) 中
    }

    int mex = 0;
    while (cnt[mex % m]-- > 0) {
        mex++;
    }

    free(cnt);
    return mex;
}
func findSmallestInteger(nums []int, m int) (mex int) {
cnt := map[int]int{}
for _, x := range nums {
cnt[(x%m+m)%m]++ // 保证取模结果在 [0, m) 中
}

for cnt[mex%m] > 0 {
cnt[mex%m]--
mex++
}
return
}
var findSmallestInteger = function(nums, m) {
    const cnt = new Map();
    for (const x of nums) {
        const v = (x % m + m) % m; // 保证取模结果在 [0, m) 中
        cnt.set(v, (cnt.get(v) ?? 0) + 1);
    }

    let mex = 0;
    while ((cnt.get(mex % m) ?? 0) > 0) {
        cnt.set(mex % m, cnt.get(mex % m) - 1);
        mex++;
    }
    return mex;
};
use std::collections::HashMap;

impl Solution {
    pub fn find_smallest_integer(nums: Vec<i32>, m: i32) -> i32 {
        let mut cnt = HashMap::new();
        for x in nums {
            // 保证取模结果在 [0, m) 中
            *cnt.entry((x % m + m) % m).or_insert(0) += 1;
        }

        for mex in 0.. {
            if let Some(c) = cnt.get_mut(&(mex % m)) {
                if *c > 0 {
                    *c -= 1;
                    continue;
                }
            }
            return mex;
        }
        unreachable!()
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。由于加多少个数,就只能减多少个数,所以第二个循环至多循环 $\mathcal{O}(n)$ 次。
  • 空间复杂度:$\mathcal{O}(\min(n,m))$。哈希表中至多有 $\mathcal{O}(\min(n,m))$ 个元素。

写法二

lc2598-c.png{:width=500px}

此外,把哈希表换成数组更快。

class Solution:
    def findSmallestInteger(self, nums: List[int], m: int) -> int:
        cnt = [0] * m
        for x in nums:
            cnt[x % m] += 1

        i = cnt.index(min(cnt))
        return m * cnt[i] + i
class Solution {
    public int findSmallestInteger(int[] nums, int m) {
        int[] cnt = new int[m];
        for (int x : nums) {
            cnt[(x % m + m) % m]++;
        }

        int i = 0;
        for (int j = 1; j < m; j++) {
            if (cnt[j] < cnt[i]) {
                i = j;
            }
        }

        return m * cnt[i] + i;
    }
}
class Solution {
public:
    int findSmallestInteger(vector<int>& nums, int m) {
        vector<int> cnt(m);
        for (int x : nums) {
            cnt[(x % m + m) % m]++;
        }

        int i = ranges::min_element(cnt) - cnt.begin();
        return m * cnt[i] + i;
    }
};
int findSmallestInteger(int* nums, int numsSize, int m) {
    int* cnt = calloc(m, sizeof(int));
    for (int i = 0; i < numsSize; i++) {
        cnt[(nums[i] % m + m) % m]++;
    }

    int i = 0;
    for (int j = 1; j < m; j++) {
        if (cnt[j] < cnt[i]) {
            i = j;
        }
    }

    int ans = m * cnt[i] + i;
    free(cnt);
    return ans;
}
func findSmallestInteger(nums []int, m int) int {
cnt := make([]int, m)
for _, x := range nums {
cnt[(x%m+m)%m]++
}

i := 0
for j := 1; j < m; j++ {
if cnt[j] < cnt[i] {
i = j
}
}

return m*cnt[i] + i
}
var findSmallestInteger = function(nums, m) {
    const cnt = Array(m).fill(0);
    for (const x of nums) {
        cnt[(x % m + m) % m]++;
    }

    let i = 0;
    for (let j = 1; j < m; j++) {
        if (cnt[j] < cnt[i]) {
            i = j;
        }
    }

    return m * cnt[i] + i;
};
impl Solution {
    pub fn find_smallest_integer(nums: Vec<i32>, m: i32) -> i32 {
        let mut cnt = vec![0; m as usize];
        for x in nums {
            cnt[((x % m + m) % m) as usize] += 1;
        }

        let mut i = 0;
        for j in 1..m as usize {
            if cnt[j] < cnt[i] {
                i = j;
            }
        }

        m * cnt[i] + i as i32
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n+m)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(m)$。

相似题目

见下面数学题单的「§1.9 同余」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

O(n) 一次遍历,简洁写法(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2024年11月10日 12:03

遍历 $\textit{nums}$,寻找严格递增段(子数组)。

设当前严格递增段的长度为 $\textit{cnt}$,上一个严格递增段的长度为 $\textit{preCnt}$。

答案有两种情况:

  • 两个子数组属于同一个严格递增段,那么 $k$ 最大是 $\left\lfloor\dfrac{\textit{cnt}}{2}\right\rfloor$。
  • 两个子数组分别属于一对相邻的严格递增段,那么 $k$ 最大是 $\min(\textit{preCnt}, \textit{cnt})$。

本题视频讲解,欢迎点赞关注~

###py

class Solution:
    def maxIncreasingSubarrays(self, nums: List[int]) -> int:
        ans = pre_cnt = cnt = 0
        for i, x in enumerate(nums):
            cnt += 1
            if i == len(nums) - 1 or x >= nums[i + 1]:  # i 是严格递增段的末尾
                ans = max(ans, cnt // 2, min(pre_cnt, cnt))
                pre_cnt = cnt
                cnt = 0
        return ans

###java

class Solution {
    public int maxIncreasingSubarrays(List<Integer> nums) {
        int ans = 0;
        int preCnt = 0;
        int cnt = 0;
        for (int i = 0; i < nums.size(); i++) {
            cnt++;
            // i 是严格递增段的末尾
            if (i == nums.size() - 1 || nums.get(i) >= nums.get(i + 1)) {
                ans = Math.max(ans, Math.max(cnt / 2, Math.min(preCnt, cnt)));
                preCnt = cnt;
                cnt = 0;
            }
        }
        return ans;
    }
}

###java

class Solution {
    public int maxIncreasingSubarrays(List<Integer> nums) {
        Integer[] a = nums.toArray(Integer[]::new); // 转成数组处理,更快
        int ans = 0;
        int preCnt = 0;
        int cnt = 0;
        for (int i = 0; i < a.length; i++) {
            cnt++;
            // i 是严格递增段的末尾
            if (i == a.length - 1 || a[i] >= a[i + 1]) {
                ans = Math.max(ans, Math.max(cnt / 2, Math.min(preCnt, cnt)));
                preCnt = cnt;
                cnt = 0;
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int maxIncreasingSubarrays(vector<int>& nums) {
        int ans = 0, pre_cnt = 0, cnt = 0;
        for (int i = 0; i < nums.size(); i++) {
            cnt++;
            if (i == nums.size() - 1 || nums[i] >= nums[i + 1]) { // i 是严格递增段的末尾
                ans = max({ans, cnt / 2, min(pre_cnt, cnt)});
                pre_cnt = cnt;
                cnt = 0;
            }
        }
        return ans;
    }
};

###c

#define MIN(a, b) ((b) < (a) ? (b) : (a))
#define MAX(a, b) ((b) > (a) ? (b) : (a))

int maxIncreasingSubarrays(int* nums, int numsSize) {
    int ans = 0, pre_cnt = 0, cnt = 0;
    for (int i = 0; i < numsSize; i++) {
        cnt++;
        if (i == numsSize - 1 || nums[i] >= nums[i + 1]) { // i 是严格递增段的末尾
            ans = MAX(ans, MAX(cnt / 2, MIN(pre_cnt, cnt)));
            pre_cnt = cnt;
            cnt = 0;
        }
    }
    return ans;
}

###go

func maxIncreasingSubarrays(nums []int) (ans int) {
preCnt, cnt := 0, 0
for i, x := range nums {
cnt++
if i == len(nums)-1 || x >= nums[i+1] { // i 是严格递增段的末尾
ans = max(ans, cnt/2, min(preCnt, cnt))
preCnt = cnt
cnt = 0
}
}
return
}

###js

var maxIncreasingSubarrays = function(nums) {
    let ans = 0, preCnt = 0, cnt = 0;
    for (let i = 0; i < nums.length; i++) {
        cnt++;
        if (i === nums.length - 1 || nums[i] >= nums[i + 1]) { // i 是严格递增段的末尾
            ans = Math.max(ans, Math.floor(cnt / 2), Math.min(preCnt, cnt));
            preCnt = cnt;
            cnt = 0;
        }
    }
    return ans;
};

###rust

impl Solution {
    pub fn max_increasing_subarrays(nums: Vec<i32>) -> i32 {
        let mut ans = 0;
        let mut pre_cnt = 0;
        let mut cnt = 0;
        for i in 0..nums.len() {
            cnt += 1;
            if i == nums.len() - 1 || nums[i] >= nums[i + 1] { // i 是严格递增段的末尾
                ans = ans.max(cnt / 2).max(pre_cnt.min(cnt));
                pre_cnt = cnt;
                cnt = 0;
            }
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

专题训练

见下面双指针题单的「六、分组循环」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

❌
❌