普通视图

发现新文章,点击刷新页面。
昨天 — 2025年10月18日首页

从小到大贪心(Python/Java/C++/Go)

作者 endlesscheng
2024年12月22日 12:05

一个现实中的例子:

  • 军训的某一天,同学们在操场上。现在教官吹响了口哨,同学们集合,排成一排。对于最靠左的同学 $A$,他需要尽量往左移动,给其他同学腾出位置。$A$ 旁边的同学,可以紧挨着 $A$。依此类推。

把 $\textit{nums}$ 视作 $n$ 个同学在一维数轴中的位置,从最左边的同学($\textit{nums}$ 的最小值)开始思考。

设最左边的同学的位置为 $a$,他尽量往左移,位置变成 $a-k$。

$\textit{nums}$ 的次小值 $b$ 呢?这位同学也尽量往左移:

  • 比如 $a=4,b=6,k=3$,那么 $a$ 变成 $a-k=1$,$b$ 变成 $b-k=3$。
  • 比如 $a=4,b=4,k=3$,那么 $a$ 变成 $a'=a-k=1$,$b$ 变成 $b-k=1$ 就和 $a'$ 一样了,可以稍微大一点(紧挨着 $a'$),把 $b$ 变成 $a'+1=2$。

一般地,$b$ 变成

$$
\max(b-k,a'+1)
$$

但这不能超过 $b+k$,所以最终要变成

$$
\min(\max(b-k,a'+1),b+k)
$$

相当于让 $a'+1$ 落在 $[b-k,b+k]$ 中,若超出范围则修正。

第三小的数也同理,通过前一个数可以算出当前元素能变成多少。

最后答案为 $\textit{nums}$ 中的不同元素个数。我们可以在修改的同时统计,如果发现当前元素修改后的值,比上一个元素修改后的值大,那么答案加一。

为方便计算,把 $\textit{nums}$ 从小到大排序。排序后,从左到右遍历数组,就相当于从最左边的人开始计算了。

本题视频讲解,欢迎点赞关注~

###py

class Solution:
    def maxDistinctElements(self, nums: List[int], k: int) -> int:
        nums.sort()
        ans = 0
        pre = -inf  # 记录每个人左边的人的位置
        for x in nums:
            x = min(max(x - k, pre + 1), x + k)
            if x > pre:
                ans += 1
                pre = x
        return ans

###java

class Solution {
    public int maxDistinctElements(int[] nums, int k) {
        Arrays.sort(nums);
        int ans = 0;
        int pre = Integer.MIN_VALUE; // 记录每个人左边的人的位置
        for (int x : nums) {
            x = Math.min(Math.max(x - k, pre + 1), x + k);
            if (x > pre) {
                ans++;
                pre = x;
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int maxDistinctElements(vector<int>& nums, int k) {
        ranges::sort(nums);
        int ans = 0;
        int pre = INT_MIN; // 记录每个人左边的人的位置
        for (int x : nums) {
            x = clamp(pre + 1, x - k, x + k); // min(max(x - k, pre + 1), x + k)
            if (x > pre) {
                ans++;
                pre = x;
            }
        }
        return ans;
    }
};

###go

func maxDistinctElements(nums []int, k int) (ans int) {
slices.Sort(nums)
pre := math.MinInt // 记录每个人左边的人的位置
for _, x := range nums {
x = min(max(x-k, pre+1), x+k)
if x > pre {
ans++
pre = x
}
}
return
}

优化

什么情况下,可以直接返回 $n$?

先考虑 $\textit{nums}$ 所有元素都相同的情况(同学们都挤在一起)。我们可以把元素 $x$ 变成 $[x-k,x+k]$ 中的整数,这一共有 $2k+1$ 个。如果 $2k+1 \ge n$,就可以让所有元素互不相同。

如果 $\textit{nums}$ 有不同元素,当 $2k+1 \ge n$ 时,更加可以让所有元素互不相同。

所以只要 $2k+1 \ge n$,就可以直接返回 $n$。

###py

class Solution:
    def maxDistinctElements(self, nums: List[int], k: int) -> int:
        if k * 2 + 1 >= len(nums):
            return len(nums)

        nums.sort()
        ans = 0
        pre = -inf  # 记录每个人左边的人的位置
        for x in nums:
            x = min(max(x - k, pre + 1), x + k)
            if x > pre:
                ans += 1
                pre = x
        return ans

###java

class Solution {
    public int maxDistinctElements(int[] nums, int k) {
        int n = nums.length;
        if (k * 2 + 1 >= n) {
            return n;
        }

        Arrays.sort(nums);
        int ans = 0;
        int pre = Integer.MIN_VALUE; // 记录每个人左边的人的位置
        for (int x : nums) {
            x = Math.min(Math.max(x - k, pre + 1), x + k);
            if (x > pre) {
                ans++;
                pre = x;
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int maxDistinctElements(vector<int>& nums, int k) {
        int n = nums.size();
        if (k * 2 + 1 >= n) {
            return n;
        }

        ranges::sort(nums);
        int ans = 0;
        int pre = INT_MIN; // 记录每个人左边的人的位置
        for (int x : nums) {
            x = clamp(pre + 1, x - k, x + k); // min(max(x - k, pre + 1), x + k)
            if (x > pre) {
                ans++;
                pre = x;
            }
        }
        return ans;
    }
};

###go

func maxDistinctElements(nums []int, k int) (ans int) {
n := len(nums)
if k*2+1 >= n {
return n
}

slices.Sort(nums)
pre := math.MinInt // 记录每个人左边的人的位置
for _, x := range nums {
x = min(max(x-k, pre+1), x+k)
if x > pre {
ans++
pre = x
}
}
return
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log n)$,其中 $n$ 是 $\textit{nums}$ 的长度。瓶颈在排序上。
  • 空间复杂度:$\mathcal{O}(1)$。忽略排序的栈开销。

专题训练

见下面贪心题单的「§1.1 从最小/最大开始贪心」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

昨天以前首页

同余分组(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2023年3月19日 12:06

下文记 $m=\textit{value}$。

由于同一个数可以加减任意倍的 $m$,我们可以先把每个 $\textit{nums}[i]$ 变成与 $\textit{nums}[i]$ 关于模 $m$ 同余的最小非负整数,以备后用。关于同余的介绍,请看 模运算的世界:当加减乘除遇上取模

本题有负数,根据这篇文章中的公式,我们可以把每个 $\textit{nums}[i]$ 变成

$$
(\textit{nums}[i]\bmod m + m)\bmod m
$$

从而保证取模结果在 $[0,m)$ 中。

例如 $\textit{nums}=[1,-6,-4,3,5]$,$m=3$,取模后变成 $[1,0,2,0,2]$。

然后枚举答案:

  • 有没有与 $0$ 关于模 $m$ 同余的数?有,我们消耗掉一个 $0$。
  • 有没有与 $1$ 关于模 $m$ 同余的数?有,我们消耗掉一个 $1$。
  • 有没有与 $2$ 关于模 $m$ 同余的数?有,我们消耗掉一个 $2$。
  • 有没有与 $3$ 关于模 $m$ 同余的数?有,我们消耗掉一个 $0$。这个取模后等于 $0$ 的数,可以继续操作,变成 $3$。
  • 有没有与 $4$ 关于模 $m$ 同余的数?也就是看是否还有 $1$,没有,那么答案等于 $4$。

怎么知道还有没有剩余元素?用一个哈希表 $\textit{cnt}$ 统计 $(\textit{nums}[i]\bmod m + m) \bmod m$ 的个数。

本题视频讲解,欢迎点赞关注~

写法一

class Solution:
    def findSmallestInteger(self, nums: List[int], m: int) -> int:
        cnt = Counter(x % m for x in nums)
        mex = 0
        while cnt[mex % m]:
            cnt[mex % m] -= 1
            mex += 1
        return mex
class Solution {
    public int findSmallestInteger(int[] nums, int m) {
        int[] cnt = new int[m];
        for (int x : nums) {
            cnt[(x % m + m) % m]++; // 保证取模结果在 [0, m) 中
        }

        int mex = 0;
        while (cnt[mex % m]-- > 0) {
            mex++;
        }
        return mex;
    }
}
class Solution {
    public int findSmallestInteger(int[] nums, int m) {
        Map<Integer, Integer> cnt = new HashMap<>();
        for (int x : nums) {
            cnt.merge((x % m + m) % m, 1, Integer::sum);
        }

        int mex = 0;
        while (cnt.merge(mex % m, -1, Integer::sum) >= 0) {
            mex++;
        }
        return mex;
    }
}
class Solution {
public:
    int findSmallestInteger(vector<int>& nums, int m) {
        unordered_map<int, int> cnt;
        for (int x : nums) {
            cnt[(x % m + m) % m]++; // 保证取模结果在 [0, m) 中
        }

        int mex = 0;
        while (cnt[mex % m]-- > 0) {
            mex++;
        }
        return mex;
    }
};
int findSmallestInteger(int* nums, int numsSize, int m) {
    int* cnt = calloc(m, sizeof(int));
    for (int i = 0; i < numsSize; i++) {
        cnt[(nums[i] % m + m) % m]++; // 保证取模结果在 [0, m) 中
    }

    int mex = 0;
    while (cnt[mex % m]-- > 0) {
        mex++;
    }

    free(cnt);
    return mex;
}
func findSmallestInteger(nums []int, m int) (mex int) {
cnt := map[int]int{}
for _, x := range nums {
cnt[(x%m+m)%m]++ // 保证取模结果在 [0, m) 中
}

for cnt[mex%m] > 0 {
cnt[mex%m]--
mex++
}
return
}
var findSmallestInteger = function(nums, m) {
    const cnt = new Map();
    for (const x of nums) {
        const v = (x % m + m) % m; // 保证取模结果在 [0, m) 中
        cnt.set(v, (cnt.get(v) ?? 0) + 1);
    }

    let mex = 0;
    while ((cnt.get(mex % m) ?? 0) > 0) {
        cnt.set(mex % m, cnt.get(mex % m) - 1);
        mex++;
    }
    return mex;
};
use std::collections::HashMap;

impl Solution {
    pub fn find_smallest_integer(nums: Vec<i32>, m: i32) -> i32 {
        let mut cnt = HashMap::new();
        for x in nums {
            // 保证取模结果在 [0, m) 中
            *cnt.entry((x % m + m) % m).or_insert(0) += 1;
        }

        for mex in 0.. {
            if let Some(c) = cnt.get_mut(&(mex % m)) {
                if *c > 0 {
                    *c -= 1;
                    continue;
                }
            }
            return mex;
        }
        unreachable!()
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。由于加多少个数,就只能减多少个数,所以第二个循环至多循环 $\mathcal{O}(n)$ 次。
  • 空间复杂度:$\mathcal{O}(\min(n,m))$。哈希表中至多有 $\mathcal{O}(\min(n,m))$ 个元素。

写法二

lc2598-c.png{:width=500px}

此外,把哈希表换成数组更快。

class Solution:
    def findSmallestInteger(self, nums: List[int], m: int) -> int:
        cnt = [0] * m
        for x in nums:
            cnt[x % m] += 1

        i = cnt.index(min(cnt))
        return m * cnt[i] + i
class Solution {
    public int findSmallestInteger(int[] nums, int m) {
        int[] cnt = new int[m];
        for (int x : nums) {
            cnt[(x % m + m) % m]++;
        }

        int i = 0;
        for (int j = 1; j < m; j++) {
            if (cnt[j] < cnt[i]) {
                i = j;
            }
        }

        return m * cnt[i] + i;
    }
}
class Solution {
public:
    int findSmallestInteger(vector<int>& nums, int m) {
        vector<int> cnt(m);
        for (int x : nums) {
            cnt[(x % m + m) % m]++;
        }

        int i = ranges::min_element(cnt) - cnt.begin();
        return m * cnt[i] + i;
    }
};
int findSmallestInteger(int* nums, int numsSize, int m) {
    int* cnt = calloc(m, sizeof(int));
    for (int i = 0; i < numsSize; i++) {
        cnt[(nums[i] % m + m) % m]++;
    }

    int i = 0;
    for (int j = 1; j < m; j++) {
        if (cnt[j] < cnt[i]) {
            i = j;
        }
    }

    int ans = m * cnt[i] + i;
    free(cnt);
    return ans;
}
func findSmallestInteger(nums []int, m int) int {
cnt := make([]int, m)
for _, x := range nums {
cnt[(x%m+m)%m]++
}

i := 0
for j := 1; j < m; j++ {
if cnt[j] < cnt[i] {
i = j
}
}

return m*cnt[i] + i
}
var findSmallestInteger = function(nums, m) {
    const cnt = Array(m).fill(0);
    for (const x of nums) {
        cnt[(x % m + m) % m]++;
    }

    let i = 0;
    for (let j = 1; j < m; j++) {
        if (cnt[j] < cnt[i]) {
            i = j;
        }
    }

    return m * cnt[i] + i;
};
impl Solution {
    pub fn find_smallest_integer(nums: Vec<i32>, m: i32) -> i32 {
        let mut cnt = vec![0; m as usize];
        for x in nums {
            cnt[((x % m + m) % m) as usize] += 1;
        }

        let mut i = 0;
        for j in 1..m as usize {
            if cnt[j] < cnt[i] {
                i = j;
            }
        }

        m * cnt[i] + i as i32
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n+m)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(m)$。

相似题目

见下面数学题单的「§1.9 同余」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

O(n) 一次遍历,简洁写法(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2024年11月10日 12:03

遍历 $\textit{nums}$,寻找严格递增段(子数组)。

设当前严格递增段的长度为 $\textit{cnt}$,上一个严格递增段的长度为 $\textit{preCnt}$。

答案有两种情况:

  • 两个子数组属于同一个严格递增段,那么 $k$ 最大是 $\left\lfloor\dfrac{\textit{cnt}}{2}\right\rfloor$。
  • 两个子数组分别属于一对相邻的严格递增段,那么 $k$ 最大是 $\min(\textit{preCnt}, \textit{cnt})$。

本题视频讲解,欢迎点赞关注~

###py

class Solution:
    def maxIncreasingSubarrays(self, nums: List[int]) -> int:
        ans = pre_cnt = cnt = 0
        for i, x in enumerate(nums):
            cnt += 1
            if i == len(nums) - 1 or x >= nums[i + 1]:  # i 是严格递增段的末尾
                ans = max(ans, cnt // 2, min(pre_cnt, cnt))
                pre_cnt = cnt
                cnt = 0
        return ans

###java

class Solution {
    public int maxIncreasingSubarrays(List<Integer> nums) {
        int ans = 0;
        int preCnt = 0;
        int cnt = 0;
        for (int i = 0; i < nums.size(); i++) {
            cnt++;
            // i 是严格递增段的末尾
            if (i == nums.size() - 1 || nums.get(i) >= nums.get(i + 1)) {
                ans = Math.max(ans, Math.max(cnt / 2, Math.min(preCnt, cnt)));
                preCnt = cnt;
                cnt = 0;
            }
        }
        return ans;
    }
}

###java

class Solution {
    public int maxIncreasingSubarrays(List<Integer> nums) {
        Integer[] a = nums.toArray(Integer[]::new); // 转成数组处理,更快
        int ans = 0;
        int preCnt = 0;
        int cnt = 0;
        for (int i = 0; i < a.length; i++) {
            cnt++;
            // i 是严格递增段的末尾
            if (i == a.length - 1 || a[i] >= a[i + 1]) {
                ans = Math.max(ans, Math.max(cnt / 2, Math.min(preCnt, cnt)));
                preCnt = cnt;
                cnt = 0;
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int maxIncreasingSubarrays(vector<int>& nums) {
        int ans = 0, pre_cnt = 0, cnt = 0;
        for (int i = 0; i < nums.size(); i++) {
            cnt++;
            if (i == nums.size() - 1 || nums[i] >= nums[i + 1]) { // i 是严格递增段的末尾
                ans = max({ans, cnt / 2, min(pre_cnt, cnt)});
                pre_cnt = cnt;
                cnt = 0;
            }
        }
        return ans;
    }
};

###c

#define MIN(a, b) ((b) < (a) ? (b) : (a))
#define MAX(a, b) ((b) > (a) ? (b) : (a))

int maxIncreasingSubarrays(int* nums, int numsSize) {
    int ans = 0, pre_cnt = 0, cnt = 0;
    for (int i = 0; i < numsSize; i++) {
        cnt++;
        if (i == numsSize - 1 || nums[i] >= nums[i + 1]) { // i 是严格递增段的末尾
            ans = MAX(ans, MAX(cnt / 2, MIN(pre_cnt, cnt)));
            pre_cnt = cnt;
            cnt = 0;
        }
    }
    return ans;
}

###go

func maxIncreasingSubarrays(nums []int) (ans int) {
preCnt, cnt := 0, 0
for i, x := range nums {
cnt++
if i == len(nums)-1 || x >= nums[i+1] { // i 是严格递增段的末尾
ans = max(ans, cnt/2, min(preCnt, cnt))
preCnt = cnt
cnt = 0
}
}
return
}

###js

var maxIncreasingSubarrays = function(nums) {
    let ans = 0, preCnt = 0, cnt = 0;
    for (let i = 0; i < nums.length; i++) {
        cnt++;
        if (i === nums.length - 1 || nums[i] >= nums[i + 1]) { // i 是严格递增段的末尾
            ans = Math.max(ans, Math.floor(cnt / 2), Math.min(preCnt, cnt));
            preCnt = cnt;
            cnt = 0;
        }
    }
    return ans;
};

###rust

impl Solution {
    pub fn max_increasing_subarrays(nums: Vec<i32>) -> i32 {
        let mut ans = 0;
        let mut pre_cnt = 0;
        let mut cnt = 0;
        for i in 0..nums.len() {
            cnt += 1;
            if i == nums.len() - 1 || nums[i] >= nums[i + 1] { // i 是严格递增段的末尾
                ans = ans.max(cnt / 2).max(pre_cnt.min(cnt));
                pre_cnt = cnt;
                cnt = 0;
            }
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

专题训练

见下面双指针题单的「六、分组循环」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

从特殊到一般,教你如何思考本题,两种写法(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2024年5月12日 12:44

k = 1

从特殊到一般,先想一想,$k=1$ 怎么做?

此时只能一步一步地向右走。无论起点在哪,终点都是 $n-1$。

如果选择 $i$ 为起点,我们计算的是子数组 $[i,n-1]$ 的元素和,即后缀和

后缀和怎么算?我们可以倒着遍历 $\textit{energy}$,同时累加元素和,即为后缀和。

答案等于所有后缀和的最大值

k = 2

再想一想,$k=2$ 怎么做?

此时我们有两个终点:$n-2$ 和 $n-1$。

对于终点 $n-1$:

  • 如果选择 $n-3$ 为起点,那么我们累加的是下标为 $n-3,n-1$ 的元素和。
  • 如果选择 $n-5$ 为起点,那么我们累加的是下标为 $n-5,n-3,n-1$ 的元素和。
  • 如果选择 $n-7$ 为起点,那么我们累加的是下标为 $n-7,n-5,n-3,n-1$ 的元素和。
  • 一般地,从 $n-1$ 开始倒着遍历,步长为 $-k=-2$,累加元素和,计算元素和的最大值。

对于终点 $n-2$:

  • 如果选择 $n-4$ 为起点,那么我们累加的是下标为 $n-4,n-2$ 的元素和。
  • 如果选择 $n-6$ 为起点,那么我们累加的是下标为 $n-6,n-4,n-2$ 的元素和。
  • 如果选择 $n-8$ 为起点,那么我们累加的是下标为 $n-8,n-6,n-4,n-2$ 的元素和。
  • 一般地,从 $n-2$ 开始倒着遍历,步长为 $-k=-2$,累加元素和,计算元素和的最大值。

是否可以从 $n-3$ 开始倒着遍历?

不行,因为 $n-3$ 还可以向右跳到 $n-1$,所以 $n-3$ 不是终点,不能作为倒着遍历的起点。

一般情况

枚举终点 $n-k,n-k+1,\dots,n-1$,倒着遍历,步长为 $-k$。

遍历的同时累加元素和 $\textit{sufSum}$,计算 $\textit{sufSum}$ 的最大值,即为答案。

写法一

###py

class Solution:
    def maximumEnergy(self, energy: List[int], k: int) -> int:
        n = len(energy)
        ans = -inf
        for i in range(n - k, n):  # 枚举终点 i
            suf_sum = accumulate(energy[j] for j in range(i, -1, -k))  # 计算后缀和
            ans = max(ans, max(suf_sum))
        return ans

###java

class Solution {
    public int maximumEnergy(int[] energy, int k) {
        int n = energy.length;
        int ans = Integer.MIN_VALUE;
        for (int i = n - k; i < n; i++) { // 枚举终点 i
            int sufSum = 0;
            for (int j = i; j >= 0; j -= k) {
                sufSum += energy[j]; // 计算后缀和
                ans = Math.max(ans, sufSum);
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int maximumEnergy(vector<int>& energy, int k) {
        int n = energy.size();
        int ans = INT_MIN;
        for (int i = n - k; i < n; i++) { // 枚举终点 i
            int suf_sum = 0;
            for (int j = i; j >= 0; j -= k) {
                suf_sum += energy[j]; // 计算后缀和
                ans = max(ans, suf_sum);
            }
        }
        return ans;
    }
};

###c

#define MAX(a, b) ((b) > (a) ? (b) : (a))

int maximumEnergy(int* energy, int energySize, int k){
    int ans = INT_MIN;
    for (int i = energySize - k; i < energySize; i++) { // 枚举终点 i
        int suf_sum = 0;
        for (int j = i; j >= 0; j -= k) {
            suf_sum += energy[j]; // 计算后缀和
            ans = MAX(ans, suf_sum);
        }
    }
    return ans;
}

###go

func maximumEnergy(energy []int, k int) int {
n := len(energy)
ans := math.MinInt
for i := n - k; i < n; i++ { // 枚举终点 i
sufSum := 0
for j := i; j >= 0; j -= k {
sufSum += energy[j] // 计算后缀和
ans = max(ans, sufSum)
}
}
return ans
}

###js

var maximumEnergy = function(energy, k) {
    const n = energy.length;
    let ans = -Infinity;
    for (let i = n - k; i < n; i++) { // 枚举终点 i
        let sufSum = 0;
        for (let j = i; j >= 0; j -= k) {
            sufSum += energy[j]; // 计算后缀和
            ans = Math.max(ans, sufSum);
        }
    }
    return ans;
};

###rust

impl Solution {
    pub fn maximum_energy(energy: Vec<i32>, k: i32) -> i32 {
        let n = energy.len();
        let k = k as usize;
        let mut ans = i32::MIN;
        for i in n - k..n { // 枚举终点 i
            let mut suf_sum = 0;
            for j in (0..=i).rev().step_by(k) {
                suf_sum += energy[j]; // 计算后缀和
                ans = ans.max(suf_sum);
            }
        }
        ans
    }
}

写法二

原地计算后缀和,把后缀和保存到 $\textit{energy}$ 中。

最后返回 $\textit{energy}$ 的最大值,即为所有后缀和的最大值。

###py

class Solution:
    def maximumEnergy(self, energy: List[int], k: int) -> int:
        for i in range(len(energy) - k - 1, -1, -1):
            energy[i] += energy[i + k]
        return max(energy)

###java

class Solution {
    public int maximumEnergy(int[] energy, int k) {
        for (int i = energy.length - k - 1; i >= 0; i--) {
            energy[i] += energy[i + k];
        }
        return Arrays.stream(energy).max().getAsInt();
    }
}

###cpp

class Solution {
public:
    int maximumEnergy(vector<int>& energy, int k) {
        int n = energy.size();
        for (int i = n - k - 1; i >= 0; i--) {
            energy[i] += energy[i + k];
        }
        return ranges::max(energy);
    }
};

###c

#define MAX(a, b) ((b) > (a) ? (b) : (a))

int maximumEnergy(int* energy, int energySize, int k) {
    int ans = INT_MIN;
    for (int i = energySize - 1; i >= 0; i--) {
        if (i + k < energySize) {
            energy[i] += energy[i + k];
        }
        ans = MAX(ans, energy[i]);
    }
    return ans;
}

###go

func maximumEnergy(energy []int, k int) int {
for i := len(energy) - k - 1; i >= 0; i-- {
energy[i] += energy[i+k]
}
return slices.Max(energy)
}

###js

var maximumEnergy = function(energy, k) {
    for (let i = energy.length - k - 1; i >= 0; i--) {
        energy[i] += energy[i + k];
    }
    return Math.max(...energy);
};

###rust

impl Solution {
    pub fn maximum_energy(mut energy: Vec<i32>, k: i32) -> i32 {
        let k = k as usize;
        for i in (0..energy.len() - k).rev() {
            energy[i] += energy[i + k];
        }
        *energy.iter().max().unwrap()
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{energy}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

专题训练

见下面贪心与思维题单的「§5.3 逆向思维」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

四种方法:正反两次扫描 / 递推 / record / 凸包+二分(Python/Java/C++/Go)

作者 endlesscheng
2025年3月23日 12:04

方法一:正反两次扫描

为了计算酿造药水的时间,定义 $\textit{lastFinish}[i]$ 表示巫师 $i$ 完成上一瓶药水的时间。

示例 1 在处理完 $\textit{mana}[0]$ 后,有

$$
\textit{lastFinish} = [5,30,40,60]
$$

如果接着 $\textit{lastFinish}$ 继续酿造下一瓶药水 $\textit{mana}[1]=1$,完成时间是多少?注意开始酿造的时间不能早于 $\textit{lastFinish}[i]$。

$i$ $\textit{skill}[i]$ $\textit{lastFinish}[i]$ 完成时间
$0$ $1$ $5$ $5+1=6$
$1$ $5$ $30$ $\max(6,30)+5=35$
$2$ $2$ $40$ $\max(35,40)+2=42$
$3$ $4$ $60$ $\max(42,60)+4=64$

题目要求「药水在当前巫师完成工作后必须立即传递给下一个巫师并开始处理」,也就是说,酿造药水的过程中是不能有停顿的。

从 $64$ 开始倒推,可以得到每名巫师的实际完成时间。比如倒数第二位巫师的完成时间,就是 $64$ 减去最后一名巫师花费的时间 $4\cdot 1$,得到 $60$。

$i$ $\textit{skill}[i+1]$ 实际完成时间
$3$ - $64$
$2$ $4$ $64-4\cdot 1=60$
$1$ $2$ $60-2\cdot 1=58$
$0$ $5$ $58-5\cdot 1=53$

按照上述过程处理每瓶药水,最终答案为 $\textit{lastFinish}[n-1]$。

本题视频讲解,欢迎点赞关注~

###py

class Solution:
    def minTime(self, skill: List[int], mana: List[int]) -> int:
        n = len(skill)
        last_finish = [0] * n  # 第 i 名巫师完成上一瓶药水的时间
        for m in mana:
            # 按题意模拟
            sum_t = 0
            for x, last in zip(skill, last_finish):
                if last > sum_t: sum_t = last  # 手写 max
                sum_t += x * m
            # 倒推:如果酿造药水的过程中没有停顿,那么 last_finish[i] 应该是多少
            last_finish[-1] = sum_t
            for i in range(n - 2, -1, -1):
                last_finish[i] = last_finish[i + 1] - skill[i + 1] * m
        return last_finish[-1]

###java

class Solution {
    public long minTime(int[] skill, int[] mana) {
        int n = skill.length;
        long[] lastFinish = new long[n]; // 第 i 名巫师完成上一瓶药水的时间
        for (int m : mana) {
            // 按题意模拟
            long sumT = 0;
            for (int i = 0; i < n; i++) {
                sumT = Math.max(sumT, lastFinish[i]) + skill[i] * m;
            }
            // 倒推:如果酿造药水的过程中没有停顿,那么 lastFinish[i] 应该是多少
            lastFinish[n - 1] = sumT;
            for (int i = n - 2; i >= 0; i--) {
                lastFinish[i] = lastFinish[i + 1] - skill[i + 1] * m;
            }
        }
        return lastFinish[n - 1];
    }
}

###cpp

class Solution {
public:
    long long minTime(vector<int>& skill, vector<int>& mana) {
        int n = skill.size();
        vector<long long> last_finish(n); // 第 i 名巫师完成上一瓶药水的时间
        for (int m : mana) {
            // 按题意模拟
            long long sum_t = 0;
            for (int i = 0; i < n; i++) {
                sum_t = max(sum_t, last_finish[i]) + skill[i] * m;
            }
            // 倒推:如果酿造药水的过程中没有停顿,那么 lastFinish[i] 应该是多少
            last_finish[n - 1] = sum_t;
            for (int i = n - 2; i >= 0; i--) {
                last_finish[i] = last_finish[i + 1] - skill[i + 1] * m;
            }
        }
        return last_finish[n - 1];
    }
};

###go

func minTime(skill, mana []int) int64 {
n := len(skill)
lastFinish := make([]int, n) // 第 i 名巫师完成上一瓶药水的时间
for _, m := range mana {
// 按题意模拟
sumT := 0
for i, x := range skill {
sumT = max(sumT, lastFinish[i]) + x*m
}
// 倒推:如果酿造药水的过程中没有停顿,那么 lastFinish[i] 应该是多少
lastFinish[n-1] = sumT
for i := n - 2; i >= 0; i-- {
lastFinish[i] = lastFinish[i+1] - skill[i+1]*m
}
}
return int64(lastFinish[n-1])
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(nm)$,其中 $n$ 是 $\textit{skill}$ 的长度,$m$ 是 $\textit{mana}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

方法二:递推开始时间

由于酿造药水的过程是连续的,所以知道了开始时间(或者完成时间)就能知道每个 $\textit{lastFinish}[i]$。所以 $\textit{lastFinish}$ 数组是多余的。

设开始酿造 $\textit{mana}[j]$ 的时间为 $\textit{start}_j$,那么有

$$
\textit{lastFinish}_j[i] = \textit{start}j + \textit{mana}[j]\cdot \sum{k=0}^{i} \textit{skill}[k]
$$

在已知 $\textit{start}_{j-1}$ 的前提下,能否递推算出 $\textit{start}_j$?

哪位巫师决定了开始时间?假设第 $i$ 位巫师决定了开始时间,那么这位巫师完成 $\textit{mana}[j-1]$ 的时间,同时也是他开始 $\textit{mana}[j]$ 的时间。

所以有

$$
\textit{lastFinish}_{j-1}[i] + \textit{mana}[j]\cdot \textit{skill}[i] = \textit{lastFinish}_j[i]
$$

两边代入 $\textit{lastFinish}_j[i]$ 的式子,得

$$
\textit{start}{j-1} + \textit{mana}[j-1]\cdot \sum{k=0}^{i} \textit{skill}[k] + \textit{mana}[j]\cdot \textit{skill}[i] = \textit{start}j + \textit{mana}[j]\cdot \sum{k=0}^{i} \textit{skill}[k]
$$

移项得

$$
\textit{start}j = \textit{start}{j-1} + \textit{mana}[j-1]\cdot \sum_{k=0}^{i} \textit{skill}[k] - \textit{mana}[j]\cdot \sum_{k=0}^{i-1} \textit{skill}[k]
$$

计算 $\textit{skill}$ 的 前缀和 数组 $s$,上式为

$$
\textit{start}j = \textit{start}{j-1} + \textit{mana}[j-1]\cdot s[i+1] - \textit{mana}[j]\cdot s[i]
$$

枚举 $i$,取最大值,得

$$
\textit{start}j = \textit{start}{j-1} + \max_{i=0}^{n-1} \left{\textit{mana}[j-1]\cdot s[i+1] - \textit{mana}[j]\cdot s[i]\right}
$$

初始值 $\textit{start}_0 = 0$。

答案为 $\textit{lastFinish}{m-1}[n-1] = \textit{start}{m-1} + \textit{mana}[m-1]\cdot s[n]$。

###py

class Solution:
    def minTime(self, skill: List[int], mana: List[int]) -> int:
        n = len(skill)
        s = list(accumulate(skill, initial=0))  # skill 的前缀和
        start = 0
        for pre, cur in pairwise(mana):
            start += max(pre * s[i + 1] - cur * s[i] for i in range(n))
        return start + mana[-1] * s[-1]

###java

class Solution {
    public long minTime(int[] skill, int[] mana) {
        int n = skill.length;
        int[] s = new int[n + 1]; // skill 的前缀和
        for (int i = 0; i < n; i++) {
            s[i + 1] = s[i] + skill[i];
        }

        int m = mana.length;
        long start = 0;
        for (int j = 1; j < m; j++) {
            long mx = 0;
            for (int i = 0; i < n; i++) {
                mx = Math.max(mx, (long) mana[j - 1] * s[i + 1] - (long) mana[j] * s[i]);
            }
            start += mx;
        }
        return start + (long) mana[m - 1] * s[n];
    }
}

###cpp

class Solution {
public:
    long long minTime(vector<int>& skill, vector<int>& mana) {
        int n = skill.size(), m = mana.size();
        vector<int> s(n + 1); // skill 的前缀和
        partial_sum(skill.begin(), skill.end(), s.begin() + 1);

        long long start = 0;
        for (int j = 1; j < m; j++) {
            long long mx = 0;
            for (int i = 0; i < n; i++) {
                mx = max(mx, 1LL * mana[j - 1] * s[i + 1] - 1LL * mana[j] * s[i]);
            }
            start += mx;
        }
        return start + 1LL * mana[m - 1] * s[n];
    }
};

###go

func minTime(skill, mana []int) int64 {
n, m := len(skill), len(mana)
s := make([]int, n+1) // skill 的前缀和
for i, x := range skill {
s[i+1] = s[i] + x
}

start := 0
for j := 1; j < m; j++ {
mx := 0
for i := range n {
mx = max(mx, mana[j-1]*s[i+1]-mana[j]*s[i])
}
start += mx
}
return int64(start + mana[m-1]*s[n])
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(nm)$,其中 $n$ 是 $\textit{skill}$ 的长度,$m$ 是 $\textit{mana}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。如果在遍历的同时计算前缀和,则可以做到 $\mathcal{O}(1)$ 空间。

方法三:record 优化

将递推式

$$
\textit{start}j = \textit{start}{j-1} + \max_{i=0}^{n-1} \left{\textit{mana}[j-1]\cdot s[i+1] - \textit{mana}[j]\cdot s[i]\right}
$$

变形为

$$
\textit{start}j = \textit{start}{j-1} + \max_{i=0}^{n-1} \left{(\textit{mana}[j-1]-\textit{mana}[j])\cdot s[i] + \textit{mana}[j-1]\cdot \textit{skill}[i] \right}
$$

设 $d = \textit{mana}[j-1]-\textit{mana}[j]$。分类讨论:

  • 如果 $d > 0$。由于 $s$ 是单调递增数组,如果 $\textit{skill}[3] < \textit{skill}[5]$,那么 $i=3$ 绝对不会算出最大值;但如果 $\textit{skill}[3] > \textit{skill}[5]$,谁会算出最大值就不一定了。所以我们只需要考虑 $\textit{skill}$ 的逆序 record,这才是可能成为最大值的数据。其中逆序 record 的意思是,倒序遍历 $\textit{skill}$,每次遍历到更大的数,就记录下标。
  • 如果 $d < 0$。由于 $s$ 是单调递增数组,如果 $\textit{skill}[5] < \textit{skill}[3]$,那么 $i=5$ 绝对不会算出最大值;但如果 $\textit{skill}[5] > \textit{skill}[3]$,谁会算出最大值就不一定了。所以我们只需要考虑 $\textit{skill}$ 的正序 record,这才是可能成为最大值的数据。其中正序 record 的意思是,正序遍历 $\textit{skill}$,每次遍历到更大的数,就记录下标。
  • $d = 0$ 的情况可以并入 $d>0$ 的情况。

###py

class Solution:
    def minTime(self, skill: List[int], mana: List[int]) -> int:
        n = len(skill)
        s = list(accumulate(skill, initial=0))

        suf_record = [n - 1]
        for i in range(n - 2, -1, -1):
            if skill[i] > skill[suf_record[-1]]:
                suf_record.append(i)

        pre_record = [0]
        for i in range(1, n):
            if skill[i] > skill[pre_record[-1]]:
                pre_record.append(i)

        start = 0
        for pre, cur in pairwise(mana):
            record = pre_record if pre < cur else suf_record
            start += max(pre * s[i + 1] - cur * s[i] for i in record)
        return start + mana[-1] * s[-1]

###java

class Solution {
    public long minTime(int[] skill, int[] mana) {
        int n = skill.length;
        int[] s = new int[n + 1];
        for (int i = 0; i < n; i++) {
            s[i + 1] = s[i] + skill[i];
        }

        List<Integer> suf = new ArrayList<>();
        suf.add(n - 1);
        for (int i = n - 2; i >= 0; i--) {
            if (skill[i] > skill[suf.getLast()]) {
                suf.add(i);
            }
        }

        List<Integer> pre = new ArrayList<>();
        pre.add(0);
        for (int i = 1; i < n; i++) {
            if (skill[i] > skill[pre.getLast()]) {
                pre.add(i);
            }
        }

        int m = mana.length;
        long start = 0;
        for (int j = 1; j < m; j++) {
            List<Integer> record = mana[j - 1] < mana[j] ? pre : suf;
            long mx = 0;
            for (int i : record) {
                mx = Math.max(mx, (long) mana[j - 1] * s[i + 1] - (long) mana[j] * s[i]);
            }
            start += mx;
        }
        return start + (long) mana[m - 1] * s[n];
    }
}

###java

class Solution {
    public long minTime(int[] skill, int[] mana) {
        int n = skill.length;
        int[] s = new int[n + 1];
        for (int i = 0; i < n; i++) {
            s[i + 1] = s[i] + skill[i];
        }

        int[] suf = new int[n];
        int sufLen = 0;
        suf[sufLen++] = n - 1;
        for (int i = n - 2; i >= 0; i--) {
            if (skill[i] > skill[suf[sufLen - 1]]) {
                suf[sufLen++] = i;
            }
        }

        int[] pre = new int[n];
        int preLen = 0;
        pre[preLen++] = 0;
        for (int i = 1; i < n; i++) {
            if (skill[i] > skill[pre[preLen - 1]]) {
                pre[preLen++] = i;
            }
        }

        int m = mana.length;
        long start = 0;
        for (int j = 1; j < m; j++) {
            int[] record = mana[j - 1] < mana[j] ? pre : suf;
            int recordLen = mana[j - 1] < mana[j] ? preLen : sufLen;
            long mx = 0;
            for (int k = 0; k < recordLen; k++) {
                int i = record[k];
                mx = Math.max(mx, (long) mana[j - 1] * s[i + 1] - (long) mana[j] * s[i]);
            }
            start += mx;
        }
        return start + (long) mana[m - 1] * s[n];
    }
}

###cpp

class Solution {
public:
    long long minTime(vector<int>& skill, vector<int>& mana) {
        int n = skill.size(), m = mana.size();
        vector<int> s(n + 1);
        partial_sum(skill.begin(), skill.end(), s.begin() + 1);

        vector<int> suf = {n - 1};
        for (int i = n - 2; i >= 0; i--) {
            if (skill[i] > skill[suf.back()]) {
                suf.push_back(i);
            }
        }

        vector<int> pre = {0};
        for (int i = 1; i < n; i++) {
            if (skill[i] > skill[pre.back()]) {
                pre.push_back(i);
            }
        }

        long long start = 0;
        for (int j = 1; j < m; j++) {
            auto& record = mana[j - 1] < mana[j] ? pre : suf;
            long long mx = 0;
            for (int i : record) {
                mx = max(mx, 1LL * mana[j - 1] * s[i + 1] - 1LL * mana[j] * s[i]);
            }
            start += mx;
        }
        return start + 1LL * mana[m - 1] * s[n];
    }
};

###go

func minTime(skill, mana []int) int64 {
n, m := len(skill), len(mana)
s := make([]int, n+1)
for i, x := range skill {
s[i+1] = s[i] + x
}

suf := []int{n - 1}
for i := n - 2; i >= 0; i-- {
if skill[i] > skill[suf[len(suf)-1]] {
suf = append(suf, i)
}
}

pre := []int{0}
for i := 1; i < n; i++ {
if skill[i] > skill[pre[len(pre)-1]] {
pre = append(pre, i)
}
}

start := 0
for j := 1; j < m; j++ {
record := suf
if mana[j-1] < mana[j] {
record = pre
}
mx := 0
for _, i := range record {
mx = max(mx, mana[j-1]*s[i+1]-mana[j]*s[i])
}
start += mx
}
return int64(start + mana[m-1]*s[n])
}

复杂度分析(最坏情况)

  • 时间复杂度:$\mathcal{O}(nm)$,其中 $n$ 是 $\textit{skill}$ 的长度,$m$ 是 $\textit{mana}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

复杂度分析(平均情况)

力扣喜欢出随机数据,上述算法在随机数据下的性能如何?

换句话说,在随机数据下,record 的期望长度是多少?

为方便分析,假设 $\textit{skill}$ 是一个随机的 $[1,n]$ 的排列。

$\textit{skill}[i]$ 如果是一个新的最大值,那么它是 $[0,i]$ 中的最大值。在随机排列的情况下,$[0,i]$ 中的排列也是随机的,所以这等价于该排列的最后一个数是最大值的概率,即

$$
\dfrac{1}{i+1}
$$

record 的期望长度,等于「每个位置能否成为新的最大值」之和,能就贡献 $1$,不能就贡献 $0$。

所以 $\textit{skill}[i]$ 给期望的贡献是 $\dfrac{1}{i+1}$。所以 record 的期望长度为

$$
\sum_{i=0}^{n-1} \dfrac{1}{i+1}
$$

由调和级数可知,record 的期望长度为 $\Theta(\log n)$。

  • 时间复杂度:$\Theta(n + m\log n)$,其中 $n$ 是 $\textit{skill}$ 的长度,$m$ 是 $\textit{mana}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

方法四:凸包 + 二分

前置知识:二维计算几何,凸包,Andrew 算法。

把递推式

$$
\textit{start}j = \textit{start}{j-1} + \max_{i=0}^{n-1} \left{(\textit{mana}[j-1]-\textit{mana}[j])\cdot s[i] + \textit{mana}[j-1]\cdot \textit{skill}[i] \right}
$$

中的

$$
(\textit{mana}[j-1]-\textit{mana}[j])\cdot s[i] + \textit{mana}[j-1]\cdot \textit{skill}[i]
$$

改成点积的形式,这样我们能得到来自几何意义上的观察。

设向量 $\mathbf{v}_i = (s[i],\textit{skill}[i])$。

设向量 $\mathbf{p} = (\textit{mana}[j-1]-\textit{mana}[j], \textit{mana}[j-1])$。

那么我们求的是

$$
\max_{i=0}^{n-1} \mathbf{p}\cdot \mathbf{v}_i
$$

根据点积的几何意义,我们求的是 $\mathbf{v}_i$ 在 $\mathbf{p}$ 方向上的投影长度,再乘以 $\mathbf{p}$ 的模长 $||\mathbf{p}||$。由于 $||\mathbf{p}||$ 是个定值,所以要最大化投影长度。

考虑 $\mathbf{v}_i$ 的上凸包(用 Andrew 算法计算),在凸包内的点,就像是山坳,比凸包顶点的投影长度短。所以只需考虑凸包顶点。

这样有一个很好的性质:顺时针(或者逆时针)遍历凸包顶点,$\mathbf{p}\cdot \mathbf{v}_i$ 会先变大再变小(单峰函数)。那么要计算最大值,就类似 852. 山脉数组的峰顶索引二分首个「下坡」的位置,具体见 我的题解

###py

class Vec:
    __slots__ = 'x', 'y'

    def __init__(self, x: int, y: int):
        self.x = x
        self.y = y

    def __sub__(self, b: "Vec") -> "Vec":
        return Vec(self.x - b.x, self.y - b.y)

    def det(self, b: "Vec") -> int:
        return self.x * b.y - self.y * b.x

    def dot(self, b: "Vec") -> int:
        return self.x * b.x + self.y * b.y

class Solution:
    # Andrew 算法,计算 points 的上凸包
    # 由于横坐标(前缀和)是严格递增的,所以无需排序
    def convex_hull(self, points: List[Vec]) -> List[Vec]:
        q = []
        for p in points:
            while len(q) > 1 and (q[-1] - q[-2]).det(p - q[-1]) >= 0:
                q.pop()
            q.append(p)
        return q

    def minTime(self, skill: List[int], mana: List[int]) -> int:
        s = list(accumulate(skill, initial=0))
        vs = [Vec(pre_sum, x) for pre_sum, x in zip(s, skill)]
        vs = self.convex_hull(vs)  # 去掉无用数据

        start = 0
        for pre, cur in pairwise(mana):
            p = Vec(pre - cur, pre)
            # p.dot(vs[i]) 是个单峰函数,二分找最大值
            check = lambda i: p.dot(vs[i]) > p.dot(vs[i + 1])
            i = bisect_left(range(len(vs) - 1), True, key=check)
            start += p.dot(vs[i])
        return start + mana[-1] * s[-1]

###java

class Solution {
    private record Vec(int x, int y) {
        Vec sub(Vec b) {
            return new Vec(x - b.x, y - b.y);
        }

        long det(Vec b) {
            return (long) x * b.y - (long) y * b.x;
        }

        long dot(Vec b) {
            return (long) x * b.x + (long) y * b.y;
        }
    }

    // Andrew 算法,计算 points 的上凸包
    // 由于横坐标(前缀和)是严格递增的,所以无需排序
    private List<Vec> convexHull(Vec[] points) {
        List<Vec> q = new ArrayList<>();
        for (Vec p : points) {
            while (q.size() > 1 && q.getLast().sub(q.get(q.size() - 2)).det(p.sub(q.getLast())) >= 0) {
                q.removeLast();
            }
            q.add(p);
        }
        return q;
    }

    public long minTime(int[] skill, int[] mana) {
        int n = skill.length;
        int[] s = new int[n + 1];
        Vec[] points = new Vec[n];
        for (int i = 0; i < n; i++) {
            s[i + 1] = s[i] + skill[i];
            points[i] = new Vec(s[i], skill[i]);
        }
        List<Vec> vs = convexHull(points); // 去掉无用数据

        int m = mana.length;
        long start = 0;
        for (int j = 1; j < m; j++) {
            Vec p = new Vec(mana[j - 1] - mana[j], mana[j - 1]);
            // p.dot(vs[i]) 是个单峰函数,二分找最大值
            int l = -1, r = vs.size() - 1;
            while (l + 1 < r) {
                int mid = (l + r) >>> 1;
                if (p.dot(vs.get(mid)) > p.dot(vs.get(mid + 1))) {
                    r = mid;
                } else {
                    l = mid;
                }
            }
            start += p.dot(vs.get(r));
        }
        return start + (long) mana[m - 1] * s[n];
    }
}

###cpp

struct Vec {
    int x, y;
    Vec operator-(const Vec& b) { return {x - b.x, y - b.y}; }
    long long det(const Vec& b) { return 1LL * x * b.y - 1LL * y * b.x; }
    long long dot(const Vec& b) { return 1LL * x * b.x + 1LL * y * b.y; }
};

class Solution {
    // Andrew 算法,计算 points 的上凸包
    // 由于横坐标(前缀和)是严格递增的,所以无需排序
    vector<Vec> convex_hull(vector<Vec>& points) {
        vector<Vec> q;
        for (auto& p : points) {
            while (q.size() > 1 && (q.back() - q[q.size() - 2]).det(p - q.back()) >= 0) {
                q.pop_back();
            }
            q.push_back(p);
        }
        return q;
    }

public:
    long long minTime(vector<int>& skill, vector<int>& mana) {
        int n = skill.size(), m = mana.size();
        vector<int> s(n + 1);
        vector<Vec> vs(n);
        for (int i = 0; i < n; i++) {
            s[i + 1] = s[i] + skill[i];
            vs[i] = {s[i], skill[i]};
        }
        vs = convex_hull(vs); // 去掉无用数据

        long long start = 0;
        for (int j = 1; j < m; j++) {
            Vec p = {mana[j - 1] - mana[j], mana[j - 1]};
            // p.dot(vs[i]) 是个单峰函数,二分找最大值
            int l = -1, r = vs.size() - 1;
            while (l + 1 < r) {
                int mid = l + (r - l) / 2;
                (p.dot(vs[mid]) > p.dot(vs[mid + 1]) ? r : l) = mid;
            }
            start += p.dot(vs[r]);
        }
        return start + 1LL * mana[m - 1] * s[n];
    }
};

###go

type vec struct{ x, y int }

func (a vec) sub(b vec) vec { return vec{a.x - b.x, a.y - b.y} }
func (a vec) det(b vec) int { return a.x*b.y - a.y*b.x }
func (a vec) dot(b vec) int { return a.x*b.x + a.y*b.y }

// Andrew 算法,计算 points 的上凸包
// 由于横坐标(前缀和)是严格递增的,所以无需排序
func convexHull(points []vec) (q []vec) {
for _, p := range points {
for len(q) > 1 && q[len(q)-1].sub(q[len(q)-2]).det(p.sub(q[len(q)-1])) >= 0 {
q = q[:len(q)-1]
}
q = append(q, p)
}
return
}

func minTime(skill, mana []int) int64 {
n, m := len(skill), len(mana)
s := make([]int, n+1)
vs := make([]vec, n)
for i, x := range skill {
s[i+1] = s[i] + x
vs[i] = vec{s[i], x}
}
vs = convexHull(vs) // 去掉无用数据

start := 0
for j := 1; j < m; j++ {
p := vec{mana[j-1] - mana[j], mana[j-1]}
// p.dot(vs[i]) 是个单峰函数,二分找最大值
i := sort.Search(len(vs)-1, func(i int) bool { return p.dot(vs[i]) > p.dot(vs[i+1]) })
start += p.dot(vs[i])
}
return int64(start + mana[m-1]*s[n])
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n + m\log n)$,其中 $n$ 是 $\textit{skill}$ 的长度,$m$ 是 $\textit{mana}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

两种方法:排序+二分 / 计数+值域后缀和(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2022年6月12日 07:00

方法一:排序 + 二分查找

写法一:使用浮点数

问题相当于给你 $n$ 个询问,每次问 $\textit{spells}[i]$ 与 $\textit{potions}$ 中的多少个数相乘,结果 $\ge \textit{success}$。

对 $\textit{potions}$ 排序后,就可以二分查找了:

  • 设 $j$ 是最小的满足 $\textit{potions}[j]\ge\dfrac{\textit{success}}{\textit{spells}[i]}$ 的下标。
  • 由于数组已经排序,那么下标大于 $j$ 的数也同样满足不等式。
  • 从 $j$ 到 $m-1$,一共有 $m-j$ 个满足不等式的数,其中 $m$ 是 $\textit{potions}$ 的长度。

有关二分查找的原理,请看【基础算法精讲 04】

class Solution:
    def successfulPairs(self, spells: List[int], potions: List[int], success: int) -> List[int]:
        potions.sort()
        m = len(potions)
        return [m - bisect_left(potions, success / x) for x in spells]
class Solution {
    public int[] successfulPairs(int[] spells, int[] potions, long success) {
        Arrays.sort(potions);
        for (int i = 0; i < spells.length; i++) {
            spells[i] = potions.length - lowerBound(potions, (double) success / spells[i]);
        }
        return spells;
    }

    // 返回 nums 中的第一个大于等于 target 的元素下标
    private int lowerBound(int[] nums, double target) {
        int left = -1, right = nums.length; // 开区间 (left, right)
        while (left + 1 < right) { // 区间不为空
            // 循环不变量:
            // nums[left] <= target
            // nums[right] > target
            int mid = (left + right) >>> 1; // 比 left+(right-left)/2 更快的写法
            if (nums[mid] >= target) {
                right = mid; // 二分范围缩小到 (left, mid)
            } else {
                left = mid; // 二分范围缩小到 (mid, right)
            }
        }
        return right;
    }
}
class Solution {
public:
    vector<int> successfulPairs(vector<int>& spells, vector<int>& potions, long long success) {
        ranges::sort(potions);
        for (int& x : spells) { // 原地修改
            x = potions.end() - ranges::lower_bound(potions, 1.0 * success / x);
        }
        return spells;
    }
};
int cmp(const void* a, const void* b) {
    return *(int*)a - *(int*)b;
}

// 返回 nums 中的第一个大于等于 target 的元素下标
int lowerBound(int* nums, int numsSize, double target) {
    int left = -1, right = numsSize; // 开区间 (left, right)
    while (left + 1 < right) { // 区间不为空
        int mid = left + (right - left) / 2;
        if (nums[mid] >= target) {
            right = mid; // 二分范围缩小到 (left, mid)
        } else {
            left = mid; // 二分范围缩小到 (mid, right)
        }
    }
    return right;
}

int* successfulPairs(int* spells, int spellsSize, int* potions, int potionsSize, long long success, int* returnSize) {
    qsort(potions, potionsSize, sizeof(int), cmp);
    for (int i = 0; i < spellsSize; i++) {
        spells[i] = potionsSize - lowerBound(potions, potionsSize, 1.0 * success / spells[i]);
    }
    *returnSize = spellsSize;
    return spells;
}
func successfulPairs(spells, potions []int, success int64) []int {
    slices.Sort(potions)
    for i, x := range spells {
        target := float64(success) / float64(x)
        j := sort.Search(len(potions), func(j int) bool { return float64(potions[j]) >= target })
        spells[i] = len(potions) - j
    }
    return spells
}
var successfulPairs = function(spells, potions, success) {
    potions.sort((a, b) => a - b);
    for (let i = 0; i < spells.length; i++) {
        const target = success / spells[i];
        spells[i] = potions.length - lowerBound(potions, target);
    }
    return spells;
};

var lowerBound = function(nums, target) {
    let left = -1, right = nums.length; // 开区间 (left, right)
    while (left + 1 < right) { // 区间不为空
        // 循环不变量:
        // nums[left] < target
        // nums[right] >= target
        const mid = Math.floor((left + right) / 2);
        if (nums[mid] >= target) {
            right = mid; // 范围缩小到 (left, mid)
        } else {
            left = mid; // 范围缩小到 (mid, right)
        }
    }
    return right;
}
var successfulPairs = function(spells, potions, success) {
    potions.sort((a, b) => a - b);
    for (let i = 0; i < spells.length; i++) {
        const target = success / spells[i];
        spells[i] = potions.length - _.sortedIndex(potions, target);
    }
    return spells;
};
impl Solution {
    pub fn successful_pairs(mut spells: Vec<i32>, mut potions: Vec<i32>, success: i64) -> Vec<i32> {
        potions.sort_unstable();
        let last = potions[potions.len() - 1] as i64;
        for x in spells.iter_mut() {
            let target = success as f64 / *x as f64;
            let j = potions.partition_point(|&x| (x as f64) < target);
            *x = (potions.len() - j) as i32;
        }
        spells
    }
}

写法二:不使用浮点数

浮点数有舍入误差,如果数据范围更大,上面的做法就不一定正确了。

更好的做法是,避免使用浮点数,只使用整数计算。一方面可以保证正确性,另一方面整数运算比浮点运算快。

对于正整数,$xy\ge\textit{success}$ 等价于 $y\ge\left\lceil\dfrac{\textit{success}}{x}\right\rceil$。

为方便二分,可以利用如下恒等式:

$$
\left\lceil\dfrac{a}{b}\right\rceil = \left\lfloor\dfrac{a+b-1}{b}\right\rfloor = \left\lfloor\dfrac{a-1}{b}\right\rfloor + 1
$$

证明见 上取整下取整转换公式的证明

根据上式,我们有

$$
y\ge\left\lceil\dfrac{\textit{success}}{x}\right\rceil = \left\lfloor\dfrac{\textit{success}-1}{x}\right\rfloor + 1
$$

也可以写成

$$
y>\left\lfloor\dfrac{\textit{success}-1}{x}\right\rfloor
$$

为什么不等式一定要这样变形?好处是只需要在二分之前做一次除法,避免在二分循环内计算乘法,效率更高。另外的好处是部分语言可以直接调用库函数二分。

class Solution:
    def successfulPairs(self, spells: List[int], potions: List[int], success: int) -> List[int]:
        potions.sort()
        m = len(potions)
        success -= 1  # 提前减一,避免在循环中反复减一
        return [m - bisect_right(potions, success // x) for x in spells]
class Solution {
    public int[] successfulPairs(int[] spells, int[] potions, long success) {
        Arrays.sort(potions);
        for (int i = 0; i < spells.length; i++) {
            long target = (success - 1) / spells[i];
            if (target < potions[potions.length - 1]) {
                // 这样写每次二分就只用比两个 int 的大小,避免把 potions 中的元素转成 long 比较
                spells[i] = potions.length - upperBound(potions, (int) target);
            } else {
                spells[i] = 0;
            }
        }
        return spells;
    }

    // 返回 nums 中的第一个大于 target 的元素下标
    private int upperBound(int[] nums, int target) {
        int left = -1, right = nums.length; // 开区间 (left, right)
        while (left + 1 < right) { // 区间不为空
            // 循环不变量:
            // nums[left] <= target
            // nums[right] > target
            int mid = (left + right) >>> 1; // 比 left+(right-left)/2 更快的写法
            if (nums[mid] > target) {
                right = mid; // 二分范围缩小到 (left, mid)
            } else {
                left = mid; // 二分范围缩小到 (mid, right)
            }
        }
        return right;
    }
}
class Solution {
public:
    vector<int> successfulPairs(vector<int>& spells, vector<int>& potions, long long success) {
        ranges::sort(potions);
        for (int& x : spells) { // 原地修改
            long long target = (success - 1) / x;
            if (target < potions.back()) {
                // 这样写每次二分就只用比两个 int 的大小,避免把 potions 中的元素转成 long long 比较
                x = potions.end() - ranges::upper_bound(potions, (int) target);
            } else {
                x = 0;
            }
        }
        return spells;
    }
};
int cmp(const void* a, const void* b) {
    return *(int*)a - *(int*)b;
}

// 返回 nums 中的第一个大于 target 的元素下标
int upperBound(int* nums, int numsSize, int target) {
    int left = -1, right = numsSize; // 开区间 (left, right)
    while (left + 1 < right) { // 区间不为空
        int mid = left + (right - left) / 2;
        if (nums[mid] > target) {
            right = mid; // 二分范围缩小到 (left, mid)
        } else {
            left = mid; // 二分范围缩小到 (mid, right)
        }
    }
    return right;
}

int* successfulPairs(int* spells, int spellsSize, int* potions, int potionsSize, long long success, int* returnSize) {
    qsort(potions, potionsSize, sizeof(int), cmp);
    for (int i = 0; i < spellsSize; i++) {
        long long target = (success - 1) / spells[i];
        if (target < potions[potionsSize - 1]) {
            // 这样写每次二分就只用比两个 int 的大小,避免把 potions 中的元素转成 long long 比较
            spells[i] = potionsSize - upperBound(potions, potionsSize, target);
        } else {
            spells[i] = 0;
        }
    }
    *returnSize = spellsSize;
    return spells;
}
func successfulPairs(spells, potions []int, success int64) []int {
slices.Sort(potions)
for i, x := range spells {
spells[i] = len(potions) - sort.SearchInts(potions, (int(success)-1)/x+1)
}
return spells
}
var successfulPairs = function(spells, potions, success) {
    potions.sort((a, b) => a - b);
    for (let i = 0; i < spells.length; i++) {
        const target = Math.ceil(success / spells[i]);
        spells[i] = potions.length - lowerBound(potions, target);
    }
    return spells;
};

var lowerBound = function(nums, target) {
    let left = -1, right = nums.length; // 开区间 (left, right)
    while (left + 1 < right) { // 区间不为空
        // 循环不变量:
        // nums[left] < target
        // nums[right] >= target
        const mid = Math.floor((left + right) / 2);
        if (nums[mid] >= target) {
            right = mid; // 范围缩小到 (left, mid)
        } else {
            left = mid; // 范围缩小到 (mid, right)
        }
    }
    return right;
}
var successfulPairs = function(spells, potions, success) {
    potions.sort((a, b) => a - b);
    for (let i = 0; i < spells.length; i++) {
        const target = Math.ceil(success / spells[i]);
        spells[i] = potions.length - _.sortedIndex(potions, target);
    }
    return spells;
};
impl Solution {
    pub fn successful_pairs(mut spells: Vec<i32>, mut potions: Vec<i32>, success: i64) -> Vec<i32> {
        potions.sort_unstable();
        let last = potions[potions.len() - 1] as i64;
        for x in spells.iter_mut() {
            let target = (success - 1) / *x as i64;
            if target < last { // 防止 i64 转成 i32 截断(这样不需要把 potions 中的数转成 i64 比较)
                let j = potions.partition_point(|&x| x <= target as i32);
                *x = (potions.len() - j) as i32;
            } else {
                *x = 0;
            }
        }
        spells
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}((n+m)\log m)$,其中 $n$ 为 $\textit{spells}$ 的长度,$m$ 为 $\textit{potions}$ 的长度。排序 $\mathcal{O}(m\log m)$。二分 $n$ 次,每次 $\mathcal{O}(\log m)$。
  • 空间复杂度:$\mathcal{O}(1)$。忽略排序的栈开销,仅用到若干额外变量。

方法二:计数 + 值域后缀和

方法一得出的结论是,统计满足 $\textit{potions}[j]\ge \left\lfloor\dfrac{\textit{success}-1}{\textit{spell}[i]}\right\rfloor + 1$ 的药水的个数。

比如 $\textit{potions}=[1,2,2,3,5,5,5]$,要计算 $\ge 2$ 的药水的个数,我们可以统计每个数出现了多少次,记在一个 $\textit{cnt}$ 数组中。在这个例子中,$\textit{cnt}=[0,1,2,1,0,3]$,比如 $\textit{cnt}[5]=3$ 表示 $5$ 出现了 $3$ 次。

那么计算 $\textit{cnt}[2] + \textit{cnt}[3] + \textit{cnt}[4] + \textit{cnt}[5]=2+1+0+3=6$,就是 $\ge 2$ 的药水的个数。

但这样太慢了。如何加速?

借鉴 前缀和 的思想,我们可以倒着遍历 $\textit{cnt}$,原地计算 $\textit{cnt}$ 的后缀和,把 $\textit{cnt}[i]$ 更新为 $\ge i$ 的药水的个数。上面的 $\textit{cnt}$ 可以更新为 $[7,7,6,4,3,3]$。比如 $\textit{cnt}[2] = 6$ 表示 $\ge 2$ 的药水的个数。

class Solution:
    def successfulPairs(self, spells: List[int], potions: List[int], success: int) -> List[int]:
        mx = max(potions)
        cnt = [0] * (mx + 1)
        for y in potions:
            cnt[y] += 1  # 统计每种药水的出现次数

        # 计算 cnt 的后缀和
        for i in range(mx - 1, -1, -1):
            cnt[i] += cnt[i + 1]
        # 计算完毕后,cnt[i] 就是 potions 值 >= i 的药水个数

        for i, x in enumerate(spells):
            low = (success - 1) // x + 1
            spells[i] = cnt[low] if low <= mx else 0
        return spells
class Solution {
    public int[] successfulPairs(int[] spells, int[] potions, long success) {
        int mx = 0;
        for (int y : potions) {
            mx = Math.max(mx, y);
        }

        int[] cnt = new int[mx + 1];
        for (int y : potions) {
            cnt[y]++; // 统计每种药水的出现次数
        }

        // 计算 cnt 的后缀和
        for (int i = mx - 1; i >= 0; i--) {
            cnt[i] += cnt[i + 1];
        }
        // 计算完毕后,cnt[i] 就是 potions 值 >= i 的药水个数

        for (int i = 0; i < spells.length; i++) {
            long low = (success - 1) / spells[i] + 1;
            spells[i] = low <= mx ? cnt[(int) low] : 0;
        }
        return spells;
    }
}
class Solution {
public:
    vector<int> successfulPairs(vector<int>& spells, vector<int>& potions, long long success) {
        int mx = ranges::max(potions);
        vector<int> cnt(mx + 1);
        for (int y : potions) {
            cnt[y]++; // 统计每种药水的出现次数
        }

        // 计算 cnt 的后缀和
        for (int i = mx - 1; i >= 0; i--) {
            cnt[i] += cnt[i + 1];
        }
        // 计算完毕后,cnt[i] 就是 potions 值 >= i 的药水个数

        for (int& x : spells) {
            long long low = (success - 1) / x + 1;
            x = low <= mx ? cnt[low] : 0;
        }
        return spells;
    }
};
#define MAX(a, b) ((b) > (a) ? (b) : (a))

int* successfulPairs(int* spells, int spellsSize, int* potions, int potionsSize, long long success, int* returnSize) {
    int mx = 0;
    for (int i = 0; i < potionsSize; i++) {
        mx = MAX(mx, potions[i]);
    }

    int* cnt = calloc(mx + 1, sizeof(int));
    for (int i = 0; i < potionsSize; i++) {
        cnt[potions[i]]++; // 统计每种药水的出现次数
    }

    // 计算 cnt 的后缀和
    for (int i = mx - 1; i >= 0; i--) {
        cnt[i] += cnt[i + 1];
    }
    // 计算完毕后,cnt[i] 就是 potions 值 >= i 的药水个数

    for (int i = 0; i < spellsSize; i++) {
        long long low = (success - 1) / spells[i] + 1;
        spells[i] = low <= mx ? cnt[low] : 0;
    }

    *returnSize = spellsSize;
    free(cnt);
    return spells;
}
func successfulPairs(spells, potions []int, success int64) []int {
mx := slices.Max(potions)
cnt := make([]int, mx+1)
for _, y := range potions {
cnt[y]++ // 统计每种药水的出现次数
}

// 计算 cnt 的后缀和
for i := mx - 1; i >= 0; i-- {
cnt[i] += cnt[i+1]
}
// 计算完毕后,cnt[i] 就是 potions 值 >= i 的药水个数

for i, x := range spells {
low := (int(success)-1)/x + 1
if low <= mx {
spells[i] = cnt[low]
} else {
spells[i] = 0
}
}
return spells
}
var successfulPairs = function(spells, potions, success) {
    const mx = Math.max(...potions);
    const cnt = Array(mx + 1).fill(0);
    for (const y of potions) {
        cnt[y]++; // 统计每种药水的出现次数
    }

    // 计算 cnt 的后缀和
    for (let i = mx - 1; i >= 0; i--) {
        cnt[i] += cnt[i + 1];
    }
    // 计算完毕后,cnt[i] 就是 potions 值 >= i 的药水个数

    for (let i = 0; i < spells.length; i++) {
        const low = Math.ceil(success / spells[i]);
        spells[i] = low <= mx ? cnt[low] : 0;
    }
    return spells;
};
impl Solution {
    pub fn successful_pairs(mut spells: Vec<i32>, potions: Vec<i32>, success: i64) -> Vec<i32> {
        let mx = *potions.iter().max().unwrap() as usize;
        let mut cnt = vec![0; mx + 1];
        for y in potions {
            cnt[y as usize] += 1; // 统计每种药水的出现次数
        }

        // 计算 cnt 的后缀和
        for i in (0..mx).rev() {
            cnt[i] += cnt[i + 1];
        }
        // 计算完毕后,cnt[i] 就是 potions 值 >= i 的药水个数

        let success = success as usize;
        for x in spells.iter_mut() {
            let low = (success - 1) / *x as usize + 1;
            *x = if low <= mx { cnt[low] } else { 0 };
        }
        spells
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n+m+U)$,其中 $n$ 为 $\textit{spells}$ 的长度,$m$ 为 $\textit{potions}$ 的长度,$U=\max(\textit{potions})$。
  • 空间复杂度:$\mathcal{O}(U)$。

思考题

把乘法改成异或要怎么做?

这题是 1803. 统计异或值在范围内的数对有多少,做法见 我的题解

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/最短路/最小生成树/二分图/基环树/欧拉路径)
  7. 动态规划(入门/背包/状态机/划分/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

贪心 + 有序集合 / 并查集(Python/Java/C++/Go/JS/Rust)

作者 endlesscheng
2025年10月2日 13:27

第一个错误的思路

从左到右遍历 $\textit{rains}$:

  • 如果 $\textit{rains}[i]=0$,把 $i$ 加到一个列表中,后面要抽水的时候再使用。
  • 如果 $\textit{rains}[i]>0$ 且 $\textit{rains}[i]$ 是满的,随便取出列表中的一天用来抽水。如果列表是空的,那么必定会发生洪水。

错误原因:比如 $\textit{rains}=[1,0,0,1,1]$,抽水之后连着下了两天的雨,最后一天必定发洪水。注意抽水必须在两个 $1$ 之间,第二个 $0$ 无法用来抽干湖水。但按照上面的算法,我们会错误地认为第二个 $0$ 可以把湖 $1$ 抽干。

第二个错误的思路

从左到右遍历 $\textit{rains}$:

  • 如果 $\textit{rains}[i]=0$,把 $i$ 入栈。
  • 如果 $\textit{rains}[i]>0$ 且 $\textit{rains}[i]$ 是满的,那么用栈顶的那一天抽水。如果栈顶那一天比 $\textit{rains}[i]$ 装满的那一天还要早,那么必定会发生洪水。

错误原因:例如 $\textit{rains}=[1,0,2,0,1,2]$,如果第二个 $0$ 用来抽湖 $1$,那么湖 $2$ 就没法抽干。

正确思路

从左到右遍历 $\textit{rains}$:

  • 如果 $\textit{rains}[i]=0$,把 $i$ 加到一个有序集合中。
  • 如果 $\textit{rains}[i]>0$ 且 $\textit{rains}[i]$ 是满的。设 $\textit{rains}[i]$ 装满的那一天为 $j$,从有序集合中选出大于 $j$ 的最早的抽水日。

解释:越晚的抽水日,灵活性越大,可以用于更晚装满的湖。所以越晚的抽水日越应该留到后面再使用。

写法一:有序集合

###py

class Solution:
    def avoidFlood(self, rains: List[int]) -> List[int]:
        n = len(rains)
        ans = [-1] * n
        full_day = {}  # lake -> 装满日
        dry_day = SortedList()  # 未被使用的抽水日
        for i, lake in enumerate(rains):
            if lake == 0:
                ans[i] = 1  # 先随便选一个湖抽干
                dry_day.add(i)  # 保存抽水日
                continue
            if lake in full_day:
                j = full_day[lake]
                # 必须在 j 之后,i 之前把 lake 抽干
                # 选一个最早的未被使用的抽水日,如果选晚的,可能会导致其他湖没有可用的抽水日
                k = dry_day.bisect_right(j)
                if k == len(dry_day):
                    return []  # 无法阻止洪水
                d = dry_day[k]
                ans[d] = lake
                dry_day.discard(d)  # 移除已使用的抽水日
            full_day[lake] = i  # 插入或更新装满日
        return ans

###java

class Solution {
    public int[] avoidFlood(int[] rains) {
        int n = rains.length;
        int[] ans = new int[n];
        Map<Integer, Integer> fullDay = new HashMap<>(); // lake -> 装满日
        TreeSet<Integer> dryDay = new TreeSet<>(); // 未被使用的抽水日
        for (int i = 0; i < n; i++) {
            int lake = rains[i];
            if (lake == 0) {
                ans[i] = 1; // 先随便选一个湖抽干
                dryDay.add(i); // 保存抽水日
                continue;
            }
            Integer j = fullDay.get(lake);
            if (j != null) {
                // 必须在 j 之后,i 之前把 lake 抽干
                // 选一个最早的未被使用的抽水日,如果选晚的,可能会导致其他湖没有可用的抽水日
                Integer d = dryDay.higher(j);
                if (d == null) {
                    return new int[]{}; // 无法阻止洪水
                }
                ans[d] = lake;
                dryDay.remove(d); // 移除已使用的抽水日
            }
            ans[i] = -1;
            fullDay.put(lake, i); // 插入或更新装满日
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    vector<int> avoidFlood(vector<int>& rains) {
        int n = rains.size();
        vector<int> ans(n, -1);
        unordered_map<int, int> full_day; // lake -> 装满日
        set<int> dry_day; // 未被使用的抽水日
        for (int i = 0; i < n; i++) {
            int lake = rains[i];
            if (lake == 0) {
                ans[i] = 1; // 先随便选一个湖抽干
                dry_day.insert(i); // 保存抽水日
                continue;
            }
            if (full_day.contains(lake)) {
                int j = full_day[lake];
                // 必须在 j 之后,i 之前把 lake 抽干
                // 选一个最早的未被使用的抽水日,如果选晚的,可能会导致其他湖没有可用的抽水日
                auto it = dry_day.upper_bound(j);
                if (it == dry_day.end()) {
                    return {}; // 无法阻止洪水
                }
                ans[*it] = lake;
                dry_day.erase(it); // 移除已使用的抽水日
            }
            full_day[lake] = i; // 插入或更新装满日
        }
        return ans;
    }
};

###go

// import "github.com/emirpasic/gods/v2/trees/redblacktree"
func avoidFlood(rains []int) []int {
ans := make([]int, len(rains))
fullDay := map[int]int{} // lake -> 装满日
dryDay := redblacktree.New[int, struct{}]() // 未被使用的抽水日
for i, lake := range rains {
if lake == 0 {
ans[i] = 1 // 先随便选一个湖抽干
dryDay.Put(i, struct{}{}) // 保存抽水日
continue
}
if j, ok := fullDay[lake]; ok {
// 必须在 j 之后,i 之前把 lake 抽干
// 选一个最早的未被使用的抽水日,如果选晚的,可能会导致其他湖没有可用的抽水日
node, _ := dryDay.Ceiling(j)
if node == nil {
return nil // 无法阻止洪水
}
ans[node.Key] = lake
dryDay.Remove(node.Key) // 移除已使用的抽水日
}
ans[i] = -1
fullDay[lake] = i // 插入或更新装满日
}
return ans
}

###js

const { AvlTree } = require('datastructures-js');

var avoidFlood = function(rains) {
    const n = rains.length;
    const ans = Array(n).fill(-1);
    const fullDay = new Map(); // lake -> 装满日
    const dryDay = new AvlTree((a, b) => a - b); // 未被使用的抽水日
    for (let i = 0; i < n; i++) {
        const lake = rains[i];
        if (lake === 0) {
            ans[i] = 1; // 先随便选一个湖抽干
            dryDay.insert(i); // 保存抽水日
            continue;
        }
        const j = fullDay.get(lake);
        if (j !== undefined) {
            // 必须在 j 之后,i 之前把 lake 抽干
            // 选一个最早的未被使用的抽水日,如果选晚的,可能会导致其他湖没有可用的抽水日
            const node = dryDay.ceil(j);
            if (node === null) {
                return []; // 无法阻止洪水
            }
            ans[node.getValue()] = lake;
            dryDay.removeNode(node); // 移除已使用的抽水日
        }
        fullDay.set(lake, i); // 插入或更新装满日
    }
    return ans;
};

###rust

use std::collections::{BTreeSet, HashMap};

impl Solution {
    pub fn avoid_flood(rains: Vec<i32>) -> Vec<i32> {
        let n = rains.len();
        let mut ans = vec![-1; n];
        let mut full_day = HashMap::new(); // lake -> 装满日
        let mut dry_day = BTreeSet::new(); // 未被使用的抽水日
        for (i, lake) in rains.into_iter().enumerate() {
            if lake == 0 {
                ans[i] = 1; // 先随便选一个湖抽干
                dry_day.insert(i); // 保存抽水日
                continue;
            }
            if let Some(&j) = full_day.get(&lake) {
                // 必须在 j 之后,i 之前把 lake 抽干
                // 选一个最早的未被使用的抽水日,如果选晚的,可能会导致其他湖没有可用的抽水日
                if let Some(&d) = dry_day.range(j..).next() {
                    ans[d] = lake;
                    dry_day.remove(&d); // 移除已使用的抽水日
                } else {
                    return vec![]; // 无法阻止洪水
                }
            }
            full_day.insert(lake, i); // 插入或更新装满日
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log n)$,其中 $n$ 是 $\textit{rains}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

写法二:并查集

删除 $i$ 时,用并查集把 $i$ 指向 $i+1$(或者 $\text{find}(i+1)$)。

例如删除 $2$,那么并查集中 $2\to 3$。

然后删除 $3$,那么并查集中 $2\to 3\to 4$。

查找 $\ge 2$ 的最小的未被删除的天,顺着并查集中的 $2$,可以找到 $4$,即为 $\ge 2$ 的最小的未被删除的天。

对于 $\textit{rains}[i]>0$ 的天,直接把 $i$ 删除。

注:关于并查集的完整模板,见 数据结构题单

###py

class Solution:
    def avoidFlood(self, rains: List[int]) -> List[int]:
        n = len(rains)
        # 非递归并查集
        fa = list(range(n + 1))
        def find(x: int) -> int:
            rt = x
            while fa[rt] != rt:
                rt = fa[rt]
            while fa[x] != rt:
                fa[x], x = rt, fa[x]
            return rt

        ans = [-1] * n
        full_day = {}  # lake -> 装满日
        for i, lake in enumerate(rains):
            if lake == 0:
                ans[i] = 1  # 先随便选一个湖抽干
                continue
            if lake in full_day:
                j = full_day[lake]
                # 必须在 j 之后,i 之前把 lake 抽干
                # 选一个最早的未被使用的抽水日,如果选晚的,可能会导致其他湖没有可用的抽水日
                dry_day = find(j + 1)
                if dry_day >= i:
                    return []  # 无法阻止洪水
                ans[dry_day] = lake
                fa[dry_day] = find(dry_day + 1)  # 删除 dry_day
            fa[i] = i + 1  # 删除 i
            full_day[lake] = i  # 插入或更新装满日
        return ans

###java

class Solution {
    public int[] avoidFlood(int[] rains) {
        int n = rains.length;
        int[] fa = new int[n + 1];
        for (int i = 0; i <= n; i++) {
            fa[i] = i;
        }

        int[] ans = new int[n];
        Map<Integer, Integer> fullDay = new HashMap<>(); // lake -> 装满日
        for (int i = 0; i < n; i++) {
            int lake = rains[i];
            if (lake == 0) {
                ans[i] = 1; // 先随便选一个湖抽干
                continue;
            }
            Integer j = fullDay.get(lake);
            if (j != null) {
                // 必须在 j 之后,i 之前把 lake 抽干
                // 选一个最早的未被使用的抽水日,如果选晚的,可能会导致其他湖没有可用的抽水日
                int dryDay = find(j + 1, fa);
                if (dryDay >= i) {
                    return new int[]{}; // 无法阻止洪水
                }
                ans[dryDay] = lake;
                fa[dryDay] = find(dryDay + 1, fa); // 删除 dryDay
            }
            ans[i] = -1;
            fa[i] = i + 1; // 删除 i
            fullDay.put(lake, i); // 插入或更新装满日
        }
        return ans;
    }

    private int find(int x, int[] fa) {
        if (fa[x] != x) {
            fa[x] = find(fa[x], fa);
        }
        return fa[x];
    }
}

###cpp

class Solution {
    vector<int> fa;

    int find(int x) {
        if (fa[x] != x) {
            fa[x] = find(fa[x]);
        }
        return fa[x];
    }

public:
    vector<int> avoidFlood(vector<int>& rains) {
        int n = rains.size();
        fa.resize(n + 1);
        ranges::iota(fa, 0); // 并查集初始化

        vector<int> ans(n, -1);
        unordered_map<int, int> full_day; // lake -> 装满日
        for (int i = 0; i < n; i++) {
            int lake = rains[i];
            if (lake == 0) {
                ans[i] = 1; // 先随便选一个湖抽干
                continue;
            }
            if (full_day.count(lake)) {
                int j = full_day[lake];
                // 必须在 j 之后,i 之前把 lake 抽干
                // 选一个最早的未被使用的抽水日,如果选晚的,可能会导致其他湖没有可用的抽水日
                int dry_day = find(j + 1);
                if (dry_day >= i) {
                    return {}; // 无法阻止洪水
                }
                ans[dry_day] = lake;
                fa[dry_day] = find(dry_day + 1); // 删除 dry_day
            }
            fa[i] = i + 1; // 删除 i
            full_day[lake] = i; // 插入或更新装满日
        }
        return ans;
    }
};

###go

func avoidFlood(rains []int) []int {
n := len(rains)
// 非递归并查集
fa := make([]int, n+1)
for i := range fa {
fa[i] = i
}
find := func(x int) int {
rt := x
for fa[rt] != rt {
rt = fa[rt]
}
for fa[x] != rt {
fa[x], x = rt, fa[x]
}
return rt
}

ans := make([]int, n)
fullDay := map[int]int{} // lake -> 装满日
for i, lake := range rains {
if lake == 0 {
ans[i] = 1 // 先随便选一个湖抽干
continue
}
if j, ok := fullDay[lake]; ok {
// 必须在 j 之后,i 之前把 lake 抽干
// 选一个最早的未被使用的抽水日,如果选晚的,可能会导致其他湖没有可用的抽水日
dryDay := find(j + 1)
if dryDay >= i {
return nil // 无法阻止洪水
}
ans[dryDay] = lake
fa[dryDay] = find(dryDay + 1) // 删除 dryDay
}
ans[i] = -1
fa[i] = i + 1 // 删除 i
fullDay[lake] = i // 插入或更新装满日
}
return ans
}

###js

var avoidFlood = function(rains) {
    const n = rains.length;
    const fa = Array(n + 1).fill(0).map((_, i) => i);

    function find(x) {
        if (fa[x] !== x) {
            fa[x] = find(fa[x]);
        }
        return fa[x];
    }

    const ans = Array(n).fill(-1);
    const fullDay = new Map(); // lake -> 装满日
    for (let i = 0; i < n; i++) {
        const lake = rains[i];
        if (lake === 0) {
            ans[i] = 1; // 先随便选一个湖抽干
            continue;
        }
        const j = fullDay.get(lake);
        if (j !== undefined) {
            // 必须在 j 之后,i 之前把 lake 抽干
            // 选一个最早的未被使用的抽水日,如果选晚的,可能会导致其他湖没有可用的抽水日
            const dryDay = find(j + 1);
            if (dryDay >= i) {
                return []; // 无法阻止洪水
            }
            ans[dryDay] = lake;
            fa[dryDay] = find(dryDay + 1); // 删除 dryDay
        }
        fa[i] = i + 1; // 删除 i
        fullDay.set(lake, i); // 插入或更新装满日
    }
    return ans;
};

###rust

use std::collections::HashMap;

impl Solution {
    pub fn avoid_flood(rains: Vec<i32>) -> Vec<i32> {
        let n = rains.len();
        let mut fa = (0..=n).collect::<Vec<_>>();

        fn find(x: usize, fa: &mut [usize]) -> usize {
            if fa[x] != x {
                fa[x] = find(fa[x], fa);
            }
            fa[x]
        }

        let mut ans = vec![-1; n];
        let mut full_day = HashMap::new(); // lake -> 装满日
        for (i, lake) in rains.into_iter().enumerate() {
            if lake == 0 {
                ans[i] = 1; // 先随便选一个湖抽干
                continue;
            }
            if let Some(&j) = full_day.get(&lake) {
                // 必须在 j 之后,i 之前把 lake 抽干
                let dry_day = find(j + 1, &mut fa);
                if dry_day >= i {
                    return vec![]; // 无法阻止洪水
                }
                ans[dry_day] = lake;
                fa[dry_day] = find(dry_day + 1, &mut fa); // 删除 dry_day
            }
            fa[i] = i + 1; // 删除 i
            full_day.insert(lake, i); // 插入或更新装满日
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{rains}$ 的长度。本题并查集的操作是均摊 $\mathcal{O}(1)$。
  • 空间复杂度:$\mathcal{O}(n)$。

专题训练

见下面数据结构题单的「§7.4 数组上的并查集」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

逆向思维(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2025年10月5日 07:32

核心想法:本题起点(答案)不明确,但终点(边界)明确,所以从边界出发能方便地找到答案。

lc417.jpg{:width=400px}

什么是边界?$\textit{heights}$ 中的 $i=0$ 或者 $i=m-1$ 或者 $j=0$ 或者 $j=n-1$ 的格子。

什么是答案?既可流向太平洋也可流向大西洋的格子。

  1. 对于可流向太平洋的格子,我们从上边界($i=0$)和左边界($j=0$)倒着往高处走,所有能访问到的格子,就是可以流向太平洋的格子。
  2. 对于可流向大西洋的格子,我们从下边界($i=m-1$)和右边界($j=n-1$)倒着往高处走,所有能访问到的格子,就是可以流向大西洋的格子。
  3. 计算这两类格子的交集,即为既可流向太平洋也可流向大西洋的格子。

关于网格图 DFS 的原理,见 200. 岛屿数量我的题解

class Solution:
    def pacificAtlantic(self, heights: List[List[int]]) -> List[List[int]]:
        m, n = len(heights), len(heights[0])

        def search(cells: List[Tuple[int, int]]) -> Set[Tuple[int, int]]:
            def dfs(i: int, j: int) -> None:
                if (i, j) in vis:  # 避免重复访问,避免反复横跳无限递归
                    return
                vis.add((i, j))  # 标记 (i,j) 已访问
                for x, y in (i, j - 1), (i, j + 1), (i - 1, j), (i + 1, j):  # 枚举相邻格子
                    if 0 <= x < m and 0 <= y < n and heights[x][y] >= heights[i][j]:  # 往高处走
                        dfs(x, y)

            vis = set()
            for i, j in cells:
                dfs(i, j)
            return vis

        pacific = [(0, j) for j in range(n)] + [(i, 0) for i in range(1, m)]
        atlantic = [(m - 1, j) for j in range(n)] + [(i, n - 1) for i in range(m - 1)]
        return list(search(pacific) & search(atlantic))  # 交集即为答案
class Solution {
    // 左右上下
    private static final int[][] DIRS = {{0, -1}, {0, 1}, {-1, 0}, {1, 0}};

    public List<List<Integer>> pacificAtlantic(int[][] heights) {
        int m = heights.length, n = heights[0].length;

        // 从太平洋边界出发
        boolean[][] pacificVis = new boolean[m][n];
        for (int j = 0; j < n; j++) {
            dfs(0, j, pacificVis, heights); // 上边界
        }
        for (int i = 1; i < m; i++) {
            dfs(i, 0, pacificVis, heights); // 左边界
        }

        // 从大西洋边界出发
        boolean[][] atlanticVis = new boolean[m][n];
        for (int j = 0; j < n; j++) {
            dfs(m - 1, j, atlanticVis, heights); // 下边界
        }
        for (int i = 0; i < m - 1; i++) {
            dfs(i, n - 1, atlanticVis, heights); // 右边界
        }

        // 交集即为答案
        List<List<Integer>> ans = new ArrayList<>();
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                if (pacificVis[i][j] && atlanticVis[i][j]) {
                    ans.add(List.of(i, j));
                }
            }
        }
        return ans;
    }

    private void dfs(int i, int j, boolean[][] vis, int[][] heights) {
        if (vis[i][j]) { // 避免重复访问,避免反复横跳无限递归
            return;
        }
        vis[i][j] = true; // 标记 (i,j) 已访问
        for (int[] d : DIRS) { // 枚举相邻格子
            int x = i + d[0], y = j + d[1];
            if (0 <= x && x < heights.length && 0 <= y && y < heights[x].length && heights[x][y] >= heights[i][j]) { // 往高处走
                dfs(x, y, vis, heights);
            }
        }
    }
}
class Solution {
    // 左右上下
    static constexpr int DIRS[4][2] = {{0, -1}, {0, 1}, {-1, 0}, {1, 0}};

public:
    vector<vector<int>> pacificAtlantic(vector<vector<int>>& heights) {
        int m = heights.size(), n = heights[0].size();

        // lambda 递归
        auto dfs = [&](this auto&& dfs, int i, int j, vector<vector<int8_t>>& vis) -> void {
            if (vis[i][j]) { // 避免重复访问,避免反复横跳无限递归
                return;
            }
            vis[i][j] = true; // 标记 (i,j) 已访问
            for (auto& [dx, dy] : DIRS) { // 枚举相邻格子
                int x = i + dx, y = j + dy;
                if (0 <= x && x < m && 0 <= y && y < n && heights[x][y] >= heights[i][j]) { // 往高处走
                    dfs(x, y, vis);
                }
            }
        };

        // 从太平洋边界出发
        vector pacific_vis(m, vector<int8_t>(n));
        for (int j = 0; j < n; j++) {
            dfs(0, j, pacific_vis); // 上边界
        }
        for (int i = 1; i < m; i++) {
            dfs(i, 0, pacific_vis); // 左边界
        }

        // 从大西洋边界出发
        vector atlantic_vis(m, vector<int8_t>(n));
        for (int j = 0; j < n; j++) {
            dfs(m - 1, j, atlantic_vis); // 下边界
        }
        for (int i = 0; i < m - 1; i++) {
            dfs(i, n - 1, atlantic_vis); // 右边界
        }

        // 交集即为答案
        vector<vector<int>> ans;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                if (pacific_vis[i][j] && atlantic_vis[i][j]) {
                    ans.push_back({i, j});
                }
            }
        }
        return ans;
    }
};
// 左右上下
static const int DIRS[4][2] = {{0, -1}, {0, 1}, {-1, 0}, {1, 0}};

void dfs(int i, int j, bool** vis, int** heights, int m, int n) {
    if (vis[i][j]) { // 避免重复访问,避免反复横跳无限递归
        return;
    }
    vis[i][j] = true; // 标记 (i,j) 已访问
    for (int k = 0; k < 4; k++) { // 枚举相邻格子
        int x = i + DIRS[k][0], y = j + DIRS[k][1];
        if (0 <= x && x < m && 0 <= y && y < n && heights[x][y] >= heights[i][j]) { // 往高处走
            dfs(x, y, vis, heights, m, n);
        }
    }
}

int** pacificAtlantic(int** heights, int heightsSize, int* heightsColSize, int* returnSize, int** returnColumnSizes) {
    int m = heightsSize, n = heightsColSize[0];

    // 从太平洋边界出发
    bool** pacific_vis = malloc(m * sizeof(bool*));
    for (int i = 0; i < m; i++) {
        pacific_vis[i] = calloc(n, sizeof(bool));
    }
    for (int j = 0; j < n; j++) {
        dfs(0, j, pacific_vis, heights, m, n); // 上边界
    }
    for (int i = 1; i < m; i++) {
        dfs(i, 0, pacific_vis, heights, m, n); // 左边界
    }

    // 从大西洋边界出发
    bool** atlantic_vis = malloc(m * sizeof(bool*));
    for (int i = 0; i < m; i++) {
        atlantic_vis[i] = calloc(n, sizeof(bool));
    }
    for (int j = 0; j < n; j++) {
        dfs(m - 1, j, atlantic_vis, heights, m, n); // 下边界
    }
    for (int i = 0; i < m - 1; i++) {
        dfs(i, n - 1, atlantic_vis, heights, m, n); // 右边界
    }

    // 交集即为答案
    int** ans = malloc(m * n * sizeof(int*));
    *returnColumnSizes = malloc(m * n * sizeof(int));
    *returnSize = 0;
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            if (pacific_vis[i][j] && atlantic_vis[i][j]) {
                (*returnColumnSizes)[*returnSize] = 2;
                ans[*returnSize] = malloc(2 * sizeof(int));
                ans[*returnSize][0] = i;
                ans[*returnSize][1] = j;
                (*returnSize)++;
            }
        }
    }
    
    free(pacific_vis);
    free(atlantic_vis);
    return ans;
}
var dirs = [][2]int{{0, -1}, {0, 1}, {-1, 0}, {1, 0}} // 左右上下

func pacificAtlantic(heights [][]int) (ans [][]int) {
    m, n := len(heights), len(heights[0])

    var dfs func(int, int, [][]bool)
    dfs = func(i, j int, vis [][]bool) {
        if vis[i][j] { // 避免重复访问,避免反复横跳无限递归
            return
        }
        vis[i][j] = true // 标记 (i,j) 已访问
        for _, d := range dirs { // 枚举相邻格子
            x, y := i+d[0], j+d[1]
            if 0 <= x && x < m && 0 <= y && y < n && heights[x][y] >= heights[i][j] { // 往高处走
                dfs(x, y, vis)
            }
        }
    }

    // 从太平洋边界出发
    pacificVis := make([][]bool, m)
    for i := range pacificVis {
        pacificVis[i] = make([]bool, n)
    }
    for j := range n {
        dfs(0, j, pacificVis) // 上边界
    }
    for i := 1; i < m; i++ {
        dfs(i, 0, pacificVis) // 左边界
    }

    // 从大西洋边界出发
    atlanticVis := make([][]bool, m)
    for i := range atlanticVis {
        atlanticVis[i] = make([]bool, n)
    }
    for j := range n {
        dfs(m-1, j, atlanticVis) // 下边界
    }
    for i := range m - 1 {
        dfs(i, n-1, atlanticVis) // 右边界
    }

    // 交集即为答案
    for i, row := range pacificVis {
        for j, ok := range row {
            if ok && atlanticVis[i][j] {
                ans = append(ans, []int{i, j})
            }
        }
    }
    return
}
const DIRS = [[0, -1], [0, 1], [-1, 0], [1, 0]]; // 左右上下

var pacificAtlantic = function(heights) {
    const m = heights.length, n = heights[0].length;

    function dfs(i, j, vis) {
        if (vis[i][j]) { // 避免重复访问,避免反复横跳无限递归
            return;
        }
        vis[i][j] = true; // 标记 (i,j) 已访问
        for (const [dx, dy] of DIRS) { // 枚举相邻格子
            const x = i + dx, y = j + dy;
            if (x >= 0 && x < m && y >= 0 && y < n && heights[x][y] >= heights[i][j]) { // 往高处走
                dfs(x, y, vis);
            }
        }
    }

    // 从太平洋边界出发
    const pacificVis = Array.from({ length: m }, () => Array(n).fill(false));
    for (let j = 0; j < n; j++) {
        dfs(0, j, pacificVis); // 上边界
    }
    for (let i = 1; i < m; i++) {
        dfs(i, 0, pacificVis); // 左边界
    }

    // 从大西洋边界出发
    const atlanticVis = Array.from({ length: m }, () => Array(n).fill(false));
    for (let j = 0; j < n; j++) {
        dfs(m - 1, j, atlanticVis); // 下边界
    }
    for (let i = 0; i < m - 1; i++) {
        dfs(i, n - 1, atlanticVis); // 右边界
    }

    // 交集即为答案
    const ans = [];
    for (let i = 0; i < m; i++) {
        for (let j = 0; j < n; j++) {
            if (pacificVis[i][j] && atlanticVis[i][j]) {
                ans.push([i, j]);
            }
        }
    }
    return ans;
};
const DIRS: [(i8, i8); 4] = [(0, -1), (0, 1), (-1, 0), (1, 0)]; // 左右上下

impl Solution {
    pub fn pacific_atlantic(heights: Vec<Vec<i32>>) -> Vec<Vec<i32>> {
        fn dfs(i: usize, j: usize, vis: &mut Vec<Vec<bool>>, heights: &Vec<Vec<i32>>) {
            if vis[i][j] { // 避免重复访问,避免反复横跳无限递归
                return;
            }
            vis[i][j] = true; // 标记 (i,j) 已访问
            for &(dx, dy) in &DIRS { // 枚举相邻格子
                let x = i + dx as usize;
                let y = j + dy as usize;
                if x < heights.len() && y < heights[i].len() && heights[x][y] >= heights[i][j] { // 往高处走
                    dfs(x, y, vis, heights);
                }
            }
        }

        let m = heights.len();
        let n = heights[0].len();

        // 从太平洋边界出发
        let mut pacific_vis = vec![vec![false; n]; m];
        for j in 0..n {
            dfs(0, j, &mut pacific_vis, &heights); // 上边界
        }
        for i in 1..m {
            dfs(i, 0, &mut pacific_vis, &heights); // 左边界
        }

        // 从大西洋边界出发
        let mut atlantic_vis = vec![vec![false; n]; m];
        for j in 0..n {
            dfs(m - 1, j, &mut atlantic_vis, &heights); // 下边界
        }
        for i in 0..m - 1 {
            dfs(i, n - 1, &mut atlantic_vis, &heights); // 右边界
        }

        // 交集即为答案
        let mut ans = vec![];
        for i in 0..m {
            for j in 0..n {
                if pacific_vis[i][j] && atlantic_vis[i][j] {
                    ans.push(vec![i as i32, j as i32]);
                }
            }
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn)$,其中 $m$ 和 $n$ 分别为 $\textit{heights}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(mn)$。

专题训练

见下面网格图题单的「一、网格图 DFS」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

短板效应(Python/Java/C++/Go/JS/Rust)

作者 endlesscheng
2024年11月23日 18:46

前置题目42. 接雨水。建议先掌握这题的双指针写法,再来做本题。

Buckets-effect-1024x937.webp{:width=400}

哪个格子的接水量,在一开始就能确定?

  • 最外面一圈的格子是无法接水的。
  • 假设 $(0,1)$ 的高度是最外面一圈的格子中最小的,且高度等于 $5$,那么和它相邻的 $(1,1)$,我们能知道:
    • $(1,1)$ 的水位不能超过 $5$,否则水会从 $(0,1)$ 流出去。
    • $(1,1)$ 的水位一定可以等于 $5$,这是因为 $(0,1)$ 的高度是最外面一圈的格子中最小的,$(1,1)$ 的水不可能从其他地方流出去。

我们从最外面一圈的格子开始。想象成一个木桶,最外面一圈格子的高度视作木板的高度。

接着上面的讨论:

  • 如果 $(1,1)$ 的高度 $\ge 5$,那么 $(0,1)$ 这块木板就没用了,我们去掉 $(0,1)$ 这块木板,改用 $(1,1)$ 这块木板。
  • 如果 $(1,1)$ 的高度 $< 5$,假设我们接的不是水,是水泥。那么把 $(1,1)$ 的高度填充为 $5$,仍然可以去掉 $(0,1)$ 这块木板,改用 $(1,1)$ 这块(填充水泥后)高为 $5$ 的木板水泥板。

继续,从当前木板中,找到一根最短的木板。假设 $(1,1)$ 是当前所有木板中最短的,那么其邻居 $(1,2)$ 和 $(2,1)$ 的水位就是 $(1,1)$ 的高度,因为超过 $(1,1)$ 高度的水会流出去。然后,去掉 $(1,1)$ 这块木板,改用 $(1,2)$ 和 $(2,1)$ 这两块木板。依此类推。

由于每次都要找最短的木板,所以用一个最小堆维护木板的高度。按照上述做法,不断循环,直到堆为空。

为方便实现,代码在初始化堆的时候,直接遍历了整个矩阵。只遍历最外面一圈的写法可以参考 Python3 的写法二。

答疑

:这种思路和 42. 接雨水 的双指针做法的联系是什么?

:42 那题只需要维护左右两个指针,本题相当于维护了“一圈”指针。42 那题每次取左右最小的指针,然后移动到相邻位置上;本题也是取最小的指针(出堆),往周围的邻居移动(入堆)。

class Solution:
    def trapRainWater(self, heightMap: List[List[int]]) -> int:
        m, n = len(heightMap), len(heightMap[0])
        h = []
        for i, row in enumerate(heightMap):
            for j, height in enumerate(row):
                if i == 0 or i == m - 1 or j == 0 or j == n - 1:
                    h.append((height, i, j))
                    row[j] = -1  # 标记 (i,j) 访问过
        heapify(h)

        ans = 0
        while h:
            min_height, i, j = heappop(h)  # min_height 是木桶的短板
            for x, y in (i, j - 1), (i, j + 1), (i - 1, j), (i + 1, j):
                if 0 <= x < m and 0 <= y < n and heightMap[x][y] >= 0:  # (x,y) 没有访问过
                    # 如果 (x,y) 的高度小于 min_height,那么接水量为 min_height - heightMap[x][y]
                    ans += max(min_height - heightMap[x][y], 0)
                    # 给木桶新增一块高为 max(min_height, heightMap[x][y]) 的木板
                    heappush(h, (max(min_height, heightMap[x][y]), x, y))
                    heightMap[x][y] = -1  # 标记 (x,y) 访问过
        return ans
class Solution:
    def trapRainWater(self, heightMap: List[List[int]]) -> int:
        m, n = len(heightMap), len(heightMap[0])
        h = []
        for j in range(n):
            h.append((heightMap[0][j], 0, j))  # 上边
            h.append((heightMap[-1][j], m - 1, j))  # 下边
            heightMap[0][j] = heightMap[-1][j] = -1
        for i in range(1, m - 1):
            h.append((heightMap[i][0], i, 0))  # 左边
            h.append((heightMap[i][-1], i, n - 1))  # 右边
            heightMap[i][0] = heightMap[i][-1] = -1
        heapify(h)

        ans = 0
        while h:
            min_height, i, j = heappop(h)  # min_height 是木桶的短板
            for x, y in (i, j - 1), (i, j + 1), (i - 1, j), (i + 1, j):
                if 0 <= x < m and 0 <= y < n and heightMap[x][y] >= 0:  # (x,y) 没有访问过
                    # 如果 (x,y) 的高度小于 min_height,那么接水量为 min_height - heightMap[x][y]
                    ans += max(min_height - heightMap[x][y], 0)
                    # 给木桶新增一块高为 max(min_height, heightMap[x][y]) 的木板
                    heappush(h, (max(min_height, heightMap[x][y]), x, y))
                    heightMap[x][y] = -1  # 标记 (x,y) 访问过
        return ans
class Solution {
    private static final int[][] DIRS = {{0, -1}, {0, 1}, {-1, 0}, {1, 0}};

    public int trapRainWater(int[][] heightMap) {
        int m = heightMap.length, n = heightMap[0].length;
        PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> (a[0] - b[0]));
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                if (i == 0 || i == m - 1 || j == 0 || j == n - 1) {
                    pq.add(new int[]{heightMap[i][j], i, j});
                    heightMap[i][j] = -1; // 标记 (i,j) 访问过
                }
            }
        }

        int ans = 0;
        while (!pq.isEmpty()) {
            int[] t = pq.poll(); // 去掉短板
            int minHeight = t[0], i = t[1], j = t[2]; // minHeight 是木桶的短板
            for (int[] d : DIRS) {
                int x = i + d[0], y = j + d[1]; // (i,j) 的邻居
                if (0 <= x && x < m && 0 <= y && y < n && heightMap[x][y] >= 0) { // (x,y) 没有访问过
                    // 如果 (x,y) 的高度小于 minHeight,那么接水量为 minHeight - heightMap[x][y]
                    ans += Math.max(minHeight - heightMap[x][y], 0);
                    // 给木桶新增一块高为 max(minHeight, heightMap[x][y]) 的木板
                    pq.add(new int[]{Math.max(minHeight, heightMap[x][y]), x, y});
                    heightMap[x][y] = -1; // 标记 (x,y) 访问过
                }
            }
        }
        return ans;
    }
}
class Solution {
    static constexpr int DIRS[4][2] = {{0, -1}, {0, 1}, {-1, 0}, {1, 0}};
public:
    int trapRainWater(vector<vector<int>>& heightMap) {
        int m = heightMap.size(), n = heightMap[0].size();
        priority_queue<tuple<int, int, int>, vector<tuple<int, int, int>>, greater<>> pq;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                if (i == 0 || i == m - 1 || j == 0 || j == n - 1) {
                    pq.emplace(heightMap[i][j], i, j);
                    heightMap[i][j] = -1; // 标记 (i,j) 访问过
                }
            }
        }

        int ans = 0;
        while (!pq.empty()) {
            auto [min_height, i, j] = pq.top(); // min_height 是木桶的短板
            pq.pop(); // 去掉短板
            for (auto& [dx, dy] : DIRS) {
                int x = i + dx, y = j + dy; // (i,j) 的邻居
                if (0 <= x && x < m && 0 <= y && y < n && heightMap[x][y] >= 0) { // (x,y) 没有访问过
                    // 如果 (x,y) 的高度小于 min_height,那么接水量为 min_height - heightMap[x][y]
                    ans += max(min_height - heightMap[x][y], 0);
                    // 给木桶新增一块高为 max(min_height, heightMap[x][y]) 的木板
                    pq.emplace(max(min_height, heightMap[x][y]), x, y);
                    heightMap[x][y] = -1; // 标记 (x,y) 访问过
                }
            }
        }
        return ans;
    }
};
var dir4 = []struct{ x, y int }{{0, -1}, {0, 1}, {-1, 0}, {1, 0}}

func trapRainWater(heightMap [][]int) (ans int) {
    m, n := len(heightMap), len(heightMap[0])
    h := hp{}
    for i, row := range heightMap {
        for j, height := range row {
            if i == 0 || i == m-1 || j == 0 || j == n-1 {
                h = append(h, cell{height, i, j})
                row[j] = -1 // 标记 (i,j) 访问过
            }
        }
    }
    heap.Init(&h)

    for len(h) > 0 {
        c := heap.Pop(&h).(cell) // 去掉短板
        minHeight, i, j := c.height, c.x, c.y // minHeight 是木桶的短板
        for _, d := range dir4 {
            x, y := i+d.x, j+d.y // (i,j) 的邻居
            if 0 <= x && x < m && 0 <= y && y < n && heightMap[x][y] >= 0 { // (x,y) 没有访问过
                // 如果 (x,y) 的高度小于 minHeight,那么接水量为 minHeight - heightMap[x][y]
                ans += max(minHeight-heightMap[x][y], 0)
                // 给木桶新增一块高为 max(minHeight, heightMap[x][y]) 的木板
                heap.Push(&h, cell{max(minHeight, heightMap[x][y]), x, y})
                heightMap[x][y] = -1 // 标记 (x,y) 访问过
            }
        }
    }
    return
}

type cell struct{ height, x, y int }
type hp []cell
func (h hp) Len() int           { return len(h) }
func (h hp) Less(i, j int) bool { return h[i].height < h[j].height }
func (h hp) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *hp) Push(v any)        { *h = append(*h, v.(cell)) }
func (h *hp) Pop() any          { a := *h; v := a[len(a)-1]; *h = a[:len(a)-1]; return v }
var trapRainWater = function(heightMap) {
    const m = heightMap.length, n = heightMap[0].length;
    const pq = new MinPriorityQueue(e => e[0]);
    for (let i = 0; i < m; i++) {
        for (let j = 0; j < n; j++) {
            if (i === 0 || i === m - 1 || j === 0 || j === n - 1) {
                pq.enqueue([heightMap[i][j], i, j]);
                heightMap[i][j] = -1; // 标记 (i,j) 访问过
            }
        }
    }

    let ans = 0;
    while (!pq.isEmpty()) {
        const [minHeight, i, j] = pq.dequeue(); // 去掉短板
        for (const [x, y] of [[i, j - 1], [i, j + 1], [i - 1, j], [i + 1, j]]) {
            if (0 <= x && x < m && 0 <= y && y < n && heightMap[x][y] >= 0) { // (x,y) 没有访问过
                // 如果 (x,y) 的高度小于 minHeight,那么接水量为 minHeight - heightMap[x][y]
                ans += Math.max(minHeight - heightMap[x][y], 0);
                // 给木桶新增一块高为 max(minHeight, heightMap[x][y]) 的木板
                pq.enqueue([Math.max(minHeight, heightMap[x][y]), x, y]);
                heightMap[x][y] = -1; // 标记 (x,y) 访问过
            }
        }
    }
    return ans;
};
use std::collections::BinaryHeap;

impl Solution {
    pub fn trap_rain_water(mut height_map: Vec<Vec<i32>>) -> i32 {
        let m = height_map.len();
        let n = height_map[0].len();
        let mut h = BinaryHeap::new();
        for (i, row) in height_map.iter_mut().enumerate() {
            for (j, height) in row.iter_mut().enumerate() {
                if i == 0 || i == m - 1 || j == 0 || j == n - 1 {
                    h.push((-*height, i, j)); // 取相反数变成最小堆
                    *height = -1; // 标记 (i,j) 访问过
                }
            }
        }

        let mut ans = 0;
        while let Some((min_height, i, j)) = h.pop() { // 去掉短板
            let min_height = -min_height; // min_height 是木桶的短板
            for (x, y) in [(i, j - 1), (i, j + 1), (i - 1, j), (i + 1, j)] {
                if x < m && y < n && height_map[x][y] >= 0 { // (x,y) 没有访问过
                    // 如果 (x,y) 的高度小于 min_height,那么接水量为 min_height - heightMap[x][y]
                    ans += 0.max(min_height - height_map[x][y]);
                    // 给木桶新增一块高为 max(min_height, heightMap[x][y]) 的木板
                    h.push((-min_height.max(height_map[x][y]), x, y));
                    height_map[x][y] = -1; // 标记 (x,y) 访问过
                }
            }
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn\log (mn))$,其中 $m$ 和 $n$ 分别为 $\textit{heightMap}$ 的行数和列数。每次出堆入堆需要 $\mathcal{O}(\log (mn))$ 的时间。
  • 空间复杂度:$\mathcal{O}(mn)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

两种方法:模拟 / O(1) 数学公式(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2024年3月31日 12:10

方法一:模拟

本题和 1518. 换水问题 几乎一样,唯一区别是每次循环要把 $\textit{numExchange}$ 加一。

###py

class Solution:
    def maxBottlesDrunk(self, numBottles: int, numExchange: int) -> int:
        ans = 0
        while numBottles >= numExchange:
            ans += numExchange  # 吨吨吨~
            numBottles -= numExchange - 1
            numExchange += 1
        return ans + numBottles

###java

class Solution {
    public int maxBottlesDrunk(int numBottles, int numExchange) {
        int ans = 0;
        while (numBottles >= numExchange) {
            ans += numExchange; // 吨吨吨~
            numBottles -= numExchange - 1;
            numExchange++;
        }
        return ans + numBottles;
    }
}

###cpp

class Solution {
public:
    int maxBottlesDrunk(int numBottles, int numExchange) {
        int ans = 0;
        while (numBottles >= numExchange) {
            ans += numExchange; // 吨吨吨~
            numBottles -= numExchange - 1;
            numExchange++;
        }
        return ans + numBottles;
    }
};

###c

int maxBottlesDrunk(int numBottles, int numExchange) {
    int ans = 0;
    while (numBottles >= numExchange) {
        ans += numExchange; // 吨吨吨~
        numBottles -= numExchange - 1;
        numExchange++;
    }
    return ans + numBottles;
}

###go

func maxBottlesDrunk(numBottles, numExchange int) (ans int) {
for numBottles >= numExchange {
ans += numExchange // 吨吨吨~
numBottles -= numExchange - 1
numExchange++;
}
return ans + numBottles
}

###js

var maxBottlesDrunk = function(numBottles, numExchange) {
    let ans = 0;
    while (numBottles >= numExchange) {
        ans += numExchange; // 吨吨吨~
        numBottles -= numExchange - 1;
        numExchange++;
    }
    return ans + numBottles;
};

###rust

impl Solution {
    pub fn max_bottles_drunk(mut num_bottles: i32, mut num_exchange: i32) -> i32 {
        let mut ans = 0;
        while num_bottles >= num_exchange {
            ans += num_exchange; // 吨吨吨~
            num_bottles -= num_exchange - 1;
            num_exchange += 1;
        }
        ans + num_bottles
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(\sqrt \textit{numBottles})$。$\textit{numExchange}=1$ 为最坏情况,根据方法二,需要循环 $\mathcal{O}(\sqrt \textit{numBottles})$ 次。
  • 空间复杂度:$\mathcal{O}(1)$。

方法二:数学公式

设 $n = \textit{numBottles}$,$e = \textit{numExchange}$。

设 $k$ 为循环次数(额外得到的水瓶数),那么答案就是 $n+k$。

循环 $k$ 次后,剩余瓶子数小于 $e+k$,即

$$
n - ((e-1) + e + (e+1) + \cdots + (e+k-2)) < e + k
$$

利用等差数列求和公式,上式化简为

$$
k^2 + (2e-1) k - 2(n-e) > 0
$$

解得

$$
k > \dfrac{-(2e-1) + \sqrt{(2e-1)^2+8(n-e)}}{2}
$$

设 $b = 2e-1$。由于 $k$ 是整数,所以 $k$ 为

$$
\left\lfloor\dfrac{\sqrt{b^2+8(n-e)} - b}{2}\right\rfloor + 1 = \left\lfloor\dfrac{\sqrt{b^2+8(n-e)} - b + 2}{2}\right\rfloor
$$

为了减少浮点运算次数,减少舍入误差,根据 下取整恒等式及其应用,上式等于

$$
\left\lfloor\dfrac{\lfloor\sqrt{b^2+8(n-e)}\rfloor - b + 2}{2}\right\rfloor
$$

注:部分编程语言是向零取整的,必须把 $+1$ 放到分数中,否则计算出负数,向零取整就是向上取整了。

###py

class Solution:
    def maxBottlesDrunk(self, n: int, e: int) -> int:
        b = e * 2 - 1
        k = (isqrt(b * b + (n - e) * 8) - b + 2) // 2
        return n + k

###java

class Solution {
    public int maxBottlesDrunk(int n, int e) {
        int b = e * 2 - 1;
        int k = ((int) Math.sqrt(b * b + (n - e) * 8) - b + 2) / 2;
        return n + k;
    }
}

###cpp

class Solution {
public:
    int maxBottlesDrunk(int n, int e) {
        int b = e * 2 - 1;
        int k = ((int) sqrt(b * b + (n - e) * 8) - b + 2) / 2;
        return n + k;
    }
};

###c

int maxBottlesDrunk(int n, int e) {
    int b = e * 2 - 1;
    int k = ((int) sqrt(b * b + (n - e) * 8) - b + 2) / 2;
    return n + k;
}

###go

func maxBottlesDrunk(n, e int) int {
b := e*2 - 1
k := (int(math.Sqrt(float64(b*b+(n-e)*8))) - b + 2) / 2
return n + k
}

###js

var maxBottlesDrunk = function(n, e) {
    const b = e * 2 - 1;
    const k = Math.floor((Math.sqrt(b * b + (n - e) * 8) - b) / 2) + 1;
    return n + k;
};

###rust

impl Solution {
    pub fn max_bottles_drunk(n: i32, e: i32) -> i32 {
        let b = e * 2 - 1;
        let delta = b * b + (n - e) * 8;
        let k = ((delta as f64).sqrt() as i32 - b + 2) / 2;
        n + k
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(1)$。
  • 空间复杂度:$\mathcal{O}(1)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

两种方法:模拟 / O(1) 数学公式(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2025年9月23日 14:24

方法一:模拟

根据题意,如果 $\textit{numBottles}\ge \textit{numExchange}$,那么我们可以用 $\textit{numExchange}$ 瓶水换 $1$ 瓶水,在这个过程中:

  • 我们喝了 $\textit{numExchange}$ 瓶水。
  • 我们手中持有的水瓶个数减少了 $\textit{numExchange}$,然后增加了 $1$。相当于把 $\textit{numBottles}$ 减少了 $\textit{numExchange}-1$。

按照上述过程模拟,循环直到 $\textit{numBottles} < \textit{numExchange}$。

最后,把剩下的 $\textit{numBottles}$ 瓶水全部喝完。

class Solution:
    def numWaterBottles(self, numBottles: int, numExchange: int) -> int:
        ans = 0
        while numBottles >= numExchange:
            ans += numExchange  # 吨吨吨~
            numBottles -= numExchange - 1
        return ans + numBottles
class Solution {
    public int numWaterBottles(int numBottles, int numExchange) {
        int ans = 0;
        while (numBottles >= numExchange) {
            ans += numExchange; // 吨吨吨~
            numBottles -= numExchange - 1;
        }
        return ans + numBottles;
    }
}
class Solution {
public:
    int numWaterBottles(int numBottles, int numExchange) {
        int ans = 0;
        while (numBottles >= numExchange) {
            ans += numExchange; // 吨吨吨~
            numBottles -= numExchange - 1;
        }
        return ans + numBottles;
    }
};
int numWaterBottles(int numBottles, int numExchange) {
    int ans = 0;
    while (numBottles >= numExchange) {
        ans += numExchange; // 吨吨吨~
        numBottles -= numExchange - 1;
    }
    return ans + numBottles;
}
func numWaterBottles(numBottles, numExchange int) (ans int) {
for numBottles >= numExchange {
ans += numExchange // 吨吨吨~
numBottles -= numExchange - 1
}
return ans + numBottles
}
var numWaterBottles = function(numBottles, numExchange) {
    let ans = 0;
    while (numBottles >= numExchange) {
        ans += numExchange; // 吨吨吨~
        numBottles -= numExchange - 1;
    }
    return ans + numBottles;
};
impl Solution {
    pub fn num_water_bottles(mut num_bottles: i32, num_exchange: i32) -> i32 {
        let mut ans = 0;
        while num_bottles >= num_exchange {
            ans += num_exchange; // 吨吨吨~
            num_bottles -= num_exchange - 1;
        }
        ans + num_bottles
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}\left(\dfrac{\textit{numBottles}}{\textit{numExchange}}\right)$。
  • 空间复杂度:$\mathcal{O}(1)$。

方法二:数学公式

上面代码循环了多少次,就额外得到了多少瓶水。

所以答案为 $\textit{numBottles}$ 加上循环次数。

设 $k$ 为上面代码的循环次数(额外得到的水瓶数)。循环 $k$ 次后剩余瓶子数小于 $\textit{numExchange}$,即

$$
\textit{numBottles} - k\cdot(\textit{numExchange} - 1) < \textit{numExchange}
$$

解得

$$
k > \dfrac{\textit{numBottles} - \textit{numExchange}}{\textit{numExchange} - 1}
$$

由于 $k$ 是整数,所以 $k$ 为

$$
\left\lfloor\dfrac{\textit{numBottles} - \textit{numExchange}}{\textit{numExchange} - 1}\right\rfloor + 1 = \left\lfloor\dfrac{\textit{numBottles} - 1}{\textit{numExchange} - 1}\right\rfloor
$$

所以我们可以通过兑换,多喝 $\left\lfloor\dfrac{\textit{numBottles} - 1}{\textit{numExchange} - 1}\right\rfloor$ 瓶水,所以答案为

$$
\textit{numBottles} + \left\lfloor\dfrac{\textit{numBottles} - 1}{\textit{numExchange} - 1}\right\rfloor
$$

class Solution:
    def numWaterBottles(self, numBottles: int, numExchange: int) -> int:
        return numBottles + (numBottles - 1) // (numExchange - 1)
class Solution {
    public int numWaterBottles(int numBottles, int numExchange) {
        return numBottles + (numBottles - 1) / (numExchange - 1);
    }
}
class Solution {
public:
    int numWaterBottles(int numBottles, int numExchange) {
        return numBottles + (numBottles - 1) / (numExchange - 1);
    }
};
int numWaterBottles(int numBottles, int numExchange) {
    return numBottles + (numBottles - 1) / (numExchange - 1);
}
func numWaterBottles(numBottles, numExchange int) int {
return numBottles + (numBottles-1)/(numExchange-1)
}
var numWaterBottles = function(numBottles, numExchange) {
    return numBottles + Math.floor((numBottles - 1) / (numExchange - 1));
};
impl Solution {
    pub fn num_water_bottles(num_bottles: i32, num_exchange: i32) -> i32 {
        num_bottles + (num_bottles - 1) / (num_exchange - 1)
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(1)$。
  • 空间复杂度:$\mathcal{O}(1)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

两种方法:O(n^2) 模拟 / O(n) 组合数学(Python/Java/C++/Go)

作者 endlesscheng
2022年4月3日 00:13

方法一:模拟

无需 $\textit{newNums}$,直接原地把 $\textit{nums}[i]$ 更新为 $(\textit{nums}[i] + \textit{nums}[i+1])\bmod 10$。

每循环一次,把 $\textit{nums}$ 的长度减一。

lc2221.png{:width=260}

示例 1 的 $\textit{nums}=[1,2,3,4,5]$。

一轮循环后,$\textit{nums}=[1+2,2+3,3+4,4+5] = [3,5,7,9]$。

再循环一轮,$\textit{nums}=[3+5,5+7,7+9] = [8,2,6]$(模 $10$)。

依此类推,循环直到 $\textit{nums}$ 只剩一个数。

最终答案为 $\textit{nums}[0]$。

class Solution:
    def triangularSum(self, nums: List[int]) -> int:
        # 每循环一轮,数组长度就减一
        for n in range(len(nums) - 1, 0, -1):
            for i in range(n):
                nums[i] = (nums[i] + nums[i + 1]) % 10
        return nums[0]
class Solution {
    public int triangularSum(int[] nums) {
        // 每循环一轮,数组长度就减一
        for (int n = nums.length - 1; n > 0; n--) {
            for (int i = 0; i < n; i++) {
                nums[i] = (nums[i] + nums[i + 1]) % 10;
            }
        }
        return nums[0];
    }
}
class Solution {
public:
    int triangularSum(vector<int>& nums) {
        // 每循环一轮,数组长度就减一
        for (int n = nums.size() - 1; n > 0; n--) {
            for (int i = 0; i < n; i++) {
                nums[i] = (nums[i] + nums[i + 1]) % 10;
            }
        }
        return nums[0];
    }
};
func triangularSum(nums []int) int {
// 每循环一轮,数组长度就减一
for n := len(nums) - 1; n > 0; n-- {
for i := range n {
nums[i] = (nums[i] + nums[i+1]) % 10
}
}
return nums[0]
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

方法二:组合数学

公式推导

能否把最终答案与原数组的元素联系起来?

例如 $\textit{nums}=[a,b,c]$,操作一轮后变成 $[a+b, b+c]$,再操作一轮后变成 $[a+2b+c]$。

又例如 $\textit{nums}=[a,b,c,d]$,操作一轮后变成 $[a+b,b+c,c+d]$,再操作一轮后变成 $[a+2b+c,b+2c+d]$,再操作一轮后变成 $[a+3b+3c+d]$。

又例如 $\textit{nums}=[a,b,c,d,e]$,倒数第二轮是 $[a+3b+3c+d, b+3c+3d+e]$,相加得到 $[a+4b+6c+4d+e]$。

这和 118. 杨辉三角 的计算过程是一样的,比如 $1,4,6,4,1$ 可以理解成 $1,3,3,1,0$ 加上 $0,1,3,3,1$。所以(根据数学归纳法)系数就是组合数。

也可以从 62. 不同路径 的角度理解。比如第一排的 $\textit{nums}[1]=2$,在最终答案中的系数是 $4$,这等于从第一排的 $2$ 移动到底部的路径数

lc2221.png{:width=260}

根据 62 我的题解 的方法二,我们要走 $n-1$ 步,其中恰好有 $i$ 步要往左下走,恰好有 $n-1-i$ 步要往右下走,问题变成从 $n-1$ 步中选择 $i$ 步往左下走的方案数。所以在最终答案中,$\textit{nums}[i]$ 的系数是组合数

$$
\dbinom {n-1} {i}
$$

最终答案为

$$
\left(\sum_{i=0}^{n-1} \dbinom {n-1} {i} \cdot \textit{nums}[i]\right)\bmod 10
$$

分离因子 2 和 5,欧拉定理求逆元

关于组合数取模的原理,请看 模运算的世界:当加减乘除遇上取模

本题由于模数 $10$ 不是质数,计算逆元无法用费马小定理,怎么办?

把每个数中的质因子 $2$ 和 $5$ 分离出来,并统计质因子 $2$ 和 $5$ 的个数。一个数去掉其中所有质因子 $2$ 和 $5$ 之后,得到的整数 $a$ 与 $10$ 互质,这样就可以用欧拉定理计算整数 $a$ 在模 $10$ 下的逆元,即 $a^{\varphi(10)-1} = a^3$。

细节

在计算组合数时,需要把分离出来的质因子 $2$ 和 $5$ 再乘回去。

预处理 $2$ 的幂模 $10$:由于 $2^i\bmod 10\ (i>0)$ 按照 $2,4,8,6$ 的周期循环,所以只需处理一个长为 $4$ 的数组 $[2,4,8,6]$。

预处理 $5$ 的幂模 $10$:由于 $i>0$ 时,$5^i\bmod 10 = 5$ 恒成立,所以无需预处理。

具体请看 视频讲解,欢迎点赞关注~

class Solution:
    def triangularSum(self, nums: List[int]) -> int:
        n = len(nums)
        # 直接调用 math.comb 算出来的组合数很大,更快的写法见【Python3 预处理】
        return sum(comb(n - 1, i) * x for i, x in enumerate(nums)) % 10
MOD = 10
MX = 1000
POW2 = (2, 4, 8, 6)

# 计算组合数,需要计算阶乘及其逆元
f = [0] * (MX + 1)  # f[n] = n!
inv_f = [0] * (MX + 1)  # inv_f[n] = n!^-1
p2 = [0] * (MX + 1)  # n! 中的 2 的幂次
p5 = [0] * (MX + 1)  # n! 中的 5 的幂次

f[0] = inv_f[0] = 1
for i in range(1, MX + 1):
    x = i

    # 分离质因子 2,计算 2 的幂次
    e2 = (x & -x).bit_length() - 1
    x >>= e2

    # 分离质因子 5,计算 5 的幂次
    e5 = 0
    while x % 5 == 0:
        e5 += 1
        x //= 5

    f[i] = f[i - 1] * x % MOD
    inv_f[i] = pow(f[i], 3, MOD)  # 欧拉定理求逆元
    p2[i] = p2[i - 1] + e2
    p5[i] = p5[i - 1] + e5

def comb(n: int, k: int) -> int:
    e2 = p2[n] - p2[k] - p2[n - k]
    return f[n] * inv_f[k] * inv_f[n - k] * \
        (POW2[(e2 - 1) % 4] if e2 else 1) * \
        (5 if p5[n] - p5[k] - p5[n - k] else 1)

class Solution:
    def triangularSum(self, nums: List[int]) -> int:
        n = len(nums)
        return sum(comb(n - 1, i) * x for i, x in enumerate(nums)) % MOD
class Solution {
    private static final int MOD = 10;
    private static final int MX = 1000;
    private static final int[] POW2 = new int[]{2, 4, 8, 6};

    // 计算组合数,需要计算阶乘及其逆元
    private static final int[] f = new int[MX + 1]; // f[n] = n!
    private static final int[] invF = new int[MX + 1]; // invF[n] = n!^-1
    private static final int[] p2 = new int[MX + 1]; // n! 中的 2 的幂次
    private static final int[] p5 = new int[MX + 1]; // n! 中的 5 的幂次

    private static boolean initialized = false;

    // 这样写比 static block 快
    private void init() {
        if (initialized) {
            return;
        }
        initialized = true;

        f[0] = invF[0] = 1;
        for (int i = 1; i <= MX; i++) {
            int x = i;

            // 分离质因子 2,计算 2 的幂次
            int e2 = Integer.numberOfTrailingZeros(x);
            x >>= e2;
            
            // 分离质因子 5,计算 5 的幂次
            int e5 = 0;
            while (x % 5 == 0) {
                e5++;
                x /= 5;
            }

            f[i] = f[i - 1] * x % MOD;
            invF[i] = pow(f[i], 3); // 欧拉定理求逆元
            p2[i] = p2[i - 1] + e2;
            p5[i] = p5[i - 1] + e5;
        }
    }

    private int pow(int x, int n) {
        int res = 1;
        while (n > 0) {
            if (n % 2 > 0) {
                res = res * x % MOD;
            }
            x = x * x % MOD;
            n /= 2;
        }
        return res;
    }

    private int comb(int n, int k) {
        int e2 = p2[n] - p2[k] - p2[n - k];
        return f[n] * invF[k] * invF[n - k] *
                (e2 > 0 ? POW2[(e2 - 1) % 4] : 1) *
                (p5[n] - p5[k] - p5[n - k] > 0 ? 5 : 1) % MOD;
    }

    public int triangularSum(int[] nums) {
        init();
        int n = nums.length;
        int ans = 0;
        for (int i = 0; i < n; i++) {
            ans += comb(n - 1, i) * nums[i];
        }
        return ans % MOD;
    }
}
const int MOD = 10;
const int MX = 1000;
const int POW2[4] = {2, 4, 8, 6};

// 计算组合数,需要计算阶乘及其逆元
int f[MX + 1]; // f[n] = n!
int inv_f[MX + 1]; // invF[n] = n!^-1
int p2[MX + 1]; // n! 中的 2 的幂次
int p5[MX + 1]; // n! 中的 5 的幂次

int qpow(int x, int n) {
    int res = 1;
    while (n > 0) {
        if (n % 2 > 0) {
            res = res * x % MOD;
        }
        x = x * x % MOD;
        n /= 2;
    }
    return res;
}

auto init = []() {
    f[0] = inv_f[0] = 1;
    for (int i = 1; i <= MX; i++) {
        int x = i;

        // 分离质因子 2,计算 2 的幂次
        int e2 = countr_zero((uint32_t) x);
        x >>= e2;

        // 分离质因子 5,计算 5 的幂次
        int e5 = 0;
        while (x % 5 == 0) {
            e5++;
            x /= 5;
        }

        f[i] = f[i - 1] * x % MOD;
        inv_f[i] = qpow(f[i], 3); // 欧拉定理求逆元
        p2[i] = p2[i - 1] + e2;
        p5[i] = p5[i - 1] + e5;
    }
    return 0;
}();

int comb(int n, int k) {
    int e2 = p2[n] - p2[k] - p2[n - k];
    return f[n] * inv_f[k] * inv_f[n - k] *
           (e2 ? POW2[(e2 - 1) % 4] : 1) *
           (p5[n] - p5[k] - p5[n - k] ? 5 : 1) % MOD;
}

class Solution {
public:
    int triangularSum(vector<int>& nums) {
        int n = nums.size();
        int ans = 0;
        for (int i = 0; i < n; i++) {
            ans += comb(n - 1, i) * nums[i];
        }
        return ans % MOD;
    }
};
const mod = 10

func pow(x, n int) int {
res := 1
for ; n > 0; n /= 2 {
if n%2 > 0 {
res = res * x % mod
}
x = x * x % mod
}
return res
}

const mx = 1000

// 计算组合数,需要计算阶乘及其逆元
var (
f    [mx + 1]int // f[n] = n!
invF [mx + 1]int // invF[n] = n!^-1
p2   [mx + 1]int // n! 中的 2 的幂次
p5   [mx + 1]int // n! 中的 5 的幂次
)

func init() {
f[0] = 1
invF[0] = 1
for i := 1; i <= mx; i++ {
x := i

// 分离质因子 2,计算 2 的幂次
e2 := bits.TrailingZeros(uint(x))
x >>= e2

// 分离质因子 5,计算 5 的幂次
e5 := 0
for x%5 == 0 {
e5++
x /= 5
}

f[i] = f[i-1] * x % mod
invF[i] = pow(f[i], 3) // 欧拉定理求逆元
p2[i] = p2[i-1] + e2
p5[i] = p5[i-1] + e5
}
}

var pow2 = [4]int{2, 4, 8, 6}

func comb(n, k int) int {
res := f[n] * invF[k] * invF[n-k]
e2 := p2[n] - p2[k] - p2[n-k]
if e2 > 0 {
res *= pow2[(e2-1)%4]
}
if p5[n]-p5[k]-p5[n-k] > 0 {
res *= 5
}
return res
}

func triangularSum(nums []int) (ans int) {
for i, x := range nums {
ans += comb(len(nums)-1, i) * x
}
return ans % mod
}

复杂度分析

不计入预处理的时间和空间。

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

:本题还可以用 Lucas 定理做,见 我的题解

相似题目

3463. 判断操作后字符串中的数字是否相等 II 2286

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

【视频】教你一步步思考动态规划,从记忆化搜索到递推(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2023年4月1日 19:12

本题视频讲解

我单独制作了一期视频来讲区间 DP,其中就包括这道题目。

区间 DP【基础算法精讲 22】,制作不易,欢迎点赞关注~

一、记忆化搜索

1039-cut.png{:width=600px}

考虑到整个递归过程中有大量重复递归调用(递归入参相同)。由于递归函数没有副作用,同样的入参无论计算多少次,算出来的结果都是一样的,因此可以用记忆化搜索来优化:

  • 如果一个状态(递归入参)是第一次遇到,那么可以在返回前,把状态及其结果记到一个 $\textit{memo}$ 数组中。
  • 如果一个状态不是第一次遇到($\textit{memo}$ 中保存的结果不等于 $\textit{memo}$ 的初始值),那么可以直接返回 $\textit{memo}$ 中保存的结果。

注意:$\textit{memo}$ 数组的初始值一定不能等于要记忆化的值!例如初始值设置为 $0$,并且要记忆化的 $\textit{dfs}(i,j)$ 也等于 $0$,那就没法判断 $0$ 到底表示第一次遇到这个状态,还是表示之前遇到过了,从而导致记忆化失效。一般把初始值设置为 $-1$。本题由于 $\textit{values}[i]>0$,所以 $\textit{memo}[i][j]$ 可以初始化成 $0$。

Python 用户可以无视上面这段,直接用 @cache 装饰器。

答疑

:区间 DP 有一个「复制一倍,断环成链」的技巧,本题为什么不用这样计算?

:无论如何旋转多边形,无论从哪条边开始计算,得到的结果都是一样的,那么不妨就从 $0$ - $(n-1)$ 这条边开始计算。

class Solution:
    def minScoreTriangulation(self, v: List[int]) -> int:
        @cache  # 缓存装饰器,避免重复计算 dfs(一行代码实现记忆化)
        def dfs(i: int, j: int) -> int:
            if i + 1 == j:
                return 0  # 只有两个点,无法组成三角形
            return min(dfs(i, k) + dfs(k, j) + v[i] * v[j] * v[k]
                       for k in range(i + 1, j))  # 枚举顶点 k

        return dfs(0, len(v) - 1)
class Solution {
    public int minScoreTriangulation(int[] values) {
        int n = values.length;
        int[][] memo = new int[n][n];
        for (int[] row : memo) {
            Arrays.fill(row, -1); // -1 表示没有计算过
        }
        return dfs(0, n - 1, values, memo);
    }

    private int dfs(int i, int j, int[] v, int[][] memo) {
        if (i + 1 == j) {
            return 0; // 只有两个点,无法组成三角形
        }

        if (memo[i][j] != -1) { // 之前计算过
            return memo[i][j];
        }

        int res = Integer.MAX_VALUE;
        for (int k = i + 1; k < j; k++) { // 枚举顶点 k
            int subRes = dfs(i, k, v, memo) + dfs(k, j, v, memo) + v[i] * v[j] * v[k];
            res = Math.min(res, subRes);
        }

        return memo[i][j] = res; // 记忆化
    }
}
class Solution {
public:
    int minScoreTriangulation(vector<int>& v) {
        int n = v.size();
        vector memo(n, vector<int>(n, -1)); // -1 表示没有计算过

        // lambda 递归函数
        auto dfs = [&](this auto&& dfs, int i, int j) -> int {
            if (i + 1 == j) {
                return 0; // 只有两个点,无法组成三角形
            }
            int& res = memo[i][j]; // 注意这里是引用,修改 res 相当于修改 memo[i][j]
            if (res != -1) { // 之前计算过
                return res;
            }
            res = INT_MAX;
            for (int k = i + 1; k < j; k++) { // 枚举顶点 k
                res = min(res, dfs(i, k) + dfs(k, j) + v[i] * v[j] * v[k]);
            }
            return res;
        };

        return dfs(0, n - 1);
    }
};
#define MIN(a, b) ((b) < (a) ? (b) : (a))

int minScoreTriangulation(int* v, int n) {
    int** memo = malloc(n * sizeof(int*));
    for (int i = 0; i < n; i++) {
        memo[i] = malloc(n * sizeof(int));
        for (int j = 0; j < n; j++) {
            memo[i][j] = -1; // -1 表示没有计算过
        }
    }

    int dfs(int i, int j) {
        if (i + 1 == j) {
            return 0; // 只有两个点,无法组成三角形
        }
        if (memo[i][j] != -1) { // 之前计算过
            return memo[i][j];
        }
        int res = INT_MAX;
        for (int k = i + 1; k < j; k++) { // 枚举顶点 k
            int sub_res = dfs(i, k) + dfs(k, j) + v[i] * v[j] * v[k];
            res = MIN(res, sub_res);
        }
        return memo[i][j] = res; // 记忆化
    }
    int ans = dfs(0, n - 1);

    for (int i = 0; i < n; i++) {
        free(memo[i]);
    }
    free(memo);
    return ans;
}
func minScoreTriangulation(v []int) int {
    n := len(v)
    memo := make([][]int, n)
    for i := range memo {
        memo[i] = make([]int, n)
        for j := range memo[i] {
            memo[i][j] = -1 // -1 表示没有计算过
        }
    }

    var dfs func(int, int) int
    dfs = func(i, j int) int {
        if i+1 == j { // 只有两个点,无法组成三角形
            return 0
        }
        p := &memo[i][j]
        if *p != -1 { // 之前计算过
            return *p
        }
        res := math.MaxInt
        for k := i + 1; k < j; k++ { // 枚举顶点 k
            res = min(res, dfs(i, k)+dfs(k, j)+v[i]*v[j]*v[k])
        }
        *p = res // 记忆化
        return res
    }

    return dfs(0, n-1)
}
var minScoreTriangulation = function(v) {
    const n = v.length;
    const memo = Array.from({ length: n }, () => Array(n).fill(-1)); // -1 表示没有计算过

    function dfs(i, j) {
        if (i + 1 === j) {
            return 0; // 只有两个点,无法组成三角形
        }
        if (memo[i][j] !== -1) { // 之前计算过
            return memo[i][j];
        }
        let res = Infinity;
        for (let k = i + 1; k < j; k++) { // 枚举顶点 k
            res = Math.min(res, dfs(i, k) + dfs(k, j) + v[i] * v[j] * v[k]);
        }
        return memo[i][j] = res; // 记忆化
    }

    return dfs(0, n - 1);
};
impl Solution {
    pub fn min_score_triangulation(v: Vec<i32>) -> i32 {
        fn dfs(i: usize, j: usize, v: &[i32], memo: &mut [Vec<i32>]) -> i32 {
            if i + 1 == j {
                return 0; // 只有两个点,无法组成三角形
            }
            if memo[i][j] != -1 { // 之前计算过
                return memo[i][j];
            }
            let mut res = i32::MAX;
            for k in i + 1..j { // 枚举顶点 k
                let val = dfs(i, k, v, memo) + dfs(k, j, v, memo) + v[i] * v[j] * v[k];
                res = res.min(val);
            }
            memo[i][j] = res; // 记忆化
            res
        }

        let n = v.len();
        let mut memo = vec![vec![-1; n]; n]; // -1 表示没有计算过
        dfs(0, n - 1, &v, &mut memo)
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^3)$,其中 $n$ 为 $\textit{values}$ 的长度。动态规划的时间复杂度 $=$ 状态个数 $\times$ 单个状态的计算时间。本题中状态个数等于 $\mathcal{O}(n^2)$,单个状态的计算时间为 $\mathcal{O}(n)$,因此时间复杂度为 $\mathcal{O}(n^3)$。
  • 空间复杂度:$\mathcal{O}(n^2)$。保存多少状态,就需要多少空间。

二、1:1 翻译成递推

根据视频中讲的,把 $\textit{dfs}$ 改成 $f$ 数组,把递归改成循环就好了。相当于原来是用递归计算每个状态 $(i,j)$,现在改用循环去计算每个状态 $(i,j)$。

状态转移方程和递归完全一致

$$
f[i][j]=\min_{k=i+1}^{j-1}{f[i][k]+f[k][j]+v[i]\cdot v[j]\cdot v[k]}
$$

需要注意循环的顺序:

  • 由于 $i<k$,$f[i]$ 要能从 $f[k]$ 转移过来,必须先计算出 $f[k]$,所以 $i$ 要倒序枚举;
  • 由于 $j>k$,$f[i][j]$ 要能从 $f[i][k]$ 转移过来,必须先计算出 $f[i][k]$,所以 $j$ 要正序枚举。

此外,递推式中的 $j\ge i+2$($j=i+1$ 的情况是初始值,无需计算),由于 $j\le n-1$,所以 $i+2\le n-1$,即 $i\le n-3$,所以 $i$ 从 $n-3$ 开始枚举。

初始值 $f[i][i+1]=0$,翻译自递归边界 $\textit{dfs}(i,i+1) = 0$。

答案为 $f[0][n-1]$,翻译自递归入口 $\textit{dfs}(0,n-1)$。

class Solution:
    def minScoreTriangulation(self, v: List[int]) -> int:
        n = len(v)
        f = [[0] * n for _ in range(n)]
        for i in range(n - 3, -1, -1):
            for j in range(i + 2, n):
                f[i][j] = min(f[i][k] + f[k][j] + v[i] * v[j] * v[k]
                              for k in range(i + 1, j))
        return f[0][-1]
class Solution {
    public int minScoreTriangulation(int[] v) {
        int n = v.length;
        int[][] f = new int[n][n];
        for (int i = n - 3; i >= 0; i--) {
            for (int j = i + 2; j < n; j++) {
                f[i][j] = Integer.MAX_VALUE;
                for (int k = i + 1; k < j; k++) {
                    f[i][j] = Math.min(f[i][j], f[i][k] + f[k][j] + v[i] * v[j] * v[k]);
                }
            }
        }
        return f[0][n - 1];
    }
}
class Solution {
public:
    int minScoreTriangulation(vector<int>& v) {
        int n = v.size();
        vector f(n, vector<int>(n));
        for (int i = n - 3; i >= 0; i--) {
            for (int j = i + 2; j < n; j++) {
                f[i][j] = INT_MAX;
                for (int k = i + 1; k < j; k++) {
                    f[i][j] = min(f[i][j], f[i][k] + f[k][j] + v[i] * v[j] * v[k]);
                }
            }
        }
        return f[0][n - 1];
    }
};
#define MIN(a, b) ((b) < (a) ? (b) : (a))

int minScoreTriangulation(int* v, int n) {
    int** f = malloc(n * sizeof(int*));
    for (int i = 0; i < n; i++) {
        f[i] = calloc(n, sizeof(int));
    }
    for (int i = n - 3; i >= 0; i--) {
        for (int j = i + 2; j < n; j++) {
            f[i][j] = INT_MAX;
            for (int k = i + 1; k < j; k++) {
                f[i][j] = MIN(f[i][j], f[i][k] + f[k][j] + v[i] * v[j] * v[k]);
            }
        }
    }
    int ans = f[0][n - 1];

    for (int i = 0; i < n; i++) {
        free(f[i]);
    }
    free(f);
    return ans;
}
func minScoreTriangulation(v []int) int {
    n := len(v)
    f := make([][]int, n)
    for i := range f {
        f[i] = make([]int, n)
    }
    for i := n - 3; i >= 0; i-- {
        for j := i + 2; j < n; j++ {
            f[i][j] = math.MaxInt
            for k := i + 1; k < j; k++ {
                f[i][j] = min(f[i][j], f[i][k]+f[k][j]+v[i]*v[j]*v[k])
            }
        }
    }
    return f[0][n-1]
}
var minScoreTriangulation = function(v) {
    const n = v.length;
    const f = Array.from({ length: n }, () => Array(n).fill(0));
    for (let i = n - 3; i >= 0; i--) {
        for (let j = i + 2; j < n; j++) {
            f[i][j] = Infinity;
            for (let k = i + 1; k < j; k++) {
                f[i][j] = Math.min(f[i][j], f[i][k] + f[k][j] + v[i] * v[j] * v[k]);
            }
        }
    }
    return f[0][n - 1];
};
impl Solution {
    pub fn min_score_triangulation(v: Vec<i32>) -> i32 {
        let n = v.len();
        let mut f = vec![vec![0; n]; n];
        for i in (0..n - 2).rev() {
            for j in i + 2..n {
                f[i][j] = i32::MAX;
                for k in i + 1..j {
                    f[i][j] = f[i][j].min(f[i][k] + f[k][j] + v[i] * v[j] * v[k]);
                }
            }
        }
        f[0][n - 1]
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^3)$,其中 $n$ 为 $\textit{values}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n^2)$。

思考题

计算把 $n$ 边形三角剖分的方案数

欢迎在评论区分享你的思路/代码。

专题训练

见下面动态规划题单的「八、区间 DP

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

两种方法:枚举 / 凸包+旋转卡壳(Python/Java/C++/Go)

作者 endlesscheng
2025年9月27日 09:02

三角形面积公式

对于平面上的三个点 $P_1,P_2,P_3$,定义 $\mathbf{a} = \overrightarrow{P_1P_2}$,$\mathbf{b} = \overrightarrow{P_1P_3}$。

根据向量叉积的定义,$|| \mathbf{a} \times \mathbf{b} ||$ 是由 $\mathbf{a}$ 和 $\mathbf{b}$ 张成的平行四边形的面积。除以 $2$ 就得到了 $\triangle P_1P_2P_3$ 的面积。

严格来说,叉积是三维概念。这里把向量 $\mathbf{a}$ 和 $\mathbf{b}$ 视作 $z$ 方向为 $0$ 的三维向量。

设 $\mathbf{a} = (x_1,y_1)$,$\mathbf{b} = (x_2,y_2)$,根据叉积的计算公式,三角形面积为

$$
\dfrac{1}{2}|x_1y_2 - y_1x_2|
$$

上式中的 $(x_1,y_1)$ 来自 $P_1,P_2$ 的横坐标之差,纵坐标之差。$(x_2,y_2)$ 来自 $P_1,P_3$ 的横坐标之差,纵坐标之差。

方法一:暴力枚举

class Solution:
    def largestTriangleArea(self, points: List[List[int]]) -> float:
        ans = 0
        for p1, p2, p3 in combinations(points, 3):
            x1, y1 = p2[0] - p1[0], p2[1] - p1[1]
            x2, y2 = p3[0] - p1[0], p3[1] - p1[1]
            ans = max(ans, abs(x1 * y2 - y1 * x2))  # 注意这里没有除以 2
        return ans / 2
class Solution {
    public double largestTriangleArea(int[][] points) {
        int n = points.length;
        int ans = 0;
        for (int i = 0; i < n - 2; i++) {
            for (int j = i + 1; j < n - 1; j++) {
                for (int k = j + 1; k < n; k++) {
                    int[] p1 = points[i], p2 = points[j], p3 = points[k];
                    int x1 = p2[0] - p1[0], y1 = p2[1] - p1[1];
                    int x2 = p3[0] - p1[0], y2 = p3[1] - p1[1];
                    ans = Math.max(ans, Math.abs(x1 * y2 - y1 * x2)); // 注意这里没有除以 2
                }
            }
        }
        return ans / 2.0;
    }
}
class Solution {
public:
    double largestTriangleArea(vector<vector<int>>& points) {
        int n = points.size();
        int ans = 0;
        for (int i = 0; i < n - 2; i++) {
            auto& p1 = points[i];
            for (int j = i + 1; j < n - 1; j++) {
                auto& p2 = points[j];
                for (int k = j + 1; k < n; k++) {
                    auto& p3 = points[k];
                    int x1 = p2[0] - p1[0], y1 = p2[1] - p1[1];
                    int x2 = p3[0] - p1[0], y2 = p3[1] - p1[1];
                    ans = max(ans, abs(x1 * y2 - y1 * x2)); // 注意这里没有除以 2
                }
            }
        }
        return ans / 2.0;
    }
};
func largestTriangleArea(points [][]int) float64 {
n := len(points)
ans := 0
for i := range n - 2 {
for j := i + 1; j < n-1; j++ {
for k := j + 1; k < n; k++ {
p1, p2, p3 := points[i], points[j], points[k]
x1, y1 := p2[0]-p1[0], p2[1]-p1[1]
x2, y2 := p3[0]-p1[0], p3[1]-p1[1]
ans = max(ans, abs(x1*y2-y1*x2)) // 注意这里没有除以 2
}
}
}
return float64(ans) / 2
}

func abs(x int) int { if x < 0 { return -x }; return x }

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^3)$,其中 $n$ 是 $\textit{points}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

方法二:凸包 + 旋转卡壳

前置题目587. 安装栅栏

若固定三角形的两个顶点,那么第三个顶点在哪?

三角形的高越长越好,所以第三个顶点相比在凸包内部,在凸包上更好(更远)。所以面积最大的三角形,三个顶点都在凸包上。

求出凸包后:

  1. 枚举凸包的顶点 $i$ 作为三角形的其中一个顶点。对于另外两个顶点,我们用旋转卡壳(同向双指针)计算。
  2. 初始化 $j=i+1$,$k=i+2$。
  3. 对于 $\triangle P_iP_jP_k$ 和 $\triangle P_iP_jP_{k+1}$ 的面积,如果后者更大,那么把 $k$ 加一。重复该过程,直到 $k+1$ 越界或者面积没有变大,跳出循环。
  4. 跳出循环后,$P_k$ 就移动到了一个在 $\overrightarrow{P_iP_j}$ 左侧且距离 $P_iP_j$ 最远的位置。计算 $\triangle P_iP_jP_k$ 的面积,更新答案的最大值。然后把 $j$ 加一,执行第三步。读者可以在纸上画画,随着 $j$ 的增大,在 $\overrightarrow{P_iP_j}$ 左侧且距离 $P_iP_j$ 最远的 $P_k$ 的下标 $k$ 也在增大,所以可以用同向双指针。

下面代码保证计算 $\mathbf{a} \times \mathbf{b}$ 时,$\mathbf{b}$ 在 $\mathbf{a}$ 的左侧,此时算出来的面积一定大于 $0$,无需取绝对值。

def det(x1: int, y1: int, x2: int, y2: int) -> int:
    return x1 * y2 - y1 * x2

def convex_hull(points: List[List[int]]) -> List[List[int]]:
    points.sort()

    # 计算下凸包(从左到右)
    q = []
    for p in points:
        while len(q) > 1 and det(q[-1][0] - q[-2][0], q[-1][1] - q[-2][1], p[0] - q[-1][0], p[1] - q[-1][1]) <= 0:
            q.pop()
        q.append(p)

    # 计算上凸包(从右到左)
    down_size = len(q)
    # 注意下凸包的最后一个点,是上凸包的右边第一个点,所以从 n-2 开始遍历
    for i in range(len(points) - 2, -1, -1):
        p = points[i]
        while len(q) > down_size and det(q[-1][0] - q[-2][0], q[-1][1] - q[-2][1], p[0] - q[-1][0], p[1] - q[-1][1]) <= 0:
            q.pop()
        q.append(p)

    # 此时首尾是同一个点 points[0],需要去掉
    q.pop()
    return q

class Solution:
    def largestTriangleArea(self, points: List[List[int]]) -> float:
        ch = convex_hull(points)

        def area(i: int, j: int, k: int) -> int:
            return det(ch[j][0] - ch[i][0], ch[j][1] - ch[i][1], ch[k][0] - ch[i][0], ch[k][1] - ch[i][1])

        m = len(ch)
        ans = 0
        # 固定三角形的其中一个顶点 ch[i]
        for i in range(m):
            # 同向双指针
            k = i + 2
            for j in range(i + 1, m - 1):
                while k + 1 < m and area(i, j, k) < area(i, j, k + 1):
                    k += 1
                # 循环结束后,ch[k] 距离 ch[i]ch[j] 最远
                ans = max(ans, area(i, j, k))  # 注意这里没有除以 2
        return ans / 2
class Solution {
    public double largestTriangleArea(int[][] points) {
        List<int[]> ch = convexHull(points);
        int m = ch.size();
        int ans = 0;
        // 固定三角形的其中一个顶点 ch[i]
        for (int i = 0; i < m; i++) {
            // 同向双指针
            int k = i + 2;
            for (int j = i + 1; j < m - 1; j++) {
                while (k + 1 < m && area(ch, i, j, k) < area(ch, i, j, k + 1)) {
                    k++;
                }
                // 循环结束后,ch[k] 距离 ch[i]ch[j] 最远
                ans = Math.max(ans, area(ch, i, j, k)); // 注意这里没有除以 2
            }
        }
        return ans / 2.0;
    }

    private List<int[]> convexHull(int[][] points) {
        Arrays.sort(points, (a, b) -> a[0] != b[0] ? a[0] - b[0] : a[1] - b[1]);

        // 计算下凸包(从左到右)
        List<int[]> q = new ArrayList<>();
        for (int[] p : points) {
            while (q.size() > 1) {
                int[] p1 = q.get(q.size() - 2);
                int[] p2 = q.getLast();
                if (det(p2[0] - p1[0], p2[1] - p1[1], p[0] - p2[0], p[1] - p2[1]) > 0) {
                    break;
                }
                q.removeLast();
            }
            q.add(p);
        }

        // 计算上凸包(从右到左)
        int downSize = q.size();
        // 注意下凸包的最后一个点,是上凸包的右边第一个点,所以从 n-2 开始遍历
        for (int i = points.length - 2; i >= 0; i--) {
            int[] p = points[i];
            while (q.size() > downSize) {
                int[] p1 = q.get(q.size() - 2);
                int[] p2 = q.getLast();
                if (det(p2[0] - p1[0], p2[1] - p1[1], p[0] - p2[0], p[1] - p2[1]) > 0) {
                    break;
                }
                q.removeLast();
            }
            q.add(p);
        }

        // 此时首尾是同一个点 points[0],需要去掉
        q.removeLast();
        return q;
    }

    private int det(int x1, int y1, int x2, int y2) {
        return x1 * y2 - y1 * x2;
    }

    private int area(List<int[]> ch, int i, int j, int k) {
        return det(ch.get(j)[0] - ch.get(i)[0], ch.get(j)[1] - ch.get(i)[1],
                ch.get(k)[0] - ch.get(i)[0], ch.get(k)[1] - ch.get(i)[1]);
    }
}
struct Vec {
    int x, y;

    Vec sub(const Vec& b) const {
        return {x - b.x, y - b.y};
    }

    int det(const Vec& b) const {
        return x * b.y - y * b.x;
    }
};

class Solution {
    vector<Vec> convexHull(vector<Vec>& points) {
        ranges::sort(points, {}, [](auto& p) { return pair(p.x, p.y); });

        vector<Vec> q;
        // 计算下凸包(从左到右)
        for (auto& p : points) {
            while (q.size() > 1 && q[q.size() - 1].sub(q[q.size() - 2]).det(p.sub(q[q.size() - 1])) <= 0) {
                q.pop_back();
            }
            q.push_back(p);
        }

        // 计算上凸包(从右到左)
        int down_size = q.size();
        // 注意下凸包的最后一个点,是上凸包的右边第一个点,所以从 n-2 开始遍历
        for (int i = (int) points.size() - 2; i >= 0; i--) {
            auto& p = points[i];
            while (q.size() > down_size && q[q.size() - 1].sub(q[q.size() - 2]).det(p.sub(q[q.size() - 1])) <= 0) {
                q.pop_back();
            }
            q.push_back(p);
        }

        // 此时首尾是同一个点 points[0],需要去掉
        q.pop_back();
        return q;
    }

public:
    double largestTriangleArea(vector<vector<int>>& points) {
        vector<Vec> a(points.size());
        for (int i = 0; i < points.size(); i++) {
            a[i] = {points[i][0], points[i][1]};
        }

        vector<Vec> ch = convexHull(a);

        auto area = [&](int i, int j, int k) -> int {
            return ch[j].sub(ch[i]).det(ch[k].sub(ch[i]));
        };

        int m = ch.size();
        int ans = 0;
        // 固定三角形的其中一个顶点 ch[i]
        for (int i = 0; i < m; i++) {
            // 同向双指针
            int k = i + 2;
            for (int j = i + 1; j < m - 1; j++) {
                while (k + 1 < m && area(i, j, k) < area(i, j, k + 1)) {
                    k++;
                }
                // 循环结束后,ch[k] 距离 ch[i]ch[j] 最远
                ans = max(ans, area(i, j, k)); // 注意这里没有除以 2
            }
        }
        return ans / 2.0;
    }
};
type vec struct{ x, y int }

func (a vec) sub(b vec) vec { return vec{a.x - b.x, a.y - b.y} }
func (a vec) det(b vec) int { return a.x*b.y - a.y*b.x }

func convexHull(points []vec) (q []vec) {
slices.SortFunc(points, func(a, b vec) int { return cmp.Or(a.x-b.x, a.y-b.y) })

// 计算下凸包(从左到右)
for _, p := range points {
for len(q) > 1 && q[len(q)-1].sub(q[len(q)-2]).det(p.sub(q[len(q)-1])) <= 0 {
q = q[:len(q)-1]
}
q = append(q, p)
}

// 计算上凸包(从右到左)
downSize := len(q)
// 注意下凸包的最后一个点,是上凸包的右边第一个点,所以从 n-2 开始遍历
for i := len(points) - 2; i >= 0; i-- {
p := points[i]
for len(q) > downSize && q[len(q)-1].sub(q[len(q)-2]).det(p.sub(q[len(q)-1])) <= 0 {
q = q[:len(q)-1]
}
q = append(q, p)
}

// 此时首尾是同一个点 points[0],需要去掉
q = q[:len(q)-1]
return
}

func largestTriangleArea(points [][]int) float64 {
a := make([]vec, len(points))
for i, p := range points {
a[i] = vec{p[0], p[1]}
}

ch := convexHull(a)
area := func(i, j, k int) int {
return ch[j].sub(ch[i]).det(ch[k].sub(ch[i]))
}

m := len(ch)
ans := 0
// 固定三角形的其中一个顶点 ch[i]
for i := range ch {
// 同向双指针
k := i + 2
for j := i + 1; j < m-1; j++ {
for k+1 < m && area(i, j, k) < area(i, j, k+1) {
k++
}
// 循环结束后,ch[k] 距离 ch[i]ch[j] 最远
ans = max(ans, area(i, j, k)) // 注意这里没有除以 2
}
}
return float64(ans) / 2
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2)$,其中 $n$ 是 $\textit{points}$ 的长度。枚举 $i$ 是 $\mathcal{O}(n)$,枚举 $j$ 和 $k$ 的同向双指针也是 $\mathcal{O}(n)$。
  • 空间复杂度:$\mathcal{O}(n)$。

:本题有 $\mathcal{O}(n)$ 做法,见论文 Maximal Area Triangles in a Convex Polygon

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

教你一步步思考 DP:从记忆化搜索到递推到空间优化!(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2024年11月23日 10:04

一、寻找子问题

由于 $\textit{triangle}$ 每排的下标都是从 $0$ 开始的,示例 1 的正确图示应该是左对齐,即

$$
\begin{aligned}
& 2 \
& 3\ 4 \
& 6\ 5\ 7 \
& 4\ 1\ 8\ 3 \
\end{aligned}
$$

我们要解决的问题(原问题)是:

  • 从最上面的 $(0,0)$ 出发,移动到 $\textit{triangle}$ 的最后一排,路径上的元素之和的最小值。

考虑下一步往哪走:

  • 走到 $(1,0)$,那么需要解决的问题为:从 $(1,0)$ 出发,移动到 $\textit{triangle}$ 最后一排,路径上的元素之和的最小值。
  • 走到 $(1,1)$,那么需要解决的问题为:从 $(1,1)$ 出发,移动到 $\textit{triangle}$ 最后一排,路径上的元素之和的最小值。

这些问题都是和原问题相似的、规模更小的子问题,可以用递归解决。

二、状态定义与状态转移方程

根据上面的讨论,定义状态为 $\textit{dfs}(i,j)$,表示从 $(i,j)$ 出发,移动到 $\textit{triangle}$ 最后一排,路径上的元素之和的最小值。

考虑下一步往哪走:

  • 走到 $(i+1,j)$,那么需要解决的问题为:从 $(i+1,j)$ 出发,移动到 $\textit{triangle}$ 最后一排,路径上的元素之和的最小值,即 $\textit{dfs}(i+1,j)$。
  • 走到 $(i+1,j+1)$,那么需要解决的问题为:从 $(i+1,j+1)$ 出发,移动到 $\textit{triangle}$ 最后一排,路径上的元素之和的最小值,即 $\textit{dfs}(i+1,j+1)$。

这两种情况取最小值,再加上当前位置的元素值 $\textit{triangle}[i][j]$ 就得到了 $\textit{dfs}(i,j)$,即

$$
\textit{dfs}(i,j) = \min(\textit{dfs}(i+1,j),\textit{dfs}(i+1,j+1)) + \textit{triangle}[i][j]
$$

递归边界:$\textit{dfs}(n-1,j)=\textit{triangle}[n-1][j]$。走到最后一排就无法再走了,路径上只有一个元素 $\textit{triangle}[n-1][j]$。

递归入口:$\textit{dfs}(0,0)$,这是原问题,也是答案。

三、递归搜索 + 保存递归返回值 = 记忆化搜索

考虑到整个递归过程中有大量重复递归调用(递归入参相同)。由于递归函数没有副作用,同样的入参无论计算多少次,算出来的结果都是一样的,因此可以用记忆化搜索来优化:

  • 如果一个状态(递归入参)是第一次遇到,那么可以在返回前,把状态及其结果记到一个 $\textit{memo}$ 数组中。
  • 如果一个状态不是第一次遇到($\textit{memo}$ 中保存的结果不等于 $\textit{memo}$ 的初始值),那么可以直接返回 $\textit{memo}$ 中保存的结果。

注意:$\textit{memo}$ 数组的初始值一定不能等于要记忆化的值!例如初始值设置为 $0$,并且要记忆化的 $\textit{dfs}(i,j)$ 也等于 $0$,那就没法判断 $0$ 到底表示第一次遇到这个状态,还是表示之前遇到过了,从而导致记忆化失效。一般把初始值设置为 $-1$。本题由于 $\textit{triangle}[i][j]$ 可以是负数,所以改用 $-\infty$ 作为初始值。

Python 用户可以无视上面这段,直接用 @cache 装饰器。

具体请看视频讲解 动态规划入门:从记忆化搜索到递推,其中包含把记忆化搜索 1:1 翻译成递推的技巧。

###py

class Solution:
    def minimumTotal(self, triangle: List[List[int]]) -> int:
        n = len(triangle)
        @cache  # 缓存装饰器,避免重复计算 dfs 的结果(记忆化)
        def dfs(i: int, j: int) -> int:
            if i == n - 1:
                return triangle[i][j]
            return min(dfs(i + 1, j), dfs(i + 1, j + 1)) + triangle[i][j]
        return dfs(0, 0)

###java

class Solution {
    public int minimumTotal(List<List<Integer>> triangle) {
        int n = triangle.size();
        int[][] memo = new int[n][n];
        for (int[] row : memo) {
            Arrays.fill(row, Integer.MIN_VALUE); // Integer.MIN_VALUE 表示没有计算过
        }
        return dfs(triangle, 0, 0, memo);
    }

    private int dfs(List<List<Integer>> triangle, int i, int j, int[][] memo) {
        if (i == triangle.size() - 1) {
            return triangle.get(i).get(j);
        }
        if (memo[i][j] != Integer.MIN_VALUE) { // 之前计算过
            return memo[i][j];
        }
        return memo[i][j] = Math.min(dfs(triangle, i + 1, j, memo),
                dfs(triangle, i + 1, j + 1, memo)) + triangle.get(i).get(j);
    }
}

###cpp

class Solution {
public:
    int minimumTotal(vector<vector<int>>& triangle) {
        int n = triangle.size();
        vector memo(n, vector<int>(n, INT_MIN)); // INT_MIN 表示没有计算过
        // lambda 递归
        auto dfs = [&](this auto&& dfs, int i, int j) -> int {
            if (i == n - 1) {
                return triangle[i][j];
            }
            int& res = memo[i][j]; // 注意这里是引用
            if (res != INT_MIN) { // 之前计算过
                return res;
            }
            return res = min(dfs(i + 1, j), dfs(i + 1, j + 1)) + triangle[i][j];
        };
        return dfs(0, 0);
    }
};

###c

#define MIN(a, b) ((b) < (a) ? (b) : (a))

int dfs(int** triangle, int i, int j, int n, int** memo) {
    if (i == n - 1) {
        return triangle[i][j];
    }
    if (memo[i][j] != 0x3f3f3f3f) { // 之前计算过
        return memo[i][j];
    }
    int res1 = dfs(triangle, i + 1, j, n, memo);
    int res2 = dfs(triangle, i + 1, j + 1, n, memo);
    return memo[i][j] = MIN(res1, res2) + triangle[i][j];
}

int minimumTotal(int** triangle, int triangleSize, int* triangleColSize) {
    int** memo = malloc(triangleSize * sizeof(int*));
    for (int i = 0; i < triangleSize; i++) {
        memo[i] = malloc((i + 1) * sizeof(int));
        memset(memo[i], 0x3f, triangleColSize[i] * sizeof(int));
    }

    int ans = dfs(triangle, 0, 0, triangleSize, memo);

    for (int i = 0; i < triangleSize; i++) {
        free(memo[i]);
    }
    free(memo);
    return ans;
}

###go

func minimumTotal(triangle [][]int) int {
    n := len(triangle)
    memo := make([][]int, n)
    for i := range memo {
        memo[i] = make([]int, n)
        for j := range memo[i] {
            memo[i][j] = math.MinInt // math.MinInt 表示没有计算过
        }
    }
    var dfs func(int, int) int
    dfs = func(i, j int) int {
        if i == n-1 {
            return triangle[i][j]
        }
        p := &memo[i][j]
        if *p != math.MinInt { // 之前计算过
            return *p
        }
        *p = min(dfs(i+1, j), dfs(i+1, j+1)) + triangle[i][j]
        return *p
    }
    return dfs(0, 0)
}

###js

var minimumTotal = function(triangle) {
    const n = triangle.length;
    const memo = Array.from({ length: n }, () => Array(n));
    function dfs(i, j) {
        if (i === n - 1) {
            return triangle[i][j];
        }
        if (memo[i][j] !== undefined) { // 之前计算过
            return memo[i][j];
        }
        return memo[i][j] = Math.min(dfs(i + 1, j), dfs(i + 1, j + 1)) + triangle[i][j];
    }
    return dfs(0, 0);
};

###rust

impl Solution {
    pub fn minimum_total(triangle: Vec<Vec<i32>>) -> i32 {
        fn dfs(i: usize, j: usize, triangle: &[Vec<i32>], memo: &mut [Vec<i32>]) -> i32 {
            if i == triangle.len() - 1 {
                return triangle[i][j];
            }
            if memo[i][j] != i32::MIN { // 之前计算过
                return memo[i][j];
            }
            memo[i][j] = dfs(i + 1, j, triangle, memo).min(dfs(i + 1, j + 1, triangle, memo)) + triangle[i][j];
            memo[i][j]
        }
        let n = triangle.len();
        let mut memo = vec![vec![i32::MIN; n]; n]; // i32::MIN 表示没有计算过
        dfs(0, 0, &triangle, &mut memo)
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2)$,其中 $n$ 为 $\textit{triangle}$ 的长度。由于每个状态只会计算一次,动态规划的时间复杂度 $=$ 状态个数 $\times$ 单个状态的计算时间。本题状态个数等于 $\mathcal{O}(n^2)$,单个状态的计算时间为 $\mathcal{O}(1)$,所以总的时间复杂度为 $\mathcal{O}(n^2)$。
  • 空间复杂度:$\mathcal{O}(n^2)$。保存多少状态,就需要多少空间。

四、1:1 翻译成递推

我们可以去掉递归中的「递」,只保留「归」的部分,即自底向上计算。

具体来说,$f[i][j]$ 的定义和 $\textit{dfs}(i,j)$ 的定义是一样的,都表示从 $(i,j)$ 出发,移动到 $\textit{triangle}$ 最后一排,路径上的元素之和的最小值。

相应的递推式(状态转移方程)也和 $\textit{dfs}$ 一样:

$$
f[i][j] = \min(f[i+1][j],f[i+1][j+1]) + \textit{triangle}[i][j]
$$

初始值 $f[n-1][j]=\textit{triangle}[n-1][j]$,翻译自递归边界 $\textit{dfs}(n-1,j)=\textit{triangle}[n-1][j]$。

答案为 $f[0][0]$,翻译自递归入口 $\textit{dfs}(0,0)$。

答疑

:如何思考循环顺序?什么时候要正序枚举,什么时候要倒序枚举?

:这里有一个通用的做法:盯着状态转移方程,想一想,要计算 $f[i][j]$,必须先把 $f[i+1][j]$ 和 $f[i+1][j+1]$ 算出来,那么只有 $i$ 从大到小枚举才能做到。对于 $j$ 来说,正序倒序都可以。

###py

class Solution:
    def minimumTotal(self, triangle: List[List[int]]) -> int:
        n = len(triangle)
        f = [[0] * (i + 1) for i in range(n)]
        f[-1] = triangle[-1]
        for i in range(n - 2, -1, -1):
            for j, x in enumerate(triangle[i]):
                f[i][j] = min(f[i + 1][j], f[i + 1][j + 1]) + x
        return f[0][0]

###java

class Solution {
    public int minimumTotal(List<List<Integer>> triangle) {
        int n = triangle.size();
        int[][] f = new int[n][n];
        for (int j = 0; j < n; j++) {
            f[n - 1][j] = triangle.get(n - 1).get(j);
        }
        for (int i = n - 2; i >= 0; i--) {
            for (int j = 0; j <= i; j++) {
                f[i][j] = Math.min(f[i + 1][j], f[i + 1][j + 1]) + triangle.get(i).get(j);
            }
        }
        return f[0][0];
    }
}

###cpp

class Solution {
public:
    int minimumTotal(vector<vector<int>>& triangle) {
        int n = triangle.size();
        vector f(n, vector<int>(n));
        f[n - 1] = triangle[n - 1];
        for (int i = n - 2; i >= 0; i--) {
            for (int j = 0; j <= i; j++) {
                f[i][j] = min(f[i + 1][j], f[i + 1][j + 1]) + triangle[i][j];
            }
        }
        return f[0][0];
    }
};

###c

#define MIN(a, b) ((b) < (a) ? (b) : (a))

int minimumTotal(int** triangle, int triangleSize, int* triangleColSize) {
    int** f = malloc(triangleSize * sizeof(int*));
    for (int i = 0; i < triangleSize; i++) {
        f[i] = malloc((i + 1) * sizeof(int));
    }
    f[triangleSize - 1] = triangle[triangleSize - 1];
    for (int i = triangleSize - 2; i >= 0; i--) {
        for (int j = 0; j <= i; j++) {
            f[i][j] = MIN(f[i + 1][j], f[i + 1][j + 1]) + triangle[i][j];
        }
    }
    int ans = f[0][0];
    for (int i = 0; i < triangleSize; i++) {
        free(f[i]);
    }
    free(f);
    return ans;
}

###go

func minimumTotal(triangle [][]int) int {
    n := len(triangle)
    f := make([][]int, n)
    for i := range f {
        f[i] = make([]int, i+1)
    }
    f[n-1] = triangle[n-1]
    for i := n - 2; i >= 0; i-- {
        for j, x := range triangle[i] {
            f[i][j] = min(f[i+1][j], f[i+1][j+1]) + x
        }
    }
    return f[0][0]
}

###js

var minimumTotal = function(triangle) {
    const n = triangle.length;
    const f = Array.from({ length: n }, () => Array(n));
    f[n - 1] = triangle[n - 1];
    for (let i = n - 2; i >= 0; i--) {
        for (let j = 0; j <= i; j++) {
            f[i][j] = Math.min(f[i + 1][j], f[i + 1][j + 1]) + triangle[i][j];
        }
    }
    return f[0][0];
};

###rust

impl Solution {
    pub fn minimum_total(triangle: Vec<Vec<i32>>) -> i32 {
        let n = triangle.len();
        let mut f = vec![vec![0; n]; n];
        f[n - 1] = triangle[n - 1].clone();
        for i in (0..n - 1).rev() {
            for (j, &x) in triangle[i].iter().enumerate() {
                f[i][j] = f[i + 1][j].min(f[i + 1][j + 1]) + x;
            }
        }
        f[0][0]
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2)$,其中 $n$ 为 $\textit{triangle}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n^2)$。

五、空间优化

也可以直接把 $\textit{triangle}$ 当作 $f$ 数组。

###py

class Solution:
    def minimumTotal(self, f: List[List[int]]) -> int:
        for i in range(len(f) - 2, -1, -1):
            for j in range(i + 1):
                f[i][j] += min(f[i + 1][j], f[i + 1][j + 1])
        return f[0][0]

###java

class Solution {
    public int minimumTotal(List<List<Integer>> f) {
        for (int i = f.size() - 2; i >= 0; i--) {
            for (int j = 0; j <= i; j++) {
                f.get(i).set(j, f.get(i).get(j) + Math.min(f.get(i + 1).get(j), f.get(i + 1).get(j + 1)));
            }
        }
        return f.get(0).get(0);
    }
}

###cpp

class Solution {
public:
    int minimumTotal(vector<vector<int>>& f) {
        for (int i = f.size() - 2; i >= 0; i--) {
            for (int j = 0; j <= i; j++) {
                f[i][j] += min(f[i + 1][j], f[i + 1][j + 1]);
            }
        }
        return f[0][0];
    }
};

###c

#define MIN(a, b) ((b) < (a) ? (b) : (a))

int minimumTotal(int** f, int triangleSize, int* triangleColSize) {
    for (int i = triangleSize - 2; i >= 0; i--) {
        for (int j = 0; j <= i; j++) {
            f[i][j] += MIN(f[i + 1][j], f[i + 1][j + 1]);
        }
    }
    return f[0][0];
}

###go

func minimumTotal(f [][]int) int {
    for i := len(f) - 2; i >= 0; i-- {
        for j := range f[i] {
            f[i][j] += min(f[i+1][j], f[i+1][j+1])
        }
    }
    return f[0][0]
}

###js

var minimumTotal = function(f) {
    for (let i = f.length - 2; i >= 0; i--) {
        for (let j = 0; j <= i; j++) {
            f[i][j] += Math.min(f[i + 1][j], f[i + 1][j + 1]);
        }
    }
    return f[0][0];
};

###rust

impl Solution {
    pub fn minimum_total(mut f: Vec<Vec<i32>>) -> i32 {
        for i in (0..f.len() - 1).rev() {
            for j in 0..=i {
                f[i][j] += f[i + 1][j].min(f[i + 1][j + 1]);
            }
        }
        f[0][0]
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2)$,其中 $n$ 为 $\textit{triangle}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

专题训练

见下面动态规划题单的「二、网格图 DP」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

模拟长除法(Python/Java/C++/Go/JS/Rust)

作者 endlesscheng
2025年9月24日 08:19

我们可以用长除法计算小数。来看两个例子。

例一:9/8 = 1.125

读者可以先在纸上算算 $9/8$,方便理解下述流程。

整数部分为 $\left\lfloor\dfrac{9}{8}\right\rfloor = 1$,初始余数为 $r = 9\bmod 8 = 1$。

  1. $r=1$。计算商 $\left\lfloor\dfrac{r\cdot 10}{8}\right\rfloor = 1$,得到小数点后第一位,更新 $r$ 为 $(r\cdot 10)\bmod 8 = 2$。
  2. $r=2$。计算商 $\left\lfloor\dfrac{r\cdot 10}{8}\right\rfloor = 2$,得到小数点后第二位,更新 $r$ 为 $(r\cdot 10)\bmod 8 = 4$。
  3. $r=4$。计算商 $\left\lfloor\dfrac{r\cdot 10}{8}\right\rfloor = 5$,得到小数点后第三位,更新 $r$ 为 $(r\cdot 10)\bmod 8 = 0$。
  4. $r=0$,说明 $9/8$ 是有限小数。

例二:3/14 = 0.2(142857)

整数部分为 $\left\lfloor\dfrac{3}{14}\right\rfloor = 0$,初始余数为 $r = 3\bmod 14 = 3$。

  1. $r=3$。计算商 $\left\lfloor\dfrac{r\cdot 10}{14}\right\rfloor = 2$,得到小数点后第一位,更新 $r$ 为 $(r\cdot 10)\bmod 14 = 2$。
  2. $r=2$。计算商 $\left\lfloor\dfrac{r\cdot 10}{14}\right\rfloor = 1$,得到小数点后第二位,更新 $r$ 为 $(r\cdot 10)\bmod 14 = 6$。
  3. $r=6$。计算商 $\left\lfloor\dfrac{r\cdot 10}{14}\right\rfloor = 4$,得到小数点后第三位,更新 $r$ 为 $(r\cdot 10)\bmod 14 = 4$。
  4. $r=4$。计算商 $\left\lfloor\dfrac{r\cdot 10}{14}\right\rfloor = 2$,得到小数点后第四位,更新 $r$ 为 $(r\cdot 10)\bmod 14 = 12$。
  5. $r=12$。计算商 $\left\lfloor\dfrac{r\cdot 10}{14}\right\rfloor = 8$,得到小数点后第五位,更新 $r$ 为 $(r\cdot 10)\bmod 14 = 8$。
  6. $r=8$。计算商 $\left\lfloor\dfrac{r\cdot 10}{14}\right\rfloor = 5$,得到小数点后第六位,更新 $r$ 为 $(r\cdot 10)\bmod 14 = 10$。
  7. $r=10$。计算商 $\left\lfloor\dfrac{r\cdot 10}{14}\right\rfloor = 7$,得到小数点后第七位,更新 $r$ 为 $(r\cdot 10)\bmod 14 = 2$。
  8. $r=2$,等于第 2 步开始时的余数。如果继续计算,我们会重复上面的第 2~7 步。这意味着我们找到了循环节。

根据 $r=2$ 首次出现的位置,可以知道循环节之前的小数为 $2$,循环节为 $142857$。

怎么知道进入循环了?

用一个哈希表记录,哈希表的 key 是余数 $r$,value 是这个余数对应着第几位小数。

计算商(添加到答案),更新 $r$ 后:

  • 如果 $r$ 在哈希表中,说明有循环节,根据哈希表中记录的小数位置,可以得到循环节之前的小数,以及循环节的内容。
  • 如果 $r$ 不在哈希表中,往哈希表中插入 $r$ 以及此时我们在算第几位小数。
  • 特别地,如果 $r=0$,说明没有循环节,退出循环。
class Solution:
    def fractionToDecimal(self, numerator: int, denominator: int) -> str:
        sign = '-' if numerator * denominator < 0 else ''
        numerator = abs(numerator)  # 保证下面的计算过程不产生负数
        denominator = abs(denominator)

        # 计算整数部分 q 和初始余数 r
        q, r = divmod(numerator, denominator)
        if r == 0:  # 没有小数部分
            return sign + str(q)

        ans = [sign + str(q) + '.']
        r_to_pos = {r: 1}  # 初始余数对应小数点后第一位
        while r:
            # 计算小数点后的数字 q,更新 r
            q, r = divmod(r * 10, denominator)
            ans.append(str(q))
            if r in r_to_pos:  # 有循环节
                pos = r_to_pos[r]  # 循环节的开始位置
                return f"{''.join(ans[:pos])}({''.join(ans[pos:])})"
            r_to_pos[r] = len(ans)  # 记录余数对应位置
        return ''.join(ans)  # 有限小数
class Solution {
    public String fractionToDecimal(int numerator, int denominator) {
        long a = numerator;
        long b = denominator;
        String sign = a * b < 0 ? "-" : "";
        a = Math.abs(a); // 保证下面的计算过程不产生负数
        b = Math.abs(b);

        // 计算整数部分 q 和初始余数 r
        long q = a / b;
        long r = a % b;
        if (r == 0) { // 没有小数部分
            return sign + q;
        }

        StringBuilder ans = new StringBuilder(sign).append(q).append('.');
        Map<Long, Integer> rToPos = new HashMap<>();
        rToPos.put(r, ans.length()); // 记录初始余数对应位置
        while (r > 0) {
            // 计算小数点后的数字 q,更新 r
            r *= 10;
            q = r / b;
            r %= b;
            ans.append(q);
            if (rToPos.containsKey(r)) { // 有循环节
                int pos = rToPos.get(r); // 循环节的开始位置
                return ans.substring(0, pos) + "(" + ans.substring(pos) + ")";
            }
            rToPos.put(r, ans.length()); // 记录余数对应位置
        }
        return ans.toString(); // 有限小数
    }
}
class Solution {
public:
    string fractionToDecimal(int numerator, int denominator) {
        long long a = numerator, b = denominator;
        string sign = a * b < 0 ? "-" : "";
        a = abs(a); // 保证下面的计算过程不产生负数
        b = abs(b);

        // 计算整数部分 q 和初始余数 r
        long long q = a / b, r = a % b;
        if (r == 0) { // 没有小数部分
            return sign + to_string(q);
        }

        string ans = sign + to_string(q) + ".";
        unordered_map<long long, int> r_to_pos = {{r, ans.size()}}; // 记录初始余数对应位置
        while (r) {
            // 计算小数点后的数字 q,更新 r
            r *= 10;
            q = r / b;
            r %= b;
            ans += '0' + q;
            if (r_to_pos.contains(r)) { // 有循环节
                int pos = r_to_pos[r]; // 循环节的开始位置
                return ans.substr(0, pos) + "(" + ans.substr(pos) + ")";
            }
            r_to_pos[r] = ans.size(); // 记录余数对应位置
        }
        return ans; // 有限小数
    }
};
func fractionToDecimal(numerator, denominator int) string {
    sign := ""
    if numerator*denominator < 0 {
        sign = "-"
    }
    numerator = abs(numerator) // 保证下面的计算过程不产生负数
    denominator = abs(denominator)

    // 计算整数部分 q 和初始余数 r
    q, r := numerator/denominator, numerator%denominator
    if r == 0 { // 没有小数部分
        return sign + strconv.Itoa(q)
    }

    ans := []byte(sign + strconv.Itoa(q) + ".")
    rToPos := map[int]int{r: len(ans)} // 记录初始余数对应位置
    for r != 0 {
        // 计算小数点后的数字 q,更新 r
        r *= 10
        q = r / denominator
        r %= denominator
        ans = append(ans, '0'+byte(q))
        if pos, ok := rToPos[r]; ok { // 有循环节,pos 为循环节的开始位置
            return string(ans[:pos]) + "(" + string(ans[pos:]) + ")"
        }
        rToPos[r] = len(ans) // 记录余数对应位置
    }
    return string(ans) // 有限小数
}

func abs(x int) int { if x < 0 { return -x }; return x }
var fractionToDecimal = function(numerator, denominator) {
    const sign = numerator * denominator < 0 ? "-" : "";
    numerator = Math.abs(numerator); // 保证下面的计算过程不产生负数
    denominator = Math.abs(denominator);

    // 计算整数部分 q 和初始余数 r
    let q = Math.floor(numerator / denominator);
    let r = numerator % denominator;
    if (r === 0) { // 没有小数部分
        return sign + String(q);
    }

    const ans = [sign + String(q) + "."];
    const r_to_pos = new Map();
    r_to_pos.set(r, 1); // 初始余数对应小数点后第一位
    while (r) {
        // 计算小数点后的数字 q,更新 r
        r *= 10;
        q = Math.floor(r / denominator);
        r = r % denominator;
        ans.push(String(q));
        if (r_to_pos.has(r)) { // 有循环节
            const pos = r_to_pos.get(r); // 循环节的开始位置
            return ans.slice(0, pos).join("") + "(" + ans.slice(pos).join("") + ")";
        }
        r_to_pos.set(r, ans.length); // 记录余数对应位置
    }
    return ans.join(""); // 有限小数
};
use std::collections::HashMap;

impl Solution {
    pub fn fraction_to_decimal(numerator: i32, denominator: i32) -> String {
        let mut a = numerator as i64;
        let mut b = denominator as i64;
        let sign = if a * b < 0 { "-" } else { "" };
        a = a.abs(); // 保证下面的计算过程不产生负数
        b = b.abs();

        // 计算整数部分 q 和初始余数 r
        let mut q = a / b;
        let mut r = a % b;
        if r == 0 { // 没有小数部分
            return format!("{}{}", sign, q);
        }

        let mut ans = format!("{}{}.", sign, q);
        let mut r_to_pos = HashMap::new();
        r_to_pos.insert(r, ans.len()); // 记录初始余数对应位置
        while r != 0 {
            // 计算小数点后的数字 q,更新 r
            r *= 10;
            q = r / b;
            r %= b;
            ans.push((b'0' + q as u8) as char);
            if let Some(&pos) = r_to_pos.get(&r) { // 有循环节,pos 为循环节的开始位置
                return format!("{}({})", &ans[..pos], &ans[pos..]);
            }
            r_to_pos.insert(r, ans.len()); // 记录余数对应位置
        }
        ans // 有限小数
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(N)$。其中 $N = \min(|\textit{denominator}|, 10^4)$。至多有 $|\textit{denominator}|$ 个不同的余数,最多循环 $\mathcal{O}(|\textit{denominator}|)$ 次。不过,本题保证答案长度小于 $10^4$。
  • 空间复杂度:$\mathcal{O}(N)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

库函数写法(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2025年9月23日 07:50

把字符串按照 $\texttt{.}$ 分割,分割出的子串转成整数,我们可以得到两个整数列表。

示例 1 的列表为 $[1,2]$ 和 $[1,10]$。

示例 2 的列表均为 $[1,1]$。

示例 3 的列表(短的列表在末尾补 $0$)均为 $[1,0,0,0]$。

问题相当于比较这两个列表的字典序:

  • 从左到右遍历两个列表(分别记作 $a$ 和 $b$)。
  • 如果 $a[i] < b[i]$,返回 $-1$。
  • 如果 $a[i] > b[i]$,返回 $1$。
  • 否则继续向后遍历。
  • 如果遍历过程中没有返回,说明两个列表相同,返回 $0$。

###py

class Solution:
    def compareVersion(self, version1: str, version2: str) -> int:
        a = map(int, version1.split('.'))
        b = map(int, version2.split('.'))
        for ver1, ver2 in zip_longest(a, b, fillvalue=0):
            if ver1 != ver2:
                return -1 if ver1 < ver2 else 1
        return 0

###java

class Solution {
    public int compareVersion(String version1, String version2) {
        String[] a = version1.split("\\.");
        String[] b = version2.split("\\.");
        int n = a.length;
        int m = b.length;
        for (int i = 0; i < n || i < m; i++) {
            int ver1 = i < n ? Integer.parseInt(a[i]) : 0;
            int ver2 = i < m ? Integer.parseInt(b[i]) : 0;
            if (ver1 != ver2) {
                return ver1 < ver2 ? -1 : 1;
            }
        }
        return 0;
    }
}

###cpp

class Solution {
    vector<string> split(const string& s, char delim) {
        vector<string> res;
        stringstream ss(s);
        string token;
        while (getline(ss, token, delim)) {
            res.push_back(token);
        }
        return res;
    }

public:
    int compareVersion(string version1, string version2) {
        auto a = split(version1, '.');
        auto b = split(version2, '.');
        int n = a.size(), m = b.size();
        for (int i = 0; i < n || i < m; i++) {
            int ver1 = i < n ? stoi(a[i]) : 0;
            int ver2 = i < m ? stoi(b[i]) : 0;
            if (ver1 != ver2) {
                return ver1 < ver2 ? -1 : 1;
            }
        }
        return 0;
    }
};

###c

int compareVersion(char* version1, char* version2) {
    char* saveptr1;
    char* saveptr2;
    char* token1 = strtok_r(version1, ".", &saveptr1);
    char* token2 = strtok_r(version2, ".", &saveptr2);

    while (token1 != NULL || token2 != NULL) {
        int ver1 = token1 ? atoi(token1) : 0;
        int ver2 = token2 ? atoi(token2) : 0;
        if (ver1 != ver2) {
            return ver1 < ver2 ? -1 : 1;
        }
        token1 = strtok_r(NULL, ".", &saveptr1);
        token2 = strtok_r(NULL, ".", &saveptr2);
    }

    return 0;
}

###go

func compareVersion(version1, version2 string) int {
    a := strings.Split(version1, ".")
    b := strings.Split(version2, ".")
    n, m := len(a), len(b)
    for i := range max(n, m) {
        ver1 := 0
        if i < n {
            ver1, _ = strconv.Atoi(a[i])
        }
        ver2 := 0
        if i < m {
            ver2, _ = strconv.Atoi(b[i])
        }
        c := cmp.Compare(ver1, ver2)
        if c != 0 {
            return c
        }
    }
    return 0
}

###js

var compareVersion = function(version1, version2) {
    const a = version1.split(".");
    const b = version2.split(".");
    const n = a.length, m = b.length;
    for (let i = 0; i < n || i < m; i++) {
        const ver1 = i < n ? parseInt(a[i]) : 0;
        const ver2 = i < m ? parseInt(b[i]) : 0;
        if (ver1 !== ver2) {
            return ver1 < ver2 ? -1 : 1;
        }
    }
    return 0;
};

###rust

impl Solution {
    pub fn compare_version(version1: String, version2: String) -> i32 {
        let a = version1.split('.').collect::<Vec<_>>();
        let b = version2.split('.').collect::<Vec<_>>();
        let n = a.len();
        let m = b.len();
        for i in 0..n.max(m) {
            let ver1 = if i < n { a[i].parse::<i32>().unwrap() } else { 0 };
            let ver2 = if i < m { b[i].parse::<i32>().unwrap() } else { 0 };
            if ver1 != ver2 {
                return if ver1 < ver2 { -1 } else { 1 };
            }
        }
        0
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n+m)$,其中 $n$ 是 $\textit{version}_1$ 的长度,$m$ 是 $\textit{version}_2$ 的长度。
  • 空间复杂度:$\mathcal{O}(n+m)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

一次遍历(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2024年1月14日 12:38

遍历 $\textit{nums}$,同时用哈希表统计每个元素的出现次数,并维护出现次数的最大值 $\textit{maxCnt}$:

  • 如果出现次数 $c > \textit{maxCnt}$,那么更新 $\textit{maxCnt}=c$,答案 $\textit{ans} = c$。
  • 如果出现次数 $c = \textit{maxCnt}$,那么答案增加 $c$。

###py

class Solution:
    def maxFrequencyElements(self, nums: List[int]) -> int:
        cnt = defaultdict(int)
        ans = max_cnt = 0
        for x in nums:
            cnt[x] += 1
            c = cnt[x]
            if c > max_cnt:
                ans = max_cnt = c
            elif c == max_cnt:
                ans += c
        return ans

###java

class Solution {
    public int maxFrequencyElements(int[] nums) {
        Map<Integer, Integer> cnt = new HashMap<>(); // 更快的写法见【Java 数组】
        int maxCnt = 0;
        int ans = 0;
        for (int x : nums) {
            int c = cnt.merge(x, 1, Integer::sum); // c = ++cnt[x]
            if (c > maxCnt) {
                ans = maxCnt = c;
            } else if (c == maxCnt) {
                ans += c;
            }
        }
        return ans;
    }
}

###java

class Solution {
    public int maxFrequencyElements(int[] nums) {
        int mx = 0;
        for (int x : nums) {
            mx = Math.max(mx, x);
        }
        
        int[] cnt = new int[mx + 1];
        int maxCnt = 0;
        int ans = 0;
        for (int x : nums) {
            int c = ++cnt[x];
            if (c > maxCnt) {
                ans = maxCnt = c;
            } else if (c == maxCnt) {
                ans += c;
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int maxFrequencyElements(vector<int>& nums) {
        unordered_map<int, int> cnt;
        int ans = 0, max_cnt = 0;
        for (int x : nums) {
            int c = ++cnt[x];
            if (c > max_cnt) {
                ans = max_cnt = c;
            } else if (c == max_cnt) {
                ans += c;
            }
        }
        return ans;
    }
};

###c

#define MAX(a, b) ((b) > (a) ? (b) : (a))

int maxFrequencyElements(int* nums, int numsSize) {
    int mx = 0; // 直接初始化 mx = 100 可以做到一次遍历
    for (int i = 0; i < numsSize; i++) {
        mx = MAX(mx, nums[i]);
    }

    int* cnt = calloc(mx + 1, sizeof(int));
    int max_cnt = 0;
    int ans = 0;

    for (int i = 0; i < numsSize; i++) {
        int c = ++cnt[nums[i]];
        if (c > max_cnt) {
            ans = c;
            max_cnt = c;
        } else if (c == max_cnt) {
            ans += c;
        }
    }

    free(cnt);
    return ans;
}

###go

func maxFrequencyElements(nums []int) (ans int) {
cnt := map[int]int{}
maxCnt := 0
for _, x := range nums {
cnt[x]++
c := cnt[x]
if c > maxCnt {
maxCnt = c
ans = c
} else if c == maxCnt {
ans += c
}
}
return
}

###js

var maxFrequencyElements = function(nums) {
    const cnt = new Map();
    let ans = 0, maxCnt = 0;
    for (const x of nums) {
        const c = (cnt.get(x) ?? 0) + 1;
        cnt.set(x, c);
        if (c > maxCnt) {
            ans = maxCnt = c;
        } else if (c === maxCnt) {
            ans += c;
        }
    }
    return ans;
};

###rust

use std::collections::HashMap;

impl Solution {
    pub fn max_frequency_elements(nums: Vec<i32>) -> i32 {
        let mut cnt = HashMap::new();
        let mut max_cnt = 0;
        let mut ans = 0;
        for x in nums {
            let e = cnt.entry(x).or_insert(0);
            *e += 1;
            let c = *e;
            if c > max_cnt {
                max_cnt = c;
                ans = c;
            } else if c == max_cnt {
                ans += c;
            }
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

❌
❌