普通视图

发现新文章,点击刷新页面。
今天 — 2024年11月26日LeetCode 每日一题题解

[Python3/Java/C++/Go/TypeScript] 一题一解:一次遍历(清晰题解)

作者 lcbin
2024年11月26日 09:14

方法一:一次遍历

我们令 $k = 3$,表示交替组的长度为 $3$。

为了方便处理,我们可以将环展开成一个长度为 $2n$ 的数组,然后从左到右遍历这个数组,用一个变量 $\textit{cnt}$ 记录当前交替组的长度,如果遇到了相同的颜色,就将 $\textit{cnt}$ 重置为 $1$,否则将 $\textit{cnt}$ 加一。如果 $\textit{cnt} \ge k$,并且当前位置 $i$ 大于等于 $n$,那么就找到了一个交替组,答案加一。

遍历结束后,返回答案即可。

###python

class Solution:
    def numberOfAlternatingGroups(self, colors: List[int]) -> int:
        k = 3
        n = len(colors)
        ans = cnt = 0
        for i in range(n << 1):
            if i and colors[i % n] == colors[(i - 1) % n]:
                cnt = 1
            else:
                cnt += 1
            ans += i >= n and cnt >= k
        return ans

###java

class Solution {
    public int numberOfAlternatingGroups(int[] colors) {
        int k = 3;
        int n = colors.length;
        int ans = 0, cnt = 0;
        for (int i = 0; i < n << 1; ++i) {
            if (i > 0 && colors[i % n] == colors[(i - 1) % n]) {
                cnt = 1;
            } else {
                ++cnt;
            }
            ans += i >= n && cnt >= k ? 1 : 0;
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int numberOfAlternatingGroups(vector<int>& colors) {
        int k = 3;
        int n = colors.size();
        int ans = 0, cnt = 0;
        for (int i = 0; i < n << 1; ++i) {
            if (i && colors[i % n] == colors[(i - 1) % n]) {
                cnt = 1;
            } else {
                ++cnt;
            }
            ans += i >= n && cnt >= k ? 1 : 0;
        }
        return ans;
    }
};

###go

func numberOfAlternatingGroups(colors []int) (ans int) {
k := 3
n := len(colors)
cnt := 0
for i := 0; i < n<<1; i++ {
if i > 0 && colors[i%n] == colors[(i-1)%n] {
cnt = 1
} else {
cnt++
}
if i >= n && cnt >= k {
ans++
}
}
return
}

###ts

function numberOfAlternatingGroups(colors: number[]): number {
    const k = 3;
    const n = colors.length;
    let [ans, cnt] = [0, 0];
    for (let i = 0; i < n << 1; ++i) {
        if (i && colors[i % n] === colors[(i - 1) % n]) {
            cnt = 1;
        } else {
            ++cnt;
        }
        ans += i >= n && cnt >= k ? 1 : 0;
    }
    return ans;
}

时间复杂度 $O(n)$,其中 $n$ 为数组 $\textit{colors}$ 的长度。空间复杂度 $O(1)$。


有任何问题,欢迎评论区交流,欢迎评论区提供其它解题思路(代码),也可以点个赞支持一下作者哈😄~

遍历计数,C 0ms

Problem: 100336. 交替组 I

[TOC]

思路

直接按题意遍历计数。

Code

执行用时分布0ms击败100.00%;消耗内存分布5.69MB击败100.00%

###C

int numberOfAlternatingGroups(int* colors, int colorsSize) {
    int ans = (colors[colorsSize - 2] != colors[colorsSize - 1] && colors[colorsSize - 1] != colors[0])
            + (colors[colorsSize - 1] != colors[0] && colors[0] != colors[1]);
    for (int i = 2; i < colorsSize; ++ i)
        if ((colors[i - 2] != colors[i - 1] && colors[i - 1] != colors[i])) ++ ans;
    return ans;
}

###Python3

class Solution:
    def numberOfAlternatingGroups(self, colors: List[int]) -> int:
        n, colors = len(colors), colors + colors[:2]
        return sum(colors[i] != colors[i + 1] != colors[i + 2] for i in range(n))

您若还有不同方法,欢迎贴在评论区,一起交流探讨! ^_^

↓ 点个赞,点收藏,留个言,再划走,感谢您支持作者! ^_^

每日一题-交替组 I🟢

2024年11月26日 00:00

给你一个整数数组 colors ,它表示一个由红色和蓝色瓷砖组成的环,第 i 块瓷砖的颜色为 colors[i] :

  • colors[i] == 0 表示第 i 块瓷砖的颜色是 红色 。
  • colors[i] == 1 表示第 i 块瓷砖的颜色是 蓝色 。

环中连续 3 块瓷砖的颜色如果是 交替 颜色(也就是说中间瓷砖的颜色与它 左边 和 右边 的颜色都不同),那么它被称为一个 交替 组。

请你返回 交替 组的数目。

注意 ,由于 colors 表示一个  ,第一块 瓷砖和 最后一块 瓷砖是相邻的。

 

示例 1:

输入:colors = [1,1,1]

输出:0

解释:

示例 2:

输入:colors = [0,1,0,0,1]

输出:3

解释:

交替组包括:

 

提示:

  • 3 <= colors.length <= 100
  • 0 <= colors[i] <= 1

交替组 I

2024年11月13日 10:16

方法一:模拟

思路

按照题意遍历数组 $\textit{colors}$ 的每个元素,判断其前一个元素和后一个元素是否都与当前元素不同,如果满足,则将结果加 $1$。注意瓷砖是环形的,则数组的首尾元素是相邻的。最后返回结果。

代码

###Python

class Solution:
    def numberOfAlternatingGroups(self, colors: List[int]) -> int:
        n = len(colors)
        res = 0
        for i in range(n):
            if colors[i] != colors[i - 1] and colors[i] != colors[(i + 1) % n]:
                res += 1
        return res

###Java

class Solution {
    public int numberOfAlternatingGroups(int[] colors) {
        int n = colors.length;
        int res = 0;
        for (int i = 0; i < n; i++) {
            if (colors[i] != colors[(i - 1 + n) % n] && colors[i] != colors[(i + 1) % n]) {
                res += 1;
            }
        }
        return res;
    }
}

###C#

public class Solution {
    public int NumberOfAlternatingGroups(int[] colors) {
        int n = colors.Length;
        int res = 0;
        for (int i = 0; i < n; i++) {
            if (colors[i] != colors[(i - 1 + n) % n] && colors[i] != colors[(i + 1) % n]) {
                res += 1;
            }
        }
        return res;
    }
}

###C++

class Solution {
public:
    int numberOfAlternatingGroups(vector<int>& colors) {
        int n = colors.size();
        int res = 0;
        for (int i = 0; i < n; i++) {
            if (colors[i] != colors[(i - 1 + n) % n] && colors[i] != colors[(i + 1) % n]) {
                res += 1;
            }
        }
        return res;
    }
};

###Go

func numberOfAlternatingGroups(colors []int) int {
    n := len(colors)
    res := 0
    for i := 0; i < n; i++ {
        if colors[i] != colors[(i-1+n)%n] && colors[i] != colors[(i+1)%n] {
            res++
        }
    }
    return res
}

###C

int numberOfAlternatingGroups(int* colors, int colorsSize) {
    int res = 0;
    for (size_t i = 0; i < colorsSize; i++) {
        if (colors[i] != colors[(i - 1 + colorsSize) % colorsSize] && colors[i] != colors[(i + 1) % colorsSize]) {
            res += 1;
        }
    }
    return res;
}

###JavaScript

var numberOfAlternatingGroups = function(colors) {
    const n = colors.length;
    let res = 0;
    for (let i = 0; i < n; i++) {
        if (colors[i] !== colors[(i - 1 + n) % n] && colors[i] !== colors[(i + 1) % n]) {
            res++;
        }
    }
    return res;
};

###TypeScript

function numberOfAlternatingGroups(colors: number[]): number {
    const n = colors.length;
    let res = 0;
    for (let i = 0; i < n; i++) {
        if (colors[i] !== colors[(i - 1 + n) % n] && colors[i] !== colors[(i + 1) % n]) {
            res++;
        }
    }
    return res;
};

###Rust

impl Solution {
    pub fn number_of_alternating_groups(colors: Vec<i32>) -> i32 {
        let n = colors.len();
        let mut res = 0;
        for i in 0..n {
            if colors[i] != colors[(i + n - 1) % n] && colors[i] != colors[(i + 1) % n] {
                res += 1;
            }
        }
        res
    }
}

复杂度分析

  • 时间复杂度:$O(n)$。

  • 空间复杂度:$O(1)$。

滑动窗口

作者 tsreaper
2024年7月7日 01:31

解法:滑动窗口

枚举组的开头,那么组中间的 $(k - 2)$ 个元素都需要满足“与两边的颜色不同”的条件。预处理哪些元素和两边的颜色不同,再用滑动窗口统计中间的 $(k - 2)$ 个元素中,有几个满足该条件即可。复杂度 $\mathcal{O}(n)$。

参考代码(c++)

###cpp

class Solution {
public:
    int numberOfAlternatingGroups(vector<int>& colors, int K) {
        int n = colors.size();
        // 预处理哪些元素与两边颜色不同
        int f[n];
        for (int i = 0; i < n; i++) {
            int x = colors[(i - 1 + n) % n];
            int y = colors[i];
            int z = colors[(i + 1) % n];
            if (x != y && y != z) f[i] = 1;
            else f[i] = 0;
        }

        // 滑动窗口
        int sm = 0;
        for (int i = 1; i + 1 < K; i++) sm += f[i];
        int ans = 0;
        for (int i = 0; i < n; i++) {
            if (sm == K - 2) ans++;
            sm -= f[(i + 1) % n];
            sm += f[(i + K - 1) % n];
        }
        return ans;
    }
};
昨天 — 2024年11月25日LeetCode 每日一题题解

每日一题-网络延迟时间🟡

2024年11月25日 00:00

n 个网络节点,标记为 1 到 n

给你一个列表 times,表示信号经过 有向 边的传递时间。 times[i] = (ui, vi, wi),其中 ui 是源节点,vi 是目标节点, wi 是一个信号从源节点传递到目标节点的时间。

现在,从某个节点 K 发出一个信号。需要多久才能使所有节点都收到信号?如果不能使所有节点收到信号,返回 -1

 

示例 1:

输入:times = [[2,1,1],[2,3,1],[3,4,1]], n = 4, k = 2
输出:2

示例 2:

输入:times = [[1,2,1]], n = 2, k = 1
输出:1

示例 3:

输入:times = [[1,2,1]], n = 2, k = 2
输出:-1

 

提示:

  • 1 <= k <= n <= 100
  • 1 <= times.length <= 6000
  • times[i].length == 3
  • 1 <= ui, vi <= n
  • ui != vi
  • 0 <= wi <= 100
  • 所有 (ui, vi) 对都 互不相同(即,不含重复边)

两种 Dijkstra 写法(附题单)Python/Java/C++/Go/JS/Rust

作者 endlesscheng
2024年3月5日 10:23

Dijkstra 算法介绍

定义 $g[i][j]$ 表示节点 $i$ 到节点 $j$ 这条边的边权。如果没有 $i$ 到 $j$ 的边,则 $g[i][j]=\infty$。

定义 $\textit{dis}[i]$ 表示起点 $k$ 到节点 $i$ 的最短路长度,一开始 $\textit{dis}[k]=0$,其余 $\textit{dis}[i]=\infty$ 表示尚未计算出。

我们的目标是计算出最终的 $\textit{dis}$ 数组。

  • 首先更新起点 $k$ 到其邻居 $y$ 的最短路,即更新 $\textit{dis}[y]$ 为 $g[k][y]$。
  • 然后取除了起点 $k$ 以外的 $\textit{dis}[i]$ 的最小值,假设最小值对应的节点是 $3$。此时可以断言:$\textit{dis}[3]$ 已经是 $k$ 到 $3$ 的最短路长度,不可能有其它 $k$ 到 $3$ 的路径更短!反证法:假设存在更短的路径,那我们一定会从 $k$ 出发经过一个点 $u$,它的 $\textit{dis}[u]$ 比 $\textit{dis}[3]$ 还要小,然后再经过一些边到达 $3$,得到更小的 $\textit{dis}[3]$。但 $\textit{dis}[3]$ 已经是最小的了,并且图中没有负数边权,所以 $u$ 是不存在的,矛盾。故原命题成立,此时我们得到了 $\textit{dis}[3]$ 的最终值。
  • 用节点 $3$ 到其邻居 $y$ 的边权 $g[3][y]$ 更新 $\textit{dis}[y]$:如果 $\textit{dis}[3] + g[3][y] < \textit{dis}[y]$,那么更新 $\textit{dis}[y]$ 为 $\textit{dis}[3] + g[3][y]$,否则不更新。
  • 然后取除了节点 $k,3$ 以外的 $\textit{dis}[i]$ 的最小值,重复上述过程。
  • 由数学归纳法可知,这一做法可以得到每个点的最短路。当所有点的最短路都已确定时,算法结束。

写法一:朴素 Dijkstra(适用于稠密图)

对于本题,在计算最短路时,如果发现当前找到的最小最短路等于 $\infty$,说明有节点无法到达,可以提前结束算法,返回 $-1$。

如果所有节点都可以到达,返回 $\max(\textit{dis})$。

代码实现时,节点编号改成从 $0$ 开始。

###py

class Solution:
    def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
        g = [[inf for _ in range(n)] for _ in range(n)]  # 邻接矩阵
        for x, y, d in times:
            g[x - 1][y - 1] = d

        dis = [inf] * n
        ans = dis[k - 1] = 0
        done = [False] * n
        while True:
            x = -1
            for i, ok in enumerate(done):
                if not ok and (x < 0 or dis[i] < dis[x]):
                    x = i
            if x < 0:
                return ans  # 最后一次算出的最短路就是最大的
            if dis[x] == inf:  # 有节点无法到达
                return -1
            ans = dis[x]  # 求出的最短路会越来越大
            done[x] = True  # 最短路长度已确定(无法变得更小)
            for y, d in enumerate(g[x]):
                # 更新 x 的邻居的最短路
                dis[y] = min(dis[y], dis[x] + d)

###java

class Solution {
    public int networkDelayTime(int[][] times, int n, int k) {
        final int INF = Integer.MAX_VALUE / 2; // 防止加法溢出
        int[][] g = new int[n][n]; // 邻接矩阵
        for (int[] row : g) {
            Arrays.fill(row, INF);
        }
        for (int[] t : times) {
            g[t[0] - 1][t[1] - 1] = t[2];
        }

        int maxDis = 0;
        int[] dis = new int[n];
        Arrays.fill(dis, INF);
        dis[k - 1] = 0;
        boolean[] done = new boolean[n];
        while (true) {
            int x = -1;
            for (int i = 0; i < n; i++) {
                if (!done[i] && (x < 0 || dis[i] < dis[x])) {
                    x = i;
                }
            }
            if (x < 0) {
                return maxDis; // 最后一次算出的最短路就是最大的
            }
            if (dis[x] == INF) { // 有节点无法到达
                return -1;
            }
            maxDis = dis[x]; // 求出的最短路会越来越大
            done[x] = true; // 最短路长度已确定(无法变得更小)
            for (int y = 0; y < n; y++) {
                // 更新 x 的邻居的最短路
                dis[y] = Math.min(dis[y], dis[x] + g[x][y]);
            }
        }
    }
}

###cpp

class Solution {
public:
    int networkDelayTime(vector<vector<int>>& times, int n, int k) {
        vector<vector<int>> g(n, vector<int>(n, INT_MAX / 2)); // 邻接矩阵
        for (auto& t : times) {
            g[t[0] - 1][t[1] - 1] = t[2];
        }

        vector<int> dis(n, INT_MAX / 2), done(n);
        dis[k - 1] = 0;
        while (true) {
            int x = -1;
            for (int i = 0; i < n; i++) {
                if (!done[i] && (x < 0 || dis[i] < dis[x])) {
                    x = i;
                }
            }
            if (x < 0) {
                return ranges::max(dis);
            }
            if (dis[x] == INT_MAX / 2) { // 有节点无法到达
                return -1;
            }
            done[x] = true; // 最短路长度已确定(无法变得更小)
            for (int y = 0; y < n; y++) {
                // 更新 x 的邻居的最短路
                dis[y] = min(dis[y], dis[x] + g[x][y]);
            }
        }
    }
};

###go

func networkDelayTime(times [][]int, n, k int) int {
    const inf = math.MaxInt / 2 // 防止加法溢出
    g := make([][]int, n) // 邻接矩阵
    for i := range g {
        g[i] = make([]int, n)
        for j := range g[i] {
            g[i][j] = inf
        }
    }
    for _, t := range times {
        g[t[0]-1][t[1]-1] = t[2]
    }

    dis := make([]int, n)
    for i := range dis {
        dis[i] = inf
    }
    dis[k-1] = 0
    done := make([]bool, n)
    for {
        x := -1
        for i, ok := range done {
            if !ok && (x < 0 || dis[i] < dis[x]) {
                x = i
            }
        }
        if x < 0 {
            return slices.Max(dis)
        }
        if dis[x] == inf { // 有节点无法到达
            return -1
        }
        done[x] = true // 最短路长度已确定(无法变得更小)
        for y, d := range g[x] {
            // 更新 x 的邻居的最短路
            dis[y] = min(dis[y], dis[x]+d)
        }
    }
}

###js

var networkDelayTime = function(times, n, k) {
    const g = Array.from({length: n}, () => Array(n).fill(Infinity)); // 邻接矩阵
    for (const [x, y, d] of times) {
        g[x - 1][y - 1] = d;
    }

    const dis = Array(n).fill(Infinity);
    dis[k - 1] = 0;
    const done = Array(n).fill(false);
    while (true) {
        let x = -1;
        for (let i = 0; i < n; i++) {
            if (!done[i] && (x < 0 || dis[i] < dis[x])) {
                x = i;
            }
        }
        if (x < 0) {
            return Math.max(...dis);
        }
        if (dis[x] === Infinity) { // 有节点无法到达
            return -1;
        }
        done[x] = true; // 最短路长度已确定(无法变得更小)
        for (let y = 0; y < n; y++) {
            // 更新 x 的邻居的最短路
            dis[y] = Math.min(dis[y], dis[x] + g[x][y]);
        }
    }
};

###rust

impl Solution {
    pub fn network_delay_time(times: Vec<Vec<i32>>, n: i32, k: i32) -> i32 {
        const INF: i32 = i32::MAX / 2; // 防止加法溢出
        let n = n as usize;
        let mut g = vec![vec![INF; n]; n]; // 邻接矩阵
        for t in &times {
            g[t[0] as usize - 1][t[1] as usize - 1] = t[2];
        }

        let mut dis = vec![INF; n];
        dis[k as usize - 1] = 0;
        let mut done = vec![false; n];
        loop {
            let mut x = n;
            for (i, &ok) in done.iter().enumerate() {
                if !ok && (x == n || dis[i] < dis[x]) {
                    x = i;
                }
            }
            if x == n {
                return *dis.iter().max().unwrap();
            }
            if dis[x] == INF { // 有节点无法到达
                return -1;
            }
            done[x] = true; // 最短路长度已确定(无法变得更小)
            for (y, &d) in g[x].iter().enumerate() {
                // 更新 x 的邻居的最短路
                dis[y] = dis[y].min(dis[x] + d);
            }
        }
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2)$。
  • 空间复杂度:$\mathcal{O}(n^2)$。

写法二:堆优化 Dijkstra(适用于稀疏图)

寻找最小值的过程可以用一个最小堆来快速完成:

  • 一开始把 $(\textit{dis}[k],k)$ 二元组入堆。
  • 当节点 $x$ 首次出堆时,$\textit{dis}[x]$ 就是写法一中寻找的最小最短路。
  • 更新 $\textit{dis}[y]$ 时,把 $(\textit{dis}[y],y)$ 二元组入堆。

注意,如果一个节点 $x$ 在出堆前,其最短路长度 $\textit{dis}[x]$ 被多次更新,那么堆中会有多个重复的 $x$,并且包含 $x$ 的二元组中的 $\textit{dis}[x]$ 是互不相同的(因为我们只在找到更小的最短路时才会把二元组入堆)。

所以写法一中的 $\textit{done}$ 数组可以省去,取而代之的是用出堆的最短路值(记作 $\textit{dx}$)与当前的 $\textit{dis}[x]$ 比较,如果 $\textit{dx} > \textit{dis}[x]$ 说明 $x$ 之前出堆过,我们已经更新了 $x$ 的邻居的最短路,所以这次就不用更新了,继续外层循环。

答疑

:为什么代码要判断 dx > dis[x]

:对于同一个 $x$,例如先入堆一个比较大的 $\textit{dis}[x]=10$,后面又把 $\textit{dis}[x]$ 更新成 $5$,之后这个 $5$ 会先出堆,然后再把 $10$ 出堆。$10$ 出堆时候是没有必要去更新周围邻居的最短路的,因为 $5$ 出堆之后,就已经把邻居的最短路更新过了,用 $10$ 是无法把邻居的最短路变得更短的,所以直接 continue

###py

class Solution:
    def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
        g = [[] for _ in range(n)]  # 邻接表
        for x, y, d in times:
            g[x - 1].append((y - 1, d))

        dis = [inf] * n
        dis[k - 1] = 0
        h = [(0, k - 1)]
        while h:
            dx, x = heappop(h)
            if dx > dis[x]:  # x 之前出堆过
                continue
            for y, d in g[x]:
                new_dis = dx + d
                if new_dis < dis[y]:
                    dis[y] = new_dis  # 更新 x 的邻居的最短路
                    heappush(h, (new_dis, y))
        mx = max(dis)
        return mx if mx < inf else -1

###java

class Solution {
    public int networkDelayTime(int[][] times, int n, int k) {
        List<int[]>[] g = new ArrayList[n]; // 邻接表
        Arrays.setAll(g, i -> new ArrayList<>());
        for (int[] t : times) {
            g[t[0] - 1].add(new int[]{t[1] - 1, t[2]});
        }

        int maxDis = 0;
        int left = n; // 未确定最短路的节点个数
        int[] dis = new int[n];
        Arrays.fill(dis, Integer.MAX_VALUE);
        dis[k - 1] = 0;
        PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> (a[0] - b[0]));
        pq.offer(new int[]{0, k - 1});
        while (!pq.isEmpty()) {
            int[] p = pq.poll();
            int dx = p[0];
            int x = p[1];
            if (dx > dis[x]) { // x 之前出堆过
                continue;
            }
            maxDis = dx; // 求出的最短路会越来越大
            left--;
            for (int[] e : g[x]) {
                int y = e[0];
                int newDis = dx + e[1];
                if (newDis < dis[y]) {
                    dis[y] = newDis; // 更新 x 的邻居的最短路
                    pq.offer(new int[]{newDis, y});
                }
            }
        }
        return left == 0 ? maxDis : -1;
    }
}

###cpp

class Solution {
public:
    int networkDelayTime(vector<vector<int>>& times, int n, int k) {
        vector<vector<pair<int, int>>> g(n); // 邻接表
        for (auto& t : times) {
            g[t[0] - 1].emplace_back(t[1] - 1, t[2]);
        }

        vector<int> dis(n, INT_MAX);
        dis[k - 1] = 0;
        priority_queue<pair<int, int>, vector<pair<int, int>>, greater<>> pq;
        pq.emplace(0, k - 1);
        while (!pq.empty()) {
            auto [dx, x] = pq.top();
            pq.pop();
            if (dx > dis[x]) { // x 之前出堆过
                continue;
            }
            for (auto &[y, d] : g[x]) {
                int new_dis = dx + d;
                if (new_dis < dis[y]) {
                    dis[y] = new_dis; // 更新 x 的邻居的最短路
                    pq.emplace(new_dis, y);
                }
            }
        }
        int mx = ranges::max(dis);
        return mx < INT_MAX ? mx : -1;
    }
};

###go

func networkDelayTime(times [][]int, n, k int) int {
    type edge struct{ to, wt int }
    g := make([][]edge, n) // 邻接表
    for _, t := range times {
        g[t[0]-1] = append(g[t[0]-1], edge{t[1] - 1, t[2]})
    }

    dis := make([]int, n)
    for i := range dis {
        dis[i] = math.MaxInt
    }
    dis[k-1] = 0
    h := hp{{0, k - 1}}
    for len(h) > 0 {
        p := heap.Pop(&h).(pair)
        dx := p.dis
        x := p.x
        if dx > dis[x] { // x 之前出堆过
            continue
        }
        for _, e := range g[x] {
            y := e.to
            newDis := dx + e.wt
            if newDis < dis[y] {
                dis[y] = newDis // 更新 x 的邻居的最短路
                heap.Push(&h, pair{newDis, y})
            }
        }
    }
    mx := slices.Max(dis)
    if mx < math.MaxInt {
        return mx
    }
    return -1
}

type pair struct{ dis, x int }
type hp []pair
func (h hp) Len() int           { return len(h) }
func (h hp) Less(i, j int) bool { return h[i].dis < h[j].dis }
func (h hp) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *hp) Push(v any)        { *h = append(*h, v.(pair)) }
func (h *hp) Pop() (v any)      { a := *h; *h, v = a[:len(a)-1], a[len(a)-1]; return }

###js

var networkDelayTime = function(times, n, k) {
    const g = Array.from({length: n}, () => []); // 邻接表
    for (const [x, y, d] of times) {
        g[x - 1].push([y - 1, d]);
    }

    const dis = Array(n).fill(Infinity);
    dis[k - 1] = 0;
    const pq = new MinPriorityQueue({priority: (p) => p[0]});
    pq.enqueue([0, k - 1]);
    while (!pq.isEmpty()) {
        const [dx, x] = pq.dequeue().element;
        if (dx > dis[x]) { // x 之前出堆过
            continue;
        }
        for (const [y, d] of g[x]) {
            const newDis = dx + d;
            if (newDis < dis[y]) {
                dis[y] = newDis; // 更新 x 的邻居的最短路
                pq.enqueue([newDis, y]);
            }
        }
    }
    const mx = Math.max(...dis);
    return mx < Infinity ? mx : -1;
};

###rust

use std::collections::BinaryHeap;

impl Solution {
    pub fn network_delay_time(times: Vec<Vec<i32>>, n: i32, k: i32) -> i32 {
        let n = n as usize;
        let k = k as usize - 1;
        let mut g = vec![vec![]; n]; // 邻接表
        for t in &times {
            g[t[0] as usize - 1].push((t[1] as usize - 1, t[2]));
        }

        let mut dis = vec![i32::MAX; n];
        dis[k] = 0;
        let mut h = BinaryHeap::new();
        h.push((0, k));
        while let Some((dx, x)) = h.pop() {
            if -dx > dis[x] { // x 之前出堆过
                continue;
            }
            for &(y, d) in &g[x] {
                let new_dis = -dx + d;
                if new_dis < dis[y] {
                    dis[y] = new_dis; // 更新 x 的邻居的最短路
                    h.push((-new_dis, y));
                }
            }
        }
        let mx = *dis.iter().max().unwrap();
        if mx < i32::MAX { mx } else { -1 }
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(m\log m)$,其中 $m$ 为 $\textit{times}$ 的长度。由于 $m\ge n-1$,分析复杂度时以 $m$ 为主。注意堆中会有重复节点,所以至多有 $\mathcal{O}(m)$ 个元素,单次操作的复杂度是 $\mathcal{O}(\log m)$。值得注意的是,如果输入的是稠密图,写法二的时间复杂度为 $\mathcal{O}(n^2\log n)$,不如写法一。
  • 空间复杂度:$\mathcal{O}(m)$。

更多相似题目,见下面图论题单中的「单源最短路:Dijkstra」。

分类题单

如何科学刷题?

  1. 滑动窗口(定长/不定长/多指针)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/最短路/最小生成树/二分图/基环树/欧拉路径)
  7. 动态规划(入门/背包/状态机/划分/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心算法(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

【宫水三叶】涵盖所有的「存图方式」与「最短路算法(详尽注释)」

作者 AC_OIer
2021年8月2日 10:21

欢迎关注我 ❤️ 提供写「证明」&「思路」的高质量专项题解
后台回复「刷题路线」有惊喜,更有「长期送实体书」活动等你来 🎉 🎉

基本分析

为了方便,我们约定 $n$ 为点数,$m$ 为边数。

根据题意,首先 $n$ 的数据范围只有 $100$,$m$ 的数据范围为 $6000$,使用「邻接表」或「邻接矩阵」来存图都可以。

同时求的是「从 $k$ 点出发,所有点都被访问到的最短时间」,将问题转换一下其实就是求「从 $k$ 点出发,到其他点 $x$ 的最短距离的最大值」。


存图方式

在开始讲解最短路之前,我们先来学习三种「存图」方式。

邻接矩阵

这是一种使用二维矩阵来进行存图的方式。

适用于边数较多的稠密图使用,当边数量接近点的数量的平方,即 $m \approx n^2$ 时,可定义为稠密图

###Java

// 邻接矩阵数组:w[a][b] = c 代表从 a 到 b 有权重为 c 的边
int[][] w = new int[N][N];

// 加边操作
void add(int a, int b, int c) {
    w[a][b] = c;
}

邻接表

这也是一种在图论中十分常见的存图方式,与数组存储单链表的实现一致(头插法)。

这种存图方式又叫链式前向星存图

适用于边数较少的稀疏图使用,当边数量接近点的数量,即 $m \approx n$ 时,可定义为稀疏图

###Java

int[] he = new int[N], e = new int[M], ne = new int[M], w = new int[M];
int idx;

void add(int a, int b, int c) {
    e[idx] = b;
    ne[idx] = he[a];
    he[a] = idx;
    w[idx] = c;
    idx++;
}

首先 idx 是用来对边进行编号的,然后对存图用到的几个数组作简单解释:

  • he 数组:存储是某个节点所对应的边的集合(链表)的头结点;
  • e 数组:由于访问某一条边指向的节点;
  • ne 数组:由于是以链表的形式进行存边,该数组就是用于找到下一条边;
  • w 数组:用于记录某条边的权重为多少。

因此当我们想要遍历所有由 a 点发出的边时,可以使用如下方式:

###Java

for (int i = he[a]; i != -1; i = ne[i]) {
    int b = e[i], c = w[i]; // 存在由 a 指向 b 的边,权重为 c
}

这是一种最简单,但是相比上述两种存图方式,使用得较少的存图方式。

只有当我们需要确保某个操作复杂度严格为 $O(m)$ 时,才会考虑使用。

具体的,我们建立一个类来记录有向边信息:

###Java

class Edge {
    // 代表从 a 到 b 有一条权重为 c 的边
    int a, b, c;
    Edge(int _a, int _b, int _c) {
        a = _a; b = _b; c = _c;
    }
}

通常我们会使用 List 存起所有的边对象,并在需要遍历所有边的时候,进行遍历:

###Java

List<Edge> es = new ArrayList<>();

...

for (Edge e : es) {
    ...
}

Floyd(邻接矩阵)

根据「基本分析」,我们可以使用复杂度为 $O(n^3)$ 的「多源汇最短路」算法 Floyd 算法进行求解,同时使用「邻接矩阵」来进行存图。

此时计算量约为 $10^6$,可以过。

跑一遍 Floyd,可以得到「从任意起点出发,到达任意起点的最短距离」。然后从所有 $w[k][x]$ 中取 $max$ 即是「从 $k$ 点出发,到其他点 $x$ 的最短距离的最大值」。

image.png

代码:

###Java

class Solution {
    int N = 110, M = 6010;
    // 邻接矩阵数组:w[a][b] = c 代表从 a 到 b 有权重为 c 的边
    int[][] w = new int[N][N];
    int INF = 0x3f3f3f3f;
    int n, k;
    public int networkDelayTime(int[][] ts, int _n, int _k) {
        n = _n; k = _k;
        // 初始化邻接矩阵
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= n; j++) {
                w[i][j] = w[j][i] = i == j ? 0 : INF;
            }
        }
        // 存图
        for (int[] t : ts) {
            int u = t[0], v = t[1], c = t[2];
            w[u][v] = c;
        }
        // 最短路
        floyd();
        // 遍历答案
        int ans = 0;
        for (int i = 1; i <= n; i++) {
            ans = Math.max(ans, w[k][i]);
        }
        return ans >= INF / 2 ? -1 : ans;
    }
    void floyd() {
        // floyd 基本流程为三层循环:
        // 枚举中转点 - 枚举起点 - 枚举终点 - 松弛操作        
        for (int p = 1; p <= n; p++) {
            for (int i = 1; i <= n; i++) {
                for (int j = 1; j <= n; j++) {
                    w[i][j] = Math.min(w[i][j], w[i][p] + w[p][j]);
                }
            }
        }
    }
}
  • 时间复杂度:$O(n^3)$
  • 空间复杂度:$O(n^2)$

朴素 Dijkstra(邻接矩阵)

同理,我们可以使用复杂度为 $O(n^2)$ 的「单源最短路」算法朴素 Dijkstra 算法进行求解,同时使用「邻接矩阵」来进行存图。

根据题意,$k$ 点作为源点,跑一遍 Dijkstra 我们可以得到从源点 $k$ 到其他点 $x$ 的最短距离,再从所有最短路中取 $max$ 即是「从 $k$ 点出发,到其他点 $x$ 的最短距离的最大值」。

朴素 Dijkstra 复杂度为 $O(n^2)$,可以过。

image.png

代码:

###Java

class Solution {
    int N = 110, M = 6010;
    // 邻接矩阵数组:w[a][b] = c 代表从 a 到 b 有权重为 c 的边
    int[][] w = new int[N][N];
    // dist[x] = y 代表从「源点/起点」到 x 的最短距离为 y
    int[] dist = new int[N];
    // 记录哪些点已经被更新过
    boolean[] vis = new boolean[N];
    int INF = 0x3f3f3f3f;
    int n, k;
    public int networkDelayTime(int[][] ts, int _n, int _k) {
        n = _n; k = _k;
        // 初始化邻接矩阵
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= n; j++) {
                w[i][j] = w[j][i] = i == j ? 0 : INF;
            }
        }
        // 存图
        for (int[] t : ts) {
            int u = t[0], v = t[1], c = t[2];
            w[u][v] = c;
        }
        // 最短路
        dijkstra();
        // 遍历答案
        int ans = 0;
        for (int i = 1; i <= n; i++) {
            ans = Math.max(ans, dist[i]);
        }
        return ans > INF / 2 ? -1 : ans;
    }
    void dijkstra() {
        // 起始先将所有的点标记为「未更新」和「距离为正无穷」
        Arrays.fill(vis, false);
        Arrays.fill(dist, INF);
        // 只有起点最短距离为 0
        dist[k] = 0;
        // 迭代 n 次
        for (int p = 1; p <= n; p++) {
            // 每次找到「最短距离最小」且「未被更新」的点 t
            int t = -1;
            for (int i = 1; i <= n; i++) {
                if (!vis[i] && (t == -1 || dist[i] < dist[t])) t = i;
            }
            // 标记点 t 为已更新
            vis[t] = true;
            // 用点 t 的「最小距离」更新其他点
            for (int i = 1; i <= n; i++) {
                dist[i] = Math.min(dist[i], dist[t] + w[t][i]);
            }
        }
    }
}
  • 时间复杂度:$O(n^2)$
  • 空间复杂度:$O(n^2)$

堆优化 Dijkstra(邻接表)

由于边数据范围不算大,我们还可以使用复杂度为 $O(m\log{n})$ 的堆优化 Dijkstra 算法进行求解。

堆优化 Dijkstra 算法与朴素 Dijkstra 都是「单源最短路」算法。

跑一遍堆优化 Dijkstra 算法求最短路,再从所有最短路中取 $max$ 即是「从 $k$ 点出发,到其他点 $x$ 的最短距离的最大值」。

此时算法复杂度为 $O(m\log{n})$,可以过。

image.png

代码:

###Java

class Solution {
    int N = 110, M = 6010;
    // 邻接表
    int[] he = new int[N], e = new int[M], ne = new int[M], w = new int[M];
    // dist[x] = y 代表从「源点/起点」到 x 的最短距离为 y
    int[] dist = new int[N];
    // 记录哪些点已经被更新过
    boolean[] vis = new boolean[N];
    int n, k, idx;
    int INF = 0x3f3f3f3f;
    void add(int a, int b, int c) {
        e[idx] = b;
        ne[idx] = he[a];
        he[a] = idx;
        w[idx] = c;
        idx++;
    }
    public int networkDelayTime(int[][] ts, int _n, int _k) {
        n = _n; k = _k;
        // 初始化链表头
        Arrays.fill(he, -1);
        // 存图
        for (int[] t : ts) {
            int u = t[0], v = t[1], c = t[2];
            add(u, v, c);
        }
        // 最短路
        dijkstra();
        // 遍历答案
        int ans = 0;
        for (int i = 1; i <= n; i++) {
            ans = Math.max(ans, dist[i]);
        }
        return ans > INF / 2 ? -1 : ans;
    }
    void dijkstra() {
        // 起始先将所有的点标记为「未更新」和「距离为正无穷」
        Arrays.fill(vis, false);
        Arrays.fill(dist, INF);
        // 只有起点最短距离为 0
        dist[k] = 0;
        // 使用「优先队列」存储所有可用于更新的点
        // 以 (点编号, 到起点的距离) 进行存储,优先弹出「最短距离」较小的点
        PriorityQueue<int[]> q = new PriorityQueue<>((a,b)->a[1]-b[1]);
        q.add(new int[]{k, 0});
        while (!q.isEmpty()) {
            // 每次从「优先队列」中弹出
            int[] poll = q.poll();
            int id = poll[0], step = poll[1];
            // 如果弹出的点被标记「已更新」,则跳过
            if (vis[id]) continue;
            // 标记该点「已更新」,并使用该点更新其他点的「最短距离」
            vis[id] = true;
            for (int i = he[id]; i != -1; i = ne[i]) {
                int j = e[i];
                if (dist[j] > dist[id] + w[i]) {
                    dist[j] = dist[id] + w[i];
                    q.add(new int[]{j, dist[j]});
                }
            }
        }
    }
}
  • 时间复杂度:$O(m\log{n} + n)$
  • 空间复杂度:$O(m)$

Bellman Ford(类 & 邻接表)

虽然题目规定了不存在「负权边」,但我们仍然可以使用可以在「负权图中求最短路」的 Bellman Ford 进行求解,该算法也是「单源最短路」算法,复杂度为 $O(n * m)$。

通常为了确保 $O(n * m)$,可以单独建一个类代表边,将所有边存入集合中,在 $n$ 次松弛操作中直接对边集合进行遍历(代码见 $P1$)。

由于本题边的数量级大于点的数量级,因此也能够继续使用「邻接表」的方式进行边的遍历,遍历所有边的复杂度的下界为 $O(n)$,上界可以确保不超过 $O(m)$(代码见 $P2$)。

image.png

代码:

###Java

class Solution {
    class Edge {
        int a, b, c;
        Edge(int _a, int _b, int _c) {
            a = _a; b = _b; c = _c;
        }
    }
    int N = 110, M = 6010;
    // dist[x] = y 代表从「源点/起点」到 x 的最短距离为 y
    int[] dist = new int[N];
    int INF = 0x3f3f3f3f;
    int n, m, k;
    // 使用类进行存边
    List<Edge> es = new ArrayList<>();
    public int networkDelayTime(int[][] ts, int _n, int _k) {
        n = _n; k = _k;
        m = ts.length;
        // 存图
        for (int[] t : ts) {
            int u = t[0], v = t[1], c = t[2];
            es.add(new Edge(u, v, c));
        }
        // 最短路
        bf();
        // 遍历答案
        int ans = 0;
        for (int i = 1; i <= n; i++) {
            ans = Math.max(ans, dist[i]);
        }
        return ans > INF / 2 ? -1 : ans;
    }
    void bf() {
        // 起始先将所有的点标记为「距离为正无穷」
        Arrays.fill(dist, INF);
        // 只有起点最短距离为 0
        dist[k] = 0;
        // 迭代 n 次
        for (int p = 1; p <= n; p++) {
            int[] prev = dist.clone();
            // 每次都使用上一次迭代的结果,执行松弛操作
            for (Edge e : es) {
                int a = e.a, b = e.b, c = e.c;
                dist[b] = Math.min(dist[b], prev[a] + c);
            }
        }
    }
}

###Java

class Solution {
    int N = 110, M = 6010;
    // 邻接表
    int[] he = new int[N], e = new int[M], ne = new int[M], w = new int[M];
    // dist[x] = y 代表从「源点/起点」到 x 的最短距离为 y
    int[] dist = new int[N];
    int INF = 0x3f3f3f3f;
    int n, m, k, idx;
    void add(int a, int b, int c) {
        e[idx] = b;
        ne[idx] = he[a];
        he[a] = idx;
        w[idx] = c;
        idx++;
    }
    public int networkDelayTime(int[][] ts, int _n, int _k) {
        n = _n; k = _k;
        m = ts.length;
        // 初始化链表头
        Arrays.fill(he, -1);
        // 存图
        for (int[] t : ts) {
            int u = t[0], v = t[1], c = t[2];
            add(u, v, c);
        }
        // 最短路
        bf();
        // 遍历答案
        int ans = 0;
        for (int i = 1; i <= n; i++) {
            ans = Math.max(ans, dist[i]);
        }
        return ans > INF / 2 ? -1 : ans;
    }
    void bf() {
        // 起始先将所有的点标记为「距离为正无穷」
        Arrays.fill(dist, INF);
        // 只有起点最短距离为 0
        dist[k] = 0;
        // 迭代 n 次
        for (int p = 1; p <= n; p++) {
            int[] prev = dist.clone();
            // 每次都使用上一次迭代的结果,执行松弛操作
            for (int a = 1; a <= n; a++) {
                for (int i = he[a]; i != -1; i = ne[i]) {
                    int b = e[i];
                    dist[b] = Math.min(dist[b], prev[a] + w[i]);
                }
            }
        }
    }
}
  • 时间复杂度:$O(n*m)$
  • 空间复杂度:$O(m)$

SPFA(邻接表)

SPFA 是对 Bellman Ford 的优化实现,可以使用队列进行优化,也可以使用栈进行优化。

通常情况下复杂度为 $O(km)$,$k$ 一般为 $4$ 到 $5$,最坏情况下仍为 $O(n * m)$,当数据为网格图时,复杂度会从 $O(km)$ 退化为 $O(n*m)$。

image.png

代码:

###Java

class Solution {
    int N = 110, M = 6010;
    // 邻接表
    int[] he = new int[N], e = new int[M], ne = new int[M], w = new int[M];
    // dist[x] = y 代表从「源点/起点」到 x 的最短距离为 y
    int[] dist = new int[N];
    // 记录哪一个点「已在队列」中
    boolean[] vis = new boolean[N];
    int INF = 0x3f3f3f3f;
    int n, k, idx;
    void add(int a, int b, int c) {
        e[idx] = b;
        ne[idx] = he[a];
        he[a] = idx;
        w[idx] = c;
        idx++;
    }
    public int networkDelayTime(int[][] ts, int _n, int _k) {
        n = _n; k = _k;
        // 初始化链表头
        Arrays.fill(he, -1);
        // 存图
        for (int[] t : ts) {
            int u = t[0], v = t[1], c = t[2];
            add(u, v, c);
        }
        // 最短路
        spfa();
        // 遍历答案
        int ans = 0;
        for (int i = 1; i <= n; i++) {
            ans = Math.max(ans, dist[i]);
        }
        return ans > INF / 2 ? -1 : ans;
    }
    void spfa() {
        // 起始先将所有的点标记为「未入队」和「距离为正无穷」
        Arrays.fill(vis, false);
        Arrays.fill(dist, INF);
        // 只有起点最短距离为 0
        dist[k] = 0;
        // 使用「双端队列」存储,存储的是点编号
        Deque<Integer> d = new ArrayDeque<>();
        // 将「源点/起点」进行入队,并标记「已入队」
        d.addLast(k);
        vis[k] = true;
        while (!d.isEmpty()) {
            // 每次从「双端队列」中取出,并标记「未入队」
            int poll = d.pollFirst();
            vis[poll] = false;
            // 尝试使用该点,更新其他点的最短距离
            // 如果更新的点,本身「未入队」则加入队列中,并标记「已入队」
            for (int i = he[poll]; i != -1; i = ne[i]) {
                int j = e[i];
                if (dist[j] > dist[poll] + w[i]) {
                    dist[j] = dist[poll] + w[i];
                    if (vis[j]) continue;
                    d.addLast(j);
                    vis[j] = true;
                }
            }
        }
    }
}
  • 时间复杂度:$O(n*m)$
  • 空间复杂度:$O(m)$

最后

如果有帮助到你,请给题解点个赞和收藏,让更多的人看到 ~ ("▔□▔)/

也欢迎你 关注我(公主号后台回复「送书」即可参与看题解学算法送实体书长期活动)或 加入「组队打卡」小群 ,提供写「证明」&「思路」的高质量题解。

所有题解已经加入 刷题指南,欢迎 star 哦 ~

【GTAlgorithm】图解算法,吃透一个Dijkstra就够了!C++/Java/Python

作者 已注销
2021年8月2日 10:06

解法:最短路

  • 题目实际是求节点 $K$ 到其他所有点中最远的距离,那么首先需要求出节点 $K$ 到其他所有点的最短路,然后取最大值即可。

  • 单源最短路问题可以使用 Dijkstra 算法,其核心思路是贪心算法。流程如下:

    1. 首先,Dijkstra 算法需要从当前全部未确定最短路的点中,找到距离源点最短的点 $x$。

    2. 其次,通过点 $x$ 更新其他所有点距离源点的最短距离。例如目前点 A 距离源点最短,距离为 3;有一条 A->B 的有向边,权值为 1,那么从源点先去 A 点再去 B 点距离为 3 + 1 = 4,若原先从源点到 B 的有向边权值为 5,那么我们便可以更新 B 到源点的最短距离为 4

    3. 当全部其他点都遍历完成后,一次循环结束,将 $x$ 标记为已经确定最短路。进入下一轮循环,直到全部点被标记为确定了最短路。


我们通过一个[例子]对 Dijkstra 算法的流程深入了解一下:


以上图片为一个有向带权图,圆圈中为节点序号,箭头上为边权,右侧为所有点距离源点 0 的距离。
image.png
将顶点 0 进行标识,并作为点 $x$,更新其到其他所有点的距离。一轮循环结束。
image.png
image.png

将顶点 2 进行标识,并作为新的点 $x$,更新。我们看到,原本点 1 的最短距离为 5,被更新为了 3。同理还更新了点 3 和点 4 的最短距离。

image.png
image.png

将顶点 1 进行标识,并作为新的点 $x$,同样更新了点 4 到源点的最短距离。

image.png

再分别标识点 4 和点 3,循环结束。


  • 我们来看在实现时需要的代码支持:
    1. 首先,Dijkstra 算法需要存储各个边权,由于本题节点数量不超过 $100$,所以代码中使用了邻接矩阵 g[i][j] 存储从点 i 到点 j 的距离。若两点之间没有给出有向边,则初始化为 inf算法还需要记录所有点到源点的最短距离,代码中使用了 dist[i] 数组存储源点到点 i 的最短距离,初始值也全部设为 inf。由于本题源点为 $K$,所以该点距离设为 0
    2. 其次,Dijkstra 算法需要标记某一节点是否已确定了最短路,在代码中使用了 used[i] 数组存储,若已确定最短距离,则值为 true,否则值为 false
    3. 之所以 inf 设置为 INT_MAX / 2,是因为在更新最短距离的时候,要有两个距离相加,为了防止溢出 int 型,所以除以 2

代码(C++)

class Solution {
public:
    int networkDelayTime(vector<vector<int>> &times, int n, int k) {
        const int inf = INT_MAX / 2;

        // 邻接矩阵存储边信息
        vector<vector<int>> g(n, vector<int>(n, inf));
        for (auto &t : times) {
            // 边序号从 0 开始
            int x = t[0] - 1, y = t[1] - 1;
            g[x][y] = t[2];
        }

        // 从源点到某点的距离数组
        vector<int> dist(n, inf);
        // 由于从 k 开始,所以该点距离设为 0,也即源点
        dist[k - 1] = 0;

        // 节点是否被更新数组
        vector<bool> used(n);

        for (int i = 0; i < n; ++i) {
            // 在还未确定最短路的点中,寻找距离最小的点
            int x = -1;
            for (int y = 0; y < n; ++y) {
                if (!used[y] && (x == -1 || dist[y] < dist[x])) {
                    x = y;
                }
            }

            // 用该点更新所有其他点的距离
            used[x] = true;
            for (int y = 0; y < n; ++y) {
                dist[y] = min(dist[y], dist[x] + g[x][y]);
            }
        }

        // 找到距离最远的点
        int ans = *max_element(dist.begin(), dist.end());
        return ans == inf ? -1 : ans;
    }
};

class Solution {
    public int networkDelayTime(int[][] times, int n, int k) {
        final int INF = Integer.MAX_VALUE / 2;

        // 邻接矩阵存储边信息
        int[][] g = new int[n][n];
        for (int i = 0; i < n; ++i) {
            Arrays.fill(g[i], INF);
        }
        for (int[] t : times) {
            // 边序号从 0 开始
            int x = t[0] - 1, y = t[1] - 1;
            g[x][y] = t[2];
        }

        // 从源点到某点的距离数组
        int[] dist = new int[n];
        Arrays.fill(dist, INF);
        // 由于从 k 开始,所以该点距离设为 0,也即源点
        dist[k - 1] = 0;

        // 节点是否被更新数组
        boolean[] used = new boolean[n];

        for (int i = 0; i < n; ++i) {
            // 在还未确定最短路的点中,寻找距离最小的点
            int x = -1;
            for (int y = 0; y < n; ++y) {
                if (!used[y] && (x == -1 || dist[y] < dist[x])) {
                    x = y;
                }
            }

            // 用该点更新所有其他点的距离
            used[x] = true;
            for (int y = 0; y < n; ++y) {
                dist[y] = Math.min(dist[y], dist[x] + g[x][y]);
            }
        }

        // 找到距离最远的点
        int ans = Arrays.stream(dist).max().getAsInt();
        return ans == INF ? -1 : ans;
    }
}

class Solution:
    def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
        # 邻接矩阵
        g = [[float('inf')] * n for _ in range(n)]
        for x, y, time in times:
            g[x - 1][y - 1] = time

        # 距离数组
        dist = [float('inf')] * n
        dist[k - 1] = 0

        # 标记数组
        used = [False] * n
        for _ in range(n):
            # 找到未标记最近的点
            x = -1
            for y, u in enumerate(used):
                if not u and (x == -1 or dist[y] < dist[x]):
                    x = y
            
            # 更新
            used[x] = True
            for y, time in enumerate(g[x]):
                dist[y] = min(dist[y], dist[x] + time)

        ans = max(dist)
        return ans if ans < float('inf') else -1

  • 时间复杂度:$O(n^2+m)$,其中 $n, m$ 分别为节点和边的数量。如果用优先队列存储,能够将复杂度降为 $O(mlgm)$。
  • 空间复杂度:$O(n^2)$,利用了邻接矩阵。
昨天以前LeetCode 每日一题题解

每日一题-最小区间🔴

2024年11月24日 00:00

你有 k 个 非递减排列 的整数列表。找到一个 最小 区间,使得 k 个列表中的每个列表至少有一个数包含在其中。

我们定义如果 b-a < d-c 或者在 b-a == d-c 时 a < c,则区间 [a,b][c,d] 小。

 

示例 1:

输入:nums = [[4,10,15,24,26], [0,9,12,20], [5,18,22,30]]
输出:[20,24]
解释: 
列表 1:[4, 10, 15, 24, 26],24 在区间 [20,24] 中。
列表 2:[0, 9, 12, 20],20 在区间 [20,24] 中。
列表 3:[5, 18, 22, 30],22 在区间 [20,24] 中。

示例 2:

输入:nums = [[1,2,3],[1,2,3],[1,2,3]]
输出:[1,1]

 

提示:

  • nums.length == k
  • 1 <= k <= 3500
  • 1 <= nums[i].length <= 50
  • -105 <= nums[i][j] <= 105
  • nums[i] 按非递减顺序排列

 

两种方法:堆/排序+滑动窗口(Python/Java/C++/Go/JS/Rust)

作者 endlesscheng
2024年11月9日 12:08

核心思路

把「每个列表至少有一个数包含在其中」的区间叫做合法区间

先求出最左边的合法区间,然后求出第二个合法区间,第三个合法区间,依此类推。

比如示例 1,最左边的合法区间是 $[0,5]$。

枚举所有合法区间的左端点,或者枚举所有合法区间的右端点。其中第一个最短的合法区间就是答案。

方法一:堆

在示例 1 中,有三个列表:

  • $[4,10,15,24,26]$。
  • $[0,9,12,20]$。
  • $[5,18,22,30]$。

我们来计算最左边的合法区间,第二个合法区间,第三个合法区间,……

也就是左端点为 $0$ 的合法区间,左端点为 $4$ 的合法区间,左端点为 $5$ 的合法区间。

求出左端点对应的右端点,就知道了区间的长度,其中第一个最短的区间就是答案。

左端点为 $0$ 的合法区间,右端点是这三个列表的第一个元素的最大值,即 $5$。

接下来,去掉 $0$,列表 $[0,9,12,20]$ 变成 $[9,12,20]$,问题变成如下三个列表:

  • $[4,10,15,24,26]$。
  • $[9,12,20]$。
  • $[5,18,22,30]$。

这三个列表的最左边的合法区间是什么?

左端点是这三个列表的第一个元素的最小值 $4$,右端点是这三个列表的第一个元素的最大值 $9$,所以合法区间为 $[4,9]$。

接下来,去掉 $4$,列表 $[4,10,15,24,26]$ 变成 $[10,15,24,26]$,重复上述过程。

在上述过程中,需要快速地求出合法区间的左端点和右端点:

  • 左端点:需要一个数据结构,支持添加元素、计算最小值、删除最小值。这可以用最小堆维护。堆顶(最小元素)就是左端点。
  • 右端点:记作 $r$。一开始 $r$ 为每个列表第一个元素的最大值。当我们去掉列表第一个元素时,就用列表的第二个元素更新 $r$ 的最大值。依此类推。

注:实际没有去掉元素,而是用下标表示元素在列表中的位置。

细节

  1. 为方便计算列表的下一个元素,需要在堆中额外保存元素属于哪个列表,以及在这个列表中的位置。所以堆中的每个元素是一个三元组,即 $(\textit{nums}[i][j],i,j)$,这样列表的下一个元素就是 $\textit{nums}[i][j+1]$。
  2. 如果列表没有下一个元素($j+1$ 等于列表长度),那么去掉当前元素后,将不会有合法区间包含这个列表的元素,算法结束。

###py

class Solution:
    def smallestRange(self, nums: List[List[int]]) -> List[int]:
        # 把每个列表的第一个元素入堆
        h = [(arr[0], i, 0) for i, arr in enumerate(nums)]
        heapify(h)

        ans_l = h[0][0]  # 第一个合法区间的左端点
        ans_r = r = max(arr[0] for arr in nums)  # 第一个合法区间的右端点
        while h[0][2] + 1 < len(nums[h[0][1]]):  # 堆顶列表有下一个元素
            _, i, j = h[0]
            x = nums[i][j + 1]  # 堆顶列表的下一个元素
            heapreplace(h, (x, i, j + 1))  # 替换堆顶
            r = max(r, x)  # 更新合法区间的右端点
            l = h[0][0]  # 当前合法区间的左端点
            if r - l < ans_r - ans_l:
                ans_l, ans_r = l, r
        return [ans_l, ans_r]

###java

class Solution {
    public int[] smallestRange(List<List<Integer>> nums) {
        PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> a[0] - b[0]);
        int r = Integer.MIN_VALUE;
        for (int i = 0; i < nums.size(); i++) {
            // 把每个列表的第一个元素入堆
            int x = nums.get(i).get(0);
            pq.offer(new int[]{x, i, 0});
            r = Math.max(r, x);
        }

        int ansL = pq.peek()[0]; // 第一个合法区间的左端点
        int ansR = r; // 第一个合法区间的右端点
        while (pq.peek()[2] + 1 < nums.get(pq.peek()[1]).size()) { // 堆顶列表有下一个元素
            int[] top = pq.poll();
            top[0] = nums.get(top[1]).get(++top[2]); // 堆顶列表的下一个元素
            r = Math.max(r, top[0]); // 更新合法区间的右端点
            pq.offer(top); // 入堆(复用 int[],提高效率)
            int l = pq.peek()[0]; // 当前合法区间的左端点
            if (r - l < ansR - ansL) {
                ansL = l;
                ansR = r;
            }
        }
        return new int[]{ansL, ansR};
    }
}

###cpp

class Solution {
public:
    vector<int> smallestRange(vector<vector<int>>& nums) {
        priority_queue<tuple<int, int, int>, vector<tuple<int, int, int>>, greater<>> pq;
        int r = INT_MIN;
        for (int i = 0; i < nums.size(); i++) {
            pq.emplace(nums[i][0], i, 0); // 把每个列表的第一个元素入堆
            r = max(r, nums[i][0]);
        }

        int ans_l = get<0>(pq.top()); // 第一个合法区间的左端点
        int ans_r = r; // 第一个合法区间的右端点
        while (true) {
            auto [_, i, j] = pq.top();
            if (j + 1 == nums[i].size()) { // 堆顶列表没有下一个元素
                break;
            }
            pq.pop();
            int x = nums[i][j + 1]; // 堆顶列表的下一个元素
            pq.emplace(x, i, j + 1); // 入堆
            r = max(r, x); // 更新合法区间的右端点
            int l = get<0>(pq.top()); // 当前合法区间的左端点
            if (r - l < ans_r - ans_l) {
                ans_l = l;
                ans_r = r;
            }
        }
        return {ans_l, ans_r};
    }
};

###go

func smallestRange(nums [][]int) []int {
    h := make(hp, len(nums))
    r := math.MinInt
    for i, arr := range nums {
        h[i] = tuple{arr[0], i, 0} // 把每个列表的第一个元素入堆
        r = max(r, arr[0])
    }
    heap.Init(&h)

    ansL, ansR := h[0].x, r // 第一个合法区间的左右端点
    for h[0].j+1 < len(nums[h[0].i]) { // 堆顶列表有下一个元素
        x := nums[h[0].i][h[0].j+1] // 堆顶列表的下一个元素
        r = max(r, x) // 更新合法区间的右端点
        h[0].x = x // 替换堆顶
        h[0].j++
        heap.Fix(&h, 0)
        l := h[0].x // 当前合法区间的左端点
        if r-l < ansR-ansL {
            ansL, ansR = l, r
        }
    }
    return []int{ansL, ansR}
}

type tuple struct{ x, i, j int }
type hp []tuple
func (h hp) Len() int           { return len(h) }
func (h hp) Less(i, j int) bool { return h[i].x < h[j].x }
func (h hp) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (hp) Push(any)             {} // 没用到,可以不写
func (hp) Pop() (_ any)         { return }

###js

var smallestRange = function(nums) {
    const pq = new MinPriorityQueue({priority: a => a[0]});
    let r = -Infinity;
    for (let i = 0; i < nums.length; i++) {
        pq.enqueue([nums[i][0], i, 0]); // 每个列表的第一个元素入堆
        r = Math.max(r, nums[i][0]);
    }

    let ansL = pq.front().element[0]; // 第一个合法区间的左端点
    let ansR = r; // 第一个合法区间的右端点
    while (true) {
        const [_, i, j] = pq.dequeue().element;
        if (j + 1 === nums[i].length) { // 堆顶列表没有下一个元素
            break;
        }
        const x = nums[i][j + 1]; // 堆顶列表的下一个元素
        pq.enqueue([x, i, j + 1]); // 入堆
        r = Math.max(r, x); // 更新合法区间的右端点
        const l = pq.front().element[0]; // 当前合法区间的左端点
        if (r - l < ansR - ansL) {
            ansL = l;
            ansR = r;
        }
    }
    return [ansL, ansR];
};

###rust

use std::collections::BinaryHeap;

impl Solution {
    pub fn smallest_range(nums: Vec<Vec<i32>>) -> Vec<i32> {
        let mut h = BinaryHeap::with_capacity(nums.len()); // 预分配空间
        let mut r = i32::MIN;
        for (i, arr) in nums.iter().enumerate() {
            // 把每个列表的第一个元素入堆
            h.push((-arr[0], i, 0)); // 取反变成最小堆
            r = r.max(arr[0]);
        }

        let mut ans_l = -h.peek().unwrap().0; // 第一个合法区间的左端点
        let mut ans_r = r; // 第一个合法区间的右端点
        while h.peek().unwrap().2 + 1 < nums[h.peek().unwrap().1].len() { // 堆顶列表有下一个元素
            let (_, i, j) = h.pop().unwrap();
            let x = nums[i][j + 1]; // 堆顶列表的下一个元素
            h.push((-x, i, j + 1)); // 入堆
            r = r.max(x); // 更新合法区间的右端点
            let l = -h.peek().unwrap().0; // 当前合法区间的左端点
            if r - l < ans_r - ans_l {
                ans_l = l;
                ans_r = r;
            }
        }
        vec![ans_l, ans_r]
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(L\log k)$,其中 $k$ 是 $\textit{nums}$ 的长度,$L$ 是所有 $\textit{nums}[i]$ 的长度之和。循环 $\mathcal{O}(L)$ 次,每次循环需要 $\mathcal{O}(\log k)$ 的时间操作堆。
  • 空间复杂度:$\mathcal{O}(k)$。堆需要 $\mathcal{O}(k)$ 的空间保存元素。

方法二:排序+滑动窗口

对于示例 1 的这三个列表:

  • $[4,10,15,24,26]$。
  • $[0,9,12,20]$。
  • $[5,18,22,30]$。

把所有元素都合在一起排序,可以得到如下结果:

$$
\begin{array}{r|}
元素值 & 0 & 4 & 5 & 9 & 10 & 12 & 15 & 18 & 20 & 22 & 24 & 26 & 30 \
所属列表编号 & 1 & 0 & 2 & 1 & 0 & 1 & 0 & 2 & 1 & 2 & 0 & 0 & 2 \
\end{array}
$$

把上表视作一个由(元素值,所属列表编号)组成的数组,即

$$
\textit{pairs} = [(0, 1), (4, 0), (5, 2), \ldots, (24, 0), (26, 0), (30, 2)]
$$

合法区间等价于 $\textit{pairs}$ 的一个连续子数组,满足列表编号 $0,1,2,\ldots,k-1$ 都在这个子数组中。

由于子数组越长,越能包含 $0,1,2,\ldots,k-1$ 所有编号,有单调性,可以用滑动窗口解决。如果你不了解滑动窗口,可以看视频【基础算法精讲 03】

细节

  1. 用一个长为 $k$ 的数组 $\textit{cnt}$ 统计窗口中的每个编号的出现次数。
  2. 判断是否包含所有编号,简单的方法是判断所有 $\textit{cnt}[i]$ 是否都大于 $0$。为了加快判断速度,可以仿照 76. 最小覆盖子串 的思路,额外用一个变量 $\textit{empty}$ 表示 $\textit{cnt}[i]=0$ 的列表个数。编号 $i$ 进入窗口前,如果 $\textit{cnt}[i]=0$,那么 $\textit{empty}$ 减一;编号 $i$ 离开窗口后,如果 $\textit{cnt}[i]=0$,那么 $\textit{empty}$ 加一。这样只需要判断 $\textit{empty}$ 是否为 $0$,就知道所有 $\textit{cnt}[i]$ 是否都大于 $0$ 了。

注:方法一相当于枚举合法区间的左端点,而方法二相当于枚举合法区间的右端点。

###py

class Solution:
    def smallestRange(self, nums: List[List[int]]) -> List[int]:
        pairs = sorted((x, i) for (i, arr) in enumerate(nums) for x in arr)
        ans_l, ans_r = -inf, inf
        empty = len(nums)
        cnt = [0] * empty
        left = 0
        for r, i in pairs:
            if cnt[i] == 0:  # 包含 nums[i] 的数字
                empty -= 1
            cnt[i] += 1
            while empty == 0:  # 每个列表都至少包含一个数
                l, i = pairs[left]
                if r - l < ans_r - ans_l:
                    ans_l, ans_r = l, r
                cnt[i] -= 1
                if cnt[i] == 0:  # 不包含 nums[i] 的数字
                    empty += 1
                left += 1
        return [ans_l, ans_r]

###java

class Solution {
    public int[] smallestRange(List<List<Integer>> nums) {
        int sumLen = 0;
        for (List<Integer> list : nums) {
            sumLen += list.size();
        }

        int[][] pairs = new int[sumLen][2];
        int pi = 0;
        for (int i = 0; i < nums.size(); i++) {
            for (int x : nums.get(i)) {
                pairs[pi][0] = x;
                pairs[pi++][1] = i;
            }
        }
        Arrays.sort(pairs, (a, b) -> a[0] - b[0]);

        int ansL = pairs[0][0];
        int ansR = pairs[sumLen - 1][0];
        int empty = nums.size();
        int[] cnt = new int[empty];
        int left = 0;
        for (int[] p : pairs) {
            int r = p[0];
            int i = p[1];
            if (cnt[i] == 0) { // 包含 nums[i] 的数字
                empty--;
            }
            cnt[i]++;
            while (empty == 0) { // 每个列表都至少包含一个数
                int l = pairs[left][0];
                if (r - l < ansR - ansL) {
                    ansL = l;
                    ansR = r;
                }
                i = pairs[left][1];
                cnt[i]--;
                if (cnt[i] == 0) { // 不包含 nums[i] 的数字
                    empty++;
                }
                left++;
            }
        }
        return new int[]{ansL, ansR};
    }
}

###cpp

class Solution {
public:
    vector<int> smallestRange(vector<vector<int>>& nums) {
        vector<pair<int, int>> pairs;
        for (int i = 0; i < nums.size(); i++) {
            for (int x : nums[i]) {
                pairs.emplace_back(x, i);
            }
        }
        // 看上去 std::sort 比 ranges::sort 更快
        sort(pairs.begin(), pairs.end());

        int ans_l = pairs[0].first;
        int ans_r = pairs.back().first;
        int empty = nums.size();
        vector<int> cnt(empty);
        int left = 0;
        for (auto [r, i] : pairs) {
            if (cnt[i] == 0) { // 包含 nums[i] 的数字
                empty--;
            }
            cnt[i]++;
            while (empty == 0) { // 每个列表都至少包含一个数
                auto [l, i] = pairs[left];
                if (r - l < ans_r - ans_l) {
                    ans_l = l;
                    ans_r = r;
                }
                cnt[i]--;
                if (cnt[i] == 0) { // 不包含 nums[i] 的数字
                    empty++;
                }
                left++;
            }
        }
        return {ans_l, ans_r};
    }
};

###go

func smallestRange(nums [][]int) []int {
    type pair struct{ x, i int }
    pairs := []pair{}
    for i, arr := range nums {
        for _, x := range arr {
            pairs = append(pairs, pair{x, i})
        }
    }
    slices.SortFunc(pairs, func(a, b pair) int { return a.x - b.x })

    ansL, ansR := pairs[0].x, pairs[len(pairs)-1].x
    empty := len(nums)
    cnt := make([]int, empty)
    left := 0
    for _, p := range pairs {
        r, i := p.x, p.i
        if cnt[i] == 0 { // 包含 nums[i] 的数字
            empty--
        }
        cnt[i]++
        for empty == 0 { // 每个列表都至少包含一个数
            l, i := pairs[left].x, pairs[left].i
            if r-l < ansR-ansL {
                ansL, ansR = l, r
            }
            cnt[i]--
            if cnt[i] == 0 {
                // 不包含 nums[i] 的数字
                empty++
            }
            left++
        }
    }
    return []int{ansL, ansR}
}

###js

var smallestRange = function(nums) {
    const pairs = [];
    for (let i = 0; i < nums.length; i++) {
        for (const x of nums[i]) {
            pairs.push([x, i]);
        }
    }
    pairs.sort((a, b) => a[0] - b[0]);

    let ansL = -Infinity, ansR = Infinity;
    let empty = nums.length;
    const cnt = Array(empty).fill(0);
    let left = 0;
    for (const [r, i] of pairs) {
        if (cnt[i] === 0) { // 包含 nums[i] 的数字
            empty--;
        }
        cnt[i]++;
        while (empty === 0) { // 每个列表都至少包含一个数
            const [l, i] = pairs[left];
            if (r - l < ansR - ansL) {
                ansL = l;
                ansR = r;
            }
            cnt[i]--;
            if (cnt[i] === 0) { // 不包含 nums[i] 的数字
                empty++;
            }
            left++;
        }
    }
    return [ansL, ansR];
};

###rust

impl Solution {
    pub fn smallest_range(nums: Vec<Vec<i32>>) -> Vec<i32> {
        let mut pairs = vec![];
        for (i, arr) in nums.iter().enumerate() {
            for &x in arr {
                pairs.push((x, i));
            }
        }
        pairs.sort_unstable_by(|a, b| a.0.cmp(&b.0));

        let mut ans_l = pairs[0].0;
        let mut ans_r = pairs[pairs.len() - 1].0;
        let mut empty = nums.len();
        let mut cnt = vec![0; empty];
        let mut left = 0;
        for &(r, i) in &pairs {
            if cnt[i] == 0 { // 包含 nums[i] 的数字
                empty -= 1;
            }
            cnt[i] += 1;
            while empty == 0 { // 每个列表都至少包含一个数
                let (l, i) = pairs[left];
                if r - l < ans_r - ans_l {
                    ans_l = l;
                    ans_r = r;
                }
                cnt[i] -= 1;
                if cnt[i] == 0 { // 不包含 nums[i] 的数字
                    empty += 1;
                }
                left += 1;
            }
        }
        vec![ans_l, ans_r]
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(L\log L)$,其中 $L$ 是所有 $\textit{nums}[i]$ 的长度之和。瓶颈在排序上。
  • 空间复杂度:$\mathcal{O}(L)$。

更多相似题目,见下面数据结构题单中的「五、堆(优先队列)」,以及滑动窗口题单中的「§2.2 求最短/最小」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/最短路/最小生成树/二分图/基环树/欧拉路径)
  7. 动态规划(入门/背包/状态机/划分/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与一般树(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

最小区间

2020年7月31日 22:43

方法一:贪心 + 最小堆

给定 $k$ 个列表,需要找到最小区间,使得每个列表都至少有一个数在该区间中。该问题可以转化为,从 $k$ 个列表中各取一个数,使得这 $k$ 个数中的最大值与最小值的差最小。

假设这 $k$ 个数中的最小值是第 $i$ 个列表中的 $x$,对于任意 $j \ne i$,设第 $j$ 个列表中被选为 $k$ 个数之一的数是 $y$,则为了找到最小区间,$y$ 应该取第 $j$ 个列表中大于等于 $x$ 的最小的数,这是一个贪心的策略。贪心策略的正确性简单证明如下:假设 $z$ 也是第 $j$ 个列表中的数,且 $z>y$,则有 $z-x>y-x$,同时包含 $x$ 和 $z$ 的区间一定不会小于同时包含 $x$ 和 $y$ 的区间。因此,其余 $k-1$ 个列表中应该取大于等于 $x$ 的最小的数。

由于 $k$ 个列表都是升序排列的,因此对每个列表维护一个指针,通过指针得到列表中的元素,指针右移之后指向的元素一定大于或等于之前的元素。

使用最小堆维护 $k$ 个指针指向的元素中的最小值,同时维护堆中元素的最大值。初始时,$k$ 个指针都指向下标 $0$,最大元素即为所有列表的下标 $0$ 位置的元素中的最大值。每次从堆中取出最小值,根据最大值和最小值计算当前区间,如果当前区间小于最小区间则用当前区间更新最小区间,然后将对应列表的指针右移,将新元素加入堆中,并更新堆中元素的最大值。

如果一个列表的指针超出该列表的下标范围,则说明该列表中的所有元素都被遍历过,堆中不会再有该列表中的元素,因此退出循环。

###Java

class Solution {
    public int[] smallestRange(List<List<Integer>> nums) {
        int rangeLeft = 0, rangeRight = Integer.MAX_VALUE;
        int minRange = rangeRight - rangeLeft;
        int max = Integer.MIN_VALUE;
        int size = nums.size();
        int[] next = new int[size];
        PriorityQueue<Integer> priorityQueue = new PriorityQueue<Integer>(new Comparator<Integer>() {
            public int compare(Integer index1, Integer index2) {
                return nums.get(index1).get(next[index1]) - nums.get(index2).get(next[index2]);
            }
        });
        for (int i = 0; i < size; i++) {
            priorityQueue.offer(i);
            max = Math.max(max, nums.get(i).get(0));
        }
        while (true) {
            int minIndex = priorityQueue.poll();
            int curRange = max - nums.get(minIndex).get(next[minIndex]);
            if (curRange < minRange) {
                minRange = curRange;
                rangeLeft = nums.get(minIndex).get(next[minIndex]);
                rangeRight = max;
            }
            next[minIndex]++;
            if (next[minIndex] == nums.get(minIndex).size()) {
                break;
            }
            priorityQueue.offer(minIndex);
            max = Math.max(max, nums.get(minIndex).get(next[minIndex]));
        }
        return new int[]{rangeLeft, rangeRight};
    }
}

###cpp

class Solution {
public:
    vector<int> smallestRange(vector<vector<int>>& nums) {
        int rangeLeft = 0, rangeRight = INT_MAX;
        int size = nums.size();
        vector<int> next(size);
        
        auto cmp = [&](const int& u, const int& v) {
            return nums[u][next[u]] > nums[v][next[v]];
        };
        priority_queue<int, vector<int>, decltype(cmp)> pq(cmp);
        int minValue = 0, maxValue = INT_MIN;
        for (int i = 0; i < size; ++i) {
            pq.emplace(i);
            maxValue = max(maxValue, nums[i][0]);
        }

        while (true) {
            int row = pq.top();
            pq.pop();
            minValue = nums[row][next[row]];
            if (maxValue - minValue < rangeRight - rangeLeft) {
                rangeLeft = minValue;
                rangeRight = maxValue;
            }
            if (next[row] == nums[row].size() - 1) {
                break;
            }
            ++next[row];
            maxValue = max(maxValue, nums[row][next[row]]);
            pq.emplace(row);
        }

        return {rangeLeft, rangeRight};
    }
};

###golang

var (
    next []int
    numsC [][]int
)

func smallestRange(nums [][]int) []int {
    numsC = nums
    rangeLeft, rangeRight := 0, math.MaxInt32
    minRange := rangeRight - rangeLeft
    max := math.MinInt32
    size := len(nums)
    next = make([]int, size)
    h := &IHeap{}
    heap.Init(h)

    for i := 0; i < size; i++ {
        heap.Push(h, i)
        max = Max(max, nums[i][0])
    }

    for {
        minIndex := heap.Pop(h).(int)
        curRange := max - nums[minIndex][next[minIndex]]
        if curRange < minRange {
            minRange = curRange
            rangeLeft, rangeRight = nums[minIndex][next[minIndex]], max
        }
        next[minIndex]++
        if next[minIndex] == len(nums[minIndex]) {
            break
        }
        heap.Push(h, minIndex)
        max = Max(max, nums[minIndex][next[minIndex]])
    }
    return []int{rangeLeft, rangeRight}
}

type IHeap []int

func (h IHeap) Len() int           { return len(h) }
func (h IHeap) Less(i, j int) bool { return numsC[h[i]][next[h[i]]] < numsC[h[j]][next[h[j]]] }
func (h IHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }

func (h *IHeap) Push(x interface{}) {
    *h = append(*h, x.(int))
}

func (h *IHeap) Pop() interface{} {
    old := *h
    n := len(old)
    x := old[n-1]
    *h = old[0 : n-1]
    return x
}

func Max(x, y int) int {
    if x > y {
        return x
    }
    return y
}

###Python

class Solution:
    def smallestRange(self, nums: List[List[int]]) -> List[int]:
        rangeLeft, rangeRight = -10**9, 10**9
        maxValue = max(vec[0] for vec in nums)
        priorityQueue = [(vec[0], i, 0) for i, vec in enumerate(nums)]
        heapq.heapify(priorityQueue)

        while True:
            minValue, row, idx = heapq.heappop(priorityQueue)
            if maxValue - minValue < rangeRight - rangeLeft:
                rangeLeft, rangeRight = minValue, maxValue
            if idx == len(nums[row]) - 1:
                break
            maxValue = max(maxValue, nums[row][idx + 1])
            heapq.heappush(priorityQueue, (nums[row][idx + 1], row, idx + 1))
        
        return [rangeLeft, rangeRight]

###C

#define maxn 100005

int heap[maxn];
int heap_count;
int **rec, *nx;

bool heap_comp(int *first, int *second) {
    return rec[*first][nx[*first]] < rec[*second][nx[*second]];
}

void swap(int *first, int *second) {
    int temp = *second;
    *second = *first;
    *first = temp;
    return;
}

void push(int num) {
    int pos = ++heap_count;
    heap[pos] = num;
    while (pos > 1) {
        if (heap_comp(&heap[pos], &heap[pos >> 1])) {
            swap(&heap[pos], &heap[pos >> 1]);
            pos >>= 1;
        } else
            break;
    }
    return;
}

void pop() {
    int top_num = 1;
    int now;
    swap(&heap[top_num], &heap[heap_count--]);
    while ((now = (top_num << 1)) <= heap_count) {
        if (heap_comp(&heap[now + 1], &heap[now]) && now < heap_count) now++;
        if (heap_comp(&heap[now], &heap[top_num])) {
            swap(&heap[top_num], &heap[now]);
            top_num = now;
        } else
            break;
    }
}

int top() { return heap[1]; }

int *smallestRange(int **nums, int numsSize, int *numsColSize,
                   int *returnSize) {
    heap_count = 0;
    nx = (int *)malloc(sizeof(int) * numsSize);
    memset(nx, 0, sizeof(int) * numsSize);
    rec = nums;

    int rangeLeft = 0, rangeRight = 2147483647;
    int minValue = 0, maxValue = -2147483648;
    for (int i = 0; i < numsSize; ++i) {
        push(i);
        maxValue = fmax(maxValue, nums[i][0]);
    }

    while (true) {
        int row = top();
        pop();
        minValue = nums[row][nx[row]];
        if (maxValue - minValue < rangeRight - rangeLeft) {
            rangeLeft = minValue;
            rangeRight = maxValue;
        }
        if (nx[row] == numsColSize[row] - 1) {
            break;
        }
        ++nx[row];
        maxValue = fmax(maxValue, nums[row][nx[row]]);
        push(row);
    }
    int *ret = malloc(sizeof(int) * 2);
    ret[0] = rangeLeft, ret[1] = rangeRight;
    *returnSize = 2;
    return ret;
}

###JavaScript

var smallestRange = function(nums) {
    let rangeLeft = 0, rangeRight = Number.MAX_SAFE_INTEGER;
    const size = nums.length;
    const next = new Array(size).fill(0);
    const pq = new MinPriorityQueue();
    let minValue = 0, maxValue = Number.MIN_SAFE_INTEGER;

    for (let i = 0; i < size; ++i) {
        pq.enqueue(i, nums[i][next[i]]);
        maxValue = Math.max(maxValue, nums[i][0]);
    }

    while (true) {
        const row = pq.dequeue().element;
        minValue = nums[row][next[row]];
        if (maxValue - minValue < rangeRight - rangeLeft) {
            rangeLeft = minValue;
            rangeRight = maxValue;
        }
        if (next[row] === nums[row].length - 1) {
            break;
        }
        ++next[row];
        maxValue = Math.max(maxValue, nums[row][next[row]]);
        pq.enqueue(row, nums[row][next[row]]);
    }

    return [rangeLeft, rangeRight];
};

###TypeScript

function smallestRange(nums: number[][]): number[] {
    let rangeLeft = 0, rangeRight = Number.MAX_SAFE_INTEGER;
    const size = nums.length;
    const next: number[] = new Array(size).fill(0);

    const pq = new MinPriorityQueue();
    let minValue = 0, maxValue = Number.MIN_SAFE_INTEGER;

    for (let i = 0; i < size; ++i) {
        pq.enqueue(i, nums[i][next[i]]);
        maxValue = Math.max(maxValue, nums[i][0]);
    }

    while (true) {
        const row = pq.dequeue().element;
        minValue = nums[row][next[row]];
        if (maxValue - minValue < rangeRight - rangeLeft) {
            rangeLeft = minValue;
            rangeRight = maxValue;
        }
        if (next[row] === nums[row].length - 1) {
            break;
        }
        ++next[row];
        maxValue = Math.max(maxValue, nums[row][next[row]]);
        pq.enqueue(row, nums[row][next[row]]);
    }

    return [rangeLeft, rangeRight];
};

###Rust

use std::cmp::Ordering;
use std::collections::BinaryHeap;

impl Solution {
    pub fn smallest_range(nums: Vec<Vec<i32>>) -> Vec<i32> {
        let mut range_left = 0;
        let mut range_right = i32::MAX;
        let size = nums.len();
        let mut next = vec![0; size];
        let mut max_value = i32::MIN;
        let mut pq = BinaryHeap::new();

        for i in 0..size {
            max_value = max_value.max(nums[i][0]);
            pq.push(std::cmp::Reverse((nums[i][0], i)));
        }

        while let Some(std::cmp::Reverse((min_value, row))) = pq.pop() {
            if max_value - min_value < range_right - range_left {
                range_left = min_value;
                range_right = max_value;
            }
            if next[row] == nums[row].len() - 1 {
                break;
            }
            next[row] += 1;
            max_value = max_value.max(nums[row][next[row]]);
            pq.push(std::cmp::Reverse((nums[row][next[row]], row)));
        }

        vec![range_left, range_right]
    }
}

复杂度分析

  • 时间复杂度:$O(nk \log k)$,其中 $n$ 是所有列表的平均长度,$k$ 是列表数量。所有的指针移动的总次数最多是 $nk$ 次,每次从堆中取出元素和添加元素都需要更新堆,时间复杂度是 $O(\log k)$,因此总时间复杂度是 $O(nk \log k)$。

  • 空间复杂度:$O(k)$,其中 $k$ 是列表数量。空间复杂度取决于堆的大小,堆中维护 $k$ 个元素。

方法二:哈希表 + 滑动窗口

思路

在讲这个方法之前我们先思考这样一个问题:有一个序列 $A = { a_1, a_2, \cdots, a_n }$ 和一个序列 $B = {b_1, b_2, \cdots, b_m}$,请找出一个 $B$ 中的一个最小的区间,使得在这个区间中 $A$ 序列的每个数字至少出现一次,请注意 $A$ 中的元素可能重复,也就是说如果 $A$ 中有 $p$ 个 $u$,那么你选择的这个区间中 $u$ 的个数一定不少于 $p$。没错,这就是我们五月份的一道打卡题:「76. 最小覆盖子串」。官方题解使用了一种滑动窗口的方法,遍历整个 $B$ 序列并用一个哈希表表示当前窗口中的元素:

  • 右边界在每次遍历到新元素的时候右移,同时将拓展到的新元素加入哈希表;
  • 左边界右移当且仅当当前区间为一个合法的答案区间,即当前窗口内的元素包含 $A$ 中所有的元素,同时将原来左边界指向的元素从哈希表中移除;
  • 答案更新当且仅当当前窗口内的元素包含 $A$ 中所有的元素。

如果这个地方不理解,可以参考「76. 最小覆盖子串的官方题解」。

回到这道题,我们发现这两道题的相似之处在于都要求我们找到某个符合条件的最小区间,我们可以借鉴「76. 最小覆盖子串」的做法:这里序列 ${ 0, 1, \cdots , k - 1 }$ 就是上面描述的 $A$ 序列,即 $k$ 个列表,我们需要在一个 $B$ 序列当中找到一个区间,可以覆盖 $A$ 序列。这里的 $B$ 序列是什么?我们可以用一个哈希映射来表示 $B$ 序列—— $B[i]$ 表示 $i$ 在哪些列表当中出现过,这里哈希映射的键是一个整数,表示列表中的某个数值,哈希映射的值是一个数组,这个数组里的元素代表当前的键出现在哪些列表里。也许文字表述比较抽象,大家可以结合下面这个例子来理解。

  • 如果列表集合为:
    0: [-1, 2, 3]
    1: [1]
    2: [1, 2]
    3: [1, 1, 3]
    
  • 那么可以得到这样一个哈希映射
    -1: [0]
     1: [1, 2, 3, 3]
     2: [0, 2]
     3: [0, 3]
    

我们得到的这个哈希映射就是这里的 $B$ 序列。我们要做的就是在 $B$ 序列上使用两个指针维护一个滑动窗口,并用一个哈希表维护当前窗口中已经包含了哪些列表中的元素,记录它们的索引。遍历 $B$ 序列的每一个元素:

  • 指向窗口右边界的指针右移当且仅当每次遍历到新的元素,并将这个新的元素对应的值数组中的每一个数加入到哈希表中;
  • 指向窗口左边界的指针右移当且仅当当前区间内的元素包含 $A$ 中所有的元素,同时将原来左边界对应的值数组的元素们从哈希表中移除;
  • 答案更新当且仅当当前窗口内的元素包含 $A$ 中所有的元素。

大家可以参考代码理解这个过程。

代码

###Java

class Solution {
    public int[] smallestRange(List<List<Integer>> nums) {
        int size = nums.size();
        Map<Integer, List<Integer>> indices = new HashMap<Integer, List<Integer>>();
        int xMin = Integer.MAX_VALUE, xMax = Integer.MIN_VALUE;
        for (int i = 0; i < size; i++) {
            for (int x : nums.get(i)) {
                List<Integer> list = indices.getOrDefault(x, new ArrayList<Integer>());
                list.add(i);
                indices.put(x, list);
                xMin = Math.min(xMin, x);
                xMax = Math.max(xMax, x);
            }
        }

        int[] freq = new int[size];
        int inside = 0;
        int left = xMin, right = xMin - 1;
        int bestLeft = xMin, bestRight = xMax;

        while (right < xMax) {
            right++;
            if (indices.containsKey(right)) {
                for (int x : indices.get(right)) {
                    freq[x]++;
                    if (freq[x] == 1) {
                        inside++;
                    }
                }
                while (inside == size) {
                    if (right - left < bestRight - bestLeft) {
                        bestLeft = left;
                        bestRight = right;
                    }
                    if (indices.containsKey(left)) {
                        for (int x: indices.get(left)) {
                            freq[x]--;
                            if (freq[x] == 0) {
                                inside--;
                            }
                        }
                    }
                    left++;
                }
            }
        }

        return new int[]{bestLeft, bestRight};
    }
}

###C++

class Solution {
public:
    vector<int> smallestRange(vector<vector<int>>& nums) {
        int n = nums.size();
        unordered_map<int, vector<int>> indices;
        int xMin = INT_MAX, xMax = INT_MIN;
        for (int i = 0; i < n; ++i) {
            for (const int& x: nums[i]) {
                indices[x].push_back(i);
                xMin = min(xMin, x);
                xMax = max(xMax, x);
            }
        }

        vector<int> freq(n);
        int inside = 0;
        int left = xMin, right = xMin - 1;
        int bestLeft = xMin, bestRight = xMax;

        while (right < xMax) {
            ++right;
            if (indices.count(right)) {
                for (const int& x: indices[right]) {
                    ++freq[x];
                    if (freq[x] == 1) {
                        ++inside;
                    }
                }
                while (inside == n) {
                    if (right - left < bestRight - bestLeft) {
                        bestLeft = left;
                        bestRight = right;
                    }
                    if (indices.count(left)) {
                        for (const int& x: indices[left]) {
                            --freq[x];
                            if (freq[x] == 0) {
                                --inside;
                            }
                        }
                    }
                    ++left;
                }
            }
        }

        return {bestLeft, bestRight};
    }
};

###Go

func smallestRange(nums [][]int) []int {
    size := len(nums)
    indices := map[int][]int{}
    xMin, xMax := math.MaxInt32, math.MinInt32
    for i := 0; i < size; i++ {
        for _, x := range nums[i] {
            indices[x] = append(indices[x], i)
            xMin = min(xMin, x)
            xMax = max(xMax, x)
        }
    }
    freq := make([]int, size)
    inside := 0
    left, right := xMin, xMin - 1
    bestLeft, bestRight := xMin, xMax
    for right < xMax {
        right++
        if len(indices[right]) > 0 {
            for _, x := range indices[right] {
                freq[x]++
                if freq[x] == 1 {
                    inside++
                }
            }
            for inside == size {
                if right - left < bestRight - bestLeft {
                    bestLeft, bestRight = left, right
                }
                for _, x := range indices[left] {
                    freq[x]--
                    if freq[x] == 0 {
                        inside--
                    }
                }
                left++
            }
        }
    }
    return []int{bestLeft, bestRight}
}

func min(x, y int) int {
    if x < y {
        return x
    }
    return y
}

func max(x, y int) int {
    if x > y {
        return x
    }
    return y
}

###Python

class Solution:
    def smallestRange(self, nums: List[List[int]]) -> List[int]:
        n = len(nums)
        indices = collections.defaultdict(list)
        xMin, xMax = 10**9, -10**9
        for i, vec in enumerate(nums):
            for x in vec:
                indices[x].append(i)
            xMin = min(xMin, *vec)
            xMax = max(xMax, *vec)
        
        freq = [0] * n
        inside = 0
        left, right = xMin, xMin - 1
        bestLeft, bestRight = xMin, xMax

        while right < xMax:
            right += 1
            if right in indices:
                for x in indices[right]:
                    freq[x] += 1
                    if freq[x] == 1:
                        inside += 1
                while inside == n:
                    if right - left < bestRight - bestLeft:
                        bestLeft, bestRight = left, right
                    if left in indices:
                        for x in indices[left]:
                            freq[x] -= 1
                            if freq[x] == 0:
                                inside -= 1
                    left += 1

        return [bestLeft, bestRight]

###C

typedef struct {
    int key;
    struct ListNode *val;
    UT_hash_handle hh;
} HashItem; 

struct ListNode *createListNode(int val) {
    struct ListNode *p = (struct ListNode*)malloc(sizeof(struct ListNode));
    p->val = val;
    p->next = NULL;
    return p;
}

HashItem *hashFindItem(HashItem **obj, int key) {
    HashItem *pEntry = NULL;
    HASH_FIND_INT(*obj, &key, pEntry);
    return pEntry;
}

bool hashAddItem(HashItem **obj, int key, int val) {
    struct ListNode *p = createListNode(val);
    HashItem *pEntry = hashFindItem(obj, key);
    if (pEntry) {
        p->next = pEntry->val;
        pEntry->val = p;
        return true;
    }
    pEntry = (HashItem *)malloc(sizeof(HashItem));
    pEntry->key = key;
    pEntry->val = p;
    HASH_ADD_INT(*obj, key, pEntry);
    return true;
}

struct ListNode* hashGetItem(HashItem **obj, int key) {
    HashItem *pEntry = hashFindItem(obj, key);
    if (!pEntry) {
        return NULL;
    }
    return pEntry->val;
}

void freeList(struct ListNode *list) {
    while (list) {
        struct ListNode *p = list;
        list = list->next;
        free(p);
    }
}

void hashFree(HashItem **obj) {
    HashItem *curr = NULL, *tmp = NULL;
    HASH_ITER(hh, *obj, curr, tmp) {
        HASH_DEL(*obj, curr);  
        free(curr);
    }
}

int* smallestRange(int** nums, int numsSize, int* numsColSize, int* returnSize) {
    int n = numsSize;
    HashItem *indices = NULL;
    int xMin = INT_MAX, xMax = INT_MIN;
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < numsColSize[i]; j++) {
            int x = nums[i][j];
            hashAddItem(&indices, x, i);
            xMin = fmin(xMin, x);
            xMax = fmax(xMax, x);
        }
    }

    int freq[n];
    memset(freq, 0, sizeof(freq));
    int inside = 0;
    int left = xMin, right = xMin - 1;
    int bestLeft = xMin, bestRight = xMax;

    while (right < xMax) {
        ++right;
        if (hashFindItem(&indices, right)) {
            for (struct ListNode *p = hashGetItem(&indices, right); p; p = p->next) {
                int x = p->val;
                ++freq[x];
                if (freq[x] == 1) {
                    ++inside;
                }
            }
            while (inside == n) {
                if (right - left < bestRight - bestLeft) {
                    bestLeft = left;
                    bestRight = right;
                }
                if (hashFindItem(&indices, left)) {
                    for (struct ListNode *p = hashGetItem(&indices, left); p; p = p->next) {
                        int x = p->val;
                        --freq[x];
                        if (freq[x] == 0) {
                            --inside;
                        }
                    }
                }
                ++left;
            }
        }
    }
    int *res = (int *)malloc(sizeof(int) * 2);
    res[0] = bestLeft;
    res[1] = bestRight;
    *returnSize = 2;
    hashFree(&indices);
    return res;
}

###JavaScript

var smallestRange = function(nums) {
    const size = nums.length;
    const indices = new Map();
    let xMin = Number.MAX_SAFE_INTEGER, xMax = Number.MIN_SAFE_INTEGER;

    for (let i = 0; i < size; i++) {
        for (const x of nums[i]) {
            if (!indices.has(x)) {
                indices.set(x, []);
            }
            indices.get(x).push(i);
            xMin = Math.min(xMin, x);
            xMax = Math.max(xMax, x);
        }
    }

    const freq = new Array(size).fill(0);
    let inside = 0;
    let left = xMin, right = xMin - 1;
    let bestLeft = xMin, bestRight = xMax;

    while (right < xMax) {
        right++;
        if (indices.has(right)) {
            for (const x of indices.get(right)) {
                freq[x]++;
                if (freq[x] === 1) {
                    inside++;
                }
            }
            while (inside === size) {
                if (right - left < bestRight - bestLeft) {
                    bestLeft = left;
                    bestRight = right;
                }
                if (indices.has(left)) {
                    for (const x of indices.get(left)) {
                        freq[x]--;
                        if (freq[x] === 0) {
                            inside--;
                        }
                    }
                }
                left++;
            }
        }
    }

    return [bestLeft, bestRight];
};

###TypeScript

function smallestRange(nums: number[][]): number[] {
    const size = nums.length;
    const indices = new Map<number, number[]>();
    let xMin = Number.MAX_SAFE_INTEGER, xMax = Number.MIN_SAFE_INTEGER;

    for (let i = 0; i < size; i++) {
        for (const x of nums[i]) {
            if (!indices.has(x)) {
                indices.set(x, []);
            }
            indices.get(x)!.push(i);
            xMin = Math.min(xMin, x);
            xMax = Math.max(xMax, x);
        }
    }

    const freq = new Array(size).fill(0);
    let inside = 0;
    let left = xMin, right = xMin - 1;
    let bestLeft = xMin, bestRight = xMax;

    while (right < xMax) {
        right++;
        if (indices.has(right)) {
            for (const x of indices.get(right)!) {
                freq[x]++;
                if (freq[x] === 1) {
                    inside++;
                }
            }
            while (inside === size) {
                if (right - left < bestRight - bestLeft) {
                    bestLeft = left;
                    bestRight = right;
                }
                if (indices.has(left)) {
                    for (const x of indices.get(left)!) {
                        freq[x]--;
                        if (freq[x] === 0) {
                            inside--;
                        }
                    }
                }
                left++;
            }
        }
    }

    return [bestLeft, bestRight];
};

###Rust

use std::collections::HashMap;

impl Solution {
    pub fn smallest_range(nums: Vec<Vec<i32>>) -> Vec<i32> {
        let size = nums.len();
        let mut indices: HashMap<i32, Vec<usize>> = HashMap::new();
        let mut x_min = i32::MAX;
        let mut x_max = i32::MIN;

        for i in 0..size {
            for &x in &nums[i] {
                indices.entry(x).or_insert_with(Vec::new).push(i);
                x_min = x_min.min(x);
                x_max = x_max.max(x);
            }
        }

        let mut freq = vec![0; size];
        let mut inside = 0;
        let mut left = x_min;
        let mut right = x_min - 1;
        let mut best_left = x_min;
        let mut best_right = x_max;

        while right < x_max {
            right += 1;
            if let Some(vec) = indices.get(&right) {
                for &x in vec {
                    freq[x] += 1;
                    if freq[x] == 1 {
                        inside += 1;
                    }
                }
                while inside == size {
                    if right - left < best_right - best_left {
                        best_left = left;
                        best_right = right;
                    }
                    if let Some(vec) = indices.get(&left) {
                        for &x in vec {
                            freq[x] -= 1;
                            if freq[x] == 0 {
                                inside -= 1;
                            }
                        }
                    }
                    left += 1;
                }
            }
        }

        vec![best_left, best_right]
    }
}

复杂度分析

  • 时间复杂度:$O(nk + |V|)$,其中 $n$ 是所有列表的平均长度,$k$ 是列表数量,$|V|$ 是列表中元素的值域,在本题中 $|V| \leq 2*10^5$。构造哈希映射的时间复杂度为 $O(nk)$,双指针的移动范围为 $|V|$,在此过程中会对哈希映射再进行一次遍历,时间复杂度为 $O(nk)$,因此总时间复杂度为 $O(nk + |V|)$。

  • 空间复杂度:$O(nk)$,即为哈希映射使用的空间。哈希映射的「键」的数量由列表中的元素个数 $nk$ 以及值域 $|V|$ 中的较小值决定,「值」为长度不固定的数组,但是它们的长度之和为 $nk$,因此哈希映射使用的空间为 $O(nk)$。在使用双指针时,还需要一个长度为 $n$ 的数组,其对应的空间在渐进意义下小于 $O(nk)$,因此可以忽略。

排序滑窗

作者 netcan
2020年5月10日 15:07

解题思路:

首先将 $k$ 组数据升序合并成一组,并记录每个数字所属的组,例如:

$[[4,10,15,24,26], [0,9,12,20], [5,18,22,30]]$

合并升序后得到:
$[(0, 1), (4, 0), (5, 2), (9, 1), (10, 0), (12, 1), (15, 0), (18, 2), (20, 1), (22, 2), (24, 0), (26, 0), (30, 2)]$

然后只看所属组的话,那么
$[1, 0, 2, 1, 0, 1, 0, 2, 1, 2, 0, 0, 2]$

按组进行滑窗,保证一个窗口的组满足$k$组后在记录窗口的最小区间值。

[1 0 2] 2 1 0 1 0 2 1 2 0 0 2    [0, 5]
1 [0 2 1] 1 0 1 0 2 1 2 0 0 2    [0, 5]
1 0 [2 1 0] 0 1 0 2 1 2 0 0 2    [0, 5]
1 0 [2 1 0 1] 1 0 2 1 2 0 0 2    [0, 5]
1 0 [2 1 0 1 0] 0 2 1 2 0 0 2    [0, 5]
1 0 2 1 0 [1 0 2] 2 1 2 0 0 2    [0, 5]
1 0 2 1 0 1 [0 2 1] 1 2 0 0 2    [0, 5]
1 0 2 1 0 1 [0 2 1 2] 2 0 0 2    [0, 5]
1 0 2 1 0 1 0 2 [1 2 0] 0 0 2    [20, 24]
1 0 2 1 0 1 0 2 [1 2 0 0] 0 2    [20, 24]
1 0 2 1 0 1 0 2 [1 2 0 0 2] 2    [20, 24]

###C++

class Solution {
public:
    vector<int> smallestRange(vector<vector<int>>& nums) {
        vector<pair<int, int>> ordered; // (number, group)
        for (size_t k = 0; k < nums.size(); ++k)
            for (auto n: nums[k]) ordered.push_back({n, k});
        sort(ordered.begin(), ordered.end());

        int i = 0, k = 0;
        vector<int> ans;
        unordered_map<int, int> count;
        for (size_t j = 0; j < ordered.size(); ++j) {
            if (! count[ordered[j].second]++) ++k;
            if (k == nums.size()) { 
                while (count[ordered[i].second] > 1) --count[ordered[i++].second]; // minialize range
                if (ans.empty() || ans[1] - ans[0] > ordered[j].first - ordered[i].first) {
                    ans = vector<int>{ordered[i].first, ordered[j].first};
                }
            }
        }

        return ans;
    }
};

[Python3/Java/C++/Go/TypeScript] 一题一解:计数(清晰题解)

作者 lcbin
2024年11月23日 09:43

方法一:计数

我们可以用一个二维数组 $\textit{cnt}$ 记录每个玩家获得的每种颜色球的数量,用一个哈希表 $\textit{s}$ 记录胜利玩家的编号。

遍历 $\textit{pick}$ 数组,对于每个元素 $[x, y]$,我们将 $\textit{cnt}[x][y]$ 加一,如果 $\textit{cnt}[x][y]$ 大于 $x$,则将 $x$ 加入哈希表 $\textit{s}$。

最后返回哈希表 $\textit{s}$ 的大小即可。

###python

class Solution:
    def winningPlayerCount(self, n: int, pick: List[List[int]]) -> int:
        cnt = [[0] * 11 for _ in range(n)]
        s = set()
        for x, y in pick:
            cnt[x][y] += 1
            if cnt[x][y] > x:
                s.add(x)
        return len(s)

###java

class Solution {
    public int winningPlayerCount(int n, int[][] pick) {
        int[][] cnt = new int[n][11];
        Set<Integer> s = new HashSet<>();
        for (var p : pick) {
            int x = p[0], y = p[1];
            if (++cnt[x][y] > x) {
                s.add(x);
            }
        }
        return s.size();
    }
}

###cpp

class Solution {
public:
    int winningPlayerCount(int n, vector<vector<int>>& pick) {
        int cnt[10][11]{};
        unordered_set<int> s;
        for (const auto& p : pick) {
            int x = p[0], y = p[1];
            if (++cnt[x][y] > x) {
                s.insert(x);
            }
        }
        return s.size();
    }
};

###go

func winningPlayerCount(n int, pick [][]int) int {
cnt := make([][11]int, n)
s := map[int]struct{}{}
for _, p := range pick {
x, y := p[0], p[1]
cnt[x][y]++
if cnt[x][y] > x {
s[x] = struct{}{}
}
}
return len(s)
}

###ts

function winningPlayerCount(n: number, pick: number[][]): number {
    const cnt: number[][] = Array.from({ length: n }, () => Array(11).fill(0));
    const s = new Set<number>();
    for (const [x, y] of pick) {
        if (++cnt[x][y] > x) {
            s.add(x);
        }
    }
    return s.size;
}

时间复杂度 $O(m + n \times M)$,空间复杂度 $O(n \times M)$。其中 $m$ 为 $\textit{pick}$ 数组的长度,而 $n$ 和 $M$ 分别为玩家数目和颜色数目。


有任何问题,欢迎评论区交流,欢迎评论区提供其它解题思路(代码),也可以点个赞支持一下作者哈😄~

❌
❌