O(n) 做法(Python/Java/C++/C/Go/JS/Rust)
本题和 3208. 交替组 II 是一样的,令 $k=3$ 即可,请看 我的题解。
本题和 3208. 交替组 II 是一样的,令 $k=3$ 即可,请看 我的题解。
定义 $g[i][j]$ 表示节点 $i$ 到节点 $j$ 这条边的边权。如果没有 $i$ 到 $j$ 的边,则 $g[i][j]=\infty$。
定义 $\textit{dis}[i]$ 表示起点 $k$ 到节点 $i$ 的最短路长度,一开始 $\textit{dis}[k]=0$,其余 $\textit{dis}[i]=\infty$ 表示尚未计算出。
我们的目标是计算出最终的 $\textit{dis}$ 数组。
对于本题,在计算最短路时,如果发现当前找到的最小最短路等于 $\infty$,说明有节点无法到达,可以提前结束算法,返回 $-1$。
如果所有节点都可以到达,返回 $\max(\textit{dis})$。
代码实现时,节点编号改成从 $0$ 开始。
###py
class Solution:
def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
g = [[inf for _ in range(n)] for _ in range(n)] # 邻接矩阵
for x, y, d in times:
g[x - 1][y - 1] = d
dis = [inf] * n
ans = dis[k - 1] = 0
done = [False] * n
while True:
x = -1
for i, ok in enumerate(done):
if not ok and (x < 0 or dis[i] < dis[x]):
x = i
if x < 0:
return ans # 最后一次算出的最短路就是最大的
if dis[x] == inf: # 有节点无法到达
return -1
ans = dis[x] # 求出的最短路会越来越大
done[x] = True # 最短路长度已确定(无法变得更小)
for y, d in enumerate(g[x]):
# 更新 x 的邻居的最短路
dis[y] = min(dis[y], dis[x] + d)
###java
class Solution {
public int networkDelayTime(int[][] times, int n, int k) {
final int INF = Integer.MAX_VALUE / 2; // 防止加法溢出
int[][] g = new int[n][n]; // 邻接矩阵
for (int[] row : g) {
Arrays.fill(row, INF);
}
for (int[] t : times) {
g[t[0] - 1][t[1] - 1] = t[2];
}
int maxDis = 0;
int[] dis = new int[n];
Arrays.fill(dis, INF);
dis[k - 1] = 0;
boolean[] done = new boolean[n];
while (true) {
int x = -1;
for (int i = 0; i < n; i++) {
if (!done[i] && (x < 0 || dis[i] < dis[x])) {
x = i;
}
}
if (x < 0) {
return maxDis; // 最后一次算出的最短路就是最大的
}
if (dis[x] == INF) { // 有节点无法到达
return -1;
}
maxDis = dis[x]; // 求出的最短路会越来越大
done[x] = true; // 最短路长度已确定(无法变得更小)
for (int y = 0; y < n; y++) {
// 更新 x 的邻居的最短路
dis[y] = Math.min(dis[y], dis[x] + g[x][y]);
}
}
}
}
###cpp
class Solution {
public:
int networkDelayTime(vector<vector<int>>& times, int n, int k) {
vector<vector<int>> g(n, vector<int>(n, INT_MAX / 2)); // 邻接矩阵
for (auto& t : times) {
g[t[0] - 1][t[1] - 1] = t[2];
}
vector<int> dis(n, INT_MAX / 2), done(n);
dis[k - 1] = 0;
while (true) {
int x = -1;
for (int i = 0; i < n; i++) {
if (!done[i] && (x < 0 || dis[i] < dis[x])) {
x = i;
}
}
if (x < 0) {
return ranges::max(dis);
}
if (dis[x] == INT_MAX / 2) { // 有节点无法到达
return -1;
}
done[x] = true; // 最短路长度已确定(无法变得更小)
for (int y = 0; y < n; y++) {
// 更新 x 的邻居的最短路
dis[y] = min(dis[y], dis[x] + g[x][y]);
}
}
}
};
###go
func networkDelayTime(times [][]int, n, k int) int {
const inf = math.MaxInt / 2 // 防止加法溢出
g := make([][]int, n) // 邻接矩阵
for i := range g {
g[i] = make([]int, n)
for j := range g[i] {
g[i][j] = inf
}
}
for _, t := range times {
g[t[0]-1][t[1]-1] = t[2]
}
dis := make([]int, n)
for i := range dis {
dis[i] = inf
}
dis[k-1] = 0
done := make([]bool, n)
for {
x := -1
for i, ok := range done {
if !ok && (x < 0 || dis[i] < dis[x]) {
x = i
}
}
if x < 0 {
return slices.Max(dis)
}
if dis[x] == inf { // 有节点无法到达
return -1
}
done[x] = true // 最短路长度已确定(无法变得更小)
for y, d := range g[x] {
// 更新 x 的邻居的最短路
dis[y] = min(dis[y], dis[x]+d)
}
}
}
###js
var networkDelayTime = function(times, n, k) {
const g = Array.from({length: n}, () => Array(n).fill(Infinity)); // 邻接矩阵
for (const [x, y, d] of times) {
g[x - 1][y - 1] = d;
}
const dis = Array(n).fill(Infinity);
dis[k - 1] = 0;
const done = Array(n).fill(false);
while (true) {
let x = -1;
for (let i = 0; i < n; i++) {
if (!done[i] && (x < 0 || dis[i] < dis[x])) {
x = i;
}
}
if (x < 0) {
return Math.max(...dis);
}
if (dis[x] === Infinity) { // 有节点无法到达
return -1;
}
done[x] = true; // 最短路长度已确定(无法变得更小)
for (let y = 0; y < n; y++) {
// 更新 x 的邻居的最短路
dis[y] = Math.min(dis[y], dis[x] + g[x][y]);
}
}
};
###rust
impl Solution {
pub fn network_delay_time(times: Vec<Vec<i32>>, n: i32, k: i32) -> i32 {
const INF: i32 = i32::MAX / 2; // 防止加法溢出
let n = n as usize;
let mut g = vec![vec![INF; n]; n]; // 邻接矩阵
for t in × {
g[t[0] as usize - 1][t[1] as usize - 1] = t[2];
}
let mut dis = vec![INF; n];
dis[k as usize - 1] = 0;
let mut done = vec![false; n];
loop {
let mut x = n;
for (i, &ok) in done.iter().enumerate() {
if !ok && (x == n || dis[i] < dis[x]) {
x = i;
}
}
if x == n {
return *dis.iter().max().unwrap();
}
if dis[x] == INF { // 有节点无法到达
return -1;
}
done[x] = true; // 最短路长度已确定(无法变得更小)
for (y, &d) in g[x].iter().enumerate() {
// 更新 x 的邻居的最短路
dis[y] = dis[y].min(dis[x] + d);
}
}
}
}
寻找最小值的过程可以用一个最小堆来快速完成:
注意,如果一个节点 $x$ 在出堆前,其最短路长度 $\textit{dis}[x]$ 被多次更新,那么堆中会有多个重复的 $x$,并且包含 $x$ 的二元组中的 $\textit{dis}[x]$ 是互不相同的(因为我们只在找到更小的最短路时才会把二元组入堆)。
所以写法一中的 $\textit{done}$ 数组可以省去,取而代之的是用出堆的最短路值(记作 $\textit{dx}$)与当前的 $\textit{dis}[x]$ 比较,如果 $\textit{dx} > \textit{dis}[x]$ 说明 $x$ 之前出堆过,我们已经更新了 $x$ 的邻居的最短路,所以这次就不用更新了,继续外层循环。
问:为什么代码要判断 dx > dis[x]
?
答:对于同一个 $x$,例如先入堆一个比较大的 $\textit{dis}[x]=10$,后面又把 $\textit{dis}[x]$ 更新成 $5$,之后这个 $5$ 会先出堆,然后再把 $10$ 出堆。$10$ 出堆时候是没有必要去更新周围邻居的最短路的,因为 $5$ 出堆之后,就已经把邻居的最短路更新过了,用 $10$ 是无法把邻居的最短路变得更短的,所以直接 continue
。
###py
class Solution:
def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
g = [[] for _ in range(n)] # 邻接表
for x, y, d in times:
g[x - 1].append((y - 1, d))
dis = [inf] * n
dis[k - 1] = 0
h = [(0, k - 1)]
while h:
dx, x = heappop(h)
if dx > dis[x]: # x 之前出堆过
continue
for y, d in g[x]:
new_dis = dx + d
if new_dis < dis[y]:
dis[y] = new_dis # 更新 x 的邻居的最短路
heappush(h, (new_dis, y))
mx = max(dis)
return mx if mx < inf else -1
###java
class Solution {
public int networkDelayTime(int[][] times, int n, int k) {
List<int[]>[] g = new ArrayList[n]; // 邻接表
Arrays.setAll(g, i -> new ArrayList<>());
for (int[] t : times) {
g[t[0] - 1].add(new int[]{t[1] - 1, t[2]});
}
int maxDis = 0;
int left = n; // 未确定最短路的节点个数
int[] dis = new int[n];
Arrays.fill(dis, Integer.MAX_VALUE);
dis[k - 1] = 0;
PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> (a[0] - b[0]));
pq.offer(new int[]{0, k - 1});
while (!pq.isEmpty()) {
int[] p = pq.poll();
int dx = p[0];
int x = p[1];
if (dx > dis[x]) { // x 之前出堆过
continue;
}
maxDis = dx; // 求出的最短路会越来越大
left--;
for (int[] e : g[x]) {
int y = e[0];
int newDis = dx + e[1];
if (newDis < dis[y]) {
dis[y] = newDis; // 更新 x 的邻居的最短路
pq.offer(new int[]{newDis, y});
}
}
}
return left == 0 ? maxDis : -1;
}
}
###cpp
class Solution {
public:
int networkDelayTime(vector<vector<int>>& times, int n, int k) {
vector<vector<pair<int, int>>> g(n); // 邻接表
for (auto& t : times) {
g[t[0] - 1].emplace_back(t[1] - 1, t[2]);
}
vector<int> dis(n, INT_MAX);
dis[k - 1] = 0;
priority_queue<pair<int, int>, vector<pair<int, int>>, greater<>> pq;
pq.emplace(0, k - 1);
while (!pq.empty()) {
auto [dx, x] = pq.top();
pq.pop();
if (dx > dis[x]) { // x 之前出堆过
continue;
}
for (auto &[y, d] : g[x]) {
int new_dis = dx + d;
if (new_dis < dis[y]) {
dis[y] = new_dis; // 更新 x 的邻居的最短路
pq.emplace(new_dis, y);
}
}
}
int mx = ranges::max(dis);
return mx < INT_MAX ? mx : -1;
}
};
###go
func networkDelayTime(times [][]int, n, k int) int {
type edge struct{ to, wt int }
g := make([][]edge, n) // 邻接表
for _, t := range times {
g[t[0]-1] = append(g[t[0]-1], edge{t[1] - 1, t[2]})
}
dis := make([]int, n)
for i := range dis {
dis[i] = math.MaxInt
}
dis[k-1] = 0
h := hp{{0, k - 1}}
for len(h) > 0 {
p := heap.Pop(&h).(pair)
dx := p.dis
x := p.x
if dx > dis[x] { // x 之前出堆过
continue
}
for _, e := range g[x] {
y := e.to
newDis := dx + e.wt
if newDis < dis[y] {
dis[y] = newDis // 更新 x 的邻居的最短路
heap.Push(&h, pair{newDis, y})
}
}
}
mx := slices.Max(dis)
if mx < math.MaxInt {
return mx
}
return -1
}
type pair struct{ dis, x int }
type hp []pair
func (h hp) Len() int { return len(h) }
func (h hp) Less(i, j int) bool { return h[i].dis < h[j].dis }
func (h hp) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *hp) Push(v any) { *h = append(*h, v.(pair)) }
func (h *hp) Pop() (v any) { a := *h; *h, v = a[:len(a)-1], a[len(a)-1]; return }
###js
var networkDelayTime = function(times, n, k) {
const g = Array.from({length: n}, () => []); // 邻接表
for (const [x, y, d] of times) {
g[x - 1].push([y - 1, d]);
}
const dis = Array(n).fill(Infinity);
dis[k - 1] = 0;
const pq = new MinPriorityQueue({priority: (p) => p[0]});
pq.enqueue([0, k - 1]);
while (!pq.isEmpty()) {
const [dx, x] = pq.dequeue().element;
if (dx > dis[x]) { // x 之前出堆过
continue;
}
for (const [y, d] of g[x]) {
const newDis = dx + d;
if (newDis < dis[y]) {
dis[y] = newDis; // 更新 x 的邻居的最短路
pq.enqueue([newDis, y]);
}
}
}
const mx = Math.max(...dis);
return mx < Infinity ? mx : -1;
};
###rust
use std::collections::BinaryHeap;
impl Solution {
pub fn network_delay_time(times: Vec<Vec<i32>>, n: i32, k: i32) -> i32 {
let n = n as usize;
let k = k as usize - 1;
let mut g = vec![vec![]; n]; // 邻接表
for t in × {
g[t[0] as usize - 1].push((t[1] as usize - 1, t[2]));
}
let mut dis = vec![i32::MAX; n];
dis[k] = 0;
let mut h = BinaryHeap::new();
h.push((0, k));
while let Some((dx, x)) = h.pop() {
if -dx > dis[x] { // x 之前出堆过
continue;
}
for &(y, d) in &g[x] {
let new_dis = -dx + d;
if new_dis < dis[y] {
dis[y] = new_dis; // 更新 x 的邻居的最短路
h.push((-new_dis, y));
}
}
}
let mx = *dis.iter().max().unwrap();
if mx < i32::MAX { mx } else { -1 }
}
}
更多相似题目,见下面图论题单中的「单源最短路:Dijkstra」。
欢迎关注 B站@灵茶山艾府
把「每个列表至少有一个数包含在其中」的区间叫做合法区间。
先求出最左边的合法区间,然后求出第二个合法区间,第三个合法区间,依此类推。
比如示例 1,最左边的合法区间是 $[0,5]$。
枚举所有合法区间的左端点,或者枚举所有合法区间的右端点。其中第一个最短的合法区间就是答案。
在示例 1 中,有三个列表:
我们来计算最左边的合法区间,第二个合法区间,第三个合法区间,……
也就是左端点为 $0$ 的合法区间,左端点为 $4$ 的合法区间,左端点为 $5$ 的合法区间。
求出左端点对应的右端点,就知道了区间的长度,其中第一个最短的区间就是答案。
左端点为 $0$ 的合法区间,右端点是这三个列表的第一个元素的最大值,即 $5$。
接下来,去掉 $0$,列表 $[0,9,12,20]$ 变成 $[9,12,20]$,问题变成如下三个列表:
这三个列表的最左边的合法区间是什么?
左端点是这三个列表的第一个元素的最小值 $4$,右端点是这三个列表的第一个元素的最大值 $9$,所以合法区间为 $[4,9]$。
接下来,去掉 $4$,列表 $[4,10,15,24,26]$ 变成 $[10,15,24,26]$,重复上述过程。
在上述过程中,需要快速地求出合法区间的左端点和右端点:
注:实际没有去掉元素,而是用下标表示元素在列表中的位置。
###py
class Solution:
def smallestRange(self, nums: List[List[int]]) -> List[int]:
# 把每个列表的第一个元素入堆
h = [(arr[0], i, 0) for i, arr in enumerate(nums)]
heapify(h)
ans_l = h[0][0] # 第一个合法区间的左端点
ans_r = r = max(arr[0] for arr in nums) # 第一个合法区间的右端点
while h[0][2] + 1 < len(nums[h[0][1]]): # 堆顶列表有下一个元素
_, i, j = h[0]
x = nums[i][j + 1] # 堆顶列表的下一个元素
heapreplace(h, (x, i, j + 1)) # 替换堆顶
r = max(r, x) # 更新合法区间的右端点
l = h[0][0] # 当前合法区间的左端点
if r - l < ans_r - ans_l:
ans_l, ans_r = l, r
return [ans_l, ans_r]
###java
class Solution {
public int[] smallestRange(List<List<Integer>> nums) {
PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> a[0] - b[0]);
int r = Integer.MIN_VALUE;
for (int i = 0; i < nums.size(); i++) {
// 把每个列表的第一个元素入堆
int x = nums.get(i).get(0);
pq.offer(new int[]{x, i, 0});
r = Math.max(r, x);
}
int ansL = pq.peek()[0]; // 第一个合法区间的左端点
int ansR = r; // 第一个合法区间的右端点
while (pq.peek()[2] + 1 < nums.get(pq.peek()[1]).size()) { // 堆顶列表有下一个元素
int[] top = pq.poll();
top[0] = nums.get(top[1]).get(++top[2]); // 堆顶列表的下一个元素
r = Math.max(r, top[0]); // 更新合法区间的右端点
pq.offer(top); // 入堆(复用 int[],提高效率)
int l = pq.peek()[0]; // 当前合法区间的左端点
if (r - l < ansR - ansL) {
ansL = l;
ansR = r;
}
}
return new int[]{ansL, ansR};
}
}
###cpp
class Solution {
public:
vector<int> smallestRange(vector<vector<int>>& nums) {
priority_queue<tuple<int, int, int>, vector<tuple<int, int, int>>, greater<>> pq;
int r = INT_MIN;
for (int i = 0; i < nums.size(); i++) {
pq.emplace(nums[i][0], i, 0); // 把每个列表的第一个元素入堆
r = max(r, nums[i][0]);
}
int ans_l = get<0>(pq.top()); // 第一个合法区间的左端点
int ans_r = r; // 第一个合法区间的右端点
while (true) {
auto [_, i, j] = pq.top();
if (j + 1 == nums[i].size()) { // 堆顶列表没有下一个元素
break;
}
pq.pop();
int x = nums[i][j + 1]; // 堆顶列表的下一个元素
pq.emplace(x, i, j + 1); // 入堆
r = max(r, x); // 更新合法区间的右端点
int l = get<0>(pq.top()); // 当前合法区间的左端点
if (r - l < ans_r - ans_l) {
ans_l = l;
ans_r = r;
}
}
return {ans_l, ans_r};
}
};
###go
func smallestRange(nums [][]int) []int {
h := make(hp, len(nums))
r := math.MinInt
for i, arr := range nums {
h[i] = tuple{arr[0], i, 0} // 把每个列表的第一个元素入堆
r = max(r, arr[0])
}
heap.Init(&h)
ansL, ansR := h[0].x, r // 第一个合法区间的左右端点
for h[0].j+1 < len(nums[h[0].i]) { // 堆顶列表有下一个元素
x := nums[h[0].i][h[0].j+1] // 堆顶列表的下一个元素
r = max(r, x) // 更新合法区间的右端点
h[0].x = x // 替换堆顶
h[0].j++
heap.Fix(&h, 0)
l := h[0].x // 当前合法区间的左端点
if r-l < ansR-ansL {
ansL, ansR = l, r
}
}
return []int{ansL, ansR}
}
type tuple struct{ x, i, j int }
type hp []tuple
func (h hp) Len() int { return len(h) }
func (h hp) Less(i, j int) bool { return h[i].x < h[j].x }
func (h hp) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (hp) Push(any) {} // 没用到,可以不写
func (hp) Pop() (_ any) { return }
###js
var smallestRange = function(nums) {
const pq = new MinPriorityQueue({priority: a => a[0]});
let r = -Infinity;
for (let i = 0; i < nums.length; i++) {
pq.enqueue([nums[i][0], i, 0]); // 每个列表的第一个元素入堆
r = Math.max(r, nums[i][0]);
}
let ansL = pq.front().element[0]; // 第一个合法区间的左端点
let ansR = r; // 第一个合法区间的右端点
while (true) {
const [_, i, j] = pq.dequeue().element;
if (j + 1 === nums[i].length) { // 堆顶列表没有下一个元素
break;
}
const x = nums[i][j + 1]; // 堆顶列表的下一个元素
pq.enqueue([x, i, j + 1]); // 入堆
r = Math.max(r, x); // 更新合法区间的右端点
const l = pq.front().element[0]; // 当前合法区间的左端点
if (r - l < ansR - ansL) {
ansL = l;
ansR = r;
}
}
return [ansL, ansR];
};
###rust
use std::collections::BinaryHeap;
impl Solution {
pub fn smallest_range(nums: Vec<Vec<i32>>) -> Vec<i32> {
let mut h = BinaryHeap::with_capacity(nums.len()); // 预分配空间
let mut r = i32::MIN;
for (i, arr) in nums.iter().enumerate() {
// 把每个列表的第一个元素入堆
h.push((-arr[0], i, 0)); // 取反变成最小堆
r = r.max(arr[0]);
}
let mut ans_l = -h.peek().unwrap().0; // 第一个合法区间的左端点
let mut ans_r = r; // 第一个合法区间的右端点
while h.peek().unwrap().2 + 1 < nums[h.peek().unwrap().1].len() { // 堆顶列表有下一个元素
let (_, i, j) = h.pop().unwrap();
let x = nums[i][j + 1]; // 堆顶列表的下一个元素
h.push((-x, i, j + 1)); // 入堆
r = r.max(x); // 更新合法区间的右端点
let l = -h.peek().unwrap().0; // 当前合法区间的左端点
if r - l < ans_r - ans_l {
ans_l = l;
ans_r = r;
}
}
vec![ans_l, ans_r]
}
}
对于示例 1 的这三个列表:
把所有元素都合在一起排序,可以得到如下结果:
$$
\begin{array}{r|}
元素值 & 0 & 4 & 5 & 9 & 10 & 12 & 15 & 18 & 20 & 22 & 24 & 26 & 30 \
所属列表编号 & 1 & 0 & 2 & 1 & 0 & 1 & 0 & 2 & 1 & 2 & 0 & 0 & 2 \
\end{array}
$$
把上表视作一个由(元素值,所属列表编号)组成的数组,即
$$
\textit{pairs} = [(0, 1), (4, 0), (5, 2), \ldots, (24, 0), (26, 0), (30, 2)]
$$
合法区间等价于 $\textit{pairs}$ 的一个连续子数组,满足列表编号 $0,1,2,\ldots,k-1$ 都在这个子数组中。
由于子数组越长,越能包含 $0,1,2,\ldots,k-1$ 所有编号,有单调性,可以用滑动窗口解决。如果你不了解滑动窗口,可以看视频【基础算法精讲 03】。
注:方法一相当于枚举合法区间的左端点,而方法二相当于枚举合法区间的右端点。
###py
class Solution:
def smallestRange(self, nums: List[List[int]]) -> List[int]:
pairs = sorted((x, i) for (i, arr) in enumerate(nums) for x in arr)
ans_l, ans_r = -inf, inf
empty = len(nums)
cnt = [0] * empty
left = 0
for r, i in pairs:
if cnt[i] == 0: # 包含 nums[i] 的数字
empty -= 1
cnt[i] += 1
while empty == 0: # 每个列表都至少包含一个数
l, i = pairs[left]
if r - l < ans_r - ans_l:
ans_l, ans_r = l, r
cnt[i] -= 1
if cnt[i] == 0: # 不包含 nums[i] 的数字
empty += 1
left += 1
return [ans_l, ans_r]
###java
class Solution {
public int[] smallestRange(List<List<Integer>> nums) {
int sumLen = 0;
for (List<Integer> list : nums) {
sumLen += list.size();
}
int[][] pairs = new int[sumLen][2];
int pi = 0;
for (int i = 0; i < nums.size(); i++) {
for (int x : nums.get(i)) {
pairs[pi][0] = x;
pairs[pi++][1] = i;
}
}
Arrays.sort(pairs, (a, b) -> a[0] - b[0]);
int ansL = pairs[0][0];
int ansR = pairs[sumLen - 1][0];
int empty = nums.size();
int[] cnt = new int[empty];
int left = 0;
for (int[] p : pairs) {
int r = p[0];
int i = p[1];
if (cnt[i] == 0) { // 包含 nums[i] 的数字
empty--;
}
cnt[i]++;
while (empty == 0) { // 每个列表都至少包含一个数
int l = pairs[left][0];
if (r - l < ansR - ansL) {
ansL = l;
ansR = r;
}
i = pairs[left][1];
cnt[i]--;
if (cnt[i] == 0) { // 不包含 nums[i] 的数字
empty++;
}
left++;
}
}
return new int[]{ansL, ansR};
}
}
###cpp
class Solution {
public:
vector<int> smallestRange(vector<vector<int>>& nums) {
vector<pair<int, int>> pairs;
for (int i = 0; i < nums.size(); i++) {
for (int x : nums[i]) {
pairs.emplace_back(x, i);
}
}
// 看上去 std::sort 比 ranges::sort 更快
sort(pairs.begin(), pairs.end());
int ans_l = pairs[0].first;
int ans_r = pairs.back().first;
int empty = nums.size();
vector<int> cnt(empty);
int left = 0;
for (auto [r, i] : pairs) {
if (cnt[i] == 0) { // 包含 nums[i] 的数字
empty--;
}
cnt[i]++;
while (empty == 0) { // 每个列表都至少包含一个数
auto [l, i] = pairs[left];
if (r - l < ans_r - ans_l) {
ans_l = l;
ans_r = r;
}
cnt[i]--;
if (cnt[i] == 0) { // 不包含 nums[i] 的数字
empty++;
}
left++;
}
}
return {ans_l, ans_r};
}
};
###go
func smallestRange(nums [][]int) []int {
type pair struct{ x, i int }
pairs := []pair{}
for i, arr := range nums {
for _, x := range arr {
pairs = append(pairs, pair{x, i})
}
}
slices.SortFunc(pairs, func(a, b pair) int { return a.x - b.x })
ansL, ansR := pairs[0].x, pairs[len(pairs)-1].x
empty := len(nums)
cnt := make([]int, empty)
left := 0
for _, p := range pairs {
r, i := p.x, p.i
if cnt[i] == 0 { // 包含 nums[i] 的数字
empty--
}
cnt[i]++
for empty == 0 { // 每个列表都至少包含一个数
l, i := pairs[left].x, pairs[left].i
if r-l < ansR-ansL {
ansL, ansR = l, r
}
cnt[i]--
if cnt[i] == 0 {
// 不包含 nums[i] 的数字
empty++
}
left++
}
}
return []int{ansL, ansR}
}
###js
var smallestRange = function(nums) {
const pairs = [];
for (let i = 0; i < nums.length; i++) {
for (const x of nums[i]) {
pairs.push([x, i]);
}
}
pairs.sort((a, b) => a[0] - b[0]);
let ansL = -Infinity, ansR = Infinity;
let empty = nums.length;
const cnt = Array(empty).fill(0);
let left = 0;
for (const [r, i] of pairs) {
if (cnt[i] === 0) { // 包含 nums[i] 的数字
empty--;
}
cnt[i]++;
while (empty === 0) { // 每个列表都至少包含一个数
const [l, i] = pairs[left];
if (r - l < ansR - ansL) {
ansL = l;
ansR = r;
}
cnt[i]--;
if (cnt[i] === 0) { // 不包含 nums[i] 的数字
empty++;
}
left++;
}
}
return [ansL, ansR];
};
###rust
impl Solution {
pub fn smallest_range(nums: Vec<Vec<i32>>) -> Vec<i32> {
let mut pairs = vec![];
for (i, arr) in nums.iter().enumerate() {
for &x in arr {
pairs.push((x, i));
}
}
pairs.sort_unstable_by(|a, b| a.0.cmp(&b.0));
let mut ans_l = pairs[0].0;
let mut ans_r = pairs[pairs.len() - 1].0;
let mut empty = nums.len();
let mut cnt = vec![0; empty];
let mut left = 0;
for &(r, i) in &pairs {
if cnt[i] == 0 { // 包含 nums[i] 的数字
empty -= 1;
}
cnt[i] += 1;
while empty == 0 { // 每个列表都至少包含一个数
let (l, i) = pairs[left];
if r - l < ans_r - ans_l {
ans_l = l;
ans_r = r;
}
cnt[i] -= 1;
if cnt[i] == 0 { // 不包含 nums[i] 的数字
empty += 1;
}
left += 1;
}
}
vec![ans_l, ans_r]
}
}
更多相似题目,见下面数据结构题单中的「五、堆(优先队列)」,以及滑动窗口题单中的「§2.2 求最短/最小」。
欢迎关注 B站@灵茶山艾府