普通视图

发现新文章,点击刷新页面。
今天 — 2026年4月12日首页

二指输入的的最小距离

2020年2月19日 11:20

方法一:动态规划

我们用 dp[i][l][r] 表示在输入了字符串 word 的第 i 个字母后,左手的位置为 l,右手的位置为 r,达到该状态的最小移动距离。这里的位置为指向的字母编号,例如 A 对应 0B 对应 1,以此类推,而非字母在键盘上的位置。这样做的好处是将字母的位置映射成一个整数而非二维的坐标,使得我们更加方便地进行状态转移。

那么如何进行状态转移呢?我们首先需要看出一个非常重要的性质:对于状态 dp[i][l][r],要么 word[i] == l,要么 word[i] == r,即在输入了第 i 个字母后,左手和右手中至少有一个在 word[i] 的位置。我们可以根据这两种情况,分别进行状态转移:

  • word[i] == l 时,左手在 word[i] 的位置。我们需要考虑在输入字符串 word 的第 i - 1 个字母时,是左手还是右手在 word[i - 1] 的位置:

    • 如果左手在 word[i - 1] 的位置,那么在输入第 i 个字母时,左手从 word[i - 1] 移动至 word[i],状态转移方程为:

      dp[i][l = word[i]][r] = dp[i - 1][l0 = word[i - 1]][r] + dist(word[i - 1], word[i])
      
    • 如果右手在 word[i - 1] 的位置,那么由于第 i 个字母使用了左手,右手就没有移动,即 word[i - 1] == r。同时,在输入 word[i1] 之前的左手位置也无关紧要,可以为任意的 l0,状态转移方程为:

      dp[i][l = word[i]][r = word[i - 1]] = dp[i - 1][l0][r = word[i - 1]] + dist(l0, word[i])
      
  • word[i] == r 时,右手在 word[i] 的位置。我们需要考虑在输入字符串 word 的第 i - 1 个字母时,是右手还是左手在 word[i - 1] 的位置:

    • 如果右手在 word[i - 1] 的位置,那么在输入第 i 个字母时,右手从 word[i - 1] 移动至 word[i],状态转移方程为:

      dp[i][l][r = word[i]] = dp[i - 1][l][r0 = word[i - 1]] + dist(word[i - 1], word[i])
      
    • 如果左手在 word[i - 1] 的位置,那么由于第 i 个字母使用了右手,左手就没有移动,即 word[i - 1] == l。同时,在输入 word[i] 之前的右手位置也无关紧要,可以为任意的 r0,状态转移方程为:

      dp[i][l = word[i - 1]][r = word[i]] = dp[i - 1][l = word[i - 1]][r0] + dist(r0, word[i])
      

对于每一个状态 dp[i][l][r],我们取它所有转移中的最小值,即为输入了字符串 word 的第 i 个字母后,左手的位置为 l,右手的位置为 r,达到该状态的最小移动距离。

在这个动态规划中,我们还需要考虑不合法的状态以及边界状态。对于某一个不合法的状态,如果用它来进行状态转移,可能会使得 dp[i][l][r] 取到一个更小且不合法的值。因此,我们一般会给所有不合法的状态赋予一个非常大的值(例如 C++ 中的整数最大值 INT_MAX),这样即使用它来进行状态转移,也会因为本身值非常大的缘故,对最优解没有任何影响。在考虑边界状态时,由于题目中规定两根手指的起始位置是零代价的,因此对于字符串中的第 0 个字母 word[0],输入它的最小移动距离为 0。此时要么左手的位置为 word[0],要么右手的位置为 word[0],因此我们可以将所有的 dp[0][l = word[0]][r] 以及 dp[0][l][r = word[0]] 作为边界状态,它们的值为 0

###C++

class Solution {
public:
    int getDistance(int p, int q) {
        int x1 = p / 6, y1 = p % 6;
        int x2 = q / 6, y2 = q % 6;
        return abs(x1 - x2) + abs(y1 - y2);
    }

    int minimumDistance(string word) {
        int n = word.size();
        int dp[n][26][26];
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < 26; ++j) {
                fill(dp[i][j], dp[i][j] + 26, INT_MAX >> 1);
            }
        }
        for (int i = 0; i < 26; ++i) {
            dp[0][i][word[0] - 'A'] = dp[0][word[0] - 'A'][i] = 0;
        }
        
        for (int i = 1; i < n; ++i) {
            int cur = word[i] - 'A';
            int prev = word[i - 1] - 'A';
            int d = getDistance(prev, cur);
            for (int j = 0; j < 26; ++j) {
                dp[i][cur][j] = min(dp[i][cur][j], dp[i - 1][prev][j] + d);
                dp[i][j][cur] = min(dp[i][j][cur], dp[i - 1][j][prev] + d);
                if (prev == j) {
                    for (int k = 0; k < 26; ++k) {
                        int d0 = getDistance(k, cur);
                        dp[i][cur][j] = min(dp[i][cur][j], dp[i - 1][k][j] + d0);
                        dp[i][j][cur] = min(dp[i][j][cur], dp[i - 1][j][k] + d0);
                    }
                }
            }
        }

        int ans = INT_MAX >> 1;
        for (int i = 0; i < 26; ++i) {
            for (int j = 0; j < 26; ++j) {
                ans = min(ans, dp[n - 1][i][j]);
            }
        }
        return ans;
    }
};

###Python

class Solution:
    def minimumDistance(self, word: str) -> int:
        n = len(word)
        BIG = 2**30
        dp = [[[BIG] * 26 for x in range(26)] for y in range(n)]
        for i in range(26):
            dp[0][i][ord(word[0]) - 65] = 0
            dp[0][ord(word[0]) - 65][i] = 0
    
        def getDistance(p, q):
            x1, y1 = p // 6, p % 6
            x2, y2 = q // 6, q % 6
            return abs(x1 - x2) + abs(y1 - y2)

        for i in range(1, n):
            cur, prev = ord(word[i]) - 65, ord(word[i - 1]) - 65
            d = getDistance(prev, cur)
            for j in range(26):
                dp[i][cur][j] = min(dp[i][cur][j], dp[i - 1][prev][j] + d)
                dp[i][j][cur] = min(dp[i][j][cur], dp[i - 1][j][prev] + d)
                if prev == j:
                    for k in range(26):
                        d0 = getDistance(k, cur)
                        dp[i][cur][j] = min(dp[i][cur][j], dp[i - 1][k][j] + d0)
                        dp[i][j][cur] = min(dp[i][j][cur], dp[i - 1][j][k] + d0)
        
        ans = min(min(dp[n - 1][x]) for x in range(26))
        return ans

###Java

class Solution {
    private int getDistance(int p, int q) {
        int x1 = p / 6, y1 = p % 6;
        int x2 = q / 6, y2 = q % 6;
        return Math.abs(x1 - x2) + Math.abs(y1 - y2);
    }

    public int minimumDistance(String word) {
        int n = word.length();
        int[][][] dp = new int[n][26][26];
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < 26; ++j) {
                for (int k = 0; k < 26; ++k) {
                    dp[i][j][k] = Integer.MAX_VALUE / 2;
                }
            }
        }
        
        for (int i = 0; i < 26; ++i) {
            dp[0][i][word.charAt(0) - 'A'] = 0;
            dp[0][word.charAt(0) - 'A'][i] = 0;
        }
        
        for (int i = 1; i < n; ++i) {
            int cur = word.charAt(i) - 'A';
            int prev = word.charAt(i - 1) - 'A';
            int d = getDistance(prev, cur);
            
            for (int j = 0; j < 26; ++j) {
                dp[i][cur][j] = Math.min(dp[i][cur][j], dp[i - 1][prev][j] + d);
                dp[i][j][cur] = Math.min(dp[i][j][cur], dp[i - 1][j][prev] + d);
                
                if (prev == j) {
                    for (int k = 0; k < 26; ++k) {
                        int d0 = getDistance(k, cur);
                        dp[i][cur][j] = Math.min(dp[i][cur][j], dp[i - 1][k][j] + d0);
                        dp[i][j][cur] = Math.min(dp[i][j][cur], dp[i - 1][j][k] + d0);
                    }
                }
            }
        }
        
        int ans = Integer.MAX_VALUE / 2;
        for (int i = 0; i < 26; ++i) {
            for (int j = 0; j < 26; ++j) {
                ans = Math.min(ans, dp[n - 1][i][j]);
            }
        }
        return ans;
    }
}

###C#

public class Solution {
    private int GetDistance(int p, int q) {
        int x1 = p / 6, y1 = p % 6;
        int x2 = q / 6, y2 = q % 6;
        return Math.Abs(x1 - x2) + Math.Abs(y1 - y2);
    }

    public int MinimumDistance(string word) {
        int n = word.Length;
        int[,,] dp = new int[n, 26, 26];
        
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < 26; ++j) {
                for (int k = 0; k < 26; ++k) {
                    dp[i, j, k] = int.MaxValue / 2;
                }
            }
        }
        
        for (int i = 0; i < 26; ++i) {
            dp[0, i, word[0] - 'A'] = 0;
            dp[0, word[0] - 'A', i] = 0;
        }
        for (int i = 1; i < n; ++i) {
            int cur = word[i] - 'A';
            int prev = word[i - 1] - 'A';
            int d = GetDistance(prev, cur);
            
            for (int j = 0; j < 26; ++j) {
                dp[i, cur, j] = Math.Min(dp[i, cur, j], dp[i - 1, prev, j] + d);
                dp[i, j, cur] = Math.Min(dp[i, j, cur], dp[i - 1, j, prev] + d);
                
                if (prev == j) {
                    for (int k = 0; k < 26; ++k) {
                        int d0 = GetDistance(k, cur);
                        dp[i, cur, j] = Math.Min(dp[i, cur, j], dp[i - 1, k, j] + d0);
                        dp[i, j, cur] = Math.Min(dp[i, j, cur], dp[i - 1, j, k] + d0);
                    }
                }
            }
        }
        
        int ans = int.MaxValue / 2;
        for (int i = 0; i < 26; ++i) {
            for (int j = 0; j < 26; ++j) {
                ans = Math.Min(ans, dp[n - 1, i, j]);
            }
        }
        return ans;
    }
}

###Go

func getDistance(p, q int) int {
    x1, y1 := p/6, p%6
    x2, y2 := q/6, q%6
    return abs(x1 - x2) + abs(y1 - y2)
}

func abs(x int) int {
    if x < 0 {
        return -x
    }
    return x
}

func minimumDistance(word string) int {
    n := len(word)
    dp := make([][26][26]int, n)
    
    for i := 0; i < n; i++ {
        for j := 0; j < 26; j++ {
            for k := 0; k < 26; k++ {
                dp[i][j][k] = 1 << 30
            }
        }
    }
    
    firstChar := int(word[0] - 'A')
    for i := 0; i < 26; i++ {
        dp[0][i][firstChar] = 0
        dp[0][firstChar][i] = 0
    }
    
    for i := 1; i < n; i++ {
        cur := int(word[i] - 'A')
        prev := int(word[i-1] - 'A')
        d := getDistance(prev, cur)
        
        for j := 0; j < 26; j++ {
            dp[i][cur][j] = min(dp[i][cur][j], dp[i-1][prev][j]+d)
            dp[i][j][cur] = min(dp[i][j][cur], dp[i-1][j][prev]+d)
            
            if prev == j {
                for k := 0; k < 26; k++ {
                    d0 := getDistance(k, cur)
                    dp[i][cur][j] = min(dp[i][cur][j], dp[i-1][k][j]+d0)
                    dp[i][j][cur] = min(dp[i][j][cur], dp[i-1][j][k]+d0)
                }
            }
        }
    }
    
    ans := 1 << 30
    for i := 0; i < 26; i++ {
        for j := 0; j < 26; j++ {
            ans = min(ans, dp[n-1][i][j])
        }
    }
    return ans
}

###C

int getDistance(int p, int q) {
    int x1 = p / 6, y1 = p % 6;
    int x2 = q / 6, y2 = q % 6;
    return abs(x1 - x2) + abs(y1 - y2);
}

int minimumDistance(char* word) {
    int n = strlen(word);
    int*** dp = (int***)malloc(n * sizeof(int**));
    for (int i = 0; i < n; ++i) {
        dp[i] = (int**)malloc(26 * sizeof(int*));
        for (int j = 0; j < 26; ++j) {
            dp[i][j] = (int*)malloc(26 * sizeof(int));
            for (int k = 0; k < 26; ++k) {
                dp[i][j][k] = INT_MAX / 2;
            }
        }
    }
    
    for (int i = 0; i < 26; ++i) {
        dp[0][i][word[0] - 'A'] = 0;
        dp[0][word[0] - 'A'][i] = 0;
    }
    for (int i = 1; i < n; ++i) {
        int cur = word[i] - 'A';
        int prev = word[i - 1] - 'A';
        int d = getDistance(prev, cur);
        
        for (int j = 0; j < 26; ++j) {
            dp[i][cur][j] = fmin(dp[i][cur][j], dp[i - 1][prev][j] + d);
            dp[i][j][cur] = fmin(dp[i][j][cur], dp[i - 1][j][prev] + d);
            
            if (prev == j) {
                for (int k = 0; k < 26; ++k) {
                    int d0 = getDistance(k, cur);
                    dp[i][cur][j] = fmin(dp[i][cur][j], dp[i - 1][k][j] + d0);
                    dp[i][j][cur] = fmin(dp[i][j][cur], dp[i - 1][j][k] + d0);
                }
            }
        }
    }
    
    int ans = INT_MAX / 2;
    for (int i = 0; i < 26; ++i) {
        for (int j = 0; j < 26; ++j) {
            if (ans > dp[n - 1][i][j]) {
                ans = dp[n - 1][i][j];
            }
        }
    }
    
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < 26; ++j) {
            free(dp[i][j]);
        }
        free(dp[i]);
    }
    free(dp);
    
    return ans;
}

###JavaScript

var minimumDistance = function(word) {
    const n = word.length;
    const getDistance = (p, q) => {
        const x1 = Math.floor(p / 6), y1 = p % 6;
        const x2 = Math.floor(q / 6), y2 = q % 6;
        return Math.abs(x1 - x2) + Math.abs(y1 - y2);
    };
    
    const dp = new Array(n);
    for (let i = 0; i < n; i++) {
        dp[i] = new Array(26);
        for (let j = 0; j < 26; j++) {
            dp[i][j] = new Array(26).fill(Math.floor(Number.MAX_SAFE_INTEGER / 2));
        }
    }
    
    const firstChar = word.charCodeAt(0) - 65;
    for (let i = 0; i < 26; i++) {
        dp[0][i][firstChar] = 0;
        dp[0][firstChar][i] = 0;
    }
    
    for (let i = 1; i < n; i++) {
        const cur = word.charCodeAt(i) - 65;
        const prev = word.charCodeAt(i - 1) - 65;
        const d = getDistance(prev, cur);
        
        for (let j = 0; j < 26; j++) {
            dp[i][cur][j] = Math.min(dp[i][cur][j], dp[i - 1][prev][j] + d);
            dp[i][j][cur] = Math.min(dp[i][j][cur], dp[i - 1][j][prev] + d);
            
            if (prev === j) {
                for (let k = 0; k < 26; k++) {
                    const d0 = getDistance(k, cur);
                    dp[i][cur][j] = Math.min(dp[i][cur][j], dp[i - 1][k][j] + d0);
                    dp[i][j][cur] = Math.min(dp[i][j][cur], dp[i - 1][j][k] + d0);
                }
            }
        }
    }
    
    let ans = Number.MAX_SAFE_INTEGER;
    for (let i = 0; i < 26; i++) {
        for (let j = 0; j < 26; j++) {
            ans = Math.min(ans, dp[n - 1][i][j]);
        }
    }
    return ans;
};

###TypeScript

function minimumDistance(word: string): number {
    const n = word.length;
    const getDistance = (p: number, q: number): number => {
        const x1 = Math.floor(p / 6), y1 = p % 6;
        const x2 = Math.floor(q / 6), y2 = q % 6;
        return Math.abs(x1 - x2) + Math.abs(y1 - y2);
    };
    
    const dp: number[][][] = new Array(n);
    for (let i = 0; i < n; i++) {
        dp[i] = new Array(26);
        for (let j = 0; j < 26; j++) {
            dp[i][j] = new Array(26).fill(Math.floor(Number.MAX_SAFE_INTEGER / 2));
        }
    }
    const firstChar = word.charCodeAt(0) - 65;
    for (let i = 0; i < 26; i++) {
        dp[0][i][firstChar] = 0;
        dp[0][firstChar][i] = 0;
    }
    
    for (let i = 1; i < n; i++) {
        const cur = word.charCodeAt(i) - 65;
        const prev = word.charCodeAt(i - 1) - 65;
        const d = getDistance(prev, cur);
        
        for (let j = 0; j < 26; j++) {
            dp[i][cur][j] = Math.min(dp[i][cur][j], dp[i - 1][prev][j] + d);
            dp[i][j][cur] = Math.min(dp[i][j][cur], dp[i - 1][j][prev] + d);
            
            if (prev === j) {
                for (let k = 0; k < 26; k++) {
                    const d0 = getDistance(k, cur);
                    dp[i][cur][j] = Math.min(dp[i][cur][j], dp[i - 1][k][j] + d0);
                    dp[i][j][cur] = Math.min(dp[i][j][cur], dp[i - 1][j][k] + d0);
                }
            }
        }
    }
    
    let ans = Number.MAX_SAFE_INTEGER;
    for (let i = 0; i < 26; i++) {
        for (let j = 0; j < 26; j++) {
            ans = Math.min(ans, dp[n - 1][i][j]);
        }
    }
    return ans;
}

###Rust

impl Solution {
    fn get_distance(p: i32, q: i32) -> i32 {
        let x1 = p / 6;
        let y1 = p % 6;
        let x2 = q / 6;
        let y2 = q % 6;
        (x1 - x2).abs() + (y1 - y2).abs()
    }

    pub fn minimum_distance(word: String) -> i32 {
        let n = word.len();
        let word_chars: Vec<char> = word.chars().collect();
        let mut dp = vec![vec![vec![i32::MAX >> 1; 26]; 26]; n];
        let first_char = (word_chars[0] as u8 - b'A') as usize;
        for i in 0..26 {
            dp[0][i][first_char] = 0;
            dp[0][first_char][i] = 0;
        }
        
        for i in 1..n {
            let cur = (word_chars[i] as u8 - b'A') as i32;
            let prev = (word_chars[i - 1] as u8 - b'A') as i32;
            let d = Self::get_distance(prev, cur);
            
            for j in 0..26 {
                let j_i32 = j as i32;
                dp[i][cur as usize][j] = dp[i][cur as usize][j].min(
                    dp[i - 1][prev as usize][j].saturating_add(d)
                );
                dp[i][j][cur as usize] = dp[i][j][cur as usize].min(
                    dp[i - 1][j][prev as usize].saturating_add(d)
                );
                
                if prev == j_i32 {
                    for k in 0..26 {
                        let d0 = Self::get_distance(k as i32, cur);
                        dp[i][cur as usize][j] = dp[i][cur as usize][j].min(
                            dp[i - 1][k][j].saturating_add(d0)
                        );
                        dp[i][j][cur as usize] = dp[i][j][cur as usize].min(
                            dp[i - 1][j][k].saturating_add(d0)
                        );
                    }
                }
            }
        }
        
        let mut ans = i32::MAX >> 1;
        for i in 0..26 {
            for j in 0..26 {
                ans = ans.min(dp[n - 1][i][j]);
            }
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$O(|\Sigma|N)$,其中 $N$ 是字符串 word 的长度,$|\Sigma|$ 是可能出现的字母数量,在本题中 $|\Sigma| = 26$。对于状态 dp[i][l][r],枚举 i 需要的时间复杂度为 $O(N)$,在此之后,如果 word[i] == l,根据上面的状态转移方程:

    • 如果左手在 word[i - 1] 的位置,那么单次状态转移的时间复杂度为 $O(1)$,需要对所有的 r 都进行转移,总时间复杂度为 $O(|\Sigma|)$;

    • 如果右手在 word[i - 1] 的位置,那么 r == word[i - 1]。虽然我们要枚举 l0,但是合法的 r 只有一个,因此总时间复杂度也为 $O(|\Sigma|)$。

    如果 word[i] == r,分析的过程相同,在此不再赘述。这样总时间复杂度即为 $O(|\Sigma|N)$。

  • 空间复杂度:$O(|\Sigma|^2 N)$。

方法二:动态规划 + 空间优化

在方法一中,我们提到了一条重要的性质:对于状态 dp[i][l][r],要么 word[i] == l,要么 word[i] == r,即在输入了第 i 个字母后,左手和右手中至少有一个在 word[i] 的位置。那么对于每一个 i,我们其实只需要存储 $2|\Sigma|$ 而不是 $|\Sigma|^2$ 个状态。例如我们可以用 dp[i][op][rest] 表示状态,其中 op 的值只能为 01op == 0 表示左手在 word[i] 的位置,op == 1 表示右手在 word[i] 的位置,而 rest 表示不在 word[i] 位置的另一只手的位置。这样我们在状态转移方程几乎不变的前提下,减少了动态规划需要的空间。

那么我们是否还可以继续进行优化呢?我们可以发现,在方法一中,状态转移方程具有高度对称性,那么我们可以断定,dp[i][op = 0][rest]dp[i][op = 1][rest] 的值一定是相等的。这是因为 dp[i][op = 0][rest] 表示左手在 word[i] 的位置且右手在 rest 的位置,如果我们将左右手互换,那么我们同样可以使用 dp[i][op = 0][rest] 的移动距离使得右手在 word[i] 的位置且左手在 rest 的位置,而这恰好就是 dp[i][op = 1][rest]

因此我们可以直接使用 dp[i][rest] 进行状态转移,其表示一只手在 word[i] 的位置,另一只手在 rest 的位置的最小移动距离。我们并不需要关心具体哪只手在 word[i] 的位置,因为两只手是完全对称的。这样以来,我们将三维的动态规划优化至了二维,大大减少了空间的使用。

###C++

class Solution {
public:
    int getDistance(int p, int q) {
        int x1 = p / 6, y1 = p % 6;
        int x2 = q / 6, y2 = q % 6;
        return abs(x1 - x2) + abs(y1 - y2);
    }

    int minimumDistance(string word) {
        int n = word.size();
        vector<vector<int>> dp(n, vector<int>(26, INT_MAX >> 1));
        fill(dp[0].begin(), dp[0].end(), 0);
        
        for (int i = 1; i < n; ++i) {
            int cur = word[i] - 'A';
            int prev = word[i - 1] - 'A';
            int d = getDistance(prev, cur);
            for (int j = 0; j < 26; ++j) {
                dp[i][j] = min(dp[i][j], dp[i - 1][j] + d);
                if (prev == j) {
                    for (int k = 0; k < 26; ++k) {
                        int d0 = getDistance(k, cur);
                        dp[i][j] = min(dp[i][j], dp[i - 1][k] + d0);
                    }
                }
            }
        }

        int ans = *min_element(dp[n - 1].begin(), dp[n - 1].end());
        return ans;
    }
};

###Python

class Solution:
    def minimumDistance(self, word: str) -> int:
        n = len(word)
        BIG = 2**30
        dp = [[0] * 26] + [[BIG] * 26 for _ in range(n - 1)]
        
        def getDistance(p, q):
            x1, y1 = p // 6, p % 6
            x2, y2 = q // 6, q % 6
            return abs(x1 - x2) + abs(y1 - y2)

        for i in range(1, n):
            cur, prev = ord(word[i]) - 65, ord(word[i - 1]) - 65
            d = getDistance(prev, cur)
            for j in range(26):
                dp[i][j] = min(dp[i][j], dp[i - 1][j] + d)
                if prev == j:
                    for k in range(26):
                        d0 = getDistance(k, cur)
                        dp[i][j] = min(dp[i][j], dp[i - 1][k] + d0)

        ans = min(dp[n - 1])
        return ans

###Java

class Solution {
    private int getDistance(int p, int q) {
        int x1 = p / 6, y1 = p % 6;
        int x2 = q / 6, y2 = q % 6;
        return Math.abs(x1 - x2) + Math.abs(y1 - y2);
    }

    public int minimumDistance(String word) {
        int n = word.length();
        int[][] dp = new int[n][26];
        for (int i = 0; i < n; i++) {
            Arrays.fill(dp[i], Integer.MAX_VALUE / 2);
        }
        Arrays.fill(dp[0], 0);
        
        for (int i = 1; i < n; i++) {
            int cur = word.charAt(i) - 'A';
            int prev = word.charAt(i - 1) - 'A';
            int d = getDistance(prev, cur);
            
            for (int j = 0; j < 26; j++) {
                dp[i][j] = Math.min(dp[i][j], dp[i - 1][j] + d);
                if (prev == j) {
                    for (int k = 0; k < 26; k++) {
                        int d0 = getDistance(k, cur);
                        dp[i][j] = Math.min(dp[i][j], dp[i - 1][k] + d0);
                    }
                }
            }
        }
        
        int ans = Integer.MAX_VALUE / 2;
        for (int value : dp[n - 1]) {
            ans = Math.min(ans, value);
        }
        return ans;
    }
}

###C#

public class Solution {
    private int GetDistance(int p, int q) {
        int x1 = p / 6, y1 = p % 6;
        int x2 = q / 6, y2 = q % 6;
        return Math.Abs(x1 - x2) + Math.Abs(y1 - y2);
    }

    public int MinimumDistance(string word) {
        int n = word.Length;
        int[,] dp = new int[n, 26];
        
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < 26; j++) {
                dp[i, j] = int.MaxValue / 2;
            }
        }
        
        for (int j = 0; j < 26; j++) {
            dp[0, j] = 0;
        }
        for (int i = 1; i < n; i++) {
            int cur = word[i] - 'A';
            int prev = word[i - 1] - 'A';
            int d = GetDistance(prev, cur);
            
            for (int j = 0; j < 26; j++) {
                dp[i, j] = Math.Min(dp[i, j], dp[i - 1, j] + d);
                if (prev == j) {
                    for (int k = 0; k < 26; k++) {
                        int d0 = GetDistance(k, cur);
                        dp[i, j] = Math.Min(dp[i, j], dp[i - 1, k] + d0);
                    }
                }
            }
        }
        
        int ans = int.MaxValue / 2;
        for (int j = 0; j < 26; j++) {
            ans = Math.Min(ans, dp[n - 1, j]);
        }
        return ans;
    }
}

###Go

func getDistance(p, q int) int {
    x1, y1 := p / 6, p % 6
    x2, y2 := q / 6, q % 6
    return abs(x1 - x2) + abs(y1 - y2)
}

func abs(x int) int {
    if x < 0 {
        return -x
    }
    return x
}

func minimumDistance(word string) int {
    n := len(word)
    dp := make([][]int, n)
    for i := range dp {
        dp[i] = make([]int, 26)
        for j := range dp[i] {
            dp[i][j] = 1 << 30
        }
    }
    for j := 0; j < 26; j++ {
        dp[0][j] = 0
    }
    
    for i := 1; i < n; i++ {
        cur := int(word[i] - 'A')
        prev := int(word[i-1] - 'A')
        d := getDistance(prev, cur)
        
        for j := 0; j < 26; j++ {
            dp[i][j] = min(dp[i][j], dp[i-1][j]+d)
            if prev == j {
                for k := 0; k < 26; k++ {
                    d0 := getDistance(k, cur)
                    dp[i][j] = min(dp[i][j], dp[i-1][k]+d0)
                }
            }
        }
    }
    
    ans := 1 << 30
    for j := 0; j < 26; j++ {
        ans = min(ans, dp[n-1][j])
    }
    return ans
}

###C

int getDistance(int p, int q) {
    int x1 = p / 6, y1 = p % 6;
    int x2 = q / 6, y2 = q % 6;
    return abs(x1 - x2) + abs(y1 - y2);
}

int minimumDistance(char* word) {
    int n = strlen(word);
    int** dp = (int**)malloc(n * sizeof(int*));
    for (int i = 0; i < n; i++) {
        dp[i] = (int*)malloc(26 * sizeof(int));
        for (int j = 0; j < 26; j++) {
            dp[i][j] = INT_MAX / 2;
        }
    }
    for (int j = 0; j < 26; j++) {
        dp[0][j] = 0;
    }
    
    for (int i = 1; i < n; i++) {
        int cur = word[i] - 'A';
        int prev = word[i - 1] - 'A';
        int d = getDistance(prev, cur);
        
        for (int j = 0; j < 26; j++) {
            dp[i][j] = fmin(dp[i][j], dp[i - 1][j] + d);
            if (prev == j) {
                for (int k = 0; k < 26; k++) {
                    int d0 = getDistance(k, cur);
                    dp[i][j] = fmin(dp[i][j], dp[i - 1][k] + d0);
                }
            }
        }
    }
    
    int ans = INT_MAX / 2;
    for (int j = 0; j < 26; j++) {
        if (ans > dp[n - 1][j]) {
            ans = dp[n - 1][j];
        }
    }
    for (int i = 0; i < n; i++) {
        free(dp[i]);
    }
    free(dp);
    
    return ans;
}

###JavaScript

var minimumDistance = function(word) {
    const n = word.length;
    const getDistance = (p, q) => {
        const x1 = Math.floor(p / 6), y1 = p % 6;
        const x2 = Math.floor(q / 6), y2 = q % 6;
        return Math.abs(x1 - x2) + Math.abs(y1 - y2);
    };
    
    const dp = new Array(n);
    for (let i = 0; i < n; i++) {
        dp[i] = new Array(26).fill(Math.floor(Number.MAX_SAFE_INTEGER / 2));
    }
    
    for (let j = 0; j < 26; j++) {
        dp[0][j] = 0;
    }
    
    for (let i = 1; i < n; i++) {
        const cur = word.charCodeAt(i) - 65;
        const prev = word.charCodeAt(i - 1) - 65;
        const d = getDistance(prev, cur);
        
        for (let j = 0; j < 26; j++) {
            dp[i][j] = Math.min(dp[i][j], dp[i - 1][j] + d);
            
            if (prev === j) {
                for (let k = 0; k < 26; k++) {
                    const d0 = getDistance(k, cur);
                    dp[i][j] = Math.min(dp[i][j], dp[i - 1][k] + d0);
                }
            }
        }
    }
    
    let ans = Math.floor(Number.MAX_SAFE_INTEGER / 2);
    for (let j = 0; j < 26; j++) {
        ans = Math.min(ans, dp[n - 1][j]);
    }
    return ans;
};

###TypeScript

function minimumDistance(word: string): number {
    const n = word.length;
    const getDistance = (p: number, q: number): number => {
        const x1 = Math.floor(p / 6), y1 = p % 6;
        const x2 = Math.floor(q / 6), y2 = q % 6;
        return Math.abs(x1 - x2) + Math.abs(y1 - y2);
    };
    
    const dp: number[][] = new Array(n);
    for (let i = 0; i < n; i++) {
        dp[i] = new Array(26).fill(Math.floor(Number.MAX_SAFE_INTEGER / 2));
    }
    for (let j = 0; j < 26; j++) {
        dp[0][j] = 0;
    }
    
    for (let i = 1; i < n; i++) {
        const cur = word.charCodeAt(i) - 65;
        const prev = word.charCodeAt(i - 1) - 65;
        const d = getDistance(prev, cur);
        
        for (let j = 0; j < 26; j++) {
            dp[i][j] = Math.min(dp[i][j], dp[i - 1][j] + d);
            
            if (prev === j) {
                for (let k = 0; k < 26; k++) {
                    const d0 = getDistance(k, cur);
                    dp[i][j] = Math.min(dp[i][j], dp[i - 1][k] + d0);
                }
            }
        }
    }
    
    let ans = Math.floor(Number.MAX_SAFE_INTEGER / 2);
    for (let j = 0; j < 26; j++) {
        ans = Math.min(ans, dp[n - 1][j]);
    }
    return ans;
}

###Rust

impl Solution {
    fn get_distance(p: i32, q: i32) -> i32 {
        let x1 = p / 6;
        let y1 = p % 6;
        let x2 = q / 6;
        let y2 = q % 6;
        (x1 - x2).abs() + (y1 - y2).abs()
    }

    pub fn minimum_distance(word: String) -> i32 {
        let n = word.len();
        let word_bytes = word.as_bytes();
        let mut dp = vec![vec![i32::MAX >> 1; 26]; n];
        
        for j in 0..26 {
            dp[0][j] = 0;
        }
        
        for i in 1..n {
            let cur = (word_bytes[i] - b'A') as i32;
            let prev = (word_bytes[i - 1] - b'A') as i32;
            let d = Self::get_distance(prev, cur);
            
            for j in 0..26 {
                let j_i32 = j as i32;
                dp[i][j] = dp[i][j].min(dp[i - 1][j].saturating_add(d));
                
                if prev == j_i32 {
                    for k in 0..26 {
                        let d0 = Self::get_distance(k as i32, cur);
                        dp[i][j] = dp[i][j].min(dp[i - 1][k].saturating_add(d0));
                    }
                }
            }
        }
        
        let mut ans = i32::MAX >> 1;
        for j in 0..26 {
            ans = ans.min(dp[n - 1][j]);
        }
        ans
    }
}

复杂度分析

  • 时间复杂度:$O(|\Sigma|N)$。

  • 空间复杂度:$O(|\Sigma|N)$。

昨天 — 2026年4月11日首页

三个相等元素之间的最小距离 II

2026年3月31日 10:57

方法一:遍历 + 哈希表

思路与算法

分析题目可知,所求三元组的距离实际上是广义三角形的三边之和,不管选取的三个点顺序如何,长度一定等于两倍的端点构成的线段的长度;换而言之,设最右侧点的下标是 $k$,最左侧点的下标是 $i$,所求的距离就是 $2 \times (k - i)$。

显然,对于所有相同元素对应下标构成的有效三元组,其最小距离必定在三个相邻元素构成的三元组间产生。以此为突破口,类比链表,我们可以通过维护前驱数组或者后继数组的方式快速求解当前元素的前驱和后继,以便计算距离并更新答案。

下面以后继数组为例讲解具体实现,采用前驱数组的方法只需要一次遍历,留给读者自行思考。

首先定义后继数组 $\textit{next}$,设 $\textit{next}[i]$ 记录了 $\textit{nums}[i]$ 在 $\textit{nums}$ 中下一次出现的位置。倒序遍历 $\textit{nums}$,配合哈希表记录 $\textit{nums}[i]$ 在倒序遍历中最近一次出现的位置,即可求出 $\textit{next}$ 数组。

随后遍历 $\textit{nums}$,借助 $\textit{next}$ 数组,我们可以在 $O(1)$ 的时间内求出与 $\textit{nums}[i]$ 值相同的两个相邻的后继元素,计算距离并更新答案即可。

代码

###C++

class Solution {
public:
    int minimumDistance(vector<int>& nums) {
        int n = nums.size();
        std::vector<int> next(n, -1);
        std::unordered_map<int, int> occur;
        int ans = n + 1;

        for (int i = n - 1; i >= 0; i--) {
            if (occur.count(nums[i])) {
                next[i] = occur[nums[i]];
            }
            occur[nums[i]] = i;
        }

        for (int i = 0; i < n; i++) {
            int secondPos = next[i];
            if (secondPos != -1) {
                int thirdPos = next[secondPos];
                if (thirdPos != -1) {
                    ans = std::min(ans, thirdPos - i);
                }
            }
        }

        return ans == n + 1 ? -1 : ans * 2;
    }
};

###JavaScript

var minimumDistance = function (nums) {
    const next = Array.from({ length: nums.length }).fill(-1);
    const occur = new Map();
    let ans = nums.length + 1;

    for (let i = nums.length - 1; i >= 0; i--) {
        if (occur.has(nums[i])) {
            next[i] = occur.get(nums[i]);
        }
        occur.set(nums[i], i);
    }

    for (let i = 0; i < nums.length; i++) {
        let secondPos = next[i];
        let thirdPos = next[secondPos];
        if (secondPos !== -1 && thirdPos !== -1) {
            ans = Math.min(ans, thirdPos - i);
        }
    }

    if (ans === nums.length + 1) {
        return -1;
    } else {
        return ans * 2;
    }
};

###TypeScript

function minimumDistance(nums: number[]): number {
    const next = Array.from<number>({ length: nums.length }).fill(-1);
    const occur = new Map<number, number>();
    let ans = nums.length + 1;

    for (let i = nums.length - 1; i >= 0; i--) {
        if (occur.has(nums[i])) {
            next[i] = occur.get(nums[i])!;
        }
        occur.set(nums[i], i);
    }

    for (let i = 0; i < nums.length; i++) {
        let secondPos = next[i];
        let thirdPos = next[secondPos];
        if (secondPos !== -1 && thirdPos !== -1) {
            ans = Math.min(ans, thirdPos - i);
        }
    }

    if (ans === nums.length + 1) {
        return -1;
    } else {
        return ans * 2;
    }
};

###Java

class Solution {
    public int minimumDistance(int[] nums) {
        int n = nums.length;
        int[] next = new int[n];
        Arrays.fill(next, -1);
        Map<Integer, Integer> occur = new HashMap<>();
        int ans = n + 1;

        for (int i = n - 1; i >= 0; i--) {
            if (occur.containsKey(nums[i])) {
                next[i] = occur.get(nums[i]);
            }
            occur.put(nums[i], i);
        }

        for (int i = 0; i < n; i++) {
            int secondPos = next[i];
            if (secondPos != -1) {
                int thirdPos = next[secondPos];
                if (thirdPos != -1) {
                    ans = Math.min(ans, thirdPos - i);
                }
            }
        }

        return ans == n + 1 ? -1 : ans * 2;
    }
}

###C#

public class Solution {
    public int MinimumDistance(int[] nums) {
        int n = nums.Length;
        int[] next = new int[n];
        Array.Fill(next, -1);
        Dictionary<int, int> occur = new();
        int ans = n + 1;

        for (int i = n - 1; i >= 0; i--) {
            if (occur.TryGetValue(nums[i], out int val)) {
                next[i] = val;
            }
            occur[nums[i]] = i;
        }

        for (int i = 0; i < n; i++) {
            int secondPos = next[i];
            if (secondPos != -1) {
                int thirdPos = next[secondPos];
                if (thirdPos != -1) {
                    ans = Math.Min(ans, thirdPos - i);
                }
            }
        }

        return ans == n + 1 ? -1 : ans * 2;
    }
}

###Go

func minimumDistance(nums []int) int {
n := len(nums)
next := make([]int, n)
for i := range next {
next[i] = -1
}
occur := make(map[int]int)
ans := n + 1

for i := n - 1; i >= 0; i-- {
if val, ok := occur[nums[i]]; ok {
next[i] = val
}
occur[nums[i]] = i
}

for i := 0; i < n; i++ {
secondPos := next[i]
if secondPos != -1 {
thirdPos := next[secondPos]
if thirdPos != -1 {
if dist := thirdPos - i; dist < ans {
ans = dist
}
}
}
}

if ans == n + 1 {
return -1
}
return ans * 2
}

###Python

class Solution:
    def minimumDistance(self, nums: List[int]) -> int:
        n = len(nums)
        nxt = [-1] * n
        occur = {}
        ans = n + 1

        for i in range(n - 1, -1, -1):
            if nums[i] in occur:
                nxt[i] = occur[nums[i]]
            occur[nums[i]] = i

        for i in range(n):
            second_pos = nxt[i]
            if second_pos != -1:
                third_pos = nxt[second_pos]
                if third_pos != -1:
                    ans = min(ans, third_pos - i)

        return -1 if ans == n + 1 else ans * 2

###C

typedef struct {
    int key;
    int val;
    UT_hash_handle hh;
} HashItem;

HashItem *hashFindItem(HashItem **obj, int key) {
    HashItem *pEntry = NULL;
    HASH_FIND_INT(*obj, &key, pEntry);
    return pEntry;
}

bool hashAddItem(HashItem **obj, int key, int val) {
    if (hashFindItem(obj, key)) {
        return false;
    }
    HashItem *pEntry = (HashItem *)malloc(sizeof(HashItem));
    pEntry->key = key;
    pEntry->val = val;
    HASH_ADD_INT(*obj, key, pEntry);
    return true;
}

bool hashSetItem(HashItem **obj, int key, int val) {
    HashItem *pEntry = hashFindItem(obj, key);
    if (!pEntry) {
        hashAddItem(obj, key, val);
    } else {
        pEntry->val = val;
    }
    return true;
}

int hashGetItem(HashItem **obj, int key, int defaultVal) {
    HashItem *pEntry = hashFindItem(obj, key);
    if (!pEntry) {
        return defaultVal;
    }
    return pEntry->val;
}

void hashFree(HashItem **obj) {
    HashItem *curr = NULL, *tmp = NULL;
    HASH_ITER(hh, *obj, curr, tmp) {
        HASH_DEL(*obj, curr);  
        free(curr);
    }
}

int minimumDistance(int* nums, int numsSize) {
    int* next = (int*)malloc(numsSize * sizeof(int));
    for (int i = 0; i < numsSize; i++) {
        next[i] = -1;
    }
    
    HashItem* occur = NULL;
    int ans = numsSize + 1;
    
    for (int i = numsSize - 1; i >= 0; i--) {
        int prevPos = hashGetItem(&occur, nums[i], -1);
        if (prevPos != -1) {
            next[i] = prevPos;
        }
        hashSetItem(&occur, nums[i], i);
    }
    
    for (int i = 0; i < numsSize; i++) {
        int secondPos = next[i];
        if (secondPos != -1) {
            int thirdPos = next[secondPos];
            if (thirdPos != -1) {
                int distance = thirdPos - i;
                if (distance < ans) {
                    ans = distance;
                }
            }
        }
    }
    
    free(next);
    hashFree(&occur);
    
    return ans == numsSize + 1 ? - 1 : ans * 2;
}

###Rust

use std::collections::HashMap;

impl Solution {
    pub fn minimum_distance(nums: Vec<i32>) -> i32 {
        let n = nums.len();
        let mut next = vec![-1_isize; n];
        let mut occur = HashMap::new();
        let mut ans = n + 1;

        for i in (0..n).rev() {
            if let Some(&val) = occur.get(&nums[i]) {
                next[i] = val as isize;
            }
            occur.insert(nums[i], i);
        }

        for i in 0..n {
            let second_pos = next[i];
            if second_pos != -1 {
                let third_pos = next[second_pos as usize];
                if third_pos != -1 {
                    ans = ans.min(third_pos as usize - i);
                }
            }
        }

        if ans == n + 1 {
            -1
        } else {
            (ans * 2) as i32
        }
    }
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。倒序遍历构造 $\textit{next}$ 数组和正序遍历求解答案各需要 $O(n)$,哈希表各项操作平均复杂度为 $O(1)$。

  • 空间复杂度:$O(n)$,$\textit{next}$ 数组和哈希表需要 $O(n)$ 的空间。

昨天以前首页

三个相等元素之间的最小距离 I

2026年3月31日 10:57

方法一:暴力

思路与算法

本题是「3741. 三个相等元素之间的最小距离 II」的数据简化版,由于数据范围较小,可以直接使用暴力求解。

首先观察要求的绝对值距离之和计算公式:可以发现实际上这就是一个广义三角形的三边之和,不管选取的三个点顺序如何,长度一定等于两倍的端点构成的线段的长度;换而言之,设最右侧点的下标是 $k$,最左侧点的下标是 $i$,所求的距离就是 $2 \times (k - i)$。

故使用三重循环暴力枚举所有不同的顺序三元组,若 $\textit{nums}$ 中对应位置的元素相同,则根据上述分析计算距离,最后取全局最小值即为所求。

代码

###C++

class Solution {
public:
    int minimumDistance(vector<int>& nums) {
        int n = nums.size();
        int ans = n + 1;

        for (int i = 0; i < n - 2; i++) {
            for (int j = i + 1; j < n - 1; j++) {
                if (nums[i] != nums[j]) {
                    continue;
                }
                for (int k = j + 1; k < n; k++) {
                    if (nums[j] == nums[k]) {
                        ans = std::min(ans, k - i);
                        break;
                    }
                }
            }
        }

        return ans == n + 1 ? -1 : ans * 2;
    }
};

###JavaScript

var minimumDistance = function (nums) {
    let ans = nums.length + 1;

    for (let i = 0; i < nums.length - 2; i++) {
        for (let j = i + 1; j < nums.length - 1; j++) {
            if (nums[i] !== nums[j]) {
                continue;
            }
            for (let k = j + 1; k < nums.length; k++) {
                if (nums[j] === nums[k]) {
                    ans = Math.min(ans, k - i);
                    break;
                }
            }
        }
    }

    if (ans === nums.length + 1) {
        return -1;
    } else {
        return ans * 2;
    }
};

###TypeScript

function minimumDistance(nums: number[]): number {
    let ans = nums.length + 1;
    for (let i = 0; i < nums.length - 2; i++) {
        for (let j = i + 1; j < nums.length - 1; j++) {
            if (nums[i] !== nums[j]) {
                continue;
            }
            for (let k = j + 1; k < nums.length; k++) {
                if (nums[j] === nums[k]) {
                    ans = Math.min(ans, k - i);
                    break;
                }
            }
        }
    }

    if (ans === nums.length + 1) {
        return -1;
    } else {
        return ans * 2;
    }
}

###Java

class Solution {
    public int minimumDistance(int[] nums) {
        int n = nums.length;
        int ans = n + 1;

        for (int i = 0; i < n - 2; i++) {
            for (int j = i + 1; j < n - 1; j++) {
                if (nums[i] != nums[j]) {
                    continue;
                }
                for (int k = j + 1; k < n; k++) {
                    if (nums[j] == nums[k]) {
                        ans = Math.min(ans, k - i);
                        break;
                    }
                }
            }
        }

        return ans == n + 1 ? -1 : ans * 2;
    }
}

###C#

public class Solution {
    public int MinimumDistance(int[] nums) {
        int n = nums.Length;
        int ans = n + 1;

        for (int i = 0; i < n - 2; i++) {
            for (int j = i + 1; j < n - 1; j++) {
                if (nums[i] != nums[j]) {
                    continue;
                }
                for (int k = j + 1; k < n; k++) {
                    if (nums[j] == nums[k]) {
                        ans = Math.Min(ans, k - i);
                        break;
                    }
                }
            }
        }

        return ans == n + 1 ? -1 : ans * 2;
    }
}

###Go

func minimumDistance(nums []int) int {
n := len(nums)
ans := n + 1

for i := 0; i < n-2; i++ {
for j := i + 1; j < n-1; j++ {
if nums[i] != nums[j] {
continue
}
for k := j + 1; k < n; k++ {
if nums[j] == nums[k] {
if dist := k - i; dist < ans {
ans = dist
}
break
}
}
}
}

if ans == n+1 {
return -1
}
return ans * 2
}

###Python

class Solution:
    def minimumDistance(self, nums: List[int]) -> int:
        n = len(nums)
        ans = n + 1

        for i in range(n - 2):
            for j in range(i + 1, n - 1):
                if nums[i] != nums[j]:
                    continue
                for k in range(j + 1, n):
                    if nums[j] == nums[k]:
                        ans = min(ans, k - i)
                        break

        return -1 if ans == n + 1 else ans * 2

###C

int minimumDistance(int* nums, int numsSize) {
    int ans = numsSize + 1;

    for (int i = 0; i < numsSize - 2; i++) {
        for (int j = i + 1; j < numsSize - 1; j++) {
            if (nums[i] != nums[j]) {
                continue;
            }
            for (int k = j + 1; k < numsSize; k++) {
                if (nums[j] == nums[k]) {
                    if (k - i < ans) {
                        ans = k - i;
                    }
                    break;
                }
            }
        }
    }

    return ans == numsSize + 1 ? -1 : ans * 2;
}

###Rust

impl Solution {
    pub fn minimum_distance(nums: Vec<i32>) -> i32 {
        let n = nums.len();
        let mut ans = n + 1;

        if n < 3 {
           return -1;
        }

        for i in 0..n - 2 {
            for j in i + 1..n - 1 {
                if nums[i] != nums[j] {
                    continue;
                }
                for k in j + 1..n {
                    if nums[j] == nums[k] {
                        ans = ans.min(k - i);
                        break;
                    }
                }
            }
        }

        if ans == n + 1 {
            -1
        } else {
            (ans * 2) as i32
        }
    }
}

复杂度分析

  • 时间复杂度:$O(n^3)$,其中 $n$ 是 $\textit{nums}$ 的长度,求解用到三重循环,每重循环需要 $O(n)$,故总时间复杂度是 $O(n^3)$。

  • 空间复杂度:$O(1)$,只声明了常数个变量。

网格图中机器人回家的最小代价

2021年11月28日 18:21

方法一:贪心

提示 $1$

如果在某一条路径中,相邻的两步分别为横向(左/右)和纵向(上/下)移动,那么交换这两步前后,路径的总代价不变。

提示 $1$ 解释

由于路径的其它部分不会改变,对应部分的代价也不会改变,因此我们只需要考虑交换的两步。不妨假设在这两步的过程中,机器人从 $(r, c)$ 移动到了 $(r + 1, c + 1)$。

考虑交换前后两种不同的移动方式(用 $\rightarrow$ 表示沿着某个方向一直移动,下同):

  • $(r, c) \rightarrow (r + 1, c) \rightarrow (r + 1, c + 1)$:第一步移动到 $r + 1$ 行,代价为 $\textit{rowCost}[r + 1]$;第二步移动到 $c + 1$ 列,代价为 $\textit{colCost}[c + 1]$。总代价为 $\textit{rowCost}[r + 1] + \textit{colCost}[c + 1]$。

  • $(r, c) \rightarrow (r, c + 1) \rightarrow (r + 1, c + 1)$:第一步移动到 $c + 1$ 列,代价为 $\textit{colCost}[c + 1]$;第二步移动到 $r + 1$ 行,代价为 $\textit{rowCost}[r + 1]$。总代价为 $\textit{colCost}[c + 1] + \textit{rowCost}[r + 1]$。

可以发现,这两种方式代价相同。因此,路径的总代价也不会改变。

提示 $2$

如果某一条路径中包含相反操作(即同时含有向左和向右的操作,或同时含有向上和向下的操作),那么这条路径的代价一定不优于将这些操作成对抵消后的路径。

除此之外,任意不包含任何相反操作的路径对应的总代价一定最小。

提示 $2$ 解释

我们首先考虑前半部分。

不失一般性地,首先考虑从 $(r, c)$ 到 $(r + x, c) (x \ge 0)$ 的两种路径。一种路径为 $(r, c) \rightarrow (r + x, c)$,另一种路径为 $(r, c) \rightarrow (r, c + 1) \rightarrow (r + x, c + 1) \rightarrow (r + x, c)$。计算可得,后者相对于前者多出了 $\textit{colCost}[c] + \textit{colCost}[c + 1] \ge 0$ 的总代价,亦即前者一定更优。

而对于一般的存在相反方向操作的路径,其中必定包含上述的路径片段;而将路径片段中的相反操作抵消后,新的路径在总代价上一定不高于原路径。因此,我们可以递归地抵消这些相反操作,直至路径不包含任何相反操作,同时在每次操作时,总代价一定不会增加。

综上可知,对于任意包含相反操作的路径,一定存在一个不包含相反操作的路径,后者的总代价小于等于前者。因此,最小总代价对应的路径一定是不包含相反操作的路径。

而对于所有的这些不包含任何相反操作的路径,这些路径一定是由一些(数量可能为 $0$)单方向的横向操作和一些(数量可能为 $0$)单方向的纵向操作组成。根据 提示 $1$,我们可以任意交换这些操作,且总代价不变。因此,任意不包含任何相反操作的路径对应的总代价一定最小。

思路与算法

根据 提示 $2$,我们只需要构造任意一条从起点到家的不包含相反操作的路径,该路径对应的总代价即为最小总代价。

为了方便计算,我们先让机器人向上或向下移动至家所在行,再让机器人向左或向右移动至家所在的格子,并在这过程中计算总代价。

而对于如何确定移动的方向,我们行间的上下移动为例:我们比较机器人所在行号 $r_1$ 与家所在行号 $r_2$,如果 $r_1 < r_2$,则我们需要向下移动;如果 $r_1 > r_2$,则我们需要向上移动;如果 $r_1 = r_2$,则我们无需移动。

最终,我们返回该总代价作为答案。

代码

###C++

class Solution {
public:
    int minCost(vector<int>& startPos, vector<int>& homePos, vector<int>& rowCosts, vector<int>& colCosts) {
        int r1 = startPos[0], c1 = startPos[1];
        int r2 = homePos[0], c2 = homePos[1];
        int res = 0;   // 总代价
        // 移动至家所在行,判断行间移动方向并计算对应代价
        if (r2 >= r1){
            res += accumulate(rowCosts.begin() + r1 + 1, rowCosts.begin() + r2 + 1, 0);
        }
        else{
            res += accumulate(rowCosts.begin() + r2, rowCosts.begin() + r1, 0);
        }
        // 移动至家所在位置,判断列间移动方向并计算对应代价
        if (c2 >= c1){
            res += accumulate(colCosts.begin() + c1 + 1, colCosts.begin() + c2 + 1, 0);
        }
        else{
            res += accumulate(colCosts.begin() + c2, colCosts.begin() + c1, 0);
        }
        return res;
    }
};

###Python

class Solution:
    def minCost(self, startPos: List[int], homePos: List[int], rowCosts: List[int], colCosts: List[int]) -> int:
        r1, c1 = startPos[0], startPos[1]
        r2, c2 = homePos[0], homePos[1]
        res = 0   # 总代价
        # 移动至家所在行,判断行间移动方向并计算对应代价
        if r2 >= r1:
            for i in range(r1 + 1, r2 + 1):
                res += rowCosts[i]
        else:
            for i in range(r2, r1):
                res += rowCosts[i]
        # 移动至家所在位置,判断列间移动方向并计算对应代价
        if c2 >= c1:
            for i in range(c1 + 1, c2 + 1):
                res += colCosts[i]
        else:
            for i in range(c2, c1):
                res += colCosts[i]
        return res

###Java

class Solution {
    public int minCost(int[] startPos, int[] homePos, int[] rowCosts, int[] colCosts) {
        int r1 = startPos[0], c1 = startPos[1];
        int r2 = homePos[0], c2 = homePos[1];
        int res = 0;   // 总代价
        
        // 移动至家所在行,判断行间移动方向并计算对应代价
        if (r2 >= r1) {
            for (int i = r1 + 1; i <= r2; i++) {
                res += rowCosts[i];
            }
        } else {
            for (int i = r2; i < r1; i++) {
                res += rowCosts[i];
            }
        }
        
        // 移动至家所在位置,判断列间移动方向并计算对应代价
        if (c2 >= c1) {
            for (int i = c1 + 1; i <= c2; i++) {
                res += colCosts[i];
            }
        } else {
            for (int i = c2; i < c1; i++) {
                res += colCosts[i];
            }
        }
        
        return res;
    }
}

###C#

public class Solution {
    public int MinCost(int[] startPos, int[] homePos, int[] rowCosts, int[] colCosts) {
        int r1 = startPos[0], c1 = startPos[1];
        int r2 = homePos[0], c2 = homePos[1];
        int res = 0;   // 总代价
        
        // 移动至家所在行,判断行间移动方向并计算对应代价
        if (r2 >= r1) {
            for (int i = r1 + 1; i <= r2; i++) {
                res += rowCosts[i];
            }
        } else {
            for (int i = r2; i < r1; i++) {
                res += rowCosts[i];
            }
        }
        
        // 移动至家所在位置,判断列间移动方向并计算对应代价
        if (c2 >= c1) {
            for (int i = c1 + 1; i <= c2; i++) {
                res += colCosts[i];
            }
        } else {
            for (int i = c2; i < c1; i++) {
                res += colCosts[i];
            }
        }
        
        return res;
    }
}

###Go

func minCost(startPos []int, homePos []int, rowCosts []int, colCosts []int) int {
    r1, c1 := startPos[0], startPos[1]
    r2, c2 := homePos[0], homePos[1]
    res := 0 // 总代价
    
    // 移动至家所在行,判断行间移动方向并计算对应代价
    if r2 >= r1 {
        for i := r1 + 1; i <= r2; i++ {
            res += rowCosts[i]
        }
    } else {
        for i := r2; i < r1; i++ {
            res += rowCosts[i]
        }
    }
    
    // 移动至家所在位置,判断列间移动方向并计算对应代价
    if c2 >= c1 {
        for i := c1 + 1; i <= c2; i++ {
            res += colCosts[i]
        }
    } else {
        for i := c2; i < c1; i++ {
            res += colCosts[i]
        }
    }
    
    return res
}

###C

int minCost(int* startPos, int startPosSize, int* homePos, int homePosSize, 
            int* rowCosts, int rowCostsSize, int* colCosts, int colCostsSize) {
    int r1 = startPos[0], c1 = startPos[1];
    int r2 = homePos[0], c2 = homePos[1];
    int res = 0;   // 总代价
    
    // 移动至家所在行,判断行间移动方向并计算对应代价
    if (r2 >= r1) {
        for (int i = r1 + 1; i <= r2; i++) {
            res += rowCosts[i];
        }
    } else {
        for (int i = r2; i < r1; i++) {
            res += rowCosts[i];
        }
    }
    
    // 移动至家所在位置,判断列间移动方向并计算对应代价
    if (c2 >= c1) {
        for (int i = c1 + 1; i <= c2; i++) {
            res += colCosts[i];
        }
    } else {
        for (int i = c2; i < c1; i++) {
            res += colCosts[i];
        }
    }
    
    return res;
}

###JavaScript

var minCost = function(startPos, homePos, rowCosts, colCosts) {
    const r1 = startPos[0], c1 = startPos[1];
    const r2 = homePos[0], c2 = homePos[1];
    let res = 0;   // 总代价
    
    // 移动至家所在行,判断行间移动方向并计算对应代价
    if (r2 >= r1) {
        for (let i = r1 + 1; i <= r2; i++) {
            res += rowCosts[i];
        }
    } else {
        for (let i = r2; i < r1; i++) {
            res += rowCosts[i];
        }
    }
    
    // 移动至家所在位置,判断列间移动方向并计算对应代价
    if (c2 >= c1) {
        for (let i = c1 + 1; i <= c2; i++) {
            res += colCosts[i];
        }
    } else {
        for (let i = c2; i < c1; i++) {
            res += colCosts[i];
        }
    }
    
    return res;
};

###TypeScript

function minCost(startPos: number[], homePos: number[], rowCosts: number[], colCosts: number[]): number {
    const r1 = startPos[0], c1 = startPos[1];
    const r2 = homePos[0], c2 = homePos[1];
    let res = 0;   // 总代价
    
    // 移动至家所在行,判断行间移动方向并计算对应代价
    if (r2 >= r1) {
        for (let i = r1 + 1; i <= r2; i++) {
            res += rowCosts[i];
        }
    } else {
        for (let i = r2; i < r1; i++) {
            res += rowCosts[i];
        }
    }
    
    // 移动至家所在位置,判断列间移动方向并计算对应代价
    if (c2 >= c1) {
        for (let i = c1 + 1; i <= c2; i++) {
            res += colCosts[i];
        }
    } else {
        for (let i = c2; i < c1; i++) {
            res += colCosts[i];
        }
    }
    
    return res;
}

###Rust

impl Solution {
    pub fn min_cost(start_pos: Vec<i32>, home_pos: Vec<i32>, row_costs: Vec<i32>, col_costs: Vec<i32>) -> i32 {
        let r1 = start_pos[0] as usize;
        let c1 = start_pos[1] as usize;
        let r2 = home_pos[0] as usize;
        let c2 = home_pos[1] as usize;
        let mut res = 0;   // 总代价
        
        // 移动至家所在行,判断行间移动方向并计算对应代价
        if r2 >= r1 {
            for i in (r1 + 1)..=r2 {
                res += row_costs[i];
            }
        } else {
            for i in r2..r1 {
                res += row_costs[i];
            }
        }
        
        // 移动至家所在位置,判断列间移动方向并计算对应代价
        if c2 >= c1 {
            for i in (c1 + 1)..=c2 {
                res += col_costs[i];
            }
        } else {
            for i in c2..c1 {
                res += col_costs[i];
            }
        }
        
        res
    }
}

复杂度分析

  • 时间复杂度:$O(m + n)$,其中 $m$ 为网格图的行数,$n$ 为网格图的列数。即为计算最小代价的时间复杂度。

  • 空间复杂度:$O(1)$。

等和矩阵分割 II

2026年3月16日 11:36

方法一:旋转矩阵 + 哈希表 + 枚举上半矩阵之和

思路与算法

本题是「等和矩阵分割 I」的增强版,在这一题的基础上,添加了 删除至多一个单元格 并且 删除后剩余部分必须保持连通 的条件。

那么需要进行删除的时候,我们需要考虑两条分割线的选取以及删除分割线哪一边的元素,为了简化思路,假设我们只判断是否存在水平分割线,并且进行删除操作时只删除水平分割线以上的元素。

能够发现,我们将矩阵进行 $3$ 次 $90$ 度的旋转,每次旋转后进行如上述所说的简化操作,就能够覆盖枚举分割线以及枚举删除元素的位置所带来的 $4$ 种不同情况。

接下来分析如何判断:

  1. 假设当前 $\textit{grid}$ 上半矩阵之和为 $\textit{sum}$,整个矩阵 $\textit{grid}$ 之和为 $\textit{total}$,那么 $\textit{grid}$ 下半矩阵之和为 $\textit{total} - \textit{sum}$。
  2. 假设我们要删除的元素为 $x$,那么需要满足 $\textit{sum} - x == \textit{total} - \textit{sum}$,于是有:$x == \textit{sum} * 2 - \textit{total}$。
  3. 因此在枚举完每一行之后只需要判断是否存在 $\textit{grid}[i][j]$ 满足 $\textit{grid}[i][j] == \textit{sum} * 2 - \textit{total}$ 即可。

我们可以使用一个集合来保存出现过的元素,便于判断是否存在满足题目要求的元素,集合中可以预存一个 $0$,这样可以将删除元素与不删除元素合并为一种情况。

特殊情况处理:

  1. 矩阵 $\textit{grid}$ 在遍历第一行的情况:
    在遍历第一行时能够删除的元素只有第一行的首尾元素,因此在统计完第一行的和之后需要判断 $\textit{grid}[0][0]$ 或者 $\textit{grid}[0][n - 1]$ 或者 $0$ 是否满足题目要求。
  2. 矩阵 $\textit{grid}$ 只有一列的情况:
    $\textit{grid}$ 只有一列时能够删除的元素只有首行以及尾行的元素,需要在遍历第 $i$ 行后判断 $\textit{grid}[0][0]$ 或者 $\textit{grid}[i][0]$ 或者 $0$ 是否满足题目要求。
  3. 当矩阵 $\textit{grid}$ 只有一行时可以跳过,因为无法进行水平分割。

其他情况中 $\textit{grid}$ 上半矩阵中所有的元素均可被删除。

在 $3$ 次旋转后就能够将所有情况覆盖到。

代码

###C++

class Solution {
public:
    vector<vector<int>> rotation(vector<vector<int>>& grid) {
        int m = grid.size(), n = grid[0].size();
        vector tmp(n, vector<int>(m));
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                tmp[j][m - 1 - i] = grid[i][j];
            }
        }
        return tmp;
    }
    bool canPartitionGrid(vector<vector<int>>& grid) {
        long long total = 0;
        long long sum;
        long long tag;
        int m = grid.size();
        int n = grid[0].size();
        unordered_set<long long> exist;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                total += grid[i][j];
            }
        }
        for (int k = 0; k < 4; k++) {
            exist.clear();
            exist.insert(0);
            sum = 0;
            m = grid.size();
            n = grid[0].size();
            if (m < 2) {
                grid = rotation(grid);
                continue;
            }
            if(n == 1){
                for(int i = 0; i < m - 1; i++){
                    sum += grid[i][0];
                    tag = sum * 2 - total;
                    if(tag == 0 || tag == grid[0][0] || tag == grid[i][0]){
                        return true;
                    }
                }
                grid = rotation(grid);
                continue;
            }
            for (int i = 0; i < m - 1; i++) {
                for(int j = 0; j < n; j++){
                    exist.insert(grid[i][j]);
                    sum += grid[i][j];
                }
                tag = sum * 2 - total;
                if(i == 0){
                    if(tag == 0 || tag == grid[0][0] || tag == grid[0][n - 1]){
                        return true;
                    }
                    continue;
                }
                if(exist.contains(tag)){
                    return true;
                }
            }
            grid = rotation(grid);
        }
        return false;
    }
};

###JavaScript

var canPartitionGrid = function(grid) {
    let total = 0;
    let m = grid.length;
    let n = grid[0].length;
    for (let i = 0; i < m; i++) {
        for (let j = 0; j < n; j++) {
            total += grid[i][j];
        }
    }
    for (let k = 0; k < 4; k++) {
        const exist = new Set();
        exist.add(0);
        let sum = 0;
        m = grid.length;
        n = grid[0].length;
        if (m < 2) {
            grid = rotation(grid);
            continue;
        }
        if (n == 1) {
            for (let i = 0; i < m - 1; i++) {
                sum += grid[i][0];
                let tag = sum * 2 - total;
                if (tag == 0 || tag == grid[0][0] || tag == grid[i][0]) {
                    return true;
                }
            }
            grid = rotation(grid);
            continue;
        }
        for (let i = 0; i < m - 1; i++) {
            for (let j = 0; j < n; j++) {
                exist.add(grid[i][j]);
                sum += grid[i][j];
            }
            let tag = sum * 2 - total;
            if (i == 0) {
                if (tag == 0 || tag == grid[0][0] || tag == grid[0][n - 1]) {
                    return true;
                }
                continue;
            }
            if (exist.has(tag)) {
                return true;
            }
        }
        grid = rotation(grid);
    }
    return false;
};

function rotation(grid) {
    const m = grid.length, n = grid[0].length;
    const tmp = Array.from({ length: n }, () => Array(m).fill(0));
    for (let i = 0; i < m; i++) {
        for (let j = 0; j < n; j++) {
            tmp[j][m - 1 - i] = grid[i][j];
        }
    }
    return tmp;
}

###TypeScript

function canPartitionGrid(grid: number[][]): boolean {
    let total = 0;
    let m = grid.length;
    let n = grid[0].length;
    for (let i = 0; i < m; i++) {
        for (let j = 0; j < n; j++) {
            total += grid[i][j];
        }
    }
    for (let k = 0; k < 4; k++) {
        const exist = new Set<number>();
        exist.add(0);
        let sum = 0;
        m = grid.length;
        n = grid[0].length;
        if (m < 2) {
            grid = rotation(grid);
            continue;
        }
        if (n == 1) {
            for (let i = 0; i < m - 1; i++) {
                sum += grid[i][0];
                let tag = sum * 2 - total;
                if (tag == 0 || tag == grid[0][0] || tag == grid[i][0]) {
                    return true;
                }
            }
            grid = rotation(grid);
            continue;
        }
        for (let i = 0; i < m - 1; i++) {
            for (let j = 0; j < n; j++) {
                exist.add(grid[i][j]);
                sum += grid[i][j];
            }
            let tag = sum * 2 - total;
            if (i == 0) {
                if (tag == 0 || tag == grid[0][0] || tag == grid[0][n - 1]) {
                    return true;
                }
                continue;
            }
            if (exist.has(tag)) {
                return true;
            }
        }
        grid = rotation(grid);
    }
    return false;
}

function rotation(grid: number[][]): number[][] {
    const m = grid.length, n = grid[0].length;
    const tmp: number[][] = Array.from({ length: n }, () => Array(m).fill(0));
    for (let i = 0; i < m; i++) {
        for (let j = 0; j < n; j++) {
            tmp[j][m - 1 - i] = grid[i][j];
        }
    }
    return tmp;
}

###Java

class Solution {
    public boolean canPartitionGrid(int[][] grid) {
        long total = 0;
        int m = grid.length;
        int n = grid[0].length;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                total += grid[i][j];
            }
        }
        for (int k = 0; k < 4; k++) {
            Set<Long> exist = new HashSet<>();
            exist.add(0L);
            long sum = 0;
            m = grid.length;
            n = grid[0].length;
            if (m < 2) {
                grid = rotation(grid);
                continue;
            }
            if (n == 1) {
                for (int i = 0; i < m - 1; i++) {
                    sum += grid[i][0];
                    long tag = sum * 2 - total;
                    if (tag == 0 || tag == grid[0][0] || tag == grid[i][0]) {
                        return true;
                    }
                }
                grid = rotation(grid);
                continue;
            }
            for (int i = 0; i < m - 1; i++) {
                for (int j = 0; j < n; j++) {
                    exist.add((long) grid[i][j]);
                    sum += grid[i][j];
                }
                long tag = sum * 2 - total;
                if (i == 0) {
                    if (tag == 0 || tag == grid[0][0] || tag == grid[0][n - 1]) {
                        return true;
                    }
                    continue;
                }
                if (exist.contains(tag)) {
                    return true;
                }
            }
            grid = rotation(grid);
        }
        return false;
    }

    public int[][] rotation(int[][] grid) {
        int m = grid.length, n = grid[0].length;
        int[][] tmp = new int[n][m];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                tmp[j][m - 1 - i] = grid[i][j];
            }
        }
        return tmp;
    }
}

###C#

public class Solution {
    public bool CanPartitionGrid(int[][] grid) {
        long total = 0;
        int m = grid.Length;
        int n = grid[0].Length;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                total += grid[i][j];
            }
        }
        for (int k = 0; k < 4; k++) {
            HashSet<long> exist = new HashSet<long>();
            exist.Add(0);
            long sum = 0;
            m = grid.Length;
            n = grid[0].Length;
            if (m < 2) {
                grid = Rotation(grid);
                continue;
            }
            if (n == 1) {
                for (int i = 0; i < m - 1; i++) {
                    sum += grid[i][0];
                    long tag = sum * 2 - total;
                    if (tag == 0 || tag == grid[0][0] || tag == grid[i][0]) {
                        return true;
                    }
                }
                grid = Rotation(grid);
                continue;
            }
            for (int i = 0; i < m - 1; i++) {
                for (int j = 0; j < n; j++) {
                    exist.Add(grid[i][j]);
                    sum += grid[i][j];
                }
                long tag = sum * 2 - total;
                if (i == 0) {
                    if (tag == 0 || tag == grid[0][0] || tag == grid[0][n - 1]) {
                        return true;
                    }
                    continue;
                }
                if (exist.Contains(tag)) {
                    return true;
                }
            }
            grid = Rotation(grid);
        }
        return false;
    }

    public int[][] Rotation(int[][] grid) {
        int m = grid.Length, n = grid[0].Length;
        int[][] tmp = new int[n][];
        for (int i = 0; i < n; i++) {
            tmp[i] = new int[m];
        }
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                tmp[j][m - 1 - i] = grid[i][j];
            }
        }
        return tmp;
    }
}

###Go

func canPartitionGrid(grid [][]int) bool {
    var total int64 = 0
    m, n := len(grid), len(grid[0])
    for i := 0; i < m; i++ {
        for j := 0; j < n; j++ {
            total += int64(grid[i][j])
        }
    }
    for k := 0; k < 4; k++ {
        exist := make(map[int64]bool)
        exist[0] = true
        var sum int64 = 0
        m, n = len(grid), len(grid[0])
        if m < 2 {
            grid = rotation(grid)
            continue
        }
        if n == 1 {
            for i := 0; i < m-1; i++ {
                sum += int64(grid[i][0])
                tag := sum*2 - total
                if tag == 0 || tag == int64(grid[0][0]) || tag == int64(grid[i][0]) {
                    return true
                }
            }
            grid = rotation(grid)
            continue
        }
        for i := 0; i < m-1; i++ {
            for j := 0; j < n; j++ {
                exist[int64(grid[i][j])] = true
                sum += int64(grid[i][j])
            }
            tag := sum*2 - total
            if i == 0 {
                if tag == 0 || tag == int64(grid[0][0]) || tag == int64(grid[0][n-1]) {
                    return true
                }
                continue
            }
            if exist[tag] {
                return true
            }
        }
        grid = rotation(grid)
    }
    return false
}

func rotation(grid [][]int) [][]int {
    m, n := len(grid), len(grid[0])
    tmp := make([][]int, n)
    for i := range tmp {
        tmp[i] = make([]int, m)
    }
    for i := 0; i < m; i++ {
        for j := 0; j < n; j++ {
            tmp[j][m-1-i] = grid[i][j]
        }
    }
    return tmp
}

###Python

class Solution:
    def canPartitionGrid(self, grid: List[List[int]]) -> bool:
        total = 0
        m = len(grid)
        n = len(grid[0])
        for i in range(m):
            for j in range(n):
                total += grid[i][j]
        for _ in range(4):
            exist = set()
            exist.add(0)
            sum_val = 0
            m = len(grid)
            n = len(grid[0])
            if m < 2:
                grid = self.rotation(grid)
                continue
            if n == 1:
                for i in range(m - 1):
                    sum_val += grid[i][0]
                    tag = sum_val * 2 - total
                    if tag == 0 or tag == grid[0][0] or tag == grid[i][0]:
                        return True
                grid = self.rotation(grid)
                continue
            for i in range(m - 1):
                for j in range(n):
                    exist.add(grid[i][j])
                    sum_val += grid[i][j]
                tag = sum_val * 2 - total
                if i == 0:
                    if tag == 0 or tag == grid[0][0] or tag == grid[0][n - 1]:
                        return True
                    continue
                if tag in exist:
                    return True
            grid = self.rotation(grid)
        return False

    def rotation(self, grid: List[List[int]]) -> List[List[int]]:
        m = len(grid)
        n = len(grid[0])
        tmp = [[0] * m for _ in range(n)]
        for i in range(m):
            for j in range(n):
                tmp[j][m - 1 - i] = grid[i][j]
        return tmp

###C

typedef struct {
    long long key;
    UT_hash_handle hh;
} HashItem;

static inline HashItem* hashFindItem(HashItem **obj, long long key) {
    HashItem *pEntry = NULL;
    HASH_FIND(hh, *obj, &key, sizeof(key), pEntry);
    return pEntry;
}

bool hashAddItem(HashItem **obj, long long key) {
    if (hashFindItem(obj, key)) {
        return false;
    }
    HashItem *pEntry = malloc(sizeof(HashItem));
    if (!pEntry) return false;
    pEntry->key = key;
    HASH_ADD(hh, *obj, key, sizeof(key), pEntry);
    return true;
}

void hashFree(HashItem **obj) {
    HashItem *curr, *tmp;
    HASH_ITER(hh, *obj, curr, tmp) {
        HASH_DEL(*obj, curr);
        free(curr);
    }
    *obj = NULL;
}


int** rotation(int** grid, int m, int n, int* newM, int* newN) {
    *newM = n;
    *newN = m;
    int** tmp = malloc(n * sizeof(int*));
    int* data = malloc(n * m * sizeof(int));
    if (!tmp || !data) {
        free(tmp);
        free(data);
        return NULL;
    }
    for (int i = 0; i < n; i++) {
        tmp[i] = data + i * m;
    }
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            tmp[j][m - 1 - i] = grid[i][j];
        }
    }
    return tmp;
}

void freeGrid(int** grid, int rows) {
    if (grid && grid[0]) {
        free(grid[0]);
    }
    free(grid);
}

static inline bool checkAndReturnTrue(HashItem **exist, int** currentGrid, int currentM, int** originalGrid) {
    hashFree(exist);
    if (currentGrid != originalGrid) {
        freeGrid(currentGrid, currentM);
    }
    return true;
}

static inline void rotateAndUpdate(int** *currentGrid, int *currentM, int *currentN, int** originalGrid) {
    int newM, newN;
    int** newGrid = rotation(*currentGrid, *currentM, *currentN, &newM, &newN);
    if (*currentGrid != originalGrid) {
        freeGrid(*currentGrid, *currentM);
    }
    *currentGrid = newGrid;
    *currentM = newM;
    *currentN = newN;
}

bool canPartitionGrid(int** grid, int gridSize, int* gridColSize) {
    const int m = gridSize;
    const int n = gridColSize[0];
    long long total = 0;

    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            total += grid[i][j];
        }
    }
    int currentM = m, currentN = n;
    int** currentGrid = grid;

    for (int k = 0; k < 4; k++) {
        HashItem* exist = NULL;
        hashAddItem(&exist, 0LL);
        long long sum = 0;
        if (currentM < 2 || currentN == 1) {
            if (currentN == 1 && currentM >= 2) {
                for (int i = 0; i < currentM - 1; i++) {
                    sum += currentGrid[i][0];
                    long long tag = sum * 2 - total;
                    if (tag == 0 || tag == currentGrid[0][0] || tag == currentGrid[i][0]) {
                        return checkAndReturnTrue(&exist, currentGrid, currentM, grid);
                    }
                }
            }
            rotateAndUpdate(&currentGrid, &currentM, &currentN, grid);
            hashFree(&exist);
            continue;
        }

        for (int i = 0; i < currentM - 1; i++) {
            for (int j = 0; j < currentN; j++) {
                hashAddItem(&exist, (long long)currentGrid[i][j]);
                sum += currentGrid[i][j];
            }
            long long tag = sum * 2 - total;
            if (i == 0) {
                if (tag == 0 || tag == currentGrid[0][0] || tag == currentGrid[0][currentN - 1]) {
                    return checkAndReturnTrue(&exist, currentGrid, currentM, grid);
                }
                continue;
            }
            if (hashFindItem(&exist, tag)) {
                return checkAndReturnTrue(&exist, currentGrid, currentM, grid);
            }
        }

        rotateAndUpdate(&currentGrid, &currentM, &currentN, grid);
        hashFree(&exist);
    }

    if (currentGrid != grid) {
        freeGrid(currentGrid, currentM);
    }

    return false;
}

###Rust

use std::collections::HashSet;

impl Solution {
    fn rotation(grid: &Vec<Vec<i32>>) -> Vec<Vec<i32>> {
        let m = grid.len();
        let n = grid[0].len();
        let mut tmp = vec![vec![0; m]; n];

        for i in 0..m {
            for j in 0..n {
                tmp[j][m - 1 - i] = grid[i][j];
            }
        }
        tmp
    }

    pub fn can_partition_grid(grid: Vec<Vec<i32>>) -> bool {
        let mut grid = grid;
        let mut total: i64 = 0;
        let mut sum: i64;
        let mut tag: i64;
        let mut m = grid.len();
        let mut n = grid[0].len();
        for i in 0..m {
            for j in 0..n {
                total += grid[i][j] as i64;
            }
        }

        let mut exist = HashSet::new();

        for _ in 0..4 {
            exist.clear();
            exist.insert(0);
            sum = 0;
            m = grid.len();
            n = grid[0].len();
            if m < 2 {
                grid = Self::rotation(&grid);
                continue;
            }
            if n == 1 {
                for i in 0..m - 1 {
                    sum += grid[i][0] as i64;
                    tag = sum * 2 - total;
                    if tag == 0 || tag == grid[0][0] as i64 || tag == grid[i][0] as i64 {
                        return true;
                    }
                }
                grid = Self::rotation(&grid);
                continue;
            }

            for i in 0..m - 1 {
                for j in 0..n {
                    exist.insert(grid[i][j] as i64);
                    sum += grid[i][j] as i64;
                }

                tag = sum * 2 - total;

                if i == 0 {
                    if tag == 0 || tag == grid[0][0] as i64 || tag == grid[0][n - 1] as i64 {
                        return true;
                    }
                    continue;
                }

                if exist.contains(&tag) {
                    return true;
                }
            }

            grid = Self::rotation(&grid);
        }

        false
    }
}

复杂度分析

  • 时间复杂度:$O(mn)$,其中 $m$ 为 $\textit{grid}$ 矩阵的行数,$n$ 为 $\textit{grid}$ 矩阵的列数。

  • 空间复杂度:$O(mn)$,其中 $m$ 为 $\textit{grid}$ 矩阵的行数,$n$ 为 $\textit{grid}$ 矩阵的列数。

矩阵的最大非负积

2020年10月9日 11:42

方法一:动态规划

思路与算法

由于矩阵中的元素有正有负,要想得到最大积,我们只存储移动过程中的最大积是不够的,例如当前的最大积为正数时,乘上一个负数后,反而不如一个负数乘上相同的负数得到的积大。

因此,我们需要存储的是移动过程中的积的范围,也就是积的最小值以及最大值。由于只能向下或者向右走,我们可以考虑使用动态规划的方法解决本题。

设 $\textit{maxgt}[i][j], \textit{minlt}[i][j]$ 分别为从坐标 $(0, 0)$ 出发,到达位置 $(i, j)$ 时乘积的最大值与最小值。由于我们只能向下或者向右走,因此乘积的取值必然只与 $(i, j-1)$ 和 $(i-1, j)$ 两个位置有关。

对于乘积的最大值而言:若 $\textit{grid}[i][j] \ge 0$,则 $\textit{maxgt}[i][j]$ 的取值取决于这两个位置的最大值,此时

$$
\textit{maxgt}[i][j] = \max(\textit{maxgt}[i][j-1], \textit{maxgt}[i-1][j]) \times \textit{grid}[i][j]
$$

相反地,若 $\textit{grid}[i][j] \le 0$,则 $\textit{maxgt}[i][j]$ 的取值取决于这两个位置的最小值,此时

$$
\textit{maxgt}[i][j] = \min(\textit{minlt}[i][j-1], \textit{minlt}[i-1][j]) \times \textit{grid}[i][j]
$$

计算乘积的最小值也是类似的思路。若 $\textit{grid}[i][j] \ge 0$,此时

$$
\textit{mingt}[i][j] = \min(\textit{mingt}[i][j-1], \textit{mingt}[i-1][j]) \times \textit{grid}[i][j]
$$

若 $\textit{grid}[i][j] \le 0$,此时

$$
\textit{mingt}[i][j] = \max(\textit{maxgt}[i][j-1], \textit{maxgt}[i-1][j]) \times \textit{grid}[i][j]
$$

特别地,当 $i=0$ 时,只需要从 $(i, j-1)$ 进行转移;$j=0$ 时,只需要从 $(i-1, j)$ 进行转移;$i=0$ 且 $j=0$ 时,$\textit{maxgt}[i][j]$ 与 $\textit{mingt}[i][j]$ 的值均为左上角的元素值 $\textit{grid}[i][j]$。

最终的答案即为 $\textit{maxgt}[m-1][n-1]$,其中 $m$ 和 $n$ 分别是矩阵的行数与列数。

代码

###C++

class Solution {
public:
    int maxProductPath(vector<vector<int>>& grid) {
        const int mod = 1000000000 + 7;
        int m = grid.size(), n = grid[0].size();
        vector<vector<long long>> maxgt(m, vector<long long>(n));
        vector<vector<long long>> minlt(m, vector<long long>(n));

        maxgt[0][0] = minlt[0][0] = grid[0][0];
        for (int i = 1; i < n; i++) {
            maxgt[0][i] = minlt[0][i] = maxgt[0][i - 1] * grid[0][i];
        }
        for (int i = 1; i < m; i++) {
            maxgt[i][0] = minlt[i][0] = maxgt[i - 1][0] * grid[i][0];
        }

        for (int i = 1; i < m; i++) {
            for (int j = 1; j < n; j++) {
                if (grid[i][j] >= 0) {
                    maxgt[i][j] = max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
                    minlt[i][j] = min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
                } else {
                    maxgt[i][j] = min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
                    minlt[i][j] = max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
                }
            }
        }
        if (maxgt[m - 1][n - 1] < 0) {
            return -1;
        } else {
            return maxgt[m - 1][n - 1] % mod;
        }
    }
};

###Java

class Solution {
    public int maxProductPath(int[][] grid) {
        final int MOD = 1000000000 + 7;
        int m = grid.length, n = grid[0].length;
        long[][] maxgt = new long[m][n];
        long[][] minlt = new long[m][n];

        maxgt[0][0] = minlt[0][0] = grid[0][0];
        for (int i = 1; i < n; i++) {
            maxgt[0][i] = minlt[0][i] = maxgt[0][i - 1] * grid[0][i];
        }
        for (int i = 1; i < m; i++) {
            maxgt[i][0] = minlt[i][0] = maxgt[i - 1][0] * grid[i][0];
        }

        for (int i = 1; i < m; i++) {
            for (int j = 1; j < n; j++) {
                if (grid[i][j] >= 0) {
                    maxgt[i][j] = Math.max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
                    minlt[i][j] = Math.min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
                } else {
                    maxgt[i][j] = Math.min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
                    minlt[i][j] = Math.max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
                }
            }
        }
        if (maxgt[m - 1][n - 1] < 0) {
            return -1;
        } else {
            return (int) (maxgt[m - 1][n - 1] % MOD);
        }
    }
}

###Python

class Solution:
    def maxProductPath(self, grid: List[List[int]]) -> int:
        mod = 10**9 + 7
        m, n = len(grid), len(grid[0])
        maxgt = [[0] * n for _ in range(m)]
        minlt = [[0] * n for _ in range(m)]

        maxgt[0][0] = minlt[0][0] = grid[0][0]
        for i in range(1, n):
            maxgt[0][i] = minlt[0][i] = maxgt[0][i - 1] * grid[0][i]
        for i in range(1, m):
            maxgt[i][0] = minlt[i][0] = maxgt[i - 1][0] * grid[i][0]
        
        for i in range(1, m):
            for j in range(1, n):
                if grid[i][j] >= 0:
                    maxgt[i][j] = max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j]
                    minlt[i][j] = min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j]
                else:
                    maxgt[i][j] = min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j]
                    minlt[i][j] = max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j]
        
        if maxgt[m - 1][n - 1] < 0:
            return -1
        return maxgt[m - 1][n - 1] % mod

###C#

public class Solution {
    public int MaxProductPath(int[][] grid) {
        const int MOD = 1000000007;
        int m = grid.Length, n = grid[0].Length;
        long[,] maxgt = new long[m, n];
        long[,] minlt = new long[m, n];

        maxgt[0, 0] = minlt[0, 0] = grid[0][0];
        for (int i = 1; i < n; i++) {
            maxgt[0, i] = minlt[0, i] = maxgt[0, i - 1] * grid[0][i];
        }
        for (int i = 1; i < m; i++) {
            maxgt[i, 0] = minlt[i, 0] = maxgt[i - 1, 0] * grid[i][0];
        }
        for (int i = 1; i < m; i++) {
            for (int j = 1; j < n; j++) {
                if (grid[i][j] >= 0) {
                    long maxPrev = Math.Max(maxgt[i, j - 1], maxgt[i - 1, j]);
                    long minPrev = Math.Min(minlt[i, j - 1], minlt[i - 1, j]);
                    maxgt[i, j] = maxPrev * grid[i][j];
                    minlt[i, j] = minPrev * grid[i][j];
                } else {
                    long maxPrev = Math.Max(maxgt[i, j - 1], maxgt[i - 1, j]);
                    long minPrev = Math.Min(minlt[i, j - 1], minlt[i - 1, j]);
                    maxgt[i, j] = minPrev * grid[i][j];
                    minlt[i, j] = maxPrev * grid[i][j];
                }
            }
        }
        
        if (maxgt[m - 1, n - 1] < 0) {
            return -1;
        } else {
            return (int)(maxgt[m - 1, n - 1] % MOD);
        }
    }
}

###Go

func maxProductPath(grid [][]int) int {
    const MOD = 1000000007
    m, n := len(grid), len(grid[0])
    
    maxgt := make([][]int64, m)
    minlt := make([][]int64, m)
    for i := range maxgt {
        maxgt[i] = make([]int64, n)
        minlt[i] = make([]int64, n)
    }
    
    maxgt[0][0] = int64(grid[0][0])
    minlt[0][0] = int64(grid[0][0])
    for i := 1; i < n; i++ {
        maxgt[0][i] = maxgt[0][i-1] * int64(grid[0][i])
        minlt[0][i] = maxgt[0][i]
    }
    for i := 1; i < m; i++ {
        maxgt[i][0] = maxgt[i-1][0] * int64(grid[i][0])
        minlt[i][0] = maxgt[i][0]
    }
    for i := 1; i < m; i++ {
        for j := 1; j < n; j++ {
            if grid[i][j] >= 0 {
                maxPrev := max(maxgt[i][j-1], maxgt[i-1][j])
                minPrev := min(minlt[i][j-1], minlt[i-1][j])
                maxgt[i][j] = maxPrev * int64(grid[i][j])
                minlt[i][j] = minPrev * int64(grid[i][j])
            } else {
                maxPrev := max(maxgt[i][j-1], maxgt[i-1][j])
                minPrev := min(minlt[i][j-1], minlt[i-1][j])
                maxgt[i][j] = minPrev * int64(grid[i][j])
                minlt[i][j] = maxPrev * int64(grid[i][j])
            }
        }
    }
    
    if maxgt[m-1][n-1] < 0 {
        return -1
    }
    return int(maxgt[m-1][n-1] % MOD)
}

###C

#define MOD 1000000007

int maxProductPath(int** grid, int gridSize, int* gridColSize) {
    int m = gridSize, n = gridColSize[0];
    long long** maxgt = (long long**)malloc(m * sizeof(long long*));
    long long** minlt = (long long**)malloc(m * sizeof(long long*));
    for (int i = 0; i < m; i++) {
        maxgt[i] = (long long*)malloc(n * sizeof(long long));
        minlt[i] = (long long*)malloc(n * sizeof(long long));
    }
    
    maxgt[0][0] = minlt[0][0] = grid[0][0];
    for (int i = 1; i < n; i++) {
        maxgt[0][i] = minlt[0][i] = maxgt[0][i - 1] * grid[0][i];
    }
    for (int i = 1; i < m; i++) {
        maxgt[i][0] = minlt[i][0] = maxgt[i - 1][0] * grid[i][0];
    }
    for (int i = 1; i < m; i++) {
        for (int j = 1; j < n; j++) {
            if (grid[i][j] >= 0) {
                maxgt[i][j] = fmax(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
                minlt[i][j] = fmin(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
            } else {
                maxgt[i][j] = fmin(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
                minlt[i][j] = fmax(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
            }
        }
    }
    
    long long result = maxgt[m - 1][n - 1];
    for (int i = 0; i < m; i++) {
        free(maxgt[i]);
        free(minlt[i]);
    }
    free(maxgt);
    free(minlt);
    
    if (result < 0) {
        return -1;
    } else {
        return result % MOD;
    }
}

###JavaScript

var maxProductPath = function(grid) {
    const MOD = 1000000007;
    const m = grid.length, n = grid[0].length;
    const maxgt = new Array(m).fill(0).map(() => new Array(n).fill(0));
    const minlt = new Array(m).fill(0).map(() => new Array(n).fill(0));
    
    maxgt[0][0] = minlt[0][0] = grid[0][0];
    for (let i = 1; i < n; i++) {
        maxgt[0][i] = minlt[0][i] = maxgt[0][i - 1] * grid[0][i];
    }
    for (let i = 1; i < m; i++) {
        maxgt[i][0] = minlt[i][0] = maxgt[i - 1][0] * grid[i][0];
    }
    for (let i = 1; i < m; i++) {
        for (let j = 1; j < n; j++) {
            if (grid[i][j] >= 0) {
                maxgt[i][j] = Math.max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
                minlt[i][j] = Math.min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
            } else {
                maxgt[i][j] = Math.min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
                minlt[i][j] = Math.max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
            }
        }
    }
    
    if (maxgt[m - 1][n - 1] < 0) {
        return -1;
    } else {
        return maxgt[m - 1][n - 1] % MOD;
    }
};

###TypeScript

function maxProductPath(grid: number[][]): number {
    const MOD = 1000000007;
    const m = grid.length, n = grid[0].length;
    
    const maxgt: number[][] = new Array(m).fill(0).map(() => new Array(n).fill(0));
    const minlt: number[][] = new Array(m).fill(0).map(() => new Array(n).fill(0));
    
    maxgt[0][0] = minlt[0][0] = grid[0][0];
    for (let i = 1; i < n; i++) {
        maxgt[0][i] = minlt[0][i] = maxgt[0][i - 1] * grid[0][i];
    }
    for (let i = 1; i < m; i++) {
        maxgt[i][0] = minlt[i][0] = maxgt[i - 1][0] * grid[i][0];
    }
    for (let i = 1; i < m; i++) {
        for (let j = 1; j < n; j++) {
            if (grid[i][j] >= 0) {
                maxgt[i][j] = Math.max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
                minlt[i][j] = Math.min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
            } else {
                maxgt[i][j] = Math.min(minlt[i][j - 1], minlt[i - 1][j]) * grid[i][j];
                minlt[i][j] = Math.max(maxgt[i][j - 1], maxgt[i - 1][j]) * grid[i][j];
            }
        }
    }
    
    if (maxgt[m - 1][n - 1] < 0) {
        return -1;
    } else {
        return maxgt[m - 1][n - 1] % MOD;
    }
}

###Rust

impl Solution {
    pub fn max_product_path(grid: Vec<Vec<i32>>) -> i32 {
        const MOD: i64 = 1_000_000_007;
        let m = grid.len();
        let n = grid[0].len();
        let mut maxgt = vec![vec![0i64; n]; m];
        let mut minlt = vec![vec![0i64; n]; m];
        maxgt[0][0] = grid[0][0] as i64;
        minlt[0][0] = grid[0][0] as i64;
        
        for i in 1..n {
            maxgt[0][i] = maxgt[0][i-1] * grid[0][i] as i64;
            minlt[0][i] = maxgt[0][i];
        }
        for i in 1..m {
            maxgt[i][0] = maxgt[i-1][0] * grid[i][0] as i64;
            minlt[i][0] = maxgt[i][0];
        }
        for i in 1..m {
            for j in 1..n {
                let grid_val = grid[i][j] as i64;
                if grid_val >= 0 {
                    let max_prev = maxgt[i][j-1].max(maxgt[i-1][j]);
                    let min_prev = minlt[i][j-1].min(minlt[i-1][j]);
                    maxgt[i][j] = max_prev * grid_val;
                    minlt[i][j] = min_prev * grid_val;
                } else {
                    let max_prev = maxgt[i][j-1].max(maxgt[i-1][j]);
                    let min_prev = minlt[i][j-1].min(minlt[i-1][j]);
                    maxgt[i][j] = min_prev * grid_val;
                    minlt[i][j] = max_prev * grid_val;
                }
            }
        }
        
        let result = maxgt[m-1][n-1];
        if result < 0 {
            -1
        } else {
            (result % MOD) as i32
        }
    }
}

复杂度分析

  • 时间复杂度:$O(mn)$,其中 $m$ 和 $n$ 为矩阵的行数与列数。我们需要遍历矩阵的每一个元素,而处理每个元素时只需要常数时间。

  • 空间复杂度:$O(mn)$。我们开辟了两个与原矩阵等大的数组。

元素和小于等于 k 的子矩阵的数目

2026年3月10日 12:43

方法一:二维前缀和

思路与算法

题目要求统计包含矩阵 $\textit{grid}$ 左上角元素的所有子矩阵中,元素和不超过 $k$ 的子矩阵个数。

我们从左上角出发,按行优先顺序遍历矩阵,将当前访问位置 $(i, j)$ 视为子矩阵的右下角。为了在一次遍历中高效计算子矩阵的和,我们维护一个数组 $\textit{cols}[j]$,用于记录当前行之前第 $j$ 列所有元素的和。在遍历第 $i$ 行时,按列从左到右遍历 $j$,将 $\textit{grid}[i][j]$ 累加到 $\textit{cols}[j]$后,并将 $\textit{cols}[j]$ 累加起来,若当前累加和 $\le k$,则答案加 $1$。

代码

###C++

class Solution {
public:
    int countSubmatrices(vector<vector<int>>& grid, int k) {
        int n = grid.size(), m = grid[0].size();
        vector<int> cols(m);
        int res = 0;
        for (int i = 0; i < n; i++) {
            int rows = 0;
            for (int j = 0; j < m; j++) {
                cols[j] += grid[i][j];
                rows += cols[j];
                if (rows <= k) {
                    res++;
                }
            }
        }
        return res;
    }
};

###Python

class Solution:
    def countSubmatrices(self, grid: List[List[int]], k: int) -> int:
        n, m = len(grid), len(grid[0])
        cols = [0] * m
        res = 0
        
        for i in range(n):
            row_sum = 0
            for j in range(m):
                cols[j] += grid[i][j]
                row_sum += cols[j]
                if row_sum <= k:
                    res += 1
        
        return res

###Rust

impl Solution {
    pub fn count_submatrices(grid: Vec<Vec<i32>>, k: i32) -> i32 {
        let n = grid.len();
        let m = grid[0].len();
        let mut cols = vec![0; m];
        let mut res = 0;
        
        for i in 0..n {
            let mut row_sum = 0;
            for j in 0..m {
                cols[j] += grid[i][j];
                row_sum += cols[j];
                if row_sum <= k {
                    res += 1;
                }
            }
        }
        
        res
    }
}

###Java

class Solution {
    public int countSubmatrices(int[][] grid, int k) {
        int n = grid.length, m = grid[0].length;
        int[] cols = new int[m];
        int res = 0;
        
        for (int i = 0; i < n; i++) {
            int rows = 0;
            for (int j = 0; j < m; j++) {
                cols[j] += grid[i][j];
                rows += cols[j];
                if (rows <= k) {
                    res++;
                }
            }
        }
        
        return res;
    }
}

###C#

public class Solution {
    public int CountSubmatrices(int[][] grid, int k) {
        int n = grid.Length, m = grid[0].Length;
        int[] cols = new int[m];
        int res = 0;
        
        for (int i = 0; i < n; i++) {
            int rows = 0;
            for (int j = 0; j < m; j++) {
                cols[j] += grid[i][j];
                rows += cols[j];
                if (rows <= k) {
                    res++;
                }
            }
        }
        
        return res;
    }
}

###Go

func countSubmatrices(grid [][]int, k int) int {
    n := len(grid)
    m := len(grid[0])
    cols := make([]int, m)
    res := 0
    
    for i := 0; i < n; i++ {
        rows := 0
        for j := 0; j < m; j++ {
            cols[j] += grid[i][j]
            rows += cols[j]
            if rows <= k {
                res++
            }
        }
    }
    
    return res
}

###C

int countSubmatrices(int** grid, int gridSize, int* gridColSize, int k) {
    int n = gridSize;
    int m = *gridColSize;
    int* cols = (int*)calloc(m, sizeof(int));
    int res = 0;
    
    for (int i = 0; i < n; i++) {
        int rows = 0;
        for (int j = 0; j < m; j++) {
            cols[j] += grid[i][j];
            rows += cols[j];
            if (rows <= k) {
                res++;
            }
        }
    }
    
    free(cols);
    return res;
}

###JavaScript

var countSubmatrices = function(grid, k) {
    const n = grid.length;
    const m = grid[0].length;
    const cols = new Array(m).fill(0);
    let res = 0;
    
    for (let i = 0; i < n; i++) {
        let rows = 0;
        for (let j = 0; j < m; j++) {
            cols[j] += grid[i][j];
            rows += cols[j];
            if (rows <= k) {
                res++;
            }
        }
    }
    
    return res;
};

###TypeScript

function countSubmatrices(grid: number[][], k: number): number {
    const n: number = grid.length;
    const m: number = grid[0].length;
    const cols: number[] = new Array(m).fill(0);
    let res: number = 0;
    
    for (let i = 0; i < n; i++) {
        let rows: number = 0;
        for (let j = 0; j < m; j++) {
            cols[j] += grid[i][j];
            rows += cols[j];
            if (rows <= k) {
                res++;
            }
        }
    }
    
    return res;
}

复杂度分析

  • 时间复杂度:$O(mn)$,其中 $m$ 是 $\textit{grid}$ 的行数,$n$ 是 $\textit{grid}$ 的列数。

  • 空间复杂度:$O(n)$。

移山所需的最少秒数

2026年3月2日 10:24

方法一:二分答案

思路与算法

根据题目描述,如果 $t$ 秒可以使山的高度降低到 $0$,那么任何大于 $t$ 的秒数也可以。因此答案具有单调性,我们可以使用二分查找来解决本题。

对于二分查找的每一步,假设当前猜测的秒数为 $\textit{mid}$,我们需要判断所有工人在 $\textit{mid}$ 秒内能否将山的高度降低 $H = \textit{mountainHeight}$。对于第 $i$ 个工人,他将山的高度降低 $k$ 所需的时间为:

$$
\textit{workerTimes}[i] \cdot (1 + 2 + \cdots + k) = \textit{workerTimes}[i] \cdot \frac{k(k+1)}{2}
$$

因此在 $\textit{mid}$ 秒内,第 $i$ 个工人 $i$ 最多能将山降低的高度,是满足

$$\textit{workerTimes}[i] \cdot \frac{k(k+1)}{2} \leq \textit{mid}$$

的最大正整数 $k$。

令 $\textit{work} = \lfloor \dfrac{\textit{mid}}{\textit{workerTimes}[i]} \rfloor$,其中 $\lfloor \cdot \rfloor$ 表示下取整,则需要满足 $\dfrac{k(k+1)}{2} \leq \textit{work}$,利用求一元二次方程求根公式可得:

$$
k = \left\lfloor \frac{-1 + \sqrt{1 + 8 \cdot \textit{work}}}{2} \right\rfloor
$$

我们将所有工人计算得到的 $k$ 值相加,如果总和大于等于 $H$,则说明 $\textit{mid}$ 秒可以完成任务,应当尝试更少的时间,否则尝试更多的时间。

二分查找的下界为 $1$,上界为 $\max(\textit{workerTimes}) \cdot \dfrac{H(H + 1)}{2}$,即最慢的工人独自完成所有工作所需的时间。

代码

###C++

class Solution {
public:
    long long minNumberOfSeconds(int mountainHeight, vector<int>& workerTimes) {
        int maxWorkerTimes = *max_element(workerTimes.begin(), workerTimes.end());
        long long l = 1, r = static_cast<long long>(maxWorkerTimes) * mountainHeight * (mountainHeight + 1) / 2;
        long long ans = 0;

        while (l <= r) {
            long long mid = (l + r) / 2;
            long long cnt = 0;
            for (int t: workerTimes) {
                long long work = mid / t;
                // 求最大的 k 满足 1+2+...+k <= work
                long long k = (-1.0 + sqrt(1 + work * 8)) / 2 + eps;
                cnt += k;
            }
            if (cnt >= mountainHeight) {
                ans = mid;
                r = mid - 1;
            }
            else {
                l = mid + 1;
            }
        }

        return ans;
    }

private:
    static constexpr double eps = 1e-7;
};

###Python

class Solution:
    def minNumberOfSeconds(self, mountainHeight: int, workerTimes: List[int]) -> int:
        maxWorkerTimes = max(workerTimes)
        l, r, ans = 1, maxWorkerTimes * mountainHeight * (mountainHeight + 1) // 2, 0
        eps = 1e-7
        
        while l <= r:
            mid = (l + r) // 2
            cnt = 0
            for t in workerTimes:
                work = mid // t
                # 求最大的 k 满足 1+2+...+k <= work
                k = int((-1 + ((1 + work * 8) ** 0.5)) / 2 + eps)
                cnt += k
            if cnt >= mountainHeight:
                ans = mid
                r = mid - 1
            else:
                l = mid + 1

        return ans

###Java

class Solution {
    private static final double EPS = 1e-7;
    
    public long minNumberOfSeconds(int mountainHeight, int[] workerTimes) {
        int maxWorkerTimes = 0;
        for (int t : workerTimes) {
            maxWorkerTimes = Math.max(maxWorkerTimes, t);
        }
        
        long l = 1;
        long r = (long) maxWorkerTimes * mountainHeight * (mountainHeight + 1) / 2;
        long ans = 0;
        
        while (l <= r) {
            long mid = (l + r) / 2;
            long cnt = 0;
            for (int t : workerTimes) {
                long work = mid / t;
                // 求最大的 k 满足 1+2+...+k <= work
                long k = (long)((-1.0 + Math.sqrt(1 + work * 8)) / 2 + EPS);
                cnt += k;
            }
            
            if (cnt >= mountainHeight) {
                ans = mid;
                r = mid - 1;
            } else {
                l = mid + 1;
            }
        }
        
        return ans;
    }
}

###C#

class Solution {
    private const double EPS = 1e-7;
    
    public long MinNumberOfSeconds(int mountainHeight, int[] workerTimes) {
        int maxWorkerTimes = 0;
        foreach (int t in workerTimes) {
            maxWorkerTimes = Math.Max(maxWorkerTimes, t);
        }
        
        long l = 1;
        long r = (long)maxWorkerTimes * mountainHeight * (mountainHeight + 1) / 2;
        long ans = 0;
        
        while (l <= r) {
            long mid = (l + r) / 2;
            long cnt = 0;
            
            foreach (int t in workerTimes) {
                long work = mid / t;
                // 求最大的 k 满足 1+2+...+k <= work
                long k = (long)((-1.0 + Math.Sqrt(1 + work * 8)) / 2 + EPS);
                cnt += k;
            }
            
            if (cnt >= mountainHeight) {
                ans = mid;
                r = mid - 1;
            } else {
                l = mid + 1;
            }
        }
        
        return ans;
    }
}

###Go

const eps = 1e-7

func minNumberOfSeconds(mountainHeight int, workerTimes []int) int64 {
    maxWorkerTimes := 0
    for _, t := range workerTimes {
        if t > maxWorkerTimes {
            maxWorkerTimes = t
        }
    }
    
    l := int64(1)
    r := int64(maxWorkerTimes) * int64(mountainHeight) * int64(mountainHeight + 1) / 2
    var ans int64 = 0
    
    for l <= r {
        mid := (l + r) / 2
        var cnt int64 = 0
        
        for _, t := range workerTimes {
            work := mid / int64(t)
            // 求最大的 k 满足 1+2+...+k <= work
            k := int64((-1.0 + math.Sqrt(1 + float64(work) * 8)) / 2 + eps)
            cnt += k
        }
        if cnt >= int64(mountainHeight) {
            ans = mid
            r = mid - 1
        } else {
            l = mid + 1
        }
    }
    
    return ans
}

###C

#define EPS 1e-7

long long minNumberOfSeconds(int mountainHeight, int* workerTimes, int workerTimesSize) {
    int maxWorkerTimes = 0;
    for (int i = 0; i < workerTimesSize; i++) {
        if (workerTimes[i] > maxWorkerTimes) {
            maxWorkerTimes = workerTimes[i];
        }
    }
    
    long long l = 1;
    long long r = (long long)maxWorkerTimes * mountainHeight * (mountainHeight + 1) / 2;
    long long ans = 0;
    
    while (l <= r) {
        long long mid = (l + r) / 2;
        long long cnt = 0;
        for (int i = 0; i < workerTimesSize; i++) {
            long long work = mid / workerTimes[i];
            // 求最大的 k 满足 1+2+...+k <= work
            long long k = (long long)((-1.0 + sqrt(1 + work * 8)) / 2 + EPS);
            cnt += k;
        }
        
        if (cnt >= mountainHeight) {
            ans = mid;
            r = mid - 1;
        } else {
            l = mid + 1;
        }
    }
    
    return ans;
}

###JavaScript

const EPS = 1e-7;

var minNumberOfSeconds = function(mountainHeight, workerTimes) {
    const maxWorkerTimes = Math.max(...workerTimes);
    let l = 1;
    let r = maxWorkerTimes * mountainHeight * (mountainHeight + 1) / 2;
    let ans = 0;
    
    while (l <= r) {
        const mid = Math.floor((l + r) / 2);
        let cnt = 0;
        for (const t of workerTimes) {
            const work = Math.floor(mid / t);
            // 求最大的 k 满足 1+2+...+k <= work
            const k = Math.floor((-1.0 + Math.sqrt(1 + work * 8)) / 2 + EPS);
            cnt += k;
        }
        
        if (cnt >= mountainHeight) {
            ans = mid;
            r = mid - 1;
        } else {
            l = mid + 1;
        }
    }
    
    return ans;
}

###TypeScript

const EPS: number = 1e-7;

function minNumberOfSeconds(mountainHeight: number, workerTimes: number[]): number {
    const maxWorkerTimes: number = Math.max(...workerTimes);
    let l: number = 1;
    let r: number = maxWorkerTimes * mountainHeight * (mountainHeight + 1) / 2;
    let ans: number = 0;
    
    while (l <= r) {
        const mid: number = Math.floor((l + r) / 2);
        let cnt: number = 0;
        for (const t of workerTimes) {
            const work: number = Math.floor(mid / t);
            // 求最大的 k 满足 1+2+...+k <= work
            const k: number = Math.floor((-1.0 + Math.sqrt(1 + work * 8)) / 2 + EPS);
            cnt += k;
        }
        
        if (cnt >= mountainHeight) {
            ans = mid;
            r = mid - 1;
        } else {
            l = mid + 1;
        }
    }
    
    return ans;
}

###Rust

const EPS: f64 = 1e-7;

impl Solution {
    pub fn min_number_of_seconds(mountain_height: i32, worker_times: Vec<i32>) -> i64 {
        let mountain_height = mountain_height as i64;
        let max_worker_times = *worker_times.iter().max().unwrap_or(&0) as i64;
        
        let mut l: i64 = 1;
        let mut r: i64 = max_worker_times * mountain_height * (mountain_height + 1) / 2;
        let mut ans: i64 = 0;
        
        while l <= r {
            let mid = (l + r) / 2;
            let mut cnt: i64 = 0;
            
            for &t in &worker_times {
                let work = mid / t as i64;
                // 求最大的 k 满足 1+2+...+k <= work
                let k = ((-1.0 + (1.0 + (work as f64) * 8.0).sqrt()) / 2.0 + EPS) as i64;
                cnt += k;
            }
            
            if cnt >= mountain_height {
                ans = mid;
                r = mid - 1;
            } else {
                l = mid + 1;
            }
        }
        
        ans
    }
}

复杂度分析

  • 时间复杂度:$O(n \log(MH^2))$,其中 $n$ 是数组 $\textit{workerTimes}$ 的长度,$M$ 是数组 $\textit{workerTimes}$ 中的最大值,$H$ 是 $\textit{mountainHeight}$。二分查找需要 $O(\log(MH^2))$ 次迭代,每次迭代遍历所有工人,需要 $O(n)$ 的时间。

  • 空间复杂度:$O(1)$。

找出所有稳定的二进制数组 II

2024年8月5日 11:48

方法一:记忆化搜索

思路

根据稳定数组的前两个条件,可知稳定数组的长度为 $\textit{zero} + \textit{one}$。第三个条件可知,稳定数组不存在长度为 $\textit{limit} + 1$ 的全 $0$ 或全 $1$ 子数组。

接下来我们分解问题,包含 $\textit{zero}$ 个 $0$ 和 $\textit{one}$ 个 $1$ 的稳定数组,末位元素可能为 $1$,也可能为 $0$。

  • 如果末位元素为 $1$,我们需要知道有多少个包含 $\textit{zero}$ 个 $0$ 和 $\textit{one}-1$ 个 $1$ 的稳定数组,再去掉“由于添加了一个 $1$ 而使得原来的稳定数组变得不稳定”的情况。那么有哪些情况会使得原来稳定的数组变得不稳定呢?即原来的稳定数组的末尾连续 $1$ 的个数正好为 $\textit{limit}$ 个。在这种情况下,添加一个 $1$ 会使得原来稳定的数组变得不稳定。这种情况出现的次数,即为包含 $\textit{zero}$ 个 $0$ 和 $\textit{one}-1-\textit{limit}$ 个 $1$,且末位元素为 $0$ 的稳定数组的个数。
  • 如果末位元素为 $0$,我们需要知道有多少个包含 $\textit{zero}-1$ 个 $0$ 和 $\textit{one}$ 个 $1$ 的稳定数组,再去掉“由于添加了一个 $0$ 而使得原来的稳定数组变得不稳定”的情况。

这样一来,我们就将问题分解为子问题了,可以用动态规划求解。用函数 $\textit{dp}(\textit{zero},\textit{one},\textit{lastBit})$,来求解包含 $\textit{zero}$ 个 $0$ 和 $\textit{one}$ 个 $1$,并且末位元素为 $\textit{lastBit}$ 的稳定数组的个数,其中 $\textit{lastBit}$ 为 $0$ 或 $1$。根据上面的讨论,可以得到递推公式:

  • $\textit{dp}(\textit{zero},\textit{one},0)$ = $\textit{dp}(\textit{zero}-1,\textit{one},0) + \textit{dp}(\textit{zero}-1,\textit{one},1) - \textit{dp}(\textit{zero}-1-\textit{limit},\textit{one},1)$
  • $\textit{dp}(\textit{zero},\textit{one},1)$ = $\textit{dp}(\textit{zero},\textit{one}-1,0) + \textit{dp}(\textit{zero},\textit{one}-1,1) - \textit{dp}(\textit{zero},\textit{one}-1-\textit{limit},0)$。

另外考虑边界情况。如果 $\textit{zero}$ 为 $0$,那么当 $\textit{lastBit}$ 为 $1$ 或者 $\textit{one}$ 大于 $\textit{limit}$ 时,不存在这样的稳定数组,返回 $0$,否则返回 $1$。如果 $\textit{zero}$ 为 $1$,也有对应的结论。

我们用记忆化搜索的方式来计算结果,记录所有的中间状态,最终返回 $\textit{dp}(\textit{zero},\textit{one},0)$ + $\textit{dp}(\textit{zero},\textit{one},1)$ 取模后的结果。

代码

###Python

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        mod = 10 ** 9 + 7

        @cache
        def dp(zero, one, lastBit):
            if zero == 0:
                if lastBit == 0 or one > limit:
                    return 0
                else:
                    return 1
            elif one == 0:
                if lastBit == 1 or zero > limit:
                    return 0
                else:
                    return 1
            if lastBit == 0:
                res = dp(zero - 1, one, 0) + dp(zero - 1, one, 1)
                if zero > limit:
                    res -= dp(zero - limit - 1, one, 1)
            else:
                res = dp(zero, one - 1, 0) + dp(zero, one - 1, 1)
                if one > limit:
                    res -= dp(zero, one - limit - 1, 0)
            return res % mod
            
        res = (dp(zero, one, 0) + dp(zero, one, 1)) % mod
        dp.cache_clear()
        return res

###C++

class Solution {
public:
    int numberOfStableArrays(int zero, int one, int limit) {
        int mod = 1e9 + 7;
        vector<vector<vector<int>>> memo(zero + 1, vector<vector<int>>(one + 1, vector<int>(2, -1)));

        function<int(int, int, int)> dp = [&](int zero, int one, int lastBit) -> int {
            if (zero == 0) {
                return (lastBit == 0 || one > limit) ? 0 : 1;
            } else if (one == 0) {
                return (lastBit == 1 || zero > limit) ? 0 : 1;
            }

            if (memo[zero][one][lastBit] == -1) {
                int res = 0;
                if (lastBit == 0) {
                    res = (dp(zero - 1, one, 0) + dp(zero - 1, one, 1)) % mod;
                    if (zero > limit) {
                        res = (res - dp(zero - limit - 1, one, 1) + mod) % mod;
                    }
                } else {
                    res = (dp(zero, one - 1, 0) + dp(zero, one - 1, 1)) % mod;
                    if (one > limit) {
                        res = (res - dp(zero, one - limit - 1, 0) + mod) % mod;
                    }
                }
                memo[zero][one][lastBit] = res % mod;
            }
            return memo[zero][one][lastBit];
        };

        return (dp(zero, one, 0) + dp(zero, one, 1)) % mod;
    }
};

###Java

class Solution {
    static final int MOD = 1000000007;
    int[][][] memo;
    int limit;

    public int numberOfStableArrays(int zero, int one, int limit) {
        this.memo = new int[zero + 1][one + 1][2];
        for (int i = 0; i <= zero; i++) {
            for (int j = 0; j <= one; j++) {
                Arrays.fill(memo[i][j], -1);
            }
        }
        this.limit = limit;
        return (dp(zero, one, 0) + dp(zero, one, 1)) % MOD;
    }

    public int dp(int zero, int one, int lastBit) {
        if (zero == 0) {
            return (lastBit == 0 || one > limit) ? 0 : 1;
        } else if (one == 0) {
            return (lastBit == 1 || zero > limit) ? 0 : 1;
        }

        if (memo[zero][one][lastBit] == -1) {
            int res = 0;
            if (lastBit == 0) {
                res = (dp(zero - 1, one, 0) + dp(zero - 1, one, 1))% MOD;
                if (zero > limit) {
                    res = (res - dp(zero - limit - 1, one, 1) + MOD) % MOD;
                }
            } else {
                res = (dp(zero, one - 1, 0) + dp(zero, one - 1, 1)) % MOD;
                if (one > limit) {
                    res = (res - dp(zero, one - limit - 1, 0) + MOD) % MOD;
                }
            }
            memo[zero][one][lastBit] = res % MOD;
        }
        return memo[zero][one][lastBit];
    }
}

###C#

public class Solution {
    const int MOD = 1000000007;
    int[][][] memo;
    int limit;

    public int NumberOfStableArrays(int zero, int one, int limit) {
        this.memo = new int[zero + 1][][];
        for (int i = 0; i <= zero; i++) {
            memo[i] = new int[one + 1][];
            for (int j = 0; j <= one; j++) {
                memo[i][j] = new int[2];
                Array.Fill(memo[i][j], -1);
            }
        }
        this.limit = limit;
        return (DP(zero, one, 0) + DP(zero, one, 1)) % MOD;
    }

    public int DP(int zero, int one, int lastBit) {
        if (zero == 0) {
            return (lastBit == 0 || one > limit) ? 0 : 1;
        } else if (one == 0) {
            return (lastBit == 1 || zero > limit) ? 0 : 1;
        }

        if (memo[zero][one][lastBit] == -1) {
            int res = 0;
            if (lastBit == 0) {
                res = (DP(zero - 1, one, 0) + DP(zero - 1, one, 1))% MOD;
                if (zero > limit) {
                    res = (res - DP(zero - limit - 1, one, 1) + MOD) % MOD;
                }
            } else {
                res = (DP(zero, one - 1, 0) + DP(zero, one - 1, 1)) % MOD;
                if (one > limit) {
                    res = (res - DP(zero, one - limit - 1, 0) + MOD) % MOD;
                }
            }
            memo[zero][one][lastBit] = res % MOD;
        }
        return memo[zero][one][lastBit];
    }
}

###C

#define MOD 1000000007

int ***createMemo(int zero, int one) {
    int ***memo = malloc((zero + 1) * sizeof(int **));
    for (int i = 0; i <= zero; ++i) {
        memo[i] = malloc((one + 1) * sizeof(int *));
        for (int j = 0; j <= one; ++j) {
            memo[i][j] = malloc(2 * sizeof(int));
            memo[i][j][0] = -1;
            memo[i][j][1] = -1;
        }
    }
    return memo;
}

void freeMemo(int zero, int one, int ***memo) {
    for (int i = 0; i <= zero; ++i) {
        for (int j = 0; j <= one; ++j) {
            free(memo[i][j]);
        }
        free(memo[i]);
    }
    free(memo);
}

int dp(int zero, int one, int lastBit, int limit, int ***memo) {
    if (zero == 0) {
        return (lastBit == 0 || one > limit) ? 0 : 1;
    } else if (one == 0) {
        return (lastBit == 1 || zero > limit) ? 0 : 1;
    }
    if (memo[zero][one][lastBit] == -1) {
        int res = 0;
        if (lastBit == 0) {
            res = (dp(zero - 1, one, 0, limit, memo) + dp(zero - 1, one, 1, limit, memo)) % MOD;
            if (zero > limit) {
                res = (res - dp(zero - limit - 1, one, 1, limit, memo) + MOD) % MOD;
            }
        } else {
            res = (dp(zero, one - 1, 0, limit, memo) + dp(zero, one - 1, 1, limit, memo)) % MOD;
            if (one > limit) {
                res = (res - dp(zero, one - limit - 1, 0, limit, memo) + MOD) % MOD;
            }
        }
        memo[zero][one][lastBit] = res % MOD;
    }
    return memo[zero][one][lastBit];
}

int numberOfStableArrays(int zero, int one, int limit) {
    int ***memo = createMemo(zero, one);
    int result = (dp(zero, one, 0, limit, memo) + dp(zero, one, 1, limit, memo)) % MOD;
    freeMemo(zero, one, memo);
    return result;
}

###Go

const MOD = 1000000007

func numberOfStableArrays(zero int, one int, limit int) int {
    memo := make([][][]int, zero + 1)
for i := range memo {
memo[i] = make([][]int, one + 1)
for j := range memo[i] {
memo[i][j] = []int{-1, -1}
}
}

var dp func(int, int, int) int
dp = func(zero, one, lastBit int) int {
if zero == 0 {
if lastBit == 0 || one > limit {
return 0
} else {
return 1
}
} else if one == 0 {
if lastBit == 1 || zero > limit {
return 0
} else {
return 1
}
}

if memo[zero][one][lastBit] == -1 {
res := 0
if lastBit == 0 {
res = (dp(zero-1, one, 0) + dp(zero - 1, one, 1)) % MOD
if zero > limit {
res = (res - dp(zero - limit - 1, one, 1) + MOD) % MOD
}
} else {
res = (dp(zero, one - 1, 0) + dp(zero, one - 1, 1)) % MOD
if one > limit {
res = (res - dp(zero, one - limit - 1, 0) + MOD) % MOD
}
}
memo[zero][one][lastBit] = res % MOD
}
return memo[zero][one][lastBit]
}

return (dp(zero, one, 0) + dp(zero, one, 1)) % MOD
}

###JavaScript

const MOD = 1000000007;

var numberOfStableArrays = function(zero, one, limit) {
    const memo = Array.from({ length: zero + 1 }, () =>
        Array.from({ length: one + 1 }, () => [-1, -1])
    );

    function dp(zero, one, lastBit) {
        if (zero === 0) {
            return lastBit === 0 || one > limit ? 0 : 1;
        } else if (one === 0) {
            return lastBit === 1 || zero > limit ? 0 : 1;
        }

        if (memo[zero][one][lastBit] === -1) {
            let res = 0;
            if (lastBit === 0) {
                res = (dp(zero - 1, one, 0) + dp(zero - 1, one, 1)) % MOD;
                if (zero > limit) {
                    res = (res - dp(zero - limit - 1, one, 1) + MOD) % MOD;
                }
            } else {
                res = (dp(zero, one - 1, 0) + dp(zero, one - 1, 1)) % MOD;
                if (one > limit) {
                    res = (res - dp(zero, one - limit - 1, 0) + MOD) % MOD;
                }
            }
            memo[zero][one][lastBit] = res % MOD;
        }
        return memo[zero][one][lastBit];
    }

    return (dp(zero, one, 0) + dp(zero, one, 1)) % MOD;
};

###TypeScript

const MOD = 1000000007;

function numberOfStableArrays(zero: number, one: number, limit: number): number {
    const memo: number[][][] = Array.from({ length: zero + 1 }, () =>
        Array.from({ length: one + 1 }, () => [-1, -1])
    );

    function dp(zero: number, one: number, lastBit: number): number {
        if (zero === 0) {
            return lastBit === 0 || one > limit ? 0 : 1;
        } else if (one === 0) {
            return lastBit === 1 || zero > limit ? 0 : 1;
        }

        if (memo[zero][one][lastBit] === -1) {
            let res = 0;
            if (lastBit === 0) {
                res = (dp(zero - 1, one, 0) + dp(zero - 1, one, 1)) % MOD;
                if (zero > limit) {
                    res = (res - dp(zero - limit - 1, one, 1) + MOD) % MOD;
                }
            } else {
                res = (dp(zero, one - 1, 0) + dp(zero, one - 1, 1)) % MOD;
                if (one > limit) {
                    res = (res - dp(zero, one - limit - 1, 0) + MOD) % MOD;
                }
            }
            memo[zero][one][lastBit] = res % MOD;
        }
        return memo[zero][one][lastBit];
    }

    return (dp(zero, one, 0) + dp(zero, one, 1)) % MOD;
};

###Rust

const MOD: i32 = 1000000007;

impl Solution {
    pub fn number_of_stable_arrays(zero: i32, one: i32, limit: i32) -> i32 {
        let mut memo = vec![vec![vec![-1; 2]; (one + 1) as usize]; (zero + 1) as usize];

        fn dp(zero: usize, one: usize, last_bit: usize, limit: usize, memo: &mut Vec<Vec<Vec<i32>>>) -> i32 {
            if zero == 0 {
                return if last_bit == 0 || one > limit { 0 } else { 1 };
            } else if one == 0 {
                return if last_bit == 1 || zero > limit { 0 } else { 1 };
            }

            if memo[zero][one][last_bit] == -1 {
                let mut res = 0;
                if last_bit == 0 {
                    res = (dp(zero - 1, one, 0, limit, memo) + dp(zero - 1, one, 1, limit, memo)) % MOD;
                    if zero > limit {
                        res = (res - dp(zero - limit - 1, one, 1, limit, memo) + MOD) % MOD;
                    }
                } else {
                    res = (dp(zero, one - 1, 0, limit, memo) + dp(zero, one - 1, 1, limit, memo)) % MOD;
                    if one > limit {
                        res = (res - dp(zero, one - limit - 1, 0, limit, memo) + MOD) % MOD;
                    }
                }
                memo[zero][one][last_bit] = res % MOD;
            }
            memo[zero][one][last_bit]
        }

        let zero = zero as usize;
        let one = one as usize;
        let limit = limit as usize;
        (dp(zero, one, 0, limit, &mut memo) + dp(zero, one, 1, limit, &mut memo)) % MOD
    }
}

复杂度分析

  • 时间复杂度:$O(\textit{zero}\times\textit{one})$,动态规划的状态一共有 $O(\textit{zero}\times\textit{one})$ 种,每个状态消耗 $O(1)$ 时间消耗。

  • 空间复杂度:$O(\textit{zero}\times\textit{one})$。

方法二:动态规划

思路

方法一用的是记忆化搜索,状态的求解是自顶向下的。方法二中我们使用动态规划,从而自底向上来求出所有状态,并用数组保存结果。状态方程的关系和方法一一致。

代码

###Python

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        mod = 10 ** 9 + 7

        dp = [[[0, 0] for _ in range(one + 1)] for _ in range(zero + 1)]
        for i in range(zero+1):
            for j in range(one+1):
                for lastBit in range(2):
                    if i == 0:
                        if lastBit == 0 or j > limit:
                            dp[i][j][lastBit] = 0
                        else:
                            dp[i][j][lastBit] = 1
                    elif j == 0:
                        if lastBit == 1 or i > limit:
                            dp[i][j][lastBit] = 0
                        else:
                            dp[i][j][lastBit] = 1
                    elif lastBit == 0:
                        dp[i][j][lastBit] = dp[i-1][j][0] + dp[i-1][j][1]
                        if i > limit:
                            dp[i][j][lastBit] -= dp[i-limit-1][j][1]
                    else:
                        dp[i][j][lastBit] = dp[i][j-1][0] + dp[i][j-1][1]
                        if j > limit:
                            dp[i][j][lastBit] -= dp[i][j-limit-1][0]
                    dp[i][j][lastBit] %= mod
        return (dp[-1][-1][0] + dp[-1][-1][1]) % mod

###C++

class Solution {
public:
    constexpr static int MOD = 1000000007;
    int numberOfStableArrays(int zero, int one, int limit) {
        vector<vector<vector<int>>> dp(zero + 1, vector<vector<int>>(one + 1, vector<int>(2)));
        for (int i = 0; i <= zero; i++) {
            for (int j = 0; j <= one; j++) {
                for (int lastBit = 0; lastBit <= 1; lastBit++) {
                    if (i == 0) {
                        if (lastBit == 0 || j > limit) {
                            dp[i][j][lastBit] = 0;
                        } else {
                            dp[i][j][lastBit] = 1;
                        }
                    } else if (j == 0) {
                        if (lastBit == 1 || i > limit) {
                            dp[i][j][lastBit] = 0;
                        } else {
                            dp[i][j][lastBit] = 1;
                        }
                    } else if (lastBit == 0) {
                        dp[i][j][lastBit] = dp[i - 1][j][0] + dp[i - 1][j][1];
                        if (i > limit) {
                            dp[i][j][lastBit] -= dp[i - limit - 1][j][1];
                        }
                    } else {
                        dp[i][j][lastBit] = dp[i][j - 1][0] + dp[i][j -1 ][1];
                        if (j > limit) {
                            dp[i][j][lastBit] -= dp[i][j - limit - 1][0];
                        }
                    }
                    dp[i][j][lastBit] %= MOD;
                    if (dp[i][j][lastBit] < 0) {
                        dp[i][j][lastBit] += MOD;
                    }
                }
            }
        }

        return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
    }
};

###Java

class Solution {
    public int numberOfStableArrays(int zero, int one, int limit) {
        final int MOD = 1000000007;
        int[][][] dp = new int[zero + 1][one + 1][2];
        for (int i = 0; i <= zero; i++) {
            for (int j = 0; j <= one; j++) {
                for (int lastBit = 0; lastBit <= 1; lastBit++) {
                    if (i == 0) {
                        if (lastBit == 0 || j > limit) {
                            dp[i][j][lastBit] = 0;
                        } else {
                            dp[i][j][lastBit] = 1;
                        }
                    } else if (j == 0) {
                        if (lastBit == 1 || i > limit) {
                            dp[i][j][lastBit] = 0;
                        } else {
                            dp[i][j][lastBit] = 1;
                        }
                    } else if (lastBit == 0) {
                        dp[i][j][lastBit] = dp[i - 1][j][0] + dp[i - 1][j][1];
                        if (i > limit) {
                            dp[i][j][lastBit] -= dp[i - limit - 1][j][1];
                        }
                    } else {
                        dp[i][j][lastBit] = dp[i][j - 1][0] + dp[i][j -1 ][1];
                        if (j > limit) {
                            dp[i][j][lastBit] -= dp[i][j - limit - 1][0];
                        }
                    }
                    dp[i][j][lastBit] %= MOD;
                    if (dp[i][j][lastBit] < 0) {
                        dp[i][j][lastBit] += MOD;
                    }
                }
            }
        }
        return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
    }
}

###C#

public class Solution {
    public int NumberOfStableArrays(int zero, int one, int limit) {
        const int MOD = 1000000007;
        int[][][] dp = new int[zero + 1][][];
        for (int i = 0; i <= zero; i++) {
            dp[i] = new int[one + 1][];
            for (int j = 0; j <= one; j++) {
                dp[i][j] = new int[2];
                for (int lastBit = 0; lastBit <= 1; lastBit++) {
                    if (i == 0) {
                        if (lastBit == 0 || j > limit) {
                            dp[i][j][lastBit] = 0;
                        } else {
                            dp[i][j][lastBit] = 1;
                        }
                    } else if (j == 0) {
                        if (lastBit == 1 || i > limit) {
                            dp[i][j][lastBit] = 0;
                        } else {
                            dp[i][j][lastBit] = 1;
                        }
                    } else if (lastBit == 0) {
                        dp[i][j][lastBit] = dp[i - 1][j][0] + dp[i - 1][j][1];
                        if (i > limit) {
                            dp[i][j][lastBit] -= dp[i - limit - 1][j][1];
                        }
                    } else {
                        dp[i][j][lastBit] = dp[i][j - 1][0] + dp[i][j -1 ][1];
                        if (j > limit) {
                            dp[i][j][lastBit] -= dp[i][j - limit - 1][0];
                        }
                    }
                    dp[i][j][lastBit] %= MOD;
                    if (dp[i][j][lastBit] < 0) {
                        dp[i][j][lastBit] += MOD;
                    }
                }
            }
        }
        return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
    }
}

###Go

const MOD = 1000000007

func numberOfStableArrays(zero int, one int, limit int) int {
    dp := make([][][]int, zero + 1)
    for i := range dp {
        dp[i] = make([][]int, one + 1)
        for j := range dp[i] {
            dp[i][j] = make([]int, 2)
        }
    }

    for i := 0; i <= zero; i++ {
        for j := 0; j <= one; j++ {
            for lastBit := 0; lastBit <= 1; lastBit++ {
                if i == 0 {
                    if lastBit == 0 || j > limit {
                        dp[i][j][lastBit] = 0
                    } else {
                        dp[i][j][lastBit] = 1
                    }
                } else if j == 0 {
                    if lastBit == 1 || i > limit {
                        dp[i][j][lastBit] = 0
                    } else {
                        dp[i][j][lastBit] = 1
                    }
                } else if lastBit == 0 {
                    dp[i][j][lastBit] = dp[i - 1][j][0] + dp[i - 1][j][1]
                    if i > limit {
                        dp[i][j][lastBit] -= dp[i - limit - 1][j][1]
                    }
                } else {
                    dp[i][j][lastBit] = dp[i][j - 1][0] + dp[i][j - 1][1]
                    if j > limit {
                        dp[i][j][lastBit] -= dp[i][j - limit - 1][0]
                    }
                }
                dp[i][j][lastBit] %= MOD
                if dp[i][j][lastBit] < 0 {
                    dp[i][j][lastBit] += MOD
                }
            }
        }
    }

    return (dp[zero][one][0] + dp[zero][one][1]) % MOD
}

###C

#define MOD 1000000007

int numberOfStableArrays(int zero, int one, int limit) {
    int dp[zero + 1][one + 1][2];
    memset(dp, 0, sizeof(dp));
    for (int i = 0; i <= zero; i++) {
        for (int j = 0; j <= one; j++) {
            for (int lastBit = 0; lastBit <= 1; lastBit++) {
                if (i == 0) {
                    if (lastBit == 0 || j > limit) {
                        dp[i][j][lastBit] = 0;
                    } else {
                        dp[i][j][lastBit] = 1;
                    }
                } else if (j == 0) {
                    if (lastBit == 1 || i > limit) {
                        dp[i][j][lastBit] = 0;
                    } else {
                        dp[i][j][lastBit] = 1;
                    }
                } else if (lastBit == 0) {
                    dp[i][j][lastBit] = dp[i - 1][j][0] + dp[i - 1][j][1];
                    if (i > limit) {
                        dp[i][j][lastBit] -= dp[i - limit - 1][j][1];
                    }
                } else {
                    dp[i][j][lastBit] = dp[i][j - 1][0] + dp[i][j - 1][1];
                    if (j > limit) {
                        dp[i][j][lastBit] -= dp[i][j - limit - 1][0];
                    }
                }
                dp[i][j][lastBit] %= MOD;
                if (dp[i][j][lastBit] < 0) {
                    dp[i][j][lastBit] += MOD;
                }
            }
        }
    }

    return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
}

###JavaScript

const MOD = 1000000007;

var numberOfStableArrays = function(zero, one, limit) {
    let dp = Array.from({ length: zero + 1 }, () =>
        Array.from({ length: one + 1 }, () => [0, 0])
    );

    for (let i = 0; i <= zero; i++) {
        for (let j = 0; j <= one; j++) {
            for (let lastBit = 0; lastBit <= 1; lastBit++) {
                if (i === 0) {
                    if (lastBit === 0 || j > limit) {
                        dp[i][j][lastBit] = 0;
                    } else {
                        dp[i][j][lastBit] = 1;
                    }
                } else if (j === 0) {
                    if (lastBit === 1 || i > limit) {
                        dp[i][j][lastBit] = 0;
                    } else {
                        dp[i][j][lastBit] = 1;
                    }
                } else if (lastBit === 0) {
                    dp[i][j][lastBit] = dp[i - 1][j][0] + dp[i - 1][j][1];
                    if (i > limit) {
                        dp[i][j][lastBit] -= dp[i - limit - 1][j][1];
                    }
                } else {
                    dp[i][j][lastBit] = dp[i][j - 1][0] + dp[i][j - 1][1];
                    if (j > limit) {
                        dp[i][j][lastBit] -= dp[i][j - limit - 1][0];
                    }
                }
                dp[i][j][lastBit] %= MOD;
                if (dp[i][j][lastBit] < 0) {
                    dp[i][j][lastBit] += MOD;
                }
            }
        }
    }

    return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
};

###TypeScript

const MOD = 1000000007;

function numberOfStableArrays(zero: number, one: number, limit: number): number {
    let dp: number[][][] = Array.from({ length: zero + 1 }, () =>
        Array.from({ length: one + 1 }, () => [0, 0])
    );

    for (let i = 0; i <= zero; i++) {
        for (let j = 0; j <= one; j++) {
            for (let lastBit = 0; lastBit <= 1; lastBit++) {
                if (i === 0) {
                    if (lastBit === 0 || j > limit) {
                        dp[i][j][lastBit] = 0;
                    } else {
                        dp[i][j][lastBit] = 1;
                    }
                } else if (j === 0) {
                    if (lastBit === 1 || i > limit) {
                        dp[i][j][lastBit] = 0;
                    } else {
                        dp[i][j][lastBit] = 1;
                    }
                } else if (lastBit === 0) {
                    dp[i][j][lastBit] = dp[i - 1][j][0] + dp[i - 1][j][1];
                    if (i > limit) {
                        dp[i][j][lastBit] -= dp[i - limit - 1][j][1];
                    }
                } else {
                    dp[i][j][lastBit] = dp[i][j - 1][0] + dp[i][j - 1][1];
                    if (j > limit) {
                        dp[i][j][lastBit] -= dp[i][j - limit - 1][0];
                    }
                }
                dp[i][j][lastBit] %= MOD;
                if (dp[i][j][lastBit] < 0) {
                    dp[i][j][lastBit] += MOD;
                }
            }
        }
    }

    return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
};

###Rust

const MOD: i32 = 1000000007;

impl Solution {
    pub fn number_of_stable_arrays(zero: i32, one: i32, limit: i32) -> i32 {
        let mut dp = vec![vec![vec![0; 2]; one as usize + 1]; zero as usize + 1];

        for i in 0..=zero as usize {
            for j in 0..=one as usize {
                for last_bit in 0..=1 {
                    if i == 0 {
                        if last_bit == 0 || j > limit as usize {
                            dp[i][j][last_bit] = 0;
                        } else {
                            dp[i][j][last_bit] = 1;
                        }
                    } else if j == 0 {
                        if last_bit == 1 || i > limit as usize {
                            dp[i][j][last_bit] = 0;
                        } else {
                            dp[i][j][last_bit] = 1;
                        }
                    } else if last_bit == 0 {
                        dp[i][j][last_bit] = dp[i - 1][j][0] + dp[i - 1][j][1];
                        if i > limit as usize {
                            dp[i][j][last_bit] -= dp[i - (limit as usize) - 1][j][1];
                        }
                    } else {
                        dp[i][j][last_bit] = dp[i][j - 1][0] + dp[i][j - 1][1];
                        if j > limit as usize {
                            dp[i][j][last_bit] -= dp[i][j - (limit as usize) - 1][0];
                        }
                    }
                    dp[i][j][last_bit] %= MOD;
                    if dp[i][j][last_bit] < 0 {
                        dp[i][j][last_bit] += MOD;
                    }
                }
            }
        }

        return (dp[zero as usize][one as usize][0] + dp[zero as usize][one as usize][1]) % MOD;
    }
}

复杂度分析

  • 时间复杂度:$O(\textit{zero}\times\textit{one})$,动态规划的状态一共有 $O(\textit{zero}\times\textit{one})$ 种,每个状态消耗 $O(1)$ 时间消耗。

  • 空间复杂度:$O(\textit{zero}\times\textit{one})$。

找出所有稳定的二进制数组 I

2024年8月1日 10:03

方法一:动态规划

题目要求二进制数组 $\textit{arr}$ 中每个长度超过 $\textit{limit}$ 的子数组同时包含 $0$ 和 $1$,这个条件等价于二进制数组 $\textit{arr}$ 中每个长度等于 $\textit{limit} + 1$ 的子数组都同时包含 $0$ 和 $1$(读者可以思考一下证明过程)。

按照题目要求,我们需要将 $\textit{zero}$ 个 $0$ 和 $\textit{one}$ 个 $1$ 依次填入二进制数组 $\textit{arr}$,使用 $\textit{dp}_0[i][j]$ 表示已经填入 $i$ 个 $0$ 和 $\textit{j}$ 个 $1$,并且最后填入的数字为 $0$ 的可行方案数目,$\textit{dp}_1[i][j]$ 表示已经填入 $i$ 个 $0$ 和 $\textit{j}$ 个 $1$,并且最后填入的数字为 $1$ 的可行方案数目。对于 $\textit{dp}_0[i][j]$,我们分析一下它的转换方程:

  • 当 $j = 0$ 且 $i \in [0, \min(\textit{zero}, \textit{limit})]$ 时:我们可以不断地填入 $0$,所以 $\textit{dp}_0[i][j] = 1$。

  • 当 $i = 0$,或者 $j = 0$ 且 $i \notin [0, \min(\textit{zero}, \textit{limit})]$ 时:我们没法构造可行的方案,所以 $\textit{dp}_0[i][j] = 0$。

  • 当 $i > 0$ 且 $j > 0$ 时:$\textit{dp}_0[i][j]$ 可以分别由 $\textit{dp}_0[i - 1][j]$ 和 $\textit{dp}_1[i - 1][j]$ 转移而来,分别考虑两种情况:

    • 对于 $\textit{dp}_1[i - 1][j]$:显然可以通过在 $\textit{dp}_1[i - 1][j]$ 对应的所有填入方案后再填入一个 $0$ 得到对应的可行方案。

    • 对于 $\textit{dp}_0[i - 1][j]$:当 $i \le \textit{limit}$ 时,显然可以通过在 $\textit{dp}_1[i - 1][j]$ 对应的所有填入方案后再填入一个 $0$ 得到对应的可行方案;当 $i \gt \textit{limit}$ 时,我们需要去除一些不可行的方案数。因为 $\textit{dp}_0[i - 1][j]$ 对应的所有填入方案都是可行的,而只有一种情况会在额外填入一个 $0$ 时,变成不可行,即先前已经连续填入 $\textit{limit}$ 个 $0$,对应的方案数为 $\textit{dp}_1[i - \textit{limit} - 1][j]$。

根据以上分析,我们有 $\textit{dp}_0[i][j]$ 的转移方程:

$$
\textit{dp}_0[i][j] = \begin{cases}
1, & i \in [0, \min(\textit{zero}, \textit{limit})], j = 0 \
\textit{dp}_1[i - 1][j] + \textit{dp}_0[i - 1][j] - \textit{dp}_1[i - \textit{limit} - 1][j], & i > limit, j > 0 \
\textit{dp}_1[i - 1][j] + \textit{dp}_0[i - 1][j], & i \in [0, limit], j > 0 \
0, & otherwise
\end{cases}
$$

同理,我们也可以获得 $\textit{dp}_1[i][j]$ 的转移方程:

$$
\textit{dp}_1[i][j] = \begin{cases}
1, & i = 0, j \in [0, \min(\textit{one}, \textit{limit})] \
\textit{dp}_0[i][j - 1] + \textit{dp}_1[i][j - 1] - \textit{dp}_0[i][j - \textit{limit} - 1], & i > 0, j > limit \
\textit{dp}_0[i][j - 1] + \textit{dp}_1[i][j - 1], & i > 0, j \in [0, limit] \
0, & otherwise
\end{cases}
$$

最后,稳定二进制数组的数目等于 $\textit{dp}_0[\textit{zero}][\textit{one}] + \textit{dp}_1[\textit{zero}][\textit{one}]$。

###C++

class Solution {
public:
    int numberOfStableArrays(int zero, int one, int limit) {
        vector<vector<vector<long long>>> dp(zero + 1, vector<vector<long long>>(one + 1, vector<long long>(2)));
        long long mod = 1e9 + 7;
        for (int i = 0; i <= min(zero, limit); i++) {
            dp[i][0][0] = 1;
        }
        for (int j = 0; j <= min(one, limit); j++) {
            dp[0][j][1] = 1;
        }
        for (int i = 1; i <= zero; i++) {
            for (int j = 1; j <= one; j++) {
                if (i > limit) {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1];
                } else {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1];
                }
                dp[i][j][0] = (dp[i][j][0] % mod + mod) % mod;
                if (j > limit) {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0];
                } else {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0];
                }
                dp[i][j][1] = (dp[i][j][1] % mod + mod) % mod;
            }
        }
        return (dp[zero][one][0] + dp[zero][one][1]) % mod;
    }
};

###Java

class Solution {
    public int numberOfStableArrays(int zero, int one, int limit) {
        final long MOD = 1000000007;
        long[][][] dp = new long[zero + 1][one + 1][2];
        for (int i = 0; i <= Math.min(zero, limit); i++) {
            dp[i][0][0] = 1;
        }
        for (int j = 0; j <= Math.min(one, limit); j++) {
            dp[0][j][1] = 1;
        }
        for (int i = 1; i <= zero; i++) {
            for (int j = 1; j <= one; j++) {
                if (i > limit) {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1];
                } else {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1];
                }
                dp[i][j][0] = (dp[i][j][0] % MOD + MOD) % MOD;
                if (j > limit) {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0];
                } else {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0];
                }
                dp[i][j][1] = (dp[i][j][1] % MOD + MOD) % MOD;
            }
        }
        return (int) ((dp[zero][one][0] + dp[zero][one][1]) % MOD);
    }
}

###C#

public class Solution {
    public int NumberOfStableArrays(int zero, int one, int limit) {
        const long MOD = 1000000007;
        long[][][] dp = new long[zero + 1][][];
        for (int i = 0; i <= zero; i++) {
            dp[i] = new long[one + 1][];
            for (int j = 0; j <= one; j++) {
                dp[i][j] = new long[2];
            }
        }
        for (int i = 0; i <= Math.Min(zero, limit); i++) {
            dp[i][0][0] = 1;
        }
        for (int j = 0; j <= Math.Min(one, limit); j++) {
            dp[0][j][1] = 1;
        }
        for (int i = 1; i <= zero; i++) {
            for (int j = 1; j <= one; j++) {
                if (i > limit) {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1];
                } else {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1];
                }
                dp[i][j][0] = (dp[i][j][0] % MOD + MOD) % MOD;
                if (j > limit) {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0];
                } else {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0];
                }
                dp[i][j][1] = (dp[i][j][1] % MOD + MOD) % MOD;
            }
        }
        return (int) ((dp[zero][one][0] + dp[zero][one][1]) % MOD);
    }
}

###Go

func numberOfStableArrays(zero int, one int, limit int) int {
    dp := make([][][2]int, zero + 1)
    mod := int(1e9 + 7)
    for i := 0; i <= zero; i++ {
        dp[i] = make([][2]int, one + 1)
    }
    for i := 0; i <= min(zero, limit); i++ {
        dp[i][0][0] = 1
    }
    for j := 0; j <= min(one, limit); j++ {
        dp[0][j][1] = 1
    }
    for i := 1; i <= zero; i++ {
        for j := 1; j <= one; j++ {
            if i > limit {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1]
            } else {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1]
            }
            dp[i][j][0] = (dp[i][j][0] % mod + mod) % mod
            if j > limit {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0]
            } else {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0]
            }
            dp[i][j][1] = (dp[i][j][1] % mod + mod) % mod
        }
    }
    return (dp[zero][one][0] + dp[zero][one][1]) % mod
}

###Python

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        dp = [[[0, 0] for _ in range(one + 1)] for _ in range(zero + 1)]
        mod = int(1e9 + 7)
        for i in range(min(zero, limit) + 1):
            dp[i][0][0] = 1
        for j in range(min(one, limit) + 1):
            dp[0][j][1] = 1
        for i in range(1, zero + 1):
            for j in range(1, one + 1):
                if i > limit:
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1]
                else:
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1]
                dp[i][j][0] = (dp[i][j][0] % mod + mod) % mod
                if j > limit:
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0]
                else:
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0]
                dp[i][j][1] = (dp[i][j][1] % mod + mod) % mod
        return (dp[zero][one][0] + dp[zero][one][1]) % mod

###C

#define MOD 1000000007

int numberOfStableArrays(int zero, int one, int limit) {
    long long dp[zero + 1][one + 1][2];
    memset(dp, 0, sizeof(dp));
    for (int i = 0; i <= (zero < limit ? zero : limit); ++i) {
        dp[i][0][0] = 1;
    }
    for (int j = 0; j <= (one < limit ? one : limit); ++j) {
        dp[0][j][1] = 1;
    }
    for (int i = 1; i <= zero; ++i) {
        for (int j = 1; j <= one; ++j) {
            if (i > limit) {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1];
            } else {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1];
            }
            dp[i][j][0] = (dp[i][j][0] % MOD + MOD) % MOD;
            if (j > limit) {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0];
            } else {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0];
            }
            dp[i][j][1] = (dp[i][j][1] % MOD + MOD) % MOD;
        }
    }
    int result = (dp[zero][one][0] + dp[zero][one][1]) % MOD;
    return result;
}

###JavaScript

const MOD = 1000000007;

var numberOfStableArrays = function(zero, one, limit) {
    const dp = Array.from({ length: zero + 1 }, () =>
        Array.from({ length: one + 1 }, () => [0, 0])
    );

    for (let i = 0; i <= Math.min(zero, limit); i++) {
        dp[i][0][0] = 1;
    }
    for (let j = 0; j <= Math.min(one, limit); j++) {
        dp[0][j][1] = 1;
    }

    for (let i = 1; i <= zero; i++) {
        for (let j = 1; j <= one; j++) {
            if (i > limit) {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1];
            } else {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1];
            }
            dp[i][j][0] = (dp[i][j][0] % MOD + MOD) % MOD;
            if (j > limit) {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0];
            } else {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0];
            }
            dp[i][j][1] = (dp[i][j][1] % MOD + MOD) % MOD;
        }
    }
    return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
};

###TypeScript

const MOD = 1000000007;

function numberOfStableArrays(zero: number, one: number, limit: number): number {
    const dp: number[][][] = Array.from({ length: zero + 1 }, () =>
        Array.from({ length: one + 1 }, () => [0, 0])
    );

    for (let i = 0; i <= Math.min(zero, limit); i++) {
        dp[i][0][0] = 1;
    }
    for (let j = 0; j <= Math.min(one, limit); j++) {
        dp[0][j][1] = 1;
    }

    for (let i = 1; i <= zero; i++) {
        for (let j = 1; j <= one; j++) {
            if (i > limit) {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit - 1][j][1];
            } else {
                dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1];
            }
            dp[i][j][0] = (dp[i][j][0] % MOD + MOD) % MOD;
            if (j > limit) {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit - 1][0];
            } else {
                dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0];
            }
            dp[i][j][1] = (dp[i][j][1] % MOD + MOD) % MOD;
        }
    }
    return (dp[zero][one][0] + dp[zero][one][1]) % MOD;
};

###Rust

const MOD: i32 = 1000000007;

impl Solution {
    pub fn number_of_stable_arrays(zero: i32, one: i32, limit: i32) -> i32 {
        let mut dp = vec![vec![vec![0; 2]; one as usize + 1]; zero as usize + 1];

        for i in 0..=zero.min(limit) as usize {
            dp[i][0][0] = 1;
        }
        for j in 0..=one.min(limit) as usize {
            dp[0][j][1] = 1;
        }

        for i in 1..=zero as usize {
            for j in 1..=one as usize {
                if i > limit as usize {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1] - dp[i - limit as usize - 1][j][1];
                } else {
                    dp[i][j][0] = dp[i - 1][j][0] + dp[i - 1][j][1];
                }
                dp[i][j][0] = (dp[i][j][0] % MOD + MOD) % MOD;
                if j > limit as usize {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0] - dp[i][j - limit as usize - 1][0];
                } else {
                    dp[i][j][1] = dp[i][j - 1][1] + dp[i][j - 1][0];
                }
                dp[i][j][1] = (dp[i][j][1] % MOD + MOD) % MOD;
            }
        }
        (dp[zero as usize][one as usize][0] + dp[zero as usize][one as usize][1]) % MOD
    }
}

复杂度分析

  • 时间复杂度:$O(\textit{zero} \times \textit{one})$,其中 $\textit{zero}$ 和 $\textit{one}$ 分别为 $0$ 和 $1$ 的出现次数。

  • 空间复杂度:$O(\textit{zero} \times \textit{one})$。

检查二进制字符串字段

2022年9月28日 10:03

方法一:寻找 $01$ 串

思路与算法

题目给定一个长度为 $n$ 的二进制字符串 $s$,并满足该字符串不含前导零。现在我们需要判断字符串中是否只包含零个或一个由连续 $1$ 组成的字段。首先我们依次分析这两种情况:

  • 字符串 $s$ 中包含零个由连续 $1$ 组成的字段,那么整个串的表示为 $00 \cdots 00$。
  • 字符串 $s$ 中只包含一个由连续 $1$ 组成的字段,因为已知字符串 $s$ 不包含前导零,所以整个串的表示为 $1 \cdots 100 \cdots 00$。

那么可以看到两种情况中都不包含 $01$ 串。且不包含的 $01$ 串的一个二进制字符串也有且仅有上面两种情况。所以我们可以通过原字符串中是否有 $01$ 串来判断字符串中是否只包含零个或一个由连续 $1$ 组成的字段。如果有 $01$ 串则说明该情况不满足,否则即满足该情况条件。

代码

###Python

class Solution:
    def checkOnesSegment(self, s: str) -> bool:
        return "01" not in s

###Java

class Solution {
    public boolean checkOnesSegment(String s) {
        return !s.contains("01");
    }
}

###C#

public class Solution {
    public bool CheckOnesSegment(string s) {
        return !s.Contains("01");
    }
}

###C++

class Solution {
public:
    bool checkOnesSegment(string s) {
        return s.find("01") == string::npos;
    }
};

###C

bool checkOnesSegment(char * s){
    return strstr(s, "01") == NULL;
}

###JavaScript

var checkOnesSegment = function(s) {
    return s.indexOf('01') === -1;
};

###go

func checkOnesSegment(s string) bool {
return !strings.Contains(s, "01")
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 为字符串 $s$ 的长度。
  • 空间复杂度:$O(1)$,仅适用常量空间。

生成交替二进制字符串的最少操作数

2022年11月28日 09:58

方法一:模拟

思路

根据题意,经过多次操作,$s$ 可能会变成两种不同的交替二进制字符串,即:

  • 开头为 $0$,后续交替的字符串;
  • 开头为 $1$,后续交替的字符串。

注意到,变成这两种不同的交替二进制字符串所需要的最少操作数加起来等于 $s$ 的长度,我们只需要计算出变为其中一种字符串的最少操作数,就可以推出另一个最少操作数,然后取最小值即可。

代码

###Python

class Solution:
    def minOperations(self, s: str) -> int:
        cnt = sum(int(c) != i % 2 for i, c in enumerate(s))
        return min(cnt, len(s) - cnt)

###Java

class Solution {
    public int minOperations(String s) {
        int cnt = 0;
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            if (c != (char) ('0' + i % 2)) {
                cnt++;
            }
        }
        return Math.min(cnt, s.length() - cnt);
    }
}

###C#

public class Solution {
    public int MinOperations(string s) {
        int cnt = 0;
        for (int i = 0; i < s.Length; i++) {
            char c = s[i];
            if (c != (char) ('0' + i % 2)) {
                cnt++;
            }
        }
        return Math.Min(cnt, s.Length - cnt);
    }
}

###C++

class Solution {
public:
    int minOperations(string s) {
        int cnt = 0;
        for (int i = 0; i < s.size(); i++) {
            char c = s[i];
            if (c != ('0' + i % 2)) {
                cnt++;
            }
        }
        return min(cnt, (int)s.size() - cnt);
    }
};

###C

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int minOperations(char * s) {
    int cnt = 0, len = strlen(s);
    for (int i = 0; i < len; i++) {
        char c = s[i];
        if (c != ('0' + i % 2)) {
            cnt++;
        }
    }
    return MIN(cnt, len - cnt);
}

###JavaScript

var minOperations = function(s) {
    let cnt = 0;
    for (let i = 0; i < s.length; i++) {
        const c = s[i];
        if (c !== (String.fromCharCode('0'.charCodeAt() + i % 2))) {
            cnt++;
        }
    }
    return Math.min(cnt, s.length - cnt);
};

###go

func minOperations(s string) int {
    cnt := 0
    for i, c := range s {
        if i%2 != int(c-'0') {
            cnt++
        }
    }
    return min(cnt, len(s)-cnt)
}

func min(a, b int) int {
    if a > b {
        return b
    }
    return a
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 为输入 $s$ 的长度,仅需遍历一遍字符串。

  • 空间复杂度:$O(1)$,只需要常数额外空间。

找出第 N 个二进制字符串中的第 K 位

2020年8月20日 20:51

方法一:递归

观察二进制字符串 $S_n$,可以发现,当 $n>1$ 时,$S_n$ 是在 $S_{n-1}$ 的基础上形成的。用 $\text{len}n$ 表示 $S_n$ 的长度,则 $S_n$ 的前 $\text{len}{n-1}$ 个字符与 $S_{n-1}$ 相同。还可以发现,当 $n>1$ 时,$\text{len}n=\text{len}{n-1} \times 2 + 1$,根据 $\text{len}_1=1$ 可知 $\text{len}_n=2^n-1$。

由于 $S_1=``0"$,且对于任意 $n \ge 1$,$S_n$ 的第 $1$ 位字符也一定是 $0'$,因此当 $k=1$ 时,直接返回字符 $0'$。

当 $n>1$ 时,$S_n$ 的长度是 $2^n-1$。$S_n$ 可以分成三个部分,左边 $2^{n-1}-1$ 个字符是 $S_{n-1}$,中间 $1$ 个字符是 $1'$,右边 $2^{n-1}-1$ 个字符是 $S_{n-1}$ 翻转与反转之后的结果。中间的字符 $1'$ 是 $S_n$ 的第 $2^{n-1}$ 位字符,因此如果 $k=2^{n-1}$,直接返回字符 $`1'$。

当 $k \ne 2^{n-1}$ 时,考虑以下两种情况:

  • 如果 $k<2^{n-1}$,则第 $k$ 位字符在 $S_n$ 的前半部分,即第 $k$ 位字符在 $S_{n-1}$ 中,因此在 $S_{n-1}$ 中寻找第 $k$ 位字符;

  • 如果 $k>2^{n-1}$,则第 $k$ 位字符在 $S_n$ 的后半部分,由于后半部分为前半部分进行翻转与反转之后的结果,因此在前半部分寻找第 $2^n-k$ 位字符,将其反转之后即为 $S_n$ 的第 $k$ 位字符。

上述过程可以通过递归实现。

###Java

class Solution {
    public char findKthBit(int n, int k) {
        if (k == 1) {
            return '0';
        }
        int mid = 1 << (n - 1);
        if (k == mid) {
            return '1';
        } else if (k < mid) {
            return findKthBit(n - 1, k);
        } else {
            k = mid * 2 - k;
            return invert(findKthBit(n - 1, k));
        }
    }

    public char invert(char bit) {
        return (char) ('0' + '1' - bit);
    }
}

###JavaScript

const invert = (bit) => bit === '0' ? '1' : '0';

var findKthBit = function(n, k) {
    if (k == 1) {
        return '0';
    }
    const mid = 1 << (n - 1);
    if (k == mid) {
        return '1';
    } else if (k < mid) {
        return findKthBit(n - 1, k);
    } else {
        k = mid * 2 - k;
        return invert(findKthBit(n - 1, k));
    }
};

###C++

class Solution {
public:
    char findKthBit(int n, int k) {
        if (k == 1) {
            return '0';
        }
        int mid = 1 << (n - 1);
        if (k == mid) {
            return '1';
        } else if (k < mid) {
            return findKthBit(n - 1, k);
        } else {
            k = mid * 2 - k;
            return invert(findKthBit(n - 1, k));
        }
    }

    char invert(char bit) {
        return (char) ('0' + '1' - bit);
    }
};

###Python

class Solution:
    def findKthBit(self, n: int, k: int) -> str:
        if k == 1:
            return "0"
        
        mid = 1 << (n - 1)
        if k == mid:
            return "1"
        elif k < mid:
            return self.findKthBit(n - 1, k)
        else:
            k = mid * 2 - k
            return "0" if self.findKthBit(n - 1, k) == "1" else "1"

###C#

public class Solution {
    public char FindKthBit(int n, int k) {
        if (k == 1) {
            return '0';
        }
        int mid = 1 << (n - 1);
        if (k == mid) {
            return '1';
        } else if (k < mid) {
            return FindKthBit(n - 1, k);
        } else {
            k = mid * 2 - k;
            return Invert(FindKthBit(n - 1, k));
        }
    }

    private char Invert(char bit) {
        return (char)('0' + '1' - bit);
    }
}

###Go

func findKthBit(n int, k int) byte {
    if k == 1 {
        return '0'
    }
    mid := 1 << (n - 1)
    if k == mid {
        return '1'
    } else if k < mid {
        return findKthBit(n - 1, k)
    } else {
        k = mid*2 - k
        return invert(findKthBit(n - 1, k))
    }
}

func invert(bit byte) byte {
    if bit == '0' {
        return '1'
    }
    return '0'
}

###C

char invert(char bit) {
    return '0' + '1' - bit;
}

char findKthBit(int n, int k) {
    if (k == 1) {
        return '0';
    }
    int mid = 1 << (n - 1);
    if (k == mid) {
        return '1';
    } else if (k < mid) {
        return findKthBit(n - 1, k);
    } else {
        k = mid * 2 - k;
        return invert(findKthBit(n - 1, k));
    }
}

###TypeScript

function findKthBit(n: number, k: number): string {
    if (k === 1) {
        return '0';
    }
    const mid = 1 << (n - 1);
    if (k === mid) {
        return '1';
    } else if (k < mid) {
        return findKthBit(n - 1, k);
    } else {
        k = mid * 2 - k;
        return invert(findKthBit(n - 1, k));
    }
}

function invert(bit: string): string {
    return bit === '0' ? '1' : '0';
}

###Rust

impl Solution {
    pub fn find_kth_bit(n: i32, k: i32) -> char {
        Self::find_kth_bit_recursive(n, k)
    }
    
    fn find_kth_bit_recursive(n: i32, k: i32) -> char {
        if k == 1 {
            return '0';
        }
        let mid = 1 << (n - 1);
        if k == mid {
            return '1';
        } else if k < mid {
            return Self::find_kth_bit_recursive(n - 1, k);
        } else {
            let new_k = mid * 2 - k;
            return Self::invert(Self::find_kth_bit_recursive(n - 1, new_k));
        }
    }
    
    fn invert(bit: char) -> char {
        if bit == '0' {
            '1'
        } else {
            '0'
        }
    }
}

复杂度分析

  • 时间复杂度:$O(n)$。字符串 $S_n$ 的长度为 $2^n-1$,每次递归调用可以将查找范围缩小一半,因此时间复杂度为 $O(\log 2^n)=O(n)$。

  • 空间复杂度:$O(n)$。空间复杂度主要取决于递归调用产生的栈空间,递归调用层数不会超过 $n$。

从根到叶的二进制数之和

2022年5月27日 19:47

前言

关于二叉树后序遍历的详细说明请参考「145. 二叉树的后序遍历的官方题解」。

方法一:递归

后序遍历的访问顺序为:左子树——右子树——根节点。我们对根节点 $\textit{root}$ 进行后序遍历:

  • 如果节点是叶子节点,返回它对应的数字 $\textit{val}$。

  • 如果节点是非叶子节点,返回它的左子树和右子树对应的结果之和。

###Python

class Solution:
    def sumRootToLeaf(self, root: Optional[TreeNode]) -> int:
        def dfs(node: Optional[TreeNode], val: int) -> int:
            if node is None:
                return 0
            val = (val << 1) | node.val
            if node.left is None and node.right is None:
                return val
            return dfs(node.left, val) + dfs(node.right, val)
        return dfs(root, 0)

###C++

class Solution {
public:
    int dfs(TreeNode *root, int val) {
        if (root == nullptr) {
            return 0;
        }
        val = (val << 1) | root->val;
        if (root->left == nullptr && root->right == nullptr) {
            return val;
        }
        return dfs(root->left, val) + dfs(root->right, val);
    }

    int sumRootToLeaf(TreeNode* root) {
        return dfs(root, 0);
    }
};

###Java

class Solution {
    public int sumRootToLeaf(TreeNode root) {
        return dfs(root, 0);
    }

    public int dfs(TreeNode root, int val) {
        if (root == null) {
            return 0;
        }
        val = (val << 1) | root.val;
        if (root.left == null && root.right == null) {
            return val;
        }
        return dfs(root.left, val) + dfs(root.right, val);
    }
}

###C#

public class Solution {
    public int SumRootToLeaf(TreeNode root) {
        return DFS(root, 0);
    }

    public int DFS(TreeNode root, int val) {
        if (root == null) {
            return 0;
        }
        val = (val << 1) | root.val;
        if (root.left == null && root.right == null) {
            return val;
        }
        return DFS(root.left, val) + DFS(root.right, val);
    }
}

###C

int dfs(struct TreeNode *root, int val) {
    if (root == NULL) {
        return 0;
    }
    val = (val << 1) | root->val;
    if (root->left == NULL && root->right == NULL) {
        return val;
    }
    return dfs(root->left, val) + dfs(root->right, val);
}

int sumRootToLeaf(struct TreeNode* root){
    return dfs(root, 0);
}

###JavaScript

var sumRootToLeaf = function(root) {
    const dfs = (root, val) => {
        if (!root) {
            return 0;
        }
        val = (val << 1) | root.val;
        if (!root.left&& !root.right) {
            return val;
        }
        return dfs(root.left, val) + dfs(root.right, val);
    }
    return dfs(root, 0);
};

###go

func dfs(node *TreeNode, val int) int {
    if node == nil {
        return 0
    }
    val = val<<1 | node.Val
    if node.Left == nil && node.Right == nil {
        return val
    }
    return dfs(node.Left, val) + dfs(node.Right, val)
}

func sumRootToLeaf(root *TreeNode) int {
    return dfs(root, 0)
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 是节点数目。总共访问 $n$ 个节点。

  • 空间复杂度:$O(n)$。递归栈需要 $O(n)$ 的空间。

方法二:迭代

我们用栈来模拟递归,同时使用一个 $\textit{prev}$ 指针来记录先前访问的节点。算法步骤如下:

  1. 如果节点 $\textit{root}$ 非空,我们将不断地将它及它的左节点压入栈中。

  2. 我们从栈中获取节点:

    • 该节点的右节点为空或者等于 $\textit{prev}$,说明该节点的左子树及右子树都已经被访问,我们将它出栈。如果该节点是叶子节点,我们将它对应的数字 $\textit{val}$ 加入结果中。设置 $\textit{prev}$ 为该节点,设置 $\textit{root}$ 为空指针。

    • 该节点的右节点非空且不等于 $\textit{prev}$,我们令 $\textit{root}$ 指向该节点的右节点。

  3. 如果 $\textit{root}$ 为空指针或者栈空,中止算法,否则重复步骤 $1$。

需要注意的是,每次出入栈都需要更新 $\textit{val}$。

###Python

class Solution:
    def sumRootToLeaf(self, root: Optional[TreeNode]) -> int:
        ans = val = 0
        st = []
        pre = None
        while root or st:
            while root:
                val = (val << 1) | root.val
                st.append(root)
                root = root.left
            root = st[-1]
            if root.right is None or root.right == pre:
                if root.left is None and root.right is None:
                    ans += val
                val >>= 1
                st.pop()
                pre = root
                root = None
            else:
                root = root.right
        return ans

###C++

class Solution {
public:
    int sumRootToLeaf(TreeNode* root) {
        stack<TreeNode *> st;
        int val = 0, ret = 0;
        TreeNode *prev = nullptr;
        while (root != nullptr || !st.empty()) {
            while (root != nullptr) {
                val = (val << 1) | root->val;
                st.push(root);
                root = root->left;
            }
            root = st.top();
            if (root->right == nullptr || root->right == prev) {
                if (root->left == nullptr && root->right == nullptr) {
                    ret += val;
                }
                val >>= 1;
                st.pop();
                prev = root;
                root = nullptr;
            } else {
                root = root->right;
            }
        }
        return ret;
    }
};

###Java

class Solution {
    public int sumRootToLeaf(TreeNode root) {
        Deque<TreeNode> stack = new ArrayDeque<TreeNode>();
        int val = 0, ret = 0;
        TreeNode prev = null;
        while (root != null || !stack.isEmpty()) {
            while (root != null) {
                val = (val << 1) | root.val;
                stack.push(root);
                root = root.left;
            }
            root = stack.peek();
            if (root.right == null || root.right == prev) {
                if (root.left == null && root.right == null) {
                    ret += val;
                }
                val >>= 1;
                stack.pop();
                prev = root;
                root = null;
            } else {
                root = root.right;
            }
        }
        return ret;
    }
}

###C#

public class Solution {
    public int SumRootToLeaf(TreeNode root) {
        Stack<TreeNode> stack = new Stack<TreeNode>();
        int val = 0, ret = 0;
        TreeNode prev = null;
        while (root != null || stack.Count > 0) {
            while (root != null) {
                val = (val << 1) | root.val;
                stack.Push(root);
                root = root.left;
            }
            root = stack.Peek();
            if (root.right == null || root.right == prev) {
                if (root.left == null && root.right == null) {
                    ret += val;
                }
                val >>= 1;
                stack.Pop();
                prev = root;
                root = null;
            } else {
                root = root.right;
            }
        }
        return ret;
    }
}

###C

#define MAX_NODE_SIZE 1000

int sumRootToLeaf(struct TreeNode* root) {
    struct TreeNode ** stack = (struct TreeNode **)malloc(sizeof(struct TreeNode *) * MAX_NODE_SIZE);
    int top = 0;
    int val = 0, ret = 0;
    struct TreeNode *prev = NULL;
    while (root != NULL || top) {
        while (root != NULL) {
            val = (val << 1) | root->val;
            stack[top++] = root;
            root = root->left;
        }
        root = stack[top - 1];
        if (root->right == NULL || root->right == prev) {
            if (root->left == NULL && root->right == NULL) {
                ret += val;
            }
            val >>= 1;
            top--;
            prev = root;
            root = NULL;
        } else {
            root = root->right;
        }
    }
    free(stack);
    return ret;
}

###JavaScript

var sumRootToLeaf = function(root) {
    const stack = [];
    let val = 0, ret = 0;
    let prev = null;
    while (root || stack.length) {
        while (root) {
            val = (val << 1) | root.val;
            stack.push(root);
            root = root.left;
        }
        root = stack[stack.length - 1];
        if (!root.right || root.right === prev) {
            if (!root.left && !root.right) {
                ret += val;
            }
            val >>= 1;
            stack.pop();
            prev = root;
            root = null;
        } else {
            root = root.right;
        }
    }
    return ret;
};

###go

func sumRootToLeaf(root *TreeNode) (ans int) {
    val, st := 0, []*TreeNode{}
    var pre *TreeNode
    for root != nil || len(st) > 0 {
        for root != nil {
            val = val<<1 | root.Val
            st = append(st, root)
            root = root.Left
        }
        root = st[len(st)-1]
        if root.Right == nil || root.Right == pre {
            if root.Left == nil && root.Right == nil {
                ans += val
            }
            val >>= 1
            st = st[:len(st)-1]
            pre = root
            root = nil
        } else {
            root = root.Right
        }
    }
    return
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 是节点数目。总共访问 $n$ 个节点。

  • 空间复杂度:$O(n)$。栈最多压入 $n$ 个节点。

检查一个字符串是否包含所有长度为 K 的二进制子串

2020年12月12日 20:47

方法一:哈希表

我们遍历字符串 $s$,并用一个哈希集合(HashSet)存储所有长度为 $k$ 的子串。在遍历完成后,只需要判断哈希集合中是否有 $2^k$ 项即可,这是因为长度为 $k$ 的二进制串的数量为 $2^k$。

注意到如果 $s$ 包含 $2^k$ 个长度为 $k$ 的二进制串,那么它的长度至少为 $2^k+k-1$。因此我们可以在遍历前判断 $s$ 是否足够长。

###C++

class Solution {
public:
    bool hasAllCodes(string s, int k) {
        if (s.size() < (1 << k) + k - 1) {
            return false;
        }

        unordered_set<string> exists;
        for (int i = 0; i + k <= s.size(); ++i) {
            exists.insert(move(s.substr(i, k)));
        }
        return exists.size() == (1 << k);
    }
};

###C++

class Solution {
public:
    bool hasAllCodes(string s, int k) {
        if (s.size() < (1 << k) + k - 1) {
            return false;
        }

        string_view sv(s);
        unordered_set<string_view> exists;
        for (int i = 0; i + k <= s.size(); ++i) {
            exists.insert(sv.substr(i, k));
        }
        return exists.size() == (1 << k);
    }
};

###Python

class Solution:
    def hasAllCodes(self, s: str, k: int) -> bool:
        if len(s) < (1 << k) + k - 1:
            return False
        
        exists = set(s[i:i+k] for i in range(len(s) - k + 1))
        return len(exists) == (1 << k)

###Java

class Solution {
    public boolean hasAllCodes(String s, int k) {
        if (s.length() < (1 << k) + k - 1) {
            return false;
        }

        Set<String> exists = new HashSet<String>();
        for (int i = 0; i + k <= s.length(); ++i) {
            exists.add(s.substring(i, i + k));
        }
        return exists.size() == (1 << k);
    }
}

###C#

public class Solution {
    public bool HasAllCodes(string s, int k) {
        if (s.Length < (1 << k) + k - 1) {
            return false;
        }

        HashSet<string> exists = new HashSet<string>();
        for (int i = 0; i + k <= s.Length; ++i) {
            exists.Add(s.Substring(i, k));
        }
        return exists.Count == (1 << k);
    }
}

###Go

func hasAllCodes(s string, k int) bool {
    if len(s) < (1 << k) + k - 1 {
        return false
    }

    exists := make(map[string]bool)
    for i := 0; i + k <= len(s); i++ {
        substring := s[i:i+k]
        exists[substring] = true
    }
    return len(exists) == (1 << k)
}

###C


typedef struct {
    char *key;
    UT_hash_handle hh;
} HashItem; 

HashItem *hashFindItem(HashItem **obj, char *key) {
    HashItem *pEntry = NULL;
    HASH_FIND_STR(*obj, key, pEntry);
    return pEntry;
}

bool hashAddItem(HashItem **obj, char *key) {
    if (hashFindItem(obj, key)) {
        return false;
    }
    HashItem *pEntry = (HashItem *)malloc(sizeof(HashItem));
    pEntry->key = strdup(key);
    HASH_ADD_STR(*obj, key, pEntry);
    return true;
}

void hashFree(HashItem **obj) {
    HashItem *curr = NULL, *tmp = NULL;
    HASH_ITER(hh, *obj, curr, tmp) {
        HASH_DEL(*obj, curr); 
        free(curr->key); 
        free(curr);             
    }
}

bool hasAllCodes(char* s, int k) {
    int len = strlen(s);
    int total = 1 << k;
    if (len < total + k - 1) {
        return false;
    }

    HashItem *exists = NULL;
    for (int i = 0; i + k <= len; ++i) {
        char tmp[k + 1];
        strncpy(tmp, s + i, k);
        tmp[k] = '\0';
        hashAddItem(&exists, tmp);
    }

    bool ret = HASH_COUNT(exists) == (1 << k);
    hashFree(&exists);
    return ret;
}

###JavaScript

var hasAllCodes = function(s, k) {
    if (s.length < (1 << k) + k - 1) {
        return false;
    }

    const exists = new Set();
    for (let i = 0; i + k <= s.length; ++i) {
        exists.add(s.substring(i, i + k));
    }
    return exists.size === (1 << k);
};

###TypeScript

function hasAllCodes(s: string, k: number): boolean {
    if (s.length < (1 << k) + k - 1) {
        return false;
    }

    const exists = new Set<string>();
    for (let i = 0; i + k <= s.length; ++i) {
        exists.add(s.substring(i, i + k));
    }
    return exists.size === (1 << k);
}

###Rust

use std::collections::HashSet;

impl Solution {
    pub fn has_all_codes(s: String, k: i32) -> bool {
        let k = k as usize;
        let total = 1 << k;
        
        if s.len() < total + k - 1 {
            return false;
        }

        let mut exists = HashSet::new();
        for i in 0..=(s.len() - k) {
            exists.insert(&s[i..i + k]);
        }
        exists.len() == total
    }
}

复杂度分析

  • 时间复杂度:$O(k * |s|)$,其中 $|s|$ 是字符串 $s$ 的长度。将长度为 $k$ 的字符串加入哈希集合的时间复杂度为 $O(k)$,即为计算哈希值的时间。

  • 空间复杂度:$O(k * 2^k)$。哈希集合中最多有 $2^k$ 项,每一项是一个长度为 $k$ 的字符串。

方法二:哈希表 + 滑动窗口

我们可以借助滑动窗口,对方法一进行优化。

假设我们当前遍历到的长度为 $k$ 的子串为

$$
s_i, s_{i+1}, \cdots, s_{i+k-1}
$$

它的下一个子串为

$$
s_{i+1}, s_{i+2}, \cdots, s_{i+k}
$$

由于这些子串都是二进制串,我们可以将其表示成对应的十进制整数的形式,即

$$
\begin{aligned}
& \textit{num}i &= s_i * 2^{k-1} + s{i+1} * 2^{k-2} + \cdots + s_{i+k-1} * 2^0 \
& \textit{num}{i+1} &= s{i+1} * 2^{k-1} + s_{i+2} * 2^{k-2} + \cdots + s_{i+k} * 2^0 \
\end{aligned}
$$

那么我们可以将这些十进制整数作为哈希表中的项。由于每一个长度为 $k$ 的二进制串都唯一对应了一个十进制整数,因此这样做与方法一是一致的。与二进制串本身不同的是,我们可以在 $O(1)$ 的时间内通过 $\textit{num}i$ 得到 $\textit{num}{i+1}$,即:

$$
num_{i+1} = (num_{i} - s_i * 2^{k-1}) * 2 + s_{i+k}
$$

这样以来,我们在遍历 $s$ 的过程中只维护子串对应的十进制整数,而不需要对字符串进行操作,从而减少了时间复杂度。

###C++

class Solution {
public:
    bool hasAllCodes(string s, int k) {
        if (s.size() < (1 << k) + k - 1) {
            return false;
        }

        int num = stoi(s.substr(0, k), nullptr, 2);
        unordered_set<int> exists = {num};
        
        for (int i = 1; i + k <= s.size(); ++i) {
            num = (num - ((s[i - 1] - '0') << (k - 1))) * 2 + (s[i + k - 1] - '0');
            exists.insert(num);
        }
        return exists.size() == (1 << k);
    }
};

###Python

class Solution:
    def hasAllCodes(self, s: str, k: int) -> bool:
        if len(s) < (1 << k) + k - 1:
            return False
        
        num = int(s[:k], base=2)
        exists = set([num])

        for i in range(1, len(s) - k + 1):
            num = (num - ((ord(s[i - 1]) - 48) << (k - 1))) * 2 + (ord(s[i + k - 1]) - 48)
            exists.add(num)
        
        return len(exists) == (1 << k)

###Java

class Solution {
    public boolean hasAllCodes(String s, int k) {
        if (s.length() < (1 << k) + k - 1) {
            return false;
        }

        int num = Integer.parseInt(s.substring(0, k), 2);
        Set<Integer> exists = new HashSet<Integer>();
        exists.add(num);
        
        for (int i = 1; i + k <= s.length(); ++i) {
            num = (num - ((s.charAt(i - 1) - '0') << (k - 1))) * 2 + (s.charAt(i + k - 1) - '0');
            exists.add(num);
        }
        return exists.size() == (1 << k);
    }
}

###C#

public class Solution {
    public bool HasAllCodes(string s, int k) {
        if (s.Length < (1 << k) + k - 1) {
            return false;
        }

        int num = Convert.ToInt32(s.Substring(0, k), 2);
        HashSet<int> exists = new HashSet<int> { num };
        
        for (int i = 1; i + k <= s.Length; ++i) {
            num = (num - ((s[i - 1] - '0') << (k - 1))) * 2 + (s[i + k - 1] - '0');
            exists.Add(num);
        }
        return exists.Count == (1 << k);
    }
}

###Go

func hasAllCodes(s string, k int) bool {
    if len(s) < (1 << k) + k - 1 {
        return false
    }

    num := 0
    for i := 0; i < k; i++ {
        num = num << 1
        if s[i] == '1' {
            num |= 1
        }
    }
    
    exists := make(map[int]bool)
    exists[num] = true
    for i := 1; i + k <= len(s); i++ {
        num = (num - (int(s[i-1]-'0') << (k-1))) * 2 + int(s[i+k-1]-'0')
        exists[num] = true
    }
    return len(exists) == (1 << k)
}

###C

typedef struct {
    int key;        
    UT_hash_handle hh;
} HashItem;

HashItem *hashFindItem(HashItem **obj, int key) {
    HashItem *pEntry = NULL;
    HASH_FIND_INT(*obj, &key, pEntry);
    return pEntry;
}

bool hashAddItem(HashItem **obj, int key) {
    if (hashFindItem(obj, key)) {
        return false;
    }
    HashItem *pEntry = (HashItem *)malloc(sizeof(HashItem));
    pEntry->key = key;
    HASH_ADD_INT(*obj, key, pEntry);
    return true;
}

void hashFree(HashItem **obj) {
    HashItem *curr = NULL, *tmp = NULL;
    HASH_ITER(hh, *obj, curr, tmp) {
        HASH_DEL(*obj, curr);
        free(curr);
    }
}

bool hasAllCodes(char* s, int k) {
    int len = strlen(s);
    int total = 1 << k;
    if (len < total + k - 1) {
        return false;
    }

    int num = 0;
    for (int i = 0; i < k; i++) {
        num = (num << 1) | (s[i] - '0');
    }
    
    HashItem *exists = NULL;
    hashAddItem(&exists, num);
    for (int i = k; i < len; i++) {
        int mask = (1 << k) - 1;
        num = ((num << 1) | (s[i] - '0')) & mask;
        hashAddItem(&exists, num);
    }

    bool ret = HASH_COUNT(exists) == total;
    hashFree(&exists);
    return ret;
}

###JavaScript

var hasAllCodes = function(s, k) {
    if (s.length < (1 << k) + k - 1) {
        return false;
    }

    let num = parseInt(s.substring(0, k), 2);
    const exists = new Set([num]);
    for (let i = 1; i + k <= s.length; ++i) {
        num = (num - (parseInt(s[i - 1]) << (k - 1))) * 2 + parseInt(s[i + k - 1]);
        exists.add(num);
    }
    return exists.size === (1 << k);
};

###TypeScript

function hasAllCodes(s: string, k: number): boolean {
    if (s.length < (1 << k) + k - 1) {
        return false;
    }

    let num = parseInt(s.substring(0, k), 2);
    const exists = new Set<number>([num]);
    for (let i = 1; i + k <= s.length; ++i) {
        num = (num - (parseInt(s[i - 1]) << (k - 1))) * 2 + parseInt(s[i + k - 1]);
        exists.add(num);
    }
    return exists.size === (1 << k);
}

###Rust

use std::collections::HashSet;

impl Solution {
    pub fn has_all_codes(s: String, k: i32) -> bool {
        let k = k as usize;
        let total = 1 << k;
        
        if s.len() < total + k - 1 {
            return false;
        }

        let bytes = s.as_bytes();
        let mut num = 0;
        for i in 0..k {
            num = (num << 1) | (bytes[i] - b'0') as usize;
        }

        let mut exists = HashSet::new();
        exists.insert(num);
        for i in 1..=bytes.len() - k {
            let high_bit = ((bytes[i - 1] - b'0') as usize) << (k - 1);
            num = (num - high_bit) << 1 | (bytes[i + k - 1] - b'0') as usize;
            exists.insert(num);
        }
        
        exists.len() == total
    }
}

复杂度分析

  • 时间复杂度:$O(|s|)$,其中 $|s|$ 是字符串 $s$ 的长度。

  • 空间复杂度:$O(2^k)$。哈希集合中最多有 $2^k$ 项,每一项是一个十进制整数。

二进制间距

2022年4月22日 22:07

方法一:位运算

思路与算法

我们可以使用一个循环从 $n$ 二进制表示的低位开始进行遍历,并找出所有的 $1$。我们用一个变量 $\textit{last}$ 记录上一个找到的 $1$ 的位置。如果当前在第 $i$ 位找到了 $1$,那么就用 $i - \textit{last}$ 更新答案,再将 $\textit{last}$ 更新为 $i$ 即可。

在循环的每一步中,我们可以使用位运算 $\texttt{n & 1}$ 获取 $n$ 的最低位,判断其是否为 $1$。在这之后,我们将 $n$ 右移一位:$\texttt{n = n >> 1}$,这样在第 $i$ 步时,$\texttt{n & 1}$ 得到的就是初始 $n$ 的第 $i$ 个二进制位。

代码

###Python

class Solution:
    def binaryGap(self, n: int) -> int:
        last, ans, i = -1, 0, 0
        while n:
            if n & 1:
                if last != -1:
                    ans = max(ans, i - last)
                last = i
            n >>= 1
            i += 1
        return ans

###C++

class Solution {
public:
    int binaryGap(int n) {
        int last = -1, ans = 0;
        for (int i = 0; n; ++i) {
            if (n & 1) {
                if (last != -1) {
                    ans = max(ans, i - last);
                }
                last = i;
            }
            n >>= 1;
        }
        return ans;
    }
};

###Java

class Solution {
    public int binaryGap(int n) {
        int last = -1, ans = 0;
        for (int i = 0; n != 0; ++i) {
            if ((n & 1) == 1) {
                if (last != -1) {
                    ans = Math.max(ans, i - last);
                }
                last = i;
            }
            n >>= 1;
        }
        return ans;
    }
}

###C#

public class Solution {
    public int BinaryGap(int n) {
        int last = -1, ans = 0;
        for (int i = 0; n != 0; ++i) {
            if ((n & 1) == 1) {
                if (last != -1) {
                    ans = Math.Max(ans, i - last);
                }
                last = i;
            }
            n >>= 1;
        }
        return ans;
    }
}

###C

#define MAX(a, b) ((a) > (b) ? (a) : (b))

int binaryGap(int n) {
    int last = -1, ans = 0;
    for (int i = 0; n; ++i) {
        if (n & 1) {
            if (last != -1) {
                ans = MAX(ans, i - last);
            }
            last = i;
        }
        n >>= 1;
    }
    return ans;
}

###go

func binaryGap(n int) (ans int) {
    for i, last := 0, -1; n > 0; i++ {
        if n&1 == 1 {
            if last != -1 {
                ans = max(ans, i-last)
            }
            last = i
        }
        n >>= 1
    }
    return
}

func max(a, b int) int {
    if b > a {
        return b
    }
    return a
}

###JavaScript

var binaryGap = function(n) {
    let last = -1, ans = 0;
    for (let i = 0; n != 0; ++i) {
        if ((n & 1) === 1) {
            if (last !== -1) {
                ans = Math.max(ans, i - last);
            }
            last = i;
        }
        n >>= 1;
    }
    return ans;
};

复杂度分析

  • 时间复杂度:$O(\log n)$。循环中的每一步 $n$ 会减少一半,因此需要 $O(\log n)$ 次循环。

  • 空间复杂度:$O(1)$。

二进制表示中质数个计算置位

2022年4月2日 18:44

方法一:数学 + 位运算

我们可以枚举 $[\textit{left},\textit{right}]$ 范围内的每个整数,挨个判断是否满足题目要求。

对于每个数 $x$,我们需要解决两个问题:

  1. 如何求出 $x$ 的二进制中的 $1$ 的个数,见「191. 位 1 的个数」,下面代码用库函数实现;
  2. 如何判断一个数是否为质数,见「204. 计数质数」的「官方解法」的方法一(注意 $0$ 和 $1$ 不是质数)。

###Python

class Solution:
    def isPrime(self, x: int) -> bool:
        if x < 2:
            return False
        i = 2
        while i * i <= x:
            if x % i == 0:
                return False
            i += 1
        return True

    def countPrimeSetBits(self, left: int, right: int) -> int:
        return sum(self.isPrime(x.bit_count()) for x in range(left, right + 1))

###C++

class Solution {
    bool isPrime(int x) {
        if (x < 2) {
            return false;
        }
        for (int i = 2; i * i <= x; ++i) {
            if (x % i == 0) {
                return false;
            }
        }
        return true;
    }

public:
    int countPrimeSetBits(int left, int right) {
        int ans = 0;
        for (int x = left; x <= right; ++x) {
            if (isPrime(__builtin_popcount(x))) {
                ++ans;
            }
        }
        return ans;
    }
};

###Java

class Solution {
    public int countPrimeSetBits(int left, int right) {
        int ans = 0;
        for (int x = left; x <= right; ++x) {
            if (isPrime(Integer.bitCount(x))) {
                ++ans;
            }
        }
        return ans;
    }

    private boolean isPrime(int x) {
        if (x < 2) {
            return false;
        }
        for (int i = 2; i * i <= x; ++i) {
            if (x % i == 0) {
                return false;
            }
        }
        return true;
    }
}

###C#

public class Solution {
    public int CountPrimeSetBits(int left, int right) {
        int ans = 0;
        for (int x = left; x <= right; ++x) {
            if (IsPrime(BitCount(x))) {
                ++ans;
            }
        }
        return ans;
    }

    private bool IsPrime(int x) {
        if (x < 2) {
            return false;
        }
        for (int i = 2; i * i <= x; ++i) {
            if (x % i == 0) {
                return false;
            }
        }
        return true;
    }

    private static int BitCount(int i) {
        i = i - ((i >> 1) & 0x55555555);
        i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
        i = (i + (i >> 4)) & 0x0f0f0f0f;
        i = i + (i >> 8);
        i = i + (i >> 16);
        return i & 0x3f;
    }
}

###go

func isPrime(x int) bool {
    if x < 2 {
        return false
    }
    for i := 2; i*i <= x; i++ {
        if x%i == 0 {
            return false
        }
    }
    return true
}

func countPrimeSetBits(left, right int) (ans int) {
    for x := left; x <= right; x++ {
        if isPrime(bits.OnesCount(uint(x))) {
            ans++
        }
    }
    return
}

###C

bool isPrime(int x) {
    if (x < 2) {
        return false;
    }
    for (int i = 2; i * i <= x; ++i) {
        if (x % i == 0) {
            return false;
        }
    }
    return true;
}

int countPrimeSetBits(int left, int right){
    int ans = 0;
    for (int x = left; x <= right; ++x) {
        if (isPrime(__builtin_popcount(x))) {
            ++ans;
        }
    }
    return ans;
}

###JavaScript

var countPrimeSetBits = function(left, right) {
    let ans = 0;
    for (let x = left; x <= right; ++x) {
        if (isPrime(bitCount(x))) {
            ++ans;
        }
    }
    return ans;
};

const isPrime = (x) => {
    if (x < 2) {
        return false;
    }
    for (let i = 2; i * i <= x; ++i) {
        if (x % i === 0) {
            return false;
        }
    }
    return true;
}

const bitCount = (x) => {
    return x.toString(2).split('0').join('').length;
}

复杂度分析

  • 时间复杂度:$O((\textit{right}-\textit{left})\sqrt{\log\textit{right}})$。二进制中 $1$ 的个数为 $O(\log\textit{right})$,判断值为 $x$ 的数是否为质数的时间为 $O(\sqrt{x})$。

  • 空间复杂度:$O(1)$。我们只需要常数的空间保存若干变量。

方法二:判断质数优化

注意到 $\textit{right} \le 10^6 < 2^{20}$,因此二进制中 $1$ 的个数不会超过 $19$,而不超过 $19$ 的质数只有

$$
2, 3, 5, 7, 11, 13, 17, 19
$$

我们可以用一个二进制数 $\textit{mask}=665772=10100010100010101100_{2}$ 来存储这些质数,其中 $\textit{mask}$ 二进制的从低到高的第 $i$ 位为 $1$ 表示 $i$ 是质数,为 $0$ 表示 $i$ 不是质数。

设整数 $x$ 的二进制中 $1$ 的个数为 $c$,若 $\textit{mask}$ 按位与 $2^c$ 不为 $0$,则说明 $c$ 是一个质数。

###Python

class Solution:
    def countPrimeSetBits(self, left: int, right: int) -> int:
        return sum(((1 << x.bit_count()) & 665772) != 0 for x in range(left, right + 1))

###C++

class Solution {
public:
    int countPrimeSetBits(int left, int right) {
        int ans = 0;
        for (int x = left; x <= right; ++x) {
            if ((1 << __builtin_popcount(x)) & 665772) {
                ++ans;
            }
        }
        return ans;
    }
};

###Java

class Solution {
    public int countPrimeSetBits(int left, int right) {
        int ans = 0;
        for (int x = left; x <= right; ++x) {
            if (((1 << Integer.bitCount(x)) & 665772) != 0) {
                ++ans;
            }
        }
        return ans;
    }
}

###C#

public class Solution {
    public int CountPrimeSetBits(int left, int right) {
        int ans = 0;
        for (int x = left; x <= right; ++x) {
            if (((1 << BitCount(x)) & 665772) != 0) {
                ++ans;
            }
        }
        return ans;
    }

    private static int BitCount(int i) {
        i = i - ((i >> 1) & 0x55555555);
        i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
        i = (i + (i >> 4)) & 0x0f0f0f0f;
        i = i + (i >> 8);
        i = i + (i >> 16);
        return i & 0x3f;
    }
}

###go

func countPrimeSetBits(left, right int) (ans int) {
    for x := left; x <= right; x++ {
        if 1<<bits.OnesCount(uint(x))&665772 != 0 {
            ans++
        }
    }
    return
}

###C

int countPrimeSetBits(int left, int right){
    int ans = 0;
    for (int x = left; x <= right; ++x) {
        if ((1 << __builtin_popcount(x)) & 665772) {
            ++ans;
        }
    }
    return ans;
}

###JavaScript

var countPrimeSetBits = function(left, right) {
    let ans = 0;
    for (let x = left; x <= right; ++x) {
        if (((1 << bitCount(x)) & 665772) != 0) {
            ++ans;
        }
    }
    return ans;
};

const bitCount = (x) => {
    return x.toString(2).split('0').join('').length;
}

复杂度分析

  • 时间复杂度:$O(\textit{right}-\textit{left})$。

  • 空间复杂度:$O(1)$。我们只需要常数的空间保存若干变量。

计数二进制子串

2020年8月9日 21:31

方法一:按字符分组

思路与算法

我们可以将字符串 $s$ 按照 $0$ 和 $1$ 的连续段分组,存在 $\textit{counts}$ 数组中,例如 $s = 00111011$,可以得到这样的 $\textit{counts}$ 数组:$\textit{counts} = {2, 3, 1, 2}$。

这里 $\textit{counts}$ 数组中两个相邻的数一定代表的是两种不同的字符。假设 $\textit{counts}$ 数组中两个相邻的数字为 $u$ 或者 $v$,它们对应着 $u$ 个 $0$ 和 $v$ 个 $1$,或者 $u$ 个 $1$ 和 $v$ 个 $0$。它们能组成的满足条件的子串数目为 $\min { u, v }$,即一对相邻的数字对答案的贡献。

我们只要遍历所有相邻的数对,求它们的贡献总和,即可得到答案。

不难得到这样的实现:

###C++

class Solution {
public:
    int countBinarySubstrings(string s) {
        vector<int> counts;
        int ptr = 0, n = s.size();
        while (ptr < n) {
            char c = s[ptr];
            int count = 0;
            while (ptr < n && s[ptr] == c) {
                ++ptr;
                ++count;
            }
            counts.push_back(count);
        }
        int ans = 0;
        for (int i = 1; i < counts.size(); ++i) {
            ans += min(counts[i], counts[i - 1]);
        }
        return ans;
    }
};

###Java

class Solution {
    public int countBinarySubstrings(String s) {
        List<Integer> counts = new ArrayList<Integer>();
        int ptr = 0, n = s.length();
        while (ptr < n) {
            char c = s.charAt(ptr);
            int count = 0;
            while (ptr < n && s.charAt(ptr) == c) {
                ++ptr;
                ++count;
            }
            counts.add(count);
        }
        int ans = 0;
        for (int i = 1; i < counts.size(); ++i) {
            ans += Math.min(counts.get(i), counts.get(i - 1));
        }
        return ans;
    }
}

###JavaScript

var countBinarySubstrings = function(s) {
    const counts = [];
    let ptr = 0, n = s.length;
    while (ptr < n) {
        const c = s.charAt(ptr);
        let count = 0;
        while (ptr < n && s.charAt(ptr) === c) {
            ++ptr;
            ++count;
        }
        counts.push(count);
    }
    let ans = 0;
    for (let i = 1; i < counts.length; ++i) {
        ans += Math.min(counts[i], counts[i - 1]);
    }
    return ans;
};

###Go

func countBinarySubstrings(s string) int {
    counts := []int{}
    ptr, n := 0, len(s)
    for ptr < n {
        c := s[ptr]
        count := 0
        for ptr < n && s[ptr] == c {
            ptr++
            count++
        }
        counts = append(counts, count)
    }
    ans := 0
    for i := 1; i < len(counts); i++ {
        ans += min(counts[i], counts[i-1])
    }
    return ans
}

###C

int countBinarySubstrings(char* s) {
    int n = strlen(s);
    int counts[n], counts_len = 0;
    memset(counts, 0, sizeof(counts));
    int ptr = 0;
    while (ptr < n) {
        char c = s[ptr];
        int count = 0;
        while (ptr < n && s[ptr] == c) {
            ++ptr;
            ++count;
        }
        counts[counts_len++] = count;
    }
    int ans = 0;
    for (int i = 1; i < counts_len; ++i) {
        ans += fmin(counts[i], counts[i - 1]);
    }
    return ans;
}

###Python

class Solution:
    def countBinarySubstrings(self, s: str) -> int:
        counts = []
        ptr, n = 0, len(s)
        
        while ptr < n:
            c = s[ptr]
            count = 0
            while ptr < n and s[ptr] == c:
                ptr += 1
                count += 1
            counts.append(count)
        
        ans = 0
        for i in range(1, len(counts)):
            ans += min(counts[i], counts[i - 1])
        
        return ans

###C#

public class Solution {
    public int CountBinarySubstrings(string s) {
        List<int> counts = new List<int>();
        int ptr = 0, n = s.Length;
        
        while (ptr < n) {
            char c = s[ptr];
            int count = 0;
            while (ptr < n && s[ptr] == c) {
                ptr++;
                count++;
            }
            counts.Add(count);
        }
        
        int ans = 0;
        for (int i = 1; i < counts.Count; i++) {
            ans += Math.Min(counts[i], counts[i - 1]);
        }
        
        return ans;
    }
}

###TypeScript

function countBinarySubstrings(s: string): number {
    const counts: number[] = [];
    let ptr = 0, n = s.length;
    
    while (ptr < n) {
        const c = s[ptr];
        let count = 0;
        while (ptr < n && s[ptr] === c) {
            ptr++;
            count++;
        }
        counts.push(count);
    }
    
    let ans = 0;
    for (let i = 1; i < counts.length; i++) {
        ans += Math.min(counts[i], counts[i - 1]);
    }
    
    return ans;
}

###Rust

impl Solution {
    pub fn count_binary_substrings(s: String) -> i32 {
        let mut counts = Vec::new();
        let bytes = s.as_bytes();
        let n = bytes.len();
        let mut ptr = 0;
        
        while ptr < n {
            let c = bytes[ptr];
            let mut count = 0;
            while ptr < n && bytes[ptr] == c {
                ptr += 1;
                count += 1;
            }
            counts.push(count);
        }
        
        let mut ans = 0;
        for i in 1..counts.len() {
            ans += counts[i].min(counts[i - 1]);
        }
        
        ans
    }
}

这个实现的时间复杂度和空间复杂度都是 $O(n)$。

对于某一个位置 $i$,其实我们只关心 $i - 1$ 位置的 $\textit{counts}$ 值是多少,所以可以用一个 $\textit{last}$ 变量来维护当前位置的前一个位置,这样可以省去一个 $\textit{counts}$ 数组的空间。

代码

###C++

class Solution {
public:
    int countBinarySubstrings(string s) {
        int ptr = 0, n = s.size(), last = 0, ans = 0;
        while (ptr < n) {
            char c = s[ptr];
            int count = 0;
            while (ptr < n && s[ptr] == c) {
                ++ptr;
                ++count;
            }
            ans += min(count, last);
            last = count;
        }
        return ans;
    }
};

###Java

class Solution {
    public int countBinarySubstrings(String s) {
        int ptr = 0, n = s.length(), last = 0, ans = 0;
        while (ptr < n) {
            char c = s.charAt(ptr);
            int count = 0;
            while (ptr < n && s.charAt(ptr) == c) {
                ++ptr;
                ++count;
            }
            ans += Math.min(count, last);
            last = count;
        }
        return ans;
    }
}

###JavaScript

var countBinarySubstrings = function(s) {
    let ptr = 0, n = s.length, last = 0, ans = 0;
    while (ptr < n) {
        const c = s.charAt(ptr);
        let count = 0;
        while (ptr < n && s.charAt(ptr) === c) {
            ++ptr;
            ++count;
        }
        ans += Math.min(count, last);
        last = count;
    }
    return ans;
};

###Go

func countBinarySubstrings(s string) int {
    var ptr, last, ans int
    n := len(s)
    for ptr < n {
        c := s[ptr]
        count := 0
        for ptr < n && s[ptr] == c {
            ptr++
            count++
        }
        ans += min(count, last)
        last = count
    }

    return ans
}

###C

int countBinarySubstrings(char* s) {
    int ptr = 0, n = strlen(s), last = 0, ans = 0;
    while (ptr < n) {
        char c = s[ptr];
        int count = 0;
        while (ptr < n && s[ptr] == c) {
            ++ptr;
            ++count;
        }
        ans += fmin(count, last);
        last = count;
    }
    return ans;
}

###Python

class Solution:
    def countBinarySubstrings(self, s: str) -> int:
        ptr, n = 0, len(s)
        last, ans = 0, 0
        
        while ptr < n:
            c = s[ptr]
            count = 0
            while ptr < n and s[ptr] == c:
                ptr += 1
                count += 1
            ans += min(count, last)
            last = count
        
        return ans

###C#

public class Solution {
    public int CountBinarySubstrings(string s) {
        int ptr = 0, n = s.Length;
        int last = 0, ans = 0;
        
        while (ptr < n) {
            char c = s[ptr];
            int count = 0;
            while (ptr < n && s[ptr] == c) {
                ptr++;
                count++;
            }
            ans += Math.Min(count, last);
            last = count;
        }
        
        return ans;
    }
}

###TypeScript

function countBinarySubstrings(s: string): number {
    let ptr = 0, n = s.length;
    let last = 0, ans = 0;
    
    while (ptr < n) {
        const c = s[ptr];
        let count = 0;
        while (ptr < n && s[ptr] === c) {
            ptr++;
            count++;
        }
        ans += Math.min(count, last);
        last = count;
    }
    
    return ans;
}

###Rust

impl Solution {
    pub fn count_binary_substrings(s: String) -> i32 {
        let bytes = s.as_bytes();
        let n = bytes.len();
        let mut ptr = 0;
        let mut last = 0;
        let mut ans = 0;
        
        while ptr < n {
            let c = bytes[ptr];
            let mut count = 0;
            while ptr < n && bytes[ptr] == c {
                ptr += 1;
                count += 1;
            }
            ans += count.min(last);
            last = count;
        }
        
        ans
    }
}

复杂度分析

  • 时间复杂度:$O(n)$。
  • 空间复杂度:$O(1)$。

交替位二进制数

2022年3月26日 09:47

方法一:模拟

思路

从最低位至最高位,我们用对 $2$ 取模再除以 $2$ 的方法,依次求出输入的二进制表示的每一位,并与前一位进行比较。如果相同,则不符合条件;如果每次比较都不相同,则符合条件。

代码

###Python

class Solution:
    def hasAlternatingBits(self, n: int) -> bool:
        prev = 2
        while n:
            cur = n % 2
            if cur == prev:
                return False
            prev = cur
            n //= 2
        return True

###Java

class Solution {
    public boolean hasAlternatingBits(int n) {
        int prev = 2;
        while (n != 0) {
            int cur = n % 2;
            if (cur == prev) {
                return false;
            }
            prev = cur;
            n /= 2;
        }
        return true;
    }
}

###C#

public class Solution {
    public bool HasAlternatingBits(int n) {
        int prev = 2;
        while (n != 0) {
            int cur = n % 2;
            if (cur == prev) {
                return false;
            }
            prev = cur;
            n /= 2;
        }
        return true;
    }
}

###C++

class Solution {
public:
    bool hasAlternatingBits(int n) {
        int prev = 2;
        while (n != 0) {
            int cur = n % 2;
            if (cur == prev) {
                return false;
            }
            prev = cur;
            n /= 2;
        }
        return true;
    }
};

###C

bool hasAlternatingBits(int n) {
    int prev = 2;
    while (n != 0) {
        int cur = n % 2;
        if (cur == prev) {
            return false;
        }
        prev = cur;
        n /= 2;
    }
    return true;
} 

###go

func hasAlternatingBits(n int) bool {
    for pre := 2; n != 0; n /= 2 {
        cur := n % 2
        if cur == pre {
            return false
        }
        pre = cur
    }
    return true
}

###JavaScript

var hasAlternatingBits = function(n) {
    let prev = 2;
    while (n !== 0) {
        const cur = n % 2;
        if (cur === prev) {
            return false;
        }
        prev = cur;
        n = Math.floor(n / 2);
    }
    return true;
};

复杂度分析

  • 时间复杂度:$O(\log n)$。输入 $n$ 的二进制表示最多有 $O(\log n)$ 位。

  • 空间复杂度:$O(1)$。使用了常数空间来存储中间变量。

方法二:位运算

思路

对输入 $n$ 的二进制表示右移一位后,得到的数字再与 $n$ 按位异或得到 $a$。当且仅当输入 $n$ 为交替位二进制数时,$a$ 的二进制表示全为 $1$(不包括前导 $0$)。这里进行简单证明:当 $a$ 的某一位为 $1$ 时,当且仅当 $n$ 的对应位和其前一位相异。当 $a$ 的每一位为 $1$ 时,当且仅当 $n$ 的所有相邻位相异,即 $n$ 为交替位二进制数。

将 $a$ 与 $a + 1$ 按位与,当且仅当 $a$ 的二进制表示全为 $1$ 时,结果为 $0$。这里进行简单证明:当且仅当 $a$ 的二进制表示全为 $1$ 时,$a + 1$ 可以进位,并将原最高位置为 $0$,按位与的结果为 $0$。否则,不会产生进位,两个最高位都为 $1$,相与结果不为 $0$。

结合上述两步,可以判断输入是否为交替位二进制数。

代码

###Python

class Solution:
    def hasAlternatingBits(self, n: int) -> bool:
        a = n ^ (n >> 1)
        return a & (a + 1) == 0

###Java

class Solution {
    public boolean hasAlternatingBits(int n) {
        int a = n ^ (n >> 1);
        return (a & (a + 1)) == 0;
    }
}

###C#

public class Solution {
    public bool HasAlternatingBits(int n) {
        int a = n ^ (n >> 1);
        return (a & (a + 1)) == 0;
    }
}

###C++

class Solution {
public:
    bool hasAlternatingBits(int n) {
        long a = n ^ (n >> 1);
        return (a & (a + 1)) == 0;
    }
};

###C

bool hasAlternatingBits(int n) {
    long a = n ^ (n >> 1);
    return (a & (a + 1)) == 0;
}

###go

func hasAlternatingBits(n int) bool {
    a := n ^ n>>1
    return a&(a+1) == 0
}

###JavaScript

var hasAlternatingBits = function(n) {
    const a = n ^ (n >> 1);
    return (a & (a + 1)) === 0;
};

复杂度分析

  • 时间复杂度:$O(1)$。仅使用了常数时间来计算。

  • 空间复杂度:$O(1)$。使用了常数空间来存储中间变量。

❌
❌