每日一题-二进制求和🟢
给你两个二进制字符串 a 和 b ,以二进制字符串的形式返回它们的和。
示例 1:
输入:a = "11", b = "1" 输出:"100"
示例 2:
输入:a = "1010", b = "1011" 输出:"10101"
提示:
1 <= a.length, b.length <= 104-
a和b仅由字符'0'或'1'组成 - 字符串如果不是
"0",就不含前导零
给你两个二进制字符串 a 和 b ,以二进制字符串的形式返回它们的和。
示例 1:
输入:a = "11", b = "1" 输出:"100"
示例 2:
输入:a = "1010", b = "1011" 输出:"10101"
提示:
1 <= a.length, b.length <= 104a 和 b 仅由字符 '0' 或 '1' 组成"0" ,就不含前导零二进制的加法怎么算?
和十进制的加法一样,从低到高(从右往左)计算。
示例 1 是 $11+1$,计算过程如下:
由此可见,需要在计算过程中维护进位值 $\textit{carry}$,每次计算
$$
\textit{sum} = a\ 这一位的值 + b\ 这一位的值 + \textit{carry}
$$
然后答案这一位填 $\textit{sum}\bmod 2$,进位更新为 $\left\lfloor\dfrac{\textit{sum}}{2}\right\rfloor$。例如 $\textit{sum} = 10$,那么答案这一位填 $0$,进位更新为 $1$。
把计算结果插在答案的末尾,最后把答案反转。
###py
class Solution:
def addBinary(self, a: str, b: str) -> str:
ans = []
i = len(a) - 1 # 从右往左遍历 a 和 b
j = len(b) - 1
carry = 0 # 保存进位
while i >= 0 or j >= 0 or carry:
x = int(a[i]) if i >= 0 else 0
y = int(b[j]) if j >= 0 else 0
s = x + y + carry # 计算这一位的加法
# 例如 s = 10,把 '0' 填入答案,把 carry 置为 1
ans.append(str(s % 2))
carry = s // 2
i -= 1
j -= 1
return ''.join(reversed(ans))
###java
class Solution {
public String addBinary(String a, String b) {
StringBuilder ans = new StringBuilder();
int i = a.length() - 1; // 从右往左遍历 a 和 b
int j = b.length() - 1;
int carry = 0; // 保存进位
while (i >= 0 || j >= 0 || carry > 0) {
int x = i >= 0 ? a.charAt(i) - '0' : 0;
int y = j >= 0 ? b.charAt(j) - '0' : 0;
int sum = x + y + carry; // 计算这一位的加法
// 例如 sum = 10,把 '0' 填入答案,把 carry 置为 1
ans.append(sum % 2);
carry = sum / 2;
i--;
j--;
}
return ans.reverse().toString();
}
}
###cpp
class Solution {
public:
string addBinary(string a, string b) {
string ans;
int i = a.size() - 1; // 从右往左遍历 a 和 b
int j = b.size() - 1;
int carry = 0; // 保存进位
while (i >= 0 || j >= 0 || carry) {
int x = i >= 0 ? a[i] - '0' : 0;
int y = j >= 0 ? b[j] - '0' : 0;
int sum = x + y + carry; // 计算这一位的加法
// 例如 sum = 10,把 '0' 填入答案,把 carry 置为 1
ans += sum % 2 + '0';
carry = sum / 2;
i--;
j--;
}
ranges::reverse(ans);
return ans;
}
};
###c
#define MAX(a, b) ((b) > (a) ? (b) : (a))
char* addBinary(char* a, char* b) {
int n = strlen(a);
int m = strlen(b);
char* ans = malloc((MAX(n, m) + 2) * sizeof(char));
int k = 0;
int i = n - 1; // 从右往左遍历 a 和 b
int j = m - 1;
int carry = 0; // 保存进位
while (i >= 0 || j >= 0 || carry) {
int x = i >= 0 ? a[i] - '0' : 0;
int y = j >= 0 ? b[j] - '0' : 0;
int sum = x + y + carry; // 计算这一位的加法
// 例如 sum = 10,把 '0' 填入答案,把 carry 置为 1
ans[k++] = sum % 2 + '0';
carry = sum / 2;
i--;
j--;
}
// 反转 ans
for (int l = 0, r = k - 1; l < r; l++, r--) {
char tmp = ans[l];
ans[l] = ans[r];
ans[r] = tmp;
}
ans[k] = '\0';
return ans;
}
###go
func addBinary(a, b string) string {
ans := []byte{}
i := len(a) - 1 // 从右往左遍历 a 和 b
j := len(b) - 1
carry := byte(0) // 保存进位
for i >= 0 || j >= 0 || carry > 0 {
// 计算这一位的加法
sum := carry
if i >= 0 {
sum += a[i] - '0'
}
if j >= 0 {
sum += b[j] - '0'
}
// 例如 sum = 10,把 '0' 填入答案,把 carry 置为 1
ans = append(ans, sum%2+'0')
carry = sum / 2
i--
j--
}
slices.Reverse(ans)
return string(ans)
}
###js
var addBinary = function(a, b) {
const ans = [];
let i = a.length - 1; // 从右往左遍历 a 和 b
let j = b.length - 1;
let carry = 0; // 保存进位
while (i >= 0 || j >= 0 || carry) {
const x = i >= 0 ? Number(a[i]) : 0;
const y = j >= 0 ? Number(b[j]) : 0;
const sum = x + y + carry; // 计算这一位的加法
// 例如 sum = 10,把 '0' 填入答案,把 carry 置为 1
ans.push(String(sum % 2));
carry = Math.floor(sum / 2);
i--;
j--;
}
return ans.reverse().join('');
};
###rust
impl Solution {
pub fn add_binary(a: String, b: String) -> String {
let a = a.as_bytes();
let b = b.as_bytes();
let mut ans = vec![];
let mut i = a.len() as isize - 1; // 从右往左遍历 a 和 b
let mut j = b.len() as isize - 1;
let mut carry = 0; // 保存进位
while i >= 0 || j >= 0 || carry > 0 {
let x = if i >= 0 { a[i as usize] - b'0' } else { 0 };
let y = if j >= 0 { b[j as usize] - b'0' } else { 0 };
let sum = x + y + carry; // 计算这一位的加法
// 例如 sum = 10,把 '0' 填入答案,把 carry 置为 1
ans.push(sum % 2 + b'0');
carry = sum / 2;
i -= 1;
j -= 1;
}
ans.reverse();
unsafe { String::from_utf8_unchecked(ans) }
}
}
直接填入答案,不反转。
###py
class Solution:
def addBinary(self, a: str, b: str) -> str:
# 保证 len(a) >= len(b),简化后续代码逻辑
if len(a) < len(b):
a, b = b, a
n, m = len(a), len(b)
ans = [0] * (n + 1)
carry = 0 # 保存进位
for i in range(n - 1, -1, -1):
j = m - (n - i)
y = int(b[j]) if j >= 0 else 0
s = int(a[i]) + y + carry
ans[i + 1] = str(s % 2)
carry = s // 2
ans[0] = str(carry)
return ''.join(ans[carry ^ 1:]) # 如果 carry == 0 则去掉 ans[0]
###java
class Solution {
public String addBinary(String a, String b) {
// 保证 a.length() >= b.length(),简化后续代码逻辑
if (a.length() < b.length()) {
return addBinary(b, a);
}
int n = a.length();
int m = b.length();
char[] ans = new char[n + 1];
int carry = 0; // 保存进位
for (int i = n - 1, j = m - 1; i >= 0; i--, j--) {
int x = a.charAt(i) - '0';
int y = j >= 0 ? b.charAt(j) - '0' : 0;
int sum = x + y + carry;
ans[i + 1] = (char) (sum % 2 + '0');
carry = sum / 2;
}
ans[0] = (char) (carry + '0');
// 如果 carry == 0 则去掉 ans[0]
return new String(ans, carry ^ 1, n + carry);
}
}
###cpp
class Solution {
public:
string addBinary(string a, string b) {
// 保证 a.size() >= b.size(),简化后续代码逻辑
if (a.size() < b.size()) {
swap(a, b);
}
int n = a.size(), m = b.size();
string ans(n + 1, 0);
int carry = 0; // 保存进位
for (int i = n - 1, j = m - 1; i >= 0; i--, j--) {
int x = a[i] - '0';
int y = j >= 0 ? b[j] - '0' : 0;
int sum = x + y + carry;
ans[i + 1] = sum % 2 + '0';
carry = sum / 2;
}
if (carry) {
ans[0] = '1';
} else {
ans.erase(ans.begin());
}
return ans;
}
};
###c
#define MAX(a, b) ((b) > (a) ? (b) : (a))
char* addBinary(char* a, char* b) {
int n = strlen(a);
int m = strlen(b);
char* ans = malloc((MAX(n, m) + 2) * sizeof(char));
ans[MAX(n, m) + 1] = '\0';
int carry = 0; // 保存进位
for (int i = n - 1, j = m - 1; i >= 0 || j >= 0; i--, j--) {
int x = i >= 0 ? a[i] - '0' : 0;
int y = j >= 0 ? b[j] - '0' : 0;
int sum = x + y + carry;
ans[MAX(i, j) + 1] = sum % 2 + '0';
carry = sum / 2;
}
ans[0] = carry + '0';
// 如果 carry == 0 则去掉 ans[0]
return ans + (carry ^ 1);
}
###go
func addBinary(a, b string) string {
// 保证 len(a) >= len(b),简化后续代码逻辑
if len(a) < len(b) {
a, b = b, a
}
n, m := len(a), len(b)
ans := make([]byte, n+1)
carry := byte(0) // 保存进位
for i := n - 1; i >= 0; i-- {
sum := a[i] - '0' + carry
if j := m - (n - i); j >= 0 {
sum += b[j] - '0'
}
ans[i+1] = sum%2 + '0'
carry = sum / 2
}
ans[0] = carry + '0'
// 如果 carry == 0 则去掉 ans[0]
return string(ans[carry^1:])
}
###js
var addBinary = function(a, b) {
// 保证 a.length >= b.length,简化后续代码逻辑
if (a.length < b.length) {
[a, b] = [b, a];
}
const n = a.length;
const m = b.length;
const ans = Array(n + 1);
let carry = 0; // 保存进位
for (let i = n - 1, j = m - 1; i >= 0; i--, j--) {
const x = Number(a[i]);
const y = j >= 0 ? Number(b[j]) : 0;
const sum = x + y + carry;
ans[i + 1] = String(sum % 2);
carry = Math.floor(sum / 2);
}
if (carry) {
ans[0] = '1';
} else {
ans.shift();
}
return ans.join('');
};
###rust
impl Solution {
pub fn add_binary(a: String, b: String) -> String {
// 保证 a.len() >= b.len(),简化后续代码逻辑
if a.len() < b.len() {
return Self::add_binary(b, a);
}
let a = a.as_bytes();
let b = b.as_bytes();
let n = a.len();
let m = b.len();
let mut ans = vec![0; n + 1];
let mut carry = 0; // 保存进位
for i in (0..n).rev() {
let x = a[i] - b'0';
let y = if n - i <= m { b[m - (n - i)] - b'0' } else { 0 };
let sum = x + y + carry;
ans[i + 1] = sum % 2 + b'0';
carry = sum / 2;
}
if carry > 0 {
ans[0] = b'1';
} else {
ans.remove(0);
}
unsafe { String::from_utf8_unchecked(ans) }
}
}
欢迎关注 B站@灵茶山艾府
考虑一个最朴素的方法:先将 $a$ 和 $b$ 转化成十进制数,求和后再转化为二进制数。利用 Python 和 Java 自带的高精度运算,我们可以很简单地写出这个程序:
###python
class Solution:
def addBinary(self, a, b) -> str:
return '{0:b}'.format(int(a, 2) + int(b, 2))
###Java
class Solution {
public String addBinary(String a, String b) {
return Integer.toBinaryString(
Integer.parseInt(a, 2) + Integer.parseInt(b, 2)
);
}
}
如果 $a$ 的位数是 $n$,$b$ 的位数为 $m$,这个算法的渐进时间复杂度为 $O(n + m)$。但是这里非常简单的实现基于 Python 和 Java 本身的高精度功能,在其他的语言中可能并不适用,并且在 Java 中:
Integer
Long
BigInteger
因此,为了适用于长度较大的字符串计算,我们应该使用更加健壮的算法。
思路和算法
我们可以借鉴「列竖式」的方法,末尾对齐,逐位相加。在十进制的计算中「逢十进一」,二进制中我们需要「逢二进一」。
具体的,我们可以取 $n = \max{ |a|, |b| }$,循环 $n$ 次,从最低位开始遍历。我们使用一个变量 $\textit{carry}$ 表示上一个位置的进位,初始值为 $0$。记当前位置对其的两个位为 $a_i$ 和 $b_i$,则每一位的答案为 $(\textit{carry} + a_i + b_i) \bmod{2}$,下一位的进位为 $\lfloor \frac{\textit{carry} + a_i + b_i}{2} \rfloor$。重复上述步骤,直到数字 $a$ 和 $b$ 的每一位计算完毕。最后如果 $\textit{carry}$ 的最高位不为 $0$,则将最高位添加到计算结果的末尾。
注意,为了让各个位置对齐,你可以先反转这个代表二进制数字的字符串,然后低下标对应低位,高下标对应高位。当然你也可以直接把 $a$ 和 $b$ 中短的那一个补 $0$ 直到和长的那个一样长,然后从高位向低位遍历,对应位置的答案按照顺序存入答案字符串内,最终将答案串反转。这里的代码给出第一种的实现。
代码
###Java
class Solution {
public String addBinary(String a, String b) {
StringBuffer ans = new StringBuffer();
int n = Math.max(a.length(), b.length()), carry = 0;
for (int i = 0; i < n; ++i) {
carry += i < a.length() ? (a.charAt(a.length() - 1 - i) - '0') : 0;
carry += i < b.length() ? (b.charAt(b.length() - 1 - i) - '0') : 0;
ans.append((char) (carry % 2 + '0'));
carry /= 2;
}
if (carry > 0) {
ans.append('1');
}
ans.reverse();
return ans.toString();
}
}
###C++
class Solution {
public:
string addBinary(string a, string b) {
string ans;
reverse(a.begin(), a.end());
reverse(b.begin(), b.end());
int n = max(a.size(), b.size()), carry = 0;
for (size_t i = 0; i < n; ++i) {
carry += i < a.size() ? (a.at(i) == '1') : 0;
carry += i < b.size() ? (b.at(i) == '1') : 0;
ans.push_back((carry % 2) ? '1' : '0');
carry /= 2;
}
if (carry) {
ans.push_back('1');
}
reverse(ans.begin(), ans.end());
return ans;
}
};
###Go
func addBinary(a string, b string) string {
ans := ""
carry := 0
lenA, lenB := len(a), len(b)
n := max(lenA, lenB)
for i := 0; i < n; i++ {
if i < lenA {
carry += int(a[lenA-i-1] - '0')
}
if i < lenB {
carry += int(b[lenB-i-1] - '0')
}
ans = strconv.Itoa(carry%2) + ans
carry /= 2
}
if carry > 0 {
ans = "1" + ans
}
return ans
}
###C
void reserve(char* s) {
int len = strlen(s);
for (int i = 0; i < len / 2; i++) {
char t = s[i];
s[i] = s[len - i - 1], s[len - i - 1] = t;
}
}
char* addBinary(char* a, char* b) {
reserve(a);
reserve(b);
int len_a = strlen(a), len_b = strlen(b);
int n = fmax(len_a, len_b), carry = 0, len = 0;
char* ans = (char*)malloc(sizeof(char) * (n + 2));
for (int i = 0; i < n; ++i) {
carry += i < len_a ? (a[i] == '1') : 0;
carry += i < len_b ? (b[i] == '1') : 0;
ans[len++] = carry % 2 + '0';
carry /= 2;
}
if (carry) {
ans[len++] = '1';
}
ans[len] = '\0';
reserve(ans);
return ans;
}
###Python
class Solution:
def addBinary(self, a: str, b: str) -> str:
ans = []
a = a[::-1]
b = b[::-1]
n = max(len(a), len(b))
carry = 0
for i in range(n):
carry += int(a[i]) if i < len(a) else 0
carry += int(b[i]) if i < len(b) else 0
ans.append(str(carry % 2))
carry //= 2
if carry:
ans.append('1')
return ''.join(ans)[::-1]
###C#
public class Solution {
public string AddBinary(string a, string b) {
char[] aArr = a.ToCharArray();
char[] bArr = b.ToCharArray();
Array.Reverse(aArr);
Array.Reverse(bArr);
int n = Math.Max(a.Length, b.Length);
int carry = 0;
List<char> ans = new List<char>();
for (int i = 0; i < n; i++) {
carry += i < aArr.Length ? (aArr[i] == '1' ? 1 : 0) : 0;
carry += i < bArr.Length ? (bArr[i] == '1' ? 1 : 0) : 0;
ans.Add((carry % 2) == 1 ? '1' : '0');
carry /= 2;
}
if (carry > 0) {
ans.Add('1');
}
ans.Reverse();
return new string(ans.ToArray());
}
}
###JavaScript
var addBinary = function(a, b) {
let ans = [];
a = a.split('').reverse().join('');
b = b.split('').reverse().join('');
const n = Math.max(a.length, b.length);
let carry = 0;
for (let i = 0; i < n; i++) {
carry += i < a.length ? parseInt(a[i]) : 0;
carry += i < b.length ? parseInt(b[i]) : 0;
ans.push((carry % 2).toString());
carry = Math.floor(carry / 2);
}
if (carry) {
ans.push('1');
}
return ans.reverse().join('');
};
###TypeScript
function addBinary(a: string, b: string): string {
let ans: string[] = [];
a = a.split('').reverse().join('');
b = b.split('').reverse().join('');
const n = Math.max(a.length, b.length);
let carry = 0;
for (let i = 0; i < n; i++) {
carry += i < a.length ? parseInt(a[i]) : 0;
carry += i < b.length ? parseInt(b[i]) : 0;
ans.push((carry % 2).toString());
carry = Math.floor(carry / 2);
}
if (carry) {
ans.push('1');
}
return ans.reverse().join('');
}
###Rust
impl Solution {
pub fn add_binary(a: String, b: String) -> String {
let mut a_chars: Vec<char> = a.chars().collect();
let mut b_chars: Vec<char> = b.chars().collect();
a_chars.reverse();
b_chars.reverse();
let n = a_chars.len().max(b_chars.len());
let mut carry = 0;
let mut ans = Vec::new();
for i in 0..n {
carry += if i < a_chars.len() { if a_chars[i] == '1' { 1 } else { 0 } } else { 0 };
carry += if i < b_chars.len() { if b_chars[i] == '1' { 1 } else { 0 } } else { 0 };
ans.push(if carry % 2 == 1 { '1' } else { '0' });
carry /= 2;
}
if carry > 0 {
ans.push('1');
}
ans.reverse();
ans.into_iter().collect()
}
}
复杂度分析
假设 $n = \max{ |a|, |b| }$。
思路和算法
如果不允许使用加减乘除,则可以使用位运算替代上述运算中的一些加减乘除的操作。
如果不了解位运算,可以先了解位运算并尝试练习以下题目:
我们可以设计这样的算法来计算:
answer = x ^ y
carry = (x & y) << 1
x = answer,y = carry
为什么这个方法是可行的呢?在第一轮计算中,answer 的最后一位是 $x$ 和 $y$ 相加之后的结果,carry 的倒数第二位是 $x$ 和 $y$ 最后一位相加的进位。接着每一轮中,由于 carry 是由 $x$ 和 $y$ 按位与并且左移得到的,那么最后会补零,所以在下面计算的过程中后面的数位不受影响,而每一轮都可以得到一个低 $i$ 位的答案和它向低 $i + 1$ 位的进位,也就模拟了加法的过程。
代码
###Java
import java.math.BigInteger;
class Solution {
public String addBinary(String a, String b) {
BigInteger x = new BigInteger(a, 2);
BigInteger y = new BigInteger(b, 2);
while (!y.equals(BigInteger.ZERO)) {
BigInteger answer = x.xor(y);
BigInteger carry = x.and(y).shiftLeft(1);
x = answer;
y = carry;
}
return x.toString(2);
}
}
###C++
class Solution {
public:
string addBinary(string a, string b) {
string result = "";
int i = a.length() - 1, j = b.length() - 1;
int carry = 0;
while (i >= 0 || j >= 0 || carry) {
int sum = carry;
if (i >= 0) {
sum += a[i--] - '0';
}
if (j >= 0) {
sum += b[j--] - '0';
}
result = char(sum % 2 + '0') + result;
carry = sum / 2;
}
return result;
}
};
###Go
func addBinary(a string, b string) string {
if a == "" {
return b
}
if b == "" {
return a
}
x := new(big.Int)
x.SetString(a, 2)
y := new(big.Int)
y.SetString(b, 2)
zero := new(big.Int)
for y.Cmp(zero) != 0 {
answer := new(big.Int)
answer.Xor(x, y)
carry := new(big.Int)
carry.And(x, y)
carry.Lsh(carry, 1)
x.Set(answer)
y.Set(carry)
}
return x.Text(2)
}
###C
char* addBinary(char* a, char* b) {
int len_a = strlen(a);
int len_b = strlen(b);
int max_len = (len_a > len_b ? len_a : len_b) + 2;
char* result = (char*)malloc(max_len * sizeof(char));
if (!result) {
return NULL;
}
int i = len_a - 1, j = len_b - 1;
int carry = 0;
int k = max_len - 2;
result[max_len - 1] = '\0';
while (i >= 0 || j >= 0 || carry) {
int sum = carry;
if (i >= 0) {
sum += a[i--] - '0';
}
if (j >= 0) {
sum += b[j--] - '0';
}
result[k--] = (sum % 2) + '0';
carry = sum / 2;
}
if (k >= 0) {
char* final_result = result + k + 1;
char* dup = strdup(final_result);
free(result);
return dup;
}
return result;
}
###Python
class Solution:
def addBinary(self, a, b) -> str:
x, y = int(a, 2), int(b, 2)
while y:
answer = x ^ y
carry = (x & y) << 1
x, y = answer, carry
return bin(x)[2:]
###C#
public class Solution {
public string AddBinary(string a, string b) {
if (string.IsNullOrEmpty(a)) {
return b;
}
if (string.IsNullOrEmpty(b)) {
return a;
}
BigInteger x = BigInteger.Parse("0" + a, System.Globalization.NumberStyles.AllowBinarySpecifier);
BigInteger y = BigInteger.Parse("0" + b, System.Globalization.NumberStyles.AllowBinarySpecifier);
while (y != 0) {
BigInteger answer = x ^ y;
BigInteger carry = (x & y) << 1;
x = answer;
y = carry;
}
if (x == 0) {
return "0";
}
string result = "";
while (x > 0) {
result = (x % 2).ToString() + result;
x /= 2;
}
return result;
}
}
###JavaScript
var addBinary = function(a, b) {
let x = BigInt('0b' + a);
let y = BigInt('0b' + b);
while (y !== 0n) {
let answer = x ^ y;
let carry = (x & y) << 1n;
x = answer;
y = carry;
}
return x.toString(2);
};
###TypeScript
function addBinary(a: string, b: string): string {
let x = BigInt('0b' + a);
let y = BigInt('0b' + b);
while (y !== 0n) {
let answer = x ^ y;
let carry = (x & y) << 1n;
x = answer;
y = carry;
}
return x.toString(2);
}
###Rust
impl Solution {
pub fn add_binary(a: String, b: String) -> String {
let a_chars: Vec<char> = a.chars().collect();
let b_chars: Vec<char> = b.chars().collect();
let mut i = a_chars.len() as i32 - 1;
let mut j = b_chars.len() as i32 - 1;
let mut carry = 0;
let mut result = Vec::new();
while i >= 0 || j >= 0 || carry > 0 {
let mut sum = carry;
if i >= 0 {
sum += a_chars[i as usize].to_digit(2).unwrap_or(0);
i -= 1;
}
if j >= 0 {
sum += b_chars[j as usize].to_digit(2).unwrap_or(0);
j -= 1;
}
result.push(char::from_digit(sum % 2, 10).unwrap());
carry = sum / 2;
}
result.iter().rev().collect()
}
}
复杂度分析
整体思路是将两个字符串较短的用 $0$ 补齐,使得两个字符串长度一致,然后从末尾进行遍历计算,得到最终结果。
本题解中大致思路与上述一致,但由于字符串操作原因,不确定最后的结果是否会多出一位进位,所以会有 2 种处理方式:
时间复杂度:$O(n)$
###Java
class Solution {
public String addBinary(String a, String b) {
StringBuilder ans = new StringBuilder();
int ca = 0;
for(int i = a.length() - 1, j = b.length() - 1;i >= 0 || j >= 0; i--, j--) {
int sum = ca;
sum += i >= 0 ? a.charAt(i) - '0' : 0;
sum += j >= 0 ? b.charAt(j) - '0' : 0;
ans.append(sum % 2);
ca = sum / 2;
}
ans.append(ca == 1 ? ca : "");
return ans.reverse().toString();
}
}
###JavaScript
/**
* @param {string} a
* @param {string} b
* @return {string}
*/
var addBinary = function(a, b) {
let ans = "";
let ca = 0;
for(let i = a.length - 1, j = b.length - 1;i >= 0 || j >= 0; i--, j--) {
let sum = ca;
sum += i >= 0 ? parseInt(a[i]) : 0;
sum += j >= 0 ? parseInt(b[j]) : 0;
ans += sum % 2;
ca = Math.floor(sum / 2);
}
ans += ca == 1 ? ca : "";
return ans.split('').reverse().join('');
};
<
,
,
>
想看大鹏画解更多高频面试题,欢迎阅读大鹏的 LeetBook:《画解剑指 Offer 》,O(∩_∩)O
类似 102. 二叉树的层序遍历,用一个 BFS 模拟香槟溢出流程:第一层溢出的香槟流到第二层,第二层溢出的香槟流到第三层,依此类推。
具体地:
###py
class Solution:
def champagneTower(self, poured: int, queryRow: int, queryGlass: int) -> float:
cur = [float(poured)]
for i in range(1, queryRow + 1):
nxt = [0.0] * (i + 1)
for j, x in enumerate(cur):
if x > 1: # 溢出到下一层
nxt[j] += (x - 1) / 2
nxt[j + 1] += (x - 1) / 2
cur = nxt
return min(cur[queryGlass], 1.0) # 如果溢出,容量是 1
###java
class Solution {
public double champagneTower(int poured, int queryRow, int queryGlass) {
double[] cur = new double[]{(double) poured};
for (int i = 1; i <= queryRow; i++) {
double[] nxt = new double[i + 1];
for (int j = 0; j < cur.length; j++) {
double x = cur[j] - 1;
if (x > 0) { // 溢出到下一层
nxt[j] += x / 2;
nxt[j + 1] += x / 2;
}
}
cur = nxt;
}
return Math.min(cur[queryGlass], 1); // 如果溢出,容量是 1
}
}
###cpp
class Solution {
public:
double champagneTower(int poured, int queryRow, int queryGlass) {
vector<double> cur = {1.0 * poured};
for (int i = 1; i <= queryRow; i++) {
vector<double> nxt(i + 1);
for (int j = 0; j < cur.size(); j++) {
double x = cur[j] - 1;
if (x > 0) { // 溢出到下一层
nxt[j] += x / 2;
nxt[j + 1] += x / 2;
}
}
cur = move(nxt);
}
return min(cur[queryGlass], 1.0); // 如果溢出,容量是 1
}
};
###c
#define MIN(a, b) ((b) < (a) ? (b) : (a))
double champagneTower(int poured, int queryRow, int queryGlass) {
double* cur = malloc(sizeof(double));
cur[0] = poured;
int curSize = 1;
for (int i = 1; i <= queryRow; i++) {
double* nxt = calloc(i + 1, sizeof(double));
for (int j = 0; j < curSize; j++) {
double x = cur[j] - 1;
if (x > 0) { // 溢出到下一层
nxt[j] += x / 2;
nxt[j + 1] += x / 2;
}
}
free(cur);
cur = nxt;
curSize = i + 1;
}
double ans = MIN(cur[queryGlass], 1); // 如果溢出,容量是 1
free(cur);
return ans;
}
###go
func champagneTower(poured, queryRow, queryGlass int) float64 {
cur := []float64{float64(poured)}
for i := 1; i <= queryRow; i++ {
nxt := make([]float64, i+1)
for j, x := range cur {
if x > 1 { // 溢出到下一层
nxt[j] += (x - 1) / 2
nxt[j+1] += (x - 1) / 2
}
}
cur = nxt
}
return min(cur[queryGlass], 1) // 如果溢出,容量是 1
}
###js
var champagneTower = function(poured, queryRow, queryGlass) {
let cur = [poured];
for (let i = 1; i <= queryRow; i++) {
const nxt = Array(i + 1).fill(0);
for (let j = 0; j < cur.length; j++) {
const x = cur[j] - 1;
if (x > 0) { // 溢出到下一层
nxt[j] += x / 2;
nxt[j + 1] += x / 2;
}
}
cur = nxt;
}
return Math.min(cur[queryGlass], 1); // 如果溢出,容量是 1
};
###rust
impl Solution {
pub fn champagne_tower(poured: i32, query_row: i32, query_glass: i32) -> f64 {
let mut cur = vec![poured as f64];
for i in 1..=query_row as usize {
let mut nxt = vec![0.0; i + 1];
for (j, x) in cur.into_iter().enumerate() {
if x > 1.0 { // 溢出到下一层
nxt[j] += (x - 1.0) / 2.0;
nxt[j + 1] += (x - 1.0) / 2.0;
}
}
cur = nxt;
}
cur[query_glass as usize].min(1.0) // 如果溢出,容量是 1
}
}
无需使用两个数组,可以像 0-1 背包那样,在同一个数组上修改。
###py
class Solution:
def champagneTower(self, poured: int, queryRow: int, queryGlass: int) -> float:
f = [0.0] * (queryRow + 1)
f[0] = float(poured)
for i in range(queryRow):
for j in range(i, -1, -1):
x = f[j] - 1
if x > 0:
f[j + 1] += x / 2
f[j] = x / 2
else:
f[j] = 0.0
return min(f[queryGlass], 1.0) # 如果溢出,容量是 1
###java
class Solution {
public double champagneTower(int poured, int queryRow, int queryGlass) {
double[] f = new double[queryRow + 1];
f[0] = poured;
for (int i = 0; i < queryRow; i++) {
for (int j = i; j >= 0; j--) {
double x = f[j] - 1;
if (x > 0) {
f[j + 1] += x / 2;
f[j] = x / 2;
} else {
f[j] = 0;
}
}
}
return Math.min(f[queryGlass], 1); // 如果溢出,容量是 1
}
}
###cpp
class Solution {
public:
double champagneTower(int poured, int queryRow, int queryGlass) {
vector<double> f(queryRow + 1);
f[0] = poured;
for (int i = 0; i < queryRow; i++) {
for (int j = i; j >= 0; j--) {
double x = f[j] - 1;
if (x > 0) {
f[j + 1] += x / 2;
f[j] = x / 2;
} else {
f[j] = 0;
}
}
}
return min(f[queryGlass], 1.0); // 如果溢出,容量是 1
}
};
###c
#define MIN(a, b) ((b) < (a) ? (b) : (a))
double champagneTower(int poured, int queryRow, int queryGlass) {
double* f = calloc(queryRow + 1, sizeof(double));
f[0] = poured;
for (int i = 0; i < queryRow; i++) {
for (int j = i; j >= 0; j--) {
double x = f[j] - 1;
if (x > 0) {
f[j + 1] += x / 2;
f[j] = x / 2;
} else {
f[j] = 0;
}
}
}
double ans = MIN(f[queryGlass], 1); // 如果溢出,容量是 1
free(f);
return ans;
}
###go
func champagneTower(poured, queryRow, queryGlass int) float64 {
f := make([]float64, queryRow+1)
f[0] = float64(poured)
for i := range queryRow {
for j := i; j >= 0; j-- {
x := f[j] - 1
if x > 0 {
f[j+1] += x / 2
f[j] = x / 2
} else {
f[j] = 0
}
}
}
return min(f[queryGlass], 1) // 如果溢出,容量是 1
}
###js
var champagneTower = function(poured, queryRow, queryGlass) {
const f = Array(queryRow + 1).fill(0);
f[0] = poured;
for (let i = 0; i < queryRow; i++) {
for (let j = i; j >= 0; j--) {
const x = f[j] - 1;
if (x > 0) {
f[j + 1] += x / 2;
f[j] = x / 2;
} else {
f[j] = 0;
}
}
}
return Math.min(f[queryGlass], 1); // 如果溢出,容量是 1
};
###rust
impl Solution {
pub fn champagne_tower(poured: i32, query_row: i32, query_glass: i32) -> f64 {
let query_row = query_row as usize;
let mut f = vec![0.0; query_row + 1];
f[0] = poured as f64;
for i in 0..query_row {
for j in (0..=i).rev() {
let x = f[j] - 1.0;
if x > 0.0 {
f[j + 1] += x / 2.0;
f[j] = x / 2.0;
} else {
f[j] = 0.0;
}
}
}
f[query_glass as usize].min(1.0) // 如果溢出,容量是 1
}
}
欢迎关注 B站@灵茶山艾府
我们把玻璃杯摆成金字塔的形状,其中 第一层 有 1 个玻璃杯, 第二层 有 2 个,依次类推到第 100 层,每个玻璃杯将盛有香槟。
从顶层的第一个玻璃杯开始倾倒一些香槟,当顶层的杯子满了,任何溢出的香槟都会立刻等流量的流向左右两侧的玻璃杯。当左右两边的杯子也满了,就会等流量的流向它们左右两边的杯子,依次类推。(当最底层的玻璃杯满了,香槟会流到地板上)
例如,在倾倒一杯香槟后,最顶层的玻璃杯满了。倾倒了两杯香槟后,第二层的两个玻璃杯各自盛放一半的香槟。在倒三杯香槟后,第二层的香槟满了 - 此时总共有三个满的玻璃杯。在倒第四杯后,第三层中间的玻璃杯盛放了一半的香槟,他两边的玻璃杯各自盛放了四分之一的香槟,如下图所示。
![]()
现在当倾倒了非负整数杯香槟后,返回第 i 行 j 个玻璃杯所盛放的香槟占玻璃杯容积的比例( i 和 j 都从0开始)。
示例 1: 输入: poured(倾倒香槟总杯数) = 1, query_glass(杯子的位置数) = 1, query_row(行数) = 1 输出: 0.00000 解释: 我们在顶层(下标是(0,0))倒了一杯香槟后,没有溢出,因此所有在顶层以下的玻璃杯都是空的。 示例 2: 输入: poured(倾倒香槟总杯数) = 2, query_glass(杯子的位置数) = 1, query_row(行数) = 1 输出: 0.50000 解释: 我们在顶层(下标是(0,0)倒了两杯香槟后,有一杯量的香槟将从顶层溢出,位于(1,0)的玻璃杯和(1,1)的玻璃杯平分了这一杯香槟,所以每个玻璃杯有一半的香槟。
示例 3:
输入: poured = 100000009, query_row = 33, query_glass = 17 输出: 1.00000
提示:
0 <= poured <= 1090 <= query_glass <= query_row < 100我们创建一个二维数组dp[i][j],其中,i表示行号,j表示酒杯编号。
根据题目描述,我们可以知道,针对于第row行第column列(dp[row][column])的这个酒杯,有机会能够注入到它的“上层”酒杯只会是dp[row-1][column-1]和dp[row-1][column],那么这里是“有机会”,因为只有这两个酒杯都满了(减1)的情况下,才会注入到dp[row][column]这个酒杯中,所以,我们可以得到状态转移方程为:
dp[row][column] = Math.max(dp[row-1][column-1]-1, 0)/2 + Math.max(dp[row-1][column]-1, 0)/2。
那么我们从第一行开始计算,逐一可以计算出每一行中每一个酒杯的容量,那么题目的结果就显而易见了。具体操作,如下图所示:
![]()
由于题目只需要获取第query_row行的第query_glass编号的酒杯容量,那么我们其实只需要关注第query_row行的酒杯容量即可,所以,用一维数组dp[]来保存最新计算的那个行中每个酒杯的容量。
计算方式与上面的解法相似,此处就不赘述了。
###java
class Solution {
public double champagneTower(int poured, int query_row, int query_glass) {
double[][] dp = new double[query_row + 2][query_row + 2];
dp[1][1] = poured; // 为了方式越界,下标(0,0)的酒杯我们存放在dp[1][1]的位置上
for (int row = 2; row <= query_row + 1; row++) {
for (int column = 1; column <= row; column++) {
dp[row][column] = Math.max(dp[row - 1][column - 1] - 1, 0) / 2 + Math.max(dp[row - 1][column] - 1, 0) / 2;
}
}
return Math.min(dp[query_row + 1][query_glass + 1], 1);
}
}
![]()
###java
class Solution {
public double champagneTower(int poured, int query_row, int query_glass) {
double[] dp = new double[query_glass + 2]; // 第i层中每个glass的容量
dp[0] = poured; // 第0层的第0个编号酒杯倾倒香槟容量
int row = 0;
while (row < query_row) { // 获取第query_row行,只需要遍历到第query_row减1行即可。
for (int glass = Math.min(row, query_glass); glass >= 0; glass--) {
double overflow = Math.max(dp[glass] - 1, 0) / 2.0;
dp[glass] = overflow; // 覆盖掉旧值
dp[glass + 1] += overflow; // 由于是倒序遍历,所以对于dp[glass + 1]要执行“+=”操作
}
row++; // 计算下一行
}
return Math.min(dp[query_glass], 1); // 如果倾倒香槟容量大于1,则只返回1.
}
}
![]()
今天的文章内容就这些了:
写作不易,笔者几个小时甚至数天完成的一篇文章,只愿换来您几秒钟的 点赞 & 分享 。
更多技术干货,欢迎大家关注公众号“爪哇缪斯” ~ \(^o^)/ ~ 「干货分享,每天更新」
为了方便,我们令 poured 为 k,query_row 和 query_glass 分别为 $n$ 和 $m$。
定义 $f[i][j]$ 为第 $i$ 行第 $j$ 列杯子所经过的水的流量(而不是最终剩余的水量)。
起始我们有 $f[0][0] = k$,最终答案为 $\min(f[n][m], 1)$。
不失一般性考虑 $f[i][j]$ 能够更新哪些状态:显然当 $f[i][j]$ 不足 $1$ 的时候,不会有水从杯子里溢出,即 $f[i][j]$ 将不能更新其他状态;当 $f[i][j]$ 大于 $1$ 时,将会有 $f[i][j] - 1$ 的水会等量留到下一行的杯子里,所流向的杯子分别是「第 $i + 1$ 行第 $j$ 列的杯子」和「第 $i + 1$ 行第 $j + 1$ 列的杯子」,增加流量均为 $\frac{f[i][j] - 1}{2}$,即有 $f[i + 1][j] += \frac{f[i][j] - 1}{2}$ 和 $f[i + 1][j + 1] += \frac{f[i][j] - 1}{2}$。
代码:
###Java
class Solution {
public double champagneTower(int k, int n, int m) {
double[][] f = new double[n + 10][n + 10];
f[0][0] = k;
for (int i = 0; i <= n; i++) {
for (int j = 0; j <= i; j++) {
if (f[i][j] <= 1) continue;
f[i + 1][j] += (f[i][j] - 1) / 2;
f[i + 1][j + 1] += (f[i][j] - 1) / 2;
}
}
return Math.min(f[n][m], 1);
}
}
###TypeScript
function champagneTower(k: number, n: number, m: number): number {
const f = new Array<Array<number>>()
for (let i = 0; i < n + 10; i++) f.push(new Array<number>(n + 10).fill(0))
f[0][0] = k
for (let i = 0; i <= n; i++) {
for (let j = 0; j <= i; j++) {
if (f[i][j] <= 1) continue
f[i + 1][j] += (f[i][j] - 1) / 2
f[i + 1][j + 1] += (f[i][j] - 1) / 2
}
}
return Math.min(f[n][m], 1)
}
###Python3
class Solution:
def champagneTower(self, k: int, n: int, m: int) -> float:
f = [[0] * (n + 10) for _ in range(n + 10)]
f[0][0] = k
for i in range(n + 1):
for j in range(i + 1):
if f[i][j] <= 1:
continue
f[i + 1][j] += (f[i][j] - 1) / 2
f[i + 1][j + 1] += (f[i][j] - 1) / 2
return min(f[n][m], 1)
如果有帮助到你,请给题解点个赞和收藏,让更多的人看到 ~ ("▔□▔)/
也欢迎你 关注我,提供写「证明」&「思路」的高质量题解。
所有题解已经加入 刷题指南,欢迎 star 哦 ~
给你一个只包含字符 'a'、'b' 和 'c' 的字符串 s。
如果一个 子串 中所有 不同 字符出现的次数都 相同,则称该子串为 平衡 子串。
请返回 s 的 最长平衡子串 的 长度 。
子串 是字符串中连续的、非空 的字符序列。
示例 1:
输入: s = "abbac"
输出: 4
解释:
最长的平衡子串是 "abba",因为不同字符 'a' 和 'b' 都恰好出现了 2 次。
示例 2:
输入: s = "aabcc"
输出: 3
解释:
最长的平衡子串是 "abc",因为不同字符 'a'、'b' 和 'c' 都恰好出现了 1 次。
示例 3:
输入: s = "aba"
输出: 2
解释:
最长的平衡子串之一是 "ab",因为不同字符 'a' 和 'b' 都恰好出现了 1 次。另一个最长的平衡子串是 "ba"。
提示:
1 <= s.length <= 105s 仅包含字符 'a'、'b' 和 'c'。先考虑简单一点的问题:如果只有字母 a 和 b 怎么做?
这是一个经典的用前缀和维护的题目。假设平衡子串的下标范围是 $[l, r]$,设 $a_i$ 表示字母 a 在长度为 $i$ 的前缀里的出现次数,$b_i$ 表示字母 b 在长度为 $i$ 的前缀里的出现次数,则
$$
a_r - a_{l - 1} = b_r - b_{l - 1}
$$
移项得
$$
a_r - b_r = a_{l - 1} - b_{l - 1}
$$
因此,类似于 leetcode 974. 和可被 K 整除的子数组,我们枚举子串的右端点 $r$,并找到 $\Delta = (a_i - b_i)$ 的值与 $r$ 相同的最小下标 $l$,则以 $r$ 为右端点的平衡子串的最大长度就是 $(r - l)$。我们可以把 $\Delta$ 的值放入哈希表,并对每个 $\Delta$ 维护最小的下标。
回到当前问题,现在增加了一个字母 c,我们能不能用类似的方法做呢?设 $c_i$ 表示字母 c 在长度为 $i$ 的前缀里的出现次数,我们继续分析一下平衡子串的条件
$$
\begin{matrix}
a_r - a_{l - 1} = b_r - b_{l - 1} \
b_r - b_{l - 1} = c_r - c_{l - 1} \
\end{matrix}
$$
移项得
$$
\begin{matrix}
a_r - b_r = a_{l - 1} - b_{l - 1} \
b_r - c_r = b_{l - 1} - c_{l - 1} \
\end{matrix}
$$
因此,我们仍然可以用相同的做法求出平衡子串的最大长度,只不过哈希表的 key 不是一个整数 $(a_i - b_i)$,而是一个数对 $(a_i - b_i, b_i - c_i)$。
复杂度 $\mathcal{O}(n)$(认为字符集大小是常数)。
其实,这个做法也可以推广到任意字符集,复杂度 $\mathcal{O}(n|\Sigma| \times 2^{|\Sigma|})$,其中 $|\Sigma|$ 是字符集大小。有兴趣的读者可以试着做一下本题强化版 CF GYM100584D - Balanced strings,链接就不附了,怕扣子又把我帖子搞没了。
class Solution {
public:
int longestBalanced(string s) {
int n = s.size(), ans = 0;
// 子串只包含一个字母的情况
auto calc1 = [&]() {
int cnt = 0;
for (int i = 0; i < n; i++) {
if (i == 0 || s[i] == s[i - 1]) cnt++;
else cnt = 1;
ans = max(ans, cnt);
}
};
// 子串只包含两个字母的情况
auto calc2 = [&](char a, char b) {
unordered_map<int, int> mp;
mp[0] = -1;
// x 表示 a_i - b_i 的值
int x = 0;
for(int i = 0; i < n; i++) {
if (s[i] == a) x++;
else if (s[i] == b) x--;
else {
// 遇到不在子串里的字符,截断
mp.clear();
x = 0;
}
if (mp.count(x)) ans = max(ans, i - mp[x]);
else mp[x] = i;
}
};
// 子串包含三个字母的情况
auto calc3 = [&]() {
unordered_map<long long, int> mp;
mp[0] = -1;
// x 表示 a_i - b_i 的值
// y 表示 b_i - c_i 的值
int x = 0, y = 0;
for (int i = 0; i < n; i++) {
if (s[i] == 'a') x++;
else if (s[i] == 'b') x--, y++;
else y--;
// c++ 的 unordered_map 不支持用 pair 作为 key
// 所以只能把数对映射成一个整数
// 当然也可以直接用 map,用 pair 作为 key
// 只是复杂度会乘上一个 log
long long key = 10LL * x * n + y;
if (mp.count(key)) ans = max(ans, i - mp[key]);
else mp[key] = i;
}
};
calc1();
calc2('a', 'b');
calc2('a', 'c');
calc2('b', 'c');
calc3();
return ans;
}
};
读者首先需要掌握用前缀和 + 哈希表的方式,求特定子数组数量或最大长度的方法。读者可以学习 灵神题单 - 常用数据结构 的“前缀和与哈希表”一节。
如果读者掌握了这一技巧,那么至少可以解答只包含 a 和 b 的情况。若此时仍然不会做本题,需要加强从特殊到一般的拓展能力。一般来说可以先考虑一个简化的问题怎么做(例如图上问题先想一下树上怎么做,树上问题先想一下链上怎么做,二维矩阵问题先想一下一维序列怎么做),这只能在大量的练习中慢慢习得。
分成如下三类问题,依次解答:
即 $s$ 中的最长连续相同子串的长度。这题是 1446. 连续字符。
这可以用分组循环解决。
适用场景:按照题目要求,序列会被分割成若干组,每一组的判断/处理逻辑是相同的。
核心思想:
这个写法的好处是,各个逻辑块分工明确,也不需要特判最后一组(易错点)。以我的经验,这个写法是所有写法中最不容易出 bug 的,推荐大家记住。
同样地,用分组循环分组,每组只包含两种字母。
对于每一组,计算含有相同数量的两种字母的最长子串。这题是 525. 连续数组。
做法见 我的题解。
仿照 525 题的做法,设 $\texttt{a}$ 在这个组的个数前缀和数组为 $S_a$,$\texttt{b}$ 在这个组的个数前缀和数组为 $S_b$,$\texttt{c}$ 在这个组的个数前缀和数组为 $S_c$。
子串 $[l,r)$ 中的字母 $\texttt{a},\texttt{b},\texttt{c}$ 的出现次数相等,可以拆分为如下两个约束:
只要满足这两个约束,由等号的传递性可知,子串 $[l,r)$ 中的字母 $\texttt{a}$ 和 $\texttt{c}$ 的出现次数相等,即三个字母的出现次数都相等。
两个约束即如下两个等式
$$
\begin{aligned}
S_a[r] - S_b[r] &= S_a[l] - S_b[l] \
S_b[r] - S_c[r] &= S_b[l] - S_c[l] \
\end{aligned}
$$
定义数组 $a[i] = (S_a[i] - S_b[i], S_b[i] - S_c[i])$,问题变成:
做法同上。
本题视频讲解,欢迎点赞关注~
###py
class Solution:
def longestBalanced(self, s: str) -> int:
n = len(s)
# 一种字母
ans = i = 0
while i < n:
start = i
i += 1
while i < n and s[i] == s[i - 1]:
i += 1
ans = max(ans, i - start)
# 两种字母
def f(x: str, y: str) -> None:
nonlocal ans
i = 0
while i < n:
pos = {0: i - 1} # 前缀和数组的首项是 0,位置相当于在 i-1
d = 0 # x 的个数减去 y 的个数
while i < n and (s[i] == x or s[i] == y):
d += 1 if s[i] == x else -1
if d in pos:
ans = max(ans, i - pos[d])
else:
pos[d] = i
i += 1
i += 1
f('a', 'b')
f('a', 'c')
f('b', 'c')
# 三种字母
# 前缀和数组的首项是 0,位置相当于在 -1
pos = {(0, 0): -1}
cnt = defaultdict(int)
for i, b in enumerate(s):
cnt[b] += 1
p = (cnt['a'] - cnt['b'], cnt['b'] - cnt['c'])
if p in pos:
ans = max(ans, i - pos[p])
else:
pos[p] = i
return ans
###java
class Solution {
public int longestBalanced(String S) {
char[] s = S.toCharArray();
int n = s.length;
int ans = 0;
// 一种字母
for (int i = 0; i < n; ) {
int start = i;
for (i++; i < n && s[i] == s[i - 1]; i++) ;
ans = Math.max(ans, i - start);
}
// 两种字母
ans = Math.max(ans, f(s, 'a', 'b'));
ans = Math.max(ans, f(s, 'a', 'c'));
ans = Math.max(ans, f(s, 'b', 'c'));
// 三种字母
// 把 (x, y) 压缩成一个 long,方便保存至哈希表
// (x, y) 变成 (x + n) << 20 | (y + n),其中 +n 避免出现负数
Map<Long, Integer> pos = new HashMap<>();
pos.put((long) n << 20 | n, -1); // 前缀和数组的首项是 0,位置相当于在 -1
int[] cnt = new int[3];
for (int i = 0; i < n; i++) {
cnt[s[i] - 'a']++;
long p = (long) (cnt[0] - cnt[1] + n) << 20 | (cnt[1] - cnt[2] + n);
if (pos.containsKey(p)) {
ans = Math.max(ans, i - pos.get(p));
} else {
pos.put(p, i);
}
}
return ans;
}
private int f(char[] s, char x, char y) {
int n = s.length;
int ans = 0;
for (int i = 0; i < n; i++) {
Map<Integer, Integer> pos = new HashMap<>();
pos.put(0, i - 1); // 前缀和数组的首项是 0,位置相当于在 i-1
int d = 0; // x 的个数减去 y 的个数
for (; i < n && (s[i] == x || s[i] == y); i++) {
d += s[i] == x ? 1 : -1;
if (pos.containsKey(d)) {
ans = Math.max(ans, i - pos.get(d));
} else {
pos.put(d, i);
}
}
}
return ans;
}
}
###cpp
class Solution {
public:
int longestBalanced(string s) {
int n = s.size();
int ans = 0;
// 一种字母
for (int i = 0; i < n;) {
int start = i;
for (i++; i < n && s[i] == s[i - 1]; i++);
ans = max(ans, i - start);
}
// 两种字母
auto f = [&](char x, char y) -> void {
for (int i = 0; i < n; i++) {
unordered_map<int, int> pos = {{0, i - 1}}; // 前缀和数组的首项是 0,位置相当于在 i-1
int d = 0; // x 的个数减去 y 的个数
for (; i < n && (s[i] == x || s[i] == y); i++) {
d += s[i] == x ? 1 : -1;
if (pos.contains(d)) {
ans = max(ans, i - pos[d]);
} else {
pos[d] = i;
}
}
}
};
f('a', 'b');
f('a', 'c');
f('b', 'c');
// 三种字母
// 把 (x, y) 压缩成一个 long long,方便保存至哈希表
// (x, y) 变成 (x + n) << 32 | (y + n),其中 +n 避免出现负数
unordered_map<long long, int> pos = {{1LL * n << 32 | n, -1}}; // 前缀和数组的首项是 0,位置相当于在 -1
int cnt[3]{};
for (int i = 0; i < n; i++) {
cnt[s[i] - 'a']++;
long long p = 1LL * (cnt[0] - cnt[1] + n) << 32 | (cnt[1] - cnt[2] + n);
if (pos.contains(p)) {
ans = max(ans, i - pos[p]);
} else {
pos[p] = i;
}
}
return ans;
}
};
###go
func longestBalanced(s string) (ans int) {
n := len(s)
// 一种字母
for i := 0; i < n; {
start := i
for i++; i < n && s[i] == s[i-1]; i++ {
}
ans = max(ans, i-start)
}
// 两种字母
f := func(x, y byte) {
for i := 0; i < n; i++ {
pos := map[int]int{0: i - 1} // 前缀和数组的首项是 0,位置相当于在 i-1
d := 0 // x 的个数减去 y 的个数
for ; i < n && (s[i] == x || s[i] == y); i++ {
if s[i] == x {
d++
} else {
d--
}
if j, ok := pos[d]; ok {
ans = max(ans, i-j)
} else {
pos[d] = i
}
}
}
}
f('a', 'b')
f('a', 'c')
f('b', 'c')
// 三种字母
type pair struct{ diffAB, diffBC int }
pos := map[pair]int{{}: -1} // 前缀和数组的首项是 0,位置相当于在 -1
cnt := [3]int{}
for i, b := range s {
cnt[b-'a']++
p := pair{cnt[0] - cnt[1], cnt[1] - cnt[2]}
if j, ok := pos[p]; ok {
ans = max(ans, i-j)
} else {
pos[p] = i
}
}
return
}
我们可以在 $[0,..n-1]$ 范围内枚举子串的起始位置 $i$,然后在 $[i,..,n-1]$ 范围内枚举子串的结束位置 $j$,并使用哈希表 $\textit{cnt}$ 记录子串 $s[i..j]$ 中每个字符出现的次数。我们使用变量 $\textit{mx}$ 记录子串中出现次数最多的字符的出现次数,使用变量 $v$ 记录子串中不同字符的个数。如果在某个位置 $j$,满足 $\textit{mx} \times v = j - i + 1$,则说明子串 $s[i..j]$ 是一个平衡子串,我们更新答案 $\textit{ans} = \max(\textit{ans}, j - i + 1)$。
###python
class Solution:
def longestBalanced(self, s: str) -> int:
n = len(s)
ans = 0
for i in range(n):
cnt = Counter()
mx = v = 0
for j in range(i, n):
cnt[s[j]] += 1
mx = max(mx, cnt[s[j]])
if cnt[s[j]] == 1:
v += 1
if mx * v == j - i + 1:
ans = max(ans, j - i + 1)
return ans
###java
class Solution {
public int longestBalanced(String s) {
int n = s.length();
int[] cnt = new int[26];
int ans = 0;
for (int i = 0; i < n; ++i) {
Arrays.fill(cnt, 0);
int mx = 0, v = 0;
for (int j = i; j < n; ++j) {
int c = s.charAt(j) - 'a';
if (++cnt[c] == 1) {
++v;
}
mx = Math.max(mx, cnt[c]);
if (mx * v == j - i + 1) {
ans = Math.max(ans, j - i + 1);
}
}
}
return ans;
}
}
###cpp
class Solution {
public:
int longestBalanced(string s) {
int n = s.size();
vector<int> cnt(26, 0);
int ans = 0;
for (int i = 0; i < n; ++i) {
fill(cnt.begin(), cnt.end(), 0);
int mx = 0, v = 0;
for (int j = i; j < n; ++j) {
int c = s[j] - 'a';
if (++cnt[c] == 1) {
++v;
}
mx = max(mx, cnt[c]);
if (mx * v == j - i + 1) {
ans = max(ans, j - i + 1);
}
}
}
return ans;
}
};
###go
func longestBalanced(s string) (ans int) {
n := len(s)
for i := 0; i < n; i++ {
cnt := [26]int{}
mx, v := 0, 0
for j := i; j < n; j++ {
c := s[j] - 'a'
cnt[c]++
if cnt[c] == 1 {
v++
}
mx = max(mx, cnt[c])
if mx*v == j-i+1 {
ans = max(ans, j-i+1)
}
}
}
return ans
}
###ts
function longestBalanced(s: string): number {
const n = s.length;
let ans: number = 0;
for (let i = 0; i < n; ++i) {
const cnt: number[] = Array(26).fill(0);
let [mx, v] = [0, 0];
for (let j = i; j < n; ++j) {
const c = s[j].charCodeAt(0) - 97;
if (++cnt[c] === 1) {
++v;
}
mx = Math.max(mx, cnt[c]);
if (mx * v === j - i + 1) {
ans = Math.max(ans, j - i + 1);
}
}
}
return ans;
}
时间复杂度 $O(n^2)$,其中 $n$ 是字符串的长度。空间复杂度 $O(|\Sigma|)$,其中 $|\Sigma|$ 是字符集的大小,本题中 $|\Sigma| = 26$。
有任何问题,欢迎评论区交流,欢迎评论区提供其它解题思路(代码),也可以点个赞支持一下作者哈 😄~
给你一个由小写英文字母组成的字符串 s。
如果一个 子串 中所有 不同 字符出现的次数都 相同 ,则称该子串为 平衡 子串。
请返回 s 的 最长平衡子串 的 长度 。
子串 是字符串中连续的、非空 的字符序列。
示例 1:
输入: s = "abbac"
输出: 4
解释:
最长的平衡子串是 "abba",因为不同字符 'a' 和 'b' 都恰好出现了 2 次。
示例 2:
输入: s = "zzabccy"
输出: 4
解释:
最长的平衡子串是 "zabc",因为不同字符 'z'、'a'、'b' 和 'c' 都恰好出现了 1 次。
示例 3:
输入: s = "aba"
输出: 2
解释:
最长的平衡子串之一是 "ab",因为不同字符 'a' 和 'b' 都恰好出现了 1 次。另一个最长的平衡子串是 "ba"。
提示:
1 <= s.length <= 1000s 仅由小写英文字母组成。用 $n$ 表示字符串 $s$ 的长度。由于 $n \le 1000$,因此可以枚举字符串 $s$ 的所有子串并判断是否为平衡子串。对于 $0 \le i < n$ 的每个下标 $i$,从小到大遍历 $i \le j < n$ 的所有下标 $j$,对于每个下标 $j$ 判断下标范围 $[i, j]$ 的子串是否为平衡子串。
具体做法是,对于每个起始下标 $i$,使用哈希表记录子串中的每个字符的出现次数。当起始下标 $i$ 确定时,对于遍历到的每个下标 $j$,执行如下操作。
在哈希表中将字符 $s[j]$ 的出现次数增加 $1$。
遍历哈希表中的所有记录,判断是否满足哈希表中的每个字符的出现次数都与字符 $s[j]$ 的出现次数相等。如果出现次数都相等,则下标范围 $[i, j]$ 的子串是平衡子串,其长度是 $j - i + 1$,使用该平衡子串的长度更新最长平衡子串的长度。
遍历结束之后,即可得到字符串 $s$ 的最长平衡子串的长度。
###Java
class Solution {
public int longestBalanced(String s) {
int longest = 0;
int n = s.length();
for (int i = 0; i < n; i++) {
Map<Character, Integer> counts = new HashMap<Character, Integer>();
Set<Map.Entry<Character, Integer>> entries = counts.entrySet();
for (int j = i; j < n; j++) {
char c = s.charAt(j);
int count = counts.getOrDefault(c, 0) + 1;
counts.put(c, count);
boolean isBalanced = true;
for (Map.Entry<Character, Integer> entry : entries) {
if (entry.getValue() != count) {
isBalanced = false;
break;
}
}
if (isBalanced) {
longest = Math.max(longest, j - i + 1);
}
}
}
return longest;
}
}
###C#
public class Solution {
public int LongestBalanced(string s) {
int longest = 0;
int n = s.Length;
for (int i = 0; i < n; i++) {
IDictionary<char, int> counts = new Dictionary<char, int>();
for (int j = i; j < n; j++) {
char c = s[j];
counts.TryAdd(c, 0);
counts[c]++;
bool isBalanced = true;
foreach (KeyValuePair<char, int> pair in counts) {
if (pair.Value != counts[c]) {
isBalanced = false;
break;
}
}
if (isBalanced) {
longest = Math.Max(longest, j - i + 1);
}
}
}
return longest;
}
}
时间复杂度:$O(n^2 |\Sigma|)$,其中 $n$ 是字符串 $s$ 的长度,$\Sigma$ 是字符集,这道题中 $\Sigma$ 是全部小写英语字母,$|\Sigma| = 26$。需要遍历的子串数量是 $O(n^2)$,对于每个子串遍历哈希表判断是否为平衡子串的时间是 $O(|\Sigma|)$,因此时间复杂度是 $O(n^2 |\Sigma|)$。
空间复杂度:$O(\min(n, |\Sigma|))$,其中 $n$ 是字符串 $s$ 的长度,$\Sigma$ 是字符集,这道题中 $\Sigma$ 是全部小写英语字母,$|\Sigma| = 26$。哈希表的空间是 $O(\min(n, |\Sigma|))$。
枚举子串左端点 $i$,然后枚举子串右端点 $j=i,i+1,i+2,\ldots,n-1$。
在枚举右端点 $j$ 的过程中,统计子串 $[i,j]$ 每种字母的出现次数 $\textit{cnt}$。
遍历 $\textit{cnt}$,如果所有字母的出现次数均相同,用子串长度 $j-i+1$ 更新答案的最大值。
本题视频讲解,欢迎点赞关注~
###py
class Solution:
def longestBalanced(self, s: str) -> int:
ans = 0
n = len(s)
for i in range(n):
cnt = defaultdict(int)
for j in range(i, n):
cnt[s[j]] += 1
if len(set(cnt.values())) == 1:
ans = max(ans, j - i + 1)
return ans
###java
class Solution {
public int longestBalanced(String S) {
char[] s = S.toCharArray();
int n = s.length;
int ans = 0;
for (int i = 0; i < n; i++) {
int[] cnt = new int[26];
next:
for (int j = i; j < n; j++) {
int base = ++cnt[s[j] - 'a'];
for (int c : cnt) {
if (c > 0 && c != base) {
continue next;
}
}
ans = Math.max(ans, j - i + 1);
}
}
return ans;
}
}
###cpp
class Solution {
public:
int longestBalanced(string s) {
int n = s.size();
int ans = 0;
for (int i = 0; i < n; i++) {
int cnt[26]{};
for (int j = i; j < n; j++) {
int base = ++cnt[s[j] - 'a'];
for (int c : cnt) {
if (c && c != base) {
base = -1;
break;
}
}
if (base != -1) {
ans = max(ans, j - i + 1);
}
}
}
return ans;
}
};
###go
func longestBalanced(s string) (ans int) {
for i := range s {
cnt := make([]int, 26)
next:
for j := i; j < len(s); j++ {
cnt[s[j]-'a']++
base := cnt[s[j]-'a']
for _, c := range cnt {
if c > 0 && c != base {
continue next
}
}
ans = max(ans, j-i+1)
}
}
return
}
设 $\textit{mx} = \max(\textit{cnt})$,设 $\textit{kinds}$ 为子串中的不同字母个数。
如果 $\textit{mx}\cdot \textit{kinds} = j-i+1$,说明子串所有字母的出现次数均为 $\textit{mx}$,均相等。
###py
# 手写 max 更快
max = lambda a, b: b if b > a else a
class Solution:
def longestBalanced(self, s: str) -> int:
ans = 0
for i in range(len(s)):
cnt = defaultdict(int)
mx = 0
for j in range(i, len(s)):
cnt[s[j]] += 1
mx = max(mx, cnt[s[j]])
if mx * len(cnt) == j - i + 1:
ans = max(ans, j - i + 1)
return ans
###java
class Solution {
public int longestBalanced(String S) {
char[] s = S.toCharArray();
int n = s.length;
int ans = 0;
for (int i = 0; i < n; i++) {
int[] cnt = new int[26];
int mx = 0, kinds = 0;
for (int j = i; j < n; j++) {
int b = s[j] - 'a';
if (cnt[b] == 0) {
kinds++;
}
mx = Math.max(mx, ++cnt[b]);
if (mx * kinds == j - i + 1) {
ans = Math.max(ans, j - i + 1);
}
}
}
return ans;
}
}
###cpp
class Solution {
public:
int longestBalanced(string s) {
int n = s.size();
int ans = 0;
for (int i = 0; i < n; i++) {
int cnt[26]{};
int mx = 0, kinds = 0;
for (int j = i; j < n; j++) {
int b = s[j] - 'a';
if (cnt[b] == 0) {
kinds++;
}
mx = max(mx, ++cnt[b]);
if (mx * kinds == j - i + 1) {
ans = max(ans, j - i + 1);
}
}
}
return ans;
}
};
###go
func longestBalanced(s string) (ans int) {
for i := range s {
cnt := [26]int{}
mx, kinds := 0, 0
for j := i; j < len(s); j++ {
b := s[j] - 'a'
if cnt[b] == 0 {
kinds++
}
cnt[b]++
mx = max(mx, cnt[b])
if mx*kinds == j-i+1 {
ans = max(ans, j-i+1)
}
}
}
return
}
推荐先完成 3714. 最长的平衡子串 II,再来理解这个做法。
能否只用一个哈希表,同时解决子串包含 $1$ 个、$2$ 个……$26$ 个字母的情况?
定义 $S[i][j]$ 表示前缀 $[0,i]$ 中的字母 $j$ 的出现次数。$S[-1][j] = 0$。
如果要找只包含字母 $\texttt{abc}$ 的平衡子串 $(l,r]$,3714 题告诉我们:
据此,定义
$$
d[i][j] =
\begin{cases}
S[i][j] - S[i][\textit{minCh}], & j\ 在子串中 \
S[i][j], & j\ 不在子串中 \
\end{cases}
$$
其中 $\textit{minCh}$ 是子串中的最小字母,用来作为减法的基准。
然而,这个定义面临一个尴尬的问题:
必须先知道子串包含哪些字母,才能准确地算出 $d[i][j]$。
难道要像 3714 题那样,枚举所有非空字母子集,即 $2^{26}-1$ 种情况?
实际上,考虑以 $i$ 为右端点的子串,当子串左端点从 $i$ 开始向左移动(扩展)时,子串中的字母种类数要么不变,要么加一,所以固定右端点时,只有至多 $26$ 种字母集合。(这有点像 LogTrick)
于是,对于每个右端点 $i$,我们至多枚举 $26$ 种字母集合。
对于一个固定的字母集合,$d[i][j]$ 的值就是固定的了。问题相当于:
现在,「枚举右维护左」中的「枚举右」解决了,「维护左」怎么做?
对于子串 $(l,r]$,我们需要:
如果 $d[l] = d[r]$,那么子串 $(l,r]$ 就是平衡子串吗?
不一定。存在 $d[i]$ 相同,但字母集合不同的情况。所以我们还需要在哈希表的 key 中添加一个 $\textit{mask}$,表示字母集合(实现时用 二进制数 压缩表示)。
小优化:如果整个 $s$ 是平衡的,返回 $n$。
###py
class Solution:
def longestBalanced(self, s: str) -> int:
n = len(s)
cnt = Counter(s)
if max(cnt.values()) * len(cnt) == n: # s 是平衡的
return n
mp = {c: i for i, c in enumerate(cnt)}
s = [mp[c] for c in s] # 离散化,字母 -> 数字
n = len(s)
suf_orders = [None] * n
order = []
for i in range(n - 1, -1, -1):
# 把最近出现的字母移到 order 末尾
try: order.remove(s[i])
except: pass
order.append(s[i])
suf_orders[i] = order[:]
order = []
cnt = [0] * len(mp)
pos = {}
ans = 0
for i, b in enumerate(s):
suf_order = suf_orders[i]
min_ch = inf
mask = 0
for j in range(len(suf_order) - 1, -1, -1):
min_ch = min(min_ch, suf_order[j])
# 注意此时 cnt 并不包含 s[i],我们计算的是前缀 s[:i] 的信息
# 在子串中的字母,计算差值
# 不在子串中的字母,维持原样
d = cnt[:]
for ch in suf_order[j:]:
d[ch] -= cnt[min_ch]
mask |= 1 << suf_order[j]
p = (tuple(d), mask) # mask 用来区分 d[ch] 是差值还是原始值
# 记录 p 首次出现的位置
if p not in pos:
pos[p] = i - 1
# 把最近出现的字母移到 order 末尾
try: order.remove(b)
except: pass
order.append(b)
cnt[b] += 1
min_ch = inf
mask = 0
for j in range(len(order) - 1, -1, -1):
min_ch = min(min_ch, order[j])
d = cnt[:]
for ch in order[j:]:
d[ch] -= cnt[min_ch]
mask |= 1 << order[j]
p = (tuple(d), mask)
# 再次遇到完全一样的 p,说明我们找到了一个平衡子串,左端点为 pos[p]+1,右端点为 i
if p in pos:
ans = max(ans, i - pos[p])
return ans
###go
func longestBalanced(s string) (ans int) {
n := len(s)
sufOrders := make([][]byte, n)
order := []byte{}
move := func(b byte) {
// 把最近出现的字母 b 移到 order 末尾
j := bytes.IndexByte(order, b)
if j >= 0 {
order = append(order[:j], order[j+1:]...)
}
order = append(order, b)
}
for i := n - 1; i >= 0; i-- {
move(s[i] - 'a')
sufOrders[i] = slices.Clone(order)
}
order = []byte{}
cnt := [27]int{} // cnt[26] 作为 mask,用来区分 tmp[ch] 是差值还是原始值
pos := map[[27]int]int{}
for i, b := range s {
sufOrder := sufOrders[i]
minCh := byte(25)
cnt[26] = 0
for j := len(sufOrder) - 1; j >= 0; j-- {
cnt[26] |= 1 << sufOrder[j]
minCh = min(minCh, sufOrder[j])
// 注意此时 cnt 并不包含 s[i],我们计算的是前缀 s[:i] 的信息
// 在子串中的字母,计算差值
// 不在子串中的字母,维持原样
d := cnt
for _, ch := range sufOrder[j:] {
d[ch] -= cnt[minCh]
}
// 记录 d 首次出现的位置
if _, ok := pos[d]; !ok {
pos[d] = i - 1
}
}
// 把最近出现的字母移到 order 末尾
move(byte(b - 'a'))
cnt[b-'a']++
minCh = byte(25)
cnt[26] = 0
for j := len(order) - 1; j >= 0; j-- {
cnt[26] |= 1 << order[j]
minCh = min(minCh, order[j])
d := cnt
for _, ch := range order[j:] {
d[ch] -= cnt[minCh]
}
// 再次遇到完全一样的状态,说明找到了一个平衡子串,左端点为 l+1,右端点为 i
if l, ok := pos[d]; ok {
ans = max(ans, i-l)
}
}
}
return
}
注:每次 $\textit{minCh}$ 变小时,我们都要重新算一遍 $\textit{tmp}$。如果计算 $\textit{tmp}$ 的多项式哈希值(代替哈希表的 key),我们可以 $\mathcal{O}(1)$ 计算哈希值的变化量,从而做到 $\mathcal{O}(n|\Sigma|)$ 时间。但该做法无法保证 100% 正确。
我们可以将问题转化为前缀和问题。定义一个前缀和变量 $\textit{now}$,表示当前子数组中奇数和偶数的差值:
$$
\textit{now} = \text{不同奇数} - \text{不同偶数}
$$
对于奇数元素记为 $+1$,偶数元素记为 $-1$。使用哈希表 $\textit{last}$ 记录每个数字上一次出现的位置,如果数字重复出现,需要撤销其之前的贡献。
为了高效计算每次右端点加入元素后子数组长度,我们使用线段树维护区间前缀和的最小值和最大值,同时支持区间加操作和线段树上二分查询。当遍历到右端点 $i$ 时,先更新当前元素的贡献,然后使用线段树查询最早出现当前前缀和 $\textit{now}$ 的位置 $pos$,当前子数组长度为 $i - pos$,更新答案:
$$
\textit{ans} = \max(\textit{ans}, i - pos)
$$
###python
class Solution:
def longestBalanced(self, nums: List[int]) -> int:
n = len(nums)
# 线段树节点
class Node:
__slots__ = ("l", "r", "mn", "mx", "lazy")
def __init__(self):
self.l = self.r = 0
self.mn = self.mx = 0
self.lazy = 0
tr = [Node() for _ in range((n + 1) * 4)]
# 建树,维护前缀和区间 [0, n]
def build(u: int, l: int, r: int):
tr[u].l, tr[u].r = l, r
tr[u].mn = tr[u].mx = tr[u].lazy = 0
if l == r:
return
mid = (l + r) >> 1
build(u << 1, l, mid)
build(u << 1 | 1, mid + 1, r)
def apply(u: int, v: int):
tr[u].mn += v
tr[u].mx += v
tr[u].lazy += v
def pushdown(u: int):
if tr[u].lazy != 0:
apply(u << 1, tr[u].lazy)
apply(u << 1 | 1, tr[u].lazy)
tr[u].lazy = 0
def pushup(u: int):
tr[u].mn = min(tr[u << 1].mn, tr[u << 1 | 1].mn)
tr[u].mx = max(tr[u << 1].mx, tr[u << 1 | 1].mx)
# 区间加
def modify(u: int, l: int, r: int, v: int):
if tr[u].l >= l and tr[u].r <= r:
apply(u, v)
return
pushdown(u)
mid = (tr[u].l + tr[u].r) >> 1
if l <= mid:
modify(u << 1, l, r, v)
if r > mid:
modify(u << 1 | 1, l, r, v)
pushup(u)
# 线段树上二分,找最小 pos 使前缀和 == target
def query(u: int, target: int) -> int:
if tr[u].l == tr[u].r:
return tr[u].l
pushdown(u)
if tr[u << 1].mn <= target <= tr[u << 1].mx:
return query(u << 1, target)
return query(u << 1 | 1, target)
build(1, 0, n)
last = {}
now = ans = 0
for i, x in enumerate(nums, start=1):
det = 1 if (x & 1) else -1
if x in last:
modify(1, last[x], n, -det)
now -= det
last[x] = i
modify(1, i, n, det)
now += det
pos = query(1, now)
ans = max(ans, i - pos)
return ans
###java
/**
*
* 思路:
* - 将「不同奇数」视为 +1,「不同偶数」视为 -1
* - 用前缀和表示当前子数组内奇偶平衡状态
* - 由于相同数值只能算一次,需要在数值重复出现时撤销旧贡献
* - 使用线段树维护前缀和的最小值 / 最大值,并支持区间加
* - 通过线段树上二分,找到最早等于当前前缀和的位置
*/
class Solution {
/**
* 线段树节点
*/
static class Node {
int l, r; // 区间范围
int mn, mx; // 区间前缀和最小值 / 最大值
int lazy; // 懒标记:区间整体加
}
/**
* 支持区间加 + 按值二分查位置的线段树
*/
static class SegmentTree {
Node[] tr;
SegmentTree(int n) {
tr = new Node[n << 2];
for (int i = 0; i < tr.length; i++) {
tr[i] = new Node();
}
build(1, 0, n);
}
// 建树,初始前缀和均为 0
void build(int u, int l, int r) {
tr[u].l = l;
tr[u].r = r;
tr[u].mn = tr[u].mx = 0;
tr[u].lazy = 0;
if (l == r) return;
int mid = (l + r) >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
// 区间 [l, r] 全部加 v
void modify(int u, int l, int r, int v) {
if (tr[u].l >= l && tr[u].r <= r) {
apply(u, v);
return;
}
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if (l <= mid) modify(u << 1, l, r, v);
if (r > mid) modify(u << 1 | 1, l, r, v);
pushup(u);
}
// 线段树上二分:查找最小位置 pos,使前缀和 == target
int query(int u, int target) {
if (tr[u].l == tr[u].r) {
return tr[u].l;
}
pushdown(u);
int left = u << 1;
int right = u << 1 | 1;
if (tr[left].mn <= target && target <= tr[left].mx) {
return query(left, target);
}
return query(right, target);
}
// 应用懒标记
void apply(int u, int v) {
tr[u].mn += v;
tr[u].mx += v;
tr[u].lazy += v;
}
// 向上更新
void pushup(int u) {
tr[u].mn = Math.min(tr[u << 1].mn, tr[u << 1 | 1].mn);
tr[u].mx = Math.max(tr[u << 1].mx, tr[u << 1 | 1].mx);
}
// 懒标记下推
void pushdown(int u) {
if (tr[u].lazy != 0) {
apply(u << 1, tr[u].lazy);
apply(u << 1 | 1, tr[u].lazy);
tr[u].lazy = 0;
}
}
}
public int longestBalanced(int[] nums) {
int n = nums.length;
SegmentTree st = new SegmentTree(n);
// last[x] 表示 x 最近一次出现的位置
Map<Integer, Integer> last = new HashMap<>();
int now = 0; // 当前前缀和
int ans = 0; // 最终答案
// 枚举子数组右端点
for (int i = 1; i <= n; i++) {
int x = nums[i - 1];
int det = (x & 1) == 1 ? 1 : -1;
// 如果之前出现过,撤销旧贡献
if (last.containsKey(x)) {
st.modify(1, last.get(x), n, -det);
now -= det;
}
// 添加新贡献
last.put(x, i);
st.modify(1, i, n, det);
now += det;
// 查找最早前缀和等于 now 的位置
int pos = st.query(1, now);
ans = Math.max(ans, i - pos);
}
return ans;
}
}
###cpp
class Node {
public:
int l = 0, r = 0;
int mn = 0, mx = 0;
int lazy = 0;
};
class SegmentTree {
public:
SegmentTree(int n) {
tr.resize(n << 2);
for (int i = 0; i < tr.size(); ++i) {
tr[i] = new Node();
}
build(1, 0, n);
}
// 区间 [l, r] 全部 +v
void modify(int u, int l, int r, int v) {
if (tr[u]->l >= l && tr[u]->r <= r) {
apply(u, v);
return;
}
pushdown(u);
int mid = (tr[u]->l + tr[u]->r) >> 1;
if (l <= mid) modify(u << 1, l, r, v);
if (r > mid) modify(u << 1 | 1, l, r, v);
pushup(u);
}
// 线段树上二分:找最小 pos 使前缀和 == target
int query(int u, int target) {
if (tr[u]->l == tr[u]->r) {
return tr[u]->l;
}
pushdown(u);
int lc = u << 1, rc = u << 1 | 1;
if (tr[lc]->mn <= target && target <= tr[lc]->mx) {
return query(lc, target);
}
return query(rc, target);
}
private:
vector<Node*> tr;
void build(int u, int l, int r) {
tr[u]->l = l;
tr[u]->r = r;
tr[u]->mn = tr[u]->mx = 0;
tr[u]->lazy = 0;
if (l == r) return;
int mid = (l + r) >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
void apply(int u, int v) {
tr[u]->mn += v;
tr[u]->mx += v;
tr[u]->lazy += v;
}
void pushup(int u) {
tr[u]->mn = min(tr[u << 1]->mn, tr[u << 1 | 1]->mn);
tr[u]->mx = max(tr[u << 1]->mx, tr[u << 1 | 1]->mx);
}
void pushdown(int u) {
if (tr[u]->lazy != 0) {
apply(u << 1, tr[u]->lazy);
apply(u << 1 | 1, tr[u]->lazy);
tr[u]->lazy = 0;
}
}
};
class Solution {
public:
int longestBalanced(vector<int>& nums) {
int n = nums.size();
SegmentTree st(n);
unordered_map<int, int> last;
int now = 0, ans = 0;
// 枚举子数组右端点
for (int i = 1; i <= n; ++i) {
int x = nums[i - 1];
int det = (x & 1) ? 1 : -1;
// 如果该值之前出现过,移除旧贡献
if (last.count(x)) {
st.modify(1, last[x], n, -det);
now -= det;
}
// 添加当前贡献
last[x] = i;
st.modify(1, i, n, det);
now += det;
// 查找最小 pos,使前缀和 == now
int pos = st.query(1, now);
ans = max(ans, i - pos);
}
return ans;
}
};
###go
// 线段树节点
type Node struct {
l, r int // 区间范围
mn, mx int // 当前区间内前缀和最小值 / 最大值
lazy int // 懒标记:区间整体加
}
// 线段树
type SegmentTree struct {
tr []Node
}
// 构造线段树,维护区间 [0, n]
func NewSegmentTree(n int) *SegmentTree {
st := &SegmentTree{
tr: make([]Node, n<<2),
}
st.build(1, 0, n)
return st
}
// 建树:初始所有前缀和为 0
func (st *SegmentTree) build(u, l, r int) {
st.tr[u] = Node{l: l, r: r, mn: 0, mx: 0, lazy: 0}
if l == r {
return
}
mid := (l + r) >> 1
st.build(u<<1, l, mid)
st.build(u<<1|1, mid+1, r)
}
// 区间 [l, r] 整体加 v
func (st *SegmentTree) modify(u, l, r, v int) {
if st.tr[u].l >= l && st.tr[u].r <= r {
st.apply(u, v)
return
}
st.pushdown(u)
mid := (st.tr[u].l + st.tr[u].r) >> 1
if l <= mid {
st.modify(u<<1, l, r, v)
}
if r > mid {
st.modify(u<<1|1, l, r, v)
}
st.pushup(u)
}
// 线段树二分:找到最小位置 pos,使前缀和 == target
func (st *SegmentTree) query(u, target int) int {
if st.tr[u].l == st.tr[u].r {
return st.tr[u].l
}
st.pushdown(u)
left, right := u<<1, u<<1|1
if st.tr[left].mn <= target && target <= st.tr[left].mx {
return st.query(left, target)
}
return st.query(right, target)
}
// 应用懒标记
func (st *SegmentTree) apply(u, v int) {
st.tr[u].mn += v
st.tr[u].mx += v
st.tr[u].lazy += v
}
// 向上更新
func (st *SegmentTree) pushup(u int) {
st.tr[u].mn = min(st.tr[u<<1].mn, st.tr[u<<1|1].mn)
st.tr[u].mx = max(st.tr[u<<1].mx, st.tr[u<<1|1].mx)
}
// 懒标记下推
func (st *SegmentTree) pushdown(u int) {
if st.tr[u].lazy != 0 {
v := st.tr[u].lazy
st.apply(u<<1, v)
st.apply(u<<1|1, v)
st.tr[u].lazy = 0
}
}
// 主函数
func longestBalanced(nums []int) int {
n := len(nums)
st := NewSegmentTree(n)
// 记录每个值最近一次出现的位置
last := make(map[int]int)
now := 0 // 当前前缀和
ans := 0 // 最终答案
// 枚举右端点
for i := 1; i <= n; i++ {
x := nums[i-1]
det := -1
if x&1 == 1 {
det = 1
}
// 若之前出现过,撤销旧贡献
if pos, ok := last[x]; ok {
st.modify(1, pos, n, -det)
now -= det
}
// 添加新贡献
last[x] = i
st.modify(1, i, n, det)
now += det
// 查找最早前缀和等于 now 的位置
pos := st.query(1, now)
ans = max(ans, i-pos)
}
return ans
}
###ts
function longestBalanced(nums: number[]): number {
const n = nums.length;
interface Node {
l: number;
r: number;
mn: number;
mx: number;
lazy: number;
}
const tr: Node[] = Array.from({ length: (n + 1) * 4 }, () => ({
l: 0,
r: 0,
mn: 0,
mx: 0,
lazy: 0,
}));
function build(u: number, l: number, r: number) {
tr[u].l = l;
tr[u].r = r;
if (l === r) return;
const mid = (l + r) >> 1;
build(u << 1, l, mid);
build((u << 1) | 1, mid + 1, r);
}
function apply(u: number, v: number) {
tr[u].mn += v;
tr[u].mx += v;
tr[u].lazy += v;
}
function pushdown(u: number) {
if (tr[u].lazy !== 0) {
apply(u << 1, tr[u].lazy);
apply((u << 1) | 1, tr[u].lazy);
tr[u].lazy = 0;
}
}
function pushup(u: number) {
tr[u].mn = Math.min(tr[u << 1].mn, tr[(u << 1) | 1].mn);
tr[u].mx = Math.max(tr[u << 1].mx, tr[(u << 1) | 1].mx);
}
function modify(u: number, l: number, r: number, v: number) {
if (tr[u].l >= l && tr[u].r <= r) {
apply(u, v);
return;
}
pushdown(u);
const mid = (tr[u].l + tr[u].r) >> 1;
if (l <= mid) modify(u << 1, l, r, v);
if (r > mid) modify((u << 1) | 1, l, r, v);
pushup(u);
}
function query(u: number, target: number): number {
if (tr[u].l === tr[u].r) return tr[u].l;
pushdown(u);
if (tr[u << 1].mn <= target && target <= tr[u << 1].mx) {
return query(u << 1, target);
}
return query((u << 1) | 1, target);
}
build(1, 0, n);
const last = new Map<number, number>();
let now = 0,
ans = 0;
nums.forEach((x, idx) => {
const i = idx + 1;
const det = x & 1 ? 1 : -1;
if (last.has(x)) {
modify(1, last.get(x)!, n, -det);
now -= det;
}
last.set(x, i);
modify(1, i, n, det);
now += det;
const pos = query(1, now);
ans = Math.max(ans, i - pos);
});
return ans;
}
时间复杂度为 $O(n \log n)$,其中 $n$ 为数组长度。每次修改和查询线段树操作 $O(\log n)$,枚举右端点共 $n$ 次,总时间复杂度为 $O(n \log n)$,空间复杂度为 $O(n)$,其中线段树节点和哈希表各占 $O(n)$ 空间。
有任何问题,欢迎评论区交流,欢迎评论区提供其它解题思路(代码),也可以点个赞支持一下作者哈 😄~
给你一个整数数组 nums。
如果子数组中 不同偶数 的数量等于 不同奇数 的数量,则称该 子数组 是 平衡的 。
返回 最长 平衡子数组的长度。
子数组 是数组中连续且 非空 的一段元素序列。
示例 1:
输入: nums = [2,5,4,3]
输出: 4
解释:
[2, 5, 4, 3]。[2, 4] 和 2 个不同的奇数 [5, 3]。因此,答案是 4 。示例 2:
输入: nums = [3,2,2,5,4]
输出: 5
解释:
[3, 2, 2, 5, 4] 。[2, 4] 和 2 个不同的奇数 [3, 5]。因此,答案是 5。示例 3:
输入: nums = [1,2,3,2]
输出: 3
解释:
[2, 3, 2]。[2] 和 1 个不同的奇数 [3]。因此,答案是 3。
提示:
1 <= nums.length <= 1051 <= nums[i] <= 105前置知识
该方法假定读者已经熟练掌握前缀和与线段树的相关知识与应用。
本题的关键突破口是将题意中的“奇数元素种类”和“偶数元素种类”以一种量化的方式转换为数据结构可以处理的问题,具体而言,我们可以设出现一种奇数元素记为 $-1$,出现一种偶数元素记为 $1$,子数组平衡的条件即为转换后所有元素之和为 $0$。
这样转换后,易观察出我们其实得到了一个差分数组,只不过将奇数元素记为 $-1$,偶元素记为 $1$。因此对其计算前缀和,前缀和为 $0$ 时说明对应前缀子数组是平衡的。因此在固定左边界的情况下,最长的平衡子数组的右边界即为该前缀和中最后一个 $0$ 所在的位置。
由于该差分数组的变化量绝对值不超过 $1$,因此前缀和满足离散条件下的介值定理,可以使用线段树寻找最右边的 $0$,具体计算方式如下:
由于满足离散条件下的介值定理,故可以直接通过最大值和最小值判断目标值 $0$ 是否在待搜索区间内,因此也能在 $O(\log n)$ 的时间内搜索完毕。
接下来的思路就是遍历左端点,寻找前缀和对应的最右侧的 $0$ 所在位置,得到最长平衡子数组的长度。设当前左边界下标是 $i$,当前最长平衡子数组长度是 $l$,有一个小优化是搜索的起点可以从 $i + l$ 开始,因为更近的结果即便找到也不能更新答案。
最后一个问题是向右移动左端点的过程中,如何撤销前一个位置的元素对前缀和的贡献。
先让我们从差分与前缀和的定义开始理解:差分数组中某位置 $i$ 的非零值 $v_i$,会累加到该位置及其之后的所有前缀和中。例如,若位置 $1$ 的差分贡献为 $-1$,则它会让 $S_1, S_2, \dots, S_N$ 的值都减小 $1$;再比如,若元素 $x$ 先后出现在位置 $p_1$ 和 $p_2$,我们可以认为位置 $p_1$ 处的 $x$ 负责区间 $[p_1, p_2 - 1]$ 上的贡献,而位置 $p_2$ 处的 $x$ 则负责 $[p_2, \dots]$ 上的贡献。
$$
[ \dots, 0, \underbrace{1, 1, \dots, 1}{\text{由第 1 个 x 贡献}}, \underbrace{1, 1, \dots, 1}{\text{由第 2 个 x 贡献}}, \dots ]
$$
因此,我们可以将每种元素出现的所有位置记录到各自的队列中,在更新左边界时,得到要撤销贡献的元素在前缀和中的贡献区间,然后在该区间上减去它的贡献即可。显然,这样区间加法操作也可以使用线段树完成。
基于以上算法,我们先统计前缀和以及元素出现的次数,然后不断更新左端点,使用线段树维护前缀和,寻找最右侧的 $0$,并更新全局最优解即可。
代码
###C++
struct LazyTag {
int to_add = 0;
LazyTag& operator+=(const LazyTag& other) {
this->to_add += other.to_add;
return *this;
}
bool has_tag() const { return to_add != 0; }
void clear() { to_add = 0; }
};
struct SegmentTreeNode {
int min_value = 0;
int max_value = 0;
// int data = 0; // 只有叶子节点使用, 本题不需要
LazyTag lazy_tag;
};
class SegmentTree {
public:
int n;
vector<SegmentTreeNode> tree;
SegmentTree(const vector<int>& data) : n(data.size()) {
tree.resize(n * 4 + 1);
build(data, 1, n, 1);
}
void add(int l, int r, int val) {
LazyTag tag{val};
update(l, r, tag, 1, n, 1);
}
int find_last(int start, int val) {
if (start > n) {
return -1;
}
return find(start, n, val, 1, n, 1);
}
private:
inline void apply_tag(int i, const LazyTag& tag) {
tree[i].min_value += tag.to_add;
tree[i].max_value += tag.to_add;
tree[i].lazy_tag += tag;
}
inline void pushdown(int i) {
if (tree[i].lazy_tag.has_tag()) {
LazyTag tag = tree[i].lazy_tag;
apply_tag(i << 1, tag);
apply_tag(i << 1 | 1, tag);
tree[i].lazy_tag.clear();
}
}
inline void pushup(int i) {
tree[i].min_value =
std::min(tree[i << 1].min_value, tree[i << 1 | 1].min_value);
tree[i].max_value =
std::max(tree[i << 1].max_value, tree[i << 1 | 1].max_value);
}
void build(const vector<int>& data, int l, int r, int i) {
if (l == r) {
tree[i].min_value = tree[i].max_value = data[l - 1];
return;
}
int mid = l + ((r - l) >> 1);
build(data, l, mid, i << 1);
build(data, mid + 1, r, i << 1 | 1);
pushup(i);
}
void update(int target_l, int target_r, const LazyTag& tag, int l, int r,
int i) {
if (target_l <= l && r <= target_r) {
apply_tag(i, tag);
return;
}
pushdown(i);
int mid = l + ((r - l) >> 1);
if (target_l <= mid)
update(target_l, target_r, tag, l, mid, i << 1);
if (target_r > mid)
update(target_l, target_r, tag, mid + 1, r, i << 1 | 1);
pushup(i);
}
int find(int target_l, int target_r, int val, int l, int r, int i) {
if (tree[i].min_value > val || tree[i].max_value < val) {
return -1;
}
// 根据介值定理,此时区间内必然存在解
if (l == r) {
return l;
}
pushdown(i);
int mid = l + ((r - l) >> 1);
// target_l 一定小于等于 r(=n)
if (target_r >= mid + 1) {
int res = find(target_l, target_r, val, mid + 1, r, i << 1 | 1);
if (res != -1)
return res;
}
if (l <= target_r && mid >= target_l) {
return find(target_l, target_r, val, l, mid, i << 1);
}
return -1;
}
};
class Solution {
public:
int longestBalanced(vector<int>& nums) {
map<int, queue<int>> occurrences;
auto sgn = [](int x) { return (x % 2) == 0 ? 1 : -1; };
int len = 0;
vector<int> prefix_sum(nums.size(), 0);
prefix_sum[0] = sgn(nums[0]);
occurrences[nums[0]].push(1);
for (int i = 1; i < nums.size(); i++) {
prefix_sum[i] = prefix_sum[i - 1];
auto& occ = occurrences[nums[i]];
if (occ.empty()) {
prefix_sum[i] += sgn(nums[i]);
}
occ.push(i + 1);
}
SegmentTree seg(prefix_sum);
for (int i = 0; i < nums.size(); i++) {
len = std::max(len, seg.find_last(i + len, 0) - i);
auto next_pos = nums.size() + 1;
occurrences[nums[i]].pop();
if (!occurrences[nums[i]].empty()) {
next_pos = occurrences[nums[i]].front();
}
seg.add(i + 1, next_pos - 1, -sgn(nums[i]));
}
return len;
}
};
###JavaScript
class LazyTag {
constructor() {
this.toAdd = 0;
}
add(other) {
this.toAdd += other.toAdd;
return this;
}
hasTag() {
return this.toAdd !== 0;
}
clear() {
this.toAdd = 0;
}
}
class SegmentTreeNode {
constructor() {
this.minValue = 0;
this.maxValue = 0;
// int data = 0; // 只有叶子节点使用, 本题不需要
this.lazyTag = new LazyTag();
}
}
class SegmentTree {
constructor(data) {
this.n = data.length;
this.tree = new Array(this.n * 4 + 1).fill(null).map(() => new SegmentTreeNode());
this.build(data, 1, this.n, 1);
}
add(l, r, val) {
const tag = new LazyTag();
tag.toAdd = val;
this.update(l, r, tag, 1, this.n, 1);
}
findLast(start, val) {
if (start > this.n) {
return -1;
}
return this.find(start, this.n, val, 1, this.n, 1);
}
applyTag(i, tag) {
this.tree[i].minValue += tag.toAdd;
this.tree[i].maxValue += tag.toAdd;
this.tree[i].lazyTag.add(tag);
}
pushdown(i) {
if (this.tree[i].lazyTag.hasTag()) {
const tag = new LazyTag();
tag.toAdd = this.tree[i].lazyTag.toAdd;
this.applyTag(i << 1, tag);
this.applyTag((i << 1) | 1, tag);
this.tree[i].lazyTag.clear();
}
}
pushup(i) {
this.tree[i].minValue = Math.min(this.tree[i << 1].minValue, this.tree[(i << 1) | 1].minValue);
this.tree[i].maxValue = Math.max(this.tree[i << 1].maxValue, this.tree[(i << 1) | 1].maxValue);
}
build(data, l, r, i) {
if (l == r) {
this.tree[i].minValue = this.tree[i].maxValue = data[l - 1];
return;
}
const mid = l + ((r - l) >> 1);
this.build(data, l, mid, i << 1);
this.build(data, mid + 1, r, (i << 1) | 1);
this.pushup(i);
}
update(targetL, targetR, tag, l, r, i) {
if (targetL <= l && r <= targetR) {
this.applyTag(i, tag);
return;
}
this.pushdown(i);
const mid = l + ((r - l) >> 1);
if (targetL <= mid)
this.update(targetL, targetR, tag, l, mid, i << 1);
if (targetR > mid)
this.update(targetL, targetR, tag, mid + 1, r, (i << 1) | 1);
this.pushup(i);
}
find(targetL, targetR, val, l, r, i) {
if (this.tree[i].minValue > val || this.tree[i].maxValue < val) {
return -1;
}
// 根据介值定理,此时区间内必然存在解
if (l == r) {
return l;
}
this.pushdown(i);
const mid = l + ((r - l) >> 1);
// targetL 一定小于等于 r(=n)
if (targetR >= mid + 1) {
const res = this.find(targetL, targetR, val, mid + 1, r, (i << 1) | 1);
if (res != -1)
return res;
}
if (l <= targetR && mid >= targetL) {
return this.find(targetL, targetR, val, l, mid, i << 1);
}
return -1;
}
}
var longestBalanced = function(nums) {
const occurrences = new Map();
const sgn = (x) => (x % 2 == 0 ? 1 : -1);
let len = 0;
const prefixSum = new Array(nums.length).fill(0);
prefixSum[0] = sgn(nums[0]);
if (!occurrences.has(nums[0])) occurrences.set(nums[0], new Queue());
occurrences.get(nums[0]).push(1);
for (let i = 1; i < nums.length; i++) {
prefixSum[i] = prefixSum[i - 1];
if (!occurrences.has(nums[i]))
occurrences.set(nums[i], new Queue());
const occ = occurrences.get(nums[i]);
if (occ.size() === 0) {
prefixSum[i] += sgn(nums[i]);
}
occ.push(i + 1);
}
const seg = new SegmentTree(prefixSum);
for (let i = 0; i < nums.length; i++) {
len = Math.max(len, seg.findLast(i + len, 0) - i);
let nextPos = nums.length + 1;
const occ = occurrences.get(nums[i]);
occ.pop();
if (occ.size() > 0) {
nextPos = occ.front();
}
seg.add(i + 1, nextPos - 1, -sgn(nums[i]));
}
return len;
}
###TypeScript
class LazyTag {
toAdd: number = 0;
add(other: LazyTag): LazyTag {
this.toAdd += other.toAdd;
return this;
}
hasTag(): boolean {
return this.toAdd !== 0;
}
clear(): void {
this.toAdd = 0;
}
}
class SegmentTreeNode {
minValue: number = 0;
maxValue: number = 0;
// int data = 0; // 只有叶子节点使用, 本题不需要
lazyTag: LazyTag = new LazyTag();
}
class SegmentTree {
n: number;
tree: SegmentTreeNode[];
constructor(data: number[]) {
this.n = data.length;
this.tree = new Array(this.n * 4 + 1).fill(null).map(() => new SegmentTreeNode());
this.build(data, 1, this.n, 1);
}
add(l: number, r: number, val: number): void {
const tag = new LazyTag();
tag.toAdd = val;
this.update(l, r, tag, 1, this.n, 1);
}
findLast(start: number, val: number): number {
if (start > this.n) {
return -1;
}
return this.find(start, this.n, val, 1, this.n, 1);
}
private applyTag(i: number, tag: LazyTag): void {
this.tree[i].minValue += tag.toAdd;
this.tree[i].maxValue += tag.toAdd;
this.tree[i].lazyTag.add(tag);
}
private pushdown(i: number): void {
if (this.tree[i].lazyTag.hasTag()) {
const tag = new LazyTag();
tag.toAdd = this.tree[i].lazyTag.toAdd;
this.applyTag(i << 1, tag);
this.applyTag((i << 1) | 1, tag);
this.tree[i].lazyTag.clear();
}
}
private pushup(i: number): void {
this.tree[i].minValue = Math.min(
this.tree[i << 1].minValue,
this.tree[(i << 1) | 1].minValue,
);
this.tree[i].maxValue = Math.max(
this.tree[i << 1].maxValue,
this.tree[(i << 1) | 1].maxValue,
);
}
private build(data: number[], l: number, r: number, i: number): void {
if (l == r) {
this.tree[i].minValue = this.tree[i].maxValue = data[l - 1];
return;
}
const mid = l + ((r - l) >> 1);
this.build(data, l, mid, i << 1);
this.build(data, mid + 1, r, (i << 1) | 1);
this.pushup(i);
}
private update(
targetL: number,
targetR: number,
tag: LazyTag,
l: number,
r: number,
i: number,
): void {
if (targetL <= l && r <= targetR) {
this.applyTag(i, tag);
return;
}
this.pushdown(i);
const mid = l + ((r - l) >> 1);
if (targetL <= mid) this.update(targetL, targetR, tag, l, mid, i << 1);
if (targetR > mid) this.update(targetL, targetR, tag, mid + 1, r, (i << 1) | 1);
this.pushup(i);
}
private find(
targetL: number,
targetR: number,
val: number,
l: number,
r: number,
i: number,
): number {
if (this.tree[i].minValue > val || this.tree[i].maxValue < val) {
return -1;
}
// 根据介值定理,此时区间内必然存在解
if (l == r) {
return l;
}
this.pushdown(i);
const mid = l + ((r - l) >> 1);
// targetL 一定小于等于 r(=n)
if (targetR >= mid + 1) {
const res = this.find(targetL, targetR, val, mid + 1, r, (i << 1) | 1);
if (res != -1) return res;
}
if (l <= targetR && mid >= targetL) {
return this.find(targetL, targetR, val, l, mid, i << 1);
}
return -1;
}
}
function longestBalanced(nums: number[]): number {
const occurrences = new Map<number, Queue<number>>();
const sgn = (x: number) => (x % 2 == 0 ? 1 : -1);
let len = 0;
const prefixSum: number[] = new Array(nums.length).fill(0);
prefixSum[0] = sgn(nums[0]);
if (!occurrences.has(nums[0])) occurrences.set(nums[0], new Queue());
occurrences.get(nums[0])!.push(1);
for (let i = 1; i < nums.length; i++) {
prefixSum[i] = prefixSum[i - 1];
if (!occurrences.has(nums[i])) occurrences.set(nums[i], new Queue());
const occ = occurrences.get(nums[i])!;
if (occ.size() === 0) {
prefixSum[i] += sgn(nums[i]);
}
occ.push(i + 1);
}
const seg = new SegmentTree(prefixSum);
for (let i = 0; i < nums.length; i++) {
len = Math.max(len, seg.findLast(i + len, 0) - i);
let nextPos = nums.length + 1;
const occ = occurrences.get(nums[i])!;
occ.pop();
if (occ.size() > 0) {
nextPos = occ.front();
}
seg.add(i + 1, nextPos - 1, -sgn(nums[i]));
}
return len;
}
###Java
class LazyTag {
int toAdd;
LazyTag() {
this.toAdd = 0;
}
LazyTag add(LazyTag other) {
this.toAdd += other.toAdd;
return this;
}
boolean hasTag() {
return this.toAdd != 0;
}
void clear() {
this.toAdd = 0;
}
}
class SegmentTreeNode {
int minValue;
int maxValue;
LazyTag lazyTag;
SegmentTreeNode() {
this.minValue = 0;
this.maxValue = 0;
this.lazyTag = new LazyTag();
}
}
class SegmentTree {
private int n;
private SegmentTreeNode[] tree;
SegmentTree(int[] data) {
this.n = data.length;
this.tree = new SegmentTreeNode[this.n * 4 + 1];
for (int i = 0; i < tree.length; i++) {
tree[i] = new SegmentTreeNode();
}
build(data, 1, this.n, 1);
}
void add(int l, int r, int val) {
LazyTag tag = new LazyTag();
tag.toAdd = val;
update(l, r, tag, 1, this.n, 1);
}
int findLast(int start, int val) {
if (start > this.n) {
return -1;
}
return find(start, this.n, val, 1, this.n, 1);
}
private void applyTag(int i, LazyTag tag) {
tree[i].minValue += tag.toAdd;
tree[i].maxValue += tag.toAdd;
tree[i].lazyTag.add(tag);
}
private void pushdown(int i) {
if (tree[i].lazyTag.hasTag()) {
LazyTag tag = new LazyTag();
tag.toAdd = tree[i].lazyTag.toAdd;
applyTag(i << 1, tag);
applyTag((i << 1) | 1, tag);
tree[i].lazyTag.clear();
}
}
private void pushup(int i) {
tree[i].minValue = Math.min(tree[i << 1].minValue, tree[(i << 1) | 1].minValue);
tree[i].maxValue = Math.max(tree[i << 1].maxValue, tree[(i << 1) | 1].maxValue);
}
private void build(int[] data, int l, int r, int i) {
if (l == r) {
tree[i].minValue = tree[i].maxValue = data[l - 1];
return;
}
int mid = l + ((r - l) >> 1);
build(data, l, mid, i << 1);
build(data, mid + 1, r, (i << 1) | 1);
pushup(i);
}
private void update(int targetL, int targetR, LazyTag tag, int l, int r, int i) {
if (targetL <= l && r <= targetR) {
applyTag(i, tag);
return;
}
pushdown(i);
int mid = l + ((r - l) >> 1);
if (targetL <= mid)
update(targetL, targetR, tag, l, mid, i << 1);
if (targetR > mid)
update(targetL, targetR, tag, mid + 1, r, (i << 1) | 1);
pushup(i);
}
private int find(int targetL, int targetR, int val, int l, int r, int i) {
if (tree[i].minValue > val || tree[i].maxValue < val) {
return -1;
}
if (l == r) {
return l;
}
pushdown(i);
int mid = l + ((r - l) >> 1);
if (targetR >= mid + 1) {
int res = find(targetL, targetR, val, mid + 1, r, (i << 1) | 1);
if (res != -1)
return res;
}
if (l <= targetR && mid >= targetL) {
return find(targetL, targetR, val, l, mid, i << 1);
}
return -1;
}
}
class Solution {
public int longestBalanced(int[] nums) {
Map<Integer, Queue<Integer>> occurrences = new HashMap<>();
int len = 0;
int[] prefixSum = new int[nums.length];
prefixSum[0] = sgn(nums[0]);
occurrences.computeIfAbsent(nums[0], k -> new LinkedList<>()).add(1);
for (int i = 1; i < nums.length; i++) {
prefixSum[i] = prefixSum[i - 1];
Queue<Integer> occ = occurrences.computeIfAbsent(nums[i], k -> new LinkedList<>());
if (occ.isEmpty()) {
prefixSum[i] += sgn(nums[i]);
}
occ.add(i + 1);
}
SegmentTree seg = new SegmentTree(prefixSum);
for (int i = 0; i < nums.length; i++) {
len = Math.max(len, seg.findLast(i + len, 0) - i);
int nextPos = nums.length + 1;
occurrences.get(nums[i]).poll();
if (!occurrences.get(nums[i]).isEmpty()) {
nextPos = occurrences.get(nums[i]).peek();
}
seg.add(i + 1, nextPos - 1, -sgn(nums[i]));
}
return len;
}
private int sgn(int x) {
return (x % 2) == 0 ? 1 : -1;
}
}
###C#
public class LazyTag {
public int toAdd;
public LazyTag() {
this.toAdd = 0;
}
public LazyTag Add(LazyTag other) {
this.toAdd += other.toAdd;
return this;
}
public bool HasTag() {
return this.toAdd != 0;
}
public void Clear() {
this.toAdd = 0;
}
}
public class SegmentTreeNode {
public int minValue;
public int maxValue;
public LazyTag lazyTag;
public SegmentTreeNode() {
this.minValue = 0;
this.maxValue = 0;
this.lazyTag = new LazyTag();
}
}
public class SegmentTree {
private int n;
private SegmentTreeNode[] tree;
public SegmentTree(int[] data) {
this.n = data.Length;
this.tree = new SegmentTreeNode[this.n * 4 + 1];
for (int i = 0; i < tree.Length; i++) {
tree[i] = new SegmentTreeNode();
}
Build(data, 1, this.n, 1);
}
public void Add(int l, int r, int val) {
LazyTag tag = new LazyTag();
tag.toAdd = val;
Update(l, r, tag, 1, this.n, 1);
}
public int FindLast(int start, int val) {
if (start > this.n) {
return -1;
}
return Find(start, this.n, val, 1, this.n, 1);
}
private void ApplyTag(int i, LazyTag tag) {
tree[i].minValue += tag.toAdd;
tree[i].maxValue += tag.toAdd;
tree[i].lazyTag.Add(tag);
}
private void Pushdown(int i) {
if (tree[i].lazyTag.HasTag()) {
LazyTag tag = new LazyTag();
tag.toAdd = tree[i].lazyTag.toAdd;
ApplyTag(i << 1, tag);
ApplyTag((i << 1) | 1, tag);
tree[i].lazyTag.Clear();
}
}
private void Pushup(int i) {
tree[i].minValue = Math.Min(tree[i << 1].minValue, tree[(i << 1) | 1].minValue);
tree[i].maxValue = Math.Max(tree[i << 1].maxValue, tree[(i << 1) | 1].maxValue);
}
private void Build(int[] data, int l, int r, int i) {
if (l == r) {
tree[i].minValue = tree[i].maxValue = data[l - 1];
return;
}
int mid = l + ((r - l) >> 1);
Build(data, l, mid, i << 1);
Build(data, mid + 1, r, (i << 1) | 1);
Pushup(i);
}
private void Update(int targetL, int targetR, LazyTag tag, int l, int r, int i) {
if (targetL <= l && r <= targetR) {
ApplyTag(i, tag);
return;
}
Pushdown(i);
int mid = l + ((r - l) >> 1);
if (targetL <= mid)
Update(targetL, targetR, tag, l, mid, i << 1);
if (targetR > mid)
Update(targetL, targetR, tag, mid + 1, r, (i << 1) | 1);
Pushup(i);
}
private int Find(int targetL, int targetR, int val, int l, int r, int i) {
if (tree[i].minValue > val || tree[i].maxValue < val) {
return -1;
}
if (l == r) {
return l;
}
Pushdown(i);
int mid = l + ((r - l) >> 1);
if (targetR >= mid + 1) {
int res = Find(targetL, targetR, val, mid + 1, r, (i << 1) | 1);
if (res != -1)
return res;
}
if (l <= targetR && mid >= targetL) {
return Find(targetL, targetR, val, l, mid, i << 1);
}
return -1;
}
}
public class Solution {
public int LongestBalanced(int[] nums) {
var occurrences = new Dictionary<int, Queue<int>>();
int len = 0;
int[] prefixSum = new int[nums.Length];
prefixSum[0] = Sgn(nums[0]);
if (!occurrences.ContainsKey(nums[0])) {
occurrences[nums[0]] = new Queue<int>();
}
occurrences[nums[0]].Enqueue(1);
for (int i = 1; i < nums.Length; i++) {
prefixSum[i] = prefixSum[i - 1];
if (!occurrences.ContainsKey(nums[i])) {
occurrences[nums[i]] = new Queue<int>();
}
var occ = occurrences[nums[i]];
if (occ.Count == 0) {
prefixSum[i] += Sgn(nums[i]);
}
occ.Enqueue(i + 1);
}
var seg = new SegmentTree(prefixSum);
for (int i = 0; i < nums.Length; i++) {
len = Math.Max(len, seg.FindLast(i + len, 0) - i);
int nextPos = nums.Length + 1;
occurrences[nums[i]].Dequeue();
if (occurrences[nums[i]].Count > 0) {
nextPos = occurrences[nums[i]].Peek();
}
seg.Add(i + 1, nextPos - 1, -Sgn(nums[i]));
}
return len;
}
private int Sgn(int x) {
return (x % 2) == 0 ? 1 : -1;
}
}
###Python
class LazyTag:
def __init__(self):
self.to_add = 0
def add(self, other):
self.to_add += other.to_add
return self
def has_tag(self):
return self.to_add != 0
def clear(self):
self.to_add = 0
class SegmentTreeNode:
def __init__(self):
self.min_value = 0
self.max_value = 0
self.lazy_tag = LazyTag()
class SegmentTree:
def __init__(self, data):
self.n = len(data)
self.tree = [SegmentTreeNode() for _ in range(self.n * 4 + 1)]
self._build(data, 1, self.n, 1)
def add(self, l, r, val):
tag = LazyTag()
tag.to_add = val
self._update(l, r, tag, 1, self.n, 1)
def find_last(self, start, val):
if start > self.n:
return -1
return self._find(start, self.n, val, 1, self.n, 1)
def _apply_tag(self, i, tag):
self.tree[i].min_value += tag.to_add
self.tree[i].max_value += tag.to_add
self.tree[i].lazy_tag.add(tag)
def _pushdown(self, i):
if self.tree[i].lazy_tag.has_tag():
tag = LazyTag()
tag.to_add = self.tree[i].lazy_tag.to_add
self._apply_tag(i << 1, tag)
self._apply_tag((i << 1) | 1, tag)
self.tree[i].lazy_tag.clear()
def _pushup(self, i):
self.tree[i].min_value = min(self.tree[i << 1].min_value,
self.tree[(i << 1) | 1].min_value)
self.tree[i].max_value = max(self.tree[i << 1].max_value,
self.tree[(i << 1) | 1].max_value)
def _build(self, data, l, r, i):
if l == r:
self.tree[i].min_value = data[l - 1]
self.tree[i].max_value = data[l - 1]
return
mid = l + ((r - l) >> 1)
self._build(data, l, mid, i << 1)
self._build(data, mid + 1, r, (i << 1) | 1)
self._pushup(i)
def _update(self, target_l, target_r, tag, l, r, i):
if target_l <= l and r <= target_r:
self._apply_tag(i, tag)
return
self._pushdown(i)
mid = l + ((r - l) >> 1)
if target_l <= mid:
self._update(target_l, target_r, tag, l, mid, i << 1)
if target_r > mid:
self._update(target_l, target_r, tag, mid + 1, r, (i << 1) | 1)
self._pushup(i)
def _find(self, target_l, target_r, val, l, r, i):
if self.tree[i].min_value > val or self.tree[i].max_value < val:
return -1
if l == r:
return l
self._pushdown(i)
mid = l + ((r - l) >> 1)
if target_r >= mid + 1:
res = self._find(target_l, target_r, val, mid + 1, r, (i << 1) | 1)
if res != -1:
return res
if l <= target_r and mid >= target_l:
return self._find(target_l, target_r, val, l, mid, i << 1)
return -1
class Solution:
def longestBalanced(self, nums: List[int]) -> int:
occurrences = defaultdict(deque)
def sgn(x):
return 1 if x % 2 == 0 else -1
length = 0
prefix_sum = [0] * len(nums)
prefix_sum[0] = sgn(nums[0])
occurrences[nums[0]].append(1)
for i in range(1, len(nums)):
prefix_sum[i] = prefix_sum[i - 1]
occ = occurrences[nums[i]]
if not occ:
prefix_sum[i] += sgn(nums[i])
occ.append(i + 1)
seg = SegmentTree(prefix_sum)
for i in range(len(nums)):
length = max(length, seg.find_last(i + length, 0) - i)
next_pos = len(nums) + 1
occurrences[nums[i]].popleft()
if occurrences[nums[i]]:
next_pos = occurrences[nums[i]][0]
seg.add(i + 1, next_pos - 1, -sgn(nums[i]))
return length
###Go
type LazyTag struct {
toAdd int
}
func (l *LazyTag) Add(other *LazyTag) *LazyTag {
l.toAdd += other.toAdd
return l
}
func (l *LazyTag) HasTag() bool {
return l.toAdd != 0
}
func (l *LazyTag) Clear() {
l.toAdd = 0
}
type SegmentTreeNode struct {
minValue int
maxValue int
lazyTag *LazyTag
}
func NewSegmentTreeNode() *SegmentTreeNode {
return &SegmentTreeNode{
minValue: 0,
maxValue: 0,
lazyTag: &LazyTag{},
}
}
type SegmentTree struct {
n int
tree []*SegmentTreeNode
}
func NewSegmentTree(data []int) *SegmentTree {
n := len(data)
tree := make([]*SegmentTreeNode, n*4+1)
for i := range tree {
tree[i] = NewSegmentTreeNode()
}
seg := &SegmentTree{n: n, tree: tree}
seg.build(data, 1, n, 1)
return seg
}
func (seg *SegmentTree) Add(l, r, val int) {
tag := &LazyTag{toAdd: val}
seg.update(l, r, tag, 1, seg.n, 1)
}
func (seg *SegmentTree) FindLast(start, val int) int {
if start > seg.n {
return -1
}
return seg.find(start, seg.n, val, 1, seg.n, 1)
}
func (seg *SegmentTree) applyTag(i int, tag *LazyTag) {
seg.tree[i].minValue += tag.toAdd
seg.tree[i].maxValue += tag.toAdd
seg.tree[i].lazyTag.Add(tag)
}
func (seg *SegmentTree) pushdown(i int) {
if seg.tree[i].lazyTag.HasTag() {
tag := &LazyTag{toAdd: seg.tree[i].lazyTag.toAdd}
seg.applyTag(i<<1, tag)
seg.applyTag((i<<1)|1, tag)
seg.tree[i].lazyTag.Clear()
}
}
func (seg *SegmentTree) pushup(i int) {
left := seg.tree[i<<1]
right := seg.tree[(i<<1)|1]
seg.tree[i].minValue = min(left.minValue, right.minValue)
seg.tree[i].maxValue = max(left.maxValue, right.maxValue)
}
func (seg *SegmentTree) build(data []int, l, r, i int) {
if l == r {
seg.tree[i].minValue = data[l-1]
seg.tree[i].maxValue = data[l-1]
return
}
mid := l + ((r - l) >> 1)
seg.build(data, l, mid, i<<1)
seg.build(data, mid+1, r, (i<<1)|1)
seg.pushup(i)
}
func (seg *SegmentTree) update(targetL, targetR int, tag *LazyTag, l, r, i int) {
if targetL <= l && r <= targetR {
seg.applyTag(i, tag)
return
}
seg.pushdown(i)
mid := l + ((r - l) >> 1)
if targetL <= mid {
seg.update(targetL, targetR, tag, l, mid, i<<1)
}
if targetR > mid {
seg.update(targetL, targetR, tag, mid+1, r, (i<<1)|1)
}
seg.pushup(i)
}
func (seg *SegmentTree) find(targetL, targetR, val, l, r, i int) int {
if seg.tree[i].minValue > val || seg.tree[i].maxValue < val {
return -1
}
if l == r {
return l
}
seg.pushdown(i)
mid := l + ((r - l) >> 1)
if targetR >= mid+1 {
res := seg.find(targetL, targetR, val, mid+1, r, (i<<1)|1)
if res != -1 {
return res
}
}
if l <= targetR && mid >= targetL {
return seg.find(targetL, targetR, val, l, mid, i<<1)
}
return -1
}
func longestBalanced(nums []int) int {
occurrences := make(map[int][]int)
sgn := func(x int) int {
if x%2 == 0 {
return 1
}
return -1
}
length := 0
prefixSum := make([]int, len(nums))
prefixSum[0] = sgn(nums[0])
occurrences[nums[0]] = append(occurrences[nums[0]], 1)
for i := 1; i < len(nums); i++ {
prefixSum[i] = prefixSum[i-1]
occ := occurrences[nums[i]]
if len(occ) == 0 {
prefixSum[i] += sgn(nums[i])
}
occurrences[nums[i]] = append(occ, i+1)
}
seg := NewSegmentTree(prefixSum)
for i := 0; i < len(nums); i++ {
length = max(length, seg.FindLast(i+length, 0)-i)
nextPos := len(nums) + 1
occurrences[nums[i]] = occurrences[nums[i]][1:]
if len(occurrences[nums[i]]) > 0 {
nextPos = occurrences[nums[i]][0]
}
seg.Add(i+1, nextPos-1, -sgn(nums[i]))
}
return length
}
###C
typedef struct ListNode ListNode;
typedef struct {
ListNode *head;
int size;
} List;
typedef struct {
int key;
List *val;
UT_hash_handle hh;
} HashItem;
List* listCreate() {
List *list = (List*)malloc(sizeof(List));
list->head = NULL;
list->size = 0;
return list;
}
void listPush(List *list, int val) {
ListNode *node = (ListNode*)malloc(sizeof(ListNode));
node->val = val;
node->next = list->head;
list->head = node;
list->size++;
}
void listPop(List *list) {
if (list->head == NULL) return;
ListNode *temp = list->head;
list->head = list->head->next;
free(temp);
list->size--;
}
int listAt(List *list, int index) {
ListNode *cur = list->head;
for (int i = 0; i < index && cur != NULL; i++) {
cur = cur->next;
}
return cur ? cur->val : -1;
}
void listReverse(List *list) {
ListNode *prev = NULL;
ListNode *cur = list->head;
ListNode *next = NULL;
while (cur != NULL) {
next = cur->next;
cur->next = prev;
prev = cur;
cur = next;
}
list->head = prev;
}
void listFree(List *list) {
while (list->head != NULL) {
listPop(list);
}
free(list);
}
HashItem* hashFindItem(HashItem **obj, int key) {
HashItem *pEntry = NULL;
HASH_FIND_INT(*obj, &key, pEntry);
return pEntry;
}
bool hashAddItem(HashItem **obj, int key, List *val) {
if (hashFindItem(obj, key)) {
return false;
}
HashItem *pEntry = (HashItem*)malloc(sizeof(HashItem));
pEntry->key = key;
pEntry->val = val;
HASH_ADD_INT(*obj, key, pEntry);
return true;
}
List* hashGetItem(HashItem **obj, int key) {
HashItem *pEntry = hashFindItem(obj, key);
if (!pEntry) {
List *newList = listCreate();
hashAddItem(obj, key, newList);
return newList;
}
return pEntry->val;
}
void hashFree(HashItem **obj) {
HashItem *curr = NULL, *tmp = NULL;
HASH_ITER(hh, *obj, curr, tmp) {
HASH_DEL(*obj, curr);
listFree(curr->val);
free(curr);
}
}
void hashIterate(HashItem **obj, void (*callback)(HashItem *item)) {
HashItem *curr = NULL, *tmp = NULL;
HASH_ITER(hh, *obj, curr, tmp) {
callback(curr);
}
}
typedef struct {
int toAdd;
} LazyTag;
void lazyTagAdd(LazyTag *tag, LazyTag *other) {
tag->toAdd += other->toAdd;
}
bool lazyTagHasTag(LazyTag *tag) {
return tag->toAdd != 0;
}
void lazyTagClear(LazyTag *tag) {
tag->toAdd = 0;
}
typedef struct {
int minValue;
int maxValue;
LazyTag lazyTag;
} SegmentTreeNode;
typedef struct {
int n;
SegmentTreeNode *tree;
} SegmentTree;
void segmentTreeApplyTag(SegmentTree *seg, int i, LazyTag *tag) {
seg->tree[i].minValue += tag->toAdd;
seg->tree[i].maxValue += tag->toAdd;
lazyTagAdd(&seg->tree[i].lazyTag, tag);
}
void segmentTreePushdown(SegmentTree *seg, int i) {
if (lazyTagHasTag(&seg->tree[i].lazyTag)) {
LazyTag tag = {seg->tree[i].lazyTag.toAdd};
segmentTreeApplyTag(seg, i << 1, &tag);
segmentTreeApplyTag(seg, (i << 1) | 1, &tag);
lazyTagClear(&seg->tree[i].lazyTag);
}
}
void segmentTreePushup(SegmentTree *seg, int i) {
seg->tree[i].minValue = fmin(seg->tree[i << 1].minValue, seg->tree[(i << 1) | 1].minValue);
seg->tree[i].maxValue = fmax(seg->tree[i << 1].maxValue, seg->tree[(i << 1) | 1].maxValue);
}
void segmentTreeBuild(SegmentTree *seg, int *data, int l, int r, int i) {
if (l == r) {
seg->tree[i].minValue = seg->tree[i].maxValue = data[l - 1];
return;
}
int mid = l + ((r - l) >> 1);
segmentTreeBuild(seg, data, l, mid, i << 1);
segmentTreeBuild(seg, data, mid + 1, r, (i << 1) | 1);
segmentTreePushup(seg, i);
}
void segmentTreeUpdate(SegmentTree *seg, int targetL, int targetR, LazyTag *tag,
int l, int r, int i) {
if (targetL <= l && r <= targetR) {
segmentTreeApplyTag(seg, i, tag);
return;
}
segmentTreePushdown(seg, i);
int mid = l + ((r - l) >> 1);
if (targetL <= mid) {
segmentTreeUpdate(seg, targetL, targetR, tag, l, mid, i << 1);
}
if (targetR > mid) {
segmentTreeUpdate(seg, targetL, targetR, tag, mid + 1, r, (i << 1) | 1);
}
segmentTreePushup(seg, i);
}
int segmentTreeFind(SegmentTree *seg, int targetL, int targetR, int val,
int l, int r, int i) {
if (seg->tree[i].minValue > val || seg->tree[i].maxValue < val) {
return -1;
}
if (l == r) {
return l;
}
segmentTreePushdown(seg, i);
int mid = l + ((r - l) >> 1);
if (targetR >= mid + 1) {
int res = segmentTreeFind(seg, targetL, targetR, val, mid + 1, r, (i << 1) | 1);
if (res != -1) {
return res;
}
}
if (targetL <= mid) {
return segmentTreeFind(seg, targetL, targetR, val, l, mid, i << 1);
}
return -1;
}
SegmentTree* segmentTreeCreate(int *data, int n) {
SegmentTree *seg = (SegmentTree*)malloc(sizeof(SegmentTree));
seg->n = n;
seg->tree = (SegmentTreeNode*)calloc(n * 4 + 1, sizeof(SegmentTreeNode));
segmentTreeBuild(seg, data, 1, n, 1);
return seg;
}
void segmentTreeAdd(SegmentTree *seg, int l, int r, int val) {
LazyTag tag = {val};
segmentTreeUpdate(seg, l, r, &tag, 1, seg->n, 1);
}
int segmentTreeFindLast(SegmentTree *seg, int start, int val) {
if (start > seg->n) {
return -1;
}
return segmentTreeFind(seg, start, seg->n, val, 1, seg->n, 1);
}
void segmentTreeFree(SegmentTree *seg) {
free(seg->tree);
free(seg);
}
int sgn(int x) {
return (x % 2 == 0) ? 1 : -1;
}
void reverseList(HashItem *item) {
listReverse(item->val);
}
int longestBalanced(int* nums, int numsSize) {
HashItem *occurrences = NULL;
int len = 0;
int *prefixSum = (int*)calloc(numsSize, sizeof(int));
prefixSum[0] = sgn(nums[0]);
List *list0 = hashGetItem(&occurrences, nums[0]);
listPush(list0, 1);
for (int i = 1; i < numsSize; i++) {
prefixSum[i] = prefixSum[i - 1];
List *occ = hashGetItem(&occurrences, nums[i]);
if (occ->size == 0) {
prefixSum[i] += sgn(nums[i]);
}
listPush(occ, i + 1);
}
hashIterate(&occurrences, reverseList);
SegmentTree *seg = segmentTreeCreate(prefixSum, numsSize);
for (int i = 0; i < numsSize; i++) {
int findResult = segmentTreeFindLast(seg, i + len, 0);
int newLen = findResult - i;
if (newLen > len) {
len = newLen;
}
int nextPos = numsSize + 1;
List *occ = hashGetItem(&occurrences, nums[i]);
listPop(occ);
if (occ->size > 0) {
nextPos = listAt(occ, 0);
}
segmentTreeAdd(seg, i + 1, nextPos - 1, -sgn(nums[i]));
}
segmentTreeFree(seg);
free(prefixSum);
hashFree(&occurrences);
return len;
}
###Rust
use std::collections::{HashMap, VecDeque};
use std::cmp::max;
#[derive(Debug, Clone, Copy)]
struct LazyTag {
add: i32,
}
impl LazyTag {
fn new() -> Self {
LazyTag { add: 0 }
}
fn is_empty(&self) -> bool {
self.add == 0
}
fn combine(&mut self, other: &LazyTag) {
self.add += other.add;
}
fn clear(&mut self) {
self.add = 0;
}
}
#[derive(Debug, Clone)]
struct Node {
min_val: i32,
max_val: i32,
lazy: LazyTag,
}
impl Node {
fn new() -> Self {
Node {
min_val: 0,
max_val: 0,
lazy: LazyTag::new(),
}
}
}
struct SegmentTree {
n: usize,
tree: Vec<Node>,
}
impl SegmentTree {
fn new(data: &[i32]) -> Self {
let n = data.len();
let mut tree = vec![Node::new(); 4 * n];
let mut seg = SegmentTree { n, tree };
seg.build(data, 1, n, 1);
seg
}
fn build(&mut self, data: &[i32], l: usize, r: usize, idx: usize) {
if l == r {
self.tree[idx].min_val = data[l - 1];
self.tree[idx].max_val = data[l - 1];
return;
}
let mid = (l + r) / 2;
self.build(data, l, mid, idx * 2);
self.build(data, mid + 1, r, idx * 2 + 1);
self.push_up(idx);
}
fn push_up(&mut self, idx: usize) {
let left_min = self.tree[idx * 2].min_val;
let left_max = self.tree[idx * 2].max_val;
let right_min = self.tree[idx * 2 + 1].min_val;
let right_max = self.tree[idx * 2 + 1].max_val;
self.tree[idx].min_val = left_min.min(right_min);
self.tree[idx].max_val = left_max.max(right_max);
}
fn apply(&mut self, idx: usize, tag: &LazyTag) {
self.tree[idx].min_val += tag.add;
self.tree[idx].max_val += tag.add;
self.tree[idx].lazy.combine(tag);
}
fn push_down(&mut self, idx: usize) {
if self.tree[idx].lazy.is_empty() {
return;
}
let tag = self.tree[idx].lazy;
self.apply(idx * 2, &tag);
self.apply(idx * 2 + 1, &tag);
self.tree[idx].lazy.clear();
}
fn range_add(&mut self, l: usize, r: usize, val: i32) {
if l > r || l > self.n || r < 1 {
return;
}
let tag = LazyTag { add: val };
self._update(l, r, &tag, 1, self.n, 1);
}
fn _update(&mut self, ql: usize, qr: usize, tag: &LazyTag,
l: usize, r: usize, idx: usize) {
if ql > r || qr < l {
return;
}
if ql <= l && r <= qr {
self.apply(idx, tag);
return;
}
self.push_down(idx);
let mid = (l + r) / 2;
if ql <= mid {
self._update(ql, qr, tag, l, mid, idx * 2);
}
if qr > mid {
self._update(ql, qr, tag, mid + 1, r, idx * 2 + 1);
}
self.push_up(idx);
}
fn find_last_zero(&mut self, start: usize, val: i32) -> i32 {
if start > self.n {
return -1;
}
self._find(start, self.n, val, 1, self.n, 1)
}
fn _find(&mut self, ql: usize, qr: usize, val: i32,
l: usize, r: usize, idx: usize) -> i32 {
if l > qr || r < ql || self.tree[idx].min_val > val || self.tree[idx].max_val < val {
return -1;
}
if l == r {
return l as i32;
}
self.push_down(idx);
let mid = (l + r) / 2;
let right_res = self._find(ql, qr, val, mid + 1, r, idx * 2 + 1);
if right_res != -1 {
return right_res;
}
self._find(ql, qr, val, l, mid, idx * 2)
}
fn query_min(&self, l: usize, r: usize) -> i32 {
self._query_min(l, r, 1, self.n, 1)
}
fn _query_min(&self, ql: usize, qr: usize, l: usize, r: usize, idx: usize) -> i32 {
if ql > r || qr < l {
return i32::MAX;
}
if ql <= l && r <= qr {
return self.tree[idx].min_val;
}
let mid = (l + r) / 2;
let left_min = self._query_min(ql, qr, l, mid, idx * 2);
let right_min = self._query_min(ql, qr, mid + 1, r, idx * 2 + 1);
left_min.min(right_min)
}
}
impl Solution {
pub fn longest_balanced(nums: Vec<i32>) -> i32 {
let n = nums.len();
if n == 0 {
return 0;
}
fn sign(x: i32) -> i32 {
if x % 2 == 0 { 1 } else { -1 }
}
let mut prefix_sum = vec![0; n];
prefix_sum[0] = sign(nums[0]);
let mut pos_map: HashMap<i32, VecDeque<usize>> = HashMap::new();
pos_map.entry(nums[0]).or_insert_with(VecDeque::new).push_back(1);
for i in 1..n {
prefix_sum[i] = prefix_sum[i - 1];
let positions = pos_map.entry(nums[i]).or_insert_with(VecDeque::new);
if positions.is_empty() {
prefix_sum[i] += sign(nums[i]);
}
positions.push_back(i + 1);
}
let mut seg_tree = SegmentTree::new(&prefix_sum);
let mut max_len = 0;
for i in 0..n {
let start_idx = i + max_len as usize;
if start_idx < n {
let last_pos = seg_tree.find_last_zero(start_idx + 1, 0);
if last_pos != -1 {
max_len = max(max_len, last_pos - i as i32);
}
}
let num = nums[i];
let next_pos = pos_map.get_mut(&num)
.and_then(|positions| {
positions.pop_front();
positions.front().copied()
})
.unwrap_or(n + 2);
let delta = -sign(num);
if i + 1 <= next_pos - 1 {
seg_tree.range_add(i + 1, next_pos - 1, delta);
}
}
max_len
}
}
复杂度分析
时间复杂度:$O(n \log n)$,其中 $n$ 是 $\textit{nums}$ 的长度。预处理元素出现下标以及前缀和需要 $O(n \log n)$,线段树建树需要 $O(n \log n)$,后续遍历寻找合法区间需要 $O(n)$,循环内读取映射集需要 $O(\log n)$,使用线段树进行上界查找和区间加都需要 $O(\log n)$,故主循环需要 $O(n \log n)$。最后总时间复杂度为 $O(n \log n)$。
空间复杂度:$O(n)$。线段树需要 $O(n)$ 的空间,队列和映射集总计需要 $O(n)$ 的空间。