普通视图

发现新文章,点击刷新页面。
今天 — 2026年1月5日首页

最大方阵和

2021年8月22日 17:11

方法一:贪心

提示 $1$

为了使得操作后方阵总和最大,我们需要使得负数元素的总和尽可能大

对于方阵中的两个负数元素,一定存在一系列的操作使得这两个负数元素均变为正数,且其余元素不变。

对于方阵中的一个正数元素和一个负数元素,一定存在一系列的操作使得这两个元素交换正负,且其余元素不变。

提示 $1$ 解释

第一部分是显然的。

对于第二部分,我们可以任意选择一条连接两个负数元素的有向路径,按顺序对路径上(除终点以外)的每个元素和它对应的下一个元素都执行一次操作。最终路径上除了两个端点以外的其他元素都被执行了两次操作,因此数值不变;两个端点元素都被执行了一次操作二变为正数。

由于方阵是网格,因此上述路径一定存在。

对于第三部分,将第二部分中的一个负数更改为正数即可证明。

提示 $2$

如果方阵中存在一个元素为 $0$,另一个元素为负数。那么一定存在一系列的操作使得负数元素变为正数,且其余元素不变。

提示 $2$ 解释

类似 提示 $1$,将一个负数元素更改为 $0$ 即可证明。

提示 $3$

如果方阵中存在 $0$,那么一定可以通过一系列的操作使得方阵中所有元素均为非负数;

如果方阵中不存在 $0$,那么:

  • 如果方阵中有奇数个负数元素,那么一定可以通过一系列的操作使得方阵中只有一个负数元素,且该负数元素可以在任何位置。同时,无论如何操作,方阵中必定存在负数元素。

  • 如果方阵中有偶数个负数元素,那么一定可以通过一系列的操作使得方阵中不存在负数元素。

提示 $3$ 解释

对于第一部分,反复对 $0$ 和负数元素进行 提示 $2$ 的操作即可。

对于第二部分,我们首先可以证明如果方阵不存在 $0$,那么负数元素数量奇偶性不会改变。然后,我们可以根据 提示 $1$ 构造出一系列操作从而达到对应的要求。

思路与算法

根据 提示 $3$,我们可以按照方阵的元素分为以下几种情况:

  • 方阵中有 $0$,那么最大方阵和即为所有元素的绝对值之和;

  • 方阵中没有 $0$,且负数元素数量为偶数,那么最大方阵和即为所有元素的绝对值之和;

  • 方阵中没有 $0$,且负数元素数量为奇数,那么最大方阵和即为所有元素的绝对值之和减去所有元素最小绝对值的两倍。

其中,第一种情况也可以按照负数元素数量的奇偶性划入后两种情况中(此时最小绝对值一定为 $0$)。

我们遍历方阵,维护负数元素的数量、元素的最小绝对值以及所有元素的绝对值之和。随后,我们按照负数元素数量的奇偶性计算对应的最大元素和并返回。

最后,矩阵所有元素绝对值之和可能超过 $32$ 位整数的上限,因此对于 $\texttt{C++}$ 等语言,需要使用 $64$ 位整数来维护。

代码

###C++

class Solution {
public:
    long long maxMatrixSum(vector<vector<int>>& matrix) {
        int n = matrix.size();
        int cnt = 0;   // 负数元素的数量
        long long total = 0;   // 所有元素的绝对值之和
        int mn = INT_MAX;   // 方阵元素的最小绝对值
        for (int i = 0; i < n; ++i){
            for (int j = 0; j < n; ++j){
                mn = min(mn, abs(matrix[i][j]));
                if (matrix[i][j] < 0){
                    ++cnt;
                }
                total += abs(matrix[i][j]);
            }
        }
        // 按照负数元素的数量的奇偶性讨论
        if (cnt % 2 == 0){
            return total;
        }
        else{
            return total - 2 * mn;
        }
    }
};

###Python

class Solution:
    def maxMatrixSum(self, matrix: List[List[int]]) -> int:
        n = len(matrix)
        cnt = 0   # 负数元素的数量
        total = 0   # 所有元素的绝对值之和
        mn = float("INF")   # 方阵元素的最小绝对值
        for i in range(n):
            for j in range(n):
                mn = min(mn, abs(matrix[i][j]))
                if matrix[i][j] < 0:
                    cnt += 1
                total += abs(matrix[i][j])
        # 按照负数元素的数量的奇偶性讨论
        if cnt % 2 == 0:
            return total
        else:
            return total - 2 * mn

###Java

class Solution {
    public long maxMatrixSum(int[][] matrix) {
        int n = matrix.length;
        int cnt = 0;   // 负数元素的数量
        long total = 0;   // 所有元素的绝对值之和
        int mn = Integer.MAX_VALUE;   // 方阵元素的最小绝对值
        for (int i = 0; i < n; ++i){
            for (int j = 0; j < n; ++j){
                mn = Math.min(mn, Math.abs(matrix[i][j]));
                if (matrix[i][j] < 0){
                    ++cnt;
                }
                total += Math.abs(matrix[i][j]);
            }
        }
        // 按照负数元素的数量的奇偶性讨论
        if (cnt % 2 == 0){
            return total;
        } else {
            return total - 2 * mn;
        }
    }
}

###C#

public class Solution {
    public long MaxMatrixSum(int[][] matrix) {
        int n = matrix.Length;
        int cnt = 0;   // 负数元素的数量
        long total = 0;   // 所有元素的绝对值之和
        int mn = int.MaxValue;   // 方阵元素的最小绝对值
        for (int i = 0; i < n; ++i){
            for (int j = 0; j < n; ++j){
                mn = Math.Min(mn, Math.Abs(matrix[i][j]));
                if (matrix[i][j] < 0){
                    ++cnt;
                }
                total += Math.Abs(matrix[i][j]);
            }
        }
        // 按照负数元素的数量的奇偶性讨论
        if (cnt % 2 == 0){
            return total;
        } else{
            return total - 2 * mn;
        }
    }
}

###Go

func maxMatrixSum(matrix [][]int) int64 {
    n := len(matrix)
    cnt := 0   // 负数元素的数量
    total := int64(0)   // 所有元素的绝对值之和
    mn := 1 << 30   // 方阵元素的最小绝对值
    for i := 0; i < n; i++ {
        for j := 0; j < n; j++ {
            mn = min(mn, abs(matrix[i][j]))
            if matrix[i][j] < 0 {
                cnt++
            }
            total += int64(abs(matrix[i][j]))
        }
    }
    // 按照负数元素的数量的奇偶性讨论
    if cnt % 2 == 0 {
        return total
    } else {
        return total - int64(2 * mn)
    }
}

func abs(x int) int {
    if x < 0 {
        return -x
    }
    return x
}

###C

long long maxMatrixSum(int** matrix, int matrixSize, int* matrixColSize) {
    int n = matrixSize;
    int cnt = 0;   // 负数元素的数量
    long long total = 0;   // 所有元素的绝对值之和
    int mn = INT_MAX;   // 方阵元素的最小绝对值
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            int abs_val = abs(matrix[i][j]);
            if (abs_val < mn) {
                mn = abs_val;
            }
            if (matrix[i][j] < 0) {
                ++cnt;
            }
            total += abs_val;
        }
    }
    // 按照负数元素的数量的奇偶性讨论
    if (cnt % 2 == 0) {
        return total;
    } else {
        return total - 2 * mn;
    }
}

###JavaScript

var maxMatrixSum = function(matrix) {
    const n = matrix.length;
    let cnt = 0;   // 负数元素的数量
    let total = 0;   // 所有元素的绝对值之和
    let mn = Number.MAX_SAFE_INTEGER;   // 方阵元素的最小绝对值
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < n; j++) {
            const absVal = Math.abs(matrix[i][j]);
            mn = Math.min(mn, absVal);
            if (matrix[i][j] < 0) {
                cnt++;
            }
            total += absVal;
        }
    }
    // 按照负数元素的数量的奇偶性讨论
    if (cnt % 2 === 0) {
        return total;
    } else {
        return total - 2 * mn;
    }
};

###TypeScript

function maxMatrixSum(matrix: number[][]): number {
    const n = matrix.length;
    let cnt = 0;   // 负数元素的数量
    let total = 0;   // 所有元素的绝对值之和
    let mn = Number.MAX_SAFE_INTEGER;   // 方阵元素的最小绝对值
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < n; j++) {
            const absVal = Math.abs(matrix[i][j]);
            mn = Math.min(mn, absVal);
            if (matrix[i][j] < 0) {
                cnt++;
            }
            total += absVal;
        }
    }
    // 按照负数元素的数量的奇偶性讨论
    if (cnt % 2 === 0) {
        return total;
    } else {
        return total - 2 * mn;
    }
}

###Rust

impl Solution {
    pub fn max_matrix_sum(matrix: Vec<Vec<i32>>) -> i64 {
        let n = matrix.len();
        let mut cnt = 0;   // 负数元素的数量
        let mut total: i64 = 0;   // 所有元素的绝对值之和
        let mut mn = i32::MAX;   // 方阵元素的最小绝对值
        for i in 0..n {
            for j in 0..n {
                let abs_val = matrix[i][j].abs();
                mn = mn.min(abs_val);
                if matrix[i][j] < 0 {
                    cnt += 1;
                }
                total += abs_val as i64;
            }
        }
        // 按照负数元素的数量的奇偶性讨论
        if cnt % 2 == 0 {
            total
        } else {
            total - 2 * mn as i64
        }
    }
}

复杂度分析

  • 时间复杂度:$O(mn)$,其中 $m$ 为 $\textit{matrix}$ 的行数,$n$ 为 $\textit{matrix}$ 的列数。

  • 空间复杂度:$O(1)$。

昨天 — 2026年1月4日首页

四因数

2020年3月23日 18:55

方法一:枚举

我们可以遍历数组 nums 中的每个元素,依次判断这些元素是否恰好有四个因数。对于任一元素 x,我们可以用类似质数判定的方法得到它的因数个数,其本质为:如果整数 x 有因数 y,那么也必有因数 x/y,并且 yx/y 中至少有一个不大于 sqrt(x)。这样我们只需要在 [1, sqrt(x)] 的区间内枚举可能为整数 x 的因数 y,并通过 x/y 得到整数 x 的其它因数,时间复杂度为 $O(\sqrt{x})$。

如果 x 恰好有四个因数,我们就将其因数之和累加到答案中。

###C++

class Solution {
public:
    int sumFourDivisors(vector<int>& nums) {
        int ans = 0;
        for (int num: nums) {
            // factor_cnt: 因数的个数
            // factor_sum: 因数的和
            int factor_cnt = 0, factor_sum = 0;
            for (int i = 1; i * i <= num; ++i) {
                if (num % i == 0) {
                    ++factor_cnt;
                    factor_sum += i;
                    if (i * i != num) {   // 判断 i 和 num/i 是否相等,若不相等才能将 num/i 看成新的因数
                        ++factor_cnt;
                        factor_sum += num / i;
                    }
                }
            }
            if (factor_cnt == 4) {
                ans += factor_sum;
            }
        }
        return ans;
    }
};

###Java

class Solution {
    public int sumFourDivisors(int[] nums) {
        int ans = 0;
        for (int num : nums) {
            // factor_cnt: 因数的个数
            // factor_sum: 因数的和
            int factor_cnt = 0, factor_sum = 0;
            for (int i = 1; i * i <= num; ++i) {
                if (num % i == 0) {
                    ++factor_cnt;
                    factor_sum += i;
                    if (i * i != num) {   // 判断 i 和 num/i 是否相等,若不相等才能将 num/i 看成新的因数
                        ++factor_cnt;
                        factor_sum += num / i;
                    }
                }
            }
            if (factor_cnt == 4) {
                ans += factor_sum;
            }
        }
        return ans;
    }
}

###Python

class Solution:
    def sumFourDivisors(self, nums: List[int]) -> int:
        ans = 0
        for num in nums:
            # factor_cnt: 因数的个数
            # factor_sum: 因数的和
            factor_cnt = factor_sum = 0
            i = 1
            while i * i <= num:
                if num % i == 0:
                    factor_cnt += 1
                    factor_sum += i
                    if i * i != num:   # 判断 i 和 num/i 是否相等,若不相等才能将 num/i 看成新的因数
                        factor_cnt += 1
                        factor_sum += num // i
                i += 1
            if factor_cnt == 4:
                ans += factor_sum
        return ans

###C#

public class Solution {
    public int SumFourDivisors(int[] nums) {
        int ans = 0;
        foreach (int num in nums) {
            // factor_cnt: 因数的个数
            // factor_sum: 因数的和
            int factor_cnt = 0, factor_sum = 0;
            for (int i = 1; i * i <= num; ++i) {
                if (num % i == 0) {
                    ++factor_cnt;
                    factor_sum += i;
                    if (i * i != num) {   // 判断 i 和 num/i 是否相等,若不相等才能将 num/i 看成新的因数
                        ++factor_cnt;
                        factor_sum += num / i;
                    }
                }
            }
            if (factor_cnt == 4) {
                ans += factor_sum;
            }
        }
        return ans;
    }
}

###Go

func sumFourDivisors(nums []int) int {
    ans := 0
    for _, num := range nums {
        // factor_cnt: 因数的个数
        // factor_sum: 因数的和
        factor_cnt, factor_sum := 0, 0
        for i := 1; i*i <= num; i++ {
            if num%i == 0 {
                factor_cnt++
                factor_sum += i
                if i*i != num {   // 判断 i 和 num/i 是否相等,若不相等才能将 num/i 看成新的因数
                    factor_cnt++
                    factor_sum += num / i
                }
            }
        }
        if factor_cnt == 4 {
            ans += factor_sum
        }
    }
    return ans
}

###C

int sumFourDivisors(int* nums, int numsSize) {
    int ans = 0;
    for (int idx = 0; idx < numsSize; idx++) {
        int num = nums[idx];
        // factor_cnt: 因数的个数
        // factor_sum: 因数的和
        int factor_cnt = 0, factor_sum = 0;
        for (int i = 1; i * i <= num; ++i) {
            if (num % i == 0) {
                ++factor_cnt;
                factor_sum += i;
                if (i * i != num) {   // 判断 i 和 num/i 是否相等,若不相等才能将 num/i 看成新的因数
                    ++factor_cnt;
                    factor_sum += num / i;
                }
            }
        }
        if (factor_cnt == 4) {
            ans += factor_sum;
        }
    }
    return ans;
}

###JavaScript

var sumFourDivisors = function(nums) {
    let ans = 0;
    for (const num of nums) {
        // factor_cnt: 因数的个数
        // factor_sum: 因数的和
        let factor_cnt = 0, factor_sum = 0;
        for (let i = 1; i * i <= num; ++i) {
            if (num % i === 0) {
                ++factor_cnt;
                factor_sum += i;
                if (i * i !== num) {   // 判断 i 和 num/i 是否相等,若不相等才能将 num/i 看成新的因数
                    ++factor_cnt;
                    factor_sum += num / i;
                }
            }
        }
        if (factor_cnt === 4) {
            ans += factor_sum;
        }
    }
    return ans;
};

###TypeScript

function sumFourDivisors(nums: number[]): number {
    let ans = 0;
    for (const num of nums) {
        // factor_cnt: 因数的个数
        // factor_sum: 因数的和
        let factor_cnt = 0, factor_sum = 0;
        for (let i = 1; i * i <= num; ++i) {
            if (num % i === 0) {
                ++factor_cnt;
                factor_sum += i;
                if (i * i !== num) {   // 判断 i 和 num/i 是否相等,若不相等才能将 num/i 看成新的因数
                    ++factor_cnt;
                    factor_sum += num / i;
                }
            }
        }
        if (factor_cnt === 4) {
            ans += factor_sum;
        }
    }
    return ans;
}

###Rust

impl Solution {
    pub fn sum_four_divisors(nums: Vec<i32>) -> i32 {
        let mut ans = 0;
        for &num in &nums {
            // factor_cnt: 因数的个数
            // factor_sum: 因数的和
            let mut factor_cnt = 0;
            let mut factor_sum = 0;
            let mut i = 1;
            while i * i <= num {
                if num % i == 0 {
                    factor_cnt += 1;
                    factor_sum += i;
                    if i * i != num {   // 判断 i 和 num/i 是否相等,若不相等才能将 num/i 看成新的因数
                        factor_cnt += 1;
                        factor_sum += num / i;
                    }
                }
                i += 1;
            }
            if factor_cnt == 4 {
                ans += factor_sum;
            }
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$O(N\sqrt{C})$,其中 $N$ 是数组 nums 的长度,$C$ 是数组 nums 中元素值的范围,在本题中 $C$ 不超过 $10^5$。

  • 空间复杂度:$O(1)$。

方法二:预处理

预备知识

分析与算法

直觉告诉我们,恰好有四个因数的整数不会有很多,我们是否可以预先找出它们呢?

根据「算数基本定理」(又叫「唯一分解定理」),如果整数 $x$ 可以分解为:

$$
x = p_1^{\alpha_1}p_2^{\alpha_2}\cdots p_k^{\alpha_k}
$$

其中 $p_i$ 为互不相同的质数(即 $x$ 的质因数)。那么 $x$ 的因数个数为:

$$
\textit{factor_count}(x) = \prod_{i=1}^k (\alpha_i + 1)
$$

如果 $\textit{factor_count}(x)$ 的值为 $4$,那么只有两种可能:

  • 整数 $x$ 只有一个质因数,对应的指数为 $3$,此时 $\textit{factor_count}(x) = (3+1) = 4$;

  • 整数 $x$ 有两个质因数,对应的指数均为 $1$,此时 $\textit{factor_count}(x) = (1+1)(1+1) = 4$。

对于第一种情况,我们需要找到所有不大于 $C^{1/3}$ 的质数;对于第二种情况,我们需要找到所有不大于 $C$ 的质数,再将它们两两相乘并筛去超过 $C$ 的那些结果。这里 $C$ 的定义与方法一中的复杂度分析部分一致。综上所述,我们需要找到所有不大于 $C$ 的质数。

我们如何找出所有不大于 $C$ 的质数呢?这时就需要「埃拉托斯特尼筛法」或「欧拉筛法」的帮助了。它们可以帮助我们快速找到这些质数。这两种筛法的算法细节不是这篇题解的重点,这里不再赘述。在找到了这些质数后,我们就可以构造出所有满足上述两种可能的 $x$ 了。我们将 $x$ 以及它的因数之和存入哈希映射(HashMap)中,这样就可以在 $O(1)$ 的时间判断数组 nums 中的每个元素是否满足要求,并统计满足要求的元素的因数之和了。

下面的代码给出了 Python 和 C++ 语言的「埃拉托斯特尼筛法」以及「欧拉筛法」的实现。

###C++

class Solution {
public:
    int sumFourDivisors(vector<int>& nums) {
        // C 是数组 nums 元素的上限,C3 是 C 的立方根
        int C = 100000, C3 = 46;
        
        vector<int> isprime(C + 1, 1);
        vector<int> primes;

        // 埃拉托斯特尼筛法
        for (int i = 2; i <= C; ++i) {
            if (isprime[i]) {
                primes.push_back(i);
            }
            for (int j = i + i; j <= C; j += i) {
                isprime[j] = 0;
            }
        }

        // 欧拉筛法
        /*
        for (int i = 2; i <= C; ++i) {
            if (isprime[i]) {
                primes.push_back(i);
            }
            for (int prime: primes) {
                if (i * prime > C) {
                    break;
                }
                isprime[i * prime] = 0;
                if (i % prime == 0) {
                    break;
                }
            }
        }
        */
        
        // 通过质数表构造出所有的四因数
        unordered_map<int, int> factor4;
        for (int prime: primes) {
            if (prime <= C3) {
                factor4[prime * prime * prime] = 1 + prime + prime * prime + prime * prime * prime;
            }
        }
        for (int i = 0; i < primes.size(); ++i) {
            for (int j = i + 1; j < primes.size(); ++j) {
                if (primes[i] <= C / primes[j]) {
                    factor4[primes[i] * primes[j]] = 1 + primes[i] + primes[j] + primes[i] * primes[j];
                }
                else {
                    break;
                }
            }
        }

        int ans = 0;
        for (int num: nums) {
            if (factor4.count(num)) {
                ans += factor4[num];
            }
        }
        return ans;
    }
};

###Java

class Solution {
    public int sumFourDivisors(int[] nums) {
        // C 是数组 nums 元素的上限,C3 是 C 的立方根
        int C = 100000, C3 = 46;
        
        boolean[] isPrime = new boolean[C + 1];
        Arrays.fill(isPrime, true);
        List<Integer> primes = new ArrayList<Integer>();

        // 埃拉托斯特尼筛法
        for (int i = 2; i <= C; ++i) {
            if (isPrime[i]) {
                primes.add(i);
            }
            for (int j = i + i; j <= C; j += i) {
                isPrime[j] = false;
            }
        }

        // 欧拉筛法
        /*
        for (int i = 2; i <= C; ++i) {
            if (isPrime[i]) {
                primes.add(i);
            }
            for (int prime : primes) {
                if (i * prime > C) {
                    break;
                }
                isPrime[i * prime] = false;
                if (i % prime == 0) {
                    break;
                }
            }
        }
        */
        
        // 通过质数表构造出所有的四因数
        Map<Integer, Integer> factor4 = new HashMap<Integer, Integer>();
        for (int prime : primes) {
            if (prime <= C3) {
                factor4.put(prime * prime * prime, 1 + prime + prime * prime + prime * prime * prime);
            }
        }
        for (int i = 0; i < primes.size(); ++i) {
            for (int j = i + 1; j < primes.size(); ++j) {
                if (primes.get(i) <= C / primes.get(j)) {
                    factor4.put(primes.get(i) * primes.get(j), 1 + primes.get(i) + primes.get(j) + primes.get(i) * primes.get(j));
                } else {
                    break;
                }
            }
        }

        int ans = 0;
        for (int num : nums) {
            if (factor4.containsKey(num)) {
                ans += factor4.get(num);
            }
        }
        return ans;
    }
}

###Python

class Solution:
    def sumFourDivisors(self, nums: List[int]) -> int:
        # C 是数组 nums 元素的上限,C3 是 C 的立方根
        C, C3 = 100000, 46

        isprime = [True] * (C + 1)
        primes = list()

        # 埃拉托斯特尼筛法
        for i in range(2, C + 1):
            if isprime[i]:
                primes.append(i)
            for j in range(i + i, C + 1, i):
                isprime[j] = False
        
        # 欧拉筛法
        """
        for i in range(2, C + 1):
            if isprime[i]:
                primes.append(i)
            for prime in primes:
                if i * prime > C:
                    break
                isprime[i * prime] = False
                if i % prime == 0:
                    break
        """
        
        # 通过质数表构造出所有的四因数
        factor4 = dict()
        for prime in primes:
            if prime <= C3:
                factor4[prime**3] = 1 + prime + prime**2 + prime**3
        for i in range(len(primes)):
            for j in range(i + 1, len(primes)):
                if primes[i] * primes[j] <= C:
                    factor4[primes[i] * primes[j]] = 1 + primes[i] + primes[j] + primes[i] * primes[j]
                else:
                    break
        
        ans = 0
        for num in nums:
            if num in factor4:
                ans += factor4[num]
        return ans

###C#

public class Solution {
    public int SumFourDivisors(int[] nums) {
        // C 是数组 nums 元素的上限,C3 是 C 的立方根
        int C = 100000, C3 = 46;
        
        int[] isprime = new int[C + 1];
        for (int i = 2; i <= C; i++) isprime[i] = 1;
        List<int> primes = new List<int>();

        // 埃拉托斯特尼筛法
        for (int i = 2; i <= C; ++i) {
            if (isprime[i] == 1) {
                primes.Add(i);
            }
            for (int j = i + i; j <= C; j += i) {
                isprime[j] = 0;
            }
        }

        // 欧拉筛法
        /*
        for (int i = 2; i <= C; ++i) {
            if (isprime[i] == 1) {
                primes.Add(i);
            }
            foreach (int prime in primes) {
                if (i * prime > C) {
                    break;
                }
                isprime[i * prime] = 0;
                if (i % prime == 0) {
                    break;
                }
            }
        }
        */
        
        // 通过质数表构造出所有的四因数
        Dictionary<int, int> factor4 = new Dictionary<int, int>();
        foreach (int prime in primes) {
            if (prime <= C3) {
                factor4[prime * prime * prime] = 1 + prime + prime * prime + prime * prime * prime;
            }
        }
        for (int i = 0; i < primes.Count; ++i) {
            for (int j = i + 1; j < primes.Count; ++j) {
                if (primes[i] <= C / primes[j]) {
                    factor4[primes[i] * primes[j]] = 1 + primes[i] + primes[j] + primes[i] * primes[j];
                }
                else {
                    break;
                }
            }
        }

        int ans = 0;
        foreach (int num in nums) {
            if (factor4.ContainsKey(num)) {
                ans += factor4[num];
            }
        }
        return ans;
    }
}

###Go

func sumFourDivisors(nums []int) int {
    // C 是数组 nums 元素的上限,C3 是 C 的立方根
    C, C3 := 100000, 46
    
    isprime := make([]int, C+1)
    for i := 2; i <= C; i++ {
        isprime[i] = 1
    }
    primes := []int{}

    // 埃拉托斯特尼筛法
    for i := 2; i <= C; i++ {
        if isprime[i] == 1 {
            primes = append(primes, i)
        }
        for j := i + i; j <= C; j += i {
            isprime[j] = 0
        }
    }

    // 欧拉筛法
    /*
    for i := 2; i <= C; i++ {
        if isprime[i] == 1 {
            primes = append(primes, i)
        }
        for _, prime := range primes {
            if i * prime > C {
                break
            }
            isprime[i * prime] = 0
            if i % prime == 0 {
                break
            }
        }
    }
    */
    
    // 通过质数表构造出所有的四因数
    factor4 := make(map[int]int)
    for _, prime := range primes {
        if prime <= C3 {
            factor4[prime * prime * prime] = 1 + prime + prime * prime + prime * prime * prime
        }
    }
    for i := 0; i < len(primes); i++ {
        for j := i + 1; j < len(primes); j++ {
            if primes[i] <= C / primes[j] {
                factor4[primes[i] * primes[j]] = 1 + primes[i] + primes[j] + primes[i] * primes[j]
            } else {
                break
            }
        }
    }

    ans := 0
    for _, num := range nums {
        if val, exists := factor4[num]; exists {
            ans += val
        }
    }
    return ans
}

###C

int sumFourDivisors(int* nums, int numsSize) {
    // C 是数组 nums 元素的上限,C3 是 C 的立方根
    const int C = 100000, C3 = 46;
    
    int* isprime = (int*)malloc((C + 1) * sizeof(int));
    memset(isprime, 0, (C + 1) * sizeof(int));
    int* primes = (int*)malloc((C + 1) * sizeof(int));
    int primeCount = 0;

    // 埃拉托斯特尼筛法
    for (int i = 2; i <= C; ++i) {
        isprime[i] = 1;
    }
    for (int i = 2; i <= C; ++i) {
        if (isprime[i]) {
            primes[primeCount++] = i;
        }
        for (int j = i + i; j <= C; j += i) {
            isprime[j] = 0;
        }
    }

    // 欧拉筛法
    /*
    for (int i = 2; i <= C; ++i) {
        if (isprime[i]) {
            primes[primeCount++] = i;
        }
        for (int j = 0; j < primeCount; ++j) {
            if (i * primes[j] > C) {
                break;
            }
            isprime[i * primes[j]] = 0;
            if (i % primes[j] == 0) {
                break;
            }
        }
    }
    */
    
    // 通过质数表构造出所有的四因数
    int* factor4_keys = (int*)malloc(primeCount * primeCount * sizeof(int));
    int* factor4_values = (int*)malloc(primeCount * primeCount * sizeof(int));
    int factor4_count = 0;
    
    for (int i = 0; i < primeCount; ++i) {
        int prime = primes[i];
        if (prime <= C3) {
            factor4_keys[factor4_count] = prime * prime * prime;
            factor4_values[factor4_count] = 1 + prime + prime * prime + prime * prime * prime;
            factor4_count++;
        }
    }
    for (int i = 0; i < primeCount; ++i) {
        for (int j = i + 1; j < primeCount; ++j) {
            if (primes[i] <= C / primes[j]) {
                factor4_keys[factor4_count] = primes[i] * primes[j];
                factor4_values[factor4_count] = 1 + primes[i] + primes[j] + primes[i] * primes[j];
                factor4_count++;
            } else {
                break;
            }
        }
    }

    int ans = 0;
    for (int idx = 0; idx < numsSize; ++idx) {
        int num = nums[idx];
        for (int i = 0; i < factor4_count; ++i) {
            if (factor4_keys[i] == num) {
                ans += factor4_values[i];
                break;
            }
        }
    }
    
    free(isprime);
    free(primes);
    free(factor4_keys);
    free(factor4_values);
    
    return ans;
}

###JavaScript

var sumFourDivisors = function(nums) {
    // C 是数组 nums 元素的上限,C3 是 C 的立方根
    const C = 100000, C3 = 46;
    
    let isprime = new Array(C + 1).fill(0);
    let primes = [];

    // 埃拉托斯特尼筛法
    for (let i = 2; i <= C; i++) {
        isprime[i] = 1;
    }
    for (let i = 2; i <= C; i++) {
        if (isprime[i]) {
            primes.push(i);
        }
        for (let j = i + i; j <= C; j += i) {
            isprime[j] = 0;
        }
    }

    // 欧拉筛法
    /*
    for (let i = 2; i <= C; i++) {
        if (isprime[i]) {
            primes.push(i);
        }
        for (let prime of primes) {
            if (i * prime > C) {
                break;
            }
            isprime[i * prime] = 0;
            if (i % prime === 0) {
                break;
            }
        }
    }
    */
    
    // 通过质数表构造出所有的四因数
    let factor4 = new Map();
    for (let prime of primes) {
        if (prime <= C3) {
            factor4.set(prime * prime * prime, 1 + prime + prime * prime + prime * prime * prime);
        }
    }
    for (let i = 0; i < primes.length; i++) {
        for (let j = i + 1; j < primes.length; j++) {
            if (primes[i] <= C / primes[j]) {
                factor4.set(primes[i] * primes[j], 1 + primes[i] + primes[j] + primes[i] * primes[j]);
            } else {
                break;
            }
        }
    }

    let ans = 0;
    for (let num of nums) {
        if (factor4.has(num)) {
            ans += factor4.get(num);
        }
    }
    return ans;
};

###TypeScript

function sumFourDivisors(nums: number[]): number {
    // C 是数组 nums 元素的上限,C3 是 C 的立方根
    const C: number = 100000, C3: number = 46;
    
    let isprime: number[] = new Array(C + 1).fill(0);
    let primes: number[] = [];

    // 埃拉托斯特尼筛法
    for (let i = 2; i <= C; i++) {
        isprime[i] = 1;
    }
    for (let i = 2; i <= C; i++) {
        if (isprime[i]) {
            primes.push(i);
        }
        for (let j = i + i; j <= C; j += i) {
            isprime[j] = 0;
        }
    }

    // 欧拉筛法
    /*
    for (let i = 2; i <= C; i++) {
        if (isprime[i]) {
            primes.push(i);
        }
        for (let prime of primes) {
            if (i * prime > C) {
                break;
            }
            isprime[i * prime] = 0;
            if (i % prime === 0) {
                break;
            }
        }
    }
    */
    
    // 通过质数表构造出所有的四因数
    let factor4: Map<number, number> = new Map();
    for (let prime of primes) {
        if (prime <= C3) {
            factor4.set(prime * prime * prime, 1 + prime + prime * prime + prime * prime * prime);
        }
    }
    for (let i = 0; i < primes.length; i++) {
        for (let j = i + 1; j < primes.length; j++) {
            if (primes[i] <= C / primes[j]) {
                factor4.set(primes[i] * primes[j], 1 + primes[i] + primes[j] + primes[i] * primes[j]);
            } else {
                break;
            }
        }
    }

    let ans: number = 0;
    for (let num of nums) {
        if (factor4.has(num)) {
            ans += factor4.get(num)!;
        }
    }
    return ans;
}

###Rust

use std::collections::HashMap;

impl Solution {
    pub fn sum_four_divisors(nums: Vec<i32>) -> i32 {
        // C 是数组 nums 元素的上限,C3 是 C 的立方根
        const C: i32 = 100000;
        const C3: i32 = 46;
        
        let mut isprime = vec![0; (C + 1) as usize];
        let mut primes = Vec::new();

        // 埃拉托斯特尼筛法
        for i in 2..=C {
            isprime[i as usize] = 1;
        }
        for i in 2..=C {
            if isprime[i as usize] == 1 {
                primes.push(i);
            }
            let mut j = i + i;
            while j <= C {
                isprime[j as usize] = 0;
                j += i;
            }
        }

        // 欧拉筛法
        /*
        for i in 2..=C {
            if isprime[i as usize] == 1 {
                primes.push(i);
            }
            for &prime in &primes {
                if i * prime > C {
                    break;
                }
                isprime[(i * prime) as usize] = 0;
                if i % prime == 0 {
                    break;
                }
            }
        }
        */
        
        // 通过质数表构造出所有的四因数
        let mut factor4 = HashMap::new();
        for &prime in &primes {
            if prime <= C3 {
                let key = prime * prime * prime;
                let value = 1 + prime + prime * prime + prime * prime * prime;
                factor4.insert(key, value);
            }
        }
        for i in 0..primes.len() {
            for j in i + 1..primes.len() {
                if primes[i] <= C / primes[j] {
                    let key = primes[i] * primes[j];
                    let value = 1 + primes[i] + primes[j] + primes[i] * primes[j];
                    factor4.insert(key, value);
                } else {
                    break;
                }
            }
        }

        let mut ans = 0;
        for num in nums {
            if let Some(&value) = factor4.get(&num) {
                ans += value;
            }
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$O(\pi^2(C) + C\log\log C + N)$ 或 $O(\pi^2(C) + C + N)$,其中 $\pi(X)$ 为「质数计数函数」,表示不超过 $X$ 的质数个数。「埃拉托斯特尼筛法」的时间复杂度为 $O(C\log\log C)$,「欧拉筛法」的时间复杂度为 $O(C)$;通过质数表构造出所有四因数的时间复杂度为 $O(\pi(C^{1/3})) + O(\pi^2(C)) = O(\pi^2(C))$,遍历数组 nums 中的所有元素并检查是否为四因数的时间复杂度为 $O(N)$。

  • 空间复杂度:$O(C + \pi(C))$,无论哪一种筛法,都需要长度为 $C$ 的数组记录每个数是否为质数,以及长度为 $\pi(C)$ 的数组存储所有的质数。

昨天以前首页

给 N x 3 网格图涂色的方案数

2020年4月19日 09:53

方法一:递推

我们可以用 $f[i][\textit{type}]$ 表示当网格的大小为 $i \times 3$ 且最后一行的填色方法为 $\textit{type}$ 时的方案数。由于我们在填充第 $i$ 行时,会影响我们填充方案的只有它上面的那一行(即 $i - 1$ 行),因此用 $f[i][\textit{type}]$ 表示状态是合理的。

那么我们如何计算 $f[i][\textit{type}]$ 呢?可以发现:

  • 首先,$\textit{type}$ 本身是要满足要求的。每一行有 $3$ 个网格,如果我们用 $0, 1, 2$ 分别代表红黄绿,那么 $\textit{type}$ 可以看成一个三进制数,例如 $\textit{type} = (102)_3$ 时,表示 $3$ 个网格从左到右的颜色分别为黄、红、绿;

    • 这样以来,我们可以预处理出所有满足要求的 $\textit{type}$。具体地,我们使用三重循环分别枚举每一个格子的颜色,只有相邻的格子颜色不相同时,$\textit{type}$ 才满足要求。
  • 其次,$f[i][\textit{type}]$ 应该等于所有 $f[i - 1][\textit{type}']$ 的和,其中 $\textit{type'}$ 和 $\textit{type}$ 可以作为相邻的行。也就是说,$\textit{type'}$ 和 $\textit{type}$ 的对应位置不能相同。

递推解法的本身不难想出,难度在于上述的预处理以及编码实现。下面给出包含详细注释的 C++JavaPython 代码。

###C++

class Solution {
private:
    static constexpr int mod = 1000000007;

public:
    int numOfWays(int n) {
        // 预处理出所有满足条件的 type
        vector<int> types;
        for (int i = 0; i < 3; ++i) {
            for (int j = 0; j < 3; ++j) {
                for (int k = 0; k < 3; ++k) {
                    if (i != j && j != k) {
                        // 只要相邻的颜色不相同就行
                        // 将其以十进制的形式存储
                        types.push_back(i * 9 + j * 3 + k);
                    }
                }
            }
        }
        int type_cnt = types.size();
        // 预处理出所有可以作为相邻行的 type 对
        vector<vector<int>> related(type_cnt, vector<int>(type_cnt));
        for (int i = 0; i < type_cnt; ++i) {
            // 得到 types[i] 三个位置的颜色
            int x1 = types[i] / 9, x2 = types[i] / 3 % 3, x3 = types[i] % 3;
            for (int j = 0; j < type_cnt; ++j) {
                // 得到 types[j] 三个位置的颜色
                int y1 = types[j] / 9, y2 = types[j] / 3 % 3, y3 = types[j] % 3;
                // 对应位置不同色,才能作为相邻的行
                if (x1 != y1 && x2 != y2 && x3 != y3) {
                    related[i][j] = 1;
                }
            }
        }
        // 递推数组
        vector<vector<int>> f(n + 1, vector<int>(type_cnt));
        // 边界情况,第一行可以使用任何 type
        for (int i = 0; i < type_cnt; ++i) {
            f[1][i] = 1;
        }
        for (int i = 2; i <= n; ++i) {
            for (int j = 0; j < type_cnt; ++j) {
                for (int k = 0; k < type_cnt; ++k) {
                    // f[i][j] 等于所有 f[i - 1][k] 的和
                    // 其中 k 和 j 可以作为相邻的行
                    if (related[k][j]) {
                        f[i][j] += f[i - 1][k];
                        f[i][j] %= mod;
                    }
                }
            }
        }
        // 最终所有的 f[n][...] 之和即为答案
        int ans = 0;
        for (int i = 0; i < type_cnt; ++i) {
            ans += f[n][i];
            ans %= mod;
        }
        return ans;
    }
};

###Java

class Solution {
    static final int MOD = 1000000007;

    public int numOfWays(int n) {
        // 预处理出所有满足条件的 type
        List<Integer> types = new ArrayList<Integer>();
        for (int i = 0; i < 3; ++i) {
            for (int j = 0; j < 3; ++j) {
                for (int k = 0; k < 3; ++k) {
                    if (i != j && j != k) {
                        // 只要相邻的颜色不相同就行
                        // 将其以十进制的形式存储
                        types.add(i * 9 + j * 3 + k);
                    }
                }
            }
        }
        int typeCnt = types.size();
        // 预处理出所有可以作为相邻行的 type 对
        int[][] related = new int[typeCnt][typeCnt];
        for (int i = 0; i < typeCnt; ++i) {
            // 得到 types[i] 三个位置的颜色
            int x1 = types.get(i) / 9, x2 = types.get(i) / 3 % 3, x3 = types.get(i) % 3;
            for (int j = 0; j < typeCnt; ++j) {
                // 得到 types[j] 三个位置的颜色
                int y1 = types.get(j) / 9, y2 = types.get(j) / 3 % 3, y3 = types.get(j) % 3;
                // 对应位置不同色,才能作为相邻的行
                if (x1 != y1 && x2 != y2 && x3 != y3) {
                    related[i][j] = 1;
                }
            }
        }
        // 递推数组
        int[][] f = new int[n + 1][typeCnt];
        // 边界情况,第一行可以使用任何 type
        for (int i = 0; i < typeCnt; ++i) {
            f[1][i] = 1;
        }
        for (int i = 2; i <= n; ++i) {
            for (int j = 0; j < typeCnt; ++j) {
                for (int k = 0; k < typeCnt; ++k) {
                    // f[i][j] 等于所有 f[i - 1][k] 的和
                    // 其中 k 和 j 可以作为相邻的行
                    if (related[k][j] != 0) {
                        f[i][j] += f[i - 1][k];
                        f[i][j] %= MOD;
                    }
                }
            }
        }
        // 最终所有的 f[n][...] 之和即为答案
        int ans = 0;
        for (int i = 0; i < typeCnt; ++i) {
            ans += f[n][i];
            ans %= MOD;
        }
        return ans;
    }
}

###Python

class Solution:
    def numOfWays(self, n: int) -> int:
        mod = 10**9 + 7
        # 预处理出所有满足条件的 type
        types = list()
        for i in range(3):
            for j in range(3):
                for k in range(3):
                    if i != j and j != k:
                        # 只要相邻的颜色不相同就行
                        # 将其以十进制的形式存储
                        types.append(i * 9 + j * 3 + k)
        type_cnt = len(types)
        # 预处理出所有可以作为相邻行的 type 对
        related = [[0] * type_cnt for _ in range(type_cnt)]
        for i, ti in enumerate(types):
            # 得到 types[i] 三个位置的颜色
            x1, x2, x3 = ti // 9, ti // 3 % 3, ti % 3
            for j, tj in enumerate(types):
                # 得到 types[j] 三个位置的颜色
                y1, y2, y3 = tj // 9, tj // 3 % 3, tj % 3
                # 对应位置不同色,才能作为相邻的行
                if x1 != y1 and x2 != y2 and x3 != y3:
                    related[i][j] = 1
        # 递推数组
        f = [[0] * type_cnt for _ in range(n + 1)]
        # 边界情况,第一行可以使用任何 type
        f[1] = [1] * type_cnt
        for i in range(2, n + 1):
            for j in range(type_cnt):
                for k in range(type_cnt):
                    # f[i][j] 等于所有 f[i - 1][k] 的和
                    # 其中 k 和 j 可以作为相邻的行
                    if related[k][j]:
                        f[i][j] += f[i - 1][k]
                        f[i][j] %= mod
        # 最终所有的 f[n][...] 之和即为答案
        ans = sum(f[n]) % mod
        return ans

###C#

public class Solution {
    private const int mod = 1000000007;
    
    public int NumOfWays(int n) {
        // 预处理出所有满足条件的 type
        List<int> types = new List<int>();
        for (int i = 0; i < 3; ++i) {
            for (int j = 0; j < 3; ++j) {
                for (int k = 0; k < 3; ++k) {
                    if (i != j && j != k) {
                        // 只要相邻的颜色不相同就行
                        // 将其以十进制的形式存储
                        types.Add(i * 9 + j * 3 + k);
                    }
                }
            }
        }
        int type_cnt = types.Count;
        // 预处理出所有可以作为相邻行的 type 对
        int[][] related = new int[type_cnt][];
        for (int i = 0; i < type_cnt; ++i) {
            related[i] = new int[type_cnt];
            // 得到 types[i] 三个位置的颜色
            int x1 = types[i] / 9, x2 = types[i] / 3 % 3, x3 = types[i] % 3;
            for (int j = 0; j < type_cnt; ++j) {
                // 得到 types[j] 三个位置的颜色
                int y1 = types[j] / 9, y2 = types[j] / 3 % 3, y3 = types[j] % 3;
                // 对应位置不同色,才能作为相邻的行
                if (x1 != y1 && x2 != y2 && x3 != y3) {
                    related[i][j] = 1;
                }
            }
        }
        // 递推数组
        int[][] f = new int[n + 1][];
        for (int i = 0; i <= n; ++i) {
            f[i] = new int[type_cnt];
        }
        // 边界情况,第一行可以使用任何 type
        for (int i = 0; i < type_cnt; ++i) {
            f[1][i] = 1;
        }
        for (int i = 2; i <= n; ++i) {
            for (int j = 0; j < type_cnt; ++j) {
                for (int k = 0; k < type_cnt; ++k) {
                    // f[i][j] 等于所有 f[i - 1][k] 的和
                    // 其中 k 和 j 可以作为相邻的行
                    if (related[k][j] == 1) {
                        f[i][j] = (f[i][j] + f[i - 1][k]) % mod;
                    }
                }
            }
        }
        // 最终所有的 f[n][...] 之和即为答案
        int ans = 0;
        for (int i = 0; i < type_cnt; ++i) {
            ans = (ans + f[n][i]) % mod;
        }
        return ans;
    }
}

###Go

func numOfWays(n int) int {
    // 预处理出所有满足条件的 type
    mod := 1000000007
    types := []int{}
    for i := 0; i < 3; i++ {
        for j := 0; j < 3; j++ {
            for k := 0; k < 3; k++ {
                if i != j && j != k {
                    // 只要相邻的颜色不相同就行
                    // 将其以十进制的形式存储
                    types = append(types, i*9 + j*3 + k)
                }
            }
        }
    }
    type_cnt := len(types)
    // 预处理出所有可以作为相邻行的 type 对
    related := make([][]int, type_cnt)
    for i := range related {
        related[i] = make([]int, type_cnt)
    }
    for i := 0; i < type_cnt; i++ {
        // 得到 types[i] 三个位置的颜色
        x1 := types[i] / 9
        x2 := types[i] / 3 % 3
        x3 := types[i] % 3
        for j := 0; j < type_cnt; j++ {
            // 得到 types[j] 三个位置的颜色
            y1 := types[j] / 9
            y2 := types[j] / 3 % 3
            y3 := types[j] % 3
            // 对应位置不同色,才能作为相邻的行
            if x1 != y1 && x2 != y2 && x3 != y3 {
                related[i][j] = 1
            }
        }
    }
    // 递推数组
    f := make([][]int, n+1)
    for i := range f {
        f[i] = make([]int, type_cnt)
    }
    // 边界情况,第一行可以使用任何 type
    for i := 0; i < type_cnt; i++ {
        f[1][i] = 1
    }
    for i := 2; i <= n; i++ {
        for j := 0; j < type_cnt; j++ {
            for k := 0; k < type_cnt; k++ {
                // f[i][j] 等于所有 f[i - 1][k] 的和
                // 其中 k 和 j 可以作为相邻的行
                if related[k][j] == 1 {
                    f[i][j] = (f[i][j] + f[i-1][k]) % mod
                }
            }
        }
    }
    // 最终所有的 f[n][...] 之和即为答案
    ans := 0
    for i := 0; i < type_cnt; i++ {
        ans = (ans + f[n][i]) % mod
    }
    return ans
}

###C

int numOfWays(int n) {
    // 预处理出所有满足条件的 type
    const int mod = 1000000007;
    int types[12];
    int type_cnt = 0;
    for (int i = 0; i < 3; ++i) {
        for (int j = 0; j < 3; ++j) {
            for (int k = 0; k < 3; ++k) {
                if (i != j && j != k) {
                    // 只要相邻的颜色不相同就行
                    // 将其以十进制的形式存储
                    types[type_cnt++] = i * 9 + j * 3 + k;
                }
            }
        }
    }
    // 预处理出所有可以作为相邻行的 type 对
    int related[12][12] = {0};
    for (int i = 0; i < type_cnt; ++i) {
        // 得到 types[i] 三个位置的颜色
        int x1 = types[i] / 9, x2 = types[i] / 3 % 3, x3 = types[i] % 3;
        for (int j = 0; j < type_cnt; ++j) {
            // 得到 types[j] 三个位置的颜色
            int y1 = types[j] / 9, y2 = types[j] / 3 % 3, y3 = types[j] % 3;
            // 对应位置不同色,才能作为相邻的行
            if (x1 != y1 && x2 != y2 && x3 != y3) {
                related[i][j] = 1;
            }
        }
    }
    // 递推数组
    int f[n + 1][type_cnt];
    // 初始化
    for (int i = 0; i <= n; ++i) {
        for (int j = 0; j < type_cnt; ++j) {
            f[i][j] = 0;
        }
    }
    // 边界情况,第一行可以使用任何 type
    for (int i = 0; i < type_cnt; ++i) {
        f[1][i] = 1;
    }
    for (int i = 2; i <= n; ++i) {
        for (int j = 0; j < type_cnt; ++j) {
            for (int k = 0; k < type_cnt; ++k) {
                // f[i][j] 等于所有 f[i - 1][k] 的和
                // 其中 k 和 j 可以作为相邻的行
                if (related[k][j]) {
                    f[i][j] = (f[i][j] + f[i - 1][k]) % mod;
                }
            }
        }
    }
    // 最终所有的 f[n][...] 之和即为答案
    int ans = 0;
    for (int i = 0; i < type_cnt; ++i) {
        ans = (ans + f[n][i]) % mod;
    }
    return ans;
}

###JavaScript

var numOfWays = function(n) {
    // 预处理出所有满足条件的 type
    const mod = 1000000007;
    const types = [];
    for (let i = 0; i < 3; ++i) {
        for (let j = 0; j < 3; ++j) {
            for (let k = 0; k < 3; ++k) {
                if (i !== j && j !== k) {
                    // 只要相邻的颜色不相同就行
                    // 将其以十进制的形式存储
                    types.push(i * 9 + j * 3 + k);
                }
            }
        }
    }
    const type_cnt = types.length;
    // 预处理出所有可以作为相邻行的 type 对
    const related = Array.from({length: type_cnt}, () => new Array(type_cnt).fill(0));
    for (let i = 0; i < type_cnt; ++i) {
        // 得到 types[i] 三个位置的颜色
        const x1 = Math.floor(types[i] / 9);
        const x2 = Math.floor(types[i] / 3) % 3;
        const x3 = types[i] % 3;
        for (let j = 0; j < type_cnt; ++j) {
            // 得到 types[j] 三个位置的颜色
            const y1 = Math.floor(types[j] / 9);
            const y2 = Math.floor(types[j] / 3) % 3;
            const y3 = types[j] % 3;
            // 对应位置不同色,才能作为相邻的行
            if (x1 !== y1 && x2 !== y2 && x3 !== y3) {
                related[i][j] = 1;
            }
        }
    }
    // 递推数组
    const f = Array.from({length: n + 1}, () => new Array(type_cnt).fill(0));
    // 边界情况,第一行可以使用任何 type
    for (let i = 0; i < type_cnt; ++i) {
        f[1][i] = 1;
    }
    for (let i = 2; i <= n; ++i) {
        for (let j = 0; j < type_cnt; ++j) {
            for (let k = 0; k < type_cnt; ++k) {
                // f[i][j] 等于所有 f[i - 1][k] 的和
                // 其中 k 和 j 可以作为相邻的行
                if (related[k][j]) {
                    f[i][j] = (f[i][j] + f[i - 1][k]) % mod;
                }
            }
        }
    }
    // 最终所有的 f[n][...] 之和即为答案
    let ans = 0;
    for (let i = 0; i < type_cnt; ++i) {
        ans = (ans + f[n][i]) % mod;
    }
    return ans;
};

###TypeScript

function numOfWays(n: number): number {
    // 预处理出所有满足条件的 type
    const mod: number = 1000000007;
    const types: number[] = [];
    for (let i = 0; i < 3; ++i) {
        for (let j = 0; j < 3; ++j) {
            for (let k = 0; k < 3; ++k) {
                if (i !== j && j !== k) {
                    // 只要相邻的颜色不相同就行
                    // 将其以十进制的形式存储
                    types.push(i * 9 + j * 3 + k);
                }
            }
        }
    }
    const type_cnt: number = types.length;
    // 预处理出所有可以作为相邻行的 type 对
    const related: number[][] = Array.from({length: type_cnt}, () => new Array(type_cnt).fill(0));
    for (let i = 0; i < type_cnt; ++i) {
        // 得到 types[i] 三个位置的颜色
        const x1: number = Math.floor(types[i] / 9);
        const x2: number = Math.floor(types[i] / 3) % 3;
        const x3: number = types[i] % 3;
        for (let j = 0; j < type_cnt; ++j) {
            // 得到 types[j] 三个位置的颜色
            const y1: number = Math.floor(types[j] / 9);
            const y2: number = Math.floor(types[j] / 3) % 3;
            const y3: number = types[j] % 3;
            // 对应位置不同色,才能作为相邻的行
            if (x1 !== y1 && x2 !== y2 && x3 !== y3) {
                related[i][j] = 1;
            }
        }
    }
    // 递推数组
    const f: number[][] = Array.from({length: n + 1}, () => new Array(type_cnt).fill(0));
    // 边界情况,第一行可以使用任何 type
    for (let i = 0; i < type_cnt; ++i) {
        f[1][i] = 1;
    }
    for (let i = 2; i <= n; ++i) {
        for (let j = 0; j < type_cnt; ++j) {
            for (let k = 0; k < type_cnt; ++k) {
                // f[i][j] 等于所有 f[i - 1][k] 的和
                // 其中 k 和 j 可以作为相邻的行
                if (related[k][j]) {
                    f[i][j] = (f[i][j] + f[i - 1][k]) % mod;
                }
            }
        }
    }
    // 最终所有的 f[n][...] 之和即为答案
    let ans: number = 0;
    for (let i = 0; i < type_cnt; ++i) {
        ans = (ans + f[n][i]) % mod;
    }
    return ans;
}

###Rust

impl Solution {
    pub fn num_of_ways(n: i32) -> i32 {
        // 预处理出所有满足条件的 type
        let mod_val = 1000000007;
        let n = n as usize;
        let mut types = Vec::new();
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    if i != j && j != k {
                        // 只要相邻的颜色不相同就行
                        // 将其以十进制的形式存储
                        types.push(i * 9 + j * 3 + k);
                    }
                }
            }
        }
        let type_cnt = types.len();
        // 预处理出所有可以作为相邻行的 type 对
        let mut related = vec![vec![0; type_cnt]; type_cnt];
        for i in 0..type_cnt {
            // 得到 types[i] 三个位置的颜色
            let x1 = types[i] / 9;
            let x2 = types[i] / 3 % 3;
            let x3 = types[i] % 3;
            for j in 0..type_cnt {
                // 得到 types[j] 三个位置的颜色
                let y1 = types[j] / 9;
                let y2 = types[j] / 3 % 3;
                let y3 = types[j] % 3;
                // 对应位置不同色,才能作为相邻的行
                if x1 != y1 && x2 != y2 && x3 != y3 {
                    related[i][j] = 1;
                }
            }
        }
        // 递推数组
        let mut f = vec![vec![0; type_cnt]; n + 1];
        // 边界情况,第一行可以使用任何 type
        for i in 0..type_cnt {
            f[1][i] = 1;
        }
        for i in 2..=n {
            for j in 0..type_cnt {
                for k in 0..type_cnt {
                    // f[i][j] 等于所有 f[i - 1][k] 的和
                    // 其中 k 和 j 可以作为相邻的行
                    if related[k][j] == 1 {
                        f[i][j] = (f[i][j] + f[i - 1][k]) % mod_val;
                    }
                }
            }
        }
        // 最终所有的 f[n][...] 之和即为答案
        let mut ans = 0;
        for i in 0..type_cnt {
            ans = (ans + f[n][i]) % mod_val;
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$O(T^2N)$,其中 $T$ 是满足要求的 $\textit{type}$ 的数量,在示例一中已经给出了 $T = 12$。在递推的过程中,我们需要计算所有的 $f[i][\textit{type}]$,并且需要枚举上一行的 $\textit{type}'$。

  • 空间复杂度:$O(T^2 + TN)$。我们需要 $T * T$ 的二维数组存储 $\textit{type}$ 之间的关系,$T * N$ 的数组存储递推的结果。注意到由于 $f[i][\textit{type}]$ 只和上一行的状态有关,我们可以使用两个一维数组存储当前行和上一行的 $f$ 值,空间复杂度降低至 $O(T^2 + 2T) = O(T^2)$。

方法二:递推优化

如果读者有一些高中数学竞赛基础,就可以发现上面的这个递推式是线性的,也就是说:

  • 我们可以进行一些化简;

  • 它存在通项公式。

直观上,我们怎么化简方法一中的递推呢?

我们把满足要求的 $\textit{type}$ 都写出来,一共有 $12$ 种:

010, 012, 020, 021, 101, 102, 120, 121, 201, 202, 210, 212

我们可以把它们分成两类:

  • ABC 类:三个颜色互不相同,一共有 $6$ 种:012, 021, 102, 120, 201, 210

  • ABA 类:左右两侧的颜色相同,也有 $6$ 种:010, 020, 101, 121, 202, 212

这样我们就可以把 $12$ 种 $\textit{type}$ 浓缩成了 $2$ 种,尝试写出这两类之间的递推式。我们用 $f[i][0]$ 表示 ABC 类,$f[i][1]$ 表示 ABA 类。在计算时,我们可以将任意一种满足要求的涂色方法带入第 i - 1 行,并检查第 i 行的方案数,这是因为同一类的涂色方法都是等价的:

  • i - 1 行是 ABC 类,第 i 行是 ABC 类:以 012 为例,那么第 i 行只能是120201,方案数为 $2$;

  • i - 1 行是 ABC 类,第 i 行是 ABA 类:以 012 为例,那么第 i 行只能是 101121,方案数为 $2$;

  • i - 1 行是 ABA 类,第 i 行是 ABC 类:以 010 为例,那么第 i 行只能是 102201,方案数为 2

  • i - 1 行是 ABA 类,第 i 行是 ABA 类:以 010 为例,那么第 i 行只能是 101121202,方案数为 3

因此我们就可以写出递推式:

$$
\begin{cases}
f[i][0] = 2 * f[i - 1][0] + 2 * f[i - 1][1] \
f[i][1] = 2 * f[i - 1][0] + 3 * f[i - 1][1]
\end{cases}
$$

###C++

class Solution {
private:
    static constexpr int mod = 1000000007;

public:
    int numOfWays(int n) {
        int fi0 = 6, fi1 = 6;
        for (int i = 2; i <= n; ++i) {
            int new_fi0 = (2LL * fi0 + 2LL * fi1) % mod;
            int new_fi1 = (2LL * fi0 + 3LL * fi1) % mod;
            fi0 = new_fi0;
            fi1 = new_fi1;
        }
        return (fi0 + fi1) % mod;
    }
};

###Java

class Solution {
    static final int MOD = 1000000007;

    public int numOfWays(int n) {
        long fi0 = 6, fi1 = 6;
        for (int i = 2; i <= n; ++i) {
            long newFi0 = (2 * fi0 + 2 * fi1) % MOD;
            long newFi1 = (2 * fi0 + 3 * fi1) % MOD;
            fi0 = newFi0;
            fi1 = newFi1;
        }
        return (int) ((fi0 + fi1) % MOD);
    }
}

###Python

class Solution:
    def numOfWays(self, n: int) -> int:
        mod = 10**9 + 7
        fi0, fi1 = 6, 6
        for i in range(2, n + 1):
            fi0, fi1 = (2 * fi0 + 2 * fi1) % mod, (2 * fi0 + 3 * fi1) % mod
        return (fi0 + fi1) % mod

###C#

public class Solution {
    private const int mod = 1000000007;
    
    public int NumOfWays(int n) {
        long fi0 = 6, fi1 = 6;
        for (int i = 2; i <= n; ++i) {
            long new_fi0 = (2 * fi0 + 2 * fi1) % mod;
            long new_fi1 = (2 * fi0 + 3 * fi1) % mod;
            fi0 = new_fi0;
            fi1 = new_fi1;
        }
        return (int)((fi0 + fi1) % mod);
    }
}

###Go

func numOfWays(n int) int {
    mod := 1000000007
    fi0, fi1 := 6, 6
    for i := 2; i <= n; i++ {
        new_fi0 := (2*fi0 + 2*fi1) % mod
        new_fi1 := (2*fi0 + 3*fi1) % mod
        fi0, fi1 = new_fi0, new_fi1
    }
    return (fi0 + fi1) % mod
}

###C

int numOfWays(int n) {
    const int mod = 1000000007;
    long long fi0 = 6, fi1 = 6;
    for (int i = 2; i <= n; ++i) {
        long long new_fi0 = (2 * fi0 + 2 * fi1) % mod;
        long long new_fi1 = (2 * fi0 + 3 * fi1) % mod;
        fi0 = new_fi0;
        fi1 = new_fi1;
    }
    return (fi0 + fi1) % mod;
}

###JavaScript

var numOfWays = function(n) {
    const mod = 1000000007;    
    let fi0 = 6, fi1 = 6;
    for (let i = 2; i <= n; i++) {
        const new_fi0 = (2 * fi0 + 2 * fi1) % mod;
        const new_fi1 = (2 * fi0 + 3 * fi1) % mod;
        fi0 = new_fi0;
        fi1 = new_fi1;
    }
    return (fi0 + fi1) % mod;
};

###TypeScript

function numOfWays(n: number): number {
    const mod: number = 1000000007;    
    let fi0: number = 6, fi1: number = 6;
    for (let i = 2; i <= n; i++) {
        const new_fi0: number = (2 * fi0 + 2 * fi1) % mod;
        const new_fi1: number = (2 * fi0 + 3 * fi1) % mod;
        fi0 = new_fi0;
        fi1 = new_fi1;
    }
    return (fi0 + fi1) % mod;
}

###Rust

impl Solution {
    pub fn num_of_ways(n: i32) -> i32 {
        let mod_val: i64 = 1000000007;
        let mut fi0: i64 = 6;
        let mut fi1: i64 = 6;
        
        for _ in 2..= n {
            let new_fi0 = (2 * fi0 + 2 * fi1) % mod_val;
            let new_fi1 = (2 * fi0 + 3 * fi1) % mod_val;
            fi0 = new_fi0;
            fi1 = new_fi1;
        }
        
        ((fi0 + fi1) % mod_val) as i32
    }
}

复杂度分析

  • 时间复杂度:$O(N)$。

  • 空间复杂度:$O(1)$。

在长度 2N 的数组中找出重复 N 次的元素

2022年5月20日 09:48

方法一:哈希表

思路与算法

记重复 $n$ 次的元素为 $x$。由于数组 $\textit{nums}$ 中有 $n+1$ 个不同的元素,而其长度为 $2n$,那么数组中剩余的元素均只出现了一次。也就是说,我们只需要找到重复出现的元素即为答案。

因此我们可以对数组进行一次遍历,并使用哈希集合存储已经出现过的元素。如果遍历到了哈希集合中的元素,那么返回该元素作为答案。

代码

###C++

class Solution {
public:
    int repeatedNTimes(vector<int>& nums) {
        unordered_set<int> found;
        for (int num: nums) {
            if (found.count(num)) {
                return num;
            }
            found.insert(num);
        }
        // 不可能的情况
        return -1;
    }
};

###Java

class Solution {
    public int repeatedNTimes(int[] nums) {
        Set<Integer> found = new HashSet<Integer>();
        for (int num : nums) {
            if (!found.add(num)) {
                return num;
            }
        }
        // 不可能的情况
        return -1;
    }
}

###C#

public class Solution {
    public int RepeatedNTimes(int[] nums) {
        ISet<int> found = new HashSet<int>();
        foreach (int num in nums) {
            if (!found.Add(num)) {
                return num;
            }
        }
        // 不可能的情况
        return -1;
    }
}

###Python

class Solution:
    def repeatedNTimes(self, nums: List[int]) -> int:
        found = set()

        for num in nums:
            if num in found:
                return num
            found.add(num)
        
        # 不可能的情况
        return -1

###C

struct HashItem {
    int key;
    UT_hash_handle hh;
};

void freeHash(struct HashItem **obj) {
    struct HashItem *curr, *tmp;
    HASH_ITER(hh, *obj, curr, tmp) {
        HASH_DEL(*obj, curr);  
        free(curr);        
    }
}

int repeatedNTimes(int* nums, int numsSize){
    struct HashItem *found = NULL;
    for (int i = 0; i < numsSize; i++) {
        struct HashItem *pEntry = NULL;
        HASH_FIND_INT(found, &nums[i], pEntry);
        if (pEntry != NULL) {
            freeHash(&found);
            return nums[i];
        } else {
            pEntry = (struct HashItem *)malloc(sizeof(struct HashItem));
            pEntry->key = nums[i];
            HASH_ADD_INT(found, key, pEntry);
        }
    }
    // 不可能的情况
    freeHash(&found);
    return -1;
}

###go

func repeatedNTimes(nums []int) int {
    found := map[int]bool{}
    for _, num := range nums {
        if found[num] {
            return num
        }
        found[num] = true
    }
    return -1 // 不可能的情况
}

###JavaScript

var repeatedNTimes = function(nums) {
    const found = new Set();
    for (const num of nums) {
        if (found.has(num)) {
            return num;
        }
        found.add(num);
    }
    // 不可能的情况
    return -1;
};

复杂度分析

  • 时间复杂度:$O(n)$。我们只需要对数组 $\textit{nums}$ 进行一次遍历。

  • 空间复杂度:$O(n)$,即为哈希集合需要使用的空间。

方法二:数学

思路与算法

我们可以考虑重复的元素 $x$ 在数组 $\textit{nums}$ 中出现的位置。

如果相邻的 $x$ 之间至少都隔了 $2$ 个位置,那么数组的总长度至少为:

$$
n + 2(n - 1) = 3n - 2
$$

当 $n > 2$ 时,$3n-2 > 2n$,不存在满足要求的数组。因此一定存在两个相邻的 $x$,它们的位置是连续的,或者只隔了 $1$ 个位置。

当 $n = 2$ 时,数组的长度最多为 $2n = 4$,因此最多只能隔 $2$ 个位置。

这样一来,我们只需要遍历所有间隔 $2$ 个位置及以内的下标对,判断对应的元素是否相等即可。

代码

###C++

class Solution {
public:
    int repeatedNTimes(vector<int>& nums) {
        int n = nums.size();
        for (int gap = 1; gap <= 3; ++gap) {
            for (int i = 0; i + gap < n; ++i) {
                if (nums[i] == nums[i + gap]) {
                    return nums[i];
                }
            }
        }
        // 不可能的情况
        return -1;
    }
};

###Java

class Solution {
    public int repeatedNTimes(int[] nums) {
        int n = nums.length;
        for (int gap = 1; gap <= 3; ++gap) {
            for (int i = 0; i + gap < n; ++i) {
                if (nums[i] == nums[i + gap]) {
                    return nums[i];
                }
            }
        }
        // 不可能的情况
        return -1;
    }
}

###C#

public class Solution {
    public int RepeatedNTimes(int[] nums) {
        int n = nums.Length;
        for (int gap = 1; gap <= 3; ++gap) {
            for (int i = 0; i + gap < n; ++i) {
                if (nums[i] == nums[i + gap]) {
                    return nums[i];
                }
            }
        }
        // 不可能的情况
        return -1;
    }
}

###Python

class Solution:
    def repeatedNTimes(self, nums: List[int]) -> int:
        n = len(nums)
        for gap in range(1, 4):
            for i in range(n - gap):
                if nums[i] == nums[i + gap]:
                    return nums[i]
        
        # 不可能的情况
        return -1

###C

int repeatedNTimes(int* nums, int numsSize) {
    for (int gap = 1; gap <= 3; ++gap) {
        for (int i = 0; i + gap < numsSize; ++i) {
            if (nums[i] == nums[i + gap]) {
                return nums[i];
            }
        }
    }
    // 不可能的情况
    return -1;
}

###go

func repeatedNTimes(nums []int) int {
    for gap := 1; gap <= 3; gap++ {
        for i, num := range nums[:len(nums)-gap] {
            if num == nums[i+gap] {
                return num
            }
        }
    }
    return -1 // 不可能的情况
}

###JavaScript

var repeatedNTimes = function(nums) {
    const n = nums.length;
    for (let gap = 1; gap <= 3; ++gap) {
        for (let i = 0; i + gap < n; ++i) {
            if (nums[i] === nums[i + gap]) {
                return nums[i];
            }
        }
    }
    // 不可能的情况
    return -1;
};

复杂度分析

  • 时间复杂度:$O(n)$。我们最多对数组进行三次遍历(除了 $n=2$ 之外,最多两次遍历)。

  • 空间复杂度:$O(1)$。

方法三:随机选择

思路与算法

我们可以每次随机选择两个不同的下标,判断它们对应的元素是否相等即可。如果相等,那么返回任意一个作为答案。

代码

###C++

class Solution {
public:
    int repeatedNTimes(vector<int>& nums) {
        int n = nums.size();
        mt19937 gen{random_device{}()};
        uniform_int_distribution<int> dis(0, n - 1);

        while (true) {
            int x = dis(gen), y = dis(gen);
            if (x != y && nums[x] == nums[y]) {
                return nums[x];
            }
        }
    }
};

###Java

class Solution {
    public int repeatedNTimes(int[] nums) {
        int n = nums.length;
        Random random = new Random();

        while (true) {
            int x = random.nextInt(n), y = random.nextInt(n);
            if (x != y && nums[x] == nums[y]) {
                return nums[x];
            }
        }
    }
}

###C#

public class Solution {
    public int RepeatedNTimes(int[] nums) {
        int n = nums.Length;
        Random random = new Random();

        while (true) {
            int x = random.Next(n), y = random.Next(n);
            if (x != y && nums[x] == nums[y]) {
                return nums[x];
            }
        }
    }
}

###Python

class Solution:
    def repeatedNTimes(self, nums: List[int]) -> int:
        n = len(nums)

        while True:
            x, y = random.randrange(n), random.randrange(n)
            if x != y and nums[x] == nums[y]:
                return nums[x]

###C

int repeatedNTimes(int* nums, int numsSize) {
    srand(time(NULL));
    while (true) {
        int x = random() % numsSize, y = random() % numsSize;
        if (x != y && nums[x] == nums[y]) {
            return nums[x];
        }
    }
}

###go

func repeatedNTimes(nums []int) int {
    n := len(nums)
    for {
        x, y := rand.Intn(n), rand.Intn(n)
        if x != y && nums[x] == nums[y] {
            return nums[x]
        }
    }
}

###JavaScript

var repeatedNTimes = function(nums) {
    const n = nums.length;

    while (true) {
        const x = Math.floor(Math.random() * n), y = Math.floor(Math.random() * n);
        if (x !== y && nums[x] === nums[y]) {
            return nums[x];
        }
    }
};

复杂度分析

  • 时间复杂度:期望 $O(1)$。选择两个相同元素的概率为 $\dfrac{n}{2n} \times \dfrac{n-1}{2n} \approx \dfrac{1}{4}$,因此期望 $4$ 次结束循环。

  • 空间复杂度:$O(1)$。

统计有序矩阵中的负数

2020年2月18日 20:22

方法一:暴力

观察数据范围注意到矩阵大小不会超过 $100*100=10^4$,所以我们可以遍历矩阵所有数,统计负数的个数。

###C++

class Solution {
public:
    int countNegatives(vector<vector<int>>& grid) {
        int num = 0;
        for (int x : grid) {
            for (int y : x) {
                if (y < 0) {
                    num++;
                }
            }
        }
        return num;
    }
};

###Java

class Solution {
    public int countNegatives(int[][] grid) {
        int num = 0;
        for (int[] row : grid) {
            for (int value : row) {
                if (value < 0) {
                    num++;
                }
            }
        }
        return num;
    }
}

###C#

public class Solution {
    public int CountNegatives(int[][] grid) {
        int num = 0;
        foreach (int[] row in grid) {
            foreach (int value in row) {
                if (value < 0) {
                    num++;
                }
            }
        }
        return num;
    }
}

###Go

func countNegatives(grid [][]int) int {
    num := 0
    for _, row := range grid {
        for _, value := range row {
            if value < 0 {
                num++
            }
        }
    }
    return num
}

###Python

class Solution:
    def countNegatives(self, grid: List[List[int]]) -> int:
        num = 0
        for row in grid:
            for value in row:
                if value < 0:
                    num += 1
        return num

###C

int countNegatives(int** grid, int gridSize, int* gridColSize) {
    int num = 0;
    for (int i = 0; i < gridSize; i++) {
        for (int j = 0; j < gridColSize[i]; j++) {
            if (grid[i][j] < 0) {
                num++;
            }
        }
    }
    return num;
}

###JavaScript

var countNegatives = function(grid) {
    let num = 0;
    for (const row of grid) {
        for (const value of row) {
            if (value < 0) {
                num++;
            }
        }
    }
    return num;
};

###TypeScript

function countNegatives(grid: number[][]): number {
    let num = 0;
    for (const row of grid) {
        for (const value of row) {
            if (value < 0) {
                num++;
            }
        }
    }
    return num;
};

###Rust

impl Solution {
    pub fn count_negatives(grid: Vec<Vec<i32>>) -> i32 {
        let mut num = 0;
        for row in grid {
            for value in row {
                if value < 0 {
                    num += 1;
                }
            }
        }
        num
    }
}

复杂度分析

  • 时间复杂度:$O(nm)$,即矩阵元素的总个数。
  • 空间复杂度:$O(1)$。

方法二:二分查找

注意到题目中给了一个性质,即矩阵中的元素无论是按行还是按列,都以非递增顺序排列,可以考虑把这个性质利用起来优化暴力。已知这个性质告诉了我们每一行的数都是有序的,所以我们通过二分查找可以找到每一行中从前往后的第一个负数,那么这个位置之后到这一行的末尾里所有的数必然是负数了,可以直接统计。

  1. 遍历矩阵的每一行。

  2. 二分查找到该行从前往后的第一个负数,考虑第 $i$ 行,我们记这个位置为 $pos_i$,那么第 $i$ 行 $[pos_i,m-1]$ 中的所有数都是负数,所以这一行对答案的贡献就是 $m-1-pos_i+1=m-pos_i$。

  3. 最后的答案就是 $\sum_{i=0}^{n-1}(m-pos_i)$。

###C++

class Solution {
public:
    int countNegatives(vector<vector<int>>& grid) {
        int num = 0;
        for (auto x : grid) {
            int l = 0, r = (int)x.size() - 1, pos = -1;
            while (l <= r) {
                int mid = l + ((r - l) >> 1);
                if (x[mid] < 0) {
                    pos = mid;
                    r = mid - 1;
                } else {
                    l = mid + 1;
                }
            }
            
            if (~pos) {  // pos != -1 表示这一行存在负数
                num += (int)x.size() - pos;
            }
        }
        return num;
    }
};

###Java

class Solution {
    public int countNegatives(int[][] grid) {
        int num = 0;
        for (int[] row : grid) {
            int l = 0, r = row.length - 1, pos = -1;
            while (l <= r) {
                int mid = l + (r - l) / 2;
                if (row[mid] < 0) {
                    pos = mid;
                    r = mid - 1;
                } else {
                    l = mid + 1;
                }
            }
            if (pos != -1) {
                num += row.length - pos;
            }
        }
        return num;
    }
}

###C#

public class Solution {
    public int CountNegatives(int[][] grid) {
        int num = 0;
        foreach (int[] row in grid) {
            int l = 0, r = row.Length - 1, pos = -1;
            while (l <= r) {
                int mid = l + (r - l) / 2;
                if (row[mid] < 0) {
                    pos = mid;
                    r = mid - 1;
                } else {
                    l = mid + 1;
                }
            }
            if (pos != -1) {
                num += row.Length - pos;
            }
        }
        return num;
    }
}

###Go

func countNegatives(grid [][]int) int {
    num := 0
    for _, row := range grid {
        l, r, pos := 0, len(row) - 1, -1
        for l <= r {
            mid := l + (r - l) / 2
            if row[mid] < 0 {
                pos = mid
                r = mid - 1
            } else {
                l = mid + 1
            }
        }
        if pos != -1 {
            num += len(row) - pos
        }
    }
    return num
}

###Python

class Solution:
    def countNegatives(self, grid: List[List[int]]) -> int:
        num = 0
        for row in grid:
            l, r, pos = 0, len(row) - 1, -1
            while l <= r:
                mid = l + (r - l) // 2
                if row[mid] < 0:
                    pos = mid
                    r = mid - 1
                else:
                    l = mid + 1
            if pos != -1:
                num += len(row) - pos
        return num

###C

int countNegatives(int** grid, int gridSize, int* gridColSize) {
    int num = 0;
    for (int i = 0; i < gridSize; i++) {
        int l = 0, r = gridColSize[i] - 1, pos = -1;
        while (l <= r) {
            int mid = l + (r - l) / 2;
            if (grid[i][mid] < 0) {
                pos = mid;
                r = mid - 1;
            } else {
                l = mid + 1;
            }
        }
        if (pos != -1) {
            num += gridColSize[i] - pos;
        }
    }
    return num;
}

###JavaScript

var countNegatives = function(grid) {
    let num = 0;
    for (const row of grid) {
        let l = 0, r = row.length - 1, pos = -1;
        while (l <= r) {
            const mid = l + Math.floor((r - l) / 2);
            if (row[mid] < 0) {
                pos = mid;
                r = mid - 1;
            } else {
                l = mid + 1;
            }
        }
        if (pos !== -1) {
            num += row.length - pos;
        }
    }
    return num;
};

###TypeScript

function countNegatives(grid: number[][]): number {
    let num = 0;
    for (const row of grid) {
        let l = 0, r = row.length - 1, pos = -1;
        while (l <= r) {
            const mid = l + Math.floor((r - l) / 2);
            if (row[mid] < 0) {
                pos = mid;
                r = mid - 1;
            } else {
                l = mid + 1;
            }
        }
        if (pos !== -1) {
            num += row.length - pos;
        }
    }
    return num;
}

###Rust

impl Solution {
    pub fn count_negatives(grid: Vec<Vec<i32>>) -> i32 {
        let mut num = 0;
        for row in grid {
            let (mut l, mut r, mut pos) = (0, row.len() as i32 - 1, -1);
            while l <= r {
                let mid = l + (r - l) / 2;
                if row[mid as usize] < 0 {
                    pos = mid;
                    r = mid - 1;
                } else {
                    l = mid + 1;
                }
            }
            if pos != -1 {
                num += row.len() as i32 - pos;
            }
        }
        num
    }
}

复杂度分析

  • 时间复杂度:二分查找一行的时间复杂度为$logm$,需要遍历$n$行,所以总时间复杂度是$O(nlogm)$。
  • 空间复杂度:$O(1)$。

方法三:分治

方法二其实只利用了一部分的性质,即每一行是非递增的,但其实整个矩阵是每行每列均非递增,这说明了一个更重要的性质:每一行从前往后第一个负数的位置是不断递减的,即我们设第 $i$ 行的第一个负数的位置为 $pos_i$,不失一般性,我们把一行全是正数的 $pos$ 设为 $m$,则
$$
pos_0>=pos_1>=pos_2>=...>=pos_{n-1}
$$
所以我们可以依此设计一个分治算法。

我们设计一个函数 $solve(l,r,L,R)$ 表示我们在统计 $[l,r]$ 行的答案,第 $[l,r]$ 行 $pos$ 的位置在 $[L,R]$ 列中,计算 $[l,r]$ 的中间行第 $mid$ 行的的 $pos_{mid}$,算完以后根据之前的方法计算这一行对答案的贡献。然后根据我们之前发现的性质,可以知道 $[l,mid-1]$ 中所有行的 $pos$ 是大于等于 $pos_{mid}$,$[mid+1,r]$ 中所有行的 $pos$ 值是小于等于 $pos_{mid}$ 的,所以可以分成两部分递归下去,即:
$$
solve(l,mid-1,pos_{mid},R)
$$

$$
solve(mid+1,r,L,pos_{mid})
$$
所以答案就是 $m-pos_{mid}+solve(l,mid-1,pos_{mid},R)+solve(mid+1,r,L,pos_{mid})$。

递归函数入口为 $solve(0,n-1,0,m-1)$。

###C++

class Solution {
public:
    int solve(int l, int r, int L, int R, vector<vector<int>>& grid) {
        if (l > r) {
            return 0;
        }
        
        int mid = l + ((r - l) >> 1);
        int pos = -1;
        // 在当前行中查找第一个负数
        for (int i = L; i <= R; ++i) {
            if (grid[mid][i] < 0) {
                pos = i;
                break;
            }
        }
        int ans = 0;
        if (pos != -1) {
            // 当前行找到负数,计算当前行的负数个数
            ans += (int)grid[0].size() - pos;
            // 递归处理上半部分(使用更小的列范围)
            ans += solve(l, mid - 1, pos, R, grid);
            // 递归处理下半部分(使用相同的列起始范围)
            ans += solve(mid + 1, r, L, pos, grid);
        } else {
            // 当前行没有负数,只需要递归处理下半部分
            ans += solve(mid + 1, r, L, R, grid);
        }
        
        return ans;
    }
    
    int countNegatives(vector<vector<int>>& grid) {
        return solve(0, (int)grid.size() - 1, 0, (int)grid[0].size() - 1, grid);
    }
};  

###Java

class Solution {
    private int solve(int l, int r, int L, int R, int[][] grid) {
        if (l > r) {
            return 0;
        }
        
        int mid = l + (r - l) / 2;
        int pos = -1;
        // 在当前行中查找第一个负数
        for (int i = L; i <= R; i++) {
            if (grid[mid][i] < 0) {
                pos = i;
                break;
            }
        }
        
        int ans = 0;
        if (pos != -1) {
            // 当前行找到负数,计算当前行的负数个数
            ans += grid[0].length - pos;
            // 递归处理上半部分(使用更小的列范围)
            ans += solve(l, mid - 1, pos, R, grid);
            // 递归处理下半部分(使用相同的列起始范围)
            ans += solve(mid + 1, r, L, pos, grid);
        } else {
            // 当前行没有负数,只需要递归处理下半部分
            ans += solve(mid + 1, r, L, R, grid);
        }
        return ans;
    }
    
    public int countNegatives(int[][] grid) {
        return solve(0, grid.length - 1, 0, grid[0].length - 1, grid);
    }
}

###C#

public class Solution {
    private int Solve(int l, int r, int L, int R, int[][] grid) {
        if (l > r) {
            return 0;
        }
        
        int mid = l + (r - l) / 2;
        int pos = -1;
        // 在当前行中查找第一个负数
        for (int i = L; i <= R; i++) {
            if (grid[mid][i] < 0) {
                pos = i;
                break;
            }
        }
        int ans = 0;
        if (pos != -1) {
            // 当前行找到负数,计算当前行的负数个数
            ans += grid[0].Length - pos;
            // 递归处理上半部分(使用更小的列范围)
            ans += Solve(l, mid - 1, pos, R, grid);
            // 递归处理下半部分(使用相同的列起始范围)
            ans += Solve(mid + 1, r, L, pos, grid);
        } else {
            // 当前行没有负数,只需要递归处理下半部分
            ans += Solve(mid + 1, r, L, R, grid);
        }
        
        return ans;
    }
    
    public int CountNegatives(int[][] grid) {
        return Solve(0, grid.Length - 1, 0, grid[0].Length - 1, grid);
    }
}

###Go

func countNegatives(grid [][]int) int {
    var solve func(l, r, L, R int) int
    solve = func(l, r, L, R int) int {
        if l > r {
            return 0
        }
        
        mid := l + (r - l) / 2
        pos := -1
        // 在当前行中查找第一个负数
        for i := L; i <= R; i++ {
            if grid[mid][i] < 0 {
                pos = i
                break
            }
        }
        
        ans := 0
        if pos != -1 {
            // 当前行找到负数,计算当前行的负数个数
            ans += len(grid[0]) - pos
            // 递归处理上半部分(使用更小的列范围)
            ans += solve(l, mid-1, pos, R)
            // 递归处理下半部分(使用相同的列起始范围)
            ans += solve(mid+1, r, L, pos)
        } else {
            // 当前行没有负数,只需要递归处理下半部分
            ans += solve(mid+1, r, L, R)
        }
        
        return ans
    }
    
    return solve(0, len(grid)-1, 0, len(grid[0])-1)
}

###Python

class Solution:
    def countNegatives(self, grid: List[List[int]]) -> int:
        def solve(l: int, r: int, L: int, R: int) -> int:
            if l > r:
                return 0
            mid = l + (r - l) // 2
            pos = -1
            # 在当前行中查找第一个负数
            for i in range(L, R + 1):
                if grid[mid][i] < 0:
                    pos = i
                    break
            ans = 0
            if pos != -1:
                # 当前行找到负数,计算当前行的负数个数
                ans += len(grid[0]) - pos
                # 递归处理上半部分(使用更小的列范围)
                ans += solve(l, mid - 1, pos, R)
                # 递归处理下半部分(使用相同的列起始范围)
                ans += solve(mid + 1, r, L, pos)
            else:
                # 当前行没有负数,只需要递归处理下半部分
                ans += solve(mid + 1, r, L, R)

            return ans
            
        return solve(0, len(grid) - 1, 0, len(grid[0]) - 1)

###C

int solve(int l, int r, int L, int R, int** grid, int gridSize, int gridColSize) {
    if (l > r) {
        return 0;
    }
    
    int mid = l + (r - l) / 2;
    int pos = -1;
    // 在当前行中查找第一个负数
    for (int i = L; i <= R; i++) {
        if (grid[mid][i] < 0) {
            pos = i;
            break;
        }
    }
    
    int ans = 0;
    if (pos != -1) {
        // 当前行找到负数,计算当前行的负数个数
        ans += gridColSize - pos;
        // 递归处理上半部分(使用更小的列范围)
        ans += solve(l, mid - 1, pos, R, grid, gridSize, gridColSize);
        // 递归处理下半部分(使用相同的列起始范围)
        ans += solve(mid + 1, r, L, pos, grid, gridSize, gridColSize);
    } else {
        // 当前行没有负数,只需要递归处理下半部分
        ans += solve(mid + 1, r, L, R, grid, gridSize, gridColSize);
    }
    
    return ans;
}

int countNegatives(int** grid, int gridSize, int* gridColSize) {
    return solve(0, gridSize - 1, 0, gridColSize[0] - 1, grid, gridSize, gridColSize[0]);
}

###JavaScript

var countNegatives = function(grid) {
    const solve = (l, r, L, R) => {
        if (l > r) {
            return 0;
        }
        
        const mid = l + Math.floor((r - l) / 2);
        let pos = -1;
        // 在当前行中查找第一个负数
        for (let i = L; i <= R; i++) {
            if (grid[mid][i] < 0) {
                pos = i;
                break;
            }
        }
        
        let ans = 0;
        if (pos !== -1) {
            // 当前行找到负数,计算当前行的负数个数
            ans += grid[0].length - pos;
            // 递归处理上半部分(使用更小的列范围)
            ans += solve(l, mid - 1, pos, R);
            // 递归处理下半部分(使用相同的列起始范围)
            ans += solve(mid + 1, r, L, pos);
        } else {
            // 当前行没有负数,只需要递归处理下半部分
            ans += solve(mid + 1, r, L, R);
        }
        
        return ans;
    };
    
    return solve(0, grid.length - 1, 0, grid[0].length - 1);
};

###TypeScript

function countNegatives(grid: number[][]): number {
    const solve = (l: number, r: number, L: number, R: number): number => {
        if (l > r) {
            return 0;
        }
        
        const mid = l + Math.floor((r - l) / 2);
        let pos = -1;
        // 在当前行中查找第一个负数
        for (let i = L; i <= R; i++) {
            if (grid[mid][i] < 0) {
                pos = i;
                break;
            }
        }
        
        let ans = 0;
        if (pos !== -1) {
            // 当前行找到负数,计算当前行的负数个数
            ans += grid[0].length - pos;
            // 递归处理上半部分(使用更小的列范围)
            ans += solve(l, mid - 1, pos, R);
            // 递归处理下半部分(使用相同的列起始范围)
            ans += solve(mid + 1, r, L, pos);
        } else {
            // 当前行没有负数,只需要递归处理下半部分
            ans += solve(mid + 1, r, L, R);
        }
        
        return ans;
    };
    
    return solve(0, grid.length - 1, 0, grid[0].length - 1);
}

###Rust

impl Solution {
    pub fn count_negatives(grid: Vec<Vec<i32>>) -> i32 {
        fn solve(l: i32, r: i32, L: i32, R: i32, grid: &Vec<Vec<i32>>) -> i32 {
            if l > r {
                return 0;
            }
            
            let mid = l + (r - l) / 2;
            let mut pos = -1;
            // 在当前行中查找第一个负数
            for i in L..=R {
                if grid[mid as usize][i as usize] < 0 {
                    pos = i;
                    break;
                }
            }
            
            let mut ans = 0;
            if pos != -1 {
                // 当前行找到负数,计算当前行的负数个数
                ans += grid[0].len() as i32 - pos;
                // 递归处理上半部分(使用更小的列范围)
                ans += solve(l, mid - 1, pos, R, grid);
                // 递归处理下半部分(使用相同的列起始范围)
                ans += solve(mid + 1, r, L, pos, grid);
            } else {
                // 当前行没有负数,只需要递归处理下半部分
                ans += solve(mid + 1, r, L, R, grid);
            }
            
            ans
        }
        
        solve(0, grid.len() as i32 - 1, 0, grid[0].len() as i32 - 1, &grid)
    }
}

复杂度分析

  • 时间复杂度:代码中找第一个负数的位置是直接遍历 $[L,R]$ 找的,再考虑到 $n$ 和 $m$ 同阶,所以每个 $solve$ 函数里需要消耗 $O(n)$ 的时间,由主定理可得时间复杂度为:
    $$
    T(n)=2T(n/2)+O(n)=O(nlogn)
    $$

  • 空间复杂度:$O(1)$。

方法四:倒序遍历

考虑方法三发现的性质,我们可以设计一个更简单的方法。考虑我们已经算出第 $i$ 行的从前往后第一个负数的位置 $pos_i$,那么第 $i+1$ 行的时候,$pos_{i+1}$ 的位置肯定是位于 $[0,pos_i]$ 中,所以对于第 $i+1$ 行我们倒着从 $pos_i$ 循环找 $pos_{i+1}$ 即可,这个循环起始变量是一直在递减的。

###C++

class Solution {
public:
    int countNegatives(vector<vector<int>>& grid) {
        int num = 0;
        int m = (int)grid[0].size();
        int pos = (int)grid[0].size() - 1;
        
        for (auto& row : grid) {
            int i;
            for (i = pos; i >= 0; --i) {
                if (row[i] >= 0) {
                    if (i + 1 < m) {
                        pos = i + 1;
                        num += m - pos;
                    }
                    break;
                }
            }
            if (i == -1) {
                num += m;
                pos = -1;
            }
        }
        
        return num;
    }
};

###Java

class Solution {
    public int countNegatives(int[][] grid) {
        int num = 0;
        int m = grid[0].length;
        int pos = grid[0].length - 1;
        
        for (int[] row : grid) {
            int i;
            for (i = pos; i >= 0; i--) {
                if (row[i] >= 0) {
                    if (i + 1 < m) {
                        pos = i + 1;
                        num += m - pos;
                    }
                    break;
                }
            }
            if (i == -1) {
                num += m;
                pos = -1;
            }
        }
        
        return num;
    }
}

###C#

public class Solution {
    public int CountNegatives(int[][] grid) {
        int num = 0;
        int m = grid[0].Length;
        int pos = grid[0].Length - 1;
        
        foreach (int[] row in grid) {
            int i;
            for (i = pos; i >= 0; i--) {
                if (row[i] >= 0) {
                    if (i + 1 < m) {
                        pos = i + 1;
                        num += m - pos;
                    }
                    break;
                }
            }
            if (i == -1) {
                num += m;
                pos = -1;
            }
        }
        
        return num;
    }
}

###Go

func countNegatives(grid [][]int) int {
    num := 0
    m := len(grid[0])
    pos := len(grid[0]) - 1
    
    for _, row := range grid {
        i := pos
        for ; i >= 0; i-- {
            if row[i] >= 0 {
                if i + 1 < m {
                    pos = i + 1
                    num += m - pos
                }
                break
            }
        }
        if i == -1 {
            num += m
            pos = -1
        }
    }
    
    return num
}

###Python

class Solution:
    def countNegatives(self, grid: List[List[int]]) -> int:
        num = 0
        m = len(grid[0])
        pos = len(grid[0]) - 1
        
        for row in grid:
            i = pos
            while i >= 0:
                if row[i] >= 0:
                    if i + 1 < m:
                        pos = i + 1
                        num += m - pos
                    break
                i -= 1
            if i == -1:
                num += m
                pos = -1
        
        return num

###C

int countNegatives(int** grid, int gridSize, int* gridColSize) {
    int num = 0;
    int m = gridColSize[0];
    int pos = gridColSize[0] - 1;
    
    for (int i = 0; i < gridSize; i++) {
        int j;
        for (j = pos; j >= 0; j--) {
            if (grid[i][j] >= 0) {
                if (j + 1 < m) {
                    pos = j + 1;
                    num += m - pos;
                }
                break;
            }
        }
        if (j == -1) {
            num += m;
            pos = -1;
        }
    }
    
    return num;
}

###JavaScript

var countNegatives = function(grid) {
    let num = 0;
    const m = grid[0].length;
    let pos = grid[0].length - 1;
    
    for (const row of grid) {
        let i;
        for (i = pos; i >= 0; i--) {
            if (row[i] >= 0) {
                if (i + 1 < m) {
                    pos = i + 1;
                    num += m - pos;
                }
                break;
            }
        }
        if (i === -1) {
            num += m;
            pos = -1;
        }
    }
    
    return num;
};

###TypeScript

function countNegatives(grid: number[][]): number {
    let num = 0;
    const m = grid[0].length;
    let pos = grid[0].length - 1;
    
    for (const row of grid) {
        let i: number;
        for (i = pos; i >= 0; i--) {
            if (row[i] >= 0) {
                if (i + 1 < m) {
                    pos = i + 1;
                    num += m - pos;
                }
                break;
            }
        }
        if (i === -1) {
            num += m;
            pos = -1;
        }
    }
    
    return num;
}

###Rust

impl Solution {
    pub fn count_negatives(grid: Vec<Vec<i32>>) -> i32 {
        let mut num = 0;
        let m = grid[0].len();
        let mut pos = (grid[0].len() - 1) as i32;
        
        for row in grid {
            let mut i = pos;
            while i >= 0 {
                if row[i as usize] >= 0 {
                    if i + 1 < m as i32 {
                        pos = i + 1;
                        num += (m as i32) - pos;
                    }
                    break;
                }
                i -= 1;
            }
            if i == -1 {
                num += m as i32;
                pos = -1;
            }
        }
        
        num
    }
}

复杂度分析

  • 时间复杂度:考虑每次循环变量的起始位置是单调不降的,所以起始位置最多移动 $m$ 次,时间复杂度 $O(n+m)$。
  • 空间复杂度:$O(1)$。

两个最好的不重叠活动

2021年11月1日 00:19

方法一:时间戳排序

思路与算法

我们可以将所有活动的左右边界放在一起进行自定义排序。具体地,我们用 $(\textit{ts}, \textit{op}, \textit{val})$ 表示一个「事件」:

  • $\textit{op}$ 表示该事件的类型。如果 $\textit{op} = 0$,说明该事件表示一个活动的开始;如果 $\textit{op} = 1$,说明该事件表示一个活动的结束。

  • $\textit{ts}$ 表示该事件发生的时间,即活动的开始时间或结束时间。

  • $\textit{val}$ 表示该事件的价值,即对应活动的 $\textit{value}$ 值。

我们将所有的时间按照 $\textit{ts}$ 为第一关键字升序排序,这样我们就能按照时间顺序依次处理每一个事件。当 $\textit{ts}$ 相等时,我们按照 $\textit{op}$ 为第二关键字升序排序,这是因为题目中要求了「第一个活动的结束时间不能等于第二个活动的起始时间」,因此当时间相同时,我们先处理开始的事件,再处理结束的事件。

当排序完成后,我们就可以通过对所有的事件进行一次遍历,从而算出最多两个时间不重叠的活动的最大价值:

  • 当我们遍历到一个结束事件时,我们用 $\textit{val}$ 来更新 $\textit{bestFirst}$,其中 $\textit{bestFirst}$ 表示当前已经结束的所有活动的最大价值。这样做的意义在于,所有已经结束的事件都可以当作第一个活动

  • 当我们遍历到一个开始事件时,我们将该活动当作第二个活动,由于第一个活动的最大价值为 $\textit{bestFirst}$,因此我们用 $\textit{val} + \textit{bestFirst}$ 更新答案即可。

代码

###C++

struct Event {
    // 时间戳
    int ts;
    // op = 0 表示左边界,op = 1 表示右边界
    int op;
    int val;
    Event(int _ts, int _op, int _val): ts(_ts), op(_op), val(_val) {}
    bool operator< (const Event& that) const {
        return tie(ts, op) < tie(that.ts, that.op);
    }
};

class Solution {
public:
    int maxTwoEvents(vector<vector<int>>& events) {
        vector<Event> evs;
        for (const auto& event: events) {
            evs.emplace_back(event[0], 0, event[2]);
            evs.emplace_back(event[1], 1, event[2]);
        }
        sort(evs.begin(), evs.end());
        
        int ans = 0, bestFirst = 0;
        for (const auto& [ts, op, val]: evs) {
            if (op == 0) {
                ans = max(ans, val + bestFirst);
            }
            else {
                bestFirst = max(bestFirst, val);
            }
        }
        return ans;
    }
};

###Python

class Event:
    def __init__(self, ts: int, op: int, val: int):
        self.ts = ts
        self.op = op
        self.val = val
    
    def __lt__(self, other: "Event") -> bool:
        return (self.ts, self.op) < (other.ts, other.op)


class Solution:
    def maxTwoEvents(self, events: List[List[int]]) -> int:
        evs = list()
        for event in events:
            evs.append(Event(event[0], 0, event[2]))
            evs.append(Event(event[1], 1, event[2]))
        evs.sort()

        ans = bestFirst = 0
        for ev in evs:
            if ev.op == 0:
                ans = max(ans, ev.val + bestFirst)
            else:
                bestFirst = max(bestFirst, ev.val)
        
        return ans

###Java

class Solution {
    public int maxTwoEvents(int[][] events) {
        List<Event> evs = new ArrayList<>();
        for (int[] event : events) {
            evs.add(new Event(event[0], 0, event[2]));
            evs.add(new Event(event[1], 1, event[2]));
        }
        Collections.sort(evs);
        int ans = 0, bestFirst = 0;
        for (Event event : evs) {
            if (event.op == 0) {
                ans = Math.max(ans, event.val + bestFirst);
            } else {
                bestFirst = Math.max(bestFirst, event.val);
            }
        }
        return ans;
    }
    
    class Event implements Comparable<Event> {
        int ts;
        int op;
        int val;
        
        Event(int ts, int op, int val) {
            this.ts = ts;
            this.op = op;
            this.val = val;
        }
        
        @Override
        public int compareTo(Event other) {
            if (this.ts != other.ts) {
                return Integer.compare(this.ts, other.ts);
            }
            return Integer.compare(this.op, other.op);
        }
    }
}

###C#

public class Solution {
    public int MaxTwoEvents(int[][] events) {
        List<Event> evs = new List<Event>();
        foreach (var eventArr in events) {
            evs.Add(new Event(eventArr[0], 0, eventArr[2]));
            evs.Add(new Event(eventArr[1], 1, eventArr[2]));
        }
        evs.Sort();
        
        int ans = 0, bestFirst = 0;
        foreach (var ev in evs) {
            if (ev.Op == 0) {
                ans = Math.Max(ans, ev.Val + bestFirst);
            } else {
                bestFirst = Math.Max(bestFirst, ev.Val);
            }
        }
        return ans;
    }
    
    class Event : IComparable<Event> {
        public int Ts { get; set; }
        public int Op { get; set; }
        public int Val { get; set; }
        
        public Event(int ts, int op, int val) {
            Ts = ts;
            Op = op;
            Val = val;
        }
        
        public int CompareTo(Event other) {
            if (Ts != other.Ts) {
                return Ts.CompareTo(other.Ts);
            }
            return Op.CompareTo(other.Op);
        }
    }
}

###Go

func maxTwoEvents(events [][]int) int {
    type Event struct {
        ts  int
        op  int
        val int
    }
    
    evs := make([]Event, 0)
    for _, event := range events {
        evs = append(evs, Event{event[0], 0, event[2]})
        evs = append(evs, Event{event[1], 1, event[2]})
    }
    
    sort.Slice(evs, func(i, j int) bool {
        if evs[i].ts != evs[j].ts {
            return evs[i].ts < evs[j].ts
        }
        return evs[i].op < evs[j].op
    })
    
    ans, bestFirst := 0, 0
    for _, ev := range evs {
        if ev.op == 0 {
            if ev.val + bestFirst > ans {
                ans = ev.val + bestFirst
            }
        } else {
            if ev.val > bestFirst {
                bestFirst = ev.val
            }
        }
    }
    return ans
}

###C

typedef struct {
    int ts;
    int op;
    int val;
} Event;

int compareEvents(const void* a, const void* b) {
    Event* e1 = (Event*)a;
    Event* e2 = (Event*)b;
    if (e1->ts != e2->ts) {
        return e1->ts - e2->ts;
    }
    return e1->op - e2->op;
}

int maxTwoEvents(int** events, int eventsSize, int* eventsColSize) {
    Event* evs = (Event*)malloc(2 * eventsSize * sizeof(Event));
    int idx = 0;
    for (int i = 0; i < eventsSize; i++) {
        evs[idx++] = (Event){events[i][0], 0, events[i][2]};
        evs[idx++] = (Event){events[i][1], 1, events[i][2]};
    }
    qsort(evs, 2 * eventsSize, sizeof(Event), compareEvents);

    int ans = 0, bestFirst = 0;
    for (int i = 0; i < 2 * eventsSize; i++) {
        if (evs[i].op == 0) {
            if (evs[i].val + bestFirst > ans) {
                ans = evs[i].val + bestFirst;
            }
        } else {
            if (evs[i].val > bestFirst) {
                bestFirst = evs[i].val;
            }
        }
    }
    
    free(evs);
    return ans;
}

###JavaScript

var maxTwoEvents = function(events) {
    const evs = [];
    for (const event of events) {
        evs.push({ts: event[0], op: 0, val: event[2]});
        evs.push({ts: event[1], op: 1, val: event[2]});
    }
    
    evs.sort((a, b) => {
        if (a.ts !== b.ts) {
            return a.ts - b.ts;
        }
        return a.op - b.op;
    });
    
    let ans = 0, bestFirst = 0;
    for (const ev of evs) {
        if (ev.op === 0) {
            ans = Math.max(ans, ev.val + bestFirst);
        } else {
            bestFirst = Math.max(bestFirst, ev.val);
        }
    }
    return ans;
};

###TypeScript

function maxTwoEvents(events: number[][]): number {
    interface Event {
        ts: number;
        op: number;
        val: number;
    }
    
    const evs: Event[] = [];
    for (const event of events) {
        evs.push({ts: event[0], op: 0, val: event[2]});
        evs.push({ts: event[1], op: 1, val: event[2]});
    }
    
    evs.sort((a, b) => {
        if (a.ts !== b.ts) {
            return a.ts - b.ts;
        }
        return a.op - b.op;
    });
    
    let ans = 0, bestFirst = 0;
    for (const ev of evs) {
        if (ev.op === 0) {
            ans = Math.max(ans, ev.val + bestFirst);
        } else {
            bestFirst = Math.max(bestFirst, ev.val);
        }
    }
    return ans;
}

###Rust

#[derive(Debug)]
struct Event {
    ts: i32,
    op: i32,
    val: i32,
}

impl Solution {
    pub fn max_two_events(events: Vec<Vec<i32>>) -> i32 {
        let mut evs: Vec<Event> = Vec::new();
        for event in events {
            evs.push(Event { ts: event[0], op: 0, val: event[2] });
            evs.push(Event { ts: event[1], op: 1, val: event[2] });
        }
        
        evs.sort_by(|a, b| {
            if a.ts != b.ts {
                a.ts.cmp(&b.ts)
            } else {
                a.op.cmp(&b.op)
            }
        });
        
        let mut ans = 0;
        let mut best_first = 0;
        for ev in evs {
            if ev.op == 0 {
                ans = ans.max(ev.val + best_first);
            } else {
                best_first = best_first.max(ev.val);
            }
        }
        
        ans
    }
}

复杂度分析

  • 时间复杂度:$O(n \log n)$,其中 $n$ 是数组 $\textit{events}$ 的长度。

  • 空间复杂度:$O(n)$。

❌
❌