两种方法:动态规划 / 组合数学(Python/Java/C++/Go)
本题和双周赛第四题是一样的,请看 我的题解。
本题和双周赛第四题是一样的,请看 我的题解。
把 $\textit{nums}$ 中的字符串转成二进制整数,保存到一个哈希集合中。
枚举 $\textit{ans} = 0,1,2,\ldots$ 直到 $\textit{ans}$ 不在哈希集合中,即为答案。
方法二告诉我们,满足要求的答案是一定存在的。
class Solution:
def findDifferentBinaryString(self, nums: List[str]) -> str:
st = {int(s, 2) for s in nums}
ans = 0
while ans in st:
ans += 1
n = len(nums)
return f"{ans:0{n}b}"
class Solution {
public String findDifferentBinaryString(String[] nums) {
Set<Integer> set = new HashSet<>();
for (String s : nums) {
set.add(Integer.parseInt(s, 2));
}
int ans = 0;
while (set.contains(ans)) {
ans++;
}
String bin = Integer.toBinaryString(ans);
return "0".repeat(nums.length - bin.length()) + bin;
}
}
class Solution {
public:
string findDifferentBinaryString(vector<string>& nums) {
unordered_set<int> st;
for (auto& s : nums) {
st.insert(stoi(s, nullptr, 2));
}
int ans = 0;
while (st.contains(ans)) {
ans++;
}
int n = nums.size();
return bitset<32>(ans).to_string().substr(32 - n);
}
};
func findDifferentBinaryString(nums []string) string {
n := len(nums)
has := make(map[int]bool, n)
for _, s := range nums {
x, _ := strconv.ParseInt(s, 2, 64)
has[int(x)] = true
}
ans := 0
for has[ans] {
ans++
}
return fmt.Sprintf("%0*b", n, ans)
}
这个方法灵感来自数学家康托关于「实数是不可数无限」的证明。
例如 $\textit{nums} = [\texttt{111}, \texttt{011}, \texttt{000}]$。我们可以构造一个字符串 $\textit{ans}$,满足:
$\textit{ans} = \texttt{001}$ 和每个 $\textit{nums}[i]$ 都至少有一个字符不同,满足题目要求。
一般地,令 $\textit{ans}[i] = \textit{nums}[i][i]\oplus 1$,即可满足要求。其中 $\oplus$ 是异或运算。
class Solution:
def findDifferentBinaryString(self, nums: List[str]) -> str:
ans = [''] * len(nums)
for i, s in enumerate(nums):
ans[i] = '1' if s[i] == '0' else '0'
return ''.join(ans)
class Solution {
public String findDifferentBinaryString(String[] nums) {
int n = nums.length;
char[] ans = new char[n];
for (int i = 0; i < n; i++) {
ans[i] = (char) (nums[i].charAt(i) ^ 1);
}
return new String(ans);
}
}
class Solution {
public:
string findDifferentBinaryString(vector<string>& nums) {
int n = nums.size();
string ans(n, 0);
for (int i = 0; i < n; i++) {
ans[i] = nums[i][i] ^ 1;
}
return ans;
}
};
func findDifferentBinaryString(nums []string) string {
ans := make([]byte, len(nums))
for i, s := range nums {
ans[i] = s[i] ^ 1
}
return string(ans)
}
欢迎关注 B站@灵茶山艾府
注意 $s$ 不含前导零,那么只含一段连续 $\texttt{1}$ 的 $s$ 只有两种情况:
如果 $s$ 包含多段连续的 $\texttt{1}$,比如示例 1 的 $s = \texttt{1001}$,在 $\texttt{0}$ 的后面还有 $\texttt{1}$。所以检查 $s$ 是否包含 $\texttt{01}$ 即可。
注:只有一个 $\texttt{1}$ 也算一段连续的 $\texttt{1}$。
###py
class Solution:
def checkOnesSegment(self, s: str) -> bool:
return "01" not in s
###java
class Solution {
public boolean checkOnesSegment(String s) {
return !s.contains("01");
}
}
###cpp
class Solution {
public:
bool checkOnesSegment(string s) {
return s.find("01") == string::npos;
}
};
###c
bool checkOnesSegment(char* s) {
return strstr(s, "01") == NULL;
}
###go
func checkOnesSegment(s string) bool {
return !strings.Contains(s, "01")
}
###js
var checkOnesSegment = function(s) {
return !s.includes("01");
};
###rust
impl Solution {
pub fn check_ones_segment(s: String) -> bool {
!s.contains("01")
}
}
欢迎关注 B站@灵茶山艾府
我们需要确定第 $k$ 个字符位于 $S_n$ 的左半、正中间还是右半。为此,首先要知道 $S_n$ 的长度。
用 $|s|$ 表示字符串 $s$ 的长度。根据题意,$|S_1| = 1$,$|S_n| = 2|S_{n-1}| + 1$,所以有
$$
|S_n| + 1 = 2(|S_{n-1}| + 1)
$$
所以 ${|S_n| + 1}$ 是个首项为 $2$,公比为 $2$ 的等比数列,得
$$
|S_n| = 2^n - 1
$$
所以 $|S_{n-1}| = 2^{n-1} - 1$,这说明 $S_n$ 的左半是第 $1$ 个字符到第 $2^{n-1}-1$ 个字符,正中间是第 $2^{n-1}$ 个字符,右半是第 $2^{n-1} + 1$ 个字符到第 $2^n-1$ 个字符。
分类讨论:
递归边界:
class Solution:
def findKthBit(self, n: int, k: int) -> str:
if n == 1:
return '0'
if k == 1 << (n - 1):
return '1'
if k < 1 << (n - 1):
return self.findKthBit(n - 1, k)
res = self.findKthBit(n - 1, (1 << n) - k)
return '0' if res == '1' else '1'
class Solution {
public char findKthBit(int n, int k) {
if (n == 1) {
return '0';
}
if (k == 1 << (n - 1)) {
return '1';
}
if (k < 1 << (n - 1)) {
return findKthBit(n - 1, k);
}
char res = findKthBit(n - 1, (1 << n) - k);
return (char) (res ^ 1);
}
}
class Solution {
public:
char findKthBit(int n, int k) {
if (n == 1) {
return '0';
}
if (k == 1 << (n - 1)) {
return '1';
}
if (k < 1 << (n - 1)) {
return findKthBit(n - 1, k);
}
return findKthBit(n - 1, (1 << n) - k) ^ 1;
}
};
char findKthBit(int n, int k) {
if (n == 1) {
return '0';
}
if (k == 1 << (n - 1)) {
return '1';
}
if (k < 1 << (n - 1)) {
return findKthBit(n - 1, k);
}
return findKthBit(n - 1, (1 << n) - k) ^ 1;
}
func findKthBit(n, k int) byte {
if n == 1 {
return '0'
}
if k == 1<<(n-1) {
return '1'
}
if k < 1<<(n-1) {
return findKthBit(n-1, k)
}
return findKthBit(n-1, 1<<n-k) ^ 1
}
var findKthBit = function(n, k) {
if (n === 1) {
return '0';
}
if (k === 1 << (n - 1)) {
return '1';
}
if (k < 1 << (n - 1)) {
return findKthBit(n - 1, k);
}
return findKthBit(n - 1, (1 << n) - k) === '1' ? '0' : '1';
};
impl Solution {
pub fn find_kth_bit(n: i32, k: i32) -> char {
if n == 1 {
return '0';
}
if k == 1 << (n - 1) {
return '1';
}
if k < 1 << (n - 1) {
return Self::find_kth_bit(n - 1, k);
}
(Self::find_kth_bit(n - 1, (1 << n) - k) as u8 ^ 1) as _
}
}
class Solution:
def findKthBit(self, n: int, k: int) -> str:
rev = 0 # 翻转次数的奇偶性
while True:
if n == 1:
return '1' if rev else '0'
if k == 1 << (n - 1):
return '0' if rev else '1'
if k > 1 << (n - 1):
k = (1 << n) - k
rev ^= 1
n -= 1
class Solution {
public char findKthBit(int n, int k) {
int rev = 0; // 翻转次数的奇偶性
while (true) {
if (n == 1) {
return (char) ('0' ^ rev);
}
if (k == 1 << (n - 1)) {
return (char) ('1' ^ rev);
}
if (k > 1 << (n - 1)) {
k = (1 << n) - k;
rev ^= 1;
}
n--;
}
}
}
class Solution {
public:
char findKthBit(int n, int k) {
int rev = 0; // 翻转次数的奇偶性
while (true) {
if (n == 1) {
return '0' ^ rev;
}
if (k == 1 << (n - 1)) {
return '1' ^ rev;
}
if (k > 1 << (n - 1)) {
k = (1 << n) - k;
rev ^= 1;
}
n--;
}
}
};
char findKthBit(int n, int k) {
int rev = 0; // 翻转次数的奇偶性
while (true) {
if (n == 1) {
return '0' ^ rev;
}
if (k == 1 << (n - 1)) {
return '1' ^ rev;
}
if (k > 1 << (n - 1)) {
k = (1 << n) - k;
rev ^= 1;
}
n--;
}
}
func findKthBit(n, k int) byte {
rev := byte(0) // 翻转次数的奇偶性
for {
if n == 1 {
return '0' ^ rev
}
if k == 1<<(n-1) {
return '1' ^ rev
}
if k > 1<<(n-1) {
k = 1<<n - k
rev ^= 1
}
n--
}
}
var findKthBit = function(n, k) {
let rev = 0; // 翻转次数的奇偶性
while (true) {
if (n === 1) {
return rev ? '1' : '0';
}
if (k === 1 << (n - 1)) {
return rev ? '0' : '1';
}
if (k > 1 << (n - 1)) {
k = (1 << n) - k;
rev ^= 1;
}
n--;
}
};
impl Solution {
pub fn find_kth_bit(mut n: i32, mut k: i32) -> char {
let mut rev = 0; // 翻转次数的奇偶性
loop {
if n == 1 {
return (b'0' ^ rev) as _;
}
if k == 1 << (n - 1) {
return (b'1' ^ rev) as _;
}
if k > 1 << (n - 1) {
k = (1 << n) - k;
rev ^= 1;
}
n -= 1;
}
}
}
$S_4 = \texttt{011100110110001}$,只看奇数位(下标从 $1$ 开始)的字符,是 $\texttt{01010101}$,这是一个 $\texttt{01}$ 交替序列,为什么?
只看奇数位:
一般地,由于 $\texttt{01}$ 交替序列反转再翻转,结果不变,所以从 $S_{i-1}$ 到 $S_i\ (i\ge 3)$,其中奇数位相当于复制了一份自身,拼在了自身后面,得到的仍然是 $\texttt{01}$ 交替序列。
所以,当 $k$ 是奇数时,可以立刻得出答案:
奇数位的字符,都发源于 $S_1 = \texttt{0}$。
偶数位的字符呢?都发源于 $S_i\ (i\ge 2)$ 正中间的那个 $\texttt{1}$,即位置为 $2,4,8,16,\ldots$ 的字符 $\texttt{1}$。
根据方法一的结论,$S_{n-1}$ 的第 $k$ 个字符,反转后,是 $S_n$ 的第 $2^n-k$ 个字符。
$2^n-k$ 有什么性质?
比如二进制 $10000 - 100 = 1100$,去掉末尾的两个 $0$,相当于 $100 - 1 = 11$,结果最低位一定是 $1$,所以 $100$ 和 $1100$ 的尾零个数相同。一般地,$k$ 和 $2^n-k$ 的尾零个数是相同的,这是个不变量!我们可以根据 $k$ 的尾零个数,找到 $k$ 发源于哪个 $S_i$ 正中间的 $\texttt{1}$。
以 $S_2$ 的中间字符(第 $2$ 个字符)为例:
一般地,设 $t$ 为 $k$ 去掉尾零后的值,即 $k = t\cdot 2^x$ 且 $t$ 是奇数。比如 $k=2,6,10,14,\ldots$ 对应着 $t=1,3,5,7,\ldots$
如何去掉 $k$ 的尾零?把 $k$ 除以其 $\text{lowbit}$ 即可。关于 $\text{lowbit}$ 的原理,请看 从集合论到位运算,常见位运算技巧分类总结。
class Solution:
def findKthBit(self, _, k: int) -> str:
if k % 2:
return str(k // 2 % 2)
k //= k & -k # 去掉 k 的尾零
return str(1 - k // 2 % 2)
class Solution {
public char findKthBit(int n, int k) {
if (k % 2 > 0) {
return (char) ('0' + k / 2 % 2);
}
k /= k & -k; // 去掉 k 的尾零
return (char) ('1' - k / 2 % 2);
}
}
class Solution {
public:
char findKthBit(int, int k) {
if (k % 2) {
return '0' + k / 2 % 2;
}
k /= k & -k; // 去掉 k 的尾零
return '1' - k / 2 % 2;
}
};
char findKthBit(int, int k) {
if (k % 2) {
return '0' + k / 2 % 2;
}
k /= k & -k; // 去掉 k 的尾零
return '1' - k / 2 % 2;
}
func findKthBit(_, k int) byte {
if k%2 > 0 {
return '0' + byte(k/2%2)
}
k /= k & -k // 去掉 k 的尾零
return '1' - byte(k/2%2)
}
var findKthBit = function(_, k) {
if (k % 2) {
return (k - 1) / 2 % 2 ? '1' : '0';
}
k /= k & -k; // 去掉 k 的尾零
return (k - 1) / 2 % 2 ? '0' : '1';
};
impl Solution {
pub fn find_kth_bit(_: i32, mut k: i32) -> char {
if k % 2 > 0 {
return (b'0' + k as u8 / 2 % 2) as _;
}
k /= k & -k; // 去掉 k 的尾零
(b'1' - k as u8 / 2 % 2) as _
}
}
见下面回溯题单的「五、其他递归/分治」。
欢迎关注 B站@灵茶山艾府
例如 $n=321$,其中最大的数字是 $3$。这个 $3$ 至少要拆分成 $3$ 个 $1$,即 $321=1__ + 1__ + 1__$。对于 $n$ 中的其余数字 $d$,可以拆分成 $d$ 个 $1$ 和 $3-d$ 个 $0$,即 $2=1+1+0$ 和 $1=1+0+0$,填到对应的位置上,得到 $321 = 111 + 110 + 100$。
一般地,设 $m$ 为 $n$ 中的最大数字,那么答案为 $m$。构造方案为:设 $n$ 的第 $i$ 个数字为 $n_i$,那么拆分出的这 $m$ 个数的第 $i$ 位上,有 $n_i$ 个 $1$ 和 $m-n_i$ 个 $0$(填入顺序随意)。
###py
class Solution:
def minPartitions(self, n: str) -> int:
return int(max(n))
###java
class Solution {
public int minPartitions(String n) {
int mx = 0;
for (char ch : n.toCharArray()) {
mx = Math.max(mx, ch);
}
return mx - '0';
}
}
###cpp
class Solution {
public:
int minPartitions(string n) {
return ranges::max(n) - '0';
}
};
###c
#define MAX(a, b) ((b) > (a) ? (b) : (a))
int minPartitions(char* n) {
char mx = 0;
for (int i = 0; n[i]; i++) {
mx = MAX(mx, n[i]);
}
return mx - '0';
}
###go
func minPartitions(n string) int {
ans := rune(0)
for _, ch := range n {
ans = max(ans, ch)
}
return int(ans - '0')
}
###js
var minPartitions = function(n) {
return Number(_.max(n));
};
###rust
impl Solution {
pub fn min_partitions(n: String) -> i32 {
(n.as_bytes().iter().max().unwrap() - b'0') as _
}
}
见下面贪心与思维题单的「§5.2 脑筋急转弯」。
欢迎关注 B站@灵茶山艾府
用位运算模拟这个过程:每拼接一个数 $i$,就把之前拼接过的数左移 $i$ 的二进制长度,然后加上 $i$。
由于左移后空出的位置全为 $0$,加法运算也可以写成或运算。
###go
func concatenatedBinary(n int) (ans int) {
for i := 1; i <= n; i++ {
ans = (ans<<bits.Len(uint(i)) | i) % (1e9 + 7)
}
return
}
做法和 2612. 最少翻转操作数 是类似的,请先阅读 我的题解。
设 $s$ 的长度为 $n$,其中有 $z$ 个 $0$。
翻转一次后,$s$ 有多少个 $0$?$z$ 可以变成什么数?
设翻转了 $x$ 个 $0$,那么也同时翻转了 $k-x$ 个 $1$,这些 $1$ 变成了 $0$。
所以 $z$ 减少了 $x$,然后又增加了 $k-x$。
所以新的 $z'$ 为
$$
z' = z - x + (k-x) = z+k-2x
$$
$x$ 最大可以是 $k$,但这不能超过 $s$ 中的 $0$ 的个数 $z$,所以 $x$ 最大为 $\min(k,z)$。
$k-x$ 最大可以是 $k$,但这不能超过 $s$ 中的 $1$ 的个数 $n-z$,所以 $k-x$ 最大为 $\min(k,n-z)$,所以 $x$ 最小为 $\max(0,k-n+z)$。
所以 $x$ 的范围为
$$
[\max(0,k-n+z),\min(k,z)]
$$
其余逻辑同 2612 题。
###py
class Solution:
def minOperations(self, s: str, k: int) -> int:
n = len(s)
not_vis = [SortedList(range(0, n + 1, 2)), SortedList(range(1, n + 1, 2))]
not_vis[0].add(n + 1) # 哨兵,下面 sl[idx] <= mx 无需判断越界
not_vis[1].add(n + 1)
start = s.count('0') # 起点
not_vis[start % 2].discard(start)
q = [start]
ans = 0
while q:
tmp = q
q = []
for z in tmp:
if z == 0: # 没有 0,翻转完毕
return ans
# not_vis[mn % 2] 中的从 mn 到 mx 都可以从 z 翻转到
mn = z + k - 2 * min(k, z)
mx = z + k - 2 * max(0, k - n + z)
sl = not_vis[mn % 2]
idx = sl.bisect_left(mn)
while sl[idx] <= mx:
j = sl.pop(idx) # 注意 pop(idx) 会使后续元素向左移,不需要写 idx += 1
q.append(j)
ans += 1
return -1
###java
class Solution {
public int minOperations(String s, int k) {
int n = s.length();
TreeSet<Integer>[] notVis = new TreeSet[2];
for (int m = 0; m < 2; m++) {
notVis[m] = new TreeSet<>();
for (int i = m; i <= n; i += 2) {
notVis[m].add(i);
}
}
// 计算起点
int start = 0;
for (int i = 0; i < n; i++) {
if (s.charAt(i) == '0') {
start++;
}
}
notVis[start % 2].remove(start);
List<Integer> q = List.of(start);
for (int ans = 0; !q.isEmpty(); ans++) {
List<Integer> tmp = q;
q = new ArrayList<>();
for (int z : tmp) {
if (z == 0) { // 没有 0,翻转完毕
return ans;
}
// notVis[mn % 2] 中的从 mn 到 mx 都可以从 z 翻转到
int mn = z + k - 2 * Math.min(k, z);
int mx = z + k - 2 * Math.max(0, k - n + z);
TreeSet<Integer> set = notVis[mn % 2];
for (Iterator<Integer> it = set.tailSet(mn).iterator(); it.hasNext(); it.remove()) {
int j = it.next();
if (j > mx) {
break;
}
q.add(j);
}
}
}
return -1;
}
}
###cpp
class Solution {
public:
int minOperations(string s, int k) {
int n = s.size();
set<int> not_vis[2];
for (int m = 0; m < 2; m++) {
for (int i = m; i <= n; i += 2) {
not_vis[m].insert(i);
}
not_vis[m].insert(n + 1); // 哨兵,下面无需判断 it != st.end()
}
int start = ranges::count(s, '0'); // 起点
not_vis[start % 2].erase(start);
vector<int> q = {start};
for (int ans = 0; !q.empty(); ans++) {
vector<int> nxt;
for (int z : q) {
if (z == 0) { // 没有 0,翻转完毕
return ans;
}
// not_vis[mn % 2] 中的从 mn 到 mx 都可以从 z 翻转到
int mn = z + k - 2 * min(k, z);
int mx = z + k - 2 * max(0, k - n + z);
auto& st = not_vis[mn % 2];
for (auto it = st.lower_bound(mn); *it <= mx; it = st.erase(it)) {
nxt.push_back(*it);
}
}
q = move(nxt);
}
return -1;
}
};
###go
// import "github.com/emirpasic/gods/v2/trees/redblacktree"
func minOperations(s string, k int) (ans int) {
n := len(s)
notVis := [2]*redblacktree.Tree[int, struct{}]{}
for m := range notVis {
notVis[m] = redblacktree.New[int, struct{}]()
for i := m; i <= n; i += 2 {
notVis[m].Put(i, struct{}{})
}
notVis[m].Put(n+1, struct{}{}) // 哨兵,下面无需判断 node != nil
}
start := strings.Count(s, "0")
notVis[start%2].Remove(start)
q := []int{start}
for q != nil {
tmp := q
q = nil
for _, z := range tmp {
if z == 0 { // 没有 0,翻转完毕
return ans
}
// notVis[mn % 2] 中的从 mn 到 mx 都可以从 z 翻转到
mn := z + k - 2*min(k, z)
mx := z + k - 2*max(0, k-n+z)
t := notVis[mn%2]
for node, _ := t.Ceiling(mn); node.Key <= mx; node, _ = t.Ceiling(mn) {
q = append(q, node.Key)
t.Remove(node.Key)
}
}
ans++
}
return -1
}
设 $s$ 中有 $z$ 个 $0$,设一共操作了 $m$ 次。那么总翻转次数为 $mk$。
这 $z$ 个 $0$ 必须翻转奇数次,其余 $n-z$ 个 $1$ 必须翻转偶数次。
总翻转次数减去 $z$,剩下每个位置都必须翻转偶数次,所以
$$
mk-z\ 是偶数
$$
下面计算 $m$ 的下界。只要能证明 $m$ 可以等于下界,问题就解决了。
要想把 $z$ 个 $0$ 变成 $1$,总翻转次数至少要是 $z$,即
$$
mk\ge z
$$
即
$$
m\ge \left\lceil\dfrac{z}{k}\right\rceil
$$
除此以外,还需要满足什么要求?
由于 $mk-z$ 是偶数,如果 $m$ 是偶数,那么 $z$ 也必须是偶数。
$s$ 中的每个位置至多翻转 $m$ 次。但是,对于 $s$ 中的 $0$,由于要翻转奇数次,所以至多翻转 $m-1$ 次。
所以 $s$ 中的所有位置的翻转次数的上界是 $z(m-1)+(n-z)m$,其可能等于 $mk$,也可能比 $mk$ 大(因为是上界),所以有
$$
z(m-1)+(n-z)m\ge mk
$$
解得
$$
m\ge \left\lceil\dfrac{z}{n-k}\right\rceil
$$
与
$$
m\ge \left\lceil\dfrac{z}{k}\right\rceil
$$
联立得
$$
m\ge \max\left(\left\lceil\dfrac{z}{k}\right\rceil,\left\lceil\dfrac{z}{n-k}\right\rceil\right)
$$
由于 $mk-z$ 是偶数,如果 $m$ 是奇数,那么 $z$ 和 $k$ 必须同为奇数,或者同为偶数(奇偶性相同)。
$s$ 中的每个位置至多翻转 $m$ 次。但是,对于 $s$ 中的 $1$,由于要翻转偶数次,所以至多翻转 $m-1$ 次。
所以 $s$ 中的所有位置的翻转次数的上界是 $zm+(n-z)(m-1)$,其可能等于 $mk$,也可能比 $mk$ 大(因为是上界),所以有
$$
zm+(n-z)(m-1)\ge mk
$$
解得
$$
m\ge \left\lceil\dfrac{n-z}{n-k}\right\rceil
$$
与
$$
m\ge \left\lceil\dfrac{z}{k}\right\rceil
$$
联立得
$$
m\ge \max\left(\left\lceil\dfrac{z}{k}\right\rceil,\left\lceil\dfrac{n-z}{n-k}\right\rceil\right)
$$
情况一和情况二取最小值。
如果两个情况都不满足要求,返回 $-1$。
这可以用 Gale-Ryser 定理证明。
具体来说,我们需要判断是否存在一个 $m$ 行 $n$ 列的 $0\text{-}1$ 矩阵 $M$,第 $i$ 行对应着第 $i$ 次操作,其中 $M_{i,j} = 0$ 表示没有翻转 $s_j$,$M_{i,j} = 1$ 表示翻转 $s_j$。每一行的元素和都是 $k$,第 $j$ 列的元素和是 $s_j$ 的翻转次数 $a_j$。由于 $a_j\le m$ 且 $\sum\limits_{j} a_j\le mk$,由 Gale-Ryser 定理可得,这样的矩阵是存在的。
如果 $z=0$,那么无需操作,答案是 $0$。
由于下界公式中的分母 $n-k$ 不能为 $0$,我们需要特判 $n=k$ 的情况,此时每次操作只能翻转整个 $s$。
关于上取整的计算,当 $a$ 为非负整数,$b$ 为正整数时,有恒等式
$$
\left\lceil\dfrac{a}{b}\right\rceil = \left\lfloor\dfrac{a+b-1}{b}\right\rfloor
$$
证明见 上取整下取整转换公式的证明。
###py
class Solution:
def minOperations(self, s: str, k: int) -> int:
n = len(s)
z = s.count('0')
if z == 0:
return 0
if k == n:
return 1 if z == n else -1
ans = inf
# 情况一:操作次数 m 是偶数
if z % 2 == 0: # z 必须是偶数
m = max((z + k - 1) // k, (z + n - k - 1) // (n - k)) # 下界
ans = m + m % 2 # 把 m 往上调整为偶数
# 情况二:操作次数 m 是奇数
if z % 2 == k % 2: # z 和 k 的奇偶性必须相同
m = max((z + k - 1) // k, (n - z + n - k - 1) // (n - k)) # 下界
ans = min(ans, m | 1) # 把 m 往上调整为奇数
return ans if ans < inf else -1
###java
class Solution {
public int minOperations(String s, int k) {
int n = s.length();
int z = 0;
for (int i = 0; i < n; i++) {
if (s.charAt(i) == '0') {
z++;
}
}
if (z == 0) {
return 0;
}
if (k == n) {
return z == n ? 1 : -1;
}
int ans = Integer.MAX_VALUE;
// 情况一:操作次数 m 是偶数
if (z % 2 == 0) { // z 必须是偶数
int m = Math.max((z + k - 1) / k, (z + n - k - 1) / (n - k)); // 下界
ans = m + m % 2; // 把 m 往上调整为偶数
}
// 情况二:操作次数 m 是奇数
if (z % 2 == k % 2) { // z 和 k 的奇偶性必须相同
int m = Math.max((z + k - 1) / k, (n - z + n - k - 1) / (n - k)); // 下界
ans = Math.min(ans, m | 1); // 把 m 往上调整为奇数
}
return ans < Integer.MAX_VALUE ? ans : -1;
}
}
###cpp
class Solution {
public:
int minOperations(string s, int k) {
int n = s.size();
int z = ranges::count(s, '0');
if (z == 0) {
return 0;
}
if (k == n) {
return z == n ? 1 : -1;
}
int ans = INT_MAX;
// 情况一:操作次数 m 是偶数
if (z % 2 == 0) { // z 必须是偶数
int m = max((z + k - 1) / k, (z + n - k - 1) / (n - k)); // 下界
ans = m + m % 2; // 把 m 往上调整为偶数
}
// 情况二:操作次数 m 是奇数
if (z % 2 == k % 2) { // z 和 k 的奇偶性必须相同
int m = max((z + k - 1) / k, (n - z + n - k - 1) / (n - k)); // 下界
ans = min(ans, m | 1); // 把 m 往上调整为奇数
}
return ans < INT_MAX ? ans : -1;
}
};
###go
func minOperations(s string, k int) int {
n := len(s)
z := strings.Count(s, "0")
if z == 0 {
return 0
}
if k == n {
if z == n {
return 1
}
return -1
}
ans := math.MaxInt
// 情况一:操作次数 m 是偶数
if z%2 == 0 { // z 必须是偶数
m := max((z+k-1)/k, (z+n-k-1)/(n-k)) // 下界
ans = m + m%2 // 把 m 往上调整为偶数
}
// 情况二:操作次数 m 是奇数
if z%2 == k%2 { // z 和 k 的奇偶性必须相同
m := max((z+k-1)/k, (n-z+n-k-1)/(n-k)) // 下界
ans = min(ans, m|1) // 把 m 往上调整为奇数
}
if ans < math.MaxInt {
return ans
}
return -1
}
双关键字排序。对于 $\textit{arr}$ 中的两个数:
###py
class Solution:
def sortByBits(self, arr: List[int]) -> List[int]:
arr.sort(key=lambda x: (x.bit_count(), x))
return arr
###java
class Solution {
public int[] sortByBits(int[] arr) {
return IntStream.of(arr)
.boxed()
.sorted((a, b) -> {
int ca = Integer.bitCount(a);
int cb = Integer.bitCount(b);
return ca != cb ? ca - cb : a - b;
})
.mapToInt(a -> a)
.toArray();
}
}
###java
class Solution {
public int[] sortByBits(int[] arr) {
for (int i = 0; i < arr.length; i++) {
arr[i] = Integer.bitCount(arr[i]) << 16 | arr[i];
}
Arrays.sort(arr);
for (int i = 0; i < arr.length; i++) {
arr[i] &= 0xffff;
}
return arr;
}
}
###cpp
class Solution {
public:
vector<int> sortByBits(vector<int>& arr) {
ranges::sort(arr, {}, [](int x) {
return pair(popcount((uint32_t) x), x);
});
return arr;
}
};
###c
int cmp(const void* a, const void* b) {
int x = *(int*)a, y = *(int*)b;
int cx = __builtin_popcount(x), cy = __builtin_popcount(y);
return cx != cy ? cx - cy : x - y;
}
int* sortByBits(int* arr, int arrSize, int* returnSize) {
qsort(arr, arrSize, sizeof(int), cmp);
*returnSize = arrSize;
return arr;
}
###go
func sortByBits(arr []int) []int {
slices.SortFunc(arr, func(a, b int) int {
return cmp.Or(bits.OnesCount(uint(a))-bits.OnesCount(uint(b)), a-b)
})
return arr
}
###js
var sortByBits = function(arr) {
return arr.sort((a, b) => bitCount32(a) - bitCount32(b) || a - b);
};
// 参考 Java 的 Integer.bitCount
function bitCount32(i) {
i = i - ((i >>> 1) & 0x55555555);
i = (i & 0x33333333) + ((i >>> 2) & 0x33333333);
i = (i + (i >>> 4)) & 0x0f0f0f0f;
i = i + (i >>> 8);
i = i + (i >>> 16);
return i & 0x3f;
}
###rust
impl Solution {
pub fn sort_by_bits(mut arr: Vec<i32>) -> Vec<i32> {
arr.sort_unstable_by_key(|&x| (x.count_ones(), x));
arr
}
}
欢迎关注 B站@灵茶山艾府
从大家熟悉的十进制说起。如何把路径 $1\to 2\to 3$ 变成十进制数 $123$?过程如下:
$$
0\xrightarrow{\times 10 + 1}1\xrightarrow{\times 10 + 2} 12\xrightarrow{\times 10 + 3}123
$$
二进制的做法类似。例如把路径 $1\to 0\to 1\to 1$ 变成二进制数 $1011$,过程如下:
$$
0\xrightarrow{\times 2 + 1}1\xrightarrow{\times 2 + 0} 10\xrightarrow{\times 2 + 1}101\xrightarrow{\times 2 + 1}1011
$$
其中 $\times 2$ 等价于左移一位,$+$ 也可以写成或运算。
我们可以对 $\textit{dfs}$ 额外添加参数 $\textit{num}$,表示在自顶向下递归的过程中,当前数字是 $\textit{num}$。每访问到一个新的节点 $\textit{node}$,就把 $\textit{num}$ 更新成 num << 1 | node.val。
如果 $\textit{node}$ 是叶子节点,把 $\textit{num}$ 加到答案中。
class Solution:
def sumRootToLeaf(self, root: Optional[TreeNode]) -> int:
# 从根到 node(不含)的路径值为 num
def dfs(node: Optional[TreeNode], num: int) -> None:
nonlocal ans
if node is None:
return
num = num << 1 | node.val
if node.left is None and node.right is None:
ans += num
return
dfs(node.left, num)
dfs(node.right, num)
ans = 0
dfs(root, 0)
return ans
class Solution {
private int ans = 0;
public int sumRootToLeaf(TreeNode root) {
dfs(root, 0);
return ans;
}
// 从根到 node(不含)的路径值为 num
private void dfs(TreeNode node, int num) {
if (node == null) {
return;
}
num = num << 1 | node.val;
if (node.left == null && node.right == null) {
ans += num;
return;
}
dfs(node.left, num);
dfs(node.right, num);
}
}
class Solution {
public:
int sumRootToLeaf(TreeNode* root) {
int ans = 0;
// 从根到 node(不含)的路径值为 num
auto dfs = [&](this auto&& dfs, TreeNode* node, int num) -> void {
if (node == nullptr) {
return;
}
num = num << 1 | node->val;
if (node->left == nullptr && node->right == nullptr) {
ans += num;
return;
}
dfs(node->left, num);
dfs(node->right, num);
};
dfs(root, 0);
return ans;
}
};
func sumRootToLeaf(root *TreeNode) (ans int) {
// 从根到 node(不含)的路径值为 num
var dfs func(*TreeNode, int)
dfs = func(node *TreeNode, num int) {
if node == nil {
return
}
num = num<<1 | node.Val
if node.Left == nil && node.Right == nil {
ans += num
return
}
dfs(node.Left, num)
dfs(node.Right, num)
}
dfs(root, 0)
return
}
class Solution:
def sumRootToLeaf(self, root: Optional[TreeNode]) -> int:
def dfs(node: Optional[TreeNode], num: int) -> int:
if node is None:
return 0
num = num << 1 | node.val
if node.left is None and node.right is None:
return num
return dfs(node.left, num) + dfs(node.right, num)
return dfs(root, 0)
class Solution {
public int sumRootToLeaf(TreeNode root) {
return dfs(root, 0);
}
private int dfs(TreeNode node, int num) {
if (node == null) {
return 0;
}
num = num << 1 | node.val;
if (node.left == null && node.right == null) {
return num;
}
return dfs(node.left, num) + dfs(node.right, num);
}
}
class Solution {
int dfs(TreeNode* node, int num) {
if (node == nullptr) {
return 0;
}
num = num << 1 | node->val;
if (node->left == nullptr && node->right == nullptr) {
return num;
}
return dfs(node->left, num) + dfs(node->right, num);
}
public:
int sumRootToLeaf(TreeNode* root) {
return dfs(root, 0);
}
};
func dfs(node *TreeNode, num int) int {
if node == nil {
return 0
}
num = num<<1 | node.Val
if node.Left == nil && node.Right == nil {
return num
}
return dfs(node.Left, num) + dfs(node.Right, num)
}
func sumRootToLeaf(root *TreeNode) int {
return dfs(root, 0)
}
见下面树题单的「§2.2 自顶向下 DFS」。
欢迎关注 B站@灵茶山艾府
暴力枚举所有长为 $k$ 的子串,保存到一个哈希集合中。
如果最终哈希集合的大小恰好等于 $2^k$,那么说明所有长为 $k$ 的二进制串都在 $s$ 中。
###py
class Solution:
def hasAllCodes(self, s: str, k: int) -> bool:
st = {s[i - k: i] for i in range(k, len(s) + 1)}
return len(st) == 1 << k
###java
class Solution {
public boolean hasAllCodes(String s, int k) {
Set<String> set = new HashSet<>();
for (int i = k; i <= s.length(); i++) {
set.add(s.substring(i - k, i));
}
return set.size() == (1 << k);
}
}
###cpp
class Solution {
public:
bool hasAllCodes(string s, int k) {
unordered_set<string> st;
for (int i = k; i <= s.size(); i++) {
st.insert(s.substr(i - k, k));
}
return st.size() == (1 << k);
}
};
###go
func hasAllCodes(s string, k int) bool {
set := map[string]struct{}{}
for i := k; i <= len(s); i++ {
set[s[i-k:i]] = struct{}{}
}
return len(set) == 1<<k
}
把子串转成整数,保存到哈希集合或者布尔数组中。
小优化:如果循环过程中发现已经找到 $2^k$ 个不同的二进制数,可以提前返回 $\texttt{true}$。
###py
class Solution:
def hasAllCodes(self, s: str, k: int) -> bool:
MASK = (1 << k) - 1
st = set() # 更快的写法见另一份代码【Python3 列表】
x = 0
for i, ch in enumerate(s):
# 把 ch 加到 x 的末尾:x 整体左移一位,然后或上 ch
# &MASK 目的是去掉超出 k 的比特位
x = (x << 1 & MASK) | int(ch)
if i >= k - 1:
st.add(x)
return len(st) == 1 << k
###py
class Solution:
def hasAllCodes(self, s: str, k: int) -> bool:
MASK = (1 << k) - 1
has = [False] * (1 << k)
cnt = x = 0
for i, ch in enumerate(s):
# 把 ch 加到 x 的末尾:x 整体左移一位,然后或上 ch
# &MASK 目的是去掉超出 k 的比特位
x = (x << 1 & MASK) | int(ch)
if i < k - 1 or has[x]:
continue
has[x] = True
cnt += 1
if cnt == 1 << k:
return True
return False
###java
class Solution {
public boolean hasAllCodes(String s, int k) {
final int MASK = (1 << k) - 1;
boolean[] has = new boolean[1 << k];
int cnt = 0;
int x = 0;
for (int i = 0; i < s.length() && cnt < (1 << k); i++) {
char ch = s.charAt(i);
// 把 ch 加到 x 的末尾:x 整体左移一位,然后或上 ch&1
// &MASK 目的是去掉超出 k 的比特位
x = (x << 1 & MASK) | (ch & 1);
if (i >= k - 1 && !has[x]) {
has[x] = true;
cnt++;
}
}
return cnt == (1 << k);
}
}
###cpp
class Solution {
public:
bool hasAllCodes(string s, int k) {
const int MASK = (1 << k) - 1;
vector<int8_t> has(1 << k);
int cnt = 0;
int x = 0;
for (int i = 0; i < s.size() && cnt < (1 << k); i++) {
// 把 s[i] 加到 x 的末尾:x 整体左移一位,然后或上 s[i]&1
// &MASK 目的是去掉超出 k 的比特位
x = (x << 1 & MASK) | (s[i] & 1);
if (i >= k - 1 && !has[x]) {
has[x] = true;
cnt++;
}
}
return cnt == (1 << k);
}
};
###go
func hasAllCodes(s string, k int) bool {
has := make([]bool, 1<<k)
cnt := 0
mask := 1<<k - 1
x := 0
for i, ch := range s {
// 把 ch 加到 x 的末尾:x 整体左移一位,然后或上 ch&1
// &mask 目的是去掉超出 k 的比特位
x = x<<1&mask | int(ch&1)
if i < k-1 || has[x] {
continue
}
has[x] = true
cnt++
if cnt == 1<<k {
return true
}
}
return false
}
欢迎关注 B站@灵茶山艾府
以 $n = 1010010$ 为例。从右往左,我们需要计算 $1001$ 的间距 $3$,以及 $101$ 的间距 $2$:
class Solution:
def binaryGap(self, n: int) -> int:
ans = 0
n //= (n & -n) * 2 # 去掉 n 末尾的 100..0
while n > 0:
gap = (n & -n).bit_length() # n 的尾零个数加一
ans = max(ans, gap)
n >>= gap # 去掉 n 末尾的 100..0
return ans
class Solution {
public int binaryGap(int n) {
int ans = 0;
n /= (n & -n) * 2; // 去掉 n 末尾的 100..0
while (n > 0) {
int gap = Integer.numberOfTrailingZeros(n) + 1;
ans = Math.max(ans, gap);
n >>= gap; // 去掉 n 末尾的 100..0
}
return ans;
}
}
class Solution {
public:
int binaryGap(int n) {
int ans = 0;
n /= (n & -n) * 2; // 去掉 n 末尾的 100..0
while (n > 0) {
int gap = countr_zero((uint32_t) n) + 1;
ans = max(ans, gap);
n >>= gap; // 去掉 n 末尾的 100..0
}
return ans;
}
};
#define MAX(a, b) ((b) > (a) ? (b) : (a))
int binaryGap(int n) {
int ans = 0;
n /= (n & -n) * 2; // 去掉 n 末尾的 100..0
while (n > 0) {
int gap = __builtin_ctz(n) + 1;
ans = MAX(ans, gap);
n >>= gap; // 去掉 n 末尾的 100..0
}
return ans;
}
func binaryGap(n int) (ans int) {
n /= n & -n * 2 // 去掉 n 末尾的 100..0
for n > 0 {
gap := bits.TrailingZeros(uint(n)) + 1
ans = max(ans, gap)
n >>= gap // 去掉 n 末尾的 100..0
}
return
}
var binaryGap = function(n) {
let ans = 0;
n /= (n & -n) * 2; // 去掉 n 末尾的 100..0
while (n > 0) {
const gap = 32 - Math.clz32(n & -n); // n 的尾零个数加一
ans = Math.max(ans, gap);
n >>= gap; // 去掉 n 末尾的 100..0
}
return ans;
};
impl Solution {
pub fn binary_gap(mut n: i32) -> i32 {
let mut ans = 0;
n /= (n & -n) * 2; // 去掉 n 末尾的 100..0
while n > 0 {
let gap = n.trailing_zeros() + 1;
ans = ans.max(gap);
n >>= gap; // 去掉 n 末尾的 100..0
}
ans as _
}
}
见下面位运算题单的「一、基础题」。
欢迎关注 B站@灵茶山艾府
枚举 $[\textit{left},\textit{right}]$ 中的整数 $x$,计算 $x$ 二进制中的 $1$ 的个数 $c$。如果 $c$ 是质数,那么答案增加一。
由于 $[1,10^6]$ 中的二进制数至多有 $19$ 个 $1$,所以只需 $19$ 以内的质数,即
$$
2, 3, 5, 7, 11, 13, 17, 19
$$
primes = {2, 3, 5, 7, 11, 13, 17, 19}
class Solution:
def countPrimeSetBits(self, left: int, right: int) -> int:
ans = 0
for x in range(left, right + 1):
if x.bit_count() in primes:
ans += 1
return ans
class Solution {
private static final Set<Integer> primes = Set.of(2, 3, 5, 7, 11, 13, 17, 19);
public int countPrimeSetBits(int left, int right) {
int ans = 0;
for (int x = left; x <= right; x++) {
if (primes.contains(Integer.bitCount(x))) {
ans++;
}
}
return ans;
}
}
class Solution {
// 注:也可以用哈希集合做,由于本题质数很少,用数组也可以
static constexpr int primes[] = {2, 3, 5, 7, 11, 13, 17, 19};
public:
int countPrimeSetBits(int left, int right) {
int ans = 0;
for (uint32_t x = left; x <= right; x++) {
if (ranges::contains(primes, popcount(x))) {
ans++;
}
}
return ans;
}
};
// 注:也可以用哈希集合做,由于本题质数很少,用 slice 也可以
var primes = []int{2, 3, 5, 7, 11, 13, 17, 19}
func countPrimeSetBits(left, right int) (ans int) {
for x := left; x <= right; x++ {
if slices.Contains(primes, bits.OnesCount(uint(x))) {
ans++
}
}
return
}
数位 DP v2.0 模板讲解(上下界数位 DP)
对于本题,在递归边界($i=n$)我们需要判断是否填了质数个 $1$,所以需要参数 $\textit{cnt}_1$ 表示填过的 $1$ 的个数。其余同 v2.0 模板。
primes = {2, 3, 5, 7, 11, 13, 17, 19}
class Solution:
def countPrimeSetBits(self, left: int, right: int) -> int:
high_s = list(map(int, bin(right)[2:])) # 避免在 dfs 中频繁调用 int()
n = len(high_s)
low_s = list(map(int, bin(left)[2:].zfill(n))) # 添加前导零,长度和 high_s 对齐
# 在 dfs 的过程中,统计二进制中的 1 的个数 cnt1
@cache # 缓存装饰器,避免重复计算 dfs(一行代码实现记忆化)
def dfs(i: int, cnt1: int, limit_low: bool, limit_high: bool) -> int:
if i == n:
return 1 if cnt1 in primes else 0
lo = low_s[i] if limit_low else 0
hi = high_s[i] if limit_high else 1
res = 0
for d in range(lo, hi + 1):
res += dfs(i + 1, cnt1 + d, limit_low and d == lo, limit_high and d == hi)
return res
return dfs(0, 0, True, True)
class Solution {
private static final Set<Integer> primes = Set.of(2, 3, 5, 7, 11, 13, 17, 19);
public int countPrimeSetBits(int left, int right) {
int n = 32 - Integer.numberOfLeadingZeros(right);
int[][] memo = new int[n][n + 1];
for (int[] row : memo) {
Arrays.fill(row, -1);
}
return dfs(n - 1, 0, true, true, left, right, memo);
}
// 在 dfs 的过程中,统计二进制中的 1 的个数 cnt1
private int dfs(int i, int cnt1, boolean limitLow, boolean limitHigh, int left, int right, int[][] memo) {
if (i < 0) {
return primes.contains(cnt1) ? 1 : 0;
}
if (!limitLow && !limitHigh && memo[i][cnt1] != -1) {
return memo[i][cnt1];
}
int lo = limitLow ? left >> i & 1 : 0;
int hi = limitHigh ? right >> i & 1 : 1;
int res = 0;
for (int d = lo; d <= hi; d++) {
res += dfs(i - 1, cnt1 + d, limitLow && d == lo, limitHigh && d == hi, left, right, memo);
}
if (!limitLow && !limitHigh) {
memo[i][cnt1] = res;
}
return res;
}
}
class Solution {
// 注:也可以用哈希集合做,由于本题质数很少,用数组也可以
static constexpr int primes[] = {2, 3, 5, 7, 11, 13, 17, 19};
public:
int countPrimeSetBits(int left, int right) {
int n = bit_width((uint32_t) right);
vector memo(n, vector<int>(n + 1, -1));
// 在 dfs 的过程中,统计二进制中的 1 的个数 cnt1
auto dfs = [&](this auto&& dfs, int i, int cnt1, bool limit_low, bool limit_high) -> int {
if (i < 0) {
return ranges::contains(primes, cnt1);
}
if (!limit_low && !limit_high && memo[i][cnt1] != -1) {
return memo[i][cnt1];
}
int lo = limit_low ? left >> i & 1 : 0;
int hi = limit_high ? right >> i & 1 : 1;
int res = 0;
for (int d = lo; d <= hi; d++) {
res += dfs(i - 1, cnt1 + d, limit_low && d == lo, limit_high && d == hi);
}
if (!limit_low && !limit_high) {
memo[i][cnt1] = res;
}
return res;
};
return dfs(n - 1, 0, true, true);
}
};
// 注:也可以用哈希集合做,由于本题质数很少,用数组也可以
var primes = []int{2, 3, 5, 7, 11, 13, 17, 19}
func countPrimeSetBits(left int, right int) int {
n := bits.Len(uint(right))
memo := make([][]int, n)
for i := range memo {
memo[i] = make([]int, n+1)
for j := range memo[i] {
memo[i][j] = -1
}
}
// 在 dfs 的过程中,统计二进制中的 1 的个数 cnt1
var dfs func(int, int, bool, bool) int
dfs = func(i, cnt1 int, limitLow, limitHigh bool) (res int) {
if i < 0 {
if slices.Contains(primes, cnt1) {
return 1
}
return 0
}
if !limitLow && !limitHigh {
p := &memo[i][cnt1]
if *p >= 0 {
return *p
}
defer func() { *p = res }()
}
lo := 0
if limitLow {
lo = left >> i & 1
}
hi := 1
if limitHigh {
hi = right >> i & 1
}
for d := lo; d <= hi; d++ {
res += dfs(i-1, cnt1+d, limitLow && d == lo, limitHigh && d == hi)
}
return
}
return dfs(n-1, 0, true, true)
}
primes = [2, 3, 5, 7, 11, 13, 17, 19]
class Solution:
def calc(self, high: int) -> int:
# 转换成计算 < high + 1 的合法正整数个数
# 这样转换可以方便下面的代码把 high 也算进来
high += 1
res = ones = 0
for i in range(high.bit_length() - 1, -1, -1):
if high >> i & 1 == 0:
continue
# 如果这一位填 0,那么后面可以随便填
# 问题变成在 i 个位置中填 k 个 1 的方案数,满足 ones + k 是质数
for p in primes:
k = p - ones # 剩余需要填的 1 的个数
if k > i:
break
if k >= 0:
res += comb(i, k)
# 这一位填 1,继续计算
ones += 1
return res
def countPrimeSetBits(self, left: int, right: int) -> int:
return self.calc(right) - self.calc(left - 1)
MX = 20
comb = [[0] * MX for _ in range(MX)]
for i in range(MX):
comb[i][0] = 1
for j in range(1, i + 1):
comb[i][j] = comb[i - 1][j - 1] + comb[i - 1][j]
primes = [2, 3, 5, 7, 11, 13, 17, 19]
class Solution:
def calc(self, high: int) -> int:
# 转换成计算 < high + 1 的合法正整数个数
# 这样转换可以方便下面的代码把 high 也算进来
high += 1
res = ones = 0
for i in range(high.bit_length() - 1, -1, -1):
if high >> i & 1 == 0:
continue
# 如果这一位填 0,那么后面可以随便填
# 问题变成在 i 个位置中填 k 个 1 的方案数,满足 ones + k 是质数
for p in primes:
k = p - ones # 剩余需要填的 1 的个数
if k > i:
break
if k >= 0:
res += comb[i][k]
# 这一位填 1,继续计算
ones += 1
return res
def countPrimeSetBits(self, left: int, right: int) -> int:
return self.calc(right) - self.calc(left - 1)
class Solution {
private static final int MX = 20;
private static final int[][] comb = new int[MX][MX];
private static final int[] primes = {2, 3, 5, 7, 11, 13, 17, 19};
private static boolean initialized = false;
// 这样写比 static block 快
public Solution() {
if (initialized) {
return;
}
initialized = true;
// 预处理组合数
for (int i = 0; i < MX; i++) {
comb[i][0] = 1;
for (int j = 1; j <= i; j++) {
comb[i][j] = comb[i - 1][j - 1] + comb[i - 1][j];
}
}
}
public int countPrimeSetBits(int left, int right) {
return calc(right) - calc(left - 1);
}
private int calc(int high) {
// 转换成计算 < high + 1 的合法正整数个数
// 这样转换可以方便下面的代码把 high 也算进来
high++;
int res = 0;
int ones = 0;
for (int i = 31 - Integer.numberOfLeadingZeros(high); i >= 0; i--) {
if ((high >> i & 1) == 0) {
continue;
}
// 如果这一位填 0,那么后面可以随便填
// 问题变成在 pos 个位置中填 k 个 1 的方案数,满足 ones + k 是质数
for (int p : primes) {
int k = p - ones; // 剩余需要填的 1 的个数
if (k > i) {
break;
}
if (k >= 0) {
res += comb[i][k];
}
}
ones++; // 这一位填 1,继续计算
}
return res;
}
}
constexpr int MX = 20;
int comb[MX][MX];
auto init = [] {
// 预处理组合数
for (int i = 0; i < MX; i++) {
comb[i][0] = 1;
for (int j = 1; j <= i; j++) {
comb[i][j] = comb[i - 1][j - 1] + comb[i - 1][j];
}
}
return 0;
}();
class Solution {
static constexpr int primes[] = {2, 3, 5, 7, 11, 13, 17, 19};
int calc(int high) {
// 转换成计算 < high + 1 的合法正整数个数
// 这样转换可以方便下面的代码把 high 也算进来
high++;
int res = 0, ones = 0;
for (int i = bit_width((uint32_t) high) - 1; i >= 0; i--) {
if ((high >> i & 1) == 0) {
continue;
}
// 如果这一位填 0,那么后面可以随便填
// 问题变成在 i 个位置中填 k 个 1 的方案数,满足 ones + k 是质数
for (int p : primes) {
int k = p - ones; // 剩余需要填的 1 的个数
if (k > i) {
break;
}
if (k >= 0) {
res += comb[i][k];
}
}
ones++; // 这一位填 1,继续计算
}
return res;
}
public:
int countPrimeSetBits(int left, int right) {
return calc(right) - calc(left - 1);
}
};
const mx = 20
var comb [mx][mx]int
var primes = []int{2, 3, 5, 7, 11, 13, 17, 19}
func init() {
// 预处理组合数
for i := range comb {
comb[i][0] = 1
for j := 1; j <= i; j++ {
comb[i][j] = comb[i-1][j-1] + comb[i-1][j]
}
}
}
func calc(high int) (res int) {
// 转换成计算 < high + 1 的合法正整数个数
// 这样转换可以方便下面的代码把 high 也算进来
high++
ones := 0
for i := bits.Len(uint(high)) - 1; i >= 0; i-- {
if high>>i&1 == 0 {
continue
}
// 如果这一位填 0,那么后面可以随便填
// 问题变成在 i 个位置中填 k 个 1 的方案数,满足 ones + k 是质数
for _, p := range primes {
k := p - ones // 剩余需要填的 1 的个数
if k > i {
break
}
if k >= 0 {
res += comb[i][k]
}
}
// 这一位填 1,继续计算
ones++
}
return res
}
func countPrimeSetBits(left, right int) int {
return calc(right) - calc(left-1)
}
不计入预处理的时间和空间。
题意:子串必须形如 $\underbrace{\texttt{0}\cdots \texttt{0}}{k\ 个\ \texttt{0}}\underbrace{\texttt{1}\cdots \texttt{1}}{k\ 个\ \texttt{1}}$ 或者 $\underbrace{\texttt{1}\cdots \texttt{1}}{k\ 个\ \texttt{1}}\underbrace{\texttt{0}\cdots \texttt{0}}{k\ 个\ \texttt{0}}$。只能有一段 $\texttt{0}$ 和一段 $\texttt{1}$,不能是 $\texttt{00111}$(两段长度不等)或者 $\texttt{010}$(超过两段)等。
例如 $s = \texttt{001110000}$,按照连续相同字符,分成三组 $\texttt{00},\texttt{111},\texttt{0000}$。
一般地,遍历 $s$,按照连续相同字符分组,计算每一组的长度。设当前这组的长度为 $\textit{cur}$,上一组的长度为 $\textit{pre}$,那么当前这组和上一组,能得到 $\min(\textit{pre},\textit{cur})$ 个合法子串,加到答案中。
###py
class Solution:
def countBinarySubstrings(self, s: str) -> int:
n = len(s)
pre = cur = ans = 0
for i in range(n):
cur += 1
if i == n - 1 or s[i] != s[i + 1]:
# 遍历到了这一组的末尾
ans += min(pre, cur)
pre = cur
cur = 0
return ans
###java
class Solution {
public int countBinarySubstrings(String S) {
char[] s = S.toCharArray();
int n = s.length;
int pre = 0;
int cur = 0;
int ans = 0;
for (int i = 0; i < n; i++) {
cur++;
if (i == n - 1 || s[i] != s[i + 1]) {
// 遍历到了这一组的末尾
ans += Math.min(pre, cur);
pre = cur;
cur = 0;
}
}
return ans;
}
}
###cpp
class Solution {
public:
int countBinarySubstrings(string s) {
int n = s.size();
int pre = 0, cur = 0, ans = 0;
for (int i = 0; i < n; i++) {
cur++;
if (i == n - 1 || s[i] != s[i + 1]) {
// 遍历到了这一组的末尾
ans += min(pre, cur);
pre = cur;
cur = 0;
}
}
return ans;
}
};
###c
#define MIN(a, b) ((b) < (a) ? (b) : (a))
int countBinarySubstrings(char* s) {
int pre = 0, cur = 0, ans = 0;
for (int i = 0; s[i]; i++) {
cur++;
if (s[i] != s[i + 1]) {
// 遍历到了这一组的末尾
ans += MIN(pre, cur);
pre = cur;
cur = 0;
}
}
return ans;
}
###go
func countBinarySubstrings(s string) (ans int) {
n := len(s)
pre, cur := 0, 0
for i := range n {
cur++
if i == n-1 || s[i] != s[i+1] {
// 遍历到了这一组的末尾
ans += min(pre, cur)
pre = cur
cur = 0
}
}
return
}
###js
var countBinarySubstrings = function(s) {
const n = s.length;
let pre = 0, cur = 0, ans = 0;
for (let i = 0; i < n; i++) {
cur++;
if (i === n - 1 || s[i] !== s[i + 1]) {
// 遍历到了这一组的末尾
ans += Math.min(pre, cur);
pre = cur;
cur = 0;
}
}
return ans;
};
###rust
impl Solution {
pub fn count_binary_substrings(s: String) -> i32 {
let s = s.as_bytes();
let n = s.len();
let mut pre = 0;
let mut cur = 0;
let mut ans = 0;
for i in 0..n {
cur += 1;
if i == n - 1 || s[i] != s[i + 1] {
// 遍历到了这一组的末尾
ans += pre.min(cur);
pre = cur;
cur = 0;
}
}
ans
}
}
见下面双指针题单的「六、分组循环」。
欢迎关注 B站@灵茶山艾府
为了做到 $\mathcal{O}(1)$ 时间,我们需要快速判断所有相邻比特位是否都不同。
如何判断不同?用哪个位运算最合适?
用异或运算最合适。对于单个比特的异或,如果两个数不同,那么结果是 $1$;如果两个数相同,那么结果是 $0$。
如何对所有相邻比特位做异或运算?
例如 $n = 10101$,可以把 $n$ 右移一位,得到 $01010$,再与 $10101$ 做异或运算,计算的就是相邻比特位的异或值了。
如果异或结果全为 $1$,就说明所有相邻比特位都不同。
如何判断一个二进制数全为 $1$?
这相当于判断二进制数加一后,是否为 231. 2 的幂。
设 $x$ 为 (n >> 1) ^ n,如果 (x + 1) & x 等于 $0$,那么说明 $x$ 全为 $1$。
class Solution:
def hasAlternatingBits(self, n: int) -> bool:
x = (n >> 1) ^ n
return (x + 1) & x == 0
class Solution {
public boolean hasAlternatingBits(int n) {
int x = (n >> 1) ^ n;
return ((x + 1) & x) == 0;
}
}
class Solution {
public:
bool hasAlternatingBits(int n) {
uint32_t x = (n >> 1) ^ n;
return ((x + 1) & x) == 0;
}
};
bool hasAlternatingBits(int n) {
uint32_t x = (n >> 1) ^ n;
return ((x + 1) & x) == 0;
}
func hasAlternatingBits(n int) bool {
x := n>>1 ^ n
return (x+1)&x == 0
}
var hasAlternatingBits = function(n) {
const x = (n >> 1) ^ n;
return ((x + 1) & x) === 0;
};
impl Solution {
pub fn has_alternating_bits(n: i32) -> bool {
let x = (n >> 1) ^ n;
(x + 1) & x == 0
}
}
见下面位运算题单的「一、基础题」。
欢迎关注 B站@灵茶山艾府
枚举小时 $h=0,1,2,\ldots,11$ 以及分钟 $m=0,1,2,\ldots,59$。如果 $h$ 二进制中的 $1$ 的个数加上 $m$ 二进制中的 $1$ 的个数恰好等于 $\textit{turnedOn}$,那么把 $h:m$ 添加到答案中。
注意如果 $m$ 是个位数,需要添加一个前导零。
class Solution:
def readBinaryWatch(self, turnedOn: int) -> List[str]:
ans = []
for h in range(12):
for m in range(60):
if h.bit_count() + m.bit_count() == turnedOn:
ans.append(f"{h}:{m:02d}")
return ans
class Solution {
public List<String> readBinaryWatch(int turnedOn) {
List<String> ans = new ArrayList<>();
for (int h = 0; h < 12; h++) {
for (int m = 0; m < 60; m++) {
if (Integer.bitCount(h) + Integer.bitCount(m) == turnedOn) {
ans.add(String.format("%d:%02d", h, m));
}
}
}
return ans;
}
}
class Solution {
public:
vector<string> readBinaryWatch(int turnedOn) {
vector<string> ans;
char s[6];
for (uint8_t h = 0; h < 12; h++) {
for (uint8_t m = 0; m < 60; m++) {
if (popcount(h) + popcount(m) == turnedOn) {
sprintf(s, "%d:%02d", h, m);
ans.emplace_back(s);
}
}
}
return ans;
}
};
func readBinaryWatch(turnedOn int) (ans []string) {
for h := range 12 {
for m := range 60 {
if bits.OnesCount8(uint8(h))+bits.OnesCount8(uint8(m)) == turnedOn {
ans = append(ans, fmt.Sprintf("%d:%02d", h, m))
}
}
}
return
}
欢迎关注 B站@灵茶山艾府
以反转一个 $8$ 位整数为例。
为方便阅读,我把这个数字记作 $12345678$。目标是得到 $87654321$。
用分治思考,反转 $12345678$ 可以分成如下三步:
反转 $1234$ 可以拆分为反转 $12$ 和 $34$,反转 $5678$ 可以拆分为反转 $56$ 和 $78$。
对于 $12$ 这种长为 $2$ 的情况,交换 $1$ 和 $2$ 即可完成反转。
![]()
你可能会问:这样做,算法能更快吗?
利用位运算「并行计算」的特点,我们可以高效地实现上述过程。
去掉递归的「递」,直接看「归」的过程(自底向上)。
递归的最底层是反转 $12$,反转 $34$,反转 $56$,反转 $78$。利用位运算,这些反转可以同时完成:
$$
\begin{array}{c}
\text{12345678} \
\left\downarrow \rule{0pt}{1.5em} \right. \rlap{\text{分离}} \
\text{1\phantom{2}3\phantom{4}5\phantom{6}7\phantom{8}} \
\text{\phantom{1}2\phantom{3}4\phantom{5}6\phantom{7}8} \
\left\downarrow \rule{0pt}{1.5em} \right. \rlap{\text{移位}} \
\text{\phantom{2}1\phantom{2}3\phantom{4}5\phantom{6}7} \
\text{2\phantom{3}4\phantom{5}6\phantom{7}8\phantom{7}} \
\left\downarrow \rule{0pt}{1.5em} \right. \rlap{\text{合并}} \
\text{21436587} \
\end{array}
$$
然后两个两个交换:
$$
\begin{array}{c}
\text{21436587} \
\left\downarrow \rule{0pt}{1.5em} \right. \rlap{\text{分离}} \
\text{21\phantom{11}65\phantom{11}} \
\text{\phantom{11}43\phantom{11}87} \
\left\downarrow \rule{0pt}{1.5em} \right. \rlap{\text{移位}} \
\text{\phantom{11}21\phantom{11}65} \
\text{43\phantom{11}87\phantom{11}} \
\left\downarrow \rule{0pt}{1.5em} \right. \rlap{\text{合并}} \
\text{43218765} \
\end{array}
$$
然后四个四个交换:
$$
\begin{array}{c}
\text{43218765} \
\left\downarrow \rule{0pt}{1.5em} \right. \rlap{\text{分离}} \
\text{4321\phantom{1111}} \
\text{\phantom{1111}8765} \
\left\downarrow \rule{0pt}{1.5em} \right. \rlap{\text{移位}} \
\text{\phantom{1111}4321} \
\text{8765\phantom{1111}} \
\left\downarrow \rule{0pt}{1.5em} \right. \rlap{\text{合并}} \
\text{87654321} \
\end{array}
$$
依此类推。
对于 $32$ 位整数,还需要执行八个八个交换,最后把高低 $16$ 位交换。
m0 = 0x55555555 # 01010101 ...
m1 = 0x33333333 # 00110011 ...
m2 = 0x0f0f0f0f # 00001111 ...
m3 = 0x00ff00ff # 00000000111111110000000011111111
m4 = 0x0000ffff # 00000000000000001111111111111111
class Solution:
def reverseBits(self, n: int) -> int:
n = n>>1&m0 | (n&m0)<<1 # 交换相邻位
n = n>>2&m1 | (n&m1)<<2 # 两个两个交换
n = n>>4&m2 | (n&m2)<<4 # 四个四个交换
n = n>>8&m3 | (n&m3)<<8 # 八个八个交换
return n>>16 | (n&m4)<<16 # 交换高低 16 位
class Solution {
private static final int m0 = 0x55555555; // 01010101 ...
private static final int m1 = 0x33333333; // 00110011 ...
private static final int m2 = 0x0f0f0f0f; // 00001111 ...
private static final int m3 = 0x00ff00ff; // 00000000111111110000000011111111
public int reverseBits(int n) {
n = n>>>1&m0 | (n&m0)<<1; // 交换相邻位
n = n>>>2&m1 | (n&m1)<<2; // 两个两个交换
n = n>>>4&m2 | (n&m2)<<4; // 四个四个交换
n = n>>>8&m3 | (n&m3)<<8; // 八个八个交换
return n>>>16 | n<<16; // 交换高低 16 位
}
}
class Solution {
static constexpr uint32_t m0 = 0x55555555; // 01010101 ...
static constexpr uint32_t m1 = 0x33333333; // 00110011 ...
static constexpr uint32_t m2 = 0x0f0f0f0f; // 00001111 ...
static constexpr uint32_t m3 = 0x00ff00ff; // 00000000111111110000000011111111
uint32_t reverseBits32(uint32_t n) {
n = n>>1&m0 | (n&m0)<<1; // 交换相邻位
n = n>>2&m1 | (n&m1)<<2; // 两个两个交换
n = n>>4&m2 | (n&m2)<<4; // 四个四个交换
n = n>>8&m3 | (n&m3)<<8; // 八个八个交换
return n>>16 | n<<16; // 交换高低 16 位
}
public:
int reverseBits(int n) {
return reverseBits32(n);
}
};
const m0 = 0x55555555 // 01010101 ...
const m1 = 0x33333333 // 00110011 ...
const m2 = 0x0f0f0f0f // 00001111 ...
const m3 = 0x00ff00ff // 00000000111111110000000011111111
const m4 = 0x0000ffff // 00000000000000001111111111111111
func reverseBits(n int) int {
n = n>>1&m0 | n&m0<<1 // 交换相邻位
n = n>>2&m1 | n&m1<<2 // 两个两个交换
n = n>>4&m2 | n&m2<<4 // 四个四个交换
n = n>>8&m3 | n&m3<<8 // 八个八个交换
return n>>16 | n&m4<<16 // 交换高低 16 位
}
class Solution:
def reverseBits(self, n: int) -> int:
# 没有 O(1) 的库函数,只能用字符串转换代替
# 032b 中的 b 表示转成二进制串,032 表示补前导零到长度等于 32
return int(f'{n:032b}'[::-1], 2)
class Solution:
def reverseBits(self, n: int) -> int:
# 没有 O(1) 的库函数,只能用字符串转换代替
return int(bin(n)[2:].zfill(32)[::-1], 2)
class Solution {
public int reverseBits(int n) {
return Integer.reverse(n);
}
}
class Solution {
public:
int reverseBits(int n) {
return __builtin_bitreverse32(n);
}
};
func reverseBits(n int) int {
return int(bits.Reverse32(uint32(n)))
}
欢迎关注 B站@灵茶山艾府
二进制的加法怎么算?
和十进制的加法一样,从低到高(从右往左)计算。
示例 1 是 $11+1$,计算过程如下:
由此可见,需要在计算过程中维护进位值 $\textit{carry}$,每次计算
$$
\textit{sum} = a\ 这一位的值 + b\ 这一位的值 + \textit{carry}
$$
然后答案这一位填 $\textit{sum}\bmod 2$,进位更新为 $\left\lfloor\dfrac{\textit{sum}}{2}\right\rfloor$。例如 $\textit{sum} = 10$,那么答案这一位填 $0$,进位更新为 $1$。
把计算结果插在答案的末尾,最后把答案反转。
###py
class Solution:
def addBinary(self, a: str, b: str) -> str:
ans = []
i = len(a) - 1 # 从右往左遍历 a 和 b
j = len(b) - 1
carry = 0 # 保存进位
while i >= 0 or j >= 0 or carry:
x = int(a[i]) if i >= 0 else 0
y = int(b[j]) if j >= 0 else 0
s = x + y + carry # 计算这一位的加法
# 例如 s = 10,把 '0' 填入答案,把 carry 置为 1
ans.append(str(s % 2))
carry = s // 2
i -= 1
j -= 1
return ''.join(reversed(ans))
###java
class Solution {
public String addBinary(String a, String b) {
StringBuilder ans = new StringBuilder();
int i = a.length() - 1; // 从右往左遍历 a 和 b
int j = b.length() - 1;
int carry = 0; // 保存进位
while (i >= 0 || j >= 0 || carry > 0) {
int x = i >= 0 ? a.charAt(i) - '0' : 0;
int y = j >= 0 ? b.charAt(j) - '0' : 0;
int sum = x + y + carry; // 计算这一位的加法
// 例如 sum = 10,把 '0' 填入答案,把 carry 置为 1
ans.append(sum % 2);
carry = sum / 2;
i--;
j--;
}
return ans.reverse().toString();
}
}
###cpp
class Solution {
public:
string addBinary(string a, string b) {
string ans;
int i = a.size() - 1; // 从右往左遍历 a 和 b
int j = b.size() - 1;
int carry = 0; // 保存进位
while (i >= 0 || j >= 0 || carry) {
int x = i >= 0 ? a[i] - '0' : 0;
int y = j >= 0 ? b[j] - '0' : 0;
int sum = x + y + carry; // 计算这一位的加法
// 例如 sum = 10,把 '0' 填入答案,把 carry 置为 1
ans += sum % 2 + '0';
carry = sum / 2;
i--;
j--;
}
ranges::reverse(ans);
return ans;
}
};
###c
#define MAX(a, b) ((b) > (a) ? (b) : (a))
char* addBinary(char* a, char* b) {
int n = strlen(a);
int m = strlen(b);
char* ans = malloc((MAX(n, m) + 2) * sizeof(char));
int k = 0;
int i = n - 1; // 从右往左遍历 a 和 b
int j = m - 1;
int carry = 0; // 保存进位
while (i >= 0 || j >= 0 || carry) {
int x = i >= 0 ? a[i] - '0' : 0;
int y = j >= 0 ? b[j] - '0' : 0;
int sum = x + y + carry; // 计算这一位的加法
// 例如 sum = 10,把 '0' 填入答案,把 carry 置为 1
ans[k++] = sum % 2 + '0';
carry = sum / 2;
i--;
j--;
}
// 反转 ans
for (int l = 0, r = k - 1; l < r; l++, r--) {
char tmp = ans[l];
ans[l] = ans[r];
ans[r] = tmp;
}
ans[k] = '\0';
return ans;
}
###go
func addBinary(a, b string) string {
ans := []byte{}
i := len(a) - 1 // 从右往左遍历 a 和 b
j := len(b) - 1
carry := byte(0) // 保存进位
for i >= 0 || j >= 0 || carry > 0 {
// 计算这一位的加法
sum := carry
if i >= 0 {
sum += a[i] - '0'
}
if j >= 0 {
sum += b[j] - '0'
}
// 例如 sum = 10,把 '0' 填入答案,把 carry 置为 1
ans = append(ans, sum%2+'0')
carry = sum / 2
i--
j--
}
slices.Reverse(ans)
return string(ans)
}
###js
var addBinary = function(a, b) {
const ans = [];
let i = a.length - 1; // 从右往左遍历 a 和 b
let j = b.length - 1;
let carry = 0; // 保存进位
while (i >= 0 || j >= 0 || carry) {
const x = i >= 0 ? Number(a[i]) : 0;
const y = j >= 0 ? Number(b[j]) : 0;
const sum = x + y + carry; // 计算这一位的加法
// 例如 sum = 10,把 '0' 填入答案,把 carry 置为 1
ans.push(String(sum % 2));
carry = Math.floor(sum / 2);
i--;
j--;
}
return ans.reverse().join('');
};
###rust
impl Solution {
pub fn add_binary(a: String, b: String) -> String {
let a = a.as_bytes();
let b = b.as_bytes();
let mut ans = vec![];
let mut i = a.len() as isize - 1; // 从右往左遍历 a 和 b
let mut j = b.len() as isize - 1;
let mut carry = 0; // 保存进位
while i >= 0 || j >= 0 || carry > 0 {
let x = if i >= 0 { a[i as usize] - b'0' } else { 0 };
let y = if j >= 0 { b[j as usize] - b'0' } else { 0 };
let sum = x + y + carry; // 计算这一位的加法
// 例如 sum = 10,把 '0' 填入答案,把 carry 置为 1
ans.push(sum % 2 + b'0');
carry = sum / 2;
i -= 1;
j -= 1;
}
ans.reverse();
unsafe { String::from_utf8_unchecked(ans) }
}
}
直接填入答案,不反转。
###py
class Solution:
def addBinary(self, a: str, b: str) -> str:
# 保证 len(a) >= len(b),简化后续代码逻辑
if len(a) < len(b):
a, b = b, a
n, m = len(a), len(b)
ans = [0] * (n + 1)
carry = 0 # 保存进位
for i in range(n - 1, -1, -1):
j = m - (n - i)
y = int(b[j]) if j >= 0 else 0
s = int(a[i]) + y + carry
ans[i + 1] = str(s % 2)
carry = s // 2
ans[0] = str(carry)
return ''.join(ans[carry ^ 1:]) # 如果 carry == 0 则去掉 ans[0]
###java
class Solution {
public String addBinary(String a, String b) {
// 保证 a.length() >= b.length(),简化后续代码逻辑
if (a.length() < b.length()) {
return addBinary(b, a);
}
int n = a.length();
int m = b.length();
char[] ans = new char[n + 1];
int carry = 0; // 保存进位
for (int i = n - 1, j = m - 1; i >= 0; i--, j--) {
int x = a.charAt(i) - '0';
int y = j >= 0 ? b.charAt(j) - '0' : 0;
int sum = x + y + carry;
ans[i + 1] = (char) (sum % 2 + '0');
carry = sum / 2;
}
ans[0] = (char) (carry + '0');
// 如果 carry == 0 则去掉 ans[0]
return new String(ans, carry ^ 1, n + carry);
}
}
###cpp
class Solution {
public:
string addBinary(string a, string b) {
// 保证 a.size() >= b.size(),简化后续代码逻辑
if (a.size() < b.size()) {
swap(a, b);
}
int n = a.size(), m = b.size();
string ans(n + 1, 0);
int carry = 0; // 保存进位
for (int i = n - 1, j = m - 1; i >= 0; i--, j--) {
int x = a[i] - '0';
int y = j >= 0 ? b[j] - '0' : 0;
int sum = x + y + carry;
ans[i + 1] = sum % 2 + '0';
carry = sum / 2;
}
if (carry) {
ans[0] = '1';
} else {
ans.erase(ans.begin());
}
return ans;
}
};
###c
#define MAX(a, b) ((b) > (a) ? (b) : (a))
char* addBinary(char* a, char* b) {
int n = strlen(a);
int m = strlen(b);
char* ans = malloc((MAX(n, m) + 2) * sizeof(char));
ans[MAX(n, m) + 1] = '\0';
int carry = 0; // 保存进位
for (int i = n - 1, j = m - 1; i >= 0 || j >= 0; i--, j--) {
int x = i >= 0 ? a[i] - '0' : 0;
int y = j >= 0 ? b[j] - '0' : 0;
int sum = x + y + carry;
ans[MAX(i, j) + 1] = sum % 2 + '0';
carry = sum / 2;
}
ans[0] = carry + '0';
// 如果 carry == 0 则去掉 ans[0]
return ans + (carry ^ 1);
}
###go
func addBinary(a, b string) string {
// 保证 len(a) >= len(b),简化后续代码逻辑
if len(a) < len(b) {
a, b = b, a
}
n, m := len(a), len(b)
ans := make([]byte, n+1)
carry := byte(0) // 保存进位
for i := n - 1; i >= 0; i-- {
sum := a[i] - '0' + carry
if j := m - (n - i); j >= 0 {
sum += b[j] - '0'
}
ans[i+1] = sum%2 + '0'
carry = sum / 2
}
ans[0] = carry + '0'
// 如果 carry == 0 则去掉 ans[0]
return string(ans[carry^1:])
}
###js
var addBinary = function(a, b) {
// 保证 a.length >= b.length,简化后续代码逻辑
if (a.length < b.length) {
[a, b] = [b, a];
}
const n = a.length;
const m = b.length;
const ans = Array(n + 1);
let carry = 0; // 保存进位
for (let i = n - 1, j = m - 1; i >= 0; i--, j--) {
const x = Number(a[i]);
const y = j >= 0 ? Number(b[j]) : 0;
const sum = x + y + carry;
ans[i + 1] = String(sum % 2);
carry = Math.floor(sum / 2);
}
if (carry) {
ans[0] = '1';
} else {
ans.shift();
}
return ans.join('');
};
###rust
impl Solution {
pub fn add_binary(a: String, b: String) -> String {
// 保证 a.len() >= b.len(),简化后续代码逻辑
if a.len() < b.len() {
return Self::add_binary(b, a);
}
let a = a.as_bytes();
let b = b.as_bytes();
let n = a.len();
let m = b.len();
let mut ans = vec![0; n + 1];
let mut carry = 0; // 保存进位
for i in (0..n).rev() {
let x = a[i] - b'0';
let y = if n - i <= m { b[m - (n - i)] - b'0' } else { 0 };
let sum = x + y + carry;
ans[i + 1] = sum % 2 + b'0';
carry = sum / 2;
}
if carry > 0 {
ans[0] = b'1';
} else {
ans.remove(0);
}
unsafe { String::from_utf8_unchecked(ans) }
}
}
欢迎关注 B站@灵茶山艾府
类似 102. 二叉树的层序遍历,用一个 BFS 模拟香槟溢出流程:第一层溢出的香槟流到第二层,第二层溢出的香槟流到第三层,依此类推。
具体地:
###py
class Solution:
def champagneTower(self, poured: int, queryRow: int, queryGlass: int) -> float:
cur = [float(poured)]
for i in range(1, queryRow + 1):
nxt = [0.0] * (i + 1)
for j, x in enumerate(cur):
if x > 1: # 溢出到下一层
nxt[j] += (x - 1) / 2
nxt[j + 1] += (x - 1) / 2
cur = nxt
return min(cur[queryGlass], 1.0) # 如果溢出,容量是 1
###java
class Solution {
public double champagneTower(int poured, int queryRow, int queryGlass) {
double[] cur = new double[]{(double) poured};
for (int i = 1; i <= queryRow; i++) {
double[] nxt = new double[i + 1];
for (int j = 0; j < cur.length; j++) {
double x = cur[j] - 1;
if (x > 0) { // 溢出到下一层
nxt[j] += x / 2;
nxt[j + 1] += x / 2;
}
}
cur = nxt;
}
return Math.min(cur[queryGlass], 1); // 如果溢出,容量是 1
}
}
###cpp
class Solution {
public:
double champagneTower(int poured, int queryRow, int queryGlass) {
vector<double> cur = {1.0 * poured};
for (int i = 1; i <= queryRow; i++) {
vector<double> nxt(i + 1);
for (int j = 0; j < cur.size(); j++) {
double x = cur[j] - 1;
if (x > 0) { // 溢出到下一层
nxt[j] += x / 2;
nxt[j + 1] += x / 2;
}
}
cur = move(nxt);
}
return min(cur[queryGlass], 1.0); // 如果溢出,容量是 1
}
};
###c
#define MIN(a, b) ((b) < (a) ? (b) : (a))
double champagneTower(int poured, int queryRow, int queryGlass) {
double* cur = malloc(sizeof(double));
cur[0] = poured;
int curSize = 1;
for (int i = 1; i <= queryRow; i++) {
double* nxt = calloc(i + 1, sizeof(double));
for (int j = 0; j < curSize; j++) {
double x = cur[j] - 1;
if (x > 0) { // 溢出到下一层
nxt[j] += x / 2;
nxt[j + 1] += x / 2;
}
}
free(cur);
cur = nxt;
curSize = i + 1;
}
double ans = MIN(cur[queryGlass], 1); // 如果溢出,容量是 1
free(cur);
return ans;
}
###go
func champagneTower(poured, queryRow, queryGlass int) float64 {
cur := []float64{float64(poured)}
for i := 1; i <= queryRow; i++ {
nxt := make([]float64, i+1)
for j, x := range cur {
if x > 1 { // 溢出到下一层
nxt[j] += (x - 1) / 2
nxt[j+1] += (x - 1) / 2
}
}
cur = nxt
}
return min(cur[queryGlass], 1) // 如果溢出,容量是 1
}
###js
var champagneTower = function(poured, queryRow, queryGlass) {
let cur = [poured];
for (let i = 1; i <= queryRow; i++) {
const nxt = Array(i + 1).fill(0);
for (let j = 0; j < cur.length; j++) {
const x = cur[j] - 1;
if (x > 0) { // 溢出到下一层
nxt[j] += x / 2;
nxt[j + 1] += x / 2;
}
}
cur = nxt;
}
return Math.min(cur[queryGlass], 1); // 如果溢出,容量是 1
};
###rust
impl Solution {
pub fn champagne_tower(poured: i32, query_row: i32, query_glass: i32) -> f64 {
let mut cur = vec![poured as f64];
for i in 1..=query_row as usize {
let mut nxt = vec![0.0; i + 1];
for (j, x) in cur.into_iter().enumerate() {
if x > 1.0 { // 溢出到下一层
nxt[j] += (x - 1.0) / 2.0;
nxt[j + 1] += (x - 1.0) / 2.0;
}
}
cur = nxt;
}
cur[query_glass as usize].min(1.0) // 如果溢出,容量是 1
}
}
无需使用两个数组,可以像 0-1 背包那样,在同一个数组上修改。
###py
class Solution:
def champagneTower(self, poured: int, queryRow: int, queryGlass: int) -> float:
f = [0.0] * (queryRow + 1)
f[0] = float(poured)
for i in range(queryRow):
for j in range(i, -1, -1):
x = f[j] - 1
if x > 0:
f[j + 1] += x / 2
f[j] = x / 2
else:
f[j] = 0.0
return min(f[queryGlass], 1.0) # 如果溢出,容量是 1
###java
class Solution {
public double champagneTower(int poured, int queryRow, int queryGlass) {
double[] f = new double[queryRow + 1];
f[0] = poured;
for (int i = 0; i < queryRow; i++) {
for (int j = i; j >= 0; j--) {
double x = f[j] - 1;
if (x > 0) {
f[j + 1] += x / 2;
f[j] = x / 2;
} else {
f[j] = 0;
}
}
}
return Math.min(f[queryGlass], 1); // 如果溢出,容量是 1
}
}
###cpp
class Solution {
public:
double champagneTower(int poured, int queryRow, int queryGlass) {
vector<double> f(queryRow + 1);
f[0] = poured;
for (int i = 0; i < queryRow; i++) {
for (int j = i; j >= 0; j--) {
double x = f[j] - 1;
if (x > 0) {
f[j + 1] += x / 2;
f[j] = x / 2;
} else {
f[j] = 0;
}
}
}
return min(f[queryGlass], 1.0); // 如果溢出,容量是 1
}
};
###c
#define MIN(a, b) ((b) < (a) ? (b) : (a))
double champagneTower(int poured, int queryRow, int queryGlass) {
double* f = calloc(queryRow + 1, sizeof(double));
f[0] = poured;
for (int i = 0; i < queryRow; i++) {
for (int j = i; j >= 0; j--) {
double x = f[j] - 1;
if (x > 0) {
f[j + 1] += x / 2;
f[j] = x / 2;
} else {
f[j] = 0;
}
}
}
double ans = MIN(f[queryGlass], 1); // 如果溢出,容量是 1
free(f);
return ans;
}
###go
func champagneTower(poured, queryRow, queryGlass int) float64 {
f := make([]float64, queryRow+1)
f[0] = float64(poured)
for i := range queryRow {
for j := i; j >= 0; j-- {
x := f[j] - 1
if x > 0 {
f[j+1] += x / 2
f[j] = x / 2
} else {
f[j] = 0
}
}
}
return min(f[queryGlass], 1) // 如果溢出,容量是 1
}
###js
var champagneTower = function(poured, queryRow, queryGlass) {
const f = Array(queryRow + 1).fill(0);
f[0] = poured;
for (let i = 0; i < queryRow; i++) {
for (let j = i; j >= 0; j--) {
const x = f[j] - 1;
if (x > 0) {
f[j + 1] += x / 2;
f[j] = x / 2;
} else {
f[j] = 0;
}
}
}
return Math.min(f[queryGlass], 1); // 如果溢出,容量是 1
};
###rust
impl Solution {
pub fn champagne_tower(poured: i32, query_row: i32, query_glass: i32) -> f64 {
let query_row = query_row as usize;
let mut f = vec![0.0; query_row + 1];
f[0] = poured as f64;
for i in 0..query_row {
for j in (0..=i).rev() {
let x = f[j] - 1.0;
if x > 0.0 {
f[j + 1] += x / 2.0;
f[j] = x / 2.0;
} else {
f[j] = 0.0;
}
}
}
f[query_glass as usize].min(1.0) // 如果溢出,容量是 1
}
}
欢迎关注 B站@灵茶山艾府
分成如下三类问题,依次解答:
即 $s$ 中的最长连续相同子串的长度。这题是 1446. 连续字符。
这可以用分组循环解决。
适用场景:按照题目要求,序列会被分割成若干组,每一组的判断/处理逻辑是相同的。
核心思想:
这个写法的好处是,各个逻辑块分工明确,也不需要特判最后一组(易错点)。以我的经验,这个写法是所有写法中最不容易出 bug 的,推荐大家记住。
同样地,用分组循环分组,每组只包含两种字母。
对于每一组,计算含有相同数量的两种字母的最长子串。这题是 525. 连续数组。
做法见 我的题解。
仿照 525 题的做法,设 $\texttt{a}$ 在这个组的个数前缀和数组为 $S_a$,$\texttt{b}$ 在这个组的个数前缀和数组为 $S_b$,$\texttt{c}$ 在这个组的个数前缀和数组为 $S_c$。
子串 $[l,r)$ 中的字母 $\texttt{a},\texttt{b},\texttt{c}$ 的出现次数相等,可以拆分为如下两个约束:
只要满足这两个约束,由等号的传递性可知,子串 $[l,r)$ 中的字母 $\texttt{a}$ 和 $\texttt{c}$ 的出现次数相等,即三个字母的出现次数都相等。
两个约束即如下两个等式
$$
\begin{aligned}
S_a[r] - S_b[r] &= S_a[l] - S_b[l] \
S_b[r] - S_c[r] &= S_b[l] - S_c[l] \
\end{aligned}
$$
定义数组 $a[i] = (S_a[i] - S_b[i], S_b[i] - S_c[i])$,问题变成:
做法同上。
本题视频讲解,欢迎点赞关注~
###py
class Solution:
def longestBalanced(self, s: str) -> int:
n = len(s)
# 一种字母
ans = i = 0
while i < n:
start = i
i += 1
while i < n and s[i] == s[i - 1]:
i += 1
ans = max(ans, i - start)
# 两种字母
def f(x: str, y: str) -> None:
nonlocal ans
i = 0
while i < n:
pos = {0: i - 1} # 前缀和数组的首项是 0,位置相当于在 i-1
d = 0 # x 的个数减去 y 的个数
while i < n and (s[i] == x or s[i] == y):
d += 1 if s[i] == x else -1
if d in pos:
ans = max(ans, i - pos[d])
else:
pos[d] = i
i += 1
i += 1
f('a', 'b')
f('a', 'c')
f('b', 'c')
# 三种字母
# 前缀和数组的首项是 0,位置相当于在 -1
pos = {(0, 0): -1}
cnt = defaultdict(int)
for i, b in enumerate(s):
cnt[b] += 1
p = (cnt['a'] - cnt['b'], cnt['b'] - cnt['c'])
if p in pos:
ans = max(ans, i - pos[p])
else:
pos[p] = i
return ans
###java
class Solution {
public int longestBalanced(String S) {
char[] s = S.toCharArray();
int n = s.length;
int ans = 0;
// 一种字母
for (int i = 0; i < n; ) {
int start = i;
for (i++; i < n && s[i] == s[i - 1]; i++) ;
ans = Math.max(ans, i - start);
}
// 两种字母
ans = Math.max(ans, f(s, 'a', 'b'));
ans = Math.max(ans, f(s, 'a', 'c'));
ans = Math.max(ans, f(s, 'b', 'c'));
// 三种字母
// 把 (x, y) 压缩成一个 long,方便保存至哈希表
// (x, y) 变成 (x + n) << 20 | (y + n),其中 +n 避免出现负数
Map<Long, Integer> pos = new HashMap<>();
pos.put((long) n << 20 | n, -1); // 前缀和数组的首项是 0,位置相当于在 -1
int[] cnt = new int[3];
for (int i = 0; i < n; i++) {
cnt[s[i] - 'a']++;
long p = (long) (cnt[0] - cnt[1] + n) << 20 | (cnt[1] - cnt[2] + n);
if (pos.containsKey(p)) {
ans = Math.max(ans, i - pos.get(p));
} else {
pos.put(p, i);
}
}
return ans;
}
private int f(char[] s, char x, char y) {
int n = s.length;
int ans = 0;
for (int i = 0; i < n; i++) {
Map<Integer, Integer> pos = new HashMap<>();
pos.put(0, i - 1); // 前缀和数组的首项是 0,位置相当于在 i-1
int d = 0; // x 的个数减去 y 的个数
for (; i < n && (s[i] == x || s[i] == y); i++) {
d += s[i] == x ? 1 : -1;
if (pos.containsKey(d)) {
ans = Math.max(ans, i - pos.get(d));
} else {
pos.put(d, i);
}
}
}
return ans;
}
}
###cpp
class Solution {
public:
int longestBalanced(string s) {
int n = s.size();
int ans = 0;
// 一种字母
for (int i = 0; i < n;) {
int start = i;
for (i++; i < n && s[i] == s[i - 1]; i++);
ans = max(ans, i - start);
}
// 两种字母
auto f = [&](char x, char y) -> void {
for (int i = 0; i < n; i++) {
unordered_map<int, int> pos = {{0, i - 1}}; // 前缀和数组的首项是 0,位置相当于在 i-1
int d = 0; // x 的个数减去 y 的个数
for (; i < n && (s[i] == x || s[i] == y); i++) {
d += s[i] == x ? 1 : -1;
if (pos.contains(d)) {
ans = max(ans, i - pos[d]);
} else {
pos[d] = i;
}
}
}
};
f('a', 'b');
f('a', 'c');
f('b', 'c');
// 三种字母
// 把 (x, y) 压缩成一个 long long,方便保存至哈希表
// (x, y) 变成 (x + n) << 32 | (y + n),其中 +n 避免出现负数
unordered_map<long long, int> pos = {{1LL * n << 32 | n, -1}}; // 前缀和数组的首项是 0,位置相当于在 -1
int cnt[3]{};
for (int i = 0; i < n; i++) {
cnt[s[i] - 'a']++;
long long p = 1LL * (cnt[0] - cnt[1] + n) << 32 | (cnt[1] - cnt[2] + n);
if (pos.contains(p)) {
ans = max(ans, i - pos[p]);
} else {
pos[p] = i;
}
}
return ans;
}
};
###go
func longestBalanced(s string) (ans int) {
n := len(s)
// 一种字母
for i := 0; i < n; {
start := i
for i++; i < n && s[i] == s[i-1]; i++ {
}
ans = max(ans, i-start)
}
// 两种字母
f := func(x, y byte) {
for i := 0; i < n; i++ {
pos := map[int]int{0: i - 1} // 前缀和数组的首项是 0,位置相当于在 i-1
d := 0 // x 的个数减去 y 的个数
for ; i < n && (s[i] == x || s[i] == y); i++ {
if s[i] == x {
d++
} else {
d--
}
if j, ok := pos[d]; ok {
ans = max(ans, i-j)
} else {
pos[d] = i
}
}
}
}
f('a', 'b')
f('a', 'c')
f('b', 'c')
// 三种字母
type pair struct{ diffAB, diffBC int }
pos := map[pair]int{{}: -1} // 前缀和数组的首项是 0,位置相当于在 -1
cnt := [3]int{}
for i, b := range s {
cnt[b-'a']++
p := pair{cnt[0] - cnt[1], cnt[1] - cnt[2]}
if j, ok := pos[p]; ok {
ans = max(ans, i-j)
} else {
pos[p] = i
}
}
return
}