阅读视图

发现新文章,点击刷新页面。

统计直线 + 去掉重复统计的平行四边形,附思考题(Python/Java/C++/Go)

核心思路

  1. 本题 $n\le 500$,我们可以 $\mathcal{O}(n^2)$ 枚举所有点对组成的直线,计算直线的斜率和截距。
  2. 把斜率相同的直线放在同一组,可以从中选择一对平行边,作为梯形的顶边和底边。⚠注意:不能选两条重合的边,所以还要按照截距分组,同一组内的边不能选。
  3. 第二步把平行四边形重复统计了一次,所以还要减去任意不共线四点组成的平行四边形的个数。

具体思路

1) 计算直线的斜率和截距

对于两个点 $(x,y)$ 和 $(x_2,y_2)$,设 $\textit{dx} = x - x_2$,$\textit{dy} = y - y_2$。

经过这两个点的斜率为

$$
k =
\begin{cases}
\dfrac{\textit{dy}}{\textit{dx}}, & \textit{dx}\ne 0 \
\infty, & \textit{dx} = 0 \
\end{cases}
$$

当 $\textit{dx} \ne 0$ 时,设直线为 $Y = k\cdot X + b$,把 $(x,y)$ 代入,解得截距

$$
b = y - k\cdot x = \dfrac{y\cdot \textit{dx}-x\cdot \textit{dy}}{\textit{dx}}
$$

当 $\textit{dx} = 0$ 时,直线平行于 $y$ 轴,人为规定 $b=x$,用来区分不同的平行线。

2) 选择一对平行边的方案数

把斜率相同的直线放在同一组,可以从中选择一对平行线,作为梯形的顶边和底边。

注意:不能选两条共线的线段,所以斜率相同的组内,还要再按照截距分组,相同斜率和截距的边不能同时选。

用哈希表套哈希表统计。

统计完后,对于每一组,用「枚举右,维护左」的思想(见周赛第二题 3623. 统计梯形的数目 I),计算选一对平行边的方案数。本题由于哈希表统计的就是线段个数,所以不需要计算 $\dfrac{c(c-1)}{2}$。

3) 平行四边形的个数

第二步把平行四边形重复统计了一次,所以还要减去任意不共线四点组成的平行四边形的个数。

怎么计算平行四边形的个数?

对于平行四边形,其两条对角线的中点是重合的。利用这一性质,按照对角线的中点分组统计。

具体地,两个点 $(x,y)$ 和 $(x_2,y_2)$ 的中点为

$$
\left(\dfrac{x+x_2}{2}, \dfrac{y+y_2}{2}\right)
$$

为避免浮点数,可以把横纵坐标都乘以 $2$(这不影响分组),即

$$
(x+x_2, y+y_2)
$$

用其作为哈希表的 key。

同样地,我们不能选两条共线的线段,所以中点相同的组内,还要再按照斜率分组,相同斜率的边不能同时选。所以同样地,用哈希表套哈希表统计。

统计完后,对于每一组,用「枚举右,维护左」的思想(见周赛第二题),计算选一对中点相同的线段的方案数。

注意计算梯形个数我们用的是顶边和底边,计算平行四边形个数我们用的是对角线。

答疑

:什么情况下用浮点数是错的?

:取两个接近 $1$ 但不相同的分数 $\dfrac{a}{a+1}$ 和 $\dfrac{a-1}{a}$,根据 IEEE 754,在使用双精度浮点数的情况下,如果这两个数的绝对差 $\dfrac{1}{a(a+1)}$ 比 $2^{-52}$ 还小,那么计算机可能会把这两个数舍入到同一个附近的浮点数上。所以当 $a$ 达到 $2^{26}\approx 6.7\cdot 10^7$ 的时候,用浮点数就不一定对了。本题数据范围只有 $2\cdot 10^3$,可以放心地使用浮点数除法。

具体请看 视频讲解,欢迎点赞关注~

优化前

###py

class Solution:
    def countTrapezoids(self, points: List[List[int]]) -> int:
        cnt = defaultdict(lambda: defaultdict(int))  # 斜率 -> 截距 -> 个数
        cnt2 = defaultdict(lambda: defaultdict(int))  # 中点 -> 斜率 -> 个数

        for i, (x, y) in enumerate(points):
            for x2, y2 in points[:i]:
                dy = y - y2
                dx = x - x2
                k = dy / dx if dx else inf
                b = (y * dx - x * dy) / dx if dx else x
                cnt[k][b] += 1  # 按照斜率和截距分组
                cnt2[(x + x2, y + y2)][k] += 1  # 按照中点和斜率分组

        ans = 0
        for m in cnt.values():
            s = 0
            for c in m.values():
                ans += s * c
                s += c

        for m in cnt2.values():
            s = 0
            for c in m.values():
                ans -= s * c  # 平行四边形会统计两次,减去多统计的一次
                s += c

        return ans

###java

class Solution {
    public int countTrapezoids(int[][] points) {
        Map<Double, Map<Double, Integer>> cnt = new HashMap<>(); // 斜率 -> 截距 -> 个数
        Map<Integer, Map<Double, Integer>> cnt2 = new HashMap<>(); // 中点 -> 斜率 -> 个数

        int n = points.length;
        for (int i = 0; i < n; i++) {
            int x = points[i][0], y = points[i][1];
            for (int j = 0; j < i; j++) {
                int x2 = points[j][0], y2 = points[j][1];
                int dy = y - y2;
                int dx = x - x2;
                double k = dx != 0 ? 1.0 * dy / dx : Double.MAX_VALUE;
                double b = dx != 0 ? 1.0 * (y * dx - x * dy) / dx : x;

                // 归一化 -0.0 为 0.0
                if (k == -0.0) {
                    k = 0.0;
                }
                if (b == -0.0) {
                    b = 0.0;
                }

                // 按照斜率和截距分组 cnt[k][b]++
                cnt.computeIfAbsent(k, _ -> new HashMap<>()).merge(b, 1, Integer::sum);

                int mid = (x + x2 + 2000) * 10000 + (y + y2 + 2000); // 把二维坐标压缩成一个 int
                // 按照中点和斜率分组 cnt2[mid][k]++
                cnt2.computeIfAbsent(mid, _ -> new HashMap<>()).merge(k, 1, Integer::sum);
            }
        }

        int ans = 0;
        for (Map<Double, Integer> m : cnt.values()) {
            int s = 0;
            for (int c : m.values()) {
                ans += s * c;
                s += c;
            }
        }

        for (Map<Double, Integer> m : cnt2.values()) {
            int s = 0;
            for (int c : m.values()) {
                ans -= s * c; // 平行四边形会统计两次,减去多统计的一次
                s += c;
            }
        }

        return ans;
    }
}

###cpp

class Solution {
public:
    int countTrapezoids(vector<vector<int>>& points) {
        // 经测试,哈希表套 map 比哈希表套哈希表更快(分组后,每一组的数据量比较小,在小数据下 map 比哈希表快)
        unordered_map<double, map<double, int>> cnt; // 斜率 -> 截距 -> 个数
        unordered_map<int, map<double, int>> cnt2; // 中点 -> 斜率 -> 个数

        int n = points.size();
        for (int i = 0; i < n; i++) {
            int x = points[i][0], y = points[i][1];
            for (int j = 0; j < i; j++) {
                int x2 = points[j][0], y2 = points[j][1];
                int dy = y - y2;
                int dx = x - x2;
                double k = dx ? 1.0 * dy / dx : DBL_MAX;
                double b = dx ? 1.0 * (y * dx - x * dy) / dx : x;
                cnt[k][b]++; // 按照斜率和截距分组
                int mid = (x + x2 + 2000) << 16 | (y + y2 + 2000); // 把二维坐标压缩成一个 int
                cnt2[mid][k]++; // 按照中点和斜率分组
            }
        }

        int ans = 0;
        for (auto& [_, m] : cnt) {
            int s = 0;
            for (auto& [_, c] : m) {
                ans += s * c;
                s += c;
            }
        }

        for (auto& [_, m] : cnt2) {
            int s = 0;
            for (auto& [_, c] : m) {
                ans -= s * c; // 平行四边形会统计两次,减去多统计的一次
                s += c;
            }
        }

        return ans;
    }
};

###go

func countTrapezoids(points [][]int) (ans int) {
cnt := map[float64]map[float64]int{} // 斜率 -> 截距 -> 个数
type pair struct{ x, y int }
cnt2 := map[pair]map[float64]int{} // 中点 -> 斜率 -> 个数

for i, p := range points {
x, y := p[0], p[1]
for _, q := range points[:i] {
x2, y2 := q[0], q[1]
dy := y - y2
dx := x - x2
k := math.MaxFloat64
b := float64(x)
if dx != 0 {
k = float64(dy) / float64(dx)
b = float64(y*dx-dy*x) / float64(dx)
}

if _, ok := cnt[k]; !ok {
cnt[k] = map[float64]int{}
}
cnt[k][b]++ // 按照斜率和截距分组

mid := pair{x + x2, y + y2}
if _, ok := cnt2[mid]; !ok {
cnt2[mid] = map[float64]int{}
}
cnt2[mid][k]++ // 按照中点和斜率分组
}
}

for _, m := range cnt {
s := 0
for _, c := range m {
ans += s * c
s += c
}
}

for _, m := range cnt2 {
s := 0
for _, c := range m {
ans -= s * c // 平行四边形会统计两次,减去多统计的一次
s += c
}
}
return
}

优化

上面做法最坏会创建 $\mathcal{O}(n^2)$ 个哈希表。这其实就是导致代码变慢的根源。

减少创建的哈希表个数,就能省下大量时间。

在随机数据下,对于相同的斜率 $k$,大概率只有一条线段,无法组成梯形。这些数据根本就不需要创建哈希表!

所以,先不创建内部的哈希表,而是先把数据保存到更轻量的列表中。在计算答案的时候,再去创建哈希表。对于大小为 $1$ 的列表,我们直接跳过,不创建哈希表。

###py

class Solution:
    def countTrapezoids(self, points: List[List[int]]) -> int:
        groups = defaultdict(list)  # 斜率 -> [截距]
        groups2 = defaultdict(list)  # 中点 -> [斜率]

        for i, (x, y) in enumerate(points):
            for x2, y2 in points[:i]:
                dy = y - y2
                dx = x - x2
                k = dy / dx if dx else inf
                b = (y * dx - x * dy) / dx if dx else x
                groups[k].append(b)
                groups2[(x + x2, y + y2)].append(k)

        ans = 0
        for g in groups.values():
            if len(g) == 1:
                continue
            s = 0
            for c in Counter(g).values():
                ans += s * c
                s += c

        for g in groups2.values():
            if len(g) == 1:
                continue
            s = 0
            for c in Counter(g).values():
                ans -= s * c  # 平行四边形会统计两次,减去多统计的一次
                s += c

        return ans

###java

class Solution {
    public int countTrapezoids(int[][] points) {
        Map<Double, List<Double>> groups = new HashMap<>(); // 斜率 -> [截距]
        Map<Integer, List<Double>> groups2 = new HashMap<>(); // 中点 -> [斜率]

        int n = points.length;
        for (int i = 0; i < n; i++) {
            int x = points[i][0], y = points[i][1];
            for (int j = 0; j < i; j++) {
                int x2 = points[j][0], y2 = points[j][1];
                int dy = y - y2;
                int dx = x - x2;
                double k = dx != 0 ? 1.0 * dy / dx : Double.MAX_VALUE;
                if (k == -0.0) {
                    k = 0.0;
                }
                double b = dx != 0 ? 1.0 * (y * dx - x * dy) / dx : x;

                groups.computeIfAbsent(k, _ -> new ArrayList<>()).add(b);
                int mid = (x + x2 + 2000) * 10000 + (y + y2 + 2000); // 把二维坐标压缩成一个 int
                groups2.computeIfAbsent(mid, _ -> new ArrayList<>()).add(k);
            }
        }

        int ans = 0;
        Map<Double, Integer> cnt = new HashMap<>();
        for (List<Double> g : groups.values()) {
            if (g.size() == 1) {
                continue;
            }
            cnt.clear();
            for (double b : g) {
                if (b == -0.0) {
                    b = 0.0;
                }
                cnt.merge(b, 1, Integer::sum);
            }
            int s = 0;
            for (int c : cnt.values()) {
                ans += s * c;
                s += c;
            }
        }

        for (List<Double> g : groups2.values()) {
            if (g.size() == 1) {
                continue;
            }
            cnt.clear();
            for (double k : g) {
                cnt.merge(k, 1, Integer::sum);
            }
            int s = 0;
            for (int c : cnt.values()) {
                ans -= s * c; // 平行四边形会统计两次,减去多统计的一次
                s += c;
            }
        }

        return ans;
    }
}

###cpp

class Solution {
public:
    int countTrapezoids(vector<vector<int>>& points) {
        unordered_map<double, vector<double>> groups; // 斜率 -> [截距]
        unordered_map<int, vector<double>> groups2; // 中点 -> [斜率]

        int n = points.size();
        for (int i = 0; i < n; i++) {
            int x = points[i][0], y = points[i][1];
            for (int j = 0; j < i; j++) {
                int x2 = points[j][0], y2 = points[j][1];
                int dy = y - y2;
                int dx = x - x2;
                double k = dx ? 1.0 * dy / dx : DBL_MAX;
                double b = dx ? 1.0 * (y * dx - x * dy) / dx : x;
                groups[k].push_back(b);
                int mid = (x + x2 + 2000) << 16 | (y + y2 + 2000); // 把二维坐标压缩成一个 int
                groups2[mid].push_back(k);
            }
        }

        int ans = 0;
        for (auto& [_, g] : groups) {
            if (g.size() == 1) {
                continue;
            }
            // 对于本题的数据,map 比哈希表快
            map<double, int> cnt;
            for (double b : g) {
                cnt[b]++;
            }
            int s = 0;
            for (auto& [_, c] : cnt) {
                ans += s * c;
                s += c;
            }
        }

        for (auto& [_, g] : groups2) {
            if (g.size() == 1) {
                continue;
            }
            map<double, int> cnt;
            for (double k : g) {
                cnt[k]++;
            }
            int s = 0;
            for (auto& [_, c] : cnt) {
                ans -= s * c; // 平行四边形会统计两次,减去多统计的一次
                s += c;
            }
        }

        return ans;
    }
};

###go

func countTrapezoids(points [][]int) (ans int) {
groups := map[float64][]float64{} // 斜率 -> [截距]
type pair struct{ x, y int }
groups2 := map[pair][]float64{} // 中点 -> [斜率]

for i, p := range points {
x, y := p[0], p[1]
for _, q := range points[:i] {
x2, y2 := q[0], q[1]
dy := y - y2
dx := x - x2
k := math.MaxFloat64
b := float64(x)
if dx != 0 {
k = float64(dy) / float64(dx)
b = float64(y*dx-dy*x) / float64(dx)
}

groups[k] = append(groups[k], b)
mid := pair{x + x2, y + y2}
groups2[mid] = append(groups2[mid], k)
}
}

for _, g := range groups {
if len(g) == 1 {
continue
}
cnt := map[float64]int{}
for _, b := range g {
cnt[b]++
}
s := 0
for _, c := range cnt {
ans += s * c
s += c
}
}

for _, g := range groups2 {
if len(g) == 1 {
continue
}
cnt := map[float64]int{}
for _, k := range g {
cnt[k]++
}
s := 0
for _, c := range cnt {
ans -= s * c // 平行四边形会统计两次,减去多统计的一次
s += c
}
}
return
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2)$,其中 $n$ 是 $\textit{points}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n^2)$。

思考题

  1. 梯形改成正方形怎么做?
  2. 梯形改成菱形怎么做?
  3. 梯形改成矩形怎么做?
  4. 梯形改成等腰梯形怎么做?
  5. 梯形改成直角梯形怎么做?

欢迎在评论区分享你的思路/代码。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

枚举右,维护左(Python/Java/C++/Go)

首先,统计每一行的点的个数,如果这一行有 $c$ 个点,那么从 $c$ 个点中选 $2$ 个点,有 $\dfrac{c(c-1)}{2}$ 种选法,可以组成一条水平边,即梯形的顶边或底边。

枚举每一行,设这一行有 $k=\dfrac{c(c-1)}{2}$ 条水平边,那么另外一条边就是之前遍历过的行的边数 $s$。根据乘法原理,之前遍历过的行与这一行,一共可以组成

$$
s\cdot k
$$

个水平梯形,加入答案。

注意:另外一条边不能是其余所有行,这会导致重复计算。

在最坏情况下,有两行,每行 $\dfrac{n}{2}$ 个点,组成约 $\dfrac{n^2}{8}$ 条线段,答案约为 $\dfrac{n^4}{64} = 1.5625\times 10^{18}$,这不超过 $64$ 位整数最大值,所以无需在循环中取模。

本题视频讲解 详细介绍了本题的计算过程和注意事项,欢迎点赞关注~

###py

class Solution:
    def countTrapezoids(self, points: List[List[int]]) -> int:
        MOD = 1_000_000_007
        cnt = Counter(p[1] for p in points)  # 统计每一行(水平线)有多少个点
        ans = s = 0
        for c in cnt.values():
            k = c * (c - 1) // 2
            ans += s * k
            s += k
        return ans % MOD

###java

class Solution {
    private static final int MOD = 1_000_000_007;

    public int countTrapezoids(int[][] points) {
        Map<Integer, Integer> cnt = new HashMap<>(points.length, 1); // 预分配空间
        for (int[] p : points) {
            cnt.merge(p[1], 1, Integer::sum); // 统计每一行(水平线)有多少个点
        }

        long ans = 0, s = 0;
        for (int c : cnt.values()) {
            long k = (long) c * (c - 1) / 2;
            ans += s * k;
            s += k;
        }
        return (int) (ans % MOD);
    }
}

###cpp

class Solution {
public:
    int countTrapezoids(vector<vector<int>>& points) {
        const int MOD = 1'000'000'007;
        unordered_map<int, int> cnt;
        for (auto& p : points) {
            cnt[p[1]]++; // 统计每一行(水平线)有多少个点
        }

        long long ans = 0, s = 0;
        for (auto& [_, c] : cnt) {
            long long k = 1LL * c * (c - 1) / 2;
            ans += s * k;
            s += k;
        }
        return ans % MOD;
    }
};

###go

func countTrapezoids(points [][]int) (ans int) {
const mod = 1_000_000_007
cnt := make(map[int]int, len(points)) // 预分配空间
for _, p := range points {
cnt[p[1]]++ // 统计每一行(水平线)有多少个点
}

s := 0
for _, c := range cnt {
k := c * (c - 1) / 2
ans += s * k
s += k
}
return ans % mod
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{points}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

专题训练

见下面数据结构题单的「§0.1 枚举右,维护左」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

两种方法:二分答案 / 排序+贪心(Python/Java/C++/Go)

方法一:二分答案

如果可以让 $n$ 台电脑同时运行 $x$ 分钟,那么必然可以同时运行 $x-1,x-2,\ldots$ 分钟(要求更宽松);如果无法让 $n$ 台电脑同时运行 $x$ 分钟,那么必然无法同时运行 $x+1,x+2,\ldots$ 分钟(要求更苛刻)。

据此,可以二分猜答案。关于二分算法的原理,请看 二分查找 红蓝染色法【基础算法精讲 04】

假设可以让 $n$ 台电脑同时运行 $x$ 分钟,那么对于电量大于 $x$ 的电池,其只能被使用 $x$ 分钟,因此每个电池的使用时间至多为 $\min(\textit{batteries}[i], x)$。累加所有电池的使用时间,记作 $\textit{sum}$。那么要让 $n$ 台电脑同时运行 $x$ 分钟,必要条件是 $n\cdot x\le \textit{sum}$。

下面证明该条件也是充分的,即如果 $n\cdot x\le \textit{sum}$ 成立,那么一定存在一种安排电池的方式,可以让 $n$ 台电脑同时运行 $x$ 分钟。

构造方法如下:

对于电量 $\ge x$ 的电池,我们可以让其给一台电脑供电 $x$ 分钟。由于一个电池不能同时给多台电脑供电,因此该电池若给一台电脑供电 $x$ 分钟,那它就不能用于其他电脑了(因为电脑运行时间就是 $x$ 分钟)。我们可以将所有电量 $\ge x$ 的电池各给一台电脑供电。

对于其余电池,设其电量和为 $\textit{sum}'$,剩余 $n'$ 台电脑未被供电。我们可以随意选择剩下的电池,供给剩余的第一台电脑(用完一个电池就换下一个电池),多余的电池电量与剩下的电池一起供给剩余的第二台电脑,依此类推。注意由于这些电池的电量均小于 $x$,按照这种做法是不会出现同一个电池在同一时间供给多台电脑的(如果某个电池供给了两台电脑,可以将这个电池的供电时间划分到第一台电脑的末尾和第二台电脑的开头)。

由于 $\textit{sum}'=\textit{sum}-(n-n')\cdot x$,结合 $n\cdot x\le \textit{sum}$ 可以得到 $n'\cdot x\le \textit{sum}'$,按照上述供电方案(用完一个电池就换下一个电池),这 $n'$ 台电脑可以运行至少 $x$ 分钟。充分性得证。

细节

下面代码采用开区间二分,这仅仅是二分的一种写法,使用闭区间或者半闭半开区间都是可以的。

  • 开区间左端点初始值:$0$。不运行任何电脑,一定满足要求。
  • 开区间右端点初始值:平均值加一,即 $\left\lfloor\dfrac{\sum \textit{batteries}[i]}{n}\right\rfloor + 1$。一定无法满足要求。
class Solution:
    def maxRunTime(self, n: int, batteries: List[int]) -> int:
        l, r = 0, sum(batteries) // n + 1
        while l + 1 < r:
            x = (l + r) // 2
            if n * x <= sum(min(b, x) for b in batteries):
                l = x
            else:
                r = x
        return l
class Solution:
    def maxRunTime(self, n: int, batteries: List[int]) -> int:
        r = sum(batteries) // n
        # 二分找最小的不满足要求的 x+1,那么最大的满足要求的就是 x
        check = lambda x: n * (x + 1) > sum(min(b, x + 1) for b in batteries)
        return bisect_left(range(r), True, key=check)
class Solution {
    public long maxRunTime(int n, int[] batteries) {
        long tot = 0;
        for (int b : batteries) {
            tot += b;
        }

        long l = 0;
        long r = tot / n + 1;
        while (l + 1 < r) {
            long x = l + (r - l) / 2;
            long sum = 0;
            for (int b : batteries) {
                sum += Math.min(b, x);
            }
            if (n * x <= sum) {
                l = x;
            } else {
                r = x;
            }
        }
        return l;
    }
}
class Solution {
public:
    long long maxRunTime(int n, vector<int>& batteries) {
        long long tot = reduce(batteries.begin(), batteries.end(), 0LL);
        long long l = 0, r = tot / n + 1;
        while (l + 1 < r) {
            long long x = l + (r - l) / 2;
            long long sum = 0;
            for (long long b : batteries) {
                sum += min(b, x);
            }
            (n * x <= sum ? l : r) = x;
        }
        return l;
    }
};
func maxRunTime(n int, batteries []int) int64 {
tot := 0
for _, b := range batteries {
tot += b
}

return int64(sort.Search(tot/n, func(x int) bool {
x++
sum := 0
for _, b := range batteries {
sum += min(b, x)
}
return n*x > sum
}))
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(m\log (S/n))$,其中 $m$ 是 $\textit{batteries}$ 的长度,$S$ 是 $\textit{batteries}$ 的元素和。
  • 空间复杂度:$\mathcal{O}(1)$。

方法二:排序 + 贪心

受解法一的启发,我们可以得出如下贪心策略:

记电池电量和为 $\textit{sum}$,则理论上至多可以供电 $x=\Big\lfloor\dfrac{\textit{sum}}{n}\Big\rfloor$ 分钟。我们对电池电量从大到小排序,然后从电量最大的电池开始遍历:

  • 若该电池电量超过 $x$,则将其供给一台电脑,问题缩减为 $n-1$ 台电脑的子问题。

  • 若该电池电量不超过 $x$,则其余电池的电量均不超过 $x$,此时有

    $$
    n\cdot x=n\cdot\Big\lfloor\dfrac{\textit{sum}}{n}\Big\rfloor \le \textit{sum}
    $$

    根据解法一的结论,这些电池可以给 $n$ 台电脑供电 $x$ 分钟。

由于随着问题规模减小,$x$ 不会增加,因此若遍历到一个电量不超过 $x$ 的电池时,可以直接返回 $x$ 作为答案。

class Solution:
    def maxRunTime(self, n: int, batteries: List[int]) -> int:
        batteries.sort(reverse=True)
        s = sum(batteries)
        for b in batteries:
            if b <= s // n:
                return s // n
            s -= b
            n -= 1
class Solution {
    public long maxRunTime(int n, int[] batteries) {
        Arrays.sort(batteries);

        long sum = 0;
        for (int b : batteries) {
            sum += b;
        }

        for (int i = batteries.length - 1; ; i--) {
            if (batteries[i] <= sum / n) {
                return sum / n;
            }
            sum -= batteries[i];
            n--;
        }
    }
}
class Solution {
public:
    long long maxRunTime(int n, vector<int>& batteries) {
        ranges::sort(batteries, greater());
        long long sum = reduce(batteries.begin(), batteries.end(), 0LL);
        for (int i = 0; ; i++) {
            if (batteries[i] <= sum / n) {
                return sum / n;
            }
            sum -= batteries[i];
            n--;
        }
    }
};
func maxRunTime(n int, batteries []int) int64 {
slices.Sort(batteries)
sum := 0
for _, b := range batteries {
sum += b
}
for i := len(batteries) - 1; ; i-- {
if batteries[i] <= sum/n {
return int64(sum / n)
}
sum -= batteries[i]
n--
}
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(m\log m)$,其中 $m$ 是 $\textit{batteries}$ 的长度。瓶颈在排序上。
  • 空间复杂度:$\mathcal{O}(1)$。忽略排序的栈开销。

专题训练

  1. 二分题单的「§2.2 求最大」。
  2. 贪心题单的「§1.1 从最小/最大开始贪心」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/最短路/最小生成树/二分图/基环树/欧拉路径)
  7. 动态规划(入门/背包/状态机/划分/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

❌