阅读视图

发现新文章,点击刷新页面。

距离字典两次编辑以内的单词

方法一:暴力

思路与算法

注意到题目的数据范围很小,我们可以直接实施暴力算法。

对于 $\textit{queries}$ 中的每个字符串 $\textit{queries}[i]$,查找 $\textit{dictionary}$ 中是否存在一个字符串,使得两个字符串中最多只有两个字符不同(即汉明距离小于等于 $2$),如果存在,就将其添加到答案中。由于添加的顺序与遍历 $\textit{queries}$ 的顺序一致,我们不需要对答案的顺序进行特殊处理。

代码

###C++

class Solution {
public:
    vector<string> twoEditWords(vector<string>& queries,
                                vector<string>& dictionary) {
        vector<string> ans;
        for (string query : queries) {
            for (string s : dictionary) {
                int dis = 0;
                for (int i = 0; i < query.size(); i++) {
                    if (query[i] != s[i]) {
                        ++dis;
                    }
                }
                if (dis <= 2) {
                    ans.push_back(query);
                    break;
                }
            }
        }
        return ans;
    }
};

###Java

class Solution {
    public List<String> twoEditWords(String[] queries, String[] dictionary) {
        List<String> ans = new ArrayList<>();
        for (String query : queries) {
            for (String s : dictionary) {
                int dis = 0;
                for (int i = 0; i < query.length(); i++) {
                    if (query.charAt(i) != s.charAt(i)) {
                        dis++;
                    }
                }
                if (dis <= 2) {
                    ans.add(query);
                    break;
                }
            }
        }
        return ans;
    }
}

###C#

public class Solution {
    public IList<string> TwoEditWords(string[] queries, string[] dictionary) {
        var ans = new List<string>();
        foreach (var query in queries) {
            foreach (var s in dictionary) {
                int dis = 0;
                for (int i = 0; i < query.Length; i++) {
                    if (query[i] != s[i]) {
                        dis++;
                    }
                }
                if (dis <= 2) {
                    ans.Add(query);
                    break;
                }
            }
        }
        return ans;
    }
}

###Python

class Solution:
    def twoEditWords(self, queries, dictionary):
        ans = []
        for query in queries:
            for s in dictionary:
                dis = 0
                for i in range(len(query)):
                    if query[i] != s[i]:
                        dis += 1
                if dis <= 2:
                    ans.append(query)
                    break
        return ans

###C

char** twoEditWords(char** queries, int queriesSize,
                    char** dictionary, int dictionarySize,
                    int* returnSize) {
    char** ans = (char**)malloc(sizeof(char*) * queriesSize);
    int cnt = 0;

    for (int i = 0; i < queriesSize; i++) {
        char* query = queries[i];
        for (int j = 0; j < dictionarySize; j++) {
            char* s = dictionary[j];
            int dis = 0;
            for (int k = 0; query[k] != '\0'; k++) {
                if (query[k] != s[k]) {
                    dis++;
                }
            }
            if (dis <= 2) {
                ans[cnt++] = query;
                break;
            }
        }
    }

    *returnSize = cnt;
    return ans;
}

###Go

func twoEditWords(queries []string, dictionary []string) []string {
    var ans []string
    for _, query := range queries {
        for _, s := range dictionary {
            dis := 0
            for i := 0; i < len(query); i++ {
                if query[i] != s[i] {
                    dis++
                }
            }
            if dis <= 2 {
                ans = append(ans, query)
                break
            }
        }
    }
    return ans
}

###JavaScript

var twoEditWords = function(queries, dictionary) {
    const ans = [];
    for (const query of queries) {
        for (const s of dictionary) {
            let dis = 0;
            for (let i = 0; i < query.length; i++) {
                if (query[i] !== s[i]) {
                    dis++;
                }
            }
            if (dis <= 2) {
                ans.push(query);
                break;
            }
        }
    }
    return ans;
};

###TypeScript

function twoEditWords(queries: string[], dictionary: string[]): string[] {
    const ans: string[] = [];
    for (const query of queries) {
        for (const s of dictionary) {
            let dis = 0;
            for (let i = 0; i < query.length; i++) {
                if (query[i] !== s[i]) {
                    dis++;
                }
            }
            if (dis <= 2) {
                ans.push(query);
                break;
            }
        }
    }
    return ans;
}

###Rust

impl Solution {
    pub fn two_edit_words(queries: Vec<String>, dictionary: Vec<String>) -> Vec<String> {
        let mut ans = Vec::new();

        for query in &queries {
            for s in &dictionary {
                let mut dis = 0;
                for (c1, c2) in query.chars().zip(s.chars()) {
                    if c1 != c2 {
                        dis += 1;
                    }
                }
                if dis <= 2 {
                    ans.push(query.clone());
                    break;
                }
            }
        }

        ans
    }
}

复杂度分析

  • 时间复杂度:$O(qkn)$,其中 $q$ 为 $\textit{queries}$ 的长度,$k$ 为 $\textit{dictionary}$ 的长度,$n$ 为 $\textit{queries}[i]$ 的长度。我们需要对每一个 $\textit{queries}[i]$ 遍历一次 $\textit{dictionary}$,然后比较两个字符串。
  • 空间复杂度:$O(1)$,仅使用常数个变量。返回数组不计入空间复杂度。

方法二:字典树

思路与算法

我们可以将 $\textit{dictionary}$ 中的所有单词插入字典树,对每个 $\textit{queries}[i]$ 做深度优先搜索,在过程中维护修改次数 $\textit{cnt}$,从而实现在字典树上进行「最多 2 次修改」的匹配搜索。

定义状态 $\textit{dfs}(i, \textit{node}, \textit{cnt})$,其中 $i$ 表示当前匹配到第 $i$ 个字符,$\textit{node}$ 表示当前所在字典树的节点,$\textit{cnt}$ 表示已修改 $\textit{cnt}$ 次。在字典树上进行查找,对于第 $i$ 个字符 $\textit{query}[i]$:

  1. 如果字典树中存在 $\textit{node}.\textit{children}[\textit{query}[i]]$,不进行修改,下一步搜索状态 $\textit{dfs}(i+1, \textit{node}.\textit{children}[\textit{query}[i]], \textit{cnt})$。
  2. 如果字典树中不存在 $\textit{node}.\textit{children}[\textit{query}[i]]$,且 $\textit{cnt}<2$,进行修改,即枚举所有 $c \neq \textit{query}[i]$,下一步搜索状态 $\textit{dfs}(i+1,\textit{node}.\textit{children}[c], \textit{cnt}+1)$。

搜索过程中,我们可以进行一些剪枝,比如某条路径找到合法解就提前终止。

代码

###C++

struct TrieNode {
    TrieNode* child[26];
    bool isEnd;
    TrieNode() {
        memset(child, 0, sizeof(child));
        isEnd = false;
    }
};

class Solution {
public:
    TrieNode* root = new TrieNode();

    void insert(string& word) {
        TrieNode* node = root;
        for (char c : word) {
            int idx = c - 'a';
            if (!node->child[idx])
                node->child[idx] = new TrieNode();
            node = node->child[idx];
        }
        node->isEnd = true;
    }

    bool dfs(string& word, int i, TrieNode* node, int cnt) {
        if (cnt > 2)
            return false;
        if (!node)
            return false;

        if (i == word.size()) {
            return node->isEnd;
        }

        int idx = word[i] - 'a';

        // 不修改
        if (node->child[idx]) {
            if (dfs(word, i + 1, node->child[idx], cnt))
                return true;
        }

        // 修改
        if (cnt < 2) {
            for (int c = 0; c < 26; c++) {
                if (c == idx)
                    continue;
                if (node->child[c]) {
                    if (dfs(word, i + 1, node->child[c], cnt + 1))
                        return true;
                }
            }
        }

        return false;
    }

    vector<string> twoEditWords(vector<string>& queries,
                                vector<string>& dictionary) {
        for (auto& w : dictionary)
            insert(w);

        vector<string> res;
        for (auto& q : queries) {
            if (dfs(q, 0, root, 0)) {
                res.push_back(q);
            }
        }
        return res;
    }
};

###Java

class Solution {
    static class TrieNode {
        TrieNode[] child = new TrieNode[26];
        boolean isEnd = false;
    }

    TrieNode root = new TrieNode();

    void insert(String word) {
        TrieNode node = root;
        for (char c : word.toCharArray()) {
            int idx = c - 'a';
            if (node.child[idx] == null)
                node.child[idx] = new TrieNode();
            node = node.child[idx];
        }
        node.isEnd = true;
    }

    boolean dfs(String word, int i, TrieNode node, int cnt) {
        if (cnt > 2 || node == null)
            return false;

        if (i == word.length())
            return node.isEnd;

        int idx = word.charAt(i) - 'a';

        // 不修改
        if (node.child[idx] != null) {
            if (dfs(word, i + 1, node.child[idx], cnt))
                return true;
        }

        // 修改
        if (cnt < 2) {
            for (int c = 0; c < 26; c++) {
                if (c == idx)
                    continue;
                if (node.child[c] != null) {
                    if (dfs(word, i + 1, node.child[c], cnt + 1))
                        return true;
                }
            }
        }

        return false;
    }

    public List<String> twoEditWords(String[] queries, String[] dictionary) {
        for (String w : dictionary)
            insert(w);

        List<String> res = new ArrayList<>();
        for (String q : queries) {
            if (dfs(q, 0, root, 0)) {
                res.add(q);
            }
        }
        return res;
    }
}

###C#

public class Solution {
    class TrieNode {
        public TrieNode[] child = new TrieNode[26];
        public bool isEnd = false;
    }

    TrieNode root = new TrieNode();

    void Insert(string word) {
        var node = root;
        foreach (char c in word) {
            int idx = c - 'a';
            if (node.child[idx] == null)
                node.child[idx] = new TrieNode();
            node = node.child[idx];
        }
        node.isEnd = true;
    }

    bool Dfs(string word, int i, TrieNode node, int cnt) {
        if (cnt > 2 || node == null)
            return false;

        if (i == word.Length)
            return node.isEnd;

        int idx = word[i] - 'a';

        // 不修改
        if (node.child[idx] != null) {
            if (Dfs(word, i + 1, node.child[idx], cnt))
                return true;
        }

        // 修改
        if (cnt < 2) {
            for (int c = 0; c < 26; c++) {
                if (c == idx) continue;
                if (node.child[c] != null) {
                    if (Dfs(word, i + 1, node.child[c], cnt + 1))
                        return true;
                }
            }
        }

        return false;
    }

    public IList<string> TwoEditWords(string[] queries, string[] dictionary) {
        foreach (var w in dictionary)
            Insert(w);

        var res = new List<string>();
        foreach (var q in queries) {
            if (Dfs(q, 0, root, 0)) {
                res.Add(q);
            }
        }
        return res;
    }
}

###Python

class TrieNode:
    def __init__(self):
        self.child = [None] * 26
        self.isEnd = False


class Solution:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        node = self.root
        for c in word:
            idx = ord(c) - ord('a')
            if not node.child[idx]:
                node.child[idx] = TrieNode()
            node = node.child[idx]
        node.isEnd = True

    def dfs(self, word, i, node, cnt):
        if cnt > 2 or not node:
            return False

        if i == len(word):
            return node.isEnd

        idx = ord(word[i]) - ord('a')

        # 不修改
        if node.child[idx] and self.dfs(word, i + 1, node.child[idx], cnt):
            return True

        # 修改
        if cnt < 2:
            for c in range(26):
                if c == idx:
                    continue
                if node.child[c] and self.dfs(word, i + 1, node.child[c], cnt + 1):
                    return True

        return False

    def twoEditWords(self, queries, dictionary):
        for w in dictionary:
            self.insert(w)

        res = []
        for q in queries:
            if self.dfs(q, 0, self.root, 0):
                res.append(q)
        return res

###C

typedef struct TrieNode {
    struct TrieNode* child[26];
    bool isEnd;
} TrieNode;

TrieNode* newNode() {
    TrieNode* node = (TrieNode*)malloc(sizeof(TrieNode));
    memset(node->child, 0, sizeof(node->child));
    node->isEnd = false;
    return node;
}

void insert(TrieNode* root, char* word) {
    TrieNode* node = root;
    for (int i = 0; word[i]; i++) {
        int idx = word[i] - 'a';
        if (!node->child[idx])
            node->child[idx] = newNode();
        node = node->child[idx];
    }
    node->isEnd = true;
}

bool dfs(char* word, int i, TrieNode* node, int cnt) {
    if (cnt > 2 || !node)
        return false;

    if (word[i] == '\0')
        return node->isEnd;

    int idx = word[i] - 'a';

    // 不修改
    if (node->child[idx] && dfs(word, i + 1, node->child[idx], cnt))
        return true;

    // 修改
    if (cnt < 2) {
        for (int c = 0; c < 26; c++) {
            if (c == idx) continue;
            if (node->child[c] && dfs(word, i + 1, node->child[c], cnt + 1))
                return true;
        }
    }

    return false;
}

char** twoEditWords(char** queries, int queriesSize,
                    char** dictionary, int dictionarySize,
                    int* returnSize) {
    TrieNode* root = newNode();
    for (int i = 0; i < dictionarySize; i++)
        insert(root, dictionary[i]);

    char** res = (char**)malloc(sizeof(char*) * queriesSize);
    int cnt = 0;

    for (int i = 0; i < queriesSize; i++) {
        if (dfs(queries[i], 0, root, 0)) {
            res[cnt++] = queries[i];
        }
    }

    *returnSize = cnt;
    return res;
}

###Go

type TrieNode struct {
    child [26]*TrieNode
    isEnd bool
}

var root = &TrieNode{}

func insert(word string) {
    node := root
    for _, c := range word {
        idx := c - 'a'
        if node.child[idx] == nil {
            node.child[idx] = &TrieNode{}
        }
        node = node.child[idx]
    }
    node.isEnd = true
}

func dfs(word string, i int, node *TrieNode, cnt int) bool {
    if cnt > 2 || node == nil {
        return false
    }

    if i == len(word) {
        return node.isEnd
    }

    idx := word[i] - 'a'

    // 不修改
    if node.child[idx] != nil && dfs(word, i+1, node.child[idx], cnt) {
        return true
    }

    // 修改
    if cnt < 2 {
        for c := 0; c < 26; c++ {
            if byte(c) == idx {
                continue
            }
            if node.child[c] != nil && dfs(word, i+1, node.child[c], cnt+1) {
                return true
            }
        }
    }

    return false
}

func twoEditWords(queries []string, dictionary []string) []string {
    root = &TrieNode{}
    for _, w := range dictionary {
        insert(w)
    }

    var res []string
    for _, q := range queries {
        if dfs(q, 0, root, 0) {
            res = append(res, q)
        }
    }
    return res
}

###JavaScript

class TrieNode {
    constructor() {
        this.child = new Array(26).fill(null);
        this.isEnd = false;
    }
}

var twoEditWords = function(queries, dictionary) {
    const root = new TrieNode();

    function insert(word) {
        let node = root;
        for (let c of word) {
            let idx = c.charCodeAt(0) - 97;
            if (!node.child[idx]) node.child[idx] = new TrieNode();
            node = node.child[idx];
        }
        node.isEnd = true;
    }

    function dfs(word, i, node, cnt) {
        if (cnt > 2 || !node) return false;
        if (i === word.length) return node.isEnd;

        let idx = word.charCodeAt(i) - 97;

        // 修改
        if (node.child[idx] && dfs(word, i + 1, node.child[idx], cnt))
            return true;

        // 不修改
        if (cnt < 2) {
            for (let c = 0; c < 26; c++) {
                if (c === idx) continue;
                if (node.child[c] && dfs(word, i + 1, node.child[c], cnt + 1))
                    return true;
            }
        }

        return false;
    }

    for (let w of dictionary) insert(w);

    const res = [];
    for (let q of queries) {
        if (dfs(q, 0, root, 0)) res.push(q);
    }

    return res;
};

###TypeScript

class TrieNode {
    child: (TrieNode | null)[] = new Array(26).fill(null);
    isEnd: boolean = false;
}

function twoEditWords(queries: string[], dictionary: string[]): string[] {
    const root = new TrieNode();

    function insert(word: string) {
        let node = root;
        for (const c of word) {
            const idx = c.charCodeAt(0) - 97;
            if (!node.child[idx]) node.child[idx] = new TrieNode();
            node = node.child[idx]!;
        }
        node.isEnd = true;
    }

    function dfs(word: string, i: number, node: TrieNode | null, cnt: number): boolean {
        if (cnt > 2 || !node) return false;
        if (i === word.length) return node.isEnd;

        const idx = word.charCodeAt(i) - 97;

        // 不修改
        if (node.child[idx] && dfs(word, i + 1, node.child[idx], cnt))
            return true;

        // 修改
        if (cnt < 2) {
            for (let c = 0; c < 26; c++) {
                if (c === idx) continue;
                if (node.child[c] && dfs(word, i + 1, node.child[c], cnt + 1))
                    return true;
            }
        }

        return false;
    }

    for (const w of dictionary) insert(w);

    const res: string[] = [];
    for (const q of queries) {
        if (dfs(q, 0, root, 0)) res.push(q);
    }

    return res;
}

###Rust

struct TrieNode {
    child: Vec<Option<Box<TrieNode>>>,
    is_end: bool,
}

impl TrieNode {
    fn new() -> Self {
        let mut child = Vec::with_capacity(26);
        for _ in 0..26 {
            child.push(None);
        }

        Self {
            child,
            is_end: false,
        }
    }
}

impl Solution {
    fn insert(root: &mut TrieNode, word: &str) {
        let mut node = root;
        for c in word.chars() {
            let idx = (c as u8 - b'a') as usize;
            if node.child[idx].is_none() {
                node.child[idx] = Some(Box::new(TrieNode::new()));
            }
            node = node.child[idx].as_mut().unwrap();
        }
        node.is_end = true;
    }

    fn dfs(word: &[u8], i: usize, node: &TrieNode, cnt: i32) -> bool {
        if cnt > 2 {
            return false;
        }
        if i == word.len() {
            return node.is_end;
        }

        let idx = (word[i] - b'a') as usize;

        // 不修改
        if let Some(ref next) = node.child[idx] {
            if Self::dfs(word, i + 1, next, cnt) {
                return true;
            }
        }

        // 修改
        if cnt < 2 {
            for c in 0..26 {
                if c == idx {
                    continue;
                }
                if let Some(ref next) = node.child[c] {
                    if Self::dfs(word, i + 1, next, cnt + 1) {
                        return true;
                    }
                }
            }
        }

        false
    }

    pub fn two_edit_words(queries: Vec<String>, dictionary: Vec<String>) -> Vec<String> {
        let mut root = TrieNode::new();

        for w in &dictionary {
            Self::insert(&mut root, w);
        }

        let mut res = vec![];
        for q in &queries {
            if Self::dfs(q.as_bytes(), 0, &root, 0) {
                res.push(q.clone());
            }
        }

        res
    }
}

复杂度分析

  • 时间复杂度:$O(k \cdot n + q \cdot n^2 \cdot 25^2)$,其中 $q$ 为 $\textit{queries}$ 的长度,$k$ 为 $\textit{dictionary}$ 的长度,$n$ 为 $\textit{queries}[i]$ 的长度。建字典树需要 $O(kn)$,查询时对于每一个字母都有修改和不修改两种选择,选择修改位置有 $C_n^2=n^2$ 种选择,其中不修改有 $1$ 条分支,修改有 $25$ 条分支,最多修改两次,因此有 $25^2$ 种选择。
  • 空间复杂度:$O(kn)$。即为字典树所占用的空间。

下标对中的最大距离

方法一:双指针

提示 $1$

考虑遍历下标对中的某一个下标,并寻找此时所有有效坐标对中距离最大的另一个下标。暴力遍历另一个下标显然不满足时间复杂度要求,此时是否存在一些可以优化寻找过程的性质?

思路与算法

不失一般性,我们遍历 $\textit{nums}_2$ 中的下标 $j$,同时寻找此时符合要求的 $\textit{nums}_1$ 中最小的下标 $i$。

假设下标 $j$ 对应的最小下标为 $i$,当 $j$ 变为 $j + 1$ 时,由于 $\textit{nums}_2$ 非递增,即 $\textit{nums}_2[j] \ge \textit{nums}_2[j+1]$,那么 $\textit{nums}_1$ 中可取元素的上界不会增加。同时由于 $\textit{nums}_1$ 也非递增,因此 $j + 1$ 对应的最小下标 $i'$ 一定满足 $i' \ge i$。

那么我们就可以在遍历 $j$ 的同时维护对应的 $i$,并用 $\textit{res}$ 来维护下标对 $(i, j)$ 的最大距离。我们将 $\textit{res}$ 初值置为 $0$,这样即使存在 $\textit{nums}_1[i] \le \textit{nums}_2[j]$ 但 $i > j$ 这种不符合要求的情况,由于此时距离为负因而不会对结果产生影响(不存在时也返回 $0$)。

另外,在维护最大距离的时候要注意下标 $i$ 的合法性,即 $i < n_1$,其中 $n_1$ 为 $\textit{nums}_1$ 的长度。

代码

###C++

class Solution {
public:
    int maxDistance(vector<int>& nums1, vector<int>& nums2) {
        int n1 = nums1.size();
        int n2 = nums2.size();
        int i = 0;
        int res = 0;
        for (int j = 0; j < n2; ++j){
            while (i < n1 && nums1[i] > nums2[j]){
                ++i;
            }
            if (i < n1){
                res = max(res, j - i);
            }
        }
        return res;
    }
};

###Python

class Solution:
    def maxDistance(self, nums1: List[int], nums2: List[int]) -> int:
        n1, n2 = len(nums1), len(nums2)
        i = 0
        res = 0
        for j in range(n2):
            while i < n1 and nums1[i] > nums2[j]:
                i += 1
            if i < n1:
                res = max(res, j - i)
        return res

###Java

class Solution {
    public int maxDistance(int[] nums1, int[] nums2) {
        int n1 = nums1.length;
        int n2 = nums2.length;
        int i = 0;
        int res = 0;
        
        for (int j = 0; j < n2; j++) {
            while (i < n1 && nums1[i] > nums2[j]) {
                i++;
            }
            if (i < n1) {
                res = Math.max(res, j - i);
            }
        }
        
        return res;
    }
}

###C#

public class Solution {
    public int MaxDistance(int[] nums1, int[] nums2) {
        int n1 = nums1.Length;
        int n2 = nums2.Length;
        int i = 0;
        int res = 0;
        
        for (int j = 0; j < n2; j++) {
            while (i < n1 && nums1[i] > nums2[j]) {
                i++;
            }
            if (i < n1) {
                res = Math.Max(res, j - i);
            }
        }
        
        return res;
    }
}

###Go

func maxDistance(nums1 []int, nums2 []int) int {
    n1 := len(nums1)
    n2 := len(nums2)
    i := 0
    res := 0
    
    for j := 0; j < n2; j++ {
        for i < n1 && nums1[i] > nums2[j] {
            i++
        }
        if i < n1 {
            if j-i > res {
                res = j - i
            }
        }
    }
    
    return res
}

###C

int maxDistance(int* nums1, int nums1Size, int* nums2, int nums2Size) {
    int i = 0;
    int res = 0;
    
    for (int j = 0; j < nums2Size; j++) {
        while (i < nums1Size && nums1[i] > nums2[j]) {
            i++;
        }
        if (i < nums1Size) {
            if (j - i > res) {
                res = j - i;
            }
        }
    }
    
    return res;
}

###JavaScript

var maxDistance = function(nums1, nums2) {
    const n1 = nums1.length;
    const n2 = nums2.length;
    let i = 0;
    let res = 0;
    
    for (let j = 0; j < n2; j++) {
        while (i < n1 && nums1[i] > nums2[j]) {
            i++;
        }
        if (i < n1) {
            res = Math.max(res, j - i);
        }
    }
    
    return res;
};

###TypeScript

function maxDistance(nums1: number[], nums2: number[]): number {
    const n1 = nums1.length;
    const n2 = nums2.length;
    let i = 0;
    let res = 0;
    
    for (let j = 0; j < n2; j++) {
        while (i < n1 && nums1[i] > nums2[j]) {
            i++;
        }
        if (i < n1) {
            res = Math.max(res, j - i);
        }
    }
    
    return res;
};

###Rust

impl Solution {
    pub fn max_distance(nums1: Vec<i32>, nums2: Vec<i32>) -> i32 {
        let n1 = nums1.len();
        let n2 = nums2.len();
        let mut i = 0;
        let mut res = 0;
        
        for j in 0..n2 {
            while i < n1 && nums1[i] > nums2[j] {
                i += 1;
            }
            if i < n1 {
                res = res.max((j as i32) - (i as i32));
            }
        }
        
        res
    }
}

复杂度分析

  • 时间复杂度:$O(n_1 + n_2)$,其中 $n_1, n_2$ 分别为 $\textit{nums}_1$ 与 $\textit{nums}_2$ 的长度。在双指针寻找最大值的过程中,我们最多会遍历两个数组各一次。

  • 空间复杂度:$O(1)$,我们使用了常数个变量进行遍历。

二指输入的的最小距离

方法一:动态规划

我们用 dp[i][l][r] 表示在输入了字符串 word 的第 i 个字母后,左手的位置为 l,右手的位置为 r,达到该状态的最小移动距离。这里的位置为指向的字母编号,例如 A 对应 0B 对应 1,以此类推,而非字母在键盘上的位置。这样做的好处是将字母的位置映射成一个整数而非二维的坐标,使得我们更加方便地进行状态转移。

那么如何进行状态转移呢?我们首先需要看出一个非常重要的性质:对于状态 dp[i][l][r],要么 word[i] == l,要么 word[i] == r,即在输入了第 i 个字母后,左手和右手中至少有一个在 word[i] 的位置。我们可以根据这两种情况,分别进行状态转移:

  • word[i] == l 时,左手在 word[i] 的位置。我们需要考虑在输入字符串 word 的第 i - 1 个字母时,是左手还是右手在 word[i - 1] 的位置:

    • 如果左手在 word[i - 1] 的位置,那么在输入第 i 个字母时,左手从 word[i - 1] 移动至 word[i],状态转移方程为:

      dp[i][l = word[i]][r] = dp[i - 1][l0 = word[i - 1]][r] + dist(word[i - 1], word[i])
      
    • 如果右手在 word[i - 1] 的位置,那么由于第 i 个字母使用了左手,右手就没有移动,即 word[i - 1] == r。同时,在输入 word[i1] 之前的左手位置也无关紧要,可以为任意的 l0,状态转移方程为:

      dp[i][l = word[i]][r = word[i - 1]] = dp[i - 1][l0][r = word[i - 1]] + dist(l0, word[i])
      
  • word[i] == r 时,右手在 word[i] 的位置。我们需要考虑在输入字符串 word 的第 i - 1 个字母时,是右手还是左手在 word[i - 1] 的位置:

    • 如果右手在 word[i - 1] 的位置,那么在输入第 i 个字母时,右手从 word[i - 1] 移动至 word[i],状态转移方程为:

      dp[i][l][r = word[i]] = dp[i - 1][l][r0 = word[i - 1]] + dist(word[i - 1], word[i])
      
    • 如果左手在 word[i - 1] 的位置,那么由于第 i 个字母使用了右手,左手就没有移动,即 word[i - 1] == l。同时,在输入 word[i] 之前的右手位置也无关紧要,可以为任意的 r0,状态转移方程为:

      dp[i][l = word[i - 1]][r = word[i]] = dp[i - 1][l = word[i - 1]][r0] + dist(r0, word[i])
      

对于每一个状态 dp[i][l][r],我们取它所有转移中的最小值,即为输入了字符串 word 的第 i 个字母后,左手的位置为 l,右手的位置为 r,达到该状态的最小移动距离。

在这个动态规划中,我们还需要考虑不合法的状态以及边界状态。对于某一个不合法的状态,如果用它来进行状态转移,可能会使得 dp[i][l][r] 取到一个更小且不合法的值。因此,我们一般会给所有不合法的状态赋予一个非常大的值(例如 C++ 中的整数最大值 INT_MAX),这样即使用它来进行状态转移,也会因为本身值非常大的缘故,对最优解没有任何影响。在考虑边界状态时,由于题目中规定两根手指的起始位置是零代价的,因此对于字符串中的第 0 个字母 word[0],输入它的最小移动距离为 0。此时要么左手的位置为 word[0],要么右手的位置为 word[0],因此我们可以将所有的 dp[0][l = word[0]][r] 以及 dp[0][l][r = word[0]] 作为边界状态,它们的值为 0

###C++

class Solution {
public:
    int getDistance(int p, int q) {
        int x1 = p / 6, y1 = p % 6;
        int x2 = q / 6, y2 = q % 6;
        return abs(x1 - x2) + abs(y1 - y2);
    }

    int minimumDistance(string word) {
        int n = word.size();
        int dp[n][26][26];
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < 26; ++j) {
                fill(dp[i][j], dp[i][j] + 26, INT_MAX >> 1);
            }
        }
        for (int i = 0; i < 26; ++i) {
            dp[0][i][word[0] - 'A'] = dp[0][word[0] - 'A'][i] = 0;
        }
        
        for (int i = 1; i < n; ++i) {
            int cur = word[i] - 'A';
            int prev = word[i - 1] - 'A';
            int d = getDistance(prev, cur);
            for (int j = 0; j < 26; ++j) {
                dp[i][cur][j] = min(dp[i][cur][j], dp[i - 1][prev][j] + d);
                dp[i][j][cur] = min(dp[i][j][cur], dp[i - 1][j][prev] + d);
                if (prev == j) {
                    for (int k = 0; k < 26; ++k) {
                        int d0 = getDistance(k, cur);
                        dp[i][cur][j] = min(dp[i][cur][j], dp[i - 1][k][j] + d0);
                        dp[i][j][cur] = min(dp[i][j][cur], dp[i - 1][j][k] + d0);
                    }
                }
            }
        }

        int ans = INT_MAX >> 1;
        for (int i = 0; i < 26; ++i) {
            for (int j = 0; j < 26; ++j) {
                ans = min(ans, dp[n - 1][i][j]);
            }
        }
        return ans;
    }
};

###Python

class Solution:
    def minimumDistance(self, word: str) -> int:
        n = len(word)
        BIG = 2**30
        dp = [[[BIG] * 26 for x in range(26)] for y in range(n)]
        for i in range(26):
            dp[0][i][ord(word[0]) - 65] = 0
            dp[0][ord(word[0]) - 65][i] = 0
    
        def getDistance(p, q):
            x1, y1 = p // 6, p % 6
            x2, y2 = q // 6, q % 6
            return abs(x1 - x2) + abs(y1 - y2)

        for i in range(1, n):
            cur, prev = ord(word[i]) - 65, ord(word[i - 1]) - 65
            d = getDistance(prev, cur)
            for j in range(26):
                dp[i][cur][j] = min(dp[i][cur][j], dp[i - 1][prev][j] + d)
                dp[i][j][cur] = min(dp[i][j][cur], dp[i - 1][j][prev] + d)
                if prev == j:
                    for k in range(26):
                        d0 = getDistance(k, cur)
                        dp[i][cur][j] = min(dp[i][cur][j], dp[i - 1][k][j] + d0)
                        dp[i][j][cur] = min(dp[i][j][cur], dp[i - 1][j][k] + d0)
        
        ans = min(min(dp[n - 1][x]) for x in range(26))
        return ans

###Java

class Solution {
    private int getDistance(int p, int q) {
        int x1 = p / 6, y1 = p % 6;
        int x2 = q / 6, y2 = q % 6;
        return Math.abs(x1 - x2) + Math.abs(y1 - y2);
    }

    public int minimumDistance(String word) {
        int n = word.length();
        int[][][] dp = new int[n][26][26];
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < 26; ++j) {
                for (int k = 0; k < 26; ++k) {
                    dp[i][j][k] = Integer.MAX_VALUE / 2;
                }
            }
        }
        
        for (int i = 0; i < 26; ++i) {
            dp[0][i][word.charAt(0) - 'A'] = 0;
            dp[0][word.charAt(0) - 'A'][i] = 0;
        }
        
        for (int i = 1; i < n; ++i) {
            int cur = word.charAt(i) - 'A';
            int prev = word.charAt(i - 1) - 'A';
            int d = getDistance(prev, cur);
            
            for (int j = 0; j < 26; ++j) {
                dp[i][cur][j] = Math.min(dp[i][cur][j], dp[i - 1][prev][j] + d);
                dp[i][j][cur] = Math.min(dp[i][j][cur], dp[i - 1][j][prev] + d);
                
                if (prev == j) {
                    for (int k = 0; k < 26; ++k) {
                        int d0 = getDistance(k, cur);
                        dp[i][cur][j] = Math.min(dp[i][cur][j], dp[i - 1][k][j] + d0);
                        dp[i][j][cur] = Math.min(dp[i][j][cur], dp[i - 1][j][k] + d0);
                    }
                }
            }
        }
        
        int ans = Integer.MAX_VALUE / 2;
        for (int i = 0; i < 26; ++i) {
            for (int j = 0; j < 26; ++j) {
                ans = Math.min(ans, dp[n - 1][i][j]);
            }
        }
        return ans;
    }
}

###C#

public class Solution {
    private int GetDistance(int p, int q) {
        int x1 = p / 6, y1 = p % 6;
        int x2 = q / 6, y2 = q % 6;
        return Math.Abs(x1 - x2) + Math.Abs(y1 - y2);
    }

    public int MinimumDistance(string word) {
        int n = word.Length;
        int[,,] dp = new int[n, 26, 26];
        
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < 26; ++j) {
                for (int k = 0; k < 26; ++k) {
                    dp[i, j, k] = int.MaxValue / 2;
                }
            }
        }
        
        for (int i = 0; i < 26; ++i) {
            dp[0, i, word[0] - 'A'] = 0;
            dp[0, word[0] - 'A', i] = 0;
        }
        for (int i = 1; i < n; ++i) {
            int cur = word[i] - 'A';
            int prev = word[i - 1] - 'A';
            int d = GetDistance(prev, cur);
            
            for (int j = 0; j < 26; ++j) {
                dp[i, cur, j] = Math.Min(dp[i, cur, j], dp[i - 1, prev, j] + d);
                dp[i, j, cur] = Math.Min(dp[i, j, cur], dp[i - 1, j, prev] + d);
                
                if (prev == j) {
                    for (int k = 0; k < 26; ++k) {
                        int d0 = GetDistance(k, cur);
                        dp[i, cur, j] = Math.Min(dp[i, cur, j], dp[i - 1, k, j] + d0);
                        dp[i, j, cur] = Math.Min(dp[i, j, cur], dp[i - 1, j, k] + d0);
                    }
                }
            }
        }
        
        int ans = int.MaxValue / 2;
        for (int i = 0; i < 26; ++i) {
            for (int j = 0; j < 26; ++j) {
                ans = Math.Min(ans, dp[n - 1, i, j]);
            }
        }
        return ans;
    }
}

###Go

func getDistance(p, q int) int {
    x1, y1 := p/6, p%6
    x2, y2 := q/6, q%6
    return abs(x1 - x2) + abs(y1 - y2)
}

func abs(x int) int {
    if x < 0 {
        return -x
    }
    return x
}

func minimumDistance(word string) int {
    n := len(word)
    dp := make([][26][26]int, n)
    
    for i := 0; i < n; i++ {
        for j := 0; j < 26; j++ {
            for k := 0; k < 26; k++ {
                dp[i][j][k] = 1 << 30
            }
        }
    }
    
    firstChar := int(word[0] - 'A')
    for i := 0; i < 26; i++ {
        dp[0][i][firstChar] = 0
        dp[0][firstChar][i] = 0
    }
    
    for i := 1; i < n; i++ {
        cur := int(word[i] - 'A')
        prev := int(word[i-1] - 'A')
        d := getDistance(prev, cur)
        
        for j := 0; j < 26; j++ {
            dp[i][cur][j] = min(dp[i][cur][j], dp[i-1][prev][j]+d)
            dp[i][j][cur] = min(dp[i][j][cur], dp[i-1][j][prev]+d)
            
            if prev == j {
                for k := 0; k < 26; k++ {
                    d0 := getDistance(k, cur)
                    dp[i][cur][j] = min(dp[i][cur][j], dp[i-1][k][j]+d0)
                    dp[i][j][cur] = min(dp[i][j][cur], dp[i-1][j][k]+d0)
                }
            }
        }
    }
    
    ans := 1 << 30
    for i := 0; i < 26; i++ {
        for j := 0; j < 26; j++ {
            ans = min(ans, dp[n-1][i][j])
        }
    }
    return ans
}

###C

int getDistance(int p, int q) {
    int x1 = p / 6, y1 = p % 6;
    int x2 = q / 6, y2 = q % 6;
    return abs(x1 - x2) + abs(y1 - y2);
}

int minimumDistance(char* word) {
    int n = strlen(word);
    int*** dp = (int***)malloc(n * sizeof(int**));
    for (int i = 0; i < n; ++i) {
        dp[i] = (int**)malloc(26 * sizeof(int*));
        for (int j = 0; j < 26; ++j) {
            dp[i][j] = (int*)malloc(26 * sizeof(int));
            for (int k = 0; k < 26; ++k) {
                dp[i][j][k] = INT_MAX / 2;
            }
        }
    }
    
    for (int i = 0; i < 26; ++i) {
        dp[0][i][word[0] - 'A'] = 0;
        dp[0][word[0] - 'A'][i] = 0;
    }
    for (int i = 1; i < n; ++i) {
        int cur = word[i] - 'A';
        int prev = word[i - 1] - 'A';
        int d = getDistance(prev, cur);
        
        for (int j = 0; j < 26; ++j) {
            dp[i][cur][j] = fmin(dp[i][cur][j], dp[i - 1][prev][j] + d);
            dp[i][j][cur] = fmin(dp[i][j][cur], dp[i - 1][j][prev] + d);
            
            if (prev == j) {
                for (int k = 0; k < 26; ++k) {
                    int d0 = getDistance(k, cur);
                    dp[i][cur][j] = fmin(dp[i][cur][j], dp[i - 1][k][j] + d0);
                    dp[i][j][cur] = fmin(dp[i][j][cur], dp[i - 1][j][k] + d0);
                }
            }
        }
    }
    
    int ans = INT_MAX / 2;
    for (int i = 0; i < 26; ++i) {
        for (int j = 0; j < 26; ++j) {
            if (ans > dp[n - 1][i][j]) {
                ans = dp[n - 1][i][j];
            }
        }
    }
    
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < 26; ++j) {
            free(dp[i][j]);
        }
        free(dp[i]);
    }
    free(dp);
    
    return ans;
}

###JavaScript

var minimumDistance = function(word) {
    const n = word.length;
    const getDistance = (p, q) => {
        const x1 = Math.floor(p / 6), y1 = p % 6;
        const x2 = Math.floor(q / 6), y2 = q % 6;
        return Math.abs(x1 - x2) + Math.abs(y1 - y2);
    };
    
    const dp = new Array(n);
    for (let i = 0; i < n; i++) {
        dp[i] = new Array(26);
        for (let j = 0; j < 26; j++) {
            dp[i][j] = new Array(26).fill(Math.floor(Number.MAX_SAFE_INTEGER / 2));
        }
    }
    
    const firstChar = word.charCodeAt(0) - 65;
    for (let i = 0; i < 26; i++) {
        dp[0][i][firstChar] = 0;
        dp[0][firstChar][i] = 0;
    }
    
    for (let i = 1; i < n; i++) {
        const cur = word.charCodeAt(i) - 65;
        const prev = word.charCodeAt(i - 1) - 65;
        const d = getDistance(prev, cur);
        
        for (let j = 0; j < 26; j++) {
            dp[i][cur][j] = Math.min(dp[i][cur][j], dp[i - 1][prev][j] + d);
            dp[i][j][cur] = Math.min(dp[i][j][cur], dp[i - 1][j][prev] + d);
            
            if (prev === j) {
                for (let k = 0; k < 26; k++) {
                    const d0 = getDistance(k, cur);
                    dp[i][cur][j] = Math.min(dp[i][cur][j], dp[i - 1][k][j] + d0);
                    dp[i][j][cur] = Math.min(dp[i][j][cur], dp[i - 1][j][k] + d0);
                }
            }
        }
    }
    
    let ans = Number.MAX_SAFE_INTEGER;
    for (let i = 0; i < 26; i++) {
        for (let j = 0; j < 26; j++) {
            ans = Math.min(ans, dp[n - 1][i][j]);
        }
    }
    return ans;
};

###TypeScript

function minimumDistance(word: string): number {
    const n = word.length;
    const getDistance = (p: number, q: number): number => {
        const x1 = Math.floor(p / 6), y1 = p % 6;
        const x2 = Math.floor(q / 6), y2 = q % 6;
        return Math.abs(x1 - x2) + Math.abs(y1 - y2);
    };
    
    const dp: number[][][] = new Array(n);
    for (let i = 0; i < n; i++) {
        dp[i] = new Array(26);
        for (let j = 0; j < 26; j++) {
            dp[i][j] = new Array(26).fill(Math.floor(Number.MAX_SAFE_INTEGER / 2));
        }
    }
    const firstChar = word.charCodeAt(0) - 65;
    for (let i = 0; i < 26; i++) {
        dp[0][i][firstChar] = 0;
        dp[0][firstChar][i] = 0;
    }
    
    for (let i = 1; i < n; i++) {
        const cur = word.charCodeAt(i) - 65;
        const prev = word.charCodeAt(i - 1) - 65;
        const d = getDistance(prev, cur);
        
        for (let j = 0; j < 26; j++) {
            dp[i][cur][j] = Math.min(dp[i][cur][j], dp[i - 1][prev][j] + d);
            dp[i][j][cur] = Math.min(dp[i][j][cur], dp[i - 1][j][prev] + d);
            
            if (prev === j) {
                for (let k = 0; k < 26; k++) {
                    const d0 = getDistance(k, cur);
                    dp[i][cur][j] = Math.min(dp[i][cur][j], dp[i - 1][k][j] + d0);
                    dp[i][j][cur] = Math.min(dp[i][j][cur], dp[i - 1][j][k] + d0);
                }
            }
        }
    }
    
    let ans = Number.MAX_SAFE_INTEGER;
    for (let i = 0; i < 26; i++) {
        for (let j = 0; j < 26; j++) {
            ans = Math.min(ans, dp[n - 1][i][j]);
        }
    }
    return ans;
}

###Rust

impl Solution {
    fn get_distance(p: i32, q: i32) -> i32 {
        let x1 = p / 6;
        let y1 = p % 6;
        let x2 = q / 6;
        let y2 = q % 6;
        (x1 - x2).abs() + (y1 - y2).abs()
    }

    pub fn minimum_distance(word: String) -> i32 {
        let n = word.len();
        let word_chars: Vec<char> = word.chars().collect();
        let mut dp = vec![vec![vec![i32::MAX >> 1; 26]; 26]; n];
        let first_char = (word_chars[0] as u8 - b'A') as usize;
        for i in 0..26 {
            dp[0][i][first_char] = 0;
            dp[0][first_char][i] = 0;
        }
        
        for i in 1..n {
            let cur = (word_chars[i] as u8 - b'A') as i32;
            let prev = (word_chars[i - 1] as u8 - b'A') as i32;
            let d = Self::get_distance(prev, cur);
            
            for j in 0..26 {
                let j_i32 = j as i32;
                dp[i][cur as usize][j] = dp[i][cur as usize][j].min(
                    dp[i - 1][prev as usize][j].saturating_add(d)
                );
                dp[i][j][cur as usize] = dp[i][j][cur as usize].min(
                    dp[i - 1][j][prev as usize].saturating_add(d)
                );
                
                if prev == j_i32 {
                    for k in 0..26 {
                        let d0 = Self::get_distance(k as i32, cur);
                        dp[i][cur as usize][j] = dp[i][cur as usize][j].min(
                            dp[i - 1][k][j].saturating_add(d0)
                        );
                        dp[i][j][cur as usize] = dp[i][j][cur as usize].min(
                            dp[i - 1][j][k].saturating_add(d0)
                        );
                    }
                }
            }
        }
        
        let mut ans = i32::MAX >> 1;
        for i in 0..26 {
            for j in 0..26 {
                ans = ans.min(dp[n - 1][i][j]);
            }
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$O(|\Sigma|N)$,其中 $N$ 是字符串 word 的长度,$|\Sigma|$ 是可能出现的字母数量,在本题中 $|\Sigma| = 26$。对于状态 dp[i][l][r],枚举 i 需要的时间复杂度为 $O(N)$,在此之后,如果 word[i] == l,根据上面的状态转移方程:

    • 如果左手在 word[i - 1] 的位置,那么单次状态转移的时间复杂度为 $O(1)$,需要对所有的 r 都进行转移,总时间复杂度为 $O(|\Sigma|)$;

    • 如果右手在 word[i - 1] 的位置,那么 r == word[i - 1]。虽然我们要枚举 l0,但是合法的 r 只有一个,因此总时间复杂度也为 $O(|\Sigma|)$。

    如果 word[i] == r,分析的过程相同,在此不再赘述。这样总时间复杂度即为 $O(|\Sigma|N)$。

  • 空间复杂度:$O(|\Sigma|^2 N)$。

方法二:动态规划 + 空间优化

在方法一中,我们提到了一条重要的性质:对于状态 dp[i][l][r],要么 word[i] == l,要么 word[i] == r,即在输入了第 i 个字母后,左手和右手中至少有一个在 word[i] 的位置。那么对于每一个 i,我们其实只需要存储 $2|\Sigma|$ 而不是 $|\Sigma|^2$ 个状态。例如我们可以用 dp[i][op][rest] 表示状态,其中 op 的值只能为 01op == 0 表示左手在 word[i] 的位置,op == 1 表示右手在 word[i] 的位置,而 rest 表示不在 word[i] 位置的另一只手的位置。这样我们在状态转移方程几乎不变的前提下,减少了动态规划需要的空间。

那么我们是否还可以继续进行优化呢?我们可以发现,在方法一中,状态转移方程具有高度对称性,那么我们可以断定,dp[i][op = 0][rest]dp[i][op = 1][rest] 的值一定是相等的。这是因为 dp[i][op = 0][rest] 表示左手在 word[i] 的位置且右手在 rest 的位置,如果我们将左右手互换,那么我们同样可以使用 dp[i][op = 0][rest] 的移动距离使得右手在 word[i] 的位置且左手在 rest 的位置,而这恰好就是 dp[i][op = 1][rest]

因此我们可以直接使用 dp[i][rest] 进行状态转移,其表示一只手在 word[i] 的位置,另一只手在 rest 的位置的最小移动距离。我们并不需要关心具体哪只手在 word[i] 的位置,因为两只手是完全对称的。这样以来,我们将三维的动态规划优化至了二维,大大减少了空间的使用。

###C++

class Solution {
public:
    int getDistance(int p, int q) {
        int x1 = p / 6, y1 = p % 6;
        int x2 = q / 6, y2 = q % 6;
        return abs(x1 - x2) + abs(y1 - y2);
    }

    int minimumDistance(string word) {
        int n = word.size();
        vector<vector<int>> dp(n, vector<int>(26, INT_MAX >> 1));
        fill(dp[0].begin(), dp[0].end(), 0);
        
        for (int i = 1; i < n; ++i) {
            int cur = word[i] - 'A';
            int prev = word[i - 1] - 'A';
            int d = getDistance(prev, cur);
            for (int j = 0; j < 26; ++j) {
                dp[i][j] = min(dp[i][j], dp[i - 1][j] + d);
                if (prev == j) {
                    for (int k = 0; k < 26; ++k) {
                        int d0 = getDistance(k, cur);
                        dp[i][j] = min(dp[i][j], dp[i - 1][k] + d0);
                    }
                }
            }
        }

        int ans = *min_element(dp[n - 1].begin(), dp[n - 1].end());
        return ans;
    }
};

###Python

class Solution:
    def minimumDistance(self, word: str) -> int:
        n = len(word)
        BIG = 2**30
        dp = [[0] * 26] + [[BIG] * 26 for _ in range(n - 1)]
        
        def getDistance(p, q):
            x1, y1 = p // 6, p % 6
            x2, y2 = q // 6, q % 6
            return abs(x1 - x2) + abs(y1 - y2)

        for i in range(1, n):
            cur, prev = ord(word[i]) - 65, ord(word[i - 1]) - 65
            d = getDistance(prev, cur)
            for j in range(26):
                dp[i][j] = min(dp[i][j], dp[i - 1][j] + d)
                if prev == j:
                    for k in range(26):
                        d0 = getDistance(k, cur)
                        dp[i][j] = min(dp[i][j], dp[i - 1][k] + d0)

        ans = min(dp[n - 1])
        return ans

###Java

class Solution {
    private int getDistance(int p, int q) {
        int x1 = p / 6, y1 = p % 6;
        int x2 = q / 6, y2 = q % 6;
        return Math.abs(x1 - x2) + Math.abs(y1 - y2);
    }

    public int minimumDistance(String word) {
        int n = word.length();
        int[][] dp = new int[n][26];
        for (int i = 0; i < n; i++) {
            Arrays.fill(dp[i], Integer.MAX_VALUE / 2);
        }
        Arrays.fill(dp[0], 0);
        
        for (int i = 1; i < n; i++) {
            int cur = word.charAt(i) - 'A';
            int prev = word.charAt(i - 1) - 'A';
            int d = getDistance(prev, cur);
            
            for (int j = 0; j < 26; j++) {
                dp[i][j] = Math.min(dp[i][j], dp[i - 1][j] + d);
                if (prev == j) {
                    for (int k = 0; k < 26; k++) {
                        int d0 = getDistance(k, cur);
                        dp[i][j] = Math.min(dp[i][j], dp[i - 1][k] + d0);
                    }
                }
            }
        }
        
        int ans = Integer.MAX_VALUE / 2;
        for (int value : dp[n - 1]) {
            ans = Math.min(ans, value);
        }
        return ans;
    }
}

###C#

public class Solution {
    private int GetDistance(int p, int q) {
        int x1 = p / 6, y1 = p % 6;
        int x2 = q / 6, y2 = q % 6;
        return Math.Abs(x1 - x2) + Math.Abs(y1 - y2);
    }

    public int MinimumDistance(string word) {
        int n = word.Length;
        int[,] dp = new int[n, 26];
        
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < 26; j++) {
                dp[i, j] = int.MaxValue / 2;
            }
        }
        
        for (int j = 0; j < 26; j++) {
            dp[0, j] = 0;
        }
        for (int i = 1; i < n; i++) {
            int cur = word[i] - 'A';
            int prev = word[i - 1] - 'A';
            int d = GetDistance(prev, cur);
            
            for (int j = 0; j < 26; j++) {
                dp[i, j] = Math.Min(dp[i, j], dp[i - 1, j] + d);
                if (prev == j) {
                    for (int k = 0; k < 26; k++) {
                        int d0 = GetDistance(k, cur);
                        dp[i, j] = Math.Min(dp[i, j], dp[i - 1, k] + d0);
                    }
                }
            }
        }
        
        int ans = int.MaxValue / 2;
        for (int j = 0; j < 26; j++) {
            ans = Math.Min(ans, dp[n - 1, j]);
        }
        return ans;
    }
}

###Go

func getDistance(p, q int) int {
    x1, y1 := p / 6, p % 6
    x2, y2 := q / 6, q % 6
    return abs(x1 - x2) + abs(y1 - y2)
}

func abs(x int) int {
    if x < 0 {
        return -x
    }
    return x
}

func minimumDistance(word string) int {
    n := len(word)
    dp := make([][]int, n)
    for i := range dp {
        dp[i] = make([]int, 26)
        for j := range dp[i] {
            dp[i][j] = 1 << 30
        }
    }
    for j := 0; j < 26; j++ {
        dp[0][j] = 0
    }
    
    for i := 1; i < n; i++ {
        cur := int(word[i] - 'A')
        prev := int(word[i-1] - 'A')
        d := getDistance(prev, cur)
        
        for j := 0; j < 26; j++ {
            dp[i][j] = min(dp[i][j], dp[i-1][j]+d)
            if prev == j {
                for k := 0; k < 26; k++ {
                    d0 := getDistance(k, cur)
                    dp[i][j] = min(dp[i][j], dp[i-1][k]+d0)
                }
            }
        }
    }
    
    ans := 1 << 30
    for j := 0; j < 26; j++ {
        ans = min(ans, dp[n-1][j])
    }
    return ans
}

###C

int getDistance(int p, int q) {
    int x1 = p / 6, y1 = p % 6;
    int x2 = q / 6, y2 = q % 6;
    return abs(x1 - x2) + abs(y1 - y2);
}

int minimumDistance(char* word) {
    int n = strlen(word);
    int** dp = (int**)malloc(n * sizeof(int*));
    for (int i = 0; i < n; i++) {
        dp[i] = (int*)malloc(26 * sizeof(int));
        for (int j = 0; j < 26; j++) {
            dp[i][j] = INT_MAX / 2;
        }
    }
    for (int j = 0; j < 26; j++) {
        dp[0][j] = 0;
    }
    
    for (int i = 1; i < n; i++) {
        int cur = word[i] - 'A';
        int prev = word[i - 1] - 'A';
        int d = getDistance(prev, cur);
        
        for (int j = 0; j < 26; j++) {
            dp[i][j] = fmin(dp[i][j], dp[i - 1][j] + d);
            if (prev == j) {
                for (int k = 0; k < 26; k++) {
                    int d0 = getDistance(k, cur);
                    dp[i][j] = fmin(dp[i][j], dp[i - 1][k] + d0);
                }
            }
        }
    }
    
    int ans = INT_MAX / 2;
    for (int j = 0; j < 26; j++) {
        if (ans > dp[n - 1][j]) {
            ans = dp[n - 1][j];
        }
    }
    for (int i = 0; i < n; i++) {
        free(dp[i]);
    }
    free(dp);
    
    return ans;
}

###JavaScript

var minimumDistance = function(word) {
    const n = word.length;
    const getDistance = (p, q) => {
        const x1 = Math.floor(p / 6), y1 = p % 6;
        const x2 = Math.floor(q / 6), y2 = q % 6;
        return Math.abs(x1 - x2) + Math.abs(y1 - y2);
    };
    
    const dp = new Array(n);
    for (let i = 0; i < n; i++) {
        dp[i] = new Array(26).fill(Math.floor(Number.MAX_SAFE_INTEGER / 2));
    }
    
    for (let j = 0; j < 26; j++) {
        dp[0][j] = 0;
    }
    
    for (let i = 1; i < n; i++) {
        const cur = word.charCodeAt(i) - 65;
        const prev = word.charCodeAt(i - 1) - 65;
        const d = getDistance(prev, cur);
        
        for (let j = 0; j < 26; j++) {
            dp[i][j] = Math.min(dp[i][j], dp[i - 1][j] + d);
            
            if (prev === j) {
                for (let k = 0; k < 26; k++) {
                    const d0 = getDistance(k, cur);
                    dp[i][j] = Math.min(dp[i][j], dp[i - 1][k] + d0);
                }
            }
        }
    }
    
    let ans = Math.floor(Number.MAX_SAFE_INTEGER / 2);
    for (let j = 0; j < 26; j++) {
        ans = Math.min(ans, dp[n - 1][j]);
    }
    return ans;
};

###TypeScript

function minimumDistance(word: string): number {
    const n = word.length;
    const getDistance = (p: number, q: number): number => {
        const x1 = Math.floor(p / 6), y1 = p % 6;
        const x2 = Math.floor(q / 6), y2 = q % 6;
        return Math.abs(x1 - x2) + Math.abs(y1 - y2);
    };
    
    const dp: number[][] = new Array(n);
    for (let i = 0; i < n; i++) {
        dp[i] = new Array(26).fill(Math.floor(Number.MAX_SAFE_INTEGER / 2));
    }
    for (let j = 0; j < 26; j++) {
        dp[0][j] = 0;
    }
    
    for (let i = 1; i < n; i++) {
        const cur = word.charCodeAt(i) - 65;
        const prev = word.charCodeAt(i - 1) - 65;
        const d = getDistance(prev, cur);
        
        for (let j = 0; j < 26; j++) {
            dp[i][j] = Math.min(dp[i][j], dp[i - 1][j] + d);
            
            if (prev === j) {
                for (let k = 0; k < 26; k++) {
                    const d0 = getDistance(k, cur);
                    dp[i][j] = Math.min(dp[i][j], dp[i - 1][k] + d0);
                }
            }
        }
    }
    
    let ans = Math.floor(Number.MAX_SAFE_INTEGER / 2);
    for (let j = 0; j < 26; j++) {
        ans = Math.min(ans, dp[n - 1][j]);
    }
    return ans;
}

###Rust

impl Solution {
    fn get_distance(p: i32, q: i32) -> i32 {
        let x1 = p / 6;
        let y1 = p % 6;
        let x2 = q / 6;
        let y2 = q % 6;
        (x1 - x2).abs() + (y1 - y2).abs()
    }

    pub fn minimum_distance(word: String) -> i32 {
        let n = word.len();
        let word_bytes = word.as_bytes();
        let mut dp = vec![vec![i32::MAX >> 1; 26]; n];
        
        for j in 0..26 {
            dp[0][j] = 0;
        }
        
        for i in 1..n {
            let cur = (word_bytes[i] - b'A') as i32;
            let prev = (word_bytes[i - 1] - b'A') as i32;
            let d = Self::get_distance(prev, cur);
            
            for j in 0..26 {
                let j_i32 = j as i32;
                dp[i][j] = dp[i][j].min(dp[i - 1][j].saturating_add(d));
                
                if prev == j_i32 {
                    for k in 0..26 {
                        let d0 = Self::get_distance(k as i32, cur);
                        dp[i][j] = dp[i][j].min(dp[i - 1][k].saturating_add(d0));
                    }
                }
            }
        }
        
        let mut ans = i32::MAX >> 1;
        for j in 0..26 {
            ans = ans.min(dp[n - 1][j]);
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$O(|\Sigma|N)$。

  • 空间复杂度:$O(|\Sigma|N)$。

三个相等元素之间的最小距离 II

方法一:遍历 + 哈希表

思路与算法

分析题目可知,所求三元组的距离实际上是广义三角形的三边之和,不管选取的三个点顺序如何,长度一定等于两倍的端点构成的线段的长度;换而言之,设最右侧点的下标是 $k$,最左侧点的下标是 $i$,所求的距离就是 $2 \times (k - i)$。

显然,对于所有相同元素对应下标构成的有效三元组,其最小距离必定在三个相邻元素构成的三元组间产生。以此为突破口,类比链表,我们可以通过维护前驱数组或者后继数组的方式快速求解当前元素的前驱和后继,以便计算距离并更新答案。

下面以后继数组为例讲解具体实现,采用前驱数组的方法只需要一次遍历,留给读者自行思考。

首先定义后继数组 $\textit{next}$,设 $\textit{next}[i]$ 记录了 $\textit{nums}[i]$ 在 $\textit{nums}$ 中下一次出现的位置。倒序遍历 $\textit{nums}$,配合哈希表记录 $\textit{nums}[i]$ 在倒序遍历中最近一次出现的位置,即可求出 $\textit{next}$ 数组。

随后遍历 $\textit{nums}$,借助 $\textit{next}$ 数组,我们可以在 $O(1)$ 的时间内求出与 $\textit{nums}[i]$ 值相同的两个相邻的后继元素,计算距离并更新答案即可。

代码

###C++

class Solution {
public:
    int minimumDistance(vector<int>& nums) {
        int n = nums.size();
        std::vector<int> next(n, -1);
        std::unordered_map<int, int> occur;
        int ans = n + 1;

        for (int i = n - 1; i >= 0; i--) {
            if (occur.count(nums[i])) {
                next[i] = occur[nums[i]];
            }
            occur[nums[i]] = i;
        }

        for (int i = 0; i < n; i++) {
            int secondPos = next[i];
            if (secondPos != -1) {
                int thirdPos = next[secondPos];
                if (thirdPos != -1) {
                    ans = std::min(ans, thirdPos - i);
                }
            }
        }

        return ans == n + 1 ? -1 : ans * 2;
    }
};

###JavaScript

var minimumDistance = function (nums) {
    const next = Array.from({ length: nums.length }).fill(-1);
    const occur = new Map();
    let ans = nums.length + 1;

    for (let i = nums.length - 1; i >= 0; i--) {
        if (occur.has(nums[i])) {
            next[i] = occur.get(nums[i]);
        }
        occur.set(nums[i], i);
    }

    for (let i = 0; i < nums.length; i++) {
        let secondPos = next[i];
        let thirdPos = next[secondPos];
        if (secondPos !== -1 && thirdPos !== -1) {
            ans = Math.min(ans, thirdPos - i);
        }
    }

    if (ans === nums.length + 1) {
        return -1;
    } else {
        return ans * 2;
    }
};

###TypeScript

function minimumDistance(nums: number[]): number {
    const next = Array.from<number>({ length: nums.length }).fill(-1);
    const occur = new Map<number, number>();
    let ans = nums.length + 1;

    for (let i = nums.length - 1; i >= 0; i--) {
        if (occur.has(nums[i])) {
            next[i] = occur.get(nums[i])!;
        }
        occur.set(nums[i], i);
    }

    for (let i = 0; i < nums.length; i++) {
        let secondPos = next[i];
        let thirdPos = next[secondPos];
        if (secondPos !== -1 && thirdPos !== -1) {
            ans = Math.min(ans, thirdPos - i);
        }
    }

    if (ans === nums.length + 1) {
        return -1;
    } else {
        return ans * 2;
    }
};

###Java

class Solution {
    public int minimumDistance(int[] nums) {
        int n = nums.length;
        int[] next = new int[n];
        Arrays.fill(next, -1);
        Map<Integer, Integer> occur = new HashMap<>();
        int ans = n + 1;

        for (int i = n - 1; i >= 0; i--) {
            if (occur.containsKey(nums[i])) {
                next[i] = occur.get(nums[i]);
            }
            occur.put(nums[i], i);
        }

        for (int i = 0; i < n; i++) {
            int secondPos = next[i];
            if (secondPos != -1) {
                int thirdPos = next[secondPos];
                if (thirdPos != -1) {
                    ans = Math.min(ans, thirdPos - i);
                }
            }
        }

        return ans == n + 1 ? -1 : ans * 2;
    }
}

###C#

public class Solution {
    public int MinimumDistance(int[] nums) {
        int n = nums.Length;
        int[] next = new int[n];
        Array.Fill(next, -1);
        Dictionary<int, int> occur = new();
        int ans = n + 1;

        for (int i = n - 1; i >= 0; i--) {
            if (occur.TryGetValue(nums[i], out int val)) {
                next[i] = val;
            }
            occur[nums[i]] = i;
        }

        for (int i = 0; i < n; i++) {
            int secondPos = next[i];
            if (secondPos != -1) {
                int thirdPos = next[secondPos];
                if (thirdPos != -1) {
                    ans = Math.Min(ans, thirdPos - i);
                }
            }
        }

        return ans == n + 1 ? -1 : ans * 2;
    }
}

###Go

func minimumDistance(nums []int) int {
n := len(nums)
next := make([]int, n)
for i := range next {
next[i] = -1
}
occur := make(map[int]int)
ans := n + 1

for i := n - 1; i >= 0; i-- {
if val, ok := occur[nums[i]]; ok {
next[i] = val
}
occur[nums[i]] = i
}

for i := 0; i < n; i++ {
secondPos := next[i]
if secondPos != -1 {
thirdPos := next[secondPos]
if thirdPos != -1 {
if dist := thirdPos - i; dist < ans {
ans = dist
}
}
}
}

if ans == n + 1 {
return -1
}
return ans * 2
}

###Python

class Solution:
    def minimumDistance(self, nums: List[int]) -> int:
        n = len(nums)
        nxt = [-1] * n
        occur = {}
        ans = n + 1

        for i in range(n - 1, -1, -1):
            if nums[i] in occur:
                nxt[i] = occur[nums[i]]
            occur[nums[i]] = i

        for i in range(n):
            second_pos = nxt[i]
            if second_pos != -1:
                third_pos = nxt[second_pos]
                if third_pos != -1:
                    ans = min(ans, third_pos - i)

        return -1 if ans == n + 1 else ans * 2

###C

typedef struct {
    int key;
    int val;
    UT_hash_handle hh;
} HashItem;

HashItem *hashFindItem(HashItem **obj, int key) {
    HashItem *pEntry = NULL;
    HASH_FIND_INT(*obj, &key, pEntry);
    return pEntry;
}

bool hashAddItem(HashItem **obj, int key, int val) {
    if (hashFindItem(obj, key)) {
        return false;
    }
    HashItem *pEntry = (HashItem *)malloc(sizeof(HashItem));
    pEntry->key = key;
    pEntry->val = val;
    HASH_ADD_INT(*obj, key, pEntry);
    return true;
}

bool hashSetItem(HashItem **obj, int key, int val) {
    HashItem *pEntry = hashFindItem(obj, key);
    if (!pEntry) {
        hashAddItem(obj, key, val);
    } else {
        pEntry->val = val;
    }
    return true;
}

int hashGetItem(HashItem **obj, int key, int defaultVal) {
    HashItem *pEntry = hashFindItem(obj, key);
    if (!pEntry) {
        return defaultVal;
    }
    return pEntry->val;
}

void hashFree(HashItem **obj) {
    HashItem *curr = NULL, *tmp = NULL;
    HASH_ITER(hh, *obj, curr, tmp) {
        HASH_DEL(*obj, curr);  
        free(curr);
    }
}

int minimumDistance(int* nums, int numsSize) {
    int* next = (int*)malloc(numsSize * sizeof(int));
    for (int i = 0; i < numsSize; i++) {
        next[i] = -1;
    }
    
    HashItem* occur = NULL;
    int ans = numsSize + 1;
    
    for (int i = numsSize - 1; i >= 0; i--) {
        int prevPos = hashGetItem(&occur, nums[i], -1);
        if (prevPos != -1) {
            next[i] = prevPos;
        }
        hashSetItem(&occur, nums[i], i);
    }
    
    for (int i = 0; i < numsSize; i++) {
        int secondPos = next[i];
        if (secondPos != -1) {
            int thirdPos = next[secondPos];
            if (thirdPos != -1) {
                int distance = thirdPos - i;
                if (distance < ans) {
                    ans = distance;
                }
            }
        }
    }
    
    free(next);
    hashFree(&occur);
    
    return ans == numsSize + 1 ? - 1 : ans * 2;
}

###Rust

use std::collections::HashMap;

impl Solution {
    pub fn minimum_distance(nums: Vec<i32>) -> i32 {
        let n = nums.len();
        let mut next = vec![-1_isize; n];
        let mut occur = HashMap::new();
        let mut ans = n + 1;

        for i in (0..n).rev() {
            if let Some(&val) = occur.get(&nums[i]) {
                next[i] = val as isize;
            }
            occur.insert(nums[i], i);
        }

        for i in 0..n {
            let second_pos = next[i];
            if second_pos != -1 {
                let third_pos = next[second_pos as usize];
                if third_pos != -1 {
                    ans = ans.min(third_pos as usize - i);
                }
            }
        }

        if ans == n + 1 {
            -1
        } else {
            (ans * 2) as i32
        }
    }
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。倒序遍历构造 $\textit{next}$ 数组和正序遍历求解答案各需要 $O(n)$,哈希表各项操作平均复杂度为 $O(1)$。

  • 空间复杂度:$O(n)$,$\textit{next}$ 数组和哈希表需要 $O(n)$ 的空间。

三个相等元素之间的最小距离 I

方法一:暴力

思路与算法

本题是「3741. 三个相等元素之间的最小距离 II」的数据简化版,由于数据范围较小,可以直接使用暴力求解。

首先观察要求的绝对值距离之和计算公式:可以发现实际上这就是一个广义三角形的三边之和,不管选取的三个点顺序如何,长度一定等于两倍的端点构成的线段的长度;换而言之,设最右侧点的下标是 $k$,最左侧点的下标是 $i$,所求的距离就是 $2 \times (k - i)$。

故使用三重循环暴力枚举所有不同的顺序三元组,若 $\textit{nums}$ 中对应位置的元素相同,则根据上述分析计算距离,最后取全局最小值即为所求。

代码

###C++

class Solution {
public:
    int minimumDistance(vector<int>& nums) {
        int n = nums.size();
        int ans = n + 1;

        for (int i = 0; i < n - 2; i++) {
            for (int j = i + 1; j < n - 1; j++) {
                if (nums[i] != nums[j]) {
                    continue;
                }
                for (int k = j + 1; k < n; k++) {
                    if (nums[j] == nums[k]) {
                        ans = std::min(ans, k - i);
                        break;
                    }
                }
            }
        }

        return ans == n + 1 ? -1 : ans * 2;
    }
};

###JavaScript

var minimumDistance = function (nums) {
    let ans = nums.length + 1;

    for (let i = 0; i < nums.length - 2; i++) {
        for (let j = i + 1; j < nums.length - 1; j++) {
            if (nums[i] !== nums[j]) {
                continue;
            }
            for (let k = j + 1; k < nums.length; k++) {
                if (nums[j] === nums[k]) {
                    ans = Math.min(ans, k - i);
                    break;
                }
            }
        }
    }

    if (ans === nums.length + 1) {
        return -1;
    } else {
        return ans * 2;
    }
};

###TypeScript

function minimumDistance(nums: number[]): number {
    let ans = nums.length + 1;
    for (let i = 0; i < nums.length - 2; i++) {
        for (let j = i + 1; j < nums.length - 1; j++) {
            if (nums[i] !== nums[j]) {
                continue;
            }
            for (let k = j + 1; k < nums.length; k++) {
                if (nums[j] === nums[k]) {
                    ans = Math.min(ans, k - i);
                    break;
                }
            }
        }
    }

    if (ans === nums.length + 1) {
        return -1;
    } else {
        return ans * 2;
    }
}

###Java

class Solution {
    public int minimumDistance(int[] nums) {
        int n = nums.length;
        int ans = n + 1;

        for (int i = 0; i < n - 2; i++) {
            for (int j = i + 1; j < n - 1; j++) {
                if (nums[i] != nums[j]) {
                    continue;
                }
                for (int k = j + 1; k < n; k++) {
                    if (nums[j] == nums[k]) {
                        ans = Math.min(ans, k - i);
                        break;
                    }
                }
            }
        }

        return ans == n + 1 ? -1 : ans * 2;
    }
}

###C#

public class Solution {
    public int MinimumDistance(int[] nums) {
        int n = nums.Length;
        int ans = n + 1;

        for (int i = 0; i < n - 2; i++) {
            for (int j = i + 1; j < n - 1; j++) {
                if (nums[i] != nums[j]) {
                    continue;
                }
                for (int k = j + 1; k < n; k++) {
                    if (nums[j] == nums[k]) {
                        ans = Math.Min(ans, k - i);
                        break;
                    }
                }
            }
        }

        return ans == n + 1 ? -1 : ans * 2;
    }
}

###Go

func minimumDistance(nums []int) int {
n := len(nums)
ans := n + 1

for i := 0; i < n-2; i++ {
for j := i + 1; j < n-1; j++ {
if nums[i] != nums[j] {
continue
}
for k := j + 1; k < n; k++ {
if nums[j] == nums[k] {
if dist := k - i; dist < ans {
ans = dist
}
break
}
}
}
}

if ans == n+1 {
return -1
}
return ans * 2
}

###Python

class Solution:
    def minimumDistance(self, nums: List[int]) -> int:
        n = len(nums)
        ans = n + 1

        for i in range(n - 2):
            for j in range(i + 1, n - 1):
                if nums[i] != nums[j]:
                    continue
                for k in range(j + 1, n):
                    if nums[j] == nums[k]:
                        ans = min(ans, k - i)
                        break

        return -1 if ans == n + 1 else ans * 2

###C

int minimumDistance(int* nums, int numsSize) {
    int ans = numsSize + 1;

    for (int i = 0; i < numsSize - 2; i++) {
        for (int j = i + 1; j < numsSize - 1; j++) {
            if (nums[i] != nums[j]) {
                continue;
            }
            for (int k = j + 1; k < numsSize; k++) {
                if (nums[j] == nums[k]) {
                    if (k - i < ans) {
                        ans = k - i;
                    }
                    break;
                }
            }
        }
    }

    return ans == numsSize + 1 ? -1 : ans * 2;
}

###Rust

impl Solution {
    pub fn minimum_distance(nums: Vec<i32>) -> i32 {
        let n = nums.len();
        let mut ans = n + 1;

        if n < 3 {
           return -1;
        }

        for i in 0..n - 2 {
            for j in i + 1..n - 1 {
                if nums[i] != nums[j] {
                    continue;
                }
                for k in j + 1..n {
                    if nums[j] == nums[k] {
                        ans = ans.min(k - i);
                        break;
                    }
                }
            }
        }

        if ans == n + 1 {
            -1
        } else {
            (ans * 2) as i32
        }
    }
}

复杂度分析

  • 时间复杂度:$O(n^3)$,其中 $n$ 是 $\textit{nums}$ 的长度,求解用到三重循环,每重循环需要 $O(n)$,故总时间复杂度是 $O(n^3)$。

  • 空间复杂度:$O(1)$,只声明了常数个变量。

网格图中机器人回家的最小代价

方法一:贪心

提示 $1$

如果在某一条路径中,相邻的两步分别为横向(左/右)和纵向(上/下)移动,那么交换这两步前后,路径的总代价不变。

提示 $1$ 解释

由于路径的其它部分不会改变,对应部分的代价也不会改变,因此我们只需要考虑交换的两步。不妨假设在这两步的过程中,机器人从 $(r, c)$ 移动到了 $(r + 1, c + 1)$。

考虑交换前后两种不同的移动方式(用 $\rightarrow$ 表示沿着某个方向一直移动,下同):

  • $(r, c) \rightarrow (r + 1, c) \rightarrow (r + 1, c + 1)$:第一步移动到 $r + 1$ 行,代价为 $\textit{rowCost}[r + 1]$;第二步移动到 $c + 1$ 列,代价为 $\textit{colCost}[c + 1]$。总代价为 $\textit{rowCost}[r + 1] + \textit{colCost}[c + 1]$。

  • $(r, c) \rightarrow (r, c + 1) \rightarrow (r + 1, c + 1)$:第一步移动到 $c + 1$ 列,代价为 $\textit{colCost}[c + 1]$;第二步移动到 $r + 1$ 行,代价为 $\textit{rowCost}[r + 1]$。总代价为 $\textit{colCost}[c + 1] + \textit{rowCost}[r + 1]$。

可以发现,这两种方式代价相同。因此,路径的总代价也不会改变。

提示 $2$

如果某一条路径中包含相反操作(即同时含有向左和向右的操作,或同时含有向上和向下的操作),那么这条路径的代价一定不优于将这些操作成对抵消后的路径。

除此之外,任意不包含任何相反操作的路径对应的总代价一定最小。

提示 $2$ 解释

我们首先考虑前半部分。

不失一般性地,首先考虑从 $(r, c)$ 到 $(r + x, c) (x \ge 0)$ 的两种路径。一种路径为 $(r, c) \rightarrow (r + x, c)$,另一种路径为 $(r, c) \rightarrow (r, c + 1) \rightarrow (r + x, c + 1) \rightarrow (r + x, c)$。计算可得,后者相对于前者多出了 $\textit{colCost}[c] + \textit{colCost}[c + 1] \ge 0$ 的总代价,亦即前者一定更优。

而对于一般的存在相反方向操作的路径,其中必定包含上述的路径片段;而将路径片段中的相反操作抵消后,新的路径在总代价上一定不高于原路径。因此,我们可以递归地抵消这些相反操作,直至路径不包含任何相反操作,同时在每次操作时,总代价一定不会增加。

综上可知,对于任意包含相反操作的路径,一定存在一个不包含相反操作的路径,后者的总代价小于等于前者。因此,最小总代价对应的路径一定是不包含相反操作的路径。

而对于所有的这些不包含任何相反操作的路径,这些路径一定是由一些(数量可能为 $0$)单方向的横向操作和一些(数量可能为 $0$)单方向的纵向操作组成。根据 提示 $1$,我们可以任意交换这些操作,且总代价不变。因此,任意不包含任何相反操作的路径对应的总代价一定最小。

思路与算法

根据 提示 $2$,我们只需要构造任意一条从起点到家的不包含相反操作的路径,该路径对应的总代价即为最小总代价。

为了方便计算,我们先让机器人向上或向下移动至家所在行,再让机器人向左或向右移动至家所在的格子,并在这过程中计算总代价。

而对于如何确定移动的方向,我们行间的上下移动为例:我们比较机器人所在行号 $r_1$ 与家所在行号 $r_2$,如果 $r_1 < r_2$,则我们需要向下移动;如果 $r_1 > r_2$,则我们需要向上移动;如果 $r_1 = r_2$,则我们无需移动。

最终,我们返回该总代价作为答案。

代码

###C++

class Solution {
public:
    int minCost(vector<int>& startPos, vector<int>& homePos, vector<int>& rowCosts, vector<int>& colCosts) {
        int r1 = startPos[0], c1 = startPos[1];
        int r2 = homePos[0], c2 = homePos[1];
        int res = 0;   // 总代价
        // 移动至家所在行,判断行间移动方向并计算对应代价
        if (r2 >= r1){
            res += accumulate(rowCosts.begin() + r1 + 1, rowCosts.begin() + r2 + 1, 0);
        }
        else{
            res += accumulate(rowCosts.begin() + r2, rowCosts.begin() + r1, 0);
        }
        // 移动至家所在位置,判断列间移动方向并计算对应代价
        if (c2 >= c1){
            res += accumulate(colCosts.begin() + c1 + 1, colCosts.begin() + c2 + 1, 0);
        }
        else{
            res += accumulate(colCosts.begin() + c2, colCosts.begin() + c1, 0);
        }
        return res;
    }
};

###Python

class Solution:
    def minCost(self, startPos: List[int], homePos: List[int], rowCosts: List[int], colCosts: List[int]) -> int:
        r1, c1 = startPos[0], startPos[1]
        r2, c2 = homePos[0], homePos[1]
        res = 0   # 总代价
        # 移动至家所在行,判断行间移动方向并计算对应代价
        if r2 >= r1:
            for i in range(r1 + 1, r2 + 1):
                res += rowCosts[i]
        else:
            for i in range(r2, r1):
                res += rowCosts[i]
        # 移动至家所在位置,判断列间移动方向并计算对应代价
        if c2 >= c1:
            for i in range(c1 + 1, c2 + 1):
                res += colCosts[i]
        else:
            for i in range(c2, c1):
                res += colCosts[i]
        return res

###Java

class Solution {
    public int minCost(int[] startPos, int[] homePos, int[] rowCosts, int[] colCosts) {
        int r1 = startPos[0], c1 = startPos[1];
        int r2 = homePos[0], c2 = homePos[1];
        int res = 0;   // 总代价
        
        // 移动至家所在行,判断行间移动方向并计算对应代价
        if (r2 >= r1) {
            for (int i = r1 + 1; i <= r2; i++) {
                res += rowCosts[i];
            }
        } else {
            for (int i = r2; i < r1; i++) {
                res += rowCosts[i];
            }
        }
        
        // 移动至家所在位置,判断列间移动方向并计算对应代价
        if (c2 >= c1) {
            for (int i = c1 + 1; i <= c2; i++) {
                res += colCosts[i];
            }
        } else {
            for (int i = c2; i < c1; i++) {
                res += colCosts[i];
            }
        }
        
        return res;
    }
}

###C#

public class Solution {
    public int MinCost(int[] startPos, int[] homePos, int[] rowCosts, int[] colCosts) {
        int r1 = startPos[0], c1 = startPos[1];
        int r2 = homePos[0], c2 = homePos[1];
        int res = 0;   // 总代价
        
        // 移动至家所在行,判断行间移动方向并计算对应代价
        if (r2 >= r1) {
            for (int i = r1 + 1; i <= r2; i++) {
                res += rowCosts[i];
            }
        } else {
            for (int i = r2; i < r1; i++) {
                res += rowCosts[i];
            }
        }
        
        // 移动至家所在位置,判断列间移动方向并计算对应代价
        if (c2 >= c1) {
            for (int i = c1 + 1; i <= c2; i++) {
                res += colCosts[i];
            }
        } else {
            for (int i = c2; i < c1; i++) {
                res += colCosts[i];
            }
        }
        
        return res;
    }
}

###Go

func minCost(startPos []int, homePos []int, rowCosts []int, colCosts []int) int {
    r1, c1 := startPos[0], startPos[1]
    r2, c2 := homePos[0], homePos[1]
    res := 0 // 总代价
    
    // 移动至家所在行,判断行间移动方向并计算对应代价
    if r2 >= r1 {
        for i := r1 + 1; i <= r2; i++ {
            res += rowCosts[i]
        }
    } else {
        for i := r2; i < r1; i++ {
            res += rowCosts[i]
        }
    }
    
    // 移动至家所在位置,判断列间移动方向并计算对应代价
    if c2 >= c1 {
        for i := c1 + 1; i <= c2; i++ {
            res += colCosts[i]
        }
    } else {
        for i := c2; i < c1; i++ {
            res += colCosts[i]
        }
    }
    
    return res
}

###C

int minCost(int* startPos, int startPosSize, int* homePos, int homePosSize, 
            int* rowCosts, int rowCostsSize, int* colCosts, int colCostsSize) {
    int r1 = startPos[0], c1 = startPos[1];
    int r2 = homePos[0], c2 = homePos[1];
    int res = 0;   // 总代价
    
    // 移动至家所在行,判断行间移动方向并计算对应代价
    if (r2 >= r1) {
        for (int i = r1 + 1; i <= r2; i++) {
            res += rowCosts[i];
        }
    } else {
        for (int i = r2; i < r1; i++) {
            res += rowCosts[i];
        }
    }
    
    // 移动至家所在位置,判断列间移动方向并计算对应代价
    if (c2 >= c1) {
        for (int i = c1 + 1; i <= c2; i++) {
            res += colCosts[i];
        }
    } else {
        for (int i = c2; i < c1; i++) {
            res += colCosts[i];
        }
    }
    
    return res;
}

###JavaScript

var minCost = function(startPos, homePos, rowCosts, colCosts) {
    const r1 = startPos[0], c1 = startPos[1];
    const r2 = homePos[0], c2 = homePos[1];
    let res = 0;   // 总代价
    
    // 移动至家所在行,判断行间移动方向并计算对应代价
    if (r2 >= r1) {
        for (let i = r1 + 1; i <= r2; i++) {
            res += rowCosts[i];
        }
    } else {
        for (let i = r2; i < r1; i++) {
            res += rowCosts[i];
        }
    }
    
    // 移动至家所在位置,判断列间移动方向并计算对应代价
    if (c2 >= c1) {
        for (let i = c1 + 1; i <= c2; i++) {
            res += colCosts[i];
        }
    } else {
        for (let i = c2; i < c1; i++) {
            res += colCosts[i];
        }
    }
    
    return res;
};

###TypeScript

function minCost(startPos: number[], homePos: number[], rowCosts: number[], colCosts: number[]): number {
    const r1 = startPos[0], c1 = startPos[1];
    const r2 = homePos[0], c2 = homePos[1];
    let res = 0;   // 总代价
    
    // 移动至家所在行,判断行间移动方向并计算对应代价
    if (r2 >= r1) {
        for (let i = r1 + 1; i <= r2; i++) {
            res += rowCosts[i];
        }
    } else {
        for (let i = r2; i < r1; i++) {
            res += rowCosts[i];
        }
    }
    
    // 移动至家所在位置,判断列间移动方向并计算对应代价
    if (c2 >= c1) {
        for (let i = c1 + 1; i <= c2; i++) {
            res += colCosts[i];
        }
    } else {
        for (let i = c2; i < c1; i++) {
            res += colCosts[i];
        }
    }
    
    return res;
}

###Rust

impl Solution {
    pub fn min_cost(start_pos: Vec<i32>, home_pos: Vec<i32>, row_costs: Vec<i32>, col_costs: Vec<i32>) -> i32 {
        let r1 = start_pos[0] as usize;
        let c1 = start_pos[1] as usize;
        let r2 = home_pos[0] as usize;
        let c2 = home_pos[1] as usize;
        let mut res = 0;   // 总代价
        
        // 移动至家所在行,判断行间移动方向并计算对应代价
        if r2 >= r1 {
            for i in (r1 + 1)..=r2 {
                res += row_costs[i];
            }
        } else {
            for i in r2..r1 {
                res += row_costs[i];
            }
        }
        
        // 移动至家所在位置,判断列间移动方向并计算对应代价
        if c2 >= c1 {
            for i in (c1 + 1)..=c2 {
                res += col_costs[i];
            }
        } else {
            for i in c2..c1 {
                res += col_costs[i];
            }
        }
        
        res
    }
}

复杂度分析

  • 时间复杂度:$O(m + n)$,其中 $m$ 为网格图的行数,$n$ 为网格图的列数。即为计算最小代价的时间复杂度。

  • 空间复杂度:$O(1)$。

等和矩阵分割 II

方法一:旋转矩阵 + 哈希表 + 枚举上半矩阵之和

思路与算法

本题是「等和矩阵分割 I」的增强版,在这一题的基础上,添加了 删除至多一个单元格 并且 删除后剩余部分必须保持连通 的条件。

那么需要进行删除的时候,我们需要考虑两条分割线的选取以及删除分割线哪一边的元素,为了简化思路,假设我们只判断是否存在水平分割线,并且进行删除操作时只删除水平分割线以上的元素。

能够发现,我们将矩阵进行 $3$ 次 $90$ 度的旋转,每次旋转后进行如上述所说的简化操作,就能够覆盖枚举分割线以及枚举删除元素的位置所带来的 $4$ 种不同情况。

接下来分析如何判断:

  1. 假设当前 $\textit{grid}$ 上半矩阵之和为 $\textit{sum}$,整个矩阵 $\textit{grid}$ 之和为 $\textit{total}$,那么 $\textit{grid}$ 下半矩阵之和为 $\textit{total} - \textit{sum}$。
  2. 假设我们要删除的元素为 $x$,那么需要满足 $\textit{sum} - x == \textit{total} - \textit{sum}$,于是有:$x == \textit{sum} * 2 - \textit{total}$。
  3. 因此在枚举完每一行之后只需要判断是否存在 $\textit{grid}[i][j]$ 满足 $\textit{grid}[i][j] == \textit{sum} * 2 - \textit{total}$ 即可。

我们可以使用一个集合来保存出现过的元素,便于判断是否存在满足题目要求的元素,集合中可以预存一个 $0$,这样可以将删除元素与不删除元素合并为一种情况。

特殊情况处理:

  1. 矩阵 $\textit{grid}$ 在遍历第一行的情况:
    在遍历第一行时能够删除的元素只有第一行的首尾元素,因此在统计完第一行的和之后需要判断 $\textit{grid}[0][0]$ 或者 $\textit{grid}[0][n - 1]$ 或者 $0$ 是否满足题目要求。
  2. 矩阵 $\textit{grid}$ 只有一列的情况:
    $\textit{grid}$ 只有一列时能够删除的元素只有首行以及尾行的元素,需要在遍历第 $i$ 行后判断 $\textit{grid}[0][0]$ 或者 $\textit{grid}[i][0]$ 或者 $0$ 是否满足题目要求。
  3. 当矩阵 $\textit{grid}$ 只有一行时可以跳过,因为无法进行水平分割。

其他情况中 $\textit{grid}$ 上半矩阵中所有的元素均可被删除。

在 $3$ 次旋转后就能够将所有情况覆盖到。

代码

###C++

class Solution {
public:
    vector<vector<int>> rotation(vector<vector<int>>& grid) {
        int m = grid.size(), n = grid[0].size();
        vector tmp(n, vector<int>(m));
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                tmp[j][m - 1 - i] = grid[i][j];
            }
        }
        return tmp;
    }
    bool canPartitionGrid(vector<vector<int>>& grid) {
        long long total = 0;
        long long sum;
        long long tag;
        int m = grid.size();
        int n = grid[0].size();
        unordered_set<long long> exist;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                total += grid[i][j];
            }
        }
        for (int k = 0; k < 4; k++) {
            exist.clear();
            exist.insert(0);
            sum = 0;
            m = grid.size();
            n = grid[0].size();
            if (m < 2) {
                grid = rotation(grid);
                continue;
            }
            if(n == 1){
                for(int i = 0; i < m - 1; i++){
                    sum += grid[i][0];
                    tag = sum * 2 - total;
                    if(tag == 0 || tag == grid[0][0] || tag == grid[i][0]){
                        return true;
                    }
                }
                grid = rotation(grid);
                continue;
            }
            for (int i = 0; i < m - 1; i++) {
                for(int j = 0; j < n; j++){
                    exist.insert(grid[i][j]);
                    sum += grid[i][j];
                }
                tag = sum * 2 - total;
                if(i == 0){
                    if(tag == 0 || tag == grid[0][0] || tag == grid[0][n - 1]){
                        return true;
                    }
                    continue;
                }
                if(exist.contains(tag)){
                    return true;
                }
            }
            grid = rotation(grid);
        }
        return false;
    }
};

###JavaScript

var canPartitionGrid = function(grid) {
    let total = 0;
    let m = grid.length;
    let n = grid[0].length;
    for (let i = 0; i < m; i++) {
        for (let j = 0; j < n; j++) {
            total += grid[i][j];
        }
    }
    for (let k = 0; k < 4; k++) {
        const exist = new Set();
        exist.add(0);
        let sum = 0;
        m = grid.length;
        n = grid[0].length;
        if (m < 2) {
            grid = rotation(grid);
            continue;
        }
        if (n == 1) {
            for (let i = 0; i < m - 1; i++) {
                sum += grid[i][0];
                let tag = sum * 2 - total;
                if (tag == 0 || tag == grid[0][0] || tag == grid[i][0]) {
                    return true;
                }
            }
            grid = rotation(grid);
            continue;
        }
        for (let i = 0; i < m - 1; i++) {
            for (let j = 0; j < n; j++) {
                exist.add(grid[i][j]);
                sum += grid[i][j];
            }
            let tag = sum * 2 - total;
            if (i == 0) {
                if (tag == 0 || tag == grid[0][0] || tag == grid[0][n - 1]) {
                    return true;
                }
                continue;
            }
            if (exist.has(tag)) {
                return true;
            }
        }
        grid = rotation(grid);
    }
    return false;
};

function rotation(grid) {
    const m = grid.length, n = grid[0].length;
    const tmp = Array.from({ length: n }, () => Array(m).fill(0));
    for (let i = 0; i < m; i++) {
        for (let j = 0; j < n; j++) {
            tmp[j][m - 1 - i] = grid[i][j];
        }
    }
    return tmp;
}

###TypeScript

function canPartitionGrid(grid: number[][]): boolean {
    let total = 0;
    let m = grid.length;
    let n = grid[0].length;
    for (let i = 0; i < m; i++) {
        for (let j = 0; j < n; j++) {
            total += grid[i][j];
        }
    }
    for (let k = 0; k < 4; k++) {
        const exist = new Set<number>();
        exist.add(0);
        let sum = 0;
        m = grid.length;
        n = grid[0].length;
        if (m < 2) {
            grid = rotation(grid);
            continue;
        }
        if (n == 1) {
            for (let i = 0; i < m - 1; i++) {
                sum += grid[i][0];
                let tag = sum * 2 - total;
                if (tag == 0 || tag == grid[0][0] || tag == grid[i][0]) {
                    return true;
                }
            }
            grid = rotation(grid);
            continue;
        }
        for (let i = 0; i < m - 1; i++) {
            for (let j = 0; j < n; j++) {
                exist.add(grid[i][j]);
                sum += grid[i][j];
            }
            let tag = sum * 2 - total;
            if (i == 0) {
                if (tag == 0 || tag == grid[0][0] || tag == grid[0][n - 1]) {
                    return true;
                }
                continue;
            }
            if (exist.has(tag)) {
                return true;
            }
        }
        grid = rotation(grid);
    }
    return false;
}

function rotation(grid: number[][]): number[][] {
    const m = grid.length, n = grid[0].length;
    const tmp: number[][] = Array.from({ length: n }, () => Array(m).fill(0));
    for (let i = 0; i < m; i++) {
        for (let j = 0; j < n; j++) {
            tmp[j][m - 1 - i] = grid[i][j];
        }
    }
    return tmp;
}

###Java

class Solution {
    public boolean canPartitionGrid(int[][] grid) {
        long total = 0;
        int m = grid.length;
        int n = grid[0].length;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                total += grid[i][j];
            }
        }
        for (int k = 0; k < 4; k++) {
            Set<Long> exist = new HashSet<>();
            exist.add(0L);
            long sum = 0;
            m = grid.length;
            n = grid[0].length;
            if (m < 2) {
                grid = rotation(grid);
                continue;
            }
            if (n == 1) {
                for (int i = 0; i < m - 1; i++) {
                    sum += grid[i][0];
                    long tag = sum * 2 - total;
                    if (tag == 0 || tag == grid[0][0] || tag == grid[i][0]) {
                        return true;
                    }
                }
                grid = rotation(grid);
                continue;
            }
            for (int i = 0; i < m - 1; i++) {
                for (int j = 0; j < n; j++) {
                    exist.add((long) grid[i][j]);
                    sum += grid[i][j];
                }
                long tag = sum * 2 - total;
                if (i == 0) {
                    if (tag == 0 || tag == grid[0][0] || tag == grid[0][n - 1]) {
                        return true;
                    }
                    continue;
                }
                if (exist.contains(tag)) {
                    return true;
                }
            }
            grid = rotation(grid);
        }
        return false;
    }

    public int[][] rotation(int[][] grid) {
        int m = grid.length, n = grid[0].length;
        int[][] tmp = new int[n][m];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                tmp[j][m - 1 - i] = grid[i][j];
            }
        }
        return tmp;
    }
}

###C#

public class Solution {
    public bool CanPartitionGrid(int[][] grid) {
        long total = 0;
        int m = grid.Length;
        int n = grid[0].Length;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                total += grid[i][j];
            }
        }
        for (int k = 0; k < 4; k++) {
            HashSet<long> exist = new HashSet<long>();
            exist.Add(0);
            long sum = 0;
            m = grid.Length;
            n = grid[0].Length;
            if (m < 2) {
                grid = Rotation(grid);
                continue;
            }
            if (n == 1) {
                for (int i = 0; i < m - 1; i++) {
                    sum += grid[i][0];
                    long tag = sum * 2 - total;
                    if (tag == 0 || tag == grid[0][0] || tag == grid[i][0]) {
                        return true;
                    }
                }
                grid = Rotation(grid);
                continue;
            }
            for (int i = 0; i < m - 1; i++) {
                for (int j = 0; j < n; j++) {
                    exist.Add(grid[i][j]);
                    sum += grid[i][j];
                }
                long tag = sum * 2 - total;
                if (i == 0) {
                    if (tag == 0 || tag == grid[0][0] || tag == grid[0][n - 1]) {
                        return true;
                    }
                    continue;
                }
                if (exist.Contains(tag)) {
                    return true;
                }
            }
            grid = Rotation(grid);
        }
        return false;
    }

    public int[][] Rotation(int[][] grid) {
        int m = grid.Length, n = grid[0].Length;
        int[][] tmp = new int[n][];
        for (int i = 0; i < n; i++) {
            tmp[i] = new int[m];
        }
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                tmp[j][m - 1 - i] = grid[i][j];
            }
        }
        return tmp;
    }
}

###Go

func canPartitionGrid(grid [][]int) bool {
    var total int64 = 0
    m, n := len(grid), len(grid[0])
    for i := 0; i < m; i++ {
        for j := 0; j < n; j++ {
            total += int64(grid[i][j])
        }
    }
    for k := 0; k < 4; k++ {
        exist := make(map[int64]bool)
        exist[0] = true
        var sum int64 = 0
        m, n = len(grid), len(grid[0])
        if m < 2 {
            grid = rotation(grid)
            continue
        }
        if n == 1 {
            for i := 0; i < m-1; i++ {
                sum += int64(grid[i][0])
                tag := sum*2 - total
                if tag == 0 || tag == int64(grid[0][0]) || tag == int64(grid[i][0]) {
                    return true
                }
            }
            grid = rotation(grid)
            continue
        }
        for i := 0; i < m-1; i++ {
            for j := 0; j < n; j++ {
                exist[int64(grid[i][j])] = true
                sum += int64(grid[i][j])
            }
            tag := sum*2 - total
            if i == 0 {
                if tag == 0 || tag == int64(grid[0][0]) || tag == int64(grid[0][n-1]) {
                    return true
                }
                continue
            }
            if exist[tag] {
                return true
            }
        }
        grid = rotation(grid)
    }
    return false
}

func rotation(grid [][]int) [][]int {
    m, n := len(grid), len(grid[0])
    tmp := make([][]int, n)
    for i := range tmp {
        tmp[i] = make([]int, m)
    }
    for i := 0; i < m; i++ {
        for j := 0; j < n; j++ {
            tmp[j][m-1-i] = grid[i][j]
        }
    }
    return tmp
}

###Python

class Solution:
    def canPartitionGrid(self, grid: List[List[int]]) -> bool:
        total = 0
        m = len(grid)
        n = len(grid[0])
        for i in range(m):
            for j in range(n):
                total += grid[i][j]
        for _ in range(4):
            exist = set()
            exist.add(0)
            sum_val = 0
            m = len(grid)
            n = len(grid[0])
            if m < 2:
                grid = self.rotation(grid)
                continue
            if n == 1:
                for i in range(m - 1):
                    sum_val += grid[i][0]
                    tag = sum_val * 2 - total
                    if tag == 0 or tag == grid[0][0] or tag == grid[i][0]:
                        return True
                grid = self.rotation(grid)
                continue
            for i in range(m - 1):
                for j in range(n):
                    exist.add(grid[i][j])
                    sum_val += grid[i][j]
                tag = sum_val * 2 - total
                if i == 0:
                    if tag == 0 or tag == grid[0][0] or tag == grid[0][n - 1]:
                        return True
                    continue
                if tag in exist:
                    return True
            grid = self.rotation(grid)
        return False

    def rotation(self, grid: List[List[int]]) -> List[List[int]]:
        m = len(grid)
        n = len(grid[0])
        tmp = [[0] * m for _ in range(n)]
        for i in range(m):
            for j in range(n):
                tmp[j][m - 1 - i] = grid[i][j]
        return tmp

###C

typedef struct {
    long long key;
    UT_hash_handle hh;
} HashItem;

static inline HashItem* hashFindItem(HashItem **obj, long long key) {
    HashItem *pEntry = NULL;
    HASH_FIND(hh, *obj, &key, sizeof(key), pEntry);
    return pEntry;
}

bool hashAddItem(HashItem **obj, long long key) {
    if (hashFindItem(obj, key)) {
        return false;
    }
    HashItem *pEntry = malloc(sizeof(HashItem));
    if (!pEntry) return false;
    pEntry->key = key;
    HASH_ADD(hh, *obj, key, sizeof(key), pEntry);
    return true;
}

void hashFree(HashItem **obj) {
    HashItem *curr, *tmp;
    HASH_ITER(hh, *obj, curr, tmp) {
        HASH_DEL(*obj, curr);
        free(curr);
    }
    *obj = NULL;
}


int** rotation(int** grid, int m, int n, int* newM, int* newN) {
    *newM = n;
    *newN = m;
    int** tmp = malloc(n * sizeof(int*));
    int* data = malloc(n * m * sizeof(int));
    if (!tmp || !data) {
        free(tmp);
        free(data);
        return NULL;
    }
    for (int i = 0; i < n; i++) {
        tmp[i] = data + i * m;
    }
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            tmp[j][m - 1 - i] = grid[i][j];
        }
    }
    return tmp;
}

void freeGrid(int** grid, int rows) {
    if (grid && grid[0]) {
        free(grid[0]);
    }
    free(grid);
}

static inline bool checkAndReturnTrue(HashItem **exist, int** currentGrid, int currentM, int** originalGrid) {
    hashFree(exist);
    if (currentGrid != originalGrid) {
        freeGrid(currentGrid, currentM);
    }
    return true;
}

static inline void rotateAndUpdate(int** *currentGrid, int *currentM, int *currentN, int** originalGrid) {
    int newM, newN;
    int** newGrid = rotation(*currentGrid, *currentM, *currentN, &newM, &newN);
    if (*currentGrid != originalGrid) {
        freeGrid(*currentGrid, *currentM);
    }
    *currentGrid = newGrid;
    *currentM = newM;
    *currentN = newN;
}

bool canPartitionGrid(int** grid, int gridSize, int* gridColSize) {
    const int m = gridSize;
    const int n = gridColSize[0];
    long long total = 0;

    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            total += grid[i][j];
        }
    }
    int currentM = m, currentN = n;
    int** currentGrid = grid;

    for (int k = 0; k < 4; k++) {
        HashItem* exist = NULL;
        hashAddItem(&exist, 0LL);
        long long sum = 0;
        if (currentM < 2 || currentN == 1) {
            if (currentN == 1 && currentM >= 2) {
                for (int i = 0; i < currentM - 1; i++) {
                    sum += currentGrid[i][0];
                    long long tag = sum * 2 - total;
                    if (tag == 0 || tag == currentGrid[0][0] || tag == currentGrid[i][0]) {
                        return checkAndReturnTrue(&exist, currentGrid, currentM, grid);
                    }
                }
            }
            rotateAndUpdate(&currentGrid, &currentM, &currentN, grid);
            hashFree(&exist);
            continue;
        }

        for (int i = 0; i < currentM - 1; i++) {
            for (int j = 0; j < currentN; j++) {
                hashAddItem(&exist, (long long)currentGrid[i][j]);
                sum += currentGrid[i][j];
            }
            long long tag = sum * 2 - total;
            if (i == 0) {
                if (tag == 0 || tag == currentGrid[0][0] || tag == currentGrid[0][currentN - 1]) {
                    return checkAndReturnTrue(&exist, currentGrid, currentM, grid);
                }
                continue;
            }
            if (hashFindItem(&exist, tag)) {
                return checkAndReturnTrue(&exist, currentGrid, currentM, grid);
            }
        }

        rotateAndUpdate(&currentGrid, &currentM, &currentN, grid);
        hashFree(&exist);
    }

    if (currentGrid != grid) {
        freeGrid(currentGrid, currentM);
    }

    return false;
}

###Rust

use std::collections::HashSet;

impl Solution {
    fn rotation(grid: &Vec<Vec<i32>>) -> Vec<Vec<i32>> {
        let m = grid.len();
        let n = grid[0].len();
        let mut tmp = vec![vec![0; m]; n];

        for i in 0..m {
            for j in 0..n {
                tmp[j][m - 1 - i] = grid[i][j];
            }
        }
        tmp
    }

    pub fn can_partition_grid(grid: Vec<Vec<i32>>) -> bool {
        let mut grid = grid;
        let mut total: i64 = 0;
        let mut sum: i64;
        let mut tag: i64;
        let mut m = grid.len();
        let mut n = grid[0].len();
        for i in 0..m {
            for j in 0..n {
                total += grid[i][j] as i64;
            }
        }

        let mut exist = HashSet::new();

        for _ in 0..4 {
            exist.clear();
            exist.insert(0);
            sum = 0;
            m = grid.len();
            n = grid[0].len();
            if m < 2 {
                grid = Self::rotation(&grid);
                continue;
            }
            if n == 1 {
                for i in 0..m - 1 {
                    sum += grid[i][0] as i64;
                    tag = sum * 2 - total;
                    if tag == 0 || tag == grid[0][0] as i64 || tag == grid[i][0] as i64 {
                        return true;
                    }
                }
                grid = Self::rotation(&grid);
                continue;
            }

            for i in 0..m - 1 {
                for j in 0..n {
                    exist.insert(grid[i][j] as i64);
                    sum += grid[i][j] as i64;
                }

                tag = sum * 2 - total;

                if i == 0 {
                    if tag == 0 || tag == grid[0][0] as i64 || tag == grid[0][n - 1] as i64 {
                        return true;
                    }
                    continue;
                }

                if exist.contains(&tag) {
                    return true;
                }
            }

            grid = Self::rotation(&grid);
        }

        false
    }
}

复杂度分析

  • 时间复杂度:$O(mn)$,其中 $m$ 为 $\textit{grid}$ 矩阵的行数,$n$ 为 $\textit{grid}$ 矩阵的列数。

  • 空间复杂度:$O(mn)$,其中 $m$ 为 $\textit{grid}$ 矩阵的行数,$n$ 为 $\textit{grid}$ 矩阵的列数。

矩阵的最大非负积

方法一:动态规划

思路与算法

由于矩阵中的元素有正有负,要想得到最大积,我们只存储移动过程中的最大积是不够的,例如当前的最大积为正数时,乘上一个负数后,反而不如一个负数乘上相同的负数得到的积大。

因此,我们需要存储的是移动过程中的积的范围,也就是积的最小值以及最大值。由于只能向下或者向右走,我们可以考虑使用动态规划的方法解决本题。

设 $\textit{maxgt}[i][j], \textit{minlt}[i][j]$ 分别为从坐标 $(0, 0)$ 出发,到达位置 $(i, j)$ 时乘积的最大值与最小值。由于我们只能向下或者向右走,因此乘积的取值必然只与 $(i, j-1)$ 和 $(i-1, j)$ 两个位置有关。

对于乘积的最大值而言:若 $\textit{grid}[i][j] \ge 0$,则 $\textit{maxgt}[i][j]$ 的取值取决于这两个位置的最大值,此时

$$
\textit{maxgt}[i][j] = \max(\textit{maxgt}[i][j-1], \textit{maxgt}[i-1][j]) \times \textit{grid}[i][j]
$$

相反地,若 $\textit{grid}[i][j] \le 0$,则 $\textit{maxgt}[i][j]$ 的取值取决于这两个位置的最小值,此时

$$
\textit{maxgt}[i][j] = \min(\textit{minlt}[i][j-1], \textit{minlt}[i-1][j]) \times \textit{grid}[i][j]
$$

计算乘积的最小值也是类似的思路。若 $\textit{grid}[i][j] \ge 0$,此时

$$
\textit{mingt}[i][j] = \min(\textit{mingt}[i][j-1], \textit{mingt}[i-1][j]) \times \textit{grid}[i][j]
$$

若 $\textit{grid}[i][j] \le 0$,此时

$$
\textit{mingt}[i][j] = \max(\textit{maxgt}[i][j-1], \textit{maxgt}[i-1][j]) \times \textit{grid}[i][j]
$$

特别地,当 $i=0$ 时,只需要从 $(i, j-1)$ 进行转移;$j=0$ 时,只需要从 $(i-1, j)$ 进行转移;$i=0$ 且 $j=0$ 时,$\textit{maxgt}[i][j]$ 与 $\textit{mingt}[i][j]$ 的值均为左上角的元素值 $\textit{grid}[i][j]$。

最终的答案即为 $\textit{maxgt}[m-1][n-1]$,其中 $m$ 和 $n$ 分别是矩阵的行数与列数。

代码

###C++

class Solution {
public:
    int maxProductPath(vector<vector<int>>& grid) {
        const int mod = 1000000000 + 7;
        int m = grid.size(), n = grid[0].size();
        vector<vector<long long>> maxgt(m, vector<long long>(n));
        vector<vector<long long>> minlt(m, vector<long long>(n));

        maxgt[0][0] = minlt[0][0] = grid[0][0];
        for (int i = 1; i < n; i++) {
            maxgt[0][i] = minlt[0][i] = maxgt[0][i - 1] * grid[0][i];
        }
        for (int i = 1; i < m; i++) {
            maxgt[i][0] = minlt[i][0] = maxgt[i - 1][0] * grid[i][0];
        }

        for (int i = 1; i < m; i++) {
            for (int j = 1; j < n; j++) {
                if (grid[i][j] >= 0) {
                    maxgt[i][j] = max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
                    minlt[i][j] = min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
                } else {
                    maxgt[i][j] = min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
                    minlt[i][j] = max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
                }
            }
        }
        if (maxgt[m - 1][n - 1] < 0) {
            return -1;
        } else {
            return maxgt[m - 1][n - 1] % mod;
        }
    }
};

###Java

class Solution {
    public int maxProductPath(int[][] grid) {
        final int MOD = 1000000000 + 7;
        int m = grid.length, n = grid[0].length;
        long[][] maxgt = new long[m][n];
        long[][] minlt = new long[m][n];

        maxgt[0][0] = minlt[0][0] = grid[0][0];
        for (int i = 1; i < n; i++) {
            maxgt[0][i] = minlt[0][i] = maxgt[0][i - 1] * grid[0][i];
        }
        for (int i = 1; i < m; i++) {
            maxgt[i][0] = minlt[i][0] = maxgt[i - 1][0] * grid[i][0];
        }

        for (int i = 1; i < m; i++) {
            for (int j = 1; j < n; j++) {
                if (grid[i][j] >= 0) {
                    maxgt[i][j] = Math.max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
                    minlt[i][j] = Math.min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
                } else {
                    maxgt[i][j] = Math.min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
                    minlt[i][j] = Math.max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
                }
            }
        }
        if (maxgt[m - 1][n - 1] < 0) {
            return -1;
        } else {
            return (int) (maxgt[m - 1][n - 1] % MOD);
        }
    }
}

###Python

class Solution:
    def maxProductPath(self, grid: List[List[int]]) -> int:
        mod = 10**9 + 7
        m, n = len(grid), len(grid[0])
        maxgt = [[0] * n for _ in range(m)]
        minlt = [[0] * n for _ in range(m)]

        maxgt[0][0] = minlt[0][0] = grid[0][0]
        for i in range(1, n):
            maxgt[0][i] = minlt[0][i] = maxgt[0][i - 1] * grid[0][i]
        for i in range(1, m):
            maxgt[i][0] = minlt[i][0] = maxgt[i - 1][0] * grid[i][0]
        
        for i in range(1, m):
            for j in range(1, n):
                if grid[i][j] >= 0:
                    maxgt[i][j] = max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j]
                    minlt[i][j] = min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j]
                else:
                    maxgt[i][j] = min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j]
                    minlt[i][j] = max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j]
        
        if maxgt[m - 1][n - 1] < 0:
            return -1
        return maxgt[m - 1][n - 1] % mod

###C#

public class Solution {
    public int MaxProductPath(int[][] grid) {
        const int MOD = 1000000007;
        int m = grid.Length, n = grid[0].Length;
        long[,] maxgt = new long[m, n];
        long[,] minlt = new long[m, n];

        maxgt[0, 0] = minlt[0, 0] = grid[0][0];
        for (int i = 1; i < n; i++) {
            maxgt[0, i] = minlt[0, i] = maxgt[0, i - 1] * grid[0][i];
        }
        for (int i = 1; i < m; i++) {
            maxgt[i, 0] = minlt[i, 0] = maxgt[i - 1, 0] * grid[i][0];
        }
        for (int i = 1; i < m; i++) {
            for (int j = 1; j < n; j++) {
                if (grid[i][j] >= 0) {
                    long maxPrev = Math.Max(maxgt[i, j - 1], maxgt[i - 1, j]);
                    long minPrev = Math.Min(minlt[i, j - 1], minlt[i - 1, j]);
                    maxgt[i, j] = maxPrev * grid[i][j];
                    minlt[i, j] = minPrev * grid[i][j];
                } else {
                    long maxPrev = Math.Max(maxgt[i, j - 1], maxgt[i - 1, j]);
                    long minPrev = Math.Min(minlt[i, j - 1], minlt[i - 1, j]);
                    maxgt[i, j] = minPrev * grid[i][j];
                    minlt[i, j] = maxPrev * grid[i][j];
                }
            }
        }
        
        if (maxgt[m - 1, n - 1] < 0) {
            return -1;
        } else {
            return (int)(maxgt[m - 1, n - 1] % MOD);
        }
    }
}

###Go

func maxProductPath(grid [][]int) int {
    const MOD = 1000000007
    m, n := len(grid), len(grid[0])
    
    maxgt := make([][]int64, m)
    minlt := make([][]int64, m)
    for i := range maxgt {
        maxgt[i] = make([]int64, n)
        minlt[i] = make([]int64, n)
    }
    
    maxgt[0][0] = int64(grid[0][0])
    minlt[0][0] = int64(grid[0][0])
    for i := 1; i < n; i++ {
        maxgt[0][i] = maxgt[0][i-1] * int64(grid[0][i])
        minlt[0][i] = maxgt[0][i]
    }
    for i := 1; i < m; i++ {
        maxgt[i][0] = maxgt[i-1][0] * int64(grid[i][0])
        minlt[i][0] = maxgt[i][0]
    }
    for i := 1; i < m; i++ {
        for j := 1; j < n; j++ {
            if grid[i][j] >= 0 {
                maxPrev := max(maxgt[i][j-1], maxgt[i-1][j])
                minPrev := min(minlt[i][j-1], minlt[i-1][j])
                maxgt[i][j] = maxPrev * int64(grid[i][j])
                minlt[i][j] = minPrev * int64(grid[i][j])
            } else {
                maxPrev := max(maxgt[i][j-1], maxgt[i-1][j])
                minPrev := min(minlt[i][j-1], minlt[i-1][j])
                maxgt[i][j] = minPrev * int64(grid[i][j])
                minlt[i][j] = maxPrev * int64(grid[i][j])
            }
        }
    }
    
    if maxgt[m-1][n-1] < 0 {
        return -1
    }
    return int(maxgt[m-1][n-1] % MOD)
}

###C

#define MOD 1000000007

int maxProductPath(int** grid, int gridSize, int* gridColSize) {
    int m = gridSize, n = gridColSize[0];
    long long** maxgt = (long long**)malloc(m * sizeof(long long*));
    long long** minlt = (long long**)malloc(m * sizeof(long long*));
    for (int i = 0; i < m; i++) {
        maxgt[i] = (long long*)malloc(n * sizeof(long long));
        minlt[i] = (long long*)malloc(n * sizeof(long long));
    }
    
    maxgt[0][0] = minlt[0][0] = grid[0][0];
    for (int i = 1; i < n; i++) {
        maxgt[0][i] = minlt[0][i] = maxgt[0][i - 1] * grid[0][i];
    }
    for (int i = 1; i < m; i++) {
        maxgt[i][0] = minlt[i][0] = maxgt[i - 1][0] * grid[i][0];
    }
    for (int i = 1; i < m; i++) {
        for (int j = 1; j < n; j++) {
            if (grid[i][j] >= 0) {
                maxgt[i][j] = fmax(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
                minlt[i][j] = fmin(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
            } else {
                maxgt[i][j] = fmin(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
                minlt[i][j] = fmax(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
            }
        }
    }
    
    long long result = maxgt[m - 1][n - 1];
    for (int i = 0; i < m; i++) {
        free(maxgt[i]);
        free(minlt[i]);
    }
    free(maxgt);
    free(minlt);
    
    if (result < 0) {
        return -1;
    } else {
        return result % MOD;
    }
}

###JavaScript

var maxProductPath = function(grid) {
    const MOD = 1000000007;
    const m = grid.length, n = grid[0].length;
    const maxgt = new Array(m).fill(0).map(() => new Array(n).fill(0));
    const minlt = new Array(m).fill(0).map(() => new Array(n).fill(0));
    
    maxgt[0][0] = minlt[0][0] = grid[0][0];
    for (let i = 1; i < n; i++) {
        maxgt[0][i] = minlt[0][i] = maxgt[0][i - 1] * grid[0][i];
    }
    for (let i = 1; i < m; i++) {
        maxgt[i][0] = minlt[i][0] = maxgt[i - 1][0] * grid[i][0];
    }
    for (let i = 1; i < m; i++) {
        for (let j = 1; j < n; j++) {
            if (grid[i][j] >= 0) {
                maxgt[i][j] = Math.max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
                minlt[i][j] = Math.min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
            } else {
                maxgt[i][j] = Math.min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
                minlt[i][j] = Math.max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
            }
        }
    }
    
    if (maxgt[m - 1][n - 1] < 0) {
        return -1;
    } else {
        return maxgt[m - 1][n - 1] % MOD;
    }
};

###TypeScript

function maxProductPath(grid: number[][]): number {
    const MOD = 1000000007;
    const m = grid.length, n = grid[0].length;
    
    const maxgt: number[][] = new Array(m).fill(0).map(() => new Array(n).fill(0));
    const minlt: number[][] = new Array(m).fill(0).map(() => new Array(n).fill(0));
    
    maxgt[0][0] = minlt[0][0] = grid[0][0];
    for (let i = 1; i < n; i++) {
        maxgt[0][i] = minlt[0][i] = maxgt[0][i - 1] * grid[0][i];
    }
    for (let i = 1; i < m; i++) {
        maxgt[i][0] = minlt[i][0] = maxgt[i - 1][0] * grid[i][0];
    }
    for (let i = 1; i < m; i++) {
        for (let j = 1; j < n; j++) {
            if (grid[i][j] >= 0) {
                maxgt[i][j] = Math.max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
                minlt[i][j] = Math.min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
            } else {
                maxgt[i][j] = Math.min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
                minlt[i][j] = Math.max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
            }
        }
    }
    
    if (maxgt[m - 1][n - 1] < 0) {
        return -1;
    } else {
        return maxgt[m - 1][n - 1] % MOD;
    }
}

###Rust

impl Solution {
    pub fn max_product_path(grid: Vec<Vec<i32>>) -> i32 {
        const MOD: i64 = 1_000_000_007;
        let m = grid.len();
        let n = grid[0].len();
        let mut maxgt = vec![vec![0i64; n]; m];
        let mut minlt = vec![vec![0i64; n]; m];
        maxgt[0][0] = grid[0][0] as i64;
        minlt[0][0] = grid[0][0] as i64;
        
        for i in 1..n {
            maxgt[0][i] = maxgt[0][i-1] * grid[0][i] as i64;
            minlt[0][i] = maxgt[0][i];
        }
        for i in 1..m {
            maxgt[i][0] = maxgt[i-1][0] * grid[i][0] as i64;
            minlt[i][0] = maxgt[i][0];
        }
        for i in 1..m {
            for j in 1..n {
                let grid_val = grid[i][j] as i64;
                if grid_val >= 0 {
                    let max_prev = maxgt[i][j-1].max(maxgt[i-1][j]);
                    let min_prev = minlt[i][j-1].min(minlt[i-1][j]);
                    maxgt[i][j] = max_prev * grid_val;
                    minlt[i][j] = min_prev * grid_val;
                } else {
                    let max_prev = maxgt[i][j-1].max(maxgt[i-1][j]);
                    let min_prev = minlt[i][j-1].min(minlt[i-1][j]);
                    maxgt[i][j] = min_prev * grid_val;
                    minlt[i][j] = max_prev * grid_val;
                }
            }
        }
        
        let result = maxgt[m-1][n-1];
        if result < 0 {
            -1
        } else {
            (result % MOD) as i32
        }
    }
}

复杂度分析

  • 时间复杂度:$O(mn)$,其中 $m$ 和 $n$ 为矩阵的行数与列数。我们需要遍历矩阵的每一个元素,而处理每个元素时只需要常数时间。

  • 空间复杂度:$O(mn)$。我们开辟了两个与原矩阵等大的数组。

元素和小于等于 k 的子矩阵的数目

方法一:二维前缀和

思路与算法

题目要求统计包含矩阵 $\textit{grid}$ 左上角元素的所有子矩阵中,元素和不超过 $k$ 的子矩阵个数。

我们从左上角出发,按行优先顺序遍历矩阵,将当前访问位置 $(i, j)$ 视为子矩阵的右下角。为了在一次遍历中高效计算子矩阵的和,我们维护一个数组 $\textit{cols}[j]$,用于记录当前行之前第 $j$ 列所有元素的和。在遍历第 $i$ 行时,按列从左到右遍历 $j$,将 $\textit{grid}[i][j]$ 累加到 $\textit{cols}[j]$后,并将 $\textit{cols}[j]$ 累加起来,若当前累加和 $\le k$,则答案加 $1$。

代码

###C++

class Solution {
public:
    int countSubmatrices(vector<vector<int>>& grid, int k) {
        int n = grid.size(), m = grid[0].size();
        vector<int> cols(m);
        int res = 0;
        for (int i = 0; i < n; i++) {
            int rows = 0;
            for (int j = 0; j < m; j++) {
                cols[j] += grid[i][j];
                rows += cols[j];
                if (rows <= k) {
                    res++;
                }
            }
        }
        return res;
    }
};

###Python

class Solution:
    def countSubmatrices(self, grid: List[List[int]], k: int) -> int:
        n, m = len(grid), len(grid[0])
        cols = [0] * m
        res = 0
        
        for i in range(n):
            row_sum = 0
            for j in range(m):
                cols[j] += grid[i][j]
                row_sum += cols[j]
                if row_sum <= k:
                    res += 1
        
        return res

###Rust

impl Solution {
    pub fn count_submatrices(grid: Vec<Vec<i32>>, k: i32) -> i32 {
        let n = grid.len();
        let m = grid[0].len();
        let mut cols = vec![0; m];
        let mut res = 0;
        
        for i in 0..n {
            let mut row_sum = 0;
            for j in 0..m {
                cols[j] += grid[i][j];
                row_sum += cols[j];
                if row_sum <= k {
                    res += 1;
                }
            }
        }
        
        res
    }
}

###Java

class Solution {
    public int countSubmatrices(int[][] grid, int k) {
        int n = grid.length, m = grid[0].length;
        int[] cols = new int[m];
        int res = 0;
        
        for (int i = 0; i < n; i++) {
            int rows = 0;
            for (int j = 0; j < m; j++) {
                cols[j] += grid[i][j];
                rows += cols[j];
                if (rows <= k) {
                    res++;
                }
            }
        }
        
        return res;
    }
}

###C#

public class Solution {
    public int CountSubmatrices(int[][] grid, int k) {
        int n = grid.Length, m = grid[0].Length;
        int[] cols = new int[m];
        int res = 0;
        
        for (int i = 0; i < n; i++) {
            int rows = 0;
            for (int j = 0; j < m; j++) {
                cols[j] += grid[i][j];
                rows += cols[j];
                if (rows <= k) {
                    res++;
                }
            }
        }
        
        return res;
    }
}

###Go

func countSubmatrices(grid [][]int, k int) int {
    n := len(grid)
    m := len(grid[0])
    cols := make([]int, m)
    res := 0
    
    for i := 0; i < n; i++ {
        rows := 0
        for j := 0; j < m; j++ {
            cols[j] += grid[i][j]
            rows += cols[j]
            if rows <= k {
                res++
            }
        }
    }
    
    return res
}

###C

int countSubmatrices(int** grid, int gridSize, int* gridColSize, int k) {
    int n = gridSize;
    int m = *gridColSize;
    int* cols = (int*)calloc(m, sizeof(int));
    int res = 0;
    
    for (int i = 0; i < n; i++) {
        int rows = 0;
        for (int j = 0; j < m; j++) {
            cols[j] += grid[i][j];
            rows += cols[j];
            if (rows <= k) {
                res++;
            }
        }
    }
    
    free(cols);
    return res;
}

###JavaScript

var countSubmatrices = function(grid, k) {
    const n = grid.length;
    const m = grid[0].length;
    const cols = new Array(m).fill(0);
    let res = 0;
    
    for (let i = 0; i < n; i++) {
        let rows = 0;
        for (let j = 0; j < m; j++) {
            cols[j] += grid[i][j];
            rows += cols[j];
            if (rows <= k) {
                res++;
            }
        }
    }
    
    return res;
};

###TypeScript

function countSubmatrices(grid: number[][], k: number): number {
    const n: number = grid.length;
    const m: number = grid[0].length;
    const cols: number[] = new Array(m).fill(0);
    let res: number = 0;
    
    for (let i = 0; i < n; i++) {
        let rows: number = 0;
        for (let j = 0; j < m; j++) {
            cols[j] += grid[i][j];
            rows += cols[j];
            if (rows <= k) {
                res++;
            }
        }
    }
    
    return res;
}

复杂度分析

  • 时间复杂度:$O(mn)$,其中 $m$ 是 $\textit{grid}$ 的行数,$n$ 是 $\textit{grid}$ 的列数。

  • 空间复杂度:$O(n)$。

移山所需的最少秒数

方法一:二分答案

思路与算法

根据题目描述,如果 $t$ 秒可以使山的高度降低到 $0$,那么任何大于 $t$ 的秒数也可以。因此答案具有单调性,我们可以使用二分查找来解决本题。

对于二分查找的每一步,假设当前猜测的秒数为 $\textit{mid}$,我们需要判断所有工人在 $\textit{mid}$ 秒内能否将山的高度降低 $H = \textit{mountainHeight}$。对于第 $i$ 个工人,他将山的高度降低 $k$ 所需的时间为:

$$
\textit{workerTimes}[i] \cdot (1 + 2 + \cdots + k) = \textit{workerTimes}[i] \cdot \frac{k(k+1)}{2}
$$

因此在 $\textit{mid}$ 秒内,第 $i$ 个工人 $i$ 最多能将山降低的高度,是满足

$$\textit{workerTimes}[i] \cdot \frac{k(k+1)}{2} \leq \textit{mid}$$

的最大正整数 $k$。

令 $\textit{work} = \lfloor \dfrac{\textit{mid}}{\textit{workerTimes}[i]} \rfloor$,其中 $\lfloor \cdot \rfloor$ 表示下取整,则需要满足 $\dfrac{k(k+1)}{2} \leq \textit{work}$,利用求一元二次方程求根公式可得:

$$
k = \left\lfloor \frac{-1 + \sqrt{1 + 8 \cdot \textit{work}}}{2} \right\rfloor
$$

我们将所有工人计算得到的 $k$ 值相加,如果总和大于等于 $H$,则说明 $\textit{mid}$ 秒可以完成任务,应当尝试更少的时间,否则尝试更多的时间。

二分查找的下界为 $1$,上界为 $\max(\textit{workerTimes}) \cdot \dfrac{H(H + 1)}{2}$,即最慢的工人独自完成所有工作所需的时间。

代码

###C++

class Solution {
public:
    long long minNumberOfSeconds(int mountainHeight, vector<int>& workerTimes) {
        int maxWorkerTimes = *max_element(workerTimes.begin(), workerTimes.end());
        long long l = 1, r = static_cast<long long>(maxWorkerTimes) * mountainHeight * (mountainHeight + 1) / 2;
        long long ans = 0;

        while (l <= r) {
            long long mid = (l + r) / 2;
            long long cnt = 0;
            for (int t: workerTimes) {
                long long work = mid / t;
                // 求最大的 k 满足 1+2+...+k <= work
                long long k = (-1.0 + sqrt(1 + work * 8)) / 2 + eps;
                cnt += k;
            }
            if (cnt >= mountainHeight) {
                ans = mid;
                r = mid - 1;
            }
            else {
                l = mid + 1;
            }
        }

        return ans;
    }

private:
    static constexpr double eps = 1e-7;
};

###Python

class Solution:
    def minNumberOfSeconds(self, mountainHeight: int, workerTimes: List[int]) -> int:
        maxWorkerTimes = max(workerTimes)
        l, r, ans = 1, maxWorkerTimes * mountainHeight * (mountainHeight + 1) // 2, 0
        eps = 1e-7
        
        while l <= r:
            mid = (l + r) // 2
            cnt = 0
            for t in workerTimes:
                work = mid // t
                # 求最大的 k 满足 1+2+...+k <= work
                k = int((-1 + ((1 + work * 8) ** 0.5)) / 2 + eps)
                cnt += k
            if cnt >= mountainHeight:
                ans = mid
                r = mid - 1
            else:
                l = mid + 1

        return ans

###Java

class Solution {
    private static final double EPS = 1e-7;
    
    public long minNumberOfSeconds(int mountainHeight, int[] workerTimes) {
        int maxWorkerTimes = 0;
        for (int t : workerTimes) {
            maxWorkerTimes = Math.max(maxWorkerTimes, t);
        }
        
        long l = 1;
        long r = (long) maxWorkerTimes * mountainHeight * (mountainHeight + 1) / 2;
        long ans = 0;
        
        while (l <= r) {
            long mid = (l + r) / 2;
            long cnt = 0;
            for (int t : workerTimes) {
                long work = mid / t;
                // 求最大的 k 满足 1+2+...+k <= work
                long k = (long)((-1.0 + Math.sqrt(1 + work * 8)) / 2 + EPS);
                cnt += k;
            }
            
            if (cnt >= mountainHeight) {
                ans = mid;
                r = mid - 1;
            } else {
                l = mid + 1;
            }
        }
        
        return ans;
    }
}

###C#

class Solution {
    private const double EPS = 1e-7;
    
    public long MinNumberOfSeconds(int mountainHeight, int[] workerTimes) {
        int maxWorkerTimes = 0;
        foreach (int t in workerTimes) {
            maxWorkerTimes = Math.Max(maxWorkerTimes, t);
        }
        
        long l = 1;
        long r = (long)maxWorkerTimes * mountainHeight * (mountainHeight + 1) / 2;
        long ans = 0;
        
        while (l <= r) {
            long mid = (l + r) / 2;
            long cnt = 0;
            
            foreach (int t in workerTimes) {
                long work = mid / t;
                // 求最大的 k 满足 1+2+...+k <= work
                long k = (long)((-1.0 + Math.Sqrt(1 + work * 8)) / 2 + EPS);
                cnt += k;
            }
            
            if (cnt >= mountainHeight) {
                ans = mid;
                r = mid - 1;
            } else {
                l = mid + 1;
            }
        }
        
        return ans;
    }
}

###Go

const eps = 1e-7

func minNumberOfSeconds(mountainHeight int, workerTimes []int) int64 {
    maxWorkerTimes := 0
    for _, t := range workerTimes {
        if t > maxWorkerTimes {
            maxWorkerTimes = t
        }
    }
    
    l := int64(1)
    r := int64(maxWorkerTimes) * int64(mountainHeight) * int64(mountainHeight + 1) / 2
    var ans int64 = 0
    
    for l <= r {
        mid := (l + r) / 2
        var cnt int64 = 0
        
        for _, t := range workerTimes {
            work := mid / int64(t)
            // 求最大的 k 满足 1+2+...+k <= work
            k := int64((-1.0 + math.Sqrt(1 + float64(work) * 8)) / 2 + eps)
            cnt += k
        }
        if cnt >= int64(mountainHeight) {
            ans = mid
            r = mid - 1
        } else {
            l = mid + 1
        }
    }
    
    return ans
}

###C

#define EPS 1e-7

long long minNumberOfSeconds(int mountainHeight, int* workerTimes, int workerTimesSize) {
    int maxWorkerTimes = 0;
    for (int i = 0; i < workerTimesSize; i++) {
        if (workerTimes[i] > maxWorkerTimes) {
            maxWorkerTimes = workerTimes[i];
        }
    }
    
    long long l = 1;
    long long r = (long long)maxWorkerTimes * mountainHeight * (mountainHeight + 1) / 2;
    long long ans = 0;
    
    while (l <= r) {
        long long mid = (l + r) / 2;
        long long cnt = 0;
        for (int i = 0; i < workerTimesSize; i++) {
            long long work = mid / workerTimes[i];
            // 求最大的 k 满足 1+2+...+k <= work
            long long k = (long long)((-1.0 + sqrt(1 + work * 8)) / 2 + EPS);
            cnt += k;
        }
        
        if (cnt >= mountainHeight) {
            ans = mid;
            r = mid - 1;
        } else {
            l = mid + 1;
        }
    }
    
    return ans;
}

###JavaScript

const EPS = 1e-7;

var minNumberOfSeconds = function(mountainHeight, workerTimes) {
    const maxWorkerTimes = Math.max(...workerTimes);
    let l = 1;
    let r = maxWorkerTimes * mountainHeight * (mountainHeight + 1) / 2;
    let ans = 0;
    
    while (l <= r) {
        const mid = Math.floor((l + r) / 2);
        let cnt = 0;
        for (const t of workerTimes) {
            const work = Math.floor(mid / t);
            // 求最大的 k 满足 1+2+...+k <= work
            const k = Math.floor((-1.0 + Math.sqrt(1 + work * 8)) / 2 + EPS);
            cnt += k;
        }
        
        if (cnt >= mountainHeight) {
            ans = mid;
            r = mid - 1;
        } else {
            l = mid + 1;
        }
    }
    
    return ans;
}

###TypeScript

const EPS: number = 1e-7;

function minNumberOfSeconds(mountainHeight: number, workerTimes: number[]): number {
    const maxWorkerTimes: number = Math.max(...workerTimes);
    let l: number = 1;
    let r: number = maxWorkerTimes * mountainHeight * (mountainHeight + 1) / 2;
    let ans: number = 0;
    
    while (l <= r) {
        const mid: number = Math.floor((l + r) / 2);
        let cnt: number = 0;
        for (const t of workerTimes) {
            const work: number = Math.floor(mid / t);
            // 求最大的 k 满足 1+2+...+k <= work
            const k: number = Math.floor((-1.0 + Math.sqrt(1 + work * 8)) / 2 + EPS);
            cnt += k;
        }
        
        if (cnt >= mountainHeight) {
            ans = mid;
            r = mid - 1;
        } else {
            l = mid + 1;
        }
    }
    
    return ans;
}

###Rust

const EPS: f64 = 1e-7;

impl Solution {
    pub fn min_number_of_seconds(mountain_height: i32, worker_times: Vec<i32>) -> i64 {
        let mountain_height = mountain_height as i64;
        let max_worker_times = *worker_times.iter().max().unwrap_or(&0) as i64;
        
        let mut l: i64 = 1;
        let mut r: i64 = max_worker_times * mountain_height * (mountain_height + 1) / 2;
        let mut ans: i64 = 0;
        
        while l <= r {
            let mid = (l + r) / 2;
            let mut cnt: i64 = 0;
            
            for &t in &worker_times {
                let work = mid / t as i64;
                // 求最大的 k 满足 1+2+...+k <= work
                let k = ((-1.0 + (1.0 + (work as f64) * 8.0).sqrt()) / 2.0 + EPS) as i64;
                cnt += k;
            }
            
            if cnt >= mountain_height {
                ans = mid;
                r = mid - 1;
            } else {
                l = mid + 1;
            }
        }
        
        ans
    }
}

复杂度分析

  • 时间复杂度:$O(n \log(MH^2))$,其中 $n$ 是数组 $\textit{workerTimes}$ 的长度,$M$ 是数组 $\textit{workerTimes}$ 中的最大值,$H$ 是 $\textit{mountainHeight}$。二分查找需要 $O(\log(MH^2))$ 次迭代,每次迭代遍历所有工人,需要 $O(n)$ 的时间。

  • 空间复杂度:$O(1)$。

找出所有稳定的二进制数组 II

方法一:记忆化搜索

思路

根据稳定数组的前两个条件,可知稳定数组的长度为 $\textit{zero} + \textit{one}$。第三个条件可知,稳定数组不存在长度为 $\textit{limit} + 1$ 的全 $0$ 或全 $1$ 子数组。

接下来我们分解问题,包含 $\textit{zero}$ 个 $0$ 和 $\textit{one}$ 个 $1$ 的稳定数组,末位元素可能为 $1$,也可能为 $0$。

  • 如果末位元素为 $1$,我们需要知道有多少个包含 $\textit{zero}$ 个 $0$ 和 $\textit{one}-1$ 个 $1$ 的稳定数组,再去掉“由于添加了一个 $1$ 而使得原来的稳定数组变得不稳定”的情况。那么有哪些情况会使得原来稳定的数组变得不稳定呢?即原来的稳定数组的末尾连续 $1$ 的个数正好为 $\textit{limit}$ 个。在这种情况下,添加一个 $1$ 会使得原来稳定的数组变得不稳定。这种情况出现的次数,即为包含 $\textit{zero}$ 个 $0$ 和 $\textit{one}-1-\textit{limit}$ 个 $1$,且末位元素为 $0$ 的稳定数组的个数。
  • 如果末位元素为 $0$,我们需要知道有多少个包含 $\textit{zero}-1$ 个 $0$ 和 $\textit{one}$ 个 $1$ 的稳定数组,再去掉“由于添加了一个 $0$ 而使得原来的稳定数组变得不稳定”的情况。

这样一来,我们就将问题分解为子问题了,可以用动态规划求解。用函数 $\textit{dp}(\textit{zero},\textit{one},\textit{lastBit})$,来求解包含 $\textit{zero}$ 个 $0$ 和 $\textit{one}$ 个 $1$,并且末位元素为 $\textit{lastBit}$ 的稳定数组的个数,其中 $\textit{lastBit}$ 为 $0$ 或 $1$。根据上面的讨论,可以得到递推公式:

  • $\textit{dp}(\textit{zero},\textit{one},0)$ = $\textit{dp}(\textit{zero}-1,\textit{one},0) + \textit{dp}(\textit{zero}-1,\textit{one},1) - \textit{dp}(\textit{zero}-1-\textit{limit},\textit{one},1)$
  • $\textit{dp}(\textit{zero},\textit{one},1)$ = $\textit{dp}(\textit{zero},\textit{one}-1,0) + \textit{dp}(\textit{zero},\textit{one}-1,1) - \textit{dp}(\textit{zero},\textit{one}-1-\textit{limit},0)$。

另外考虑边界情况。如果 $\textit{zero}$ 为 $0$,那么当 $\textit{lastBit}$ 为 $1$ 或者 $\textit{one}$ 大于 $\textit{limit}$ 时,不存在这样的稳定数组,返回 $0$,否则返回 $1$。如果 $\textit{zero}$ 为 $1$,也有对应的结论。

我们用记忆化搜索的方式来计算结果,记录所有的中间状态,最终返回 $\textit{dp}(\textit{zero},\textit{one},0)$ + $\textit{dp}(\textit{zero},\textit{one},1)$ 取模后的结果。

代码

###Python

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        mod = 10 ** 9 + 7

        @cache
        def dp(zero, one, lastBit):
            if zero == 0:
                if lastBit == 0 or one > limit:
                    return 0
                else:
                    return 1
            elif one == 0:
                if lastBit == 1 or zero > limit:
                    return 0
                else:
                    return 1
            if lastBit == 0:
                res = dp(zero - 1, one, 0) + dp(zero - 1, one, 1)
                if zero > limit:
                    res -= dp(zero - limit - 1, one, 1)
            else:
                res = dp(zero, one - 1, 0) + dp(zero, one - 1, 1)
                if one > limit:
                    res -= dp(zero, one - limit - 1, 0)
            return res % mod
            
        res = (dp(zero, one, 0) + dp(zero, one, 1)) % mod
        dp.cache_clear()
        return res

###C++

class Solution {
public:
    int numberOfStableArrays(int zero, int one, int limit) {
        int mod = 1e9 + 7;
        vector<vector<vector<int>>> memo(zero + 1, vector<vector<int>>(one + 1, vector<int>(2, -1)));

        function<int(int, int, int)> dp = [&](int zero, int one, int lastBit) -> int {
            if (zero == 0) {
                return (lastBit == 0 || one > limit) ? 0 : 1;
            } else if (one == 0) {
                return (lastBit == 1 || zero > limit) ? 0 : 1;
            }

            if (memo[zero][one][lastBit] == -1) {
                int res = 0;
                if (lastBit == 0) {
                    res = (dp(zero - 1, one, 0) + dp(zero - 1, one, 1)) % mod;
                    if (zero > limit) {
                        res = (res - dp(zero - limit - 1, one, 1) + mod) % mod;
                    }
                } else {
                    res = (dp(zero, one - 1, 0) + dp(zero, one - 1, 1)) % mod;
                    if (one > limit) {
                        res = (res - dp(zero, one - limit - 1, 0) + mod) % mod;
                    }
                }
                memo[zero][one][lastBit] = res % mod;
            }
            return memo[zero][one][lastBit];
        };

        return (dp(zero, one, 0) + dp(zero, one, 1)) % mod;
    }
};

###Java

class Solution {
    static final int MOD = 1000000007;
    int[][][] memo;
    int limit;

    public int numberOfStableArrays(int zero, int one, int limit) {
        this.memo = new int[zero + 1][one + 1][2];
        for (int i = 0; i <= zero; i++) {
            for (int j = 0; j <= one; j++) {
                Arrays.fill(memo[i][j], -1);
            }
        }
        this.limit = limit;
        return (dp(zero, one, 0) + dp(zero, one, 1)) % MOD;
    }

    public int dp(int zero, int one, int lastBit) {
        if (zero == 0) {
            return (lastBit == 0 || one > limit) ? 0 : 1;
        } else if (one == 0) {
            return (lastBit == 1 || zero > limit) ? 0 : 1;
        }

        if (memo[zero][one][lastBit] == -1) {
            int res = 0;
            if (lastBit == 0) {
                res = (dp(zero - 1, one, 0) + dp(zero - 1, one, 1))% MOD;
                if (zero > limit) {
                    res = (res - dp(zero - limit - 1, one, 1) + MOD) % MOD;
                }
            } else {
                res = (dp(zero, one - 1, 0) + dp(zero, one - 1, 1)) % MOD;
                if (one > limit) {
                    res = (res - dp(zero, one - limit - 1, 0) + MOD) % MOD;
                }
            }
            memo[zero][one][lastBit] = res % MOD;
        }
        return memo[zero][one][lastBit];
    }
}

###C#

public class Solution {
    const int MOD = 1000000007;
    int[][][] memo;
    int limit;

    public int NumberOfStableArrays(int zero, int one, int limit) {
        this.memo = new int[zero + 1][][];
        for (int i = 0; i <= zero; i++) {
            memo[i] = new int[one + 1][];
            for (int j = 0; j <= one; j++) {
                memo[i][j] = new int[2];
                Array.Fill(memo[i][j], -1);
            }
        }
        this.limit = limit;
        return (DP(zero, one, 0) + DP(zero, one, 1)) % MOD;
    }

    public int DP(int zero, int one, int lastBit) {
        if (zero == 0) {
            return (lastBit == 0 || one > limit) ? 0 : 1;
        } else if (one == 0) {
            return (lastBit == 1 || zero > limit) ? 0 : 1;
        }

        if (memo[zero][one][lastBit] == -1) {
            int res = 0;
            if (lastBit == 0) {
                res = (DP(zero - 1, one, 0) + DP(zero - 1, one, 1))% MOD;
                if (zero > limit) {
                    res = (res - DP(zero - limit - 1, one, 1) + MOD) % MOD;
                }
            } else {
                res = (DP(zero, one - 1, 0) + DP(zero, one - 1, 1)) % MOD;
                if (one > limit) {
                    res = (res - DP(zero, one - limit - 1, 0) + MOD) % MOD;
                }
            }
            memo[zero][one][lastBit] = res % MOD;
        }
        return memo[zero][one][lastBit];
    }
}

###C

#define MOD 1000000007

int ***createMemo(int zero, int one) {
    int ***memo = malloc((zero + 1) * sizeof(int **));
    for (int i = 0; i <= zero; ++i) {
        memo[i] = malloc((one + 1) * sizeof(int *));
        for (int j = 0; j <= one; ++j) {
            memo[i][j] = malloc(2 * sizeof(int));
            memo[i][j][0] = -1;
            memo[i][j][1] = -1;
        }
    }
    return memo;
}

void freeMemo(int zero, int one, int ***memo) {
    for (int i = 0; i <= zero; ++i) {
        for (int j = 0; j <= one; ++j) {
            free(memo[i][j]);
        }
        free(memo[i]);
    }
    free(memo);
}

int dp(int zero, int one, int lastBit, int limit, int ***memo) {
    if (zero == 0) {
        return (lastBit == 0 || one > limit) ? 0 : 1;
    } else if (one == 0) {
        return (lastBit == 1 || zero > limit) ? 0 : 1;
    }
    if (memo[zero][one][lastBit] == -1) {
        int res = 0;
        if (lastBit == 0) {
            res = (dp(zero - 1, one, 0, limit, memo) + dp(zero - 1, one, 1, limit, memo)) % MOD;
            if (zero > limit) {
                res = (res - dp(zero - limit - 1, one, 1, limit, memo) + MOD) % MOD;
            }
        } else {
            res = (dp(zero, one - 1, 0, limit, memo) + dp(zero, one - 1, 1, limit, memo)) % MOD;
            if (one > limit) {
                res = (res - dp(zero, one - limit - 1, 0, limit, memo) + MOD) % MOD;
            }
        }
        memo[zero][one][lastBit] = res % MOD;
    }
    return memo[zero][one][lastBit];
}

int numberOfStableArrays(int zero, int one, int limit) {
    int ***memo = createMemo(zero, one);
    int result = (dp(zero, one, 0, limit, memo) + dp(zero, one, 1, limit, memo)) % MOD;
    freeMemo(zero, one, memo);
    return result;
}

###Go

const MOD = 1000000007

func numberOfStableArrays(zero int, one int, limit int) int {
    memo := make([][][]int, zero + 1)
for i := range memo {
memo[i] = make([][]int, one + 1)
for j := range memo[i] {
memo[i][j] = []int{-1, -1}
}
}

var dp func(int, int, int) int
dp = func(zero, one, lastBit int) int {
if zero == 0 {
if lastBit == 0 || one > limit {
return 0
} else {
return 1
}
} else if one == 0 {
if lastBit == 1 || zero > limit {
return 0
} else {
return 1
}
}

if memo[zero][one][lastBit] == -1 {
res := 0
if lastBit == 0 {
res = (dp(zero-1, one, 0) + dp(zero - 1, one, 1)) % MOD
if zero > limit {
res = (res - dp(zero - limit - 1, one, 1) + MOD) % MOD
}
} else {
res = (dp(zero, one - 1, 0) + dp(zero, one - 1, 1)) % MOD
if one > limit {
res = (res - dp(zero, one - limit - 1, 0) + MOD) % MOD
}
}
memo[zero][one][lastBit] = res % MOD
}
return memo[zero][one][lastBit]
}

return (dp(zero, one, 0) + dp(zero, one, 1)) % MOD
}

###JavaScript

const MOD = 1000000007;

var numberOfStableArrays = function(zero, one, limit) {
    const memo = Array.from({ length: zero + 1 }, () =>
        Array.from({ length: one + 1 }, () => [-1, -1])
    );

    function dp(zero, one, lastBit) {
        if (zero === 0) {
            return lastBit === 0 || one > limit ? 0 : 1;
        } else if (one === 0) {
            return lastBit === 1 || zero > limit ? 0 : 1;
        }

        if (memo[zero][one][lastBit] === -1) {
            let res = 0;
            if (lastBit === 0) {
                res = (dp(zero - 1, one, 0) + dp(zero - 1, one, 1)) % MOD;
                if (zero > limit) {
                    res = (res - dp(zero - limit - 1, one, 1) + MOD) % MOD;
                }
            } else {
                res = (dp(zero, one - 1, 0) + dp(zero, one - 1, 1)) % MOD;
                if (one > limit) {
                    res = (res - dp(zero, one - limit - 1, 0) + MOD) % MOD;
                }
            }
            memo[zero][one][lastBit] = res % MOD;
        }
        return memo[zero][one][lastBit];
    }

    return (dp(zero, one, 0) + dp(zero, one, 1)) % MOD;
};

###TypeScript

const MOD = 1000000007;

function numberOfStableArrays(zero: number, one: number, limit: number): number {
    const memo: number[][][] = Array.from({ length: zero + 1 }, () =>
        Array.from({ length: one + 1 }, () => [-1, -1])
    );

    function dp(zero: number, one: number, lastBit: number): number {
        if (zero === 0) {
            return lastBit === 0 || one > limit ? 0 : 1;
        } else if (one === 0) {
            return lastBit === 1 || zero > limit ? 0 : 1;
        }

        if (memo[zero][one][lastBit] === -1) {
            let res = 0;
            if (lastBit === 0) {
                res = (dp(zero - 1, one, 0) + dp(zero - 1, one, 1)) % MOD;
                if (zero > limit) {
                    res = (res - dp(zero - limit - 1, one, 1) + MOD) % MOD;
                }
            } else {
                res = (dp(zero, one - 1, 0) + dp(zero, one - 1, 1)) % MOD;
                if (one > limit) {
                    res = (res - dp(zero, one - limit - 1, 0) + MOD) % MOD;
                }
            }
            memo[zero][one][lastBit] = res % MOD;
        }
        return memo[zero][one][lastBit];
    }

    return (dp(zero, one, 0) + dp(zero, one, 1)) % MOD;
};

###Rust

const MOD: i32 = 1000000007;

impl Solution {
    pub fn number_of_stable_arrays(zero: i32, one: i32, limit: i32) -> i32 {
        let mut memo = vec![vec![vec![-1; 2]; (one + 1) as usize]; (zero + 1) as usize];

        fn dp(zero: usize, one: usize, last_bit: usize, limit: usize, memo: &mut Vec<Vec<Vec<i32>>>) -> i32 {
            if zero == 0 {
                return if last_bit == 0 || one > limit { 0 } else { 1 };
            } else if one == 0 {
                return if last_bit == 1 || zero > limit { 0 } else { 1 };
            }

            if memo[zero][one][last_bit] == -1 {
                let mut res = 0;
                if last_bit == 0 {
                    res = (dp(zero - 1, one, 0, limit, memo) + dp(zero - 1, one, 1, limit, memo)) % MOD;
                    if zero > limit {
                        res = (res - dp(zero - limit - 1, one, 1, limit, memo) + MOD) % MOD;
                    }
                } else {
                    res = (dp(zero, one - 1, 0, limit, memo) + dp(zero, one - 1, 1, limit, memo)) % MOD;
                    if one > limit {
                        res = (res - dp(zero, one - limit - 1, 0, limit, memo) + MOD) % MOD;
                    }
                }
                memo[zero][one][last_bit] = res % MOD;
            }
            memo[zero][one][last_bit]
        }

        let zero = zero as usize;
        let one = one as usize;
        let limit = limit as usize;
        (dp(zero, one, 0, limit, &mut memo) + dp(zero, one, 1, limit, &mut memo)) % MOD
    }
}

复杂度分析

  • 时间复杂度:$O(\textit{zero}\times\textit{one})$,动态规划的状态一共有 $O(\textit{zero}\times\textit{one})$ 种,每个状态消耗 $O(1)$ 时间消耗。

  • 空间复杂度:$O(\textit{zero}\times\textit{one})$。

方法二:动态规划

思路

方法一用的是记忆化搜索,状态的求解是自顶向下的。方法二中我们使用动态规划,从而自底向上来求出所有状态,并用数组保存结果。状态方程的关系和方法一一致。

代码

###Python

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        mod = 10 ** 9 + 7

        dp = [[[0, 0] for _ in range(one + 1)] for _ in range(zero + 1)]
        for i in range(zero+1):
            for j in range(one+1):
                for lastBit in range(2):
                    if i == 0:
                        if lastBit == 0 or j > limit:
                            dp[i][j][lastBit] = 0
                        else:
                            dp[i][j][lastBit] = 1
                    elif j == 0:
                        if lastBit == 1 or i > limit:
                            dp[i][j][lastBit] = 0
                        else:
                            dp[i][j][lastBit] = 1
                    elif lastBit == 0:
                        dp[i][j][lastBit] = dp[i-1][j][0] + dp[i-1][j][1]
                        if i > limit:
                            dp[i][j][lastBit] -= dp[i-limit-1][j][1]
                    else:
                        dp[i][j][lastBit] = dp[i][j-1][0] + dp[i][j-1][1]
                        if j > limit:
                            dp[i][j][lastBit] -= dp[i][j-limit-1][0]
                    dp[i][j][lastBit] %= mod
        return (dp[-1][-1][0] + dp[-1][-1][1]) % mod

###C++

class Solution {
public:
    constexpr static int MOD = 1000000007;
    int numberOfStableArrays(int zero, int one, int limit) {
        vector<vector<vector<int>>> dp(zero + 1, vector<vector<int>>(one + 1, vector<int>(2)));
        for (int i = 0; i <= zero; i++) {
            for (int j = 0; j <= one; j++) {
                for (int lastBit = 0; lastBit <= 1; lastBit++) {
                    if (i == 0) {
                        if (lastBit == 0 || j > limit) {
                            dp[i][j][lastBit] = 0;
                        } else {
                            dp[i][j][lastBit] = 1;
                        }
                    } else if (j == 0) {
                        if (lastBit == 1 || i > limit) {
                            dp[i][j][lastBit] = 0;
                        } else {
                            dp[i][j][lastBit] = 1;
                        }
                    } else if (lastBit == 0) {
                        dp[i][j][lastBit] = dp[i - 1][j][0] + dp[i - 1][j][1];
                        if (i > limit) {
                            dp[i][j][lastBit] -= dp[i - limit - 1][j][1];
                        }
                    } else {
                        dp[i][j][lastBit] = dp[i][j - 1][0] + dp[i][j -1 ][1];
                        if (j > limit) {
                            dp[i][j][lastBit] -= dp[i][j - limit - 1][0];
                        }
                    }
                    dp[i][j][lastBit] %= MOD;
                    if (dp[i][j][lastBit] < 0) {
                        dp[i][j][lastBit] += MOD;
                    }
                }
            }
        }

        return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
    }
};

###Java

class Solution {
    public int numberOfStableArrays(int zero, int one, int limit) {
        final int MOD = 1000000007;
        int[][][] dp = new int[zero + 1][one + 1][2];
        for (int i = 0; i <= zero; i++) {
            for (int j = 0; j <= one; j++) {
                for (int lastBit = 0; lastBit <= 1; lastBit++) {
                    if (i == 0) {
                        if (lastBit == 0 || j > limit) {
                            dp[i][j][lastBit] = 0;
                        } else {
                            dp[i][j][lastBit] = 1;
                        }
                    } else if (j == 0) {
                        if (lastBit == 1 || i > limit) {
                            dp[i][j][lastBit] = 0;
                        } else {
                            dp[i][j][lastBit] = 1;
                        }
                    } else if (lastBit == 0) {
                        dp[i][j][lastBit] = dp[i - 1][j][0] + dp[i - 1][j][1];
                        if (i > limit) {
                            dp[i][j][lastBit] -= dp[i - limit - 1][j][1];
                        }
                    } else {
                        dp[i][j][lastBit] = dp[i][j - 1][0] + dp[i][j -1 ][1];
                        if (j > limit) {
                            dp[i][j][lastBit] -= dp[i][j - limit - 1][0];
                        }
                    }
                    dp[i][j][lastBit] %= MOD;
                    if (dp[i][j][lastBit] < 0) {
                        dp[i][j][lastBit] += MOD;
                    }
                }
            }
        }
        return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
    }
}

###C#

public class Solution {
    public int NumberOfStableArrays(int zero, int one, int limit) {
        const int MOD = 1000000007;
        int[][][] dp = new int[zero + 1][][];
        for (int i = 0; i <= zero; i++) {
            dp[i] = new int[one + 1][];
            for (int j = 0; j <= one; j++) {
                dp[i][j] = new int[2];
                for (int lastBit = 0; lastBit <= 1; lastBit++) {
                    if (i == 0) {
                        if (lastBit == 0 || j > limit) {
                            dp[i][j][lastBit] = 0;
                        } else {
                            dp[i][j][lastBit] = 1;
                        }
                    } else if (j == 0) {
                        if (lastBit == 1 || i > limit) {
                            dp[i][j][lastBit] = 0;
                        } else {
                            dp[i][j][lastBit] = 1;
                        }
                    } else if (lastBit == 0) {
                        dp[i][j][lastBit] = dp[i - 1][j][0] + dp[i - 1][j][1];
                        if (i > limit) {
                            dp[i][j][lastBit] -= dp[i - limit - 1][j][1];
                        }
                    } else {
                        dp[i][j][lastBit] = dp[i][j - 1][0] + dp[i][j -1 ][1];
                        if (j > limit) {
                            dp[i][j][lastBit] -= dp[i][j - limit - 1][0];
                        }
                    }
                    dp[i][j][lastBit] %= MOD;
                    if (dp[i][j][lastBit] < 0) {
                        dp[i][j][lastBit] += MOD;
                    }
                }
            }
        }
        return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
    }
}

###Go

const MOD = 1000000007

func numberOfStableArrays(zero int, one int, limit int) int {
    dp := make([][][]int, zero + 1)
    for i := range dp {
        dp[i] = make([][]int, one + 1)
        for j := range dp[i] {
            dp[i][j] = make([]int, 2)
        }
    }

    for i := 0; i <= zero; i++ {
        for j := 0; j <= one; j++ {
            for lastBit := 0; lastBit <= 1; lastBit++ {
                if i == 0 {
                    if lastBit == 0 || j > limit {
                        dp[i][j][lastBit] = 0
                    } else {
                        dp[i][j][lastBit] = 1
                    }
                } else if j == 0 {
                    if lastBit == 1 || i > limit {
                        dp[i][j][lastBit] = 0
                    } else {
                        dp[i][j][lastBit] = 1
                    }
                } else if lastBit == 0 {
                    dp[i][j][lastBit] = dp[i - 1][j][0] + dp[i - 1][j][1]
                    if i > limit {
                        dp[i][j][lastBit] -= dp[i - limit - 1][j][1]
                    }
                } else {
                    dp[i][j][lastBit] = dp[i][j - 1][0] + dp[i][j - 1][1]
                    if j > limit {
                        dp[i][j][lastBit] -= dp[i][j - limit - 1][0]
                    }
                }
                dp[i][j][lastBit] %= MOD
                if dp[i][j][lastBit] < 0 {
                    dp[i][j][lastBit] += MOD
                }
            }
        }
    }

    return (dp[zero][one][0] + dp[zero][one][1]) % MOD
}

###C

#define MOD 1000000007

int numberOfStableArrays(int zero, int one, int limit) {
    int dp[zero + 1][one + 1][2];
    memset(dp, 0, sizeof(dp));
    for (int i = 0; i <= zero; i++) {
        for (int j = 0; j <= one; j++) {
            for (int lastBit = 0; lastBit <= 1; lastBit++) {
                if (i == 0) {
                    if (lastBit == 0 || j > limit) {
                        dp[i][j][lastBit] = 0;
                    } else {
                        dp[i][j][lastBit] = 1;
                    }
                } else if (j == 0) {
                    if (lastBit == 1 || i > limit) {
                        dp[i][j][lastBit] = 0;
                    } else {
                        dp[i][j][lastBit] = 1;
                    }
                } else if (lastBit == 0) {
                    dp[i][j][lastBit] = dp[i - 1][j][0] + dp[i - 1][j][1];
                    if (i > limit) {
                        dp[i][j][lastBit] -= dp[i - limit - 1][j][1];
                    }
                } else {
                    dp[i][j][lastBit] = dp[i][j - 1][0] + dp[i][j - 1][1];
                    if (j > limit) {
                        dp[i][j][lastBit] -= dp[i][j - limit - 1][0];
                    }
                }
                dp[i][j][lastBit] %= MOD;
                if (dp[i][j][lastBit] < 0) {
                    dp[i][j][lastBit] += MOD;
                }
            }
        }
    }

    return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
}

###JavaScript

const MOD = 1000000007;

var numberOfStableArrays = function(zero, one, limit) {
    let dp = Array.from({ length: zero + 1 }, () =>
        Array.from({ length: one + 1 }, () => [0, 0])
    );

    for (let i = 0; i <= zero; i++) {
        for (let j = 0; j <= one; j++) {
            for (let lastBit = 0; lastBit <= 1; lastBit++) {
                if (i === 0) {
                    if (lastBit === 0 || j > limit) {
                        dp[i][j][lastBit] = 0;
                    } else {
                        dp[i][j][lastBit] = 1;
                    }
                } else if (j === 0) {
                    if (lastBit === 1 || i > limit) {
                        dp[i][j][lastBit] = 0;
                    } else {
                        dp[i][j][lastBit] = 1;
                    }
                } else if (lastBit === 0) {
                    dp[i][j][lastBit] = dp[i - 1][j][0] + dp[i - 1][j][1];
                    if (i > limit) {
                        dp[i][j][lastBit] -= dp[i - limit - 1][j][1];
                    }
                } else {
                    dp[i][j][lastBit] = dp[i][j - 1][0] + dp[i][j - 1][1];
                    if (j > limit) {
                        dp[i][j][lastBit] -= dp[i][j - limit - 1][0];
                    }
                }
                dp[i][j][lastBit] %= MOD;
                if (dp[i][j][lastBit] < 0) {
                    dp[i][j][lastBit] += MOD;
                }
            }
        }
    }

    return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
};

###TypeScript

const MOD = 1000000007;

function numberOfStableArrays(zero: number, one: number, limit: number): number {
    let dp: number[][][] = Array.from({ length: zero + 1 }, () =>
        Array.from({ length: one + 1 }, () => [0, 0])
    );

    for (let i = 0; i <= zero; i++) {
        for (let j = 0; j <= one; j++) {
            for (let lastBit = 0; lastBit <= 1; lastBit++) {
                if (i === 0) {
                    if (lastBit === 0 || j > limit) {
                        dp[i][j][lastBit] = 0;
                    } else {
                        dp[i][j][lastBit] = 1;
                    }
                } else if (j === 0) {
                    if (lastBit === 1 || i > limit) {
                        dp[i][j][lastBit] = 0;
                    } else {
                        dp[i][j][lastBit] = 1;
                    }
                } else if (lastBit === 0) {
                    dp[i][j][lastBit] = dp[i - 1][j][0] + dp[i - 1][j][1];
                    if (i > limit) {
                        dp[i][j][lastBit] -= dp[i - limit - 1][j][1];
                    }
                } else {
                    dp[i][j][lastBit] = dp[i][j - 1][0] + dp[i][j - 1][1];
                    if (j > limit) {
                        dp[i][j][lastBit] -= dp[i][j - limit - 1][0];
                    }
                }
                dp[i][j][lastBit] %= MOD;
                if (dp[i][j][lastBit] < 0) {
                    dp[i][j][lastBit] += MOD;
                }
            }
        }
    }

    return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
};

###Rust

const MOD: i32 = 1000000007;

impl Solution {
    pub fn number_of_stable_arrays(zero: i32, one: i32, limit: i32) -> i32 {
        let mut dp = vec![vec![vec![0; 2]; one as usize + 1]; zero as usize + 1];

        for i in 0..=zero as usize {
            for j in 0..=one as usize {
                for last_bit in 0..=1 {
                    if i == 0 {
                        if last_bit == 0 || j > limit as usize {
                            dp[i][j][last_bit] = 0;
                        } else {
                            dp[i][j][last_bit] = 1;
                        }
                    } else if j == 0 {
                        if last_bit == 1 || i > limit as usize {
                            dp[i][j][last_bit] = 0;
                        } else {
                            dp[i][j][last_bit] = 1;
                        }
                    } else if last_bit == 0 {
                        dp[i][j][last_bit] = dp[i - 1][j][0] + dp[i - 1][j][1];
                        if i > limit as usize {
                            dp[i][j][last_bit] -= dp[i - (limit as usize) - 1][j][1];
                        }
                    } else {
                        dp[i][j][last_bit] = dp[i][j - 1][0] + dp[i][j - 1][1];
                        if j > limit as usize {
                            dp[i][j][last_bit] -= dp[i][j - (limit as usize) - 1][0];
                        }
                    }
                    dp[i][j][last_bit] %= MOD;
                    if dp[i][j][last_bit] < 0 {
                        dp[i][j][last_bit] += MOD;
                    }
                }
            }
        }

        return (dp[zero as usize][one as usize][0] + dp[zero as usize][one as usize][1]) % MOD;
    }
}

复杂度分析

  • 时间复杂度:$O(\textit{zero}\times\textit{one})$,动态规划的状态一共有 $O(\textit{zero}\times\textit{one})$ 种,每个状态消耗 $O(1)$ 时间消耗。

  • 空间复杂度:$O(\textit{zero}\times\textit{one})$。

找出所有稳定的二进制数组 I

方法一:动态规划

题目要求二进制数组 $\textit{arr}$ 中每个长度超过 $\textit{limit}$ 的子数组同时包含 $0$ 和 $1$,这个条件等价于二进制数组 $\textit{arr}$ 中每个长度等于 $\textit{limit} + 1$ 的子数组都同时包含 $0$ 和 $1$(读者可以思考一下证明过程)。

按照题目要求,我们需要将 $\textit{zero}$ 个 $0$ 和 $\textit{one}$ 个 $1$ 依次填入二进制数组 $\textit{arr}$,使用 $\textit{dp}_0[i][j]$ 表示已经填入 $i$ 个 $0$ 和 $\textit{j}$ 个 $1$,并且最后填入的数字为 $0$ 的可行方案数目,$\textit{dp}_1[i][j]$ 表示已经填入 $i$ 个 $0$ 和 $\textit{j}$ 个 $1$,并且最后填入的数字为 $1$ 的可行方案数目。对于 $\textit{dp}_0[i][j]$,我们分析一下它的转换方程:

  • 当 $j = 0$ 且 $i \in [0, \min(\textit{zero}, \textit{limit})]$ 时:我们可以不断地填入 $0$,所以 $\textit{dp}_0[i][j] = 1$。

  • 当 $i = 0$,或者 $j = 0$ 且 $i \notin [0, \min(\textit{zero}, \textit{limit})]$ 时:我们没法构造可行的方案,所以 $\textit{dp}_0[i][j] = 0$。

  • 当 $i > 0$ 且 $j > 0$ 时:$\textit{dp}_0[i][j]$ 可以分别由 $\textit{dp}_0[i - 1][j]$ 和 $\textit{dp}_1[i - 1][j]$ 转移而来,分别考虑两种情况:

    • 对于 $\textit{dp}_1[i - 1][j]$:显然可以通过在 $\textit{dp}_1[i - 1][j]$ 对应的所有填入方案后再填入一个 $0$ 得到对应的可行方案。

    • 对于 $\textit{dp}_0[i - 1][j]$:当 $i \le \textit{limit}$ 时,显然可以通过在 $\textit{dp}_1[i - 1][j]$ 对应的所有填入方案后再填入一个 $0$ 得到对应的可行方案;当 $i \gt \textit{limit}$ 时,我们需要去除一些不可行的方案数。因为 $\textit{dp}_0[i - 1][j]$ 对应的所有填入方案都是可行的,而只有一种情况会在额外填入一个 $0$ 时,变成不可行,即先前已经连续填入 $\textit{limit}$ 个 $0$,对应的方案数为 $\textit{dp}_1[i - \textit{limit} - 1][j]$。

根据以上分析,我们有 $\textit{dp}_0[i][j]$ 的转移方程:

$$
\textit{dp}_0[i][j] = \begin{cases}
1, & i \in [0, \min(\textit{zero}, \textit{limit})], j = 0 \
\textit{dp}_1[i - 1][j] + \textit{dp}_0[i - 1][j] - \textit{dp}_1[i - \textit{limit} - 1][j], & i > limit, j > 0 \
\textit{dp}_1[i - 1][j] + \textit{dp}_0[i - 1][j], & i \in [0, limit], j > 0 \
0, & otherwise
\end{cases}
$$

同理,我们也可以获得 $\textit{dp}_1[i][j]$ 的转移方程:

$$
\textit{dp}_1[i][j] = \begin{cases}
1, & i = 0, j \in [0, \min(\textit{one}, \textit{limit})] \
\textit{dp}_0[i][j - 1] + \textit{dp}_1[i][j - 1] - \textit{dp}_0[i][j - \textit{limit} - 1], & i > 0, j > limit \
\textit{dp}_0[i][j - 1] + \textit{dp}_1[i][j - 1], & i > 0, j \in [0, limit] \
0, & otherwise
\end{cases}
$$

最后,稳定二进制数组的数目等于 $\textit{dp}_0[\textit{zero}][\textit{one}] + \textit{dp}_1[\textit{zero}][\textit{one}]$。

###C++

class Solution {
public:
    int numberOfStableArrays(int zero, int one, int limit) {
        vector<vector<vector<long long>>> dp(zero + 1, vector<vector<long long>>(one + 1, vector<long long>(2)));
        long long mod = 1e9 + 7;
        for (int i = 0; i <= min(zero, limit); i++) {
            dp[i][0][0] = 1;
        }
        for (int j = 0; j <= min(one, limit); j++) {
            dp[0][j][1] = 1;
        }
        for (int i = 1; i <= zero; i++) {
            for (int j = 1; j <= one; j++) {
                if (i > limit) {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1];
                } else {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1];
                }
                dp[i][j][0] = (dp[i][j][0] % mod + mod) % mod;
                if (j > limit) {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0];
                } else {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0];
                }
                dp[i][j][1] = (dp[i][j][1] % mod + mod) % mod;
            }
        }
        return (dp[zero][one][0] + dp[zero][one][1]) % mod;
    }
};

###Java

class Solution {
    public int numberOfStableArrays(int zero, int one, int limit) {
        final long MOD = 1000000007;
        long[][][] dp = new long[zero + 1][one + 1][2];
        for (int i = 0; i <= Math.min(zero, limit); i++) {
            dp[i][0][0] = 1;
        }
        for (int j = 0; j <= Math.min(one, limit); j++) {
            dp[0][j][1] = 1;
        }
        for (int i = 1; i <= zero; i++) {
            for (int j = 1; j <= one; j++) {
                if (i > limit) {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1];
                } else {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1];
                }
                dp[i][j][0] = (dp[i][j][0] % MOD + MOD) % MOD;
                if (j > limit) {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0];
                } else {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0];
                }
                dp[i][j][1] = (dp[i][j][1] % MOD + MOD) % MOD;
            }
        }
        return (int) ((dp[zero][one][0] + dp[zero][one][1]) % MOD);
    }
}

###C#

public class Solution {
    public int NumberOfStableArrays(int zero, int one, int limit) {
        const long MOD = 1000000007;
        long[][][] dp = new long[zero + 1][][];
        for (int i = 0; i <= zero; i++) {
            dp[i] = new long[one + 1][];
            for (int j = 0; j <= one; j++) {
                dp[i][j] = new long[2];
            }
        }
        for (int i = 0; i <= Math.Min(zero, limit); i++) {
            dp[i][0][0] = 1;
        }
        for (int j = 0; j <= Math.Min(one, limit); j++) {
            dp[0][j][1] = 1;
        }
        for (int i = 1; i <= zero; i++) {
            for (int j = 1; j <= one; j++) {
                if (i > limit) {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1];
                } else {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1];
                }
                dp[i][j][0] = (dp[i][j][0] % MOD + MOD) % MOD;
                if (j > limit) {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0];
                } else {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0];
                }
                dp[i][j][1] = (dp[i][j][1] % MOD + MOD) % MOD;
            }
        }
        return (int) ((dp[zero][one][0] + dp[zero][one][1]) % MOD);
    }
}

###Go

func numberOfStableArrays(zero int, one int, limit int) int {
    dp := make([][][2]int, zero + 1)
    mod := int(1e9 + 7)
    for i := 0; i <= zero; i++ {
        dp[i] = make([][2]int, one + 1)
    }
    for i := 0; i <= min(zero, limit); i++ {
        dp[i][0][0] = 1
    }
    for j := 0; j <= min(one, limit); j++ {
        dp[0][j][1] = 1
    }
    for i := 1; i <= zero; i++ {
        for j := 1; j <= one; j++ {
            if i > limit {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1]
            } else {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1]
            }
            dp[i][j][0] = (dp[i][j][0] % mod + mod) % mod
            if j > limit {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0]
            } else {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0]
            }
            dp[i][j][1] = (dp[i][j][1] % mod + mod) % mod
        }
    }
    return (dp[zero][one][0] + dp[zero][one][1]) % mod
}

###Python

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        dp = [[[0, 0] for _ in range(one + 1)] for _ in range(zero + 1)]
        mod = int(1e9 + 7)
        for i in range(min(zero, limit) + 1):
            dp[i][0][0] = 1
        for j in range(min(one, limit) + 1):
            dp[0][j][1] = 1
        for i in range(1, zero + 1):
            for j in range(1, one + 1):
                if i > limit:
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1]
                else:
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1]
                dp[i][j][0] = (dp[i][j][0] % mod + mod) % mod
                if j > limit:
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0]
                else:
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0]
                dp[i][j][1] = (dp[i][j][1] % mod + mod) % mod
        return (dp[zero][one][0] + dp[zero][one][1]) % mod

###C

#define MOD 1000000007

int numberOfStableArrays(int zero, int one, int limit) {
    long long dp[zero + 1][one + 1][2];
    memset(dp, 0, sizeof(dp));
    for (int i = 0; i <= (zero < limit ? zero : limit); ++i) {
        dp[i][0][0] = 1;
    }
    for (int j = 0; j <= (one < limit ? one : limit); ++j) {
        dp[0][j][1] = 1;
    }
    for (int i = 1; i <= zero; ++i) {
        for (int j = 1; j <= one; ++j) {
            if (i > limit) {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1];
            } else {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1];
            }
            dp[i][j][0] = (dp[i][j][0] % MOD + MOD) % MOD;
            if (j > limit) {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0];
            } else {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0];
            }
            dp[i][j][1] = (dp[i][j][1] % MOD + MOD) % MOD;
        }
    }
    int result = (dp[zero][one][0] + dp[zero][one][1]) % MOD;
    return result;
}

###JavaScript

const MOD = 1000000007;

var numberOfStableArrays = function(zero, one, limit) {
    const dp = Array.from({ length: zero + 1 }, () =>
        Array.from({ length: one + 1 }, () => [0, 0])
    );

    for (let i = 0; i <= Math.min(zero, limit); i++) {
        dp[i][0][0] = 1;
    }
    for (let j = 0; j <= Math.min(one, limit); j++) {
        dp[0][j][1] = 1;
    }

    for (let i = 1; i <= zero; i++) {
        for (let j = 1; j <= one; j++) {
            if (i > limit) {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1];
            } else {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1];
            }
            dp[i][j][0] = (dp[i][j][0] % MOD + MOD) % MOD;
            if (j > limit) {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0];
            } else {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0];
            }
            dp[i][j][1] = (dp[i][j][1] % MOD + MOD) % MOD;
        }
    }
    return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
};

###TypeScript

const MOD = 1000000007;

function numberOfStableArrays(zero: number, one: number, limit: number): number {
    const dp: number[][][] = Array.from({ length: zero + 1 }, () =>
        Array.from({ length: one + 1 }, () => [0, 0])
    );

    for (let i = 0; i <= Math.min(zero, limit); i++) {
        dp[i][0][0] = 1;
    }
    for (let j = 0; j <= Math.min(one, limit); j++) {
        dp[0][j][1] = 1;
    }

    for (let i = 1; i <= zero; i++) {
        for (let j = 1; j <= one; j++) {
            if (i > limit) {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1];
            } else {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1];
            }
            dp[i][j][0] = (dp[i][j][0] % MOD + MOD) % MOD;
            if (j > limit) {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0];
            } else {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0];
            }
            dp[i][j][1] = (dp[i][j][1] % MOD + MOD) % MOD;
        }
    }
    return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
};

###Rust

const MOD: i32 = 1000000007;

impl Solution {
    pub fn number_of_stable_arrays(zero: i32, one: i32, limit: i32) -> i32 {
        let mut dp = vec![vec![vec![0; 2]; one as usize + 1]; zero as usize + 1];

        for i in 0..=zero.min(limit) as usize {
            dp[i][0][0] = 1;
        }
        for j in 0..=one.min(limit) as usize {
            dp[0][j][1] = 1;
        }

        for i in 1..=zero as usize {
            for j in 1..=one as usize {
                if i > limit as usize {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit as usize - 1][j][1];
                } else {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1];
                }
                dp[i][j][0] = (dp[i][j][0] % MOD + MOD) % MOD;
                if j > limit as usize {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit as usize - 1][0];
                } else {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0];
                }
                dp[i][j][1] = (dp[i][j][1] % MOD + MOD) % MOD;
            }
        }
        (dp[zero as usize][one as usize][0] + dp[zero as usize][one as usize][1]) % MOD
    }
}

复杂度分析

  • 时间复杂度:$O(\textit{zero} \times \textit{one})$,其中 $\textit{zero}$ 和 $\textit{one}$ 分别为 $0$ 和 $1$ 的出现次数。

  • 空间复杂度:$O(\textit{zero} \times \textit{one})$。

检查二进制字符串字段

方法一:寻找 $01$ 串

思路与算法

题目给定一个长度为 $n$ 的二进制字符串 $s$,并满足该字符串不含前导零。现在我们需要判断字符串中是否只包含零个或一个由连续 $1$ 组成的字段。首先我们依次分析这两种情况:

  • 字符串 $s$ 中包含零个由连续 $1$ 组成的字段,那么整个串的表示为 $00 \cdots 00$。
  • 字符串 $s$ 中只包含一个由连续 $1$ 组成的字段,因为已知字符串 $s$ 不包含前导零,所以整个串的表示为 $1 \cdots 100 \cdots 00$。

那么可以看到两种情况中都不包含 $01$ 串。且不包含的 $01$ 串的一个二进制字符串也有且仅有上面两种情况。所以我们可以通过原字符串中是否有 $01$ 串来判断字符串中是否只包含零个或一个由连续 $1$ 组成的字段。如果有 $01$ 串则说明该情况不满足,否则即满足该情况条件。

代码

###Python

class Solution:
    def checkOnesSegment(self, s: str) -> bool:
        return "01" not in s

###Java

class Solution {
    public boolean checkOnesSegment(String s) {
        return !s.contains("01");
    }
}

###C#

public class Solution {
    public bool CheckOnesSegment(string s) {
        return !s.Contains("01");
    }
}

###C++

class Solution {
public:
    bool checkOnesSegment(string s) {
        return s.find("01") == string::npos;
    }
};

###C

bool checkOnesSegment(char * s){
    return strstr(s, "01") == NULL;
}

###JavaScript

var checkOnesSegment = function(s) {
    return s.indexOf('01') === -1;
};

###go

func checkOnesSegment(s string) bool {
return !strings.Contains(s, "01")
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 为字符串 $s$ 的长度。
  • 空间复杂度:$O(1)$,仅适用常量空间。

生成交替二进制字符串的最少操作数

方法一:模拟

思路

根据题意,经过多次操作,$s$ 可能会变成两种不同的交替二进制字符串,即:

  • 开头为 $0$,后续交替的字符串;
  • 开头为 $1$,后续交替的字符串。

注意到,变成这两种不同的交替二进制字符串所需要的最少操作数加起来等于 $s$ 的长度,我们只需要计算出变为其中一种字符串的最少操作数,就可以推出另一个最少操作数,然后取最小值即可。

代码

###Python

class Solution:
    def minOperations(self, s: str) -> int:
        cnt = sum(int(c) != i % 2 for i, c in enumerate(s))
        return min(cnt, len(s) - cnt)

###Java

class Solution {
    public int minOperations(String s) {
        int cnt = 0;
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            if (c != (char) ('0' + i % 2)) {
                cnt++;
            }
        }
        return Math.min(cnt, s.length() - cnt);
    }
}

###C#

public class Solution {
    public int MinOperations(string s) {
        int cnt = 0;
        for (int i = 0; i < s.Length; i++) {
            char c = s[i];
            if (c != (char) ('0' + i % 2)) {
                cnt++;
            }
        }
        return Math.Min(cnt, s.Length - cnt);
    }
}

###C++

class Solution {
public:
    int minOperations(string s) {
        int cnt = 0;
        for (int i = 0; i < s.size(); i++) {
            char c = s[i];
            if (c != ('0' + i % 2)) {
                cnt++;
            }
        }
        return min(cnt, (int)s.size() - cnt);
    }
};

###C

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int minOperations(char * s) {
    int cnt = 0, len = strlen(s);
    for (int i = 0; i < len; i++) {
        char c = s[i];
        if (c != ('0' + i % 2)) {
            cnt++;
        }
    }
    return MIN(cnt, len - cnt);
}

###JavaScript

var minOperations = function(s) {
    let cnt = 0;
    for (let i = 0; i < s.length; i++) {
        const c = s[i];
        if (c !== (String.fromCharCode('0'.charCodeAt() + i % 2))) {
            cnt++;
        }
    }
    return Math.min(cnt, s.length - cnt);
};

###go

func minOperations(s string) int {
    cnt := 0
    for i, c := range s {
        if i%2 != int(c-'0') {
            cnt++
        }
    }
    return min(cnt, len(s)-cnt)
}

func min(a, b int) int {
    if a > b {
        return b
    }
    return a
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 为输入 $s$ 的长度,仅需遍历一遍字符串。

  • 空间复杂度:$O(1)$,只需要常数额外空间。

找出第 N 个二进制字符串中的第 K 位

方法一:递归

观察二进制字符串 $S_n$,可以发现,当 $n>1$ 时,$S_n$ 是在 $S_{n-1}$ 的基础上形成的。用 $\text{len}n$ 表示 $S_n$ 的长度,则 $S_n$ 的前 $\text{len}{n-1}$ 个字符与 $S_{n-1}$ 相同。还可以发现,当 $n>1$ 时,$\text{len}n=\text{len}{n-1} \times 2 + 1$,根据 $\text{len}_1=1$ 可知 $\text{len}_n=2^n-1$。

由于 $S_1=``0"$,且对于任意 $n \ge 1$,$S_n$ 的第 $1$ 位字符也一定是 $0'$,因此当 $k=1$ 时,直接返回字符 $0'$。

当 $n>1$ 时,$S_n$ 的长度是 $2^n-1$。$S_n$ 可以分成三个部分,左边 $2^{n-1}-1$ 个字符是 $S_{n-1}$,中间 $1$ 个字符是 $1'$,右边 $2^{n-1}-1$ 个字符是 $S_{n-1}$ 翻转与反转之后的结果。中间的字符 $1'$ 是 $S_n$ 的第 $2^{n-1}$ 位字符,因此如果 $k=2^{n-1}$,直接返回字符 $`1'$。

当 $k \ne 2^{n-1}$ 时,考虑以下两种情况:

  • 如果 $k<2^{n-1}$,则第 $k$ 位字符在 $S_n$ 的前半部分,即第 $k$ 位字符在 $S_{n-1}$ 中,因此在 $S_{n-1}$ 中寻找第 $k$ 位字符;

  • 如果 $k>2^{n-1}$,则第 $k$ 位字符在 $S_n$ 的后半部分,由于后半部分为前半部分进行翻转与反转之后的结果,因此在前半部分寻找第 $2^n-k$ 位字符,将其反转之后即为 $S_n$ 的第 $k$ 位字符。

上述过程可以通过递归实现。

###Java

class Solution {
    public char findKthBit(int n, int k) {
        if (k == 1) {
            return '0';
        }
        int mid = 1 << (n - 1);
        if (k == mid) {
            return '1';
        } else if (k < mid) {
            return findKthBit(n - 1, k);
        } else {
            k = mid * 2 - k;
            return invert(findKthBit(n - 1, k));
        }
    }

    public char invert(char bit) {
        return (char) ('0' + '1' - bit);
    }
}

###JavaScript

const invert = (bit) => bit === '0' ? '1' : '0';

var findKthBit = function(n, k) {
    if (k == 1) {
        return '0';
    }
    const mid = 1 << (n - 1);
    if (k == mid) {
        return '1';
    } else if (k < mid) {
        return findKthBit(n - 1, k);
    } else {
        k = mid * 2 - k;
        return invert(findKthBit(n - 1, k));
    }
};

###C++

class Solution {
public:
    char findKthBit(int n, int k) {
        if (k == 1) {
            return '0';
        }
        int mid = 1 << (n - 1);
        if (k == mid) {
            return '1';
        } else if (k < mid) {
            return findKthBit(n - 1, k);
        } else {
            k = mid * 2 - k;
            return invert(findKthBit(n - 1, k));
        }
    }

    char invert(char bit) {
        return (char) ('0' + '1' - bit);
    }
};

###Python

class Solution:
    def findKthBit(self, n: int, k: int) -> str:
        if k == 1:
            return "0"
        
        mid = 1 << (n - 1)
        if k == mid:
            return "1"
        elif k < mid:
            return self.findKthBit(n - 1, k)
        else:
            k = mid * 2 - k
            return "0" if self.findKthBit(n - 1, k) == "1" else "1"

###C#

public class Solution {
    public char FindKthBit(int n, int k) {
        if (k == 1) {
            return '0';
        }
        int mid = 1 << (n - 1);
        if (k == mid) {
            return '1';
        } else if (k < mid) {
            return FindKthBit(n - 1, k);
        } else {
            k = mid * 2 - k;
            return Invert(FindKthBit(n - 1, k));
        }
    }

    private char Invert(char bit) {
        return (char)('0' + '1' - bit);
    }
}

###Go

func findKthBit(n int, k int) byte {
    if k == 1 {
        return '0'
    }
    mid := 1 << (n - 1)
    if k == mid {
        return '1'
    } else if k < mid {
        return findKthBit(n - 1, k)
    } else {
        k = mid*2 - k
        return invert(findKthBit(n - 1, k))
    }
}

func invert(bit byte) byte {
    if bit == '0' {
        return '1'
    }
    return '0'
}

###C

char invert(char bit) {
    return '0' + '1' - bit;
}

char findKthBit(int n, int k) {
    if (k == 1) {
        return '0';
    }
    int mid = 1 << (n - 1);
    if (k == mid) {
        return '1';
    } else if (k < mid) {
        return findKthBit(n - 1, k);
    } else {
        k = mid * 2 - k;
        return invert(findKthBit(n - 1, k));
    }
}

###TypeScript

function findKthBit(n: number, k: number): string {
    if (k === 1) {
        return '0';
    }
    const mid = 1 << (n - 1);
    if (k === mid) {
        return '1';
    } else if (k < mid) {
        return findKthBit(n - 1, k);
    } else {
        k = mid * 2 - k;
        return invert(findKthBit(n - 1, k));
    }
}

function invert(bit: string): string {
    return bit === '0' ? '1' : '0';
}

###Rust

impl Solution {
    pub fn find_kth_bit(n: i32, k: i32) -> char {
        Self::find_kth_bit_recursive(n, k)
    }
    
    fn find_kth_bit_recursive(n: i32, k: i32) -> char {
        if k == 1 {
            return '0';
        }
        let mid = 1 << (n - 1);
        if k == mid {
            return '1';
        } else if k < mid {
            return Self::find_kth_bit_recursive(n - 1, k);
        } else {
            let new_k = mid * 2 - k;
            return Self::invert(Self::find_kth_bit_recursive(n - 1, new_k));
        }
    }
    
    fn invert(bit: char) -> char {
        if bit == '0' {
            '1'
        } else {
            '0'
        }
    }
}

复杂度分析

  • 时间复杂度:$O(n)$。字符串 $S_n$ 的长度为 $2^n-1$,每次递归调用可以将查找范围缩小一半,因此时间复杂度为 $O(\log 2^n)=O(n)$。

  • 空间复杂度:$O(n)$。空间复杂度主要取决于递归调用产生的栈空间,递归调用层数不会超过 $n$。

从根到叶的二进制数之和

前言

关于二叉树后序遍历的详细说明请参考「145. 二叉树的后序遍历的官方题解」。

方法一:递归

后序遍历的访问顺序为:左子树——右子树——根节点。我们对根节点 $\textit{root}$ 进行后序遍历:

  • 如果节点是叶子节点,返回它对应的数字 $\textit{val}$。

  • 如果节点是非叶子节点,返回它的左子树和右子树对应的结果之和。

###Python

class Solution:
    def sumRootToLeaf(self, root: Optional[TreeNode]) -> int:
        def dfs(node: Optional[TreeNode], val: int) -> int:
            if node is None:
                return 0
            val = (val << 1) | node.val
            if node.left is None and node.right is None:
                return val
            return dfs(node.left, val) + dfs(node.right, val)
        return dfs(root, 0)

###C++

class Solution {
public:
    int dfs(TreeNode *root, int val) {
        if (root == nullptr) {
            return 0;
        }
        val = (val << 1) | root->val;
        if (root->left == nullptr && root->right == nullptr) {
            return val;
        }
        return dfs(root->left, val) + dfs(root->right, val);
    }

    int sumRootToLeaf(TreeNode* root) {
        return dfs(root, 0);
    }
};

###Java

class Solution {
    public int sumRootToLeaf(TreeNode root) {
        return dfs(root, 0);
    }

    public int dfs(TreeNode root, int val) {
        if (root == null) {
            return 0;
        }
        val = (val << 1) | root.val;
        if (root.left == null && root.right == null) {
            return val;
        }
        return dfs(root.left, val) + dfs(root.right, val);
    }
}

###C#

public class Solution {
    public int SumRootToLeaf(TreeNode root) {
        return DFS(root, 0);
    }

    public int DFS(TreeNode root, int val) {
        if (root == null) {
            return 0;
        }
        val = (val << 1) | root.val;
        if (root.left == null && root.right == null) {
            return val;
        }
        return DFS(root.left, val) + DFS(root.right, val);
    }
}

###C

int dfs(struct TreeNode *root, int val) {
    if (root == NULL) {
        return 0;
    }
    val = (val << 1) | root->val;
    if (root->left == NULL && root->right == NULL) {
        return val;
    }
    return dfs(root->left, val) + dfs(root->right, val);
}

int sumRootToLeaf(struct TreeNode* root){
    return dfs(root, 0);
}

###JavaScript

var sumRootToLeaf = function(root) {
    const dfs = (root, val) => {
        if (!root) {
            return 0;
        }
        val = (val << 1) | root.val;
        if (!root.left&& !root.right) {
            return val;
        }
        return dfs(root.left, val) + dfs(root.right, val);
    }
    return dfs(root, 0);
};

###go

func dfs(node *TreeNode, val int) int {
    if node == nil {
        return 0
    }
    val = val<<1 | node.Val
    if node.Left == nil && node.Right == nil {
        return val
    }
    return dfs(node.Left, val) + dfs(node.Right, val)
}

func sumRootToLeaf(root *TreeNode) int {
    return dfs(root, 0)
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 是节点数目。总共访问 $n$ 个节点。

  • 空间复杂度:$O(n)$。递归栈需要 $O(n)$ 的空间。

方法二:迭代

我们用栈来模拟递归,同时使用一个 $\textit{prev}$ 指针来记录先前访问的节点。算法步骤如下:

  1. 如果节点 $\textit{root}$ 非空,我们将不断地将它及它的左节点压入栈中。

  2. 我们从栈中获取节点:

    • 该节点的右节点为空或者等于 $\textit{prev}$,说明该节点的左子树及右子树都已经被访问,我们将它出栈。如果该节点是叶子节点,我们将它对应的数字 $\textit{val}$ 加入结果中。设置 $\textit{prev}$ 为该节点,设置 $\textit{root}$ 为空指针。

    • 该节点的右节点非空且不等于 $\textit{prev}$,我们令 $\textit{root}$ 指向该节点的右节点。

  3. 如果 $\textit{root}$ 为空指针或者栈空,中止算法,否则重复步骤 $1$。

需要注意的是,每次出入栈都需要更新 $\textit{val}$。

###Python

class Solution:
    def sumRootToLeaf(self, root: Optional[TreeNode]) -> int:
        ans = val = 0
        st = []
        pre = None
        while root or st:
            while root:
                val = (val << 1) | root.val
                st.append(root)
                root = root.left
            root = st[-1]
            if root.right is None or root.right == pre:
                if root.left is None and root.right is None:
                    ans += val
                val >>= 1
                st.pop()
                pre = root
                root = None
            else:
                root = root.right
        return ans

###C++

class Solution {
public:
    int sumRootToLeaf(TreeNode* root) {
        stack<TreeNode *> st;
        int val = 0, ret = 0;
        TreeNode *prev = nullptr;
        while (root != nullptr || !st.empty()) {
            while (root != nullptr) {
                val = (val << 1) | root->val;
                st.push(root);
                root = root->left;
            }
            root = st.top();
            if (root->right == nullptr || root->right == prev) {
                if (root->left == nullptr && root->right == nullptr) {
                    ret += val;
                }
                val >>= 1;
                st.pop();
                prev = root;
                root = nullptr;
            } else {
                root = root->right;
            }
        }
        return ret;
    }
};

###Java

class Solution {
    public int sumRootToLeaf(TreeNode root) {
        Deque<TreeNode> stack = new ArrayDeque<TreeNode>();
        int val = 0, ret = 0;
        TreeNode prev = null;
        while (root != null || !stack.isEmpty()) {
            while (root != null) {
                val = (val << 1) | root.val;
                stack.push(root);
                root = root.left;
            }
            root = stack.peek();
            if (root.right == null || root.right == prev) {
                if (root.left == null && root.right == null) {
                    ret += val;
                }
                val >>= 1;
                stack.pop();
                prev = root;
                root = null;
            } else {
                root = root.right;
            }
        }
        return ret;
    }
}

###C#

public class Solution {
    public int SumRootToLeaf(TreeNode root) {
        Stack<TreeNode> stack = new Stack<TreeNode>();
        int val = 0, ret = 0;
        TreeNode prev = null;
        while (root != null || stack.Count > 0) {
            while (root != null) {
                val = (val << 1) | root.val;
                stack.Push(root);
                root = root.left;
            }
            root = stack.Peek();
            if (root.right == null || root.right == prev) {
                if (root.left == null && root.right == null) {
                    ret += val;
                }
                val >>= 1;
                stack.Pop();
                prev = root;
                root = null;
            } else {
                root = root.right;
            }
        }
        return ret;
    }
}

###C

#define MAX_NODE_SIZE 1000

int sumRootToLeaf(struct TreeNode* root) {
    struct TreeNode ** stack = (struct TreeNode **)malloc(sizeof(struct TreeNode *) * MAX_NODE_SIZE);
    int top = 0;
    int val = 0, ret = 0;
    struct TreeNode *prev = NULL;
    while (root != NULL || top) {
        while (root != NULL) {
            val = (val << 1) | root->val;
            stack[top++] = root;
            root = root->left;
        }
        root = stack[top - 1];
        if (root->right == NULL || root->right == prev) {
            if (root->left == NULL && root->right == NULL) {
                ret += val;
            }
            val >>= 1;
            top--;
            prev = root;
            root = NULL;
        } else {
            root = root->right;
        }
    }
    free(stack);
    return ret;
}

###JavaScript

var sumRootToLeaf = function(root) {
    const stack = [];
    let val = 0, ret = 0;
    let prev = null;
    while (root || stack.length) {
        while (root) {
            val = (val << 1) | root.val;
            stack.push(root);
            root = root.left;
        }
        root = stack[stack.length - 1];
        if (!root.right || root.right === prev) {
            if (!root.left && !root.right) {
                ret += val;
            }
            val >>= 1;
            stack.pop();
            prev = root;
            root = null;
        } else {
            root = root.right;
        }
    }
    return ret;
};

###go

func sumRootToLeaf(root *TreeNode) (ans int) {
    val, st := 0, []*TreeNode{}
    var pre *TreeNode
    for root != nil || len(st) > 0 {
        for root != nil {
            val = val<<1 | root.Val
            st = append(st, root)
            root = root.Left
        }
        root = st[len(st)-1]
        if root.Right == nil || root.Right == pre {
            if root.Left == nil && root.Right == nil {
                ans += val
            }
            val >>= 1
            st = st[:len(st)-1]
            pre = root
            root = nil
        } else {
            root = root.Right
        }
    }
    return
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 是节点数目。总共访问 $n$ 个节点。

  • 空间复杂度:$O(n)$。栈最多压入 $n$ 个节点。

❌