阅读视图

发现新文章,点击刷新页面。

每日一题-元素和小于等于 k 的子矩阵的数目🟡

给你一个下标从 0 开始的整数矩阵 grid 和一个整数 k

返回包含 grid 左上角元素、元素和小于或等于 k子矩阵的数目。

 

示例 1:

输入:grid = [[7,6,3],[6,6,1]], k = 18
输出:4
解释:如上图所示,只有 4 个子矩阵满足:包含 grid 的左上角元素,并且元素和小于或等于 18 。

示例 2:

输入:grid = [[7,2,9],[1,5,0],[2,6,6]], k = 20
输出:6
解释:如上图所示,只有 6 个子矩阵满足:包含 grid 的左上角元素,并且元素和小于或等于 20 。

 

提示:

  • m == grid.length
  • n == grid[i].length
  • 1 <= n, m <= 1000
  • 0 <= grid[i][j] <= 1000
  • 1 <= k <= 109

元素和小于等于 k 的子矩阵的数目

方法一:二维前缀和

思路与算法

题目要求统计包含矩阵 $\textit{grid}$ 左上角元素的所有子矩阵中,元素和不超过 $k$ 的子矩阵个数。

我们从左上角出发,按行优先顺序遍历矩阵,将当前访问位置 $(i, j)$ 视为子矩阵的右下角。为了在一次遍历中高效计算子矩阵的和,我们维护一个数组 $\textit{cols}[j]$,用于记录当前行之前第 $j$ 列所有元素的和。在遍历第 $i$ 行时,按列从左到右遍历 $j$,将 $\textit{grid}[i][j]$ 累加到 $\textit{cols}[j]$后,并将 $\textit{cols}[j]$ 累加起来,若当前累加和 $\le k$,则答案加 $1$。

代码

###C++

class Solution {
public:
    int countSubmatrices(vector<vector<int>>& grid, int k) {
        int n = grid.size(), m = grid[0].size();
        vector<int> cols(m);
        int res = 0;
        for (int i = 0; i < n; i++) {
            int rows = 0;
            for (int j = 0; j < m; j++) {
                cols[j] += grid[i][j];
                rows += cols[j];
                if (rows <= k) {
                    res++;
                }
            }
        }
        return res;
    }
};

###Python

class Solution:
    def countSubmatrices(self, grid: List[List[int]], k: int) -> int:
        n, m = len(grid), len(grid[0])
        cols = [0] * m
        res = 0
        
        for i in range(n):
            row_sum = 0
            for j in range(m):
                cols[j] += grid[i][j]
                row_sum += cols[j]
                if row_sum <= k:
                    res += 1
        
        return res

###Rust

impl Solution {
    pub fn count_submatrices(grid: Vec<Vec<i32>>, k: i32) -> i32 {
        let n = grid.len();
        let m = grid[0].len();
        let mut cols = vec![0; m];
        let mut res = 0;
        
        for i in 0..n {
            let mut row_sum = 0;
            for j in 0..m {
                cols[j] += grid[i][j];
                row_sum += cols[j];
                if row_sum <= k {
                    res += 1;
                }
            }
        }
        
        res
    }
}

###Java

class Solution {
    public int countSubmatrices(int[][] grid, int k) {
        int n = grid.length, m = grid[0].length;
        int[] cols = new int[m];
        int res = 0;
        
        for (int i = 0; i < n; i++) {
            int rows = 0;
            for (int j = 0; j < m; j++) {
                cols[j] += grid[i][j];
                rows += cols[j];
                if (rows <= k) {
                    res++;
                }
            }
        }
        
        return res;
    }
}

###C#

public class Solution {
    public int CountSubmatrices(int[][] grid, int k) {
        int n = grid.Length, m = grid[0].Length;
        int[] cols = new int[m];
        int res = 0;
        
        for (int i = 0; i < n; i++) {
            int rows = 0;
            for (int j = 0; j < m; j++) {
                cols[j] += grid[i][j];
                rows += cols[j];
                if (rows <= k) {
                    res++;
                }
            }
        }
        
        return res;
    }
}

###Go

func countSubmatrices(grid [][]int, k int) int {
    n := len(grid)
    m := len(grid[0])
    cols := make([]int, m)
    res := 0
    
    for i := 0; i < n; i++ {
        rows := 0
        for j := 0; j < m; j++ {
            cols[j] += grid[i][j]
            rows += cols[j]
            if rows <= k {
                res++
            }
        }
    }
    
    return res
}

###C

int countSubmatrices(int** grid, int gridSize, int* gridColSize, int k) {
    int n = gridSize;
    int m = *gridColSize;
    int* cols = (int*)calloc(m, sizeof(int));
    int res = 0;
    
    for (int i = 0; i < n; i++) {
        int rows = 0;
        for (int j = 0; j < m; j++) {
            cols[j] += grid[i][j];
            rows += cols[j];
            if (rows <= k) {
                res++;
            }
        }
    }
    
    free(cols);
    return res;
}

###JavaScript

var countSubmatrices = function(grid, k) {
    const n = grid.length;
    const m = grid[0].length;
    const cols = new Array(m).fill(0);
    let res = 0;
    
    for (let i = 0; i < n; i++) {
        let rows = 0;
        for (let j = 0; j < m; j++) {
            cols[j] += grid[i][j];
            rows += cols[j];
            if (rows <= k) {
                res++;
            }
        }
    }
    
    return res;
};

###TypeScript

function countSubmatrices(grid: number[][], k: number): number {
    const n: number = grid.length;
    const m: number = grid[0].length;
    const cols: number[] = new Array(m).fill(0);
    let res: number = 0;
    
    for (let i = 0; i < n; i++) {
        let rows: number = 0;
        for (let j = 0; j < m; j++) {
            cols[j] += grid[i][j];
            rows += cols[j];
            if (rows <= k) {
                res++;
            }
        }
    }
    
    return res;
}

复杂度分析

  • 时间复杂度:$O(mn)$,其中 $m$ 是 $\textit{grid}$ 的行数,$n$ 是 $\textit{grid}$ 的列数。

  • 空间复杂度:$O(n)$。

C++二位前缀和

Problem: 100237. 元素和小于等于 k 的子矩阵的数目

思路

  • 非常标准的二位前缀和,处理之后对每个位置都可以在$O(1)$的时间完成一次查询,暴力即可

Code

###C++

class Solution {
public:
    int countSubmatrices(vector<vector<int>>& grid, int k) {
        int n=grid.size(),m=grid[0].size();
        vector<vector<int>>arr(n+1,vector<int>(m+1));
        int res=0;
        for(int i=1;i<=n;i++){
            for(int j=1;j<=m;j++){
                arr[i][j]=arr[i-1][j]+arr[i][j-1]-arr[i-1][j-1]+grid[i-1][j-1];
                if(arr[i][j]<=k)
                    res++;
            }
        }
        return res;
    }
};

两种方法:二维前缀和模板 / 维护每列元素和(Python/Java/C++/Go)

方法一:二维前缀和

前置知识【图解】二维前缀和

本题相当于统计有多少个二维前缀和 $\le k$。

###py

class Solution:
    def countSubmatrices(self, grid: List[List[int]], k: int) -> int:
        m, n = len(grid), len(grid[0])
        s = [[0] * (n + 1) for _ in range(m + 1)]
        ans = 0
        for i, row in enumerate(grid):
            for j, x in enumerate(row):
                s[i + 1][j + 1] = s[i + 1][j] + s[i][j + 1] - s[i][j] + x
                if s[i + 1][j + 1] <= k:
                    ans += 1
        return ans

###java

class Solution {
    public int countSubmatrices(int[][] grid, int k) {
        int m = grid.length;
        int n = grid[0].length;
        int[][] sum = new int[m + 1][n + 1];
        int ans = 0;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                sum[i + 1][j + 1] = sum[i + 1][j] + sum[i][j + 1] - sum[i][j] + grid[i][j];
                if (sum[i + 1][j + 1] <= k) {
                    ans++;
                }
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int countSubmatrices(vector<vector<int>>& grid, int k) {
        int m = grid.size(), n = grid[0].size();
        vector sum(m + 1, vector<int>(n + 1));
        int ans = 0;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                sum[i + 1][j + 1] = sum[i + 1][j] + sum[i][j + 1] - sum[i][j] + grid[i][j];
                ans += sum[i + 1][j + 1] <= k;
            }
        }
        return ans;
    }
};

###go

func countSubmatrices(grid [][]int, k int) (ans int) {
m, n := len(grid), len(grid[0])
sum := make([][]int, m+1)
sum[0] = make([]int, n+1)
for i, row := range grid {
sum[i+1] = make([]int, n+1)
for j, x := range row {
sum[i+1][j+1] = sum[i+1][j] + sum[i][j+1] - sum[i][j] + x
if sum[i+1][j+1] <= k {
ans++
}
}
}
return
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn)$,其中 $m$ 和 $n$ 分别是 $\textit{grid}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(mn)$。

:如果原地计算二维前缀和,可以做到 $\mathcal{O}(1)$ 额外空间。

方法二:维护每列的元素和

遍历每一行,同时用一个长为 $n$ 的数组 $\textit{colSum}$ 维护每一列的元素和。

遍历当前行时,一边更新 $\textit{colSum}[j]$,一边累加 $\textit{colSum}[j]$ 到变量 $s$ 中。

如果 $s\le k$ 则把答案加一,否则可以跳出循环(因为矩阵元素都非负)。

###py

class Solution:
    def countSubmatrices(self, grid: List[List[int]], k: int) -> int:
        col_sum = [0] * len(grid[0])
        ans = 0
        for row in grid:
            s = 0
            for j, x in enumerate(row):
                col_sum[j] += x
                s += col_sum[j]
                if s > k:
                    break
                ans += 1
        return ans

###java

class Solution {
    public int countSubmatrices(int[][] grid, int k) {
        int n = grid[0].length;
        int[] colSum = new int[n];
        int ans = 0;
        for (int[] row : grid) {
            int s = 0;
            for (int j = 0; j < n; j++) {
                colSum[j] += row[j];
                s += colSum[j];
                if (s > k) {
                    break;
                }
                ans++;
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int countSubmatrices(vector<vector<int>>& grid, int k) {
        int n = grid[0].size();
        vector<int> col_sum(n);
        int ans = 0;
        for (auto& row : grid) {
            int s = 0;
            for (int j = 0; j < n; j++) {
                col_sum[j] += row[j];
                s += col_sum[j];
                if (s > k) {
                    break;
                }
                ans++;
            }
        }
        return ans;
    }
};

###go

func countSubmatrices(grid [][]int, k int) (ans int) {
colSum := make([]int, len(grid[0]))
for _, row := range grid {
s := 0
for j, x := range row {
colSum[j] += x
s += colSum[j]
if s > k {
break
}
ans++
}
}
return
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn)$,其中 $m$ 和 $n$ 分别是 $\textit{grid}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(n)$。

:如果把每列元素和保存到 $\textit{grid}$ 的第一行,可以做到 $\mathcal{O}(1)$ 额外空间。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

每日一题-重新排列后的最大子矩阵🟡

给你一个二进制矩阵 matrix ,它的大小为 m x n ,你可以将 matrix 中的  按任意顺序重新排列。

请你返回最优方案下将 matrix 重新排列后,全是 1 的子矩阵面积。

 

示例 1:

输入:matrix = [[0,0,1],[1,1,1],[1,0,1]]
输出:4
解释:你可以按照上图方式重新排列矩阵的每一列。
最大的全 1 子矩阵是上图中加粗的部分,面积为 4 。

示例 2:

输入:matrix = [[1,0,1,0,1]]
输出:3
解释:你可以按照上图方式重新排列矩阵的每一列。
最大的全 1 子矩阵是上图中加粗的部分,面积为 3 。

示例 3:

输入:matrix = [[1,1,0],[1,0,1]]
输出:2
解释:由于你只能整列整列重新排布,所以没有比面积为 2 更大的全 1 子矩形。

示例 4:

输入:matrix = [[0,0],[0,0]]
输出:0
解释:由于矩阵中没有 1 ,没有任何全 1 的子矩阵,所以面积为 0 。

 

提示:

  • m == matrix.length
  • n == matrix[i].length
  • 1 <= m * n <= 105
  • matrix[i][j] 要么是 0 ,要么是 1

枚举子矩形的底边 + O(mn) 优化(Python/Java/C++/Go)

做法类似 85. 最大矩形,枚举子矩形的底边(最后一行),定义 $\textit{heights}[j]$ 表示从 $\textit{matrix}[i][j]$ 往上有多少个连续的 $1$(柱子的高度),问题变成:

  • 你可以重排 $\textit{heights}$。重排后,对于 $\textit{heights}$ 的连续子数组,子数组长度(矩形底边长)$\times$ 子数组最小值(矩形的高),即为全 $1$ 子矩形的面积。

对于示例 1,以第三行为底边算出来的 $\textit{heights} = [2,0,3]$,下图重排后是 $[2,3,0]$。其中子数组 $[2,3]$,长为 $2$,最小值为 $2$,所以对应的子矩形面积为 $2\times 2 = 4$。

lc1727.png{:width=430px}

如何找到面积最大的子矩形?还是枚举。

枚举子数组的长度 $k = 1,2,\ldots,n$。由于我们可以重排 $\textit{heights}$,那么贪心地,把 $\textit{heights}$ 最大的 $k$ 个数排在一起,就可以让子数组的最小值(矩形的高)尽量大,从而得到最大的矩形面积。

对于 $\textit{heights}$ 的计算,如果 $\textit{matrix}[i][j]=0$,那么 $\textit{heights}[j] = 0$。否则,把高度增加 $1$。形象地说,就是在柱子下面垫一块石头,把柱子抬高。

优化前

###py

class Solution:
    def largestSubmatrix(self, matrix: List[List[int]]) -> int:
        n = len(matrix[0])
        heights = [0] * n
        ans = 0

        for row in matrix:  # 枚举子矩形的底边
            for j, x in enumerate(row):
                if x == 0:
                    heights[j] = 0
                else:
                    heights[j] += 1

            hs = sorted(heights)  # 复制一份 heights 再排序
            for i, h in enumerate(hs):  # 把 hs[i:] 作为子数组
                # 子数组长为 n-i,最小值为 h,对应的子矩形面积为 (n-i)*h
                ans = max(ans, (n - i) * h)  

        return ans

###java

class Solution {
    public int largestSubmatrix(int[][] matrix) {
        int n = matrix[0].length;
        int[] heights = new int[n];
        int ans = 0;

        for (int[] row : matrix) { // 枚举子矩形的底边
            for (int j = 0; j < n; j++) {
                if (row[j] == 0) {
                    heights[j] = 0;
                } else {
                    heights[j]++;
                }
            }

            int[] hs = heights.clone();
            Arrays.sort(hs);
            for (int i = 0; i < n; i++) { // 把 [i,n-1] 作为子数组
                // 子数组长为 n-i,最小值为 hs[i],对应的子矩形面积为 (n-i)*hs[i]
                ans = Math.max(ans, (n - i) * hs[i]); 
            }
        }

        return ans;
    }
}

###cpp

class Solution {
public:
    int largestSubmatrix(vector<vector<int>>& matrix) {
        int n = matrix[0].size();
        vector<int> heights(n);
        int ans = 0;

        for (auto& row : matrix) { // 枚举子矩形的底边
            for (int j = 0; j < n; j++) {
                int x = row[j];
                if (x == 0) {
                    heights[j] = 0;
                } else {
                    heights[j]++;
                }
            }

            auto hs = heights;
            ranges::sort(hs);
            for (int i = 0; i < n; i++) { // 把 [i,n-1] 作为子数组
                // 子数组长为 n-i,最小值为 hs[i],对应的子矩形面积为 (n-i)*hs[i]
                ans = max(ans, (n - i) * hs[i]); 
            }
        }
        return ans;
    }
};

###go

func largestSubmatrix(matrix [][]int) (ans int) {
n := len(matrix[0])
heights := make([]int, n)

for _, row := range matrix { // 枚举子矩形的底边
for j, x := range row {
if x == 0 {
heights[j] = 0
} else {
heights[j]++
}
}

hs := slices.Clone(heights)
slices.Sort(hs)
for i, h := range hs { // 把 hs[i:] 作为子数组
ans = max(ans, (n-i)*h) // 子数组长为 n-i,最小值为 h,对应的子矩形面积为 (n-i)*h
}
}

return
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn\log n)$,其中 $m$ 和 $n$ 分别是 $\textit{matrix}$ 的行数和列数。瓶颈在排序上,有 $m$ 行,每行都要跑一个 $\mathcal{O}(n\log n)$ 的排序。
  • 空间复杂度:$\mathcal{O}(n)$。

优化

考察从 $i-1$ 行到 $i$ 行,$\textit{heights}$ 会如何变化:

  • 如果 $\textit{matrix}[i][j] = 0$,那么 $\textit{heights}[j] = 0$。在排序后,$0$ 会排在大于 $0$ 的高度前面。
  • 如果 $\textit{matrix}[i][j] = 1$,那么 $\textit{heights}[j]$ 增加一。对于那些增加一的高度,相对大小是不变的,无需再次排序。比如把 $1,2,3$ 都增加一,得到 $2,3,4$,这三个数的相对大小不变。

举个例子。假设 $i-1$ 行的 $\textit{heights}$ 排序后是 $[0,{\color{red}0},{\color{red}0},1,{\color{red}2},{\color{red}3}]$,把红色数字加一,其余数字变成 $0$,得到 $[0,{\color{red}1},{\color{red}1},0,{\color{red}3},{\color{red}4}]$。把 $0$ 排在红色数字前面,得到 $[0,0,{\color{red}1},{\color{red}1},{\color{red}3},{\color{red}4}]$。注意红色数字的相对大小是不变的,无需再次排序。

一般地,如果已知 $i-1$ 行的 $\textit{heights}$ 排序后的结果,那么对于 $i$ 行,我们只需把高度变成 $0$ 的数据排在前面,其余(增加一的)高度的相对大小不变,无需再次排序。这样就可以把排序的时间从 $\mathcal{O}(n\log n)$ 优化成 $\mathcal{O}(n)$。

但是,如果直接对 $\textit{heights}$ 排序,我们就不知道每个高度对应矩阵的哪一列了。如何解决?创建一个 $0$ 到 $n-1$ 的下标数组(列号数组)$\textit{idx}$,对下标数组排序。

###py

class Solution:
    def largestSubmatrix(self, matrix: List[List[int]]) -> int:
        n = len(matrix[0])
        heights = [0] * n
        idx = list(range(n))  # 按照高度排序后的列号
        ans = 0

        for row in matrix:
            zeros = []
            non_zeros = []
            for j in idx:
                if row[j] == 0:
                    heights[j] = 0
                    zeros.append(j)
                else:
                    heights[j] += 1
                    non_zeros.append(j)
            idx = zeros + non_zeros  # 把高度为 0 的列号排在其他高度前面

            # heights[idx[i]] 是递增的
            for i in range(len(zeros), n):  # 高度 0 无需计算
                ans = max(ans, (n - i) * heights[idx[i]])

        return ans

###java

class Solution {
    public int largestSubmatrix(int[][] matrix) {
        int n = matrix[0].length;
        int[] heights = new int[n];
        int[] idx = new int[n]; // 按照高度排序后的列号
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        int[] nonZeros = new int[n]; // 避免在循环内反复申请内存
        int ans = 0;

        for (int[] row : matrix) {
            int p = 0;
            int q = 0;
            for (int j : idx) {
                if (row[j] == 0) {
                    heights[j] = 0;
                    idx[p++] = j; // 高度 0 排在前面
                } else {
                    heights[j]++;
                    nonZeros[q++] = j;
                }
            }

            // heights[idx[i]] 是递增的
            for (int i = p; i < n; i++) { // 高度 0 无需计算
                idx[i] = nonZeros[i - p]; // 把 nonZeros 复制到 idx 的 [p,n-1] 中
                ans = Math.max(ans, (n - i) * heights[idx[i]]);
            }
        }

        return ans;
    }
}

###cpp

class Solution {
public:
    int largestSubmatrix(vector<vector<int>>& matrix) {
        int n = matrix[0].size();
        vector<int> heights(n);
        vector<int> idx(n); // 按照高度排序后的列号
        ranges::iota(idx, 0); // idx[i] = i
        vector<int> non_zeros(n); // 避免在循环内反复申请内存
        int ans = 0;

        for (auto& row : matrix) {
            int p = 0, q = 0;
            for (int j : idx) {
                if (row[j] == 0) {
                    heights[j] = 0;
                    idx[p++] = j; // 高度 0 排在前面
                } else {
                    heights[j]++;
                    non_zeros[q++] = j;
                }
            }

            // heights[idx[i]] 是递增的
            for (int i = p; i < n; i++) { // 高度 0 无需计算
                idx[i] = non_zeros[i - p]; // 把 non_zeros 复制到 idx 的 [p,n-1] 中
                ans = max(ans, (n - i) * heights[idx[i]]);
            }
        }

        return ans;
    }
};

###go

func largestSubmatrix(matrix [][]int) (ans int) {
n := len(matrix[0])
heights := make([]int, n)
idx := make([]int, n) // 按照高度排序后的列号
for i := range idx {
idx[i] = i
}
_nonZeros := make([]int, n) // 避免在循环内反复申请内存

for _, row := range matrix {
zeros := idx[:0]
nonZeros := _nonZeros[:0]
for _, j := range idx {
if row[j] == 0 {
heights[j] = 0
zeros = append(zeros, j)
} else {
heights[j]++
nonZeros = append(nonZeros, j)
}
}
idx = append(zeros, nonZeros...) // 把高度为 0 的列号排在其他高度前面

// heights[idx[i]] 是递增的
for i := len(zeros); i < n; i++ { // 高度 0 无需计算
ans = max(ans, (n-i)*heights[idx[i]])
}
}

return
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn)$,其中 $m$ 和 $n$ 分别是 $\textit{matrix}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(n)$。

专题训练

见下面贪心题单的「§1.6 先枚举,再贪心」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

【贪心】活学活用,做题不慌

提示:如果你没有做出来这道题,建议先去回顾一下 85.最大矩形 的思想。本质上,这类题目都是通过枚举其中一个维度,将问题划归为一维问题来进行求解。

设矩阵为 $H$ 行 $W$ 列。

首先,我们维护一个等大的数组 $\textit{up}$,其中 $\textit{up}[i][j]$ 表示 $\textit{matrix}[i][j]$ 上面有多少个 $1$(包括它自己)。

随后,我们枚举最大矩形的底部位置 $i$。由于列之间可以任意排序,所以可以按照 $\textit{up}[i][0], \textit{up}[i][1], ... , \textit{up}[i][W-1]$ 的大小进行递增排序。

在递增排序过后,设(排序后的)第 $i$ 行第 $1$ 列上面有 $a_1$ 个 $1$。由于已经递增排序,所以第 $i$ 行第 $1$ 列右面的所有位置的上面都至少有 $a_1$ 个 $1$。于是,底边为第 $i$ 行,高度为 $a_1$ 的矩阵的最大宽度为 $W$,对应面积为 $a_1W$。

同理,设第 $i$ 行第 $2$ 列上面有 $a_2$ 个 $1$,则第 $i$ 行第 $1$ 列右面的所有位置的上面都至少有 $a_2$ 个 $1$,因此对应面积 $a_2(W-1)$。以此类推,我们能够得到底边为第 $i$ 行的矩形的最大面积。

随后,再枚举所有的 $i$,就可以得到整体的最大面积。

class Solution {
public:
    int largestSubmatrix(vector<vector<int>>& matrix) {
        int h = matrix.size(), w = matrix[0].size();
        vector<vector<int>> up(h, vector<int>(w, 0));
        
        for (int i = 0; i < h; i++) {
            for (int j = 0; j < w; j++) {
                if (matrix[i][j] == 1) {
                    up[i][j] = (i == 0 ? 0 : up[i-1][j]) + 1;
                }
            }
        }

        int ret = 0;
        for (int i = 0; i < h; i++) {
            vector<int> buf;
            for (int j = 0; j < w; j++) {
                buf.push_back(up[i][j]);
            }
            sort(buf.begin(), buf.end());
            for (int j = 0; j < w; j++) {
                ret = max(ret, buf[j] * (w - j));
            }
        }
        return ret;
    }
};

时间复杂度
共枚举 $H$ 行,每行需要 $O(W\log W)$ 的排序以及 $O(W)$ 的额外扫描,故总体复杂度为 $O(HW\log W)$。

Java 预处理数组,遍历每行排序

预处理数组,计算以这个点为结尾,上面有多少个连续的1,就是这一列以这个点为结尾的最大高度
这样就将二维问题转成一维

遍历每一行,对每一行进行排序,记录矩形的最长的高度,每次更新结果

class Solution {
    public int largestSubmatrix(int[][] matrix) {
        int n=matrix.length;
        int m=matrix[0].length;
        int res=0;
        for(int i=1;i<n;i++){
            for(int j=0;j<m;j++){
                if(matrix[i][j]==1){
                    //记录向上连续1的个数
                    matrix[i][j]+=matrix[i-1][j];
                }
            }
        }
        for(int i=0;i<n;i++){
            Arrays.sort(matrix[i]);
            for(int j=m-1;j>=0;j--){
                //更新矩形的最大高度
                int height=matrix[i][j];
                //更新最大面积
                res=Math.max(res,height*(m-j));
            }
        }
        return res;
    }
}

每日一题-奇妙序列🔴

请你实现三个 API appendaddAll 和 multAll 来实现奇妙序列。

请实现 Fancy 类 :

  • Fancy() 初始化一个空序列对象。
  • void append(val) 将整数 val 添加在序列末尾。
  • void addAll(inc) 将所有序列中的现有数值都增加 inc 。
  • void multAll(m) 将序列中的所有现有数值都乘以整数 m 。
  • int getIndex(idx) 得到下标为 idx 处的数值(下标从 0 开始),并将结果对 109 + 7 取余。如果下标大于等于序列的长度,请返回 -1 。

 

示例:

输入:
["Fancy", "append", "addAll", "append", "multAll", "getIndex", "addAll", "append", "multAll", "getIndex", "getIndex", "getIndex"]
[[], [2], [3], [7], [2], [0], [3], [10], [2], [0], [1], [2]]
输出:
[null, null, null, null, null, 10, null, null, null, 26, 34, 20]

解释:
Fancy fancy = new Fancy();
fancy.append(2);   // 奇妙序列:[2]
fancy.addAll(3);   // 奇妙序列:[2+3] -> [5]
fancy.append(7);   // 奇妙序列:[5, 7]
fancy.multAll(2);  // 奇妙序列:[5*2, 7*2] -> [10, 14]
fancy.getIndex(0); // 返回 10
fancy.addAll(3);   // 奇妙序列:[10+3, 14+3] -> [13, 17]
fancy.append(10);  // 奇妙序列:[13, 17, 10]
fancy.multAll(2);  // 奇妙序列:[13*2, 17*2, 10*2] -> [26, 34, 20]
fancy.getIndex(0); // 返回 26
fancy.getIndex(1); // 返回 34
fancy.getIndex(2); // 返回 20

 

提示:

  • 1 <= val, inc, m <= 100
  • 0 <= idx <= 105
  • 总共最多会有 105 次对 appendaddAllmultAll 和 getIndex 的调用。

懒更新 + 等价转化(Python/Java/C++/Go)

只有加法

从特殊到一般,先考虑一个简单的问题:只有加法,没有乘法,怎么做?

执行 $\texttt{addAll}(\textit{inc})$ 时,如果把每个数都增加 $\textit{inc}$,就太慢了。

我们可以采用一种「懒更新」的想法,等到调用 $\texttt{getIndex}(\textit{idx})$ 时,才做计算。

比如序列 $\textit{vals} = [3,1,4]$,执行 $\texttt{addAll}(2)$ 时,我们不去把 $\textit{vals}$ 的每个数都增加 $2$,而是用一个变量 $\textit{add}$ 表示「每个数都要增加 $\textit{add}$」。执行 $\texttt{addAll}(2)$ 时,只把 $\textit{add}$ 增加 $2$。等到调用 $\texttt{getIndex}(\textit{idx})$ 时,才计算加法:

$$
\textit{vals}[\textit{idx}] + \textit{add}
$$

即为 $\textit{vals}[\textit{idx}]$ 更新后的数值。

如何处理 $\texttt{append}(\textit{val})$ 呢?

为了让 $\textit{val}$ 兼容 $\textit{vals}[\textit{idx}] + \textit{add}$ 这个式子,我们可以先把 $\textit{val}$ 减少 $\textit{add}$,再添加到 $\textit{vals}$ 的末尾,比如 $\textit{val} = 5$,$\textit{add} = 2$,那么往 $\textit{vals}$ 的末尾添加 $5-2=3$,就可以让式子 $\textit{vals}[\textit{idx}] + \textit{add}$ 对所有元素都保持一致

只有乘法

如果只有乘法,没有加法呢?

同理,用变量 $\textit{mul}$ 表示「每个数都要乘以 $\textit{mul}$」。执行 $\texttt{multAll}(2)$ 时,只把 $\textit{mul}$ 乘以 $2$。等到调用 $\texttt{getIndex}(\textit{idx})$ 时,才计算乘法:

$$
\textit{vals}[\textit{idx}] \cdot \textit{mul}
$$

即为 $\textit{vals}[\textit{idx}]$ 更新后的数值。

如何处理 $\texttt{append}(\textit{val})$ 呢?

为了让 $\textit{val}$ 兼容 $\textit{vals}[\textit{idx}] \cdot \textit{mul}$ 这个式子,我们可以先把 $\textit{val}$ 除以 $\textit{mul}$,再添加到 $\textit{vals}$ 的末尾,比如 $\textit{val} = 6$,$\textit{mul} = 2$,那么往 $\textit{vals}$ 的末尾添加 $6/2=3$,就可以让式子 $\textit{vals}[\textit{idx}] \cdot \textit{mul}$ 对所有元素都保持一致

注意:在模运算中,除以 $\textit{mul}$ 等价于乘以 $\textit{mul}$ 关于 $M = 10^9+7$ 的逆元,即 $\textit{mul}^{M-2}$。原理见 模运算的世界:当加减乘除遇上取模

加法和乘法

把上述方法结合起来,用 $\textit{add}$ 记录操作 $\texttt{addAll}$,用 $\textit{mul}$ 记录操作 $\texttt{multAll}$。

  • 初始值:$\textit{add} = 0$,$\textit{mul}=1$。
  • 执行 $\texttt{addAll}(\textit{inc})$ 时,把 $\textit{add}$ 增加 $\textit{inc}$。
  • 执行 $\texttt{multAll}(m)$ 时,由于 $(v \cdot \textit{mul} + \textit{add})\cdot m = v \cdot (\textit{mul}\cdot m) + \textit{add}\cdot m$,所以把 $\textit{mul}$ 乘以 $m$,把 $\textit{add}$ 乘以 $m$。

调用 $\texttt{getIndex}(\textit{idx})$ 时,计算

$$
\textit{vals}[\textit{idx}] \cdot \textit{mul} + \textit{add}
$$

即为 $\textit{vals}[\textit{idx}]$ 更新后的数值。

如何处理 $\texttt{append}(\textit{val})$ 呢?

为了让 $\textit{val}$ 兼容 $\textit{vals}[\textit{idx}] \cdot \textit{mul} + \textit{add}$ 这个式子,我们可以先计算 $v = \dfrac{\textit{val} - \textit{add}}{\textit{mul}}$,再把 $v$ 添加到 $\textit{vals}$ 的末尾,就可以让式子 $\textit{vals}[\textit{idx}] \cdot \textit{mul} + \textit{add}$ 对所有元素都保持一致

代码实现时,注意取模。为什么可以在中途取模?原理见 模运算的世界:当加减乘除遇上取模

###py

MOD = 1_000_000_007

class Fancy:
    def __init__(self):
        self.vals = []
        self.add = 0
        self.mul = 1

    def append(self, val: int) -> None:
        self.vals.append((val - self.add) * pow(self.mul, -1, MOD) % MOD)

    def addAll(self, inc: int) -> None:
        self.add += inc

    def multAll(self, m: int) -> None:
        self.mul = self.mul * m % MOD
        self.add = self.add * m % MOD

    def getIndex(self, idx: int) -> int:
        if idx >= len(self.vals):
            return -1
        return (self.vals[idx] * self.mul + self.add) % MOD

###java

class Fancy {
    private static final int MOD = 1_000_000_007;

    private final List<Integer> vals = new ArrayList<>();
    private long add = 0;
    private long mul = 1;

    public void append(int val) {
        // 注意这里有减法,计算结果可能是负数,+MOD 可以保证计算结果非负
        vals.add((int) ((val - add + MOD) * pow(mul, MOD - 2) % MOD));
    }

    public void addAll(int inc) {
        add = (add + inc) % MOD;
    }

    public void multAll(int m) {
        mul = mul * m % MOD;
        add = add * m % MOD;
    }

    public int getIndex(int idx) {
        if (idx >= vals.size()) {
            return -1;
        }
        return (int) ((vals.get(idx) * mul + add) % MOD);
    }

    private long pow(long x, int n) {
        long res = 1;
        for (; n > 0; n /= 2) {
            if (n % 2 > 0) {
                res = res * x % MOD;
            }
            x = x * x % MOD;
        }
        return res;
    }
}

###cpp

class Fancy {
    static constexpr int MOD = 1'000'000'007;

    vector<int> vals;
    long long add = 0;
    long long mul = 1;

    long long pow(long long x, int n) {
        long long res = 1;
        for (; n; n /= 2) {
            if (n % 2) {
                res = res * x % MOD;
            }
            x = x * x % MOD;
        }
        return res;
    }

public:
    void append(int val) {
        // 注意这里有减法,计算结果可能是负数,+MOD 可以保证计算结果非负
        vals.push_back((val - add + MOD) * pow(mul, MOD - 2) % MOD);
    }

    void addAll(int inc) {
        add = (add + inc) % MOD;
    }

    void multAll(int m) {
        mul = mul * m % MOD;
        add = add * m % MOD;
    }

    int getIndex(int idx) {
        if (idx >= vals.size()) {
            return -1;
        }
        return (vals[idx] * mul + add) % MOD;
    }
};

###go

const mod = 1_000_000_007

func pow(x, n int) int {
res := 1
for ; n > 0; n /= 2 {
if n%2 > 0 {
res = res * x % mod
}
x = x * x % mod
}
return res
}

type Fancy struct {
vals []int
add  int
mul  int
}

func Constructor() Fancy {
return Fancy{mul: 1}
}

func (f *Fancy) Append(val int) {
// 注意这里有减法,计算结果可能是负数,+mod 可以保证计算结果非负
f.vals = append(f.vals, (val-f.add+mod)*pow(f.mul, mod-2)%mod)
}

func (f *Fancy) AddAll(inc int) {
f.add = (f.add + inc) % mod
}

func (f *Fancy) MultAll(m int) {
f.mul = f.mul * m % mod
f.add = f.add * m % mod
}

func (f *Fancy) GetIndex(idx int) int {
if idx >= len(f.vals) {
return -1
}
return (f.vals[idx]*f.mul + f.add) % mod
}

复杂度分析

  • 时间复杂度:$\texttt{append}$ 为 $\mathcal{O}(\log M)$,其余为 $\mathcal{O}(1)$,其中 $M=10^9+7$。
  • 空间复杂度:$\mathcal{O}(q)$,其中 $q$ 是 $\texttt{append}$ 的调用次数。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

如果不了解乘法逆元,就好好学习一下线段树的懒更新

首先,这个问题借助乘法逆元,可以非常优美地解决,具体请参考零神的题解:https://leetcode.cn/problems/fancy-sequence/solution/qi-miao-xu-lie-by-zerotrac2/

经常玩儿算法比赛的同学,还是需要学习乘法逆元的。乘法逆元的在算法竞赛中使用非常广泛,但在面试中近乎没有用。鉴于现在力扣的问题越来越竞赛化,所以学习乘法逆元不吃亏。


但是,对于这个问题,如果你不会乘法逆元,也是可以解决的。由于问题的操作都是在一个区间中进行的,所以很容易想到使用线段树求解。

其实树状数组也可以,不过因为力扣官方网站本身有非常好的线段树的解析,所以,这个题解的代码,我使用线段树。

另外值得一提的是,使用线段树解决,其实可以解决这个问题的“更强版本”。在这个问题中,每次操作都是对已经有的所有数据进行的。但是如果问题变成,在指定的某个区间,所有的元素做乘法或者加法,既支持的是 addRange(l, r, inc)multRange(l, r, m) 两种操作的话,线段树也可以轻松应对。(可以出一个“奇妙序列 II”了。)

因为这个问题需要区间更新,所以,需要使用线段树的懒更新功能。强烈建议大家首先仔细阅读学习这份力扣的文档,在这份文档的后续,详细介绍了如何为线段树添加懒更新的功能,并且附有代码:https://leetcode.com/articles/a-recursive-approach-to-segment-trees-range-sum-queries-lazy-propagation/

下面,我对这个问题的解题代码,将直接基于这篇文章的线段树代码进行修改。


上面的文章中,解决了如果为指定区间所有元素添加一个值,怎样进行区间更新。但是在这个问题中,还需要处理,如果为指定区间所有元素乘以一个值,要怎么办?

首先,非常关键的一点是,经过一系列题目中的操作,对于每一个元素 x,最终都可以化为 x * a + b 的形式。

比如,对于某个 x,我们有了 ((x + inc1) * m1 + inc2) * m2 + inc3,展开,就有了:

x * (m1 * m2) + (inc1 * m1 * m2 + inc2 * m2 + inc3)。

相当于是 a = m1 * m2;b = inc1 * m1 * m2 + inc2 * m2 + inc3。

因此,我们在线段树的懒标记中,只需要记录经过若干操作以后,对每个元素的 a 是多少,b 是多少,就好了。

所以,我们的线段树需要两个懒标记数组,我使用 lazymlazya,来记录乘数和加数。


下面的关键就是,如果来了一个运算,如何做懒更新?

假设本来的元素是 x,对应的乘数是 a,加数是 b。也就是其实应该是 ax + b。

来了一个乘法运算 m。则变成了 (ax + b) * m = amx + bm。可以看到,乘数和加数都要乘以 m;

来了一个加法运算 inc,则变成了 ax + b + inc = ax + (b + inc)。可以看到,只要加数加 m 就好了。

在下面的代码中,注释有“懒更新”字眼的代码,是实现的这部分逻辑。


除此之外,还有一个问题,懒更新除了更新下面的 lazymlazya 数组,还需要更新当前的区间 tree[treeID]

这其实很简单,因为当前区间的所有元素都要乘以 a 然后加上 b。所以整体就是对原来的区间和乘以 a,然后加上区间长度 len * b(每个元素都加上了 b)。

在下面的代码中,注释有“区间更新”字眼的代码,是实现的这部分逻辑。

取此之外,所有的逻辑都和上面文章中的线段树是一样的。

参考代码(C++):

class SegmentTree{

private:
    int n;
    vector<long long> tree, lazya, lazym;
    const long long MOD = 1e9 + 7;

public:
    SegmentTree(int n): n(n), tree(4 * n, 0ll), lazya(4 * n, 0ll), lazym(4 * n, 1ll){}

    void add(int index, int val){
        update(0, 0, n - 1, index, index, val, 1);
    }

    void update_add(int uL, int uR, int inc){
        update(0, 0, n-1, uL, uR, inc, 1);
    }

    void update_mul(int uL, int uR, int mul){
        update(0, 0, n-1, uL, uR, 0, mul);
    }

    int query(int index){
        return query(0, 0, n-1, index);
    }

private:
    void update(int treeID, int treeL, int treeR, int uL, int uR, int inc, int mul){

        if(lazya[treeID] != 0 || lazym[treeID] != 1){
            // 区间更新
            tree[treeID] = (tree[treeID] * lazym[treeID] + (treeR - treeL + 1)* lazya[treeID]) % MOD;

            // 懒更新
            if(treeL != treeR){
                lazym[2 * treeID + 1] = lazym[2 * treeID + 1] * lazym[treeID] % MOD;
                lazya[2 * treeID + 1] = lazya[2 * treeID + 1] * lazym[treeID] % MOD;
                lazya[2 * treeID + 1] = (lazya[2 * treeID + 1] + lazya[treeID]) % MOD;

                lazym[2 * treeID + 2] = lazym[2 * treeID + 2] * lazym[treeID] % MOD;
                lazya[2 * treeID + 2] = lazya[2 * treeID + 2] * lazym[treeID] % MOD;
                lazya[2 * treeID + 2] = (lazya[2 * treeID + 2] + lazya[treeID]) % MOD;
            }
            lazya[treeID] = 0;
            lazym[treeID] = 1;
        }

        if (treeL > uR || treeR < uL) return;

        if(uL <= treeL && uR >= treeR){
            // 区间更新
            tree[treeID] = (tree[treeID] + tree[treeID] * (mul - 1) + (treeR - treeL + 1) * inc) % MOD;

            // 懒更新
            if(treeL != treeR){
                lazym[2 * treeID + 1] = lazym[2 * treeID + 1] * mul % MOD;
                lazya[2 * treeID + 1] = lazya[2 * treeID + 1] * mul % MOD;
                lazya[2 * treeID + 1] = (lazya[2 * treeID + 1] + inc) % MOD;

                lazym[2 * treeID + 2] = lazym[2 * treeID + 2] * mul % MOD;
                lazya[2 * treeID + 2] = lazya[2 * treeID + 2] * mul % MOD;
                lazya[2 * treeID + 2] = (lazya[2 * treeID + 2] + inc) % MOD;
            }
            return;
        }

        int mid = (treeL + treeR) / 2;
        update(2 * treeID + 1, treeL, mid, uL, uR, inc, mul);
        update(2 * treeID + 2, mid + 1, treeR, uL, uR, inc, mul);
        tree[treeID] = (tree[treeID * 2 + 1] + tree[treeID * 2 + 2]) % MOD;
        return;
    }

    int query(int treeID, int treeL, int treeR, int index){

        if(lazya[treeID] != 0 || lazym[treeID] != 1){
            // 区间更新
            tree[treeID] = (tree[treeID] * lazym[treeID] + (treeR - treeL + 1)* lazya[treeID]) % MOD;

            // 懒更新
            if(treeL != treeR){
                lazym[2 * treeID + 1] = lazym[2 * treeID + 1] * lazym[treeID] % MOD;
                lazya[2 * treeID + 1] = lazya[2 * treeID + 1] * lazym[treeID] % MOD;
                lazya[2 * treeID + 1] = (lazya[2 * treeID + 1] + lazya[treeID]) % MOD;

                lazym[2 * treeID + 2] = lazym[2 * treeID + 2] * lazym[treeID] % MOD;
                lazya[2 * treeID + 2] = lazya[2 * treeID + 2] * lazym[treeID] % MOD;
                lazya[2 * treeID + 2] = (lazya[2 * treeID + 2] + lazya[treeID]) % MOD;
            }
            lazya[treeID] = 0;
            lazym[treeID] = 1;
        }

        if(treeL == treeR) return tree[treeID];

        int mid = (treeL + treeR) / 2;
        if(index <= mid) return query(2 * treeID + 1, treeL, mid, index);
        return query(2 * treeID + 2, mid + 1, treeR, index);
    }
};

有了这样一个线段树,题目中要求的 Fancy 类是非常简单的:

class Fancy {

private:
    SegmentTree tree;
    int len = 0;

public:
    Fancy() : tree(1e5 + 1){}

    void append(int val) {
        tree.add(len ++, val);
    }

    void addAll(int inc) {
        tree.update_add(0, len - 1, inc);
    }

    void multAll(int mul) {
        tree.update_mul(0, len - 1, mul);
    }

    int getIndex(int idx) {
        if(idx >= 0 && idx < len)
            return tree.query(idx);
        return -1;
    }
};

对于这个方法,所有的操作都是 O(logn) 的。

依然是,这个方法的优点是,如果问题要求在指定的某个区间做乘法或者加法,即支持的是 addRange(l, r, inc)multRange(l, r, m) 两种操作的话,这个解法也能轻松应对。


觉得有帮助请点赞哇!

奇妙序列

预备知识:乘法逆元

设模数为 $m$,整数 $a$ 在模 $m$ 的意义下存在乘法逆元整数 $a^{-1}~(0 < a^{-1} < m)$,当且仅当

$$
a a^{-1} \equiv 1 ~ (\bmod ~ m)
$$

成立。根据上式可得

$$
aa^{-1} = km + 1, \quad k \in \mathbb{N}
$$

整理得

$$
a^{-1} \cdot a - k \cdot m = 1
$$

当 $m$ 为质数时,根据「裴蜀定理」,$\text{gcd}(a, m) = 1$,因此必存在整数 $a^{-1}$ 和 $k$ 使得上式成立。

如果 $(a^{-1}_0, k_0)$ 是一组解,那么

$$
(a^{-1}_0 + cm, k_0 + ca), \quad c \in \mathbb{Z}
$$

都是上式的解。因此必然存在一组解中的整数 $a^{-1}$ 满足 $0 < a^{-1} < m$。

那么如何求出 $a^{-1}$ 呢?一种简单的方法是使用「费马小定理」,即

$$
a^{m-1} \equiv 1 ~ (\bmod ~ m)
$$

那么有

$$
a^{m-1} a^{-1} \equiv a^{-1} ~ (\bmod ~ m)
$$

$$
a^{m-2} a a^{-1} \equiv a^{-1} ~ (\bmod ~ m)
$$

$$
a^{-1} \equiv a^{m-2} ~ (\bmod ~ m)
$$

使用「乘法逆元」有什么好处呢?如果我们要求 $\dfrac{c}{a}$ 对 $m$ 取模的结果,那么我们可以化除法为乘法,即

$$
\frac{c}{a} \equiv c \cdot a^{-1} ~ (\bmod ~ m)
$$

而 $a^{-1}$ 就等于 $a^{m-2}$ 对 $m$ 取模的结果,后者使用快速幂即可,时间复杂度为 $O(\log m)$,可以参考 50. Pow(x, n) 的官方题解

方法一:将 getIndex 作为瓶颈

思路与算法

我们可以将所有的 addAll 操作以及 multAll 操作「浓缩」成一次操作 $(a, b)$,表示将任意整数 $x$ 变为 $ax+b$:

  • 初始时 $(a, b) = (1, 0)$;

  • 遇到 addAll(inc) 操作时,将 $b$ 增加 $\textit{inc}$;

  • 遇到 multAll(m) 操作时,将 $a$ 和 $b$ 都乘上 $m$。

我们记 $v$ 为原始序列(也就是保存了每个 append(val)val 的原始值的序列),$(a_i, b_i)$ 表示在 $v_i$ 被加入 $v$ 中之前,所有进行的操作「浓缩」后的结果。特别地,$(a_0, b_0) = (1, 0)$。当我们遇到 getIndex(idx) 时,考虑 $(a_\textit{idx}, b_\textit{idx})$ 以及 $v$ 中最后一个元素的 $(a, b)$:

  • 在 $v_\textit{idx}$ 被放入 $v$ 之前,操作为 $(a_\textit{idx}, b_\textit{idx})$;

  • 在 $v_\textit{idx}$ 被放入 $v$ 之后到目前为止,操作为 $(a, b)$。

因此,对 $v_\textit{idx}$ 进行的操作,就等价于可以将 $(a_\textit{idx}, b_\textit{idx})$ 变成 $(a, b)$ 的操作,记为 $(a_o, b_o)$。具体地:

$$
\begin{cases}
a_\textit{idx} \cdot a_o \equiv a ~ (\bmod ~ m) \
b_\textit{idx} \cdot a_o + b_o \equiv b ~ (\bmod ~ m)
\end{cases}
$$

这里 $m$ 为质数 $10^9+7$。其实就是对于任意的 $x$,有

$$
a_o ( a_\textit{idx} \cdot x + b_\textit{idx} ) + b_o = ax + b
$$

根据「预备知识」,可以通过乘法逆元得到 $a_o$

$$
a_o \equiv a_\textit{idx}^{-1} \cdot a ~ (\bmod ~ m)
$$

将其带入也可以得到 $b_o$

$$
b_o \equiv b - b_\textit{idx} \cdot a_o ~ (\bmod ~ m)
$$

这样返回 $a_o \cdot v_\textit{idx} + b_o$ 对 $m$ 取模的结果即可。

代码

###C++

class Fancy {
private:
    static constexpr int mod = 1000000007;
    vector<int> v, a, b;
    
public:
    Fancy() {
        a.push_back(1);
        b.push_back(0);
    }
    
    // 快速幂
    int quickmul(int x, int y) {
        int ret = 1;
        int cur = x;
        while (y) {
            if (y & 1) {
                ret = (long long)ret * cur % mod;
            }
            cur = (long long)cur * cur % mod;
            y >>= 1;
        }
        return ret;
    }
    
    // 乘法逆元
    int inv(int x) {
        return quickmul(x, mod - 2);
    }
    
    void append(int val) {
        v.push_back(val);
        a.push_back(a.back());
        b.push_back(b.back());
    }
    
    void addAll(int inc) {
        b.back() = (b.back() + inc) % mod;
    }
    
    void multAll(int m) {
        a.back() = (long long)a.back() * m % mod;
        b.back() = (long long)b.back() * m % mod;
    }
    
    int getIndex(int idx) {
        if (idx >= v.size()) {
            return -1;
        }
        int ao = (long long)inv(a[idx]) * a.back() % mod;
        int bo = (b.back() - (long long)b[idx] * ao % mod + mod) % mod;
        int ans = ((long long)ao * v[idx] % mod + bo) % mod;
        return ans;
    }
};

复杂度分析

  • 时间复杂度:getIndex(idx) 为 $O(\log m)$,其余均为 $O(1)$。

  • 空间复杂度:$O(n)$。

方法二:将 append 作为瓶颈

思路与算法

我们也可以将求解逆元的步骤放在 append(val) 操作时。

当我们将 $v_i$ 加入 $v$ 中时,如果目前为止的操作为 $(a, b)$,那么我们可以将 $\dfrac{v-b}{a}$ 代替 $v$ 加入 $v_i$。这样在任意时刻,所有元素都可以看作进行了相同的操作

根据「预备知识」,可以通过乘法逆元,计算 $(v-b) \cdot a^{-1}$ 来等价 $\dfrac{v-b}{a}$。

代码

###C++

class Fancy {
private:
    static constexpr int mod = 1000000007;
    vector<int> v;
    int a, b;
    
public:
    Fancy(): a(1), b(0) {}
    
    // 快速幂
    int quickmul(int x, int y) {
        int ret = 1;
        int cur = x;
        while (y) {
            if (y & 1) {
                ret = (long long)ret * cur % mod;
            }
            cur = (long long)cur * cur % mod;
            y >>= 1;
        }
        return ret;
    }
    
    // 乘法逆元
    int inv(int x) {
        return quickmul(x, mod - 2);
    }
    
    void append(int val) {
        v.push_back((long long)((val - b + mod) % mod) * inv(a) % mod);
    }
    
    void addAll(int inc) {
        b = (b + inc) % mod;
    }
    
    void multAll(int m) {
        a = (long long)a * m % mod;
        b = (long long)b * m % mod;
    }
    
    int getIndex(int idx) {
        if (idx >= v.size()) {
            return -1;
        }
        int ans = ((long long)a * v[idx] % mod + b) % mod;
        return ans;
    }
};

###Python

class Fancy:

    def __init__(self):
        self.mod = 10**9 + 7
        self.v = list()
        self.a = 1
        self.b = 0

    # 快速幂
    def quickmul(self, x: int, y: int) -> int:
        return pow(x, y, self.mod)
    
    # 乘法逆元
    def inv(self, x: int) -> int:
        return self.quickmul(x, self.mod - 2)

    def append(self, val: int) -> None:
        self.v.append((val - self.b) * self.inv(self.a) % self.mod)

    def addAll(self, inc: int) -> None:
        self.b = (self.b + inc) % self.mod

    def multAll(self, m: int) -> None:
        self.a = self.a * m % self.mod
        self.b = self.b * m % self.mod

    def getIndex(self, idx: int) -> int:
        if idx >= len(self.v):
            return -1
        return (self.a * self.v[idx] + self.b) % self.mod

复杂度分析

  • 时间复杂度:append(val) 为 $O(\log m)$,其余均为 $O(1)$。

  • 空间复杂度:$O(n)$。

每日一题-长度为 n 的开心字符串中字典序第 k 小的字符串🟡

一个 「开心字符串」定义为:

  • 仅包含小写字母 ['a', 'b', 'c'].
  • 对所有在 1 到 s.length - 1 之间的 i ,满足 s[i] != s[i + 1] (字符串的下标从 1 开始)。

比方说,字符串 "abc""ac","b" 和 "abcbabcbcb" 都是开心字符串,但是 "aa""baa" 和 "ababbc" 都不是开心字符串。

给你两个整数 n 和 k ,你需要将长度为 n 的所有开心字符串按字典序排序。

请你返回排序后的第 k 个开心字符串,如果长度为 n 的开心字符串少于 k 个,那么请你返回 空字符串 。

 

示例 1:

输入:n = 1, k = 3
输出:"c"
解释:列表 ["a", "b", "c"] 包含了所有长度为 1 的开心字符串。按照字典序排序后第三个字符串为 "c" 。

示例 2:

输入:n = 1, k = 4
输出:""
解释:长度为 1 的开心字符串只有 3 个。

示例 3:

输入:n = 3, k = 9
输出:"cab"
解释:长度为 3 的开心字符串总共有 12 个 ["aba", "abc", "aca", "acb", "bab", "bac", "bca", "bcb", "cab", "cac", "cba", "cbc"] 。第 9 个字符串为 "cab"

示例 4:

输入:n = 2, k = 7
输出:""

示例 5:

输入:n = 10, k = 100
输出:"abacbabacb"

 

提示:

  • 1 <= n <= 10
  • 1 <= k <= 100

 

O(n) 简洁写法(Python/Java/C++/C/Go/JS/Rust)

首先计算有多少个长为 $n$ 的开心字符串。

这是一个计数问题。有 $n$ 个空位,第一个位置可以填 $3$ 种字母,后面的每个位置,都不能和前一个位置的字母相同,所以都只有 $2$ 种填法,因此方案数为

$$
3\cdot 2^{n-1}
$$

如果 $k > 3\cdot 2^{n-1}$,返回空串。否则答案是存在的。


为方便计算,我们先把 $k$ 减一,改成从 $0$ 开始。

以 $n=4$,$k=12=1100_{(2)}$(减一后)为例说明。

答案的第一个字母是什么?

如果第一个字母是 $\texttt{a}$,那么后面的 $n-1=3$ 个位置有 $2^{n-1}=8$ 种填法,不够。

如果第一个字母是 $\texttt{b}$,同样地,后面的 $n-1=3$ 个位置有 $2^{n-1}=8$ 种填法,这就够了,所以答案的第一个字母填 $\texttt{b}$。

现在答案为 $\texttt{b___}$,剩余的三个位置填什么字母?

由于每个位置都有两种填法,这和 $k$ 的二进制可以完美对应,如下表(注意相邻字母不同的要求):

$k$ 的低三位 对应填法
$000$ $\texttt{aba}$
$001$ $\texttt{abc}$
$010$ $\texttt{aca}$
$011$ $\texttt{acb}$
$100$ $\texttt{cab}$
$101$ $\texttt{cac}$
$110$ $\texttt{cba}$
$111$ $\texttt{cbc}$

对于 $k=1100_{(2)}$(减一后)这个例子,答案是这样填的:

  • 答案的第二个字母不能和前一个字母 $\textit{b}$ 相同,只能填 $\texttt{a}$ 或者 $\texttt{c}$,由于 $k$ 这一位是 $1$,所以填 $\texttt{c}$。
  • 答案的第三个字母不能和前一个字母 $\textit{c}$ 相同,只能填 $\texttt{a}$ 或者 $\texttt{b}$,由于 $k$ 这一位是 $0$,所以填 $\texttt{a}$。
  • 答案的第四个字母不能和前一个字母 $\textit{a}$ 相同,只能填 $\texttt{b}$ 或者 $\texttt{c}$,由于 $k$ 这一位是 $0$,所以填 $\texttt{b}$。

所以答案为 $\texttt{bcab}$。

一般地,答案的第一个字母是第 $\left\lfloor\dfrac{k}{2^{n-1}}\right\rfloor$ 个小写字母,随后的字母可以根据 $k\bmod 2^{n-1}$ 二进制从高到低是 $0$ 还是 $1$,填入相应的字母:

  • 首先,如果二进制这一位是 $0$,那么填入 $\texttt{a}$,否则填入 $\texttt{b}$。
  • 然后修正:看填入的字母是否大于等于左侧相邻字母,如果大于等于,那么把填入的字母加一。比如左侧相邻字母是 $\texttt{a}$,那么当前这一位如果填的是 $\texttt{a}$,要变成 $\texttt{b}$;如果填的是 $\texttt{b}$,要变成 $\texttt{c}$。

:这个「加一」的技巧可用于生成两个不同的随机整数,见 961. 在长度 2N 的数组中找出重复 N 次的元素 我的题解的方法四。

###py

class Solution:
    def getHappyString(self, n: int, k: int) -> str:
        if k > 3 << (n - 1):
            return ""
        k -= 1  # 改成从 0 开始,方便计算
        ans = [ord('a')] * n
        ans[0] += k >> (n - 1)
        for i in range(1, n):
            ans[i] += k >> (n - 1 - i) & 1
            if ans[i] >= ans[i - 1]:
                ans[i] += 1
        return ''.join(map(chr, ans))

###java

class Solution {
    public String getHappyString(int n, int k) {
        if (k > 3 << (n - 1)) {
            return "";
        }
        k--; // 改成从 0 开始,方便计算
        char[] ans = new char[n];
        ans[0] = (char) ('a' + (k >> (n - 1)));
        for (int i = 1; i < n; i++) {
            ans[i] = (char) ('a' + (k >> (n - 1 - i) & 1));
            if (ans[i] >= ans[i - 1]) {
                ans[i]++;
            }
        }
        return new String(ans);
    }
}

###cpp

class Solution {
public:
    string getHappyString(int n, int k) {
        if (k > 3 << (n - 1)) {
            return "";
        }
        k--; // 改成从 0 开始,方便计算
        string ans(n, 'a');
        ans[0] += k >> (n - 1);
        for (int i = 1; i < n; i++) {
            ans[i] += k >> (n - 1 - i) & 1;
            if (ans[i] >= ans[i - 1]) {
                ans[i]++;
            }
        }
        return ans;
    }
};

###c

char* getHappyString(int n, int k) {
    if (k > 3 << (n - 1)) {
        return "";
    }
    k--; // 改成从 0 开始,方便计算
    char* ans = malloc((n + 1) * sizeof(char));
    ans[0] = 'a' + (k >> (n - 1));
    for (int i = 1; i < n; i++) {
        ans[i] = 'a' + (k >> (n - 1 - i) & 1);
        if (ans[i] >= ans[i - 1]) {
            ans[i]++;
        }
    }
    ans[n] = '\0';
    return ans;
}

###go

func getHappyString(n, k int) string {
if k > 3<<(n-1) {
return ""
}
k-- // 改成从 0 开始,方便计算
ans := make([]byte, n)
ans[0] = 'a' + byte(k>>(n-1))
for i := 1; i < n; i++ {
ans[i] = 'a' + byte(k>>(n-1-i)&1)
if ans[i] >= ans[i-1] {
ans[i]++
}
}
return string(ans)
}

###js

var getHappyString = function(n, k) {
    if (k > 3 << (n - 1)) {
        return "";
    }
    k--; // 改成从 0 开始,方便计算
    const ans = Array(n).fill('a'.charCodeAt(0));
    ans[0] += k >> (n - 1);
    for (let i = 1; i < n; i++) {
        ans[i] += k >> (n - 1 - i) & 1;
        if (ans[i] >= ans[i - 1]) {
            ans[i]++;
        }
    }
    return ans.map(c => String.fromCharCode(c)).join('');
};

###rust

impl Solution {
    pub fn get_happy_string(n: i32, mut k: i32) -> String {
        if k > 3 << (n - 1) {
            return String::new();
        }
        k -= 1; // 改成从 0 开始,方便计算
        let n = n as usize;
        let mut ans = vec![0; n];
        ans[0] = b'a' + (k >> (n - 1)) as u8;
        for i in 1..n {
            ans[i] = b'a' + (k >> (n - 1 - i) & 1) as u8;
            if ans[i] >= ans[i - 1] {
                ans[i] += 1;
            }
        }
        unsafe { String::from_utf8_unchecked(ans) }
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$。
  • 空间复杂度:$\mathcal{O}(1)$。返回值不计入。

相似题目

60. 排列序列

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

长度为 n 的开心字符串中字典序第 k 小的字符串(看图写作)

由题:按字典序画出如下树
未命名文件 (11).png
由图知,以a开头长度为n的字符串共有$2^{n-1}$种,b,c开头的也是如此,则总计$3*2^{n-1}$种(3棵n层的满二叉树)
若对第n层曲所有树叶结点依次进行编号,从0开始,则第k大的字符串对应序号k-1的节点
由a,b,c开头的字符串均有$2^{n-1}$种,且依次排序
则序号$order/(2^{n-1})$,可以得到其位于哪棵树
则序号$order$%$(2^{n-1})$,可以得到其位于该棵树的第几号节点
得到其在当前树(满二叉树)的位置,则可通过该树分叉出的含义(二进制,0取较小的字符,1去较大的字符)

例如:n=3,k=7

序号6
每棵树叶子结点数:$2^{3-1}=4$
位于6/4=1,位于b开头的树
在b这棵树叶子节点中为6%4=2号节点(0开始)
2对应(n-1)位的二进制为0b10
则字符串为bca

b(起始结点)
->c(上一字符b,当前可选字符a,c。1表示往较大的一个字符转移)
->a(上一字符c,当前可选字符a,b。0表示往较小的一个字符转移)

###java

class Solution {
    public String getHappyString(int n, int k) {
        int count=3<<(n-1);//3*2^(n-1)
        if(k>count) return "";
        char[] result=new char[n];
        int[][] stateTab=new int[][]{{1,2},{0,2},{0,1}};
        //状态0,1,2分别表示a,b,c,
        //转移参数:0表示下一个取较小的字符,1表示下一个取较大的字符
        int order=k-1;//序号k-1,表示第k大
        int index=0 ,state= order>>(n-1);// <=>order/2^(n-1)
        result[index++]=(char)(state+'a');
        int tree=order&((1<<(n-1))-1);//获取其在树中的位置<=>order% 2^(n-1)
        for(int i=n-2;i>=0;i--){
            state=stateTab[state][(tree >> i) & 1];//(tree >> i) & 1取二进制第i位
            result[index++]=(char)(state+'a');
        }
        return String.valueOf(result);
    }
}

简单代码双100解题思想,可扩展K值到亿规模

  1. 该题与字典序第n个10进制数思想一样
  2. 整个题解空间就是根节点包含3个子节点,其他非叶节点包含两个子节点(因为相邻的字母不相同)
  3. 按照第2步,不断判断当前已确定前缀下,子树的节点数目同目标值大小比较,寻找所在的子树
  4. 该方法可以适配解决比如总字符长度到30,k值到1个亿的取值范围
    public String getHappyString(int n, int k) {
        if (total(n) < k) return "";
        char[] result = new char[n];
        int idx = 0;
        while(idx < n) {
            char pre = idx == 0? '0' : result[idx-1];
            result[idx] = pre == 'a' ? 'b' : 'a';
            int len = 1 << (n - idx - 1);
            while(k > len) {
                result[idx] = (char)(result[idx]+1);
                if (result[idx] == pre) {
                    result[idx] = (char)(result[idx]+1);
                }
                k-=len;
            }
            ++idx;
        }
        return new String(result);
    }


    int total(int n) {
        return 3 * (1 << (n-1));
    }

移山所需的最少秒数

方法一:二分答案

思路与算法

根据题目描述,如果 $t$ 秒可以使山的高度降低到 $0$,那么任何大于 $t$ 的秒数也可以。因此答案具有单调性,我们可以使用二分查找来解决本题。

对于二分查找的每一步,假设当前猜测的秒数为 $\textit{mid}$,我们需要判断所有工人在 $\textit{mid}$ 秒内能否将山的高度降低 $H = \textit{mountainHeight}$。对于第 $i$ 个工人,他将山的高度降低 $k$ 所需的时间为:

$$
\textit{workerTimes}[i] \cdot (1 + 2 + \cdots + k) = \textit{workerTimes}[i] \cdot \frac{k(k+1)}{2}
$$

因此在 $\textit{mid}$ 秒内,第 $i$ 个工人 $i$ 最多能将山降低的高度,是满足

$$\textit{workerTimes}[i] \cdot \frac{k(k+1)}{2} \leq \textit{mid}$$

的最大正整数 $k$。

令 $\textit{work} = \lfloor \dfrac{\textit{mid}}{\textit{workerTimes}[i]} \rfloor$,其中 $\lfloor \cdot \rfloor$ 表示下取整,则需要满足 $\dfrac{k(k+1)}{2} \leq \textit{work}$,利用求一元二次方程求根公式可得:

$$
k = \left\lfloor \frac{-1 + \sqrt{1 + 8 \cdot \textit{work}}}{2} \right\rfloor
$$

我们将所有工人计算得到的 $k$ 值相加,如果总和大于等于 $H$,则说明 $\textit{mid}$ 秒可以完成任务,应当尝试更少的时间,否则尝试更多的时间。

二分查找的下界为 $1$,上界为 $\max(\textit{workerTimes}) \cdot \dfrac{H(H + 1)}{2}$,即最慢的工人独自完成所有工作所需的时间。

代码

###C++

class Solution {
public:
    long long minNumberOfSeconds(int mountainHeight, vector<int>& workerTimes) {
        int maxWorkerTimes = *max_element(workerTimes.begin(), workerTimes.end());
        long long l = 1, r = static_cast<long long>(maxWorkerTimes) * mountainHeight * (mountainHeight + 1) / 2;
        long long ans = 0;

        while (l <= r) {
            long long mid = (l + r) / 2;
            long long cnt = 0;
            for (int t: workerTimes) {
                long long work = mid / t;
                // 求最大的 k 满足 1+2+...+k <= work
                long long k = (-1.0 + sqrt(1 + work * 8)) / 2 + eps;
                cnt += k;
            }
            if (cnt >= mountainHeight) {
                ans = mid;
                r = mid - 1;
            }
            else {
                l = mid + 1;
            }
        }

        return ans;
    }

private:
    static constexpr double eps = 1e-7;
};

###Python

class Solution:
    def minNumberOfSeconds(self, mountainHeight: int, workerTimes: List[int]) -> int:
        maxWorkerTimes = max(workerTimes)
        l, r, ans = 1, maxWorkerTimes * mountainHeight * (mountainHeight + 1) // 2, 0
        eps = 1e-7
        
        while l <= r:
            mid = (l + r) // 2
            cnt = 0
            for t in workerTimes:
                work = mid // t
                # 求最大的 k 满足 1+2+...+k <= work
                k = int((-1 + ((1 + work * 8) ** 0.5)) / 2 + eps)
                cnt += k
            if cnt >= mountainHeight:
                ans = mid
                r = mid - 1
            else:
                l = mid + 1

        return ans

###Java

class Solution {
    private static final double EPS = 1e-7;
    
    public long minNumberOfSeconds(int mountainHeight, int[] workerTimes) {
        int maxWorkerTimes = 0;
        for (int t : workerTimes) {
            maxWorkerTimes = Math.max(maxWorkerTimes, t);
        }
        
        long l = 1;
        long r = (long) maxWorkerTimes * mountainHeight * (mountainHeight + 1) / 2;
        long ans = 0;
        
        while (l <= r) {
            long mid = (l + r) / 2;
            long cnt = 0;
            for (int t : workerTimes) {
                long work = mid / t;
                // 求最大的 k 满足 1+2+...+k <= work
                long k = (long)((-1.0 + Math.sqrt(1 + work * 8)) / 2 + EPS);
                cnt += k;
            }
            
            if (cnt >= mountainHeight) {
                ans = mid;
                r = mid - 1;
            } else {
                l = mid + 1;
            }
        }
        
        return ans;
    }
}

###C#

class Solution {
    private const double EPS = 1e-7;
    
    public long MinNumberOfSeconds(int mountainHeight, int[] workerTimes) {
        int maxWorkerTimes = 0;
        foreach (int t in workerTimes) {
            maxWorkerTimes = Math.Max(maxWorkerTimes, t);
        }
        
        long l = 1;
        long r = (long)maxWorkerTimes * mountainHeight * (mountainHeight + 1) / 2;
        long ans = 0;
        
        while (l <= r) {
            long mid = (l + r) / 2;
            long cnt = 0;
            
            foreach (int t in workerTimes) {
                long work = mid / t;
                // 求最大的 k 满足 1+2+...+k <= work
                long k = (long)((-1.0 + Math.Sqrt(1 + work * 8)) / 2 + EPS);
                cnt += k;
            }
            
            if (cnt >= mountainHeight) {
                ans = mid;
                r = mid - 1;
            } else {
                l = mid + 1;
            }
        }
        
        return ans;
    }
}

###Go

const eps = 1e-7

func minNumberOfSeconds(mountainHeight int, workerTimes []int) int64 {
    maxWorkerTimes := 0
    for _, t := range workerTimes {
        if t > maxWorkerTimes {
            maxWorkerTimes = t
        }
    }
    
    l := int64(1)
    r := int64(maxWorkerTimes) * int64(mountainHeight) * int64(mountainHeight + 1) / 2
    var ans int64 = 0
    
    for l <= r {
        mid := (l + r) / 2
        var cnt int64 = 0
        
        for _, t := range workerTimes {
            work := mid / int64(t)
            // 求最大的 k 满足 1+2+...+k <= work
            k := int64((-1.0 + math.Sqrt(1 + float64(work) * 8)) / 2 + eps)
            cnt += k
        }
        if cnt >= int64(mountainHeight) {
            ans = mid
            r = mid - 1
        } else {
            l = mid + 1
        }
    }
    
    return ans
}

###C

#define EPS 1e-7

long long minNumberOfSeconds(int mountainHeight, int* workerTimes, int workerTimesSize) {
    int maxWorkerTimes = 0;
    for (int i = 0; i < workerTimesSize; i++) {
        if (workerTimes[i] > maxWorkerTimes) {
            maxWorkerTimes = workerTimes[i];
        }
    }
    
    long long l = 1;
    long long r = (long long)maxWorkerTimes * mountainHeight * (mountainHeight + 1) / 2;
    long long ans = 0;
    
    while (l <= r) {
        long long mid = (l + r) / 2;
        long long cnt = 0;
        for (int i = 0; i < workerTimesSize; i++) {
            long long work = mid / workerTimes[i];
            // 求最大的 k 满足 1+2+...+k <= work
            long long k = (long long)((-1.0 + sqrt(1 + work * 8)) / 2 + EPS);
            cnt += k;
        }
        
        if (cnt >= mountainHeight) {
            ans = mid;
            r = mid - 1;
        } else {
            l = mid + 1;
        }
    }
    
    return ans;
}

###JavaScript

const EPS = 1e-7;

var minNumberOfSeconds = function(mountainHeight, workerTimes) {
    const maxWorkerTimes = Math.max(...workerTimes);
    let l = 1;
    let r = maxWorkerTimes * mountainHeight * (mountainHeight + 1) / 2;
    let ans = 0;
    
    while (l <= r) {
        const mid = Math.floor((l + r) / 2);
        let cnt = 0;
        for (const t of workerTimes) {
            const work = Math.floor(mid / t);
            // 求最大的 k 满足 1+2+...+k <= work
            const k = Math.floor((-1.0 + Math.sqrt(1 + work * 8)) / 2 + EPS);
            cnt += k;
        }
        
        if (cnt >= mountainHeight) {
            ans = mid;
            r = mid - 1;
        } else {
            l = mid + 1;
        }
    }
    
    return ans;
}

###TypeScript

const EPS: number = 1e-7;

function minNumberOfSeconds(mountainHeight: number, workerTimes: number[]): number {
    const maxWorkerTimes: number = Math.max(...workerTimes);
    let l: number = 1;
    let r: number = maxWorkerTimes * mountainHeight * (mountainHeight + 1) / 2;
    let ans: number = 0;
    
    while (l <= r) {
        const mid: number = Math.floor((l + r) / 2);
        let cnt: number = 0;
        for (const t of workerTimes) {
            const work: number = Math.floor(mid / t);
            // 求最大的 k 满足 1+2+...+k <= work
            const k: number = Math.floor((-1.0 + Math.sqrt(1 + work * 8)) / 2 + EPS);
            cnt += k;
        }
        
        if (cnt >= mountainHeight) {
            ans = mid;
            r = mid - 1;
        } else {
            l = mid + 1;
        }
    }
    
    return ans;
}

###Rust

const EPS: f64 = 1e-7;

impl Solution {
    pub fn min_number_of_seconds(mountain_height: i32, worker_times: Vec<i32>) -> i64 {
        let mountain_height = mountain_height as i64;
        let max_worker_times = *worker_times.iter().max().unwrap_or(&0) as i64;
        
        let mut l: i64 = 1;
        let mut r: i64 = max_worker_times * mountain_height * (mountain_height + 1) / 2;
        let mut ans: i64 = 0;
        
        while l <= r {
            let mid = (l + r) / 2;
            let mut cnt: i64 = 0;
            
            for &t in &worker_times {
                let work = mid / t as i64;
                // 求最大的 k 满足 1+2+...+k <= work
                let k = ((-1.0 + (1.0 + (work as f64) * 8.0).sqrt()) / 2.0 + EPS) as i64;
                cnt += k;
            }
            
            if cnt >= mountain_height {
                ans = mid;
                r = mid - 1;
            } else {
                l = mid + 1;
            }
        }
        
        ans
    }
}

复杂度分析

  • 时间复杂度:$O(n \log(MH^2))$,其中 $n$ 是数组 $\textit{workerTimes}$ 的长度,$M$ 是数组 $\textit{workerTimes}$ 中的最大值,$H$ 是 $\textit{mountainHeight}$。二分查找需要 $O(\log(MH^2))$ 次迭代,每次迭代遍历所有工人,需要 $O(n)$ 的时间。

  • 空间复杂度:$O(1)$。

每日一题-移山所需的最少秒数🟡

给你一个整数 mountainHeight 表示山的高度。

同时给你一个整数数组 workerTimes,表示工人们的工作时间(单位:)。

工人们需要 同时 进行工作以 降低 山的高度。对于工人 i :

  • 山的高度降低 x,需要花费 workerTimes[i] + workerTimes[i] * 2 + ... + workerTimes[i] * x 秒。例如:
    • 山的高度降低 1,需要 workerTimes[i] 秒。
    • 山的高度降低 2,需要 workerTimes[i] + workerTimes[i] * 2 秒,依此类推。

返回一个整数,表示工人们使山的高度降低到 0 所需的 最少 秒数。

 

示例 1:

输入: mountainHeight = 4, workerTimes = [2,1,1]

输出: 3

解释:

将山的高度降低到 0 的一种方式是:

  • 工人 0 将高度降低 1,花费 workerTimes[0] = 2 秒。
  • 工人 1 将高度降低 2,花费 workerTimes[1] + workerTimes[1] * 2 = 3 秒。
  • 工人 2 将高度降低 1,花费 workerTimes[2] = 1 秒。

因为工人同时工作,所需的最少时间为 max(2, 3, 1) = 3 秒。

示例 2:

输入: mountainHeight = 10, workerTimes = [3,2,2,4]

输出: 12

解释:

  • 工人 0 将高度降低 2,花费 workerTimes[0] + workerTimes[0] * 2 = 9 秒。
  • 工人 1 将高度降低 3,花费 workerTimes[1] + workerTimes[1] * 2 + workerTimes[1] * 3 = 12 秒。
  • 工人 2 将高度降低 3,花费 workerTimes[2] + workerTimes[2] * 2 + workerTimes[2] * 3 = 12 秒。
  • 工人 3 将高度降低 2,花费 workerTimes[3] + workerTimes[3] * 2 = 12 秒。

所需的最少时间为 max(9, 12, 12, 12) = 12 秒。

示例 3:

输入: mountainHeight = 5, workerTimes = [1]

输出: 15

解释:

这个示例中只有一个工人,所以答案是 workerTimes[0] + workerTimes[0] * 2 + workerTimes[0] * 3 + workerTimes[0] * 4 + workerTimes[0] * 5 = 15 秒。

 

提示:

  • 1 <= mountainHeight <= 105
  • 1 <= workerTimes.length <= 104
  • 1 <= workerTimes[i] <= 106
❌