阅读视图

发现新文章,点击刷新页面。

每日一题-子矩阵元素加 1🟡

给你一个正整数 n ,表示最初有一个 n x n 、下标从 0 开始的整数矩阵 mat ,矩阵中填满了 0 。

另给你一个二维整数数组 query 。针对每个查询 query[i] = [row1i, col1i, row2i, col2i] ,请你执行下述操作:

  • 找出 左上角(row1i, col1i)右下角(row2i, col2i) 的子矩阵,将子矩阵中的 每个元素1 。也就是给所有满足 row1i <= x <= row2icol1i <= y <= col2imat[x][y]1

返回执行完所有操作后得到的矩阵 mat

 

示例 1:

输入:n = 3, queries = [[1,1,2,2],[0,0,1,1]]
输出:[[1,1,0],[1,2,1],[0,1,1]]
解释:上图所展示的分别是:初始矩阵、执行完第一个操作后的矩阵、执行完第二个操作后的矩阵。
- 第一个操作:将左上角为 (1, 1) 且右下角为 (2, 2) 的子矩阵中的每个元素加 1 。 
- 第二个操作:将左上角为 (0, 0) 且右下角为 (1, 1) 的子矩阵中的每个元素加 1 。 

示例 2:

输入:n = 2, queries = [[0,0,1,1]]
输出:[[1,1],[1,1]]
解释:上图所展示的分别是:初始矩阵、执行完第一个操作后的矩阵。 
- 第一个操作:将矩阵中的每个元素加 1 。

 

提示:

  • 1 <= n <= 500
  • 1 <= queries.length <= 104
  • 0 <= row1i <= row2i < n
  • 0 <= col1i <= col2i < n

二维差分(+图解)

这个题可以用标准的二维差分来做:
对所有的查询,首先维护二维差分数组;然后对差分数组求前缀和即为答案。

如果不熟悉二维差分,可以参考我的这篇 题解。下面的说明摘自我之前的题解。

如果将矩阵的第 $(i,j)$ 个单元格中的值增加 $1$,那么,若对矩阵求二维前缀和,那么下图 $(a)$ 中的黄色区域的值都会增加 $1$。

如果要将矩阵中的 任意 矩形区域(如下图中 $(b)$ 的蓝色区域)的值增加 $1$ 呢?只需按照下图 $(c)$ 来修改矩阵即可。修改后,若对矩阵求前缀和,那么,只会有蓝色的区域的值 $+1$,其它区域的值都不变。

image.png

###c++

class Solution {
public:
    vector<vector<int>> rangeAddQueries(int n, vector<vector<int>>& queries) {
        vector<vector<int>> diff(n + 1, vector<int>(n + 1, 0));
        vector<vector<int>> ret(n, vector<int>(n, 0));
        for(const auto& q : queries) {
            diff[q[0]][q[1]]++;
            diff[q[0]][q[3]+1]--;
            diff[q[2]+1][q[1]]--;
            diff[q[2]+1][q[3]+1]++;
        }
        for(int i = 0; i < n; ++i)
            for(int j = 1; j < n; ++j) diff[i][j] += diff[i][j-1];
        for(int i = 1; i < n; ++i)
            for(int j = 0; j < n; ++j) diff[i][j] += diff[i-1][j];
        for(int i = 0; i < n; ++i)
            for(int j = 0; j < n; ++j) ret[i][j] = diff[i][j];
        return ret;
    }
};

二维树状数组Py/C++/C/Java/JS/Go/Rust/Swift/Kotlin/TS/C#

二维树状数组

根据题意,我们很自然的想到,需要一个支持在二维数组中进行区间更新和单点查询的数据结构。
因此可以使用二维树状数组求解。
注:本题使用二维差分的做法更优,请参考灵神的题解。

P.S. 抱歉之前的描述有误,我用的二维树状数组而不是线段树。

代码

###Python3

#二维树状数组,维护区域和
class SegmentTree2D:
    def __init__(self, n, m):
        self.n = n
        self.m = m
        self.tree = [[0] * (m + 1) for _ in range(n + 1)]

    def lowbit(self, x):
        return x & (-x)

    def update(self, x, y, val):
        while x <= self.n:
            y1 = y
            while y1 <= self.m:
                self.tree[x][y1] += val
                y1 += self.lowbit(y1)
            x += self.lowbit(x)

    def query(self, x, y):
        res = 0
        while x > 0:
            y1 = y
            while y1 > 0:
                res += self.tree[x][y1]
                y1 -= self.lowbit(y1)
            x -= self.lowbit(x)
        return res


class Solution:
    def rangeAddQueries(self, n: int, queries: List[List[int]]) -> List[List[int]]:
        seg = SegmentTree2D(n, n)
        for x1, y1, x2, y2 in queries:
            seg.update(x1 + 1, y1 + 1, 1)
            seg.update(x2 + 2, y1 + 1, -1)
            seg.update(x1 + 1, y2 + 2, -1)
            seg.update(x2 + 2, y2 + 2, 1)
        res = [[0] * n for _ in range(n)]
        for i in range(n):
            for j in range(n):
                res[i][j] = seg.query(i + 1, j + 1)
        return res

###C++

#include<bits/stdc++.h>

using namespace std;

class SegmentTree2D{
public:
    int n, m;
    vector<vector<int>> tree;

    SegmentTree2D(int n, int m): n(n), m(m), tree(n + 1, vector<int>(m + 1, 0)){}

    int lowbit(int x){
        return x & (-x);
    }

    void update(int x, int y, int val){
        while(x <= n){
            int y1 = y;
            while(y1 <= m){
                tree[x][y1] += val;
                y1 += lowbit(y1);
            }
            x += lowbit(x);
        }
    }

    int query(int x, int y){
        int res = 0;
        while(x > 0){
            int y1 = y;
            while(y1 > 0){
                res += tree[x][y1];
                y1 -= lowbit(y1);
            }
            x -= lowbit(x);
        }
        return res;
    }
};

class Solution {
public:
    vector<vector<int>> rangeAddQueries(int n, vector<vector<int>>& queries) {
        SegmentTree2D seg(n, n);
        for(auto& q: queries){
            seg.update(q[0] + 1, q[1] + 1, 1);
            seg.update(q[2] + 2, q[1] + 1, -1);
            seg.update(q[0] + 1, q[3] + 2, -1);
            seg.update(q[2] + 2, q[3] + 2, 1);
        }
        vector<vector<int>> res(n, vector<int>(n, 0));
        for(int i = 0; i < n; i++){
            for(int j = 0; j < n; j++){
                res[i][j] = seg.query(i + 1, j + 1);
            }
        }
        return res;
    }
};

###C

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

typedef struct SegmentTree2D{
    int n, m;
    int **tree;
} SegmentTree2D;

int lowbit(int x){
    return x & (-x);
}

void update(SegmentTree2D *seg, int x, int y, int val){
    while(x <= seg->n){
        int y1 = y;
        while(y1 <= seg->m){
            seg->tree[x][y1] += val;
            y1 += lowbit(y1);
        }
        x += lowbit(x);
    }
}

int query(SegmentTree2D *seg, int x, int y){
    int res = 0;
    while(x > 0){
        int y1 = y;
        while(y1 > 0){
            res += seg->tree[x][y1];
            y1 -= lowbit(y1);
        }
        x -= lowbit(x);
    }
    return res;
}

/**
 * Return an array of arrays of size *returnSize.
 * The sizes of the arrays are returned as *returnColumnSizes array.
 * Note: Both returned array and *columnSizes array must be malloced, assume caller calls free().
 */
int** rangeAddQueries(int n, int** queries, int queriesSize, int* queriesColSize, int* returnSize, int** returnColumnSizes){
    SegmentTree2D *seg = (SegmentTree2D *)malloc(sizeof(SegmentTree2D));
    seg->n = n;
    seg->m = n;
    seg->tree = (int **)malloc(sizeof(int *) * (n + 1));
    for(int i = 0; i <= n; i++){
        seg->tree[i] = (int *)malloc(sizeof(int) * (n + 1));
        memset(seg->tree[i], 0, sizeof(int) * (n + 1));
    }
    for(int i = 0; i < queriesSize; i++){
        update(seg, queries[i][0] + 1, queries[i][1] + 1, 1);
        update(seg, queries[i][2] + 2, queries[i][1] + 1, -1);
        update(seg, queries[i][0] + 1, queries[i][3] + 2, -1);
        update(seg, queries[i][2] + 2, queries[i][3] + 2, 1);
    }
    int **res = (int **)malloc(sizeof(int *) * n);
    for(int i = 0; i < n; i++){
        res[i] = (int *)malloc(sizeof(int) * n);
        for(int j = 0; j < n; j++){
            res[i][j] = query(seg, i + 1, j + 1);
        }
    }
    *returnSize = n;
    *returnColumnSizes = (int *)malloc(sizeof(int) * n);
    for(int i = 0; i < n; i++){
        (*returnColumnSizes)[i] = n;
    }
    return res;
}

###Java

//二维树状数组
class SegmentTree2D{
    int n, m;
    int[][] tree;

    public SegmentTree2D(int n, int m){
        this.n = n;
        this.m = m;
        tree = new int[n + 1][m + 1];
    }

    int lowbit(int x){
        return x & (-x);
    }

    void update(int x, int y, int val){
        while(x <= n){
            int y1 = y;
            while(y1 <= m){
                tree[x][y1] += val;
                y1 += lowbit(y1);
            }
            x += lowbit(x);
        }
    }

    int query(int x, int y){
        int res = 0;
        while(x > 0){
            int y1 = y;
            while(y1 > 0){
                res += tree[x][y1];
                y1 -= lowbit(y1);
            }
            x -= lowbit(x);
        }
        return res;
    }
};

class Solution {
    public int[][] rangeAddQueries(int n, int[][] queries) {
        SegmentTree2D seg = new SegmentTree2D(n, n);
        for(int[] q: queries){
            seg.update(q[0] + 1, q[1] + 1, 1);
            seg.update(q[2] + 2, q[1] + 1, -1);
            seg.update(q[0] + 1, q[3] + 2, -1);
            seg.update(q[2] + 2, q[3] + 2, 1);
        }
        int[][] res = new int[n][n];
        for(int i = 0; i < n; i++){
            for(int j = 0; j < n; j++){
                res[i][j] = seg.query(i + 1, j + 1);
            }
        }
        return res;
    }
}

###Javascript

// #二维树状数组,维护区域和
class SegmentTree2D {
    constructor(n, m) {
        this.n = n;
        this.m = m;
        this.tree = new Array(n + 1).fill(0).map(() => new Array(m + 1).fill(0));
    }

    lowbit(x) {
        return x & (-x);
    }

    update(x, y, val) {
        while (x <= this.n) {
            let y1 = y;
            while (y1 <= this.m) {
                this.tree[x][y1] += val;
                y1 += this.lowbit(y1);
            }
            x += this.lowbit(x);
        }
    }

    query(x, y) {
        let res = 0;
        while (x > 0) {
            let y1 = y;
            while (y1 > 0) {
                res += this.tree[x][y1];
                y1 -= this.lowbit(y1);
            }
            x -= this.lowbit(x);
        }
        return res;
    }
}

/**
 * @param {number} n
 * @param {number[][]} queries
 * @return {number[][]}
 */
var rangeAddQueries = function(n, queries) {
    let seg = new SegmentTree2D(n, n);
    for (let [x1, y1, x2, y2] of queries) {
        seg.update(x1 + 1, y1 + 1, 1);
        seg.update(x2 + 2, y1 + 1, -1);
        seg.update(x1 + 1, y2 + 2, -1);
        seg.update(x2 + 2, y2 + 2, 1);
    }
    let res = new Array(n).fill(0).map(() => new Array(n).fill(0));
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < n; j++) {
            res[i][j] = seg.query(i + 1, j + 1);
        }
    }
    return res;
};

###Go

package main

type SegmentTree2D struct {
n, m int
tree [][]int
}

func lowbit(x int) int {
return x & (-x)
}

func update(seg *SegmentTree2D, x, y, val int) {
for x <= seg.n {
y1 := y
for y1 <= seg.m {
seg.tree[x][y1] += val
y1 += lowbit(y1)
}
x += lowbit(x)
}
}

func query(seg *SegmentTree2D, x, y int) int {
res := 0
for x > 0 {
y1 := y
for y1 > 0 {
res += seg.tree[x][y1]
y1 -= lowbit(y1)
}
x -= lowbit(x)
}
return res
}

func rangeAddQueries(n int, queries [][]int) [][]int {
seg := &SegmentTree2D{
n:    n,
m:    n,
tree: make([][]int, n+1),
}
for i := 0; i <= n; i++ {
seg.tree[i] = make([]int, n+1)
}
for _, query := range queries {
update(seg, query[0]+1, query[1]+1, 1)
update(seg, query[2]+2, query[1]+1, -1)
update(seg, query[0]+1, query[3]+2, -1)
update(seg, query[2]+2, query[3]+2, 1)
}
res := make([][]int, n)
for i := 0; i < n; i++ {
res[i] = make([]int, n)
for j := 0; j < n; j++ {
res[i][j] = query(seg, i+1, j+1)
}
}
return res
}

###Rust

use std::cmp::min;

struct SegmentTree2D {
    n: usize,
    m: usize,
    tree: Vec<Vec<i32>>,
}

impl SegmentTree2D {
    fn new(n: usize, m: usize) -> Self {
        let mut tree = vec![vec![0; m + 1]; n + 1];
        SegmentTree2D { n, m, tree }
    }

    fn lowbit(&self, x: usize) -> usize {
        x & (!x + 1)
    }

    fn update(&mut self, x: usize, y: usize, val: i32) {
        let mut x = x;
        while x <= self.n {
            let mut y1 = y;
            while y1 <= self.m {
                self.tree[x][y1] += val;
                y1 += self.lowbit(y1);
            }
            x += self.lowbit(x);
        }
    }

    fn query(&self, x: usize, y: usize) -> i32 {
        let mut res = 0;
        let mut x = x;
        while x > 0 {
            let mut y1 = y;
            while y1 > 0 {
                res += self.tree[x][y1];
                y1 -= self.lowbit(y1);
            }
            x -= self.lowbit(x);
        }
        res
    }
}

impl Solution {
    pub fn range_add_queries(n: i32, queries: Vec<Vec<i32>>) -> Vec<Vec<i32>> {
        let mut seg = SegmentTree2D::new(n as usize, n as usize);
        for query in queries {
            seg.update(query[0] as usize + 1, query[1] as usize + 1, 1);
            seg.update(query[2] as usize + 2, query[1] as usize + 1, -1);
            seg.update(query[0] as usize + 1, query[3] as usize + 2, -1);
            seg.update(query[2] as usize + 2, query[3] as usize + 2, 1);
        }
        let mut res = vec![vec![0; n as usize]; n as usize];
        for i in 0..n as usize {
            for j in 0..n as usize {
                res[i][j] = seg.query(i + 1, j + 1);
            }
        }
        res
    }
}

###Swift

class SegmentTree2D {
    var n: Int
    var m: Int
    var tree: [[Int]]
    
    init(n: Int, m: Int) {
        self.n = n
        self.m = m
        self.tree = Array(repeating: Array(repeating: 0, count: m + 1), count: n + 1)
    }
    
    func lowbit(_ x: Int) -> Int {
        return x & (-x)
    }
    
    func update(_ x: Int, _ y: Int, _ val: Int) {
        var x = x
        while x <= n {
            var y1 = y
            while y1 <= m {
                tree[x][y1] += val
                y1 += lowbit(y1)
            }
            x += lowbit(x)
        }
    }
    
    func query(_ x: Int, _ y: Int) -> Int {
        var res = 0
        var x = x
        while x > 0 {
            var y1 = y
            while y1 > 0 {
                res += tree[x][y1]
                y1 -= lowbit(y1)
            }
            x -= lowbit(x)
        }
        return res
    }
}

class Solution {
    func rangeAddQueries(_ n: Int, _ queries: [[Int]]) -> [[Int]] {
        let seg = SegmentTree2D(n: n, m: n)
        for q in queries {
            seg.update(q[0] + 1, q[1] + 1, 1)
            seg.update(q[2] + 2, q[1] + 1, -1)
            seg.update(q[0] + 1, q[3] + 2, -1)
            seg.update(q[2] + 2, q[3] + 2, 1)
        }
        var res = Array(repeating: Array(repeating: 0, count: n), count: n)
        for i in 0..<n {
            for j in 0..<n {
                res[i][j] = seg.query(i + 1, j + 1)
            }
        }
        return res
    }
}

###Kotlin

class SegmentTree2D {
    val n: Int
    val m: Int
    val tree: Array<IntArray>

    constructor(n: Int, m: Int) {
        this.n = n
        this.m = m
        tree = Array(n + 1) { IntArray(m + 1) }
    }

    fun lowbit(x: Int): Int {
        return x and (-x)
    }

    fun update(x: Int, y: Int, `val`: Int) {
        var x = x
        while (x <= n) {
            var y1 = y
            while (y1 <= m) {
                tree[x][y1] += `val`
                y1 += lowbit(y1)
            }
            x += lowbit(x)
        }
    }

    fun query(x: Int, y: Int): Int {
        var res = 0
        var x = x
        while (x > 0) {
            var y1 = y
            while (y1 > 0) {
                res += tree[x][y1]
                y1 -= lowbit(y1)
            }
            x -= lowbit(x)
        }
        return res
    }
}

class Solution {
    fun rangeAddQueries(n: Int, queries: Array<IntArray>): Array<IntArray> {
        val seg = SegmentTree2D(n, n)
        for (q in queries) {
            seg.update(q[0] + 1, q[1] + 1, 1)
            seg.update(q[2] + 2, q[1] + 1, -1)
            seg.update(q[0] + 1, q[3] + 2, -1)
            seg.update(q[2] + 2, q[3] + 2, 1)
        }
        val res = Array(n) { IntArray(n) }
        for (i in 0 until n) {
            for (j in 0 until n) {
                res[i][j] = seg.query(i + 1, j + 1)
            }
        }
        return res
    }
}

###TypeScript

class SegmentTree2D {
    n: number;
    m: number;
    tree: number[][];
    constructor(n: number, m: number) {
        this.n = n;
        this.m = m;
        this.tree = new Array(n + 1).fill(0).map(() => new Array(m + 1).fill(0));
    }
    lowbit(x: number): number {
        return x & (-x);
    }

    update(x: number, y: number, val: number): void {
        while (x <= this.n) {
            let y1 = y;
            while (y1 <= this.m) {
                this.tree[x][y1] += val;
                y1 += this.lowbit(y1);
            }
            x += this.lowbit(x);
        }
    }

    query(x: number, y: number): number {
        let res = 0;
        while (x > 0) {
            let y1 = y;
            while (y1 > 0) {
                res += this.tree[x][y1];
                y1 -= this.lowbit(y1);
            }
            x -= this.lowbit(x);
        }
        return res;
    }
}

function rangeAddQueries(n: number, queries: number[][]): number[][] {
    const seg = new SegmentTree2D(n, n);
    for (const query of queries) {
        seg.update(query[0] + 1, query[1] + 1, 1);
        seg.update(query[2] + 2, query[1] + 1, -1);
        seg.update(query[0] + 1, query[3] + 2, -1);
        seg.update(query[2] + 2, query[3] + 2, 1);
    }
    const res: number[][] = new Array(n).fill(0).map(() => new Array(n).fill(0));
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < n; j++) {
            res[i][j] = seg.query(i + 1, j + 1);
        }
    }
    return res;
};

###C#

class SegmentTree2D{
    int n, m;
    int[,] tree;

    public SegmentTree2D(int n, int m){
        this.n = n;
        this.m = m;
        tree = new int[n + 1,m + 1];
    }

    public int lowbit(int x){
        return x & (-x);
    }

    public void update(int x, int y, int val){
        while(x <= n){
            int y1 = y;
            while(y1 <= m){
                tree[x,y1] += val;
                y1 += lowbit(y1);
            }
            x += lowbit(x);
        }
    }

    public int query(int x, int y){
        int res = 0;
        while(x > 0){
            int y1 = y;
            while(y1 > 0){
                res += tree[x,y1];
                y1 -= lowbit(y1);
            }
            x -= lowbit(x);
        }
        return res;
    }
};

public class Solution {
    public int[][] RangeAddQueries(int n, int[][] queries) {
        SegmentTree2D seg = new SegmentTree2D(n, n);
        foreach(int[] q in queries){
            seg.update(q[0] + 1, q[1] + 1, 1);
            seg.update(q[2] + 2, q[1] + 1, -1);
            seg.update(q[0] + 1, q[3] + 2, -1);
            seg.update(q[2] + 2, q[3] + 2, 1);
        }
        int[][] res = new int[n][];
        for(int i = 0; i < n; i++){
            res[i] = new int[n];
            for(int j = 0; j < n; j++){
                res[i][j] = seg.query(i + 1, j + 1);
            }
        }
        return res;
    }
}

时间复杂度:$O(Q \times \log(N^2) + N^2 \times \log(N^2))$。其中Q是查询次数,N是二维数组的宽度和高度。
空间复杂度:$O(N^2)$。

各语言执行用时(Java最快,Python最慢):

微信图片_20230117175113.jpg{:width=400}

【模板题】二维差分+二维前缀和(Python/Java/C++/C/Go/JS/Rust)

前置知识

  1. 【图解】从一维差分到二维差分
  2. 【图解】一张图秒懂二维前缀和

思路

二维差分 $\mathcal{O}(1)$ 处理每个 $\textit{queries}[i]$。

然后计算二维差分矩阵的二维前缀和,即为答案。

代码实现时,为方便计算二维前缀和,可以在二维差分矩阵最上面添加一行 $0$,最左边添加一列 $0$,这样计算二维前缀和无需考虑下标越界。

class Solution:
    def rangeAddQueries(self, n: int, queries: List[List[int]]) -> List[List[int]]:
        # 二维差分
        diff = [[0] * (n + 2) for _ in range(n + 2)]
        for r1, c1, r2, c2 in queries:
            diff[r1 + 1][c1 + 1] += 1
            diff[r1 + 1][c2 + 2] -= 1
            diff[r2 + 2][c1 + 1] -= 1
            diff[r2 + 2][c2 + 2] += 1

        # 原地计算 diff 的二维前缀和,然后填入答案
        ans = [[0] * n for _ in range(n)]
        for i in range(n):
            for j in range(n):
                diff[i + 1][j + 1] += diff[i + 1][j] + diff[i][j + 1] - diff[i][j]
                ans[i][j] = diff[i + 1][j + 1]
        return ans
class Solution {
    public int[][] rangeAddQueries(int n, int[][] queries) {
        // 二维差分
        int[][] diff = new int[n + 2][n + 2];
        for (int[] q : queries) {
            int r1 = q[0], c1 = q[1], r2 = q[2], c2 = q[3];
            diff[r1 + 1][c1 + 1]++;
            diff[r1 + 1][c2 + 2]--;
            diff[r2 + 2][c1 + 1]--;
            diff[r2 + 2][c2 + 2]++;
        }

        // 原地计算 diff 的二维前缀和,然后填入答案
        int[][] ans = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                diff[i + 1][j + 1] += diff[i + 1][j] + diff[i][j + 1] - diff[i][j];
                ans[i][j] = diff[i + 1][j + 1];
            }
        }
        return ans;
    }
}
class Solution {
public:
    vector<vector<int>> rangeAddQueries(int n, vector<vector<int>>& queries) {
        // 二维差分
        vector diff(n + 2, vector<int>(n + 2));
        for (auto& q : queries) {
            int r1 = q[0], c1 = q[1], r2 = q[2], c2 = q[3];
            diff[r1 + 1][c1 + 1]++;
            diff[r1 + 1][c2 + 2]--;
            diff[r2 + 2][c1 + 1]--;
            diff[r2 + 2][c2 + 2]++;
        }

        // 原地计算 diff 的二维前缀和,然后填入答案
        vector ans(n, vector<int>(n));
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                diff[i + 1][j + 1] += diff[i + 1][j] + diff[i][j + 1] - diff[i][j];
                ans[i][j] = diff[i + 1][j + 1];
            }
        }
        return ans;
    }
};
int** rangeAddQueries(int n, int** queries, int queriesSize, int* queriesColSize, int* returnSize, int** returnColumnSizes) {
    // 二维差分
    int** diff = calloc(n + 2, sizeof(int*));
    for (int i = 0; i < n + 2; i++) {
        diff[i] = calloc(n + 2, sizeof(int));
    }
    for (int i = 0; i < queriesSize; i++) {
        int r1 = queries[i][0], c1 = queries[i][1], r2 = queries[i][2], c2 = queries[i][3];
        diff[r1 + 1][c1 + 1]++;
        diff[r1 + 1][c2 + 2]--;
        diff[r2 + 2][c1 + 1]--;
        diff[r2 + 2][c2 + 2]++;
    }

    // 原地计算 diff 的二维前缀和,然后填入答案
    int** ans = malloc(n * sizeof(int*));
    *returnSize = n;
    *returnColumnSizes = malloc(n * sizeof(int));
    for (int i = 0; i < n; i++) {
        ans[i] = malloc(n * sizeof(int));
        (*returnColumnSizes)[i] = n;
    }
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            diff[i + 1][j + 1] += diff[i + 1][j] + diff[i][j + 1] - diff[i][j];
            ans[i][j] = diff[i + 1][j + 1];
        }
    }

    for (int i = 0; i < n + 2; i++) {
        free(diff[i]);
    }
    free(diff);
    return ans;
}
func rangeAddQueries(n int, queries [][]int) [][]int {
// 二维差分
diff := make([][]int, n+2)
for i := range diff {
diff[i] = make([]int, n+2)
}
for _, q := range queries {
r1, c1, r2, c2 := q[0], q[1], q[2], q[3]
diff[r1+1][c1+1]++
diff[r1+1][c2+2]--
diff[r2+2][c1+1]--
diff[r2+2][c2+2]++
}

// 原地计算 diff 的二维前缀和,然后填入答案
ans := make([][]int, n)
for i := range ans {
ans[i] = make([]int, n)
for j := range ans[i] {
diff[i+1][j+1] += diff[i+1][j] + diff[i][j+1] - diff[i][j]
ans[i][j] = diff[i+1][j+1]
}
}
return ans
}
var rangeAddQueries = function(n, queries) {
    // 二维差分
    const diff = Array.from({ length: n + 2 }, () => Array(n + 2).fill(0));
    for (const [r1, c1, r2, c2] of queries) {
        diff[r1 + 1][c1 + 1]++;
        diff[r1 + 1][c2 + 2]--;
        diff[r2 + 2][c1 + 1]--;
        diff[r2 + 2][c2 + 2]++;
    }

    // 原地计算 diff 的二维前缀和,然后填入答案
    const ans = Array.from({ length: n }, () => Array(n).fill(0));
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < n; j++) {
            diff[i + 1][j + 1] += diff[i + 1][j] + diff[i][j + 1] - diff[i][j];
            ans[i][j] = diff[i + 1][j + 1];
        }
    }
    return ans;
};
impl Solution {
    pub fn range_add_queries(n: i32, queries: Vec<Vec<i32>>) -> Vec<Vec<i32>> {
        let n = n as usize;
        // 二维差分
        let mut diff = vec![vec![0; n + 2]; n + 2];
        for q in queries {
            let (r1, c1, r2, c2) = (q[0] as usize, q[1] as usize, q[2] as usize, q[3] as usize);
            diff[r1 + 1][c1 + 1] += 1;
            diff[r1 + 1][c2 + 2] -= 1;
            diff[r2 + 2][c1 + 1] -= 1;
            diff[r2 + 2][c2 + 2] += 1;
        }

        // 原地计算 diff 的二维前缀和,然后填入答案
        let mut ans = vec![vec![0; n]; n];
        for i in 0..n {
            for j in 0..n {
                diff[i + 1][j + 1] += diff[i + 1][j] + diff[i][j + 1] - diff[i][j];
                ans[i][j] = diff[i + 1][j + 1];
            }
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2+q)$,其中 $q$ 是 $\textit{queries}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n^2)$。

:也可以创建 $n\times n$ 大小的 $\textit{diff}$,原地计算二维前缀和,最后直接返回 $\textit{diff}$。

专题训练

见数据结构题单的「§2.2 二维差分」和「§1.6 二维前缀和」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

每日一题-将 1 移动到末尾的最大操作次数🟡

给你一个 二进制字符串 s

你可以对这个字符串执行 任意次 下述操作:

  • 选择字符串中的任一下标 ii + 1 < s.length ),该下标满足 s[i] == '1's[i + 1] == '0'
  • 将字符 s[i]右移 直到它到达字符串的末端或另一个 '1'。例如,对于 s = "010010",如果我们选择 i = 1,结果字符串将会是 s = "000110"

返回你能执行的 最大 操作次数。

 

示例 1:

输入: s = "1001101"

输出: 4

解释:

可以执行以下操作:

  • 选择下标 i = 0。结果字符串为 s = "0011101"
  • 选择下标 i = 4。结果字符串为 s = "0011011"
  • 选择下标 i = 3。结果字符串为 s = "0010111"
  • 选择下标 i = 2。结果字符串为 s = "0001111"

示例 2:

输入: s = "00111"

输出: 0

 

提示:

  • 1 <= s.length <= 105
  • s[i]'0''1'

递推

解法:递推

设 $f(i)$ 表示可以移动原字符串中下标为 $i$ 的 $1$ 几次。

若 $s_i$ 和 $s_{i + 1}$ 都是 $1$,后续只有 $s_{i + 1}$ 动了,$s_i$ 才能动,所以这种情况下 $f(i) = f(i + 1)$。

若 $s_i$ 是 $1$,$s_{i + 1}$ 是 $0$,下一个 $1$ 在 $s_j$($j > i$),那么我们可以先把 $s_i$ 移动到下标 $(j - 1)$,就变成了上面那种情况。所以这种情况下 $f(i) = f(j) + 1$。

答案就是 $\sum f(i)$。复杂度 $\mathcal{O}(n)$。

参考代码(c++)

###cpp

class Solution {
public:
    int maxOperations(string s) {
        int n = s.size();
        long long ans = 0;
        for (int i = n - 2, last = 0; i >= 0; i--) if (s[i] == '1') {
            if (s[i + 1] == '0') last++;
            ans += last;
        }
        return ans;
    }
};

堵车模型(Python/Java/C++/C/Go/JS/Rust)

把 $1$ 当作,想象有一条长为 $n$ 的道路上有一些车。

题意:把所有的车都开到最右边。例如 $011010$ 最终要变成 $000111$。

如果优先操作右边的(能移动的)车,那么这些车都只需操作一次:

$$
\begin{aligned}
& 011010 \
\to{} & 011001 \
\to{} & 010011 \
\to{} & 000111 \
\end{aligned}
$$

一共需要操作 $3$ 次(注意一次操作可以让一辆车移动多次)。

而如果优先操作左边的(能移动的)车,这会制造大量的「堵车」,每辆车的操作次数会更多。

$$
\begin{aligned}
& 011010 \
\to{} & 010110 \
\to{} & 001110 \
\to{} & 001101 \
\to{} & 001011 \
\to{} & 000111 \
\end{aligned}
$$

一共需要操作 $5$ 次。

算法

  1. 从左到右遍历 $s$,同时用一个变量 $\textit{cnt}_1$ 维护遍历到的 $1$ 的个数。
  2. 如果 $s[i]$ 是 $1$,把 $\textit{cnt}_1$ 增加 $1$。
  3. 如果 $s[i]$ 是 $0$ 且 $s[i-1]$ 是 $1$,意味着我们找到了一段道路,可以让 $i$ 左边的每辆车都操作一次,把答案增加 $\textit{cnt}_1$。
  4. 遍历结束,返回答案。

本题视频讲解,欢迎点赞关注~

###py

class Solution:
    def maxOperations(self, s: str) -> int:
        ans = cnt1 = 0
        for i, c in enumerate(s):
            if c == '1':
                cnt1 += 1
            elif i > 0 and s[i - 1] == '1':
                ans += cnt1
        return ans

###java

class Solution {
    public int maxOperations(String S) {
        char[] s = S.toCharArray();
        int ans = 0;
        int cnt1 = 0;
        for (int i = 0; i < s.length; i++) {
            if (s[i] == '1') {
                cnt1++;
            } else if (i > 0 && s[i - 1] == '1') {
                ans += cnt1;
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int maxOperations(string s) {
        int ans = 0, cnt1 = 0;
        for (int i = 0; i < s.size(); i++) {
            if (s[i] == '1') {
                cnt1++;
            } else if (i > 0 && s[i - 1] == '1') {
                ans += cnt1;
            }
        }
        return ans;
    }
};

###c

int maxOperations(char* s) {
    int ans = 0, cnt1 = 0;
    for (int i = 0; s[i]; i++) {
        char c = s[i];
        if (c == '1') {
            cnt1++;
        } else if (i > 0 && s[i - 1] == '1') {
            ans += cnt1;
        }
    }
    return ans;
}

###go

func maxOperations(s string) (ans int) {
cnt1 := 0
for i, c := range s {
if c == '1' {
cnt1++
} else if i > 0 && s[i-1] == '1' {
ans += cnt1
}
}
return
}

###js

var maxOperations = function(s) {
    let ans = 0, cnt1 = 0;
    for (let i = 0; i < s.length; i++) {
        const c = s[i];
        if (c === '1') {
            cnt1++;
        } else if (i > 0 && s[i - 1] === '1') {
            ans += cnt1;
        }
    }
    return ans;
};

###rust

impl Solution {
    pub fn max_operations(s: String) -> i32 {
        let s = s.as_bytes();
        let mut ans = 0;
        let mut cnt1 = 0;
        for (i, &c) in s.iter().enumerate() {
            if c == b'1' {
                cnt1 += 1;
            } else if i > 0 && s[i - 1] == b'1' {
                ans += cnt1;
            }
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $s$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

思考题

构造一个 $s$,让返回值尽量大。

如果 $n=10^5$,答案最大能是多少?会不会超过 $\texttt{int}$ 最大值?

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

每日一题-使数组所有元素变成 1 的最少操作次数🟡

给你一个下标从 0 开始的  整数数组 nums 。你可以对数组执行以下操作 任意 次:

  • 选择一个满足 0 <= i < n - 1 的下标 i ,将 nums[i] 或者 nums[i+1] 两者之一替换成它们的最大公约数。

请你返回使数组 nums 中所有元素都等于 1 的 最少 操作次数。如果无法让数组全部变成 1 ,请你返回 -1 。

两个正整数的最大公约数指的是能整除这两个数的最大正整数。

 

示例 1:

输入:nums = [2,6,3,4]
输出:4
解释:我们可以执行以下操作:
- 选择下标 i = 2 ,将 nums[2] 替换为 gcd(3,4) = 1 ,得到 nums = [2,6,1,4] 。
- 选择下标 i = 1 ,将 nums[1] 替换为 gcd(6,1) = 1 ,得到 nums = [2,1,1,4] 。
- 选择下标 i = 0 ,将 nums[0] 替换为 gcd(2,1) = 1 ,得到 nums = [1,1,1,4] 。
- 选择下标 i = 2 ,将 nums[3] 替换为 gcd(1,4) = 1 ,得到 nums = [1,1,1,1] 。

示例 2:

输入:nums = [2,10,6,14]
输出:-1
解释:无法将所有元素都变成 1 。

 

提示:

  • 2 <= nums.length <= 50
  • 1 <= nums[i] <= 106

python SlidingWindowAggregation O(nlogk) 求gcd为1的最短子数组

解题思路

SlidingWindowAggregation 是一个维护幺半群的滑动窗口的数据结构,可以在 $O(1)$ 时间内做到入队出队查询滑窗内的聚合值,原理类似 面试题 03.02. 栈的最小值 (StackAggregation).


ps: 吐槽一下这道题的数据量,看到 50 很容易会去想回溯+剪枝,但是会收获一大堆TLE 🤣

代码

###python3

INF = int(1e20)

class Solution:
    def minOperations(self, nums: List[int]) -> int:
        if gcd(*nums) != 1:
            return -1
        if 1 in nums:
            return len(nums) - nums.count(1)
        return minLen(nums) - 1 + len(nums) - 1


def minLen(nums: List[int]) -> int:
    """gcd为1的最短子数组.不存在返回INF."""
    n = len(nums)
    S = SlidingWindowAggregation(lambda: 0, gcd)
    res, n = INF, len(nums)
    for right in range(n):
        S.append(nums[right])
        while S and S.query() == 1:
            res = min(res, len(S))
            S.popleft()
    return res

###python3

from typing import Callable, Generic, List, TypeVar

E = TypeVar("E")

class SlidingWindowAggregation(Generic[E]):
    """SlidingWindowAggregation

    Api:
    1. append value to tail,O(1).
    2. pop value from head,O(1).
    3. query aggregated value in window,O(1).
    """

    __slots__ = ["_stack0", "_stack1", "_stack2", "_stack3", "_e0", "_e1", "_size", "_op", "_e"]

    def __init__(self, e: Callable[[], E], op: Callable[[E, E], E]):
        """
        Args:
            e: unit element
            op: merge function
        """
        self._stack0 = []
        self._stack1 = []
        self._stack2 = []
        self._stack3 = []
        self._e = e
        self._e0 = e()
        self._e1 = e()
        self._size = 0
        self._op = op

    def append(self, value: E) -> None:
        if not self._stack0:
            self._push0(value)
            self._transfer()
        else:
            self._push1(value)
        self._size += 1

    def popleft(self) -> None:
        if not self._size:
            return
        if not self._stack0:
            self._transfer()
        self._stack0.pop()
        self._stack2.pop()
        self._e0 = self._stack2[-1] if self._stack2 else self._e()
        self._size -= 1

    def query(self) -> E:
        return self._op(self._e0, self._e1)

    def _push0(self, value):
        self._stack0.append(value)
        self._e0 = self._op(value, self._e0)
        self._stack2.append(self._e0)

    def _push1(self, value):
        self._stack1.append(value)
        self._e1 = self._op(self._e1, value)
        self._stack3.append(self._e1)

    def _transfer(self):
        while self._stack1:
            self._push0(self._stack1.pop())
        while self._stack3:
            self._stack3.pop()
        self._e1 = self._e()

    def __len__(self):
        return self._size

两种方法:暴力枚举/利用GCD性质,附题单(Python/Java/C++/Go)

方法一:计算最短的 GCD 等于 1 的子数组

提示 1

首先,如果所有数的 GCD(最大公约数)大于 $1$,那么无论如何都无法操作出 $1$,我们返回 $-1$。

如果 $\textit{nums}$ 中有一个 $1$,那么从 $1$ 向左向右不断替换就能把所有数变成 $1$。

例如 $[2,2,1,2,2]\rightarrow[2,\underline{1},1,2,2]\rightarrow[\underline{1},1,1,2,2]\rightarrow[1,1,1,\underline{1},2]\rightarrow[1,1,1,1,\underline{1}]$,一共 $n-1=5-1=4$ 次操作。

如果有多个 $1$,那么每个 $1$ 只需要向左修改,最后一个 $1$ 向右修改剩余的数字。

例如 $[2,1,2,1,2]\rightarrow[\underline{1},1,2,1,2]\rightarrow[1,1,\underline{1},1,2]\rightarrow[1,1,1,1,\underline{1}]$,一共 $n-\textit{cnt}_1=5-2=3$ 次操作。这里 $\textit{cnt}_1$ 表示 $\textit{nums}$ 中 $1$ 的个数。

所以如果 $\textit{nums}$ 中有 $1$,答案为

$$
n-\textit{cnt}_1
$$

如果 $\textit{nums}$ 中没有 $1$ 呢?

提示 2

如果 $\textit{nums}$ 中没有 $1$,想办法花费尽量少的操作得出一个 $1$。

由于只能操作相邻的数,所以这个 $1$ 必然是一个连续子数组的 GCD。(如果在不连续的情况下得到了 $1$,那么这个 $1$ 只能属于其中某个连续子数组,其余的操作是多余的。)

那么找到最短的 GCD 为 $1$ 的子数组,设其长度为 $\textit{minSize}$,那么我们需要操作 $\textit{minSize}-1$ 次得到 $1$。

例如 $[2,6,3,4]$ 中的 $[3,4]$ 可以操作 $2-1=1$ 次得到 $1$。

然后就转化成提示 1 中的情况了,最终答案为

$$
(\textit{minSize}-1) + (n-1) = \textit{minSize}+n-2
$$

本题视频讲解(第四题)

class Solution:
    def minOperations(self, nums: List[int]) -> int:
        if gcd(*nums) > 1:
            return -1
        n = len(nums)
        cnt1 = sum(x == 1 for x in nums)
        if cnt1:
            return n - cnt1

        min_size = n
        for i in range(n):
            g = 0
            for j in range(i, n):
                g = gcd(g, nums[j])
                if g == 1:
                    # 这里本来是 j-i+1,把 +1 提出来合并到 return 中
                    min_size = min(min_size, j - i)
                    break
        return min_size + n - 1
class Solution {
    public int minOperations(int[] nums) {
        int n = nums.length, gcdAll = 0, cnt1 = 0;
        for (int x : nums) {
            gcdAll = gcd(gcdAll, x);
            if (x == 1) ++cnt1;
        }
        if (gcdAll > 1) return -1;
        if (cnt1 > 0) return n - cnt1;

        int minSize = n;
        for (int i = 0; i < n; i++) {
            int g = 0;
            for (int j = i; j < n; j++) {
                g = gcd(g, nums[j]);
                if (g == 1) {
                    // 这里本来是 j-i+1,把 +1 提出来合并到 return 中
                    minSize = Math.min(minSize, j - i);
                    break;
                }
            }
        }
        return minSize + n - 1;
    }

    private int gcd(int a, int b) {
        while (a != 0) {
            int tmp = a;
            a = b % a;
            b = tmp;
        }
        return b;
    }
}
class Solution {
public:
    int minOperations(vector<int>& nums) {
        int n = nums.size(), gcd_all = 0, cnt1 = 0;
        for (int x : nums) {
            gcd_all = gcd(gcd_all, x);
            cnt1 += x == 1;
        }
        if (gcd_all > 1) return -1;
        if (cnt1) return n - cnt1;

        int min_size = n;
        for (int i = 0; i < n; i++) {
            int g = 0;
            for (int j = i; j < n; j++) {
                g = gcd(g, nums[j]);
                if (g == 1) {
                    // 这里本来是 j-i+1,把 +1 提出来合并到 return 中
                    min_size = min(min_size, j - i);
                    break;
                }
            }
        }
        return min_size + n - 1;
    }
};
func minOperations(nums []int) int {
n, gcdAll, cnt1 := len(nums), 0, 0
for _, x := range nums {
gcdAll = gcd(gcdAll, x)
if x == 1 {
cnt1++
}
}
if gcdAll > 1 {
return -1
}
if cnt1 > 0 {
return n - cnt1
}

minSize := n
for i := range nums {
g := 0
for j, x := range nums[i:] {
g = gcd(g, x)
if g == 1 {
// 这里本来是 j+1,把 +1 提出来合并到 return 中
minSize = min(minSize, j)
break
}
}
}
return minSize + n - 1
}

func gcd(a, b int) int {
for a != 0 {
a, b = b%a, a
}
return b
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n(n+\log U))$,其中 $n$ 为 $\textit{nums}$ 的长度,$U=\max(\textit{nums})$。外层循环时,单看 $g=\textit{nums}[i]$,它因为求 GCD 减半的次数是 $\mathcal{O}(\log U)$ 次,因此内层循环的时间复杂度为 $\mathcal{O}(n+\log U)$,所以总的时间复杂度为 $\mathcal{O}(n(n+\log U))$。
  • 空间复杂度:$\mathcal{O}(1)$。

方法二:利用 GCD 的性质

前置知识LogTrick 入门教程

这个做法可以解决 $n=10^5$ 的情况。

class Solution:
    def minOperations(self, nums: List[int]) -> int:
        if gcd(*nums) > 1:
            return -1
        n = len(nums)
        cnt1 = sum(x == 1 for x in nums)
        if cnt1:
            return n - cnt1

        min_size = n
        a = []  # [GCD,相同 GCD 闭区间的右端点]
        for i, x in enumerate(nums):
            a.append([x, i])

            # 原地去重,因为相同的 GCD 都相邻在一起
            j = 0
            for p in a:
                p[0] = gcd(p[0], x)
                if a[j][0] != p[0]:
                    j += 1
                    a[j] = p
                else:
                    a[j][1] = p[1]
            del a[j + 1:]

            if a[0][0] == 1:
                # 这里本来是 i-a[0][1]+1,把 +1 提出来合并到 return 中
                min_size = min(min_size, i - a[0][1])
        return min_size + n - 1
class Solution {
    public int minOperations(int[] nums) {
        int n = nums.length, gcdAll = 0, cnt1 = 0;
        for (int x : nums) {
            gcdAll = gcd(gcdAll, x);
            if (x == 1) ++cnt1;
        }
        if (gcdAll > 1) return -1;
        if (cnt1 > 0) return n - cnt1;

        int minSize = n;
        var g = new ArrayList<int[]>(); // [GCD,相同 GCD 闭区间的右端点]
        for (int i = 0; i < n; i++) {
            g.add(new int[]{nums[i], i});
            // 原地去重,因为相同的 GCD 都相邻在一起
            var j = 0;
            for (var p : g) {
                p[0] = gcd(p[0], nums[i]);
                if (g.get(j)[0] == p[0])
                    g.get(j)[1] = p[1]; // 合并相同值,下标取最小的
                else g.set(++j, p);
            }
            g.subList(j + 1, g.size()).clear();
            if (g.get(0)[0] == 1)
                // 这里本来是 i-g.get(0)[1]+1,把 +1 提出来合并到 return 中
                minSize = Math.min(minSize, i - g.get(0)[1]);
        }
        return minSize + n - 1;
    }

    private int gcd(int a, int b) {
        while (a != 0) {
            int tmp = a;
            a = b % a;
            b = tmp;
        }
        return b;
    }
}
class Solution {
public:
    int minOperations(vector<int>& nums) {
        int n = nums.size(), gcd_all = 0, cnt1 = 0;
        for (int x : nums) {
            gcd_all = gcd(gcd_all, x);
            cnt1 += x == 1;
        }
        if (gcd_all > 1) return -1;
        if (cnt1) return n - cnt1;

        int min_size = n;
        vector<pair<int, int>> g; // {GCD,相同 GCD 闭区间的右端点}
        for (int i = 0; i < n; i++) {
            g.emplace_back(nums[i], i);
            // 原地去重,因为相同的 GCD 都相邻在一起
            int j = 0;
            for (auto& p : g) {
                p.first = gcd(p.first, nums[i]);
                if (g[j].first == p.first)
                    g[j].second = p.second;
                else g[++j] = move(p);
            }
            g.resize(j + 1);
            if (g[0].first == 1)
                // 这里本来是 i-g[0].second+1,把 +1 提出来合并到 return 中
                min_size = min(min_size, i - g[0].second);
        }
        return min_size + n - 1;
    }
};
func minOperations(nums []int) int {
n, gcdAll, cnt1 := len(nums), 0, 0
for _, x := range nums {
gcdAll = gcd(gcdAll, x)
if x == 1 {
cnt1++
}
}
if gcdAll > 1 {
return -1
}
if cnt1 > 0 {
return n - cnt1
}

minSize := n
type result struct{ gcd, i int }
a := []result{}
for i, x := range nums {
for j, r := range a {
a[j].gcd = gcd(r.gcd, x)
}
a = append(a, result{x, i})

// 去重
j := 0
for _, q := range a[1:] {
if a[j].gcd != q.gcd {
j++
a[j] = q
} else {
a[j].i = q.i // 相同 gcd 保存最右边的位置
}
}
a = a[:j+1]

if a[0].gcd == 1 {
// 这里本来是 i-a[0].i+1,把 +1 提出来合并到 return 中
minSize = min(minSize, i-a[0].i)
}
}
return minSize + n - 1
}

func gcd(a, b int) int {
for a != 0 {
a, b = b%a, a
}
return b
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log U)$,其中 $n$ 为 $\textit{nums}$ 的长度,$U=\max(\textit{nums})$。单看每个元素,它因为求 GCD 减半的次数是 $\mathcal{O}(\log U)$ 次,并且每次去重的时间复杂度也为 $\mathcal{O}(\log U)$,因此时间复杂度为 $\mathcal{O}(n\log U)$。
  • 空间复杂度:$\mathcal{O}(\log U)$。

注:由于本题数据范围小,这两种做法的运行时间区别并不明显。

可以用该模板秒杀的题目

见下面位运算题单的「LogTrick」。

补充:

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

思维题,从 n^3 逐步优化至 nlogn

解法:思维

首先处理一些特殊情况:

  1. 如果整个序列的最大公约数都不是 $1$,那么无解。
  2. 如果序列里已经存在一些 $1$ 了,那么每次可以选择 $1$ 与它旁边的一个数 $x > 1$,把 $x$ 也变成 $1$。因此只需要 $c$ 步即可,其中 $c$ 表示序列中大于 $1$ 的数有几个。

剩下的情况就是:序列的最大公约数为 $1$,但是每个元素都大于 $1$。如果我们能通过某种方式把第一个 $1$ 弄出来,就转化为了上述第二种特殊情况,答案就是“弄出第一个 $1$ 的最少步数”,加上“弄出第一个 $1$ 后,序列里大于 $1$ 的数还有几个”。第二项的值显然是 $(n - 1)$,因此剩下的问题就是计算“弄出第一个 $1$ 的最少步数”。

首先注意到一个关键结论:进行任意次操作之后,序列里的第 $i$ 个数一定是 $\text{gcd}(a[l..r])$,其中 $a[l..r]$ 表示序列里从第 $l$ 个数到第 $r$ 个数形成的连续子数组,并且满足 $l \le i \le r$。这个结论可以用归纳法进行证明。

既然操作的结果是一段连续子数组的最大公约数,那么对于一个长度为 $k$ 的子数组,我们可以通过 $(k - 1)$ 次操作把其中一个数变成整个子数组的最大公约数。因此,我们只需要找到长度最小的子数组,使得子数组的最大公约数等于 $1$,那么“弄出第一个 $1$ 的最少步数”就是子数组的长度减去一。

因为整个序列的长度只有 $50$,我们完全可以从 $2$ 到 $n$ 枚举 $k$,并检查是否存在长度为 $k$ 的,且最大公约数为 $1$ 的子数组。复杂度 $\mathcal{O}(n^3)$。

参考代码(c++)

###c++

class Solution {
public:
    int minOperations(vector<int>& nums) {
        int n = nums.size();

        // 特殊情况 1:整个序列的 gcd > 1
        int g = nums[0];
        for (int x : nums) g = gcd(g, x);
        if (g > 1) return -1;

        // 特殊情况 2:序列里已经有 1
        int cnt = 0;
        for (int x : nums) if (x != 1) cnt++;
        if (cnt != n) return cnt;

        // 剩余情况,枚举子数组的长度 l 和子数组的初始下标 i,检查 a[i..i + l - 1] 的 gcd 是否为 1
        for (int l = 2; l <= n; l++) for (int i = 0, j = l - 1; j < n; i++, j++) {
            int g = nums[i];
            for (int k = i; k <= j; k++) g = gcd(g, nums[k]);
            // 找到了符合要求的子数组,答案就是子数组长度减去一,再加上 (n - 1)
            if (g == 1) return l - 1 + n - 1;
        }

        // 不可能,但是函数的所有 branch 一定要有返回值
        assert(false);
        return -1;
    }
};

我们还可以对以上代码进行优化。主要是针对“寻找最大公约数等于 $1$,且长度最小的子数组”这一部分。如果一个子数组的最大公约数为 $1$,那么包含该子数组的其它子数组最大公约数也等于 $1$,因此这一部分可以使用 two pointers 解决,复杂度降至 $\mathcal{O}(n^2)$。

参考代码(c++)

###c++

class Solution {
public:
    int minOperations(vector<int>& nums) {
        int n = nums.size();

        // 特殊情况 1:整个序列的 gcd > 1
        int g = nums[0];
        for (int x : nums) g = gcd(g, x);
        if (g > 1) return -1;

        // 特殊情况 2:序列里已经有 1
        int cnt = 0;
        for (int x : nums) if (x != 1) cnt++;
        if (cnt != n) return cnt;

        // 求 a[L..R] 的 gcd
        auto gao = [&](int L, int R) {
            int g = 0;
            for (int i = L; i <= R; i++) g = gcd(g, nums[i]);
            return g;
        };

        int ans = 1e9;
        // 双指针求以 i 为开头的,长度最短的,最大公约数为 1 的子数组
        // 这个子数组含 i 不含 j
        for (int i = 0, j = 0; i < n; i++) {
            while (j < n && gao(i, j - 1) != 1) j++;
            // 找到了符合要求的子数组,答案就是子数组长度减去一,再加上 (n - 1)
            if (gao(i, j - 1) == 1) ans = min(ans, j - i - 1 + n - 1);
        }
        return ans;
    }
};

上述求子数组最大公约数的 gao 函数仍然可以优化。由于最大公约数满足结合律,我们可以通过 RMQ线段树 在每次 $\mathcal{O}(\log n)$ 的复杂度下计算子数组的最大公约数,复杂度降至 $\mathcal{O}(n\log n)$。

参考代码(c++)

###c++

class Solution {
public:
    int minOperations(vector<int>& nums) {
        int n = nums.size();

        // 特殊情况 1:整个序列的 gcd > 1
        int g = nums[0];
        for (int x : nums) g = gcd(g, x);
        if (g > 1) return -1;

        // 特殊情况 2:序列里已经有 1
        int cnt = 0;
        for (int x : nums) if (x != 1) cnt++;
        if (cnt != n) return cnt;

        // go[id] 是线段树节点 id 代表的区间的 gcd
        int go[n * 4 + 10];
        // 构建线段树
        function<void(int, int, int)> build = [&](int id, int l, int r) {
            if (l == r) go[id] = nums[l];
            else {
                int nxt = id << 1, mid = (l + r) >> 1;
                build(nxt, l, mid);
                build(nxt | 1, mid + 1, r);
                go[id] = gcd(go[nxt], go[nxt | 1]);
            }
        };
        build(1, 0, n - 1);

        // 查询 [ql, qr] 区间的 gcd
        function<int(int, int, int, int, int)> query = [&](int id, int l, int r, int ql, int qr) {
            if (ql <= l && r <= qr) return go[id];
            int nxt = id << 1, mid = (l + r) >> 1;
            return gcd(
                ql <= mid ? query(nxt, l, mid, ql, qr) : 0,
                qr > mid ? query(nxt | 1, mid + 1, r, ql, qr) : 0
            );
        };

        // 求 a[L..R] 的 gcd
        auto gao = [&](int L, int R) {
            if (L > R) return 0;
            return query(1, 0, n - 1, L, R);
        };

        int ans = 1e9;
        // 双指针求以 i 为开头的,长度最短的,最大公约数为 1 的子数组
        // 这个子数组含 i 不含 j
        for (int i = 0, j = 0; i < n; i++) {
            while (j < n && gao(i, j - 1) != 1) j++;
            // 找到了符合要求的子数组,答案就是子数组长度减去一,再加上 (n - 1)
            if (gao(i, j - 1) == 1) ans = min(ans, j - i - 1 + n - 1);
        }
        return ans;
    }
};

每日一题-一和零🟡

给你一个二进制字符串数组 strs 和两个整数 mn

请你找出并返回 strs 的最大子集的长度,该子集中 最多m0n1

如果 x 的所有元素也是 y 的元素,集合 x 是集合 y子集

 

示例 1:

输入:strs = ["10", "0001", "111001", "1", "0"], m = 5, n = 3
输出:4
解释:最多有 5 个 0 和 3 个 1 的最大子集是 {"10","0001","1","0"} ,因此答案是 4 。
其他满足题意但较小的子集包括 {"0001","1"} 和 {"10","1","0"} 。{"111001"} 不满足题意,因为它含 4 个 1 ,大于 n 的值 3 。

示例 2:

输入:strs = ["10", "0", "1"], m = 1, n = 1
输出:2
解释:最大的子集是 {"0", "1"} ,所以答案是 2 。

 

提示:

  • 1 <= strs.length <= 600
  • 1 <= strs[i].length <= 100
  • strs[i] 仅由 '0' 和 '1' 组成
  • 1 <= m, n <= 100

一步步思考:从记忆化搜索到递推到空间优化!(Python/Java/C++/Go)

前言

设 $\textit{strs}[i]$ 中 $0$ 的个数为 $\textit{cnt}_0[i]$,$1$ 的个数为 $\textit{cnt}_1[i]$,那么本题相当于:

  • 有一个容量为 $(m,n)$ 的背包,至多可以装入 $m$ 个 $0$ 和 $n$ 个 $1$。现在有 $n$ 个物品,每个物品的体积为 $(\textit{cnt}_0[i],\textit{cnt}_1[i])$,表示该物品有 $\textit{cnt}_0[i]$ 个 $0$ 和 $\textit{cnt}_1[i]$ 个 $1$。问:最多可以选多少个物品?

这相当于背包有两种体积(二维),所以在定义状态的时候,相比只有一种体积的 0-1 背包,要多加一个参数。

如果你不了解 0-1 背包,请看【基础算法精讲 18】

一、记忆化搜索

在一维 0-1 背包的基础上,多加一个参数,即定义 $\textit{dfs}(i,j,k)$ 表示在 $[0,i]$ 中选字符串,在 $0$ 的个数至多为 $j$,$1$ 的个数至多为 $k$ 的约束下,至多可以选多少个字符串。

考虑 $\textit{strs}[i]$ 选或不选:

  • 不选:问题变成在 $[0,i-1]$ 中选字符串,在 $0$ 的个数至多为 $j$,$1$ 的个数至多为 $k$ 的约束下,至多可以选多少个字符串,即 $\textit{dfs}(i,j,k) = \textit{dfs}(i-1,j,k)$。
  • 选:如果 $j\ge \textit{cnt}_0[i]$ 并且 $k\ge \textit{cnt}_1[i]$ 则可以选。问题变成在 $[0,i-1]$ 中选字符串,在 $0$ 的个数至多为 $j-\textit{cnt}_0[i]$,$1$ 的个数至多为 $k-\textit{cnt}_1[i]$ 的约束下,至多可以选多少个字符串,即 $\textit{dfs}(i,j,k) = \textit{dfs}(i-1,j-\textit{cnt}_0[i],k-\textit{cnt}_1[i]) + 1$。

两种情况取最大值,得

$$
\textit{dfs}(i,j,k) = \max(\textit{dfs}(i-1,j,k), \textit{dfs}(i-1,j-\textit{cnt}_0[i],k-\textit{cnt}_1[i]) + 1)
$$

如果

递归边界:$\textit{dfs}(-1,j,k)=0$。此时没有物品可以选。

递归入口:$\textit{dfs}(k-1,m,n)$,这是原问题,也是答案。其中 $k$ 为 $\textit{strs}$ 的长度。

class Solution:
    def findMaxForm(self, strs: List[str], m: int, n: int) -> int:
        cnt0 = [s.count('0') for s in strs]

        @cache  # 缓存装饰器,避免重复计算 dfs 的结果(记忆化)
        def dfs(i: int, j: int, k: int) -> int:
            if i < 0:
                return 0
            res = dfs(i - 1, j, k)  # 不选 strs[i]
            cnt1 = len(strs[i]) - cnt0[i]
            if j >= cnt0[i] and k >= cnt1:
                res = max(res, dfs(i - 1, j - cnt0[i], k - cnt1) + 1)  # 选 strs[i]
            return res

        return dfs(len(strs) - 1, m, n)
class Solution {
    public int findMaxForm(String[] strs, int m, int n) {
        int k = strs.length;
        int[] cnt0 = new int[k];
        for (int i = 0; i < k; i++) {
            cnt0[i] = (int) strs[i].chars().filter(ch -> ch == '0').count();
        }

        int[][][] memo = new int[strs.length][m + 1][n + 1];
        for (int[][] mat : memo) {
            for (int[] arr : mat) {
                Arrays.fill(arr, -1); // -1 表示没有计算过
            }
        }
        return dfs(k - 1, m, n, strs, cnt0, memo);
    }

    private int dfs(int i, int j, int k, String[] strs, int[] cnt0, int[][][] memo) {
        if (i < 0) {
            return 0;
        }
        if (memo[i][j][k] != -1) { // 之前计算过
            return memo[i][j][k];
        }
        // 不选 strs[i]
        int res = dfs(i - 1, j, k, strs, cnt0, memo);  
        int cnt1 = strs[i].length() - cnt0[i];
        if (j >= cnt0[i] && k >= cnt1) {
            // 选 strs[i]
            res = Math.max(res, dfs(i - 1, j - cnt0[i], k - cnt1, strs, cnt0, memo) + 1);
        }
        return memo[i][j][k] = res; // 记忆化
    }
}
class Solution {
public:
    int findMaxForm(vector<string>& strs, int m, int n) {
        vector<int> cnt0(strs.size());
        for (int i = 0; i < strs.size(); i++) {
            cnt0[i] = ranges::count(strs[i], '0');
        }

        vector memo(strs.size(), vector(m + 1, vector<int>(n + 1, -1))); // -1 表示没有计算过
        auto dfs = [&](this auto&& dfs, int i, int j, int k) -> int {
            if (i < 0) {
                return 0;
            }
            int& res = memo[i][j][k]; // 注意这里是引用
            if (res != -1) { // 之前计算过
                return res;
            }
            res = dfs(i - 1, j, k); // 不选 strs[i]
            int cnt1 = strs[i].size() - cnt0[i];
            if (j >= cnt0[i] && k >= cnt1) {
                res = max(res, dfs(i - 1, j - cnt0[i], k - cnt1) + 1); // 选 strs[i]
            }
            return res;
        };
        return dfs(strs.size() - 1, m, n);
    }
};
func findMaxForm(strs []string, m, n int) int {
    k := len(strs)
    cnt0 := make([]int, k)
    for i, s := range strs {
        cnt0[i] = strings.Count(s, "0")
    }

    memo := make([][][]int, k)
    for i := range memo {
        memo[i] = make([][]int, m+1)
        for j := range memo[i] {
            memo[i][j] = make([]int, n+1)
            for k := range memo[i][j] {
                memo[i][j][k] = -1 // -1 表示没有计算过
            }
        }
    }
    var dfs func(int, int, int) int
    dfs = func(i, j, k int) int {
        if i < 0 {
            return 0
        }
        p := &memo[i][j][k]
        if *p != -1 { // 之前计算过
            return *p
        }
        res := dfs(i-1, j, k) // 不选 strs[i]
        cnt1 := len(strs[i]) - cnt0[i]
        if j >= cnt0[i] && k >= cnt1 {
            res = max(res, dfs(i-1, j-cnt0[i], k-cnt1)+1) // 选 strs[i]
        }
        *p = res // 记忆化
        return res
    }
    return dfs(k-1, m, n)
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(kmn+L)$,其中 $k$ 为 $\textit{strs}$ 的长度,$L$ 为 $\textit{strs}$ 中所有字符串的长度之和。由于每个状态只会计算一次,动态规划的时间复杂度 $=$ 状态个数 $\times$ 单个状态的计算时间。本题状态个数等于 $\mathcal{O}(kmn)$,单个状态的计算时间为 $\mathcal{O}(1)$,所以总的时间复杂度为 $\mathcal{O}(kmn)$。
  • 空间复杂度:$\mathcal{O}(kmn)$。保存多少状态,就需要多少空间。

二、1:1 翻译成递推

我们可以去掉递归中的「递」,只保留「归」的部分,即自底向上计算。

具体来说,$f[i+1][j][k]$ 的定义和 $\textit{dfs}(i,j,k)$ 的定义是一样的,都表示在 $[0,i]$ 中选字符串,在 $0$ 的个数至多为 $j$,$1$ 的个数至多为 $k$ 的约束下,至多可以选多少个字符串。这里 $+1$ 是为了把 $\textit{dfs}(-1,j,k)$ 这个状态也翻译过来,这样我们可以把 $f[0][j][k]$ 作为初始值。

相应的递推式(状态转移方程)也和 $\textit{dfs}$ 一样:

$$
f[i+1][j][k] = \max(f[i][j][k], f[i][j-\textit{cnt}_0[i]][k-\textit{cnt}_1[i]] + 1)
$$

初始值 $f[0][j][k]=0$,翻译自递归边界 $\textit{dfs}(-1,j,k)=0$。

答案为 $f[k][m][n]$,翻译自递归入口 $\textit{dfs}(k-1,m,n)$。其中 $k$ 为 $\textit{strs}$ 的长度。

class Solution:
    def findMaxForm(self, strs: List[str], m: int, n: int) -> int:
        f = [[[0] * (n + 1) for _ in range(m + 1)] for _ in range(len(strs) + 1)]
        for i, s in enumerate(strs):
            cnt0 = s.count('0')
            cnt1 = len(s) - cnt0
            for j in range(m + 1):
                for k in range(n + 1):
                    if j >= cnt0 and k >= cnt1:
                        f[i + 1][j][k] = max(f[i][j][k], f[i][j - cnt0][k - cnt1] + 1)
                    else:
                        f[i + 1][j][k] = f[i][j][k]
        return f[-1][m][n]
class Solution {
    public int findMaxForm(String[] strs, int m, int n) {
        int[][][] f = new int[strs.length + 1][m + 1][n + 1];
        for (int i = 0; i < strs.length; i++) {
            int cnt0 = (int) strs[i].chars().filter(ch -> ch == '0').count();
            int cnt1 = strs[i].length() - cnt0;
            for (int j = 0; j <= m; j++) {
                for (int k = 0; k <= n; k++) {
                    if (j >= cnt0 && k >= cnt1) {
                        f[i + 1][j][k] = Math.max(f[i][j][k], f[i][j - cnt0][k - cnt1] + 1);
                    } else {
                        f[i + 1][j][k] = f[i][j][k];
                    }
                }
            }
        }
        return f[strs.length][m][n];
    }
}
class Solution {
public:
    int findMaxForm(vector<string>& strs, int m, int n) {
        vector f(strs.size() + 1, vector(m + 1, vector<int>(n + 1)));
        for (int i = 0; i < strs.size(); i++) {
            int cnt0 = ranges::count(strs[i], '0');
            int cnt1 = strs[i].size() - cnt0;
            for (int j = 0; j <= m; j++) {
                for (int k = 0; k <= n; k++) {
                    if (j >= cnt0 && k >= cnt1) {
                        f[i + 1][j][k] = max(f[i][j][k], f[i][j - cnt0][k - cnt1] + 1);
                    } else {
                        f[i + 1][j][k] = f[i][j][k];
                    }
                }
            }
        }
        return f.back()[m][n];
    }
};
func findMaxForm(strs []string, m, n int) int {
    k := len(strs)
    f := make([][][]int, k+1)
    for i := range f {
        f[i] = make([][]int, m+1)
        for j := range f[i] {
            f[i][j] = make([]int, n+1)
        }
    }
    for i, s := range strs {
        cnt0 := strings.Count(s, "0")
        cnt1 := len(s) - cnt0
        for j := range m + 1 {
            for k := range n + 1 {
                if j >= cnt0 && k >= cnt1 {
                    f[i+1][j][k] = max(f[i][j][k], f[i][j-cnt0][k-cnt1]+1)
                } else {
                    f[i+1][j][k] = f[i][j][k]
                }
            }
        }
    }
    return f[k][m][n]
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(kmn+L)$,其中 $k$ 为 $\textit{strs}$ 的长度,$L$ 为 $\textit{strs}$ 中所有字符串的长度之和。
  • 空间复杂度:$\mathcal{O}(kmn)$。

三、空间优化

观察上面的状态转移方程,在计算 $f[i+1]$ 时,只会用到 $f[i]$,不会用到比 $i$ 更早的状态。

那么去掉第一个维度,把 $f[i+1]$ 和 $f[i]$ 保存到同一个二维数组中。

状态转移方程改为

$$
f[j][k] = \max(f[j][k], f[j-\textit{cnt}_0[i]][k-\textit{cnt}_1[i]] + 1)
$$

初始值 $f[j][k]=0$。

答案为 $f[m][n]$。

下面代码为什么要倒序循环,请看【基础算法精讲 18】

class Solution:
    def findMaxForm(self, strs: List[str], m: int, n: int) -> int:
        f = [[0] * (n + 1) for _ in range(m + 1)]
        for s in strs:
            cnt0 = s.count('0')
            cnt1 = len(s) - cnt0
            for j in range(m, cnt0 - 1, -1):
                for k in range(n, cnt1 - 1, -1):
                    f[j][k] = max(f[j][k], f[j - cnt0][k - cnt1] + 1)
        return f[m][n]
class Solution {
    public int findMaxForm(String[] strs, int m, int n) {
        int[][] f = new int[m + 1][n + 1];
        for (String s : strs) {
            int cnt0 = (int) s.chars().filter(ch -> ch == '0').count();
            int cnt1 = s.length() - cnt0;
            for (int j = m; j >= cnt0; j--) {
                for (int k = n; k >= cnt1; k--) {
                    f[j][k] = Math.max(f[j][k], f[j - cnt0][k - cnt1] + 1);
                }
            }
        }
        return f[m][n];
    }
}
class Solution {
public:
    int findMaxForm(vector<string>& strs, int m, int n) {
        vector f(m + 1, vector<int>(n + 1));
        for (string& s : strs) {
            int cnt0 = ranges::count(s, '0');
            int cnt1 = s.size() - cnt0;
            for (int j = m; j >= cnt0; j--) {
                for (int k = n; k >= cnt1; k--) {
                    f[j][k] = max(f[j][k], f[j - cnt0][k - cnt1] + 1);
                }
            }
        }
        return f[m][n];
    }
};
func findMaxForm(strs []string, m, n int) int {
    f := make([][]int, m+1)
    for i := range f {
        f[i] = make([]int, n+1)
    }
    for _, s := range strs {
        cnt0 := strings.Count(s, "0")
        cnt1 := len(s) - cnt0
        for j := m; j >= cnt0; j-- {
            for k := n; k >= cnt1; k-- {
                f[j][k] = max(f[j][k], f[j-cnt0][k-cnt1]+1)
            }
        }
    }
    return f[m][n]
}

进一步优化

比如 $n=m=90$,前 $3$ 个字符串总共有 $5$ 个 $0$ 和 $6$ 个 $1$,那么无论我们怎么选,也选不出几十个 $0$ 和 $1$,所以上面的代码中,其实有大量的循环是多余的。

为此,额外用两个变量 $\textit{sum}_0$ 和 $\textit{sum}_1$ 分别维护前 $i$ 个字符串中的 $0$ 的个数和 $1$ 的个数(但不能超过 $m$ 和 $n$)。循环的时候 $j$ 从 $\textit{sum}_0$ 开始,$k$ 从 $\textit{sum}_1$ 开始。

注意这个优化会导致只有一部分 $f[j][k]$ 被更新到,最大值并没有传递给 $f[m][n]$,可能留在二维数组中间的某个位置上。所以最后要遍历 $f$,取其中最大值作为答案。

class Solution:
    def findMaxForm(self, strs: List[str], m: int, n: int) -> int:
        f = [[0] * (n + 1) for _ in range(m + 1)]
        sum0 = sum1 = 0
        for s in strs:
            cnt0 = s.count('0')
            cnt1 = len(s) - cnt0
            sum0 = min(sum0 + cnt0, m)
            sum1 = min(sum1 + cnt1, n)
            for j in range(sum0, cnt0 - 1, -1):
                for k in range(sum1, cnt1 - 1, -1):
                    v = f[j - cnt0][k - cnt1] + 1
                    if v > f[j][k]:  # 手写 max,效率更高
                        f[j][k] = v
        return max(map(max, f))
class Solution {
    public int findMaxForm(String[] strs, int m, int n) {
        int[][] f = new int[m + 1][n + 1];
        int sum0 = 0;
        int sum1 = 0;
        for (String s : strs) {
            int cnt0 = (int) s.chars().filter(ch -> ch == '0').count();
            int cnt1 = s.length() - cnt0;
            sum0 = Math.min(sum0 + cnt0, m);
            sum1 = Math.min(sum1 + cnt1, n);
            for (int j = sum0; j >= cnt0; j--) {
                for (int k = sum1; k >= cnt1; k--) {
                    f[j][k] = Math.max(f[j][k], f[j - cnt0][k - cnt1] + 1);
                }
            }
        }
        int ans = 0;
        for (int[] row : f) {
            for (int v : row) {
                ans = Math.max(ans, v);
            }
        }
        return ans;
    }
}
class Solution {
public:
    int findMaxForm(vector<string>& strs, int m, int n) {
        vector f(m + 1, vector<int>(n + 1));
        int sum0 = 0, sum1 = 0;
        for (string& s : strs) {
            int cnt0 = ranges::count(s, '0');
            int cnt1 = s.size() - cnt0;
            sum0 = min(sum0 + cnt0, m);
            sum1 = min(sum1 + cnt1, n);
            for (int j = sum0; j >= cnt0; j--) {
                for (int k = sum1; k >= cnt1; k--) {
                    f[j][k] = max(f[j][k], f[j - cnt0][k - cnt1] + 1);
                }
            }
        }
        int ans = 0;
        for (auto& row : f) {
            ans = max(ans, ranges::max(row));
        }
        return ans;
    }
};
func findMaxForm(strs []string, m, n int) (ans int) {
    f := make([][]int, m+1)
    for i := range f {
        f[i] = make([]int, n+1)
    }
    sum0, sum1 := 0, 0
    for _, s := range strs {
        cnt0 := strings.Count(s, "0")
        cnt1 := len(s) - cnt0
        sum0 = min(sum0+cnt0, m)
        sum1 = min(sum1+cnt1, n)
        for j := sum0; j >= cnt0; j-- {
            for k := sum1; k >= cnt1; k-- {
                f[j][k] = max(f[j][k], f[j-cnt0][k-cnt1]+1)
            }
        }
    }
    for _, row := range f {
        ans = max(ans, slices.Max(row))
    }
    return
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(kmn+L)$,其中 $k$ 为 $\textit{strs}$ 的长度,$L$ 为 $\textit{strs}$ 中所有字符串的长度之和。
  • 空间复杂度:$\mathcal{O}(mn)$。

更多相似题目,见 动态规划题单 中的「§3.1 0-1 背包」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/最短路/最小生成树/二分图/基环树/欧拉路径)
  7. 【本题相关】动态规划(入门/背包/状态机/划分/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

【宫水三叶】详解如何转换「背包问题」,以及逐步空间优化

(多维)01 背包

通常与「背包问题」相关的题考察的是 将原问题转换为「背包问题」的能力

要将原问题转换为「背包问题」,往往需要从题目中抽象出「价值」与「成本」的概念。

这道题如果抽象成「背包问题」的话,应该是:

每个字符串的价值都是 1(对答案的贡献都是 1),选择的成本是该字符串中 1 的数量和 0 的数量。

问我们在 1 的数量不超过 $m$,0 的数量不超过 $n$ 的条件下,最大价值是多少。

由于每个字符串只能被选一次,且每个字符串的选与否对应了「价值」和「成本」,求解的问题也是「最大价值」是多少。

因此可以直接套用 01 背包的「状态定义」来做:

$f[k][i][j]$ 代表考虑前 k 件物品,在数字 1 容量不超过 $i$,数字 0 容量不超过 $j$ 的条件下的「最大价值」(每个字符串的价值均为 1)。

有了「状态定义」之后,「转移方程」也很好推导:

$$f[k][i][j] = \max(f[k - 1][i][j], f[k - 1][i - cnt[k][0]][j - cnt[k][1]] + 1)$$

其中 $cnt$ 数组记录的是字符串中出现的 $01$ 数量。

代码(为了方便理解,$P1$ 将第一件物品的处理单独抽了出来,也可以不抽出来,只需要将让物品下标从 $1$ 开始即可,见 $P2$):

###Java

class Solution {
    public int findMaxForm(String[] strs, int m, int n) {
        int len = strs.length;
        // 预处理每一个字符包含 0 和 1 的数量
        int[][] cnt = new int[len][2];
        for (int i = 0; i < len; i++) {
            String str = strs[i];
            int zero = 0, one = 0;
            for (char c : str.toCharArray()) {
                if (c == '0') {
                    zero++;
                } else {
                    one++;
                }
            }
            cnt[i] = new int[]{zero, one}; 
        }

        // 处理只考虑第一件物品的情况
        int[][][] f = new int[len][m + 1][n + 1];
        for (int i = 0; i <= m; i++) {
            for (int j = 0; j <= n; j++) {
                f[0][i][j] = (i >= cnt[0][0] && j >= cnt[0][1]) ? 1 : 0;
            }
        }

        // 处理考虑其余物品的情况
        for (int k = 1; k < len; k++) {
            int zero = cnt[k][0], one = cnt[k][1];
            for (int i = 0; i <= m; i++) {
                for (int j = 0; j <= n; j++) {
                    // 不选择第 k 件物品
                    int a = f[k-1][i][j];
                    // 选择第 k 件物品(前提是有足够的 m 和 n 额度可使用)
                    int b = (i >= zero && j >= one) ? f[k-1][i-zero][j-one] + 1 : 0;
                    f[k][i][j] = Math.max(a, b);
                }
            }
        }
        return f[len-1][m][n];
    }
}

###Java

class Solution {
    public int findMaxForm(String[] strs, int m, int n) {
        int len = strs.length;
        int[][] cnt = new int[len][2];
        for (int i = 0; i < len; i++) {
            String str = strs[i];
            int zero = 0, one = 0;
            for (char c : str.toCharArray()) {
                if (c == '0') zero++;    
                else one++;

            }
            cnt[i] = new int[]{zero, one}; 
        }
        int[][][] f = new int[len + 1][m + 1][n + 1];
        for (int k = 1; k <= len; k++) {
            int zero = cnt[k - 1][0], one = cnt[k - 1][1];
            for (int i = 0; i <= m; i++) {
                for (int j = 0; j <= n; j++) {
                    int a = f[k - 1][i][j];
                    int b = (i >= zero && j >= one) ? f[k - 1][i - zero][j - one] + 1 : 0;
                    f[k][i][j] = Math.max(a, b);
                }
            }
        }
        return f[len][m][n];
    }
}
  • 时间复杂度:预处理字符串的复杂度为 $O(\sum_{i = 0}^{k - 1}len(strs[i]))$,处理状态转移的 $O(k * m * n)$。整体复杂度为:$O(k * m * n + \sum_{i = 0}^{k - 1}len(strs[i]))$
  • 空间复杂度:$O(k * m * n)$

滚动数组

根据「状态转移」可知,更新某个物品的状态时,只依赖于上一个物品的状态。

因此,可以使用「滚动数组」的方式进行空间优化。

代码(为了方便理解,$P1$ 将第一件物品的处理单独抽了出来,也可以不抽出来,只需要将让物品下标从 $1$ 开始即可,见 $P2$):

###Java

class Solution {
    public int findMaxForm(String[] strs, int m, int n) {
        int len = strs.length;
        // 预处理每一个字符包含 0 和 1 的数量
        int[][] cnt = new int[len][2];
        for (int i = 0; i < len; i++) {
            String str = strs[i];
            int zero = 0, one = 0;
            for (char c : str.toCharArray()) {
                if (c == '0') {
                    zero++;
                } else {
                    one++;
                }
            }
            cnt[i] = new int[]{zero, one}; 
        }

        // 处理只考虑第一件物品的情况
        // 「物品维度」修改为 2 
        int[][][] f = new int[2][m + 1][n + 1];
        for (int i = 0; i <= m; i++) {
            for (int j = 0; j <= n; j++) {
                f[0][i][j] = (i >= cnt[0][0] && j >= cnt[0][1]) ? 1 : 0;
            }
        }

        // 处理考虑其余物品的情况
        for (int k = 1; k < len; k++) {
            int zero = cnt[k][0], one = cnt[k][1];
            for (int i = 0; i <= m; i++) {
                for (int j = 0; j <= n; j++) {
                    // 不选择第 k 件物品
                    // 将 k-1 修改为 (k-1)&1
                    int a = f[(k-1)&1][i][j];
                    // 选择第 k 件物品(前提是有足够的 m 和 n 额度可使用)
                    // 将 k-1 修改为 (k-1)&1
                    int b = (i >= zero && j >= one) ? f[(k-1)&1][i-zero][j-one] + 1 : 0;
                    f[k&1][i][j] = Math.max(a, b);
                }
            }
        }
        // 将 len-1 修改为 (len-1)&1
        return f[(len-1)&1][m][n];
    }
}

###Java

class Solution {
    public int findMaxForm(String[] strs, int m, int n) {
        int len = strs.length;
        int[][] cnt = new int[len][2];
        for (int i = 0; i < len; i++) {
            String str = strs[i];
            int zero = 0, one = 0;
            for (char c : str.toCharArray()) {
                if (c == '0') zero++;
                else one++; 
            }
            cnt[i] = new int[]{zero, one}; 
        }
        int[][][] f = new int[2][m + 1][n + 1];
        for (int k = 1; k <= len; k++) {
            int zero = cnt[k - 1][0], one = cnt[k - 1][1];
            for (int i = 0; i <= m; i++) {
                for (int j = 0; j <= n; j++) {
                    int a = f[(k-1) & 1][i][j];
                    int b = (i >= zero && j >= one) ? f[(k-1) & 1][i - zero][j - one] + 1 : 0;
                    f[k&1][i][j] = Math.max(a, b);
                }
            }
        }
        return f[len&1][m][n];
    }
}
  • 时间复杂度:预处理字符串的复杂度为 $O(\sum_{i = 0}^{k - 1}len(strs[i]))$,处理状态转移的 $O(k * m * n)$。整体复杂度为:$O(k * m * n + \sum_{i = 0}^{k - 1}len(strs[i]))$
  • 空间复杂度:$O(m * n)$

一维空间优化

事实上,我们还能继续进行空间优化。

再次观察我们的「状态转移方程」发现:$f[k][i][j]$ 不仅仅依赖于上一行,还明确依赖于比 $i$ 小和比 $j$ 小的状态。

即可只依赖于「上一行」中「正上方」的格子,和「正上方左边」的格子。

对应到「朴素的 01 背包问题」依赖关系如图:

image.png

因此可直接参考「01 背包的空间优化」方式:取消掉「物品维度」,然后调整容量的遍历顺序。

代码:

###Java

class Solution {
    public int findMaxForm(String[] strs, int m, int n) {
        int len = strs.length;
        int[][] cnt = new int[len][2];
        for (int i = 0; i < len; i++) {
            int zero = 0, one = 0;
            for (char c : strs[i].toCharArray()) {
                if (c == '0') zero++;
                else one++;
            }
            cnt[i] = new int[]{zero, one};
        }
        int[][] f = new int[m + 1][n + 1];
        for (int k = 0; k < len; k++) {
            int zero = cnt[k][0], one = cnt[k][1];
            for (int i = m; i >= zero; i--) {
                for (int j = n; j >= one; j--) {
                    f[i][j] = Math.max(f[i][j], f[i - zero][j - one] + 1);
                }
            }
        }
        return f[m][n];
    }
}
  • 时间复杂度:预处理字符串的复杂度为 $O(\sum_{i = 0}^{k - 1}len(strs[i]))$,处理状态转移的 $O(k * m * n)$。整体复杂度为:$O(k * m * n + \sum_{i = 0}^{k - 1}len(strs[i]))$
  • 空间复杂度:$O(m * n)$

其他「背包」问题

看不懂「背包」解决方案?

以下是公主号讲过的「背包专题」系列目录,欢迎 关注 🍭🍭🍭 :

  1. 01背包 : 背包问题 第一讲

    1. 【练习】01背包 : 背包问题 第二讲

    2. 【学习&练习】01背包 : 背包问题 第三讲

    3. 【加餐/补充】01 背包:背包问题 第二十一讲

  2. 完全背包 : 背包问题 第四讲

    1. 【练习】完全背包 : 背包问题 第五讲

    2. 【练习】完全背包 : 背包问题 第六讲

    3. 【练习】完全背包 : 背包问题 第七讲

  3. 多重背包 : 背包问题 第八讲

  4. 多重背包(优化篇)

    1. 【上】多重背包(优化篇): 背包问题 第九讲

    2. 【下】多重背包(优化篇): 背包问题 第十讲

  5. 混合背包 : 背包问题 第十一讲

  6. 分组背包 : 背包问题 第十二讲

    1. 【练习】分组背包 : 背包问题 第十三讲
  7. 多维背包

    1. 【练习】多维背包 : 背包问题 第十四讲

    2. 【练习】多维背包 : 背包问题 第十五讲

  8. 树形背包 : 背包问题 第十六讲

    1. 【练习篇】树形背包 : 背包问题 第十七讲

    2. 【练习篇】树形背包 : 背包问题 第十八讲

  9. 背包求方案数

    1. 【练习】背包求方案数 : 背包问题 第十九讲

    2. 【练习】背包求方案数 : 背包问题 第十五讲

    [注:因为之前实在找不到题,这道「求方案数」题作为“特殊”的「多维费用背包问题求方案数」讲过]

  10. 背包求具体方案

    1. 【练习】背包求具体方案 : 背包问题 第二十讲
  11. 泛化背包

    1. 【练习】泛化背包

最后

如果有帮助到你,请给题解点个赞和收藏,让更多的人看到 ~ ("▔□▔)/

也欢迎你 关注我(公主号后台回复「送书」即可参与长期看题解学算法送实体书活动)或 加入「组队打卡」小群 ,提供写「证明」&「思路」的高质量题解。

所有题解已经加入 刷题指南,欢迎 star 哦 ~

动态规划(转换为 0-1 背包问题)

思路:把总共的 01 的个数视为背包的容量,每一个字符串视为装进背包的物品。这道题就可以使用 0-1 背包问题的思路完成,这里的目标值是能放进背包的字符串的数量。


动态规划的思路是:物品一个一个尝试,容量一点一点尝试,每个物品分类讨论的标准是:选与不选。

定义状态:尝试题目问啥,就把啥定义成状态。dp[i][j][k] 表示输入字符串在子区间 [0, i] 能够使用 j0k1 的字符串的最大数量。
状态转移方程
$$dp[i][j][k]=
\begin{cases}
dp[i - 1][j][k], & 不选择当前考虑的字符串,至少是这个数值\
dp[i - 1][j - 当前字符串使用 ;0; 的个数][k - 当前字符串使用 ;1; 的个数] + 1 & 选择当前考虑的字符串
\end{cases}
$$
初始化:为了避免分类讨论,通常多设置一行。这里可以认为,第 $0$ 个字符串是空串。第 $0$ 行默认初始化为 $0$。
输出:输出是最后一个状态,即:dp[len][m][n]

参考代码1

###Java

public class Solution {

    public int findMaxForm(String[] strs, int m, int n) {
        int len = strs.length;
        int[][][] dp = new int[len + 1][m + 1][n + 1];

        for (int i = 1; i <= len; i++) {
            // 注意:有一位偏移
            int[] count = countZeroAndOne(strs[i - 1]);
            for (int j = 0; j <= m; j++) {
                for (int k = 0; k <= n; k++) {
                    // 先把上一行抄下来
                    dp[i][j][k] = dp[i - 1][j][k];
                    int zeros = count[0];
                    int ones = count[1];
                    if (j >= zeros && k >= ones) {
                        dp[i][j][k] = Math.max(dp[i - 1][j][k], dp[i - 1][j - zeros][k - ones] + 1);
                    }
                }
            }
        }
        return dp[len][m][n];
    }

    private int[] countZeroAndOne(String str) {
        int[] cnt = new int[2];
        for (char c : str.toCharArray()) {
            cnt[c - '0']++;
        }
        return cnt;
    }
}

第 5 步:思考优化空间

因为当前行只参考了上一行的值,因此可以使用「滚动数组」,也可以「从后向前赋值」。

参考代码2:这里选用「从后向前赋值」

###Java

public class Solution {

    public int findMaxForm(String[] strs, int m, int n) {
        int[][] dp = new int[m + 1][n + 1];
        dp[0][0] = 0;
        for (String s : strs) {
            int[] zeroAndOne = calcZeroAndOne(s);
            int zeros = zeroAndOne[0];
            int ones = zeroAndOne[1];
            for (int i = m; i >= zeros; i--) {
                for (int j = n; j >= ones; j--) {
                    dp[i][j] = Math.max(dp[i][j], dp[i - zeros][j - ones] + 1);
                }
            }
        }
        return dp[m][n];
    }

    private int[] calcZeroAndOne(String str) {
        int[] res = new int[2];
        for (char c : str.toCharArray()) {
            res[c - '0']++;
        }
        return res;
    }
}

每日一题-将所有元素变为 0 的最少操作次数🟡

给你一个大小为 n非负 整数数组 nums 。你的任务是对该数组执行若干次(可能为 0 次)操作,使得 所有 元素都变为 0。

在一次操作中,你可以选择一个子数组 [i, j](其中 0 <= i <= j < n),将该子数组中所有 最小的非负整数 的设为 0。

返回使整个数组变为 0 所需的最少操作次数。

一个 子数组 是数组中的一段连续元素。

 

示例 1:

输入: nums = [0,2]

输出: 1

解释:

  • 选择子数组 [1,1](即 [2]),其中最小的非负整数是 2。将所有 2 设为 0,结果为 [0,0]
  • 因此,所需的最少操作次数为 1。

示例 2:

输入: nums = [3,1,2,1]

输出: 3

解释:

  • 选择子数组 [1,3](即 [1,2,1]),最小非负整数是 1。将所有 1 设为 0,结果为 [3,0,2,0]
  • 选择子数组 [2,2](即 [2]),将 2 设为 0,结果为 [3,0,0,0]
  • 选择子数组 [0,0](即 [3]),将 3 设为 0,结果为 [0,0,0,0]
  • 因此,最少操作次数为 3。

示例 3:

输入: nums = [1,2,1,2,1,2]

输出: 4

解释:

  • 选择子数组 [0,5](即 [1,2,1,2,1,2]),最小非负整数是 1。将所有 1 设为 0,结果为 [0,2,0,2,0,2]
  • 选择子数组 [1,1](即 [2]),将 2 设为 0,结果为 [0,0,0,2,0,2]
  • 选择子数组 [3,3](即 [2]),将 2 设为 0,结果为 [0,0,0,0,0,2]
  • 选择子数组 [5,5](即 [2]),将 2 设为 0,结果为 [0,0,0,0,0,0]
  • 因此,最少操作次数为 4。

 

提示:

  • 1 <= n == nums.length <= 105
  • 0 <= nums[i] <= 105

从分治到单调栈,简洁写法(Python/Java/C++/Go)

分治

回顾示例 3 $\textit{nums}=[1,2,1,2,1,2]$ 的操作过程:

  • 首先,只需要一次操作(选择整个数组),就可以把所有的最小值 $1$ 都变成 $0$。现在数组是 $[0,2,0,2,0,2]$。
  • 这些被 $0$ 分割开的 $2$,无法合在一起操作(因为子数组会包含 $0$,导致 $2$ 无法变成 $0$),只能一个一个操作。

一般地:

  1. 先通过一次操作,把 $\textit{nums}$ 的最小值都变成 $0$(如果最小值已经是 $0$ 则跳过这步)。
  2. 此时 $\textit{nums}$ 被这些 $0$ 划分成了若干段,后续操作只能在每段内部,不能跨段操作(否则子数组会包含 $0$)。每一段是规模更小的子问题,可以用第一步的方法解决。这样我们可以写一个递归去处理。递归边界:如果操作后全为 $0$,直接返回。

找最小值可以用 ST 表或者线段树,但这种做法很麻烦。有没有简单的做法呢?

单调栈

从左往右遍历数组,只在「必须要操作」的时候,才把答案加一。

什么时候必须要操作?

示例 3 $\textit{nums}=[1,2,1,2,1,2]$,因为 $2$ 左右两侧都有小于 $2$ 的数,需要单独操作。

又例如 $\textit{nums}=[1,2,3,2,1]$:

  • 遍历到第二个 $2$ 时,可以知道 $3$ 左右两侧都有小于 $3$ 的数,所以 $3$ 必须要操作一次,答案加一。注意这不表示第一次操作的是 $3$,而是某次操作会把 $3$ 变成 $0$。
  • 遍历到末尾 $1$ 时,可以知道中间的两个 $2$,左边有 $1$,右边也有 $1$,必须操作一次,答案加一。比如选择 $[2,3,2]$ 可以把这两个 $2$ 都变成 $0$。
  • 最后,数组中的 $1$ 需要操作一次都变成 $0$。

我们怎么知道「$3$ 左右两侧都有小于 $3$ 的数」?

遍历数组的同时,把遍历过的元素用栈记录:

  • 如果当前元素比栈顶大(或者栈为空),那么直接入栈。
  • 如果当前元素比栈顶小,那么对于栈顶来说,左边(栈顶倒数第二个数)比栈顶小(原因后面解释),右边(当前元素)也比栈顶小,所以栈顶必须操作一次。然后弹出栈顶。
  • 如果当前元素等于栈顶,可以在同一次操作中把当前元素与栈顶都变成 $0$,所以无需入栈。注意这保证了栈中没有重复元素。

如果当前元素比栈顶小,就弹出栈顶,我们会得到一个底小顶大的单调栈,这就保证了「对于栈顶来说,左边(栈顶倒数第二个数)比栈顶小」。

遍历结束后,因为栈是严格递增的,所以栈中每个非零数字都需要操作一次。

代码实现时,可以直接把 $\textit{nums}$ 当作栈。

具体请看 视频讲解,欢迎点赞关注~

###py

class Solution:
    def minOperations(self, nums: List[int]) -> int:
        ans = 0
        st = []
        for x in nums:
            while st and x < st[-1]:
                st.pop()
                ans += 1
            # 如果 x 与栈顶相同,那么 x 与栈顶可以在同一次操作中都变成 0,x 无需入栈
            if not st or x != st[-1]:
                st.append(x)
        return ans + len(st) - (st[0] == 0)  # 0 不需要操作

###py

class Solution:
    def minOperations(self, nums: List[int]) -> int:
        ans = 0
        top = -1  # 栈顶下标(把 nums 当作栈)
        for x in nums:
            while top >= 0 and x < nums[top]:
                top -= 1  # 出栈
                ans += 1
            # 如果 x 与栈顶相同,那么 x 与栈顶可以在同一次操作中都变成 0,x 无需入栈
            if top < 0 or x != nums[top]:
                top += 1
                nums[top] = x  # 入栈
        return ans + top + (nums[0] > 0)

###java

class Solution {
    public int minOperations(int[] nums) {
        int ans = 0;
        int top = -1; // 栈顶下标(把 nums 当作栈)
        for (int x : nums) {
            while (top >= 0 && x < nums[top]) {
                top--; // 出栈
                ans++;
            }
            // 如果 x 与栈顶相同,那么 x 与栈顶可以在同一次操作中都变成 0,x 无需入栈
            if (top < 0 || x != nums[top]) {
                nums[++top] = x; // 入栈
            }
        }
        return ans + top + (nums[0] > 0 ? 1 : 0);
    }
}

###cpp

class Solution {
public:
    int minOperations(vector<int>& nums) {
        int ans = 0;
        int top = -1; // 栈顶下标(把 nums 当作栈)
        for (int x : nums) {
            while (top >= 0 && x < nums[top]) {
                top--; // 出栈
                ans++;
            }
            // 如果 x 与栈顶相同,那么 x 与栈顶可以在同一次操作中都变成 0,x 无需入栈
            if (top < 0 || x != nums[top]) {
                nums[++top] = x; // 入栈
            }
        }
        return ans + top + (nums[0] > 0);
    }
};

###go

func minOperations(nums []int) (ans int) {
st := nums[:0] // 原地
for _, x := range nums {
for len(st) > 0 && x < st[len(st)-1] {
st = st[:len(st)-1]
ans++
}
// 如果 x 与栈顶相同,那么 x 与栈顶可以在同一次操作中都变成 0,x 无需入栈
if len(st) == 0 || x != st[len(st)-1] {
st = append(st, x)
}
}
if st[0] == 0 { // 0 不需要操作
ans--
}
return ans + len(st)
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。每个元素至多入栈出栈各一次,所以二重循环的循环次数是 $\mathcal{O}(n)$。
  • 空间复杂度:$\mathcal{O}(n)$ 或 $\mathcal{O}(1)$。原地做法可以做到 $\mathcal{O}(1)$ 空间。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

贪心 & 单调栈

解法:贪心 & 单调栈

首先,由于每次操作只能把一种元素变成 $0$,因此答案至少是元素种数。那为什么答案不一定等于元素种数呢?考虑题目中的样例 nums = [1, 2, 1, 2, 1, 2],当我们把所有 $1$ 都变成 $0$ 后,nums = [0, 2, 0, 2, 0, 2],这时候的三个 $2$ 被更小的 $0$ 隔开了,无法在同一次操作内处理掉。

也就是说,我们每次只考虑一种元素。如果两元素之间有其它更小的元素,那么这两个元素无法在同一次操作内处理掉,答案加 $1$。

怎么知道两个元素之间有没有更小的元素呢?我们可以用单调栈求出每个元素右边最近的更小元素。假设下标 $i$ 右边最近的更小元素下标为 $f_i$,那么考虑下标 $i$ 和 $j$ 之间有没有更小元素时,只要检查是否 $f_i < j$ 即可。不熟悉单调栈的读者可以学习 leetcode 496. 下一个更大元素 I 的单调栈解法。复杂度 $\mathcal{O}(n)$。

参考代码(c++)

class Solution {
public:
    int minOperations(vector<int>& nums) {
        int n = nums.size();
        
        // f[i]:下标 i 右边最近的更小元素下标
        int f[n];
        for (int i = 0; i < n; i++) f[i] = n;
        // 用单调栈求
        stack<int> stk;
        for (int i = 0; i < n; i++) {
            while (!stk.empty() && nums[stk.top()] > nums[i]) {
                f[stk.top()] = i;
                stk.pop();
            }
            stk.push(i);
        }
        
        // 把下标按值分类,每次只考虑一种元素
        unordered_map<int, vector<int>> mp;
        for (int i = 0; i < n; i++) mp[nums[i]].push_back(i);
        int ans = 0;
        for (auto &p : mp) {
            auto &vec = p.second;
            // 除了 0 以外,每种元素至少一次操作
            if (p.first > 0) ans++;
            // 两元素之间有更小元素,无法在一次操作内处理,答案再加 1
            for (int i = 0; i + 1 < vec.size(); i++) if (f[vec[i]] < vec[i + 1]) ans++;
        }
        return ans;
    }
};
❌