阅读视图

发现新文章,点击刷新页面。

最早完成陆地和水上游乐设施的时间 II

方法一:分类讨论

思路与算法

可以先玩陆地项目再玩水上项目,也可以反过来。要找到最早的完成时间,我们需要分别计算这两种顺序下的最优结果,然后取其中的最小值。以“先陆地、后水上”为例,计算逻辑如下:

  • 对于陆地类别的所有项目,分别计算它们的“最早开始时间 + 持续时间”,找到其中的最小值。
  • 准备玩第二个项目时,会遇到两种情况:
    • 如果水上项目已经开放了,你可以立即开始,完成时刻就是“第一个项目的结束时间 + 水上项目的持续时间”。
    • 如果水上项目还没开放,你必须等到它的最早开始时间才能动工,完成时刻就是“水上项目的最早开始时间 + 水上项目的持续时间”。

总结一下,对于固定的第二类项目,最终完成时间为:

$$
\max(\textit{finish1}, \textit{start2}) + \textit{duration2}
$$

其中 $\textit{finish1}$ 表示第一类项目的结束时间,$\textit{start2}$ 表示第二类项目的开始时间。由于该表达式随着 $\textit{finish1}$ 的增大单调不减,因此为了使最终完成时间最小,我们只需要保留第一类项目中的最早结束时间即可。

在陆地项目结束最早的前提下,遍历所有的水上项目,并找到最早结束时间。

在陆地项目结束最早的前提下,遍历所有的水上项目,并找到最早结束时间。最后,交换顺序,按照同样的方法计算“先水上、后陆地”的最早完成时间。比较这两种顺序得到的结果,返回数值较小作为最终答案。

代码

###C++

class Solution {
    int solve(vector<int>& start1, vector<int>& duration1, vector<int>& start2, vector<int>& duration2) {
        int finish1 = INT_MAX;
        for (int i = 0; i < start1.size(); i++) {
            finish1 = min(finish1, start1[i] + duration1[i]);
        }

        int finish2 = INT_MAX;
        for (int i = 0; i < start2.size(); i++) {
            finish2 = min(finish2, max(start2[i], finish1) + duration2[i]);
        }
        return finish2;
    }

public:
    int earliestFinishTime(vector<int>& landStartTime, vector<int>& landDuration, vector<int>& waterStartTime, vector<int>& waterDuration) {
        int land_water = solve(landStartTime, landDuration, waterStartTime, waterDuration);
        int water_land = solve(waterStartTime, waterDuration, landStartTime, landDuration);
        return min(land_water, water_land);
    }
};

###Java

class Solution {
    private int solve(int[] start1, int[] duration1, int[] start2, int[] duration2) {
        int finish1 = Integer.MAX_VALUE;
        for (int i = 0; i < start1.length; i++) {
            finish1 = Math.min(finish1, start1[i] + duration1[i]);
        }
        int finish2 = Integer.MAX_VALUE;
        for (int i = 0; i < start2.length; i++) {
            finish2 = Math.min(finish2, Math.max(start2[i], finish1) + duration2[i]);
        }
        return finish2;
    }

    public int earliestFinishTime(int[] landStartTime, int[] landDuration, int[] waterStartTime, int[] waterDuration) {
        int land_water = solve(landStartTime, landDuration, waterStartTime, waterDuration);
        int water_land = solve(waterStartTime, waterDuration, landStartTime, landDuration);
        return Math.min(land_water, water_land);
    }
}

###Python

class Solution:
    def earliestFinishTime(self, landStartTime: List[int], landDuration: List[int], waterStartTime: List[int], waterDuration: List[int]) -> int:
        def solve(start1, duration1, start2, duration2):
            finish1 = inf
            for i in range(len(start1)):
                finish1 = min(finish1, start1[i] + duration1[i])
            finish2 = inf
            for i in range(len(start2)):
                finish2 = min(finish2, max(start2[i], finish1) + duration2[i])
            return finish2

        land_water = solve(landStartTime, landDuration, waterStartTime, waterDuration)
        water_land = solve(waterStartTime, waterDuration, landStartTime, landDuration)
        return min(land_water, water_land)

###JavaScript

var earliestFinishTime = function(landStartTime, landDuration, waterStartTime, waterDuration) {
    function solve(start1, duration1, start2, duration2) {
        let finish1 = Infinity;
        for (let i = 0; i < start1.length; i++) {
            finish1 = Math.min(finish1, start1[i] + duration1[i]);
        }
        let finish2 = Infinity;
        for (let i = 0; i < start2.length; i++) {
            finish2 = Math.min(finish2, Math.max(start2[i], finish1) + duration2[i]);
        }
        return finish2;
    }

    let land_water = solve(landStartTime, landDuration, waterStartTime, waterDuration);
    let water_land = solve(waterStartTime, waterDuration, landStartTime, landDuration);
    return Math.min(land_water, water_land);
};

###TypeScript

function solve(start1, duration1, start2, duration2) {
    let finish1 = Infinity;
    for (let i = 0; i < start1.length; i++) {
        finish1 = Math.min(finish1, start1[i] + duration1[i]);
    }
    let finish2 = Infinity;
    for (let i = 0; i < start2.length; i++) {
        finish2 = Math.min(finish2, Math.max(start2[i], finish1) + duration2[i]);
    }
    return finish2;
}

function earliestFinishTime(landStartTime: number[], landDuration: number[], waterStartTime: number[], waterDuration: number[]): number {
    let land_water = solve(landStartTime, landDuration, waterStartTime, waterDuration);
    let water_land = solve(waterStartTime, waterDuration, landStartTime, landDuration);
    return Math.min(land_water, water_land);
};

###Go

func earliestFinishTime(landStartTime []int, landDuration []int, waterStartTime []int, waterDuration []int) int {
    solve := func(start1, duration1, start2, duration2 []int) int {
        finish1 := 2147483647
        for i := 0; i < len(start1); i++ {
            if val := start1[i] + duration1[i]; val < finish1 {
                finish1 = val
            }
        }
        finish2 := 2147483647
        for i := 0; i < len(start2); i++ {
            curStart := start2[i]
            if finish1 > curStart {
                curStart = finish1
            }
            if val := curStart + duration2[i]; val < finish2 {
                finish2 = val
            }
        }
        return finish2
    }

    land_water := solve(landStartTime, landDuration, waterStartTime, waterDuration)
    water_land := solve(waterStartTime, waterDuration, landStartTime, landDuration)
    if land_water < water_land {
        return land_water
    }
    return water_land
}

###C#

public class Solution {
    private int solve(int[] start1, int[] duration1, int[] start2, int[] duration2) {
        int finish1 = int.MaxValue;
        for (int i = 0; i < start1.Length; i++) {
            finish1 = Math.Min(finish1, start1[i] + duration1[i]);
        }
        int finish2 = int.MaxValue;
        for (int i = 0; i < start2.Length; i++) {
            finish2 = Math.Min(finish2, Math.Max(start2[i], finish1) + duration2[i]);
        }
        return finish2;
    }

    public int EarliestFinishTime(int[] landStartTime, int[] landDuration, int[] waterStartTime, int[] waterDuration) {
        int land_water = solve(landStartTime, landDuration, waterStartTime, waterDuration);
        int water_land = solve(waterStartTime, waterDuration, landStartTime, landDuration);
        return Math.Min(land_water, water_land);
    }
}

###C

#define min(a, b) ((a) < (b) ? (a) : (b))
#define max(a, b) ((a) > (b) ? (a) : (b))

int solve(int* start1, int start1Size, int* duration1, int* start2, int start2Size, int* duration2) {
    int finish1 = INT_MAX;
    for (int i = 0; i < start1Size; i++) {
        finish1 = min(finish1, start1[i] + duration1[i]);
    }
    int finish2 = INT_MAX;
    for (int i = 0; i < start2Size; i++) {
        finish2 = min(finish2, max(start2[i], finish1) + duration2[i]);
    }
    return finish2;
}

int earliestFinishTime(int* landStartTime, int landStartTimeSize, int* landDuration, int landDurationSize, int* waterStartTime, int waterStartTimeSize, int* waterDuration, int waterDurationSize) {
    int land_water = solve(landStartTime, landStartTimeSize, landDuration, waterStartTime, waterStartTimeSize, waterDuration);
    int water_land = solve(waterStartTime, waterStartTimeSize, waterDuration, landStartTime, landStartTimeSize, landDuration);
    return min(land_water, water_land);
}

###Rust

impl Solution {
    fn solve(start1: &Vec<i32>, duration1: &Vec<i32>, start2: &Vec<i32>, duration2: &Vec<i32>) -> i32 {
        let mut finish1 = i32::MAX;
        for i in 0..start1.len() {
            finish1 = finish1.min(start1[i] + duration1[i]);
        }
        let mut finish2 = i32::MAX;
        for i in 0..start2.len() {
            finish2 = finish2.min(start2[i].max(finish1) + duration2[i]);
        }
        finish2
    }

    pub fn earliest_finish_time(land_start_time: Vec<i32>, land_duration: Vec<i32>, water_start_time: Vec<i32>, water_duration: Vec<i32>) -> i32 {
        let land_water = Self::solve(&land_start_time, &land_duration, &water_start_time, &water_duration);
        let water_land = Self::solve(&water_start_time, &water_duration, &land_start_time, &land_duration);
        land_water.min(water_land)
    }
}

复杂度分析

  • 时间复杂度:$O(n + m)$,其中 $n$ 和 $m$ 是输入数组的长度。

  • 空间复杂度:$O(1)$。

最早完成陆地和水上游乐设施的时间 I

方法一:暴力枚举 + 分类讨论

思路与算法

暴力枚举所有水上项目和陆地项目的组合。可以先玩陆地项目再玩水上项目,也可以反过来。要找到最早的完成时间,我们需要分别计算这两种顺序下的最优结果,然后取其中的最小值。以“先陆地、后水上”为例,计算逻辑如下:

  • 对于陆地类别项目,分别计算它的“开始时间 + 持续时间”。
  • 准备玩第二个项目时,会遇到两种情况:
    • 如果水上项目已经开放了,你可以立即开始,完成时刻就是“第一个项目的结束时间 + 水上项目的持续时间”。
    • 如果水上项目还没开放,你必须等到它开始才能动工,完成时刻就是“水上项目的开始时间 + 水上项目的持续时间”。

然后交换顺序,按照同样的方法计算“先水上、后陆地”的最早完成时间。

暴力枚举所有组合,最后返回数值较小作为最终答案。

代码

###C++

class Solution {
public:
    int earliestFinishTime(vector<int>& landStartTime, vector<int>& landDuration, vector<int>& waterStartTime, vector<int>& waterDuration) {
        int n = landStartTime.size();
        int m = waterStartTime.size();
        int res = INT_MAX;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                int land = landStartTime[i] + landDuration[i];
                int land_water = max(land, waterStartTime[j]) + waterDuration[j];
                res = min(res, land_water);

                int water = waterStartTime[j] + waterDuration[j];
                int water_land = max(water, landStartTime[i]) + landDuration[i];
                res = min(res, water_land);
            }
        }
        return res;
    }
};

###Java

class Solution {
    public int earliestFinishTime(int[] landStartTime, int[] landDuration, int[] waterStartTime, int[] waterDuration) {
        int n = landStartTime.length;
        int m = waterStartTime.length;
        int res = Integer.MAX_VALUE;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                int land = landStartTime[i] + landDuration[i];
                int land_water = Math.max(land, waterStartTime[j]) + waterDuration[j];
                res = Math.min(res, land_water);

                int water = waterStartTime[j] + waterDuration[j];
                int water_land = Math.max(water, landStartTime[i]) + landDuration[i];
                res = Math.min(res, water_land);
            }
        }
        return res;
    }
}

###Python

class Solution:
    def earliestFinishTime(self, landStartTime: List[int], landDuration: List[int], waterStartTime: List[int], waterDuration: List[int]) -> int:
        n = len(landStartTime)
        m = len(waterStartTime)
        res = inf
        for i in range(n):
            for j in range(m):
                land = landStartTime[i] + landDuration[i]
                land_water = max(land,  waterStartTime[j]) + waterDuration[j]
                res = min(res, land_water)

                water = waterStartTime[j] + waterDuration[j]
                water_land = max(water, landStartTime[i]) + landDuration[i]
                res = min(res, water_land)
        return res

###JavaScript

var earliestFinishTime = function(landStartTime, landDuration, waterStartTime, waterDuration) {
    let n = landStartTime.length;
    let m = waterStartTime.length;
    let res = Infinity;
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < m; j++) {
            let land = landStartTime[i] + landDuration[i];
            let land_water = Math.max(land, waterStartTime[j]) + waterDuration[j];
            res = Math.min(res, land_water);

            let water = waterStartTime[j] + waterDuration[j];
            let water_land = Math.max(water, landStartTime[i]) + landDuration[i];
            res = Math.min(res, water_land);
        }
    }
    return res;
};

###TypeScript

function earliestFinishTime(landStartTime: number[], landDuration: number[], waterStartTime: number[], waterDuration: number[]): number {
    let n = landStartTime.length;
    let m = waterStartTime.length;
    let res = Infinity;
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < m; j++) {
            let land = landStartTime[i] + landDuration[i];
            let land_water = Math.max(land, waterStartTime[j]) + waterDuration[j];
            res = Math.min(res, land_water);

            let water = waterStartTime[j] + waterDuration[j];
            let water_land = Math.max(water, landStartTime[i]) + landDuration[i];
            res = Math.min(res, water_land);
        }
    }
    return res;
};

###Go

func earliestFinishTime(landStartTime []int, landDuration []int, waterStartTime []int, waterDuration []int) int {
    n := len(landStartTime)
    m := len(waterStartTime)
    res := 1000000
    for i := 0; i < n; i++ {
        for j := 0; j < m; j++ {
            land := landStartTime[i] + landDuration[i]
            land_water := land
            if waterStartTime[j] > land_water {
                land_water = waterStartTime[j]
            }
            land_water += waterDuration[j]
            if land_water < res {
                res = land_water
            }

            water := waterStartTime[j] + waterDuration[j]
            water_land := water
            if landStartTime[i] > water_land {
                water_land = landStartTime[i]
            }
            water_land += landDuration[i]
            if water_land < res {
                res = water_land
            }
        }
    }
    return res
}

###C#

public class Solution {
    public int EarliestFinishTime(int[] landStartTime, int[] landDuration, int[] waterStartTime, int[] waterDuration) {
        int n = landStartTime.Length;
        int m = waterStartTime.Length;
        int res = int.MaxValue;;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                int land = landStartTime[i] + landDuration[i];
                int land_water = Math.Max(land, waterStartTime[j]) + waterDuration[j];
                res = Math.Min(res, land_water);

                int water = waterStartTime[j] + waterDuration[j];
                int water_land = Math.Max(water, landStartTime[i]) + landDuration[i];
                res = Math.Min(res, water_land);
            }
        }
        return res;
    }
}

###C

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

int earliestFinishTime(int* landStartTime, int landStartTimeSize, int* landDuration, int landDurationSize, int* waterStartTime, int waterStartTimeSize, int* waterDuration, int waterDurationSize) {
    int n = landStartTimeSize;
    int m = waterStartTimeSize;
    int res = INT_MAX;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            int land = landStartTime[i] + landDuration[i];
            int land_water = MAX(land, waterStartTime[j]) + waterDuration[j];
            res = MIN(res, land_water);

            int water = waterStartTime[j] + waterDuration[j];
            int water_land = MAX(water, landStartTime[i]) + landDuration[i];
            res = MIN(res, water_land);
        }
    }
    return res;
}

###Rust

impl Solution {
    pub fn earliest_finish_time(land_start_time: Vec<i32>, land_duration: Vec<i32>, water_start_time: Vec<i32>, water_duration: Vec<i32>) -> i32 {
        let n = land_start_time.len();
        let m = water_start_time.len();
        let mut res = i32::MAX;
        for i in 0..n {
            for j in 0..m {
                let land = land_start_time[i] + land_duration[i];
                let land_water = land.max(water_start_time[j]) + water_duration[j];
                res = res.min(land_water);

                let water = water_start_time[j] + water_duration[j];
                let water_land = water.max(land_start_time[i]) + land_duration[i];
                res = res.min(water_land);
            }
        }
        res
    }
}

复杂度分析

  • 时间复杂度:$O(n \times m)$,其中 $n$ 和 $m$ 是输入数组的长度。

  • 空间复杂度:$O(1)$,其中 $n$ 是数组的长度。

方法二:线性枚举 + 分类讨论

思路与算法

可以先玩陆地项目再玩水上项目,也可以反过来。要找到最早的完成时间,我们需要分别计算这两种顺序下的最优结果,然后取其中的最小值。

对于固定的第二类项目,最终完成时间为:

$$
\max(\textit{finish1}, \textit{start2}) + \textit{duration2}
$$

其中 $\textit{finish1}$ 表示第一类项目的结束时间,$\textit{start2}$ 表示第二类项目的开始时间。由于该表达式随着 $\textit{finish1}$ 的增大单调不减,因此为了使最终完成时间最小,我们只需要保留第一类项目中的最早结束时间即可。

在陆地项目结束最早的前提下,遍历所有的水上项目,并找到最早结束时间。

最后,交换顺序,按照同样的方法计算“先水上、后陆地”的最早完成时间。比较这两种顺序得到的结果,返回数值较小作为最终答案。

代码

###C++

class Solution {
    int solve(vector<int>& start1, vector<int>& duration1, vector<int>& start2, vector<int>& duration2) {
        int finish1 = INT_MAX;
        for (int i = 0; i < start1.size(); i++) {
            finish1 = min(finish1, start1[i] + duration1[i]);
        }

        int finish2 = INT_MAX;
        for (int i = 0; i < start2.size(); i++) {
            finish2 = min(finish2, max(start2[i], finish1) + duration2[i]);
        }
        return finish2;
    }

public:
    int earliestFinishTime(vector<int>& landStartTime, vector<int>& landDuration, vector<int>& waterStartTime, vector<int>& waterDuration) {
        int land_water = solve(landStartTime, landDuration, waterStartTime, waterDuration);
        int water_land = solve(waterStartTime, waterDuration, landStartTime, landDuration);
        return min(land_water, water_land);
    }
};

###Java

class Solution {
    private int solve(int[] start1, int[] duration1, int[] start2, int[] duration2) {
        int finish1 = Integer.MAX_VALUE;
        for (int i = 0; i < start1.length; i++) {
            finish1 = Math.min(finish1, start1[i] + duration1[i]);
        }
        int finish2 = Integer.MAX_VALUE;
        for (int i = 0; i < start2.length; i++) {
            finish2 = Math.min(finish2, Math.max(start2[i], finish1) + duration2[i]);
        }
        return finish2;
    }

    public int earliestFinishTime(int[] landStartTime, int[] landDuration, int[] waterStartTime, int[] waterDuration) {
        int land_water = solve(landStartTime, landDuration, waterStartTime, waterDuration);
        int water_land = solve(waterStartTime, waterDuration, landStartTime, landDuration);
        return Math.min(land_water, water_land);
    }
}

###Python

class Solution:
    def earliestFinishTime(self, landStartTime: List[int], landDuration: List[int], waterStartTime: List[int], waterDuration: List[int]) -> int:
        def solve(start1, duration1, start2, duration2):
            finish1 = inf
            for i in range(len(start1)):
                finish1 = min(finish1, start1[i] + duration1[i])
            finish2 = inf
            for i in range(len(start2)):
                finish2 = min(finish2, max(start2[i], finish1) + duration2[i])
            return finish2

        land_water = solve(landStartTime, landDuration, waterStartTime, waterDuration)
        water_land = solve(waterStartTime, waterDuration, landStartTime, landDuration)
        return min(land_water, water_land)

###JavaScript

var earliestFinishTime = function(landStartTime, landDuration, waterStartTime, waterDuration) {
    function solve(start1, duration1, start2, duration2) {
        let finish1 = Infinity;
        for (let i = 0; i < start1.length; i++) {
            finish1 = Math.min(finish1, start1[i] + duration1[i]);
        }
        let finish2 = Infinity;
        for (let i = 0; i < start2.length; i++) {
            finish2 = Math.min(finish2, Math.max(start2[i], finish1) + duration2[i]);
        }
        return finish2;
    }

    let land_water = solve(landStartTime, landDuration, waterStartTime, waterDuration);
    let water_land = solve(waterStartTime, waterDuration, landStartTime, landDuration);
    return Math.min(land_water, water_land);
};

###TypeScript

function solve(start1, duration1, start2, duration2) {
    let finish1 = Infinity;
    for (let i = 0; i < start1.length; i++) {
        finish1 = Math.min(finish1, start1[i] + duration1[i]);
    }
    let finish2 = Infinity;
    for (let i = 0; i < start2.length; i++) {
        finish2 = Math.min(finish2, Math.max(start2[i], finish1) + duration2[i]);
    }
    return finish2;
}

function earliestFinishTime(landStartTime: number[], landDuration: number[], waterStartTime: number[], waterDuration: number[]): number {
    let land_water = solve(landStartTime, landDuration, waterStartTime, waterDuration);
    let water_land = solve(waterStartTime, waterDuration, landStartTime, landDuration);
    return Math.min(land_water, water_land);
};

###Go

func earliestFinishTime(landStartTime []int, landDuration []int, waterStartTime []int, waterDuration []int) int {
    solve := func(start1, duration1, start2, duration2 []int) int {
        finish1 := 2147483647
        for i := 0; i < len(start1); i++ {
            if val := start1[i] + duration1[i]; val < finish1 {
                finish1 = val
            }
        }
        finish2 := 2147483647
        for i := 0; i < len(start2); i++ {
            curStart := start2[i]
            if finish1 > curStart {
                curStart = finish1
            }
            if val := curStart + duration2[i]; val < finish2 {
                finish2 = val
            }
        }
        return finish2
    }

    land_water := solve(landStartTime, landDuration, waterStartTime, waterDuration)
    water_land := solve(waterStartTime, waterDuration, landStartTime, landDuration)
    if land_water < water_land {
        return land_water
    }
    return water_land
}

###C#

public class Solution {
    private int solve(int[] start1, int[] duration1, int[] start2, int[] duration2) {
        int finish1 = int.MaxValue;
        for (int i = 0; i < start1.Length; i++) {
            finish1 = Math.Min(finish1, start1[i] + duration1[i]);
        }
        int finish2 = int.MaxValue;
        for (int i = 0; i < start2.Length; i++) {
            finish2 = Math.Min(finish2, Math.Max(start2[i], finish1) + duration2[i]);
        }
        return finish2;
    }

    public int EarliestFinishTime(int[] landStartTime, int[] landDuration, int[] waterStartTime, int[] waterDuration) {
        int land_water = solve(landStartTime, landDuration, waterStartTime, waterDuration);
        int water_land = solve(waterStartTime, waterDuration, landStartTime, landDuration);
        return Math.Min(land_water, water_land);
    }
}

###C

#define min(a, b) ((a) < (b) ? (a) : (b))
#define max(a, b) ((a) > (b) ? (a) : (b))

int solve(int* start1, int start1Size, int* duration1, int* start2, int start2Size, int* duration2) {
    int finish1 = INT_MAX;
    for (int i = 0; i < start1Size; i++) {
        finish1 = min(finish1, start1[i] + duration1[i]);
    }
    int finish2 = INT_MAX;
    for (int i = 0; i < start2Size; i++) {
        finish2 = min(finish2, max(start2[i], finish1) + duration2[i]);
    }
    return finish2;
}

int earliestFinishTime(int* landStartTime, int landStartTimeSize, int* landDuration, int landDurationSize, int* waterStartTime, int waterStartTimeSize, int* waterDuration, int waterDurationSize) {
    int land_water = solve(landStartTime, landStartTimeSize, landDuration, waterStartTime, waterStartTimeSize, waterDuration);
    int water_land = solve(waterStartTime, waterStartTimeSize, waterDuration, landStartTime, landStartTimeSize, landDuration);
    return min(land_water, water_land);
}

###Rust

impl Solution {
    fn solve(start1: &Vec<i32>, duration1: &Vec<i32>, start2: &Vec<i32>, duration2: &Vec<i32>) -> i32 {
        let mut finish1 = i32::MAX;
        for i in 0..start1.len() {
            finish1 = finish1.min(start1[i] + duration1[i]);
        }
        let mut finish2 = i32::MAX;
        for i in 0..start2.len() {
            finish2 = finish2.min(start2[i].max(finish1) + duration2[i]);
        }
        finish2
    }

    pub fn earliest_finish_time(landStartTime: Vec<i32>, landDuration: Vec<i32>, waterStartTime: Vec<i32>, waterDuration: Vec<i32>) -> i32 {
        let land_water = Self::solve(&landStartTime, &landDuration, &waterStartTime, &waterDuration);
        let water_land = Self::solve(&waterStartTime, &waterDuration, &landStartTime, &landDuration);
        land_water.min(water_land)
    }
}

复杂度分析

  • 时间复杂度:$O(n + m)$,其中 $n$ 和 $m$ 是输入数组的长度。

  • 空间复杂度:$O(1)$。

打折购买糖果的最小开销

方法一:贪心

提示 $1$

我们采用如下的策略,可以使得购买糖果的总开销最小:

将糖果价格从高到低排序,然后按照每 $3$ 个分成一组,花钱购买前两个,赠送第三个。

提示 $1$ 解释

我们假设糖果数量为 $n$,那么最多可以赠送的糖果数量为 $\lfloor n / 3 \rfloor$,提示 $1$ 的方法中,赠送糖果的数量与这个上界相等。那么,我们可以将证明分为两部分:

  1. 开销最小的购买方案一定是赠送数量最多的方案;
  2. 提示 $1$ 的购买方案一定是赠送数量最多的方案中最优的。

对于第一部分,任意一个赠送糖果数量少于 $\lfloor n / 3 \rfloor$ 的方案,都一定可以找到至少三个未被分组的糖果,对于这三个糖果,一定可以使得价格最低的糖果免费。因此命题 $1$ 成立。

对于第二部分,我们不妨假设 $\textit{cost}$ 数组已经按照价格降序排序,根据定义,免费获得糖果中价格最高的一定不大于 $\textit{cost}[2]$(假设该下标存在,下同)。类似地,我们可以得出,价格第 $k (0 \le k \le \lfloor n / 3 \rfloor)$ 高的糖果的价格一定不大于 $\textit{cost}[3k+2]$。

提示 $1$ 的方案中,所有的不等式均取了等号,同时考虑到免费糖果数量一定,因此命题 $2$ 成立。

综上,我们可以得出,提示 $1$ 的购买方案是开销最小的。

思路与算法

根据 提示 $1$,我们首先将糖果价格数组 $\textit{cost}$ 从高到低排序,此时免费获得所有下标模 $3$ 余 $2$ 的糖果的方案开销最小。随后我们遍历数组计算总开销,在计算时我们需要跳过这些免费获得的糖果。最终,我们将总开销返回作为答案。

代码

###C++

class Solution {
public:
    int minimumCost(vector<int>& cost) {
        sort(cost.begin(), cost.end(), greater<int>());
        int res = 0;
        int n = cost.size();
        for (int i = 0; i < n; ++i) {
            if (i % 3 != 2) {
                res += cost[i];
            }
        }
        return res;
    }
};

###Python

class Solution:
    def minimumCost(self, cost: List[int]) -> int:
        cost.sort(key = lambda x: -x)
        res = 0
        n = len(cost)
        for i in range(n):
            if i % 3 != 2:
                res += cost[i]
        return res

###Java

class Solution {
    public int minimumCost(int[] cost) {
        Arrays.sort(cost);
        int res = 0;
        int n = cost.length;
        for (int i = n - 1; i >= 0; --i) {
            if ((n - 1 - i) % 3 != 2) {
                res += cost[i];
            }
        }
        return res;
    }
}

###C#

public class Solution {
    public int MinimumCost(int[] cost) {
        Array.Sort(cost);
        Array.Reverse(cost);
        
        int res = 0;
        int n = cost.Length;
        for (int i = 0; i < n; ++i) {
            if (i % 3 != 2) {
                res += cost[i];
            }
        }
        return res;
    }
}

###Go

func minimumCost(cost []int) int {
    sort.Sort(sort.Reverse(sort.IntSlice(cost)))
    res := 0
    n := len(cost)
    for i := 0; i < n; i++ {
        if i % 3 != 2 {
            res += cost[i]
        }
    }
    return res
}

###C

int compareDesc(const void* a, const void* b) {
    return *(int*)b - *(int*)a;
}

int minimumCost(int* cost, int costSize) {
    qsort(cost, costSize, sizeof(int), compareDesc);
    int res = 0;
    for (int i = 0; i < costSize; ++i) {
        if (i % 3 != 2) {
            res += cost[i];
        }
    }
    return res;
}

###JavaScript

var minimumCost = function(cost) {
    cost.sort((a, b) => b - a);
    let res = 0;
    const n = cost.length;
    for (let i = 0; i < n; ++i) {
        if (i % 3 !== 2) {
            res += cost[i];
        }
    }
    return res;
};

###TypeScript

function minimumCost(cost: number[]): number {
    cost.sort((a, b) => b - a);
    let res = 0;
    const n = cost.length;
    for (let i = 0; i < n; ++i) {
        if (i % 3 !== 2) {
            res += cost[i];
        }
    }
    return res;
}

###Rust

impl Solution {
    pub fn minimum_cost(mut cost: Vec<i32>) -> i32 {
        cost.sort_by(|a, b| b.cmp(a));
        let mut res = 0;
        let n = cost.len();
        for i in 0..n {
            if i % 3 != 2 {
                res += cost[i];
            }
        }
        res
    }
}

复杂度分析

  • 时间复杂度:$O(n \log n)$,其中 $n$ 为 $\textit{cost}$ 的长度。即为对糖果按照价格排序的时间复杂度。

  • 空间复杂度:$O(\log n)$,即为排序的栈空间开销。

跳跃游戏 VII

方法一:动态规划 + 前缀和优化

提示 $1$

我们用 $f(i)$ 表示能否从位置 $0$ 按照给定的规则跳到位置 $i$。

如果 $s[i]$ 为 $1$,我们无法跳到位置 $i$,此时 $f(i) = \text{False}$。

如果 $s[i]$ 为 $0$,我们可以枚举位置 $j$,表示最后一步是从位置 $j$ 跳到位置 $i$ 的。位置 $j$ 需要满足 $j \in [i - \textit{maxJump}, i - \textit{minJump}]$ 并且 $j \geq 0$,只要存在一个 $j$ 满足 $f(j)=\text{True}$,那么 $f(i)$ 就为 $\text{True}$。因此我们可以写出状态转移方程:

$$
f(i) = \text{any}\big(f(j)\big), \quad 其中 ~ j \in [i - \textit{maxJump}, i - \textit{minJump}] ~并且~ j \geq 0
$$

如果字符串 $s$ 的长度为 $n$,我们按照上述状态转移方程进行动态规划后,最终的答案即为 $f(n-1)$。

然而该状态转移方程的转移时间为 $O(n)$,即动态规划的总时间复杂度为 $O(n^2)$,会超出时间限制,因此我们需要进行优化。

提示 $2$

为了叙述方便,我们用 $\textit{left}_i$ 和 $\textit{right}_i$ 表示位置 $i$ 在状态转移中对应的 $j$ 的区间。在大部分情况下,有:

$$
[\textit{left}_i, \textit{right}_i] = [i - \textit{maxJump}, i - \textit{minJump}]
$$

但由于有 $j \geq 0$ 的限制,可能需要对该区间进行一些处理,具体的处理方法可以参考代码部分。

根据提示 $1$,$f(i)$ 的值为 $\text{True}$,当且仅当 $s[i]$ 为 $0$,并且区间 $[\textit{left}_i, \textit{right}_i]$ 中存在一个位置作为下标对应的 $f$ 值也为 $\text{True}$。如果我们将 $\text{True}$ 看成 $1$,$\text{False}$ 看成 $0$,那么其等价于:

  • $f(i)$ 的值为 $\text{True}$,当且仅当 $s[i]$ 为 $0$,并且 $\sum\limits_{j=\textit{left}_i}^{\textit{right}_i} f(j)$ 的值不为 $0$。

由于 $\sum\limits_{j=\textit{left}_i}^{\textit{right}_i} f(j)$ 是数组 $f$ 的一段连续区间的求和,因此我们可以在动态规划的同时维护数组 $f$ 的前缀和数组 $\textit{pre}$,其中:

$$
\textit{pre}(i) = \sum_{j=0}^{i} f(i)
$$

这样就可以通过:

$$
\sum_{j=\textit{left}_i}^{\textit{right}_i} f(j) = \textit{pre}(\textit{right}_i) - \textit{pre}(\textit{left}_i - 1)
$$

在 $O(1)$ 的时间快速地进行状态转移了,使得动态规划的总时间减少为 $O(n)$。这里同样需要注意处理 $\textit{left}_i \leq 0$ 的情况,可以参考代码部分。

细节

动态规划的边界条件为 $f(0) = \text{True}$。在进行状态转移时,我们可以从 $i = \textit{minJump}$ 开始,保证 $\textit{right}_i$ 恒大于等于 $0$,这样就只需要特殊处理 $\textit{left}_i$ 了。

代码

###C++

class Solution {
public:
    bool canReach(string s, int minJump, int maxJump) {
        int n = s.size();
        vector<int> f(n), pre(n);
        f[0] = 1;
        // 由于我们从 i=minJump 开始动态规划,因此需要将 [0,minJump) 这部分的前缀和预处理出来
        for (int i = 0; i < minJump; ++i) {
            pre[i] = 1;
        }
        for (int i = minJump; i < n; ++i) {
            int left = i - maxJump, right = i - minJump;
            if (s[i] == '0') {
                int total = pre[right] - (left <= 0 ? 0 : pre[left - 1]);
                f[i] = (total != 0);
            }
            pre[i] = pre[i - 1] + f[i];
        }
        return f[n - 1];
    }
};

###Python

class Solution:
    def canReach(self, s: str, minJump: int, maxJump: int) -> bool:
        n = len(s)
        f, pre = [0] * n, [0] * n
        f[0] = 1
        # 由于我们从 i=minJump 开始动态规划,因此需要将 [0,minJump) 这部分的前缀和预处理出来
        for i in range(minJump):
            pre[i] = 1
        for i in range(minJump, n):
            left, right = i - maxJump, i - minJump
            if s[i] == "0":
                total = pre[right] - (0 if left <= 0 else pre[left - 1])
                f[i] = int(total != 0)
            pre[i] = pre[i - 1] + f[i]

        return bool(f[n - 1])

###Java

class Solution {
    public boolean canReach(String s, int minJump, int maxJump) {
        int n = s.length();
        int[] f = new int[n];
        int[] pre = new int[n];
        f[0] = 1;
        // 由于我们从 i=minJump 开始动态规划,因此需要将 [0,minJump) 这部分的前缀和预处理出来
        for (int i = 0; i < minJump; i++) {
            pre[i] = 1;
        }
        for (int i = minJump; i < n; i++) {
            int left = i - maxJump;
            int right = i - minJump;
            if (s.charAt(i) == '0') {
                int total = pre[right] - (left <= 0 ? 0 : pre[left - 1]);
                f[i] = total != 0 ? 1 : 0;
            }
            pre[i] = pre[i - 1] + f[i];
        }
        return f[n - 1] == 1;
    }
}

###C#

public class Solution {
    public bool CanReach(string s, int minJump, int maxJump) {
        int n = s.Length;
        int[] f = new int[n];
        int[] pre = new int[n];
        f[0] = 1;
        // 由于我们从 i=minJump 开始动态规划,因此需要将 [0,minJump) 这部分的前缀和预处理出来
        for (int i = 0; i < minJump; i++) {
            pre[i] = 1;
        }
        for (int i = minJump; i < n; i++) {
            int left = i - maxJump;
            int right = i - minJump;
            if (s[i] == '0') {
                int total = pre[right] - (left <= 0 ? 0 : pre[left - 1]);
                f[i] = total != 0 ? 1 : 0;
            }
            pre[i] = pre[i - 1] + f[i];
        }
        return f[n - 1] == 1;
    }
}

###Go

func canReach(s string, minJump int, maxJump int) bool {
    n := len(s)
    f := make([]int, n)
    pre := make([]int, n)
    f[0] = 1
    // 由于我们从 i=minJump 开始动态规划,因此需要将 [0,minJump) 这部分的前缀和预处理出来
    for i := 0; i < minJump; i++ {
        pre[i] = 1
    }
    for i := minJump; i < n; i++ {
        left := i - maxJump
        right := i - minJump
        if s[i] == '0' {
            total := pre[right]
            if left > 0 {
                total -= pre[left-1]
            }
            if total != 0 {
                f[i] = 1
            } else {
                f[i] = 0
            }
        }
        pre[i] = pre[i-1] + f[i]
    }
    return f[n-1] == 1
}

###C

bool canReach(char* s, int minJump, int maxJump) {
    int n = strlen(s);
    int* f = (int*)calloc(n, sizeof(int));
    int* pre = (int*)malloc(n * sizeof(int));
    f[0] = 1;
    // 由于我们从 i=minJump 开始动态规划,因此需要将 [0,minJump) 这部分的前缀和预处理出来
    for (int i = 0; i < minJump; i++) {
        pre[i] = 1;
    }
    for (int i = minJump; i < n; i++) {
        int left = i - maxJump;
        int right = i - minJump;
        if (s[i] == '0') {
            int total = pre[right];
            if (left > 0) {
                total -= pre[left - 1];
            }
            f[i] = total != 0 ? 1 : 0;
        }
        pre[i] = pre[i - 1] + f[i];
    }
    bool result = (f[n - 1] == 1);
    free(f);
    free(pre);
    return result;
}

###JavaScript

var canReach = function(s, minJump, maxJump) {
    const n = s.length;
    const f = new Array(n).fill(0);
    const pre = new Array(n).fill(0);
    f[0] = 1;
    // 由于我们从 i=minJump 开始动态规划,因此需要将 [0,minJump) 这部分的前缀和预处理出来
    for (let i = 0; i < minJump; i++) {
        pre[i] = 1;
    }
    for (let i = minJump; i < n; i++) {
        const left = i - maxJump;
        const right = i - minJump;
        if (s[i] === '0') {
            const total = pre[right] - (left <= 0 ? 0 : pre[left - 1]);
            f[i] = total !== 0 ? 1 : 0;
        }
        pre[i] = pre[i - 1] + f[i];
    }
    return f[n - 1] === 1;
};

###TypeScript

function canReach(s: string, minJump: number, maxJump: number): boolean {
    const n: number = s.length;
    const f: number[] = new Array(n).fill(0);
    const pre: number[] = new Array(n).fill(0);
    f[0] = 1;
    // 由于我们从 i=minJump 开始动态规划,因此需要将 [0,minJump) 这部分的前缀和预处理出来
    for (let i = 0; i < minJump; i++) {
        pre[i] = 1;
    }
    for (let i = minJump; i < n; i++) {
        const left: number = i - maxJump;
        const right: number = i - minJump;
        if (s[i] === '0') {
            const total: number = pre[right] - (left <= 0 ? 0 : pre[left - 1]);
            f[i] = total !== 0 ? 1 : 0;
        }
        pre[i] = pre[i - 1] + f[i];
    }
    return f[n - 1] === 1;
}

###Rust

impl Solution {
    pub fn can_reach(s: String, min_jump: i32, max_jump: i32) -> bool {
        let n = s.len();
        let min_jump = min_jump as usize;
        let max_jump = max_jump as usize;
        let mut f = vec![0; n];
        let mut pre = vec![0; n];
        f[0] = 1;
        // 由于我们从 i=minJump 开始动态规划,因此需要将 [0,minJump) 这部分的前缀和预处理出来
        for i in 0..min_jump {
            pre[i] = 1;
        }
        let s_chars: Vec<char> = s.chars().collect();
        for i in min_jump..n {
            let left = i as i32 - max_jump as i32;
            let right = i - min_jump;
            if s_chars[i] == '0' {
                let total = if left <= 0 {
                    pre[right]
                } else {
                    pre[right] - pre[left as usize - 1]
                };
                f[i] = if total != 0 { 1 } else { 0 };
            }
            pre[i] = pre[i - 1] + f[i];
        }
        f[n - 1] == 1
    }
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 是字符串 $s$ 的长度。

  • 空间复杂度:$O(n)$,即为数组 $f$ 和 $\textit{pre}$ 需要使用的空间。

跳跃游戏 V

方法一:记忆化搜索

我们用 $\textit{dp}[i]$ 表示从位置 $i$ 开始跳跃,最多可以访问的下标个数。我们可以写出如下的状态转移方程:

$$
\textit{dp}[i] = \max(\textit{dp}[j]) + 1
$$

其中 $j$ 需要满足三个条件:

  • $0 \leq j < \textit{arr}.\text{length}$,即 $j$ 必须在数组 $\textit{arr}$ 的范围内;

  • $i - d \leq j \leq i + d$,即 $j$ 到 $i$ 的距离不能超过给定的 $d$;

  • 从 $\textit{arr}[j]$ 到 $\textit{arr}[i]$ 的这些元素除了 $\textit{arr}[i]$ 本身之外,都必须小于 $\textit{arr}[i]$,这是题目中的要求。

对于任意的位置 $i$,根据第二个条件,我们只需要在其左右两侧最多扫描 $d$ 个元素,就可以找出所有满足条件的位置 $j$。随后我们通过这些 $j$ 的 $\textit{dp}$ 值对位置 $i$ 进行状态转移,就可以得到 $\textit{dp}[i]$ 的值。

此时出现了一个需要解决的问题:如何保证在处理到位置 $i$ 时,所有满足条件的位置 $j$ 都已经被处理过了呢?换句话说,如何保证这些位置 $j$ 对应的 $\textit{dp}[j]$ 都已经计算过了?如果我们用常规的动态规划方法(例如根据位置从小到大或者从大到小进行动态规划),那么并不能保证这一点,因为 $j$ 分布在位置 $i$ 的两侧。

因此我们需要借助记忆化搜索的方法,即当我们需要 $\textit{dp}[j]$ 的值时,如果我们之前已经计算过,就直接返回这个值(记忆);如果我们之前没有计算过,就先将 $\textit{dp}[i]$ 搁在一边,转而去计算 $\textit{dp}[j]$(搜索),当 $\textit{dp}[j]$ 计算完成后,再用其对 $\textit{dp}[i]$ 进行状态转移。

记忆化搜索一定能在有限的时间内停止吗?如果它不能在有限的时间内停止,说明在搜索的过程中出现了环。即当我们需要计算 $\textit{dp}[i]$ 时,我们发现某个 $\textit{dp}[j]$ 没有计算过,接着在计算 $\textit{dp}[j]$ 时,又发现某个 $\textit{dp}[k]$ 没有计算过,以此类推,直到某次搜索时发现当前位置的 $\textit{dp}$ 值需要 $\textit{dp}[i]$ 的值才能得到,这样就出现了环。在本题中,根据第三个条件,$\textit{arr}[j]$ 是一定小于 $\textit{arr}[i]$ 的,即我们的搜索每深入一层,就跳到了高度更小的位置。因此在搜索的过程中不会出现环。这样一来,我们通过记忆化搜索,就可以在与常规的动态规划相同的时间复杂度内得到所有的 $\textit{dp}$ 值。

注意:如果你不太能理解这篇题解在讲什么,请使用搜索引擎,补充「记忆化搜索」的相关知识。记忆化搜索以深度优先搜索为基础,在第一次搜索到某个状态时,会将该状态与其对应的值存储下来,这样在未来的搜索中,如果搜索到相同的状态,就不用再进行重复搜索了。记忆化搜索和动态规划非常相似,大部分的题目如果可以使用动态规划解决,那么一定可以使用记忆化搜索解决,反之亦然。这是因为记忆化搜索要求搜索状态满足拓扑序(即不会出现环),而动态规划同样要求状态满足拓扑序,不然就没法进行状态转移了。

###C++

class Solution {
private:
    vector<int> f;
    
public:
    void dfs(vector<int>& arr, int id, int d, int n) {
        if (f[id] != -1) {
            return;
        }
        f[id] = 1;
        for (int i = id - 1; i >= 0 && id - i <= d && arr[id] > arr[i]; --i) {
            dfs(arr, i, d, n);
            f[id] = max(f[id], f[i] + 1);
        }
        for (int i = id + 1; i < n && i - id <= d && arr[id] > arr[i]; ++i) {
            dfs(arr, i, d, n);
            f[id] = max(f[id], f[i] + 1);
        }
    }
    
    int maxJumps(vector<int>& arr, int d) {
        int n = arr.size();
        f.resize(n, -1);
        for (int i = 0; i < n; ++i) {
            dfs(arr, i, d, n);
        }
        return *max_element(f.begin(), f.end());
    }
};

###Python

class Solution:
    def maxJumps(self, arr: List[int], d: int) -> int:
        seen = dict()

        def dfs(pos):
            if pos in seen:
                return
            seen[pos] = 1

            i = pos - 1
            while i >= 0 and pos - i <= d and arr[pos] > arr[i]:
                dfs(i)
                seen[pos] = max(seen[pos], seen[i] + 1)
                i -= 1
            i = pos + 1
            while i < len(arr) and i - pos <= d and arr[pos] > arr[i]:
                dfs(i)
                seen[pos] = max(seen[pos], seen[i] + 1)
                i += 1

        for i in range(len(arr)):
            dfs(i)

        return max(seen.values())

###Java

class Solution {
    private int[] f;
    
    private void dfs(int[] arr, int id, int d, int n) {
        if (f[id] != -1) {
            return;
        }
        f[id] = 1;
        for (int i = id - 1; i >= 0 && id - i <= d && arr[id] > arr[i]; --i) {
            dfs(arr, i, d, n);
            f[id] = Math.max(f[id], f[i] + 1);
        }
        for (int i = id + 1; i < n && i - id <= d && arr[id] > arr[i]; ++i) {
            dfs(arr, i, d, n);
            f[id] = Math.max(f[id], f[i] + 1);
        }
    }
    
    public int maxJumps(int[] arr, int d) {
        int n = arr.length;
        f = new int[n];
        Arrays.fill(f, -1);
        for (int i = 0; i < n; ++i) {
            dfs(arr, i, d, n);
        }
        int ans = 0;
        for (int val : f) {
            ans = Math.max(ans, val);
        }
        return ans;
    }
}

###C#

public class Solution {
    private int[] f;
    
    private void Dfs(int[] arr, int id, int d, int n) {
        if (f[id] != -1) {
            return;
        }
        f[id] = 1;
        for (int i = id - 1; i >= 0 && id - i <= d && arr[id] > arr[i]; --i) {
            Dfs(arr, i, d, n);
            f[id] = Math.Max(f[id], f[i] + 1);
        }
        for (int i = id + 1; i < n && i - id <= d && arr[id] > arr[i]; ++i) {
            Dfs(arr, i, d, n);
            f[id] = Math.Max(f[id], f[i] + 1);
        }
    }
    
    public int MaxJumps(int[] arr, int d) {
        int n = arr.Length;
        f = new int[n];
        Array.Fill(f, -1);
        for (int i = 0; i < n; ++i) {
            Dfs(arr, i, d, n);
        }
        return f.Max();
    }
}

###Go

func maxJumps(arr []int, d int) int {
    n := len(arr)
    f := make([]int, n)
    for i := range f {
        f[i] = -1
    }
    
    var dfs func(int)
    dfs = func(id int) {
        if f[id] != -1 {
            return
        }
        f[id] = 1
        for i := id - 1; i >= 0 && id-i <= d && arr[id] > arr[i]; i-- {
            dfs(i)
            if f[i]+1 > f[id] {
                f[id] = f[i] + 1
            }
        }
        for i := id + 1; i < n && i-id <= d && arr[id] > arr[i]; i++ {
            dfs(i)
            if f[i]+1 > f[id] {
                f[id] = f[i] + 1
            }
        }
    }
    
    for i := 0; i < n; i++ {
        dfs(i)
    }
    
    ans := 0
    for _, val := range f {
        if val > ans {
            ans = val
        }
    }
    return ans
}

###C

void dfs(int* arr, int id, int d, int n, int *f) {
    if (f[id] != -1) {
        return;
    }
    f[id] = 1;
    for (int i = id - 1; i >= 0 && id - i <= d && arr[id] > arr[i]; --i) {
        dfs(arr, i, d, n, f);
        if (f[i] + 1 > f[id]) {
            f[id] = f[i] + 1;
        }
    }
    for (int i = id + 1; i < n && i - id <= d && arr[id] > arr[i]; ++i) {
        dfs(arr, i, d, n, f);
        if (f[i] + 1 > f[id]) {
            f[id] = f[i] + 1;
        }
    }
}

int maxJumps(int* arr, int arrSize, int d) {
    int n = arrSize;
    int *f = (int*)malloc(n * sizeof(int));
    for (int i = 0; i < n; ++i) {
        f[i] = -1;
    }
    
    for (int i = 0; i < n; ++i) {
        dfs(arr, i, d, n, f);
    }
    
    int ans = 0;
    for (int i = 0; i < n; ++i) {
        if (f[i] > ans) {
            ans = f[i];
        }
    }
    
    free(f);
    return ans;
}

###JavaScript

var maxJumps = function(arr, d) {
    const n = arr.length;
    const f = new Array(n).fill(-1);
    
    const dfs = (id) => {
        if (f[id] !== -1) {
            return;
        }
        f[id] = 1;
        for (let i = id - 1; i >= 0 && id - i <= d && arr[id] > arr[i]; --i) {
            dfs(i);
            f[id] = Math.max(f[id], f[i] + 1);
        }
        for (let i = id + 1; i < n && i - id <= d && arr[id] > arr[i]; ++i) {
            dfs(i);
            f[id] = Math.max(f[id], f[i] + 1);
        }
    };
    
    for (let i = 0; i < n; ++i) {
        dfs(i);
    }
    
    return Math.max(...f);
};

###TypeScript

function maxJumps(arr: number[], d: number): number {
    const n = arr.length;
    const f: number[] = new Array(n).fill(-1);
    
    const dfs = (id: number): void => {
        if (f[id] !== -1) {
            return;
        }
        f[id] = 1;
        for (let i = id - 1; i >= 0 && id - i <= d && arr[id] > arr[i]; --i) {
            dfs(i);
            f[id] = Math.max(f[id], f[i] + 1);
        }
        for (let i = id + 1; i < n && i - id <= d && arr[id] > arr[i]; ++i) {
            dfs(i);
            f[id] = Math.max(f[id], f[i] + 1);
        }
    };
    
    for (let i = 0; i < n; ++i) {
        dfs(i);
    }
    
    return Math.max(...f);
}

###Rust

impl Solution {
    pub fn max_jumps(arr: Vec<i32>, d: i32) -> i32 {
        let n = arr.len();
        let mut f = vec![-1; n];
        let d = d as usize;
        
        fn dfs(arr: &Vec<i32>, f: &mut Vec<i32>, id: usize, d: usize, n: usize) {
            if f[id] != -1 {
                return;
            }
            f[id] = 1;
            
            let mut i = id as i32 - 1;
            while i >= 0 && (id as i32 - i) <= d as i32 && arr[id] > arr[i as usize] {
                let i_idx = i as usize;
                dfs(arr, f, i_idx, d, n);
                f[id] = f[id].max(f[i_idx] + 1);
                i -= 1;
            }
            
            for i in id + 1..n {
                if i - id <= d && arr[id] > arr[i] {
                    dfs(arr, f, i, d, n);
                    f[id] = f[id].max(f[i] + 1);
                } else {
                    break;
                }
            }
        }
        
        for i in 0..n {
            dfs(&arr, &mut f, i, d, n);
        }
        
        *f.iter().max().unwrap()
    }
}

复杂度分析

  • 时间复杂度:$O(ND)$,其中 $N$ 是数组 arr 的长度。

  • 空间复杂度:$O(N)$。

思考

上面我们提到:大部分的题目如果可以使用动态规划解决,那么一定可以使用记忆化搜索解决,反之亦然。那么本题如何使用动态规划解决呢?

由于我们已经得到了状态转移方程,因此重点在于动态规划的顺序。可以发现,如果我们将所有的位置按照高度进行升序排序,并按照该顺序计算状态转移方程,那么就可以完成动态规划。这是因为在第三个条件中,arr[j] < arr[i] 一定成立,因此对于位置 i,如果我们在此之前计算出了所有高度小于 arr[i] 的位置的 dp 值,那么在对位置 i 进行状态转移时,所有满足条件的 jdp 值就已经全部计算完成了,因此我们可以通过该顺序完成动态规划。

搜索旋转排序数组

📺 视频题解

33. 搜索旋转排序数组_1.mp4

📖 文字题解

方法一:二分查找

思路和算法

对于有序数组,可以使用二分查找的方法查找元素。

但是这道题中,数组本身不是有序的,进行旋转后只保证了数组的局部是有序的,这还能进行二分查找吗?答案是可以的。

可以发现的是,我们将数组从中间分开成左右两部分的时候,一定有一部分的数组是有序的。拿示例来看,我们从 6 这个位置分开以后数组变成了 [4, 5, 6][7, 0, 1, 2] 两个部分,其中左边 [4, 5, 6] 这个部分的数组是有序的,其他也是如此。

这启示我们可以在常规二分查找的时候查看当前 mid 为分割位置分割出来的两个部分 [l, mid][mid + 1, r] 哪个部分是有序的,并根据有序的那个部分确定我们该如何改变二分查找的上下界,因为我们能够根据有序的那部分判断出 target 在不在这个部分:

  • 如果 [l, mid - 1] 是有序数组,且 target 的大小满足 $[\textit{nums}[l],\textit{nums}[mid])$,则我们应该将搜索范围缩小至 [l, mid - 1],否则在 [mid + 1, r] 中寻找。
  • 如果 [mid, r] 是有序数组,且 target 的大小满足 $[\textit{nums}[mid+1],\textit{nums}[r]]$,则我们应该将搜索范围缩小至 [mid + 1, r],否则在 [l, mid - 1] 中寻找。

fig1

需要注意的是,二分的写法有很多种,所以在判断 target 大小与有序部分的关系的时候可能会出现细节上的差别。

###C++

class Solution {
public:
    int search(vector<int>& nums, int target) {
        int n = (int)nums.size();
        if (!n) {
            return -1;
        }
        if (n == 1) {
            return nums[0] == target ? 0 : -1;
        }
        int l = 0, r = n - 1;
        while (l <= r) {
            int mid = (l + r) / 2;
            if (nums[mid] == target) return mid;
            if (nums[0] <= nums[mid]) {
                if (nums[0] <= target && target < nums[mid]) {
                    r = mid - 1;
                } else {
                    l = mid + 1;
                }
            } else {
                if (nums[mid] < target && target <= nums[n - 1]) {
                    l = mid + 1;
                } else {
                    r = mid - 1;
                }
            }
        }
        return -1;
    }
};

###Java

class Solution {
    public int search(int[] nums, int target) {
        int n = nums.length;
        if (n == 0) {
            return -1;
        }
        if (n == 1) {
            return nums[0] == target ? 0 : -1;
        }
        int l = 0, r = n - 1;
        while (l <= r) {
            int mid = (l + r) / 2;
            if (nums[mid] == target) {
                return mid;
            }
            if (nums[0] <= nums[mid]) {
                if (nums[0] <= target && target < nums[mid]) {
                    r = mid - 1;
                } else {
                    l = mid + 1;
                }
            } else {
                if (nums[mid] < target && target <= nums[n - 1]) {
                    l = mid + 1;
                } else {
                    r = mid - 1;
                }
            }
        }
        return -1;
    }
}

###Python

class Solution:
    def search(self, nums: List[int], target: int) -> int:
        if not nums:
            return -1
        l, r = 0, len(nums) - 1
        while l <= r:
            mid = (l + r) // 2
            if nums[mid] == target:
                return mid
            if nums[0] <= nums[mid]:
                if nums[0] <= target < nums[mid]:
                    r = mid - 1
                else:
                    l = mid + 1
            else:
                if nums[mid] < target <= nums[len(nums) - 1]:
                    l = mid + 1
                else:
                    r = mid - 1
        return -1

复杂度分析

  • 时间复杂度: $O(\log n)$,其中 $n$ 为 $\textit{nums}$ 数组的大小。整个算法时间复杂度即为二分查找的时间复杂度 $O(\log n)$。

  • 空间复杂度: $O(1)$ 。我们只需要常数级别的空间存放变量。

跳跃游戏 IV

方法一:广度优先搜索

思路

记数组 $\textit{arr}$ 的长度为 $n$。题目描述的数组可以抽象为一个无向图,数组元素为图的顶点,相邻下标的元素之间有一条无向边相连,所有值相同元素之间也有无向边相连。每条边的权重都为 $1$,即此图为无权图。求从第一个元素到最后一个元素的最少操作数,即求从第一个元素到最后一个元素的最短路径长度。求无权图两点间的最短路可以用广度优先搜索来解,时间复杂度为 $O(V+E)$,其中 $V$ 为图的顶点数,$E$ 为图的边数。

在此题中,$V = n$,而 $E$ 可达 $O(n^2)$ 数量级,按照常规方法使用广度优先搜索会超时。造成超时的主要原因是所有值相同的元素构成了一个稠密子图,普通的广度优先搜索方法会对这个稠密子图中的所有边都访问一次。但对于无权图的最短路问题,这样的访问是不必要的。在第一次访问到这个子图中的某个节点时,即会将这个子图的所有其他未在队列中的节点都放入队列。在第二次访问到这个子图中的节点时,就不需要去考虑这个子图中的其他节点了,因为所有其他节点都已经在队列中或者已经被访问过了。因此,在用广度优先搜索解决此题时,先需要找出所有的值相同的子图,用一个哈希表 $\textit{idxSameValue}$ 保存。在第一次把这个子图的所有节点放入队列后,把该子图清空,就不会重复访问该子图的其他边了。

代码

###Python

class Solution:
    def minJumps(self, arr: List[int]) -> int:
        idxSameValue = defaultdict(list)
        for i, a in enumerate(arr):
            idxSameValue[a].append(i)
        visitedIndex = set()
        q = deque()
        q.append([0, 0])
        visitedIndex.add(0)
        while q:
            idx, step = q.popleft()
            if idx == len(arr) - 1:
                return step
            v = arr[idx]
            step += 1
            for i in idxSameValue[v]:
                if i not in visitedIndex:
                    visitedIndex.add(i)
                    q.append([i, step])
            del idxSameValue[v]
            if idx + 1 < len(arr) and (idx + 1) not in visitedIndex:
                visitedIndex.add(idx + 1)
                q.append([idx+1, step])
            if idx - 1 >= 0 and (idx - 1) not in visitedIndex:
                visitedIndex.add(idx - 1)
                q.append([idx-1, step])

###Java

class Solution {
    public int minJumps(int[] arr) {
        Map<Integer, List<Integer>> idxSameValue = new HashMap<Integer, List<Integer>>();
        for (int i = 0; i < arr.length; i++) {
            idxSameValue.putIfAbsent(arr[i], new ArrayList<Integer>());
            idxSameValue.get(arr[i]).add(i);
        }
        Set<Integer> visitedIndex = new HashSet<Integer>();
        Queue<int[]> queue = new ArrayDeque<int[]>();
        queue.offer(new int[]{0, 0});
        visitedIndex.add(0);
        while (!queue.isEmpty()) {
            int[] idxStep = queue.poll();
            int idx = idxStep[0], step = idxStep[1];
            if (idx == arr.length - 1) {
                return step;
            }
            int v = arr[idx];
            step++;
            if (idxSameValue.containsKey(v)) {
                for (int i : idxSameValue.get(v)) {
                    if (visitedIndex.add(i)) {
                        queue.offer(new int[]{i, step});
                    }
                }
                idxSameValue.remove(v);
            }
            if (idx + 1 < arr.length && visitedIndex.add(idx + 1)) {
                queue.offer(new int[]{idx + 1, step});
            }
            if (idx - 1 >= 0 && visitedIndex.add(idx - 1)) {
                queue.offer(new int[]{idx - 1, step});
            }
        }
        return -1;
    }
}

###C#

public class Solution {
    public int MinJumps(int[] arr) {
        Dictionary<int, IList<int>> idxSameValue = new Dictionary<int, IList<int>>();
        for (int i = 0; i < arr.Length; i++) {
            if (!idxSameValue.ContainsKey(arr[i])) {
                idxSameValue.Add(arr[i], new List<int>());
            }
            idxSameValue[arr[i]].Add(i);
        }
        ISet<int> visitedIndex = new HashSet<int>();
        Queue<int[]> queue = new Queue<int[]>();
        queue.Enqueue(new int[]{0, 0});
        visitedIndex.Add(0);
        while (queue.Count > 0) {
            int[] idxStep = queue.Dequeue();
            int idx = idxStep[0], step = idxStep[1];
            if (idx == arr.Length - 1) {
                return step;
            }
            int v = arr[idx];
            step++;
            if (idxSameValue.ContainsKey(v)) {
                foreach (int i in idxSameValue[v]) {
                    if (visitedIndex.Add(i)) {
                        queue.Enqueue(new int[]{i, step});
                    }
                }
                idxSameValue.Remove(v);
            }
            if (idx + 1 < arr.Length && visitedIndex.Add(idx + 1)) {
                queue.Enqueue(new int[]{idx + 1, step});
            }
            if (idx - 1 >= 0 && visitedIndex.Add(idx - 1)) {
                queue.Enqueue(new int[]{idx - 1, step});
            }
        }
        return 0;
    }
}

###C++

class Solution {
public:
    int minJumps(vector<int>& arr) {
        unordered_map<int, vector<int>> idxSameValue;
        for (int i = 0; i < arr.size(); i++) {
            idxSameValue[arr[i]].push_back(i);
        }
        unordered_set<int> visitedIndex;
        queue<pair<int, int>> q;
        q.emplace(0, 0);
        visitedIndex.emplace(0);
        while (!q.empty()) {
            auto [idx, step] = q.front();
            q.pop();
            if (idx == arr.size() - 1) {
                return step;
            }
            int v = arr[idx];
            step++;
            if (idxSameValue.count(v)) {
                for (auto & i : idxSameValue[v]) {
                    if (!visitedIndex.count(i)) {
                        visitedIndex.emplace(i);
                        q.emplace(i, step);
                    }
                }
                idxSameValue.erase(v);
            }
            if (idx + 1 < arr.size() && !visitedIndex.count(idx + 1)) {
                visitedIndex.emplace(idx + 1);
                q.emplace(idx + 1, step);
            }
            if (idx - 1 >= 0 && !visitedIndex.count(idx - 1)) {
                visitedIndex.emplace(idx - 1);
                q.emplace(idx - 1, step);
            }
        }
        return -1;
    }
};

###C

typedef struct IdxHashEntry {
    int key;               
    struct ListNode * head;
    UT_hash_handle hh;         
}IdxHashEntry;

typedef struct SetHashEntry {
    int key; 
    UT_hash_handle hh;         
}SetHashEntry;

typedef struct Pair {
    int idx;
    int step;
}Pair;

void hashAddIdxItem(struct IdxHashEntry **obj, int key, int val) {
    struct IdxHashEntry *pEntry = NULL;
    struct ListNode * node = (struct ListNode *)malloc(sizeof(struct ListNode));
    node->val = val;
    node->next = NULL;

    HASH_FIND(hh, *obj, &key, sizeof(key), pEntry);
    if (NULL == pEntry) {
        pEntry = (struct IdxHashEntry *)malloc(sizeof(struct IdxHashEntry));
        pEntry->key = key;
        pEntry->head = node;
        HASH_ADD(hh, *obj, key, sizeof(int), pEntry);
    } else {
        node->next = pEntry->head;
        pEntry->head = node;
    }
} 

struct IdxHashEntry *hashFindIdxItem(struct IdxHashEntry **obj, int key)
{
    struct IdxHashEntry *pEntry = NULL;
    HASH_FIND(hh, *obj, &key, sizeof(int), pEntry);
    return pEntry;
}

void hashFreeIdxAll(struct IdxHashEntry **obj)
{
    struct IdxHashEntry *curr = NULL, *next = NULL;
    HASH_ITER(hh, *obj, curr, next)
    {
        HASH_DEL(*obj, curr);  
        free(curr);      
    }
}

void hashAddSetItem(struct SetHashEntry **obj, int key) {
    struct SetHashEntry *pEntry = NULL;
    HASH_FIND(hh, *obj, &key, sizeof(key), pEntry);
    if (pEntry == NULL) {
        pEntry = malloc(sizeof(struct SetHashEntry));
        pEntry->key = key;
        HASH_ADD(hh, *obj, key, sizeof(int), pEntry);
    }
} 

struct SetHashEntry *hashFindSetItem(struct SetHashEntry **obj, int key)
{
    struct SetHashEntry *pEntry = NULL;
    HASH_FIND(hh, *obj, &key, sizeof(int), pEntry);
    return pEntry;
}

void hashFreeSetAll(struct SetHashEntry **obj)
{
    struct SetHashEntry *curr = NULL, *next = NULL;
    HASH_ITER(hh, *obj, curr, next)
    {
        HASH_DEL(*obj, curr);  
        free(curr);      
    }
}

int minJumps(int* arr, int arrSize){
    struct IdxHashEntry * idxSameValue = NULL;
    for (int i = 0; i < arrSize; i++) {
        hashAddIdxItem(&idxSameValue, arr[i], i);
    }
    
    struct SetHashEntry * visitedIndex = NULL;
    struct Pair * queue = (struct Pair *)malloc(sizeof(struct Pair) * arrSize * 2);
    int head = 0;
    int tail = 0;
    queue[tail].idx = 0;
    queue[tail].step = 0;
    tail++;
    hashAddSetItem(&visitedIndex, 0);
    while (head != tail) {
        int idx = queue[head].idx;
        int step = queue[head].step;
        head++;
        if (idx + 1 == arrSize) {
            hashFreeIdxAll(&idxSameValue);
            hashFreeSetAll(&visitedIndex);
            free(queue);
            return step;
        }
        int v = arr[idx];
        step++;
        struct IdxHashEntry * pEntry = hashFindIdxItem(&idxSameValue, v);
        if (NULL != pEntry) {
            for (struct ListNode * node = pEntry->head; node; node = node->next) {
                if (NULL == hashFindSetItem(&visitedIndex, node->val)) {
                    hashAddSetItem(&visitedIndex, node->val);
                    queue[tail].idx = node->val;
                    queue[tail].step = step;
                    tail++;
                }
            }
            HASH_DEL(idxSameValue, pEntry);
        }
        if (idx + 1 < arrSize && NULL == hashFindSetItem(&visitedIndex, idx + 1)) {
            hashAddSetItem(&visitedIndex, idx + 1);
            queue[tail].idx = idx + 1;
            queue[tail].step = step;
            tail++;
        }
        if (idx - 1 >= 0 && NULL == hashFindSetItem(&visitedIndex, idx - 1)) {
            hashAddSetItem(&visitedIndex, idx - 1);
            queue[tail].idx = idx - 1;
            queue[tail].step = step;
            tail++;
        }
    }
    hashFreeIdxAll(&idxSameValue);
    hashFreeSetAll(&visitedIndex);
    free(queue);
    return -1;
}

###go

func minJumps(arr []int) int {
    n := len(arr)
    idx := map[int][]int{}
    for i, v := range arr {
        idx[v] = append(idx[v], i)
    }
    vis := map[int]bool{0: true}
    type pair struct{ idx, step int }
    q := []pair{{}}
    for {
        p := q[0]
        q = q[1:]
        i, step := p.idx, p.step
        if i == n-1 {
            return step
        }
        for _, j := range idx[arr[i]] {
            if !vis[j] {
                vis[j] = true
                q = append(q, pair{j, step + 1})
            }
        }
        delete(idx, arr[i])
        if !vis[i+1] {
            vis[i+1] = true
            q = append(q, pair{i + 1, step + 1})
        }
        if i > 0 && !vis[i-1] {
            vis[i-1] = true
            q = append(q, pair{i - 1, step + 1})
        }
    }
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 为数组 $\textit{arr}$ 的长度。每个元素最多只进入队列一次,最多被判断是否需要进入队列三次。

  • 空间复杂度:$O(n)$,其中 $n$ 为数组 $\textit{arr}$ 的长度。队列,哈希表和哈希集合均最多存储 $n$ 个元素。

跳跃游戏 III

方法一:广度优先搜索

我们可以使用广度优先搜索的方法得到从 $\text{start}$ 开始能够到达的所有位置,如果其中某个位置对应的元素值为 $0$,那么就返回 $\text{True}$。

具体地,我们初始时将 $\text{start}$ 加入队列。在每一次的搜索过程中,我们取出队首的节点 $u$,它可以到达的位置为 $u + \text{arr}[u]$ 和 $u - \text{arr}[u]$。如果某个位置落在数组的下标范围 $[0, \text{len}(\text{arr}))$ 内,并且没有被搜索过,则将该位置加入队尾。只要我们搜索到一个对应元素值为 $0$ 的位置,我们就返回 $\text{True}$。在搜索结束后,如果仍然没有找到符合要求的位置,我们就返回 $\text{False}$。

###C++

class Solution {
public:
    bool canReach(vector<int>& arr, int start) {
        if (arr[start] == 0) {
            return true;
        }
        
        int n = arr.size();
        vector<bool> used(n);
        queue<int> q;
        q.push(start);
        used[start] = true;

        while (!q.empty()) {
            int u = q.front();
            q.pop();
            if (u + arr[u] < n && !used[u + arr[u]]) {
                if (arr[u + arr[u]] == 0) {
                    return true;
                }
                q.push(u + arr[u]);
                used[u + arr[u]] = true;
            }
            if (u - arr[u] >= 0 && !used[u - arr[u]]) {
                if (arr[u - arr[u]] == 0) {
                    return true;
                }
                q.push(u - arr[u]);
                used[u - arr[u]] = true;
            }
        }
        return false;
    }
};

###Python

class Solution:
    def canReach(self, arr: List[int], start: int) -> bool:
        if arr[start] == 0:
            return True

        n = len(arr)
        used = {start}
        q = collections.deque([start])

        while len(q) > 0:
            u = q.popleft()
            for v in [u + arr[u], u - arr[u]]:
                if 0 <= v < n and v not in used:
                    if arr[v] == 0:
                        return True
                    q.append(v)
                    used.add(v)
        
        return False

###Java

class Solution {
    public boolean canReach(int[] arr, int start) {
        if (arr[start] == 0) {
            return true;
        }

        int n = arr.length;
        boolean[] used = new boolean[n];
        used[start] = true;
        Queue<Integer> q = new LinkedList<>();
        q.offer(start);

        while (!q.isEmpty()) {
            int u = q.poll();
            for (int v : new int[]{u + arr[u], u - arr[u]}) {
                if (0 <= v && v < n && !used[v]) {
                    if (arr[v] == 0) {
                        return true;
                    }
                    q.offer(v);
                    used[v] = true;
                }
            }
        }
        
        return false;
    }
}

###C#

public class Solution {
    public bool CanReach(int[] arr, int start) {
        if (arr[start] == 0) {
            return true;
        }

        int n = arr.Length;
        bool[] used = new bool[n];
        used[start] = true;
        Queue<int> q = new Queue<int>();
        q.Enqueue(start);

        while (q.Count > 0) {
            int u = q.Dequeue();
            foreach (int v in new int[]{u + arr[u], u - arr[u]}) {
                if (0 <= v && v < n && !used[v]) {
                    if (arr[v] == 0) {
                        return true;
                    }
                    q.Enqueue(v);
                    used[v] = true;
                }
            }
        }
        
        return false;
    }
}

###Go

func canReach(arr []int, start int) bool {
    if arr[start] == 0 {
        return true
    }

    n := len(arr)
    used := make([]bool, n)
    used[start] = true
    q := []int{start}

    for len(q) > 0 {
        u := q[0]
        q = q[1:]
        for _, v := range []int{u + arr[u], u - arr[u]} {
            if 0 <= v && v < n && !used[v] {
                if arr[v] == 0 {
                    return true
                }
                q = append(q, v)
                used[v] = true
            }
        }
    }
    
    return false
}

###C

bool canReach(int* arr, int arrSize, int start) {
    if (arr[start] == 0) {
        return true;
    }

    bool* used = (bool*)calloc(arrSize, sizeof(bool));
    used[start] = true;
    int* q = (int*)malloc(arrSize * sizeof(int));
    int front = 0, rear = 0;
    q[rear++] = start;

    while (front < rear) {
        int u = q[front++];
        int next[] = {u + arr[u], u - arr[u]};
        for (int i = 0; i < 2; i++) {
            int v = next[i];
            if (0 <= v && v < arrSize && !used[v]) {
                if (arr[v] == 0) {
                    free(used);
                    free(q);
                    return true;
                }
                q[rear++] = v;
                used[v] = true;
            }
        }
    }
    
    free(used);
    free(q);
    return false;
}

###JavaScript

var canReach = function(arr, start) {
    if (arr[start] === 0) {
        return true;
    }

    const n = arr.length;
    const used = new Array(n).fill(false);
    used[start] = true;
    const q = new Queue([start]);

    while (!q.isEmpty()) {
        const u = q.dequeue();
        for (const v of [u + arr[u], u - arr[u]]) {
            if (0 <= v && v < n && !used[v]) {
                if (arr[v] === 0) {
                    return true;
                }
                q.enqueue(v);
                used[v] = true;
            }
        }
    }
    
    return false;
};

###TypeScript

function canReach(arr: number[], start: number): boolean {
    if (arr[start] === 0) {
        return true;
    }

    const n = arr.length;
    const used: boolean[] = new Array(n).fill(false);
    used[start] = true;
    const q = new Queue<number>([start]);

    while (!q.isEmpty()) {
        const u = q.dequeue();
        for (const v of [u + arr[u], u - arr[u]]) {
            if (0 <= v && v < n && !used[v]) {
                if (arr[v] === 0) {
                    return true;
                }
                q.enqueue(v);
                used[v] = true;
            }
        }
    }
    
    return false;
}

###Rust

use std::collections::VecDeque;

impl Solution {
    pub fn can_reach(arr: Vec<i32>, start: i32) -> bool {
        let start = start as usize;
        if arr[start] == 0 {
            return true;
        }

        let n = arr.len();
        let mut used = vec![false; n];
        used[start] = true;
        let mut q = VecDeque::new();
        q.push_back(start);

        while let Some(u) = q.pop_front() {
            for &v in &[u as i32 + arr[u], u as i32 - arr[u]] {
                if 0 <= v && (v as usize) < n {
                    let v = v as usize;
                    if !used[v] {
                        if arr[v] == 0 {
                            return true;
                        }
                        q.push_back(v);
                        used[v] = true;
                    }
                }
            }
        }
        
        false
    }
}

复杂度分析

  • 时间复杂度:$O(N)$,其中 $N$ 是数组 arr 的长度。

  • 空间复杂度:$O(N)$。

寻找旋转排序数组中的最小值 II

📺 视频题解

...寻找旋转排序数组中的最小值 II.mp4

📖 文字题解

前言

本题是「153. 寻找旋转排序数组中的最小值」的延伸。读者可以先尝试第 153 题,体会在旋转数组中进行二分查找的思路,再来尝试解决本题。

方法一:二分查找

思路与算法

一个包含重复元素的升序数组在经过旋转之后,可以得到下面可视化的折线图:

fig1

其中横轴表示数组元素的下标,纵轴表示数组元素的值。图中标出了最小值的位置,是我们需要查找的目标。

我们考虑数组中的最后一个元素 $x$:在最小值右侧的元素,它们的值一定都小于等于 $x$;而在最小值左侧的元素,它们的值一定都大于等于 $x$。因此,我们可以根据这一条性质,通过二分查找的方法找出最小值。

在二分查找的每一步中,左边界为 $\it low$,右边界为 $\it high$,区间的中点为 $\it pivot$,最小值就在该区间内。我们将中轴元素 $\textit{nums}[\textit{pivot}]$ 与右边界元素 $\textit{nums}[\textit{high}]$ 进行比较,可能会有以下的三种情况:

第一种情况是 $\textit{nums}[\textit{pivot}] < \textit{nums}[\textit{high}]$。如下图所示,这说明 $\textit{nums}[\textit{pivot}]$ 是最小值右侧的元素,因此我们可以忽略二分查找区间的右半部分。

fig2

第二种情况是 $\textit{nums}[\textit{pivot}] > \textit{nums}[\textit{high}]$。如下图所示,这说明 $\textit{nums}[\textit{pivot}]$ 是最小值左侧的元素,因此我们可以忽略二分查找区间的左半部分。

fig3

第三种情况是 $\textit{nums}[\textit{pivot}] == \textit{nums}[\textit{high}]$。如下图所示,由于重复元素的存在,我们并不能确定 $\textit{nums}[\textit{pivot}]$ 究竟在最小值的左侧还是右侧,因此我们不能莽撞地忽略某一部分的元素。我们唯一可以知道的是,由于它们的值相同,所以无论 $\textit{nums}[\textit{high}]$ 是不是最小值,都有一个它的「替代品」$\textit{nums}[\textit{pivot}]$,因此我们可以忽略二分查找区间的右端点。

fig4

当二分查找结束时,我们就得到了最小值所在的位置。

###C++

class Solution {
public:
    int findMin(vector<int>& nums) {
        int low = 0;
        int high = nums.size() - 1;
        while (low < high) {
            int pivot = low + (high - low) / 2;
            if (nums[pivot] < nums[high]) {
                high = pivot;
            }
            else if (nums[pivot] > nums[high]) {
                low = pivot + 1;
            }
            else {
                high -= 1;
            }
        }
        return nums[low];
    }
};

###Java

class Solution {
    public int findMin(int[] nums) {
        int low = 0;
        int high = nums.length - 1;
        while (low < high) {
            int pivot = low + (high - low) / 2;
            if (nums[pivot] < nums[high]) {
                high = pivot;
            } else if (nums[pivot] > nums[high]) {
                low = pivot + 1;
            } else {
                high -= 1;
            }
        }
        return nums[low];
    }
}

###Python

class Solution:
    def findMin(self, nums: List[int]) -> int:    
        low, high = 0, len(nums) - 1
        while low < high:
            pivot = low + (high - low) // 2
            if nums[pivot] < nums[high]:
                high = pivot 
            elif nums[pivot] > nums[high]:
                low = pivot + 1
            else:
                high -= 1
        return nums[low]

###C

int findMin(int* nums, int numsSize) {
    int low = 0;
    int high = numsSize - 1;
    while (low < high) {
        int pivot = low + (high - low) / 2;
        if (nums[pivot] < nums[high]) {
            high = pivot;
        } else if (nums[pivot] > nums[high]) {
            low = pivot + 1;
        } else {
            high -= 1;
        }
    }
    return nums[low];
}

###golang

func findMin(nums []int) int {
    low, high := 0, len(nums) - 1
    for low < high {
        pivot := low + (high - low) / 2
        if nums[pivot] < nums[high] {
            high = pivot
        } else if nums[pivot] > nums[high] {
            low = pivot + 1
        } else {
            high--
        }
    }
    return nums[low]
}

###JavaScript

var findMin = function(nums) {
    let low = 0;
    let high = nums.length - 1;
    while (low < high) {
        const pivot = low + Math.floor((high - low) / 2);
        if (nums[pivot] < nums[high]) {
            high = pivot;
        } else if (nums[pivot] > nums[high]) {
            low = pivot + 1;
        } else {
            high -= 1;
        }
    }
    return nums[low];
};

复杂度分析

  • 时间复杂度:平均时间复杂度为 $O(\log n)$,其中 $n$ 是数组 $\it nums$ 的长度。如果数组是随机生成的,那么数组中包含相同元素的概率很低,在二分查找的过程中,大部分情况都会忽略一半的区间。而在最坏情况下,如果数组中的元素完全相同,那么 $\texttt{while}$ 循环就需要执行 $n$ 次,每次忽略区间的右端点,时间复杂度为 $O(n)$。

  • 空间复杂度:$O(1)$。

寻找旋转排序数组中的最小值

方法一:二分查找

思路与算法

一个不包含重复元素的升序数组在经过旋转之后,可以得到下面可视化的折线图:

fig1

其中横轴表示数组元素的下标,纵轴表示数组元素的值。图中标出了最小值的位置,是我们需要查找的目标。

我们考虑数组中的最后一个元素 $x$:在最小值右侧的元素(不包括最后一个元素本身),它们的值一定都严格小于 $x$;而在最小值左侧的元素,它们的值一定都严格大于 $x$。因此,我们可以根据这一条性质,通过二分查找的方法找出最小值。

在二分查找的每一步中,左边界为 $\it low$,右边界为 $\it high$,区间的中点为 $\it pivot$,最小值就在该区间内。我们将中轴元素 $\textit{nums}[\textit{pivot}]$ 与右边界元素 $\textit{nums}[\textit{high}]$ 进行比较,可能会有以下的三种情况:

第一种情况是 $\textit{nums}[\textit{pivot}] < \textit{nums}[\textit{high}]$。如下图所示,这说明 $\textit{nums}[\textit{pivot}]$ 是最小值右侧的元素,因此我们可以忽略二分查找区间的右半部分。

fig2

第二种情况是 $\textit{nums}[\textit{pivot}] > \textit{nums}[\textit{high}]$。如下图所示,这说明 $\textit{nums}[\textit{pivot}]$ 是最小值左侧的元素,因此我们可以忽略二分查找区间的左半部分。

fig3

由于数组不包含重复元素,并且只要当前的区间长度不为 $1$,$\it pivot$ 就不会与 $\it high$ 重合;而如果当前的区间长度为 $1$,这说明我们已经可以结束二分查找了。因此不会存在 $\textit{nums}[\textit{pivot}] = \textit{nums}[\textit{high}]$ 的情况。

当二分查找结束时,我们就得到了最小值所在的位置。

###C++

class Solution {
public:
    int findMin(vector<int>& nums) {
        int low = 0;
        int high = nums.size() - 1;
        while (low < high) {
            int pivot = low + (high - low) / 2;
            if (nums[pivot] < nums[high]) {
                high = pivot;
            }
            else {
                low = pivot + 1;
            }
        }
        return nums[low];
    }
};

###Java

class Solution {
    public int findMin(int[] nums) {
        int low = 0;
        int high = nums.length - 1;
        while (low < high) {
            int pivot = low + (high - low) / 2;
            if (nums[pivot] < nums[high]) {
                high = pivot;
            } else {
                low = pivot + 1;
            }
        }
        return nums[low];
    }
}

###Python

class Solution:
    def findMin(self, nums: List[int]) -> int:    
        low, high = 0, len(nums) - 1
        while low < high:
            pivot = low + (high - low) // 2
            if nums[pivot] < nums[high]:
                high = pivot 
            else:
                low = pivot + 1
        return nums[low]

###C

int findMin(int* nums, int numsSize) {
    int low = 0;
    int high = numsSize - 1;
    while (low < high) {
        int pivot = low + (high - low) / 2;
        if (nums[pivot] < nums[high]) {
            high = pivot;
        } else {
            low = pivot + 1;
        }
    }
    return nums[low];
}

###golang

func findMin(nums []int) int {
    low, high := 0, len(nums) - 1
    for low < high {
        pivot := low + (high - low) / 2
        if nums[pivot] < nums[high] {
            high = pivot
        } else {
            low = pivot + 1
        }
    }
    return nums[low]
}

###JavaScript

var findMin = function(nums) {
    let low = 0;
    let high = nums.length - 1;
    while (low < high) {
        const pivot = low + Math.floor((high - low) / 2);
        if (nums[pivot] < nums[high]) {
            high = pivot;
        } else {
            low = pivot + 1;
        }
    }
    return nums[low];
};

复杂度分析

  • 时间复杂度:时间复杂度为 $O(\log n)$,其中 $n$ 是数组 $\textit{nums}$ 的长度。在二分查找的过程中,每一步会忽略一半的区间,因此时间复杂度为 $O(\log n)$。

  • 空间复杂度:$O(1)$。

检查数组是否是好的

方法一:排序

思路与算法

将数组进行排序,随后遍历前 $n$ 个元素,比对是否等于 $i + 1$,最后检查末尾元素是否等于 $n$ 即可。

代码

###C++

class Solution {
public:
    bool isGood(vector<int>& nums) {
        sort(nums.begin(), nums.end());
        int n = nums.size() - 1;
        for (int i = 0; i < n; ++i) {
            if (nums[i] != i + 1) {
                return false;
            }
        }
        return nums[n] == n;
    }
};

###Java

class Solution {
    public boolean isGood(int[] nums) {
        Arrays.sort(nums);
        int n = nums.length - 1;
        for (int i = 0; i < n; i++) {
            if (nums[i] != i + 1) {
                return false;
            }
        }
        return nums[n] == n;
    }
}

###Python

class Solution:
    def isGood(self, nums: List[int]) -> bool:
        nums.sort()
        n = len(nums) - 1
        for i in range(n):
            if nums[i] != i + 1:
                return False
        return nums[n] == n

###JavaScript

var isGood = function(nums) {
    nums.sort((a, b) => a - b);
    const n = nums.length - 1;
    for (let i = 0; i < n; i++) {
        if (nums[i] !== i + 1) {
            return false;
        }
    }
    return nums[n] === n;
};

###TypeScript

function isGood(nums: number[]): boolean {
    nums.sort((a, b) => a - b);
    const n = nums.length - 1;
    for (let i = 0; i < n; i++) {
        if (nums[i] !== i + 1) {
            return false;
        }
    }
    return nums[n] === n;
};

###Go

func isGood(nums []int) bool {
    sort.Ints(nums)
    n := len(nums) - 1
    for i := 0; i < n; i++ {
        if nums[i] != i + 1 {
            return false
        }
    }
    return nums[n] == n
}

###C#

public class Solution {
    public bool IsGood(int[] nums) {
        Array.Sort(nums);
        int n = nums.Length - 1;
        for (int i = 0; i < n; i++) {
            if (nums[i] != i + 1) {
                return false;
            }
        }
        return nums[n] == n;
    }
}

###C

int cmp(const void *a, const void *b) {
    return (*(int*)a - *(int*)b);
}

bool isGood(int* nums, int numsSize) {
    qsort(nums, numsSize, sizeof(int), cmp);
    int n = numsSize - 1;
    for (int i = 0; i < n; i++) {
        if (nums[i] != i + 1) {
            return false;
        }
    }
    return nums[n] == n;
}

###Rust

impl Solution {
    pub fn is_good(nums: Vec<i32>) -> bool {
        let mut nums = nums;
        nums.sort_unstable();
        let n = nums.len() - 1;
        for i in 0..n {
            if nums[i] != (i + 1) as i32 {
                return false;
            }
        }
        nums[n] == n as i32
    }
}

复杂度分析

  • 时间复杂度:$O(n\log n)$,其中 $n$ 是数组的长度。

  • 空间复杂度:$O(\log n)$,其中 $n$ 是数组的长度。

方法二:统计频数

思路与算法

遍历数组,使用数组来统计每个元素出现的次数。

在遍历统计的过程中,如发现有超过 $n$ 的数,就可以提前判断为不符合并返回。数字 $n$,出现的次数不能超过 $2$ 次。 其它数字,出现的次数不能超过 $1$ 次。否则提前判断为不符合并返回。

若全部满足,则说明它是好数组,并返回结果。

代码

###C++

class Solution {
public:
    bool isGood(vector<int>& nums) {
        int n = nums.size();
        vector<int> count(n, 0);
        for (int a : nums) {
            if (a >= n) {
                return false;
            }
            if (a < n - 1 && count[a] > 0) {
                return false;
            }
            if (a == n - 1 && count[a] > 1) {
                return false;
            }
            count[a]++;
        }
        return true;
    }
};

###Java

class Solution {
    public boolean isGood(int[] nums) {
        int n = nums.length;
        int[] count = new int[n];
        for (int a : nums) {
            if (a >= n) {
                return false;
            }
            if (a < n - 1 && count[a] > 0) {
                return false;
            }
            if (a == n - 1 && count[a] > 1) {
                return false;
            }
            count[a]++;
        }
        return true;
    }
}

###Python

class Solution:
    def isGood(self, nums: List[int]) -> bool:
        n = len(nums)
        count = [0] * n
        for a in nums:
            if a >= n:
                return False
            if a < n - 1 and count[a] > 0:
                return False
            if a == n - 1 and count[a] > 1:
                return False
            count[a] += 1
        return True

###JavaScript

var isGood = function(nums) {
    const n = nums.length;
    const count = new Array(n).fill(0);
    for (const a of nums) {
        if (a >= n) {
            return false;
        }
        if (a < n - 1 && count[a] > 0) {
            return false;
        }
        if (a === n - 1 && count[a] > 1) {
            return false;
        }
        count[a]++;
    }
    return true;
};

###TypeScript

function isGood(nums: number[]): boolean {
    const n = nums.length;
    const count = new Array(n).fill(0);
    for (const a of nums) {
        if (a >= n) {
            return false;
        }
        if (a < n - 1 && count[a] > 0) {
            return false;
        }
        if (a === n - 1 && count[a] > 1) {
            return false;
        }
        count[a]++;
    }
    return true;
};

###Go

func isGood(nums []int) bool {
    n := len(nums)
    count := make([]int, n)
    for _, a := range nums {
        if a < 1 || a >= n {
            return false
        }
        if a < n - 1 && count[a] > 0 {
            return false
        }
        if a == n - 1 && count[a] > 1 {
            return false
        }
        count[a]++
    }
    return true
}

###C#

public class Solution {
    public bool IsGood(int[] nums) {
        int n = nums.Length;
        int[] count = new int[n];
        foreach (int a in nums) {
            if (a < 1 || a >= n) {
                return false;
            }
            if (a < n - 1 && count[a] > 0) {
                return false;
            }
            if (a == n - 1 && count[a] > 1) {
                return false;
            }
            count[a]++;
        }
        return true;
    }
}

###C

bool isGood(int* nums, int numsSize) {
    int n = numsSize;
    int* count = (int*)calloc(n, sizeof(int));
    for (int i = 0; i < n; i++) {
        int a = nums[i];
        if (a < 1 || a >= n) {
            free(count);
            return false;
        }
        if (a < n - 1 && count[a] > 0) {
            free(count);
            return false;
        }
        if (a == n - 1 && count[a] > 1) {
            free(count);
            return false;
        }
        count[a]++;
    }
    free(count);
    return true;
}

###Rust

impl Solution {
    pub fn is_good(nums: Vec<i32>) -> bool {
        let n = nums.len() as i32;
        let mut count = vec![0; n as usize];
        for &a in nums.iter() {
            if a < 1 || a >= n {
                return false;
            }
            if a < n - 1 && count[a as usize] > 0 {
                return false;
            }
            if a == n - 1 && count[a as usize] > 1 {
                return false;
            }
            count[a as usize] += 1;
        }
        true
    }
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 是数组的长度。

  • 空间复杂度:$O(n)$,其中 $n$ 是数组的长度。

循环轮转矩阵

方法一:枚举每一层

思路与算法

对于一个 $m \times n$ 的矩阵 $\textit{grid}$,它的层数为 $\min(m / 2, n / 2)$。我们可以从外向内枚举矩阵的每一层模拟循环轮转操作。

为了方便模拟,我们从左上角起按照逆时针方向遍历每一层的元素。在本文中,我们将遍历过程分为四个部分,每个部分按顺序遍历每条边除了最后一个元素以外的所有元素。

我们将这些元素的行坐标、列坐标与数值保存在对应的数组 $r, c, \textit{val}$ 中,并计算元素总数,即数组的长度 $\textit{total}$。此时,如果对该层元素进行 $\textit{total}$ 次循环轮转操作,那么该层元素不会改变。因此,实际的循环轮转操作数量即为 $\textit{kk} = k % \textit{total}$。

那么,这一层中遍历到的第 $i$ 个位置在轮转操作后存放的值对应 $\textit{val}$ 数组中下标为 $(i - \textit{kk} + \textit{total}) % \textit{total}$ 的值。此处在取模时加上 $\textit{total}$ 是为了避免出现负数。

我们遍历行列坐标数组,并在 $\textit{grid}$ 中更新每个坐标对应的轮转操作后的取值。当枚举并更新完所有层后,$\textit{grid}$ 即为轮转操作后的矩阵。

代码

###C++

class Solution {
public:
    vector<vector<int>> rotateGrid(vector<vector<int>>& grid, int k) {
        int m = grid.size();
        int n = grid[0].size();
        int nlayer = min(m / 2, n / 2);   // 层数
        // 从左上角起逆时针枚举每一层
        for (int layer = 0; layer < nlayer; ++layer){
            vector<int> r, c, val;   // 每个元素的行下标,列下标与数值
            for (int i = layer; i < m - layer - 1; ++i){   // 左
                r.push_back(i);
                c.push_back(layer);
                val.push_back(grid[i][layer]);
            }
            for (int j = layer; j < n - layer - 1; ++j){   // 下
                r.push_back(m - layer - 1);
                c.push_back(j);
                val.push_back(grid[m-layer-1][j]);
            }
            for (int i = m - layer - 1; i > layer; --i){   // 右
                r.push_back(i);
                c.push_back(n - layer - 1);
                val.push_back(grid[i][n-layer-1]);
            }
            for (int j = n - layer - 1; j > layer; --j){   // 上
                r.push_back(layer);
                c.push_back(j);
                val.push_back(grid[layer][j]);
            }
            int total = val.size();   // 每一层的元素总数
            int kk = k % total;   // 等效轮转次数
            // 找到每个下标对应的轮转后的取值
            for (int i = 0; i < total; ++i){
                int idx = (i + total - kk) % total;   // 轮转后取值对应的下标
                grid[r[i]][c[i]] = val[idx];
            }
        }
        return grid;
    }
};

###Python

class Solution:
    def rotateGrid(self, grid: List[List[int]], k: int) -> List[List[int]]:
        m, n = len(grid), len(grid[0])
        nlayer = min(m // 2, n // 2)   # 层数
        # 从左上角起逆时针枚举每一层
        for layer in range(nlayer):
            r = []   # 每个元素的行下标
            c = []   # 每个元素的列下标
            val = []   # 每个元素的数值
            for i in range(layer, m - layer - 1):   # 左 
                r.append(i)
                c.append(layer)
                val.append(grid[i][layer])
            for j in range(layer, n - layer - 1):   # 下
                r.append(m - layer - 1)
                c.append(j)
                val.append(grid[m-layer-1][j])
            for i in range(m - layer - 1, layer, -1):   # 右
                r.append(i)
                c.append(n - layer - 1)
                val.append(grid[i][n-layer-1])
            for j in range(n - layer - 1, layer, -1):   # 上
                r.append(layer)
                c.append(j)
                val.append(grid[layer][j])
            total = len(val)   # 每一层的元素总数
            kk = k % total   # 等效轮转次数
            # 找到每个下标对应的轮转后的取值
            for i in range(total):
                idx = (i + total - kk) % total   # 轮转后取值对应的下标
                grid[r[i]][c[i]] = val[idx]
        return grid

###Java

class Solution {
    public int[][] rotateGrid(int[][] grid, int k) {
        int m = grid.length;
        int n = grid[0].length;
        int nlayer = Math.min(m / 2, n / 2);   // 层数
        // 从左上角起逆时针枚举每一层
        for (int layer = 0; layer < nlayer; ++layer){
            List<Integer> r = new ArrayList<>();
            List<Integer> c = new ArrayList<>();
            List<Integer> val = new ArrayList<>();   // 每个元素的行下标,列下标与数值
            for (int i = layer; i < m - layer - 1; ++i){   // 左
                r.add(i);
                c.add(layer);
                val.add(grid[i][layer]);
            }
            for (int j = layer; j < n - layer - 1; ++j){   // 下
                r.add(m - layer - 1);
                c.add(j);
                val.add(grid[m - layer - 1][j]);
            }
            for (int i = m - layer - 1; i > layer; --i){   // 右
                r.add(i);
                c.add(n - layer - 1);
                val.add(grid[i][n - layer - 1]);
            }
            for (int j = n - layer - 1; j > layer; --j){   // 上
                r.add(layer);
                c.add(j);
                val.add(grid[layer][j]);
            }
            int total = val.size();   // 每一层的元素总数
            int kk = k % total;   // 等效轮转次数
            // 找到每个下标对应的轮转后的取值
            for (int i = 0; i < total; ++i){
                int idx = (i + total - kk) % total;   // 轮转后取值对应的下标
                grid[r.get(i)][c.get(i)] = val.get(idx);
            }
        }
        return grid;
    }
}

###C#

public class Solution {
    public int[][] RotateGrid(int[][] grid, int k) {
        int m = grid.Length;
        int n = grid[0].Length;
        int nlayer = Math.Min(m / 2, n / 2);   // 层数
        // 从左上角起逆时针枚举每一层
        for (int layer = 0; layer < nlayer; ++layer){
            List<int> r = new List<int>();
            List<int> c = new List<int>();
            List<int> val = new List<int>();   // 每个元素的行下标,列下标与数值
            for (int i = layer; i < m - layer - 1; ++i){   // 左
                r.Add(i);
                c.Add(layer);
                val.Add(grid[i][layer]);
            }
            for (int j = layer; j < n - layer - 1; ++j){   // 下
                r.Add(m - layer - 1);
                c.Add(j);
                val.Add(grid[m - layer - 1][j]);
            }
            for (int i = m - layer - 1; i > layer; --i){   // 右
                r.Add(i);
                c.Add(n - layer - 1);
                val.Add(grid[i][n - layer - 1]);
            }
            for (int j = n - layer - 1; j > layer; --j){   // 上
                r.Add(layer);
                c.Add(j);
                val.Add(grid[layer][j]);
            }
            int total = val.Count;   // 每一层的元素总数
            int kk = k % total;   // 等效轮转次数
            // 找到每个下标对应的轮转后的取值
            for (int i = 0; i < total; ++i){
                int idx = (i + total - kk) % total;   // 轮转后取值对应的下标
                grid[r[i]][c[i]] = val[idx];
            }
        }
        return grid;
    }
}

###Go

func rotateGrid(grid [][]int, k int) [][]int {
    m := len(grid)
    n := len(grid[0])
    nlayer := min(m / 2, n / 2)   // 层数
    // 从左上角起逆时针枚举每一层
    for layer := 0; layer < nlayer; layer++ {
        r := make([]int, 0)
        c := make([]int, 0)
        val := make([]int, 0)   // 每个元素的行下标,列下标与数值
        for i := layer; i < m - layer - 1; i++ {   // 左
            r = append(r, i)
            c = append(c, layer)
            val = append(val, grid[i][layer])
        }
        for j := layer; j < n - layer - 1; j++ {   // 下
            r = append(r, m - layer - 1)
            c = append(c, j)
            val = append(val, grid[m-layer - 1][j])
        }
        for i := m - layer - 1; i > layer; i-- {   // 右
            r = append(r, i)
            c = append(c, n - layer - 1)
            val = append(val, grid[i][n - layer - 1])
        }
        for j := n - layer - 1; j > layer; j-- {   // 上
            r = append(r, layer)
            c = append(c, j)
            val = append(val, grid[layer][j])
        }
        total := len(val)   // 每一层的元素总数
        kk := k % total   // 等效轮转次数
        // 找到每个下标对应的轮转后的取值
        for i := 0; i < total; i++ {
            idx := (i + total - kk) % total   // 轮转后取值对应的下标
            grid[r[i]][c[i]] = val[idx]
        }
    }
    return grid
}

###C

int** rotateGrid(int** grid, int gridSize, int* gridColSize, int k, int* returnSize, int** returnColumnSizes) {
    int m = gridSize;
    int n = gridColSize[0];
    *returnSize = m;
    *returnColumnSizes = (int*)malloc(m * sizeof(int));
    for (int i = 0; i < m; i++) {
        (*returnColumnSizes)[i] = n;
    }
    
    int nlayer = fmin(m / 2, n / 2);   // 层数
    // 从左上角起逆时针枚举每一层
    for (int layer = 0; layer < nlayer; ++layer) {
        int maxSize = 2 * (m + n - 4 * layer - 2);
        int* r = (int*)malloc(maxSize * sizeof(int));
        int* c = (int*)malloc(maxSize * sizeof(int));
        int* val = (int*)malloc(maxSize * sizeof(int));   // 每个元素的行下标,列下标与数值
        int idx = 0;
        
        for (int i = layer; i < m - layer - 1; ++i) {   // 左
            r[idx] = i;
            c[idx] = layer;
            val[idx] = grid[i][layer];
            idx++;
        }
        for (int j = layer; j < n - layer - 1; ++j) {   // 下
            r[idx] = m - layer - 1;
            c[idx] = j;
            val[idx] = grid[m - layer - 1][j];
            idx++;
        }
        for (int i = m - layer - 1; i > layer; --i) {   // 右
            r[idx] = i;
            c[idx] = n - layer - 1;
            val[idx] = grid[i][n - layer - 1];
            idx++;
        }
        for (int j = n - layer - 1; j > layer; --j) {   // 上
            r[idx] = layer;
            c[idx] = j;
            val[idx] = grid[layer][j];
            idx++;
        }
        
        int total = idx;   // 每一层的元素总数
        int kk = k % total;   // 等效轮转次数
        // 找到每个下标对应的轮转后的取值
        for (int i = 0; i < total; ++i) {
            int pos = (i + total - kk) % total;   // 轮转后取值对应的下标
            grid[r[i]][c[i]] = val[pos];
        }
        
        free(r);
        free(c);
        free(val);
    }
    return grid;
}

###JavaScript

var rotateGrid = function(grid, k) {
    const m = grid.length;
    const n = grid[0].length;
    const nlayer = Math.min(Math.floor(m / 2), Math.floor(n / 2));   // 层数
    // 从左上角起逆时针枚举每一层
    for (let layer = 0; layer < nlayer; ++layer) {
        const r = [];
        const c = [];
        const val = [];   // 每个元素的行下标,列下标与数值
        for (let i = layer; i < m - layer - 1; ++i) {   // 左
            r.push(i);
            c.push(layer);
            val.push(grid[i][layer]);
        }
        for (let j = layer; j < n - layer - 1; ++j) {   // 下
            r.push(m - layer - 1);
            c.push(j);
            val.push(grid[m - layer - 1][j]);
        }
        for (let i = m - layer - 1; i > layer; --i) {   // 右
            r.push(i);
            c.push(n - layer - 1);
            val.push(grid[i][n - layer - 1]);
        }
        for (let j = n - layer - 1; j > layer; --j) {   // 上
            r.push(layer);
            c.push(j);
            val.push(grid[layer][j]);
        }
        const total = val.length;   // 每一层的元素总数
        const kk = k % total;   // 等效轮转次数
        // 找到每个下标对应的轮转后的取值
        for (let i = 0; i < total; ++i) {
            const idx = (i + total - kk) % total;   // 轮转后取值对应的下标
            grid[r[i]][c[i]] = val[idx];
        }
    }
    return grid;
};

###TypeScript

function rotateGrid(grid: number[][], k: number): number[][] {
    const m: number = grid.length;
    const n: number = grid[0].length;
    const nlayer: number = Math.min(Math.floor(m / 2), Math.floor(n / 2));   // 层数
    // 从左上角起逆时针枚举每一层
    for (let layer = 0; layer < nlayer; ++layer) {
        const r: number[] = [];
        const c: number[] = [];
        const val: number[] = [];   // 每个元素的行下标,列下标与数值
        for (let i = layer; i < m - layer - 1; ++i) {   // 左
            r.push(i);
            c.push(layer);
            val.push(grid[i][layer]);
        }
        for (let j = layer; j < n - layer - 1; ++j) {   // 下
            r.push(m - layer - 1);
            c.push(j);
            val.push(grid[m - layer - 1][j]);
        }
        for (let i = m - layer - 1; i > layer; --i) {   // 右
            r.push(i);
            c.push(n - layer - 1);
            val.push(grid[i][n - layer - 1]);
        }
        for (let j = n - layer - 1; j > layer; --j) {   // 上
            r.push(layer);
            c.push(j);
            val.push(grid[layer][j]);
        }
        const total: number = val.length;   // 每一层的元素总数
        const kk: number = k % total;   // 等效轮转次数
        // 找到每个下标对应的轮转后的取值
        for (let i = 0; i < total; ++i) {
            const idx: number = (i + total - kk) % total;   // 轮转后取值对应的下标
            grid[r[i]][c[i]] = val[idx];
        }
    }
    return grid;
}

###Rust

impl Solution {
    pub fn rotate_grid(mut grid: Vec<Vec<i32>>, k: i32) -> Vec<Vec<i32>> {
        let m = grid.len();
        let n = grid[0].len();
        let nlayer = (m / 2).min(n / 2);   // 层数
        let k = k as usize;
        // 从左上角起逆时针枚举每一层
        for layer in 0..nlayer {
            let mut r = Vec::new();
            let mut c = Vec::new();
            let mut val = Vec::new();   // 每个元素的行下标,列下标与数值
            for i in layer..m - layer - 1 {   // 左
                r.push(i);
                c.push(layer);
                val.push(grid[i][layer]);
            }
            for j in layer..n - layer - 1 {   // 下
                r.push(m - layer - 1);
                c.push(j);
                val.push(grid[m - layer - 1][j]);
            }
            for i in (layer + 1..=m - layer - 1).rev() {   // 右
                r.push(i);
                c.push(n - layer - 1);
                val.push(grid[i][n - layer - 1]);
            }
            for j in (layer + 1..=n - layer - 1).rev() {   // 上
                r.push(layer);
                c.push(j);
                val.push(grid[layer][j]);
            }
            let total = val.len();   // 每一层的元素总数
            let kk = k % total;   // 等效轮转次数
            // 找到每个下标对应的轮转后的取值
            for i in 0..total {
                let idx = (i + total - kk) % total;   // 轮转后取值对应的下标
                grid[r[i]][c[i]] = val[idx];
            }
        }
        grid
    }
}

复杂度分析

  • 时间复杂度:$O(mn)$,其中 $m$ 和 $n$ 分别为 $\textit{grid}$ 的行数和列数。即为遍历 $\textit{grid}$ 并进行旋转的时间复杂度。

  • 空间复杂度:$O(m + n)$,即为存储每一层行列与数值的辅助数组大小。事实上,我们可以利用原地旋转将空间复杂度优化至 $O(1)$,但这样写出的代码并不直观,因此本题解中不给出空间复杂度最优的写法。

旋转盒子

方法一:用队列维护空位

提示 $1$

当我们将盒子顺时针旋转之后,原先的「每一行」就变成了「每一列」。

由于石头受到重力只会竖直向下掉落,因此「每一列」之间都互不影响,我们可以依次计算「每一列」的结果,即原先的「每一行」的结果。

思路与算法

由于重力向下,那么我们应当从右向左遍历原先的「每一行」。

我们使用一个队列来存放一行中的空位:

  • 当我们遍历到一块石头时,就从队首取出一个空位来放置这块石头。如果队列为空,那么说明右侧没有空位,这块石头不会下落;

  • 当我们遍历到一个空位时,我们将其加入队列;

  • 当我们遍历到一个障碍物时,需要将队列清空,障碍物右侧的空位都是不可用的。

在遍历完所有的行后,我们将矩阵顺时针旋转 $90$ 度,放入答案数组中即可。

代码

###C++

class Solution {
public:
    vector<vector<char>> rotateTheBox(vector<vector<char>>& box) {
        int m = box.size();
        int n = box[0].size();

        for (int i = 0; i < m; ++i) {
            deque<int> q;
            for (int j = n - 1; j >= 0; --j) {
                if (box[i][j] == '*') {
                    // 遇到障碍物,清空队列
                    q.clear();
                }
                else if (box[i][j] == '#') {
                    if (!q.empty()) {
                        // 如果队列不为空,石头就会下落
                        int pos = q.front();
                        q.pop_front();
                        box[i][pos] = '#';
                        box[i][j] = '.';
                        // 由于下落,石头变为空位,也需要加入队列
                        q.push_back(j);
                    }
                }
                else {
                    // 将空位加入队列
                    q.push_back(j);
                }
            }
        }

        // 将矩阵顺时针旋转 90 度放入答案
        vector<vector<char>> ans(n, vector<char>(m));
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                ans[j][m - i - 1] = box[i][j];
            }
        }
        return ans;
    }
};

###Python

class Solution:
    def rotateTheBox(self, box: List[List[str]]) -> List[List[str]]:
        m, n = len(box), len(box[0])

        for i in range(m):
            q = deque()
            for j in range(n - 1, -1, -1):
                if box[i][j] == "*":
                    # 遇到障碍物,清空队列
                    q.clear()
                elif box[i][j] == "#":
                    if q:
                        # 如果队列不为空,石头就会下落
                        pos = q.popleft()
                        box[i][pos] = "#"
                        box[i][j] = "."
                        # 由于下落,石头变为空位,也需要加入队列
                        q.append(j)
                else:
                    # 将空位加入队列
                    q.append(j)

        # 将矩阵顺时针旋转 90 度放入答案
        ans = [[""] * m for _ in range(n)]
        for i in range(m):
            for j in range(n):
                ans[j][m - i - 1] = box[i][j]
        return ans

复杂度分析

  • 时间复杂度:$O(mn)$。

  • 空间复杂度:$O(n)$,即为队列需要使用的空间。这里我们不计算返回的答案使用的空间。

方法二:用指针维护空位

提示 $1$

在遍历完某一个位置之后,如果队列不为空,那么:

  • 队尾一定是该位置;
  • 队列中的位置一定是连续的。

提示 $1$ 解释

如果队列不为空,那么该位置一定是空位(要么原本就是空位,要么原本有一块石头下落,该位置变成了空位),因此该位置会被加入队列成为队尾。

如果队列中的位置不连续,假设队列中没有位置 $x$,但有小于 $x$ 和大于 $x$ 的位置,当我们在此之前遍历到位置 $x$ 时,$x$ 没有被放入队列,说明 $x$ 不是空位,并且那时的队列为空,这样队列中就不可能有大于 $x$ 的位置了,这就产生了矛盾。

思路与算法

根据提示 $1$,我们就无需显式地维护这个队列了。

如果队列不为空,那么队尾一定为当前位置,且队列中的位置连续。因此我们只需要维护队首对应的位置即可。

代码

###C++

class Solution {
public:
    vector<vector<char>> rotateTheBox(vector<vector<char>>& box) {
        int m = box.size();
        int n = box[0].size();

        for (int i = 0; i < m; ++i) {
            // 队首对应的位置
            int front_pos = n - 1;
            for (int j = n - 1; j >= 0; --j) {
                if (box[i][j] == '*') {
                    // 遇到障碍物,清空队列
                    front_pos = j - 1;
                }
                else if (box[i][j] == '#') {
                    if (front_pos > j) {
                        // 如果队列不为空,石头就会下落
                        box[i][front_pos] = '#';
                        box[i][j] = '.';
                        --front_pos;
                    }
                    else {
                        front_pos = j - 1;
                    }
                }
            }
        }

        // 将矩阵顺时针旋转 90 度放入答案
        vector<vector<char>> ans(n, vector<char>(m));
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                ans[j][m - i - 1] = box[i][j];
            }
        }
        return ans;
    }
};

###Python

class Solution:
    def rotateTheBox(self, box: List[List[str]]) -> List[List[str]]:
        m, n = len(box), len(box[0])

        for i in range(m):
            # 队首对应的位置
            front_pos = n - 1
            for j in range(n - 1, -1, -1):
                if box[i][j] == "*":
                    # 遇到障碍物,清空队列
                    front_pos = j - 1
                elif box[i][j] == "#":
                    if front_pos > j:
                        # 如果队列不为空,石头就会下落
                        box[i][front_pos] = "#"
                        box[i][j] = "."
                        front_pos -= 1
                    else:
                        front_pos = j - 1

        # 将矩阵顺时针旋转 90 度放入答案
        ans = [[""] * m for _ in range(n)]
        for i in range(m):
            for j in range(n):
                ans[j][m - i - 1] = box[i][j]
        return ans

复杂度分析

  • 时间复杂度:$O(mn)$。

  • 空间复杂度:$O(1)$。这里我们不计算返回的答案使用的空间。

旋转链表

方法一:闭合为环

思路及算法

记给定链表的长度为 $n$,注意到当向右移动的次数 $k \geq n$ 时,我们仅需要向右移动 $k \bmod n$ 次即可。因为每 $n$ 次移动都会让链表变为原状。这样我们可以知道,新链表的最后一个节点为原链表的第 $(n - 1) - (k \bmod n)$ 个节点(从 $0$ 开始计数)。

这样,我们可以先将给定的链表连接成环,然后将指定位置断开。

具体代码中,我们首先计算出链表的长度 $n$,并找到该链表的末尾节点,将其与头节点相连。这样就得到了闭合为环的链表。然后我们找到新链表的最后一个节点(即原链表的第 $(n - 1) - (k \bmod n)$ 个节点),将当前闭合为环的链表断开,即可得到我们所需要的结果。

特别地,当链表长度不大于 $1$,或者 $k$ 为 $n$ 的倍数时,新链表将与原链表相同,我们无需进行任何处理。

代码

###C++

class Solution {
public:
    ListNode* rotateRight(ListNode* head, int k) {
        if (k == 0 || head == nullptr || head->next == nullptr) {
            return head;
        }
        int n = 1;
        ListNode* iter = head;
        while (iter->next != nullptr) {
            iter = iter->next;
            n++;
        }
        int add = n - k % n;
        if (add == n) {
            return head;
        }
        iter->next = head;
        while (add--) {
            iter = iter->next;
        }
        ListNode* ret = iter->next;
        iter->next = nullptr;
        return ret;
    }
};

###Java

class Solution {
    public ListNode rotateRight(ListNode head, int k) {
        if (k == 0 || head == null || head.next == null) {
            return head;
        }
        int n = 1;
        ListNode iter = head;
        while (iter.next != null) {
            iter = iter.next;
            n++;
        }
        int add = n - k % n;
        if (add == n) {
            return head;
        }
        iter.next = head;
        while (add-- > 0) {
            iter = iter.next;
        }
        ListNode ret = iter.next;
        iter.next = null;
        return ret;
    }
}

###Python

class Solution:
    def rotateRight(self, head: ListNode, k: int) -> ListNode:
        if k == 0 or not head or not head.next:
            return head
        
        n = 1
        cur = head
        while cur.next:
            cur = cur.next
            n += 1
        
        if (add := n - k % n) == n:
            return head
        
        cur.next = head
        while add:
            cur = cur.next
            add -= 1
        
        ret = cur.next
        cur.next = None
        return ret

###JavaScript

var rotateRight = function(head, k) {
    if (k === 0 || !head || !head.next) {
        return head;
    }
    let n = 1;
    let cur = head;
    while (cur.next) {
        cur = cur.next;
        n++;
    }

    let add = n - k % n;
    if (add === n) {
        return head;
    }

    cur.next = head;
    while (add) {
        cur = cur.next;
        add--;
    }

    const ret = cur.next;
    cur.next = null;
    return ret;
};

###go

func rotateRight(head *ListNode, k int) *ListNode {
    if k == 0 || head == nil || head.Next == nil {
        return head
    }
    n := 1
    iter := head
    for iter.Next != nil {
        iter = iter.Next
        n++
    }
    add := n - k%n
    if add == n {
        return head
    }
    iter.Next = head
    for add > 0 {
        iter = iter.Next
        add--
    }
    ret := iter.Next
    iter.Next = nil
    return ret
}

###C

struct ListNode* rotateRight(struct ListNode* head, int k) {
    if (k == 0 || head == NULL || head->next == NULL) {
        return head;
    }
    int n = 1;
    struct ListNode* iter = head;
    while (iter->next != NULL) {
        iter = iter->next;
        n++;
    }
    int add = n - k % n;
    if (add == n) {
        return head;
    }
    iter->next = head;
    while (add--) {
        iter = iter->next;
    }
    struct ListNode* ret = iter->next;
    iter->next = NULL;
    return ret;
}

复杂度分析

  • 时间复杂度:$O(n)$,最坏情况下,我们需要遍历该链表两次。

  • 空间复杂度:$O(1)$,我们只需要常数的空间存储若干变量。

旋转图像

📺 视频题解

48. 旋转图像.mp4

📖 文字题解

方法一:使用辅助数组

我们以题目中的示例二

$$
\begin{bmatrix}
5 & 1 & 9 & 11 \
2 & 4 & 8 & 10 \
13 & 3 & 6 & 7 \
15 & 14 & 12 & 16
\end{bmatrix}
$$

作为例子,分析将图像旋转 90 度之后,这些数字出现在什么位置。

对于矩阵中的第一行而言,在旋转后,它出现在倒数第一列的位置:

$$
\begin{bmatrix}
5 & 1 & 9 & 11 \
\circ & \circ & \circ & \circ \
\circ & \circ & \circ & \circ \
\circ & \circ & \circ & \circ \
\end{bmatrix}
\xRightarrow[]{旋转后}
\begin{bmatrix}
\circ & \circ & \circ & 5 \
\circ & \circ & \circ & 1 \
\circ & \circ & \circ & 9 \
\circ & \circ & \circ & 11
\end{bmatrix}
$$

并且,第一行的第 $x$ 个元素在旋转后恰好是倒数第一列的第 $x$ 个元素。

对于矩阵中的第二行而言,在旋转后,它出现在倒数第二列的位置:

$$
\begin{bmatrix}
\circ & \circ & \circ & \circ \
2 & 4 & 8 & 10 \
\circ & \circ & \circ & \circ \
\circ & \circ & \circ & \circ
\end{bmatrix}
\xRightarrow[]{旋转后}
\begin{bmatrix}
\circ & \circ & 2 & \circ \
\circ & \circ & 4 & \circ \
\circ & \circ & 8 & \circ \
\circ & \circ & 10 & \circ
\end{bmatrix}
$$

对于矩阵中的第三行和第四行同理。这样我们可以得到规律:

对于矩阵中第 $i$ 行的第 $j$ 个元素,在旋转后,它出现在倒数第 $i$ 列的第 $j$ 个位置。

我们将其翻译成代码。由于矩阵中的行列从 $0$ 开始计数,因此对于矩阵中的元素 $\textit{matrix}[\textit{row}][\textit{col}]$,在旋转后,它的新位置为 $\textit{matrix}_\textit{new}[\textit{col}][n - \textit{row} - 1]$。

这样以来,我们使用一个与 $\textit{matrix}$ 大小相同的辅助数组 ${matrix}\textit{new}$,临时存储旋转后的结果。我们遍历 $\textit{matrix}$ 中的每一个元素,根据上述规则将该元素存放到 ${matrix}\textit{new}$ 中对应的位置。在遍历完成之后,再将 ${matrix}_\textit{new}$ 中的结果复制到原数组中即可。

###C++

class Solution {
public:
    void rotate(vector<vector<int>>& matrix) {
        int n = matrix.size();
        // C++ 这里的 = 拷贝是值拷贝,会得到一个新的数组
        auto matrix_new = matrix;
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                matrix_new[j][n - i - 1] = matrix[i][j];
            }
        }
        // 这里也是值拷贝
        matrix = matrix_new;
    }
};

###Java

class Solution {
    public void rotate(int[][] matrix) {
        int n = matrix.length;
        int[][] matrix_new = new int[n][n];
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                matrix_new[j][n - i - 1] = matrix[i][j];
            }
        }
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                matrix[i][j] = matrix_new[i][j];
            }
        }
    }
}

###Python

class Solution:
    def rotate(self, matrix: List[List[int]]) -> None:
        n = len(matrix)
        # Python 这里不能 matrix_new = matrix 或 matrix_new = matrix[:] 因为是引用拷贝
        matrix_new = [[0] * n for _ in range(n)]
        for i in range(n):
            for j in range(n):
                matrix_new[j][n - i - 1] = matrix[i][j]
        # 不能写成 matrix = matrix_new
        matrix[:] = matrix_new

###JavaScript

var rotate = function(matrix) {
    const n = matrix.length;
    const matrix_new = new Array(n).fill(0).map(() => new Array(n).fill(0));
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < n; j++) {
            matrix_new[j][n - i - 1] = matrix[i][j];
        }
    }
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < n; j++) {
            matrix[i][j] = matrix_new[i][j];
        }
    }
};

###Go

func rotate(matrix [][]int) {
    n := len(matrix)
    tmp := make([][]int, n)
    for i := range tmp {
        tmp[i] = make([]int, n)
    }
    for i, row := range matrix {
        for j, v := range row {
            tmp[j][n-1-i] = v
        }
    }
    copy(matrix, tmp) // 拷贝 tmp 矩阵每行的引用
}

###C

void rotate(int** matrix, int matrixSize, int* matrixColSize) {
    int matrix_new[matrixSize][matrixSize];
    for (int i = 0; i < matrixSize; i++) {
        for (int j = 0; j < matrixSize; j++) {
            matrix_new[i][j] = matrix[i][j];
        }
    }
    for (int i = 0; i < matrixSize; ++i) {
        for (int j = 0; j < matrixSize; ++j) {
            matrix[j][matrixSize - i - 1] = matrix_new[i][j];
        }
    }
}

复杂度分析

  • 时间复杂度:$O(N^2)$,其中 $N$ 是 $\textit{matrix}$ 的边长。

  • 空间复杂度:$O(N^2)$。我们需要使用一个和 $\textit{matrix}$ 大小相同的辅助数组。

方法二:原地旋转

题目中要求我们尝试在不使用额外内存空间的情况下进行矩阵的旋转,也就是说,我们需要「原地旋转」这个矩阵。那么我们如何在方法一的基础上完成原地旋转呢?

我们观察方法一中的关键等式:

$$
\textit{matrix}_\textit{new}[\textit{col}][n - \textit{row} - 1] = \textit{matrix}[\textit{row}][\textit{col}]
$$

它阻止了我们进行原地旋转,这是因为如果我们直接将 $\textit{matrix}[\textit{row}][\textit{col}]$ 放到原矩阵中的目标位置 $\textit{matrix}[\textit{col}][n - \textit{row} - 1]$:

$$
\textit{matrix}[\textit{col}][n - \textit{row} - 1] = \textit{matrix}[\textit{row}][\textit{col}]
$$

原矩阵中的 $\textit{matrix}[\textit{col}][n - \textit{row} - 1]$ 就被覆盖了!这并不是我们想要的结果。因此我们可以考虑用一个临时变量 $\textit{temp}$ 暂存 $\textit{matrix}[\textit{col}][n - \textit{row} - 1]$ 的值,这样虽然 $\textit{matrix}[\textit{col}][n - \textit{row} - 1]$ 被覆盖了,我们还是可以通过 $\textit{temp}$ 获取它原来的值:

$$
\left{
\begin{alignedat}{2}
&\textit{temp} &&= \textit{matrix}[\textit{col}][n - \textit{row} - 1]\
&\textit{matrix}[\textit{col}][n - \textit{row} - 1] &&= \textit{matrix}[\textit{row}][\textit{col}]
\end{alignedat}
\right.
$$

那么 $\textit{matrix}[\textit{col}][n - \textit{row} - 1]$ 经过旋转操作之后会到哪个位置呢?我们还是使用方法一中的关键等式,不过这次,我们需要将

$$
\left{
\begin{alignedat}{2}
& \textit{row} &&= \textit{col} \
& \textit{col} &&= n - \textit{row} - 1
\end{alignedat}
\right.
$$

带入关键等式,就可以得到:

$$
\textit{matrix}[n - \textit{row} - 1][n - \textit{col} - 1] = \textit{matrix}[\textit{col}][n - \textit{row} - 1]
$$

同样地,直接赋值会覆盖掉 $\textit{matrix}[n - \textit{row} - 1][n - \textit{col} - 1]$ 原来的值,因此我们还是需要使用一个临时变量进行存储,不过这次,我们可以直接使用之前的临时变量 $\textit{temp}$:

$$
\left{
\begin{alignedat}{2}
&\textit{temp} &&= \textit{matrix}[n - \textit{row} - 1][n - \textit{col} - 1]\
&\textit{matrix}[n - \textit{row} - 1][n - \textit{col} - 1] &&= \textit{matrix}[\textit{col}][n - \textit{row} - 1]\
&\textit{matrix}[\textit{col}][n - \textit{row} - 1] &&= \textit{matrix}[\textit{row}][\textit{col}]
\end{alignedat}
\right.
$$

我们再重复一次之前的操作,$\textit{matrix}[n - \textit{row} - 1][n - \textit{col} - 1]$ 经过旋转操作之后会到哪个位置呢?

$$
\left{
\begin{alignedat}{2}
& \textit{row} &&= n - \textit{row} - 1\
& \textit{col} &&= n - \textit{col} - 1
\end{alignedat}
\right.
$$

带入关键等式,就可以得到:

$$
\textit{matrix}[n - \textit{col} - 1][\textit{row}] = \textit{matrix}[n - \textit{row} - 1][n - \textit{col} - 1]
$$

写进去:

$$
\left{
\begin{alignedat}{2}
&\textit{temp} &&= \textit{matrix}[n - \textit{col} - 1][\textit{row}]\
&\textit{matrix}[n - \textit{col} - 1][\textit{row}] &&= \textit{matrix}[n - \textit{row} - 1][n - \textit{col} - 1]\
&\textit{matrix}[n - \textit{row} - 1][n - \textit{col} - 1] &&= \textit{matrix}[\textit{col}][n - \textit{row} - 1]\
&\textit{matrix}[\textit{col}][n - \textit{row} - 1] &&= \textit{matrix}[\textit{row}][\textit{col}]
\end{alignedat}
\right.
$$

不要灰心,再来一次!$\textit{matrix}[n - \textit{col} - 1][\textit{row}]$ 经过旋转操作之后回到哪个位置呢?

$$
\left{
\begin{alignedat}{2}
& \textit{row} &&= n - \textit{col} - 1\
& \textit{col} &&= \textit{row}
\end{alignedat}
\right.
$$

带入关键等式,就可以得到:

$$
\textit{matrix}[\textit{row}][\textit{col}] = \textit{matrix}[n - \textit{col} - 1][\textit{row}]
$$

我们回到了最初的起点 $\textit{matrix}[\textit{row}][\textit{col}]$,也就是说:

$$
\begin{cases}
\textit{matrix}[\textit{row}][\textit{col}]\
\textit{matrix}[\textit{col}][n - \textit{row} - 1]\
\textit{matrix}[n - \textit{row} - 1][n - \textit{col} - 1]\
\textit{matrix}[n - \textit{col} - 1][\textit{row}]
\end{cases}
$$

这四项处于一个循环中,并且每一项旋转后的位置就是下一项所在的位置!因此我们可以使用一个临时变量 $\textit{temp}$ 完成这四项的原地交换:

$$
\left{
\begin{alignedat}{2}
&\textit{temp} &&= \textit{matrix}[\textit{row}][\textit{col}]\
&\textit{matrix}[\textit{row}][\textit{col}] &&= \textit{matrix}[n - \textit{col} - 1][\textit{row}]\
&\textit{matrix}[n - \textit{col} - 1][\textit{row}] &&= \textit{matrix}[n - \textit{row} - 1][n - \textit{col} - 1]\
&\textit{matrix}[n - \textit{row} - 1][n - \textit{col} - 1] &&= \textit{matrix}[\textit{col}][n - \textit{row} - 1]\
&\textit{matrix}[\textit{col}][n - \textit{row} - 1] &&= \textit{temp}
\end{alignedat}
\right.
$$

当我们知道了如何原地旋转矩阵之后,还有一个重要的问题在于:我们应该枚举哪些位置 $(\textit{row}, \textit{col})$ 进行上述的原地交换操作呢?由于每一次原地交换四个位置,因此:

  • 当 $n$ 为偶数时,我们需要枚举 $n^2 / 4 = (n/2) \times (n/2)$ 个位置,可以将该图形分为四块,以 $4 \times 4$ 的矩阵为例:

fig1{:width="80%"}

保证了不重复、不遗漏;

  • 当 $n$ 为奇数时,由于中心的位置经过旋转后位置不变,我们需要枚举 $(n^2-1) / 4 = ((n-1)/2) \times ((n+1)/2)$ 个位置,需要换一种划分的方式,以 $5 \times 5$ 的矩阵为例:

fig2{:width="80%"}

同样保证了不重复、不遗漏,矩阵正中央的点无需旋转。

###C++

class Solution {
public:
    void rotate(vector<vector<int>>& matrix) {
        int n = matrix.size();
        for (int i = 0; i < n / 2; ++i) {
            for (int j = 0; j < (n + 1) / 2; ++j) {
                int temp = matrix[i][j];
                matrix[i][j] = matrix[n - j - 1][i];
                matrix[n - j - 1][i] = matrix[n - i - 1][n - j - 1];
                matrix[n - i - 1][n - j - 1] = matrix[j][n - i - 1];
                matrix[j][n - i - 1] = temp;
            }
        }
    }
};

###C++

class Solution {
public:
    void rotate(vector<vector<int>>& matrix) {
        int n = matrix.size();
        for (int i = 0; i < n / 2; ++i) {
            for (int j = 0; j < (n + 1) / 2; ++j) {
                tie(matrix[i][j], matrix[n - j - 1][i], matrix[n - i - 1][n - j - 1], matrix[j][n - i - 1]) \
                    = make_tuple(matrix[n - j - 1][i], matrix[n - i - 1][n - j - 1], matrix[j][n - i - 1], matrix[i][j]);
            }
        }
    }
};

###Java

class Solution {
    public void rotate(int[][] matrix) {
        int n = matrix.length;
        for (int i = 0; i < n / 2; ++i) {
            for (int j = 0; j < (n + 1) / 2; ++j) {
                int temp = matrix[i][j];
                matrix[i][j] = matrix[n - j - 1][i];
                matrix[n - j - 1][i] = matrix[n - i - 1][n - j - 1];
                matrix[n - i - 1][n - j - 1] = matrix[j][n - i - 1];
                matrix[j][n - i - 1] = temp;
            }
        }
    }
}

###Python

class Solution:
    def rotate(self, matrix: List[List[int]]) -> None:
        n = len(matrix)
        for i in range(n // 2):
            for j in range((n + 1) // 2):
                matrix[i][j], matrix[n - j - 1][i], matrix[n - i - 1][n - j - 1], matrix[j][n - i - 1] \
                    = matrix[n - j - 1][i], matrix[n - i - 1][n - j - 1], matrix[j][n - i - 1], matrix[i][j]

###JavaScript

var rotate = function(matrix) {
    const n = matrix.length;
    for (let i = 0; i < Math.floor(n / 2); ++i) {
        for (let j = 0; j < Math.floor((n + 1) / 2); ++j) {
            const temp = matrix[i][j];
            matrix[i][j] = matrix[n - j - 1][i];
            matrix[n - j - 1][i] = matrix[n - i - 1][n - j - 1];
            matrix[n - i - 1][n - j - 1] = matrix[j][n - i - 1];
            matrix[j][n - i - 1] = temp;
        }
    }
};

###Go

func rotate(matrix [][]int) {
    n := len(matrix)
    for i := 0; i < n/2; i++ {
        for j := 0; j < (n+1)/2; j++ {
            matrix[i][j], matrix[n-j-1][i], matrix[n-i-1][n-j-1], matrix[j][n-i-1] =
                matrix[n-j-1][i], matrix[n-i-1][n-j-1], matrix[j][n-i-1], matrix[i][j]
        }
    }
}

###C

void rotate(int** matrix, int matrixSize, int* matrixColSize) {
    for (int i = 0; i < matrixSize / 2; ++i) {
        for (int j = 0; j < (matrixSize + 1) / 2; ++j) {
            int temp = matrix[i][j];
            matrix[i][j] = matrix[matrixSize - j - 1][i];
            matrix[matrixSize - j - 1][i] = matrix[matrixSize - i - 1][matrixSize - j - 1];
            matrix[matrixSize - i - 1][matrixSize - j - 1] = matrix[j][matrixSize - i - 1];
            matrix[j][matrixSize - i - 1] = temp;
        }
    }
}

复杂度分析

  • 时间复杂度:$O(N^2)$,其中 $N$ 是 $\textit{matrix}$ 的边长。我们需要枚举的子矩阵大小为 O($\lfloor n/2 \rfloor \times \lfloor (n+1)/2 \rfloor) = O(N^2)$。

  • 空间复杂度:$O(1)$。为原地旋转。

方法三:用翻转代替旋转

我们还可以另辟蹊径,用翻转操作代替旋转操作。我们还是以题目中的示例二

$$
\begin{bmatrix}
5 & 1 & 9 & 11 \
2 & 4 & 8 & 10 \
13 & 3 & 6 & 7 \
15 & 14 & 12 & 16
\end{bmatrix}
$$

作为例子,先将其通过水平轴翻转得到:

$$
\begin{bmatrix}
5 & 1 & 9 & 11 \
2 & 4 & 8 & 10 \
13 & 3 & 6 & 7 \
15 & 14 & 12 & 16
\end{bmatrix}
\xRightarrow[]{水平翻转}
\begin{bmatrix}
15 & 14 & 12 & 16 \
13 & 3 & 6 & 7 \
2 & 4 & 8 & 10 \
5 & 1 & 9 & 11
\end{bmatrix}
$$

再根据主对角线翻转得到:

$$
\begin{bmatrix}
15 & 14 & 12 & 16 \
13 & 3 & 6 & 7 \
2 & 4 & 8 & 10 \
5 & 1 & 9 & 11
\end{bmatrix}
\xRightarrow[]{主对角线翻转}
\begin{bmatrix}
15 & 13 & 2 & 5 \
14 & 3 & 4 & 1 \
12 & 6 & 8 & 9 \
16 & 7 & 10 & 11
\end{bmatrix}
$$

就得到了答案。这是为什么呢?对于水平轴翻转而言,我们只需要枚举矩阵上半部分的元素,和下半部分的元素进行交换,即

$$
\textit{matrix}[\textit{row}][\textit{col}] \xRightarrow[]{水平轴翻转} \textit{matrix}[n - \textit{row} - 1][\textit{col}]
$$

对于主对角线翻转而言,我们只需要枚举对角线左侧的元素,和右侧的元素进行交换,即

$$
\textit{matrix}[\textit{row}][\textit{col}] \xRightarrow[]{主对角线翻转} \textit{matrix}[\textit{col}][\textit{row}]
$$

将它们联立即可得到:

$$
\begin{aligned}
\textit{matrix}[\textit{row}][\textit{col}] & \xRightarrow[]{水平轴翻转} \textit{matrix}[n - \textit{row} - 1][\textit{col}] \
&\xRightarrow[]{主对角线翻转} \textit{matrix}[\textit{col}][n - \textit{row} - 1]
\end{aligned}
$$

和方法一、方法二中的关键等式:

$$
\textit{matrix}_\textit{new}[\textit{col}][n - \textit{row} - 1] = \textit{matrix}[\textit{row}][\textit{col}]
$$

是一致的。

###C++

class Solution {
public:
    void rotate(vector<vector<int>>& matrix) {
        int n = matrix.size();
        // 水平翻转
        for (int i = 0; i < n / 2; ++i) {
            for (int j = 0; j < n; ++j) {
                swap(matrix[i][j], matrix[n - i - 1][j]);
            }
        }
        // 主对角线翻转
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < i; ++j) {
                swap(matrix[i][j], matrix[j][i]);
            }
        }
    }
};

###Java

class Solution {
    public void rotate(int[][] matrix) {
        int n = matrix.length;
        // 水平翻转
        for (int i = 0; i < n / 2; ++i) {
            for (int j = 0; j < n; ++j) {
                int temp = matrix[i][j];
                matrix[i][j] = matrix[n - i - 1][j];
                matrix[n - i - 1][j] = temp;
            }
        }
        // 主对角线翻转
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < i; ++j) {
                int temp = matrix[i][j];
                matrix[i][j] = matrix[j][i];
                matrix[j][i] = temp;
            }
        }
    }
}

###Python

class Solution:
    def rotate(self, matrix: List[List[int]]) -> None:
        n = len(matrix)
        # 水平翻转
        for i in range(n // 2):
            for j in range(n):
                matrix[i][j], matrix[n - i - 1][j] = matrix[n - i - 1][j], matrix[i][j]
        # 主对角线翻转
        for i in range(n):
            for j in range(i):
                matrix[i][j], matrix[j][i] = matrix[j][i], matrix[i][j]

###JavaScript

var rotate = function(matrix) {
    const n = matrix.length;
    // 水平翻转
    for (let i = 0; i < Math.floor(n / 2); i++) {
        for (let j = 0; j < n; j++) {
            [matrix[i][j], matrix[n - i - 1][j]] = [matrix[n - i - 1][j], matrix[i][j]];
        }
    }
    // 主对角线翻转
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < i; j++) {
            [matrix[i][j], matrix[j][i]] = [matrix[j][i], matrix[i][j]];
        }
    }
};

###Go

func rotate(matrix [][]int) {
    n := len(matrix)
    // 水平翻转
    for i := 0; i < n/2; i++ {
        matrix[i], matrix[n-1-i] = matrix[n-1-i], matrix[i]
    }
    // 主对角线翻转
    for i := 0; i < n; i++ {
        for j := 0; j < i; j++ {
            matrix[i][j], matrix[j][i] = matrix[j][i], matrix[i][j]
        }
    }
}

###C

void swap(int* a, int* b) {
    int t = *a;
    *a = *b, *b = t;
}

void rotate(int** matrix, int matrixSize, int* matrixColSize) {
    // 水平翻转
    for (int i = 0; i < matrixSize / 2; ++i) {
        for (int j = 0; j < matrixSize; ++j) {
            swap(&matrix[i][j], &matrix[matrixSize - i - 1][j]);
        }
    }
    // 主对角线翻转
    for (int i = 0; i < matrixSize; ++i) {
        for (int j = 0; j < i; ++j) {
            swap(&matrix[i][j], &matrix[j][i]);
        }
    }
}

复杂度分析

  • 时间复杂度:$O(N^2)$,其中 $N$ 是 $\textit{matrix}$ 的边长。对于每一次翻转操作,我们都需要枚举矩阵中一半的元素。

  • 空间复杂度:$O(1)$。为原地翻转得到的原地旋转。

旋转字符串

方法一:模拟

思路

首先,如果 $s$ 和 $\textit{goal}$ 的长度不一样,那么无论怎么旋转,$s$ 都不能得到 $\textit{goal}$,返回 $\text{false}$。在长度一样(都为 $n$)的前提下,假设 $s$ 旋转 $i$ 位,则与 $\textit{goal}$ 中的某一位字符 $\textit{goal}[j]$ 对应的原 $s$ 中的字符应该为 $s[(i+j) \bmod n]$。在固定 $i$ 的情况下,遍历所有 $j$,若对应字符都相同,则返回 $\text{true}$。否则,继续遍历其他候选的 $i$。若所有的 $i$ 都不能使 $s$ 变成 $\textit{goal}$,则返回 $\text{false}$。

代码

###Python

class Solution:
    def rotateString(self, s: str, goal: str) -> bool:
        m, n = len(s), len(goal)
        if m != n:
            return False
        for i in range(n):
            for j in range(n):
                if s[(i + j) % n] != goal[j]:
                    break
            else:
                return True
        return False

###Java

class Solution {
    public boolean rotateString(String s, String goal) {
        int m = s.length(), n = goal.length();
        if (m != n) {
            return false;
        }
        for (int i = 0; i < n; i++) {
            boolean flag = true;
            for (int j = 0; j < n; j++) {
                if (s.charAt((i + j) % n) != goal.charAt(j)) {
                    flag = false;
                    break;
                }
            }
            if (flag) {
                return true;
            }
        }
        return false;
    }
}

###C#

public class Solution {
    public bool rotateString(string s, string goal) {
        int m = s.Length, n = goal.Length;
        if (m != n) {
            return false;
        }
        for (int i = 0; i < n; i++) {
            bool flag = true;
            for (int j = 0; j < n; j++) {
                if (s[(i + j) % n] != goal[j]) {
                    flag = false;
                    break;
                }
            }
            if (flag) {
                return true;
            }
        }
        return false;
    }
}

###C++

class Solution {
public:
    bool rotateString(string s, string goal) {
        int m = s.size(), n = goal.size();
        if (m != n) {
            return false;
        }
        for (int i = 0; i < n; i++) {
            bool flag = true;
            for (int j = 0; j < n; j++) {
                if (s[(i + j) % n] != goal[j]) {
                    flag = false;
                    break;
                }
            }
            if (flag) {
                return true;
            }
        }
        return false;
    }
};

###C

bool rotateString(char * s, char * goal){
    int m = strlen(s), n = strlen(goal);
    if (m != n) {
        return false;
    }
    for (int i = 0; i < n; i++) {
        bool flag = true;
        for (int j = 0; j < n; j++) {
            if (s[(i + j) % n] != goal[j]) {
                flag = false;
                break;
            }
        }
        if (flag) {
            return true;
        }
    }
    return false;
}

###JavaScript

var rotateString = function(s, goal) {
    const m = s.length, n = goal.length;
    if (m !== n) {
        return false;
    }
    for (let i = 0; i < n; i++) {
        let flag = true;
        for (let j = 0; j < n; j++) {
            if (s[(i + j) % n] !== goal[j]) {
                flag = false;
                break;
            }
        }
        if (flag) {
            return true;
        }
    }
    return false;
};

###go

func rotateString(s, goal string) bool {
    n := len(s)
    if n != len(goal) {
        return false
    }
next:
    for i := 0; i < n; i++ {
        for j := 0; j < n; j++ {
            if s[(i+j)%n] != goal[j] {
                continue next
            }
        }
        return true
    }
    return false
}

复杂度分析

  • 时间复杂度:$O(n^2)$,其中 $n$ 是字符串 $s$ 的长度。我们需要双重循环来判断。

  • 空间复杂度:$O(1)$。仅使用常数空间。

方法二:搜索子字符串

思路

首先,如果 $s$ 和 $\textit{goal}$ 的长度不一样,那么无论怎么旋转,$s$ 都不能得到 $\textit{goal}$,返回 $\text{false}$。字符串 $s + s$ 包含了所有 $s$ 可以通过旋转操作得到的字符串,只需要检查 $\textit{goal}$ 是否为 $s + s$ 的子字符串即可。具体可以参考「28. 实现 strStr() 的官方题解」的实现代码,本题解中采用直接调用库函数的方法。

代码

###Python

class Solution:
    def rotateString(self, s: str, goal: str) -> bool:
        return len(s) == len(goal) and goal in s + s

###Java

class Solution {
    public boolean rotateString(String s, String goal) {
        return s.length() == goal.length() && (s + s).contains(goal);
    }
}

###C#

public class Solution {
    public bool rotateString(string s, string goal) {
        return s.Length == goal.Length && (s + s).Contains(goal);
    }
}

###C++

class Solution {
public:
    bool rotateString(string s, string goal) {
        return s.size() == goal.size() && (s + s).find(goal) != string::npos;
    }
};

###C

bool rotateString(char * s, char * goal){
    int m = strlen(s), n = strlen(goal);
    if (m != n) {
        return false;
    }
    char * str = (char *)malloc(sizeof(char) * (m + n + 1));
    sprintf(str, "%s%s", goal, goal);
    return strstr(str, s) != NULL;
}

###JavaScript

var rotateString = function(s, goal) {
    return s.length === goal.length && (s + s).indexOf(goal) !== -1;
};

###go

func rotateString(s, goal string) bool {
    return len(s) == len(goal) && strings.Contains(s+s, goal)
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 是字符串 $s$ 的长度。$\text{KMP}$ 算法搜索子字符串的时间复杂度为 $O(n)$,其他搜索子字符串的方法会略有差异。

  • 空间复杂度:$O(n)$,其中 $n$ 是字符串 $s$ 的长度。$\text{KMP}$ 算法搜索子字符串的空间复杂度为 $O(n)$,其他搜索子字符串的方法会略有差异。

旋转数组

方法一:迭代

思路

记数组 $\textit{nums}$ 的元素之和为 $\textit{numSum}$。根据公式,可以得到:

  • $F(0) = 0 \times \textit{nums}[0] + 1 \times \textit{nums}[1] + \ldots + (n-1) \times \textit{nums}[n-1]$
  • $F(1) = 1 \times \textit{nums}[0] + 2 \times \textit{nums}[1] + \ldots + 0 \times \textit{nums}[n-1] = F(0) + \textit{numSum} - n \times \textit{nums}[n-1]$

更一般地,当 $1 \le k \lt n$ 时,$F(k) = F(k-1) + \textit{numSum} - n \times \textit{nums}[n-k]$。我们可以不停迭代计算出不同的 $F(k)$,并求出最大值。

代码

###Python

class Solution:
    def maxRotateFunction(self, nums: List[int]) -> int:
        f, n, numSum = 0, len(nums), sum(nums)
        for i, num in enumerate(nums):
            f += i * num
        res = f
        for i in range(n - 1, 0, -1):
            f = f + numSum - n * nums[i]
            res = max(res, f)
        return res

###Java

class Solution {
    public int maxRotateFunction(int[] nums) {
        int f = 0, n = nums.length, numSum = Arrays.stream(nums).sum();
        for (int i = 0; i < n; i++) {
            f += i * nums[i];
        }
        int res = f;
        for (int i = n - 1; i > 0; i--) {
            f += numSum - n * nums[i];
            res = Math.max(res, f);
        }
        return res;
    }
}

###C#

public class Solution {
    public int MaxRotateFunction(int[] nums) {
        int f = 0, n = nums.Length, numSum = nums.Sum();
        for (int i = 0; i < n; i++) {
            f += i * nums[i];
        }
        int res = f;
        for (int i = n - 1; i > 0; i--) {
            f += numSum - n * nums[i];
            res = Math.Max(res, f);
        }
        return res;
    }
}

###C++

class Solution {
public:
    int maxRotateFunction(vector<int>& nums) {
        int f = 0, n = nums.size();
        int numSum = accumulate(nums.begin(), nums.end(), 0);
        for (int i = 0; i < n; i++) {
            f += i * nums[i];
        }
        int res = f;
        for (int i = n - 1; i > 0; i--) {
            f += numSum - n * nums[i];
            res = max(res, f);
        }
        return res;
    }
};

###C

#define MAX(a, b) ((a) > (b) ? (a) : (b))

int maxRotateFunction(int* nums, int numsSize){
    int f = 0, numSum = 0;
    for (int i = 0; i < numsSize; i++) {
        f += i * nums[i];
        numSum += nums[i];
    }
    int res = f;
    for (int i = numsSize - 1; i > 0; i--) {
        f += numSum - numsSize * nums[i];
        res = MAX(res, f);
    }
    return res;
}

###JavaScript

var maxRotateFunction = function(nums) {
    let f = 0, n = nums.length, numSum = _.sum(nums);
    for (let i = 0; i < n; i++) {
        f += i * nums[i];
    }
    let res = f;
    for (let i = n - 1; i > 0; i--) {
        f += numSum - n * nums[i];
        res = Math.max(res, f);
    }
    return res;
};

###go

func maxRotateFunction(nums []int) int {
    numSum := 0
    for _, v := range nums {
        numSum += v
    }
    f := 0
    for i, num := range nums {
        f += i * num
    }
    ans := f
    for i := len(nums) - 1; i > 0; i-- {
        f += numSum - len(nums)*nums[i]
        ans = max(ans, f)
    }
    return ans
}

func max(a, b int) int {
    if b > a {
        return b
    }
    return a
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 是数组 $\textit{nums}$ 的长度。计算 $\textit{numSum}$ 需要 $O(n)$ 时间,计算初始值 $F(0)$ 也需要 $O(n)$ 时间,因为我们只需遍历一次数组。之后,我们进行 $n - 1$ 次迭代来计算 $F(k)$ 的其余值。每次迭代使用递推关系式更新值:

    $$
    F(k) = F(k-1) + \textit{numSum} - n \cdot \textit{nums}[n-k]
    $$

    此更新仅涉及常数数量的算术运算,因此每次迭代的时间复杂度为 $O(1)$。因此,总的时间复杂度为:

    $$
    O(n) + O(n) + O(n) = O(n)
    $$

    总体而言,该算法的时间复杂度为线性。

  • 空间复杂度:$O(1)$。仅使用常数空间。

网格中得分最大的路径

方法一:动态规划

思路与算法

题目要求我们在总花费不超过 $k$ 的情况下,找到一条从 $\textit{grid}[0][0]$ 到 $\textit{grid}[m-1][n-1]$ 的路径,使得获得的分数最大。这种有限制的最优化问题结构类似背包问题,可以用动态规划解决。

定义状态 $\textit{dp}[i][j][c]$ 表示到达位置 $(i,j)$,当前花费为 $c$ 时的最大得分。

我们从当前格子向后转移,即从 $(i,j)$ 出发,可以向下或向右移动,将下一个格子的代价和分数加入:

  • 向下:转移到 $(i+1,j)$
  • 向右:转移到 $(i,j+1)$

状态转移为:

$$
\begin{aligned}
\textit{dp}[i+1][j][c + \textit{cost}(i+1,j)] &= \max(\textit{dp}[i+1][j][c + \textit{cost}(i+1,j)],\textit{dp}[i][j][c] + \textit{grid}[i+1][j]) \
\textit{dp}[i][j+1][c + \textit{cost}(i,j+1)] &= \max(\textit{dp}[i][j+1][c + \textit{cost}(i,j+1)],\textit{dp}[i][j][c] + \textit{grid}[i][j+1])
\end{aligned}
$$

其中:

$$
\textit{cost}(i,j) =
\begin{cases}
1, & \textit{grid}[i][j] \neq 0 \
0, & \textit{grid}[i][j] = 0
\end{cases}
$$

初始状态为 $\textit{dp}[0][0][0] = 0$(起点不计入得分和花费)。

最终答案为:

$$
\max\limits_{0 \le c \le k} \textit{dp}[m-1][n-1][c]
$$

代码

###C++

class Solution {
public:
    int maxPathScore(vector<vector<int>>& grid, int k) {
        int m = grid.size();
        int n = grid[0].size();
        vector<vector<vector<int>>> dp(
            m, vector<vector<int>>(n, vector<int>(k + 1, INT_MIN)));
        dp[0][0][0] = 0;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                for (int c = 0; c <= k; c++) {
                    if (dp[i][j][c] == INT_MIN)
                        continue;
                    if (i + 1 < m) {
                        int val = grid[i + 1][j];
                        int cost = (val == 0 ? 0 : 1);
                        if (c + cost <= k) {
                            dp[i + 1][j][c + cost] =
                                max(dp[i + 1][j][c + cost], dp[i][j][c] + val);
                        }
                    }
                    if (j + 1 < n) {
                        int val = grid[i][j + 1];
                        int cost = (val == 0 ? 0 : 1);
                        if (c + cost <= k) {
                            dp[i][j + 1][c + cost] =
                                max(dp[i][j + 1][c + cost], dp[i][j][c] + val);
                        }
                    }
                }
            }
        }
        int ans = INT_MIN;
        for (int c = 0; c <= k; c++) {
            ans = max(ans, dp[m - 1][n - 1][c]);
        }
        return ans < 0 ? -1 : ans;
    }
};

###Java

class Solution {
    public int maxPathScore(int[][] grid, int k) {
        int m = grid.length;
        int n = grid[0].length;

        int[][][] dp = new int[m][n][k + 1];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                Arrays.fill(dp[i][j], Integer.MIN_VALUE);
            }
        }

        dp[0][0][0] = 0;

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                for (int c = 0; c <= k; c++) {
                    if (dp[i][j][c] == Integer.MIN_VALUE) continue;

                    if (i + 1 < m) {
                        int val = grid[i + 1][j];
                        int cost = (val == 0 ? 0 : 1);
                        if (c + cost <= k) {
                            dp[i + 1][j][c + cost] = Math.max(
                                dp[i + 1][j][c + cost],
                                dp[i][j][c] + val
                            );
                        }
                    }

                    if (j + 1 < n) {
                        int val = grid[i][j + 1];
                        int cost = (val == 0 ? 0 : 1);
                        if (c + cost <= k) {
                            dp[i][j + 1][c + cost] = Math.max(
                                dp[i][j + 1][c + cost],
                                dp[i][j][c] + val
                            );
                        }
                    }
                }
            }
        }

        int ans = Integer.MIN_VALUE;
        for (int c = 0; c <= k; c++) {
            ans = Math.max(ans, dp[m - 1][n - 1][c]);
        }

        return ans < 0 ? -1 : ans;
    }
}

###Python3

class Solution:
    def maxPathScore(self, grid, k):
        m, n = len(grid), len(grid[0])

        INF = float('-inf')
        dp = [[[INF] * (k + 1) for _ in range(n)] for _ in range(m)]
        dp[0][0][0] = 0

        for i in range(m):
            for j in range(n):
                for c in range(k + 1):
                    if dp[i][j][c] == INF:
                        continue

                    if i + 1 < m:
                        val = grid[i + 1][j]
                        cost = 0 if val == 0 else 1
                        if c + cost <= k:
                            dp[i + 1][j][c + cost] = max(
                                dp[i + 1][j][c + cost],
                                dp[i][j][c] + val
                            )

                    if j + 1 < n:
                        val = grid[i][j + 1]
                        cost = 0 if val == 0 else 1
                        if c + cost <= k:
                            dp[i][j + 1][c + cost] = max(
                                dp[i][j + 1][c + cost],
                                dp[i][j][c] + val
                            )

        ans = max(dp[m - 1][n - 1])
        return -1 if ans < 0 else ans

###C

int maxPathScore(int** grid, int m, int* gridColSize, int k) {
    int n = gridColSize[0];

    int*** dp = (int***)malloc(m * sizeof(int**));
    for (int i = 0; i < m; i++) {
        dp[i] = (int**)malloc(n * sizeof(int*));
        for (int j = 0; j < n; j++) {
            dp[i][j] = (int*)malloc((k + 1) * sizeof(int));
            for (int c = 0; c <= k; c++) {
                dp[i][j][c] = INT_MIN;
            }
        }
    }

    dp[0][0][0] = 0;

    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            for (int c = 0; c <= k; c++) {
                if (dp[i][j][c] == INT_MIN) continue;

                if (i + 1 < m) {
                    int val = grid[i + 1][j];
                    int cost = val == 0 ? 0 : 1;
                    if (c + cost <= k) {
                        int* target = &dp[i + 1][j][c + cost];
                        if (*target < dp[i][j][c] + val)
                            *target = dp[i][j][c] + val;
                    }
                }

                if (j + 1 < n) {
                    int val = grid[i][j + 1];
                    int cost = val == 0 ? 0 : 1;
                    if (c + cost <= k) {
                        int* target = &dp[i][j + 1][c + cost];
                        if (*target < dp[i][j][c] + val)
                            *target = dp[i][j][c] + val;
                    }
                }
            }
        }
    }

    int ans = INT_MIN;
    for (int c = 0; c <= k; c++) {
        if (dp[m - 1][n - 1][c] > ans)
            ans = dp[m - 1][n - 1][c];
    }

    return ans < 0 ? -1 : ans;
}

###Golang

func maxPathScore(grid [][]int, k int) int {
    m, n := len(grid), len(grid[0])

    const INF = math.MinInt32

    dp := make([][][]int, m)
    for i := range dp {
        dp[i] = make([][]int, n)
        for j := range dp[i] {
            dp[i][j] = make([]int, k+1)
            for c := range dp[i][j] {
                dp[i][j][c] = INF
            }
        }
    }

    dp[0][0][0] = 0

    for i := 0; i < m; i++ {
        for j := 0; j < n; j++ {
            for c := 0; c <= k; c++ {
                if dp[i][j][c] == INF {
                    continue
                }

                if i+1 < m {
                    val := grid[i+1][j]
                    cost := 0
                    if val != 0 {
                        cost = 1
                    }
                    if c+cost <= k {
                        if dp[i+1][j][c+cost] < dp[i][j][c]+val {
                            dp[i+1][j][c+cost] = dp[i][j][c] + val
                        }
                    }
                }

                if j+1 < n {
                    val := grid[i][j+1]
                    cost := 0
                    if val != 0 {
                        cost = 1
                    }
                    if c+cost <= k {
                        if dp[i][j+1][c+cost] < dp[i][j][c]+val {
                            dp[i][j+1][c+cost] = dp[i][j][c] + val
                        }
                    }
                }
            }
        }
    }

    ans := INF
    for c := 0; c <= k; c++ {
        if dp[m-1][n-1][c] > ans {
            ans = dp[m-1][n-1][c]
        }
    }

    if ans < 0 {
        return -1
    }
    return ans
}

###C#

public class Solution {
    public int MaxPathScore(int[][] grid, int k) {
        int m = grid.Length, n = grid[0].Length;

        int[,,] dp = new int[m, n, k + 1];

        for (int i = 0; i < m; i++)
            for (int j = 0; j < n; j++)
                for (int c = 0; c <= k; c++)
                    dp[i, j, c] = int.MinValue;

        dp[0, 0, 0] = 0;

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                for (int c = 0; c <= k; c++) {
                    if (dp[i, j, c] == int.MinValue) continue;

                    if (i + 1 < m) {
                        int val = grid[i + 1][j];
                        int cost = val == 0 ? 0 : 1;
                        if (c + cost <= k) {
                            dp[i + 1, j, c + cost] = Math.Max(
                                dp[i + 1, j, c + cost],
                                dp[i, j, c] + val
                            );
                        }
                    }

                    if (j + 1 < n) {
                        int val = grid[i][j + 1];
                        int cost = val == 0 ? 0 : 1;
                        if (c + cost <= k) {
                            dp[i, j + 1, c + cost] = Math.Max(
                                dp[i, j + 1, c + cost],
                                dp[i, j, c] + val
                            );
                        }
                    }
                }
            }
        }

        int ans = int.MinValue;
        for (int c = 0; c <= k; c++) {
            ans = Math.Max(ans, dp[m - 1, n - 1, c]);
        }

        return ans < 0 ? -1 : ans;
    }
}

###JavaScript

var maxPathScore = function(grid, k) {
    const m = grid.length, n = grid[0].length;

    const INF = -Infinity;
    const dp = Array.from({ length: m }, () =>
        Array.from({ length: n }, () => Array(k + 1).fill(INF))
    );

    dp[0][0][0] = 0;

    for (let i = 0; i < m; i++) {
        for (let j = 0; j < n; j++) {
            for (let c = 0; c <= k; c++) {
                if (dp[i][j][c] === INF) continue;

                if (i + 1 < m) {
                    const val = grid[i + 1][j];
                    const cost = val === 0 ? 0 : 1;
                    if (c + cost <= k) {
                        dp[i + 1][j][c + cost] = Math.max(
                            dp[i + 1][j][c + cost],
                            dp[i][j][c] + val
                        );
                    }
                }

                if (j + 1 < n) {
                    const val = grid[i][j + 1];
                    const cost = val === 0 ? 0 : 1;
                    if (c + cost <= k) {
                        dp[i][j + 1][c + cost] = Math.max(
                            dp[i][j + 1][c + cost],
                            dp[i][j][c] + val
                        );
                    }
                }
            }
        }
    }

    let ans = Math.max(...dp[m - 1][n - 1]);
    return ans < 0 ? -1 : ans;
};

###TypeScript

function maxPathScore(grid: number[][], k: number): number {
    const m = grid.length, n = grid[0].length;

    const INF = -Infinity;
    const dp: number[][][] = Array.from({ length: m }, () =>
        Array.from({ length: n }, () => Array(k + 1).fill(INF))
    );

    dp[0][0][0] = 0;

    for (let i = 0; i < m; i++) {
        for (let j = 0; j < n; j++) {
            for (let c = 0; c <= k; c++) {
                if (dp[i][j][c] === INF) continue;

                if (i + 1 < m) {
                    const val = grid[i + 1][j];
                    const cost = val === 0 ? 0 : 1;
                    if (c + cost <= k) {
                        dp[i + 1][j][c + cost] = Math.max(
                            dp[i + 1][j][c + cost],
                            dp[i][j][c] + val
                        );
                    }
                }

                if (j + 1 < n) {
                    const val = grid[i][j + 1];
                    const cost = val === 0 ? 0 : 1;
                    if (c + cost <= k) {
                        dp[i][j + 1][c + cost] = Math.max(
                            dp[i][j + 1][c + cost],
                            dp[i][j][c] + val
                        );
                    }
                }
            }
        }
    }

    const ans = Math.max(...dp[m - 1][n - 1]);
    return ans < 0 ? -1 : ans;
}

###Rust

impl Solution {
    pub fn max_path_score(grid: Vec<Vec<i32>>, k: i32) -> i32 {
        let m = grid.len();
        let n = grid[0].len();
        let k = k as usize;

        let inf = i32::MIN / 2;
        let mut dp = vec![vec![vec![inf; k + 1]; n]; m];

        dp[0][0][0] = 0;

        for i in 0..m {
            for j in 0..n {
                for c in 0..=k {
                    if dp[i][j][c] == inf {
                        continue;
                    }

                    if i + 1 < m {
                        let val = grid[i + 1][j];
                        let cost = if val == 0 { 0 } else { 1 };
                        if c + cost <= k {
                            dp[i + 1][j][c + cost] =
                                dp[i + 1][j][c + cost].max(dp[i][j][c] + val);
                        }
                    }

                    if j + 1 < n {
                        let val = grid[i][j + 1];
                        let cost = if val == 0 { 0 } else { 1 };
                        if c + cost <= k {
                            dp[i][j + 1][c + cost] =
                                dp[i][j + 1][c + cost].max(dp[i][j][c] + val);
                        }
                    }
                }
            }
        }

        let mut ans = inf;
        for c in 0..=k {
            ans = ans.max(dp[m - 1][n - 1][c]);
        }

        if ans < 0 { -1 } else { ans }
    }
}

复杂度分析

  • 时间复杂度:$O(mnk)$,其中 $m$,$n$ 分别是矩阵 $\textit{grid}$ 的行数和列数。
  • 空间复杂度:$O(mnk)$。

距离字典两次编辑以内的单词

方法一:暴力

思路与算法

注意到题目的数据范围很小,我们可以直接实施暴力算法。

对于 $\textit{queries}$ 中的每个字符串 $\textit{queries}[i]$,查找 $\textit{dictionary}$ 中是否存在一个字符串,使得两个字符串中最多只有两个字符不同(即汉明距离小于等于 $2$),如果存在,就将其添加到答案中。由于添加的顺序与遍历 $\textit{queries}$ 的顺序一致,我们不需要对答案的顺序进行特殊处理。

代码

###C++

class Solution {
public:
    vector<string> twoEditWords(vector<string>& queries,
                                vector<string>& dictionary) {
        vector<string> ans;
        for (string query : queries) {
            for (string s : dictionary) {
                int dis = 0;
                for (int i = 0; i < query.size(); i++) {
                    if (query[i] != s[i]) {
                        ++dis;
                    }
                }
                if (dis <= 2) {
                    ans.push_back(query);
                    break;
                }
            }
        }
        return ans;
    }
};

###Java

class Solution {
    public List<String> twoEditWords(String[] queries, String[] dictionary) {
        List<String> ans = new ArrayList<>();
        for (String query : queries) {
            for (String s : dictionary) {
                int dis = 0;
                for (int i = 0; i < query.length(); i++) {
                    if (query.charAt(i) != s.charAt(i)) {
                        dis++;
                    }
                }
                if (dis <= 2) {
                    ans.add(query);
                    break;
                }
            }
        }
        return ans;
    }
}

###C#

public class Solution {
    public IList<string> TwoEditWords(string[] queries, string[] dictionary) {
        var ans = new List<string>();
        foreach (var query in queries) {
            foreach (var s in dictionary) {
                int dis = 0;
                for (int i = 0; i < query.Length; i++) {
                    if (query[i] != s[i]) {
                        dis++;
                    }
                }
                if (dis <= 2) {
                    ans.Add(query);
                    break;
                }
            }
        }
        return ans;
    }
}

###Python

class Solution:
    def twoEditWords(self, queries, dictionary):
        ans = []
        for query in queries:
            for s in dictionary:
                dis = 0
                for i in range(len(query)):
                    if query[i] != s[i]:
                        dis += 1
                if dis <= 2:
                    ans.append(query)
                    break
        return ans

###C

char** twoEditWords(char** queries, int queriesSize,
                    char** dictionary, int dictionarySize,
                    int* returnSize) {
    char** ans = (char**)malloc(sizeof(char*) * queriesSize);
    int cnt = 0;

    for (int i = 0; i < queriesSize; i++) {
        char* query = queries[i];
        for (int j = 0; j < dictionarySize; j++) {
            char* s = dictionary[j];
            int dis = 0;
            for (int k = 0; query[k] != '\0'; k++) {
                if (query[k] != s[k]) {
                    dis++;
                }
            }
            if (dis <= 2) {
                ans[cnt++] = query;
                break;
            }
        }
    }

    *returnSize = cnt;
    return ans;
}

###Go

func twoEditWords(queries []string, dictionary []string) []string {
    var ans []string
    for _, query := range queries {
        for _, s := range dictionary {
            dis := 0
            for i := 0; i < len(query); i++ {
                if query[i] != s[i] {
                    dis++
                }
            }
            if dis <= 2 {
                ans = append(ans, query)
                break
            }
        }
    }
    return ans
}

###JavaScript

var twoEditWords = function(queries, dictionary) {
    const ans = [];
    for (const query of queries) {
        for (const s of dictionary) {
            let dis = 0;
            for (let i = 0; i < query.length; i++) {
                if (query[i] !== s[i]) {
                    dis++;
                }
            }
            if (dis <= 2) {
                ans.push(query);
                break;
            }
        }
    }
    return ans;
};

###TypeScript

function twoEditWords(queries: string[], dictionary: string[]): string[] {
    const ans: string[] = [];
    for (const query of queries) {
        for (const s of dictionary) {
            let dis = 0;
            for (let i = 0; i < query.length; i++) {
                if (query[i] !== s[i]) {
                    dis++;
                }
            }
            if (dis <= 2) {
                ans.push(query);
                break;
            }
        }
    }
    return ans;
}

###Rust

impl Solution {
    pub fn two_edit_words(queries: Vec<String>, dictionary: Vec<String>) -> Vec<String> {
        let mut ans = Vec::new();

        for query in &queries {
            for s in &dictionary {
                let mut dis = 0;
                for (c1, c2) in query.chars().zip(s.chars()) {
                    if c1 != c2 {
                        dis += 1;
                    }
                }
                if dis <= 2 {
                    ans.push(query.clone());
                    break;
                }
            }
        }

        ans
    }
}

复杂度分析

  • 时间复杂度:$O(qkn)$,其中 $q$ 为 $\textit{queries}$ 的长度,$k$ 为 $\textit{dictionary}$ 的长度,$n$ 为 $\textit{queries}[i]$ 的长度。我们需要对每一个 $\textit{queries}[i]$ 遍历一次 $\textit{dictionary}$,然后比较两个字符串。
  • 空间复杂度:$O(1)$,仅使用常数个变量。返回数组不计入空间复杂度。

方法二:字典树

思路与算法

我们可以将 $\textit{dictionary}$ 中的所有单词插入字典树,对每个 $\textit{queries}[i]$ 做深度优先搜索,在过程中维护修改次数 $\textit{cnt}$,从而实现在字典树上进行「最多 2 次修改」的匹配搜索。

定义状态 $\textit{dfs}(i, \textit{node}, \textit{cnt})$,其中 $i$ 表示当前匹配到第 $i$ 个字符,$\textit{node}$ 表示当前所在字典树的节点,$\textit{cnt}$ 表示已修改 $\textit{cnt}$ 次。在字典树上进行查找,对于第 $i$ 个字符 $\textit{query}[i]$:

  1. 如果字典树中存在 $\textit{node}.\textit{children}[\textit{query}[i]]$,不进行修改,下一步搜索状态 $\textit{dfs}(i+1, \textit{node}.\textit{children}[\textit{query}[i]], \textit{cnt})$。
  2. 如果字典树中不存在 $\textit{node}.\textit{children}[\textit{query}[i]]$,且 $\textit{cnt}<2$,进行修改,即枚举所有 $c \neq \textit{query}[i]$,下一步搜索状态 $\textit{dfs}(i+1,\textit{node}.\textit{children}[c], \textit{cnt}+1)$。

搜索过程中,我们可以进行一些剪枝,比如某条路径找到合法解就提前终止。

代码

###C++

struct TrieNode {
    TrieNode* child[26];
    bool isEnd;
    TrieNode() {
        memset(child, 0, sizeof(child));
        isEnd = false;
    }
};

class Solution {
public:
    TrieNode* root = new TrieNode();

    void insert(string& word) {
        TrieNode* node = root;
        for (char c : word) {
            int idx = c - 'a';
            if (!node->child[idx])
                node->child[idx] = new TrieNode();
            node = node->child[idx];
        }
        node->isEnd = true;
    }

    bool dfs(string& word, int i, TrieNode* node, int cnt) {
        if (cnt > 2)
            return false;
        if (!node)
            return false;

        if (i == word.size()) {
            return node->isEnd;
        }

        int idx = word[i] - 'a';

        // 不修改
        if (node->child[idx]) {
            if (dfs(word, i + 1, node->child[idx], cnt))
                return true;
        }

        // 修改
        if (cnt < 2) {
            for (int c = 0; c < 26; c++) {
                if (c == idx)
                    continue;
                if (node->child[c]) {
                    if (dfs(word, i + 1, node->child[c], cnt + 1))
                        return true;
                }
            }
        }

        return false;
    }

    vector<string> twoEditWords(vector<string>& queries,
                                vector<string>& dictionary) {
        for (auto& w : dictionary)
            insert(w);

        vector<string> res;
        for (auto& q : queries) {
            if (dfs(q, 0, root, 0)) {
                res.push_back(q);
            }
        }
        return res;
    }
};

###Java

class Solution {
    static class TrieNode {
        TrieNode[] child = new TrieNode[26];
        boolean isEnd = false;
    }

    TrieNode root = new TrieNode();

    void insert(String word) {
        TrieNode node = root;
        for (char c : word.toCharArray()) {
            int idx = c - 'a';
            if (node.child[idx] == null)
                node.child[idx] = new TrieNode();
            node = node.child[idx];
        }
        node.isEnd = true;
    }

    boolean dfs(String word, int i, TrieNode node, int cnt) {
        if (cnt > 2 || node == null)
            return false;

        if (i == word.length())
            return node.isEnd;

        int idx = word.charAt(i) - 'a';

        // 不修改
        if (node.child[idx] != null) {
            if (dfs(word, i + 1, node.child[idx], cnt))
                return true;
        }

        // 修改
        if (cnt < 2) {
            for (int c = 0; c < 26; c++) {
                if (c == idx)
                    continue;
                if (node.child[c] != null) {
                    if (dfs(word, i + 1, node.child[c], cnt + 1))
                        return true;
                }
            }
        }

        return false;
    }

    public List<String> twoEditWords(String[] queries, String[] dictionary) {
        for (String w : dictionary)
            insert(w);

        List<String> res = new ArrayList<>();
        for (String q : queries) {
            if (dfs(q, 0, root, 0)) {
                res.add(q);
            }
        }
        return res;
    }
}

###C#

public class Solution {
    class TrieNode {
        public TrieNode[] child = new TrieNode[26];
        public bool isEnd = false;
    }

    TrieNode root = new TrieNode();

    void Insert(string word) {
        var node = root;
        foreach (char c in word) {
            int idx = c - 'a';
            if (node.child[idx] == null)
                node.child[idx] = new TrieNode();
            node = node.child[idx];
        }
        node.isEnd = true;
    }

    bool Dfs(string word, int i, TrieNode node, int cnt) {
        if (cnt > 2 || node == null)
            return false;

        if (i == word.Length)
            return node.isEnd;

        int idx = word[i] - 'a';

        // 不修改
        if (node.child[idx] != null) {
            if (Dfs(word, i + 1, node.child[idx], cnt))
                return true;
        }

        // 修改
        if (cnt < 2) {
            for (int c = 0; c < 26; c++) {
                if (c == idx) continue;
                if (node.child[c] != null) {
                    if (Dfs(word, i + 1, node.child[c], cnt + 1))
                        return true;
                }
            }
        }

        return false;
    }

    public IList<string> TwoEditWords(string[] queries, string[] dictionary) {
        foreach (var w in dictionary)
            Insert(w);

        var res = new List<string>();
        foreach (var q in queries) {
            if (Dfs(q, 0, root, 0)) {
                res.Add(q);
            }
        }
        return res;
    }
}

###Python

class TrieNode:
    def __init__(self):
        self.child = [None] * 26
        self.isEnd = False


class Solution:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        node = self.root
        for c in word:
            idx = ord(c) - ord('a')
            if not node.child[idx]:
                node.child[idx] = TrieNode()
            node = node.child[idx]
        node.isEnd = True

    def dfs(self, word, i, node, cnt):
        if cnt > 2 or not node:
            return False

        if i == len(word):
            return node.isEnd

        idx = ord(word[i]) - ord('a')

        # 不修改
        if node.child[idx] and self.dfs(word, i + 1, node.child[idx], cnt):
            return True

        # 修改
        if cnt < 2:
            for c in range(26):
                if c == idx:
                    continue
                if node.child[c] and self.dfs(word, i + 1, node.child[c], cnt + 1):
                    return True

        return False

    def twoEditWords(self, queries, dictionary):
        for w in dictionary:
            self.insert(w)

        res = []
        for q in queries:
            if self.dfs(q, 0, self.root, 0):
                res.append(q)
        return res

###C

typedef struct TrieNode {
    struct TrieNode* child[26];
    bool isEnd;
} TrieNode;

TrieNode* newNode() {
    TrieNode* node = (TrieNode*)malloc(sizeof(TrieNode));
    memset(node->child, 0, sizeof(node->child));
    node->isEnd = false;
    return node;
}

void insert(TrieNode* root, char* word) {
    TrieNode* node = root;
    for (int i = 0; word[i]; i++) {
        int idx = word[i] - 'a';
        if (!node->child[idx])
            node->child[idx] = newNode();
        node = node->child[idx];
    }
    node->isEnd = true;
}

bool dfs(char* word, int i, TrieNode* node, int cnt) {
    if (cnt > 2 || !node)
        return false;

    if (word[i] == '\0')
        return node->isEnd;

    int idx = word[i] - 'a';

    // 不修改
    if (node->child[idx] && dfs(word, i + 1, node->child[idx], cnt))
        return true;

    // 修改
    if (cnt < 2) {
        for (int c = 0; c < 26; c++) {
            if (c == idx) continue;
            if (node->child[c] && dfs(word, i + 1, node->child[c], cnt + 1))
                return true;
        }
    }

    return false;
}

char** twoEditWords(char** queries, int queriesSize,
                    char** dictionary, int dictionarySize,
                    int* returnSize) {
    TrieNode* root = newNode();
    for (int i = 0; i < dictionarySize; i++)
        insert(root, dictionary[i]);

    char** res = (char**)malloc(sizeof(char*) * queriesSize);
    int cnt = 0;

    for (int i = 0; i < queriesSize; i++) {
        if (dfs(queries[i], 0, root, 0)) {
            res[cnt++] = queries[i];
        }
    }

    *returnSize = cnt;
    return res;
}

###Go

type TrieNode struct {
    child [26]*TrieNode
    isEnd bool
}

var root = &TrieNode{}

func insert(word string) {
    node := root
    for _, c := range word {
        idx := c - 'a'
        if node.child[idx] == nil {
            node.child[idx] = &TrieNode{}
        }
        node = node.child[idx]
    }
    node.isEnd = true
}

func dfs(word string, i int, node *TrieNode, cnt int) bool {
    if cnt > 2 || node == nil {
        return false
    }

    if i == len(word) {
        return node.isEnd
    }

    idx := word[i] - 'a'

    // 不修改
    if node.child[idx] != nil && dfs(word, i+1, node.child[idx], cnt) {
        return true
    }

    // 修改
    if cnt < 2 {
        for c := 0; c < 26; c++ {
            if byte(c) == idx {
                continue
            }
            if node.child[c] != nil && dfs(word, i+1, node.child[c], cnt+1) {
                return true
            }
        }
    }

    return false
}

func twoEditWords(queries []string, dictionary []string) []string {
    root = &TrieNode{}
    for _, w := range dictionary {
        insert(w)
    }

    var res []string
    for _, q := range queries {
        if dfs(q, 0, root, 0) {
            res = append(res, q)
        }
    }
    return res
}

###JavaScript

class TrieNode {
    constructor() {
        this.child = new Array(26).fill(null);
        this.isEnd = false;
    }
}

var twoEditWords = function(queries, dictionary) {
    const root = new TrieNode();

    function insert(word) {
        let node = root;
        for (let c of word) {
            let idx = c.charCodeAt(0) - 97;
            if (!node.child[idx]) node.child[idx] = new TrieNode();
            node = node.child[idx];
        }
        node.isEnd = true;
    }

    function dfs(word, i, node, cnt) {
        if (cnt > 2 || !node) return false;
        if (i === word.length) return node.isEnd;

        let idx = word.charCodeAt(i) - 97;

        // 修改
        if (node.child[idx] && dfs(word, i + 1, node.child[idx], cnt))
            return true;

        // 不修改
        if (cnt < 2) {
            for (let c = 0; c < 26; c++) {
                if (c === idx) continue;
                if (node.child[c] && dfs(word, i + 1, node.child[c], cnt + 1))
                    return true;
            }
        }

        return false;
    }

    for (let w of dictionary) insert(w);

    const res = [];
    for (let q of queries) {
        if (dfs(q, 0, root, 0)) res.push(q);
    }

    return res;
};

###TypeScript

class TrieNode {
    child: (TrieNode | null)[] = new Array(26).fill(null);
    isEnd: boolean = false;
}

function twoEditWords(queries: string[], dictionary: string[]): string[] {
    const root = new TrieNode();

    function insert(word: string) {
        let node = root;
        for (const c of word) {
            const idx = c.charCodeAt(0) - 97;
            if (!node.child[idx]) node.child[idx] = new TrieNode();
            node = node.child[idx]!;
        }
        node.isEnd = true;
    }

    function dfs(word: string, i: number, node: TrieNode | null, cnt: number): boolean {
        if (cnt > 2 || !node) return false;
        if (i === word.length) return node.isEnd;

        const idx = word.charCodeAt(i) - 97;

        // 不修改
        if (node.child[idx] && dfs(word, i + 1, node.child[idx], cnt))
            return true;

        // 修改
        if (cnt < 2) {
            for (let c = 0; c < 26; c++) {
                if (c === idx) continue;
                if (node.child[c] && dfs(word, i + 1, node.child[c], cnt + 1))
                    return true;
            }
        }

        return false;
    }

    for (const w of dictionary) insert(w);

    const res: string[] = [];
    for (const q of queries) {
        if (dfs(q, 0, root, 0)) res.push(q);
    }

    return res;
}

###Rust

struct TrieNode {
    child: Vec<Option<Box<TrieNode>>>,
    is_end: bool,
}

impl TrieNode {
    fn new() -> Self {
        let mut child = Vec::with_capacity(26);
        for _ in 0..26 {
            child.push(None);
        }

        Self {
            child,
            is_end: false,
        }
    }
}

impl Solution {
    fn insert(root: &mut TrieNode, word: &str) {
        let mut node = root;
        for c in word.chars() {
            let idx = (c as u8 - b'a') as usize;
            if node.child[idx].is_none() {
                node.child[idx] = Some(Box::new(TrieNode::new()));
            }
            node = node.child[idx].as_mut().unwrap();
        }
        node.is_end = true;
    }

    fn dfs(word: &[u8], i: usize, node: &TrieNode, cnt: i32) -> bool {
        if cnt > 2 {
            return false;
        }
        if i == word.len() {
            return node.is_end;
        }

        let idx = (word[i] - b'a') as usize;

        // 不修改
        if let Some(ref next) = node.child[idx] {
            if Self::dfs(word, i + 1, next, cnt) {
                return true;
            }
        }

        // 修改
        if cnt < 2 {
            for c in 0..26 {
                if c == idx {
                    continue;
                }
                if let Some(ref next) = node.child[c] {
                    if Self::dfs(word, i + 1, next, cnt + 1) {
                        return true;
                    }
                }
            }
        }

        false
    }

    pub fn two_edit_words(queries: Vec<String>, dictionary: Vec<String>) -> Vec<String> {
        let mut root = TrieNode::new();

        for w in &dictionary {
            Self::insert(&mut root, w);
        }

        let mut res = vec![];
        for q in &queries {
            if Self::dfs(q.as_bytes(), 0, &root, 0) {
                res.push(q.clone());
            }
        }

        res
    }
}

复杂度分析

  • 时间复杂度:$O(k \cdot n + q \cdot n^2 \cdot 25^2)$,其中 $q$ 为 $\textit{queries}$ 的长度,$k$ 为 $\textit{dictionary}$ 的长度,$n$ 为 $\textit{queries}[i]$ 的长度。建字典树需要 $O(kn)$,查询时对于每一个字母都有修改和不修改两种选择,选择修改位置有 $C_n^2=n^2$ 种选择,其中不修改有 $1$ 条分支,修改有 $25$ 条分支,最多修改两次,因此有 $25^2$ 种选择。
  • 空间复杂度:$O(kn)$。即为字典树所占用的空间。
❌