阅读视图

发现新文章,点击刷新页面。

两种 Dijkstra 写法(附题单)Python/Java/C++/Go/JS/Rust

Dijkstra 算法介绍

定义 $g[i][j]$ 表示节点 $i$ 到节点 $j$ 这条边的边权。如果没有 $i$ 到 $j$ 的边,则 $g[i][j]=\infty$。

定义 $\textit{dis}[i]$ 表示起点 $k$ 到节点 $i$ 的最短路长度,一开始 $\textit{dis}[k]=0$,其余 $\textit{dis}[i]=\infty$ 表示尚未计算出。

我们的目标是计算出最终的 $\textit{dis}$ 数组。

  • 首先更新起点 $k$ 到其邻居 $y$ 的最短路,即更新 $\textit{dis}[y]$ 为 $g[k][y]$。
  • 然后取除了起点 $k$ 以外的 $\textit{dis}[i]$ 的最小值,假设最小值对应的节点是 $3$。此时可以断言:$\textit{dis}[3]$ 已经是 $k$ 到 $3$ 的最短路长度,不可能有其它 $k$ 到 $3$ 的路径更短!反证法:假设存在更短的路径,那我们一定会从 $k$ 出发经过一个点 $u$,它的 $\textit{dis}[u]$ 比 $\textit{dis}[3]$ 还要小,然后再经过一些边到达 $3$,得到更小的 $\textit{dis}[3]$。但 $\textit{dis}[3]$ 已经是最小的了,并且图中没有负数边权,所以 $u$ 是不存在的,矛盾。故原命题成立,此时我们得到了 $\textit{dis}[3]$ 的最终值。
  • 用节点 $3$ 到其邻居 $y$ 的边权 $g[3][y]$ 更新 $\textit{dis}[y]$:如果 $\textit{dis}[3] + g[3][y] < \textit{dis}[y]$,那么更新 $\textit{dis}[y]$ 为 $\textit{dis}[3] + g[3][y]$,否则不更新。
  • 然后取除了节点 $k,3$ 以外的 $\textit{dis}[i]$ 的最小值,重复上述过程。
  • 由数学归纳法可知,这一做法可以得到每个点的最短路。当所有点的最短路都已确定时,算法结束。

写法一:朴素 Dijkstra(适用于稠密图)

对于本题,在计算最短路时,如果发现当前找到的最小最短路等于 $\infty$,说明有节点无法到达,可以提前结束算法,返回 $-1$。

如果所有节点都可以到达,返回 $\max(\textit{dis})$。

代码实现时,节点编号改成从 $0$ 开始。

###py

class Solution:
    def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
        g = [[inf for _ in range(n)] for _ in range(n)]  # 邻接矩阵
        for x, y, d in times:
            g[x - 1][y - 1] = d

        dis = [inf] * n
        ans = dis[k - 1] = 0
        done = [False] * n
        while True:
            x = -1
            for i, ok in enumerate(done):
                if not ok and (x < 0 or dis[i] < dis[x]):
                    x = i
            if x < 0:
                return ans  # 最后一次算出的最短路就是最大的
            if dis[x] == inf:  # 有节点无法到达
                return -1
            ans = dis[x]  # 求出的最短路会越来越大
            done[x] = True  # 最短路长度已确定(无法变得更小)
            for y, d in enumerate(g[x]):
                # 更新 x 的邻居的最短路
                dis[y] = min(dis[y], dis[x] + d)

###java

class Solution {
    public int networkDelayTime(int[][] times, int n, int k) {
        final int INF = Integer.MAX_VALUE / 2; // 防止加法溢出
        int[][] g = new int[n][n]; // 邻接矩阵
        for (int[] row : g) {
            Arrays.fill(row, INF);
        }
        for (int[] t : times) {
            g[t[0] - 1][t[1] - 1] = t[2];
        }

        int maxDis = 0;
        int[] dis = new int[n];
        Arrays.fill(dis, INF);
        dis[k - 1] = 0;
        boolean[] done = new boolean[n];
        while (true) {
            int x = -1;
            for (int i = 0; i < n; i++) {
                if (!done[i] && (x < 0 || dis[i] < dis[x])) {
                    x = i;
                }
            }
            if (x < 0) {
                return maxDis; // 最后一次算出的最短路就是最大的
            }
            if (dis[x] == INF) { // 有节点无法到达
                return -1;
            }
            maxDis = dis[x]; // 求出的最短路会越来越大
            done[x] = true; // 最短路长度已确定(无法变得更小)
            for (int y = 0; y < n; y++) {
                // 更新 x 的邻居的最短路
                dis[y] = Math.min(dis[y], dis[x] + g[x][y]);
            }
        }
    }
}

###cpp

class Solution {
public:
    int networkDelayTime(vector<vector<int>>& times, int n, int k) {
        vector<vector<int>> g(n, vector<int>(n, INT_MAX / 2)); // 邻接矩阵
        for (auto& t : times) {
            g[t[0] - 1][t[1] - 1] = t[2];
        }

        vector<int> dis(n, INT_MAX / 2), done(n);
        dis[k - 1] = 0;
        while (true) {
            int x = -1;
            for (int i = 0; i < n; i++) {
                if (!done[i] && (x < 0 || dis[i] < dis[x])) {
                    x = i;
                }
            }
            if (x < 0) {
                return ranges::max(dis);
            }
            if (dis[x] == INT_MAX / 2) { // 有节点无法到达
                return -1;
            }
            done[x] = true; // 最短路长度已确定(无法变得更小)
            for (int y = 0; y < n; y++) {
                // 更新 x 的邻居的最短路
                dis[y] = min(dis[y], dis[x] + g[x][y]);
            }
        }
    }
};

###go

func networkDelayTime(times [][]int, n, k int) int {
    const inf = math.MaxInt / 2 // 防止加法溢出
    g := make([][]int, n) // 邻接矩阵
    for i := range g {
        g[i] = make([]int, n)
        for j := range g[i] {
            g[i][j] = inf
        }
    }
    for _, t := range times {
        g[t[0]-1][t[1]-1] = t[2]
    }

    dis := make([]int, n)
    for i := range dis {
        dis[i] = inf
    }
    dis[k-1] = 0
    done := make([]bool, n)
    for {
        x := -1
        for i, ok := range done {
            if !ok && (x < 0 || dis[i] < dis[x]) {
                x = i
            }
        }
        if x < 0 {
            return slices.Max(dis)
        }
        if dis[x] == inf { // 有节点无法到达
            return -1
        }
        done[x] = true // 最短路长度已确定(无法变得更小)
        for y, d := range g[x] {
            // 更新 x 的邻居的最短路
            dis[y] = min(dis[y], dis[x]+d)
        }
    }
}

###js

var networkDelayTime = function(times, n, k) {
    const g = Array.from({length: n}, () => Array(n).fill(Infinity)); // 邻接矩阵
    for (const [x, y, d] of times) {
        g[x - 1][y - 1] = d;
    }

    const dis = Array(n).fill(Infinity);
    dis[k - 1] = 0;
    const done = Array(n).fill(false);
    while (true) {
        let x = -1;
        for (let i = 0; i < n; i++) {
            if (!done[i] && (x < 0 || dis[i] < dis[x])) {
                x = i;
            }
        }
        if (x < 0) {
            return Math.max(...dis);
        }
        if (dis[x] === Infinity) { // 有节点无法到达
            return -1;
        }
        done[x] = true; // 最短路长度已确定(无法变得更小)
        for (let y = 0; y < n; y++) {
            // 更新 x 的邻居的最短路
            dis[y] = Math.min(dis[y], dis[x] + g[x][y]);
        }
    }
};

###rust

impl Solution {
    pub fn network_delay_time(times: Vec<Vec<i32>>, n: i32, k: i32) -> i32 {
        const INF: i32 = i32::MAX / 2; // 防止加法溢出
        let n = n as usize;
        let mut g = vec![vec![INF; n]; n]; // 邻接矩阵
        for t in &times {
            g[t[0] as usize - 1][t[1] as usize - 1] = t[2];
        }

        let mut dis = vec![INF; n];
        dis[k as usize - 1] = 0;
        let mut done = vec![false; n];
        loop {
            let mut x = n;
            for (i, &ok) in done.iter().enumerate() {
                if !ok && (x == n || dis[i] < dis[x]) {
                    x = i;
                }
            }
            if x == n {
                return *dis.iter().max().unwrap();
            }
            if dis[x] == INF { // 有节点无法到达
                return -1;
            }
            done[x] = true; // 最短路长度已确定(无法变得更小)
            for (y, &d) in g[x].iter().enumerate() {
                // 更新 x 的邻居的最短路
                dis[y] = dis[y].min(dis[x] + d);
            }
        }
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2)$。
  • 空间复杂度:$\mathcal{O}(n^2)$。

写法二:堆优化 Dijkstra(适用于稀疏图)

寻找最小值的过程可以用一个最小堆来快速完成:

  • 一开始把 $(\textit{dis}[k],k)$ 二元组入堆。
  • 当节点 $x$ 首次出堆时,$\textit{dis}[x]$ 就是写法一中寻找的最小最短路。
  • 更新 $\textit{dis}[y]$ 时,把 $(\textit{dis}[y],y)$ 二元组入堆。

注意,如果一个节点 $x$ 在出堆前,其最短路长度 $\textit{dis}[x]$ 被多次更新,那么堆中会有多个重复的 $x$,并且包含 $x$ 的二元组中的 $\textit{dis}[x]$ 是互不相同的(因为我们只在找到更小的最短路时才会把二元组入堆)。

所以写法一中的 $\textit{done}$ 数组可以省去,取而代之的是用出堆的最短路值(记作 $\textit{dx}$)与当前的 $\textit{dis}[x]$ 比较,如果 $\textit{dx} > \textit{dis}[x]$ 说明 $x$ 之前出堆过,我们已经更新了 $x$ 的邻居的最短路,所以这次就不用更新了,继续外层循环。

答疑

:为什么代码要判断 dx > dis[x]

:对于同一个 $x$,例如先入堆一个比较大的 $\textit{dis}[x]=10$,后面又把 $\textit{dis}[x]$ 更新成 $5$,之后这个 $5$ 会先出堆,然后再把 $10$ 出堆。$10$ 出堆时候是没有必要去更新周围邻居的最短路的,因为 $5$ 出堆之后,就已经把邻居的最短路更新过了,用 $10$ 是无法把邻居的最短路变得更短的,所以直接 continue

###py

class Solution:
    def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
        g = [[] for _ in range(n)]  # 邻接表
        for x, y, d in times:
            g[x - 1].append((y - 1, d))

        dis = [inf] * n
        dis[k - 1] = 0
        h = [(0, k - 1)]
        while h:
            dx, x = heappop(h)
            if dx > dis[x]:  # x 之前出堆过
                continue
            for y, d in g[x]:
                new_dis = dx + d
                if new_dis < dis[y]:
                    dis[y] = new_dis  # 更新 x 的邻居的最短路
                    heappush(h, (new_dis, y))
        mx = max(dis)
        return mx if mx < inf else -1

###java

class Solution {
    public int networkDelayTime(int[][] times, int n, int k) {
        List<int[]>[] g = new ArrayList[n]; // 邻接表
        Arrays.setAll(g, i -> new ArrayList<>());
        for (int[] t : times) {
            g[t[0] - 1].add(new int[]{t[1] - 1, t[2]});
        }

        int maxDis = 0;
        int left = n; // 未确定最短路的节点个数
        int[] dis = new int[n];
        Arrays.fill(dis, Integer.MAX_VALUE);
        dis[k - 1] = 0;
        PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> (a[0] - b[0]));
        pq.offer(new int[]{0, k - 1});
        while (!pq.isEmpty()) {
            int[] p = pq.poll();
            int dx = p[0];
            int x = p[1];
            if (dx > dis[x]) { // x 之前出堆过
                continue;
            }
            maxDis = dx; // 求出的最短路会越来越大
            left--;
            for (int[] e : g[x]) {
                int y = e[0];
                int newDis = dx + e[1];
                if (newDis < dis[y]) {
                    dis[y] = newDis; // 更新 x 的邻居的最短路
                    pq.offer(new int[]{newDis, y});
                }
            }
        }
        return left == 0 ? maxDis : -1;
    }
}

###cpp

class Solution {
public:
    int networkDelayTime(vector<vector<int>>& times, int n, int k) {
        vector<vector<pair<int, int>>> g(n); // 邻接表
        for (auto& t : times) {
            g[t[0] - 1].emplace_back(t[1] - 1, t[2]);
        }

        vector<int> dis(n, INT_MAX);
        dis[k - 1] = 0;
        priority_queue<pair<int, int>, vector<pair<int, int>>, greater<>> pq;
        pq.emplace(0, k - 1);
        while (!pq.empty()) {
            auto [dx, x] = pq.top();
            pq.pop();
            if (dx > dis[x]) { // x 之前出堆过
                continue;
            }
            for (auto &[y, d] : g[x]) {
                int new_dis = dx + d;
                if (new_dis < dis[y]) {
                    dis[y] = new_dis; // 更新 x 的邻居的最短路
                    pq.emplace(new_dis, y);
                }
            }
        }
        int mx = ranges::max(dis);
        return mx < INT_MAX ? mx : -1;
    }
};

###go

func networkDelayTime(times [][]int, n, k int) int {
    type edge struct{ to, wt int }
    g := make([][]edge, n) // 邻接表
    for _, t := range times {
        g[t[0]-1] = append(g[t[0]-1], edge{t[1] - 1, t[2]})
    }

    dis := make([]int, n)
    for i := range dis {
        dis[i] = math.MaxInt
    }
    dis[k-1] = 0
    h := hp{{0, k - 1}}
    for len(h) > 0 {
        p := heap.Pop(&h).(pair)
        dx := p.dis
        x := p.x
        if dx > dis[x] { // x 之前出堆过
            continue
        }
        for _, e := range g[x] {
            y := e.to
            newDis := dx + e.wt
            if newDis < dis[y] {
                dis[y] = newDis // 更新 x 的邻居的最短路
                heap.Push(&h, pair{newDis, y})
            }
        }
    }
    mx := slices.Max(dis)
    if mx < math.MaxInt {
        return mx
    }
    return -1
}

type pair struct{ dis, x int }
type hp []pair
func (h hp) Len() int           { return len(h) }
func (h hp) Less(i, j int) bool { return h[i].dis < h[j].dis }
func (h hp) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *hp) Push(v any)        { *h = append(*h, v.(pair)) }
func (h *hp) Pop() (v any)      { a := *h; *h, v = a[:len(a)-1], a[len(a)-1]; return }

###js

var networkDelayTime = function(times, n, k) {
    const g = Array.from({length: n}, () => []); // 邻接表
    for (const [x, y, d] of times) {
        g[x - 1].push([y - 1, d]);
    }

    const dis = Array(n).fill(Infinity);
    dis[k - 1] = 0;
    const pq = new MinPriorityQueue({priority: (p) => p[0]});
    pq.enqueue([0, k - 1]);
    while (!pq.isEmpty()) {
        const [dx, x] = pq.dequeue().element;
        if (dx > dis[x]) { // x 之前出堆过
            continue;
        }
        for (const [y, d] of g[x]) {
            const newDis = dx + d;
            if (newDis < dis[y]) {
                dis[y] = newDis; // 更新 x 的邻居的最短路
                pq.enqueue([newDis, y]);
            }
        }
    }
    const mx = Math.max(...dis);
    return mx < Infinity ? mx : -1;
};

###rust

use std::collections::BinaryHeap;

impl Solution {
    pub fn network_delay_time(times: Vec<Vec<i32>>, n: i32, k: i32) -> i32 {
        let n = n as usize;
        let k = k as usize - 1;
        let mut g = vec![vec![]; n]; // 邻接表
        for t in &times {
            g[t[0] as usize - 1].push((t[1] as usize - 1, t[2]));
        }

        let mut dis = vec![i32::MAX; n];
        dis[k] = 0;
        let mut h = BinaryHeap::new();
        h.push((0, k));
        while let Some((dx, x)) = h.pop() {
            if -dx > dis[x] { // x 之前出堆过
                continue;
            }
            for &(y, d) in &g[x] {
                let new_dis = -dx + d;
                if new_dis < dis[y] {
                    dis[y] = new_dis; // 更新 x 的邻居的最短路
                    h.push((-new_dis, y));
                }
            }
        }
        let mx = *dis.iter().max().unwrap();
        if mx < i32::MAX { mx } else { -1 }
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(m\log m)$,其中 $m$ 为 $\textit{times}$ 的长度。由于 $m\ge n-1$,分析复杂度时以 $m$ 为主。注意堆中会有重复节点,所以至多有 $\mathcal{O}(m)$ 个元素,单次操作的复杂度是 $\mathcal{O}(\log m)$。值得注意的是,如果输入的是稠密图,写法二的时间复杂度为 $\mathcal{O}(n^2\log n)$,不如写法一。
  • 空间复杂度:$\mathcal{O}(m)$。

更多相似题目,见下面图论题单中的「单源最短路:Dijkstra」。

分类题单

如何科学刷题?

  1. 滑动窗口(定长/不定长/多指针)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/最短路/最小生成树/二分图/基环树/欧拉路径)
  7. 动态规划(入门/背包/状态机/划分/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心算法(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

两种方法:堆/排序+滑动窗口(Python/Java/C++/Go/JS/Rust)

核心思路

把「每个列表至少有一个数包含在其中」的区间叫做合法区间

先求出最左边的合法区间,然后求出第二个合法区间,第三个合法区间,依此类推。

比如示例 1,最左边的合法区间是 $[0,5]$。

枚举所有合法区间的左端点,或者枚举所有合法区间的右端点。其中第一个最短的合法区间就是答案。

方法一:堆

在示例 1 中,有三个列表:

  • $[4,10,15,24,26]$。
  • $[0,9,12,20]$。
  • $[5,18,22,30]$。

我们来计算最左边的合法区间,第二个合法区间,第三个合法区间,……

也就是左端点为 $0$ 的合法区间,左端点为 $4$ 的合法区间,左端点为 $5$ 的合法区间。

求出左端点对应的右端点,就知道了区间的长度,其中第一个最短的区间就是答案。

左端点为 $0$ 的合法区间,右端点是这三个列表的第一个元素的最大值,即 $5$。

接下来,去掉 $0$,列表 $[0,9,12,20]$ 变成 $[9,12,20]$,问题变成如下三个列表:

  • $[4,10,15,24,26]$。
  • $[9,12,20]$。
  • $[5,18,22,30]$。

这三个列表的最左边的合法区间是什么?

左端点是这三个列表的第一个元素的最小值 $4$,右端点是这三个列表的第一个元素的最大值 $9$,所以合法区间为 $[4,9]$。

接下来,去掉 $4$,列表 $[4,10,15,24,26]$ 变成 $[10,15,24,26]$,重复上述过程。

在上述过程中,需要快速地求出合法区间的左端点和右端点:

  • 左端点:需要一个数据结构,支持添加元素、计算最小值、删除最小值。这可以用最小堆维护。堆顶(最小元素)就是左端点。
  • 右端点:记作 $r$。一开始 $r$ 为每个列表第一个元素的最大值。当我们去掉列表第一个元素时,就用列表的第二个元素更新 $r$ 的最大值。依此类推。

注:实际没有去掉元素,而是用下标表示元素在列表中的位置。

细节

  1. 为方便计算列表的下一个元素,需要在堆中额外保存元素属于哪个列表,以及在这个列表中的位置。所以堆中的每个元素是一个三元组,即 $(\textit{nums}[i][j],i,j)$,这样列表的下一个元素就是 $\textit{nums}[i][j+1]$。
  2. 如果列表没有下一个元素($j+1$ 等于列表长度),那么去掉当前元素后,将不会有合法区间包含这个列表的元素,算法结束。

###py

class Solution:
    def smallestRange(self, nums: List[List[int]]) -> List[int]:
        # 把每个列表的第一个元素入堆
        h = [(arr[0], i, 0) for i, arr in enumerate(nums)]
        heapify(h)

        ans_l = h[0][0]  # 第一个合法区间的左端点
        ans_r = r = max(arr[0] for arr in nums)  # 第一个合法区间的右端点
        while h[0][2] + 1 < len(nums[h[0][1]]):  # 堆顶列表有下一个元素
            _, i, j = h[0]
            x = nums[i][j + 1]  # 堆顶列表的下一个元素
            heapreplace(h, (x, i, j + 1))  # 替换堆顶
            r = max(r, x)  # 更新合法区间的右端点
            l = h[0][0]  # 当前合法区间的左端点
            if r - l < ans_r - ans_l:
                ans_l, ans_r = l, r
        return [ans_l, ans_r]

###java

class Solution {
    public int[] smallestRange(List<List<Integer>> nums) {
        PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> a[0] - b[0]);
        int r = Integer.MIN_VALUE;
        for (int i = 0; i < nums.size(); i++) {
            // 把每个列表的第一个元素入堆
            int x = nums.get(i).get(0);
            pq.offer(new int[]{x, i, 0});
            r = Math.max(r, x);
        }

        int ansL = pq.peek()[0]; // 第一个合法区间的左端点
        int ansR = r; // 第一个合法区间的右端点
        while (pq.peek()[2] + 1 < nums.get(pq.peek()[1]).size()) { // 堆顶列表有下一个元素
            int[] top = pq.poll();
            top[0] = nums.get(top[1]).get(++top[2]); // 堆顶列表的下一个元素
            r = Math.max(r, top[0]); // 更新合法区间的右端点
            pq.offer(top); // 入堆(复用 int[],提高效率)
            int l = pq.peek()[0]; // 当前合法区间的左端点
            if (r - l < ansR - ansL) {
                ansL = l;
                ansR = r;
            }
        }
        return new int[]{ansL, ansR};
    }
}

###cpp

class Solution {
public:
    vector<int> smallestRange(vector<vector<int>>& nums) {
        priority_queue<tuple<int, int, int>, vector<tuple<int, int, int>>, greater<>> pq;
        int r = INT_MIN;
        for (int i = 0; i < nums.size(); i++) {
            pq.emplace(nums[i][0], i, 0); // 把每个列表的第一个元素入堆
            r = max(r, nums[i][0]);
        }

        int ans_l = get<0>(pq.top()); // 第一个合法区间的左端点
        int ans_r = r; // 第一个合法区间的右端点
        while (true) {
            auto [_, i, j] = pq.top();
            if (j + 1 == nums[i].size()) { // 堆顶列表没有下一个元素
                break;
            }
            pq.pop();
            int x = nums[i][j + 1]; // 堆顶列表的下一个元素
            pq.emplace(x, i, j + 1); // 入堆
            r = max(r, x); // 更新合法区间的右端点
            int l = get<0>(pq.top()); // 当前合法区间的左端点
            if (r - l < ans_r - ans_l) {
                ans_l = l;
                ans_r = r;
            }
        }
        return {ans_l, ans_r};
    }
};

###go

func smallestRange(nums [][]int) []int {
    h := make(hp, len(nums))
    r := math.MinInt
    for i, arr := range nums {
        h[i] = tuple{arr[0], i, 0} // 把每个列表的第一个元素入堆
        r = max(r, arr[0])
    }
    heap.Init(&h)

    ansL, ansR := h[0].x, r // 第一个合法区间的左右端点
    for h[0].j+1 < len(nums[h[0].i]) { // 堆顶列表有下一个元素
        x := nums[h[0].i][h[0].j+1] // 堆顶列表的下一个元素
        r = max(r, x) // 更新合法区间的右端点
        h[0].x = x // 替换堆顶
        h[0].j++
        heap.Fix(&h, 0)
        l := h[0].x // 当前合法区间的左端点
        if r-l < ansR-ansL {
            ansL, ansR = l, r
        }
    }
    return []int{ansL, ansR}
}

type tuple struct{ x, i, j int }
type hp []tuple
func (h hp) Len() int           { return len(h) }
func (h hp) Less(i, j int) bool { return h[i].x < h[j].x }
func (h hp) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (hp) Push(any)             {} // 没用到,可以不写
func (hp) Pop() (_ any)         { return }

###js

var smallestRange = function(nums) {
    const pq = new MinPriorityQueue({priority: a => a[0]});
    let r = -Infinity;
    for (let i = 0; i < nums.length; i++) {
        pq.enqueue([nums[i][0], i, 0]); // 每个列表的第一个元素入堆
        r = Math.max(r, nums[i][0]);
    }

    let ansL = pq.front().element[0]; // 第一个合法区间的左端点
    let ansR = r; // 第一个合法区间的右端点
    while (true) {
        const [_, i, j] = pq.dequeue().element;
        if (j + 1 === nums[i].length) { // 堆顶列表没有下一个元素
            break;
        }
        const x = nums[i][j + 1]; // 堆顶列表的下一个元素
        pq.enqueue([x, i, j + 1]); // 入堆
        r = Math.max(r, x); // 更新合法区间的右端点
        const l = pq.front().element[0]; // 当前合法区间的左端点
        if (r - l < ansR - ansL) {
            ansL = l;
            ansR = r;
        }
    }
    return [ansL, ansR];
};

###rust

use std::collections::BinaryHeap;

impl Solution {
    pub fn smallest_range(nums: Vec<Vec<i32>>) -> Vec<i32> {
        let mut h = BinaryHeap::with_capacity(nums.len()); // 预分配空间
        let mut r = i32::MIN;
        for (i, arr) in nums.iter().enumerate() {
            // 把每个列表的第一个元素入堆
            h.push((-arr[0], i, 0)); // 取反变成最小堆
            r = r.max(arr[0]);
        }

        let mut ans_l = -h.peek().unwrap().0; // 第一个合法区间的左端点
        let mut ans_r = r; // 第一个合法区间的右端点
        while h.peek().unwrap().2 + 1 < nums[h.peek().unwrap().1].len() { // 堆顶列表有下一个元素
            let (_, i, j) = h.pop().unwrap();
            let x = nums[i][j + 1]; // 堆顶列表的下一个元素
            h.push((-x, i, j + 1)); // 入堆
            r = r.max(x); // 更新合法区间的右端点
            let l = -h.peek().unwrap().0; // 当前合法区间的左端点
            if r - l < ans_r - ans_l {
                ans_l = l;
                ans_r = r;
            }
        }
        vec![ans_l, ans_r]
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(L\log k)$,其中 $k$ 是 $\textit{nums}$ 的长度,$L$ 是所有 $\textit{nums}[i]$ 的长度之和。循环 $\mathcal{O}(L)$ 次,每次循环需要 $\mathcal{O}(\log k)$ 的时间操作堆。
  • 空间复杂度:$\mathcal{O}(k)$。堆需要 $\mathcal{O}(k)$ 的空间保存元素。

方法二:排序+滑动窗口

对于示例 1 的这三个列表:

  • $[4,10,15,24,26]$。
  • $[0,9,12,20]$。
  • $[5,18,22,30]$。

把所有元素都合在一起排序,可以得到如下结果:

$$
\begin{array}{r|}
元素值 & 0 & 4 & 5 & 9 & 10 & 12 & 15 & 18 & 20 & 22 & 24 & 26 & 30 \
所属列表编号 & 1 & 0 & 2 & 1 & 0 & 1 & 0 & 2 & 1 & 2 & 0 & 0 & 2 \
\end{array}
$$

把上表视作一个由(元素值,所属列表编号)组成的数组,即

$$
\textit{pairs} = [(0, 1), (4, 0), (5, 2), \ldots, (24, 0), (26, 0), (30, 2)]
$$

合法区间等价于 $\textit{pairs}$ 的一个连续子数组,满足列表编号 $0,1,2,\ldots,k-1$ 都在这个子数组中。

由于子数组越长,越能包含 $0,1,2,\ldots,k-1$ 所有编号,有单调性,可以用滑动窗口解决。如果你不了解滑动窗口,可以看视频【基础算法精讲 03】

细节

  1. 用一个长为 $k$ 的数组 $\textit{cnt}$ 统计窗口中的每个编号的出现次数。
  2. 判断是否包含所有编号,简单的方法是判断所有 $\textit{cnt}[i]$ 是否都大于 $0$。为了加快判断速度,可以仿照 76. 最小覆盖子串 的思路,额外用一个变量 $\textit{empty}$ 表示 $\textit{cnt}[i]=0$ 的列表个数。编号 $i$ 进入窗口前,如果 $\textit{cnt}[i]=0$,那么 $\textit{empty}$ 减一;编号 $i$ 离开窗口后,如果 $\textit{cnt}[i]=0$,那么 $\textit{empty}$ 加一。这样只需要判断 $\textit{empty}$ 是否为 $0$,就知道所有 $\textit{cnt}[i]$ 是否都大于 $0$ 了。

注:方法一相当于枚举合法区间的左端点,而方法二相当于枚举合法区间的右端点。

###py

class Solution:
    def smallestRange(self, nums: List[List[int]]) -> List[int]:
        pairs = sorted((x, i) for (i, arr) in enumerate(nums) for x in arr)
        ans_l, ans_r = -inf, inf
        empty = len(nums)
        cnt = [0] * empty
        left = 0
        for r, i in pairs:
            if cnt[i] == 0:  # 包含 nums[i] 的数字
                empty -= 1
            cnt[i] += 1
            while empty == 0:  # 每个列表都至少包含一个数
                l, i = pairs[left]
                if r - l < ans_r - ans_l:
                    ans_l, ans_r = l, r
                cnt[i] -= 1
                if cnt[i] == 0:  # 不包含 nums[i] 的数字
                    empty += 1
                left += 1
        return [ans_l, ans_r]

###java

class Solution {
    public int[] smallestRange(List<List<Integer>> nums) {
        int sumLen = 0;
        for (List<Integer> list : nums) {
            sumLen += list.size();
        }

        int[][] pairs = new int[sumLen][2];
        int pi = 0;
        for (int i = 0; i < nums.size(); i++) {
            for (int x : nums.get(i)) {
                pairs[pi][0] = x;
                pairs[pi++][1] = i;
            }
        }
        Arrays.sort(pairs, (a, b) -> a[0] - b[0]);

        int ansL = pairs[0][0];
        int ansR = pairs[sumLen - 1][0];
        int empty = nums.size();
        int[] cnt = new int[empty];
        int left = 0;
        for (int[] p : pairs) {
            int r = p[0];
            int i = p[1];
            if (cnt[i] == 0) { // 包含 nums[i] 的数字
                empty--;
            }
            cnt[i]++;
            while (empty == 0) { // 每个列表都至少包含一个数
                int l = pairs[left][0];
                if (r - l < ansR - ansL) {
                    ansL = l;
                    ansR = r;
                }
                i = pairs[left][1];
                cnt[i]--;
                if (cnt[i] == 0) { // 不包含 nums[i] 的数字
                    empty++;
                }
                left++;
            }
        }
        return new int[]{ansL, ansR};
    }
}

###cpp

class Solution {
public:
    vector<int> smallestRange(vector<vector<int>>& nums) {
        vector<pair<int, int>> pairs;
        for (int i = 0; i < nums.size(); i++) {
            for (int x : nums[i]) {
                pairs.emplace_back(x, i);
            }
        }
        // 看上去 std::sort 比 ranges::sort 更快
        sort(pairs.begin(), pairs.end());

        int ans_l = pairs[0].first;
        int ans_r = pairs.back().first;
        int empty = nums.size();
        vector<int> cnt(empty);
        int left = 0;
        for (auto [r, i] : pairs) {
            if (cnt[i] == 0) { // 包含 nums[i] 的数字
                empty--;
            }
            cnt[i]++;
            while (empty == 0) { // 每个列表都至少包含一个数
                auto [l, i] = pairs[left];
                if (r - l < ans_r - ans_l) {
                    ans_l = l;
                    ans_r = r;
                }
                cnt[i]--;
                if (cnt[i] == 0) { // 不包含 nums[i] 的数字
                    empty++;
                }
                left++;
            }
        }
        return {ans_l, ans_r};
    }
};

###go

func smallestRange(nums [][]int) []int {
    type pair struct{ x, i int }
    pairs := []pair{}
    for i, arr := range nums {
        for _, x := range arr {
            pairs = append(pairs, pair{x, i})
        }
    }
    slices.SortFunc(pairs, func(a, b pair) int { return a.x - b.x })

    ansL, ansR := pairs[0].x, pairs[len(pairs)-1].x
    empty := len(nums)
    cnt := make([]int, empty)
    left := 0
    for _, p := range pairs {
        r, i := p.x, p.i
        if cnt[i] == 0 { // 包含 nums[i] 的数字
            empty--
        }
        cnt[i]++
        for empty == 0 { // 每个列表都至少包含一个数
            l, i := pairs[left].x, pairs[left].i
            if r-l < ansR-ansL {
                ansL, ansR = l, r
            }
            cnt[i]--
            if cnt[i] == 0 {
                // 不包含 nums[i] 的数字
                empty++
            }
            left++
        }
    }
    return []int{ansL, ansR}
}

###js

var smallestRange = function(nums) {
    const pairs = [];
    for (let i = 0; i < nums.length; i++) {
        for (const x of nums[i]) {
            pairs.push([x, i]);
        }
    }
    pairs.sort((a, b) => a[0] - b[0]);

    let ansL = -Infinity, ansR = Infinity;
    let empty = nums.length;
    const cnt = Array(empty).fill(0);
    let left = 0;
    for (const [r, i] of pairs) {
        if (cnt[i] === 0) { // 包含 nums[i] 的数字
            empty--;
        }
        cnt[i]++;
        while (empty === 0) { // 每个列表都至少包含一个数
            const [l, i] = pairs[left];
            if (r - l < ansR - ansL) {
                ansL = l;
                ansR = r;
            }
            cnt[i]--;
            if (cnt[i] === 0) { // 不包含 nums[i] 的数字
                empty++;
            }
            left++;
        }
    }
    return [ansL, ansR];
};

###rust

impl Solution {
    pub fn smallest_range(nums: Vec<Vec<i32>>) -> Vec<i32> {
        let mut pairs = vec![];
        for (i, arr) in nums.iter().enumerate() {
            for &x in arr {
                pairs.push((x, i));
            }
        }
        pairs.sort_unstable_by(|a, b| a.0.cmp(&b.0));

        let mut ans_l = pairs[0].0;
        let mut ans_r = pairs[pairs.len() - 1].0;
        let mut empty = nums.len();
        let mut cnt = vec![0; empty];
        let mut left = 0;
        for &(r, i) in &pairs {
            if cnt[i] == 0 { // 包含 nums[i] 的数字
                empty -= 1;
            }
            cnt[i] += 1;
            while empty == 0 { // 每个列表都至少包含一个数
                let (l, i) = pairs[left];
                if r - l < ans_r - ans_l {
                    ans_l = l;
                    ans_r = r;
                }
                cnt[i] -= 1;
                if cnt[i] == 0 { // 不包含 nums[i] 的数字
                    empty += 1;
                }
                left += 1;
            }
        }
        vec![ans_l, ans_r]
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(L\log L)$,其中 $L$ 是所有 $\textit{nums}[i]$ 的长度之和。瓶颈在排序上。
  • 空间复杂度:$\mathcal{O}(L)$。

更多相似题目,见下面数据结构题单中的「五、堆(优先队列)」,以及滑动窗口题单中的「§2.2 求最短/最小」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/最短路/最小生成树/二分图/基环树/欧拉路径)
  7. 动态规划(入门/背包/状态机/划分/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与一般树(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

❌