普通视图

发现新文章,点击刷新页面。
今天 — 2025年4月7日LeetCode 每日一题题解

每日一题-分割等和子集🟡

2025年4月7日 00:00

给你一个 只包含正整数 非空 数组 nums 。请你判断是否可以将这个数组分割成两个子集,使得两个子集的元素和相等。

 

示例 1:

输入:nums = [1,5,11,5]
输出:true
解释:数组可以分割成 [1, 5, 5] 和 [11] 。

示例 2:

输入:nums = [1,2,3,5]
输出:false
解释:数组不能分割成两个元素和相等的子集。

 

提示:

  • 1 <= nums.length <= 200
  • 1 <= nums[i] <= 100

三种写法:记忆化搜索 / 递推 / bitset 优化(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2024年5月21日 15:21

一、分析

设 $\textit{nums}$ 的元素和为 $s$。

两个子集的元素和相等,意味着:

  1. 把 $\textit{nums}$ 分成两个子集,每个子集的元素和恰好等于 $\dfrac{s}{2}$。
  2. $s$ 必须是偶数。

如果 $s$ 是奇数,$\dfrac{s}{2}$ 不是整数,直接返回 $\texttt{false}$。

如果 $s$ 是偶数,问题相当于:

  • 能否从 $\textit{nums}$ 中选出一个子序列,其元素和恰好等于 $\dfrac{s}{2}$?

这可以用「恰好装满」型 0-1 背包解决,请看视频讲解:0-1 背包和完全背包【基础算法精讲 18】。制作不易,欢迎点赞关注~

二、记忆化搜索

定义 $\textit{dfs}(i,j)$ 表示能否从 $\textit{nums}[0]$ 到 $\textit{nums}[i]$ 中选出一个和恰好等于 $j$ 的子序列。

考虑 $\textit{nums}[i]$ 选或不选:

  • 选:问题变成能否从 $\textit{nums}[0]$ 到 $\textit{nums}[i-1]$ 中选出一个和恰好等于 $j-\textit{nums}[i]$ 的子序列,即 $\textit{dfs}(i-1,j-\textit{nums}[i])$。
  • 不选:问题变成能否从 $\textit{nums}[0]$ 到 $\textit{nums}[i-1]$ 中选出一个和恰好等于 $j$ 的子序列,即 $\textit{dfs}(i-1,j)$。

这两个只要有一个成立,$\textit{dfs}(i,j)$ 就是 $\texttt{true}$。所以有

$$
\textit{dfs}(i,j) = \textit{dfs}(i-1,j-\textit{nums}[i]) \vee \textit{dfs}(i-1,j)
$$

其中 $\vee$ 即编程语言中的 ||。代码实现时,可以只在 $j\ge \textit{nums}[i]$ 时才调用 $\textit{dfs}(i-1,j-\textit{nums}[i])$,因为任何子序列的和都不会是负的。

递归边界:$\textit{dfs}(-1,0) = \texttt{true},\ \textit{dfs}(-1,>0) = \texttt{false}$。如果所有数都考虑完了,且 $j=0$,表示找到了一个和为 $s/2$ 的子序列,返回 $\texttt{true}$,否则返回 $\texttt{false}$。

递归入口:$\textit{dfs}(n-1,s/2)$,即答案。

###py

class Solution:
    def canPartition(self, nums: List[int]) -> bool:
        @cache  # 缓存装饰器,避免重复计算 dfs 的结果(记忆化)
        def dfs(i: int, j: int) -> bool:
            if i < 0:
                return j == 0
            return j >= nums[i] and dfs(i - 1, j - nums[i]) or dfs(i - 1, j)

        s = sum(nums)
        return s % 2 == 0 and dfs(len(nums) - 1, s // 2)

###java

class Solution {
    public boolean canPartition(int[] nums) {
        int s = 0;
        for (int num : nums) {
            s += num;
        }
        if (s % 2 != 0) {
            return false;
        }
        int n = nums.length;
        int[][] memo = new int[n][s / 2 + 1];
        for (int[] row : memo) {
            Arrays.fill(row, -1); // -1 表示没有计算过
        }
        return dfs(n - 1, s / 2, nums, memo);
    }

    private boolean dfs(int i, int j, int[] nums, int[][] memo) {
        if (i < 0) {
            return j == 0;
        }
        if (memo[i][j] != -1) { // 之前计算过
            return memo[i][j] == 1;
        }
        boolean res = j >= nums[i] && dfs(i - 1, j - nums[i], nums, memo) || dfs(i - 1, j, nums, memo);
        memo[i][j] = res ? 1 : 0; // 记忆化
        return res;
    }
}

###cpp

class Solution {
public:
    bool canPartition(vector<int>& nums) {
        int s = reduce(nums.begin(), nums.end());
        if (s % 2) {
            return false;
        }
        int n = nums.size();
        vector memo(n, vector<int>(s / 2 + 1, -1)); // -1 表示没有计算过
        auto dfs = [&](this auto&& dfs, int i, int j) -> bool {
            if (i < 0) {
                return j == 0;
            }
            int& res = memo[i][j]; // 注意这里是引用
            if (res != -1) { // 之前计算过
                return res;
            }
            return res = j >= nums[i] && dfs(i - 1, j - nums[i]) || dfs(i - 1, j);
        };
        return dfs(n - 1, s / 2);
    }
};

###c

bool dfs(int i, int j, int* nums, int** memo) {
    if (i < 0) {
        return j == 0;
    }
    if (memo[i][j] != -1) { // 之前计算过
        return memo[i][j] == 1;
    }
    return memo[i][j] = j >= nums[i] && dfs(i - 1, j - nums[i], nums, memo) || dfs(i - 1, j, nums, memo);
}

bool canPartition(int* nums, int numsSize) {
    int s = 0;
    for (int i = 0; i < numsSize; i++) {
        s += nums[i];
    }
    if (s % 2) {
        return false;
    }
    int** memo = malloc(numsSize * sizeof(int*));
    for (int i = 0; i < numsSize; i++) {
        memo[i] = malloc((s / 2 + 1) * sizeof(int));
        memset(memo[i], -1, (s / 2 + 1) * sizeof(int)); // -1 表示没有计算过
    }
    int ans = dfs(numsSize - 1, s / 2, nums, memo);
    for (int i = 0; i < numsSize; i++) {
        free(memo[i]);
    }
    free(memo);
    return ans;
}

###go

func canPartition(nums []int) bool {
    s := 0
    for _, x := range nums {
        s += x
    }
    if s%2 != 0 {
        return false
    }
    n := len(nums)
    memo := make([][]int8, n)
    for i := range memo {
        memo[i] = make([]int8, s/2+1)
        for j := range memo[i] {
            memo[i][j] = -1 // -1 表示没有计算过
        }
    }
    var dfs func(int, int) bool
    dfs = func(i, j int) bool {
        if i < 0 {
            return j == 0
        }
        p := &memo[i][j]
        if *p != -1 { // 之前计算过
            return *p == 1
        }
        res := j >= nums[i] && dfs(i-1, j-nums[i]) || dfs(i-1, j)
        if res {
            *p = 1 // 记忆化
        } else {
            *p = 0
        }
        return res
    }
    return dfs(n-1, s/2)
}

###js

const canPartition = function(nums) {
    const s = _.sum(nums);
    if (s % 2) {
        return false;
    }
    const n = nums.length;
    const memo = Array.from({length: n}, () => Array(s / 2 + 1).fill(-1)); // -1 表示没有计算过
    function dfs(i, j) {
        if (i < 0) {
            return j === 0;
        }
        if (memo[i][j] !== -1) { // 之前计算过
            return memo[i][j] === 1;
        }
        const res = j >= nums[i] && dfs(i - 1, j - nums[i]) || dfs(i - 1, j);
        memo[i][j] = res ? 1 : 0; // 记忆化
        return res;
    }
    return dfs(n - 1, s / 2);
};

###rust

impl Solution {
    pub fn can_partition(nums: Vec<i32>) -> bool {
        let s = nums.iter().sum::<i32>() as usize;
        if s % 2 != 0 {
            return false;
        }
        fn dfs(i: usize, j: usize, nums: &[i32], memo: &mut [Vec<i32>]) -> bool {
            if i == nums.len() {
                return j == 0;
            }
            if memo[i][j] != -1 { // 之前计算过
                return memo[i][j] == 1;
            }
            let x = nums[i] as usize;
            let res = j >= x && dfs(i + 1, j - x, nums, memo) || dfs(i + 1, j, nums, memo);
            memo[i][j] = if res { 1 } else { 0 }; // 记忆化
            res
        }
        let n = nums.len();
        let mut memo = vec![vec![-1; s / 2 + 1]; n]; // -1 表示没有计算过
        // 为方便起见,改成 i 从 0 开始
        dfs(0, s / 2, &nums, &mut memo)
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(ns)$,其中 $n$ 是 $\textit{nums}$ 的长度,$s$ 是 $\textit{nums}$ 的元素和(的一半)。由于每个状态只会计算一次,动态规划的时间复杂度 $=$ 状态个数 $\times$ 单个状态的计算时间。本题状态个数等于 $\mathcal{O}(ns)$,单个状态的计算时间为 $\mathcal{O}(1)$,所以动态规划的时间复杂度为 $\mathcal{O}(ns)$。
  • 空间复杂度:$\mathcal{O}(ns)$。保存多少状态,就需要多少空间。

三、1:1 翻译成递推

我们可以去掉递归中的「递」,只保留「归」的部分,即自底向上计算。

具体来说,$f[i][j]$ 的定义和 $\textit{dfs}(i,j)$ 的定义是一样的,都表示能否从 $\textit{nums}[0]$ 到 $\textit{nums}[i]$ 中选出一个和恰好等于 $j$ 的子序列。

相应的递推式(状态转移方程)也和 $\textit{dfs}$ 一样:

$$
f[i][j] = f[i-1][j-\textit{nums}[i]] \vee f[i-1][j]
$$

但是,这种定义方式没有状态能表示递归边界,即 $i=-1$ 的情况。

解决办法:在二维数组 $f$ 的最上边插入一排状态,那么其余状态全部向下偏移一位,把 $f[i]$ 改为 $f[i+1]$,把 $f[i-1]$ 改为 $f[i]$。

修改后 $f[i+1][j]$ 表示能否从 $\textit{nums}[0]$ 到 $\textit{nums}[i]$ 中选出一个和为 $j$ 的子序列。$f[0]$ 对应递归边界。

修改后的递推式为

$$
f[i+1][j] = f[i][j-\textit{nums}[i]] \vee f[i][j]
$$

:为什么 $\textit{nums}$ 的下标不用变?

:既然是在 $f$ 的最上边插入一排状态,那么就只需要修改和 $f$ 有关的下标,其余任何逻辑都无需修改。或者说,如果把 $\textit{nums}[i]$ 也改成 $\textit{nums}[i+1]$,那么 $\textit{nums}[0]$ 就被我们给忽略掉了。

初始值 $f[0][0]=\texttt{true}$,翻译自递归边界 $\textit{dfs}(-1,0)=\texttt{true}$。其余值初始化成 $\texttt{false}$。

答案为 $f[n][s/2]$,翻译自递归入口 $\textit{dfs}(n-1,s/2)$。

###py

class Solution:
    def canPartition(self, nums: List[int]) -> bool:
        s = sum(nums)
        if s % 2:
            return False
        s //= 2  # 注意这里把 s 减半了
        n = len(nums)
        f = [[False] * (s + 1) for _ in range(n + 1)]
        f[0][0] = True
        for i, x in enumerate(nums):
            for j in range(s + 1):
                f[i + 1][j] = j >= x and f[i][j - x] or f[i][j]
        return f[n][s]

###java

class Solution {
    public boolean canPartition(int[] nums) {
        int s = 0;
        for (int num : nums) {
            s += num;
        }
        if (s % 2 != 0) {
            return false;
        }
        s /= 2; // 注意这里把 s 减半了
        int n = nums.length;
        boolean[][] f = new boolean[n + 1][s + 1];
        f[0][0] = true;
        for (int i = 0; i < n; i++) {
            int x = nums[i];
            for (int j = 0; j <= s; j++) {
                f[i + 1][j] = j >= x && f[i][j - x] || f[i][j];
            }
        }
        return f[n][s];
    }
}

###cpp

class Solution {
public:
    bool canPartition(vector<int>& nums) {
        int s = reduce(nums.begin(), nums.end());
        if (s % 2) {
            return false;
        }
        s /= 2; // 注意这里把 s 减半了
        int n = nums.size();
        vector f(n + 1, vector<int>(s + 1));
        f[0][0] = true;
        for (int i = 0; i < n; i++) {
            int x = nums[i];
            for (int j = 0; j <= s; j++) {
                f[i + 1][j] = j >= x && f[i][j - x] || f[i][j];
            }
        }
        return f[n][s];
    }
};

###c

bool canPartition(int* nums, int numsSize) {
    int s = 0;
    for (int i = 0; i < numsSize; i++) {
        s += nums[i];
    }
    if (s % 2) {
        return false;
    }
    s = s / 2 + 1;
    bool* f = calloc((numsSize + 1) * s, sizeof(bool));
    f[0] = true;
    for (int i = 0; i < numsSize; i++) {
        int x = nums[i];
        for (int j = 0; j < s; j++) {
            f[(i + 1) * s + j] = j >= x && f[i * s + j - x] || f[i * s + j];
        }
    }
    bool ans = f[(numsSize + 1) * s - 1];
    free(f);
    return ans;
}

###go

func canPartition(nums []int) bool {
    s := 0
    for _, num := range nums {
        s += num
    }
    if s%2 != 0 {
        return false
    }
    s /= 2 // 注意这里把 s 减半了
    n := len(nums)
    f := make([][]bool, n+1)
    for i := range f {
        f[i] = make([]bool, s+1)
    }
    f[0][0] = true
    for i, x := range nums {
        for j := 0; j <= s; j++ {
            f[i+1][j] = j >= x && f[i][j-x] || f[i][j]
        }
    }
    return f[n][s]
}

###js

var canPartition = function(nums) {
    let s = _.sum(nums);
    if (s % 2) {
        return false;
    }
    s /= 2; // 注意这里把 s 减半了
    const n = nums.length;
    const f = Array.from({length: n + 1}, () => Array(s + 1).fill(false));
    f[0][0] = true;
    for (let i = 0; i < n; i++) {
        const x = nums[i];
        for (let j = 0; j <= s; j++) {
            f[i + 1][j] = j >= x && f[i][j - x] || f[i][j];
        }
    }
    return f[n][s];
};

###rust

impl Solution {
    pub fn can_partition(nums: Vec<i32>) -> bool {
        let s = nums.iter().sum::<i32>();
        if s % 2 != 0 {
            return false;
        }
        let s = s as usize / 2; // 注意这里把 s 减半了
        let n = nums.len();
        let mut f = vec![vec![false; s + 1]; n + 1];
        f[0][0] = true;
        for (i, &x) in nums.iter().enumerate() {
            let x = x as usize;
            for j in 0..=s {
                f[i + 1][j] = j >= x && f[i][j - x] || f[i][j];
            }
        }
        f[n][s]
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(ns)$,其中 $n$ 是 $\textit{nums}$ 的长度,$s$ 是 $\textit{nums}$ 的元素和(的一半)。
  • 空间复杂度:$\mathcal{O}(ns)$。

四、空间优化

观察上面的状态转移方程,在计算 $f[i+1]$ 时,只会用到 $f[i]$,不会用到比 $i$ 更早的状态。

因此可以去掉第一个维度,反复利用同一个一维数组。

状态转移方程改为

$$
f[j] = f[j] \vee f[j-\textit{nums}[i]]
$$

初始值 $f[0]= \texttt{true}$。

答案为 $f[s/2]$。

具体例子,以及为什么要倒序遍历 $j$,请看 0-1 背包视频讲解

此外,设前 $i$ 个数的和为 $s'$,由于子序列的元素和不可能比 $s'$ 还大,$j$ 可以从 $\min(s',s/2)$ 开始倒着枚举。比如 $\textit{nums}$ 前两个数的和等于 $5$,那么我们无法在前两个数中,选出一个元素和大于 $5$ 的子序列,所以对于 $j>5$ 的 $f$ 值,一定是 $\texttt{false}$,无需计算。

此外,可以在循环中提前判断 $f[s/2]$ 是否为 $\texttt{true}$,是就直接返回 $\texttt{true}$。

###py

class Solution:
    def canPartition(self, nums: List[int]) -> bool:
        s = sum(nums)
        if s % 2:
            return False
        s //= 2  # 注意这里把 s 减半了
        f = [True] + [False] * s
        s2 = 0
        for i, x in enumerate(nums):
            s2 = min(s2 + x, s)
            for j in range(s2, x - 1, -1):
                f[j] = f[j] or f[j - x]
            if f[s]:
                return True
        return False

###java

class Solution {
    public boolean canPartition(int[] nums) {
        int s = 0;
        for (int x : nums) {
            s += x;
        }
        if (s % 2 != 0) {
            return false;
        }
        s /= 2; // 注意这里把 s 减半了
        boolean[] f = new boolean[s + 1];
        f[0] = true;
        int s2 = 0;
        for (int x : nums) {
            s2 = Math.min(s2 + x, s);
            for (int j = s2; j >= x; j--) {
                f[j] = f[j] || f[j - x];
            }
            if (f[s]) {
                return true;
            }
        }
        return false;
    }
}

###cpp

class Solution {
public:
    bool canPartition(vector<int>& nums) {
        int s = reduce(nums.begin(), nums.end());
        if (s % 2) {
            return false;
        }
        s /= 2; // 注意这里把 s 减半了
        vector<int> f(s + 1);
        f[0] = true;
        int s2 = 0;
        for (int x : nums) {
            s2 = min(s2 + x, s);
            for (int j = s2; j >= x; j--) {
                f[j] |= f[j - x];
            }
            if (f[s]) {
                return true;
            }
        }
        return false;
    }
};

###c

#define MIN(a, b) ((b) < (a) ? (b) : (a))

bool canPartition(int* nums, int numsSize) {
    int s = 0;
    for (int i = 0; i < numsSize; i++) {
        s += nums[i];
    }
    if (s % 2) {
        return false;
    }
    s /= 2; // 注意这里把 s 减半了
    bool* f = calloc(s + 1, sizeof(bool));
    f[0] = true;
    int s2 = 0;
    for (int i = 0; i < numsSize; i++) {
        int x = nums[i];
        s2 = MIN(s2 + x, s);
        for (int j = s2; j >= x; j--) {
            f[j] |= f[j - x];
        }
        if (f[s]) {
            free(f);
            return true;
        }
    }
    free(f);
    return false;
}

###go

func canPartition(nums []int) bool {
    s := 0
    for _, x := range nums {
        s += x
    }
    if s%2 != 0 {
        return false
    }
    s /= 2 // 注意这里把 s 减半了
    f := make([]bool, s+1)
    f[0] = true
    s2 := 0
    for _, x := range nums {
        s2 = min(s2+x, s)
        for j := s2; j >= x; j-- {
            f[j] = f[j] || f[j-x]
        }
        if f[s] {
            return true
        }
    }
    return false
}

###js

var canPartition = function(nums) {
    let s = _.sum(nums);
    if (s % 2) {
        return false;
    }
    s /= 2; // 注意这里把 s 减半了
    const f = Array(s + 1).fill(false);
    f[0] = true;
    let s2 = 0;
    for (const x of nums) {
        s2 = Math.min(s2 + x, s);
        for (let j = s2; j >= x; j--) {
            f[j] = f[j] || f[j - x];
        }
        if (f[s]) {
            return true;
        }
    }
    return false;
};

###rust

impl Solution {
    pub fn can_partition(nums: Vec<i32>) -> bool {
        let s = nums.iter().sum::<i32>();
        if s % 2 != 0 {
            return false;
        }
        let s = s as usize / 2; // 注意这里把 s 减半了
        let mut f = vec![false; s + 1];
        f[0] = true;
        let mut s2 = 0;
        for x in nums {
            let x = x as usize;
            s2 = (s2 + x).min(s);
            for j in (x..=s2).rev() {
                f[j] = f[j] || f[j - x];
            }
            if f[s] {
                return true;
            }
        }
        false
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(ns)$,其中 $n$ 是 $\textit{nums}$ 的长度,$s$ 是 $\textit{nums}$ 的元素和(的一半)。
  • 空间复杂度:$\mathcal{O}(s)$。

附:bitset 做法

把布尔数组压缩成一个二进制数,二进制数从低到高第 $i$ 位是 $0$,表示布尔数组的第 $i$ 个元素是 $\texttt{false}$;从低到高第 $i$ 位是 $1$,表示布尔数组的第 $i$ 个元素是 $\texttt{true}$。

转移方程等价于,把 $f$ 中的每个比特位增加 $x=\textit{nums}[i]$,即左移 $x$ 位,然后跟原来 $f$ 计算 OR。前者对应选 $x$,后者对应不选 $x$。

判断 $f[s]$ 是否为 $\texttt{true}$,等价于判断 $f$ 的第 $s$ 位是否为 $1$,即 (f >> s & 1) == 1

###py

class Solution:
    def canPartition(self, nums: List[int]) -> bool:
        s = sum(nums)
        if s % 2:
            return False
        s //= 2
        f = 1
        for x in nums:
            f |= f << x
        return (f >> s & 1) == 1

###java

import java.math.BigInteger;

class Solution {
    public boolean canPartition(int[] nums) {
        int s = 0;
        for (int x : nums) {
            s += x;
        }
        if (s % 2 != 0) {
            return false;
        }
        s /= 2;
        BigInteger f = BigInteger.ONE;
        for (int x : nums) {
            f = f.or(f.shiftLeft(x)); // f |= f << x;
        }
        return f.testBit(s); // 判断 f 中第 s 位是否为 1
    }
}

###cpp

class Solution {
public:
    bool canPartition(vector<int>& nums) {
        int s = reduce(nums.begin(), nums.end());
        if (s % 2) {
            return false;
        }
        s /= 2;
        bitset<10001> f; // sum(nums[i]) / 2 <= 10000
        f[0] = 1;
        for (int x : nums) {
            f |= f << x;
        }
        return f[s]; // 判断 f 中第 s 位是否为 1
    }
};

###go

func canPartition(nums []int) bool {
    s := 0
    for _, x := range nums {
        s += x
    }
    if s%2 != 0 {
        return false
    }
    s /= 2
    f := big.NewInt(1)
    p := new(big.Int)
    for _, x := range nums {
        f.Or(f, p.Lsh(f, uint(x)))
    }
    return f.Bit(s) == 1
}

###js

var canPartition = function(nums) {
    let s = _.sum(nums);
    if (s % 2) {
        return false;
    }
    s /= 2;
    let f = 1n;
    for (const x of nums) {
        f |= f << BigInt(x);
    }
    return (f >> BigInt(s) & 1n) === 1n;
};

复杂度分析

  • 时间复杂度:$\mathcal{O}(ns/w)$,其中 $n$ 是 $\textit{nums}$ 的长度,$s$ 是 $\textit{nums}$ 的元素和(的一半),$w=32$ 或者 $64$。
  • 空间复杂度:$\mathcal{O}(s/w)$。

思考题

改成计算分割的方案数,要怎么做?

欢迎在评论区分享你的思路/代码。

更多相似题目,见下面动态规划题单中的「§3.1 0-1 背包」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/最短路/最小生成树/二分图/基环树/欧拉路径)
  7. 动态规划(入门/背包/状态机/划分/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

分割等和子集

2020年10月10日 23:32

📺 视频题解

416.分割等和子集.mp4

📖 文字题解

前言

作者在这里希望读者认真阅读前言部分。

本题是经典的「NP 完全问题」,也就是说,如果你发现了该问题的一个多项式算法,那么恭喜你证明出了 P=NP,可以期待一下图灵奖了。

正因如此,我们不应期望该问题有多项式时间复杂度的解法。我们能想到的,例如基于贪心算法的「将数组降序排序后,依次将每个元素添加至当前元素和较小的子集中」之类的方法都是错误的,可以轻松地举出反例。因此,我们必须尝试非多项式时间复杂度的算法,例如时间复杂度与元素大小相关的动态规划

方法一:动态规划

思路与算法

这道题可以换一种表述:给定一个只包含正整数的非空数组 $\textit{nums}[0]$,判断是否可以从数组中选出一些数字,使得这些数字的和等于整个数组的元素和的一半。因此这个问题可以转换成「$0-1$ 背包问题」。这道题与传统的「$0-1$ 背包问题」的区别在于,传统的「$0-1$ 背包问题」要求选取的物品的重量之和不能超过背包的总容量,这道题则要求选取的数字的和恰好等于整个数组的元素和的一半。类似于传统的「$0-1$ 背包问题」,可以使用动态规划求解。

在使用动态规划求解之前,首先需要进行以下判断。

  • 根据数组的长度 $n$ 判断数组是否可以被划分。如果 $n<2$,则不可能将数组分割成元素和相等的两个子集,因此直接返回 $\text{false}$。

  • 计算整个数组的元素和 $\textit{sum}$ 以及最大元素 $\textit{maxNum}$。如果 $\textit{sum}$ 是奇数,则不可能将数组分割成元素和相等的两个子集,因此直接返回 $\text{false}$。如果 $\textit{sum}$ 是偶数,则令 $\textit{target}=\frac{\textit{sum}}{2}$,需要判断是否可以从数组中选出一些数字,使得这些数字的和等于 $\textit{target}$。如果 $\textit{maxNum}>\textit{target}$,则除了 $\textit{maxNum}$ 以外的所有元素之和一定小于 $\textit{target}$,因此不可能将数组分割成元素和相等的两个子集,直接返回 $\text{false}$。

创建二维数组 $\textit{dp}$,包含 $n$ 行 $\textit{target}+1$ 列,其中 $\textit{dp}[i][j]$ 表示从数组的 $[0,i]$ 下标范围内选取若干个正整数(可以是 $0$ 个),是否存在一种选取方案使得被选取的正整数的和等于 $j$。初始时,$\textit{dp}$ 中的全部元素都是 $\text{false}$。

在定义状态之后,需要考虑边界情况。以下两种情况都属于边界情况。

  • 如果不选取任何正整数,则被选取的正整数之和等于 $0$。因此对于所有 $0 \le i < n$,都有 $\textit{dp}[i][0]=\text{true}$。

  • 当 $i==0$ 时,只有一个正整数 $\textit{nums}[0]$ 可以被选取,因此 $\textit{dp}[0][\textit{nums}[0]]=\text{true}$。

对于 $i>0$ 且 $j>0$ 的情况,如何确定 $\textit{dp}[i][j]$ 的值?需要分别考虑以下两种情况。

  • 如果 $j \ge \textit{nums}[i]$,则对于当前的数字 $\textit{nums}[i]$,可以选取也可以不选取,两种情况只要有一个为 $\text{true}$,就有 $\textit{dp}[i][j]=\text{true}$。

    • 如果不选取 $\textit{nums}[i]$,则 $\textit{dp}[i][j]=\textit{dp}[i-1][j]$;
    • 如果选取 $\textit{nums}[i]$,则 $\textit{dp}[i][j]=\textit{dp}[i-1][j-\textit{nums}[i]]$。
  • 如果 $j < \textit{nums}[i]$,则在选取的数字的和等于 $j$ 的情况下无法选取当前的数字 $\textit{nums}[i]$,因此有 $\textit{dp}[i][j]=\textit{dp}[i-1][j]$。

状态转移方程如下:

$$
\textit{dp}[i][j]=\begin{cases}
\textit{dp}[i-1][j]~|~\textit{dp}[i-1][j-\textit{nums}[i]], & j \ge \textit{nums}[i] \
\textit{dp}[i-1][j], & j < \textit{nums}[i]
\end{cases}
$$

最终得到 $\textit{dp}[n-1][\textit{target}]$ 即为答案。

<ppt1,ppt2,ppt3,ppt4,ppt5,ppt6,ppt7,ppt8,ppt9,ppt10,ppt11,ppt12>

###Java

class Solution {
    public boolean canPartition(int[] nums) {
        int n = nums.length;
        if (n < 2) {
            return false;
        }
        int sum = 0, maxNum = 0;
        for (int num : nums) {
            sum += num;
            maxNum = Math.max(maxNum, num);
        }
        if (sum % 2 != 0) {
            return false;
        }
        int target = sum / 2;
        if (maxNum > target) {
            return false;
        }
        boolean[][] dp = new boolean[n][target + 1];
        for (int i = 0; i < n; i++) {
            dp[i][0] = true;
        }
        dp[0][nums[0]] = true;
        for (int i = 1; i < n; i++) {
            int num = nums[i];
            for (int j = 1; j <= target; j++) {
                if (j >= num) {
                    dp[i][j] = dp[i - 1][j] | dp[i - 1][j - num];
                } else {
                    dp[i][j] = dp[i - 1][j];
                }
            }
        }
        return dp[n - 1][target];
    }
}

###C++

class Solution {
public:
    bool canPartition(vector<int>& nums) {
        int n = nums.size();
        if (n < 2) {
            return false;
        }
        int sum = accumulate(nums.begin(), nums.end(), 0);
        int maxNum = *max_element(nums.begin(), nums.end());
        if (sum & 1) {
            return false;
        }
        int target = sum / 2;
        if (maxNum > target) {
            return false;
        }
        vector<vector<int>> dp(n, vector<int>(target + 1, 0));
        for (int i = 0; i < n; i++) {
            dp[i][0] = true;
        }
        dp[0][nums[0]] = true;
        for (int i = 1; i < n; i++) {
            int num = nums[i];
            for (int j = 1; j <= target; j++) {
                if (j >= num) {
                    dp[i][j] = dp[i - 1][j] | dp[i - 1][j - num];
                } else {
                    dp[i][j] = dp[i - 1][j];
                }
            }
        }
        return dp[n - 1][target];
    }
};

###JavaScript

var canPartition = function(nums) {
    const n = nums.length;
    if (n < 2) {
        return false;
    }
    let sum = 0, maxNum = 0;
    for (const num of nums) {
        sum += num;
        maxNum = maxNum > num ? maxNum : num;
    }
    if (sum & 1) {
        return false;
    }
    const target = Math.floor(sum / 2);
    if (maxNum > target) {
        return false;
    }
    const dp = new Array(n).fill(0).map(() => new Array(target + 1, false));
    for (let i = 0; i < n; i++) {
        dp[i][0] = true;
    }
    dp[0][nums[0]] = true;
    for (let i = 1; i < n; i++) {
        const num = nums[i];
        for (let j = 1; j <= target; j++) {
            if (j >= num) {
                dp[i][j] = dp[i - 1][j] | dp[i - 1][j - num];
            } else {
                dp[i][j] = dp[i - 1][j];
            }
        }
    }
    return dp[n - 1][target];
};

###Golang

func canPartition(nums []int) bool {
    n := len(nums)
    if n < 2 {
        return false
    }

    sum, max := 0, 0
    for _, v := range nums {
        sum += v
        if v > max {
            max = v
        }
    }
    if sum%2 != 0 {
        return false
    }

    target := sum / 2
    if max > target {
        return false
    }

    dp := make([][]bool, n)
    for i := range dp {
        dp[i] = make([]bool, target+1)
    }
    for i := 0; i < n; i++ {
        dp[i][0] = true
    }
    dp[0][nums[0]] = true
    for i := 1; i < n; i++ {
        v := nums[i]
        for j := 1; j <= target; j++ {
            if j >= v {
                dp[i][j] = dp[i-1][j] || dp[i-1][j-v]
            } else {
                dp[i][j] = dp[i-1][j]
            }
        }
    }
    return dp[n-1][target]
}

###C

bool canPartition(int* nums, int numsSize) {
    if (numsSize < 2) {
        return false;
    }
    int sum = 0, maxNum = 0;
    for (int i = 0; i < numsSize; ++i) {
        sum += nums[i];
        maxNum = fmax(maxNum, nums[i]);
    }
    if (sum & 1) {
        return false;
    }
    int target = sum / 2;
    if (maxNum > target) {
        return false;
    }
    int dp[numsSize][target + 1];
    memset(dp, 0, sizeof(dp));
    for (int i = 0; i < numsSize; i++) {
        dp[i][0] = true;
    }
    dp[0][nums[0]] = true;
    for (int i = 1; i < numsSize; i++) {
        int num = nums[i];
        for (int j = 1; j <= target; j++) {
            if (j >= num) {
                dp[i][j] = dp[i - 1][j] | dp[i - 1][j - num];
            } else {
                dp[i][j] = dp[i - 1][j];
            }
        }
    }
    return dp[numsSize - 1][target];
}

###Python

class Solution:
    def canPartition(self, nums: List[int]) -> bool:
        n = len(nums)
        if n < 2:
            return False
        
        total = sum(nums)
        maxNum = max(nums)
        if total & 1:
            return False
        
        target = total // 2
        if maxNum > target:
            return False
        
        dp = [[False] * (target + 1) for _ in range(n)]
        for i in range(n):
            dp[i][0] = True
        
        dp[0][nums[0]] = True
        for i in range(1, n):
            num = nums[i]
            for j in range(1, target + 1):
                if j >= num:
                    dp[i][j] = dp[i - 1][j] | dp[i - 1][j - num]
                else:
                    dp[i][j] = dp[i - 1][j]
        
        return dp[n - 1][target]

上述代码的空间复杂度是 $O(n \times \textit{target})$。但是可以发现在计算 $\textit{dp}$ 的过程中,每一行的 $dp$ 值都只与上一行的 $dp$ 值有关,因此只需要一个一维数组即可将空间复杂度降到 $O(\textit{target})$。此时的转移方程为:
$$
\textit{dp}[j]=\textit{dp}[j]\ |\ dp[j-\textit{nums}[i]]
$$
且需要注意的是第二层的循环我们需要从大到小计算,因为如果我们从小到大更新 $\textit{dp}$ 值,那么在计算 $\textit{dp}[j]$ 值的时候,$\textit{dp}[j-\textit{nums}[i]]$ 已经是被更新过的状态,不再是上一行的 $\textit{dp}$ 值。

代码

###Java

class Solution {
    public boolean canPartition(int[] nums) {
        int n = nums.length;
        if (n < 2) {
            return false;
        }
        int sum = 0, maxNum = 0;
        for (int num : nums) {
            sum += num;
            maxNum = Math.max(maxNum, num);
        }
        if (sum % 2 != 0) {
            return false;
        }
        int target = sum / 2;
        if (maxNum > target) {
            return false;
        }
        boolean[] dp = new boolean[target + 1];
        dp[0] = true;
        for (int i = 0; i < n; i++) {
            int num = nums[i];
            for (int j = target; j >= num; --j) {
                dp[j] |= dp[j - num];
            }
        }
        return dp[target];
    }
}

###C++

class Solution {
public:
    bool canPartition(vector<int>& nums) {
        int n = nums.size();
        if (n < 2) {
            return false;
        }
        int sum = 0, maxNum = 0;
        for (auto& num : nums) {
            sum += num;
            maxNum = max(maxNum, num);
        }
        if (sum & 1) {
            return false;
        }
        int target = sum / 2;
        if (maxNum > target) {
            return false;
        }
        vector<int> dp(target + 1, 0);
        dp[0] = true;
        for (int i = 0; i < n; i++) {
            int num = nums[i];
            for (int j = target; j >= num; --j) {
                dp[j] |= dp[j - num];
            }
        }
        return dp[target];
    }
};

###JavaScript

var canPartition = function(nums) {
    const n = nums.length;
    if (n < 2) {
        return false;
    }
    let sum = 0, maxNum = 0;
    for (const num of nums) {
        sum += num;
        maxNum = maxNum > num ? maxNum : num;
    }
    if (sum & 1) {
        return false;
    }
    const target = Math.floor(sum / 2);
    if (maxNum > target) {
        return false;
    }
    const dp = new Array(target + 1).fill(false);
    dp[0] = true;
    for (const num of nums) {
        for (let j = target; j >= num; --j) {
            dp[j] |= dp[j - num];
        }
    }
    return dp[target];
};

###Golang

func canPartition(nums []int) bool {
    n := len(nums)
    if n < 2 {
        return false
    }

    sum, max := 0, 0
    for _, v := range nums {
        sum += v
        if v > max {
            max = v
        }
    }
    if sum%2 != 0 {
        return false
    }

    target := sum / 2
    if max > target {
        return false
    }

    dp := make([]bool, target+1)
    dp[0] = true
    for i := 0; i < n; i++ {
        v := nums[i]
        for j := target; j >= v; j-- {
            dp[j] = dp[j] || dp[j-v]
        }
    }
    return dp[target]
}

###C

bool canPartition(int* nums, int numsSize) {
    if (numsSize < 2) {
        return false;
    }
    int sum = 0, maxNum = 0;
    for (int i = 0; i < numsSize; ++i) {
        sum += nums[i];
        maxNum = fmax(maxNum, nums[i]);
    }
    if (sum & 1) {
        return false;
    }
    int target = sum / 2;
    if (maxNum > target) {
        return false;
    }
    int dp[target + 1];
    memset(dp, 0, sizeof(dp));
    dp[0] = true;
    for (int i = 0; i < numsSize; i++) {
        int num = nums[i];
        for (int j = target; j >= num; --j) {
            dp[j] |= dp[j - num];
        }
    }
    return dp[target];
}

###Python

class Solution:
    def canPartition(self, nums: List[int]) -> bool:
        n = len(nums)
        if n < 2:
            return False
        
        total = sum(nums)
        if total % 2 != 0:
            return False
        
        target = total // 2
        dp = [True] + [False] * target
        for i, num in enumerate(nums):
            for j in range(target, num - 1, -1):
                dp[j] |= dp[j - num]
        
        return dp[target]

复杂度分析

  • 时间复杂度:$O(n \times \textit{target})$,其中 $n$ 是数组的长度,$\textit{target}$ 是整个数组的元素和的一半。需要计算出所有的状态,每个状态在进行转移时的时间复杂度为 $O(1)$。

  • 空间复杂度:$O(\textit{target})$,其中 $\textit{target}$ 是整个数组的元素和的一半。空间复杂度取决于 $\textit{dp}$ 数组,在不进行空间优化的情况下,空间复杂度是 $O(n \times \textit{target})$,在进行空间优化的情况下,空间复杂度可以降到 $O(\textit{target})$。

动态规划(转换为 0-1 背包问题)

作者 liweiwei1419
2019年7月4日 20:27

关于背包问题的介绍,大家可以在互联网上搜索「背包九讲」进行学习,其中「0-1」背包问题是这些问题的基础。「力扣」上涉及的背包问题有「0-1」背包问题、完全背包问题、多重背包问题。

本题解有些地方使用了「0-1」背包问题的描述,因此会不加解释的使用「背包」、「容量」这样的名词。

说明:这里感谢很多朋友在这篇题解下提出的建议,对我的启发很大。本题解的阅读建议是:先浏览代码,然后再看代码之前的分析,能更有效理解知识点和整个问题的思考路径。题解后也增加了「总结」,供大家参考。


转换为 「0 - 1」 背包问题

这道问题是我学习「背包」问题的入门问题,做这道题需要做一个等价转换:是否可以从输入数组中挑选出一些正整数,使得这些数的和 等于 整个数组元素的和的一半。很坦白地说,如果不是我的老师告诉我可以这样想,我很难想出来。容易知道:数组的和一定得是偶数。

本题与 0-1 背包问题有一个很大的不同,即:

  • 0-1 背包问题选取的物品的容积总量 不能超过 规定的总量;
  • 本题选取的数字之和需要 恰好等于 规定的和的一半。

这一点区别,决定了在初始化的时候,所有的值应该初始化为 false。 (《背包九讲》的作者在介绍 「0-1 背包」问题的时候,有强调过这点区别。)

「0 - 1」 背包问题的思路

作为「0-1 背包问题」,它的特点是:「每个数只能用一次」。解决的基本思路是:物品一个一个选,容量也一点一点增加去考虑,这一点是「动态规划」的思想,特别重要。
在实际生活中,我们也是这样做的,一个一个地尝试把候选物品放入「背包」,通过比较得出一个物品要不要拿走。

具体做法是:画一个 len 行,target + 1 列的表格。这里 len 是物品的个数,target 是背包的容量。len 行表示一个一个物品考虑,target + 1多出来的那 1 列,表示背包容量从 0 开始考虑。很多时候,我们需要考虑这个容量为 0 的数值。

状态与状态转移方程

  • 状态定义:dp[i][j]表示从数组的 [0, i] 这个子区间内挑选一些正整数,每个数只能用一次,使得这些数的和恰好等于 j
  • 状态转移方程:很多时候,状态转移方程思考的角度是「分类讨论」,对于「0-1 背包问题」而言就是「当前考虑到的数字选与不选」。
    • 不选择 nums[i],如果在 [0, i - 1] 这个子区间内已经有一部分元素,使得它们的和为 j ,那么 dp[i][j] = true
    • 选择 nums[i],如果在 [0, i - 1] 这个子区间内就得找到一部分元素,使得它们的和为 j - nums[i]

状态转移方程:

###java

dp[i][j] = dp[i - 1][j] or dp[i - 1][j - nums[i]]

一般写出状态转移方程以后,就需要考虑初始化条件。

  • j - nums[i] 作为数组的下标,一定得保证大于等于 0 ,因此 nums[i] <= j
  • 注意到一种非常特殊的情况:j 恰好等于 nums[i],即单独 nums[i] 这个数恰好等于此时「背包的容积」 j,这也是符合题意的。

因此完整的状态转移方程是:

$$
\text{dp}[i][j]=
\begin{cases}
\text{dp}[i - 1][j], & 至少是这个答案,如果 \ \text{dp}[i - 1][j] \ 为真,直接计算下一个状态 \
\text{true}, & \text{nums[i] = j} \
\text{dp}[i - 1][j - nums[i]]. & \text{nums[i] < j}
\end{cases}
$$

说明:虽然写成花括号,但是它们的关系是 或者

  • 初始化:dp[0][0] = false,因为候选数 nums[0] 是正整数,凑不出和为 $0$;
  • 输出:dp[len - 1][target],这里 len 表示数组的长度,target 是数组的元素之和(必须是偶数)的一半。

说明

  • 事实上 dp[0][0] = true 也是可以的,相应地状态转移方程有所变化,请见下文;
  • 如果觉得这个初始化非常难理解,解释性差的朋友,我个人觉得可以不用具体解释它的意义,初始化的值保证状态转移能够正确完成即可。

参考代码 1

###Java

public class Solution {

    public boolean canPartition(int[] nums) {
        int len = nums.length;
        // 题目已经说非空数组,可以不做非空判断
        int sum = 0;
        for (int num : nums) {
            sum += num;
        }
        // 特判:如果是奇数,就不符合要求
        if ((sum & 1) == 1) {
            return false;
        }

        int target = sum / 2;
        // 创建二维状态数组,行:物品索引,列:容量(包括 0)
        boolean[][] dp = new boolean[len][target + 1];

        // 先填表格第 0 行,第 1 个数只能让容积为它自己的背包恰好装满
        if (nums[0] <= target) {
            dp[0][nums[0]] = true;
        }
        // 再填表格后面几行
        for (int i = 1; i < len; i++) {
            for (int j = 0; j <= target; j++) {
                // 直接从上一行先把结果抄下来,然后再修正
                dp[i][j] = dp[i - 1][j];

                if (nums[i] == j) {
                    dp[i][j] = true;
                    continue;
                }
                if (nums[i] < j) {
                    dp[i][j] = dp[i - 1][j] || dp[i - 1][j - nums[i]];
                }
            }
        }
        return dp[len - 1][target];
    }
}

复杂度分析

  • 时间复杂度:$O(NC)$:这里 $N$ 是数组元素的个数,$C$ 是数组元素的和的一半。
  • 空间复杂度:$O(NC)$。

解释设置 dp[0][0] = true 的合理性(重点)

修改状态数组初始化的定义:dp[0][0] = true。考虑容量为 $0$ 的时候,即 dp[i][0]。按照本意来说,应该设置为 false ,但是注意到状态转移方程(代码中):

###java

dp[i][j] = dp[i - 1][j] || dp[i - 1][j - nums[i]];

j - nums[i] == 0 成立的时候,根据上面分析,就说明单独的 nums[i] 这个数就恰好能够在被分割为单独的一组,其余的数分割成为另外一组。因此,我们把初始化的 dp[i][0] 设置成为 true 是没有问题的。

注意:观察状态转移方程,or 的结果只要为真,表格 这一列 下面所有的值都为真。因此在填表的时候,只要表格的最后一列是 true,代码就可以结束,直接返回 true

参考代码 2

###Java

public class Solution {

    public boolean canPartition(int[] nums) {
        int len = nums.length;
        int sum = 0;
        for (int num : nums) {
            sum += num;
        }
        if ((sum & 1) == 1) {
            return false;
        }

        int target = sum / 2;
        boolean[][] dp = new boolean[len][target + 1];
        
        // 初始化成为 true 虽然不符合状态定义,但是从状态转移来说是完全可以的
        dp[0][0] = true;

        if (nums[0] <= target) {
            dp[0][nums[0]] = true;
        }
        for (int i = 1; i < len; i++) {
            for (int j = 0; j <= target; j++) {
                dp[i][j] = dp[i - 1][j];
                if (nums[i] <= j) {
                    dp[i][j] = dp[i - 1][j] || dp[i - 1][j - nums[i]];
                }
            }

            // 由于状态转移方程的特殊性,提前结束,可以认为是剪枝操作
            if (dp[i][target]) {
                return true;
            }
        }
        return dp[len - 1][target];
    }
}

复杂度分析:(同上)


考虑空间优化(重要)

说明:这个技巧很常见、很基础,请一定要掌握。

「0-1 背包问题」常规优化:「状态数组」从二维降到一维,减少空间复杂度。

  • 在「填表格」的时候,当前行只参考了上一行的值,因此状态数组可以只设置 $2$ 行,使用「滚动数组」的技巧「填表格」即可;

  • 实际上,在「滚动数组」的基础上还可以优化,在「填表格」的时候,当前行总是参考了它上面一行 「头顶上」 那个位置和「左上角」某个位置的值。因此,我们可以只开一个一维数组,从后向前依次填表即可。

友情提示:这一点在刚开始学习的时候,可能会觉得很奇怪。理解的办法是:拿题目中的示例,画一个表格,自己模拟一遍程序是如何「填表」的行为,就很清楚为什么状态数组降到 1 行的时候,需要「从后前向」填表。

  • 「从后向前」 写的过程中,一旦 nums[i] <= j 不满足,可以马上退出当前循环,因为后面的 j 的值肯定越来越小,没有必要继续做判断,直接进入外层循环的下一层。相当于也是一个剪枝,这一点是「从前向后」填表所不具备的。

说明:如果对空间优化技巧还有疑惑的朋友,本题解下的精选评论也解释了如何理解这个空间优化的技巧,请大家前往观看。

参考代码 3:只展示了使用一维表格,并且「从后向前」填表格的代码。

###Java

public class Solution {

    public boolean canPartition(int[] nums) {
        int len = nums.length;
        int sum = 0;
        for (int num : nums) {
            sum += num;
        }
        if ((sum & 1) == 1) {
            return false;
        }

        int target = sum / 2;
        boolean[] dp = new boolean[target + 1];
        dp[0] = true;

        if (nums[0] <= target) {
            dp[nums[0]] = true;
        }
        for (int i = 1; i < len; i++) {
            for (int j = target; nums[i] <= j; j--) {
                if (dp[target]) {
                    return true;
                }
                dp[j] = dp[j] || dp[j - nums[i]];
            }
        }
        return dp[target];
    }
}

复杂度分析:

  • 时间复杂度:$O(NC)$:这里 $N$ 是数组元素的个数,$C$ 是数组元素的和的一半;
  • 空间复杂度:$O(C)$:减少了物品那个维度,无论来多少个数,用一行表示状态就够了。

总结

image.png

「0-1 背包」问题是一类非常重要的动态规划问题,一开始学习的时候,可能会觉得比较陌生。建议动笔计算,手动模拟填表的过程,其实就是画表格。这个过程非常重要,自己动手填过表,更能加深体会程序是如何执行的,也能更好地理解「空间优化」技巧的思路和好处。

image.png

在编写代码完成以后,把数组 dp 打印出来,看看是不是与自己手算的一样。以加深体会动态规划的设计思想:「不是直接面对问题求解,而是从一个最小规模的问题开始,新问的最优解均是由比它规模还小的子问题的最优解转换得到,在求解的过程中记录每一步的结果,直至所要求的问题得到解」。


最后思考为什么题目说是正整数,有 $0$ 是否可以,有实数可以吗,有负数可以吗?

  • $0$ 的存在意义不大,放在哪个子集都是可以的;
  • 实数有可能是无理数,也可能是无限不循环小数,在计算整个数组元素的和的一半,要除法,然后在比较两个子集元素的和是否相等的时候,就会遇到精度的问题;
  • 再说负数,负数其实也是可以存在的,但要用到「回溯搜算法」解决。

相关问题

「力扣」上的 0-1 背包问题

  • 「力扣」第 416 题:分割等和子集(中等);
  • 「力扣」第 474 题:一和零(中等);
  • 「力扣」第 494 题:目标和(中等);
  • 「力扣」第 879 题:盈利计划(困难);

「力扣」上的 完全背包问题

  • 「力扣」第 322 题:零钱兑换(中等);
  • 「力扣」第 518 题:零钱兑换 II(中等);
  • 「力扣」第 1449 题:数位成本和为目标值的最大数字(困难)。

这里要注意鉴别:「力扣」第 377 题,不是「完全背包」问题。

参考资料

昨天 — 2025年4月6日LeetCode 每日一题题解

三种方法:记忆化搜索/递推/最优性优化(Python/Java/C++/Go)

作者 endlesscheng
2025年4月6日 08:36

一、分析

设子集为 $A$,题目要求对于任意 $(A[i],A[j])$,都满足 $A[i]\bmod A[j] = 0$ 或者 $A[j]\bmod A[i] = 0$,也就是一个数是另一个数的倍数

这里有两个条件,不好处理。我们可以把 $A$ 排序,或者说把 $\textit{nums}$ 排序(从小到大)。由于 $\textit{nums}$ 所有元素互不相同(没有相等的情况),题目要求变成:

  • 从(排序后的)$\textit{nums}$ 中选一个子序列,在子序列中,右边的数一定是左边的数的倍数。

由于 $x$ 的倍数的倍数仍然是 $x$ 的倍数,只要相邻元素满足倍数关系,那么任意两数一定满足倍数关系。于是题目要求变成:

  • 从(排序后的)$\textit{nums}$ 中选一个子序列,在子序列中,任意相邻的两个数,右边的数一定是左边的数的倍数。

这类似 300. 最长递增子序列,都是相邻元素有约束,且要计算的都是子序列的最长长度。

下文把满足题目要求的子序列叫做合法子序列。

二、记忆化搜索

先把 $\textit{nums}$ 从小到大排序。

注意:排序不影响答案,因为题目说「如果存在多个有效解子集,返回其中任何一个均可」。如果 $[1,2]$ 合法,那么 $[2,1]$ 也合法。

仿照 300 题,定义 $\textit{dfs}(i)$ 表示以 $\textit{nums}[i]$ 结尾的合法子序列的最长长度。

枚举子序列倒数第二个数是 $\textit{nums}[j]$,如果 $\textit{nums}[i]\bmod \textit{nums}[j] = 0$,那么问题变成以 $\textit{nums}[j]$ 结尾的合法子序列的最长长度,即 $\textit{dfs}(i) = \textit{dfs}(j) + 1$。

取最大值,有

$$
\textit{dfs}(i) = \max_{j=0}^{i-1} \textit{dfs}(j) + 1
$$

其中 $\textit{nums}[i]\bmod \textit{nums}[j] = 0$。如果没有满足要求的 $j$,那么 $\textit{dfs}(i) = 1$,即 $\textit{nums}[i]$ 单独一个数作为子序列。

递归边界:$\textit{dfs}(0) = 1$。其实不需要判断递归边界,因为当 $i=0$ 的时候,不会进入循环,直接返回了。

递归入口:$\textit{dfs}(i)$。

注:从右往左递归,主要是为了方便把递归翻译成递推。从左往右递归也是可以的。

如何生成具体方案?

本题可以直接把合法子序列作为 $\textit{dfs}$ 的返回值,但这样做空间复杂度是 $\mathcal{O}(n^2)$ 的,且不是通用做法。

通用做法是,用一个数组 $\textit{from}$ 记录转移来源,初始值 $\textit{from}[i]=-1$

如果 $\textit{dfs}(j)$ 是 $[0,i-1]$ 中的最大值,那么记录 $\textit{from}[i] = j$。

此外,记录最大的 $\textit{dfs}(i)$ 的下标 $\textit{maxI}$,也就是最长合法子序列的最后一个数的下标。

这样我们可以从 $i=\textit{maxI}$ 开始,顺着 $\textit{from}[i]$,找到合法子序列的倒数第二个数,倒数第三个数,……,第一个数。把找到的数记录到一个列表中,最后返回这个列表。

class Solution:
    def largestDivisibleSubset(self, nums: List[int]) -> List[int]:
        nums.sort()
        n = len(nums)
        from_ = [-1] * n

        @cache  # 缓存装饰器,避免重复计算 dfs(一行代码实现记忆化)
        def dfs(i: int) -> int:
            res = 0
            for j in range(i):
                if nums[i] % nums[j]:
                    continue
                f = dfs(j)
                if f > res:
                    res = f
                    from_[i] = j  # 记录最佳转移来源
            return res + 1  # 加上 nums[i] 自己

        max_f = max_i = 0
        for i in range(n):
            f = dfs(i)
            if f > max_f:
                max_f = f
                max_i = i  # 最长合法子序列的最后一个数的下标

        path = []
        i = max_i
        while i >= 0:
            path.append(nums[i])
            i = from_[i]
        return path  # 不需要 reverse,任意顺序返回均可
class Solution {
    public List<Integer> largestDivisibleSubset(int[] nums) {
        Arrays.sort(nums);

        int n = nums.length;
        int[] memo = new int[n];
        int[] from = new int[n];
        Arrays.fill(from, -1);
        int maxF = 0;
        int maxI = 0;

        for (int i = 0; i < n; i++) {
            int f = dfs(i, nums, memo, from);
            if (f > maxF) {
                maxF = f;
                maxI = i; // 最长合法子序列的最后一个数的下标
            }
        }

        List<Integer> path = new ArrayList<>(maxF); // 预分配空间
        for (int i = maxI; i >= 0; i = from[i]) {
            path.add(nums[i]);
        }
        return path; // 不需要 reverse,任意顺序返回均可
    }

    private int dfs(int i, int[] nums, int[] memo, int[] from) {
        if (memo[i] > 0) { // 之前计算过
            return memo[i];
        }
        int res = 0;
        for (int j = 0; j < i; j++) {
            if (nums[i] % nums[j] != 0) {
                continue;
            }
            int f = dfs(j, nums, memo, from);
            if (f > res) {
                res = f;
                from[i] = j; // 记录最佳转移来源
            }
        }
        return memo[i] = res + 1; // 记忆化
    }
}
class Solution {
public:
    vector<int> largestDivisibleSubset(vector<int>& nums) {
        ranges::sort(nums);
        int n = nums.size();
        vector<int> memo(n), from_(n, -1);

        auto dfs = [&](this auto&& dfs, int i) -> int {
            int& res = memo[i]; // 注意这里是引用
            if (res) { // 之前计算过
                return res;
            }
            for (int j = 0; j < i; j++) {
                if (nums[i] % nums[j]) {
                    continue;
                }
                int f = dfs(j);
                if (f > res) {
                    res = f;
                    from_[i] = j; // 记录最佳转移来源
                }
            }
            res++; // 加上 nums[i] 自己
            return res;
        };

        int max_f = 0, max_i = 0;
        for (int i = 0; i < n; i++) {
            int f = dfs(i);
            if (f > max_f) {
                max_f = f;
                max_i = i; // 最长合法子序列的最后一个数的下标
            }
        }

        vector<int> path;
        for (int i = max_i; i >= 0; i = from_[i]) {
            path.push_back(nums[i]);
        }
        return path; // 不需要 reverse,任意顺序返回均可
    }
};
func largestDivisibleSubset(nums []int) []int {
    slices.Sort(nums)
    n := len(nums)
    memo := make([]int, n)
    from := make([]int, n)
    for i := range from {
        from[i] = -1
    }

    var dfs func(i int) int
    dfs = func(i int) (res int) {
        p := &memo[i]
        if *p > 0 { // 之前计算过
            return *p
        }
        defer func() { *p = res }() // 记忆化
        x := nums[i]
        for j, y := range nums[:i] {
            if x%y != 0 {
                continue
            }
            f := dfs(j)
            if f > res {
                res = f
                from[i] = j // 记录最佳转移来源
            }
        }
        return res + 1 // 加上 nums[i] 自己
    }

    maxF, maxI := 0, 0
    for i := range n {
        f := dfs(i)
        if f > maxF {
            maxF = f
            maxI = i // 最长合法子序列的最后一个数的下标
        }
    }

    path := make([]int, 0, maxF) // 预分配空间
    for i := maxI; i >= 0; i = from[i] {
        path = append(path, nums[i])
    }
    return path // 不需要 reverse,任意顺序返回均可
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2)$,其中 $n$ 是 $\textit{nums}$ 的长度。由于每个状态只会计算一次,动态规划的时间复杂度 $=$ 状态个数 $\times$ 单个状态的计算时间。本题状态个数等于 $\mathcal{O}(n)$,单个状态的计算时间为 $\mathcal{O}(n)$,所以总的时间复杂度为 $\mathcal{O}(n^2)$。
  • 空间复杂度:$\mathcal{O}(n)$。保存多少状态,就需要多少空间。

三、1:1 翻译成递推

我们可以去掉递归中的「递」,只保留「归」的部分,即自底向上计算。

具体来说,$f[i]$ 的定义和 $\textit{dfs}(i)$ 的定义是完全一样的,都表示以 $\textit{nums}[i]$ 结尾的合法子序列的最长长度。

相应的递推式(状态转移方程)也和 $\textit{dfs}$ 一样:

$$
f[i] = \max_{j=0}^{i-1} f[j] + 1
$$

其中 $\textit{nums}[i]\bmod \textit{nums}[j] = 0$。如果没有满足要求的 $j$,那么 $f[i] = 1$,即 $\textit{nums}[i]$ 单独一个数作为子序列。

class Solution:
    def largestDivisibleSubset(self, nums: List[int]) -> List[int]:
        nums.sort()

        n = len(nums)
        f = [0] * n
        from_ = [-1] * n
        max_i = 0

        for i, x in enumerate(nums):
            for j in range(i):
                if x % nums[j] == 0 and f[j] > f[i]:
                    f[i] = f[j]
                    from_[i] = j  # 记录最佳转移来源
            f[i] += 1
            if f[i] > f[max_i]:
                max_i = i  # 最长合法子序列的最后一个数的下标

        path = []
        i = max_i
        while i >= 0:
            path.append(nums[i])
            i = from_[i]
        return path  # 不需要 reverse,任意顺序返回均可
class Solution {
    public List<Integer> largestDivisibleSubset(int[] nums) {
        Arrays.sort(nums);

        int n = nums.length;
        int[] f = new int[n];
        int[] from = new int[n];
        Arrays.fill(from, -1);
        int maxI = 0;

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < i; j++) {
                if (nums[i] % nums[j] == 0 && f[j] > f[i]) {
                    f[i] = f[j];
                    from[i] = j; // 记录最佳转移来源
                }
            }
            f[i]++;
            if (f[i] > f[maxI]) {
                maxI = i; // 最长合法子序列的最后一个数的下标
            }
        }

        List<Integer> path = new ArrayList<>(f[maxI]); // 预分配空间
        for (int i = maxI; i >= 0; i = from[i]) {
            path.add(nums[i]);
        }
        return path; // 不需要 reverse,任意顺序返回均可
    }
}
class Solution {
public:
    vector<int> largestDivisibleSubset(vector<int>& nums) {
        ranges::sort(nums);

        int n = nums.size();
        vector<int> f(n), from_(n, -1);
        int max_i = 0;

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < i; j++) {
                if (nums[i] % nums[j] == 0 && f[j] > f[i]) {
                    f[i] = f[j];
                    from_[i] = j; // 记录最佳转移来源
                }
            }
            f[i]++;
            if (f[i] > f[max_i]) {
                max_i = i; // 最长合法子序列的最后一个数的下标
            }
        }

        vector<int> path;
        for (int i = max_i; i >= 0; i = from_[i]) {
            path.push_back(nums[i]);
        }
        return path; // 不需要 reverse,任意顺序返回均可
    }
};
func largestDivisibleSubset(nums []int) []int {
    slices.Sort(nums)

    n := len(nums)
    f := make([]int, n)
    from := make([]int, n)
    for i := range from {
        from[i] = -1
    }
    maxI := 0

    for i, x := range nums {
        for j, y := range nums[:i] {
            if x%y == 0 && f[j] > f[i] {
                f[i] = f[j]
                from[i] = j // 记录最佳转移来源
            }
        }
        f[i]++
        if f[i] > f[maxI] {
            maxI = i // 最长合法子序列的最后一个数的下标
        }
    }

    path := make([]int, 0, f[maxI]) // 预分配空间
    for i := maxI; i >= 0; i = from[i] {
        path = append(path, nums[i])
    }
    return path // 不需要 reverse,任意顺序返回均可
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

四、最优性优化

设 $\textit{nums}$ 的最大值为 $\textit{maxNum}$,目前计算出的最大的 $f[i]$ 为 $\textit{maxF}$。

假设 $\textit{maxNum}=16$,目前 $\textit{maxF}=3$,$\textit{nums}[i]=5$,算出的 $f[i]=2$。我们可以断定:这个 $f[i]$ 是无用数据,不需要在后续的内层循环中遍历它!

为什么?$\textit{nums}[i]=5$ 的不超过 $\textit{maxNum}$ 的倍数有 $10,15$,这些倍数的倍数,一定会超过 $\textit{maxNum}$,所以 $f[i]$ 的「增长潜力」只有 $1$。由于 $f[i]+1\le \textit{maxF}$,在后续的计算中,不可能更新 $\textit{maxF}$,所以 $f[i]$ 是无用数据。

一般地,把一个数 $x$ 不断地乘 $2$,要求结果 $\le \textit{maxNum}$,最多能乘多少次?

这相当于 $x\cdot 2^k\le \textit{maxNum}$,解得 $k\le \left\lfloor \log_2 \dfrac{\textit{maxNum}}{x}\right\rfloor$。

一般地,如果

$$
f[i] + \left\lfloor \log_2 \dfrac{\textit{maxNum}}{\textit{nums}[i]}\right\rfloor \le \textit{maxF}
$$

那么这个 $f[i]$ 是无用数据,在后续的循环中,不需要遍历 $i$。

我们可以维护一个需要遍历的下标列表 $\textit{validIdx}$,把不满足上式的 $i$ 加进去。

此外,当 $\textit{maxF}$ 变大时,扫描 $\textit{validIdx}$,去掉无用数据。

注:这个做法类似搜索中的「最优性剪枝」:如果发现继续递归下去不可能更新答案,那么不继续往下递归。

class Solution:
    def largestDivisibleSubset(self, nums: List[int]) -> List[int]:
        nums.sort()
        max_num = nums[-1]

        n = len(nums)
        f = [0] * n
        from_ = [-1] * n
        max_f = max_i = 0
        valid_idx = []

        for i, x in enumerate(nums):
            for j in valid_idx:  # 只需要遍历在 valid_idx 中的下标
                if x % nums[j] == 0 and f[j] > f[i]:
                    f[i] = f[j]
                    from_[i] = j
            f[i] += 1
            if f[i] > max_f:
                max_f = f[i]
                max_i = i
                # max_f 变大,去掉 valid_idx 中的无用数据(这里直接生成一个新的)
                new_valid_idx = []
                for j in valid_idx:
                    if f[j] + (max_num // nums[j]).bit_length() - 1 > max_f:
                        new_valid_idx.append(j)
                valid_idx = new_valid_idx
            if f[i] + (max_num // x).bit_length() - 1 > max_f:
                valid_idx.append(i)

        path = []
        i = max_i
        while i >= 0:
            path.append(nums[i])
            i = from_[i]
        return path
// 更快的写法见【Java 数组】
class Solution {
    public List<Integer> largestDivisibleSubset(int[] nums) {
        Arrays.sort(nums);
        int n = nums.length;
        int maxNum = nums[nums.length - 1];

        int[] f = new int[n];
        int[] from = new int[n];
        Arrays.fill(from, -1);
        int maxF = 0;
        int maxI = 0;
        List<Integer> validIdx = new ArrayList<>();

        for (int i = 0; i < n; i++) {
            int x = nums[i];
            for (int j : validIdx) { // 只需要遍历在 validIdx 中的下标
                if (x % nums[j] == 0 && f[j] > f[i]) {
                    f[i] = f[j];
                    from[i] = j;
                }
            }
            f[i]++;
            if (f[i] > maxF) {
                maxF = f[i];
                maxI = i;
                // maxF 变大,去掉 validIdx 中的无用数据(这里直接生成一个新的)
                List<Integer> newValidIdx = new ArrayList<>();
                for (int j : validIdx) {
                    if (f[j] + 31 - Integer.numberOfLeadingZeros(maxNum / nums[j]) > maxF) {
                        newValidIdx.add(j);
                    }
                }
                validIdx = newValidIdx;
            }
            if (f[i] + 31 - Integer.numberOfLeadingZeros(maxNum / x) > maxF) {
                validIdx.add(i);
            }
        }

        List<Integer> path = new ArrayList<>(maxF);
        for (int i = maxI; i >= 0; i = from[i]) {
            path.add(nums[i]);
        }
        return path;
    }
}
class Solution {
    public List<Integer> largestDivisibleSubset(int[] nums) {
        Arrays.sort(nums);
        int n = nums.length;
        int maxNum = nums[nums.length - 1];

        int[] f = new int[n];
        int[] from = new int[n];
        Arrays.fill(from, -1);
        int maxF = 0;
        int maxI = 0;
        int[] validIdx = new int[n]; // 改成数组
        int m = 0; // validIdx 的大小

        for (int i = 0; i < n; i++) {
            int x = nums[i];
            for (int k = 0; k < m; k++) {
                int j = validIdx[k];
                if (x % nums[j] == 0 && f[j] > f[i]) {
                    f[i] = f[j];
                    from[i] = j;
                }
            }
            f[i]++;
            if (f[i] > maxF) {
                maxF = f[i];
                maxI = i;
                int m2 = 0;
                for (int k = 0; k < m; k++) {
                    int j = validIdx[k];
                    if (f[j] + 31 - Integer.numberOfLeadingZeros(maxNum / nums[j]) > maxF) {
                        validIdx[m2++] = j; // 原地修改
                    }
                }
                m = m2;
            }
            if (f[i] + 31 - Integer.numberOfLeadingZeros(maxNum / x) > maxF) {
                validIdx[m++] = i;
            }
        }

        List<Integer> path = new ArrayList<>(maxF);
        for (int i = maxI; i >= 0; i = from[i]) {
            path.add(nums[i]);
        }
        return path;
    }
}
class Solution {
public:
    vector<int> largestDivisibleSubset(vector<int>& nums) {
        ranges::sort(nums);
        unsigned max_num = nums.back();

        int n = nums.size();
        vector<int> f(n), from_(n, -1);
        int max_f = 0, max_i = 0;
        vector<int> valid_idx;

        for (int i = 0; i < n; i++) {
            for (int j : valid_idx) { // 只需要遍历在 valid_idx 中的下标
                if (nums[i] % nums[j] == 0 && f[j] > f[i]) {
                    f[i] = f[j];
                    from_[i] = j;
                }
            }
            f[i]++;
            if (f[i] > max_f) {
                max_f = f[i];
                max_i = i;
                // max_f 变大,去掉 valid_idx 中的无用数据(这里直接生成一个新的)
                vector<int> new_valid_idx;
                for (int j : valid_idx) {
                    if (f[j] + bit_width(max_num / nums[j]) - 1 > max_f) {
                        new_valid_idx.push_back(j);
                    }
                }
                valid_idx = move(new_valid_idx);
            }
            if (f[i] + bit_width(max_num / nums[i]) - 1 > max_f) {
                valid_idx.push_back(i);
            }
        }

        vector<int> path;
        for (int i = max_i; i >= 0; i = from_[i]) {
            path.push_back(nums[i]);
        }
        return path;
    }
};
func largestDivisibleSubset(nums []int) []int {
    slices.Sort(nums)
    n := len(nums)
    maxNum := nums[n-1]

    f := make([]int, n)
    from := make([]int, n)
    for i := range from {
        from[i] = -1
    }
    maxF, maxI := 0, 0
    validIdx := []int{}

    for i, x := range nums {
        for _, j := range validIdx { // 只需要遍历在 validIdx 中的下标
            if x%nums[j] == 0 && f[j] > f[i] {
                f[i] = f[j]
                from[i] = j
            }
        }
        f[i]++
        if f[i] > maxF {
            maxF = f[i]
            maxI = i
            // maxF 变大,去掉 validIdx 中的无用数据
            newValidIdx := validIdx[:0]
            for _, j := range validIdx {
                if f[j]+bits.Len(uint(maxNum/nums[j]))-1 > maxF {
                    newValidIdx = append(newValidIdx, j)
                }
            }
            validIdx = newValidIdx
        }
        if f[i]+bits.Len(uint(maxNum/x))-1 > maxF {
            validIdx = append(validIdx, i)
        }
    }

    path := make([]int, 0, maxF)
    for i := maxI; i >= 0; i = from[i] {
        path = append(path, nums[i])
    }
    return path
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/最短路/最小生成树/二分图/基环树/欧拉路径)
  7. 动态规划(入门/背包/状态机/划分/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

每日一题-最大整除子集🟡

2025年4月6日 00:00
给你一个由 无重复 正整数组成的集合 nums ,请你找出并返回其中最大的整除子集 answer ,子集中每一元素对 (answer[i], answer[j]) 都应当满足:
  • answer[i] % answer[j] == 0 ,或
  • answer[j] % answer[i] == 0

如果存在多个有效解子集,返回其中任何一个均可。

 

示例 1:

输入:nums = [1,2,3]
输出:[1,2]
解释:[1,3] 也会被视为正确答案。

示例 2:

输入:nums = [1,2,4,8]
输出:[1,2,4,8]

 

提示:

  • 1 <= nums.length <= 1000
  • 1 <= nums[i] <= 2 * 109
  • nums 中的所有整数 互不相同

【宫水三叶の相信科学系列】详解为何能转换为序列 DP 问题

作者 AC_OIer
2021年4月23日 08:52

基本分析

根据题意:对于符合要求的「整除子集」中的任意两个值,必然满足「较大数」是「较小数」的倍数。

数据范围是 $10^3$,我们不可能采取获取所有子集,再检查子集是否合法的爆搜解法。

通常「递归」做不了,我们就往「递推」方向去考虑。

由于存在「整除子集」中任意两个值必然存在倍数/约数关系的性质,我们自然会想到对 nums 进行排序,然后从集合 nums 中从大到小进行取数,每次取数只考虑当前决策的数是否与「整除子集」中的最后一个数成倍数关系即可。

这时候你可能会想枚举每个数作为「整除子集」的起点,然后从前往后遍历一遍,每次都将符合「与当前子集最后一个元素成倍数」关系的数加入答案。

举个🌰,假设有原数组 [1,2,4,8],“或许”我们期望的决策过程是:

  1. 遍历到数字 1,此时「整除子集」为空,加到「整除子集」中;
  2. 遍历到数字 2,与「整除子集」的最后一个元素(1)成倍数关系,加到「整除子集」中;
  3. 遍历到数字 4,与「整除子集」的最后一个元素(2)成倍数关系,自然也与 2 之前的元素成倍数关系,加到「整除子集」中;
  4. 遍历到数字 8,与「整除子集」的最后一个元素(4)成倍数关系,自然也与 4 之前的元素成倍数关系,加到「整除子集」中。

但这样的做法只能够确保得到「合法解」,无法确保得到的是「最长整除子集」

当时担心本题数据太弱,上述错误的解法也能够通过,所以还特意实现了一下,还好被卡住了(🤣

同时也得到这个反例:[9,18,54,90,108,180,360,540,720],如果按照我们上述逻辑,我们得到的是 [9,18,54,108,540] 答案(长度为 5),但事实上存在更长的「整除子集」: [9,18,90,180,360,720](长度为 6)。

其本质是因为同一个数的不同倍数之间不存在必然的「倍数/约数关系」,而只存在「具有公约数」的性质,这会导致我们「模拟解法」错过最优解。

比如上述 🌰,54 & 9018 存在倍数关系,但两者本身不存在倍数关系。

因此当我们决策到某一个数 nums[i] 时(nums 已排好序),我们无法直接将 nums[i] 直接接在符合「约数关系」的、最靠近位置 i 的数后面,而是要检查位置 i 前面的所有符合「约数关系」的位置,找一个已经形成「整除子集」长度最大的数

换句话说,当我们对 nums 排好序并从前往后处理时,在处理到 nums[i] 时,我们希望知道位置 i 之前的下标已经形成的「整除子集」长度是多少,然后从中选一个最长的「整除子集」,将 nums[i] 接在后面(前提是符合「倍数关系」)。


动态规划

基于上述分析,我们不难发现这其实是一个序列 DP 问题:某个状态的转移依赖于与前一个状态的关系。即 nums[i] 能否接在 nums[j] 后面,取决于是否满足 nums[i] % nums[j] == 0 条件。

可看做是「最长上升子序列」问题的变形题。

定义 $f[i]$ 为考虑前 i 个数字,且以第 i 个数为结尾的最长「整除子集」长度。

我们不失一般性的考虑任意位置 i,存在两种情况:

  • 如果在 i 之前找不到符合条件 nums[i] % nums[j] == 0 的位置 j,那么 nums[i] 不能接在位置 i 之前的任何数的后面,只能自己独立作为「整除子集」的第一个数,此时状态转移方程为 $f[i] = 1$;
  • 如果在 i 之前能够找到符合条件的位置 j,则取所有符合条件的 f[j] 的最大值,代表如果希望找到以 nums[i] 为结尾的最长「整除子集」,需要将 nums[i] 接到符合条件的最长的 nums[j] 后面,此时状态转移方程为 $f[i] = f[j] + 1$。

同时由于我们需要输出具体方案,需要额外使用 g[] 数组来记录每个状态是由哪个状态转移而来。

定义 $g[i]$ 为记录 $f[i]$ 是由哪个下标的状态转移而来,如果 $f[i] = f[j] + 1$, 则有 $g[i] = j$。

对于求方案数的题目,多开一个数组来记录状态从何转移而来是最常见的手段。

当我们求得所有的状态值之后,可以对 f[] 数组进行遍历,取得具体的最长「整除子集」长度和对应下标,然后使用 g[] 数组进行回溯,取得答案。

代码(感谢 @Benhao@007 两位同学提供的其他语言版本):

###Java

class Solution {
    public List<Integer> largestDivisibleSubset(int[] nums) {
        Arrays.sort(nums);
        int n = nums.length;
        int[] f = new int[n];
        int[] g = new int[n];
        for (int i = 0; i < n; i++) {
            // 至少包含自身一个数,因此起始长度为 1,由自身转移而来
            int len = 1, prev = i;
            for (int j = 0; j < i; j++) {
                if (nums[i] % nums[j] == 0) {
                    // 如果能接在更长的序列后面,则更新「最大长度」&「从何转移而来」
                    if (f[j] + 1 > len) {
                        len = f[j] + 1;
                        prev = j;
                    }
                }
            }
            // 记录「最终长度」&「从何转移而来」
            f[i] = len;
            g[i] = prev;
        }
        
        // 遍历所有的 f[i],取得「最大长度」和「对应下标」
        int max = -1, idx = -1;
        for (int i = 0; i < n; i++) {
            if (f[i] > max) {
                idx = i;
                max = f[i];
            }
        }

        // 使用 g[] 数组回溯出具体方案
        List<Integer> ans = new ArrayList<>();
        while (ans.size() != max) {
            ans.add(nums[idx]);
            idx = g[idx];
        }
        return ans;
    }
}

###C++

class Solution {
public:
    vector<int> largestDivisibleSubset(vector<int>& nums) {
        sort(nums.begin(), nums.end());
        int n = nums.size();
        vector<int> f(n, 0);
        vector<int> g(n ,0);
        
        for(int i = 0; i < n; i++) {
            // 至少包含自身一个数,因此起始长度为 1,由自身转移而来
            int len = 1, prev = i;
            for(int j = 0; j < i; j++) {
                if(nums[i] % nums[j] == 0) {
                    // 如果能接在更长的序列后面,则更新「最大长度」&「从何转移而来」
                    if(f[j] + 1 > len) {
                        len = f[j] + 1;
                        prev = j;
                    }
                }
            }
            f[i] = len;
            g[i] = prev;
        }

        // 遍历所有的 f[i],取得「最大长度」和「对应下标」
        int idx = max_element(f.begin(), f.end()) - f.begin();
        int max = f[idx];

        // 使用 g[] 数组回溯出具体方案
        vector<int> ans;
        while(ans.size() != max) {
            ans.push_back(nums[idx]);
            idx = g[idx];
        }
        return ans;
    }
};

###Python3

class Solution:
    def largestDivisibleSubset(self, nums: List[int]) -> List[int]:
        nums.sort()
        n = len(nums)
        f, g = [0] * n, [0] * n
        for i in range(n):
            # 至少包含自身一个数,因此起始长度为 1,由自身转移而来
            length, prev = 1, i
            for j in range(i):
                if nums[i] % nums[j] == 0:
                    # 如果能接在更长的序列后面,则更新「最大长度」&「从何转移而来」
                    if f[j] + 1 > length:
                        length = f[j] + 1
                        prev = j
            # 记录「最终长度」&「从何转移而来」
            f[i] = length
            g[i] = prev

        # 遍历所有的 f[i],取得「最大长度」和「对应下标」
        max_len = idx = -1
        for i in range(n):
            if f[i] > max_len:
                idx = i
                max_len = f[i]
        
        # 使用 g[] 数组回溯出具体方案
        ans = []
        while len(ans) < max_len:
            ans.append(nums[idx])
            idx = g[idx]
        ans.reverse()
        return ans

###Go

func largestDivisibleSubset(nums []int) []int {
sort.Ints(nums)
n := len(nums)
// 定义 f[i] 为考虑前 i 个数字,且以第 i 个数为结尾的最长「整除子集」长度。
f := make([]int, n)
// 定义 g[i] 为记录 f[i] 是由哪个下标的状态转移而来,如果 f[i] = f[j] + 1, 则有 g[i] = j。
g := make([]int, n)

for i := 0; i < n; i++ {
// 至少包含自身一个数,因此起始长度为 1,由自身转移而来
l := 1
prev := i
for j := 0; j < i; j++ {
if nums[i]%nums[j] == 0 {
// 如果能接在更长的序列后面,则更新「最大长度」&「从何转移而来」
if f[j]+1 > l {
l = f[j] + 1
prev = j
}
}
}

// 记录「最终长度」&「从何转移而来」
f[i] = l
g[i] = prev
}

// 遍历所有的 f[i],取得「最大长度」和「对应下标」
max := -1
idx := -1
for i := 0; i < n; i++ {
if f[i] > max {
idx = i
max = f[i]
}
}

// 使用 g[] 数组回溯出具体方案
var ans []int
for len(ans) != max {
ans = append(ans, nums[idx])

idx = g[idx]
}
return ans
}
  • 时间复杂度:$O(n^2)$
  • 空间复杂度:$O(n)$

证明

之所以上述解法能够成立,问题能够转化为「最长上升子序列(LIS)」问题进行求解,本质是利用了「全序关系」中的「可传递性」。

在 LIS 问题中,我们是利用了「关系运算符 $\geqslant$ 」的传递性,因此当我们某个数 a 能够接在 b 后面,只需要确保 $a \geqslant b$ 成立,即可确保 a 大于等于 b 之前的所有值。

那么同理,如果我们想要上述解法成立,我们还需要证明如下内容:

  • 「倍数/约数关系」具有传递性

由于我们将 nums[i] 往某个数字后面接时(假设为 nums[j]),只检查了其与 nums[j] 的关系,并没有去检查 nums[i]nums[j] 之前的数值是否具有「倍数/约数关系」。

换句话说,我们只确保了最终答案 [a1, a2, a3, ..., an] 相邻两数值之间具有「倍数/约数关系」,并不明确任意两值之间具有「倍数/约数关系」。

因此需要证得由 $a | b$ 和 $b | c$,可推导出 $a | c$ 的传递性:

由 $a | b$ 可得 $b = x * a$
由 $b | c$ 可得 $c = y * b$

最终有 $c = y * b = y * x * a$,由于 $x$ 和 $y$ 都是整数,因此可得 $a | c$。

得证「倍数/约数关系」具有传递性。


最后

如果有帮助到你,请给题解点个赞和收藏,让更多的人看到 ~ ("▔□▔)/

也欢迎你 关注我 和 加入我们的「组队打卡」小群 ,提供写「证明」&「思路」的高质量题解

最大整除子集

2021年4月22日 22:54

前言

首先需要理解什么叫「整除子集」。根据题目的描述,如果一个所有元素互不相同的集合中的任意元素存在整除关系,就称为整除子集。为了得到「最大整除子集」,我们需要考虑如何从一个小的整除子集扩充成为更大的整除子集

根据整除关系具有传递性,即如果 $a\big|b$,并且 $b\big|c$,那么 $a\big|c$,可知:

  • 如果整数 $a$ 是整除子集 $S_1$ 的最小整数 $b$ 的约数(即 $a\big|b$),那么可以将 $a$ 添加到 $S_1$ 中得到一个更大的整除子集;

  • 如果整数 $c$ 是整除子集 $S_2$ 的最大整数 $d$ 的倍数(即 $d\big|c$),那么可以将 $c$ 添加到 $S_2$ 中得到一个更大的整除子集。

这两点揭示了当前问题状态转移的特点,因此可以使用动态规划的方法求解。题目只要求我们得到多个目标子集的其中一个,根据求解动态规划问题的经验,我们需要将子集的大小定义为状态,然后根据结果倒推得到一个目标子集。事实上,当前问题和使用动态规划解决的经典问题「300. 最长递增子序列」有相似之处。

方法一:动态规划

根据前言的分析,我们需要将输入数组 $\textit{nums}$ 按照升序排序,以便获得一个子集的最小整数或者最大整数。又根据动态规划的「无后效性」状态设计准则,我们需要将状态定义成「某个元素必须选择」。

状态定义:$\textit{dp}[i]$ 表示在输入数组 $\textit{nums}$ 升序排列的前提下,以 $\textit{nums}[i]$ 为最大整数的「整除子集」的大小(在这种定义下 $\textit{nums}[i]$ 必须被选择)。

状态转移方程:枚举 $j = 0 \ldots i-1$ 的所有整数 $\textit{nums}[j]$,如果 $\textit{nums}[j]$ 能整除 $\textit{nums}[i]$,说明 $\textit{nums}[i]$ 可以扩充在以 $\textit{nums}[j]$ 为最大整数的整除子集里成为一个更大的整除子集。

初始化:由于 $\textit{nums}[i]$ 必须被选择,因此对于任意 $i = 0 \ldots n-1$,初始的时候 $\textit{dp}[i] = 1$,这里 $n$ 是输入数组的长度。

输出:由于最大整除子集不一定包含 $\textit{nums}$ 中最大的整数,所以我们需要枚举所有的 $\textit{dp}[i]$,选出最大整除子集的大小 $\textit{maxSize}$,以及该最大子集中的最大整数 $\textit{maxVal}$。按照如下方式倒推获得一个目标子集:

  1. 倒序遍历数组 $\textit{dp}$,直到找到 $\textit{dp}[i] = \textit{maxSize}$ 为止,把此时对应的 $\textit{nums}[i]$ 加入结果集,此时 $\textit{maxVal} = \textit{nums}[i]$;

  2. 然后将 $\textit{maxSize}$ 的值减 $1$,继续倒序遍历找到 $\textit{dp}[i] = \textit{maxSize}$,且 $\textit{nums}[i]$ 能整除 $\textit{maxVal}$ 的 $i$ 为止,将此时的 $\textit{nums}[i]$ 加入结果集,$\textit{maxVal}$ 更新为此时的 $num[i]$;

  3. 重复上述操作,直到 $\textit{maxSize}$ 的值变成 $0$,此时的结果集即为一个目标子集。

下面用一个例子说明如何得到最大整除子集。假设输入数组为 $[2,4,7,8,9,12,16,18]$(已经有序),得到的动态规划表格如下:

$\textit{nums}$ $2$ $4$ $7$ $8$ $9$ $12$ $16$ $20$
$\textit{dp}$ $1$ $2$ $1$ $3$ $1$ $3$ $4$ $3$

得到最大整除子集的做法如下:

  1. 根据 $\textit{dp}$ 的计算结果,$\textit{maxSize}=4$,$\textit{maxVal}=16$,因此大小为 $4$ 的最大整除子集包含的最大整数为 $16$;

  2. 然后查找大小为 $3$ 的最大整除子集,我们看到 $8$ 和 $12$ 对应的状态值都是 $3$,最大整除子集一定包含 $8$,这是因为 $8 \big| 16$;

  3. 然后查找大小为 $2$ 的最大整除子集,我们看到 $4$ 对应的状态值是 $2$,最大整除子集一定包含 $4$;

  4. 然后查找大小为 $1$ 的最大整除子集,我们看到 $2$ 对应的状态值是 $1$,最大整除子集一定包含 $2$。

通过这样的方式,我们就找到了满足条件的某个最大整除子集 $[16,8,4,2]$。

代码

###Java

class Solution {
    public List<Integer> largestDivisibleSubset(int[] nums) {
        int len = nums.length;
        Arrays.sort(nums);

        // 第 1 步:动态规划找出最大子集的个数、最大子集中的最大整数
        int[] dp = new int[len];
        Arrays.fill(dp, 1);
        int maxSize = 1;
        int maxVal = dp[0];
        for (int i = 1; i < len; i++) {
            for (int j = 0; j < i; j++) {
                // 题目中说「没有重复元素」很重要
                if (nums[i] % nums[j] == 0) {
                    dp[i] = Math.max(dp[i], dp[j] + 1);
                }
            }

            if (dp[i] > maxSize) {
                maxSize = dp[i];
                maxVal = nums[i];
            }
        }

        // 第 2 步:倒推获得最大子集
        List<Integer> res = new ArrayList<Integer>();
        if (maxSize == 1) {
            res.add(nums[0]);
            return res;
        }
        
        for (int i = len - 1; i >= 0 && maxSize > 0; i--) {
            if (dp[i] == maxSize && maxVal % nums[i] == 0) {
                res.add(nums[i]);
                maxVal = nums[i];
                maxSize--;
            }
        }
        return res;
    }
}

###JavaScript

var largestDivisibleSubset = function(nums) {
    const len = nums.length;
    nums.sort((a, b) => a - b);

    // 第 1 步:动态规划找出最大子集的个数、最大子集中的最大整数
    const dp = new Array(len).fill(1);
    let maxSize = 1;
    let maxVal = dp[0];
    for (let i = 1; i < len; i++) {
        for (let j = 0; j < i; j++) {
            // 题目中说「没有重复元素」很重要
            if (nums[i] % nums[j] === 0) {
                dp[i] = Math.max(dp[i], dp[j] + 1);
            }
        }

        if (dp[i] > maxSize) {
            maxSize = dp[i];
            maxVal = nums[i];
        }
    }

    // 第 2 步:倒推获得最大子集
    const res = [];
    if (maxSize === 1) {
        res.push(nums[0]);
        return res;
    }
    
    for (let i = len - 1; i >= 0 && maxSize > 0; i--) {
        if (dp[i] === maxSize && maxVal % nums[i] === 0) {
            res.push(nums[i]);
            maxVal = nums[i];
            maxSize--;
        }
    }
    return res;
};

###go

func largestDivisibleSubset(nums []int) (res []int) {
    sort.Ints(nums)

    // 第 1 步:动态规划找出最大子集的个数、最大子集中的最大整数
    n := len(nums)
    dp := make([]int, n)
    for i := range dp {
        dp[i] = 1
    }
    maxSize, maxVal := 1, 1
    for i := 1; i < n; i++ {
        for j, v := range nums[:i] {
            if nums[i]%v == 0 && dp[j]+1 > dp[i] {
                dp[i] = dp[j] + 1
            }
        }
        if dp[i] > maxSize {
            maxSize, maxVal = dp[i], nums[i]
        }
    }

    if maxSize == 1 {
        return []int{nums[0]}
    }

    // 第 2 步:倒推获得最大子集
    for i := n - 1; i >= 0 && maxSize > 0; i-- {
        if dp[i] == maxSize && maxVal%nums[i] == 0 {
            res = append(res, nums[i])
            maxVal = nums[i]
            maxSize--
        }
    }
    return
}

###C++

class Solution {
public:
    vector<int> largestDivisibleSubset(vector<int>& nums) {
        int len = nums.size();
        sort(nums.begin(), nums.end());

        // 第 1 步:动态规划找出最大子集的个数、最大子集中的最大整数
        vector<int> dp(len, 1);
        int maxSize = 1;
        int maxVal = dp[0];
        for (int i = 1; i < len; i++) {
            for (int j = 0; j < i; j++) {
                // 题目中说「没有重复元素」很重要
                if (nums[i] % nums[j] == 0) {
                    dp[i] = max(dp[i], dp[j] + 1);
                }
            }

            if (dp[i] > maxSize) {
                maxSize = dp[i];
                maxVal = nums[i];
            }
        }

        // 第 2 步:倒推获得最大子集
        vector<int> res;
        if (maxSize == 1) {
            res.push_back(nums[0]);
            return res;
        }

        for (int i = len - 1; i >= 0 && maxSize > 0; i--) {
            if (dp[i] == maxSize && maxVal % nums[i] == 0) {
                res.push_back(nums[i]);
                maxVal = nums[i];
                maxSize--;
            }
        }
        return res;
    }
};

###C

int cmp(int* a, int* b) {
    return *a - *b;
}

int* largestDivisibleSubset(int* nums, int numsSize, int* returnSize) {
    int len = numsSize;
    qsort(nums, numsSize, sizeof(int), cmp);

    // 第 1 步:动态规划找出最大子集的个数、最大子集中的最大整数
    int dp[len];
    for (int i = 0; i < len; i++) {
        dp[i] = 1;
    }
    int maxSize = 1;
    int maxVal = dp[0];
    for (int i = 1; i < len; i++) {
        for (int j = 0; j < i; j++) {
            // 题目中说「没有重复元素」很重要
            if (nums[i] % nums[j] == 0) {
                dp[i] = fmax(dp[i], dp[j] + 1);
            }
        }

        if (dp[i] > maxSize) {
            maxSize = dp[i];
            maxVal = nums[i];
        }
    }

    // 第 2 步:倒推获得最大子集
    int* res = malloc(sizeof(int) * len);
    *returnSize = 0;
    if (maxSize == 1) {
        res[(*returnSize)++] = nums[0];
        return res;
    }

    for (int i = len - 1; i >= 0 && maxSize > 0; i--) {
        if (dp[i] == maxSize && maxVal % nums[i] == 0) {
            res[(*returnSize)++] = nums[i];
            maxVal = nums[i];
            maxSize--;
        }
    }
    return res;
}

复杂度分析

  • 时间复杂度:$O(n^2)$,其中 $n$ 为输入数组的长度。对数组 $\textit{nums}$ 排序的时间复杂度为 $O(n \log n)$,计算数组 $\textit{dp}$ 元素的时间复杂度为 $O(n^2)$,倒序遍历得到一个目标子集,时间复杂度为 $O(n)$。

  • 空间复杂度:$O(n)$,其中 $n$ 为输入数组的长度。需要创建长度为 $n$ 的数组 $\textit{dp}$。


📚 好读书?读好书!让时间更有价值| 世界读书日

4 月 22 日至 4 月 28 日,进入「学习」,完成页面右上角的「让时间更有价值」限时阅读任务,可获得「2021 读书日纪念勋章」。更多活动详情戳上方标题了解更多👆

今日学习任务:

个人题解[dp][c++][O(n^2)]

作者 leolee
2019年10月27日 17:44

做题前的思考:假如已经有一个两两倍数关系的序列2 4 8,若此时来一个8的倍数x,那么x一定是8 4 2的倍数,2 4 8 x依然保持了两两倍数关系的序列,到了此处就发现可以大问题转换成子问题,于是联想到动态规划。并且在扩展序列的时候是有序的,所以我们需要给nums排序.

对于已经排序后的nums
设定状态: dp[i]: 以nums[i]结尾的序列最大长度
last[i]: 在最大序列中 nums[i]的上一个元素在nums出现的下标
状态转移方程:
使用二重循环,对于每一个nums[i],看他可以接在之前的哪个序列dp[j]上,使得dp[i]最长
nums[i]%nums[j] == 0是可以接的条件,dp[i]<=dp[j]是使得dp[i]变长的条件
初始状态:dp[i] = 1 (i:1 - n) 每一个只有自己的序列长度为1

###cpp

        for(int i = 0;i<sz;i++){
            for(int j = 0;j<i;j++)
                if(nums[i]%nums[j] == 0 && dp[i]<=dp[j]){
                    dp[i] = dp[j]+1;
                    last[i] = j;
                }

###cpp

class Solution {
public:
    vector<int> largestDivisibleSubset(vector<int>& nums) {
        int sz = nums.size(),mx = 0,end = -1;
        vector<int> dp(sz,1),last(sz,-1),res;
        sort(nums.begin(),nums.end());
        for(int i = 0;i<sz;i++){
            for(int j = 0;j<i;j++){
                if(nums[i]%nums[j] == 0 && dp[i]<=dp[j]){
                    dp[i] = dp[j]+1;
                    last[i] = j;
                }
            }
            if(dp[i]>mx){
                mx = dp[i];
                end = i;
            }
        }
        for(int i = end;i!=-1;i = last[i]){//倒序输出
            res.push_back(nums[i]);
        }
        return res;
    }
};
昨天以前LeetCode 每日一题题解

每日一题-找出所有子集的异或总和再求和🟢

2025年4月5日 00:00

一个数组的 异或总和 定义为数组中所有元素按位 XOR 的结果;如果数组为 ,则异或总和为 0

  • 例如,数组 [2,5,6]异或总和2 XOR 5 XOR 6 = 1

给你一个数组 nums ,请你求出 nums 中每个 子集异或总和 ,计算并返回这些值相加之

注意:在本题中,元素 相同 的不同子集应 多次 计数。

数组 a 是数组 b 的一个 子集 的前提条件是:从 b 删除几个(也可能不删除)元素能够得到 a

 

示例 1:

输入:nums = [1,3]
输出:6
解释:[1,3] 共有 4 个子集:
- 空子集的异或总和是 0 。
- [1] 的异或总和为 1 。
- [3] 的异或总和为 3 。
- [1,3] 的异或总和为 1 XOR 3 = 2 。
0 + 1 + 3 + 2 = 6

示例 2:

输入:nums = [5,1,6]
输出:28
解释:[5,1,6] 共有 8 个子集:
- 空子集的异或总和是 0 。
- [5] 的异或总和为 5 。
- [1] 的异或总和为 1 。
- [6] 的异或总和为 6 。
- [5,1] 的异或总和为 5 XOR 1 = 4 。
- [5,6] 的异或总和为 5 XOR 6 = 3 。
- [1,6] 的异或总和为 1 XOR 6 = 7 。
- [5,1,6] 的异或总和为 5 XOR 1 XOR 6 = 2 。
0 + 5 + 1 + 6 + 4 + 3 + 7 + 2 = 28

示例 3:

输入:nums = [3,4,5,6,7,8]
输出:480
解释:每个子集的全部异或总和值之和为 480 。

 

提示:

  • 1 <= nums.length <= 12
  • 1 <= nums[i] <= 20

O(n) 数学做法(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2025年3月16日 22:31

提示 1

对于异或运算,每个比特位是互相独立的,我们可以先思考只有一个比特位的情况,也就是 $\textit{nums}$ 中只有 $0$ 和 $1$ 的情况。(从特殊到一般)

在这种情况下,如果子集中有偶数个 $1$,那么异或和为 $0$;如果子集中有奇数个 $1$,那么异或和为 $1$。所以关键是求出异或和为 $1$ 的子集个数。

设 $\textit{nums}$ 的长度为 $n$,且包含 $1$。我们可以先把其中一个 $1$ 拿出来,剩下 $n-1$ 个数随便选或不选,有 $2^{n-1}$ 种选法。

  • 如果这 $n-1$ 个数中选了偶数个 $1$,那么放入我们拿出来的 $1$(选这个 $1$),得到奇数个 $1$,异或和为 $1$。
  • 如果这 $n-1$ 个数中选了奇数个 $1$,那么不放入我们拿出来的 $1$(不选这个 $1$),得到奇数个 $1$,异或和为 $1$。

所以,恰好有 $2^{n-1}$ 个子集的异或和为 $1$。

注意这个结论与 $\textit{nums}$ 中有多少个 $1$ 是无关的,只要有 $1$,异或和为 $1$ 的子集个数就是 $2^{n-1}$。如果 $\textit{nums}$ 中没有 $1$,那么有 $0$ 个子集的异或和为 $1$。

所以,在有至少一个 $1$ 的情况下,$\textit{nums}$ 的所有子集的异或和的总和为

$$
2^{n-1}
$$

其他证明方法见文末。

提示 2

推广到多个比特位的情况。

例如 $\textit{nums}=[3,2,8]$,第 $0,1,3$ 个比特位上有 $1$,每个比特位对应的「所有子集的异或和的总和」分别为

$$
2^0 \cdot 2^{n-1},\ 2^1 \cdot 2^{n-1},\ 2^3\cdot 2^{n-1}
$$

相加得

$$
(2^0 + 2^1 + 2^3) \cdot 2^{n-1}
$$

怎么知道哪些比特位上有 $1$?计算 $\textit{nums}$ 的所有元素的 OR,即 $1011_{(2)}$。

注意到,所有元素的 OR,就是上例中的 $2^0 + 2^1 + 2^3$。

一般地,设 $\textit{nums}$ 所有元素的 OR 为 $\textit{or}$,$\textit{nums}$ 的所有子集的异或和的总和为

$$
\textit{or} \cdot 2^{n-1}
$$

###py

class Solution:
    def subsetXORSum(self, nums: List[int]) -> int:
        return reduce(or_, nums) << (len(nums) - 1)

###java

class Solution {
    public int subsetXORSum(int[] nums) {
        int or = 0;
        for (int x : nums) {
            or |= x;
        }
        return or << (nums.length - 1);
    }
}

###cpp

class Solution {
public:
    int subsetXORSum(vector<int>& nums) {
        int or_ = 0;
        for (int x : nums) {
            or_ |= x;
        }
        return or_ << (nums.size() - 1);
    }
};

###cpp

class Solution {
public:
    int subsetXORSum(vector<int>& nums) {
        return reduce(nums.begin(), nums.end(), 0, bit_or()) << (nums.size() - 1);
    }
};

###c

int subsetXORSum(int* nums, int numsSize) {
    int or = 0;
    for (int i = 0; i < numsSize; i++) {
        or |= nums[i];
    }
    return or << (numsSize - 1);
}

###go

func subsetXORSum(nums []int) int {
    or := 0
    for _, x := range nums {
        or |= x
    }
    return or << (len(nums) - 1)
}

###js

var subsetXORSum = function(nums) {
    return nums.reduce((or, x) => or | x, 0) << (nums.length - 1);
};

###rust

impl Solution {
    pub fn subset_xor_sum(nums: Vec<i32>) -> i32 {
        let n = nums.len();
        nums.into_iter().reduce(|or, x| or | x).unwrap() << (n - 1)
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

附:其他证明方法

定理 1:大小为 $m$ 的集合中,有 $2^{m-1}$ 个大小为奇数的子集。其中 $m$ 是正整数。

第一种证法

分类讨论:

  • 如果 $m$ 是奇数,那么对于一个大小为奇数的子集,其补集的大小为偶数,因为奇数减奇数等于偶数。比如 $m=5$,选 $3$ 个数,剩余数字个数 $5-3=2$ 是偶数。同样地,偶数大小的子集,其补集为奇数。不同的奇数大小子集,所对应的偶数大小子集也是不同的(反过来也是)。所以可以把 $2^m$ 个子集均分成两部分:恰好有 $2^{m-1}$ 个大小为奇数的子集,恰好有 $2^{m-1}$ 个大小为偶数的子集。这两部分是一一对应的(双射)。
  • 如果 $m$ 是偶数,我们可以先拿一个数出来,剩下 $m-1$ 个数,且 $m-1$ 是奇数。根据上面的结论,恰好有 $2^{m-2}$ 个大小为奇数的子集,恰好有 $2^{m-2}$ 个大小为偶数的子集。然后我们把拿出来的 $1$,加到每个大小为偶数的子集中,得到 $2^{m-2}$ 个大小为奇数的子集。所以一共有 $2^{m-2} + 2^{m-2} = 2^{m-1}$ 个大小为奇数的子集。

综上所述,无论 $m$ 是奇是偶,都有 $2^{m-1}$ 个大小为奇数的子集。

第二种证法

根据二项式定理,我们有

$$
2^m = (1+1)^m = \binom m 0 + \binom m 1 + \binom m 2 + \cdots + \binom m m
$$

以及

$$
0^m = (1-1)^m = \binom m 0 - \binom m 1 + \binom m 2 - \cdots + (-1)^m \binom m m
$$

两个式子相减,得

$$
2^m = 2\cdot\left[\binom m 1 + \binom m 3 + \binom m 5 + \cdots\right]
$$

$$
\binom m 1 + \binom m 3 + \binom m 5 + \cdots = 2^{m-1}
$$

定理 2:如果 $\textit{nums}$ 包含 $1$,那么恰有 $2^{n-1}$ 个子集有奇数个 $1$。

证明

设 $\textit{nums}$ 的长度为 $n$,其中有 $m$ 个 $1$ 和 $n-m$ 个 $0$。

先从 $m$ 个 $1$ 中选奇数个 $1$,根据定理 1,这有 $2^{m-1}$ 种选法。

再选 $0$,这 $n-m$ 个 $0$,每个 $0$ 选或不选都可以,有 $2^{n-m}$ 种选法。

根据乘法原理,一共有

$$
2^{m-1}\cdot 2^{n-m} = 2^{n-1}
$$

个子集有奇数个 $1$。

变形题

本题有很多变形题,例如:

  • 把子集(子序列)改成连续子数组,要怎么做?
  • 所有子序列的异或和的异或和是多少?
  • 所有子序列的元素和的异或和是多少?

按照「子数组/子序列」的「元素和/异或和」的「总和/异或和」组合题目,一共可以得到八道题。解答见 灵茶八题

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/最短路/最小生成树/二分图/基环树/欧拉路径)
  7. 动态规划(入门/背包/状态机/划分/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

找出所有子集的异或总和再求和

2021年5月19日 00:24

方法一:递归法枚举子集

思路与算法

我们用函数 $\textit{dfs}(\textit{val}, \textit{idx})$ 来递归枚举数组 $\textit{nums}$ 的子集。其中 $\textit{val}$ 代表当前选取部分的异或值,$\textit{idx}$ 代表递归的当前位置。

我们用 $n$ 来表示 $\textit{nums}$ 的长度。在进入 $\textit{dfs}(\textit{val}, \textit{idx})$ 时,数组中 $[0,\textit{idx} - 1]$ 部分的选取情况是已经确定的,而 $[\textit{idx}, n)$ 部分的选取情况还未确定。我们需要确定 $\textit{idx}$ 位置的选取情况,然后求解子问题 $\textit{dfs}(\textit{val'}, \textit{idx} + 1)$。

此时选取情况有两种:

  • 选取,此时 $\textit{val'} = \textit{val} \oplus \textit{nums}[\textit{idx}]$,其中 $\oplus$ 代表异或运算;

  • 不选取,此时 $\textit{val'} = \textit{val}$。

当 $\textit{idx} = n$ 时,递归结束。与此同时,我们维护这些子集异或总和 $\textit{val}$ 的和。

代码

###C++

class Solution {
public:
    int res;
    int n;
    
    void dfs(int val, int idx, vector<int>& nums){
        if (idx == n){
            // 终止递归
            res += val;
            return;
        }
        // 考虑选择当前数字
        dfs(val ^ nums[idx], idx + 1, nums);
        // 考虑不选择当前数字
        dfs(val, idx + 1, nums);
    }
    
    int subsetXORSum(vector<int>& nums) {
        res = 0;
        n = nums.size();
        dfs(0, 0, nums);
        return res;
    }
};

###Python

class Solution:
    def subsetXORSum(self, nums: List[int]) -> int:
        res = 0
        n = len(nums)
        def dfs(val, idx):
            nonlocal res
            if idx == n:
                # 终止递归
                res += val
                return
            # 考虑选择当前数字
            dfs(val ^ nums[idx], idx + 1)
            # 考虑不选择当前数字
            dfs(val, idx + 1)
        
        dfs(0, 0)
        return res

###Java

class Solution {
    int res;
    int n;

    // 深度优先搜索
    void dfs(int val, int idx, int[] nums) {
        if (idx == n) {
            // 终止递归
            res += val;
            return;
        }
        // 考虑选择当前数字
        dfs(val ^ nums[idx], idx + 1, nums);
        // 考虑不选择当前数字
        dfs(val, idx + 1, nums);
    }

    public int subsetXORSum(int[] nums) {
        res = 0;
        n = nums.length;
        dfs(0, 0, nums);
        return res;
    }
}

###C#

public class Solution {
    int res;
    int n;

    // 深度优先搜索
    void Dfs(int val, int idx, int[] nums) {
        if (idx == n) {
            // 终止递归
            res += val;
            return;
        }
        // 考虑选择当前数字
        Dfs(val ^ nums[idx], idx + 1, nums);
        // 考虑不选择当前数字
        Dfs(val, idx + 1, nums);
    }

    public int SubsetXORSum(int[] nums) {
        res = 0;
        n = nums.Length;
        Dfs(0, 0, nums);
        return res;
    }
}

###Go

func subsetXORSum(nums []int) int {
    return dfs(0, 0, nums)
}

func dfs(val, idx int, nums []int) int {
    if idx == len(nums) {
        // 终止递归
        return val
    }
    // 考虑选择当前数字, 考虑不选择当前数字
    return dfs(val ^ nums[idx], idx + 1, nums) + dfs(val, idx + 1, nums)
}

###C

int dfs(int val, int idx, int* nums, int numsSize) {
    if (idx == numsSize) {
        // 终止递归
        return val;
    }
    // 考虑选择当前数字, 考虑不选择当前数字
    return dfs(val ^ nums[idx], idx + 1, nums, numsSize) + dfs(val, idx + 1, nums, numsSize);
}

int subsetXORSum(int* nums, int numsSize) {
    return dfs(0, 0, nums, numsSize);
}

###JavaScript

var subsetXORSum = function(nums) {
    return dfs(0, 0, nums);
};

// 深度优先搜索
function dfs(val, idx, nums) {
    if (idx === nums.length) {
        // 终止递归
        return val;
    }
    // 考虑选择当前数字, 考虑不选择当前数字
    return dfs(val ^ nums[idx], idx + 1, nums) + dfs(val, idx + 1, nums);
}

###JavaScript

function subsetXORSum(nums: number[]): number {
    return dfs(0, 0, nums);
};

// 深度优先搜索
function dfs(val: number, idx: number, nums: number[]): number {
    if (idx === nums.length) {
        // 终止递归
        return val;
    }
    // 考虑选择当前数字, 考虑不选择当前数字
    return dfs(val ^ nums[idx], idx + 1, nums) + dfs(val, idx + 1, nums);
}

###Rust

impl Solution {
    pub fn subset_xor_sum(nums: Vec<i32>) -> i32 {
        fn dfs(val: i32, idx: usize, nums: &[i32]) -> i32 {
            if idx == nums.len() {
                // 终止递归
                return val;
            }
            // 考虑选择当前数字, 考虑不选择当前数字
            dfs(val ^ nums[idx], idx + 1, nums) + dfs(val, idx + 1, nums)
        }
        
        dfs(0, 0, &nums)
    }
}

复杂度分析

  • 时间复杂度:$O(2^n)$,其中 $n$ 为 $\textit{nums}$ 的长度。

    第 $\textit{idx}$ 层的递归函数共有 $2^\textit{idx}$ 个,总计共会调用 $\sum_{i = 0}^n 2^i = 2^{n+1} - 1$ 次递归函数。而每个递归函数的时间复杂度均为 $O(1)$。

  • 空间复杂度:$O(n)$,即为递归时的栈空间开销。

方法二:迭代法枚举子集

提示 $1$

一个长度为 $n$ 的数组 $\textit{nums}$ 有 $2^n$ 个子集(包括空集与自身)。我们可以将这些子集一一映射到 $[0, 2^n-1]$ 中的整数。

提示 $2$

数组中的每个元素都有「选取」与「未选取」两个状态,可以对应一个二进制位的 $1$ 与 $0$。那么对于一个长度为 $n$ 的数组 $\textit{nums}$,我们也可以用 $n$ 个二进制位的整数来唯一表示每个元素的选取情况。此时该整数第 $j$ 位的取值表示数组第 $j$ 个元素是否包含在对应的子集中。

思路与算法

我们也可以用迭代来实现子集枚举。

根据 提示 $1$提示 $2$,我们枚举 $[0, 2^n-1]$ 中的整数 $i$,其第 $j$ 位的取值表示 $\textit{nums}$ 的第 $j$ 个元素是否包含在对应的子集中。

对于每个整数 $i$,我们遍历它的每一位计算对应子集的异或总和,并维护这些值之和。

代码

###C++

class Solution {
public:
    int subsetXORSum(vector<int>& nums) {
        int res = 0;
        int n = nums.size();
        for (int i = 0; i < (1 << n); ++i){   // 遍历所有子集
            int tmp = 0;
            for (int j = 0; j < n; ++j){   // 遍历每个元素
                if (i & (1 << j)){
                    tmp ^= nums[j];
                }
            }
            res += tmp;
        }
        return res;
    }
};

###Python

class Solution:
    def subsetXORSum(self, nums: List[int]) -> int:
        res = 0
        n = len(nums)
        for i in range(1 << n):   # 遍历所有子集
            tmp = 0
            for j in range(n):   # 遍历每个元素
                if i & (1 << j):
                    tmp ^= nums[j]
            res += tmp
        return res

###Java

class Solution {
    public long mostPoints(int[][] questions) {
        int n = questions.length;
        long[] dp = new long[n + 1]; // 解决每道题及以后题目的最高分数
        for (int i = n - 1; i >= 0; i--) {
            dp[i] = Math.max(dp[i + 1], questions[i][0] + dp[Math.min(n, i + questions[i][1] + 1)]);
        }
        return dp[0];
    }
}

###C#

public class Solution {
    public long MostPoints(int[][] questions) {
        int n = questions.Length;
        long[] dp = new long[n + 1]; // 解决每道题及以后题目的最高分数
        for (int i = n - 1; i >= 0; i--) {
            dp[i] = Math.Max(dp[i + 1], questions[i][0] + dp[Math.Min(n, i + questions[i][1] + 1)]);
        }
        return dp[0];
    }
}

###Go

func mostPoints(questions [][]int) int64 {
    n := len(questions)
    dp := make([]int64, n + 1) // 解决每道题及以后题目的最高分数
    for i := n - 1; i >= 0; i-- {
        dp[i] = max(dp[i + 1], int64(questions[i][0]) + dp[min(n, i + questions[i][1] + 1)])
    }
    return dp[0]
}

###C

long long max(long long a, long long b) {
    return a > b ? a : b;
}

long long min(long long a, long long b) {
    return a < b ? a : b;
}

long long mostPoints(int** questions, int questionsSize, int* questionsColSize) {
    long long dp[questionsSize + 1]; // 解决每道题及以后题目的最高分数
    memset(dp, 0, sizeof(dp));
    for (int i = questionsSize - 1; i >= 0; --i) {
        dp[i] = max(dp[i + 1], questions[i][0] + dp[min(questionsSize, i + questions[i][1] + 1)]);
    }
    long long result = dp[0];
    return result;
}

###JavaScript

var mostPoints = function(questions) {
    const n = questions.length;
    const dp = new Array(n + 1).fill(0); // 解决每道题及以后题目的最高分数
    for (let i = n - 1; i >= 0; i--) {
        dp[i] = Math.max(dp[i + 1], questions[i][0] + dp[Math.min(n, i + questions[i][1] + 1)]);
    }
    return dp[0];
};

###JavaScript

function mostPoints(questions: number[][]): number {
    const n = questions.length;
    const dp: number[] = new Array(n + 1).fill(0); // 解决每道题及以后题目的最高分数
    for (let i = n - 1; i >= 0; i--) {
        dp[i] = Math.max(dp[i + 1], questions[i][0] + dp[Math.min(n, i + questions[i][1] + 1)]);
    }
    return dp[0];
};

###Rust

impl Solution {
    pub fn most_points(questions: Vec<Vec<i32>>) -> i64 {
        let n = questions.len();
        let mut dp = vec![0i64; n + 1]; // 解决每道题及以后题目的最高分数
        for i in (0..n).rev() {
            dp[i] = dp[i + 1].max(questions[i][0] as i64 + dp[(n).min(i + questions[i][1] as usize + 1)]);
        }
        dp[0]
    }
}

复杂度分析

  • 时间复杂度:$O(n2^n)$,其中 $n$ 为 $\textit{nums}$ 的长度。我们遍历了 $\textit{nums}$ 的 $2^n$ 个子集,每个子集需要 $O(n)$ 的时间计算异或总和。

  • 空间复杂度:$O(1)$。

方法三:按位考虑 + 二项式展开

提示 $1$

由于异或运算本质上是按位操作,因此我们可以按位考虑取值情况。

提示 $2$

对于数组中所有元素的某一位,存在两种可能:

  • 第一种,所有元素该位都为 $0$;

  • 第二种,至少有一个元素该位为 $1$。

假设数组元素个数为 $n$,那么第一种情况下,所有子集异或总和中该位均为 $0$;第二种情况下,所有子集异或总和中该位为 $0$ 的个数与为 $1$ 的个数相等,均为 $2^{n-1}$。

提示 $2$ 解释

首先,一个子集的异或总和中某位为 $0$ 当且仅当子集内该位为 $1$ 的元素数量为偶数(包括 $0$),某位为 $1$ 当且仅当子集内该位为 $1$ 的元素数量为奇数。那么第一种情况时显然所有子集的异或总和中该位都为 $0$。

其次,假设数组内某一位为 $1$ 的元素个数为 $m$,那么它的子集里面包含 $k$ 个 $1$ 的数量为($k \le m \le n$):

$$
2^{n-m}\binom{k}{m},
$$

那么包含奇数个 $1$ 的子集数量为:

$$
\sum_{k\ \text{is odd}, 0\le k\le m}2^{n-m}\binom{k}{m} = 2^{n-m}\sum_{k\ \text{is odd}, 0\le k\le m}\binom{k}{m},
$$

同理,包含偶数个 $1$ 的子集数量为:

$$
\sum_{k\ \text{is even}, 0\le k\le m}2^{n-m}\binom{k}{m} = 2^{n-m}\sum_{k\ \text{is even}, 0\le k\le m}\binom{k}{m}.
$$

事实上,我们通过对于 $(x + 1)^m$ 二项式展开并取 $x = -1$ 时,有:

$$
(-1+1)^m = \sum_{k = 0}^{m} \binom{k}{m} (-1)^k 1^{m-k} = \sum_{k\ \text{is even}, 0\le k\le m}\binom{k}{m} - \sum_{k\ \text{is odd}, 0\le k\le m}\binom{k}{m} = 0.
$$

这也就说明,包含奇数个 $1$ 的子集数量与包含偶数个 $1$ 的子集数量相等,均为全体子集数量的一半,即 $2^{n-1}$。

思路与算法

根据 提示 $2$,我们用 $\textit{res}$ 来维护数组全体元素的按位或,使得 $\textit{res}$ 的某一位为 $1$ 当且仅当数组中存在该位为 $1$ 的元素。

那么,对于 $\textit{res}$ 中为 $1$ 的任何一位,其对于结果的贡献均为该位对应的值乘上异或总和为 $1$ 的子集数量 $2^{n-1}$;对于为 $0$ 的任何一位,乘上 $2^{n-1}$ 也不会对结果产生影响。因此我们可以直接将 $\textit{res}$ 算术左移 $n - 1$ 位作为结果返回。

代码

###C++

class Solution {
public:
    int subsetXORSum(vector<int>& nums) {
        int res = 0;
        int n = nums.size();
        for (auto num: nums){
            res |= num;
        }
        return res << (n - 1);
    }
};

###Python

class Solution:
    def subsetXORSum(self, nums: List[int]) -> int:
        res = 0
        n = len(nums)
        for num in nums:
            res |= num
        return res << (n - 1)

###Java

class Solution {
    public int subsetXORSum(int[] nums) {
        int res = 0;
        int n = nums.length;
        for (int num : nums) {
            res |= num;
        }
        return res << (n - 1);
    }
}

###C#

public class Solution {
    public int SubsetXORSum(int[] nums) {
        int res = 0;
        int n = nums.Length;
        foreach (int num in nums) {
            res |= num;
        }
        return res << (n - 1);
    }
}

###Go

func subsetXORSum(nums []int) int {
    res := 0
    n := len(nums)
    for _, num := range nums {
        res |= num
    }
    return res << (n - 1)
}

###C

int subsetXORSum(int* nums, int numsSize) {
    int res = 0;
    for (int i = 0; i < numsSize; ++i) {
        res |= nums[i];
    }
    return res << (numsSize - 1);
}

###JavaScript

var subsetXORSum = function(nums) {
    let res = 0;
    const n = nums.length;
    for (let num of nums) {
        res |= num;
    }
    return res << (n - 1);
};

###JavaScript

function subsetXORSum(nums: number[]): number {
    let res = 0;
    const n = nums.length;
    for (let num of nums) {
        res |= num;
    }
    return res << (n - 1);
};

###Rust

impl Solution {
    pub fn subset_xor_sum(nums: Vec<i32>) -> i32 {
        let mut res = 0;
        let n = nums.len();
        for &num in &nums {
            res |= num;
        }
        res << (n - 1)
    }
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 为 $\textit{nums}$ 的长度,即为一遍遍历数组的时间复杂度。

  • 空间复杂度:$O(1)$。

每日一题-最深叶节点的最近公共祖先🟡

2025年4月4日 00:00

给你一个有根节点 root 的二叉树,返回它 最深的叶节点的最近公共祖先 。

回想一下:

  • 叶节点 是二叉树中没有子节点的节点
  • 树的根节点的 深度 为 0,如果某一节点的深度为 d,那它的子节点的深度就是 d+1
  • 如果我们假定 A 是一组节点 S 的 最近公共祖先S 中的每个节点都在以 A 为根节点的子树中,且 A 的深度达到此条件下可能的最大值。

 

示例 1:

输入:root = [3,5,1,6,2,0,8,null,null,7,4]
输出:[2,7,4]
解释:我们返回值为 2 的节点,在图中用黄色标记。
在图中用蓝色标记的是树的最深的节点。
注意,节点 6、0 和 8 也是叶节点,但是它们的深度是 2 ,而节点 7 和 4 的深度是 3 。

示例 2:

输入:root = [1]
输出:[1]
解释:根节点是树中最深的节点,它是它本身的最近公共祖先。

示例 3:

输入:root = [0,1,3,null,2]
输出:[2]
解释:树中最深的叶节点是 2 ,最近公共祖先是它自己。

 

提示:

  • 树中的节点数将在 [1, 1000] 的范围内。
  • 0 <= Node.val <= 1000
  • 每个节点的值都是 独一无二 的。

 

注意:本题与力扣 865 重复:https://leetcode-cn.com/problems/smallest-subtree-with-all-the-deepest-nodes/

两种递归思路(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2023年9月6日 07:52

前言

推荐先把 236. 二叉树的最近公共祖先 做了,对理解本题做法有帮助。

本题最深的叶子可能只有一个,此时这个叶子就是答案。如果最深的叶子不止一个,那么答案为所有最深叶子的最近公共祖先。

方法一:递归递归,有递有归

回顾 236 题的做法:

  • 如果要找的节点只在左子树中,那么最近公共祖先也只在左子树中。
  • 如果要找的节点只在右子树中,那么最近公共祖先也只在右子树中。
  • 如果要找的节点左右子树都有,那么最近公共祖先就是当前节点。

对于本题,要找的节点是最深的叶子。

如果左子树的最大深度比右子树的大,那么(子树中的)最深叶子就只在左子树中,所以(子树中的)最深叶子的最近公共祖先也只在左子树中。

如果左右子树的最大深度一样呢?当前节点一定是最近公共祖先吗?

不一定。比如上图节点 $1$ 的左右子树最深叶子 $0,8$ 的深度都是 $2$,但该深度并不是全局最大深度,所以节点 $1$ 并不是答案。

根据以上讨论,正确做法如下:

  1. 从根节点开始递归,同时维护全局最大深度 $\textit{maxDepth}$。
  2. 在「递」的时候往下传 $\textit{depth}$,用来表示当前节点的深度。
  3. 在「归」的时候往上传当前子树最深的空节点的深度。这里为了方便,用空节点代替叶子,因为最深的空节点的上面一定是最深的叶子。
  4. 设左子树最深空节点的深度为 $\textit{leftMaxDepth}$,右子树最深空节点的深度为 $\textit{rightMaxDepth}$。如果最深的空节点左右子树都有,即 $\textit{leftMaxDepth}=\textit{rightMaxDepth}=\textit{maxDepth}$,那么更新答案为当前节点。注意这并不代表我们找到了答案,如果后面发现了更深的空节点,答案还会更新。另外注意,这个判断方式在只有一个最深叶子的情况下,也是正确的。
class Solution:
    def lcaDeepestLeaves(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        ans = None
        max_depth = -1  # 全局最大深度
        def dfs(node: Optional[TreeNode], depth: int) -> int:
            nonlocal ans, max_depth
            if node is None:
                max_depth = max(max_depth, depth)  # 维护全局最大深度
                return depth
            left_max_depth = dfs(node.left, depth + 1)  # 左子树最深空节点的深度
            right_max_depth = dfs(node.right, depth + 1)  # 右子树最深空节点的深度
            if left_max_depth == right_max_depth == max_depth:  # 最深的空节点左右子树都有
                ans = node
            return max(left_max_depth, right_max_depth)  # 当前子树最深空节点的深度
        dfs(root, 0)
        return ans
class Solution {
    private TreeNode ans;
    private int maxDepth = -1; // 全局最大深度

    public TreeNode lcaDeepestLeaves(TreeNode root) {
        dfs(root, 0);
        return ans;
    }

    private int dfs(TreeNode node, int depth) {
        if (node == null) {
            maxDepth = Math.max(maxDepth, depth); // 维护全局最大深度
            return depth;
        }
        int leftMaxDepth = dfs(node.left, depth + 1); // 左子树最深空节点的深度
        int rightMaxDepth = dfs(node.right, depth + 1); // 右子树最深空节点的深度
        if (leftMaxDepth == rightMaxDepth && leftMaxDepth == maxDepth) { // 最深的空节点左右子树都有
            ans = node;
        }
        return Math.max(leftMaxDepth, rightMaxDepth); // 当前子树最深空节点的深度
    }
}
class Solution {
public:
    TreeNode* lcaDeepestLeaves(TreeNode* root) {
        TreeNode* ans = nullptr;
        int max_depth = -1; // 全局最大深度
        auto dfs = [&](this auto&& dfs, TreeNode* node, int depth) {
            if (node == nullptr) {
                max_depth = max(max_depth, depth); // 维护全局最大深度
                return depth;
            }
            int left_max_depth = dfs(node->left, depth + 1); // 左子树最深空节点的深度
            int right_max_depth = dfs(node->right, depth + 1); // 右子树最深空节点的深度
            if (left_max_depth == right_max_depth && left_max_depth == max_depth) { // 最深的空节点左右子树都有
                ans = node;
            }
            return max(left_max_depth, right_max_depth); // 当前子树最深空节点的深度
        };
        dfs(root, 0);
        return ans;
    }
};
#define MAX(a, b) ((b) > (a) ? (b) : (a))

struct TreeNode* lcaDeepestLeaves(struct TreeNode* root) {
    struct TreeNode* ans = NULL;
    int max_depth = -1; // 全局最大深度

    int dfs(struct TreeNode* node, int depth) {
        if (node == NULL) {
            max_depth = MAX(max_depth, depth); // 维护全局最大深度
            return depth;
        }
        int left_max_depth = dfs(node->left, depth + 1); // 左子树最深空节点的深度
        int right_max_depth = dfs(node->right, depth + 1); // 右子树最深空节点的深度
        if (left_max_depth == right_max_depth && left_max_depth == max_depth) { // 最深的空节点左右子树都有
            ans = node;
        }
        return MAX(left_max_depth, right_max_depth); // 当前子树最深空节点的深度
    }

    dfs(root, 0);
    return ans;
}
func lcaDeepestLeaves(root *TreeNode) (ans *TreeNode) {
    maxDepth := -1 // 全局最大深度
    var dfs func(*TreeNode, int) int
    dfs = func(node *TreeNode, depth int) int {
        if node == nil {
            maxDepth = max(maxDepth, depth) // 维护全局最大深度
            return depth
        }
        leftMaxDepth := dfs(node.Left, depth+1) // 左子树最深空节点的深度
        rightMaxDepth := dfs(node.Right, depth+1) // 右子树最深空节点的深度
        if leftMaxDepth == rightMaxDepth && leftMaxDepth == maxDepth { // 最深的空节点左右子树都有
            ans = node
        }
        return max(leftMaxDepth, rightMaxDepth) // 当前子树最深空节点的深度
    }
    dfs(root, 0)
    return
}
var lcaDeepestLeaves = function(root) {
    let ans = null;
    let maxDepth = -1; // 全局最大深度
    function dfs(node, depth) {
        if (node === null) {
            maxDepth = Math.max(maxDepth, depth); // 维护全局最大深度
            return depth;
        }
        const leftMaxDepth = dfs(node.left, depth + 1); // 左子树最深空节点的深度
        const rightMaxDepth = dfs(node.right, depth + 1); // 右子树最深空节点的深度
        if (leftMaxDepth === rightMaxDepth && leftMaxDepth === maxDepth) { // 最深的空节点左右子树都有
            ans = node;
        }
        return Math.max(leftMaxDepth, rightMaxDepth);// 当前子树最深空节点的深度
    }
    dfs(root, 0);
    return ans;
};
use std::rc::Rc;
use std::cell::RefCell;

impl Solution {
    pub fn lca_deepest_leaves(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
        let mut ans = None;
        let mut max_depth = -1; // 全局最大深度

        fn dfs(node: &Option<Rc<RefCell<TreeNode>>>, depth: i32, max_depth: &mut i32, ans: &mut Option<Rc<RefCell<TreeNode>>>) -> i32 {
            if let Some(node) = node {
                let x = node.borrow();
                let left_max_depth = dfs(&x.left, depth + 1, max_depth, ans); // 左子树最深空节点的深度
                let right_max_depth = dfs(&x.right, depth + 1, max_depth, ans); // 右子树最深空节点的深度
                if left_max_depth == right_max_depth && left_max_depth == *max_depth { // 最深的空节点左右子树都有
                    *ans = Some(node.clone());
                }
                left_max_depth.max(right_max_depth) // 当前子树最深空节点的深度
            } else {
                *max_depth = (*max_depth).max(depth); // 维护全局最大深度
                depth
            }
        }

        dfs(&root, 0, &mut max_depth, &mut ans);
        ans
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$。每个节点都会恰好访问一次。
  • 空间复杂度:$\mathcal{O}(n)$。最坏情况下,二叉树是一条链,递归需要 $\mathcal{O}(n)$ 的栈空间。

方法二:自底向上

能否不用外部变量 $\textit{ans}$ 和 $\textit{maxDepth}$ 呢?

把每棵子树都看成是一个「子问题」,即对于每棵子树,我们需要知道:

  • 这棵子树最深叶子的深度。这里是指叶子在这棵子树内的深度,而不是在整棵二叉树的视角下的深度。相当于这棵子树的高度
  • 这棵子树的最深叶子的最近公共祖先 $\textit{lca}$。

设子树的根节点为 $\textit{node}$,$\textit{node}$ 的左子树的高度为 $\textit{leftHeight}$,$\textit{node}$ 的右子树的高度为 $\textit{rightHeight}$。分类讨论:

  • 如果 $\textit{leftHeight} > \textit{rightHeight}$,那么 $\textit{node}$ 子树的高度为 $\textit{leftHeight} + 1$,$\textit{lca}$ 是左子树的 $\textit{lca}$。
  • 如果 $\textit{leftHeight} < \textit{rightHeight}$,那么 $\textit{node}$ 子树的高度为 $\textit{rightHeight} + 1$,$\textit{lca}$ 是右子树的 $\textit{lca}$。
  • 如果 $\textit{leftHeight} = \textit{rightHeight}$,那么 $\textit{node}$ 子树的高度为 $\textit{leftHeight} + 1$,$\textit{lca}$ 就是 $\textit{node}$。
class Solution:
    def lcaDeepestLeaves(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        def dfs(node: Optional[TreeNode]) -> (int, Optional[TreeNode]):
            if node is None:
                return 0, None
            left_height, left_lca = dfs(node.left)
            right_height, right_lca = dfs(node.right)
            if left_height > right_height:  # 左子树更高
                return left_height + 1, left_lca
            if left_height < right_height:  # 右子树更高
                return right_height + 1, right_lca
            return left_height + 1, node  # 一样高
        return dfs(root)[1]
class Solution {
    public TreeNode lcaDeepestLeaves(TreeNode root) {
        return dfs(root).getValue();
    }

    private Pair<Integer, TreeNode> dfs(TreeNode node) {
        if (node == null) {
            return new Pair<>(0, null);
        }
        Pair<Integer, TreeNode> left = dfs(node.left);
        Pair<Integer, TreeNode> right = dfs(node.right);
        if (left.getKey() > right.getKey()) { // 左子树更高
            return new Pair<>(left.getKey() + 1, left.getValue());
        }
        if (left.getKey() < right.getKey()) { // 右子树更高
            return new Pair<>(right.getKey() + 1, right.getValue());
        }
        return new Pair<>(left.getKey() + 1, node); // 一样高
    }
}
class Solution {
    pair<int, TreeNode*> dfs(TreeNode* node) {
        if (node == nullptr) {
            return {0, nullptr};
        }
        auto [left_height, left_lca] = dfs(node->left);
        auto [right_height, right_lca] = dfs(node->right);
        if (left_height > right_height) { // 左子树更高
            return {left_height + 1, left_lca};
        }
        if (left_height < right_height) { // 右子树更高
            return {right_height + 1, right_lca};
        }
        return {left_height + 1, node}; // 一样高
    }

public:
    TreeNode* lcaDeepestLeaves(TreeNode* root) {
        return dfs(root).second;
    }
};
typedef struct {
    int height;
    struct TreeNode* lca;
} Pair;

Pair dfs(struct TreeNode* node) {
    if (node == NULL) {
        return (Pair) {0, NULL};
    }
    Pair left = dfs(node->left);
    Pair right = dfs(node->right);
    if (left.height > right.height) { // 左子树更高
        return (Pair) {left.height + 1, left.lca};
    }
    if (left.height < right.height) { // 右子树更高
        return (Pair) {right.height + 1, right.lca};
    }
    return (Pair) {left.height + 1, node}; // 一样高
}

struct TreeNode* lcaDeepestLeaves(struct TreeNode* root) {
    return dfs(root).lca;
}
func dfs(node *TreeNode) (int, *TreeNode) {
    if node == nil {
        return 0, nil
    }
    leftHeight, leftLCA := dfs(node.Left)
    rightHeight, rightLCA := dfs(node.Right)
    if leftHeight > rightHeight { // 左子树更高
        return leftHeight + 1, leftLCA
    }
    if leftHeight < rightHeight { // 右子树更高
        return rightHeight + 1, rightLCA
    }
    return leftHeight + 1, node // 一样高
}

func lcaDeepestLeaves(root *TreeNode) *TreeNode {
    _, lca := dfs(root)
    return lca
}
var dfs = function(node) {
    if (node === null) {
        return [0, null];
    }
    const [leftHeight, leftLca] = dfs(node.left);
    const [rightHeight, rightLca] = dfs(node.right);
    if (leftHeight > rightHeight) { // 左子树更高
        return [leftHeight + 1, leftLca];
    }
    if (leftHeight < rightHeight) { // 右子树更高
        return [rightHeight + 1, rightLca];
    }
    return [leftHeight + 1, node]; // 一样高
};

var lcaDeepestLeaves = function(root) {
    return dfs(root, 0)[1];
};
use std::rc::Rc;
use std::cell::RefCell;

impl Solution {
    pub fn lca_deepest_leaves(root: Option<Rc<RefCell<TreeNode>>>) -> Option<Rc<RefCell<TreeNode>>> {
        fn dfs(node: &Option<Rc<RefCell<TreeNode>>>) -> (i32, Option<Rc<RefCell<TreeNode>>>) {
            if let Some(node) = node {
                let x = node.borrow();
                let (left_height, left_lca) = dfs(&x.left);
                let (right_height, right_lca) = dfs(&x.right);
                if left_height > right_height {
                    return (left_height + 1, left_lca); // 左子树更高
                }
                if left_height < right_height {
                    return (right_height + 1, right_lca); // 右子树更高
                }
                (left_height + 1, Some(node.clone())) // 一样高
            } else {
                (0, None)
            }
        }
        dfs(&root).1
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$。每个节点都会恰好访问一次。
  • 空间复杂度:$\mathcal{O}(n)$。最坏情况下,二叉树是一条链,递归需要 $\mathcal{O}(n)$ 的栈空间。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/最短路/最小生成树/二分图/基环树/欧拉路径)
  7. 动态规划(入门/背包/状态机/划分/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

最深叶节点的最近公共祖先

2023年9月1日 10:23

方法一:递归

思路与算法

题目给出一个二叉树,要求返回它最深的叶节点的最近公共祖先。其中树的根节点的深度为 $0$,我们注意到所有深度最大的节点,都是树的叶节点。为方便说明,我们把最深的叶节点的最近公共祖先,称之为 $\textit{lca}$ 节点。

我们用递归的方式,进行深度优先搜索,对树中的每个节点进行递归,返回当前子树的最大深度 $d$ 和 $\textit{lca}$ 节点。如果当前节点为空,我们返回深度 $0$ 和空节点。在每次搜索中,我们递归地搜索左子树和右子树,然后比较左右子树的深度:

  • 如果左子树更深,最深叶节点在左子树中,我们返回 {左子树深度 + $1$,左子树的 $\textit{lca}$ 节点}
  • 如果右子树更深,最深叶节点在右子树中,我们返回 {右子树深度 + $1$,右子树的 $\textit{lca}$ 节点}
  • 如果左右子树一样深,左右子树都有最深叶节点,我们返回 {左子树深度 + $1$,当前节点}

最后我们返回根节点的 $\textit{lca}$ 节点即可。

代码

###C++

class Solution {
public:
    pair<TreeNode*, int> f(TreeNode* root) {
        if (!root) {
            return {root, 0};
        }

        auto left = f(root->left);
        auto right = f(root->right);

        if (left.second > right.second) {
            return {left.first, left.second + 1};
        }
        if (left.second < right.second) {
            return {right.first, right.second + 1};
        }
        return {root, left.second + 1};

    }

    TreeNode* lcaDeepestLeaves(TreeNode* root) {
        return f(root).first;
    }
};

###Java

class Solution {
    public TreeNode lcaDeepestLeaves(TreeNode root) {
        return f(root).getKey();
    }

    private Pair<TreeNode, Integer> f(TreeNode root) {
        if (root == null) {
            return new Pair<>(root, 0);
        }

        Pair<TreeNode, Integer> left = f(root.left);
        Pair<TreeNode, Integer> right = f(root.right);

        if (left.getValue() > right.getValue()) {
            return new Pair<>(left.getKey(), left.getValue() + 1);
        }
        if (left.getValue() < right.getValue()) {
            return new Pair<>(right.getKey(), right.getValue() + 1);
        }
        return new Pair<>(root, left.getValue() + 1);
    }
}

###C#

public class Solution {
    public TreeNode LcaDeepestLeaves(TreeNode root) {
        return f(root).Item1;
    }

    private Tuple<TreeNode, int> f(TreeNode root) {
        if (root == null) {
            return new Tuple<TreeNode, int>(root, 0);
        }

        Tuple<TreeNode, int> left = f(root.left);
        Tuple<TreeNode, int> right = f(root.right);

        if (left.Item2 > right.Item2) {
            return new Tuple<TreeNode, int>(left.Item1, left.Item2 + 1);
        }
        if (left.Item2 < right.Item2) {
            return new Tuple<TreeNode, int>(right.Item1, right.Item2 + 1);
        }
        return new Tuple<TreeNode, int>(root, left.Item2 + 1);
    }
}

###Python

class Solution:
    def lcaDeepestLeaves(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        def f(root):
            if not root:
                return 0, None

            d1, lca1 = f(root.left)
            d2, lca2 = f(root.right)

            if d1 > d2:
                return d1 + 1, lca1
            if d1 < d2:
                return d2 + 1, lca2
            return d1 + 1, root

        return f(root)[1]

###JavaScript

var lcaDeepestLeaves = function(root) {
    return f(root)[1];
};

function f(root) {
    if (!root) {
      return [0, root];
    }

    let [d1, lca1] = f(root.left);
    let [d2, lca2] = f(root.right);

    if (d1 > d2) {
      return [d1 + 1, lca1];
    }
    if (d1 < d2) {
      return [d2 + 1, lca2];
    }
    return [d1 + 1, root];
}

###Go

func lcaDeepestLeaves(root *TreeNode) *TreeNode {
    _, lca := f(root)
    return lca
}

func f(root *TreeNode) (int, *TreeNode) {
    if root == nil {
        return 0, nil
    }

    d1, lca1 := f(root.Left)
    h2, lca2 := f(root.Right)

    if d1 > h2 {
        return d1 + 1, lca1
    }
    if d1 < h2 {
        return h2 + 1, lca2
    }
    return d1 + 1, root
}

###C

struct Pair {
    struct TreeNode *node;
    int depth;
};

struct Pair f(struct TreeNode *root) {
    if (root == NULL) {
        return (struct Pair) {NULL, 0};
    }

    struct Pair left = f(root->left);
    struct Pair right = f(root->right);

    if (left.depth > right.depth) {
        return (struct Pair) {left.node, left.depth + 1};
    }
    if (left.depth < right.depth) {
        return (struct Pair) {right.node, right.depth + 1};
    }
    return (struct Pair) {root, left.depth + 1};
}

struct TreeNode *lcaDeepestLeaves(struct TreeNode *root) {
    return f(root).node;
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 是树的节点数量。

  • 空间复杂度:$O(d)$,其中 $d$ 是树的深度。空间复杂度主要是递归的空间,最差情况为 $O(n)$,其中 $n$ 是树的节点数量。

两种思路,一种前序遍历,一种后序遍历,速度100%

作者 qiujunlin
2020年11月17日 19:21

解题思路

第一种容想到的常规解法

类似于前序遍历,从根节点开始,分别求左右子树的高度left,和right。

  • 情况1:left=right 那么两边子树的最深高度相同,返回本节点
  • 情况2:left<right 说明最深节点在右子树,直接返回右子树的递归结果
  • 情况2:left>right 说明最深节点在左子树,直接返回右子树的递归结果

其中求子树的高度需要定义一个方法,就是104. 二叉树的最大深度,很简单。
image.png

代码

###java

class Solution {
    public TreeNode lcaDeepestLeaves(TreeNode root) {
       if(root==null) return null;
       int left=dfs(root.left);
       int right=dfs(root.right);
       if(left==right) return root;
       else if(left<right) return lcaDeepestLeaves(root.right);
       return lcaDeepestLeaves(root.left);
    }
    int dfs(TreeNode  node){
      if(node==null) return 0;
      return 1+Math.max(dfs(node.right),dfs(node.left));
    }
}

第二种方法,

第二种方法其实就是求后序遍历,代码结构有点类似于求最大深度,只不过要想办法保存最近的节点,和返回深度

首先定义一个点来保存最近公共祖先,定义一个pre来保存上一次得到的最近公共祖先的深度。
在递归过程中,带一个参数level表示当前遍历到的节点的深度

如果node为空,返回当前深度。
如果不为空,则当前节点的逻辑为:
分别求左子树和右子树的最大深度,left和right

  • 1.left=right 如果相同,并且当前深度大于上一次的最大深度,说明当前节点为最新的最近公共祖先,上一次的没有当前这个深,将当前节点保存在结果中,并将深度pre更新。
  • 2.left不等于right 则直接返左右子树的最大深度
    image.png
class Solution {
    TreeNode res = null;
    int pre=0;
    public TreeNode lcaDeepestLeaves(TreeNode root) {
        dfs(root,1);
        return res;

    }
    int dfs(TreeNode  node,int depth){
      if(node==null) return depth;
      int left=dfs(node.left,depth+1);
      int right =dfs(node.right,depth+1);
      if(left==right&&left>=pre){
           res=node;
           pre=left;
      } 
      return Math.max(left,right);
    }
}

有序三元组中的最大值 II

2025年3月14日 09:49

方法一:贪心 + 前后缀数组

令数组 $\textit{nums}$ 的长度为 $n$。根据值公式 $(\textit{nums}[i] - \textit{nums}[j]) \times \textit{nums}[k]$ 可知,当固定 $j$ 时,$\textit{nums}[i]$ 和 $\textit{nums}[k]$ 分别取 $[0, j)$ 和 $[j + 1, n)$ 的最大值时,三元组的值最大。我们使用 $\textit{leftMax}[j]$ 和 $\textit{rightMax}[j]$ 维护前缀 $[0, j)$ 最大值和后缀 $[j + 1, n)$ 最大值,依次枚举 $j$,计算值 $(\textit{leftMax}[j] - \textit{nums}[j]) \times \textit{rightMax}[j]$,返回最大值(若所有值都为负数,则返回 $0$)。

###C++

class Solution {
public:
    long long maximumTripletValue(vector<int>& nums) {
        int n = nums.size();
        vector<int> leftMax(n), rightMax(n);
        for (int i = 1; i < n; i++) {
            leftMax[i] = max(leftMax[i - 1], nums[i - 1]);
            rightMax[n - 1 - i] = max(rightMax[n - i], nums[n - i]);
        }
        long long res = 0;
        for (int j = 1; j < n - 1; j++) {
            res = max(res, (long long)(leftMax[j] - nums[j]) * rightMax[j]);
        }
        return res;
    }
};

###Go

func maximumTripletValue(nums []int) int64 {
    n := len(nums)
    leftMax := make([]int, n)
    rightMax := make([]int, n)
    for i := 1; i < n; i++ {
        leftMax[i] = max(leftMax[i - 1], nums[i - 1])
    }
    for i := 1; i < n; i++ {
        rightMax[n - 1 - i] = max(rightMax[n - i], nums[n - i])
    }
    var res int64 = 0
    for j := 1; j < n - 1; j++ {
        res = max(res, int64((leftMax[j] - nums[j]) * rightMax[j]))
    }
    return res
}

###Python

class Solution:
    def maximumTripletValue(self, nums: List[int]) -> int:
        n = len(nums)
        leftMax = [0] * n
        rightMax = [0] * n
        for i in range(1, n):
            leftMax[i] = max(leftMax[i - 1], nums[i - 1])
            rightMax[n - 1 - i] = max(rightMax[n - i], nums[n - i])
        res = 0
        for j in range(1, n - 1):
            res = max(res, (leftMax[j] - nums[j]) * rightMax[j])
        return res

###Java

public class Solution {
    public long maximumTripletValue(int[] nums) {
        int n = nums.length;
        int[] leftMax = new int[n];
        int[] rightMax = new int[n];
        for (int i = 1; i < n; i++) {
            leftMax[i] = Math.max(leftMax[i - 1], nums[i - 1]);
            rightMax[n - 1 - i] = Math.max(rightMax[n - i], nums[n - i]);
        }
        long res = 0;
        for (int j = 1; j < n - 1; j++) {
            res = Math.max(res, (long)(leftMax[j] - nums[j]) * rightMax[j]);
        }
        return res;
    }
}

###JavaScript

var maximumTripletValue = function(nums) {
    const n = nums.length;
    const leftMax = new Array(n).fill(0);
    const rightMax = new Array(n).fill(0);
    for (let i = 1; i < n; i++) {
        leftMax[i] = Math.max(leftMax[i - 1], nums[i - 1]);
        rightMax[n - 1 - i] = Math.max(rightMax[n - i], nums[n - i]);
    }
    let res = 0;
    for (let j = 1; j < n - 1; j++) {
        res = Math.max(res, (leftMax[j] - nums[j]) * rightMax[j]);
    }
    return res;
};

###TypeScript

function maximumTripletValue(nums: number[]): number {
    const n = nums.length;
    const leftMax: number[] = new Array(n).fill(0);
    const rightMax: number[] = new Array(n).fill(0);
    for (let i = 1; i < n; i++) {
        leftMax[i] = Math.max(leftMax[i - 1], nums[i - 1]);
        rightMax[n - 1 - i] = Math.max(rightMax[n - i], nums[n - i]);
    }
    let res = 0;
    for (let j = 1; j < n - 1; j++) {
        res = Math.max(res, (leftMax[j] - nums[j]) * rightMax[j]);
    }
    return res;
}

###C#

public class Solution {
    public long MaximumTripletValue(int[] nums) {
        int n = nums.Length;
        int[] leftMax = new int[n];
        int[] rightMax = new int[n];
        for (int i = 1; i < n; i++) {
            leftMax[i] = Math.Max(leftMax[i - 1], nums[i - 1]);
            rightMax[n - 1 - i] = Math.Max(rightMax[n - i], nums[n - i]);
        }
        long res = 0;
        for (int j = 1; j < n - 1; j++) {
            res = Math.Max(res, (long)(leftMax[j] - nums[j]) * rightMax[j]);
        }
        return res;
    }
}

###C

long long maximumTripletValue(int *nums, int numsSize) {
    int leftMax[numsSize], rightMax[numsSize];
    leftMax[0] = 0;
    rightMax[numsSize - 1] = 0;
    for (int i = 1; i < numsSize; i++) {
        leftMax[i] = fmax(leftMax[i - 1], nums[i - 1]);
        rightMax[numsSize - 1 - i] = fmax(rightMax[numsSize - i], nums[numsSize - i]);
    }
    long long res = 0;
    for (int j = 1; j < numsSize - 1; j++) {
        long long temp = (long long)(leftMax[j] - nums[j]) * rightMax[j];
        res = fmax(res, temp);
    }
    return res;
}

###Rust

impl Solution {
    pub fn maximum_triplet_value(nums: Vec<i32>) -> i64 {
        let n = nums.len();
        let mut left_max = vec![0; n];
        let mut right_max = vec![0; n];
        for i in 1..n {
            left_max[i] = left_max[i - 1].max(nums[i - 1]);
            right_max[n - i - 1] = right_max[n - i].max(nums[n - i]);
        }
        let mut res = 0;
        for j in 1..n - 1 {
            res = res.max((left_max[j] - nums[j]) as i64 * right_max[j] as i64);
        }
        res
    }
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 是数组 $\textit{nums}$ 的长度。

  • 空间复杂度:$O(n)$。

方法二:贪心

类似于方法一,我们固定 $k$,那么当 $\textit{nums}[i] - \textit{nums}[j]$ 取最大值时,三元组的值最大。我们可以用 $\textit{imax}$ 维护 $\textit{nums}[i]$ 的最大值,$\textit{dmax}$ 维护 $\textit{nums}[i] - \textit{nums}[j]$ 的最大值,在枚举 $k$ 的过程中,更新 $\textit{dmax}$ 和 $\textit{imax}$。

###C++

class Solution {
public:
    long long maximumTripletValue(vector<int>& nums) {
        int n = nums.size();
        long long res = 0, imax = 0, dmax = 0;
        for (int k = 0; k < n; k++) {
            res = max(res, dmax * nums[k]);
            dmax = max(dmax, imax - nums[k]);
            imax = max(imax, static_cast<long long>(nums[k]));
        }
        return res;
    }
};

###Java

class Solution {
    public long maximumTripletValue(int[] nums) {
        int n = nums.length;
        long res = 0, imax = 0, dmax = 0;
        for (int k = 0; k < n; k++) {
            res = Math.max(res, dmax * nums[k]);
            dmax = Math.max(dmax, imax - nums[k]);
            imax = Math.max(imax, nums[k]);
        }
        return res;
    }
}

###C#

public class Solution {
    public long MaximumTripletValue(int[] nums) {
        int n = nums.Length;
        long res = 0, imax = 0, dmax = 0;
        for (int k = 0; k < n; k++) {
            res = Math.Max(res, dmax * nums[k]);
            dmax = Math.Max(dmax, imax - nums[k]);
            imax = Math.Max(imax, nums[k]);
        }
        return res;
    }
}

###Python

class Solution:
    def maximumTripletValue(self, nums: List[int]) -> int:
        n = len(nums)
        res, imax, dmax = 0, 0, 0
        for k in range(n):
            res = max(res, dmax * nums[k])
            dmax = max(dmax, imax - nums[k])
            imax = max(imax, nums[k])
        return res

###C

long long maximumTripletValue(int* nums, int numsSize) {
    long long res = 0, imax = 0, dmax = 0;
    for (int k = 0; k < numsSize; k++) {
        res = fmax(res, dmax * nums[k]);
        dmax = fmax(dmax, imax - nums[k]);
        imax = fmax(imax, nums[k]);
    }
    return res;
}

###Go

func maximumTripletValue(nums []int) int64 {
    n := len(nums)
    var res, imax, dmax int64 = 0, 0, 0
    for k := 0; k < n; k++ {
        res = max(res, dmax * int64(nums[k]))
        dmax = max(dmax, imax - int64(nums[k]))
        imax = max(imax, int64(nums[k]))
    }
    return res
}

###JavaScript

var maximumTripletValue = function(nums) {
    const n = nums.length;
    let res = 0, imax = 0, dmax = 0;
    for (let k = 0; k < n; k++) {
        res = Math.max(res, dmax * nums[k]);
        dmax = Math.max(dmax, imax - nums[k]);
        imax = Math.max(imax, nums[k]);
    }
    return res;
};

###TypeScript

function maximumTripletValue(nums: number[]): number {
    const n: number = nums.length;
    let res: number = 0, imax: number = 0, dmax: number = 0;
    for (let k = 0; k < n; k++) {
        res = Math.max(res, dmax * nums[k]);
        dmax = Math.max(dmax, imax - nums[k]);
        imax = Math.max(imax, nums[k]);
    }
    return res;
}

###Rust

impl Solution {
    pub fn maximum_triplet_value(nums: Vec<i32>) -> i64 {
        let mut res = 0;
        let mut imax = 0;
        let mut dmax = 0;
        for &num in nums.iter() {
            res = res.max(dmax * num as i64);
            dmax = dmax.max(imax - num as i64);
            imax = imax.max(num as i64);
        }
        res
    }
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 是数组 $\textit{nums}$ 的长度。

  • 空间复杂度:$O(1)$。

[Python3/Java/C++/Go/TypeScript] 一题一解:维护前缀最大值和最大差值(清晰题解)

作者 lcbin
2025年4月3日 06:18

方法一:维护前缀最大值和最大差值

我们用两个变量 $\textit{mx}$ 和 $\textit{mxDiff}$ 分别维护前缀最大值和最大差值,用一个变量 $\textit{ans}$ 维护答案。初始时,这些变量都为 $0$。

接下来,我们枚举数组的每个元素 $x$ 作为 $\textit{nums}[k]$,首先更新答案 $\textit{ans} = \max(\textit{ans}, \textit{mxDiff} \times x)$,然后我们更新最大差值 $\textit{mxDiff} = \max(\textit{mxDiff}, \textit{mx} - x)$,最后更新前缀最大值 $\textit{mx} = \max(\textit{mx}, x)$。

枚举完所有元素后,返回答案 $\textit{ans}$。

###python

class Solution:
    def maximumTripletValue(self, nums: List[int]) -> int:
        ans = mx = mx_diff = 0
        for x in nums:
            ans = max(ans, mx_diff * x)
            mx_diff = max(mx_diff, mx - x)
            mx = max(mx, x)
        return ans

###java

class Solution {
    public long maximumTripletValue(int[] nums) {
        long ans = 0, mxDiff = 0;
        int mx = 0;
        for (int x : nums) {
            ans = Math.max(ans, mxDiff * x);
            mxDiff = Math.max(mxDiff, mx - x);
            mx = Math.max(mx, x);
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    long long maximumTripletValue(vector<int>& nums) {
        long long ans = 0, mxDiff = 0;
        int mx = 0;
        for (int x : nums) {
            ans = max(ans, mxDiff * x);
            mxDiff = max(mxDiff, 1LL * mx - x);
            mx = max(mx, x);
        }
        return ans;
    }
};

###go

func maximumTripletValue(nums []int) int64 {
ans, mx, mxDiff := 0, 0, 0
for _, x := range nums {
ans = max(ans, mxDiff*x)
mxDiff = max(mxDiff, mx-x)
mx = max(mx, x)
}
return int64(ans)
}

###ts

function maximumTripletValue(nums: number[]): number {
    let [ans, mx, mxDiff] = [0, 0, 0];
    for (const x of nums) {
        ans = Math.max(ans, mxDiff * x);
        mxDiff = Math.max(mxDiff, mx - x);
        mx = Math.max(mx, x);
    }
    return ans;
}

###rust

impl Solution {
    pub fn maximum_triplet_value(nums: Vec<i32>) -> i64 {
        let mut ans: i64 = 0;
        let mut mx: i32 = 0;
        let mut mx_diff: i32 = 0;

        for &x in &nums {
            ans = ans.max(mx_diff as i64 * x as i64);
            mx_diff = mx_diff.max(mx - x);
            mx = mx.max(x);
        }

        ans
    }
}

时间复杂度 $O(n)$,其中 $n$ 是数组长度。空间复杂度 $O(1)$。


有任何问题,欢迎评论区交流,欢迎评论区提供其它解题思路(代码),也可以点个赞支持一下作者哈😄~

❌
❌