[Python3/Java/C++/Go/TypeScript] 一题一解:线段树+前缀和+哈希表(清晰题解)
方法一:线段树 + 前缀和 + 哈希表
我们可以将问题转化为前缀和问题。定义一个前缀和变量 $\textit{now}$,表示当前子数组中奇数和偶数的差值:
$$
\textit{now} = \text{不同奇数} - \text{不同偶数}
$$
对于奇数元素记为 $+1$,偶数元素记为 $-1$。使用哈希表 $\textit{last}$ 记录每个数字上一次出现的位置,如果数字重复出现,需要撤销其之前的贡献。
为了高效计算每次右端点加入元素后子数组长度,我们使用线段树维护区间前缀和的最小值和最大值,同时支持区间加操作和线段树上二分查询。当遍历到右端点 $i$ 时,先更新当前元素的贡献,然后使用线段树查询最早出现当前前缀和 $\textit{now}$ 的位置 $pos$,当前子数组长度为 $i - pos$,更新答案:
$$
\textit{ans} = \max(\textit{ans}, i - pos)
$$
###python
class Solution:
def longestBalanced(self, nums: List[int]) -> int:
n = len(nums)
# 线段树节点
class Node:
__slots__ = ("l", "r", "mn", "mx", "lazy")
def __init__(self):
self.l = self.r = 0
self.mn = self.mx = 0
self.lazy = 0
tr = [Node() for _ in range((n + 1) * 4)]
# 建树,维护前缀和区间 [0, n]
def build(u: int, l: int, r: int):
tr[u].l, tr[u].r = l, r
tr[u].mn = tr[u].mx = tr[u].lazy = 0
if l == r:
return
mid = (l + r) >> 1
build(u << 1, l, mid)
build(u << 1 | 1, mid + 1, r)
def apply(u: int, v: int):
tr[u].mn += v
tr[u].mx += v
tr[u].lazy += v
def pushdown(u: int):
if tr[u].lazy != 0:
apply(u << 1, tr[u].lazy)
apply(u << 1 | 1, tr[u].lazy)
tr[u].lazy = 0
def pushup(u: int):
tr[u].mn = min(tr[u << 1].mn, tr[u << 1 | 1].mn)
tr[u].mx = max(tr[u << 1].mx, tr[u << 1 | 1].mx)
# 区间加
def modify(u: int, l: int, r: int, v: int):
if tr[u].l >= l and tr[u].r <= r:
apply(u, v)
return
pushdown(u)
mid = (tr[u].l + tr[u].r) >> 1
if l <= mid:
modify(u << 1, l, r, v)
if r > mid:
modify(u << 1 | 1, l, r, v)
pushup(u)
# 线段树上二分,找最小 pos 使前缀和 == target
def query(u: int, target: int) -> int:
if tr[u].l == tr[u].r:
return tr[u].l
pushdown(u)
if tr[u << 1].mn <= target <= tr[u << 1].mx:
return query(u << 1, target)
return query(u << 1 | 1, target)
build(1, 0, n)
last = {}
now = ans = 0
for i, x in enumerate(nums, start=1):
det = 1 if (x & 1) else -1
if x in last:
modify(1, last[x], n, -det)
now -= det
last[x] = i
modify(1, i, n, det)
now += det
pos = query(1, now)
ans = max(ans, i - pos)
return ans
###java
/**
*
* 思路:
* - 将「不同奇数」视为 +1,「不同偶数」视为 -1
* - 用前缀和表示当前子数组内奇偶平衡状态
* - 由于相同数值只能算一次,需要在数值重复出现时撤销旧贡献
* - 使用线段树维护前缀和的最小值 / 最大值,并支持区间加
* - 通过线段树上二分,找到最早等于当前前缀和的位置
*/
class Solution {
/**
* 线段树节点
*/
static class Node {
int l, r; // 区间范围
int mn, mx; // 区间前缀和最小值 / 最大值
int lazy; // 懒标记:区间整体加
}
/**
* 支持区间加 + 按值二分查位置的线段树
*/
static class SegmentTree {
Node[] tr;
SegmentTree(int n) {
tr = new Node[n << 2];
for (int i = 0; i < tr.length; i++) {
tr[i] = new Node();
}
build(1, 0, n);
}
// 建树,初始前缀和均为 0
void build(int u, int l, int r) {
tr[u].l = l;
tr[u].r = r;
tr[u].mn = tr[u].mx = 0;
tr[u].lazy = 0;
if (l == r) return;
int mid = (l + r) >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
// 区间 [l, r] 全部加 v
void modify(int u, int l, int r, int v) {
if (tr[u].l >= l && tr[u].r <= r) {
apply(u, v);
return;
}
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if (l <= mid) modify(u << 1, l, r, v);
if (r > mid) modify(u << 1 | 1, l, r, v);
pushup(u);
}
// 线段树上二分:查找最小位置 pos,使前缀和 == target
int query(int u, int target) {
if (tr[u].l == tr[u].r) {
return tr[u].l;
}
pushdown(u);
int left = u << 1;
int right = u << 1 | 1;
if (tr[left].mn <= target && target <= tr[left].mx) {
return query(left, target);
}
return query(right, target);
}
// 应用懒标记
void apply(int u, int v) {
tr[u].mn += v;
tr[u].mx += v;
tr[u].lazy += v;
}
// 向上更新
void pushup(int u) {
tr[u].mn = Math.min(tr[u << 1].mn, tr[u << 1 | 1].mn);
tr[u].mx = Math.max(tr[u << 1].mx, tr[u << 1 | 1].mx);
}
// 懒标记下推
void pushdown(int u) {
if (tr[u].lazy != 0) {
apply(u << 1, tr[u].lazy);
apply(u << 1 | 1, tr[u].lazy);
tr[u].lazy = 0;
}
}
}
public int longestBalanced(int[] nums) {
int n = nums.length;
SegmentTree st = new SegmentTree(n);
// last[x] 表示 x 最近一次出现的位置
Map<Integer, Integer> last = new HashMap<>();
int now = 0; // 当前前缀和
int ans = 0; // 最终答案
// 枚举子数组右端点
for (int i = 1; i <= n; i++) {
int x = nums[i - 1];
int det = (x & 1) == 1 ? 1 : -1;
// 如果之前出现过,撤销旧贡献
if (last.containsKey(x)) {
st.modify(1, last.get(x), n, -det);
now -= det;
}
// 添加新贡献
last.put(x, i);
st.modify(1, i, n, det);
now += det;
// 查找最早前缀和等于 now 的位置
int pos = st.query(1, now);
ans = Math.max(ans, i - pos);
}
return ans;
}
}
###cpp
class Node {
public:
int l = 0, r = 0;
int mn = 0, mx = 0;
int lazy = 0;
};
class SegmentTree {
public:
SegmentTree(int n) {
tr.resize(n << 2);
for (int i = 0; i < tr.size(); ++i) {
tr[i] = new Node();
}
build(1, 0, n);
}
// 区间 [l, r] 全部 +v
void modify(int u, int l, int r, int v) {
if (tr[u]->l >= l && tr[u]->r <= r) {
apply(u, v);
return;
}
pushdown(u);
int mid = (tr[u]->l + tr[u]->r) >> 1;
if (l <= mid) modify(u << 1, l, r, v);
if (r > mid) modify(u << 1 | 1, l, r, v);
pushup(u);
}
// 线段树上二分:找最小 pos 使前缀和 == target
int query(int u, int target) {
if (tr[u]->l == tr[u]->r) {
return tr[u]->l;
}
pushdown(u);
int lc = u << 1, rc = u << 1 | 1;
if (tr[lc]->mn <= target && target <= tr[lc]->mx) {
return query(lc, target);
}
return query(rc, target);
}
private:
vector<Node*> tr;
void build(int u, int l, int r) {
tr[u]->l = l;
tr[u]->r = r;
tr[u]->mn = tr[u]->mx = 0;
tr[u]->lazy = 0;
if (l == r) return;
int mid = (l + r) >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
void apply(int u, int v) {
tr[u]->mn += v;
tr[u]->mx += v;
tr[u]->lazy += v;
}
void pushup(int u) {
tr[u]->mn = min(tr[u << 1]->mn, tr[u << 1 | 1]->mn);
tr[u]->mx = max(tr[u << 1]->mx, tr[u << 1 | 1]->mx);
}
void pushdown(int u) {
if (tr[u]->lazy != 0) {
apply(u << 1, tr[u]->lazy);
apply(u << 1 | 1, tr[u]->lazy);
tr[u]->lazy = 0;
}
}
};
class Solution {
public:
int longestBalanced(vector<int>& nums) {
int n = nums.size();
SegmentTree st(n);
unordered_map<int, int> last;
int now = 0, ans = 0;
// 枚举子数组右端点
for (int i = 1; i <= n; ++i) {
int x = nums[i - 1];
int det = (x & 1) ? 1 : -1;
// 如果该值之前出现过,移除旧贡献
if (last.count(x)) {
st.modify(1, last[x], n, -det);
now -= det;
}
// 添加当前贡献
last[x] = i;
st.modify(1, i, n, det);
now += det;
// 查找最小 pos,使前缀和 == now
int pos = st.query(1, now);
ans = max(ans, i - pos);
}
return ans;
}
};
###go
// 线段树节点
type Node struct {
l, r int // 区间范围
mn, mx int // 当前区间内前缀和最小值 / 最大值
lazy int // 懒标记:区间整体加
}
// 线段树
type SegmentTree struct {
tr []Node
}
// 构造线段树,维护区间 [0, n]
func NewSegmentTree(n int) *SegmentTree {
st := &SegmentTree{
tr: make([]Node, n<<2),
}
st.build(1, 0, n)
return st
}
// 建树:初始所有前缀和为 0
func (st *SegmentTree) build(u, l, r int) {
st.tr[u] = Node{l: l, r: r, mn: 0, mx: 0, lazy: 0}
if l == r {
return
}
mid := (l + r) >> 1
st.build(u<<1, l, mid)
st.build(u<<1|1, mid+1, r)
}
// 区间 [l, r] 整体加 v
func (st *SegmentTree) modify(u, l, r, v int) {
if st.tr[u].l >= l && st.tr[u].r <= r {
st.apply(u, v)
return
}
st.pushdown(u)
mid := (st.tr[u].l + st.tr[u].r) >> 1
if l <= mid {
st.modify(u<<1, l, r, v)
}
if r > mid {
st.modify(u<<1|1, l, r, v)
}
st.pushup(u)
}
// 线段树二分:找到最小位置 pos,使前缀和 == target
func (st *SegmentTree) query(u, target int) int {
if st.tr[u].l == st.tr[u].r {
return st.tr[u].l
}
st.pushdown(u)
left, right := u<<1, u<<1|1
if st.tr[left].mn <= target && target <= st.tr[left].mx {
return st.query(left, target)
}
return st.query(right, target)
}
// 应用懒标记
func (st *SegmentTree) apply(u, v int) {
st.tr[u].mn += v
st.tr[u].mx += v
st.tr[u].lazy += v
}
// 向上更新
func (st *SegmentTree) pushup(u int) {
st.tr[u].mn = min(st.tr[u<<1].mn, st.tr[u<<1|1].mn)
st.tr[u].mx = max(st.tr[u<<1].mx, st.tr[u<<1|1].mx)
}
// 懒标记下推
func (st *SegmentTree) pushdown(u int) {
if st.tr[u].lazy != 0 {
v := st.tr[u].lazy
st.apply(u<<1, v)
st.apply(u<<1|1, v)
st.tr[u].lazy = 0
}
}
// 主函数
func longestBalanced(nums []int) int {
n := len(nums)
st := NewSegmentTree(n)
// 记录每个值最近一次出现的位置
last := make(map[int]int)
now := 0 // 当前前缀和
ans := 0 // 最终答案
// 枚举右端点
for i := 1; i <= n; i++ {
x := nums[i-1]
det := -1
if x&1 == 1 {
det = 1
}
// 若之前出现过,撤销旧贡献
if pos, ok := last[x]; ok {
st.modify(1, pos, n, -det)
now -= det
}
// 添加新贡献
last[x] = i
st.modify(1, i, n, det)
now += det
// 查找最早前缀和等于 now 的位置
pos := st.query(1, now)
ans = max(ans, i-pos)
}
return ans
}
###ts
function longestBalanced(nums: number[]): number {
const n = nums.length;
interface Node {
l: number;
r: number;
mn: number;
mx: number;
lazy: number;
}
const tr: Node[] = Array.from({ length: (n + 1) * 4 }, () => ({
l: 0,
r: 0,
mn: 0,
mx: 0,
lazy: 0,
}));
function build(u: number, l: number, r: number) {
tr[u].l = l;
tr[u].r = r;
if (l === r) return;
const mid = (l + r) >> 1;
build(u << 1, l, mid);
build((u << 1) | 1, mid + 1, r);
}
function apply(u: number, v: number) {
tr[u].mn += v;
tr[u].mx += v;
tr[u].lazy += v;
}
function pushdown(u: number) {
if (tr[u].lazy !== 0) {
apply(u << 1, tr[u].lazy);
apply((u << 1) | 1, tr[u].lazy);
tr[u].lazy = 0;
}
}
function pushup(u: number) {
tr[u].mn = Math.min(tr[u << 1].mn, tr[(u << 1) | 1].mn);
tr[u].mx = Math.max(tr[u << 1].mx, tr[(u << 1) | 1].mx);
}
function modify(u: number, l: number, r: number, v: number) {
if (tr[u].l >= l && tr[u].r <= r) {
apply(u, v);
return;
}
pushdown(u);
const mid = (tr[u].l + tr[u].r) >> 1;
if (l <= mid) modify(u << 1, l, r, v);
if (r > mid) modify((u << 1) | 1, l, r, v);
pushup(u);
}
function query(u: number, target: number): number {
if (tr[u].l === tr[u].r) return tr[u].l;
pushdown(u);
if (tr[u << 1].mn <= target && target <= tr[u << 1].mx) {
return query(u << 1, target);
}
return query((u << 1) | 1, target);
}
build(1, 0, n);
const last = new Map<number, number>();
let now = 0,
ans = 0;
nums.forEach((x, idx) => {
const i = idx + 1;
const det = x & 1 ? 1 : -1;
if (last.has(x)) {
modify(1, last.get(x)!, n, -det);
now -= det;
}
last.set(x, i);
modify(1, i, n, det);
now += det;
const pos = query(1, now);
ans = Math.max(ans, i - pos);
});
return ans;
}
时间复杂度为 $O(n \log n)$,其中 $n$ 为数组长度。每次修改和查询线段树操作 $O(\log n)$,枚举右端点共 $n$ 次,总时间复杂度为 $O(n \log n)$,空间复杂度为 $O(n)$,其中线段树节点和哈希表各占 $O(n)$ 空间。
有任何问题,欢迎评论区交流,欢迎评论区提供其它解题思路(代码),也可以点个赞支持一下作者哈 😄~