普通视图

发现新文章,点击刷新页面。
今天 — 2025年5月18日LeetCode 每日一题题解

每日一题-用三种不同颜色为网格涂色🔴

2025年5月18日 00:00

给你两个整数 mn 。构造一个 m x n 的网格,其中每个单元格最开始是白色。请你用 红、绿、蓝 三种颜色为每个单元格涂色。所有单元格都需要被涂色。

涂色方案需要满足:不存在相邻两个单元格颜色相同的情况 。返回网格涂色的方法数。因为答案可能非常大, 返回 109 + 7 取余 的结果。

 

示例 1:

输入:m = 1, n = 1
输出:3
解释:如上图所示,存在三种可能的涂色方案。

示例 2:

输入:m = 1, n = 2
输出:6
解释:如上图所示,存在六种可能的涂色方案。

示例 3:

输入:m = 5, n = 5
输出:580986

 

提示:

  • 1 <= m <= 5
  • 1 <= n <= 1000

用三种不同颜色为网格涂色

2021年7月11日 15:17

方法一:状态压缩动态规划

提示 $1$

要使得任意两个相邻的格子的颜色均不相同,我们需要保证:

  • 同一行内任意两个相邻格子的颜色互不相同;

  • 相邻的两行之间,同一列上的两个格子的颜色互不相同。

因此,我们可以考虑:

  • 首先通过枚举的方法,找出所有对一行进行涂色的方案数;

  • 然后通过动态规划的方法,计算出对整个 $m \times n$ 的方格进行涂色的方案数。

在本题中,$m$ 和 $n$ 的最大值分别是 $5$ 和 $1000$,我们需要将较小的 $m$ 看成行的长度,较大的 $n$ 看成列的长度,这样才可以对一行进行枚举。

思路与算法

我们首先枚举对一行进行涂色的方案数。

对于我们可以选择红绿蓝三种颜色,我们可以将它们看成 $0, 1, 2$。这样一来,一种涂色方案就对应着一个长度为 $m$ 的三进制数,其十进制的范围为 $[0, 3^m)$。

因此,我们可以枚举 $[0, 3^m)$ 范围内的所有整数,将其转换为长度为 $m$ 的三进制串,再判断其是否满足任意相邻的两个数位均不相同即可。

随后我们就可以使用动态规划来计算方案数了。我们用 $f[i][\textit{mask}]$ 表示我们已经对 $0, 1, \cdots, i$ 行进行了涂色,并且第 $i$ 行的涂色方案对应的三进制表示为 $\textit{mask}$ 的前提下的总方案数。在进行状态转移时,我们可以考虑第 $i-1$ 行的涂色方案 $\textit{mask}'$:

$$
f[i][\textit{mask}] = \sum_{\textit{mask} ~与~ \textit{mask}' 同一数位上的数字均不相同} f[i-1][\textit{mask}']
$$

只要 $\textit{mask}'$ 与 $\textit{mask}$ 同一数位上的数字均不相同,就说明这两行可以相邻,我们就可以进行状态转移。

最终的答案即为所有满足 $\textit{mask} \in [0, 3^m)$ 的 $f[n-1][\textit{mask}]$ 之和。

细节

上述动态规划中的边界条件在于第 $0$ 行的涂色。当 $i=0$ 时,$f[i-1][..]$ 不是合法状态,无法进行转移,我们需要对它们进行特判:即如果 $\textit{mask}$ 任意相邻的两个数位均不相同,那么 $f[0][\textit{mask}] = 1$,否则 $f[0][\textit{mask}] = 0$。

在其余情况下的状态转移时,对于给定的 $\textit{mask}$,我们总是要找出所有满足要求的 $\textit{mask}'$,因此我们不妨也把它们预处理出来,具体可以参考下方给出的代码。

最后需要注意的是,在状态转移方程中,$f[i][..]$ 只会从 $f[i-1][..]$ 转移而来,因此我们可以使用两个长度为 $3^m$ 的一维数组,交替地进行状态转移。

代码

###C++

class Solution {
private:
    static constexpr int mod = 1000000007;

public:
    int colorTheGrid(int m, int n) {
        // 哈希映射 valid 存储所有满足要求的对一行进行涂色的方案
        // 键表示 mask,值表示 mask 的三进制串(以列表的形式存储)
        unordered_map<int, vector<int>> valid;

        // 在 [0, 3^m) 范围内枚举满足要求的 mask
        int mask_end = pow(3, m);
        for (int mask = 0; mask < mask_end; ++mask) {
            vector<int> color;
            int mm = mask;
            for (int i = 0; i < m; ++i) {
                color.push_back(mm % 3);
                mm /= 3;
            }
            bool check = true;
            for (int i = 0; i < m - 1; ++i) {
                if (color[i] == color[i + 1]) {
                    check = false;
                    break;
                }
            }
            if (check) {
                valid[mask] = move(color);
            }
        }

        // 预处理所有的 (mask1, mask2) 二元组,满足 mask1 和 mask2 作为相邻行时,同一列上两个格子的颜色不同
        unordered_map<int, vector<int>> adjacent;
        for (const auto& [mask1, color1]: valid) {
            for (const auto& [mask2, color2]: valid) {
                bool check = true;
                for (int i = 0; i < m; ++i) {
                    if (color1[i] == color2[i]) {
                        check = false;
                        break;
                    }
                }
                if (check) {
                    adjacent[mask1].push_back(mask2);
                }
            }
        }

        vector<int> f(mask_end);
        for (const auto& [mask, _]: valid) {
            f[mask] = 1;
        }
        for (int i = 1; i < n; ++i) {
            vector<int> g(mask_end);
            for (const auto& [mask2, _]: valid) {
                for (int mask1: adjacent[mask2]) {
                    g[mask2] += f[mask1];
                    if (g[mask2] >= mod) {
                        g[mask2] -= mod;
                    }
                }
            }
            f = move(g);
        }

        int ans = 0;
        for (int num: f) {
            ans += num;
            if (ans >= mod) {
                ans -= mod;
            }
        }
        return ans;
    }
};

###Python

class Solution:
    def colorTheGrid(self, m: int, n: int) -> int:
        mod = 10**9 + 7
        # 哈希映射 valid 存储所有满足要求的对一行进行涂色的方案
        # 键表示 mask,值表示 mask 的三进制串(以列表的形式存储)
        valid = dict()
        
        # 在 [0, 3^m) 范围内枚举满足要求的 mask
        for mask in range(3**m):
            color = list()
            mm = mask
            for i in range(m):
                color.append(mm % 3)
                mm //= 3
            if any(color[i] == color[i + 1] for i in range(m - 1)):
                continue
            valid[mask] = color
        
        # 预处理所有的 (mask1, mask2) 二元组,满足 mask1 和 mask2 作为相邻行时,同一列上两个格子的颜色不同
        adjacent = defaultdict(list)
        for mask1, color1 in valid.items():
            for mask2, color2 in valid.items():
                if not any(x == y for x, y in zip(color1, color2)):
                    adjacent[mask1].append(mask2)
        
        f = [int(mask in valid) for mask in range(3**m)]
        for i in range(1, n):
            g = [0] * (3**m)
            for mask2 in valid.keys():
                for mask1 in adjacent[mask2]:
                    g[mask2] += f[mask1]
                    if g[mask2] >= mod:
                        g[mask2] -= mod
            f = g
            
        return sum(f) % mod

###Java

class Solution {
    static final int mod = 1000000007;

    public int colorTheGrid(int m, int n) {
        // 哈希映射 valid 存储所有满足要求的对一行进行涂色的方案
        Map<Integer, List<Integer>> valid = new HashMap<>();
        // 在 [0, 3^m) 范围内枚举满足要求的 mask
        int maskEnd = (int) Math.pow(3, m);
        for (int mask = 0; mask < maskEnd; ++mask) {
            List<Integer> color = new ArrayList<>();
            int mm = mask;
            for (int i = 0; i < m; ++i) {
                color.add(mm % 3);
                mm /= 3;
            }
            boolean check = true;
            for (int i = 0; i < m - 1; ++i) {
                if (color.get(i).equals(color.get(i + 1))) {
                    check = false;
                    break;
                }
            }
            if (check) {
                valid.put(mask, color);
            }
        }

        // 预处理所有的 (mask1, mask2) 二元组,满足 mask1 和 mask2 作为相邻行时,同一列上两个格子的颜色不同
        Map<Integer, List<Integer>> adjacent = new HashMap<>();
        for (int mask1 : valid.keySet()) {
            for (int mask2 : valid.keySet()) {
                boolean check = true;
                for (int i = 0; i < m; ++i) {
                    if (valid.get(mask1).get(i).equals(valid.get(mask2).get(i))) {
                        check = false;
                        break;
                    }
                }
                if (check) {
                    adjacent.computeIfAbsent(mask1, k -> new ArrayList<>()).add(mask2);
                }
            }
        }

        Map<Integer, Integer> f = new HashMap<>();
        for (int mask : valid.keySet()) {
            f.put(mask, 1);
        }
        for (int i = 1; i < n; ++i) {
            Map<Integer, Integer> g = new HashMap<>();
            for (int mask2 : valid.keySet()) {
                for (int mask1 : adjacent.getOrDefault(mask2, new ArrayList<>())) {
                    g.put(mask2, (g.getOrDefault(mask2, 0) + f.getOrDefault(mask1, 0)) % mod);
                }
            }
            f = g;
        }

        int ans = 0;
        for (int num : f.values()) {
            ans = (ans + num) % mod;
        }
        return ans;
    }
}

###C#

public class Solution {
    private const int mod = 1000000007;

    public int ColorTheGrid(int m, int n) {
        // 哈希映射 valid 存储所有满足要求的对一行进行涂色的方案
        var valid = new Dictionary<int, List<int>>();
        // 在 [0, 3^m) 范围内枚举满足要求的 mask
        int maskEnd = (int)Math.Pow(3, m);
        for (int mask = 0; mask < maskEnd; ++mask) {
            var color = new List<int>();
            int mm = mask;
            for (int i = 0; i < m; ++i) {
                color.Add(mm % 3);
                mm /= 3;
            }
            bool check = true;
            for (int i = 0; i < m - 1; ++i) {
                if (color[i] == color[i + 1]) {
                    check = false;
                    break;
                }
            }
            if (check) {
                valid[mask] = color;
            }
        }

        // 预处理所有的 (mask1, mask2) 二元组,满足 mask1 和 mask2 作为相邻行时,同一列上两个格子的颜色不同
        var adjacent = new Dictionary<int, List<int>>();
        foreach (var mask1 in valid.Keys) {
            foreach (var mask2 in valid.Keys) {
                bool check = true;
                for (int i = 0; i < m; ++i) {
                    if (valid[mask1][i] == valid[mask2][i]) {
                        check = false;
                        break;
                    }
                }
                if (check) {
                    if (!adjacent.ContainsKey(mask1)) {
                        adjacent[mask1] = new List<int>();
                    }
                    adjacent[mask1].Add(mask2);
                }
            }
        }

        var f = new Dictionary<int, int>();
        foreach (var mask in valid.Keys) {
            f[mask] = 1;
        }
        for (int i = 1; i < n; ++i) {
            var g = new Dictionary<int, int>();
            foreach (var mask2 in valid.Keys) {
                if (adjacent.ContainsKey(mask2)) {
                    foreach (var mask1 in adjacent[mask2]) {
                        if (!g.ContainsKey(mask2)) {
                            g[mask2] = 0;
                        }
                        g[mask2] = (g[mask2] + f[mask1]) % mod;
                    }
                }
            }
            f = g;
        }

        int ans = 0;
        foreach (var num in f.Values) {
            ans = (ans + num) % mod;
        }
        return ans;
    }
}

###Go

const mod = 1000000007

func colorTheGrid(m int, n int) int {
// 哈希映射 valid 存储所有满足要求的对一行进行涂色的方案
valid := make(map[int][]int)

// 在 [0, 3^m) 范围内枚举满足要求的 mask
maskEnd := int(math.Pow(3, float64(m)))
for mask := 0; mask < maskEnd; mask++ {
color := make([]int, m)
mm := mask
for i := 0; i < m; i++ {
color[i] = mm % 3
mm /= 3
}
check := true
for i := 0; i < m-1; i++ {
if color[i] == color[i+1] {
check = false
break
}
}
if check {
valid[mask] = color
}
}

// 预处理所有的 (mask1, mask2) 二元组,满足 mask1 和 mask2 作为相邻行时,同一列上两个格子的颜色不同
adjacent := make(map[int][]int)
for mask1 := range valid {
for mask2 := range valid {
check := true
for i := 0; i < m; i++ {
if valid[mask1][i] == valid[mask2][i] {
check = false
break
}
}
if check {
adjacent[mask1] = append(adjacent[mask1], mask2)
}
}
}

f := make(map[int]int)
for mask := range valid {
f[mask] = 1
}
for i := 1; i < n; i++ {
g := make(map[int]int)
for mask2 := range valid {
for _, mask1 := range adjacent[mask2] {
g[mask2] = (g[mask2] + f[mask1]) % mod
}
}
f = g
}

ans := 0
for _, num := range f {
ans = (ans + num) % mod
}
return ans
}

###C

#define MOD 1000000007

struct ListNode *createListNode(int val) {
    struct ListNode *obj = (struct ListNode*)malloc(sizeof(struct ListNode));
    obj->val = val;
    obj->next = NULL;
    return obj;
}

void freeList(struct ListNode *list) {
    while (list) {
        struct ListNode *p = list;
        list = list->next;
        free(p);
    }
}

typedef struct {
    int key;
    struct ListNode *val;
    UT_hash_handle hh;
} HashItem; 

HashItem *hashFindItem(HashItem **obj, int key) {
    HashItem *pEntry = NULL;
    HASH_FIND_INT(*obj, &key, pEntry);
    return pEntry;
}

bool hashAddItem(HashItem **obj, int key, int val) {
    HashItem *pEntry = hashFindItem(obj, key);
    struct ListNode *p = createListNode(val);
    if (!pEntry) {
        pEntry = (HashItem *)malloc(sizeof(HashItem));
        pEntry->key = key;
        pEntry->val = p;
        HASH_ADD_INT(*obj, key, pEntry);
    } else {
        p->next = pEntry->val;
        pEntry->val = p;
    }
    return true;
}

bool hashSetItem(HashItem **obj, int key, struct ListNode *list) {
    HashItem *pEntry = hashFindItem(obj, key);
    if (!pEntry) {
        pEntry = (HashItem *)malloc(sizeof(HashItem));
        pEntry->key = key;
        pEntry->val = list;
        HASH_ADD_INT(*obj, key, pEntry);
    } else {
        freeList(pEntry->val);
        pEntry->val = list;
    }
    return true;
}

struct ListNode* hashGetItem(HashItem **obj, int key) {
    HashItem *pEntry = hashFindItem(obj, key);
    if (!pEntry) {
        return NULL;
    }
    return pEntry->val;
}

void hashFree(HashItem **obj) {
    HashItem *curr = NULL, *tmp = NULL;
    HASH_ITER(hh, *obj, curr, tmp) {
        HASH_DEL(*obj, curr); 
        freeList(curr->val); 
        free(curr);
    }
}

// 主函数
int colorTheGrid(int m, int n) {
    // 哈希映射 valid 存储所有满足要求的对一行进行涂色的方案
    // 键表示 mask,值表示 mask 的三进制串(以列表的形式存储)
    HashItem *valid = NULL;
    // 在 [0, 3^m) 范围内枚举满足要求的 mask
    int mask_end = pow(3, m);
    for (int mask = 0; mask < mask_end; ++mask) {
        int mm = mask;
        int color[m];
        for (int i = 0; i < m; ++i) {
            color[i] = mm % 3;
            mm /= 3;
        }
        bool check = true;
        for (int i = 0; i < m - 1; ++i) {
            if (color[i] == color[i + 1]) {
                check = false;
                break;
            }
        }
        if (check) {
            for (int i = 0; i < m; i++) {
                hashAddItem(&valid, mask, color[i]);
            }
        }
    }

    // 预处理所有的 (mask1, mask2) 二元组,满足 mask1 和 mask2 作为相邻行时,同一列上两个格子的颜色不同
    HashItem *adjacent = NULL;
    for (HashItem *pEntry1 = valid; pEntry1; pEntry1 = pEntry1->hh.next) {
        int mask1 = pEntry1->key;
        for (HashItem *pEntry2 = valid; pEntry2; pEntry2 = pEntry2->hh.next) {
            int mask2 = pEntry2->key;
            bool check = true;
            for (struct ListNode *p1 = pEntry1->val, *p2 = pEntry2->val; p1 && p2; p1 = p1->next, p2 = p2->next) {
                if (p1->val == p2->val) {
                    check = false;
                    break;
                }
            }
            if (check) {
                hashAddItem(&adjacent, mask1, mask2);
            }
        }
    }

    int f[mask_end];
    memset(f, 0, sizeof(f));
    for (HashItem *pEntry = valid; pEntry; pEntry = pEntry->hh.next) {
        int mask = pEntry->key;
        f[mask] = 1;
    }
    for (int i = 1; i < n; ++i) {
        int g[mask_end];
        memset(g, 0, sizeof(g));
        for (HashItem *pEntry1 = valid; pEntry1; pEntry1 = pEntry1->hh.next) {
            int mask2 = pEntry1->key;
            for (struct ListNode *p = hashGetItem(&adjacent, mask2); p != NULL; p = p->next) {
                int mask1 = p->val;
                g[mask2] += f[mask1];
                if (g[mask2] >= MOD) {
                    g[mask2] -= MOD;
                }
            }
        }
        memcpy(f, g, sizeof(f));
    }

    int ans = 0;
    for (int i = 0; i < mask_end; i++) {
        ans += f[i];
        if (ans >= MOD) {
            ans -= MOD;
        }
    }
    hashFree(&valid);
    hashFree(&adjacent);
    return ans;
}

###JavaScript

var colorTheGrid = function(m, n) {
    const mod = 1000000007;
    // 哈希映射 valid 存储所有满足要求的对一行进行涂色的方案
    const valid = new Map();
    // 在 [0, 3^m) 范围内枚举满足要求的 mask
    const maskEnd = Math.pow(3, m);
    for (let mask = 0; mask < maskEnd; ++mask) {
        const color = [];
        let mm = mask;
        for (let i = 0; i < m; ++i) {
            color.push(mm % 3);
            mm = Math.floor(mm / 3);
        }
        let check = true;
        for (let i = 0; i < m - 1; ++i) {
            if (color[i] === color[i + 1]) {
                check = false;
                break;
            }
        }
        if (check) {
            valid.set(mask, color);
        }
    }

    // 预处理所有的 (mask1, mask2) 二元组,满足 mask1 和 mask2 作为相邻行时,同一列上两个格子的颜色不同
    const adjacent = new Map();
    for (const [mask1, color1] of valid.entries()) {
        for (const [mask2, color2] of valid.entries()) {
            let check = true;
            for (let i = 0; i < m; ++i) {
                if (color1[i] === color2[i]) {
                    check = false;
                    break;
                }
            }
            if (check) {
                if (!adjacent.has(mask1)) {
                    adjacent.set(mask1, []);
                }
                adjacent.get(mask1).push(mask2);
            }
        }
    }

    let f = new Map();
    for (const [mask, _] of valid.entries()) {
        f.set(mask, 1);
    }
    for (let i = 1; i < n; ++i) {
        const g = new Map();
        for (const [mask2, _] of valid.entries()) {
            for (const mask1 of adjacent.get(mask2) || []) {
                g.set(mask2, ((g.get(mask2) || 0) + f.get(mask1)) % mod);
            }
        }
        f = g;
    }

    let ans = 0;
    for (const num of f.values()) {
        ans = (ans + num) % mod;
    }
    return ans;
}

###TypeScript

function colorTheGrid(m: number, n: number): number {
    const mod = 1000000007;
    // 哈希映射 valid 存储所有满足要求的对一行进行涂色的方案
    const valid = new Map<number, number[]>();

    // 在 [0, 3^m) 范围内枚举满足要求的 mask
    const maskEnd = Math.pow(3, m);
    for (let mask = 0; mask < maskEnd; ++mask) {
        const color: number[] = [];
        let mm = mask;
        for (let i = 0; i < m; ++i) {
            color.push(mm % 3);
            mm = Math.floor(mm / 3);
        }
        let check = true;
        for (let i = 0; i < m - 1; ++i) {
            if (color[i] === color[i + 1]) {
                check = false;
                break;
            }
        }
        if (check) {
            valid.set(mask, color);
        }
    }

    // 预处理所有的 (mask1, mask2) 二元组,满足 mask1 和 mask2 作为相邻行时,同一列上两个格子的颜色不同
    const adjacent = new Map<number, number[]>();
    for (const [mask1, color1] of valid.entries()) {
        for (const [mask2, color2] of valid.entries()) {
            let check = true;
            for (let i = 0; i < m; ++i) {
                if (color1[i] === color2[i]) {
                    check = false;
                    break;
                }
            }
            if (check) {
                if (!adjacent.has(mask1)) {
                    adjacent.set(mask1, []);
                }
                adjacent.get(mask1)!.push(mask2);
            }
        }
    }

    let f = new Map<number, number>();
    for (const [mask, _] of valid.entries()) {
        f.set(mask, 1);
    }
    for (let i = 1; i < n; ++i) {
        const g = new Map<number, number>();
        for (const [mask2, _] of valid.entries()) {
            for (const mask1 of adjacent.get(mask2) || []) {
                g.set(mask2, ((g.get(mask2) || 0) + f.get(mask1)!) % mod);
            }
        }
        f = g;
    }

    let ans = 0;
    for (const num of f.values()) {
        ans = (ans + num) % mod;
    }
    return ans;
}

###Rust

use std::collections::HashMap;

const MOD: i32 = 1_000_000_007;

impl Solution {
    pub fn color_the_grid(m: i32, n: i32) -> i32 {
        let m = m as usize;
        let n = n as usize;
        // 哈希映射 valid 存储所有满足要求的对一行进行涂色的方案
        let mut valid = HashMap::new();
        // 在 [0, 3^m) 范围内枚举满足要求的 mask
        let mask_end = 3i32.pow(m as u32);
        for mask in 0..mask_end {
            let mut color = Vec::new();
            let mut mm = mask;
            for _ in 0..m {
                color.push(mm % 3);
                mm /= 3;
            }
            let mut check = true;
            for i in 0..m - 1 {
                if color[i] == color[i + 1] {
                    check = false;
                    break;
                }
            }
            if check {
                valid.insert(mask, color);
            }
        }

        // 预处理所有的 (mask1, mask2) 二元组,满足 mask1 和 mask2 作为相邻行时,同一列上两个格子的颜色不同
        let mut adjacent = HashMap::new();
        for (&mask1, color1) in &valid {
            for (&mask2, color2) in &valid {
                let mut check = true;
                for i in 0..m {
                    if color1[i] == color2[i] {
                        check = false;
                        break;
                    }
                }
                if check {
                    adjacent.entry(mask1).or_insert(Vec::new()).push(mask2);
                }
            }
        }

        let mut f = HashMap::new();
        for &mask in valid.keys() {
            f.insert(mask, 1);
        }
        for _ in 1..n {
            let mut g = HashMap::new();
            for &mask2 in valid.keys() {
                let mut total = 0;
                if let Some(list) = adjacent.get(&mask2) {
                    for &mask1 in list {
                        total = (total + f.get(&mask1).unwrap_or(&0)) % MOD;
                    }
                }
                g.insert(mask2, total);
            }
            f = g;
        }

        let mut ans = 0;
        for &num in f.values() {
            ans = (ans + num) % MOD;
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$O(3^{2m} \cdot n)$。

    • 预处理 $\textit{mask}$ 的时间复杂度为 $O(m \cdot 3^m)$;

    • 预处理 $(\textit{mask}, \textit{mask}')$ 二元组的时间复杂度为 $O(3^{2m})$;

    • 动态规划的时间复杂度为 $O(3^{2m} \cdot n)$,其在渐近意义下大于前两者。

  • 空间复杂度:$O(3^{2m})$。

    • 存储 $\textit{mask}$ 的哈希映射需要的空间为 $O(m \cdot 3^m)$;

    • 存储 $(\textit{mask}, \textit{mask}')$ 二元组需要的空间为 $O(3^{2m})$,在渐进意义下大于其余两者;

    • 动态规划存储状态需要的空间为 $O(3^m)$。

不过需要注意的是,在实际的情况下,当 $m=5$ 时,满足要求的 $\textit{mask}$ 仅有 $48$ 个,远小于 $3^m=324$;满足要求的 $(\textit{mask}, \textit{mask}')$ 二元组仅有 $486$ 对,远小于 $3^{2m}=59049$。因此该算法的实际运行时间会较快。

O(2^m*n)大聪明解法,推了个小时,0ms用时,绝对双百

作者 lzt666
2021年7月11日 13:06

A代表第一种颜色,B代表第二种颜色,C代表第三种颜色。

0代表红色,1代表绿色,2代表蓝色。

1. m=1

第一行有3种情况,且接下来的所有行,都是2种情况

    long long mod = 1000000007;
    int ans=3;
    for(int i=1;i<n;++i)    
        ans= (ans * 2LL) % mod;
    return ans;

2. m=2

所有颜色排列有6种

01, 02, 10, 12, 20, 21

可分类为AB
对于第一行为AB,第二行则有BA、BC、CA三种情况,第三行同理有对应三种情况,第n行同理有三种情况。

    long long mod = 1000000007;
    int ans=6;
    for(int i=1;i<n;++i)
        ans= (ans * 3LL) % mod;
    return ans;

3. m=3

所有颜色排列有12种

010, 012, 020, 021, 101, 102, 120, 121, 201, 202, 210, 212

可分类为ABC和ABA

  • ABC类:共6种:012, 021, 102, 120, 201, 210;
  • ABA类:共6种:010, 020, 101, 121, 202, 212。

则可据此根据上一行的类型递推该行的类型种数。

  • 第 i - 1 行是 ABC 类,第 i 行是 ABC 类:以 012 为例,那么第 i 行只能是120 或 201,方案数为 2;
  • 第 i - 1 行是 ABC 类,第 i 行是 ABA 类:以 012 为例,那么第 i 行只能是101 或 121,方案数为 2;
  • 第 i - 1 行是 ABA 类,第 i 行是 ABC 类:以 010 为例,那么第 i 行只能是102 或 201,方案数为 2;
  • 第 i - 1 行是 ABA 类,第 i 行是 ABA 类:以 010 为例,那么第 i 行只能是101,121 或 202,方案数为 3。

故有递推式

f[i][0] = 2 * f[i - 1][0] + 2 * f[i - 1][1];
f[i][1] = 2 * f[i - 1][0] + 3 * f[i - 1][1];

4. m=4

所有颜色排列有24种
可分类为ABCA、ABCB、ABAB、ABAC
(实际上ABCB和ABAC可归为一类,见评论区用户AndrewPei代码)

  • ABCA类:共6种
  • ABCB类:共6种
  • ABAB类:共6种
  • ABAC类:共6种

则可据此根据上一行的类型递推该行的类型种数。

  • 第 i - 1 行是 ABCA 类,第 i 行是 ABCA 类:方案数为 3;
  • 第 i - 1 行是 ABCA 类,第 i 行是 ABCB 类:方案数为 2;
  • 第 i - 1 行是 ABCA 类,第 i 行是 ABAB 类:方案数为 1;
  • 第 i - 1 行是 ABCA 类,第 i 行是 ABAC 类:方案数为 2。
  • 第 i - 1 行是 ABCB 类,第 i 行是 ABCA 类:方案数为 2;
  • 第 i - 1 行是 ABCB 类,第 i 行是 ABCB 类:方案数为 2;
  • 第 i - 1 行是 ABCB 类,第 i 行是 ABAB 类:方案数为 1;
  • 第 i - 1 行是 ABCB 类,第 i 行是 ABAC 类:方案数为 2。
  • 第 i - 1 行是 ABAB 类,第 i 行是 ABCA 类:方案数为 1;
  • 第 i - 1 行是 ABAB 类,第 i 行是 ABCB 类:方案数为 1;
  • 第 i - 1 行是 ABAB 类,第 i 行是 ABAB 类:方案数为 2;
  • 第 i - 1 行是 ABAB 类,第 i 行是 ABAC 类:方案数为 1。
  • 第 i - 1 行是 ABAC 类,第 i 行是 ABCA 类:方案数为 2;
  • 第 i - 1 行是 ABAC 类,第 i 行是 ABCB 类:方案数为 2;
  • 第 i - 1 行是 ABAC 类,第 i 行是 ABAB 类:方案数为 1;
  • 第 i - 1 行是 ABAC 类,第 i 行是 ABAC 类:方案数为 2。

故有递推式

f[i][0] = 3 * f[i - 1][0] + 2 * f[i - 1][1] + 1 * f[i - 1][2] + 2 * f[i - 1][3];
f[i][1] = 2 * f[i - 1][0] + 2 * f[i - 1][1] + 1 * f[i - 1][2] + 2 * f[i - 1][3];
f[i][2] = 1 * f[i - 1][0] + 1 * f[i - 1][1] + 2 * f[i - 1][2] + 1 * f[i - 1][3];
f[i][3] = 2 * f[i - 1][0] + 2 * f[i - 1][1] + 1 * f[i - 1][2] + 2 * f[i - 1][3];

5. m=5

同理,实在是写不下去了,直接上递推式

f[i][0] = 3 * f[i - 1][0] + 2 * f[i - 1][1] + 2 * f[i - 1][2] + 1 * f[i - 1][3] + 0 * f[i - 1][4] + 1 * f[i - 1][5] + 2 * f[i - 1][6] + 2 * f[i - 1][7];
f[i][1] = 2 * f[i - 1][0] + 2 * f[i - 1][1] + 2 * f[i - 1][2] + 1 * f[i - 1][3] + 1 * f[i - 1][4] + 1 * f[i - 1][5] + 1 * f[i - 1][6] + 1 * f[i - 1][7];
f[i][2] = 2 * f[i - 1][0] + 2 * f[i - 1][1] + 2 * f[i - 1][2] + 1 * f[i - 1][3] + 0 * f[i - 1][4] + 1 * f[i - 1][5] + 2 * f[i - 1][6] + 2 * f[i - 1][7];
f[i][3] = 1 * f[i - 1][0] + 1 * f[i - 1][1] + 1 * f[i - 1][2] + 2 * f[i - 1][3] + 1 * f[i - 1][4] + 1 * f[i - 1][5] + 1 * f[i - 1][6] + 1 * f[i - 1][7];
f[i][4] = 0 * f[i - 1][0] + 1 * f[i - 1][1] + 0 * f[i - 1][2] + 1 * f[i - 1][3] + 2 * f[i - 1][4] + 1 * f[i - 1][5] + 0 * f[i - 1][6] + 1 * f[i - 1][7];
f[i][5] = 1 * f[i - 1][0] + 1 * f[i - 1][1] + 1 * f[i - 1][2] + 1 * f[i - 1][3] + 1 * f[i - 1][4] + 2 * f[i - 1][5] + 1 * f[i - 1][6] + 1 * f[i - 1][7];
f[i][6] = 2 * f[i - 1][0] + 1 * f[i - 1][1] + 2 * f[i - 1][2] + 1 * f[i - 1][3] + 0 * f[i - 1][4] + 1 * f[i - 1][5] + 2 * f[i - 1][6] + 1 * f[i - 1][7];
f[i][7] = 2 * f[i - 1][0] + 1 * f[i - 1][1] + 2 * f[i - 1][2] + 1 * f[i - 1][3] + 1 * f[i - 1][4] + 1 * f[i - 1][5] + 1 * f[i - 1][6] + 2 * f[i - 1][7];

代码如下(貌似系数可以矩阵快速幂递推来着,等我哪天有时间再试试)

class Solution {
public:
    int colorTheGrid(int m, int n) {
        long long mod = 1000000007;
        if(m==1)
        {
            int ans=3;
            for(int i=1;i<n;++i)    ans= ans * 2LL % mod;
            return ans;
        }
        else if(m==2)
        {
            int fi = 6;
            for(int i=1;i<n;++i)    fi= 3LL * fi % mod;
            return fi;
        }
        else if(m==3)
        {
            int fi0 = 6, fi1 = 6;
            for (int i = 1; i < n; ++i) {
                int new_fi0 = (2LL * fi0 + 2LL * fi1) % mod;
                int new_fi1 = (2LL * fi0 + 3LL * fi1) % mod;
                fi0 = new_fi0;
                fi1 = new_fi1;
            }
            return ((long long)fi0 + fi1) % mod;
        }
        else if(m==4)
        {
            //ABAB//ABAC//ABCA//ABCB
            int fi0 = 6, fi1 = 6, fi2=6, fi3=6;
            for (int i = 1; i < n; ++i) {
                int new_fi0 = (3LL * fi0 + 2LL * fi1+ 1LL*fi2+ 2LL*fi3) % mod;
                int new_fi1 = (2LL * fi0 + 2LL * fi1+ 1LL*fi2+2LL*fi3) % mod;
                int new_fi2 = (1LL * fi0 + 1LL * fi1+ 2LL*fi2 +1LL*fi3) % mod;
                int new_fi3 = (2LL * fi0 + 2LL * fi1+ 1LL*fi2+2LL*fi3) % mod;
                fi0 = new_fi0;
                fi1 = new_fi1;
                fi2 = new_fi2;
                fi3 = new_fi3;
            }
            return ((long long)fi0 + fi1+ fi2+ fi3) % mod;
        }
        else
        {
            //ABABA//ABABC//ABACA//ABACB//ABCAB//ABCAC//ABCBA//ABCBC
            int fi0 = 6, fi1 = 6, fi2=6 ,fi3 =6, fi4=6, fi5=6, fi6=6, fi7=6;
            for (int i = 1; i < n; ++i) {
                int new_fi0 = (3LL * fi0 + 2LL * fi1+ 2LL*fi2+ 1LL*fi3+ 0LL*fi4 +1LL*fi5 +2LL*fi6+2LL*fi7) % mod;
                int new_fi1 = (2LL * fi0 + 2LL * fi1+ 2LL*fi2+ 1LL*fi3+ 1LL*fi4 +1LL*fi5 +1LL*fi6+1LL*fi7) % mod;
                int new_fi2 = (2LL * fi0 + 2LL * fi1+ 2LL*fi2+ 1LL*fi3+ 0LL*fi4 +1LL*fi5 +2LL*fi6+2LL*fi7) % mod;
                int new_fi3 = (1LL * fi0 + 1LL * fi1+ 1LL*fi2+ 2LL*fi3+ 1LL*fi4 +1LL*fi5 +1LL*fi6+1LL*fi7) % mod;
                int new_fi4 = (0LL * fi0 + 1LL * fi1+ 0LL*fi2+ 1LL*fi3+ 2LL*fi4 +1LL*fi5 +0LL*fi6+1LL*fi7) % mod;
                int new_fi5 = (1LL * fi0 + 1LL * fi1+ 1LL*fi2+ 1LL*fi3+ 1LL*fi4 +2LL*fi5 +1LL*fi6+1LL*fi7) % mod;
                int new_fi6 = (2LL * fi0 + 1LL * fi1+ 2LL*fi2+ 1LL*fi3+ 0LL*fi4 +1LL*fi5 +2LL*fi6+1LL*fi7) % mod;
                int new_fi7 = (2LL * fi0 + 1LL * fi1+ 2LL*fi2+ 1LL*fi3+ 1LL*fi4 +1LL*fi5 +1LL*fi6+2LL*fi7) % mod;
                fi0 = new_fi0;
                fi1 = new_fi1;
                fi2 = new_fi2;
                fi3 = new_fi3;
                fi4 = new_fi4;
                fi5 = new_fi5;
                fi6 = new_fi6;
                fi7 = new_fi7;
            }
            return ((long long)fi0 + fi1+ fi2+ fi3+ fi4 + fi5+ fi6+ fi7) % mod;
        }
    }
};

综上所述

  1. 对于m = 1,可分为1种情况
  2. 对于m > 1,可分为2^(m-2)种情况

故时间复杂度为O((2^m)*n)

记忆化搜索/递推/矩阵快速幂(Python/Java/C++/Go)

作者 endlesscheng
2021年7月11日 12:09

前言

如果只有红绿两种颜色,可以把这两种颜色分别用 $0$ 和 $1$ 表示,用一个长为 $m$ 的二进制数表示一行的颜色。

例如 $m=5$,二进制数 $01010_{(2)}$ 表示红绿红绿红。

本题有红绿蓝三种颜色,可以分别用 $0,1,2$ 表示,用一个长为 $m$ 的三进制数表示一行的颜色。

例如 $m=5$,三进制数 $01202_{(3)}$ 表示红绿蓝红蓝。

注:本题不区分左右,三进制数从高到低读还是从低到高读都可以。

思路

首先预处理所有合法的(没有相邻相同颜色的)三进制数,记在数组 $\textit{valid}$ 中。

然后对于每个 $\textit{valid}[i]$,预处理它的下一列颜色,要求左右相邻颜色不同。把 $\textit{valid}$ 的下标记在数组 $\textit{nxt}[i]$ 中。

预处理这些数据之后,就可以 DP 了。

对于 $m\times n$ 的网格,如果最后一列填的是三进制数 $\textit{valid}[j]$,那么问题为:对于 $m\times (n-1)$ 的网格,最后一列填的是三进制数 $\textit{valid}[j]$ 的情况下的涂色方案数。

继续,如果倒数第二列填的是三进制数 $\textit{valid}[k]$,那么接下来要解决的问题为:对于 $m\times (n-2)$ 的网格,右边一列填的是三进制数 $\textit{valid}[k]$ 的情况下的涂色方案数。

所以定义 $\textit{dfs}(i,j)$ 表示对于 $m\times i$ 的网格,右边第 $i+1$ 列填的是三进制数 $\textit{valid}[j]$ 的情况下的涂色方案数。

枚举第 $i$ 列填颜色 $\textit{valid}[k]$(其中 $k$ 是 $\textit{nxt}[j]$ 中的元素),问题变成对于 $m\times (i-1)$ 的网格,右边第 $i$ 列填的是三进制数 $\textit{valid}[k]$ 的情况下的涂色方案数。

累加得

$$
\textit{dfs}(i,j) = \sum_{k} \textit{dfs}(i-1,k)
$$

递归边界:$\textit{dfs}(0,j) = 1$,表示找到了一个合法涂色方案。

递归入口:$\displaystyle\sum\limits_{j} \textit{dfs}(n-1, j)$。第 $n$ 列填颜色 $\textit{valid}[j]$。

细节

三进制数最大为 $22\ldots 2_{(3)} = 3^m-1$。枚举 $[0,3^m-1]$ 中的三进制数,怎么判断一个三进制数是否合法?

我们需要取出三进制数中的每一位。

回想一下十进制数 $12345$ 怎么取出百位的 $3$:$12345$ 除以 $100$ 下取整,得到 $123$,再模 $10$,得到 $3$。

所以对于三进制数,可以除以 $3^i$ 下取整,再模 $3$。

为什么可以在 DP 的计算过程中取模?可以看 模运算的世界:当加减乘除遇上取模

写法一:记忆化搜索

class Solution:
    def colorTheGrid(self, m: int, n: int) -> int:
        pow3 = [3 ** i for i in range(m)]
        valid = []
        for color in range(3 ** m):
            for i in range(1, m):
                if color // pow3[i] % 3 == color // pow3[i - 1] % 3:  # 相邻颜色相同
                    break
            else:  # 没有中途 break,合法
                valid.append(color)

        nv = len(valid)
        nxt = [[] for _ in range(nv)]
        for i, color1 in enumerate(valid):
            for j, color2 in enumerate(valid):
                for p3 in pow3:
                    if color1 // p3 % 3 == color2 // p3 % 3:  # 相邻颜色相同
                        break
                else:  # 没有中途 break,合法
                    nxt[i].append(j)

        MOD = 1_000_000_007
        @cache  # 缓存装饰器,避免重复计算 dfs(一行代码实现记忆化)
        def dfs(i: int, j: int) -> int:
            if i == 0:
                return 1  # 找到了一个合法涂色方案
            return sum(dfs(i - 1, k) for k in nxt[j]) % MOD
        return sum(dfs(n - 1, j) for j in range(nv)) % MOD
class Solution {
    private static final int MOD = 1_000_000_007;

    public int colorTheGrid(int m, int n) {
        int[] pow3 = new int[m];
        pow3[0] = 1;
        for (int i = 1; i < m; i++) {
            pow3[i] = pow3[i - 1] * 3;
        }

        List<Integer> valid = new ArrayList<>();
        next:
        for (int color = 0; color < pow3[m - 1] * 3; color++) {
            for (int i = 1; i < m; i++) {
                if (color / pow3[i] % 3 == color / pow3[i - 1] % 3) { // 相邻颜色相同
                    continue next;
                }
            }
            valid.add(color);
        }

        int nv = valid.size();
        List<Integer>[] nxt = new ArrayList[nv];
        Arrays.setAll(nxt, i -> new ArrayList<>());
        for (int i = 0; i < nv; i++) {
            next2:
            for (int j = 0; j < nv; j++) {
                for (int p3 : pow3)
                    if (valid.get(i) / p3 % 3 == valid.get(j) / p3 % 3) { // 相邻颜色相同
                        continue next2;
                    }
                nxt[i].add(j);
            }
        }

        int[][] memo = new int[n][nv];
        for (int[] row : memo) {
            Arrays.fill(row, -1);
        }

        long ans = 0;
        for (int j = 0; j < nv; j++) {
            ans += dfs(n - 1, j, nxt, memo);
        }
        return (int) (ans % MOD);
    }

    private int dfs(int i, int j, List<Integer>[] nxt, int[][] memo) {
        if (i == 0) {
            return 1; // 找到了一个合法涂色方案
        }
        if (memo[i][j] != -1) { // 之前计算过
            return memo[i][j];
        }
        long res = 0;
        for (int k : nxt[j]) {
            res += dfs(i - 1, k, nxt, memo);
        }
        return memo[i][j] = (int) (res % MOD); // 记忆化
    }
}
class Solution {
    const int MOD = 1'000'000'007;
public:
    int colorTheGrid(int m, int n) {
        vector<int> pow3(m);
        pow3[0] = 1;
        for (int i = 1; i < m; i++) {
            pow3[i] = pow3[i - 1] * 3;
        }

        vector<int> valid;
        for (int color = 0; color < pow3[m - 1] * 3; color++) {
            bool ok = true;
            for (int i = 1; i < m; i++) {
                if (color / pow3[i] % 3 == color / pow3[i - 1] % 3) { // 相邻颜色相同
                    ok = false;
                    break;
                }
            }
            if (ok) {
                valid.push_back(color);
            }
        }

        int nv = valid.size();
        vector<vector<int>> nxt(nv);
        for (int i = 0; i < nv; i++) {
            for (int j = 0; j < nv; j++) {
                bool ok = true;
                for (int k = 0; k < m; k++) {
                    if (valid[i] / pow3[k] % 3 == valid[j] / pow3[k] % 3) { // 相邻颜色相同
                        ok = false;
                        break;
                    }
                }
                if (ok) {
                    nxt[i].push_back(j);
                }
            }
        }

        vector memo(n, vector<int>(nv, -1));
        auto dfs = [&](this auto&& dfs, int i, int j) -> int {
            if (i == 0) {
                return 1; // 找到了一个合法涂色方案
            }
            int& res = memo[i][j]; // 注意这里是引用
            if (res != -1) { // 之前计算过
                return res;
            }
            res = 0;
            for (int k : nxt[j]) {
                res = (res + dfs(i - 1, k)) % MOD;
            }
            return res;
        };

        long long ans = 0;
        for (int j = 0; j < nv; j++) {
            ans += dfs(n - 1, j);
        }
        return ans % MOD;
    }
};
func colorTheGrid(m, n int) int {
const mod = 1_000_000_007
pow3 := make([]int, m)
pow3[0] = 1
for i := 1; i < m; i++ {
pow3[i] = pow3[i-1] * 3
}

valid := []int{}
next:
for color := range pow3[m-1] * 3 {
for i := range m - 1 {
if color/pow3[i+1]%3 == color/pow3[i]%3 { // 相邻颜色相同
continue next
}
}
valid = append(valid, color)
}

nv := len(valid)
nxt := make([][]int, nv)
for i, color1 := range valid {
next2:
for j, color2 := range valid {
for _, p3 := range pow3 {
if color1/p3%3 == color2/p3%3 { // 相邻颜色相同
continue next2
}
}
nxt[i] = append(nxt[i], j)
}
}

memo := make([][]int, n)
for i := range memo {
memo[i] = make([]int, nv)
for j := range memo[i] {
memo[i][j] = -1
}
}
var dfs func(int, int) int
dfs = func(i, j int) (res int) {
if i == 0 {
return 1 // 找到了一个合法涂色方案
}
p := &memo[i][j]
if *p != -1 { // 之前计算过
return *p
}
defer func() { *p = res }() // 记忆化
for _, k := range nxt[j] {
res += dfs(i-1, k)
}
return res % mod
}

ans := 0
for j := range nv {
ans += dfs(n-1, j)
}
return ans % mod
}

复杂度分析

有多少个状态?$\textit{valid}$ 有多长?

对于一列长为 $m$ 的涂色方案,第一个颜色有 $3$ 种,其余颜色不能与上一个颜色相同,所以都是 $2$ 种。所以 $\textit{valid}$ 的长度为

$$
3\cdot 2^{m-1}
$$

所以状态个数为 $i$ 的个数 $\mathcal{O}(n)$ 乘以 $j$ 的个数 $\mathcal{O}(2^m)$,一共有 $\mathcal{O}(n2^m)$ 个状态。

  • 时间复杂度:$\mathcal{O}(n4^m)$。由于每个状态只会计算一次,动态规划的时间复杂度 $=$ 状态个数 $\times$ 单个状态的计算时间。本题状态个数等于 $\mathcal{O}(n2^m)$,单个状态的计算时间为 $\mathcal{O}(2^m)$,所以总的时间复杂度为 $\mathcal{O}(n4^m)$。
  • 空间复杂度:$\mathcal{O}(4^m + n2^m)$。其中 $\mathcal{O}(4^m)$ 是 $\textit{nxt}$ 需要的空间,$\mathcal{O}(n2^m)$ 是记忆化搜索需要的空间。

写法二:递推

把记忆化搜索 1:1 翻译成递推,原理见 动态规划入门:从记忆化搜索到递推【基础算法精讲 17】

class Solution:
    def colorTheGrid(self, m: int, n: int) -> int:
        pow3 = [3 ** i for i in range(m)]
        valid = []
        for color in range(3 ** m):
            for i in range(1, m):
                if color // pow3[i] % 3 == color // pow3[i - 1] % 3:  # 相邻颜色相同
                    break
            else:  # 没有中途 break,合法
                valid.append(color)

        nv = len(valid)
        nxt = [[] for _ in range(nv)]
        for i, color1 in enumerate(valid):
            for j, color2 in enumerate(valid):
                for p3 in pow3:
                    if color1 // p3 % 3 == color2 // p3 % 3:  # 相邻颜色相同
                        break
                else:  # 没有中途 break,合法
                    nxt[i].append(j)

        MOD = 1_000_000_007
        f = [[0] * nv for _ in range(n)]
        f[0] = [1] * nv  # dfs 的递归边界就是 DP 数组的初始值
        for i in range(1, n):
            for j in range(nv):
                f[i][j] = sum(f[i - 1][k] for k in nxt[j]) % MOD
        return sum(f[-1]) % MOD  # 递归入口就是答案
class Solution {
    private static final int MOD = 1_000_000_007;

    public int colorTheGrid(int m, int n) {
        int[] pow3 = new int[m];
        pow3[0] = 1;
        for (int i = 1; i < m; i++) {
            pow3[i] = pow3[i - 1] * 3;
        }

        List<Integer> valid = new ArrayList<>();
        next:
        for (int color = 0; color < pow3[m - 1] * 3; color++) {
            for (int i = 1; i < m; i++) {
                if (color / pow3[i] % 3 == color / pow3[i - 1] % 3) { // 相邻颜色相同
                    continue next;
                }
            }
            valid.add(color);
        }

        int nv = valid.size();
        List<Integer>[] nxt = new ArrayList[nv];
        Arrays.setAll(nxt, i -> new ArrayList<>());
        for (int i = 0; i < nv; i++) {
            next2:
            for (int j = 0; j < nv; j++) {
                for (int p3 : pow3)
                    if (valid.get(i) / p3 % 3 == valid.get(j) / p3 % 3) { // 相邻颜色相同
                        continue next2;
                    }
                nxt[i].add(j);
            }
        }

        int[][] f = new int[n][nv];
        Arrays.fill(f[0], 1);
        for (int i = 1; i < n; i++) {
            for (int j = 0; j < nv; j++) {
                for (int k : nxt[j]) {
                    f[i][j] = (f[i][j] + f[i - 1][k]) % MOD;
                }
            }
        }

        long ans = 0;
        for (int j = 0; j < nv; j++) {
            ans += f[n - 1][j];
        }
        return (int) (ans % MOD);
    }
}
class Solution {
    const int MOD = 1'000'000'007;
public:
    int colorTheGrid(int m, int n) {
        vector<int> pow3(m);
        pow3[0] = 1;
        for (int i = 1; i < m; i++) {
            pow3[i] = pow3[i - 1] * 3;
        }

        vector<int> valid;
        for (int color = 0; color < pow3[m - 1] * 3; color++) {
            bool ok = true;
            for (int i = 1; i < m; i++) {
                if (color / pow3[i] % 3 == color / pow3[i - 1] % 3) { // 相邻颜色相同
                    ok = false;
                    break;
                }
            }
            if (ok) {
                valid.push_back(color);
            }
        }

        int nv = valid.size();
        vector<vector<int>> nxt(nv);
        for (int i = 0; i < nv; i++) {
            for (int j = 0; j < nv; j++) {
                bool ok = true;
                for (int k = 0; k < m; k++) {
                    if (valid[i] / pow3[k] % 3 == valid[j] / pow3[k] % 3) { // 相邻颜色相同
                        ok = false;
                        break;
                    }
                }
                if (ok) {
                    nxt[i].push_back(j);
                }
            }
        }

        vector f(n, vector<int>(nv));
        ranges::fill(f[0], 1);
        for (int i = 1; i < n; i++) {
            for (int j = 0; j < nv; j++) {
                for (int k : nxt[j]) {
                    f[i][j] = (f[i][j] + f[i - 1][k]) % MOD;
                }
            }
        }

        long long ans = 0;
        for (int j = 0; j < nv; j++) {
            ans += f[n - 1][j];
        }
        return ans % MOD;
    }
};
func colorTheGrid(m, n int) int {
const mod = 1_000_000_007
pow3 := make([]int, m)
pow3[0] = 1
for i := 1; i < m; i++ {
pow3[i] = pow3[i-1] * 3
}

valid := []int{}
next:
for color := range pow3[m-1] * 3 {
for i := range m - 1 {
if color/pow3[i+1]%3 == color/pow3[i]%3 { // 相邻颜色相同
continue next
}
}
valid = append(valid, color)
}

nv := len(valid)
nxt := make([][]int, nv)
for i, color1 := range valid {
next2:
for j, color2 := range valid {
for _, p3 := range pow3 {
if color1/p3%3 == color2/p3%3 { // 相邻颜色相同
continue next2
}
}
nxt[i] = append(nxt[i], j)
}
}

f := make([][]int, n)
for i := range f {
f[i] = make([]int, nv)
}
for j := range f[0] {
f[0][j] = 1
}
for i := 1; i < n; i++ {
for j := range f[i] {
for _, k := range nxt[j] {
f[i][j] += f[i-1][k]
}
f[i][j] %= mod
}
}

ans := 0
for _, fv := range f[n-1] {
ans += fv
}
return ans % mod
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n4^m)$。理由同写法一。
  • 空间复杂度:$\mathcal{O}(4^m + n2^m)$。:用滚动数组可以优化至 $\mathcal{O}(4^m)$。

附:矩阵快速幂

$n=10^{18}$ 也可以通过。

原理讲解

import numpy as np

MOD = 1_000_000_007

# a^n @ f0
def pow(a: np.ndarray, n: int, f0: np.ndarray) -> np.ndarray:
    res = f0
    while n:
        if n & 1:
            res = a @ res % MOD
        a = a @ a % MOD
        n >>= 1
    return res

class Solution:
    def colorTheGrid(self, m: int, n: int) -> int:
        pow3 = [3 ** i for i in range(m)]
        valid = []
        for color in range(3 ** m):
            for i in range(1, m):
                if color // pow3[i] % 3 == color // pow3[i - 1] % 3:  # 相邻颜色相同
                    break
            else:  # 没有中途 break,合法
                valid.append(color)

        nv = len(valid)
        m = np.zeros((nv, nv), dtype=object)
        for i, color1 in enumerate(valid):
            for j, color2 in enumerate(valid):
                for p3 in pow3:
                    if color1 // p3 % 3 == color2 // p3 % 3:  # 相邻颜色相同
                        break
                else:  # 没有中途 break,合法
                    m[i, j] = 1

        f0 = np.ones((nv,), dtype=object)
        res = pow(m, n - 1, f0)
        return np.sum(res) % MOD
import numpy as np

MOD = 1_000_000_007

class Solution:
    def colorTheGrid(self, m: int, n: int) -> int:
        pow3 = [3 ** i for i in range(m)]
        valid = []
        for color in range(3 ** m):
            for i in range(1, m):
                if color // pow3[i] % 3 == color // pow3[i - 1] % 3:  # 相邻颜色相同
                    break
            else:  # 没有中途 break,合法
                valid.append(color)

        nv = len(valid)
        m = np.zeros((nv, nv), dtype=object)
        for i, color1 in enumerate(valid):
            for j, color2 in enumerate(valid):
                for p3 in pow3:
                    if color1 // p3 % 3 == color2 // p3 % 3:  # 相邻颜色相同
                        break
                else:  # 没有中途 break,合法
                    m[i, j] = 1

        f0 = np.ones((nv,), dtype=object)
        res = np.linalg.matrix_power(m, n - 1) @ f0
        return np.sum(res) % MOD

复杂度分析

  • 时间复杂度:$\mathcal{O}(2^{m\omega} \log n)$。矩阵长宽均为 $\mathcal{O}(2^m)$,计算一次矩阵乘法需要 $\mathcal{O}((2^m)^\omega)$ 的时间,其中 $\omega\le 3$。
  • 空间复杂度:$\mathcal{O}(4^m)$。

双倍经验

1411. 给 N x 3 网格图涂色的方案数

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

昨天 — 2025年5月17日LeetCode 每日一题题解

O(n) 插入排序,简洁写法(Python/Java/C++/C/Go/JS/Rust)

作者 endlesscheng
2025年5月17日 08:57

不让用 $\texttt{sort}$ 吗?有意思……

技巧:O(1) 插入元素

假设现在有一个有序数组 $a=[0,0,1,1,2,2]$。在 $a$ 中插入一个 $0$,同时保证 $a$ 是有序的,你会怎么做?

最暴力的想法是,把 $0$ 插在数组的最左边,原来的元素全体右移一位,得到 $[0,0,0,1,1,2,2]$。这样做是 $\mathcal{O}(n)$ 的。

实际上,我们可以「狸猫换太子」:不是插入元素,而是修改元素!

对比一下插入前后:

  • 插入前 $[0,0,1,1,2,2]$。
  • 插入后 $[0,0,0,1,1,2,2]$。

竖着看,其实只有 $3$ 个位置变了:

  1. 原来的 $a[2]$ 变成 $0$。
  2. 原来的 $a[4]$ 变成 $1$。
  3. 末尾新增一个 $2$,相当于 $a[6]=2$。

怎么知道要修改的位置(下标)?

  1. 维护 $0$ 的个数,即为改成 $0$ 的位置,记作 $p_0$。上例中 $p_0=2$。把 $a[p_0]$ 改成 $0$。
  2. 维护 $0$ 和 $1$ 的个数,即为改成 $1$ 的位置,记作 $p_1$。上例中 $p_1=4$。把 $a[p_1]$ 改成 $1$。
  3. 末尾新增的位置记作 $i$,把 $a[i]$ 改成 $2$。

细节

如果 $a$ 中没有 $2$ 呢?上面第三步就错了。

比如现在 $a=[1]$,插入一个 $0$,变成 $[0,1]$。

如果按照上面三步走,最后把 $a[1]$ 改成 $2$,得到的是 $[0,2]$,这就错了。

要写很多 $\texttt{if-else}$,特判这些特殊情况吗?

不需要,我们可以倒过来算:先把 $a[1]$ 改成 $2$,再把 $a[1]$ 改成 $1$(覆盖),最后 $a[0]$ 改成 $0$,得到 $[0,1]$。这种「覆盖」等价于「没有 $2$ 的时候不改成 $2$」。

如果插入的是 $1$ 呢?

跳过「把 $a[p_0]$ 改成 $0$」这一步。

如果插入的是 $2$ 呢?

只需要把 $a[i]$ 改成 $2$ 即可。

本题思路

对 $\textit{nums}$ 执行插入排序,也就是对 $i=0,1,2,\ldots,n-1$ 依次执行如下过程:

  • 现在前缀 $\textit{nums}[0]$ 到 $\textit{nums}[i-1]$ 是有序的,我们把 $\textit{nums}[i]$ 插入到这个有序前缀中,从而把前缀 $\textit{nums}[0]$ 到 $\textit{nums}[i]$ 变成有序的。
  • 算法执行完后,$\textit{nums}$ 就是一个有序数组了。
class Solution:
    def sortColors(self, nums: List[int]) -> None:
        p0 = p1 = 0
        for i, x in enumerate(nums):
            nums[i] = 2
            if x <= 1:
                nums[p1] = 1
                p1 += 1
            if x == 0:
                nums[p0] = 0
                p0 += 1
class Solution {
    public void sortColors(int[] nums) {
        int p0 = 0;
        int p1 = 0;
        for (int i = 0; i < nums.length; i++) {
            int x = nums[i];
            nums[i] = 2;
            if (x <= 1) {
                nums[p1++] = 1;
            }
            if (x == 0) {
                nums[p0++] = 0;
            }
        }
    }
}
class Solution {
public:
    void sortColors(vector<int>& nums) {
        int p0 = 0, p1 = 0;
        for (int i = 0; i < nums.size(); i++) {
            int x = nums[i];
            nums[i] = 2;
            if (x <= 1) {
                nums[p1++] = 1;
            }
            if (x == 0) {
                nums[p0++] = 0;
            }
        }
    }
};
void sortColors(int* nums, int numsSize) {
    int p0 = 0, p1 = 0;
    for (int i = 0; i < numsSize; i++) {
        int x = nums[i];
        nums[i] = 2;
        if (x <= 1) {
            nums[p1++] = 1;
        }
        if (x == 0) {
            nums[p0++] = 0;
        }
    }
}
func sortColors(nums []int) {
    p0, p1 := 0, 0
    for i, x := range nums {
        nums[i] = 2
        if x <= 1 {
            nums[p1] = 1
            p1++
        }
        if x == 0 {
            nums[p0] = 0
            p0++
        }
    }
}
var sortColors = function(nums) {
    let p0 = 0, p1 = 0;
    for (let i = 0; i < nums.length; i++) {
        const x = nums[i];
        nums[i] = 2;
        if (x <= 1) {
            nums[p1++] = 1;
        }
        if (x === 0) {
            nums[p0++] = 0;
        }
    }
};
impl Solution {
    pub fn sort_colors(nums: &mut Vec<i32>) {
        let mut p0 = 0;
        let mut p1 = 0;
        for i in 0..nums.len() {
            let x = nums[i];
            nums[i] = 2;
            if x <= 1 {
                nums[p1] = 1;
                p1 += 1;
            }
            if x == 0 {
                nums[p0] = 0;
                p0 += 1;
            }
        }
    }
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

每日一题-颜色分类🟡

2025年5月17日 00:00

给定一个包含红色、白色和蓝色、共 n 个元素的数组 nums ,原地 对它们进行排序,使得相同颜色的元素相邻,并按照红色、白色、蓝色顺序排列。

我们使用整数 0、 12 分别表示红色、白色和蓝色。

    必须在不使用库内置的 sort 函数的情况下解决这个问题。

     

    示例 1:

    输入:nums = [2,0,2,1,1,0]
    输出:[0,0,1,1,2,2]
    

    示例 2:

    输入:nums = [2,0,1]
    输出:[0,1,2]
    

     

    提示:

    • n == nums.length
    • 1 <= n <= 300
    • nums[i]012

     

    进阶:

    • 你能想出一个仅使用常数空间的一趟扫描算法吗?

    颜色分类

    2020年10月6日 21:28

    📺 视频题解

    75.颜色分类.mp4

    📖 文字题解

    前言

    本题是经典的「荷兰国旗问题」,由计算机科学家 Edsger W. Dijkstra 首先提出。

    根据题目中的提示,我们可以统计出数组中 $0, 1, 2$ 的个数,再根据它们的数量,重写整个数组。这种方法较为简单,也很容易想到,而本题解中会介绍两种基于指针进行交换的方法。

    方法一:单指针

    思路与算法

    我们可以考虑对数组进行两次遍历。在第一次遍历中,我们将数组中所有的 $0$ 交换到数组的头部。在第二次遍历中,我们将数组中所有的 $1$ 交换到头部的 $0$ 之后。此时,所有的 $2$ 都出现在数组的尾部,这样我们就完成了排序。

    具体地,我们使用一个指针 $\textit{ptr}$ 表示「头部」的范围,$\textit{ptr}$ 中存储了一个整数,表示数组 $\textit{nums}$ 从位置 $0$ 到位置 $\textit{ptr}-1$ 都属于「头部」。$\textit{ptr}$ 的初始值为 $0$,表示还没有数处于「头部」。

    在第一次遍历中,我们从左向右遍历整个数组,如果找到了 $0$,那么就需要将 $0$ 与「头部」位置的元素进行交换,并将「头部」向后扩充一个位置。在遍历结束之后,所有的 $0$ 都被交换到「头部」的范围,并且「头部」只包含 $0$。

    在第二次遍历中,我们从「头部」开始,从左向右遍历整个数组,如果找到了 $1$,那么就需要将 $1$ 与「头部」位置的元素进行交换,并将「头部」向后扩充一个位置。在遍历结束之后,所有的 $1$ 都被交换到「头部」的范围,并且都在 $0$ 之后,此时 $2$ 只出现在「头部」之外的位置,因此排序完成。

    代码

    ###C++

    class Solution {
    public:
        void sortColors(vector<int>& nums) {
            int n = nums.size();
            int ptr = 0;
            for (int i = 0; i < n; ++i) {
                if (nums[i] == 0) {
                    swap(nums[i], nums[ptr]);
                    ++ptr;
                }
            }
            for (int i = ptr; i < n; ++i) {
                if (nums[i] == 1) {
                    swap(nums[i], nums[ptr]);
                    ++ptr;
                }
            }
        }
    };
    

    ###Java

    class Solution {
        public void sortColors(int[] nums) {
            int n = nums.length;
            int ptr = 0;
            for (int i = 0; i < n; ++i) {
                if (nums[i] == 0) {
                    int temp = nums[i];
                    nums[i] = nums[ptr];
                    nums[ptr] = temp;
                    ++ptr;
                }
            }
            for (int i = ptr; i < n; ++i) {
                if (nums[i] == 1) {
                    int temp = nums[i];
                    nums[i] = nums[ptr];
                    nums[ptr] = temp;
                    ++ptr;
                }
            }
        }
    }
    

    ###Python

    class Solution:
        def sortColors(self, nums: List[int]) -> None:
            n = len(nums)
            ptr = 0
            for i in range(n):
                if nums[i] == 0:
                    nums[i], nums[ptr] = nums[ptr], nums[i]
                    ptr += 1
            for i in range(ptr, n):
                if nums[i] == 1:
                    nums[i], nums[ptr] = nums[ptr], nums[i]
                    ptr += 1
    

    ###Golang

    func swapColors(colors []int, target int) (countTarget int) {
        for i, c := range colors {
            if c == target {
                colors[i], colors[countTarget] = colors[countTarget], colors[i]
                countTarget++
            }
        }
        return
    }
    
    func sortColors(nums []int) {
        count0 := swapColors(nums, 0) // 把 0 排到前面
        swapColors(nums[count0:], 1)  // nums[:count0] 全部是 0 了,对剩下的 nums[count0:] 把 1 排到前面
    }
    

    ###C

    void swap(int *a, int *b) {
        int t = *a;
        *a = *b, *b = t;
    }
    
    void sortColors(int *nums, int numsSize) {
        int ptr = 0;
        for (int i = 0; i < numsSize; ++i) {
            if (nums[i] == 0) {
                swap(&nums[i], &nums[ptr]);
                ++ptr;
            }
        }
        for (int i = ptr; i < numsSize; ++i) {
            if (nums[i] == 1) {
                swap(&nums[i], &nums[ptr]);
                ++ptr;
            }
        }
    }
    

    复杂度分析

    • 时间复杂度:$O(n)$,其中 $n$ 是数组 $\textit{nums}$ 的长度。

    • 空间复杂度:$O(1)$。

    方法二:双指针

    思路与算法

    方法一需要进行两次遍历,那么我们是否可以仅使用一次遍历呢?我们可以额外使用一个指针,即使用两个指针分别用来交换 $0$ 和 $1$。

    具体地,我们用指针 $p_0$ 来交换 $0$,$p_1$ 来交换 $1$,初始值都为 $0$。当我们从左向右遍历整个数组时:

    • 如果找到了 $1$,那么将其与 $\textit{nums}[p_1]$ 进行交换,并将 $p_1$ 向后移动一个位置,这与方法一是相同的;

    • 如果找到了 $0$,那么将其与 $\textit{nums}[p_0]$ 进行交换,并将 $p_0$ 向后移动一个位置。这样做是正确的吗?我们可以注意到,因为连续的 $0$ 之后是连续的 $1$,因此如果我们将 $0$ 与 $\textit{nums}[p_0]$ 进行交换,那么我们可能会把一个 $1$ 交换出去。当 $p_0 < p_1$ 时,我们已经将一些 $1$ 连续地放在头部,此时一定会把一个 $1$ 交换出去,导致答案错误。因此,如果 $p_0 < p_1$,那么我们需要再将 $\textit{nums}[i]$ 与 $\textit{nums}[p_1]$ 进行交换,其中 $i$ 是当前遍历到的位置,在进行了第一次交换后,$\textit{nums}[i]$ 的值为 $1$,我们需要将这个 $1$ 放到「头部」的末端。在最后,无论是否有 $p_0 < p_1$,我们需要将 $p_0$ 和 $p_1$ 均向后移动一个位置,而不是仅将 $p_0$ 向后移动一个位置。

    <ppt1,ppt2,ppt3,ppt4,ppt5,ppt6,ppt7,ppt8,ppt9,ppt10,ppt11,ppt12,ppt13,ppt14,ppt15,ppt16,ppt17,ppt18>

    代码

    ###C++

    class Solution {
    public:
        void sortColors(vector<int>& nums) {
            int n = nums.size();
            int p0 = 0, p1 = 0;
            for (int i = 0; i < n; ++i) {
                if (nums[i] == 1) {
                    swap(nums[i], nums[p1]);
                    ++p1;
                } else if (nums[i] == 0) {
                    swap(nums[i], nums[p0]);
                    if (p0 < p1) {
                        swap(nums[i], nums[p1]);
                    }
                    ++p0;
                    ++p1;
                }
            }
        }
    };
    

    ###Java

    class Solution {
        public void sortColors(int[] nums) {
            int n = nums.length;
            int p0 = 0, p1 = 0;
            for (int i = 0; i < n; ++i) {
                if (nums[i] == 1) {
                    int temp = nums[i];
                    nums[i] = nums[p1];
                    nums[p1] = temp;
                    ++p1;
                } else if (nums[i] == 0) {
                    int temp = nums[i];
                    nums[i] = nums[p0];
                    nums[p0] = temp;
                    if (p0 < p1) {
                        temp = nums[i];
                        nums[i] = nums[p1];
                        nums[p1] = temp;
                    }
                    ++p0;
                    ++p1;
                }
            }
        }
    }
    

    ###Python

    class Solution:
        def sortColors(self, nums: List[int]) -> None:
            n = len(nums)
            p0 = p1 = 0
            for i in range(n):
                if nums[i] == 1:
                    nums[i], nums[p1] = nums[p1], nums[i]
                    p1 += 1
                elif nums[i] == 0:
                    nums[i], nums[p0] = nums[p0], nums[i]
                    if p0 < p1:
                        nums[i], nums[p1] = nums[p1], nums[i]
                    p0 += 1
                    p1 += 1
    

    ###Golang

    func sortColors(nums []int) {
        p0, p1 := 0, 0
        for i, c := range nums {
            if c == 0 {
                nums[i], nums[p0] = nums[p0], nums[i]
                if p0 < p1 {
                    nums[i], nums[p1] = nums[p1], nums[i]
                }
                p0++
                p1++
            } else if c == 1 {
                nums[i], nums[p1] = nums[p1], nums[i]
                p1++
            }
        }
    }
    

    ###C

    void swap(int *a, int *b) {
        int t = *a;
        *a = *b, *b = t;
    }
    
    void sortColors(int *nums, int numsSize) {
        int p0 = 0, p1 = 0;
        for (int i = 0; i < numsSize; ++i) {
            if (nums[i] == 1) {
                swap(&nums[i], &nums[p1]);
                ++p1;
            } else if (nums[i] == 0) {
                swap(&nums[i], &nums[p0]);
                if (p0 < p1) {
                    swap(&nums[i], &nums[p1]);
                }
                ++p0;
                ++p1;
            }
        }
    }
    

    复杂度分析

    • 时间复杂度:$O(n)$,其中 $n$ 是数组 $\textit{nums}$ 的长度。

    • 空间复杂度:$O(1)$。

    方法三:双指针

    思路与算法

    与方法二类似,我们也可以考虑使用指针 $p_0$ 来交换 $0$,$p_2$ 来交换 $2$。此时,$p_0$ 的初始值仍然为 $0$,而 $p_2$ 的初始值为 $n-1$。在遍历的过程中,我们需要找出所有的 $0$ 交换至数组的头部,并且找出所有的 $2$ 交换至数组的尾部。

    由于此时其中一个指针 $p_2$ 是从右向左移动的,因此当我们在从左向右遍历整个数组时,如果遍历到的位置超过了 $p_2$,那么就可以直接停止遍历了。

    具体地,我们从左向右遍历整个数组,设当前遍历到的位置为 $i$,对应的元素为 $\textit{nums}[i]$;

    • 如果找到了 $0$,那么与前面两种方法类似,将其与 $\textit{nums}[p_0]$ 进行交换,并将 $p_0$ 向后移动一个位置;

    • 如果找到了 $2$,那么将其与 $\textit{nums}[p_2]$ 进行交换,并将 $p_2$ 向前移动一个位置。

    这样做是正确的吗?可以发现,对于第二种情况,当我们将 $\textit{nums}[i]$ 与 $\textit{nums}[p_2]$ 进行交换之后,新的 $\textit{nums}[i]$ 可能仍然是 $2$,也可能是 $0$。然而此时我们已经结束了交换,开始遍历下一个元素 $\textit{nums}[i+1]$,不会再考虑 $\textit{nums}[i]$ 了,这样我们就会得到错误的答案。

    因此,当我们找到 $2$ 时,我们需要不断地将其与 $\textit{nums}[p_2]$ 进行交换,直到新的 $\textit{nums}[i]$ 不为 $2$。此时,如果 $\textit{nums}[i]$ 为 $0$,那么对应着第一种情况;如果 $\textit{nums}[i]$ 为 $1$,那么就不需要进行任何后续的操作。

    <fig1,fig2,fig3,fig4,fig5,fig6,fig7,fig8,fig9,fig10,fig11,fig12,fig13>

    代码

    ###C++

    class Solution {
    public:
        void sortColors(vector<int>& nums) {
            int n = nums.size();
            int p0 = 0, p2 = n - 1;
            for (int i = 0; i <= p2; ++i) {
                while (i <= p2 && nums[i] == 2) {
                    swap(nums[i], nums[p2]);
                    --p2;
                }
                if (nums[i] == 0) {
                    swap(nums[i], nums[p0]);
                    ++p0;
                }
            }
        }
    };
    

    ###Java

    class Solution {
        public void sortColors(int[] nums) {
            int n = nums.length;
            int p0 = 0, p2 = n - 1;
            for (int i = 0; i <= p2; ++i) {
                while (i <= p2 && nums[i] == 2) {
                    int temp = nums[i];
                    nums[i] = nums[p2];
                    nums[p2] = temp;
                    --p2;
                }
                if (nums[i] == 0) {
                    int temp = nums[i];
                    nums[i] = nums[p0];
                    nums[p0] = temp;
                    ++p0;
                }
            }
        }
    }
    

    ###Python

    class Solution:
        def sortColors(self, nums: List[int]) -> None:
            n = len(nums)
            p0, p2 = 0, n - 1
            i = 0
            while i <= p2:
                while i <= p2 and nums[i] == 2:
                    nums[i], nums[p2] = nums[p2], nums[i]
                    p2 -= 1
                if nums[i] == 0:
                    nums[i], nums[p0] = nums[p0], nums[i]
                    p0 += 1
                i += 1
    

    ###Golang

    func sortColors(nums []int) {
        p0, p2 := 0, len(nums)-1
        for i := 0; i <= p2; i++ {
            for ; i <= p2 && nums[i] == 2; p2-- {
                nums[i], nums[p2] = nums[p2], nums[i]
            }
            if nums[i] == 0 {
                nums[i], nums[p0] = nums[p0], nums[i]
                p0++
            }
        }
    }
    

    ###C

    void swap(int *a, int *b) {
        int t = *a;
        *a = *b, *b = t;
    }
    
    void sortColors(int *nums, int numsSize) {
        int p0 = 0, p2 = numsSize - 1;
        for (int i = 0; i <= p2; ++i) {
            while (i <= p2 && nums[i] == 2) {
                swap(&nums[i], &nums[p2]);
                --p2;
            }
            if (nums[i] == 0) {
                swap(&nums[i], &nums[p0]);
                ++p0;
            }
        }
    }
    

    复杂度分析

    • 时间复杂度:$O(n)$,其中 $n$ 是数组 $\textit{nums}$ 的长度。

    • 空间复杂度:$O(1)$。

    「三路快排」应用(Java)

    作者 liweiwei1419
    2020年1月10日 01:03

    本题其实是经典的荷兰国旗问题,最初由荷兰计算机科学家艾兹赫尔·迪克斯特拉(Edsger W. Dijkstra)提出,并以荷兰国旗的颜色(红、白、蓝)为类比。

    如果你学习过「快速排序」,知道「三路快排」,这道题就非常简单,本题考查的知识点就是「三路快排」。我以前录过视频讲解了「快速排序」:地址,在第 6 节讲到了「三路快排」。

    思路分析

    容易想到的做法:

    • 排序。可以使用编程语言提供的排序函数,一般认为是「快速排序」或者「归并排序」,排序以后即为所求。但不符合题目的「进阶」要求「一趟扫描」和「常数空间」;
    • 使用「计数排序」。分别统计 0、1、2 的个数,再赋值回原数组。但不符合题目的「进阶」要求「一趟扫描」。

    「快速排序」的 partition,其中有一种 partition 叫做「三路快排」,即通过一次 partition,把区间里的数根据基准元素 pivot 分成 3 个部分:

    • 第 1 个部分:小于 pivot
    • 第 2 个部分:等于 pivot
    • 第 3 个部分:大于 pivot

    本题的解法和「三路快排」几乎是一样,在「一趟扫描」的过程中,使用变量 i 扫描,分别使用两个变量放在数组的头和尾,我们就分别命名为 zerotwo。其中

    • zero 是 0 和 1 的分界;
    • two 是 2 和还未遍历到的数的分界。

    变量的定义如下图所示(「参考代码 1」采用的定义方式):

    image.png

    变量 i 看到 1 的时候直接 ++ ,看到 0 的时候交换到数组的前面,看到 2 的时候交换到数组的末尾。具体细节,可以先写出变量的定义,在编码就容易多了

    这里给出两版代码,其实是一样的,它们的区别仅仅在于 zero 指向 0 还是指向 1 ,two 指向 2 还是指向未看到的数。我们把 zerotwo 的定义写成了区间的形式,写在注释中。

    由于区间定义不同,初始化,循环过程中,退出循环的条件就有细微差别。

    参考代码 1

    ###java

    public class Solution {
    
        public void sortColors(int[] nums) {
            int n = nums.length;
            if (n < 2) {
                return;
            }
    
            // all in [0..zero) = 0
            // all in [zero..i) = 1
            // all in [two..n - 1] = 2
            // 初始化的时候,要满足上面 3 个区间全是空区间
            int zero = 0;
            int two = n;
            int i = 0;
            // i = two 的时候,[zero..i)、[two..len - 1] 已经接起来了,所以 while 里面是 i < two
            while (i < two) {
                if (nums[i] == 0) {
                    // zero 指向 1,所以先交换
                    swap(nums, i, zero);
                    zero++;
                    i++;
                } else if (nums[i] == 1) {
                    // 直接划分到 1 所在的区间
                    i++;
                } else {
                    // [two..n - 1] = 2,two 指向 2,所以先 -- ,再交换
                    two--;
                    swap(nums, i, two);
                }
            }
        }
    
        private void swap(int[] nums, int index1, int index2) {
            int temp = nums[index1];
            nums[index1] = nums[index2];
            nums[index2] = temp;
        }
    
    }
    

    复杂度分析

    • 时间复杂度:$O(n)$,这里 $n$ 是输入数组的长度;
    • 空间复杂度:$O(1)$。

    参考代码 2

    ###java

    public class Solution {
    
        public void sortColors(int[] nums) {
            int n = nums.length;
            if (n < 2) {
                return;
            }
    
            // all in [0, zero] = 0
            // all in (zero, i) = 1
            // all in (two, n - 1] = 2
            int zero = -1;
            int two = n - 1;
            int i = 0;
            while (i <= two) {
                if (nums[i] == 0) {
                    zero++;
                    swap(nums, i, zero);
                    i++;
                } else if (nums[i] == 1) {
                    i++;
                } else {
                    swap(nums, i, two);
                    two--;
                }
            }
        }
    
        private void swap(int[] nums, int index1, int index2) {
            int temp = nums[index1];
            nums[index1] = nums[index2];
            nums[index2] = temp;
        }
    
    }
    

    复杂度分析:(同「参考代码 1」)。

    昨天以前LeetCode 每日一题题解

    最长相邻不相等子序列 II

    2025年5月6日 10:11

    方法一:动态规划

    思路与算法

    题目要求找到 $[0, 1, ..., n - 1]$ 中的最长子序列,该子序列中满足前后相邻下标对应的 $\textit{groups}$ 值不同,且相邻下标对应的 $\textit{words}$ 的汉明距离为 $1$。与「最长相邻不相等子序列 I」题目类似,我们仍可采用动态规划来解决该问题。

    设 $\textit{dp}[i]$ 表示以下标 $i$ 为结尾的最长子序列长度,设 $\text{HammingDistance}(s,t)$ 表示两个字符串 $s,t$ 的「汉明距离」。子序列中如果下标 $i$ 可以添加在下标 $j$ 之后,则此时一定满足 $\textit{groups}[i] \neq \textit{groups}[j], j < i$ 且 $\text{HammingDistance}(\textit{words}[i], \textit{words}[j]) = 1$,此时下标 $i$ 可以添加到下标 $j$ 之后,此时以下标 $i$ 为结尾的最长子序列长度为 $\textit{dp}[i] = \textit{dp}[j] + 1$,我们可以得到动态规划递推公式如下:

    $$
    \textit{dp}[i] = \max(\textit{dp}[i], \textit{dp}[j] + 1) \quad if \quad \textit{groups}[i] \neq \textit{groups}[j],\text{HammingDistance}(\textit{words}[i], \textit{words}[j]) = 1
    $$

    对于下标 $i$,我们可以枚举 $i$ 之前的小标,即可求得以 $i$ 为结尾的最长子序列的长度,依次求出以每个下标为结尾的最长子序列长度即可找到 $[0, 1, ..., n - 1]$
    中的最长子序列长度。为了方便计算,我们用 $\textit{prev}[i]$ 记载以下标 $i$ 为结尾的最长子序列中 $i$ 的上一个下标。当我们找到最长子序列的结尾下标 $i$ 时,沿着 $i$ 往前即可找到整个序列的下标,并将每个下标对应的字符串加入到数组中,对整个数组反转后的结果即为答案。

    代码

    ###C++

    class Solution {
    public:
        vector<string> getWordsInLongestSubsequence(vector<string>& words, vector<int>& groups) {
            int n = groups.size();
            vector<int> dp(n, 1);
            vector<int> prev(n, -1);
            int maxIndex = 0;
            for (int i = 1; i < n; i++) {
                for (int j = 0; j < i; j++) {
                    if (check(words[i], words[j]) == 1 && dp[j] + 1 > dp[i] && groups[i] != groups[j]) {
                        dp[i] = dp[j] + 1;
                        prev[i] = j;
                    }
                }
                if (dp[i] > dp[maxIndex]) {
                    maxIndex = i;
                }
            }
    
            vector<string> ans;
            for (int i = maxIndex; i >= 0; i = prev[i]) {
                ans.emplace_back(words[i]);
            }
            reverse(ans.begin(), ans.end());
            return ans;
        }
    
        bool check(string &s1, string &s2) {
            if (s1.size() != s2.size()) {
                return false;
            }
            int diff = 0;
            for (int i = 0; i < s1.size(); i++) {
                diff += s1[i] != s2[i];
                if (diff > 1) {
                    return false;
                }
            }
            return diff == 1;
        }
    };
    

    ###Java

    class Solution {
        public List<String> getWordsInLongestSubsequence(String[] words, int[] groups) {
            int n = groups.length;
            int[] dp = new int[n];
            int[] prev = new int[n];
            Arrays.fill(dp, 1);
            Arrays.fill(prev, -1);
            int maxIndex = 0;
            for (int i = 1; i < n; i++) {
                for (int j = 0; j < i; j++) {
                    if (check(words[i], words[j]) && dp[j] + 1 > dp[i] && groups[i] != groups[j]) {
                        dp[i] = dp[j] + 1;
                        prev[i] = j;
                    }
                }
                if (dp[i] > dp[maxIndex]) {
                    maxIndex = i;
                }
            }
            List<String> ans = new ArrayList<>();
            for (int i = maxIndex; i >= 0; i = prev[i]) {
                ans.add(words[i]);
            }
            Collections.reverse(ans);
            return ans;
        }
    
        private boolean check(String s1, String s2) {
            if (s1.length() != s2.length()) {
                return false;
            }
            int diff = 0;
            for (int i = 0; i < s1.length(); i++) {
                if (s1.charAt(i) != s2.charAt(i)) {
                    if (++diff > 1) {
                        return false;
                    }
                }
            }
            return diff == 1;
        }
    }
    

    ###C#

    public class Solution {
        public IList<string> GetWordsInLongestSubsequence(string[] words, int[] groups) {
            int n = groups.Length;
            int[] dp = new int[n];
            int[] prev = new int[n];
            Array.Fill(dp, 1);
            Array.Fill(prev, -1);
            int maxIndex = 0;
    
            for (int i = 1; i < n; i++) {
                for (int j = 0; j < i; j++) {
                    if (Check(words[i], words[j]) && dp[j] + 1 > dp[i] && groups[i] != groups[j]) {
                        dp[i] = dp[j] + 1;
                        prev[i] = j;
                    }
                }
                if (dp[i] > dp[maxIndex]) {
                    maxIndex = i;
                }
            }
    
            List<string> ans = new List<string>();
            for (int i = maxIndex; i >= 0; i = prev[i]) {
                ans.Add(words[i]);
            }
            ans.Reverse();
            return ans;
        }
    
        private bool Check(string s1, string s2) {
            if (s1.Length != s2.Length) {
                return false;
            }
            int diff = 0;
            for (int i = 0; i < s1.Length; i++) {
                if (s1[i] != s2[i]) {
                    if (++diff > 1) {
                        return false;
                    }
                }
            }
            return diff == 1;
        }
    }
    

    ###Python

    class Solution:
        def getWordsInLongestSubsequence(self, words: List[str], groups: List[int]) -> List[str]:
            n = len(groups)
            dp = [1] * n
            prev_ = [-1] * n
            max_index = 0
    
            for i in range(1, n):
                for j in range(i):
                    if self.check(words[i], words[j]) and dp[j] + 1 > dp[i] and groups[i] != groups[j]:
                        dp[i] = dp[j] + 1
                        prev_[i] = j
                if dp[i] > dp[max_index]:
                    max_index = i
    
            ans = []
            i = max_index
            while i >= 0:
                ans.append(words[i])
                i = prev_[i]
            ans.reverse()
            return ans
    
        def check(self, s1: str, s2: str) -> bool:
            if len(s1) != len(s2):
                return False
            diff = 0
            for c1, c2 in zip(s1, s2):
                if c1 != c2:
                    diff += 1
                    if diff > 1:
                        return False
            return diff == 1
    

    ###Go

    func getWordsInLongestSubsequence(words []string, groups []int) []string {
        n := len(groups)
    dp := make([]int, n)
    prev := make([]int, n)
    for i := range dp {
    dp[i] = 1
    prev[i] = -1
    }
    maxIndex := 0
    
    for i := 1; i < n; i++ {
    for j := 0; j < i; j++ {
    if check(words[i], words[j]) && dp[j]+1 > dp[i] && groups[i] != groups[j] {
    dp[i] = dp[j] + 1
    prev[i] = j
    }
    }
    if dp[i] > dp[maxIndex] {
    maxIndex = i
    }
    }
    
    ans := []string{}
    for i := maxIndex; i >= 0; i = prev[i] {
    ans = append(ans, words[i])
    }
    reverse(ans)
    return ans
    }
    
    func check(s1, s2 string) bool {
    if len(s1) != len(s2) {
    return false
    }
    diff := 0
    for i := 0; i < len(s1); i++ {
    if s1[i] != s2[i] {
    diff++
                if diff > 1 {
                    return false
                }
    }
    }
    return diff == 1
    }
    
    func reverse(arr []string) {
    for i, j := 0, len(arr) - 1; i < j; i, j = i + 1, j - 1 {
    arr[i], arr[j] = arr[j], arr[i]
    }
    }
    

    ###C

    bool check(const char *s1, const char *s2) {
        if (strlen(s1) != strlen(s2)) {
            return false;
        }
        int diff = 0;
        for (int i = 0; s1[i]; i++) {
            if (s1[i] != s2[i]) {
                if (++diff > 1) {
                    return false;
                }
            }
        }
        return diff == 1;
    }
    
    char **getWordsInLongestSubsequence(char **words, int wordsSize, int *groups, int groupsSize, int *returnSize) {
        int *dp = (int *)malloc(wordsSize * sizeof(int));
        int *prev = (int *)malloc(wordsSize * sizeof(int));
        for (int i = 0; i < wordsSize; i++) {
            dp[i] = 1;
            prev[i] = -1;
        }
        int maxIndex = 0;
        for (int i = 1; i < wordsSize; i++) {
            for (int j = 0; j < i; j++) {
                if (check(words[i], words[j]) && dp[j] + 1 > dp[i] && groups[i] != groups[j]) {
                    dp[i] = dp[j] + 1;
                    prev[i] = j;
                }
            }
            if (dp[i] > dp[maxIndex]) {
                maxIndex = i;
            }
        }
    
        int count = 0;
        for (int i = maxIndex; i >= 0; i = prev[i]) {
            count++;
        }
    
        char **ans = (char **)malloc(count * sizeof(char *));
        int index = 0;
        for (int i = maxIndex; i >= 0; i = prev[i]) {
            ans[index++] = words[i];
        }
        for (int i = 0; i < count / 2; i++) {
            char *temp = ans[i];
            ans[i] = ans[count - 1 - i];
            ans[count - 1 - i] = temp;
        }
    
        *returnSize = count;
        free(dp);
        free(prev);
        return ans;
    }
    

    ###JavaScript

    var getWordsInLongestSubsequence = function(words, groups) {
        const n = groups.length;
        const dp = new Array(n).fill(1);
        const prev = new Array(n).fill(-1);
        let maxIndex = 0;
    
        for (let i = 1; i < n; i++) {
            for (let j = 0; j < i; j++) {
                if (check(words[i], words[j]) && dp[j] + 1 > dp[i] && groups[i] !== groups[j]) {
                    dp[i] = dp[j] + 1;
                    prev[i] = j;
                }
            }
            if (dp[i] > dp[maxIndex]) {
                maxIndex = i;
            }
        }
    
        const ans = [];
        for (let i = maxIndex; i >= 0; i = prev[i]) {
            ans.push(words[i]);
        }
        ans.reverse();
        return ans;
    };
    
    const check = (s1, s2) => {
        if (s1.length !== s2.length) {
            return false;
        }
        let diff = 0;
        for (let i = 0; i < s1.length; i++) {
            if (s1[i] !== s2[i]) {
                if (++diff > 1) {
                    return false;
                }
            }
        }
        return diff === 1;
    };
    

    ###TypeScript

    function getWordsInLongestSubsequence(words: string[], groups: number[]): string[] {
        const n = groups.length;
        const dp = new Array(n).fill(1);
        const prev = new Array(n).fill(-1);
        let maxIndex = 0;
    
        for (let i = 1; i < n; i++) {
            for (let j = 0; j < i; j++) {
                if (check(words[i], words[j]) && dp[j] + 1 > dp[i] && groups[i] !== groups[j]) {
                    dp[i] = dp[j] + 1;
                    prev[i] = j;
                }
            }
            if (dp[i] > dp[maxIndex]) {
                maxIndex = i;
            }
        }
    
        const ans = [];
        for (let i = maxIndex; i >= 0; i = prev[i]) {
            ans.push(words[i]);
        }
        ans.reverse();
        return ans;
    };
    
    function check(s1: string, s2: string): boolean {
        if (s1.length !== s2.length) {
            return false;
        }
        let diff = 0;
        for (let i = 0; i < s1.length; i++) {
            if (s1[i] !== s2[i]) {
                if (++diff > 1) {
                    return false;
                }
            }
        }
        return diff === 1;
    }
    

    ###Rust

    impl Solution {
        pub fn get_words_in_longest_subsequence(words: Vec<String>, groups: Vec<i32>) -> Vec<String> {
            let n = groups.len();
            let mut dp = vec![1; n];
            let mut prev = vec![-1; n];
            let mut max_index = 0;
            for i in 1..n {
                for j in 0..i {
                    if Self::check(&words[i], &words[j]) && dp[j] + 1 > dp[i] && groups[i] != groups[j] {
                        dp[i] = dp[j] + 1;
                        prev[i] = j as i32;
                    }
                }
                if dp[i] > dp[max_index] {
                    max_index = i;
                }
            }
            let mut ans = Vec::new();
            let mut i = max_index as i32;
            while i >= 0 {
                ans.push(words[i as usize].clone());
                i = prev[i as usize];
            }
            ans.reverse();
            ans
        }
    
        fn check(s1: &String, s2: &String) -> bool {
            if s1.len() != s2.len() {
                return false;
            }
            let mut diff = 0;
            for (c1, c2) in s1.chars().zip(s2.chars()) {
                if c1 != c2 {
                    diff += 1;
                    if diff > 1 {
                        return false;
                    }
                }
            }
            diff == 1
        }
    }
    

    复杂度分析

    • 时间复杂度:$O(n^2L)$,其中 $n$ 表示给定数组的长度,$L$ 表示字符串数组 $\textit{word}$ 中字符串的长度。计算两个字符串的汉明码距离需要的时间为 $L$,找到以索引 $i$ 为结尾的最长子序列的需要遍历 $i$ 之前的所有索引,此时需要的时间为 $O(nL)$,求出以每个索引为结尾的最长子序列长度此时需要总时间为 $O(n^2L)$。

    • 空间复杂度:$O(n)$。其中 $n$ 表示给定数组的长度。需要存储以每个索引为结尾的最长子序列长度,一共需要的空间为 $O(n)$。

    ❌
    ❌