普通视图

发现新文章,点击刷新页面。
今天 — 2026年3月24日首页

前后缀分解(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2023年10月15日 12:12

请先完成本题的一维版本238. 除了自身以外数组的乘积

把矩阵平铺成一维数组,就是 238 题了。我们需要算出每个数左边所有数的乘积,以及右边所有数的乘积。

先算出从 $\textit{grid}[i][j]$ 的下一个元素开始,到最后一个元素 $\textit{grid}[n-1][m-1]$ 的乘积,记作 $\textit{suf}[i][j]$。这可以从最后一个数 $\textit{grid}[n-1][m-1]$ 开始,倒着遍历 $\textit{grid}$ 得到。

然后算出从第一个数 $\textit{grid}[0][0]$ 开始,到 $\textit{grid}[i][j]$ 的上一个元素的乘积,记作 $\textit{pre}[i][j]$。这可以从第一行第一列开始,正着遍历得到。

那么

$$
p[i][j] = \textit{pre}[i][j]\cdot \textit{suf}[i][j]
$$

代码实现时,可以先初始化 $p[i][j]=\textit{suf}[i][j]$,然后在计算 $\textit{pre}[i][j]$ 的过程中,把 $\textit{pre}[i][j]$ 乘到 $\textit{p}[i][j]$ 中,就得到了最终答案。这样写的话,$\textit{pre}$ 和 $\textit{suf}$ 可以直接用单个变量表示,无需创建数组。

代码实现时,注意取模。为什么可以在中途取模?原理见 模运算的世界:当加减乘除遇上取模

class Solution:
    def constructProductMatrix(self, grid: List[List[int]]) -> List[List[int]]:
        MOD = 12345
        n, m = len(grid), len(grid[0])
        p = [[0] * m for _ in range(n)]

        suf = 1  # 后缀乘积
        for i in range(n - 1, -1, -1):
            for j in range(m - 1, -1, -1):
                p[i][j] = suf  # p[i][j] 先初始化成后缀乘积
                suf = suf * grid[i][j] % MOD

        pre = 1  # 前缀乘积
        for i, row in enumerate(grid):
            for j, x in enumerate(row):
                p[i][j] = p[i][j] * pre % MOD  # 乘上前缀乘积
                pre = pre * x % MOD

        return p
class Solution {
    public int[][] constructProductMatrix(int[][] grid) {
        final int MOD = 12345;
        int n = grid.length;
        int m = grid[0].length;
        int[][] p = new int[n][m];

        long suf = 1; // 后缀乘积
        for (int i = n - 1; i >= 0; i--) {
            for (int j = m - 1; j >= 0; j--) {
                p[i][j] = (int) suf; // p[i][j] 先初始化成后缀乘积
                suf = suf * grid[i][j] % MOD;
            }
        }

        long pre = 1; // 前缀乘积
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                p[i][j] = (int) (p[i][j] * pre % MOD); // 乘上前缀乘积
                pre = pre * grid[i][j] % MOD;
            }
        }

        return p;
    }
}
class Solution {
public:
    vector<vector<int>> constructProductMatrix(vector<vector<int>>& grid) {
        constexpr int MOD = 12345;
        int n = grid.size(), m = grid[0].size();
        vector p(n, vector<int>(m));

        long long suf = 1; // 后缀乘积
        for (int i = n - 1; i >= 0; i--) {
            for (int j = m - 1; j >= 0; j--) {
                p[i][j] = suf; // p[i][j] 先初始化成后缀乘积
                suf = suf * grid[i][j] % MOD;
            }
        }

        long long pre = 1; // 前缀乘积
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                p[i][j] = p[i][j] * pre % MOD; // 乘上前缀乘积
                pre = pre * grid[i][j] % MOD;
            }
        }

        return p;
    }
};
int** constructProductMatrix(int** grid, int gridSize, int* gridColSize, int* returnSize, int** returnColumnSizes) {
    const int MOD = 12345;
    int n = gridSize, m = gridColSize[0];
    int** p = malloc(n * sizeof(int*));
    *returnSize = n;
    *returnColumnSizes = malloc(n * sizeof(int));
    for (int i = 0; i < n; i++) {
        p[i] = malloc(m * sizeof(int));
        (*returnColumnSizes)[i] = m;
    }

    long long suf = 1; // 后缀乘积
    for (int i = n - 1; i >= 0; i--) {
        for (int j = m - 1; j >= 0; j--) {
            p[i][j] = suf; // p[i][j] 先初始化成后缀乘积
            suf = suf * grid[i][j] % MOD;
        }
    }

    long long pre = 1; // 前缀乘积
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            p[i][j] = p[i][j] * pre % MOD; // 乘上前缀乘积
            pre = pre * grid[i][j] % MOD;
        }
    }

    return p;
}
func constructProductMatrix(grid [][]int) [][]int {
const mod = 12345
n, m := len(grid), len(grid[0])
p := make([][]int, n)
suf := 1 // 后缀乘积
for i := n - 1; i >= 0; i-- {
p[i] = make([]int, m)
for j := m - 1; j >= 0; j-- {
p[i][j] = suf // p[i][j] 先初始化成后缀乘积
suf = suf * grid[i][j] % mod
}
}

pre := 1 // 前缀乘积
for i, row := range grid {
for j, x := range row {
p[i][j] = p[i][j] * pre % mod // 乘上前缀乘积
pre = pre * x % mod
}
}
return p
}
var constructProductMatrix = function(grid) {
    const MOD = 12345;
    const n = grid.length, m = grid[0].length;
    const p = Array.from({ length: n }, () => Array(m).fill(0));

    let suf = 1; // 后缀乘积
    for (let i = n - 1; i >= 0; i--) {
        for (let j = m - 1; j >= 0; j--) {
            p[i][j] = suf; // p[i][j] 先初始化成后缀乘积
            suf = suf * grid[i][j] % MOD;
        }
    }

    let pre = 1; // 前缀乘积
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < m; j++) {
            p[i][j] = p[i][j] * pre % MOD; // 乘上前缀乘积
            pre = pre * grid[i][j] % MOD;
        }
    }

    return p;
};
impl Solution {
    pub fn construct_product_matrix(grid: Vec<Vec<i32>>) -> Vec<Vec<i32>> {
        const MOD: i64 = 12345;
        let n = grid.len();
        let m = grid[0].len();
        let mut p = vec![vec![0; m]; n];

        let mut suf = 1; // 后缀乘积
        for i in (0..n).rev() {
            for j in (0..m).rev() {
                p[i][j] = suf as i32; // p[i][j] 先初始化成后缀乘积
                suf = suf * grid[i][j] as i64 % MOD;
            }
        }

        let mut pre = 1; // 前缀乘积
        for i in 0..n {
            for j in 0..m {
                p[i][j] = (p[i][j] as i64 * pre % MOD) as i32; // 乘上前缀乘积
                pre = pre * grid[i][j] as i64 % MOD;
            }
        }

        p
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(nm)$,其中 $n$ 和 $m$ 分别是 $\textit{grid}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(1)$。返回值不计入。

专题训练

见下面动态规划题单的「专题:前后缀分解」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

昨天 — 2026年3月23日首页

记忆化搜索 -> 递推 -> 空间优化(Python/Java/C++/Go)

作者 endlesscheng
2026年3月13日 09:24

前置题目

  1. 152. 乘积最大子数组我的题解
  2. 64. 最小路径和我的题解

状态定义与状态转移方程

思路和 152 题是一样的,除了计算最大路径乘积,还要计算最小路径乘积(因为负负得正)。

定义 $\textit{dfs}(i,j)$ 表示从左上角 $(0,0)$ 到 $(i,j)$ 的最小路径乘积以及最大路径乘积($\textit{dfs}$ 返回两个数)。

设 $x = \textit{grid}[i][j]$。分类讨论如何到达 $(i,j)$:

  • 如果是从上边过来,那么必须先到达 $(i-1,j)$,我们需要知道从 $(0,0)$ 到 $(i-1,j)$ 的最小路径乘积 $\textit{mn}$ 以及最大路径乘积 $\textit{mx}$,这可以从 $\textit{dfs}(i-1,j)$ 获取到。从左上角 $(0,0)$ 到 $(i,j)$ 的最小路径乘积为 $\min(\textit{mn}\cdot x, \textit{mx}\cdot x)$,从左上角 $(0,0)$ 到 $(i,j)$ 的最大路径乘积为 $\max(\textit{mn}\cdot x, \textit{mx}\cdot x)$。理由同 152 题。
  • 如果是从左边过来,那么必须先到达 $(i,j-1)$,我们需要知道从 $(0,0)$ 到 $(i,j-1)$ 的最小路径乘积以及最大路径乘积,这可以从 $\textit{dfs}(i,j-1)$ 获取到。计算方法同上。

两种情况取最小值(最大值),即为 $\textit{dfs}(i,j)$ 的返回值。

递归边界:$\textit{dfs}(0,0) = (x,x)$。

递归入口:$\textit{dfs}(m-1,n-1)$。取返回值中的最大路径乘积作为答案。如果答案是负数,返回 $-1$;否则返回答案模 $10^9+7$ 的结果。

注意:题目要求算完了再取模。如果在中途取模,可能会把一个很大的数模成很小的数,导致计算错误。比如两个数 $10^9+8$ 和 $10^9$,取模之前是 $10^9+8$ 更大,但取模后这两个数分别变成 $1$ 和 $10^9$,后者更大。

###py

class Solution:
    def maxProductPath(self, grid: List[List[int]]) -> int:
        @cache  # 缓存装饰器,避免重复计算 dfs(一行代码实现记忆化)
        def dfs(i: int, j: int) -> Tuple[int, int]:
            x = grid[i][j]
            if i == j == 0:
                return x, x

            res_min, res_max = inf, -inf
            if i > 0:
                mn, mx = dfs(i - 1, j)
                res_min = min(mn * x, mx * x)
                res_max = max(mn * x, mx * x)
            if j > 0:
                mn, mx = dfs(i, j - 1)
                res_min = min(res_min, mn * x, mx * x)
                res_max = max(res_max, mn * x, mx * x)

            return res_min, res_max

        ans = dfs(len(grid) - 1, len(grid[0]) - 1)[1]
        return -1 if ans < 0 else ans % 1_000_000_007

###java

class Solution {
    public int maxProductPath(int[][] grid) {
        int m = grid.length, n = grid[0].length;
        long[][][] memo = new long[m][n][2];
        for (long[][] row : memo) {
            for (long[] p : row) {
                p[0] = p[1] = Long.MIN_VALUE;
            }
        }

        long ans = dfs(m - 1, n - 1, grid, memo)[1];
        return ans < 0 ? -1 : (int) (ans % 1_000_000_007);
    }

    private long[] dfs(int i, int j, int[][] grid, long[][][] memo) {
        long x = grid[i][j];
        if (i == 0 && j == 0) {
            return new long[]{x, x};
        }

        long[] p = memo[i][j];
        if (p[0] != Long.MIN_VALUE) { // 之前计算过
            return p;
        }

        long resMin = Long.MAX_VALUE;
        long resMax = Long.MIN_VALUE;
        if (i > 0) {
            long[] res = dfs(i - 1, j, grid, memo);
            long mn = res[0], mx = res[1];
            resMin = Math.min(mn * x, mx * x);
            resMax = Math.max(mn * x, mx * x);
        }
        if (j > 0) {
            long[] res = dfs(i, j - 1, grid, memo);
            long mn = res[0], mx = res[1];
            resMin = Math.min(resMin, Math.min(mn * x, mx * x));
            resMax = Math.max(resMax, Math.max(mn * x, mx * x));
        }

        p[0] = resMin;
        p[1] = resMax; // 记忆化
        return p;
    }
}

###cpp

class Solution {
public:
    int maxProductPath(vector<vector<int>>& grid) {
        int m = grid.size(), n = grid[0].size();
        vector memo(m, vector<array<long long, 2>>(n, {LLONG_MIN, LLONG_MIN}));

        auto dfs = [&](this auto&& dfs, int i, int j) -> array<long long, 2> {
            long long x = grid[i][j];
            if (i == 0 && j == 0) {
                return {x, x};
            }

            auto& res = memo[i][j]; // 注意这里是引用
            if (res[0] != LLONG_MIN) { // 之前计算过
                return res;
            }

            long long res_min = LLONG_MAX;
            long long res_max = LLONG_MIN;
            if (i > 0) {
                auto [mn, mx] = dfs(i - 1, j);
                res_min = min(mn * x, mx * x);
                res_max = max(mn * x, mx * x);
            }
            if (j > 0) {
                auto [mn, mx] = dfs(i, j - 1);
                res_min = min(res_min, min(mn * x, mx * x));
                res_max = max(res_max, max(mn * x, mx * x));
            }

            res = {res_min, res_max}; // 记忆化
            return res;
        };

        long long ans = dfs(m - 1, n - 1)[1];
        return ans < 0 ? -1 : ans % 1'000'000'007;
    }
};

###go

func maxProductPath(grid [][]int) int {
m, n := len(grid), len(grid[0])
memo := make([][][2]int, m)
for i := range memo {
memo[i] = make([][2]int, n)
for j := range memo[i] {
memo[i][j] = [2]int{math.MinInt, math.MinInt}
}
}

var dfs func(int, int) (int, int)
dfs = func(i, j int) (int, int) {
x := grid[i][j]
if i == 0 && j == 0 {
return x, x
}

p := &memo[i][j]
if p[0] != math.MinInt { // 之前计算过
return p[0], p[1]
}

resMin := math.MaxInt
resMax := math.MinInt
if i > 0 {
mn, mx := dfs(i-1, j)
resMin = min(mn*x, mx*x)
resMax = max(mn*x, mx*x)
}
if j > 0 {
mn, mx := dfs(i, j-1)
resMin = min(resMin, mn*x, mx*x)
resMax = max(resMax, mn*x, mx*x)
}

p[0], p[1] = resMin, resMax // 记忆化
return resMin, resMax
}

_, ans := dfs(m-1, n-1)
if ans < 0 {
return -1
}
return ans % 1_000_000_007
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn)$,其中 $m$ 和 $n$ 分别是 $\textit{grid}$ 的行数和列数。由于每个状态只会计算一次,动态规划的时间复杂度 $=$ 状态个数 $\times$ 单个状态的计算时间。本题状态个数等于 $\mathcal{O}(mn)$,单个状态的计算时间为 $\mathcal{O}(1)$,所以总的时间复杂度为 $\mathcal{O}(mn)$。
  • 空间复杂度:$\mathcal{O}(mn)$。保存多少状态,就需要多少空间。

1:1 翻译成递推

把 $\textit{dfs}(i,j)$ 改成 $f[i][j]$。

###py

class Solution:
    def maxProductPath(self, grid: List[List[int]]) -> int:
        m, n = len(grid), len(grid[0])
        f = [[None] * n for _ in range(m)]

        for i, row in enumerate(grid):
            for j, x in enumerate(row):
                if i == j == 0:
                    f[0][0] = (x, x)
                    continue

                res_min, res_max = inf, -inf
                if i > 0:
                    mn, mx = f[i - 1][j]
                    res_min = min(mn * x, mx * x)
                    res_max = max(mn * x, mx * x)
                if j > 0:
                    mn, mx = f[i][j - 1]
                    res_min = min(res_min, mn * x, mx * x)
                    res_max = max(res_max, mn * x, mx * x)

                f[i][j] = (res_min, res_max)

        ans = f[-1][-1][1]
        return -1 if ans < 0 else ans % 1_000_000_007

###java

class Solution {
    public int maxProductPath(int[][] grid) {
        int m = grid.length, n = grid[0].length;
        long[][][] f = new long[m][n][2];

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                long x = grid[i][j];
                if (i == 0 && j == 0) {
                    f[0][0][0] = x;
                    f[0][0][1] = x;
                    continue;
                }

                long resMin = Long.MAX_VALUE;
                long resMax = Long.MIN_VALUE;
                if (i > 0) {
                    long mn = f[i - 1][j][0], mx = f[i - 1][j][1];
                    resMin = Math.min(mn * x, mx * x);
                    resMax = Math.max(mn * x, mx * x);
                }
                if (j > 0) {
                    long mn = f[i][j - 1][0], mx = f[i][j - 1][1];
                    resMin = Math.min(resMin, Math.min(mn * x, mx * x));
                    resMax = Math.max(resMax, Math.max(mn * x, mx * x));
                }

                f[i][j][0] = resMin;
                f[i][j][1] = resMax;
            }
        }

        long ans = f[m - 1][n - 1][1];
        return ans < 0 ? -1 : (int) (ans % 1_000_000_007);
    }
}

###cpp

class Solution {
public:
    int maxProductPath(vector<vector<int>>& grid) {
        int m = grid.size(), n = grid[0].size();
        vector f(m, vector<array<long long, 2>>(n));

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                long long x = grid[i][j];
                if (i == 0 && j == 0) {
                    f[0][0] = {x, x};
                    continue;
                }

                long long res_min = LLONG_MAX;
                long long res_max = LLONG_MIN;
                if (i > 0) {
                    auto [mn, mx] = f[i - 1][j];
                    res_min = min(mn * x, mx * x);
                    res_max = max(mn * x, mx * x);
                }
                if (j > 0) {
                    auto [mn, mx] = f[i][j - 1];
                    res_min = min(res_min, min(mn * x, mx * x));
                    res_max = max(res_max, max(mn * x, mx * x));
                }

                f[i][j] = {res_min, res_max};
            }
        }

        long long ans = f[m - 1][n - 1][1];
        return ans < 0 ? -1 : ans % 1'000'000'007;
    }
};

###go

func maxProductPath(grid [][]int) int {
m, n := len(grid), len(grid[0])
f := make([][][2]int, m)
for i := range f {
f[i] = make([][2]int, n)
}

for i, row := range grid {
for j, x := range row {
if i == 0 && j == 0 {
f[0][0] = [2]int{x, x}
continue
}

resMin := math.MaxInt
resMax := math.MinInt
if i > 0 {
mn, mx := f[i-1][j][0], f[i-1][j][1]
resMin = min(mn*x, mx*x)
resMax = max(mn*x, mx*x)
}
if j > 0 {
mn, mx := f[i][j-1][0], f[i][j-1][1]
resMin = min(resMin, mn*x, mx*x)
resMax = max(resMax, mn*x, mx*x)
}

f[i][j] = [2]int{resMin, resMax}
}
}

ans := f[m-1][n-1][1]
if ans < 0 {
return -1
}
return ans % 1_000_000_007
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn)$,其中 $m$ 和 $n$ 分别是 $\textit{grid}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(mn)$。

空间优化

原理见 64 题 我的题解

###py

class Solution:
    def maxProductPath(self, grid: List[List[int]]) -> int:
        n = len(grid[0])
        f_min = [0] * n
        f_max = [0] * n

        for i, row in enumerate(grid):
            for j, x in enumerate(row):
                if i == j == 0:
                    f_min[0] = f_max[0] = x
                    continue

                res_min, res_max = inf, -inf
                if i > 0:
                    mn, mx = f_min[j], f_max[j]
                    res_min = min(mn * x, mx * x)
                    res_max = max(mn * x, mx * x)
                if j > 0:
                    mn, mx = f_min[j - 1], f_max[j - 1]
                    res_min = min(res_min, mn * x, mx * x)
                    res_max = max(res_max, mn * x, mx * x)

                f_min[j] = res_min
                f_max[j] = res_max

        ans = f_max[-1]
        return -1 if ans < 0 else ans % 1_000_000_007

###java

class Solution {
    public int maxProductPath(int[][] grid) {
        int m = grid.length, n = grid[0].length;
        long[] fMin = new long[n];
        long[] fMax = new long[n];

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                long x = grid[i][j];
                if (i == 0 && j == 0) {
                    fMin[0] = fMax[0] = x;
                    continue;
                }

                long resMin = Long.MAX_VALUE;
                long resMax = Long.MIN_VALUE;
                if (i > 0) {
                    long mn = fMin[j], mx = fMax[j];
                    resMin = Math.min(mn * x, mx * x);
                    resMax = Math.max(mn * x, mx * x);
                }
                if (j > 0) {
                    long mn = fMin[j - 1], mx = fMax[j - 1];
                    resMin = Math.min(resMin, Math.min(mn * x, mx * x));
                    resMax = Math.max(resMax, Math.max(mn * x, mx * x));
                }

                fMin[j] = resMin;
                fMax[j] = resMax;
            }
        }

        long ans = fMax[n - 1];
        return ans < 0 ? -1 : (int) (ans % 1_000_000_007);
    }
}

###cpp

class Solution {
public:
    int maxProductPath(vector<vector<int>>& grid) {
        int m = grid.size(), n = grid[0].size();
        vector<long long> f_min(n), f_max(n);

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                long long x = grid[i][j];
                if (i == 0 && j == 0) {
                    f_min[0] = f_max[0] = x;
                    continue;
                }

                long long res_min = LLONG_MAX;
                long long res_max = LLONG_MIN;
                if (i > 0) {
                    long long mn = f_min[j], mx = f_max[j];
                    res_min = min(mn * x, mx * x);
                    res_max = max(mn * x, mx * x);
                }
                if (j > 0) {
                    long long mn = f_min[j - 1], mx = f_max[j - 1];
                    res_min = min(res_min, min(mn * x, mx * x));
                    res_max = max(res_max, max(mn * x, mx * x));
                }

                f_min[j] = res_min;
                f_max[j] = res_max;
            }
        }

        long long ans = f_max[n - 1];
        return ans < 0 ? -1 : ans % 1'000'000'007;
    }
};

###go

func maxProductPath(grid [][]int) int {
n := len(grid[0])
fMin := make([]int, n)
fMax := make([]int, n)

for i, row := range grid {
for j, x := range row {
if i == 0 && j == 0 {
fMin[0], fMax[0] = x, x
continue
}

resMin := math.MaxInt
resMax := math.MinInt
if i > 0 {
mn, mx := fMin[j], fMax[j]
resMin = min(mn*x, mx*x)
resMax = max(mn*x, mx*x)
}
if j > 0 {
mn, mx := fMin[j-1], fMax[j-1]
resMin = min(resMin, mn*x, mx*x)
resMax = max(resMax, mn*x, mx*x)
}

fMin[j] = resMin
fMax[j] = resMax
}
}

ans := fMax[n-1]
if ans < 0 {
return -1
}
return ans % 1_000_000_007
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn)$,其中 $m$ 和 $n$ 分别是 $\textit{grid}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(n)$。

专题训练

见下面动态规划题单的「二、网格图 DP」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

昨天以前首页

两种方法:先旋转再比较 / 直接比较(Python/Java/C++/Go)

作者 endlesscheng
2021年6月6日 12:34

方法一:先旋转再比较

枚举 $\textit{mat}$ 旋转 $0,1,2,3$ 次,判断旋转后的 $\textit{mat}$ 是否等于 $\textit{target}$。

旋转方阵有原地算法,见 48. 旋转图像我的题解

class Solution:
    # 48. 旋转图像
    def rotate(self, matrix: List[List[int]]) -> None:
        n = len(matrix)
        for i, row in enumerate(matrix):
            for j in range(i + 1, n):  # 遍历对角线上方元素,做转置
                row[j], matrix[j][i] = matrix[j][i], row[j]
            row.reverse()  # 行翻转

    def findRotation(self, mat: List[List[int]], target: List[List[int]]) -> bool:
        for _ in range(4):
            if mat == target:
                return True
            self.rotate(mat)
        return False
class Solution {
    public boolean findRotation(int[][] mat, int[][] target) {
        for (int i = 0; i < 4; i++) {
            if (Arrays.deepEquals(mat, target)) {
                return true;
            }
            rotate(mat);
        }
        return false;
    }

    // 48. 旋转图像
    public void rotate(int[][] matrix) {
        int n = matrix.length;
        for (int i = 0; i < n; i++) {
            int[] row = matrix[i];
            for (int j = i + 1; j < n; j++) { // 遍历对角线上方元素,做转置
                int tmp = row[j];
                row[j] = matrix[j][i];
                matrix[j][i] = tmp;
            }
            for (int j = 0; j < n / 2; j++) { // 遍历左半元素,做行翻转
                int tmp = row[j];
                row[j] = row[n - 1 - j];
                row[n - 1 - j] = tmp;
            }
        }
    }
}
class Solution {
    // 48. 旋转图像
    void rotate(vector<vector<int>>& matrix) {
        int n = matrix.size();
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) { // 遍历对角线上方元素,做转置
                swap(matrix[i][j], matrix[j][i]);
            }
            ranges::reverse(matrix[i]); // 行翻转
        }
    }

public:
    bool findRotation(vector<vector<int>>& mat, vector<vector<int>>& target) {
        for (int i = 0; i < 4; i++) {
            if (mat == target) {
                return true;
            }
            rotate(mat);
        }
        return false;
    }
};
// 48. 旋转图像
func rotate(matrix [][]int) {
n := len(matrix)
for i, row := range matrix {
for j := i + 1; j < n; j++ { // 遍历对角线上方元素,做转置
row[j], matrix[j][i] = matrix[j][i], row[j]
}
slices.Reverse(row) // 行翻转
}
}

func findRotation(mat, target [][]int) bool {
for range 4 {
if slices.EqualFunc(mat, target, slices.Equal[[]int]) {
return true
}
rotate(mat)
}
return false
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2)$,其中 $n$ 是 $\textit{mat}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(1)$。

方法二:直接比较

顺时针旋转 $90^\circ$ 后,位于 $(i,j)$ 的元素去哪了?

根据 48 题 我的题解,结论如下:

$$
(i,j)\xrightarrow{旋转\ 90^\circ} (j,n-1-i) \xrightarrow{旋转\ 90^\circ} (n-1-i,n-1-j) \xrightarrow{旋转\ 90^\circ} (n-1-j,i)
$$

所以对于 $\textit{mat}[i][j]$,它需要比较四个位置上的值:

  • 旋转 $0$ 次:比较 $\textit{target}[i][j]$。
  • 旋转 $1$ 次:比较 $\textit{target}[j][n-1-i]$。
  • 旋转 $2$ 次:比较 $\textit{target}[n-1-i][n-1-j]$。
  • 旋转 $3$ 次:比较 $\textit{target}[n-1-j][i]$。

如果对于某个旋转次数,所有的比较都为真,那么返回 $\texttt{true}$。否则返回 $\texttt{false}$。

class Solution:
    def findRotation(self, mat: List[List[int]], target: List[List[int]]) -> bool:
        ok = (1 << 4) - 1  # ok = [True] * 4
        for i, row in enumerate(mat):
            for j, x in enumerate(row):
                if x != target[i][j]:
                    ok &= ~(1 << 0)  # ok[0] = False
                if x != target[j][-1 - i]:
                    ok &= ~(1 << 1)  # ok[1] = False
                if x != target[-1 - i][-1 - j]:
                    ok &= ~(1 << 2)  # ok[2] = False
                if x != target[-1 - j][i]:
                    ok &= ~(1 << 3)  # ok[3] = False
                if ok == 0:  # 所有的 ok[i] 都是 False
                    return False
        return True  # 至少有一个 ok[i] 是 True
class Solution {
    public boolean findRotation(int[][] mat, int[][] target) {
        int n = mat.length;
        int ok = (1 << 4) - 1; // boolean[] ok = {true, true, true, true};
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                int x = mat[i][j];
                if (x != target[i][j]) {
                    ok &= ~(1 << 0); // ok[0] = false
                }
                if (x != target[j][n - 1 - i]) {
                    ok &= ~(1 << 1); // ok[1] = false
                }
                if (x != target[n - 1 - i][n - 1 - j]) {
                    ok &= ~(1 << 2); // ok[2] = false
                }
                if (x != target[n - 1 - j][i]) {
                    ok &= ~(1 << 3); // ok[3] = false
                }
                if (ok == 0) { // 所有的 ok[i] 都是 false
                    return false;
                }
            }
        }
        return true; // 至少有一个 ok[i] 是 true
    }
}
class Solution {
public:
    bool findRotation(vector<vector<int>>& mat, vector<vector<int>>& target) {
        int n = mat.size();
        int ok = (1 << 4) - 1; // bool ok[4] = {true, true, true, true}
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                int x = mat[i][j];
                if (x != target[i][j]) {
                    ok &= ~(1 << 0); // ok[0] = false
                }
                if (x != target[j][n - 1 - i]) {
                    ok &= ~(1 << 1); // ok[1] = false
                }
                if (x != target[n - 1 - i][n - 1 - j]) {
                    ok &= ~(1 << 2); // ok[2] = false
                }
                if (x != target[n - 1 - j][i]) {
                    ok &= ~(1 << 3); // ok[3] = false
                }
                if (ok == 0) { // 所有的 ok[i] 都是 false
                    return false;
                }
            }
        }
        return true; // 至少有一个 ok[i] 是 true
    }
};
func findRotation(mat, target [][]int) bool {
n := len(mat)
ok := 1<<4 - 1 // ok := [4]bool{true, true, true, true}
for i, row := range mat {
for j, x := range row {
if x != target[i][j] {
ok &^= 1 << 0 // ok[0] = false
}
if x != target[j][n-1-i] {
ok &^= 1 << 1 // ok[1] = false
}
if x != target[n-1-i][n-1-j] {
ok &^= 1 << 2 // ok[2] = false
}
if x != target[n-1-j][i] {
ok &^= 1 << 3 // ok[3] = false
}
if ok == 0 { // 所有的 ok[i] 都是 false
return false
}
}
}
return true // 至少有一个 ok[i] 是 true
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2)$,其中 $n$ 是 $\textit{mat}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(1)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

双指针(Python/Java/C++/Go)

作者 endlesscheng
2025年8月10日 13:44

根据题意,交换的范围是行号 $[x,x+k-1]$,列号 $[y,y+k-1]$。

类似 344. 反转字符串,用双指针实现:

  • 初始化 $l=x$,$r=x+k-1$。
  • 循环直到 $l\ge r$。
  • 每次循环,对于 $[y,y+k-1]$ 中的每个整数 $j$,交换 $\textit{grid}[l][j]$ 和 $\textit{grid}[r][j]$。

具体请看 视频讲解,欢迎点赞关注~

###py

class Solution:
    def reverseSubmatrix(self, grid: List[List[int]], x: int, y: int, k: int) -> List[List[int]]:
        l, r = x, x + k - 1
        while l < r:
            for j in range(y, y + k):
                grid[l][j], grid[r][j] = grid[r][j], grid[l][j]
            l += 1
            r -= 1
        return grid

###py

class Solution:
    def reverseSubmatrix(self, grid: List[List[int]], x: int, y: int, k: int) -> List[List[int]]:
        l, r = x, x + k - 1
        while l < r:
            grid[l][y: y + k], grid[r][y: y + k] = grid[r][y: y + k], grid[l][y: y + k]
            l += 1
            r -= 1
        return grid

###java

class Solution {
    public int[][] reverseSubmatrix(int[][] grid, int x, int y, int k) {
        int l = x;
        int r = x + k - 1;
        while (l < r) {
            for (int j = y; j < y + k; j++) {
                int tmp = grid[l][j];
                grid[l][j] = grid[r][j];
                grid[r][j] = tmp;
            }
            l++;
            r--;
        }
        return grid;
    }
}

###cpp

class Solution {
public:
    vector<vector<int>> reverseSubmatrix(vector<vector<int>>& grid, int x, int y, int k) {
        int l = x, r = x + k - 1;
        while (l < r) {
            for (int j = y; j < y + k; j++) {
                swap(grid[l][j], grid[r][j]);
            }
            l++;
            r--;
        }
        return grid;
    }
};

###go

func reverseSubmatrix(grid [][]int, x, y, k int) [][]int {
l, r := x, x+k-1
for l < r {
for j := y; j < y+k; j++ {
grid[l][j], grid[r][j] = grid[r][j], grid[l][j]
}
l++
r--
}
return grid
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(k^2)$。
  • 空间复杂度:$\mathcal{O}(1)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

暴力枚举(Python/Java/C++/Go)

作者 endlesscheng
2025年6月1日 13:51

暴力枚举所有子矩形。把子矩形中的所有元素添加到一个数组 $a$ 中,然后把 $a$ 排序。排序后,不同元素之差的最小值一定来自 $a$ 的相邻元素,计算相邻不同元素之差的最小值。

本题视频讲解,欢迎点赞关注~

###py

class Solution:
    def minAbsDiff(self, grid: List[List[int]], k: int) -> List[List[int]]:
        m, n = len(grid), len(grid[0])
        ans = [[0] * (n - k + 1) for _ in range(m - k + 1)]
        for i in range(m - k + 1):
            sub_grid = grid[i: i + k]
            for j in range(n - k + 1):
                a = []
                for row in sub_grid:
                    a += row[j: j + k]
                a.sort()

                res = inf
                for x, y in pairwise(a):
                    if x < y:  # 题目要求相减的两个数必须不同
                        res = min(res, y - x)
                if res < inf:
                    ans[i][j] = res
        return ans

###java

class Solution {
    public int[][] minAbsDiff(int[][] grid, int k) {
        int m = grid.length;
        int n = grid[0].length;
        int[][] ans = new int[m - k + 1][n - k + 1];
        int[] a = new int[k * k];
        for (int i = 0; i <= m - k; i++) {
            for (int j = 0; j <= n - k; j++) {
                int idx = 0;
                for (int x = 0; x < k; x++) {
                    for (int y = 0; y < k; y++) {
                        a[idx++] = grid[i + x][j + y];
                    }
                }
                Arrays.sort(a);

                int res = Integer.MAX_VALUE;
                for (int p = 1; p < a.length; p++) {
                    if (a[p] > a[p - 1]) { // 题目要求相减的两个数必须不同
                        res = Math.min(res, a[p] - a[p - 1]);
                    }
                }
                if (res < Integer.MAX_VALUE) {
                    ans[i][j] = res;
                }
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    vector<vector<int>> minAbsDiff(vector<vector<int>>& grid, int k) {
        int m = grid.size(), n = grid[0].size();
        vector ans(m - k + 1, vector<int>(n - k + 1));
        for (int i = 0; i <= m - k; i++) {
            for (int j = 0; j <= n - k; j++) {
                vector<int> a;
                for (int x = 0; x < k; x++) {
                    for (int y = 0; y < k; y++) {
                        a.push_back(grid[i + x][j + y]);
                    }
                }
                ranges::sort(a);

                int res = INT_MAX;
                for (int p = 1; p < a.size(); p++) {
                    if (a[p] > a[p - 1]) { // 题目要求相减的两个数必须不同
                        res = min(res, a[p] - a[p - 1]);
                    }
                }
                if (res < INT_MAX) {
                    ans[i][j] = res;
                }
            }
        }
        return ans;
    }
};

###go

func minAbsDiff(grid [][]int, k int) [][]int {
m, n := len(grid), len(grid[0])
ans := make([][]int, m-k+1)
arr := make([]int, k*k)
for i := range ans {
ans[i] = make([]int, n-k+1)
for j := range ans[i] {
a := arr[:0] // 避免反复 make
for _, row := range grid[i : i+k] {
a = append(a, row[j:j+k]...)
}
slices.Sort(a)

res := math.MaxInt
for p := 1; p < len(a); p++ {
if a[p] > a[p-1] { // 题目要求相减的两个数必须不同
res = min(res, a[p]-a[p-1])
}
}
if res < math.MaxInt {
ans[i][j] = res
}
}
}
return ans
}

复杂度分析

  • 时间复杂度:$\mathcal{O}((m-k)(n-k)k^2\log k)$,其中 $m$ 和 $n$ 分别为 $\textit{grid}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(k^2)$。返回值不计入。

:考虑用定长滑动窗口 + 有序集合 + 懒删除堆,用有序集合维护窗口(子矩阵)元素,用懒删除堆维护相邻不同元素之差。添加删除的时候更新相邻不同元素之差。

这样可以做到 $\mathcal{O}((m-k)nk\log k)$,但常数比较大。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

两种方法:二维前缀和模板 / 维护每列元素和(Python/Java/C++/Go)

作者 endlesscheng
2024年3月3日 12:11

方法一:二维前缀和

前置知识【图解】二维前缀和

本题相当于统计有多少个二维前缀和 $\le k$。

###py

class Solution:
    def countSubmatrices(self, grid: List[List[int]], k: int) -> int:
        m, n = len(grid), len(grid[0])
        s = [[0] * (n + 1) for _ in range(m + 1)]
        ans = 0
        for i, row in enumerate(grid):
            for j, x in enumerate(row):
                s[i + 1][j + 1] = s[i + 1][j] + s[i][j + 1] - s[i][j] + x
                if s[i + 1][j + 1] <= k:
                    ans += 1
        return ans

###java

class Solution {
    public int countSubmatrices(int[][] grid, int k) {
        int m = grid.length;
        int n = grid[0].length;
        int[][] sum = new int[m + 1][n + 1];
        int ans = 0;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                sum[i + 1][j + 1] = sum[i + 1][j] + sum[i][j + 1] - sum[i][j] + grid[i][j];
                if (sum[i + 1][j + 1] <= k) {
                    ans++;
                }
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int countSubmatrices(vector<vector<int>>& grid, int k) {
        int m = grid.size(), n = grid[0].size();
        vector sum(m + 1, vector<int>(n + 1));
        int ans = 0;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                sum[i + 1][j + 1] = sum[i + 1][j] + sum[i][j + 1] - sum[i][j] + grid[i][j];
                ans += sum[i + 1][j + 1] <= k;
            }
        }
        return ans;
    }
};

###go

func countSubmatrices(grid [][]int, k int) (ans int) {
m, n := len(grid), len(grid[0])
sum := make([][]int, m+1)
sum[0] = make([]int, n+1)
for i, row := range grid {
sum[i+1] = make([]int, n+1)
for j, x := range row {
sum[i+1][j+1] = sum[i+1][j] + sum[i][j+1] - sum[i][j] + x
if sum[i+1][j+1] <= k {
ans++
}
}
}
return
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn)$,其中 $m$ 和 $n$ 分别是 $\textit{grid}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(mn)$。

:如果原地计算二维前缀和,可以做到 $\mathcal{O}(1)$ 额外空间。

方法二:维护每列的元素和

遍历每一行,同时用一个长为 $n$ 的数组 $\textit{colSum}$ 维护每一列的元素和。

遍历当前行时,一边更新 $\textit{colSum}[j]$,一边累加 $\textit{colSum}[j]$ 到变量 $s$ 中。

如果 $s\le k$ 则把答案加一,否则可以跳出循环(因为矩阵元素都非负)。

###py

class Solution:
    def countSubmatrices(self, grid: List[List[int]], k: int) -> int:
        col_sum = [0] * len(grid[0])
        ans = 0
        for row in grid:
            s = 0
            for j, x in enumerate(row):
                col_sum[j] += x
                s += col_sum[j]
                if s > k:
                    break
                ans += 1
        return ans

###java

class Solution {
    public int countSubmatrices(int[][] grid, int k) {
        int n = grid[0].length;
        int[] colSum = new int[n];
        int ans = 0;
        for (int[] row : grid) {
            int s = 0;
            for (int j = 0; j < n; j++) {
                colSum[j] += row[j];
                s += colSum[j];
                if (s > k) {
                    break;
                }
                ans++;
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int countSubmatrices(vector<vector<int>>& grid, int k) {
        int n = grid[0].size();
        vector<int> col_sum(n);
        int ans = 0;
        for (auto& row : grid) {
            int s = 0;
            for (int j = 0; j < n; j++) {
                col_sum[j] += row[j];
                s += col_sum[j];
                if (s > k) {
                    break;
                }
                ans++;
            }
        }
        return ans;
    }
};

###go

func countSubmatrices(grid [][]int, k int) (ans int) {
colSum := make([]int, len(grid[0]))
for _, row := range grid {
s := 0
for j, x := range row {
colSum[j] += x
s += colSum[j]
if s > k {
break
}
ans++
}
}
return
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn)$,其中 $m$ 和 $n$ 分别是 $\textit{grid}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(n)$。

:如果把每列元素和保存到 $\textit{grid}$ 的第一行,可以做到 $\mathcal{O}(1)$ 额外空间。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

枚举子矩形的底边 + O(mn) 优化(Python/Java/C++/Go)

作者 endlesscheng
2026年3月10日 22:22

做法类似 85. 最大矩形,枚举子矩形的底边(最后一行),定义 $\textit{heights}[j]$ 表示从 $\textit{matrix}[i][j]$ 往上有多少个连续的 $1$(柱子的高度),问题变成:

  • 你可以重排 $\textit{heights}$。重排后,对于 $\textit{heights}$ 的连续子数组,子数组长度(矩形底边长)$\times$ 子数组最小值(矩形的高),即为全 $1$ 子矩形的面积。

对于示例 1,以第三行为底边算出来的 $\textit{heights} = [2,0,3]$,下图重排后是 $[2,3,0]$。其中子数组 $[2,3]$,长为 $2$,最小值为 $2$,所以对应的子矩形面积为 $2\times 2 = 4$。

lc1727.png{:width=430px}

如何找到面积最大的子矩形?还是枚举。

枚举子数组的长度 $k = 1,2,\ldots,n$。由于我们可以重排 $\textit{heights}$,那么贪心地,把 $\textit{heights}$ 最大的 $k$ 个数排在一起,就可以让子数组的最小值(矩形的高)尽量大,从而得到最大的矩形面积。

对于 $\textit{heights}$ 的计算,如果 $\textit{matrix}[i][j]=0$,那么 $\textit{heights}[j] = 0$。否则,把高度增加 $1$。形象地说,就是在柱子下面垫一块石头,把柱子抬高。

优化前

###py

class Solution:
    def largestSubmatrix(self, matrix: List[List[int]]) -> int:
        n = len(matrix[0])
        heights = [0] * n
        ans = 0

        for row in matrix:  # 枚举子矩形的底边
            for j, x in enumerate(row):
                if x == 0:
                    heights[j] = 0
                else:
                    heights[j] += 1

            hs = sorted(heights)  # 复制一份 heights 再排序
            for i, h in enumerate(hs):  # 把 hs[i:] 作为子数组
                # 子数组长为 n-i,最小值为 h,对应的子矩形面积为 (n-i)*h
                ans = max(ans, (n - i) * h)  

        return ans

###java

class Solution {
    public int largestSubmatrix(int[][] matrix) {
        int n = matrix[0].length;
        int[] heights = new int[n];
        int ans = 0;

        for (int[] row : matrix) { // 枚举子矩形的底边
            for (int j = 0; j < n; j++) {
                if (row[j] == 0) {
                    heights[j] = 0;
                } else {
                    heights[j]++;
                }
            }

            int[] hs = heights.clone();
            Arrays.sort(hs);
            for (int i = 0; i < n; i++) { // 把 [i,n-1] 作为子数组
                // 子数组长为 n-i,最小值为 hs[i],对应的子矩形面积为 (n-i)*hs[i]
                ans = Math.max(ans, (n - i) * hs[i]); 
            }
        }

        return ans;
    }
}

###cpp

class Solution {
public:
    int largestSubmatrix(vector<vector<int>>& matrix) {
        int n = matrix[0].size();
        vector<int> heights(n);
        int ans = 0;

        for (auto& row : matrix) { // 枚举子矩形的底边
            for (int j = 0; j < n; j++) {
                int x = row[j];
                if (x == 0) {
                    heights[j] = 0;
                } else {
                    heights[j]++;
                }
            }

            auto hs = heights;
            ranges::sort(hs);
            for (int i = 0; i < n; i++) { // 把 [i,n-1] 作为子数组
                // 子数组长为 n-i,最小值为 hs[i],对应的子矩形面积为 (n-i)*hs[i]
                ans = max(ans, (n - i) * hs[i]); 
            }
        }
        return ans;
    }
};

###go

func largestSubmatrix(matrix [][]int) (ans int) {
n := len(matrix[0])
heights := make([]int, n)

for _, row := range matrix { // 枚举子矩形的底边
for j, x := range row {
if x == 0 {
heights[j] = 0
} else {
heights[j]++
}
}

hs := slices.Clone(heights)
slices.Sort(hs)
for i, h := range hs { // 把 hs[i:] 作为子数组
ans = max(ans, (n-i)*h) // 子数组长为 n-i,最小值为 h,对应的子矩形面积为 (n-i)*h
}
}

return
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn\log n)$,其中 $m$ 和 $n$ 分别是 $\textit{matrix}$ 的行数和列数。瓶颈在排序上,有 $m$ 行,每行都要跑一个 $\mathcal{O}(n\log n)$ 的排序。
  • 空间复杂度:$\mathcal{O}(n)$。

优化

考察从 $i-1$ 行到 $i$ 行,$\textit{heights}$ 会如何变化:

  • 如果 $\textit{matrix}[i][j] = 0$,那么 $\textit{heights}[j] = 0$。在排序后,$0$ 会排在大于 $0$ 的高度前面。
  • 如果 $\textit{matrix}[i][j] = 1$,那么 $\textit{heights}[j]$ 增加一。对于那些增加一的高度,相对大小是不变的,无需再次排序。比如把 $1,2,3$ 都增加一,得到 $2,3,4$,这三个数的相对大小不变。

举个例子。假设 $i-1$ 行的 $\textit{heights}$ 排序后是 $[0,{\color{red}0},{\color{red}0},1,{\color{red}2},{\color{red}3}]$,把红色数字加一,其余数字变成 $0$,得到 $[0,{\color{red}1},{\color{red}1},0,{\color{red}3},{\color{red}4}]$。把 $0$ 排在红色数字前面,得到 $[0,0,{\color{red}1},{\color{red}1},{\color{red}3},{\color{red}4}]$。注意红色数字的相对大小是不变的,无需再次排序。

一般地,如果已知 $i-1$ 行的 $\textit{heights}$ 排序后的结果,那么对于 $i$ 行,我们只需把高度变成 $0$ 的数据排在前面,其余(增加一的)高度的相对大小不变,无需再次排序。这样就可以把排序的时间从 $\mathcal{O}(n\log n)$ 优化成 $\mathcal{O}(n)$。

但是,如果直接对 $\textit{heights}$ 排序,我们就不知道每个高度对应矩阵的哪一列了。如何解决?创建一个 $0$ 到 $n-1$ 的下标数组(列号数组)$\textit{idx}$,对下标数组排序。

###py

class Solution:
    def largestSubmatrix(self, matrix: List[List[int]]) -> int:
        n = len(matrix[0])
        heights = [0] * n
        idx = list(range(n))  # 按照高度排序后的列号
        ans = 0

        for row in matrix:
            zeros = []
            non_zeros = []
            for j in idx:
                if row[j] == 0:
                    heights[j] = 0
                    zeros.append(j)
                else:
                    heights[j] += 1
                    non_zeros.append(j)
            idx = zeros + non_zeros  # 把高度为 0 的列号排在其他高度前面

            # heights[idx[i]] 是递增的
            for i in range(len(zeros), n):  # 高度 0 无需计算
                ans = max(ans, (n - i) * heights[idx[i]])

        return ans

###java

class Solution {
    public int largestSubmatrix(int[][] matrix) {
        int n = matrix[0].length;
        int[] heights = new int[n];
        int[] idx = new int[n]; // 按照高度排序后的列号
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        int[] nonZeros = new int[n]; // 避免在循环内反复申请内存
        int ans = 0;

        for (int[] row : matrix) {
            int p = 0;
            int q = 0;
            for (int j : idx) {
                if (row[j] == 0) {
                    heights[j] = 0;
                    idx[p++] = j; // 高度 0 排在前面
                } else {
                    heights[j]++;
                    nonZeros[q++] = j;
                }
            }

            // heights[idx[i]] 是递增的
            for (int i = p; i < n; i++) { // 高度 0 无需计算
                idx[i] = nonZeros[i - p]; // 把 nonZeros 复制到 idx 的 [p,n-1] 中
                ans = Math.max(ans, (n - i) * heights[idx[i]]);
            }
        }

        return ans;
    }
}

###cpp

class Solution {
public:
    int largestSubmatrix(vector<vector<int>>& matrix) {
        int n = matrix[0].size();
        vector<int> heights(n);
        vector<int> idx(n); // 按照高度排序后的列号
        ranges::iota(idx, 0); // idx[i] = i
        vector<int> non_zeros(n); // 避免在循环内反复申请内存
        int ans = 0;

        for (auto& row : matrix) {
            int p = 0, q = 0;
            for (int j : idx) {
                if (row[j] == 0) {
                    heights[j] = 0;
                    idx[p++] = j; // 高度 0 排在前面
                } else {
                    heights[j]++;
                    non_zeros[q++] = j;
                }
            }

            // heights[idx[i]] 是递增的
            for (int i = p; i < n; i++) { // 高度 0 无需计算
                idx[i] = non_zeros[i - p]; // 把 non_zeros 复制到 idx 的 [p,n-1] 中
                ans = max(ans, (n - i) * heights[idx[i]]);
            }
        }

        return ans;
    }
};

###go

func largestSubmatrix(matrix [][]int) (ans int) {
n := len(matrix[0])
heights := make([]int, n)
idx := make([]int, n) // 按照高度排序后的列号
for i := range idx {
idx[i] = i
}
_nonZeros := make([]int, n) // 避免在循环内反复申请内存

for _, row := range matrix {
zeros := idx[:0]
nonZeros := _nonZeros[:0]
for _, j := range idx {
if row[j] == 0 {
heights[j] = 0
zeros = append(zeros, j)
} else {
heights[j]++
nonZeros = append(nonZeros, j)
}
}
idx = append(zeros, nonZeros...) // 把高度为 0 的列号排在其他高度前面

// heights[idx[i]] 是递增的
for i := len(zeros); i < n; i++ { // 高度 0 无需计算
ans = max(ans, (n-i)*heights[idx[i]])
}
}

return
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn)$,其中 $m$ 和 $n$ 分别是 $\textit{matrix}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(n)$。

专题训练

见下面贪心题单的「§1.6 先枚举,再贪心」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

懒更新 + 等价转化(Python/Java/C++/Go)

作者 endlesscheng
2026年3月8日 10:18

只有加法

从特殊到一般,先考虑一个简单的问题:只有加法,没有乘法,怎么做?

执行 $\texttt{addAll}(\textit{inc})$ 时,如果把每个数都增加 $\textit{inc}$,就太慢了。

我们可以采用一种「懒更新」的想法,等到调用 $\texttt{getIndex}(\textit{idx})$ 时,才做计算。

比如序列 $\textit{vals} = [3,1,4]$,执行 $\texttt{addAll}(2)$ 时,我们不去把 $\textit{vals}$ 的每个数都增加 $2$,而是用一个变量 $\textit{add}$ 表示「每个数都要增加 $\textit{add}$」。执行 $\texttt{addAll}(2)$ 时,只把 $\textit{add}$ 增加 $2$。等到调用 $\texttt{getIndex}(\textit{idx})$ 时,才计算加法:

$$
\textit{vals}[\textit{idx}] + \textit{add}
$$

即为 $\textit{vals}[\textit{idx}]$ 更新后的数值。

如何处理 $\texttt{append}(\textit{val})$ 呢?

为了让 $\textit{val}$ 兼容 $\textit{vals}[\textit{idx}] + \textit{add}$ 这个式子,我们可以先把 $\textit{val}$ 减少 $\textit{add}$,再添加到 $\textit{vals}$ 的末尾,比如 $\textit{val} = 5$,$\textit{add} = 2$,那么往 $\textit{vals}$ 的末尾添加 $5-2=3$,就可以让式子 $\textit{vals}[\textit{idx}] + \textit{add}$ 对所有元素都保持一致

只有乘法

如果只有乘法,没有加法呢?

同理,用变量 $\textit{mul}$ 表示「每个数都要乘以 $\textit{mul}$」。执行 $\texttt{multAll}(2)$ 时,只把 $\textit{mul}$ 乘以 $2$。等到调用 $\texttt{getIndex}(\textit{idx})$ 时,才计算乘法:

$$
\textit{vals}[\textit{idx}] \cdot \textit{mul}
$$

即为 $\textit{vals}[\textit{idx}]$ 更新后的数值。

如何处理 $\texttt{append}(\textit{val})$ 呢?

为了让 $\textit{val}$ 兼容 $\textit{vals}[\textit{idx}] \cdot \textit{mul}$ 这个式子,我们可以先把 $\textit{val}$ 除以 $\textit{mul}$,再添加到 $\textit{vals}$ 的末尾,比如 $\textit{val} = 6$,$\textit{mul} = 2$,那么往 $\textit{vals}$ 的末尾添加 $6/2=3$,就可以让式子 $\textit{vals}[\textit{idx}] \cdot \textit{mul}$ 对所有元素都保持一致

注意:在模运算中,除以 $\textit{mul}$ 等价于乘以 $\textit{mul}$ 关于 $M = 10^9+7$ 的逆元,即 $\textit{mul}^{M-2}$。原理见 模运算的世界:当加减乘除遇上取模

加法和乘法

把上述方法结合起来,用 $\textit{add}$ 记录操作 $\texttt{addAll}$,用 $\textit{mul}$ 记录操作 $\texttt{multAll}$。

  • 初始值:$\textit{add} = 0$,$\textit{mul}=1$。
  • 执行 $\texttt{addAll}(\textit{inc})$ 时,把 $\textit{add}$ 增加 $\textit{inc}$。
  • 执行 $\texttt{multAll}(m)$ 时,由于 $(v \cdot \textit{mul} + \textit{add})\cdot m = v \cdot (\textit{mul}\cdot m) + \textit{add}\cdot m$,所以把 $\textit{mul}$ 乘以 $m$,把 $\textit{add}$ 乘以 $m$。

调用 $\texttt{getIndex}(\textit{idx})$ 时,计算

$$
\textit{vals}[\textit{idx}] \cdot \textit{mul} + \textit{add}
$$

即为 $\textit{vals}[\textit{idx}]$ 更新后的数值。

如何处理 $\texttt{append}(\textit{val})$ 呢?

为了让 $\textit{val}$ 兼容 $\textit{vals}[\textit{idx}] \cdot \textit{mul} + \textit{add}$ 这个式子,我们可以先计算 $v = \dfrac{\textit{val} - \textit{add}}{\textit{mul}}$,再把 $v$ 添加到 $\textit{vals}$ 的末尾,就可以让式子 $\textit{vals}[\textit{idx}] \cdot \textit{mul} + \textit{add}$ 对所有元素都保持一致

代码实现时,注意取模。为什么可以在中途取模?原理见 模运算的世界:当加减乘除遇上取模

###py

MOD = 1_000_000_007

class Fancy:
    def __init__(self):
        self.vals = []
        self.add = 0
        self.mul = 1

    def append(self, val: int) -> None:
        self.vals.append((val - self.add) * pow(self.mul, -1, MOD) % MOD)

    def addAll(self, inc: int) -> None:
        self.add += inc

    def multAll(self, m: int) -> None:
        self.mul = self.mul * m % MOD
        self.add = self.add * m % MOD

    def getIndex(self, idx: int) -> int:
        if idx >= len(self.vals):
            return -1
        return (self.vals[idx] * self.mul + self.add) % MOD

###java

class Fancy {
    private static final int MOD = 1_000_000_007;

    private final List<Integer> vals = new ArrayList<>();
    private long add = 0;
    private long mul = 1;

    public void append(int val) {
        // 注意这里有减法,计算结果可能是负数,+MOD 可以保证计算结果非负
        vals.add((int) ((val - add + MOD) * pow(mul, MOD - 2) % MOD));
    }

    public void addAll(int inc) {
        add = (add + inc) % MOD;
    }

    public void multAll(int m) {
        mul = mul * m % MOD;
        add = add * m % MOD;
    }

    public int getIndex(int idx) {
        if (idx >= vals.size()) {
            return -1;
        }
        return (int) ((vals.get(idx) * mul + add) % MOD);
    }

    private long pow(long x, int n) {
        long res = 1;
        for (; n > 0; n /= 2) {
            if (n % 2 > 0) {
                res = res * x % MOD;
            }
            x = x * x % MOD;
        }
        return res;
    }
}

###cpp

class Fancy {
    static constexpr int MOD = 1'000'000'007;

    vector<int> vals;
    long long add = 0;
    long long mul = 1;

    long long pow(long long x, int n) {
        long long res = 1;
        for (; n; n /= 2) {
            if (n % 2) {
                res = res * x % MOD;
            }
            x = x * x % MOD;
        }
        return res;
    }

public:
    void append(int val) {
        // 注意这里有减法,计算结果可能是负数,+MOD 可以保证计算结果非负
        vals.push_back((val - add + MOD) * pow(mul, MOD - 2) % MOD);
    }

    void addAll(int inc) {
        add = (add + inc) % MOD;
    }

    void multAll(int m) {
        mul = mul * m % MOD;
        add = add * m % MOD;
    }

    int getIndex(int idx) {
        if (idx >= vals.size()) {
            return -1;
        }
        return (vals[idx] * mul + add) % MOD;
    }
};

###go

const mod = 1_000_000_007

func pow(x, n int) int {
res := 1
for ; n > 0; n /= 2 {
if n%2 > 0 {
res = res * x % mod
}
x = x * x % mod
}
return res
}

type Fancy struct {
vals []int
add  int
mul  int
}

func Constructor() Fancy {
return Fancy{mul: 1}
}

func (f *Fancy) Append(val int) {
// 注意这里有减法,计算结果可能是负数,+mod 可以保证计算结果非负
f.vals = append(f.vals, (val-f.add+mod)*pow(f.mul, mod-2)%mod)
}

func (f *Fancy) AddAll(inc int) {
f.add = (f.add + inc) % mod
}

func (f *Fancy) MultAll(m int) {
f.mul = f.mul * m % mod
f.add = f.add * m % mod
}

func (f *Fancy) GetIndex(idx int) int {
if idx >= len(f.vals) {
return -1
}
return (f.vals[idx]*f.mul + f.add) % mod
}

复杂度分析

  • 时间复杂度:$\texttt{append}$ 为 $\mathcal{O}(\log M)$,其余为 $\mathcal{O}(1)$,其中 $M=10^9+7$。
  • 空间复杂度:$\mathcal{O}(q)$,其中 $q$ 是 $\texttt{append}$ 的调用次数。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

O(n) 简洁写法(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2026年3月4日 16:09

首先计算有多少个长为 $n$ 的开心字符串。

这是一个计数问题。有 $n$ 个空位,第一个位置可以填 $3$ 种字母,后面的每个位置,都不能和前一个位置的字母相同,所以都只有 $2$ 种填法,因此方案数为

$$
3\cdot 2^{n-1}
$$

如果 $k > 3\cdot 2^{n-1}$,返回空串。否则答案是存在的。


为方便计算,我们先把 $k$ 减一,改成从 $0$ 开始。

以 $n=4$,$k=12=1100_{(2)}$(减一后)为例说明。

答案的第一个字母是什么?

如果第一个字母是 $\texttt{a}$,那么后面的 $n-1=3$ 个位置有 $2^{n-1}=8$ 种填法,不够。

如果第一个字母是 $\texttt{b}$,同样地,后面的 $n-1=3$ 个位置有 $2^{n-1}=8$ 种填法,这就够了,所以答案的第一个字母填 $\texttt{b}$。

现在答案为 $\texttt{b___}$,剩余的三个位置填什么字母?

由于每个位置都有两种填法,这和 $k$ 的二进制可以完美对应,如下表(注意相邻字母不同的要求):

$k$ 的低三位 对应填法
$000$ $\texttt{aba}$
$001$ $\texttt{abc}$
$010$ $\texttt{aca}$
$011$ $\texttt{acb}$
$100$ $\texttt{cab}$
$101$ $\texttt{cac}$
$110$ $\texttt{cba}$
$111$ $\texttt{cbc}$

对于 $k=1100_{(2)}$(减一后)这个例子,答案是这样填的:

  • 答案的第二个字母不能和前一个字母 $\textit{b}$ 相同,只能填 $\texttt{a}$ 或者 $\texttt{c}$,由于 $k$ 这一位是 $1$,所以填 $\texttt{c}$。
  • 答案的第三个字母不能和前一个字母 $\textit{c}$ 相同,只能填 $\texttt{a}$ 或者 $\texttt{b}$,由于 $k$ 这一位是 $0$,所以填 $\texttt{a}$。
  • 答案的第四个字母不能和前一个字母 $\textit{a}$ 相同,只能填 $\texttt{b}$ 或者 $\texttt{c}$,由于 $k$ 这一位是 $0$,所以填 $\texttt{b}$。

所以答案为 $\texttt{bcab}$。

一般地,答案的第一个字母是第 $\left\lfloor\dfrac{k}{2^{n-1}}\right\rfloor$ 个小写字母,随后的字母可以根据 $k\bmod 2^{n-1}$ 二进制从高到低是 $0$ 还是 $1$,填入相应的字母:

  • 首先,如果二进制这一位是 $0$,那么填入 $\texttt{a}$,否则填入 $\texttt{b}$。
  • 然后修正:看填入的字母是否大于等于左侧相邻字母,如果大于等于,那么把填入的字母加一。比如左侧相邻字母是 $\texttt{a}$,那么当前这一位如果填的是 $\texttt{a}$,要变成 $\texttt{b}$;如果填的是 $\texttt{b}$,要变成 $\texttt{c}$。

:这个「加一」的技巧可用于生成两个不同的随机整数,见 961. 在长度 2N 的数组中找出重复 N 次的元素 我的题解的方法四。

###py

class Solution:
    def getHappyString(self, n: int, k: int) -> str:
        if k > 3 << (n - 1):
            return ""
        k -= 1  # 改成从 0 开始,方便计算
        ans = [ord('a')] * n
        ans[0] += k >> (n - 1)
        for i in range(1, n):
            ans[i] += k >> (n - 1 - i) & 1
            if ans[i] >= ans[i - 1]:
                ans[i] += 1
        return ''.join(map(chr, ans))

###java

class Solution {
    public String getHappyString(int n, int k) {
        if (k > 3 << (n - 1)) {
            return "";
        }
        k--; // 改成从 0 开始,方便计算
        char[] ans = new char[n];
        ans[0] = (char) ('a' + (k >> (n - 1)));
        for (int i = 1; i < n; i++) {
            ans[i] = (char) ('a' + (k >> (n - 1 - i) & 1));
            if (ans[i] >= ans[i - 1]) {
                ans[i]++;
            }
        }
        return new String(ans);
    }
}

###cpp

class Solution {
public:
    string getHappyString(int n, int k) {
        if (k > 3 << (n - 1)) {
            return "";
        }
        k--; // 改成从 0 开始,方便计算
        string ans(n, 'a');
        ans[0] += k >> (n - 1);
        for (int i = 1; i < n; i++) {
            ans[i] += k >> (n - 1 - i) & 1;
            if (ans[i] >= ans[i - 1]) {
                ans[i]++;
            }
        }
        return ans;
    }
};

###c

char* getHappyString(int n, int k) {
    if (k > 3 << (n - 1)) {
        return "";
    }
    k--; // 改成从 0 开始,方便计算
    char* ans = malloc((n + 1) * sizeof(char));
    ans[0] = 'a' + (k >> (n - 1));
    for (int i = 1; i < n; i++) {
        ans[i] = 'a' + (k >> (n - 1 - i) & 1);
        if (ans[i] >= ans[i - 1]) {
            ans[i]++;
        }
    }
    ans[n] = '\0';
    return ans;
}

###go

func getHappyString(n, k int) string {
if k > 3<<(n-1) {
return ""
}
k-- // 改成从 0 开始,方便计算
ans := make([]byte, n)
ans[0] = 'a' + byte(k>>(n-1))
for i := 1; i < n; i++ {
ans[i] = 'a' + byte(k>>(n-1-i)&1)
if ans[i] >= ans[i-1] {
ans[i]++
}
}
return string(ans)
}

###js

var getHappyString = function(n, k) {
    if (k > 3 << (n - 1)) {
        return "";
    }
    k--; // 改成从 0 开始,方便计算
    const ans = Array(n).fill('a'.charCodeAt(0));
    ans[0] += k >> (n - 1);
    for (let i = 1; i < n; i++) {
        ans[i] += k >> (n - 1 - i) & 1;
        if (ans[i] >= ans[i - 1]) {
            ans[i]++;
        }
    }
    return ans.map(c => String.fromCharCode(c)).join('');
};

###rust

impl Solution {
    pub fn get_happy_string(n: i32, mut k: i32) -> String {
        if k > 3 << (n - 1) {
            return String::new();
        }
        k -= 1; // 改成从 0 开始,方便计算
        let n = n as usize;
        let mut ans = vec![0; n];
        ans[0] = b'a' + (k >> (n - 1)) as u8;
        for i in 1..n {
            ans[i] = b'a' + (k >> (n - 1 - i) & 1) as u8;
            if ans[i] >= ans[i - 1] {
                ans[i] += 1;
            }
        }
        unsafe { String::from_utf8_unchecked(ans) }
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$。
  • 空间复杂度:$\mathcal{O}(1)$。返回值不计入。

相似题目

60. 排列序列

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

两种方法:最小堆模拟/二分答案(Python/Java/C++/Go)

作者 endlesscheng
2024年9月22日 12:08

方法一:最小堆模拟

循环 $\textit{mountainHeight}$ 次,每次选一个「工作后总用时」最短的工人,把山的高度降低 $1$。

注意工人们是同时工作的,这个算法不是贪心,是按照事件发生顺序的模拟。

本题视频讲解,欢迎点赞关注~

###py

class Solution:
    def minNumberOfSeconds(self, mountainHeight: int, workerTimes: List[int]) -> int:
        h = [(t, t, t) for t in workerTimes]
        heapify(h)

        for _ in range(mountainHeight):
            # 工作后总用时,当前工作(山高度降低 1)用时,workerTimes[i]
            total, cur, base = h[0]
            heapreplace(h, (total + cur + base, cur + base, base))
        return total  # 最后一个出堆的 total 即为答案

###java

class Solution {
    public long minNumberOfSeconds(int mountainHeight, int[] workerTimes) {
        PriorityQueue<long[]> pq = new PriorityQueue<>((a, b) -> Long.compare(a[0], b[0]));
        for (int t : workerTimes) {
            pq.offer(new long[]{t, t, t});
        }

        long ans = 0;
        while (mountainHeight-- > 0) {
            // 工作后总用时,当前工作(山高度降低 1)用时,workerTimes[i]
            long[] top = pq.poll();
            long total = top[0], cur = top[1], base = top[2];
            ans = total; // 最后一个出堆的 total 即为答案
            pq.offer(new long[]{total + cur + base, cur + base, base});
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    long long minNumberOfSeconds(int mountainHeight, vector<int>& workerTimes) {
        priority_queue<tuple<long long, long long, int>, vector<tuple<long long, long long, int>>, greater<>> pq;
        for (int t : workerTimes) {
            pq.emplace(t, t, t);
        }

        long long ans = 0;
        while (mountainHeight--) {
            // 工作后总用时,当前工作(山高度降低 1)用时,workerTimes[i]
            auto [total, cur, base] = pq.top(); pq.pop();
            ans = total; // 最后一个出堆的 total 即为答案
            pq.emplace(total + cur + base, cur + base, base);
        }
        return ans;
    }
};

###go

func minNumberOfSeconds(mountainHeight int, workerTimes []int) int64 {
h := make(hp, len(workerTimes))
for i, t := range workerTimes {
h[i] = worker{t, t, t}
}
heap.Init(&h)

ans := 0
for range mountainHeight {
ans = h[0].total // 最后一个出堆的 total 即为答案
h[0].cur += h[0].base
h[0].total += h[0].cur
heap.Fix(&h, 0)
}
return int64(ans)
}

// 工作后总用时,当前工作(山高度降低 1)用时,workerTimes[i]
type worker struct{ total, cur, base int }
type hp []worker
func (h hp) Len() int           { return len(h) }
func (h hp) Less(i, j int) bool { return h[i].total < h[j].total }
func (h hp) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (hp) Push(any)             {}
func (hp) Pop() (_ any)         { return }

复杂度分析

  • 时间复杂度:$\mathcal{O}(\textit{mountainHeight}\log n)$,其中 $n$ 是 $\textit{workerTimes}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

方法二:二分答案

由于花的时间越多,能够降低的高度也越多,所以有单调性,可以二分答案。

问题变成:

  • 每个工人至多花费 $m$ 秒,总共降低的高度是多少?能否大于等于 $\textit{mountainHeight}$?

遍历 $\textit{workerTimes}$,设 $t=\textit{workerTimes}[i]$,那么有

$$
t + 2t+ \cdots + xt = t\cdot \dfrac{x(x+1)}{2} \le m
$$

$$
\dfrac{x(x+1)}{2} \le \left\lfloor\dfrac{m}{t}\right\rfloor = k
$$

解得

$$
x \le \dfrac{-1 + \sqrt{1 + 8k}}{2}
$$

所以第 $i$ 名工人可以把山的高度降低

$$
\left\lfloor \dfrac{-1 + \sqrt{1 + 8k}}{2} \right\rfloor = \left\lfloor \dfrac{-1 + \lfloor\sqrt{1 + 8k}\rfloor}{2} \right\rfloor
$$

上式是个关于下取整的恒等式,理由见 下取整恒等式及其应用

累加上式,如果和 $\ge \textit{mountainHeight}$,则说明答案 $\le m$,否则说明答案 $> m$。

最后,讨论二分的上下界。这里用开区间二分,其他二分写法也是可以的。

  • 开区间二分下界:$0$,无法把山的高度降低到 $0$。
  • 开区间二分上界:设 $\textit{maxT}$ 为 $\textit{workerTimes}$ 的最大值,假设每个工人都是最慢的 $\textit{maxT}$,那么单个工人要把山降低 $h=\left\lceil\dfrac{mountainHeight}{n}\right\rceil$,耗时 $\textit{maxT}\cdot(1+2+\cdots+h)=\textit{maxT}\cdot\dfrac{h(h+1)}{2}$,将其作为开区间的二分上界,一定可以把山的高度降低到 $\le 0$。

关于上取整的计算,当 $a$ 和 $b$ 均为正整数时,我们有

$$
\left\lceil\dfrac{a}{b}\right\rceil = \left\lfloor\dfrac{a-1}{b}\right\rfloor + 1
$$

证明见 上取整下取整转换公式的证明

###py

class Solution:
    def minNumberOfSeconds(self, mountainHeight: int, workerTimes: List[int]) -> int:
        def check(m: int) -> bool:
            left_h = mountainHeight
            for t in workerTimes:
                left_h -= (isqrt(m // t * 8 + 1) - 1) // 2
                if left_h <= 0:
                    return True
            return False

        max_t = max(workerTimes)
        h = (mountainHeight - 1) // len(workerTimes) + 1
        return bisect_left(range(max_t * h * (h + 1) // 2), True, 1, key=check)

###py

class Solution:
    def minNumberOfSeconds(self, mountainHeight: int, workerTimes: List[int]) -> int:
        f = lambda m: sum((isqrt(m // t * 8 + 1) - 1) // 2 for t in workerTimes)
        max_t = max(workerTimes)
        h = (mountainHeight - 1) // len(workerTimes) + 1
        return bisect_left(range(max_t * h * (h + 1) // 2), mountainHeight, 1, key=f)

###java

class Solution {
    public long minNumberOfSeconds(int mountainHeight, int[] workerTimes) {
        int maxT = 0;
        for (int t : workerTimes) {
            maxT = Math.max(maxT, t);
        }
        int h = (mountainHeight - 1) / workerTimes.length + 1;
        long left = 0;
        long right = (long) maxT * h * (h + 1) / 2;
        while (left + 1 < right) {
            long mid = (left + right) / 2;
            if (check(mid, mountainHeight, workerTimes)) {
                right = mid;
            } else {
                left = mid;
            }
        }
        return right;
    }

    private boolean check(long m, int leftH, int[] workerTimes) {
        for (int t : workerTimes) {
            leftH -= ((int) Math.sqrt(m / t * 8 + 1) - 1) / 2;
            if (leftH <= 0) {
                return true;
            }
        }
        return false;
    }
}

###cpp

class Solution {
public:
    long long minNumberOfSeconds(int mountainHeight, vector<int>& workerTimes) {
        auto check = [&](long long m) {
            int left_h = mountainHeight;
            for (int t : workerTimes) {
                left_h -= ((int) sqrt(m / t * 8 + 1) - 1) / 2;
                if (left_h <= 0) {
                    return true;
                }
            }
            return false;
        };

        int max_t = ranges::max(workerTimes);
        int h = (mountainHeight - 1) / workerTimes.size() + 1;
        long long left = 0, right = (long long) max_t * h * (h + 1) / 2;
        while (left + 1 < right) {
            long long mid = (left + right) / 2;
            (check(mid) ? right : left) = mid;
        }
        return right;
    }
};

###go

func minNumberOfSeconds(mountainHeight int, workerTimes []int) int64 {
maxT := slices.Max(workerTimes)
h := (mountainHeight-1)/len(workerTimes) + 1
ans := 1 + sort.Search(maxT*h*(h+1)/2-1, func(m int) bool {
m++
leftH := mountainHeight
for _, t := range workerTimes {
leftH -= (int(math.Sqrt(float64(m/t*8+1))) - 1) / 2
if leftH <= 0 {
return true
}
}
return false
})
return int64(ans)
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log U)$,其中 $n$ 是 $\textit{workerTimes}$ 的长度,$U\le 5\cdot 10^{10}(10^5+1)$ 是二分上界。二分 $\mathcal{O}(\log U)$ 次,每次 $\mathcal{O}(n)$ 时间。开平方有专门的 CPU 指令,可以视作 $\mathcal{O}(1)$。
  • 空间复杂度:$\mathcal{O}(1)$。

专题训练

见下面二分题单的「二分答案:求最小」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

库函数写法(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2026年3月11日 07:59

题目让我们把 $n$ 取反。

例如二进制 $n=11001$,取反后是 $00110$,即十进制的 $6$。

看上去,计算 ~n 就好了?

但这样做会把更高位的 $0$ 也取反,对于 $32$ 位整数来说,$11001$ 实际是 $00000000000000000000000000011001$,取反后是 $11111111111111111111111111100110$。

所以对于这个例子,要只把 $n$ 的低 $5$ 位取反,也就是计算 $n$ 和 $11111$ 的异或。

$11111$ 怎么算?设 $w=5$ 是 $n$ 的二进制长度,计算 1 << w 可以得到 $100000$,再减去 $1$,得到 $11111$。

特殊情况:根据题意,$n=0$ 反转后是 $1$,如果用库函数算 $n=0$ 的二进制长度,会算出 $0$,这会导致 $n$ 取反后的值是 $0$。所以特判 $n=0$ 的情况,返回 $1$。

###py

class Solution:
    def bitwiseComplement(self, n: int) -> int:
        if n == 0:
            return 1
        w = n.bit_length()
        return ((1 << w) - 1) ^ n

###java

class Solution {
    public int bitwiseComplement(int n) {
        if (n == 0) {
            return 1;
        }
        int w = 32 - Integer.numberOfLeadingZeros(n);
        return ((1 << w) - 1) ^ n;
    }
}

###cpp

class Solution {
public:
    int bitwiseComplement(int n) {
        if (n == 0) {
            return 1;
        }
        int w = bit_width((uint32_t) n);
        return ((1 << w) - 1) ^ n;
    }
};

###c

int bitwiseComplement(int n) {
    if (n == 0) {
        return 1;
    }
    int w = 32 - __builtin_clz(n);
    return ((1 << w) - 1) ^ n;
}

###go

func bitwiseComplement(n int) int {
if n == 0 {
return 1
}
w := bits.Len(uint(n))
return 1<<w - 1 ^ n
}

###js

var bitwiseComplement = function(n) {
    if (n === 0) {
        return 1;
    }
    const w = 32 - Math.clz32(n);
    return ((1 << w) - 1) ^ n;
};

###rust

impl Solution {
    pub fn bitwise_complement(n: i32) -> i32 {
        if n == 0 {
            return 1;
        }
        let w = n.ilog2() + 1;
        ((1 << w) - 1) ^ n
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(1)$。
  • 空间复杂度:$\mathcal{O}(1)$。

专题训练

见下面位运算题单的「一、基础题」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

从 DP 到组合数学(Python/Java/C++/Go)

作者 endlesscheng
2024年4月28日 09:25

方法一:记忆化搜索

前置知识动态规划入门:从记忆化搜索到递推,其中包含如何把记忆化搜索 1:1 翻译成递推的技巧。

先解释 $\textit{limit}$,意思是数组中至多有连续 $\textit{limit}$ 个 $0$,且至多有连续 $\textit{limit}$ 个 $1$。

看示例 3,$\textit{zero}=3,\ \textit{one}=3,\ \textit{limit}=2$。

考虑稳定数组的最后一个位置填 $0$ 还是 $1$:

  • 填 $0$,问题变成剩下 $2$ 个 $0$ 和 $3$ 个 $1$ 怎么填。
  • 填 $1$,问题变成剩下 $3$ 个 $0$ 和 $2$ 个 $1$ 怎么填。
  • 这两个都是和原问题相似的子问题。

看上去,定义 $\textit{dfs}(i,j)$ 表示用 $i$ 个 $0$ 和 $j$ 个 $1$ 构造稳定数组的方案数?但这样定义不方便计算 $\textit{limit}$ 带来的影响。

比如 $\textit{limit}=3$,如果在以 $1000$ 结尾的稳定数组的后面,再添加一个 $0$,得到以 $10000$ 结尾的数组,这是不合法的。我们需要把这种不合法的情况减掉。

为了减去不合法的情况,我们需要知道稳定数组的最后一个数是 $0$ 还是 $1$。

改成定义 $\textit{dfs}(i,j,k)$ 表示用 $i$ 个 $0$ 和 $j$ 个 $1$ 构造长为 $i+j$ 的稳定数组的方案数,其中最后一个位置(第 $i+j$ 个位置)已经填入了 $k$,其中 $k$ 为 $0$ 或 $1$。

以 $k=0$ 为例,考虑 $\textit{dfs}(i,j,0)$ 怎么算。现在最后一个位置(第 $i+j$ 个位置)已经填入了 $0$,消耗了一个 $0$,还剩下 $i-1$ 个 $0$。问题变成用 $i-1$ 个 $0$ 和 $j$ 个 $1$ 构造长为 $i+j-1$ 的稳定数组的方案数。对于这个子问题,枚举最后一个位置(第 $i+j-1$ 个位置)填入 $k=0$ 还是 $k=1$,对应的方案数为 $\textit{dfs}(i-1,j,0)$ 和 $\textit{dfs}(i-1,j,1)$。

看上去,把这两种情况加起来,我们就得到了

$$
\textit{dfs}(i,j,0) = \textit{dfs}(i-1,j,0) + \textit{dfs}(i-1,j,1)
$$

但是,这会产生不合法的情况。

以 $\textit{limit}=3$ 为例说明。$\textit{dfs}(i-1,j,0)$ 是一些以 $0$ 结尾的稳定数组(合法数组),讨论末尾 $0$ 的个数:

  • 末尾恰好有连续 $1$ 个 $0$,即 $10$。
  • 末尾恰好有连续 $2$ 个 $0$,即 $100$。
  • 末尾恰好有连续 $3$ 个 $0$,即 $1000$。由于末尾不能超过连续 $3$ 个 $0$,末尾是 $000$ 的稳定数组,倒数第 $4$ 个数一定是 $1$,也就是在 $\textit{dfs}(i-1,j,0)$ 中有 $\textit{dfs}(i-4,j,1)$ 个末尾是 $1000$ 的稳定数组。

若要通过 $\textit{dfs}(i-1,j,0)$ 计算 $\textit{dfs}(i,j,0)$,相当于往这 $\textit{dfs}(i-1,j,0)$ 个稳定数组的末尾再加一个 $0$:

  • 末尾恰好有连续 $2$ 个 $0$,即 $100$,这是合法的。
  • 末尾恰好有连续 $3$ 个 $0$,即 $1000$,这是合法的。
  • 末尾恰好有连续 $4$ 个 $0$,即 $10000$,这是不合法的,要全部去掉!根据上面的分析,要减去 $\textit{dfs}(i-4,j,1)$。

一般地,在 $\textit{dfs}(i-1,j,0)$ 个稳定数组的末尾添加一个 $0$,会得到 $\textit{dfs}(i-\textit{limit}-1,j,1)$ 个不合法的数组。从上文的状态转移方程中减掉 $\textit{dfs}(i-\textit{limit}-1,j,1)$,得

$$
\textit{dfs}(i,j,0) = \textit{dfs}(i-1,j,0) + \textit{dfs}(i-1,j,1) - \textit{dfs}(i-\textit{limit}-1,j,1)
$$

同理得

$$
\textit{dfs}(i,j,1) = \textit{dfs}(i,j-1,0) + \textit{dfs}(i,j-1,1) - \textit{dfs}(i,j-\textit{limit}-1,0)
$$

递归边界

  • 如果 $i<0$ 或者 $j<0$,返回 $0$。也可以在递归 $\textit{dfs}(i-\textit{limit}-1,j,1)$ 前判断 $i>\textit{limit}$,在递归 $\textit{dfs}(i,j-\textit{limit}-1,0)$ 前判断 $j>\textit{limit}$。下面代码在递归前判断。
  • 如果 $i=0$,那么当 $k=1$ 且 $j\le \textit{limit}$ 的情况下返回 $1$,否则返回 $0$;如果 $j=0$,那么当 $k=0$ 且 $i\le \textit{limit}$ 的情况下返回 $1$,否则返回 $0$。

递归入口:$\textit{dfs}(\textit{zero},\textit{one},0)+\textit{dfs}(\textit{zero},\textit{one},1)$,即答案。

请看 视频讲解 第四题,欢迎点赞关注~

###py

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        MOD = 1_000_000_007
        @cache  # 缓存装饰器,避免重复计算 dfs 的结果(记忆化)
        def dfs(i: int, j: int, k: int) -> int:
            if i == 0:
                return 1 if k == 1 and j <= limit else 0
            if j == 0:
                return 1 if k == 0 and i <= limit else 0
            if k == 0:
                return (dfs(i - 1, j, 0) + dfs(i - 1, j, 1) - (dfs(i - limit - 1, j, 1) if i > limit else 0)) % MOD
            else:  # else 可以去掉,这里仅仅是为了代码对齐
                return (dfs(i, j - 1, 0) + dfs(i, j - 1, 1) - (dfs(i, j - limit - 1, 0) if j > limit else 0)) % MOD
        ans = (dfs(zero, one, 0) + dfs(zero, one, 1)) % MOD
        dfs.cache_clear()  # 防止爆内存
        return ans

###java

class Solution {
    private static final int MOD = 1_000_000_007;

    public int numberOfStableArrays(int zero, int one, int limit) {
        int[][][] memo = new int[zero + 1][one + 1][2];
        for (int[][] m : memo) {
            for (int[] m2 : m) {
                Arrays.fill(m2, -1); // -1 表示没有计算过
            }
        }
        return (dfs(zero, one, 0, limit, memo) + dfs(zero, one, 1, limit, memo)) % MOD;
    }

    private int dfs(int i, int j, int k, int limit, int[][][] memo) {
        if (i == 0) { // 递归边界
            return k == 1 && j <= limit ? 1 : 0;
        }
        if (j == 0) { // 递归边界
            return k == 0 && i <= limit ? 1 : 0;
        }
        if (memo[i][j][k] != -1) { // 之前计算过
            return memo[i][j][k];
        }
        if (k == 0) {
            // + MOD 保证答案非负
            memo[i][j][k] = (int) (((long) dfs(i - 1, j, 0, limit, memo) + dfs(i - 1, j, 1, limit, memo) +
                    (i > limit ? MOD - dfs(i - limit - 1, j, 1, limit, memo) : 0)) % MOD);
        } else {
            memo[i][j][k] = (int) (((long) dfs(i, j - 1, 0, limit, memo) + dfs(i, j - 1, 1, limit, memo) +
                    (j > limit ? MOD - dfs(i, j - limit - 1, 0, limit, memo) : 0)) % MOD);
        }
        return memo[i][j][k];
    }
}

###cpp

class Solution {
public:
    int numberOfStableArrays(int zero, int one, int limit) {
        const int MOD = 1'000'000'007;
        vector memo(zero + 1, vector<array<int, 2>>(one + 1, {-1, -1})); // -1 表示没有计算过

        auto dfs = [&](this auto&& dfs, int i, int j, int k) -> int {
            if (i == 0) { // 递归边界
                return k == 1 && j <= limit;
            }
            if (j == 0) { // 递归边界
                return k == 0 && i <= limit;
            }
            int& res = memo[i][j][k]; // 注意这里是引用
            if (res != -1) { // 之前计算过
                return res;
            }
            if (k == 0) {
                // + MOD 保证答案非负
                res = ((long long) dfs(i - 1, j, 0) + dfs(i - 1, j, 1) +
                       (i > limit ? MOD - dfs(i - limit - 1, j, 1) : 0)) % MOD;
            } else {
                res = ((long long) dfs(i, j - 1, 0) + dfs(i, j - 1, 1) +
                       (j > limit ? MOD - dfs(i, j - limit - 1, 0) : 0)) % MOD;
            }
            return res;
        };

        return (dfs(zero, one, 0) + dfs(zero, one, 1)) % MOD;
    }
};

###go

func numberOfStableArrays(zero, one, limit int) int {
const mod = 1_000_000_007
memo := make([][][2]int, zero+1)
for i := range memo {
memo[i] = make([][2]int, one+1)
for j := range memo[i] {
memo[i][j] = [2]int{-1, -1}
}
}
var dfs func(int, int, int) int
dfs = func(i, j, k int) (res int) {
if i == 0 { // 递归边界
if k == 1 && j <= limit {
return 1
}
return
}
if j == 0 { // 递归边界
if k == 0 && i <= limit {
return 1
}
return
}
p := &memo[i][j][k]
if *p != -1 { // 之前计算过
return *p
}
if k == 0 {
// +mod 保证答案非负
res = (dfs(i-1, j, 0) + dfs(i-1, j, 1)) % mod
if i > limit {
res = (res - dfs(i-limit-1, j, 1) + mod) % mod
}
} else {
res = (dfs(i, j-1, 0) + dfs(i, j-1, 1)) % mod
if j > limit {
res = (res - dfs(i, j-limit-1, 0) + mod) % mod
}
}
*p = res // 记忆化
return
}
return (dfs(zero, one, 0) + dfs(zero, one, 1)) % mod
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(\textit{zero}\cdot \textit{one})$。由于每个状态只会计算一次,动态规划的时间复杂度 $=$ 状态个数 $\times$ 单个状态的计算时间。本题状态个数等于 $\mathcal{O}(\textit{zero}\cdot \textit{one})$,单个状态的计算时间为 $\mathcal{O}(1)$,所以动态规划的时间复杂度为 $\mathcal{O}(\textit{zero}\cdot \textit{one})$。
  • 空间复杂度:$\mathcal{O}(\textit{zero}\cdot \textit{one})$。有多少个状态,$\textit{memo}$ 数组的大小就是多少。

方法二:递推

和 $\textit{dfs}(i,j,k)$ 一样,定义 $f[i][j][k]$ 表示用 $i$ 个 $0$ 和 $j$ 个 $1$ 构造稳定数组的方案数,其中第 $i+j$ 个位置要填 $k$,其中 $k$ 为 $0$ 或 $1$。

状态转移方程:

$$
\begin{aligned}
f[i][j][0] &= f[i-1][j][0] + f[i-1][j][1] - f[i-\textit{limit}-1][j][1] \
f[i][j][1] &= f[i][j-1][0] + f[i][j-1][1] - f[i][j-\textit{limit}-1][0] \
\end{aligned}
$$

如果 $i\le \textit{limit}$ 则 $f[i-\textit{limit}-1][j][1]$ 视作 $0$,如果 $j\le \textit{limit}$ 则 $f[i][j-\textit{limit}-1][0]$ 视作 $0$。

初始值:$f[i][0][0] = f[0][j][1] = 1$,其中 $1\le i \le \min(\textit{limit}, \textit{zero}),\ 1\le j \le \min(\textit{limit}, \textit{one})$。翻译自递归边界。

答案:$f[\textit{zero}][\textit{one}][0] + f[\textit{zero}][\textit{one}][1]$。翻译自递归入口。

###py

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        MOD = 1_000_000_007
        f = [[[0, 0] for _ in range(one + 1)] for _ in range(zero + 1)]
        for i in range(1, min(limit, zero) + 1):
            f[i][0][0] = 1
        for j in range(1, min(limit, one) + 1):
            f[0][j][1] = 1
        for i in range(1, zero + 1):
            for j in range(1, one + 1):
                f[i][j][0] = (f[i - 1][j][0] + f[i - 1][j][1] - (f[i - limit - 1][j][1] if i > limit else 0)) % MOD
                f[i][j][1] = (f[i][j - 1][0] + f[i][j - 1][1] - (f[i][j - limit - 1][0] if j > limit else 0)) % MOD
        return sum(f[-1][-1]) % MOD

###java

class Solution {
    public int numberOfStableArrays(int zero, int one, int limit) {
        final int MOD = 1_000_000_007;
        int[][][] f = new int[zero + 1][one + 1][2];
        for (int i = 1; i <= Math.min(limit, zero); i++) {
            f[i][0][0] = 1;
        }
        for (int j = 1; j <= Math.min(limit, one); j++) {
            f[0][j][1] = 1;
        }
        for (int i = 1; i <= zero; i++) {
            for (int j = 1; j <= one; j++) {
                // + MOD 保证答案非负
                f[i][j][0] = (int) (((long) f[i - 1][j][0] + f[i - 1][j][1] + (i > limit ? MOD - f[i - limit - 1][j][1] : 0)) % MOD);
                f[i][j][1] = (int) (((long) f[i][j - 1][0] + f[i][j - 1][1] + (j > limit ? MOD - f[i][j - limit - 1][0] : 0)) % MOD);
            }
        }
        return (f[zero][one][0] + f[zero][one][1]) % MOD;
    }
}

###cpp

class Solution {
public:
    int numberOfStableArrays(int zero, int one, int limit) {
        const int MOD = 1'000'000'007;
        vector<vector<array<int, 2>>> f(zero + 1, vector<array<int, 2>>(one + 1));
        for (int i = 1; i <= min(limit, zero); i++) {
            f[i][0][0] = 1;
        }
        for (int j = 1; j <= min(limit, one); j++) {
            f[0][j][1] = 1;
        }
        for (int i = 1; i <= zero; i++) {
            for (int j = 1; j <= one; j++) {
                // + MOD 保证答案非负
                f[i][j][0] = ((long long) f[i - 1][j][0] + f[i - 1][j][1] + (i > limit ? MOD - f[i - limit - 1][j][1] : 0)) % MOD;
                f[i][j][1] = ((long long) f[i][j - 1][0] + f[i][j - 1][1] + (j > limit ? MOD - f[i][j - limit - 1][0] : 0)) % MOD;
            }
        }
        return (f[zero][one][0] + f[zero][one][1]) % MOD;
    }
};

###go

func numberOfStableArrays(zero, one, limit int) (ans int) {
const mod = 1_000_000_007
f := make([][][2]int, zero+1)
for i := range f {
f[i] = make([][2]int, one+1)
}
for i := 1; i <= min(limit, zero); i++ {
f[i][0][0] = 1
}
for j := 1; j <= min(limit, one); j++ {
f[0][j][1] = 1
}
for i := 1; i <= zero; i++ {
for j := 1; j <= one; j++ {
f[i][j][0] = (f[i-1][j][0] + f[i-1][j][1]) % mod
if i > limit {
// + mod 保证答案非负
f[i][j][0] = (f[i][j][0] - f[i-limit-1][j][1] + mod) % mod
}
f[i][j][1] = (f[i][j-1][0] + f[i][j-1][1]) % mod
if j > limit {
f[i][j][1] = (f[i][j][1] - f[i][j-limit-1][0] + mod) % mod
}
}
}
return (f[zero][one][0] + f[zero][one][1]) % mod
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(\textit{zero}\cdot \textit{one})$。
  • 空间复杂度:$\mathcal{O}(\textit{zero}\cdot \textit{one})$。

方法三:容斥原理+乘法原理

回顾一个经典的组合数学问题:

  • 把 $n$ 个无区别的小球,放入 $m$ 个有区别的盒子,不允许空盒,有多少种方案?

这可以用隔板法解决,$n$ 个小球之间有 $n-1$ 个空隙,从中选择 $m-1$ 个空隙,插入 $m-1$ 个隔板,这样就把小球分成了 $m$ 组,并且每一组都是非空的,方案数就是 $n-1$ 选 $m-1$ 的组合数 $\dbinom {n-1} {m-1}$。

  • 只考虑 $0$,把 $0$ 分成 $i$ 组,方案数就是 $f_0[i] = \dbinom {\textit{zero}-1} {i-1}$;
  • 只考虑 $1$,把 $1$ 分成 $i$ 组,方案数就是 $f_1[i] = \dbinom {\textit{one}-1} {i-1}$。

如何综合考虑 $0$ 和 $1$?要如何计算方案数?

例如 $10110001$,相当于把 $0$ 分成了 $2$ 组,把 $1$ 分成了 $3$ 组。

一般地,设 $1$ 分成了 $i$ 组,那么 $0$ 会分成多少组?有哪些情况?

有如下四种情况:

  • $0$ 分成 $i-1$ 组,例如 $10110001$。注意第一个数和最后一个数一定是 $1$。
  • $0$ 分成 $i$ 组,且第一个数是 $0$,例如 $01010011$。注意最后一个数一定是 $1$。
  • $0$ 分成 $i$ 组,且第一个数是 $1$,例如 $10100110$。注意最后一个数一定是 $0$。
  • $0$ 分成 $i+1$ 组,例如 $01010110$。注意第一个数和最后一个数一定是 $0$。

注意 $0$ 和 $1$ 内部的分组方案是互相独立的,例如

$$
\begin{aligned}
&11010001\
&10110001\
&10100011
\end{aligned}
$$

这些例子的 $0$ 的组数不变,$1$ 的组数不变,$0$ 的分组方式也不变(都是一个 $0$ 和三个 $0$),只有 $1$ 的分组方式在变。

根据乘法原理,综合考虑 $0$ 和 $1$,把 $1$ 分成 $i$ 组总的方案数,等于上面说的四种情况(只考虑 $0$,把 $0$ 分成 $i-1,i,i+1$ 组)的方案数之和,乘以只考虑 $1$,把 $1$ 分成 $i$ 组的方案数,即

$$
(f_0[i-1] + 2\cdot f_0[i] + f_0[i+1])\cdot f_1[i]
$$

接下来,考虑 $\textit{limit}$ 带来的影响。推荐先看 2929. 给小朋友们分糖果 II 以及 我的题解

根据容斥原理,对于 $f_0[i] = \dbinom {\textit{zero}-1} {i-1}$,我们需要减去「至少 $1$ 组有超过 $\textit{limit}$ 个 $0$」的方案数,再加上「至少 $2$ 组有超过 $\textit{limit}$ 个 $0$」的方案数,再减去「至少 $3$ 组有超过 $\textit{limit}$ 个 $0$」的方案数,……,直到「至少 $j$ 组有超过 $\textit{limit}$ 个 $0$」的方案数,$j$ 的值见下文。

  • 至少 $j$ 组有超过 $\textit{limit}$ 个 $0$,相当于先从 $i$ 组中选 $j$ 组,每组先放入 $\textit{limit}$ 个 $0$,然后再把剩下的 $\textit{zero} - j\cdot \textit{limit}$ 分成 $i$ 组(需要满足 $\textit{zero} - j\cdot \textit{limit} \ge i$),方案数为

$$
\dbinom {i} {j} \dbinom {\textit{zero} - j\cdot \textit{limit}-1} {i-1}
$$

所以

$$
f_0[i] = \dbinom {\textit{zero}-1} {i-1} + \sum_{j} (-1)^j \dbinom {i} {j} \dbinom {\textit{zero} - j\cdot \textit{limit}-1} {i-1}
$$

其中 $j\ge 1$ 且需要满足 $\textit{zero} - j\cdot \textit{limit} \ge i$,即

$$
1\le j\le \left\lfloor\dfrac{zero - i}{\textit{limit}}\right\rfloor
$$

同理有

$$
f_1[i] = \dbinom {\textit{one}-1} {i-1} + \sum_{j} (-1)^j \dbinom {i} {j} \dbinom {\textit{one} - j\cdot \textit{limit}-1} {i-1}
$$

其中

$$
1\le j\le \left\lfloor\dfrac{one - i}{\textit{limit}}\right\rfloor
$$

最终答案为

$$
\sum_{i} (f_0[i-1] + 2\cdot f_0[i] + f_0[i+1])\cdot f_1[i]
$$

其中:

  1. $i\le \textit{one}$。因为 $1$ 最多分成 $\textit{one}$ 组。
  2. $i-1\le \textit{zero}$。因为 $0$ 最多分成 $\textit{zero}$ 组,当 $i-1 > \textit{zero}$ 时,上式中的 $f_0[i-1] = f_0[i] = f_0[i+1] = 0$,无需累加。
  3. $i\cdot \textit{limit}\ge \textit{one}$,即 $i\ge\left\lceil\dfrac{\textit{one}}{\textit{limit}}\right\rceil$。因为每组至多 $\textit{limit}$ 个 $1$,分成 $i$ 组,至多 $i\cdot \textit{limit}$ 个 $1$,这个数必须 $\ge \textit{one}$,不然剩下的 $1$ 放到哪一组都会导致组内 $1$ 的个数超过 $\textit{limit}$。

整理得

$$
\left\lceil\dfrac{\textit{one}}{\textit{limit}}\right\rceil \le i\le \min(\textit{one},\textit{zero}+1)
$$

代码实现时:

  1. 上取整 $\left\lceil\dfrac{a}{b}\right\rceil$ 转换成下取整 $\left\lfloor\dfrac{a-1}{b}\right\rfloor+1$。
  2. $(-1)^j$ 可以用 $1-j\bmod 2\cdot 2$ 表示,因为当 $j$ 是偶数时,该式为 $1$;当 $j$ 是奇数时,该式为 $-1$,符合 $(-1)^j$。
  3. 预处理阶乘及其逆元,利用公式 $\dbinom {n} {m} = \dfrac{n!}{m!(n-m)!}$ 计算组合数。

关于取模的知识点,见 模运算的世界:当加减乘除遇上取模

###py

MOD = 1_000_000_007
MX = 1001

fac = [0] * MX  # f[i] = i!
fac[0] = 1
for i in range(1, MX):
    fac[i] = fac[i - 1] * i % MOD

inv_f = [0] * MX  # inv_f[i] = i!^-1
inv_f[-1] = pow(fac[-1], -1, MOD)
for i in range(MX - 1, 0, -1):
    inv_f[i - 1] = inv_f[i] * i % MOD

def comb(n: int, m: int) -> int:
    return fac[n] * inv_f[m] * inv_f[n - m] % MOD

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        if zero > one:
            zero, one = one, zero  # 保证空间复杂度为 O(min(zero, one))
        f0 = [0] * (zero + 3)
        for i in range((zero - 1) // limit + 1, zero + 1):
            f0[i] = comb(zero - 1, i - 1)
            for j in range(1, (zero - i) // limit + 1):
                f0[i] = (f0[i] + (-1 if j % 2 else 1) * comb(i, j) * comb(zero - j * limit - 1, i - 1)) % MOD

        ans = 0
        for i in range((one - 1) // limit + 1, min(one, zero + 1) + 1):
            f1 = comb(one - 1, i - 1)
            for j in range(1, (one - i) // limit + 1):
                f1 = (f1 + (-1 if j % 2 else 1) * comb(i, j) * comb(one - j * limit - 1, i - 1)) % MOD
            ans = (ans + (f0[i - 1] + f0[i] * 2 + f0[i + 1]) * f1) % MOD
        return ans

###java

class Solution {
    private static final int MOD = 1_000_000_007;
    private static final int MX = 1001;

    private static final long[] F = new long[MX]; // f[i] = i!
    private static final long[] INV_F = new long[MX]; // inv_f[i] = i!^-1

    static {
        F[0] = 1;
        for (int i = 1; i < MX; i++) {
            F[i] = F[i - 1] * i % MOD;
        }

        INV_F[MX - 1] = pow(F[MX - 1], MOD - 2);
        for (int i = MX - 1; i > 0; i--) {
            INV_F[i - 1] = INV_F[i] * i % MOD;
        }
    }

    public int numberOfStableArrays(int zero, int one, int limit) {
        if (zero > one) {
            // swap,保证空间复杂度为 O(min(zero, one))
            int t = zero;
            zero = one;
            one = t;
        }
        long[] f0 = new long[zero + 3];
        for (int i = (zero - 1) / limit + 1; i <= zero; i++) {
            f0[i] = comb(zero - 1, i - 1);
            for (int j = 1; j <= (zero - i) / limit; j++) {
                f0[i] = (f0[i] + (1 - j % 2 * 2) * comb(i, j) * comb(zero - j * limit - 1, i - 1)) % MOD;
            }
        }

        long ans = 0;
        for (int i = (one - 1) / limit + 1; i <= Math.min(one, zero + 1); i++) {
            long f1 = comb(one - 1, i - 1);
            for (int j = 1; j <= (one - i) / limit; j++) {
                f1 = (f1 + (1 - j % 2 * 2) * comb(i, j) * comb(one - j * limit - 1, i - 1)) % MOD;
            }
            ans = (ans + (f0[i - 1] + f0[i] * 2 + f0[i + 1]) * f1) % MOD;
        }
        return (int) ((ans + MOD) % MOD); // 保证结果非负
    }

    private long comb(int n, int m) {
        return F[n] * INV_F[m] % MOD * INV_F[n - m] % MOD;
    }

    private static long pow(long x, int n) {
        long res = 1;
        for (; n > 0; n /= 2) {
            if (n % 2 > 0) {
                res = res * x % MOD;
            }
            x = x * x % MOD;
        }
        return res;
    }
}

###cpp

const int MOD = 1'000'000'007;
const int MX = 1001;

long long F[MX]; // F[i] = i!
long long INV_F[MX]; // INV_F[i] = i!^-1

long long pow(long long x, int n) {
    long long res = 1;
    for (; n; n /= 2) {
        if (n % 2) {
            res = res * x % MOD;
        }
        x = x * x % MOD;
    }
    return res;
}

auto init = [] {
    F[0] = 1;
    for (int i = 1; i < MX; i++) {
        F[i] = F[i - 1] * i % MOD;
    }

    INV_F[MX - 1] = pow(F[MX - 1], MOD - 2);
    for (int i = MX - 1; i; i--) {
        INV_F[i - 1] = INV_F[i] * i % MOD;
    }
    return 0;
}();

long long comb(int n, int m) {
    return F[n] * INV_F[m] % MOD * INV_F[n - m] % MOD;
}

class Solution {
public:
    int numberOfStableArrays(int zero, int one, int limit) {
        if (zero > one) {
            swap(zero, one); // 保证空间复杂度为 O(min(zero, one))
        }
        vector<long long> f0(zero + 3);
        for (int i = (zero - 1) / limit + 1; i <= zero; i++) {
            f0[i] = comb(zero - 1, i - 1);
            for (int j = 1; j <= (zero - i) / limit; j++) {
                f0[i] = (f0[i] + (1 - j % 2 * 2) * comb(i, j) * comb(zero - j * limit - 1, i - 1)) % MOD;
            }
        }

        long long ans = 0;
        for (int i = (one - 1) / limit + 1; i <= min(one, zero + 1); i++) {
            long long f1 = comb(one - 1, i - 1);
            for (int j = 1; j <= (one - i) / limit; j++) {
                f1 = (f1 + (1 - j % 2 * 2) * comb(i, j) * comb(one - j * limit - 1, i - 1)) % MOD;
            }
            ans = (ans + (f0[i - 1] + f0[i] * 2 + f0[i + 1]) * f1) % MOD;
        }
        return (ans + MOD) % MOD; // 保证结果非负
    }
};

###go

const mod = 1_000_000_007
const mx = 1001

var f [mx]int    // f[i] = i!
var invF [mx]int // invF[i] = i!^-1

func init() {
f[0] = 1
for i := 1; i < mx; i++ {
f[i] = f[i-1] * i % mod
}

invF[mx-1] = pow(f[mx-1], mod-2)
for i := mx - 1; i > 0; i-- {
invF[i-1] = invF[i] * i % mod
}
}

func comb(n, m int) int {
return f[n] * invF[m] % mod * invF[n-m] % mod
}

func numberOfStableArrays(zero, one, limit int) (ans int) {
if zero > one {
zero, one = one, zero // 保证空间复杂度为 O(min(zero, one))
}
f0 := make([]int, zero+3)
for i := (zero-1)/limit + 1; i <= zero; i++ {
f0[i] = comb(zero-1, i-1)
for j := 1; j <= (zero-i)/limit; j++ {
f0[i] = (f0[i] + (1-j%2*2)*comb(i, j)*comb(zero-j*limit-1, i-1)) % mod
}
}

for i := (one-1)/limit + 1; i <= min(one, zero+1); i++ {
f1 := comb(one-1, i-1)
for j := 1; j <= (one-i)/limit; j++ {
f1 = (f1 + (1-j%2*2)*comb(i, j)*comb(one-j*limit-1, i-1)) % mod
}
ans = (ans + (f0[i-1]+f0[i]*2+f0[i+1])*f1) % mod
}
return (ans + mod) % mod // 保证结果非负
}

func pow(x, n int) int {
res := 1
for ; n > 0; n /= 2 {
if n%2 > 0 {
res = res * x % mod
}
x = x * x % mod
}
return res
}

复杂度分析

  • 时间复杂度:$\mathcal{O}\left(\dfrac{\textit{zero}\cdot\textit{one}}{\textit{limit}}\right)$。忽略预处理的时间和空间。
  • 空间复杂度:$\mathcal{O}(\min(\textit{zero},\textit{one}))$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

两种方法:暴力枚举 / 康托对角线(Python/Java/C++/Go)

作者 endlesscheng
2021年8月22日 12:13

方法一:暴力枚举

把 $\textit{nums}$ 中的字符串转成二进制整数,保存到一个哈希集合中。

枚举 $\textit{ans} = 0,1,2,\ldots$ 直到 $\textit{ans}$ 不在哈希集合中,即为答案。

方法二告诉我们,满足要求的答案是一定存在的。

class Solution:
    def findDifferentBinaryString(self, nums: List[str]) -> str:
        st = {int(s, 2) for s in nums}

        ans = 0
        while ans in st:
            ans += 1

        n = len(nums)
        return f"{ans:0{n}b}"
class Solution {
    public String findDifferentBinaryString(String[] nums) {
        Set<Integer> set = new HashSet<>();
        for (String s : nums) {
            set.add(Integer.parseInt(s, 2));
        }

        int ans = 0;
        while (set.contains(ans)) {
            ans++;
        }

        String bin = Integer.toBinaryString(ans);
        return "0".repeat(nums.length - bin.length()) + bin;
    }
}
class Solution {
public:
    string findDifferentBinaryString(vector<string>& nums) {
        unordered_set<int> st;
        for (auto& s : nums) {
            st.insert(stoi(s, nullptr, 2));
        }

        int ans = 0;
        while (st.contains(ans)) {
            ans++;
        }

        int n = nums.size();
        return bitset<32>(ans).to_string().substr(32 - n);
    }
};
func findDifferentBinaryString(nums []string) string {
n := len(nums)
has := make(map[int]bool, n)
for _, s := range nums {
x, _ := strconv.ParseInt(s, 2, 64)
has[int(x)] = true
}

ans := 0
for has[ans] {
ans++
}

return fmt.Sprintf("%0*b", n, ans)
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2)$,其中 $n$ 是 $\textit{nums}$ 的长度。把长为 $n$ 的字符串转成整数需要 $\mathcal{O}(n)$ 的时间。
  • 空间复杂度:$\mathcal{O}(n)$。

方法二:康托对角线

这个方法灵感来自数学家康托关于「实数是不可数无限」的证明。

例如 $\textit{nums} = [\texttt{111}, \texttt{011}, \texttt{000}]$。我们可以构造一个字符串 $\textit{ans}$,满足:

  • $\textit{ans}[0] = \texttt{0} \ne \textit{nums}[0][0]$。
  • $\textit{ans}[1] = \texttt{0} \ne \textit{nums}[1][1]$。
  • $\textit{ans}[2] = \texttt{1} \ne \textit{nums}[2][2]$。

$\textit{ans} = \texttt{001}$ 和每个 $\textit{nums}[i]$ 都至少有一个字符不同,满足题目要求。

一般地,令 $\textit{ans}[i] = \textit{nums}[i][i]\oplus 1$,即可满足要求。其中 $\oplus$ 是异或运算。

class Solution:
    def findDifferentBinaryString(self, nums: List[str]) -> str:
        ans = [''] * len(nums)
        for i, s in enumerate(nums):
            ans[i] = '1' if s[i] == '0' else '0'
        return ''.join(ans)
class Solution {
    public String findDifferentBinaryString(String[] nums) {
        int n = nums.length;
        char[] ans = new char[n];
        for (int i = 0; i < n; i++) {
            ans[i] = (char) (nums[i].charAt(i) ^ 1);
        }
        return new String(ans);
    }
}
class Solution {
public:
    string findDifferentBinaryString(vector<string>& nums) {
        int n = nums.size();
        string ans(n, 0);
        for (int i = 0; i < n; i++) {
            ans[i] = nums[i][i] ^ 1;
        }
        return ans;
    }
};
func findDifferentBinaryString(nums []string) string {
ans := make([]byte, len(nums))
for i, s := range nums {
ans[i] = s[i] ^ 1
}
return string(ans)
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。注意这个方法没有遍历整个字符串,只访问了每个字符串的其中一个字符。
  • 空间复杂度:$\mathcal{O}(1)$,返回值不计入。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

简单题,简单做(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2026年3月6日 08:12

注意 $s$ 不含前导零,那么只含一段连续 $\texttt{1}$ 的 $s$ 只有两种情况:

  • $s$ 全是 $\texttt{1}$。
  • $s$ 是 $\texttt{11}\cdots\texttt{100}\cdots\texttt{0}$,一段 $\texttt{1}$ 和一段 $\texttt{0}$。

如果 $s$ 包含多段连续的 $\texttt{1}$,比如示例 1 的 $s = \texttt{1001}$, $\texttt{0}$ 的后面还有 $\texttt{1}$。所以检查 $s$ 是否包含 $\texttt{01}$ 即可。

:只有一个 $\texttt{1}$ 也算一段连续的 $\texttt{1}$。

###py

class Solution:
    def checkOnesSegment(self, s: str) -> bool:
        return "01" not in s

###java

class Solution {
    public boolean checkOnesSegment(String s) {
        return !s.contains("01");
    }
}

###cpp

class Solution {
public:
    bool checkOnesSegment(string s) {
        return s.find("01") == string::npos;
    }
};

###c

bool checkOnesSegment(char* s) {
    return strstr(s, "01") == NULL;
}

###go

func checkOnesSegment(s string) bool {
return !strings.Contains(s, "01")
}

###js

var checkOnesSegment = function(s) {
    return !s.includes("01");
};

###rust

impl Solution {
    pub fn check_ones_segment(s: String) -> bool {
        !s.contains("01")
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $s$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

从递归到 O(1) 数学公式(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2026年2月26日 08:55

方法一:递归 / 迭代

我们需要确定第 $k$ 个字符位于 $S_n$ 的左半、正中间还是右半。为此,首先要知道 $S_n$ 的长度。

用 $|s|$ 表示字符串 $s$ 的长度。根据题意,$|S_1| = 1$,$|S_n| = 2|S_{n-1}| + 1$,所以有

$$
|S_n| + 1 = 2(|S_{n-1}| + 1)
$$

所以 ${|S_n| + 1}$ 是个首项为 $2$,公比为 $2$ 的等比数列,得

$$
|S_n| = 2^n - 1
$$

所以 $|S_{n-1}| = 2^{n-1} - 1$,这说明 $S_n$ 的左半是第 $1$ 个字符到第 $2^{n-1}-1$ 个字符,正中间是第 $2^{n-1}$ 个字符,右半是第 $2^{n-1} + 1$ 个字符到第 $2^n-1$ 个字符。

分类讨论:

  • 如果 $k < 2^{n-1}$,那么第 $k$ 个字符位于 $S_n$ 的左半,问题变成 $S_{n-1}$ 的第 $k$ 个字符。这可以递归解决。
  • 如果 $k > 2^{n-1}$,那么第 $k$ 个字符位于 $S_n$ 的右半,问题变成 $S_{n-1}$ 反转后的第 $k-2^{n-1}$ 个字符,即反转前的第 $2^{n-1}-(k-2^{n-1}) = 2^n-k$ 个字符(比如 $k=2^n-1$ 对应反转前的第 $1$ 个字符)。这个字符再翻转,即为 $S_n$ 的第 $k$ 个字符。这也可以递归解决。

递归边界:

  • 如果 $n=1$,那么返回 $S_1$ 唯一的字符 $\texttt{0}$。
  • 如果 $k = 2^{n-1}$,那么返回 $S_n$ 正中间的字符 $\texttt{1}$。

递归写法

class Solution:
    def findKthBit(self, n: int, k: int) -> str:
        if n == 1:
            return '0'
        if k == 1 << (n - 1):
            return '1'
        if k < 1 << (n - 1):
            return self.findKthBit(n - 1, k)
        res = self.findKthBit(n - 1, (1 << n) - k)
        return '0' if res == '1' else '1'
class Solution {
    public char findKthBit(int n, int k) {
        if (n == 1) {
            return '0';
        }
        if (k == 1 << (n - 1)) {
            return '1';
        }
        if (k < 1 << (n - 1)) {
            return findKthBit(n - 1, k);
        }
        char res = findKthBit(n - 1, (1 << n) - k);
        return (char) (res ^ 1);
    }
}
class Solution {
public:
    char findKthBit(int n, int k) {
        if (n == 1) {
            return '0';
        }
        if (k == 1 << (n - 1)) {
            return '1';
        }
        if (k < 1 << (n - 1)) {
            return findKthBit(n - 1, k);
        }
        return findKthBit(n - 1, (1 << n) - k) ^ 1;
    }
};
char findKthBit(int n, int k) {
    if (n == 1) {
        return '0';
    }
    if (k == 1 << (n - 1)) {
        return '1';
    }
    if (k < 1 << (n - 1)) {
        return findKthBit(n - 1, k);
    }
    return findKthBit(n - 1, (1 << n) - k) ^ 1;
}
func findKthBit(n, k int) byte {
if n == 1 {
return '0'
}
if k == 1<<(n-1) {
return '1'
}
if k < 1<<(n-1) {
return findKthBit(n-1, k)
}
return findKthBit(n-1, 1<<n-k) ^ 1
}
var findKthBit = function(n, k) {
    if (n === 1) {
        return '0';
    }
    if (k === 1 << (n - 1)) {
        return '1';
    }
    if (k < 1 << (n - 1)) {
        return findKthBit(n - 1, k);
    }
    return findKthBit(n - 1, (1 << n) - k) === '1' ? '0' : '1';
};
impl Solution {
    pub fn find_kth_bit(n: i32, k: i32) -> char {
        if n == 1 {
            return '0';
        }
        if k == 1 << (n - 1) {
            return '1';
        }
        if k < 1 << (n - 1) {
            return Self::find_kth_bit(n - 1, k);
        }
        (Self::find_kth_bit(n - 1, (1 << n) - k) as u8 ^ 1) as _
    }
}

迭代写法

class Solution:
    def findKthBit(self, n: int, k: int) -> str:
        rev = 0  # 翻转次数的奇偶性
        while True:
            if n == 1:
                return '1' if rev else '0'
            if k == 1 << (n - 1):
                return '0' if rev else '1'
            if k > 1 << (n - 1):
                k = (1 << n) - k
                rev ^= 1
            n -= 1
class Solution {
    public char findKthBit(int n, int k) {
        int rev = 0; // 翻转次数的奇偶性
        while (true) {
            if (n == 1) {
                return (char) ('0' ^ rev);
            }
            if (k == 1 << (n - 1)) {
                return (char) ('1' ^ rev);
            }
            if (k > 1 << (n - 1)) {
                k = (1 << n) - k;
                rev ^= 1;
            }
            n--;
        }
    }
}
class Solution {
public:
    char findKthBit(int n, int k) {
        int rev = 0; // 翻转次数的奇偶性
        while (true) {
            if (n == 1) {
                return '0' ^ rev;
            }
            if (k == 1 << (n - 1)) {
                return '1' ^ rev;
            }
            if (k > 1 << (n - 1)) {
                k = (1 << n) - k;
                rev ^= 1;
            }
            n--;
        }
    }
};
char findKthBit(int n, int k) {
    int rev = 0; // 翻转次数的奇偶性
    while (true) {
        if (n == 1) {
            return '0' ^ rev;
        }
        if (k == 1 << (n - 1)) {
            return '1' ^ rev;
        }
        if (k > 1 << (n - 1)) {
            k = (1 << n) - k;
            rev ^= 1;
        }
        n--;
    }
}
func findKthBit(n, k int) byte {
rev := byte(0) // 翻转次数的奇偶性
for {
if n == 1 {
return '0' ^ rev
}
if k == 1<<(n-1) {
return '1' ^ rev
}
if k > 1<<(n-1) {
k = 1<<n - k
rev ^= 1
}
n--
}
}
var findKthBit = function(n, k) {
    let rev = 0; // 翻转次数的奇偶性
    while (true) {
        if (n === 1) {
            return rev ? '1' : '0';
        }
        if (k === 1 << (n - 1)) {
            return rev ? '0' : '1';
        }
        if (k > 1 << (n - 1)) {
            k = (1 << n) - k;
            rev ^= 1;
        }
        n--;
    }
};
impl Solution {
    pub fn find_kth_bit(mut n: i32, mut k: i32) -> char {
        let mut rev = 0; // 翻转次数的奇偶性
        loop {
            if n == 1 {
                return (b'0' ^ rev) as _;
            }
            if k == 1 << (n - 1) {
                return (b'1' ^ rev) as _;
            }
            if k > 1 << (n - 1) {
                k = (1 << n) - k;
                rev ^= 1;
            }
            n -= 1;
        }
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$。
  • 空间复杂度:$\mathcal{O}(1)$。

方法二:数学公式

奇数位

$S_4 = \texttt{011100110110001}$,只看奇数位(下标从 $1$ 开始)的字符,是 $\texttt{01010101}$,这是一个 $\texttt{01}$ 交替序列,为什么?

只看奇数位:

  • $S_1 = \texttt{0}$。
  • $S_2$ 把 $\texttt{0}$ 反转再翻转,得到 $\texttt{1}$,拼起来是 $\texttt{01}$。
  • $S_3$ 把 $\texttt{01}$ 反转再翻转,得到 $\texttt{01}$,拼起来是 $\texttt{0101}$。
  • $S_4$ 把 $\texttt{0101}$ 反转再翻转,得到 $\texttt{0101}$,拼起来是 $\texttt{01010101}$。

一般地,由于 $\texttt{01}$ 交替序列反转再翻转,结果不变,所以从 $S_{i-1}$ 到 $S_i\ (i\ge 3)$,其中奇数位相当于复制了一份自身,拼在了自身后面,得到的仍然是 $\texttt{01}$ 交替序列。

所以,当 $k$ 是奇数时,可以立刻得出答案:

  • 设 $k' = \dfrac{k-1}{2}$。这会把 $k=1,3,5,7,\ldots$ 变成 $k'=0,1,2,3,\ldots$
  • 如果 $k'$ 是偶数,那么答案是 $\texttt{0}$。
  • 如果 $k'$ 是奇数,那么答案是 $\texttt{1}$。
  • 一般地,答案为 $k'\bmod 2$ 对应的字符。

偶数位

奇数位的字符,都发源于 $S_1 = \texttt{0}$。

偶数位的字符呢?都发源于 $S_i\ (i\ge 2)$ 正中间的那个 $\texttt{1}$,即位置为 $2,4,8,16,\ldots$ 的字符 $\texttt{1}$。

根据方法一的结论,$S_{n-1}$ 的第 $k$ 个字符,反转后,是 $S_n$ 的第 $2^n-k$ 个字符。

$2^n-k$ 有什么性质?

比如二进制 $10000 - 100 = 1100$,去掉末尾的两个 $0$,相当于 $100 - 1 = 11$,结果最低位一定是 $1$,所以 $100$ 和 $1100$ 的尾零个数相同。一般地,$k$ 和 $2^n-k$ 的尾零个数是相同的,这是个不变量!我们可以根据 $k$ 的尾零个数,找到 $k$ 发源于哪个 $S_i$ 正中间的 $\texttt{1}$。

以 $S_2$ 的中间字符(第 $2$ 个字符)为例:

  • 我们把 $S_2$ 的第 $2$ 个字符反转到了 $S_3$ 的第 $8-2=6$ 个字符。把 $\texttt{1}$ 反转再翻转,得到 $\texttt{0}$,拼起来是 $\texttt{10}$。
  • 我们把 $S_3$ 的第 $2,6$ 个字符反转到了 $S_4$ 的第 $14,10$ 个字符。把 $\texttt{10}$ 反转再翻转,得到 $\texttt{10}$,拼起来是 $\texttt{1010}$。注意 $2,6,10,14$ 的二进制尾零个数都是 $1$,且这些位置上的字符拼起来是一个 $\texttt{10}$ 交替序列。

一般地,设 $t$ 为 $k$ 去掉尾零后的值,即 $k = t\cdot 2^x$ 且 $t$ 是奇数。比如 $k=2,6,10,14,\ldots$ 对应着 $t=1,3,5,7,\ldots$

  • 设 $t' = \dfrac{t-1}{2}$。这会把 $t=1,3,5,7,\ldots$ 变成 $t'=0,1,2,3,\ldots$
  • 如果 $t'$ 是偶数,那么答案是 $\texttt{1}$。
  • 如果 $t'$ 是奇数,那么答案是 $\texttt{0}$。
  • 一般地,答案为 $1 - t'\bmod 2$ 对应的字符。

如何去掉 $k$ 的尾零?把 $k$ 除以其 $\text{lowbit}$ 即可。关于 $\text{lowbit}$ 的原理,请看 从集合论到位运算,常见位运算技巧分类总结

class Solution:
    def findKthBit(self, _, k: int) -> str:
        if k % 2:
            return str(k // 2 % 2)
        k //= k & -k  # 去掉 k 的尾零
        return str(1 - k // 2 % 2)
class Solution {
    public char findKthBit(int n, int k) {
        if (k % 2 > 0) {
            return (char) ('0' + k / 2 % 2);
        }
        k /= k & -k; // 去掉 k 的尾零
        return (char) ('1' - k / 2 % 2);
    }
}
class Solution {
public:
    char findKthBit(int, int k) {
        if (k % 2) {
            return '0' + k / 2 % 2;
        }
        k /= k & -k; // 去掉 k 的尾零
        return '1' - k / 2 % 2;
    }
};
char findKthBit(int, int k) {
    if (k % 2) {
        return '0' + k / 2 % 2;
    }
    k /= k & -k; // 去掉 k 的尾零
    return '1' - k / 2 % 2;
}
func findKthBit(_, k int) byte {
if k%2 > 0 {
return '0' + byte(k/2%2)
}
k /= k & -k // 去掉 k 的尾零
return '1' - byte(k/2%2)
}
var findKthBit = function(_, k) {
    if (k % 2) {
        return (k - 1) / 2 % 2 ? '1' : '0';
    }
    k /= k & -k; // 去掉 k 的尾零
    return (k - 1) / 2 % 2 ? '0' : '1';
};
impl Solution {
    pub fn find_kth_bit(_: i32, mut k: i32) -> char {
        if k % 2 > 0 {
            return (b'0' + k as u8 / 2 % 2) as _;
        }
        k /= k & -k; // 去掉 k 的尾零
        (b'1' - k as u8 / 2 % 2) as _
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(1)$。
  • 空间复杂度:$\mathcal{O}(1)$。

相似题目

见下面回溯题单的「五、其他递归/分治」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

脑筋急转弯(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2026年2月24日 15:21

例如 $n=321$,其中最大的数字是 $3$。这个 $3$ 至少要拆分成 $3$ 个 $1$,即 $321=1__ + 1__ + 1__$。对于 $n$ 中的其余数字 $d$,可以拆分成 $d$ 个 $1$ 和 $3-d$ 个 $0$,即 $2=1+1+0$ 和 $1=1+0+0$,填到对应的位置上,得到 $321 = 111 + 110 + 100$。

一般地,设 $m$ 为 $n$ 中的最大数字,那么答案为 $m$。构造方案为:设 $n$ 的第 $i$ 个数字为 $n_i$,那么拆分出的这 $m$ 个数的第 $i$ 位上,有 $n_i$ 个 $1$ 和 $m-n_i$ 个 $0$(填入顺序随意)。

###py

class Solution:
    def minPartitions(self, n: str) -> int:
        return int(max(n))

###java

class Solution {
    public int minPartitions(String n) {
        int mx = 0;
        for (char ch : n.toCharArray()) {
            mx = Math.max(mx, ch);
        }
        return mx - '0';
    }
}

###cpp

class Solution {
public:
    int minPartitions(string n) {
        return ranges::max(n) - '0';
    }
};

###c

#define MAX(a, b) ((b) > (a) ? (b) : (a))

int minPartitions(char* n) {
    char mx = 0;
    for (int i = 0; n[i]; i++) {
        mx = MAX(mx, n[i]);
    }
    return mx - '0';
}

###go

func minPartitions(n string) int {
ans := rune(0)
for _, ch := range n {
ans = max(ans, ch)
}
return int(ans - '0')
}

###js

var minPartitions = function(n) {
    return Number(_.max(n));
};

###rust

impl Solution {
    pub fn min_partitions(n: String) -> i32 {
        (n.as_bytes().iter().max().unwrap() - b'0') as _
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(|n|)$,其中 $|n|$ 表示 $n$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

专题训练

见下面贪心与思维题单的「§5.2 脑筋急转弯」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

Golang 简洁写法

作者 endlesscheng
2020年12月6日 12:32

用位运算模拟这个过程:每拼接一个数 $i$,就把之前拼接过的数左移 $i$ 的二进制长度,然后加上 $i$。

由于左移后空出的位置全为 $0$,加法运算也可以写成或运算。

###go

func concatenatedBinary(n int) (ans int) {
    for i := 1; i <= n; i++ {
        ans = (ans<<bits.Len(uint(i)) | i) % (1e9 + 7)
    }
    return
}

两种方法:BFS / 数学(Python/Java/C++/Go)

作者 endlesscheng
2025年8月31日 10:29

方法一:BFS

做法和 2612. 最少翻转操作数 是类似的,请先阅读 我的题解

设 $s$ 的长度为 $n$,其中有 $z$ 个 $0$。

翻转一次后,$s$ 有多少个 $0$?$z$ 可以变成什么数?

设翻转了 $x$ 个 $0$,那么也同时翻转了 $k-x$ 个 $1$,这些 $1$ 变成了 $0$。

所以 $z$ 减少了 $x$,然后又增加了 $k-x$。

所以新的 $z'$ 为

$$
z' = z - x + (k-x) = z+k-2x
$$

$x$ 最大可以是 $k$,但这不能超过 $s$ 中的 $0$ 的个数 $z$,所以 $x$ 最大为 $\min(k,z)$。

$k-x$ 最大可以是 $k$,但这不能超过 $s$ 中的 $1$ 的个数 $n-z$,所以 $k-x$ 最大为 $\min(k,n-z)$,所以 $x$ 最小为 $\max(0,k-n+z)$。

所以 $x$ 的范围为

$$
[\max(0,k-n+z),\min(k,z)]
$$

其余逻辑同 2612 题。

###py

class Solution:
    def minOperations(self, s: str, k: int) -> int:
        n = len(s)
        not_vis = [SortedList(range(0, n + 1, 2)), SortedList(range(1, n + 1, 2))]
        not_vis[0].add(n + 1)  # 哨兵,下面 sl[idx] <= mx 无需判断越界
        not_vis[1].add(n + 1)

        start = s.count('0')  # 起点
        not_vis[start % 2].discard(start)
        q = [start]
        ans = 0
        while q:
            tmp = q
            q = []
            for z in tmp:
                if z == 0:  # 没有 0,翻转完毕
                    return ans
                # not_vis[mn % 2] 中的从 mn 到 mx 都可以从 z 翻转到
                mn = z + k - 2 * min(k, z)
                mx = z + k - 2 * max(0, k - n + z)
                sl = not_vis[mn % 2]
                idx = sl.bisect_left(mn)
                while sl[idx] <= mx:
                    j = sl.pop(idx)  # 注意 pop(idx) 会使后续元素向左移,不需要写 idx += 1
                    q.append(j)
            ans += 1
        return -1

###java

class Solution {
    public int minOperations(String s, int k) {
        int n = s.length();
        TreeSet<Integer>[] notVis = new TreeSet[2];
        for (int m = 0; m < 2; m++) {
            notVis[m] = new TreeSet<>();
            for (int i = m; i <= n; i += 2) {
                notVis[m].add(i);
            }
        }

        // 计算起点
        int start = 0;
        for (int i = 0; i < n; i++) {
            if (s.charAt(i) == '0') {
                start++;
            }
        }

        notVis[start % 2].remove(start);
        List<Integer> q = List.of(start);
        for (int ans = 0; !q.isEmpty(); ans++) {
            List<Integer> tmp = q;
            q = new ArrayList<>();
            for (int z : tmp) {
                if (z == 0) { // 没有 0,翻转完毕
                    return ans;
                }
                // notVis[mn % 2] 中的从 mn 到 mx 都可以从 z 翻转到
                int mn = z + k - 2 * Math.min(k, z);
                int mx = z + k - 2 * Math.max(0, k - n + z);
                TreeSet<Integer> set = notVis[mn % 2];
                for (Iterator<Integer> it = set.tailSet(mn).iterator(); it.hasNext(); it.remove()) {
                    int j = it.next();
                    if (j > mx) {
                        break;
                    }
                    q.add(j);
                }
            }
        }
        return -1;
    }
}

###cpp

class Solution {
public:
    int minOperations(string s, int k) {
        int n = s.size();
        set<int> not_vis[2];
        for (int m = 0; m < 2; m++) {
            for (int i = m; i <= n; i += 2) {
                not_vis[m].insert(i);
            }
            not_vis[m].insert(n + 1); // 哨兵,下面无需判断 it != st.end()
        }

        int start = ranges::count(s, '0'); // 起点
        not_vis[start % 2].erase(start);
        vector<int> q = {start};
        for (int ans = 0; !q.empty(); ans++) {
            vector<int> nxt;
            for (int z : q) {
                if (z == 0) { // 没有 0,翻转完毕
                    return ans;
                }
                // not_vis[mn % 2] 中的从 mn 到 mx 都可以从 z 翻转到
                int mn = z + k - 2 * min(k, z);
                int mx = z + k - 2 * max(0, k - n + z);
                auto& st = not_vis[mn % 2];
                for (auto it = st.lower_bound(mn); *it <= mx; it = st.erase(it)) {
                    nxt.push_back(*it);
                }
            }
            q = move(nxt);
        }
        return -1;
    }
};

###go

// import "github.com/emirpasic/gods/v2/trees/redblacktree"
func minOperations(s string, k int) (ans int) {
n := len(s)
notVis := [2]*redblacktree.Tree[int, struct{}]{}
for m := range notVis {
notVis[m] = redblacktree.New[int, struct{}]()
for i := m; i <= n; i += 2 {
notVis[m].Put(i, struct{}{})
}
notVis[m].Put(n+1, struct{}{}) // 哨兵,下面无需判断 node != nil
}

start := strings.Count(s, "0")
notVis[start%2].Remove(start)
q := []int{start}
for q != nil {
tmp := q
q = nil
for _, z := range tmp {
if z == 0 { // 没有 0,翻转完毕
return ans
}
// notVis[mn % 2] 中的从 mn 到 mx 都可以从 z 翻转到
mn := z + k - 2*min(k, z)
mx := z + k - 2*max(0, k-n+z)
t := notVis[mn%2]
for node, _ := t.Ceiling(mn); node.Key <= mx; node, _ = t.Ceiling(mn) {
q = append(q, node.Key)
t.Remove(node.Key)
}
}
ans++
}
return -1
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log n)$,其中 $n$ 是 $s$ 的长度。$[0,n]$ 中的每个数至多入队出队各一次,每次 $\mathcal{O}(\log n)$ 时间。
  • 空间复杂度:$\mathcal{O}(n)$。

方法二:数学

分析

设 $s$ 中有 $z$ 个 $0$,设一共操作了 $m$ 次。那么总翻转次数为 $mk$。

这 $z$ 个 $0$ 必须翻转奇数次,其余 $n-z$ 个 $1$ 必须翻转偶数次。

总翻转次数减去 $z$,剩下每个位置都必须翻转偶数次,所以

$$
mk-z\ 是偶数
$$

下面计算 $m$ 的下界。只要能证明 $m$ 可以等于下界,问题就解决了。

要想把 $z$ 个 $0$ 变成 $1$,总翻转次数至少要是 $z$,即

$$
mk\ge z
$$

$$
m\ge \left\lceil\dfrac{z}{k}\right\rceil
$$

除此以外,还需要满足什么要求?

情况一:m 是偶数

由于 $mk-z$ 是偶数,如果 $m$ 是偶数,那么 $z$ 也必须是偶数。

$s$ 中的每个位置至多翻转 $m$ 次。但是,对于 $s$ 中的 $0$,由于要翻转奇数次,所以至多翻转 $m-1$ 次。

所以 $s$ 中的所有位置的翻转次数的上界是 $z(m-1)+(n-z)m$,其可能等于 $mk$,也可能比 $mk$ 大(因为是上界),所以有

$$
z(m-1)+(n-z)m\ge mk
$$

解得

$$
m\ge \left\lceil\dfrac{z}{n-k}\right\rceil
$$

$$
m\ge \left\lceil\dfrac{z}{k}\right\rceil
$$

联立得

$$
m\ge \max\left(\left\lceil\dfrac{z}{k}\right\rceil,\left\lceil\dfrac{z}{n-k}\right\rceil\right)
$$

情况二:m 是奇数

由于 $mk-z$ 是偶数,如果 $m$ 是奇数,那么 $z$ 和 $k$ 必须同为奇数,或者同为偶数(奇偶性相同)。

$s$ 中的每个位置至多翻转 $m$ 次。但是,对于 $s$ 中的 $1$,由于要翻转偶数次,所以至多翻转 $m-1$ 次。

所以 $s$ 中的所有位置的翻转次数的上界是 $zm+(n-z)(m-1)$,其可能等于 $mk$,也可能比 $mk$ 大(因为是上界),所以有

$$
zm+(n-z)(m-1)\ge mk
$$

解得

$$
m\ge \left\lceil\dfrac{n-z}{n-k}\right\rceil
$$

$$
m\ge \left\lceil\dfrac{z}{k}\right\rceil
$$

联立得

$$
m\ge \max\left(\left\lceil\dfrac{z}{k}\right\rceil,\left\lceil\dfrac{n-z}{n-k}\right\rceil\right)
$$

情况一和情况二取最小值。

如果两个情况都不满足要求,返回 $-1$。

下界可以取到

这可以用 Gale-Ryser 定理证明。

具体来说,我们需要判断是否存在一个 $m$ 行 $n$ 列的 $0\text{-}1$ 矩阵 $M$,第 $i$ 行对应着第 $i$ 次操作,其中 $M_{i,j} = 0$ 表示没有翻转 $s_j$,$M_{i,j} = 1$ 表示翻转 $s_j$。每一行的元素和都是 $k$,第 $j$ 列的元素和是 $s_j$ 的翻转次数 $a_j$。由于 $a_j\le m$ 且 $\sum\limits_{j} a_j\le mk$,由 Gale-Ryser 定理可得,这样的矩阵是存在的。

特殊情况

如果 $z=0$,那么无需操作,答案是 $0$。

由于下界公式中的分母 $n-k$ 不能为 $0$,我们需要特判 $n=k$ 的情况,此时每次操作只能翻转整个 $s$。

  • 如果 $z=n$,即 $s$ 全为 $0$,那么只需操作 $1$ 次。
  • 否则无论怎么操作,$s$ 中始终有 $0$,返回 $-1$。

上取整转成下取整

关于上取整的计算,当 $a$ 为非负整数,$b$ 为正整数时,有恒等式

$$
\left\lceil\dfrac{a}{b}\right\rceil = \left\lfloor\dfrac{a+b-1}{b}\right\rfloor
$$

证明见 上取整下取整转换公式的证明

###py

class Solution:
    def minOperations(self, s: str, k: int) -> int:
        n = len(s)
        z = s.count('0')
        if z == 0:
            return 0
        if k == n:
            return 1 if z == n else -1

        ans = inf
        # 情况一:操作次数 m 是偶数
        if z % 2 == 0:  # z 必须是偶数
            m = max((z + k - 1) // k, (z + n - k - 1) // (n - k))  # 下界
            ans = m + m % 2  # 把 m 往上调整为偶数

        # 情况二:操作次数 m 是奇数
        if z % 2 == k % 2:  # z 和 k 的奇偶性必须相同
            m = max((z + k - 1) // k, (n - z + n - k - 1) // (n - k))  # 下界
            ans = min(ans, m | 1)  # 把 m 往上调整为奇数

        return ans if ans < inf else -1

###java

class Solution {
    public int minOperations(String s, int k) {
        int n = s.length();
        int z = 0;
        for (int i = 0; i < n; i++) {
            if (s.charAt(i) == '0') {
                z++;
            }
        }

        if (z == 0) {
            return 0;
        }
        if (k == n) {
            return z == n ? 1 : -1;
        }

        int ans = Integer.MAX_VALUE;
        // 情况一:操作次数 m 是偶数
        if (z % 2 == 0) { // z 必须是偶数
            int m = Math.max((z + k - 1) / k, (z + n - k - 1) / (n - k)); // 下界
            ans = m + m % 2; // 把 m 往上调整为偶数
        }

        // 情况二:操作次数 m 是奇数
        if (z % 2 == k % 2) { // z 和 k 的奇偶性必须相同
            int m = Math.max((z + k - 1) / k, (n - z + n - k - 1) / (n - k)); // 下界
            ans = Math.min(ans, m | 1); // 把 m 往上调整为奇数
        }

        return ans < Integer.MAX_VALUE ? ans : -1;
    }
}

###cpp

class Solution {
public:
    int minOperations(string s, int k) {
        int n = s.size();
        int z = ranges::count(s, '0');
        if (z == 0) {
            return 0;
        }
        if (k == n) {
            return z == n ? 1 : -1;
        }

        int ans = INT_MAX;
        // 情况一:操作次数 m 是偶数
        if (z % 2 == 0) { // z 必须是偶数
            int m = max((z + k - 1) / k, (z + n - k - 1) / (n - k)); // 下界
            ans = m + m % 2; // 把 m 往上调整为偶数
        }

        // 情况二:操作次数 m 是奇数
        if (z % 2 == k % 2) { // z 和 k 的奇偶性必须相同
            int m = max((z + k - 1) / k, (n - z + n - k - 1) / (n - k)); // 下界
            ans = min(ans, m | 1); // 把 m 往上调整为奇数
        }

        return ans < INT_MAX ? ans : -1;
    }
};

###go

func minOperations(s string, k int) int {
n := len(s)
z := strings.Count(s, "0")
if z == 0 {
return 0
}
if k == n {
if z == n {
return 1
}
return -1
}

ans := math.MaxInt
// 情况一:操作次数 m 是偶数
if z%2 == 0 { // z 必须是偶数
m := max((z+k-1)/k, (z+n-k-1)/(n-k)) // 下界
ans = m + m%2 // 把 m 往上调整为偶数
}

// 情况二:操作次数 m 是奇数
if z%2 == k%2 { // z 和 k 的奇偶性必须相同
m := max((z+k-1)/k, (n-z+n-k-1)/(n-k)) // 下界
ans = min(ans, m|1) // 把 m 往上调整为奇数
}

if ans < math.MaxInt {
return ans
}
return -1
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $s$ 的长度。瓶颈在遍历 $s$ 上。如果已知 $s$ 中的 $0$ 的个数,则时间复杂度是 $\mathcal{O}(1)$。
  • 空间复杂度:$\mathcal{O}(1)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

❌
❌