普通视图

发现新文章,点击刷新页面。
今天 — 2026年1月23日首页

两种方法:两个有序集合 / 懒删除堆+数组模拟双向链表(Python/Java/C++/Go)

作者 endlesscheng
2025年4月6日 13:47

为了快速模拟题目的操作,我们需要维护三种信息:

  1. 把相邻元素和 $s$,以及相邻元素中的左边元素的下标 $i$,组成一个 pair $(s,i)$。我们需要添加 pair、删除 pair 以及查询这些 pair 的最小值(双关键字比较),这可以用有序集合,或者懒删除堆
  2. 维护剩余下标。我们需要查询每个下标 $i$ 左侧最近剩余下标,以及右侧最近剩余下标。这可以用有序集合,或者两个并查集,或者双向链表
  3. 在相邻元素中,满足「左边元素大于右边元素」的个数,记作 $\textit{dec}$。

不断模拟操作,直到 $\textit{dec} = 0$。

题目说「用它们的和替换这对元素」,设操作的这对元素的下标为 $i$ 和 $\textit{nxt}$,操作相当于把 $\textit{nums}[i]$ 增加 $\textit{nums}[\textit{nxt}]$,然后删除 $\textit{nums}[\textit{nxt}]$。

在这个过程中,$\textit{dec}$ 如何变化?

设操作的这对元素的下标为 $i$ 和 $\textit{nxt}$,$i$ 左侧最近剩余下标为 $\textit{pre}$,$\textit{nxt}$ 右侧最近剩余下标为 $\textit{nxt}_2$。

操作会影响 $\textit{nums}[i]$ 和 $\textit{nums}[\textit{nxt}]$,也会影响周边相邻元素的大小关系。所以每次操作,我们需要重新考察 $4$ 个元素值的大小关系,下标从左到右为 $\textit{pre},i,\textit{nxt},\textit{nxt}_2$。

  1. 删除 $\textit{nums}[\textit{nxt}]$。如果删除前 $\textit{nums}[i] > \textit{nums}[\textit{nxt}]$,把 $\textit{dec}$ 减一。
  2. 如果删除前 $\textit{nums}[\textit{pre}] > \textit{nums}[i]$,把 $\textit{dec}$ 减一。如果删除后 $\textit{nums}[\textit{pre}] > s$,把 $\textit{dec}$ 加一。这里 $s$ 表示操作的这对元素之和,也就是新的 $\textit{nums}[i]$ 的值。
  3. 如果删除前 $\textit{nums}[\textit{nxt}] > \textit{nums}[\textit{nxt}_2]$,把 $\textit{dec}$ 减一。删除后 $i$ 和 $\textit{nxt}_2$ 相邻,如果删除后 $s > \textit{nums}[\textit{nxt}_2]$,把 $\textit{dec}$ 加一。

上述过程中,同时维护(添加删除)新旧相邻元素和以及下标。

本题视频讲解,欢迎点赞关注~

写法一:两个有序集合

###py

class Solution:
    def minimumPairRemoval(self, nums: List[int]) -> int:
        sl = SortedList()  # (相邻元素和,左边那个数的下标)
        idx = SortedList(range(len(nums)))  # 剩余下标
        dec = 0  # 递减的相邻对的个数

        for i, (x, y) in enumerate(pairwise(nums)):
            if x > y:
                dec += 1
            sl.add((x + y, i))

        ans = 0
        while dec > 0:
            ans += 1

            s, i = sl.pop(0)  # 删除相邻元素和最小的一对
            k = idx.bisect_left(i)

            # (当前元素,下一个数)
            nxt = idx[k + 1]
            if nums[i] > nums[nxt]:  # 旧数据
                dec -= 1

            # (前一个数,当前元素)
            if k > 0:
                pre = idx[k - 1]
                if nums[pre] > nums[i]:  # 旧数据
                    dec -= 1
                if nums[pre] > s:  # 新数据
                    dec += 1
                sl.remove((nums[pre] + nums[i], pre))
                sl.add((nums[pre] + s, pre))

            # (下一个数,下下一个数)
            if k + 2 < len(idx):
                nxt2 = idx[k + 2]
                if nums[nxt] > nums[nxt2]:  # 旧数据
                    dec -= 1
                if s > nums[nxt2]:  # 新数据(当前元素,下下一个数)
                    dec += 1
                sl.remove((nums[nxt] + nums[nxt2], nxt))
                sl.add((s + nums[nxt2], i))

            nums[i] = s  # 把 nums[nxt] 加到 nums[i] 中
            idx.remove(nxt)  # 删除 nxt

        return ans

###java

class Solution {
    private record Pair(long s, int i) {
    }

    public int minimumPairRemoval(int[] nums) {
        int n = nums.length;
        // (相邻元素和,左边那个数的下标)
        TreeSet<Pair> pairs = new TreeSet<>((a, b) -> a.s != b.s ? Long.compare(a.s, b.s) : a.i - b.i);
        int dec = 0; // 递减的相邻对的个数
        for (int i = 0; i < n - 1; i++) {
            int x = nums[i];
            int y = nums[i + 1];
            if (x > y) {
                dec++;
            }
            pairs.add(new Pair(x + y, i));
        }

        // 剩余下标
        TreeSet<Integer> idx = new TreeSet<>();
        for (int i = 0; i < n; i++) {
            idx.add(i);
        }

        long[] a = new long[n];
        for (int i = 0; i < n; i++) {
            a[i] = nums[i];
        }

        int ans = 0;
        while (dec > 0) {
            ans++;

            // 删除相邻元素和最小的一对
            Pair p = pairs.pollFirst();
            long s = p.s;
            int i = p.i;

            // (当前元素,下一个数)
            int nxt = idx.higher(i);
            if (a[i] > a[nxt]) { // 旧数据
                dec--;
            }

            // (前一个数,当前元素)
            Integer pre = idx.lower(i);
            if (pre != null) {
                if (a[pre] > a[i]) { // 旧数据
                    dec--;
                }
                if (a[pre] > s) { // 新数据
                    dec++;
                }
                pairs.remove(new Pair(a[pre] + a[i], pre));
                pairs.add(new Pair(a[pre] + s, pre));
            }

            // (下一个数,下下一个数)
            Integer nxt2 = idx.higher(nxt);
            if (nxt2 != null) {
                if (a[nxt] > a[nxt2]) { // 旧数据
                    dec--;
                }
                if (s > a[nxt2]) { // 新数据(当前元素,下下一个数)
                    dec++;
                }
                pairs.remove(new Pair(a[nxt] + a[nxt2], nxt));
                pairs.add(new Pair(s + a[nxt2], i));
            }

            a[i] = s; // 把 a[nxt] 加到 a[i] 中
            idx.remove(nxt); // 删除 nxt
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int minimumPairRemoval(vector<int>& nums) {
        int n = nums.size();
        set<pair<long long, int>> pairs; // (相邻元素和,左边那个数的下标)
        int dec = 0; // 递减的相邻对的个数
        for (int i = 0; i + 1 < n; i++) {
            int x = nums[i], y = nums[i + 1];
            if (x > y) {
                dec++;
            }
            pairs.emplace(x + y, i);
        }

        set<int> idx; // 剩余下标
        for (int i = 0; i < n; i++) {
            idx.insert(i);
        }

        vector<long long> a(nums.begin(), nums.end());
        int ans = 0;
        while (dec > 0) {
            ans++;

            // 删除相邻元素和最小的一对
            auto [s, i] = *pairs.begin();
            pairs.erase(pairs.begin());

            auto it = idx.lower_bound(i);

            // (当前元素,下一个数)
            auto nxt_it = next(it);
            int nxt = *nxt_it;
            dec -= a[i] > a[nxt]; // 旧数据

            // (前一个数,当前元素)
            if (it != idx.begin()) {
                int pre = *prev(it);
                dec -= a[pre] > a[i]; // 旧数据
                dec += a[pre] > s; // 新数据
                pairs.erase({a[pre] + a[i], pre});
                pairs.emplace(a[pre] + s, pre);
            }

            // (下一个数,下下一个数)
            auto nxt2_it = next(nxt_it);
            if (nxt2_it != idx.end()) {
                int nxt2 = *nxt2_it;
                dec -= a[nxt] > a[nxt2]; // 旧数据
                dec += s > a[nxt2]; // 新数据(当前元素,下下一个数)
                pairs.erase({a[nxt] + a[nxt2], nxt});
                pairs.emplace(s + a[nxt2], i);
            }

            a[i] = s; // 把 a[nxt] 加到 a[i] 中
            idx.erase(nxt); // 删除 nxt
        }
        return ans;
    }
};

###go

// import "github.com/emirpasic/gods/v2/trees/redblacktree"
func minimumPairRemoval(nums []int) (ans int) {
n := len(nums)
type pair struct{ s, i int }
// (相邻元素和,左边那个数的下标)
pairs := redblacktree.NewWith[pair, struct{}](func(a, b pair) int { return cmp.Or(a.s-b.s, a.i-b.i) })
dec := 0 // 递减的相邻对的个数
for i := range n - 1 {
x, y := nums[i], nums[i+1]
if x > y {
dec++
}
pairs.Put(pair{x + y, i}, struct{}{})
}

// 剩余下标
idx := redblacktree.New[int, struct{}]()
for i := range n {
idx.Put(i, struct{}{})
}

for dec > 0 {
ans++

it := pairs.Left()
s := it.Key.s
i := it.Key.i
pairs.Remove(it.Key) // 删除相邻元素和最小的一对

// (当前元素,下一个数)
node, _ := idx.Ceiling(i + 1)
nxt := node.Key
if nums[i] > nums[nxt] { // 旧数据
dec--
}

// (前一个数,当前元素)
node, _ = idx.Floor(i - 1)
if node != nil {
pre := node.Key
if nums[pre] > nums[i] { // 旧数据
dec--
}
if nums[pre] > s { // 新数据
dec++
}
pairs.Remove(pair{nums[pre] + nums[i], pre})
pairs.Put(pair{nums[pre] + s, pre}, struct{}{})
}

// (下一个数,下下一个数)
node, _ = idx.Ceiling(nxt + 1)
if node != nil {
nxt2 := node.Key
if nums[nxt] > nums[nxt2] { // 旧数据
dec--
}
if s > nums[nxt2] { // 新数据(当前元素,下下一个数)
dec++
}
pairs.Remove(pair{nums[nxt] + nums[nxt2], nxt})
pairs.Put(pair{s + nums[nxt2], i}, struct{}{})
}

nums[i] = s // 把 nums[nxt] 加到 nums[i] 中
idx.Remove(nxt) // 删除 nxt
}
return
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log n)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

写法二:懒删除堆 + 两个数组模拟双向链表

用最小堆(懒删除堆)代替维护 pair 的有序集合。

用双向链表代替维护下标的有序集合。进一步地,可以用两个数组模拟双向链表的 $\textit{prev}$ 指针和 $\textit{next}$ 指针。

如果堆顶下标 $i$ 被删除,或者 $i$ 右边下标 $\textit{nxt}$ 被删除,或者堆顶元素和不等于 $\textit{nums}[i]+\textit{nums}[\textit{nxt}]$,则弹出堆顶。

###py

class Solution:
    def minimumPairRemoval(self, nums: List[int]) -> int:
        n = len(nums)
        h = []  # (相邻元素和,左边那个数的下标)
        dec = 0  # 递减的相邻对的个数
        for i, (x, y) in enumerate(pairwise(nums)):
            if x > y:
                dec += 1
            h.append((x + y, i))
        heapify(h)
        lazy = defaultdict(int)

        # 每个下标的左右最近的未删除下标
        left = list(range(-1, n))  # 加一个哨兵,防止下标越界
        right = list(range(1, n + 1))

        ans = 0
        while dec:
            ans += 1

            while lazy[h[0]]:
                lazy[heappop(h)] -= 1
            s, i = heappop(h)  # 删除相邻元素和最小的一对

            # (当前元素,下一个数)
            nxt = right[i]
            if nums[i] > nums[nxt]:  # 旧数据
                dec -= 1

            # (前一个数,当前元素)
            pre = left[i]
            if pre >= 0:
                if nums[pre] > nums[i]:  # 旧数据
                    dec -= 1
                if nums[pre] > s:  # 新数据
                    dec += 1
                lazy[(nums[pre] + nums[i], pre)] += 1  # 懒删除
                heappush(h, (nums[pre] + s, pre))

            # (下一个数,下下一个数)
            nxt2 = right[nxt]
            if nxt2 < n:
                if nums[nxt] > nums[nxt2]:  # 旧数据
                    dec -= 1
                if s > nums[nxt2]:  # 新数据(当前元素,下下一个数)
                    dec += 1
                lazy[(nums[nxt] + nums[nxt2], nxt)] += 1  # 懒删除
                heappush(h, (s + nums[nxt2], i))

            nums[i] = s
            # 删除 nxt
            l, r = left[nxt], right[nxt]
            right[l] = r  # 模拟双向链表的删除操作
            left[r] = l

        return ans

###py

class Solution:
    def minimumPairRemoval(self, nums: List[int]) -> int:
        n = len(nums)
        h = []  # (相邻元素和,左边那个数的下标)
        dec = 0  # 递减的相邻对的个数
        for i, (x, y) in enumerate(pairwise(nums)):
            if x > y:
                dec += 1
            h.append((x + y, i))
        heapify(h)

        # 每个下标的左右最近的未删除下标
        left = list(range(-1, n))  # 加一个哨兵,防止下标越界
        right = list(range(1, n + 1))  # 注意最下面的代码,删除 nxt 的时候额外把 right[nxt] 置为 n

        ans = 0
        while dec:
            ans += 1

            # 如果堆顶数据与实际数据不符,说明堆顶数据是之前本应删除,但没有删除的数据(懒删除)
            while right[h[0][1]] >= n or h[0][0] != nums[h[0][1]] + nums[right[h[0][1]]]:
                heappop(h)
            s, i = heappop(h)  # 删除相邻元素和最小的一对

            # (当前元素,下一个数)
            nxt = right[i]
            if nums[i] > nums[nxt]:  # 旧数据
                dec -= 1

            # (前一个数,当前元素)
            pre = left[i]
            if pre >= 0:
                if nums[pre] > nums[i]:  # 旧数据
                    dec -= 1
                if nums[pre] > s:  # 新数据
                    dec += 1
                heappush(h, (nums[pre] + s, pre))

            # (下一个数,下下一个数)
            nxt2 = right[nxt]
            if nxt2 < n:
                if nums[nxt] > nums[nxt2]:  # 旧数据
                    dec -= 1
                if s > nums[nxt2]:  # 新数据(当前元素,下下一个数)
                    dec += 1
                heappush(h, (s + nums[nxt2], i))

            nums[i] = s
            # 删除 nxt
            l, r = left[nxt], right[nxt]
            right[l] = r  # 模拟双向链表的删除操作
            left[r] = l
            right[nxt] = n  # 表示删除 nxt

        return ans

###java

class Solution {
    private record Pair(long s, int i) {
    }

    public int minimumPairRemoval(int[] nums) {
        int n = nums.length;
        // (相邻元素和,左边那个数的下标)
        PriorityQueue<Pair> h = new PriorityQueue<>((a, b) -> a.s != b.s ? Long.compare(a.s, b.s) : a.i - b.i);
        int dec = 0; // 递减的相邻对的个数
        for (int i = 0; i < n - 1; i++) {
            int x = nums[i];
            int y = nums[i + 1];
            if (x > y) {
                dec++;
            }
            h.offer(new Pair(x + y, i));
        }

        // 每个下标的左右最近的未删除下标
        int[] left = new int[n + 1];
        int[] right = new int[n + 1];
        for (int i = 0; i <= n; i++) {
            left[i] = i - 1;
            right[i] = i + 1;
        }

        long[] a = new long[n];
        for (int i = 0; i < n; i++) {
            a[i] = nums[i];
        }

        int ans = 0;
        while (dec > 0) {
            ans++;

            // 如果堆顶数据与实际数据不符,说明堆顶数据是之前本应删除,但没有删除的数据(懒删除)
            while (right[h.peek().i] >= n || h.peek().s != a[h.peek().i] + a[right[h.peek().i]]) {
                h.poll();
            }

            // 删除相邻元素和最小的一对
            Pair p = h.poll();
            long s = p.s;
            int i = p.i;

            // (当前元素,下一个数)
            int nxt = right[i];
            if (a[i] > a[nxt]) { // 旧数据
                dec--;
            }

            // (前一个数,当前元素)
            int pre = left[i];
            if (pre >= 0) {
                if (a[pre] > a[i]) { // 旧数据
                    dec--;
                }
                if (a[pre] > s) { // 新数据
                    dec++;
                }
                h.offer(new Pair(a[pre] + s, pre));
            }

            // (下一个数,下下一个数)
            int nxt2 = right[nxt];
            if (nxt2 < n) {
                if (a[nxt] > a[nxt2]) { // 旧数据
                    dec--;
                }
                if (s > a[nxt2]) { // 新数据(当前元素,下下一个数)
                    dec++;
                }
                h.add(new Pair(s + a[nxt2], i));
            }

            a[i] = s; // 把 a[nxt] 加到 a[i] 中
            // 删除 nxt
            int l = left[nxt];
            int r = right[nxt];
            right[l] = r; // 模拟双向链表的删除操作
            left[r] = l;
            right[nxt] = n; // 表示删除 nxt
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    int minimumPairRemoval(vector<int>& nums) {
        int n = nums.size();
        priority_queue<pair<long long, int>, vector<pair<long long, int>>, greater<>> pq; // (相邻元素和,左边那个数的下标)
        int dec = 0; // 递减的相邻对的个数
        for (int i = 0; i + 1 < n; i++) {
            int x = nums[i], y = nums[i + 1];
            if (x > y) {
                dec++;
            }
            pq.emplace(x + y, i);
        }

        // 每个下标的左右最近的未删除下标
        vector<int> left(n + 1), right(n);
        ranges::iota(left, -1);
        ranges::iota(right, 1);

        vector<long long> a(nums.begin(), nums.end());
        int ans = 0;
        while (dec) {
            ans++;

            // 如果堆顶数据与实际数据不符,说明堆顶数据是之前本应删除,但没有删除的数据(懒删除)
            while (right[pq.top().second] >= n || pq.top().first != a[pq.top().second] + a[right[pq.top().second]]) {
                pq.pop();
            }
            auto [s, i] = pq.top();
            pq.pop(); // 删除相邻元素和最小的一对

            // (当前元素,下一个数)
            int nxt = right[i];
            dec -= a[i] > a[nxt]; // 旧数据

            // (前一个数,当前元素)
            int pre = left[i];
            if (pre >= 0) {
                dec -= a[pre] > a[i]; // 旧数据
                dec += a[pre] > s; // 新数据
                pq.emplace(a[pre] + s, pre);
            }

            // (下一个数,下下一个数)
            int nxt2 = right[nxt];
            if (nxt2 < n) {
                dec -= a[nxt] > a[nxt2]; // 旧数据
                dec += s > a[nxt2]; // 新数据(当前元素,下下一个数)
                pq.emplace(s + a[nxt2], i);
            }

            a[i] = s;
            // 删除 nxt
            int l = left[nxt], r = right[nxt];
            right[l] = r; // 模拟双向链表的删除操作
            left[r] = l;
            right[nxt] = n; // 表示删除 nxt
        }

        return ans;
    }
};

###go

func minimumPairRemoval(nums []int) (ans int) {
n := len(nums)
h := make(hp, n-1)
dec := 0 // 递减的相邻对的个数
for i := range n - 1 {
x, y := nums[i], nums[i+1]
if x > y {
dec++
}
h[i] = pair{x + y, i}
}
heap.Init(&h)
lazy := map[pair]int{}

// 每个下标的左右最近的未删除下标
left := make([]int, n+1) // 加一个哨兵,防止下标越界
right := make([]int, n)
for i := range n {
left[i] = i - 1
right[i] = i + 1
}
remove := func(i int) {
l, r := left[i], right[i]
right[l] = r // 模拟双向链表的删除操作
left[r] = l
}

for dec > 0 {
ans++

for lazy[h[0]] > 0 {
lazy[h[0]]--
heap.Pop(&h)
}
p := heap.Pop(&h).(pair) // 删除相邻元素和最小的一对
s := p.s
i := p.i

// (当前元素,下一个数)
nxt := right[i]
if nums[i] > nums[nxt] { // 旧数据
dec--
}

// (前一个数,当前元素)
pre := left[i]
if pre >= 0 {
if nums[pre] > nums[i] { // 旧数据
dec--
}
if nums[pre] > s { // 新数据
dec++
}
lazy[pair{nums[pre] + nums[i], pre}]++ // 懒删除
heap.Push(&h, pair{nums[pre] + s, pre})
}

// (下一个数,下下一个数)
nxt2 := right[nxt]
if nxt2 < n {
if nums[nxt] > nums[nxt2] { // 旧数据
dec--
}
if s > nums[nxt2] { // 新数据(当前元素,下下一个数)
dec++
}
lazy[pair{nums[nxt] + nums[nxt2], nxt}]++ // 懒删除
heap.Push(&h, pair{s + nums[nxt2], i})
}

nums[i] = s
remove(nxt)
}
return
}

type pair struct{ s, i int } // (相邻元素和,左边那个数的下标)
type hp []pair

func (h hp) Len() int           { return len(h) }
func (h hp) Less(i, j int) bool { a, b := h[i], h[j]; return a.s < b.s || a.s == b.s && a.i < b.i }
func (h hp) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *hp) Push(v any)        { *h = append(*h, v.(pair)) }
func (h *hp) Pop() any          { a := *h; v := a[len(a)-1]; *h = a[:len(a)-1]; return v }

###go

func minimumPairRemoval(nums []int) (ans int) {
n := len(nums)
h := make(hp, n-1)
dec := 0 // 递减的相邻对的个数
for i := range n - 1 {
x, y := nums[i], nums[i+1]
if x > y {
dec++
}
h[i] = pair{x + y, i}
}
heap.Init(&h)

// 每个下标的左右最近的未删除下标
left := make([]int, n+1) // 加一个哨兵,防止下标越界
right := make([]int, n)
for i := range n {
left[i] = i - 1
right[i] = i + 1
}
remove := func(i int) {
l, r := left[i], right[i]
right[l] = r // 模拟双向链表的删除操作
left[r] = l
right[i] = n // 表示 i 已被删除
}

for dec > 0 {
ans++

// 如果堆顶数据与实际数据不符,说明堆顶数据是之前本应删除,但没有删除的数据(懒删除)
for right[h[0].i] >= n || nums[h[0].i]+nums[right[h[0].i]] != h[0].s {
heap.Pop(&h)
}
p := heap.Pop(&h).(pair) // 删除相邻元素和最小的一对
s := p.s
i := p.i

// (当前元素,下一个数)
nxt := right[i]
if nums[i] > nums[nxt] { // 旧数据
dec--
}

// (前一个数,当前元素)
pre := left[i]
if pre >= 0 {
if nums[pre] > nums[i] { // 旧数据
dec--
}
if nums[pre] > s { // 新数据
dec++
}
heap.Push(&h, pair{nums[pre] + s, pre})
}

// (下一个数,下下一个数)
nxt2 := right[nxt]
if nxt2 < n {
if nums[nxt] > nums[nxt2] { // 旧数据
dec--
}
if s > nums[nxt2] { // 新数据(当前元素,下下一个数)
dec++
}
heap.Push(&h, pair{s + nums[nxt2], i})
}

nums[i] = s
remove(nxt)
}
return
}

type pair struct{ s, i int } // (相邻元素和,左边那个数的下标)
type hp []pair

func (h hp) Len() int           { return len(h) }
func (h hp) Less(i, j int) bool { a, b := h[i], h[j]; return a.s < b.s || a.s == b.s && a.i < b.i }
func (h hp) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *hp) Push(v any)        { *h = append(*h, v.(pair)) }
func (h *hp) Pop() any          { a := *h; v := a[len(a)-1]; *h = a[:len(a)-1]; return v }

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log n)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(n)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

昨天 — 2026年1月22日首页
昨天以前首页

O(1) 计算每个数(Python/Java/C++/Go)

作者 endlesscheng
2024年10月13日 08:11

例如 $x=100111$,那么 $x\ |\ (x+1) = 100111\ |\ 101000 = 101111$。

可以发现,$x\ |\ (x+1)$ 的本质是把二进制最右边的 $0$ 置为 $1$。

反过来,如果已知 $x\ |\ (x+1) = 101111$,那么倒推 $x$,需要把 $101111$ 中的某个 $1$ 变成 $0$。满足要求的 $x$ 有:

$$
\begin{aligned}
100111 \
101011 \
101101 \
101110 \
\end{aligned}
$$

其中最小的是 $100111$,也就是把 $101111$ 最右边的 $0$ 的右边的 $1$ 置为 $0$。

无解的情况:由于 $x\ |\ (x+1)$ 最低位一定是 $1$(因为 $x$ 和 $x+1$ 中必有一奇数),所以如果 $\textit{nums}[i]$ 是偶数(质数中只有 $2$),那么无解。

写法一

举例说明:把 $101111$ 取反,得 $010000$,其 $\text{lowbit}=10000$,右移一位得 $1000$。把 $101111$ 与 $1000$ 异或,即可得到 $100111$。

关于 $\text{lowbit}$ 的原理,请看 从集合论到位运算,常见位运算技巧分类总结!

本题视频讲解,欢迎点赞关注~

###py

class Solution:
    def minBitwiseArray(self, nums: List[int]) -> List[int]:
        for i, x in enumerate(nums):
            if x == 2:
                nums[i] = -1
            else:
                t = ~x
                nums[i] ^= (t & -t) >> 1
        return nums

###java

class Solution {
    public int[] minBitwiseArray(List<Integer> nums) {
        int n = nums.size();
        int[] ans = new int[n];
        for (int i = 0; i < n; i++) {
            int x = nums.get(i);
            if (x == 2) {
                ans[i] = -1;
            } else {
                int t = ~x;
                int lowbit = t & -t;
                ans[i] = x ^ (lowbit >> 1);
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    vector<int> minBitwiseArray(vector<int>& nums) {
        for (int& x : nums) { // 注意这里是引用
            if (x == 2) {
                x = -1;
            } else {
                int t = ~x;
                x ^= (t & -t) >> 1;
            }
        }
        return nums;
    }
};

###go

func minBitwiseArray(nums []int) []int {
for i, x := range nums {
if x == 2 {
nums[i] = -1
} else {
t := ^x
nums[i] ^= t & -t >> 1
}
}
return nums
}

写法二

把 $101111$ 加一,得到 $110000$,再 AND $101111$ 取反后的值 $010000$,可以得到方法一中的 $\text{lowbit}=10000$。

###py

class Solution:
    def minBitwiseArray(self, nums: List[int]) -> List[int]:
        for i, x in enumerate(nums):
            if x == 2:
                nums[i] = -1
            else:
                nums[i] ^= ((x + 1) & ~x) >> 1
        return nums

###java

class Solution {
    public int[] minBitwiseArray(List<Integer> nums) {
        int n = nums.size();
        int[] ans = new int[n];
        for (int i = 0; i < n; i++) {
            int x = nums.get(i);
            if (x == 2) {
                ans[i] = -1;
            } else {
                int lowbit = (x + 1) & ~x;
                ans[i] = x ^ (lowbit >> 1);
            }
        }
        return ans;
    }
}

###cpp

class Solution {
public:
    vector<int> minBitwiseArray(vector<int>& nums) {
        for (int& x : nums) { // 注意这里是引用
            if (x == 2) {
                x = -1;
            } else {
                x ^= ((x + 1) & ~x) >> 1;
            }
        }
        return nums;
    }
};

###go

func minBitwiseArray(nums []int) []int {
for i, x := range nums {
if x == 2 {
nums[i] = -1
} else {
nums[i] ^= (x + 1) &^ x >> 1
}
}
return nums
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 是 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

专题训练

见下面位运算题单的「八、思维题」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

无需二分,暴力枚举就是 O(mn)(Python/Java/C++/Go)

作者 endlesscheng
2026年1月6日 09:29

前置知识【图解】一张图秒懂二维前缀和

预处理二维前缀和后,暴力的做法是写一个三重循环:

  • 外面两重循环,枚举正方形的左上角 $(i,j)$。
  • 最内层循环,枚举正方形的边长为 $1,2,3,\ldots$ 直到出界或者正方形元素和超过 $\textit{threshold}$ 为止。在此过程中更新答案 $\textit{ans}$ 的最大值。

这样做的时间复杂度是 $\mathcal{O}(mn\min(m,n))$。

只需改一个地方,就能让算法的时间复杂度变小:

  • 最内层循环,从 $\textit{ans}+1$ 开始枚举。

比如现在 $\textit{ans}=3$,那么枚举正方形的边长为 $1,2,3$ 是毫无意义的,不会让答案变得更大。所以直接从 $\textit{ans}+1=4$ 开始枚举更好。

###py

class Solution:
    def maxSideLength(self, mat: List[List[int]], threshold: int) -> int:
        m, n = len(mat), len(mat[0])
        s = [[0] * (n + 1) for _ in range(m + 1)]
        for i, row in enumerate(mat):
            for j, x in enumerate(row):
                s[i + 1][j + 1] = s[i + 1][j] + s[i][j + 1] - s[i][j] + x

        # 返回左上角在 (r1, c1),右下角在 (r2, c2) 的子矩阵元素和
        def query(r1: int, c1: int, r2: int, c2: int) -> int:
            return s[r2 + 1][c2 + 1] - s[r2 + 1][c1] - s[r1][c2 + 1] + s[r1][c1]

        ans = 0
        for i in range(m):
            for j in range(n):
                # 边长为 ans+1 的正方形,左上角在 (i, j),右下角在 (i+ans, j+ans)
                while i + ans < m and j + ans < n and query(i, j, i + ans, j + ans) <= threshold:
                    ans += 1
        return ans

###java

class Solution {
    public int maxSideLength(int[][] mat, int threshold) {
        int m = mat.length;
        int n = mat[0].length;
        int[][] sum = new int[m + 1][n + 1];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                sum[i + 1][j + 1] = sum[i + 1][j] + sum[i][j + 1] - sum[i][j] + mat[i][j];
            }
        }

        int ans = 0;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                // 边长为 ans+1 的正方形,左上角在 (i, j),右下角在 (i+ans, j+ans)
                while (i + ans < m && j + ans < n && query(sum, i, j, i + ans, j + ans) <= threshold) {
                    ans++;
                }
            }
        }
        return ans;
    }

    // 返回左上角在 (r1, c1),右下角在 (r2, c2) 的子矩阵元素和
    private int query(int[][] sum, int r1, int c1, int r2, int c2) {
        return sum[r2 + 1][c2 + 1] - sum[r2 + 1][c1] - sum[r1][c2 + 1] + sum[r1][c1];
    }
}

###cpp

class Solution {
public:
    int maxSideLength(vector<vector<int>>& mat, int threshold) {
        int m = mat.size(), n = mat[0].size();
        vector sum(m + 1, vector<int>(n + 1));
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                sum[i + 1][j + 1] = sum[i + 1][j] + sum[i][j + 1] - sum[i][j] + mat[i][j];
            }
        }

        // 返回左上角在 (r1, c1),右下角在 (r2, c2) 的子矩阵元素和
        auto query = [&](int r1, int c1, int r2, int c2) -> int {
            return sum[r2 + 1][c2 + 1] - sum[r2 + 1][c1] - sum[r1][c2 + 1] + sum[r1][c1];
        };

        int ans = 0;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                // 边长为 ans+1 的正方形,左上角在 (i, j),右下角在 (i+ans, j+ans)
                while (i + ans < m && j + ans < n && query(i, j, i + ans, j + ans) <= threshold) {
                    ans++;
                }
            }
        }
        return ans;
    }
};

###go

func maxSideLength(mat [][]int, threshold int) (ans int) {
m, n := len(mat), len(mat[0])
sum := make([][]int, m+1)
sum[0] = make([]int, n+1)
for i, row := range mat {
sum[i+1] = make([]int, n+1)
for j, x := range row {
sum[i+1][j+1] = sum[i+1][j] + sum[i][j+1] - sum[i][j] + x
}
}

// 返回左上角在 (r1, c1),右下角在 (r2, c2) 的子矩阵元素和
query := func(r1, c1, r2, c2 int) int {
return sum[r2+1][c2+1] - sum[r2+1][c1] - sum[r1][c2+1] + sum[r1][c1]
}

for i := range m {
for j := range n {
// 边长为 ans+1 的正方形,左上角在 (i, j),右下角在 (i+ans, j+ans)
for i+ans < m && j+ans < n && query(i, j, i+ans, j+ans) <= threshold {
ans++
}
}
}
return
}

注:外层循环可以改成 i + ans < m 以及 j + ans < n。但测试了一下,并没有明显提升。

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn)$,其中 $m$ 和 $n$ 分别是 $\textit{grid}$ 的行数和列数。虽然我们写了个三重循环,但由于答案最大是 $\min(m,n)$,所以最内层的 ans++ 最多执行 $\min(m,n)$ 次,三重循环的时间复杂度为 $\mathcal{O}(mn + \min(m,n)) = \mathcal{O}(mn)$。
  • 空间复杂度:$\mathcal{O}(mn)$。

思考题

本题还有一种做法:枚举对角线,在对角线上做 不定长滑动窗口,把正方形的左上角和右下角看作滑动窗口的左右端点。这也可以做到 $\mathcal{O}(mn)$ 时间。

用这个思路解决如下问题:

  • 计算元素总和小于或等于阈值的正方形的个数

欢迎在评论区分享你的代码。

专题训练

见下面数据结构题单的「§1.6 二维前缀和」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

从 O(N^4) 优化到 O(N^3)(Python/Java/C++/Go)

作者 endlesscheng
2021年6月13日 00:16

注:本题不能二分答案。「每行每列的元素和都相等」是一个非常刁钻的要求,可能中间的某个 $k$ 满足要求,$k$ 大一点或小一点都无法让每行每列的元素和都相等。

方法一:四种前缀和

从大到小枚举 $k$,判断 $\textit{grid}$ 是否存在一个 $k\times k$ 的子矩阵 $M$,满足如下要求:

  • 设 $M$ 第一行的元素和为 $s$。
  • $M$ 每行的元素和都是 $s$。
  • $M$ 每列的元素和都是 $s$。
  • $M$ 主对角线的元素和为 $s$。
  • $M$ 反对角线的元素和为 $s$。

这些参与求和的元素,在 $\textit{grid}$ 中都是连续的,我们可以用四种前缀和计算:

  • $\textit{rowSum}[i][j+1]$ 表示 $\textit{grid}$ 的 $i$ 行的前缀 $[0,j]$ 的元素和,即 $(i,0),(i,1),\ldots,(i,j)$ 的元素和。
  • $\textit{colSum}[i+1][j]$ 表示 $\textit{grid}$ 的 $j$ 列的前缀 $[0,i]$ 的元素和,即 $(0,j),(1,j),\ldots,(i,j)$ 的元素和。
  • $\textit{diagSum}[i+1][j+1]$ 表示从最上边或最左边出发,向右下↘到 $(i,j)$ 这条线上的元素和。
  • $\textit{antiSum}[i+1][j]$ 表示从最上边或最右边出发,向左下↙到 $(i,j)$ 这条线上的元素和。

为什么这里有一些 $+1$?原理在 前缀和 中讲了,是为了兼容子数组恰好是前缀的情况,此时仍然可以用两个前缀和之差算出子数组和,无需特判。

写个三重循环,依次枚举 $k,i,j$,其中 $k\times k$ 子矩阵的左上角为 $(i-k,j-k)$,右下角为 $(i-1,j-1)$,那么:

  • 主对角线的元素和为 $\textit{diagSum}[i][j] - \textit{diagSum}[i-k][j-k]$。
  • 反对角线的元素和为 $\textit{antiSum}[i][j-k]-\textit{antiSum}[i-k][j]$。
  • 在 $[i-k,i-1]$ 中枚举行号 $r$,行元素和为 $\textit{rowSum}[r][j] - \textit{rowSum}[r][j-k]$。
  • 在 $[j-k,j-1]$ 中枚举列号 $c$,列元素和为 $\textit{colSum}[i][c] - \textit{colSum}[i-k][c]$。

代码实现时,可以先求主对角线的元素和、反对角线的元素和,如果二者不相等,则无需枚举 $r$ 和 $c$。

class Solution:
    def largestMagicSquare(self, grid: List[List[int]]) -> int:
        m, n = len(grid), len(grid[0])
        row_sum = [[0] * (n + 1) for _ in range(m)]       # → 前缀和
        col_sum = [[0] * n for _ in range(m + 1)]         # ↓ 前缀和
        diag_sum = [[0] * (n + 1) for _ in range(m + 1)]  # ↘ 前缀和
        anti_sum = [[0] * (n + 1) for _ in range(m + 1)]  # ↙ 前缀和

        for i, row in enumerate(grid):
            for j, x in enumerate(row):
                row_sum[i][j + 1] = row_sum[i][j] + x
                col_sum[i + 1][j] = col_sum[i][j] + x
                diag_sum[i + 1][j + 1] = diag_sum[i][j] + x
                anti_sum[i + 1][j] = anti_sum[i][j + 1] + x

        # k×k 子矩阵的左上角为 (i−k, j−k),右下角为 (i−1, j−1)
        for k in range(min(m, n), 0, -1):
            for i in range(k, m + 1):
                for j in range(k, n + 1):
                    # 子矩阵主对角线的和
                    s = diag_sum[i][j] - diag_sum[i - k][j - k]

                    # 子矩阵反对角线的和等于 s
                    # 子矩阵每行的和都等于 s
                    # 子矩阵每列的和都等于 s
                    if anti_sum[i][j - k] - anti_sum[i - k][j] == s and \
                       all(row_sum[r][j] - row_sum[r][j - k] == s for r in range(i - k, i)) and \
                       all(col_sum[i][c] - col_sum[i - k][c] == s for c in range(j - k, j)):
                        return k
class Solution {
    public int largestMagicSquare(int[][] grid) {
        int m = grid.length;
        int n = grid[0].length;
        int[][] rowSum = new int[m][n + 1];      // → 前缀和
        int[][] colSum = new int[m + 1][n];      // ↓ 前缀和
        int[][] diagSum = new int[m + 1][n + 1]; // ↘ 前缀和
        int[][] antiSum = new int[m + 1][n + 1]; // ↙ 前缀和

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                int x = grid[i][j];
                rowSum[i][j + 1] = rowSum[i][j] + x;
                colSum[i + 1][j] = colSum[i][j] + x;
                diagSum[i + 1][j + 1] = diagSum[i][j] + x;
                antiSum[i + 1][j] = antiSum[i][j + 1] + x;
            }
        }

        // k×k 子矩阵的左上角为 (i−k, j−k),右下角为 (i−1, j−1)
        for (int k = Math.min(m, n); ; k--) {
            for (int i = k; i <= m; i++) {
                next:
                for (int j = k; j <= n; j++) {
                    // 子矩阵主对角线的和
                    int sum = diagSum[i][j] - diagSum[i - k][j - k];

                    // 子矩阵反对角线的和
                    if (antiSum[i][j - k] - antiSum[i - k][j] != sum) {
                        continue;
                    }

                    // 子矩阵每行的和
                    for (int r = i - k; r < i; r++) {
                        if (rowSum[r][j] - rowSum[r][j - k] != sum) {
                            continue next;
                        }
                    }

                    // 子矩阵每列的和
                    for (int c = j - k; c < j; c++) {
                        if (colSum[i][c] - colSum[i - k][c] != sum) {
                            continue next;
                        }
                    }

                    return k;
                }
            }
        }
    }
}
class Solution {
public:
    int largestMagicSquare(vector<vector<int>>& grid) {
        int m = grid.size(), n = grid[0].size();
        vector row_sum(m, vector<int>(n + 1));      // → 前缀和
        vector col_sum(m + 1, vector<int>(n));      // ↓ 前缀和
        vector diag_sum(m + 1, vector<int>(n + 1)); // ↘ 前缀和
        vector anti_sum(m + 1, vector<int>(n + 1)); // ↙ 前缀和

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                int x = grid[i][j];
                row_sum[i][j + 1] = row_sum[i][j] + x;
                col_sum[i + 1][j] = col_sum[i][j] + x;
                diag_sum[i + 1][j + 1] = diag_sum[i][j] + x;
                anti_sum[i + 1][j] = anti_sum[i][j + 1] + x;
            }
        }

        // k×k 子矩阵的左上角为 (i−k, j−k),右下角为 (i−1, j−1)
        for (int k = min(m, n); ; k--) {
            for (int i = k; i <= m; i++) {
                for (int j = k; j <= n; j++) {
                    // 子矩阵主对角线的和
                    int sum = diag_sum[i][j] - diag_sum[i - k][j - k];

                    // 子矩阵反对角线的和
                    if (anti_sum[i][j - k] - anti_sum[i - k][j] != sum) {
                        continue;
                    }

                    // 子矩阵每行的和
                    bool ok = true;
                    for (int r = i - k; r < i; r++) {
                        if (row_sum[r][j] - row_sum[r][j - k] != sum) {
                            ok = false;
                            break;
                        }
                    }
                    if (!ok) {
                        continue;
                    }

                    // 子矩阵每列的和
                    for (int c = j - k; c < j; c++) {
                        if (col_sum[i][c] - col_sum[i - k][c] != sum) {
                            ok = false;
                            break;
                        }
                    }
                    if (ok) {
                        return k;
                    }
                }
            }
        }
    }
};
func largestMagicSquare(grid [][]int) int {
m, n := len(grid), len(grid[0])
rowSum := make([][]int, m)    // → 前缀和
colSum := make([][]int, m+1)  // ↓ 前缀和
diagSum := make([][]int, m+1) // ↘ 前缀和
antiSum := make([][]int, m+1) // ↙ 前缀和
for i := range m + 1 {
colSum[i] = make([]int, n)
diagSum[i] = make([]int, n+1)
antiSum[i] = make([]int, n+1)
}

for i, row := range grid {
rowSum[i] = make([]int, n+1)
for j, x := range row {
rowSum[i][j+1] = rowSum[i][j] + x
colSum[i+1][j] = colSum[i][j] + x
diagSum[i+1][j+1] = diagSum[i][j] + x
antiSum[i+1][j] = antiSum[i][j+1] + x
}
}

// k×k 子矩阵的左上角为 (i−k, j−k),右下角为 (i−1, j−1)
for k := min(m, n); ; k-- {
for i := k; i <= m; i++ {
next:
for j := k; j <= n; j++ {
// 子矩阵主对角线的和
sum := diagSum[i][j] - diagSum[i-k][j-k]

// 子矩阵反对角线的和
if antiSum[i][j-k]-antiSum[i-k][j] != sum {
continue
}

// 子矩阵每行的和
for _, rowS := range rowSum[i-k : i] {
if rowS[j]-rowS[j-k] != sum {
continue next
}
}

// 子矩阵每列的和
for c := j - k; c < j; c++ {
if colSum[i][c]-colSum[i-k][c] != sum {
continue next
}
}

return k
}
}
}
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn\min(m,n)^2)$,其中 $m$ 和 $n$ 分别是 $\textit{grid}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(mn)$。

方法二:维护连续等和行列的个数

从大到小枚举 $k$,判断 $\textit{grid}$ 是否存在一个 $k\times k$ 的子矩阵 $M$,满足如下要求:

  • 设 $M$ 第一行的元素和为 $s$。
  • $M$ 每行的元素和都是 $s$。优化:想象有一个 $k\times k$ 的窗口在向下滑动,我们可以维护到第 $i$ 行时,有连续多少行的和都等于 $s$。维护一个计数器 $\textit{sameCnt}$,如果当前行的和等于前一行的和,那么把 $\textit{sameCnt}$ 加一,否则把 $\textit{sameCnt}$ 重置为 $1$。如果 $\textit{sameCnt}\ge k$,则说明子矩阵每行的元素和都相等。
  • $M$ 每列的元素和都是 $s$。优化:想象有一个 $k\times k$ 的窗口在向右滑动,我们可以维护到第 $j$ 列时,有连续多少列的和都等于 $s$。算法同上。
  • $M$ 主对角线的元素和为 $s$。
  • $M$ 反对角线的元素和为 $s$。
class Solution:
    def largestMagicSquare(self, grid: List[List[int]]) -> int:
        m, n = len(grid), len(grid[0])
        row_sum = [[0] * (n + 1) for _ in range(m)]       # → 前缀和
        col_sum = [[0] * n for _ in range(m + 1)]         # ↓ 前缀和
        diag_sum = [[0] * (n + 1) for _ in range(m + 1)]  # ↘ 前缀和
        anti_sum = [[0] * (n + 1) for _ in range(m + 1)]  # ↙ 前缀和

        for i, row in enumerate(grid):
            for j, x in enumerate(row):
                row_sum[i][j + 1] = row_sum[i][j] + x
                col_sum[i + 1][j] = col_sum[i][j] + x
                diag_sum[i + 1][j + 1] = diag_sum[i][j] + x
                anti_sum[i + 1][j] = anti_sum[i][j + 1] + x

        # is_same_col_sum[i][j] 表示右下角为 (i, j) 的子矩形,每列元素和是否都相等
        is_same_col_sum = [[False] * n for _ in range(m)]

        for k in range(min(m, n), 1, -1):
            for i in range(k, m + 1):
                # 想象有一个 k×k 的窗口在向右滑动
                same_cnt = 1
                for j in range(1, n):
                    if col_sum[i][j] - col_sum[i - k][j] == col_sum[i][j - 1] - col_sum[i - k][j - 1]:
                        same_cnt += 1
                    else:
                        same_cnt = 1
                    # 连续 k 列元素和是否都一样
                    is_same_col_sum[i - 1][j] = same_cnt >= k

            for j in range(k, n + 1):
                # 想象有一个 k×k 的窗口在向下滑动
                sum_row = row_sum[0][j] - row_sum[0][j - k]
                same_cnt = 1
                for i in range(2, m + 1):
                    row_s = row_sum[i - 1][j] - row_sum[i - 1][j - k]
                    if row_s == sum_row:
                        same_cnt += 1
                        if (same_cnt >= k and  # 连续 k 行元素和都一样
                            is_same_col_sum[i - 1][j - 1] and  # 连续 k 列元素和都一样
                            col_sum[i][j - 1] - col_sum[i - k][j - 1] == sum_row and  # 列和 = 行和
                            diag_sum[i][j] - diag_sum[i - k][j - k] == sum_row and  # 主对角线和 = 行和
                            anti_sum[i][j - k] - anti_sum[i - k][j] == sum_row):  # 反对角线和 = 行和
                            return k
                    else:
                        sum_row = row_s
                        same_cnt = 1

        return 1
class Solution {
    public int largestMagicSquare(int[][] grid) {
        int m = grid.length;
        int n = grid[0].length;
        int[][] rowSum = new int[m][n + 1];      // → 前缀和
        int[][] colSum = new int[m + 1][n];      // ↓ 前缀和
        int[][] diagSum = new int[m + 1][n + 1]; // ↘ 前缀和
        int[][] antiSum = new int[m + 1][n + 1]; // ↙ 前缀和

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                int x = grid[i][j];
                rowSum[i][j + 1] = rowSum[i][j] + x;
                colSum[i + 1][j] = colSum[i][j] + x;
                diagSum[i + 1][j + 1] = diagSum[i][j] + x;
                antiSum[i + 1][j] = antiSum[i][j + 1] + x;
            }
        }

        // isSameColSum[i][j] 表示右下角为 (i, j) 的子矩形,每列元素和是否都相等
        boolean[][] isSameColSum = new boolean[m][n];

        for (int k = Math.min(m, n); k > 1; k--) {
            for (int i = k; i <= m; i++) {
                // 想象有一个 k×k 的窗口在向右滑动
                int sameCnt = 1;
                for (int j = 1; j < n; j++) {
                    if (colSum[i][j] - colSum[i - k][j] == colSum[i][j - 1] - colSum[i - k][j - 1]) {
                        sameCnt++;
                    } else {
                        sameCnt = 1;
                    }
                    // 连续 k 列元素和是否都一样
                    isSameColSum[i - 1][j] = sameCnt >= k;
                }
            }

            for (int j = k; j <= n; j++) {
                // 想象有一个 k×k 的窗口在向下滑动
                int sum = rowSum[0][j] - rowSum[0][j - k];
                int sameCnt = 1;
                for (int i = 2; i <= m; i++) {
                    int rowS = rowSum[i - 1][j] - rowSum[i - 1][j - k];
                    if (rowS == sum) {
                        sameCnt++;
                        if (sameCnt >= k && // 连续 k 行元素和都一样
                            isSameColSum[i - 1][j - 1] && // 连续 k 列元素和都一样
                            colSum[i][j - 1] - colSum[i - k][j - 1] == sum && // 列和 = 行和
                            diagSum[i][j] - diagSum[i - k][j - k] == sum && // 主对角线和 = 行和
                            antiSum[i][j - k] - antiSum[i - k][j] == sum) { // 反对角线和 = 行和
                            return k;
                        }
                    } else {
                        sum = rowS;
                        sameCnt = 1;
                    }
                }
            }
        }

        return 1;
    }
}
class Solution {
public:
    int largestMagicSquare(vector<vector<int>>& grid) {
        int m = grid.size(), n = grid[0].size();
        vector row_sum(m, vector<int>(n + 1));      // → 前缀和
        vector col_sum(m + 1, vector<int>(n));      // ↓ 前缀和
        vector diag_sum(m + 1, vector<int>(n + 1)); // ↘ 前缀和
        vector anti_sum(m + 1, vector<int>(n + 1)); // ↙ 前缀和

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                int x = grid[i][j];
                row_sum[i][j + 1] = row_sum[i][j] + x;
                col_sum[i + 1][j] = col_sum[i][j] + x;
                diag_sum[i + 1][j + 1] = diag_sum[i][j] + x;
                anti_sum[i + 1][j] = anti_sum[i][j + 1] + x;
            }
        }

        // is_same_col_sum[i][j] 表示右下角为 (i, j) 的子矩形,每列元素和是否都相等
        vector is_same_col_sum(m, vector<int8_t>(n));

        for (int k = min(m, n); k > 1; k--) {
            for (int i = k; i <= m; i++) {
                // 想象有一个 k×k 的窗口在向右滑动
                int same_cnt = 1;
                for (int j = 1; j < n; j++) {
                    if (col_sum[i][j] - col_sum[i - k][j] == col_sum[i][j - 1] - col_sum[i - k][j - 1]) {
                        same_cnt++;
                    } else {
                        same_cnt = 1;
                    }
                    // 连续 k 列元素和是否都一样
                    is_same_col_sum[i - 1][j] = same_cnt >= k;
                }
            }

            for (int j = k; j <= n; j++) {
                // 想象有一个 k×k 的窗口在向下滑动
                int sum_row = row_sum[0][j] - row_sum[0][j - k];
                int same_cnt = 1;
                for (int i = 2; i <= m; i++) {
                    int row_s = row_sum[i - 1][j] - row_sum[i - 1][j - k];
                    if (row_s == sum_row) {
                        same_cnt++;
                        if (same_cnt >= k && // 连续 k 行元素和都一样
                            is_same_col_sum[i - 1][j - 1] && // 连续 k 列元素和都一样
                            col_sum[i][j - 1] - col_sum[i - k][j - 1] == sum_row && // 列和 = 行和
                            diag_sum[i][j] - diag_sum[i - k][j - k] == sum_row && // 主对角线和 = 行和
                            anti_sum[i][j - k] - anti_sum[i - k][j] == sum_row) { // 反对角线和 = 行和
                            return k;
                        }
                    } else {
                        sum_row = row_s;
                        same_cnt = 1;
                    }
                }
            }
        }

        return 1;
    }
};
func largestMagicSquare(grid [][]int) int {
m, n := len(grid), len(grid[0])
rowSum := make([][]int, m)    // → 前缀和
colSum := make([][]int, m+1)  // ↓ 前缀和
diagSum := make([][]int, m+1) // ↘ 前缀和
antiSum := make([][]int, m+1) // ↙ 前缀和
for i := range m + 1 {
colSum[i] = make([]int, n)
diagSum[i] = make([]int, n+1)
antiSum[i] = make([]int, n+1)
}
for i, row := range grid {
rowSum[i] = make([]int, n+1)
for j, x := range row {
rowSum[i][j+1] = rowSum[i][j] + x
colSum[i+1][j] = colSum[i][j] + x
diagSum[i+1][j+1] = diagSum[i][j] + x
antiSum[i+1][j] = antiSum[i][j+1] + x
}
}

// isSameColSum[i][j] 表示右下角为 (i, j) 的子矩形,每列元素和是否都相等
isSameColSum := make([][]bool, m)
for i := range isSameColSum {
isSameColSum[i] = make([]bool, n)
}
for k := min(m, n); k > 1; k-- {
for i := k; i <= m; i++ {
// 想象有一个 k×k 的窗口在向右滑动
sameCnt := 1
for j := 1; j < n; j++ {
if colSum[i][j]-colSum[i-k][j] == colSum[i][j-1]-colSum[i-k][j-1] {
sameCnt++
} else {
sameCnt = 1
}
// 连续 k 列元素和是否都一样
isSameColSum[i-1][j] = sameCnt >= k
}
}

for j := k; j <= n; j++ {
// 想象有一个 k×k 的窗口在向下滑动
sum := rowSum[0][j] - rowSum[0][j-k]
sameCnt := 1
for i := 2; i <= m; i++ {
rowS := rowSum[i-1][j] - rowSum[i-1][j-k]
if rowS == sum {
sameCnt++
if sameCnt >= k && // 连续 k 行元素和都一样
isSameColSum[i-1][j-1] && // 连续 k 列元素和都一样
colSum[i][j-1]-colSum[i-k][j-1] == sum && // 列和 = 行和
diagSum[i][j]-diagSum[i-k][j-k] == sum && // 主对角线和 = 行和
antiSum[i][j-k]-antiSum[i-k][j] == sum {  // 反对角线和 = 行和
return k
}
} else {
sum = rowS
sameCnt = 1
}
}
}
}

return 1
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(mn\min(m,n))$,其中 $m$ 和 $n$ 分别是 $\textit{grid}$ 的行数和列数。
  • 空间复杂度:$\mathcal{O}(mn)$。

相似题目

1878. 矩阵中最大的三个菱形和

专题训练

见下面数据结构题单的「一、前缀和」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

❌
❌