非暴力做法(Python/Java/C++/Go)
本题和周赛第四题是一样的,请看 我的题解。
本题和周赛第四题是一样的,请看 我的题解。
本文把 $\textit{width}$ 简称为 $w$,把 $\textit{height}$ 简称为 $h$。
根据题意,机器人只能在网格图的最外圈中移动,移动一整圈需要 $2(w+h-2)$ 步。
![]()
设当前移动的总步数模 $2(w+h-2)$ 的结果为 $s$。分类讨论:
⚠注意:总步数为 $0$ 时,机器人面朝东,但总步数为 $2(w+h-2)$ 的正整数倍时,机器人面朝南。需要特判总步数为 $0$ 的特殊情况吗?不需要,当总步数大于 $0$ 时,我们可以把取模后的范围从 $[0, 2(w+h-2)-1]$ 调整到 $[1, 2(w+h-2)]$,从而使原先模为 $0$ 的总步数变成 $2(w+h-2)$,落入面朝南的分支中,这样就可以避免特判了。
class Robot:
def __init__(self, width: int, height: int) -> None:
self.w = width
self.h = height
self.s = 0
def step(self, num: int) -> None:
# 由于机器人只能走外圈,那么走 (w+h-2)*2 步后会回到起点
# 把 s 取模调整到 [1, (w+h-2)*2],这样不需要特判 s == 0 时的方向
self.s = (self.s + num - 1) % ((self.w + self.h - 2) * 2) + 1
def _getState(self) -> Tuple[int, int, str]:
w, h, s = self.w, self.h, self.s
if s < w:
return s, 0, "East"
if s < w + h - 1:
return w - 1, s - w + 1, "North"
if s < w * 2 + h - 2:
return w * 2 + h - s - 3, h - 1, "West"
return 0, (w + h) * 2 - s - 4, "South"
def getPos(self) -> List[int]:
x, y, _ = self._getState()
return [x, y]
def getDir(self) -> str:
return self._getState()[2]
class Robot {
private int w, h, s;
public Robot(int width, int height) {
w = width;
h = height;
s = 0;
}
public void step(int num) {
// 由于机器人只能走外圈,那么走 (w+h-2)*2 步后会回到起点
// 把 s 取模调整到 [1, (w+h-2)*2],这样不需要特判 s == 0 时的方向
s = (s + num - 1) % ((w + h - 2) * 2) + 1;
}
public int[] getPos() {
Object[] t = getState();
return new int[]{(int) t[0], (int) t[1]};
}
public String getDir() {
Object[] t = getState();
return (String) t[2];
}
private Object[] getState() {
if (s < w) {
return new Object[]{s, 0, "East"};
} else if (s < w + h - 1) {
return new Object[]{w - 1, s - w + 1, "North"};
} else if (s < w * 2 + h - 2) {
return new Object[]{w * 2 + h - s - 3, h - 1, "West"};
} else {
return new Object[]{0, (w + h) * 2 - s - 4, "South"};
}
}
}
class Robot {
int w;
int h;
int s = 0;
tuple<int, int, string> getState() {
if (s < w) {
return {s, 0, "East"};
} else if (s < w + h - 1) {
return {w - 1, s - w + 1, "North"};
} else if (s < w * 2 + h - 2) {
return {w * 2 + h - s - 3, h - 1, "West"};
} else {
return {0, (w + h) * 2 - s - 4, "South"};
}
}
public:
Robot(int width, int height) : w(width), h(height) {}
void step(int num) {
// 由于机器人只能走外圈,那么走 (w+h-2)*2 步后会回到起点
// 把 s 取模调整到 [1, (w+h-2)*2],这样不需要特判 s == 0 时的方向
s = (s + num - 1) % ((w + h - 2) * 2) + 1;
}
vector<int> getPos() {
auto [x, y, _] = getState();
return {x, y};
}
string getDir() {
return get<2>(getState());
}
};
typedef struct {
int w;
int h;
int s;
} Robot;
Robot* robotCreate(int width, int height) {
Robot* r = malloc(sizeof(Robot));
r->w = width;
r->h = height;
r->s = 0;
return r;
}
void robotStep(Robot* r, int num) {
// 由于机器人只能走外圈,那么走 (w+h-2)*2 步后会回到起点
// 把 s 取模调整到 [1, (w+h-2)*2],这样不需要特判 s == 0 时的方向
r->s = (r->s + num - 1) % ((r->w + r->h - 2) * 2) + 1;
}
int* robotGetPos(Robot* r, int* returnSize) {
int w = r->w, h = r->h, s = r->s;
int x, y;
if (s < w) {
x = s;
y = 0;
} else if (s < w + h - 1) {
x = w - 1;
y = s - w + 1;
} else if (s < w * 2 + h - 2) {
x = w * 2 + h - s - 3;
y = h - 1;
} else {
x = 0;
y = (w + h) * 2 - s - 4;
}
int* ans = malloc(2 * sizeof(int));
*returnSize = 2;
ans[0] = x;
ans[1] = y;
return ans;
}
char* robotGetDir(Robot* r) {
int w = r->w, h = r->h, s = r->s;
if (s < w) {
return "East";
} else if (s < w + h - 1) {
return "North";
} else if (s < w * 2 + h - 2) {
return "West";
} else {
return "South";
}
}
void robotFree(Robot* r) {
free(r);
}
type Robot struct {
w, h, step int
}
func Constructor(width, height int) Robot {
return Robot{width, height, 0}
}
func (r *Robot) Step(num int) {
// 由于机器人只能走外圈,那么走 (w+h-2)*2 步后会回到起点
// 把 step 取模调整到 [1, (w+h-2)*2],这样不需要特判 step == 0 时的方向
r.step = (r.step+num-1)%((r.w+r.h-2)*2) + 1
}
func (r *Robot) getState() (x, y int, dir string) {
w, h, step := r.w, r.h, r.step
switch {
case step < w:
return step, 0, "East"
case step < w+h-1:
return w - 1, step - w + 1, "North"
case step < w*2+h-2:
return w*2 + h - step - 3, h - 1, "West"
default:
return 0, (w+h)*2 - step - 4, "South"
}
}
func (r *Robot) GetPos() []int {
x, y, _ := r.getState()
return []int{x, y}
}
func (r *Robot) GetDir() string {
_, _, d := r.getState()
return d
}
var Robot = function(width, height) {
this.w = width;
this.h = height;
this.s = 0;
};
Robot.prototype.step = function(num) {
// 由于机器人只能走外圈,那么走 (w+h-2)*2 步后会回到起点
// 把 s 取模调整到 [1, (w+h-2)*2],这样不需要特判 s === 0 时的方向
this.s = (this.s + num - 1) % ((this.w + this.h - 2) * 2) + 1;
};
Robot.prototype.getState = function() {
const w = this.w, h = this.h, s = this.s;
if (s < w) {
return [s, 0, "East"];
} else if (s < w + h - 1) {
return [w - 1, s - w + 1, "North"];
} else if (s < w * 2 + h - 2) {
return [w * 2 + h - s - 3, h - 1, "West"];
} else {
return [0, (w + h) * 2 - s - 4, "South"];
}
};
Robot.prototype.getPos = function() {
const [x, y, _] = this.getState();
return [x, y];
};
Robot.prototype.getDir = function() {
return this.getState()[2];
};
struct Robot {
w: i32,
h: i32,
s: i32,
}
impl Robot {
fn new(width: i32, height: i32) -> Self {
Self { w: width, h: height, s: 0 }
}
fn step(&mut self, num: i32) {
// 由于机器人只能走外圈,那么走 (w+h-2)*2 步后会回到起点
// 把 s 取模调整到 [1, (w+h-2)*2],这样不需要特判 s == 0 时的方向
self.s = (self.s + num - 1) % ((self.w + self.h - 2) * 2) + 1;
}
fn get_state(&self) -> (i32, i32, String) {
let w = self.w;
let h = self.h;
let s = self.s;
if s < w {
(s, 0, "East".to_string())
} else if s < w + h - 1 {
(w - 1, s - w + 1, "North".to_string())
} else if s < w * 2 + h - 2 {
(w * 2 + h - s - 3, h - 1, "West".to_string())
} else {
(0, (w + h) * 2 - s - 4, "South".to_string())
}
}
fn get_pos(&self) -> Vec<i32> {
let (x, y, _) = self.get_state();
vec![x, y]
}
fn get_dir(&self) -> String {
let (_, _, d) = self.get_state();
d
}
}
欢迎关注 B站@灵茶山艾府
总体思路:模拟机器人行走的过程。一步一步走,如果下一步是障碍物,则停止移动,继续执行下一个命令。
怎么表示机器人移动的方向?
我们可以用一个向量数组
$$
\textit{dirs} = [(0, 1), (1, 0), (0, -1), (-1, 0)]
$$
分别表示顺时针的上右下左(北东南西)四个方向。
用一个下标 $k$ 表示当前机器人的方向为 $\textit{dirs}[k]$,初始 $k=0$,表示初始方向为上。
设 $c = \textit{commands}[i]$。当 $c>0$ 时,机器人要往 $\textit{dirs}[k]$ 方向移动 $c$ 个单位长度。一步一步移动,如果发现下一步是障碍物,则停止移动,继续执行下一个命令。
为了快速判断某个坐标是否为障碍物(是否在 $\textit{obstacles}$ 数组中),我们可以把 $\textit{obstacles}$ 转成哈希集合,判断坐标是否在哈希集合中。
注:可能起点也有障碍物,这相当于机器人站在障碍物上,是可以继续移动的。
###py
DIRS = (0, 1), (1, 0), (0, -1), (-1, 0) # 上右下左(顺时针)
class Solution:
def robotSim(self, commands: List[int], obstacles: List[List[int]]) -> int:
obstacle_set = set(map(tuple, obstacles))
ans = x = y = k = 0
for c in commands:
if c == -1: # 右转
k = (k + 1) % 4
elif c == -2: # 左转
k = (k + 3) % 4
else: # 直行
while c > 0 and (x + DIRS[k][0], y + DIRS[k][1]) not in obstacle_set:
x += DIRS[k][0]
y += DIRS[k][1]
c -= 1
ans = max(ans, x * x + y * y)
return ans
###java
class Solution {
private static final int[][] DIRS = {{0, 1}, {1, 0}, {0, -1}, {-1, 0}}; // 上右下左(顺时针)
public int robotSim(int[] commands, int[][] obstacles) {
HashSet<Integer> obstacleSet = new HashSet<>(obstacles.length, 1); // 预分配空间
final int OFFSET = (int) 3e4;
for (int[] p : obstacles) {
// p 是两个 16 位整数,合并成一个 32 位整数
obstacleSet.add((p[0] + OFFSET) << 16 | (p[1] + OFFSET));
}
int x = 0;
int y = 0;
int k = 0;
int ans = 0;
for (int c : commands) {
if (c == -1) { // 右转
k = (k + 1) % 4;
} else if (c == -2) { // 左转
k = (k + 3) % 4;
} else { // 直行
while (c-- > 0) {
int nx = x + DIRS[k][0];
int ny = y + DIRS[k][1];
if (obstacleSet.contains((nx + OFFSET) << 16 | (ny + OFFSET))) {
break;
}
x = nx;
y = ny;
}
ans = Math.max(ans, x * x + y * y);
}
}
return ans;
}
}
###cpp
class Solution {
static constexpr int DIRS[4][2] = {{0, 1}, {1, 0}, {0, -1}, {-1, 0}}; // 上右下左(顺时针)
public:
int robotSim(vector<int>& commands, vector<vector<int>>& obstacles) {
unordered_set<int> obstacle_set;
obstacle_set.reserve(obstacles.size()); // 预分配空间
constexpr int OFFSET = 3e4;
for (auto& p : obstacles) {
// p 是两个 16 位整数,合并成一个 32 位整数
obstacle_set.insert((p[0] + OFFSET) << 16 | (p[1] + OFFSET));
}
int ans = 0, x = 0, y = 0, k = 0;
for (int c : commands) {
if (c == -1) { // 右转
k = (k + 1) % 4;
} else if (c == -2) { // 左转
k = (k + 3) % 4;
} else { // 直行
while (c--) {
int nx = x + DIRS[k][0];
int ny = y + DIRS[k][1];
if (obstacle_set.contains((nx + OFFSET) << 16 | (ny + OFFSET))) {
break;
}
x = nx;
y = ny;
}
ans = max(ans, x * x + y * y);
}
}
return ans;
}
};
###go
type pair struct{ x, y int }
var dirs = [...]pair{{0, 1}, {1, 0}, {0, -1}, {-1, 0}} // 上右下左(顺时针)
func robotSim(commands []int, obstacles [][]int) (ans int) {
isObstacle := make(map[pair]bool, len(obstacles)) // 预分配空间
for _, p := range obstacles {
isObstacle[pair{p[0], p[1]}] = true
}
x, y, k := 0, 0, 0
for _, c := range commands {
if c == -1 { // 右转
k = (k + 1) % 4
} else if c == -2 { // 左转
k = (k + 3) % 4
} else { // 直行
for ; c > 0 && !isObstacle[pair{x + dirs[k].x, y + dirs[k].y}]; c-- {
x += dirs[k].x
y += dirs[k].y
}
ans = max(ans, x*x+y*y)
}
}
return
}
设 $c = \textit{commands}[i]$。
因此,这两种情况可以进一步统一成,把 $k$ 更新成 $(k + 2c + 3)\bmod 4$。但是,当 $k=0$ 且 $2c+3=-1$ 时,$k + 2c + 3=-1$ 是负数。对于模 $4$ 运算,多增加 $4$ 不影响结果,所以可以把 $2c+3$ 改成 $2c+7$,也就把 $k$ 更新成
$$
(k + 2c + 7)\bmod 4
$$
###py
DIRS = (0, 1), (1, 0), (0, -1), (-1, 0) # 上右下左(顺时针)
class Solution:
def robotSim(self, commands: List[int], obstacles: List[List[int]]) -> int:
obstacle_set = set(map(tuple, obstacles))
ans = x = y = k = 0
for c in commands:
if c < 0:
k = (k + c * 2 + 7) % 4 # c=-2 左转,c=-1 右转
continue
while c > 0 and (x + DIRS[k][0], y + DIRS[k][1]) not in obstacle_set:
x += DIRS[k][0]
y += DIRS[k][1]
c -= 1
ans = max(ans, x * x + y * y)
return ans
###java
class Solution {
private static final int[][] DIRS = {{0, 1}, {1, 0}, {0, -1}, {-1, 0}}; // 上右下左(顺时针)
public int robotSim(int[] commands, int[][] obstacles) {
HashSet<Integer> obstacleSet = new HashSet<>(obstacles.length, 1); // 预分配空间
final int OFFSET = (int) 3e4;
for (int[] p : obstacles) {
// p 是两个 16 位整数,合并成一个 32 位整数
obstacleSet.add((p[0] + OFFSET) << 16 | (p[1] + OFFSET));
}
int x = 0;
int y = 0;
int k = 0;
int ans = 0;
for (int c : commands) {
if (c < 0) {
k = (k + c * 2 + 7) % 4; // c=-2 左转,c=-1 右转
continue;
}
while (c-- > 0) {
int nx = x + DIRS[k][0];
int ny = y + DIRS[k][1];
if (obstacleSet.contains((nx + OFFSET) << 16 | (ny + OFFSET))) {
break;
}
x = nx;
y = ny;
}
ans = Math.max(ans, x * x + y * y);
}
return ans;
}
}
###cpp
class Solution {
static constexpr int DIRS[4][2] = {{0, 1}, {1, 0}, {0, -1}, {-1, 0}}; // 上右下左(顺时针)
public:
int robotSim(vector<int>& commands, vector<vector<int>>& obstacles) {
unordered_set<int> obstacle_set;
obstacle_set.reserve(obstacles.size()); // 预分配空间
constexpr int OFFSET = 3e4;
for (auto& p : obstacles) {
// p 是两个 16 位整数,合并成一个 32 位整数
obstacle_set.insert((p[0] + OFFSET) << 16 | (p[1] + OFFSET));
}
int ans = 0, x = 0, y = 0, k = 0;
for (int c : commands) {
if (c < 0) {
k = (k + c * 2 + 7) % 4; // c=-2 左转,c=-1 右转
continue;
}
while (c--) {
int nx = x + DIRS[k][0];
int ny = y + DIRS[k][1];
if (obstacle_set.contains((nx + OFFSET) << 16 | (ny + OFFSET))) {
break;
}
x = nx;
y = ny;
}
ans = max(ans, x * x + y * y);
}
return ans;
}
};
###go
type pair struct{ x, y int }
var dirs = [...]pair{{0, 1}, {1, 0}, {0, -1}, {-1, 0}} // 上右下左(顺时针)
func robotSim(commands []int, obstacles [][]int) (ans int) {
isObstacle := make(map[pair]bool, len(obstacles)) // 预分配空间
for _, p := range obstacles {
isObstacle[pair{p[0], p[1]}] = true
}
x, y, k := 0, 0, 0
for _, c := range commands {
if c < 0 {
k = (k + c*2 + 7) % 4 // c=-2 左转,c=-1 右转
continue
}
for ; c > 0 && !isObstacle[pair{x + dirs[k].x, y + dirs[k].y}]; c-- {
x += dirs[k].x
y += dirs[k].y
}
ans = max(ans, x*x+y*y)
}
return
}
欢迎关注 B站@灵茶山艾府
机器人的横坐标,等于向右移动的次数,减去向左移动的次数。如果 $\texttt{R}$ 的个数等于 $\texttt{L}$ 的个数,那么最终横坐标为 $0$。
机器人的纵坐标,等于向上移动的次数,减去向下移动的次数。如果 $\texttt{U}$ 的个数等于 $\texttt{D}$ 的个数,那么最终纵坐标为 $0$。
这两个条件同时成立,才能回到原点。
###py
class Solution:
def judgeCircle(self, moves: str) -> bool:
return moves.count('R') == moves.count('L') and \
moves.count('U') == moves.count('D')
###py
class Solution:
def judgeCircle(self, moves: str) -> bool:
cnt = Counter(moves)
return cnt['R'] == cnt['L'] and cnt['U'] == cnt['D']
###java
class Solution {
public boolean judgeCircle(String moves) {
int x = 0;
int y = 0;
for (char move : moves.toCharArray()) {
if (move == 'R') {
x++;
} else if (move == 'L') {
x--;
} else if (move == 'U') {
y++;
} else {
y--;
}
}
return x == 0 && y == 0;
}
}
###cpp
class Solution {
public:
bool judgeCircle(string moves) {
return ranges::count(moves, 'R') == ranges::count(moves, 'L') &&
ranges::count(moves, 'U') == ranges::count(moves, 'D');
}
};
###c
bool judgeCircle(char* moves) {
int x = 0, y = 0;
for (int i = 0; moves[i]; i++) {
char move = moves[i];
if (move == 'R') {
x++;
} else if (move == 'L') {
x--;
} else if (move == 'U') {
y++;
} else {
y--;
}
}
return x == 0 && y == 0;
}
###go
func judgeCircle(moves string) bool {
return strings.Count(moves, "R") == strings.Count(moves, "L") &&
strings.Count(moves, "U") == strings.Count(moves, "D")
}
###js
var judgeCircle = function(moves) {
const cnt = _.countBy(moves);
return cnt['R'] === cnt['L'] && cnt['U'] === cnt['D'];
};
###rust
impl Solution {
pub fn judge_circle(moves: String) -> bool {
moves.matches('R').count() == moves.matches('L').count() &&
moves.matches('U').count() == moves.matches('D').count()
}
}
欢迎关注 B站@灵茶山艾府
脑筋急转弯:由于题目保证代价均为非负数,所以除了径直走以外,其它弯弯绕绕的策略都不可能更优,那么直接统计径直走的代价即可。
设起点为 $(x_0,y_0)$,终点为 $(x_1,y_1)$。
分别计算上下移动的代价,左右移动的代价,二者之和就是总代价。
代码实现时,不需要根据 $x_0$ 和 $x_1$ 的大小关系分情况讨论,而是计算 $\textit{rowCosts}$ 的子数组 $[\min(x_0,x_1), \max(x_0,x_1)]$ 的元素和,再减去多算的起点代价 $\textit{rowCosts}[x_0]$。对于 $y_0$ 和 $y_1$ 同理。
class Solution:
def minCost(self, startPos: List[int], homePos: List[int], rowCosts: List[int], colCosts: List[int]) -> int:
x0, y0 = startPos
x1, y1 = homePos
# 起点的代价不计入,先减去
ans = -rowCosts[x0] - colCosts[y0]
# 累加代价(包含起点)
ans += sum(rowCosts[min(x0, x1): max(x0, x1) + 1])
ans += sum(colCosts[min(y0, y1): max(y0, y1) + 1])
return ans
class Solution {
public int minCost(int[] startPos, int[] homePos, int[] rowCosts, int[] colCosts) {
int x0 = startPos[0], y0 = startPos[1];
int x1 = homePos[0], y1 = homePos[1];
// 起点的代价不计入,先减去
int ans = -rowCosts[x0] - colCosts[y0];
// 累加代价(包含起点)
int l1 = Math.min(x0, x1), r1 = Math.max(x0, x1);
for (int i = l1; i <= r1; i++) {
ans += rowCosts[i];
}
int l2 = Math.min(y0, y1), r2 = Math.max(y0, y1);
for (int i = l2; i <= r2; i++) {
ans += colCosts[i];
}
return ans;
}
}
class Solution {
public:
int minCost(vector<int>& startPos, vector<int>& homePos, vector<int>& rowCosts, vector<int>& colCosts) {
int x0 = startPos[0], y0 = startPos[1];
int x1 = homePos[0], y1 = homePos[1];
// 起点的代价不计入,先减去
int ans = -rowCosts[x0] - colCosts[y0];
// 累加代价(包含起点)
ans += reduce(rowCosts.begin() + min(x0, x1), rowCosts.begin() + max(x0, x1) + 1, 0);
ans += reduce(colCosts.begin() + min(y0, y1), colCosts.begin() + max(y0, y1) + 1, 0);
return ans;
}
};
#define MIN(a, b) ((b) < (a) ? (b) : (a))
#define MAX(a, b) ((b) > (a) ? (b) : (a))
int minCost(int* startPos, int startPosSize, int* homePos, int homePosSize, int* rowCosts, int rowCostsSize, int* colCosts, int colCostsSize) {
int x0 = startPos[0], y0 = startPos[1];
int x1 = homePos[0], y1 = homePos[1];
// 起点的代价不计入,先减去
int ans = -rowCosts[x0] - colCosts[y0];
// 累加代价(包含起点)
int l1 = MIN(x0, x1), r1 = MAX(x0, x1);
for (int i = l1; i <= r1; i++) {
ans += rowCosts[i];
}
int l2 = MIN(y0, y1), r2 = MAX(y0, y1);
for (int i = l2; i <= r2; i++) {
ans += colCosts[i];
}
return ans;
}
func minCost(startPos, homePos, rowCosts, colCosts []int) int {
x0, y0 := startPos[0], startPos[1]
x1, y1 := homePos[0], homePos[1]
// 起点的代价不计入,先减去
ans := -rowCosts[x0] - colCosts[y0]
// 累加代价(包含起点)
for _, cost := range rowCosts[min(x0, x1) : max(x0, x1)+1] {
ans += cost
}
for _, cost := range colCosts[min(y0, y1) : max(y0, y1)+1] {
ans += cost
}
return ans
}
var minCost = function(startPos, homePos, rowCosts, colCosts) {
const [x0, y0] = startPos;
const [x1, y1] = homePos;
// 起点的代价不计入,先减去
let ans = -rowCosts[x0] - colCosts[y0];
// 累加代价(包含起点)
ans += _.sum(rowCosts.slice(Math.min(x0, x1), Math.max(x0, x1) + 1));
ans += _.sum(colCosts.slice(Math.min(y0, y1), Math.max(y0, y1) + 1));
return ans;
};
impl Solution {
pub fn min_cost(start_pos: Vec<i32>, home_pos: Vec<i32>, row_costs: Vec<i32>, col_costs: Vec<i32>) -> i32 {
let x0 = start_pos[0] as usize;
let y0 = start_pos[1] as usize;
let x1 = home_pos[0] as usize;
let y1 = home_pos[1] as usize;
// 起点的代价不计入,先减去
let mut ans = -row_costs[x0] - col_costs[y0];
// 累加代价(包含起点)
ans += row_costs[x0.min(x1)..=x0.max(x1)].iter().sum::<i32>();
ans += col_costs[y0.min(y1)..=y0.max(y1)].iter().sum::<i32>();
ans
}
}
本题是图论中的最短路问题。在有负数边权的情况下,可以用 Bellman-Ford 算法解决。需要注意的是,如果有负环,则最小代价为 $-\infty$。
见下面贪心与思维题单的「§5.2 脑筋急转弯」。
欢迎关注 B站@灵茶山艾府
先把机器人和墙壁从小到大排序。
考虑最右边的机器人。分类讨论:
这些问题都是和原问题相似的、规模更小的子问题,可以用递归解决。
注:从右往左思考,主要是为了方便把递归翻译成递推。从左往右思考也是可以的。
根据上面的讨论,定义状态为 $\textit{dfs}(i,j)$,表示对于(排序后)下标在 $[0,i]$ 中的机器人,在机器人 $i+1$ 往左/右射击的前提下,能摧毁的最大墙壁数量。其中 $j=0$ 表示机器人 $i+1$ 往左射击,$j=1$ 表示机器人 $i+1$ 往右射击。
考虑机器人 $i$ 往哪个方向射击:
这两种情况取最大值,就得到了 $\textit{dfs}(i,j)$,即
$$
\textit{dfs}(i,j) = \max(\textit{dfs}(i-1,0) + \textit{cur}_0- \textit{left}, \textit{dfs}(i-1,1) + \textit{right} - \textit{cur}_1)
$$
递归边界:$\textit{dfs}(-1,j)=0$。没有机器人,无法摧毁墙壁。
递归入口:$\textit{dfs}(n-1,1)$。机器人 $n-1$ 右边没有机器人,等价于右边那个机器人往右射击。
考虑到整个递归过程中有大量重复递归调用(递归入参相同)。由于递归函数没有副作用,同样的入参无论计算多少次,算出来的结果都是一样的,因此可以用记忆化搜索来优化:
⚠注意:$\textit{memo}$ 数组的初始值一定不能等于要记忆化的值!例如初始值设置为 $0$,并且要记忆化的 $\textit{dfs}(i,j)$ 也等于 $0$,那就没法判断 $0$ 到底表示第一次遇到这个状态,还是表示之前遇到过了,从而导致记忆化失效。一般把初始值设置为 $-1$。
Python 用户可以无视上面这段,直接用
@cache装饰器。
具体请看视频讲解 动态规划入门:从记忆化搜索到递推【基础算法精讲 17】,其中包含把记忆化搜索 1:1 翻译成递推的技巧。
本题视频讲解,欢迎点赞关注~
###py
class Solution:
def maxWalls(self, robots: List[int], distance: List[int], walls: List[int]) -> int:
n = len(robots)
a = sorted(zip(robots, distance), key=lambda p: p[0])
walls.sort()
@cache # 缓存装饰器,避免重复计算 dfs(一行代码实现记忆化)
def dfs(i: int, j: int) -> int:
if i < 0:
return 0
x, d = a[i]
# 往左射,墙的坐标范围为 [left_x, x]
left_x = x - d
if i > 0:
left_x = max(left_x, a[i - 1][0] + 1) # +1 表示不能射到左边那个机器人
left = bisect_left(walls, left_x)
cur = bisect_right(walls, x)
res_left = dfs(i - 1, 0) + cur - left # 下标在 [left, cur-1] 中的墙都能摧毁
# 往右射,墙的坐标范围为 [x, right_x]
right_x = x + d
if i + 1 < n:
x2, d2 = a[i + 1]
if j == 0: # 右边那个机器人往左射
x2 -= d2
right_x = min(right_x, x2 - 1) # -1 表示不能射到右边那个机器人(或者它往左射到的墙)
right = bisect_right(walls, right_x)
cur = bisect_left(walls, x)
res_right = dfs(i - 1, 1) + right - cur # 下标在 [cur, right-1] 中的墙都能摧毁
return max(res_left, res_right)
return dfs(n - 1, 1)
###java
class Solution {
public int maxWalls(int[] robots, int[] distance, int[] walls) {
int n = robots.length;
int[][] a = new int[n][2];
for (int i = 0; i < n; i++) {
a[i][0] = robots[i];
a[i][1] = distance[i];
}
Arrays.sort(a, (p, q) -> p[0] - q[0]);
Arrays.sort(walls);
int[][] memo = new int[n][2];
for (int[] row : memo) {
Arrays.fill(row, -1); // -1 表示没有计算过
}
return dfs(n - 1, 1, a, walls, memo);
}
private int dfs(int i, int j, int[][] a, int[] walls, int[][] memo) {
if (i < 0) {
return 0;
}
if (memo[i][j] != -1) { // 之前计算过
return memo[i][j];
}
int x = a[i][0], d = a[i][1];
// 往左射,墙的坐标范围为 [leftX, x]
int leftX = x - d;
if (i > 0) {
leftX = Math.max(leftX, a[i - 1][0] + 1); // +1 表示不能射到左边那个机器人
}
int left = lowerBound(walls, leftX);
int cur = lowerBound(walls, x + 1);
int resLeft = dfs(i - 1, 0, a, walls, memo) + cur - left; // 下标在 [left, cur-1] 中的墙都能摧毁
// 往右射,墙的坐标范围为 [x, rightX]
int rightX = x + d;
if (i + 1 < a.length) {
int x2 = a[i + 1][0];
if (j == 0) { // 右边那个机器人往左射
x2 -= a[i + 1][1];
}
rightX = Math.min(rightX, x2 - 1); // -1 表示不能射到右边那个机器人(或者它往左射到的墙)
}
int right = lowerBound(walls, rightX + 1);
cur = lowerBound(walls, x);
int resRight = dfs(i - 1, 1, a, walls, memo) + right - cur; // 下标在 [cur, right-1] 中的墙都能摧毁
return memo[i][j] = Math.max(resLeft, resRight); // 记忆化
}
// 见 https://www.bilibili.com/video/BV1AP41137w7/
private int lowerBound(int[] nums, int target) {
int left = -1;
int right = nums.length;
while (left + 1 < right) {
int mid = left + (right - left) / 2;
if (nums[mid] >= target) {
right = mid;
} else {
left = mid;
}
}
return right;
}
}
###cpp
class Solution {
public:
int maxWalls(vector<int>& robots, vector<int>& distance, vector<int>& walls) {
int n = robots.size();
struct Pair { int x, d; };
vector<Pair> a(n);
for (int i = 0; i < n; i++) {
a[i] = {robots[i], distance[i]};
}
ranges::sort(a, {}, &Pair::x);
ranges::sort(walls);
vector memo(n, array<int, 2>{-1, -1}); // -1 表示没有计算过
auto dfs = [&](this auto&& dfs, int i, int j) -> int {
if (i < 0) {
return 0;
}
int& res = memo[i][j]; // 注意这里是引用
if (res != -1) { // 之前计算过
return res;
}
auto [x, d] = a[i];
// 往左射,墙的坐标范围为 [left_x, x]
int left_x = x - d;
if (i > 0) {
left_x = max(left_x, a[i - 1].x + 1); // +1 表示不能射到左边那个机器人
}
int left = ranges::lower_bound(walls, left_x) - walls.begin();
int cur = ranges::upper_bound(walls, x) - walls.begin();
res = dfs(i - 1, 0) + cur - left; // 下标在 [left, cur-1] 中的墙都能摧毁
// 往右射,墙的坐标范围为 [x, right_x]
int right_x = x + d;
if (i + 1 < n) {
auto [x2, d2] = a[i + 1];
if (j == 0) { // 右边那个机器人往左射
x2 -= d2;
}
right_x = min(right_x, x2 - 1); // -1 表示不能射到右边那个机器人(或者它往左射到的墙)
}
int right = ranges::upper_bound(walls, right_x) - walls.begin();
cur = ranges::lower_bound(walls, x) - walls.begin();
res = max(res, dfs(i - 1, 1) + right - cur); // 下标在 [cur, right-1] 中的墙都能摧毁
return res;
};
return dfs(n - 1, 1);
}
};
###go
func maxWalls(robots []int, distance []int, walls []int) int {
n := len(robots)
type pair struct{ x, d int }
a := make([]pair, n)
for i, x := range robots {
a[i] = pair{x, distance[i]}
}
slices.SortFunc(a, func(a, b pair) int { return a.x - b.x })
slices.Sort(walls)
memo := make([][2]int, n)
for i := range memo {
memo[i] = [2]int{-1, -1}
}
var dfs func(int, int) int
dfs = func(i, j int) int {
if i < 0 {
return 0
}
p := &memo[i][j]
if *p != -1 {
return *p
}
// 往左射,墙的坐标范围为 [leftX, a[i].x]
leftX := a[i].x - a[i].d
if i > 0 {
leftX = max(leftX, a[i-1].x+1) // +1 表示不能射到左边那个机器人
}
left := sort.SearchInts(walls, leftX)
cur := sort.SearchInts(walls, a[i].x+1)
res := dfs(i-1, 0) + cur - left // 下标在 [left, cur-1] 中的墙都能摧毁
// 往右射,墙的坐标范围为 [a[i].x, rightX]
rightX := a[i].x + a[i].d
if i+1 < n {
x2 := a[i+1].x
if j == 0 { // 右边那个机器人往左射
x2 -= a[i+1].d
}
rightX = min(rightX, x2-1) // 不能到达右边那个机器人(或者它往左射到的墙)
}
right := sort.SearchInts(walls, rightX+1)
cur = sort.SearchInts(walls, a[i].x)
res = max(res, dfs(i-1, 1)+right-cur) // 下标在 [cur, right-1] 中的墙都能摧毁
*p = res
return res
}
return dfs(n-1, 1)
}
添加两个位置分别为 $0$ 和 $\infty$ 的机器人,当作哨兵,从而简化边界的判断。
###py
class Solution:
def maxWalls(self, robots: List[int], distance: List[int], walls: List[int]) -> int:
n = len(robots)
a = [(0, 0)] + sorted(zip(robots, distance), key=lambda p: p[0]) + [(inf, 0)]
walls.sort()
@cache # 缓存装饰器,避免重复计算 dfs(一行代码实现记忆化)
def dfs(i: int, j: int) -> int:
if i == 0:
return 0
x, d = a[i]
# 往左射,墙的坐标范围为 [left_x, x]
left_x = max(x - d, a[i - 1][0] + 1) # +1 表示不能射到左边那个机器人
left = bisect_left(walls, left_x)
cur = bisect_right(walls, x)
res_left = dfs(i - 1, 0) + cur - left # 下标在 [left, cur-1] 中的墙都能摧毁
# 往右射,墙的坐标范围为 [x, right_x]
x2, d2 = a[i + 1]
if j == 0: # 右边那个机器人往左射
x2 -= d2
right_x = min(x + d, x2 - 1) # -1 表示不能射到右边那个机器人(或者它往左射到的墙)
right = bisect_right(walls, right_x)
cur = bisect_left(walls, x)
res_right = dfs(i - 1, 1) + right - cur # 下标在 [cur, right-1] 中的墙都能摧毁
return max(res_left, res_right)
return dfs(n, 1)
###java
class Solution {
public int maxWalls(int[] robots, int[] distance, int[] walls) {
int n = robots.length;
int[][] a = new int[n + 2][2];
for (int i = 0; i < n; i++) {
a[i][0] = robots[i];
a[i][1] = distance[i];
}
a[n + 1][0] = Integer.MAX_VALUE;
Arrays.sort(a, (p, q) -> p[0] - q[0]);
Arrays.sort(walls);
int[][] memo = new int[n + 1][2];
for (int[] row : memo) {
Arrays.fill(row, -1); // -1 表示没有计算过
}
return dfs(n, 1, a, walls, memo);
}
private int dfs(int i, int j, int[][] a, int[] walls, int[][] memo) {
if (i == 0) {
return 0;
}
if (memo[i][j] != -1) { // 之前计算过
return memo[i][j];
}
int x = a[i][0], d = a[i][1];
// 往左射,墙的坐标范围为 [leftX, x]
int leftX = Math.max(x - d, a[i - 1][0] + 1); // +1 表示不能射到左边那个机器人
int left = lowerBound(walls, leftX);
int cur = lowerBound(walls, x + 1);
int resLeft = dfs(i - 1, 0, a, walls, memo) + cur - left; // 下标在 [left, cur-1] 中的墙都能摧毁
// 往右射,墙的坐标范围为 [x, rightX]
int x2 = a[i + 1][0];
if (j == 0) { // 右边那个机器人往左射
x2 -= a[i + 1][1];
}
int rightX = Math.min(x + d, x2 - 1); // -1 表示不能射到右边那个机器人(或者它往左射到的墙)
int right = lowerBound(walls, rightX + 1);
cur = lowerBound(walls, x);
int resRight = dfs(i - 1, 1, a, walls, memo) + right - cur; // 下标在 [cur, right-1] 中的墙都能摧毁
return memo[i][j] = Math.max(resLeft, resRight); // 记忆化
}
// 见 https://www.bilibili.com/video/BV1AP41137w7/
private int lowerBound(int[] nums, int target) {
int left = -1;
int right = nums.length;
while (left + 1 < right) {
int mid = left + (right - left) / 2;
if (nums[mid] >= target) {
right = mid;
} else {
left = mid;
}
}
return right;
}
}
###cpp
class Solution {
public:
int maxWalls(vector<int>& robots, vector<int>& distance, vector<int>& walls) {
int n = robots.size();
struct Pair { int x, d; };
vector<Pair> a(n + 2);
for (int i = 0; i < n; i++) {
a[i] = {robots[i], distance[i]};
}
a[n + 1].x = INT_MAX;
ranges::sort(a, {}, &Pair::x);
ranges::sort(walls);
vector memo(n + 1, array<int, 2>{-1, -1}); // -1 表示没有计算过
auto dfs = [&](this auto&& dfs, int i, int j) -> int {
if (i == 0) {
return 0;
}
int& res = memo[i][j]; // 注意这里是引用
if (res != -1) { // 之前计算过
return res;
}
auto [x, d] = a[i];
// 往左射,墙的坐标范围为 [left_x, x]
int left_x = max(x - d, a[i - 1].x + 1); // +1 表示不能射到左边那个机器人
int left = ranges::lower_bound(walls, left_x) - walls.begin();
int cur = ranges::upper_bound(walls, x) - walls.begin();
res = dfs(i - 1, 0) + cur - left; // 下标在 [left, cur-1] 中的墙都能摧毁
// 往右射,墙的坐标范围为 [x, right_x]
auto [x2, d2] = a[i + 1];
if (j == 0) { // 右边那个机器人往左射
x2 -= d2;
}
int right_x = min(x + d, x2 - 1); // -1 表示不能射到右边那个机器人(或者它往左射到的墙)
int right = ranges::upper_bound(walls, right_x) - walls.begin();
cur = ranges::lower_bound(walls, x) - walls.begin();
res = max(res, dfs(i - 1, 1) + right - cur); // 下标在 [cur, right-1] 中的墙都能摧毁
return res;
};
return dfs(n, 1);
}
};
###go
func maxWalls(robots []int, distance []int, walls []int) int {
n := len(robots)
type pair struct{ x, d int }
a := make([]pair, n+2)
for i, x := range robots {
a[i] = pair{x, distance[i]}
}
a[n+1].x = math.MaxInt // 哨兵
slices.SortFunc(a, func(a, b pair) int { return a.x - b.x })
slices.Sort(walls)
memo := make([][2]int, n+1)
for i := range memo {
memo[i] = [2]int{-1, -1}
}
var dfs func(int, int) int
dfs = func(i, j int) int {
if i == 0 {
return 0
}
p := &memo[i][j]
if *p != -1 {
return *p
}
// 往左射,墙的坐标范围为 [leftX, a[i].x]
leftX := max(a[i].x-a[i].d, a[i-1].x+1) // +1 表示不能射到左边那个机器人
left := sort.SearchInts(walls, leftX)
cur := sort.SearchInts(walls, a[i].x+1)
res := dfs(i-1, 0) + cur - left // 下标在 [left, cur-1] 中的墙都能摧毁
// 往右射,墙的坐标范围为 [a[i].x, rightX]
x2 := a[i+1].x
if j == 0 { // 右边那个机器人往左射
x2 -= a[i+1].d
}
rightX := min(a[i].x+a[i].d, x2-1) // -1 表示不能射到右边那个机器人(或者它往左射到的墙)
right := sort.SearchInts(walls, rightX+1)
cur = sort.SearchInts(walls, a[i].x)
res = max(res, dfs(i-1, 1)+right-cur) // 下标在 [cur, right-1] 中的墙都能摧毁
*p = res
return res
}
return dfs(n, 1)
}
我们可以去掉递归中的「递」,只保留「归」的部分,即自底向上计算。
具体来说,$f[i][j]$ 的定义和 $\textit{dfs}(i,j)$ 的定义是一样的,都表示对于(排序,添加哨兵后的)下标在 $[1,i]$ 中的机器人,在机器人 $i+1$ 往左/右射击的前提下,能摧毁的最大墙壁数量。
相应的递推式(状态转移方程)也和 $\textit{dfs}$ 一样:
$$
f[i][j] = \max(f[i-1][0] + \textit{cur}_0- \textit{left}, f[i-1][1] + \textit{right} - \textit{cur}_1)
$$
初始值 $f[0][j]=0$,翻译自(添加哨兵后的)递归边界 $\textit{dfs}(0,j)=0$。
答案为 $f[n][1]$,翻译自(添加哨兵后的)递归入口 $\textit{dfs}(n,1)$。
###py
class Solution:
def maxWalls(self, robots: List[int], distance: List[int], walls: List[int]) -> int:
n = len(robots)
a = [(0, 0)] + sorted(zip(robots, distance), key=lambda p: p[0]) + [(inf, 0)]
walls.sort()
f = [[0, 0] for _ in range(n + 1)]
for i in range(1, n + 1):
x, d = a[i]
# 往左射,墙的坐标范围为 [left_x, x]
left_x = max(x - d, a[i - 1][0] + 1) # +1 表示不能射到左边那个机器人
left = bisect_left(walls, left_x)
cur = bisect_right(walls, x)
left_res = f[i - 1][0] + cur - left # 下标在 [left, cur-1] 中的墙都能摧毁
cur = bisect_left(walls, x)
for j in range(2):
# 往右射,墙的坐标范围为 [x, right_x]
x2, d2 = a[i + 1]
if j == 0: # 右边那个机器人往左射
x2 -= d2
right_x = min(x + d, x2 - 1) # -1 表示不能射到右边那个机器人(或者它往左射到的墙)
right = bisect_right(walls, right_x)
f[i][j] = max(left_res, f[i - 1][1] + right - cur) # 下标在 [cur, right-1] 中的墙都能摧毁
return f[n][1]
###java
class Solution {
public int maxWalls(int[] robots, int[] distance, int[] walls) {
int n = robots.length;
int[][] a = new int[n + 2][2];
for (int i = 0; i < n; i++) {
a[i][0] = robots[i];
a[i][1] = distance[i];
}
a[n + 1][0] = Integer.MAX_VALUE;
Arrays.sort(a, (p, q) -> p[0] - q[0]);
Arrays.sort(walls);
int[][] f = new int[n + 1][2];
for (int i = 1; i <= n; i++) {
int x = a[i][0], d = a[i][1];
// 往左射,墙的坐标范围为 [leftX, x]
int leftX = Math.max(x - d, a[i - 1][0] + 1); // +1 表示不能射到左边那个机器人
int left = lowerBound(walls, leftX);
int cur = lowerBound(walls, x + 1);
int leftRes = f[i - 1][0] + cur - left; // 下标在 [left, cur-1] 中的墙都能摧毁
cur = lowerBound(walls, x);
for (int j = 0; j < 2; j++) {
// 往右射,墙的坐标范围为 [x, rightX]
int x2 = a[i + 1][0];
if (j == 0) { // 右边那个机器人往左射
x2 -= a[i + 1][1];
}
int rightX = Math.min(x + d, x2 - 1); // -1 表示不能射到右边那个机器人(或者它往左射到的墙)
int right = lowerBound(walls, rightX + 1);
f[i][j] = Math.max(leftRes, f[i - 1][1] + right - cur); // 下标在 [cur, right-1] 中的墙都能摧毁
}
}
return f[n][1];
}
// 见 https://www.bilibili.com/video/BV1AP41137w7/
private int lowerBound(int[] nums, int target) {
int left = -1;
int right = nums.length;
while (left + 1 < right) {
int mid = left + (right - left) / 2;
if (nums[mid] >= target) {
right = mid;
} else {
left = mid;
}
}
return right;
}
}
###cpp
class Solution {
public:
int maxWalls(vector<int>& robots, vector<int>& distance, vector<int>& walls) {
int n = robots.size();
struct Pair { int x, d; };
vector<Pair> a(n + 2);
for (int i = 0; i < n; i++) {
a[i] = {robots[i], distance[i]};
}
a[n + 1].x = INT_MAX;
ranges::sort(a, {}, &Pair::x);
ranges::sort(walls);
vector<array<int, 2>> f(n + 1);
for (int i = 1; i <= n; i++) {
auto [x, d] = a[i];
// 往左射,墙的坐标范围为 [left_x, x]
int left_x = max(x - d, a[i - 1].x + 1); // +1 表示不能射到左边那个机器人
int left = ranges::lower_bound(walls, left_x) - walls.begin();
int cur = ranges::upper_bound(walls, x) - walls.begin();
int left_res = f[i - 1][0] + cur - left; // 下标在 [left, cur-1] 中的墙都能摧毁
cur = ranges::lower_bound(walls, x) - walls.begin();
for (int j = 0; j < 2; j++) {
// 往右射,墙的坐标范围为 [x, right_x]
auto [x2, d2] = a[i + 1];
if (j == 0) { // 右边那个机器人往左射
x2 -= d2;
}
int right_x = min(x + d, x2 - 1); // -1 表示不能射到右边那个机器人(或者它往左射到的墙)
int right = ranges::upper_bound(walls, right_x) - walls.begin();
f[i][j] = max(left_res, f[i - 1][1] + right - cur); // 下标在 [cur, right-1] 中的墙都能摧毁
}
}
return f[n][1];
}
};
###go
func maxWalls(robots []int, distance []int, walls []int) int {
n := len(robots)
type pair struct{ x, d int }
a := make([]pair, n+2)
for i, x := range robots {
a[i] = pair{x, distance[i]}
}
a[n+1].x = math.MaxInt // 哨兵
slices.SortFunc(a, func(a, b pair) int { return a.x - b.x })
slices.Sort(walls)
f := make([][2]int, n+1)
for i := 1; i <= n; i++ {
p := a[i]
// 往左射,墙的坐标范围为 [leftX, p.x]
leftX := max(p.x-p.d, a[i-1].x+1) // +1 表示不能射到左边那个机器人
left := sort.SearchInts(walls, leftX)
cur := sort.SearchInts(walls, p.x+1)
leftRes := f[i-1][0] + cur - left // 下标在 [left, cur-1] 中的墙都能摧毁
cur = sort.SearchInts(walls, p.x)
for j := range 2 {
// 往右射,墙的坐标范围为 [p.x, rightX]
x2 := a[i+1].x
if j == 0 { // 右边那个机器人往左射
x2 -= a[i+1].d
}
rightX := min(p.x+p.d, x2-1) // -1 表示不能射到右边那个机器人(或者它往左射到的墙)
right := sort.SearchInts(walls, rightX+1)
f[i][j] = max(leftRes, f[i-1][1]+right-cur) // 下标在 [cur, right-1] 中的墙都能摧毁
}
}
return f[n][1]
}
观察上面的状态转移方程,在计算 $f[i+1]$ 时,只会用到 $f[i]$,不会用到比 $i$ 更早的状态。
类似 背包问题,去掉 $f$ 的第一个维度,把 $f[i+1]$ 和 $f[i]$ 保存到同一个数组中。
###py
# 手写 min max 更快
min = lambda a, b: b if b < a else a
max = lambda a, b: b if b > a else a
class Solution:
def maxWalls(self, robots: List[int], distance: List[int], walls: List[int]) -> int:
n = len(robots)
a = [(0, 0)] + sorted(zip(robots, distance), key=lambda p: p[0]) + [(inf, 0)]
walls.sort()
f = [0, 0]
for i in range(1, n + 1):
x, d = a[i]
# 往左射,墙的坐标范围为 [left_x, x]
left_x = max(x - d, a[i - 1][0] + 1) # +1 表示不能射到左边那个机器人
left = bisect_left(walls, left_x)
cur = bisect_right(walls, x)
left_res = f[0] + cur - left # 下标在 [left, cur-1] 中的墙都能摧毁
cur = bisect_left(walls, x)
for j in range(2):
# 往右射,墙的坐标范围为 [x, right_x]
x2, d2 = a[i + 1]
if j == 0: # 右边那个机器人往左射
x2 -= d2
right_x = min(x + d, x2 - 1) # -1 表示不能射到右边那个机器人(或者它往左射到的墙)
right = bisect_right(walls, right_x)
f[j] = max(left_res, f[1] + right - cur) # 下标在 [cur, right-1] 中的墙都能摧毁
return f[1]
###java
class Solution {
public int maxWalls(int[] robots, int[] distance, int[] walls) {
int n = robots.length;
int[][] a = new int[n + 2][2];
for (int i = 0; i < n; i++) {
a[i][0] = robots[i];
a[i][1] = distance[i];
}
a[n + 1][0] = Integer.MAX_VALUE;
Arrays.sort(a, (p, q) -> p[0] - q[0]);
Arrays.sort(walls);
int[] f = new int[2];
for (int i = 1; i <= n; i++) {
int x = a[i][0], d = a[i][1];
// 往左射,墙的坐标范围为 [leftX, x]
int leftX = Math.max(x - d, a[i - 1][0] + 1); // +1 表示不能射到左边那个机器人
int left = lowerBound(walls, leftX);
int cur = lowerBound(walls, x + 1);
int leftRes = f[0] + cur - left; // 下标在 [left, cur-1] 中的墙都能摧毁
cur = lowerBound(walls, x);
for (int j = 0; j < 2; j++) {
// 往右射,墙的坐标范围为 [x, rightX]
int x2 = a[i + 1][0];
if (j == 0) { // 右边那个机器人往左射
x2 -= a[i + 1][1];
}
int rightX = Math.min(x + d, x2 - 1); // -1 表示不能射到右边那个机器人(或者它往左射到的墙)
int right = lowerBound(walls, rightX + 1);
f[j] = Math.max(leftRes, f[1] + right - cur); // 下标在 [cur, right-1] 中的墙都能摧毁
}
}
return f[1];
}
// 见 https://www.bilibili.com/video/BV1AP41137w7/
private int lowerBound(int[] nums, int target) {
int left = -1;
int right = nums.length;
while (left + 1 < right) {
int mid = left + (right - left) / 2;
if (nums[mid] >= target) {
right = mid;
} else {
left = mid;
}
}
return right;
}
}
###cpp
class Solution {
public:
int maxWalls(vector<int>& robots, vector<int>& distance, vector<int>& walls) {
int n = robots.size();
struct Pair { int x, d; };
vector<Pair> a(n + 2);
for (int i = 0; i < n; i++) {
a[i] = {robots[i], distance[i]};
}
a[n + 1].x = INT_MAX;
ranges::sort(a, {}, &Pair::x);
ranges::sort(walls);
int f[2]{};
for (int i = 1; i <= n; i++) {
auto [x, d] = a[i];
// 往左射,墙的坐标范围为 [left_x, x]
int left_x = max(x - d, a[i - 1].x + 1); // +1 表示不能射到左边那个机器人
int left = ranges::lower_bound(walls, left_x) - walls.begin();
int cur = ranges::upper_bound(walls, x) - walls.begin();
int left_res = f[0] + cur - left; // 下标在 [left, cur-1] 中的墙都能摧毁
cur = ranges::lower_bound(walls, x) - walls.begin();
for (int j = 0; j < 2; j++) {
// 往右射,墙的坐标范围为 [x, right_x]
auto [x2, d2] = a[i + 1];
if (j == 0) { // 右边那个机器人往左射
x2 -= d2;
}
int right_x = min(x + d, x2 - 1); // -1 表示不能射到右边那个机器人(或者它往左射到的墙)
int right = ranges::upper_bound(walls, right_x) - walls.begin();
f[j] = max(left_res, f[1] + right - cur); // 下标在 [cur, right-1] 中的墙都能摧毁
}
}
return f[1];
}
};
###go
func maxWalls(robots []int, distance []int, walls []int) int {
n := len(robots)
type pair struct{ x, d int }
a := make([]pair, n+2)
for i, x := range robots {
a[i] = pair{x, distance[i]}
}
a[n+1].x = math.MaxInt // 哨兵
slices.SortFunc(a, func(a, b pair) int { return a.x - b.x })
slices.Sort(walls)
f := [2]int{}
for i := 1; i <= n; i++ {
p := a[i]
// 往左射,墙的坐标范围为 [leftX, p.x]
leftX := max(p.x-p.d, a[i-1].x+1) // +1 表示不能射到左边那个机器人
left := sort.SearchInts(walls, leftX)
cur := sort.SearchInts(walls, p.x+1)
leftRes := f[0] + cur - left // 下标在 [left, cur-1] 中的墙都能摧毁
cur = sort.SearchInts(walls, p.x)
for j := range 2 {
// 往右射,墙的坐标范围为 [p.x, rightX]
x2 := a[i+1].x
if j == 0 { // 右边那个机器人往左射
x2 -= a[i+1].d
}
rightX := min(p.x+p.d, x2-1) // -1 表示不能射到右边那个机器人(或者它往左射到的墙)
right := sort.SearchInts(walls, rightX+1)
f[j] = max(leftRes, f[1]+right-cur) // 下标在 [cur, right-1] 中的墙都能摧毁
}
}
return f[1]
}
由于随着 $i$ 变大,二分查找中的 $\textit{left},\textit{cur},\textit{right}$ 也随之变大,我们可以用双指针(多指针)优化。这样算法瓶颈就在排序上了。
###py
# 手写 min max 更快
min = lambda a, b: b if b < a else a
max = lambda a, b: b if b > a else a
class Solution:
def maxWalls(self, robots: List[int], distance: List[int], walls: List[int]) -> int:
n, m = len(robots), len(walls)
a = [(0, 0)] + sorted(zip(robots, distance), key=lambda p: p[0]) + [(inf, 0)]
walls.sort()
f0 = f1 = left = cur = right0 = right1 = 0
for i in range(1, n + 1):
x, d = a[i]
# 往左射,墙的坐标范围为 [left_x, x]
left_x = max(x - d, a[i - 1][0] + 1) # +1 表示不能射到左边那个机器人
while left < m and walls[left] < left_x:
left += 1
while cur < m and walls[cur] < x:
cur += 1
cur1 = cur
if cur < m and walls[cur] == x:
cur += 1
left_res = f0 + cur - left # 下标在 [left, cur-1] 中的墙都能摧毁
# 往右射,右边那个机器人往左射,墙的坐标范围为 [x, right_x]
x2, d2 = a[i + 1]
right_x = min(x + d, x2 - d2 - 1) # -1 表示不能射到右边那个机器人
while right0 < m and walls[right0] <= right_x:
right0 += 1
f0 = max(left_res, f1 + right0 - cur1) # 下标在 [cur1, right0-1] 中的墙都能摧毁
# 往右射,右边那个机器人往右射,墙的坐标范围为 [x, right_x]
right_x = min(x + d, x2 - 1) # -1 表示不能射到右边那个机器人
while right1 < m and walls[right1] <= right_x:
right1 += 1
f1 = max(left_res, f1 + right1 - cur1) # 下标在 [cur1, right1-1] 中的墙都能摧毁
return f1
###java
class Solution {
public int maxWalls(int[] robots, int[] distance, int[] walls) {
int n = robots.length, m = walls.length;
int[][] a = new int[n + 2][2];
for (int i = 0; i < n; i++) {
a[i][0] = robots[i];
a[i][1] = distance[i];
}
a[n + 1][0] = Integer.MAX_VALUE;
Arrays.sort(a, (p, q) -> p[0] - q[0]);
Arrays.sort(walls);
int f0 = 0, f1 = 0, left = 0, cur = 0, right0 = 0, right1 = 0;
for (int i = 1; i <= n; i++) {
int x = a[i][0], d = a[i][1];
// 往左射,墙的坐标范围为 [leftX, x]
int leftX = Math.max(x - d, a[i - 1][0] + 1); // +1 表示不能射到左边那个机器人
while (left < m && walls[left] < leftX) {
left++;
}
while (cur < m && walls[cur] < x) {
cur++;
}
int cur1 = cur;
if (cur < m && walls[cur] == x) {
cur++;
}
int leftRes = f0 + cur - left; // 下标在 [left, cur-1] 中的墙都能摧毁
// 往右射,右边那个机器人往左射,墙的坐标范围为 [x, rightX]
int x2 = a[i + 1][0], d2 = a[i + 1][1];
int rightX = Math.min(x + d, x2 - d2 - 1); // -1 表示不能射到右边那个机器人
while (right0 < m && walls[right0] <= rightX) {
right0++;
}
f0 = Math.max(leftRes, f1 + right0 - cur1); // 下标在 [cur1, right0-1] 中的墙都能摧毁
// 往右射,右边那个机器人往右射,墙的坐标范围为 [x, rightX]
rightX = Math.min(x + d, x2 - 1); // -1 表示不能射到右边那个机器人
while (right1 < m && walls[right1] <= rightX) {
right1++;
}
f1 = Math.max(leftRes, f1 + right1 - cur1); // 下标在 [cur1, right1-1] 中的墙都能摧毁
}
return f1;
}
// 见 https://www.bilibili.com/video/BV1AP41137w7/
private int lowerBound(int[] nums, int target) {
int left = -1;
int right = nums.length;
while (left + 1 < right) {
int mid = left + (right - left) / 2;
if (nums[mid] >= target) {
right = mid;
} else {
left = mid;
}
}
return right;
}
}
###cpp
class Solution {
public:
int maxWalls(vector<int>& robots, vector<int>& distance, vector<int>& walls) {
int n = robots.size(), m = walls.size();
struct Pair { int x, d; };
vector<Pair> a(n + 2);
for (int i = 0; i < n; i++) {
a[i] = {robots[i], distance[i]};
}
a[n + 1].x = INT_MAX;
ranges::sort(a, {}, &Pair::x);
ranges::sort(walls);
int f0 = 0, f1 = 0, left = 0, cur = 0, right0 = 0, right1 = 0;
for (int i = 1; i <= n; i++) {
auto [x, d] = a[i];
// 往左射,墙的坐标范围为 [left_x, x]
int left_x = max(x - d, a[i - 1].x + 1); // +1 表示不能射到左边那个机器人
while (left < m && walls[left] < left_x) {
left++;
}
while (cur < m && walls[cur] < x) {
cur++;
}
int cur1 = cur;
if (cur < m && walls[cur] == x) {
cur++;
}
int left_res = f0 + cur - left; // 下标在 [left, cur-1] 中的墙都能摧毁
// 往右射,右边那个机器人往左射,墙的坐标范围为 [x, right_x]
auto [x2, d2] = a[i + 1];
int right_x = min(x + d, x2 - d2 - 1); // -1 表示不能射到右边那个机器人
while (right0 < m && walls[right0] <= right_x) {
right0++;
}
f0 = max(left_res, f1 + right0 - cur1); // 下标在 [cur1, right0-1] 中的墙都能摧毁
// 往右射,右边那个机器人往右射,墙的坐标范围为 [x, right_x]
right_x = min(x + d, x2 - 1); // -1 表示不能射到右边那个机器人
while (right1 < m && walls[right1] <= right_x) {
right1++;
}
f1 = max(left_res, f1 + right1 - cur1); // 下标在 [cur1, right1-1] 中的墙都能摧毁
}
return f1;
}
};
###go
func maxWalls(robots []int, distance []int, walls []int) int {
n, m := len(robots), len(walls)
type pair struct{ x, d int }
a := make([]pair, n+2)
for i, x := range robots {
a[i] = pair{x, distance[i]}
}
a[n+1].x = math.MaxInt // 哨兵
slices.SortFunc(a, func(a, b pair) int { return a.x - b.x })
slices.Sort(walls)
var f0, f1, left, cur, right0, right1 int
for i := 1; i <= n; i++ {
p := a[i]
// 往左射,墙的坐标范围为 [leftX, p.x]
leftX := max(p.x-p.d, a[i-1].x+1) // +1 表示不能射到左边那个机器人
for left < m && walls[left] < leftX {
left++
}
for cur < m && walls[cur] < p.x {
cur++
}
cur1 := cur
if cur < m && walls[cur] == p.x {
cur++
}
leftRes := f0 + cur - left // 下标在 [left, cur-1] 中的墙都能摧毁
// 往右射,右边那个机器人往左射,墙的坐标范围为 [p.x, rightX]
q := a[i+1]
rightX := min(p.x+p.d, q.x-q.d-1) // -1 表示不能射到右边那个机器人(或者它往左射到的墙)
for right0 < m && walls[right0] <= rightX {
right0++
}
f0 = max(leftRes, f1+right0-cur1) // 下标在 [cur1, right0-1] 中的墙都能摧毁
// 往右射,右边那个机器人往右射,墙的坐标范围为 [p.x, rightX]
rightX = min(p.x+p.d, q.x-1) // -1 表示不能射到右边那个机器人(或者它往左射到的墙)
for right1 < m && walls[right1] <= rightX {
right1++
}
f1 = max(leftRes, f1+right1-cur1) // 下标在 [cur1, right0-1] 中的墙都能摧毁
}
return f1
}
见下面动态规划题单的「六、状态机 DP」。
欢迎关注 B站@灵茶山艾府
本题相当于可以不选路径上的至多 $2$ 个数。
多一个约束,就多一个参数。
额外增加一个参数 $k$,定义 $\textit{dfs}(i,j,k)$ 表示从 $(0,0)$ 走到 $(i,j)$,在可用感化次数为 $k$ 的情况下,可以获得的最大金币数。
用「选或不选」分类讨论:
两种情况取最大值。
递归边界:
递归入口:$\textit{dfs}(m-1,n-1,2)$,这是原问题,也是答案。
⚠注意:由于答案可能是负数,所以记忆化数组的初始值不能是 $-1$。可以初始化成 $-\infty$。
具体请看 视频讲解,欢迎点赞关注~
###py
class Solution:
def maximumAmount(self, coins: List[List[int]]) -> int:
@cache # 缓存装饰器,避免重复计算 dfs 的结果(记忆化)
def dfs(i: int, j: int, k: int) -> int:
if i < 0 or j < 0:
return -inf
x = coins[i][j]
if i == 0 and j == 0:
return max(x, 0) if k else x
res = max(dfs(i - 1, j, k), dfs(i, j - 1, k)) + x # 选
if k and x < 0:
res = max(res, dfs(i - 1, j, k - 1), dfs(i, j - 1, k - 1)) # 不选
return res
ans = dfs(len(coins) - 1, len(coins[0]) - 1, 2)
dfs.cache_clear() # 避免超出内存限制
return ans
###java
class Solution {
public int maximumAmount(int[][] coins) {
int m = coins.length;
int n = coins[0].length;
int[][][] memo = new int[m][n][3];
for (int[][] mat : memo) {
for (int[] row : mat) {
Arrays.fill(row, Integer.MIN_VALUE);
}
}
return dfs(m - 1, n - 1, 2, coins, memo);
}
private int dfs(int i, int j, int k, int[][] coins, int[][][] memo) {
if (i < 0 || j < 0) {
return Integer.MIN_VALUE;
}
int x = coins[i][j];
if (i == 0 && j == 0) {
return k > 0 ? Math.max(x, 0) : x;
}
if (memo[i][j][k] != Integer.MIN_VALUE) { // 之前计算过
return memo[i][j][k];
}
int res = Math.max(dfs(i - 1, j, k, coins, memo), dfs(i, j - 1, k, coins, memo)) + x; // 选
if (k > 0 && x < 0) {
res = Math.max(res, Math.max(dfs(i - 1, j, k - 1, coins, memo), dfs(i, j - 1, k - 1, coins, memo))); // 不选
}
return memo[i][j][k] = res; // 记忆化
}
}
###cpp
class Solution {
public:
int maximumAmount(vector<vector<int>>& coins) {
int m = coins.size(), n = coins[0].size();
vector memo(m, vector(n, array<int, 3>{INT_MIN, INT_MIN, INT_MIN}));
auto dfs = [&](this auto&& dfs, int i, int j, int k) -> int {
if (i < 0 || j < 0) {
return INT_MIN;
}
int x = coins[i][j];
if (i == 0 && j == 0) {
return memo[i][j][k] = k ? max(x, 0) : x;
}
int& res = memo[i][j][k]; // 注意这里是引用
if (res != INT_MIN) { // 之前计算过
return res;
}
res = max(dfs(i - 1, j, k), dfs(i, j - 1, k)) + x; // 选
if (k && x < 0) {
res = max({res, dfs(i - 1, j, k - 1), dfs(i, j - 1, k - 1)}); // 不选
}
return res;
};
return dfs(m - 1, n - 1, 2);
}
};
###go
func maximumAmount(coins [][]int) int {
m, n := len(coins), len(coins[0])
memo := make([][][3]int, m)
for i := range memo {
memo[i] = make([][3]int, n)
for j := range memo[i] {
for k := range memo[i][j] {
memo[i][j][k] = math.MinInt
}
}
}
var dfs func(int, int, int) int
dfs = func(i, j, k int) int {
if i < 0 || j < 0 {
return math.MinInt
}
x := coins[i][j]
if i == 0 && j == 0 {
if k == 0 {
return x
}
return max(x, 0)
}
p := &memo[i][j][k]
if *p != math.MinInt { // 之前计算过
return *p
}
res := max(dfs(i-1, j, k), dfs(i, j-1, k)) + x // 选
if x < 0 && k > 0 {
res = max(res, dfs(i-1, j, k-1), dfs(i, j-1, k-1)) // 不选
}
*p = res // 记忆化
return res
}
return dfs(m-1, n-1, 2)
}
1:1 地把记忆化搜索翻译成递推,见 讲解。
代码实现时,可以把 $f[0][1][k]$ 初始化成 $0$,这样我们无需单独计算 $f[1][1]$。
###py
class Solution:
def maximumAmount(self, coins: List[List[int]]) -> int:
m, n = len(coins), len(coins[0])
f = [[[-inf] * 3 for _ in range(n + 1)] for _ in range(m + 1)]
f[0][1] = [0] * 3
for i, row in enumerate(coins):
for j, x in enumerate(row):
f[i + 1][j + 1][0] = max(f[i + 1][j][0], f[i][j + 1][0]) + x
f[i + 1][j + 1][1] = max(f[i + 1][j][1] + x, f[i][j + 1][1] + x,
f[i + 1][j][0], f[i][j + 1][0])
f[i + 1][j + 1][2] = max(f[i + 1][j][2] + x, f[i][j + 1][2] + x,
f[i + 1][j][1], f[i][j + 1][1])
return f[m][n][2]
###java
class Solution {
public int maximumAmount(int[][] coins) {
int m = coins.length;
int n = coins[0].length;
int[][][] f = new int[m + 1][n + 1][3];
for (int[] row : f[0]) {
Arrays.fill(row, Integer.MIN_VALUE);
}
Arrays.fill(f[0][1], 0);
for (int i = 0; i < m; i++) {
Arrays.fill(f[i + 1][0], Integer.MIN_VALUE);
for (int j = 0; j < n; j++) {
int x = coins[i][j];
f[i + 1][j + 1][0] = Math.max(f[i + 1][j][0], f[i][j + 1][0]) + x;
f[i + 1][j + 1][1] = Math.max(
Math.max(f[i + 1][j][1], f[i][j + 1][1]) + x,
Math.max(f[i + 1][j][0], f[i][j + 1][0])
);
f[i + 1][j + 1][2] = Math.max(
Math.max(f[i + 1][j][2], f[i][j + 1][2]) + x,
Math.max(f[i + 1][j][1], f[i][j + 1][1])
);
}
}
return f[m][n][2];
}
}
###cpp
class Solution {
public:
int maximumAmount(vector<vector<int>>& coins) {
int m = coins.size(), n = coins[0].size();
vector f(m + 1, vector(n + 1, array<int, 3>{INT_MIN / 2, INT_MIN / 2, INT_MIN / 2}));
f[0][1] = {0, 0, 0};
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
int x = coins[i][j];
f[i + 1][j + 1][0] = max(f[i + 1][j][0], f[i][j + 1][0]) + x;
f[i + 1][j + 1][1] = max({f[i + 1][j][1] + x, f[i][j + 1][1] + x,
f[i + 1][j][0], f[i][j + 1][0]});
f[i + 1][j + 1][2] = max({f[i + 1][j][2] + x, f[i][j + 1][2] + x,
f[i + 1][j][1], f[i][j + 1][1]});
}
}
return f[m][n][2];
}
};
###go
func maximumAmount(coins [][]int) int {
m, n := len(coins), len(coins[0])
f := make([][][3]int, m+1)
for i := range f {
f[i] = make([][3]int, n+1)
}
for j := range f[0] {
f[0][j] = [3]int{math.MinInt / 2, math.MinInt / 2, math.MinInt / 2}
}
f[0][1] = [3]int{}
for i, row := range coins {
f[i+1][0] = [3]int{math.MinInt / 2, math.MinInt / 2, math.MinInt / 2}
for j, x := range row {
f[i+1][j+1][0] = max(f[i+1][j][0], f[i][j+1][0]) + x
f[i+1][j+1][1] = max(f[i+1][j][1]+x, f[i][j+1][1]+x, f[i+1][j][0], f[i][j+1][0])
f[i+1][j+1][2] = max(f[i+1][j][2]+x, f[i][j+1][2]+x, f[i+1][j][1], f[i][j+1][1])
}
}
return f[m][n][2]
}
举个例子,在计算 $f[1][1]$ 时,会用到 $f[0][1]$,但是之后就不再用到了。那么干脆把 $f[1][1]$ 记到 $f[0][1]$ 中,这样对于 $f[1][2]$ 来说,它需要的数据就在 $f[0][1]$ 和 $f[0][2]$ 中。$f[1][2]$ 算完后也可以同样记到 $f[0][2]$ 中。
所以第一个维度可以去掉。
具体可以看【基础算法精讲 18】中的讲解。本题的转移方程类似完全背包,故整体采用正序遍历(但内部的 $k$ 要倒序)。
###py
class Solution:
def maximumAmount(self, coins: List[List[int]]) -> int:
n = len(coins[0])
f = [[-inf] * 3 for _ in range(n + 1)]
f[1] = [0] * 3
for row in coins:
for j, x in enumerate(row):
f[j + 1][2] = max(f[j][2] + x, f[j + 1][2] + x, f[j][1], f[j + 1][1])
f[j + 1][1] = max(f[j][1] + x, f[j + 1][1] + x, f[j][0], f[j + 1][0])
f[j + 1][0] = max(f[j][0], f[j + 1][0]) + x
return f[n][2]
###py
class Solution:
def maximumAmount(self, coins: List[List[int]]) -> int:
max = lambda a, b: a if a > b else b
n = len(coins[0])
f = [[-inf] * 3 for _ in range(n + 1)]
f[1] = [0] * 3
for row in coins:
for j, x in enumerate(row):
f[j + 1][2] = max(max(f[j][2], f[j + 1][2]) + x, max(f[j][1], f[j + 1][1]))
f[j + 1][1] = max(max(f[j][1], f[j + 1][1]) + x, max(f[j][0], f[j + 1][0]))
f[j + 1][0] = max(f[j][0], f[j + 1][0]) + x
return f[n][2]
###java
class Solution {
public int maximumAmount(int[][] coins) {
int n = coins[0].length;
int[][] f = new int[n + 1][3];
for (int[] row : f) {
Arrays.fill(row, Integer.MIN_VALUE);
}
Arrays.fill(f[1], 0);
for (int[] row : coins) {
for (int j = 0; j < n; j++) {
int x = row[j];
f[j + 1][2] = Math.max(
Math.max(f[j][2], f[j + 1][2]) + x,
Math.max(f[j][1], f[j + 1][1])
);
f[j + 1][1] = Math.max(
Math.max(f[j][1], f[j + 1][1]) + x,
Math.max(f[j][0], f[j + 1][0])
);
f[j + 1][0] = Math.max(f[j][0], f[j + 1][0]) + x;
}
}
return f[n][2];
}
}
###cpp
class Solution {
public:
int maximumAmount(vector<vector<int>>& coins) {
int n = coins[0].size();
vector f(n + 1, array<int, 3>{INT_MIN / 2, INT_MIN / 2, INT_MIN / 2});
f[1] = {0, 0, 0};
for (auto& row : coins) {
for (int j = 0; j < n; j++) {
int x = row[j];
f[j + 1][2] = max({f[j][2] + x, f[j + 1][2] + x, f[j][1], f[j + 1][1]});
f[j + 1][1] = max({f[j][1] + x, f[j + 1][1] + x, f[j][0], f[j + 1][0]});
f[j + 1][0] = max(f[j][0], f[j + 1][0]) + x;
}
}
return f[n][2];
}
};
###go
func maximumAmount(coins [][]int) int {
n := len(coins[0])
f := make([][3]int, n+1)
for j := range f {
f[j] = [3]int{math.MinInt / 2, math.MinInt / 2, math.MinInt / 2}
}
f[1] = [3]int{}
for _, row := range coins {
for j, x := range row {
f[j+1][2] = max(f[j][2]+x, f[j+1][2]+x, f[j][1], f[j+1][1])
f[j+1][1] = max(f[j][1]+x, f[j+1][1]+x, f[j][0], f[j+1][0])
f[j+1][0] = max(f[j][0], f[j+1][0]) + x
}
}
return f[n][2]
}
更多相似题目,见下面动态规划题单中的「二、网格图 DP」。
从左到右遍历这些机器人(需要先按照位置排序),向右的机器人会和向左的机器人碰撞。
遍历到一个向左的机器人时,我们需要找到左边最近的未移除的机器人。这可以用一个栈维护。
如果当前机器人向右,那么直接入栈,继续向后遍历。
如果当前机器人向左,设其健康度为 $h$,栈顶机器人的健康度为 $\textit{top}$,分类讨论:
⚠注意:比大小的这两个健康度都是正整数,所以减一的那个健康度一定大于 $1$。所以减一后,健康度大于 $0$。
代码实现时,直接在 $\textit{healths}$ 上修改,移除机器人 $i$ 相当于把 $\textit{healths}[i]$ 置为 $0$。最后返回 $\textit{healths}$ 中的正数。
class Solution:
def survivedRobotsHealths(self, positions: List[int], healths: List[int], directions: str) -> List[int]:
# 创建一个下标数组,对下标数组排序,这样不会打乱输入顺序
idx = sorted(range(len(positions)), key=lambda i: positions[i])
st = []
for i in idx:
if directions[i] == 'R': # 机器人 i 向右
st.append(i)
continue
while st: # 栈顶机器人向右
j = st[-1]
if healths[j] > healths[i]: # 栈顶机器人的健康度大
healths[i] = 0 # 移除机器人 i
healths[j] -= 1
break
if healths[j] == healths[i]: # 健康度一样大,都移除
healths[i] = 0
healths[j] = 0
st.pop()
break
# 机器人 i 的健康度大
healths[i] -= 1
healths[j] = 0 # 移除机器人 j
st.pop()
# 返回幸存机器人的健康度
return [h for h in healths if h > 0]
class Solution {
public List<Integer> survivedRobotsHealths(int[] positions, int[] healths, String directions) {
int n = positions.length;
// 创建一个下标数组,对下标数组排序,这样不会打乱输入顺序
Integer[] idx = new Integer[n];
for (int i = 0; i < n; i++) {
idx[i] = i;
}
Arrays.sort(idx, (i, j) -> positions[i] - positions[j]);
int[] st = new int[n];
int top = -1;
for (int i : idx) {
if (directions.charAt(i) == 'R') { // 机器人 i 向右
st[++top] = i;
continue;
}
while (top >= 0) { // 栈顶机器人向右
int j = st[top];
if (healths[j] > healths[i]) { // 栈顶机器人的健康度大
healths[i] = 0; // 移除机器人 i
healths[j]--;
break;
}
if (healths[j] == healths[i]) { // 健康度一样大,都移除
healths[i] = 0;
healths[j] = 0;
top--;
break;
}
// 机器人 i 的健康度大
healths[i]--;
healths[j] = 0; // 移除机器人 j
top--;
}
}
// 返回幸存机器人的健康度
List<Integer> ans = new ArrayList<>();
for (int h : healths) {
if (h > 0) {
ans.add(h);
}
}
return ans;
}
}
class Solution {
public:
vector<int> survivedRobotsHealths(vector<int>& positions, vector<int>& healths, string directions) {
int n = positions.size();
// 创建一个下标数组,对下标数组排序,这样不会打乱输入顺序
vector<int> idx(n);
ranges::iota(idx, 0); // idx[i] = i
ranges::sort(idx, {}, [&](int i) { return positions[i]; });
stack<int> st;
for (int i : idx) {
if (directions[i] == 'R') { // 机器人 i 向右
st.push(i);
continue;
}
while (!st.empty()) { // 栈顶机器人向右
int j = st.top();
if (healths[j] > healths[i]) { // 栈顶机器人的健康度大
healths[i] = 0; // 移除机器人 i
healths[j]--;
break;
}
if (healths[j] == healths[i]) { // 健康度一样大,都移除
healths[i] = 0;
healths[j] = 0;
st.pop();
break;
}
// 机器人 i 的健康度大
healths[i]--;
healths[j] = 0; // 移除机器人 j
st.pop();
}
}
// 返回幸存机器人的健康度
vector<int> ans;
for (int h : healths) {
if (h > 0) {
ans.push_back(h);
}
}
return ans;
}
};
int* _positions;
int cmp(const void* i, const void* j) {
return _positions[*(int*)i] - _positions[*(int*)j];
}
int* survivedRobotsHealths(int* positions, int positionsSize, int* healths, int healthsSize, char* directions, int* returnSize) {
int n = positionsSize;
// 创建一个下标数组,对下标数组排序,这样不会打乱输入顺序
int* idx = malloc(n * sizeof(int));
for (int i = 0; i < n; i++) {
idx[i] = i;
}
_positions = positions;
qsort(idx, n, sizeof(int), cmp);
int* st = malloc(n * sizeof(int));
int top = -1;
for (int k = 0; k < n; k++) {
int i = idx[k];
if (directions[i] == 'R') { // 机器人 i 向右
st[++top] = i;
continue;
}
while (top >= 0) { // 栈顶机器人向右
int j = st[top];
if (healths[j] > healths[i]) { // 栈顶机器人的健康度大
healths[i] = 0; // 移除机器人 i
healths[j]--;
break;
}
if (healths[j] == healths[i]) { // 健康度一样大,都移除
healths[i] = 0;
healths[j] = 0;
top--;
break;
}
// 机器人 i 的健康度大
healths[i]--;
healths[j] = 0; // 移除机器人 j
top--;
}
}
free(idx);
// 返回幸存机器人的健康度
int* ans = st;
*returnSize = 0;
for (int i = 0; i < n; i++) {
if (healths[i] > 0) {
ans[(*returnSize)++] = healths[i];
}
}
return ans;
}
func survivedRobotsHealths(positions []int, healths []int, directions string) (ans []int) {
// 创建一个下标数组,对下标数组排序,这样不会打乱输入顺序
idx := make([]int, len(positions))
for i := range idx {
idx[i] = i
}
slices.SortFunc(idx, func(i, j int) int { return positions[i] - positions[j] })
st := []int{}
for _, i := range idx {
if directions[i] == 'R' { // 机器人 i 向右
st = append(st, i)
continue
}
for len(st) > 0 { // 栈顶机器人向右
j := st[len(st)-1]
if healths[j] > healths[i] { // 栈顶机器人的健康度大
healths[i] = 0 // 移除机器人 i
healths[j]--
break
}
if healths[j] == healths[i] { // 健康度一样大,都移除
healths[i] = 0
healths[j] = 0
st = st[:len(st)-1]
break
}
// 机器人 i 的健康度大
healths[i]--
healths[j] = 0 // 移除机器人 j
st = st[:len(st)-1]
}
}
// 返回幸存机器人的健康度
for _, h := range healths {
if h > 0 {
ans = append(ans, h)
}
}
return
}
var survivedRobotsHealths = function(positions, healths, directions) {
// 创建一个下标数组,对下标数组排序,这样不会打乱输入顺序
const idx = Array.from({ length: positions.length }, (_, i) => i)
.sort((i, j) => positions[i] - positions[j]);
const st = [];
for (const i of idx) {
if (directions[i] === 'R') { // 机器人 i 向右
st.push(i);
continue;
}
while (st.length > 0) { // 栈顶机器人向右
const j = st[st.length - 1];
if (healths[j] > healths[i]) { // 栈顶机器人的健康度大
healths[i] = 0; // 移除机器人 i
healths[j] -= 1;
break;
}
if (healths[j] === healths[i]) { // 健康度一样大,都移除
healths[i] = 0;
healths[j] = 0;
st.pop();
break;
}
// 机器人 i 的健康度大
healths[i] -= 1;
healths[j] = 0; // 移除机器人 j
st.pop();
}
}
// 返回幸存机器人的健康度
return healths.filter(h => h > 0);
};
impl Solution {
pub fn survived_robots_healths(positions: Vec<i32>, mut healths: Vec<i32>, directions: String) -> Vec<i32> {
// 创建一个下标数组,对下标数组排序,这样不会打乱输入顺序
let mut idx = (0..positions.len()).collect::<Vec<_>>();
idx.sort_unstable_by_key(|&i| positions[i]);
let directions = directions.as_bytes();
let mut st = vec![];
for i in idx {
if directions[i] == b'R' { // 机器人 i 向右
st.push(i);
continue;
}
while let Some(&j) = st.last() { // 栈顶机器人向右
if healths[j] > healths[i] { // 栈顶机器人的健康度大
healths[i] = 0; // 移除机器人 i
healths[j] -= 1;
break;
}
if healths[j] == healths[i] { // 健康度一样大,都移除
healths[i] = 0;
healths[j] = 0;
st.pop();
break;
}
// 机器人 i 的健康度大
healths[i] -= 1;
healths[j] = 0; // 移除机器人 j
st.pop();
}
}
// 返回幸存机器人的健康度
healths.into_iter().filter(|&h| h > 0).collect()
}
}
见下面数据结构题单的「§3.3 邻项消除」。
欢迎关注 B站@灵茶山艾府
首先说做法。下文把 $\textit{str}_1$ 简记为 $s$,把 $\textit{str}_2$ 简记为 $t$。
先模拟:处理 $s$ 中的 T,把字符串 $t$ 填入答案的对应位置,如果发现矛盾,就返回空串。没填的位置(待定位置)初始化为 $\texttt{a}$。
再贪心:从左到右检查 F 对应的答案子串,如果发现子串和 $t$ 相同,那么把子串的最后一个待定位置改成 $\texttt{b}$。
本题的贪心策略是简单的,难点在正确性上。考虑如下问题:
$t$ 全为 $\texttt{a}$ 的情况。
这是容易证明的,因为把待定位置改成 $\texttt{b}$ 后,前面的受到影响的子串(包含这个 $\texttt{b}$ 的子串)一定不会等于 $t$,毕竟 $t$ 只有 $\texttt{a}$。
例如 $t=\texttt{aaa}$,现在 $\textit{ans}=\texttt{aaa?????aaa}$。其中 $\texttt{?}$ 表示待定位置,初始值为 $\texttt{a}$。
下面讨论 $t$ 包含不等于 $\texttt{a}$ 的字母的情况。
猜想:$t$ 形如 $t' + \texttt{aa\ldots a} + t'$。例如 $\texttt{baab},\texttt{baaaaba},\texttt{abaaaba}$ 等。
例如 $t=\texttt{baaaaba}$,即 $\texttt{ba} + \texttt{aaa} + \texttt{ba}$。
设 $\textit{ans} = \texttt{baaaaba???baaaaba}$。中间的 $\texttt{???}$ 不能全为 $\texttt{a}$,改成 $\texttt{aab}$,得 $\texttt{baaaaba}\underline{\texttt{aab}}\texttt{baaaaba}$,这里产生的 $\texttt{baaab}$ 可以保证前面的 F 对应子串不会和 $t$ 相同。
这可以推广到一般情况。抛砖引玉,欢迎在评论区发表你的证明。
同理,一旦我们修改了 $\textit{ans}[j]$,那么后面包含 $\textit{ans}[j]$ 的子串都不会和 $t$ 相同。所以只需改最后一个待定位置,不会出现改子串倒数第二个待定位置的情况。进一步地,可以直接跳到 $j+1$ 继续循环,这个优化用在方法二中。
###py
class Solution:
def generateString(self, s: str, t: str) -> str:
n, m = len(s), len(t)
ans = ['?'] * (n + m - 1) # ? 表示待定位置
# 处理 T
for i, b in enumerate(s):
if b != 'T':
continue
# 子串必须等于 t
for j, c in enumerate(t):
v = ans[i + j]
if v != '?' and v != c:
return ""
ans[i + j] = c
old_ans = ans
ans = ['a' if c == '?' else c for c in ans] # 待定位置的初始值为 a
# 处理 F
for i, b in enumerate(s):
if b != 'F':
continue
# 子串必须不等于 t
if ''.join(ans[i: i + m]) != t:
continue
# 找最后一个待定位置
for j in range(i + m - 1, i - 1, -1):
if old_ans[j] == '?': # 之前填 a,现在改成 b
ans[j] = 'b'
break
else:
return ""
return ''.join(ans)
###java
class Solution {
public String generateString(String S, String t) {
char[] s = S.toCharArray();
int n = s.length;
int m = t.length();
char[] ans = new char[n + m - 1];
Arrays.fill(ans, '?'); // '?' 表示待定位置
// 处理 T
for (int i = 0; i < n; i++) {
if (s[i] != 'T') {
continue;
}
// 子串必须等于 t
for (int j = 0; j < m; j++) {
char v = ans[i + j];
if (v != '?' && v != t.charAt(j)) {
return "";
}
ans[i + j] = t.charAt(j);
}
}
char[] oldAns = ans.clone();
for (int i = 0; i < ans.length; i++) {
if (ans[i] == '?') {
ans[i] = 'a'; // 待定位置的初始值为 'a'
}
}
// 处理 F
for (int i = 0; i < n; i++) {
if (s[i] != 'F') {
continue;
}
// 子串必须不等于 t
if (!new String(ans, i, m).equals(t)) {
continue;
}
// 找最后一个待定位置
boolean ok = false;
for (int j = i + m - 1; j >= i; j--) {
if (oldAns[j] == '?') { // 之前填 'a',现在改成 'b'
ans[j] = 'b';
ok = true;
break;
}
}
if (!ok) {
return "";
}
}
return new String(ans);
}
}
###cpp
class Solution {
public:
string generateString(string s, string t) {
int n = s.size(), m = t.size();
string ans(n + m - 1, '?'); // ? 表示待定位置
// 处理 T
for (int i = 0; i < n; i++) {
if (s[i] != 'T') {
continue;
}
// 子串必须等于 t
for (int j = 0; j < m; j++) {
char v = ans[i + j];
if (v != '?' && v != t[j]) {
return "";
}
ans[i + j] = t[j];
}
}
string old_ans = ans;
for (char& c : ans) {
if (c == '?') {
c = 'a'; // 待定位置的初始值为 a
}
}
// 处理 F
for (int i = 0; i < n; i++) {
if (s[i] != 'F') {
continue;
}
// 子串必须不等于 t
if (string(ans.begin() + i, ans.begin() + i + m) != t) {
continue;
}
// 找最后一个待定位置
bool ok = false;
for (int j = i + m - 1; j >= i; j--) {
if (old_ans[j] == '?') { // 之前填 a,现在改成 b
ans[j] = 'b';
ok = true;
break;
}
}
if (!ok) {
return "";
}
}
return ans;
}
};
###go
func generateString(s, T string) string {
n, m := len(s), len(T)
t := []byte(T)
ans := bytes.Repeat([]byte{'?'}, n+m-1) // ? 表示待定位置
// 处理 T
for i, b := range s {
if b != 'T' {
continue
}
// sub 必须等于 t
sub := ans[i : i+m]
for j, c := range sub {
if c != '?' && c != t[j] {
return ""
}
sub[j] = t[j]
}
}
oldAns := ans
ans = bytes.ReplaceAll(ans, []byte{'?'}, []byte{'a'}) // 待定位置的初始值为 a
// 处理 F
next:
for i, b := range s {
if b != 'F' {
continue
}
// sub 必须不等于 t
sub := ans[i : i+m]
if !bytes.Equal(sub, t) {
continue
}
// 找最后一个待定位置
old := oldAns[i : i+m]
for j := m - 1; j >= 0; j-- {
if old[j] == '?' { // 之前填 a,现在改成 b
sub[j] = 'b'
continue next
}
}
return ""
}
return string(ans)
}
在模拟(处理 $s$ 中的 T)的过程中,如果两个 $t$ 重叠,我们需要判断 $t$ 的某个长度的前后缀是否相同,这可以用 Z 函数直接解决。
判断 $\textit{ans}$ 子串是否等于 $t$ 也可以用 Z 函数。计算 $t + \textit{ans}$ 的 Z 函数,如果 $z[i+m]<m$,就说明从 $i$ 开始的 $\textit{ans}$ 子串不等于 $t$。
如果子串等于 $t$,那么找一个小于 $i+m$ 的最近待定位置,改成 $\texttt{b}$。这可以用一个数组 $\textit{preQ}$ 预处理每个 $\le i$ 的最近待定位置。
###py
class Solution:
def calc_z(self, s: str) -> List[int]:
n = len(s)
z = [0] * n
box_l, box_r = 0, 0 # z-box 左右边界(闭区间)
for i in range(1, n):
if i <= box_r:
z[i] = min(z[i - box_l], box_r - i + 1)
while i + z[i] < n and s[z[i]] == s[i + z[i]]:
box_l, box_r = i, i + z[i]
z[i] += 1
z[0] = n
return z
def generateString(self, s: str, t: str) -> str:
n, m = len(s), len(t)
ans = ['?'] * (n + m - 1)
# 处理 T
z = self.calc_z(t)
pre = -m
for i, b in enumerate(s):
if b != 'T':
continue
size = max(pre + m - i, 0)
# t 的长为 size 的前后缀必须相同
if size > 0 and z[m - size] < size:
return ""
# size 后的内容都是 '?',填入 t
ans[i + size: i + m] = t[size:]
pre = i
# 计算 <= i 的最近待定位置
pre_q = [-1] * len(ans)
pre = -1
for i, c in enumerate(ans):
if c == '?':
ans[i] = 'a' # 待定位置的初始值为 a
pre = i
pre_q[i] = pre
# 找 ans 中的等于 t 的位置,可以用 KMP 或者 Z 函数
z = self.calc_z(t + ''.join(ans))
# 处理 F
i = 0
while i < n:
if s[i] != 'F':
i += 1
continue
# 子串必须不等于 t
if z[m + i] < m:
i += 1
continue
# 找最后一个待定位置
j = pre_q[i + m - 1]
if j < i: # 没有
return ""
ans[j] = 'b'
i = j + 1 # 直接跳过 j
return ''.join(ans)
###java
class Solution {
public String generateString(String S, String t) {
char[] s = S.toCharArray();
int n = s.length;
int m = t.length();
char[] ans = new char[n + m - 1];
Arrays.fill(ans, '?');
// 处理 T
int[] z = calcZ(t);
int pre = -m;
for (int i = 0; i < n; i++) {
if (s[i] != 'T') {
continue;
}
int size = Math.max(pre + m - i, 0);
// t 的长为 size 的前后缀必须相同
if (size > 0 && z[m - size] < size) {
return "";
}
// size 后的内容都是 '?',填入 t
for (int j = size; j < m; j++) {
ans[i + j] = t.charAt(j);
}
pre = i;
}
// 计算 <= i 的最近待定位置
int[] preQ = new int[ans.length];
pre = -1;
for (int i = 0; i < ans.length; i++) {
if (ans[i] == '?') {
ans[i] = 'a'; // 待定位置的初始值为 a
pre = i;
}
preQ[i] = pre;
}
// 找 ans 中的等于 t 的位置,可以用 KMP 或者 Z 函数
z = calcZ(t + new String(ans));
// 处理 F
for (int i = 0; i < n; i++) {
if (s[i] != 'F') {
continue;
}
// 子串必须不等于 t
if (z[m + i] < m) {
continue;
}
// 找最后一个待定位置
int j = preQ[i + m - 1];
if (j < i) { // 没有
return "";
}
ans[j] = 'b';
i = j; // 直接跳到 j
}
return new String(ans);
}
private int[] calcZ(String S) {
char[] s = S.toCharArray();
int n = s.length;
int[] z = new int[n];
int boxL = 0; // z-box 左右边界(闭区间)
int boxR = 0;
for (int i = 1; i < n; i++) {
if (i <= boxR) {
z[i] = Math.min(z[i - boxL], boxR - i + 1);
}
while (i + z[i] < n && s[z[i]] == s[i + z[i]]) {
boxL = i;
boxR = i + z[i];
z[i]++;
}
}
z[0] = n;
return z;
}
}
###cpp
class Solution {
vector<int> calc_z(const string& s) {
int n = s.size();
vector<int> z(n);
int box_l = 0, box_r = 0; // z-box 左右边界(闭区间)
for (int i = 1; i < n; i++) {
if (i <= box_r) {
z[i] = min(z[i - box_l], box_r - i + 1);
}
while (i + z[i] < n && s[z[i]] == s[i + z[i]]) {
box_l = i;
box_r = i + z[i];
z[i]++;
}
}
z[0] = n;
return z;
}
public:
string generateString(string s, string t) {
int n = s.size(), m = t.size();
string ans(n + m - 1, '?');
// 处理 T
vector<int> z = calc_z(t);
int pre = -m;
for (int i = 0; i < n; i++) {
if (s[i] != 'T') {
continue;
}
int size = max(pre + m - i, 0);
// t 的长为 size 的前后缀必须相同
if (size > 0 && z[m - size] < size) {
return "";
}
// size 后的内容都是 '?',填入 t
for (int j = size; j < m; j++) {
ans[i + j] = t[j];
}
pre = i;
}
// 计算 <= i 的最近待定位置
vector<int> pre_q(ans.size());
pre = -1;
for (int i = 0; i < ans.size(); i++) {
if (ans[i] == '?') {
ans[i] = 'a'; // 待定位置的初始值为 a
pre = i;
}
pre_q[i] = pre;
}
// 找 ans 中的等于 t 的位置,可以用 KMP 或者 Z 函数
z = calc_z(t + ans);
// 处理 F
for (int i = 0; i < n; i++) {
if (s[i] != 'F') {
continue;
}
// 子串必须不等于 t
if (z[m + i] < m) {
continue;
}
// 找最后一个待定位置
int j = pre_q[i + m - 1];
if (j < i) { // 没有
return "";
}
ans[j] = 'b';
i = j; // 直接跳到 j
}
return ans;
}
};
###go
func calcZ(s string) []int {
n := len(s)
z := make([]int, n)
boxL, boxR := 0, 0 // z-box 左右边界(闭区间)
for i := 1; i < n; i++ {
if i <= boxR {
z[i] = min(z[i-boxL], boxR-i+1)
}
for i+z[i] < n && s[z[i]] == s[i+z[i]] {
boxL, boxR = i, i+z[i]
z[i]++
}
}
z[0] = n
return z
}
func generateString(s, t string) string {
n, m := len(s), len(t)
ans := bytes.Repeat([]byte{'?'}, n+m-1)
// 处理 T
pre := -m
z := calcZ(t)
for i, b := range s {
if b != 'T' {
continue
}
size := max(pre+m-i, 0)
// t 的长为 size 的前后缀必须相同
if size > 0 && z[m-size] < size {
return ""
}
// size 后的内容都是 '?',填入 t
copy(ans[i+size:], t[size:])
pre = i
}
// 计算 <= i 的最近待定位置
preQ := make([]int, len(ans))
pre = -1
for i, c := range ans {
if c == '?' {
ans[i] = 'a' // 待定位置的初始值为 a
pre = i
}
preQ[i] = pre
}
// 找 ans 中的等于 t 的位置,可以用 KMP 或者 Z 函数
z = calcZ(t + string(ans))
// 处理 F
for i := 0; i < n; i++ {
if s[i] != 'F' {
continue
}
// 子串必须不等于 t
if z[m+i] < m {
continue
}
// 找最后一个待定位置
j := preQ[i+m-1]
if j < i { // 没有
return ""
}
ans[j] = 'b'
i = j // 直接跳到 j
}
return string(ans)
}
更多相似题目,见下面贪心题单中的「§3.1 字典序最小/最大」和字符串题单中的「二、Z 函数」。
有一个更强的结论:改成 $j-i=2$,也是一样的。
既然可以交换相距为 $2$ 的字符,那么相距为 $4$ 的字符可以通过多次交换实现。例如 $x-y-z$ 变成 $y-x-z$ 变成 $y-z-x$ 变成 $z-y-x$。
依此类推,所有相距为偶数的字符都可以随意交换。
所以只需要看下标为偶数的字符个数是否都一样,以及下标为奇数的字符个数是否都一样。
class Solution:
def checkStrings(self, s1: str, s2: str) -> bool:
return Counter(s1[::2]) == Counter(s2[::2]) and \
Counter(s1[1::2]) == Counter(s2[1::2])
class Solution {
public boolean checkStrings(String s1, String s2) {
int[][] cnt1 = new int[2][26];
int[][] cnt2 = new int[2][26];
for (int i = 0; i < s1.length(); i++) {
cnt1[i % 2][s1.charAt(i) - 'a']++;
cnt2[i % 2][s2.charAt(i) - 'a']++;
}
return Arrays.deepEquals(cnt1, cnt2);
}
}
class Solution {
public:
bool checkStrings(string s1, string s2) {
int cnt1[2][26]{}, cnt2[2][26]{};
for (int i = 0; i < s1.length(); i++) {
cnt1[i % 2][s1[i] - 'a']++;
cnt2[i % 2][s2[i] - 'a']++;
}
return memcmp(cnt1, cnt2, sizeof(cnt1)) == 0;
}
};
func checkStrings(s1, s2 string) bool {
var cnt1, cnt2 [2][26]int
for i, c := range s1 {
cnt1[i%2][c-'a']++
cnt2[i%2][s2[i]-'a']++
}
return cnt1 == cnt2
}
改成 $j-i=3$ 要怎么做?
欢迎关注 B站@灵茶山艾府
实际上,只需把每行右移 $k$ 次。无需判断奇偶,无需考虑左移。
为什么?如果一行左移 $k$ 次等于自己,那么这个过程的逆过程,就是把自己右移 $k$ 次,得到自己。
判断 $\textit{row}$ 右移 $k$ 次是否等于 $\textit{row}$,可以比较 $\textit{row}[j]$ 与右移 $k$ 次后的位置 $\textit{row}[(j+k)\bmod n]$ 是否相等。
###py
class Solution:
def areSimilar(self, mat: List[List[int]], k: int) -> bool:
k %= len(mat[0]) # 右移 n 次等价于右移 0 次,右移 n+1 次等价于右移 1 次,依此类推,先模个 n
return k == 0 or all(row == row[k:] + row[:k] for row in mat)
###py
class Solution:
def areSimilar(self, mat: List[List[int]], k: int) -> bool:
n = len(mat[0])
for row in mat:
for j in range(n):
if row[j] != row[(j + k) % n]:
return False
return True
###java
class Solution {
public boolean areSimilar(int[][] mat, int k) {
int n = mat[0].length;
for (int[] row : mat) {
for (int j = 0; j < n; j++) {
if (row[j] != row[(j + k) % n]) {
return false;
}
}
}
return true;
}
}
###cpp
class Solution {
public:
bool areSimilar(vector<vector<int>>& mat, int k) {
int n = mat[0].size();
for (auto& row : mat) {
for (int j = 0; j < n; j++) {
if (row[j] != row[(j + k) % n]) {
return false;
}
}
}
return true;
}
};
###c
bool areSimilar(int** mat, int matSize, int* matColSize, int k) {
int n = matColSize[0];
for (int i = 0; i < matSize; i++) {
int* row = mat[i];
for (int j = 0; j < n; j++) {
if (row[j] != row[(j + k) % n]) {
return false;
}
}
}
return true;
}
###go
func areSimilar(mat [][]int, k int) bool {
n := len(mat[0])
for _, row := range mat {
for j, x := range row {
if x != row[(j+k)%n] {
return false
}
}
}
return true
}
###go
func areSimilar(mat [][]int, k int) bool {
k %= len(mat[0]) // 右移 n 次等价于右移 0 次,右移 n+1 次等价于右移 1 次,依此类推,先模个 n
for _, row := range mat {
if !slices.Equal(row, append(row[k:], row[:k]...)) {
return false
}
}
return true
}
###js
var areSimilar = function(mat, k) {
const n = mat[0].length;
for (const row of mat) {
for (let j = 0; j < n; j++) {
if (row[j] !== row[(j + k) % n]) {
return false;
}
}
}
return true;
};
###rust
impl Solution {
pub fn are_similar(mat: Vec<Vec<i32>>, k: i32) -> bool {
let n = mat[0].len();
let k = k as usize;
for row in mat {
for j in 0..n {
if row[j] != row[(j + k) % n] {
return false;
}
}
}
true
}
}
可能有同学脑筋没转过来,这里详细解释下。
如果给你两个数组 $a$ 和 $b$,要判断 $a$ 左移/右移后是否等于 $b$,那么 $a$ 左移 $k$ 次和右移 $k$ 次是不一样的。
但本题这两个数组都是 $a$,要判断 $a$ 左移/右移后是否等于 $a$ 自己。
由于 $a$ 左移 $k$ 次后和 $b$ 比较,等价于 $b$ 右移 $k$ 次后和 $a$ 比较。在 $b$ 就是 $a$ 的情况下,等价于 $a$ 自己右移 $k$ 次和 $a$ 比较。
欢迎关注 B站@灵茶山艾府
设整个 $\textit{grid}$ 的元素和为 $\textit{total}$。
设第一部分的元素和为 $s$,那么第二部分的元素和为 $\textit{total}-s$。
据此,我们可以一边遍历 $\textit{grid}$,一边计算第一部分的元素和 $s$,一边用哈希集合记录遍历过的元素。
每一行/列遍历结束后,判断 $x=2s-\textit{total}$ 是否在哈希集合中,如果在,就说明存在 $x$,使得 $s - x = \textit{total}-s$ 成立。
小技巧:预先把 $0$ 加到哈希集合中,这样可以把不删和删合并成一种情况。
对于删第二部分中的元素的情况,可以把 $\textit{grid}$ 上下翻转,复用删第一部分中的元素的代码。
先计算水平分割的情况。
分类讨论:
对于垂直分割,可以把 $\textit{grid}$ 旋转 $90$ 度,复用上述代码。
具体请看 视频讲解,欢迎点赞关注~
###py
class Solution:
def canPartitionGrid(self, grid: List[List[int]]) -> bool:
total = sum(sum(row) for row in grid)
# 能否水平分割
def check(a: List[List[int]]) -> bool:
m, n = len(a), len(a[0])
# 删除上半部分中的一个数,能否满足要求
def f(a: List[List[int]]) -> bool:
st = {0} # 0 对应不删除数字
s = 0
for i, row in enumerate(a[:-1]):
for j, x in enumerate(row):
s += x
# 第一行,不能删除中间元素
if i > 0 or j == 0 or j == n - 1:
st.add(x)
# 特殊处理只有一列的情况,此时只能删除第一个数或者分割线上那个数
if n == 1:
if s * 2 == total or s * 2 - total == a[0][0] or s * 2 - total == row[0]:
return True
continue
if s * 2 - total in st:
return True
# 如果分割到更下面,那么可以删第一行的元素
if i == 0:
st.update(row)
return False
# 删除上半部分中的数 or 删除下半部分中的数
return f(a) or f(a[::-1])
# 水平分割 or 垂直分割
return check(grid) or check(list(zip(*grid)))
###java
class Solution {
public boolean canPartitionGrid(int[][] grid) {
long total = 0;
for (int[] row : grid) {
for (int x : row) {
total += x;
}
}
// 水平分割 or 垂直分割
return check(grid, total) || check(rotate(grid), total);
}
private boolean check(int[][] a, long total) {
// 删除上半部分中的一个数
if (f(a, total)) {
return true;
}
reverse(a);
// 删除下半部分中的一个数
return f(a, total);
}
private boolean f(int[][] a, long total) {
int m = a.length, n = a[0].length;
Set<Long> st = new HashSet<>();
st.add(0L); // 0 对应不删除数字
long s = 0;
for (int i = 0; i < m - 1; i++) {
int[] row = a[i];
for (int j = 0; j < n; j++) {
int x = row[j];
s += x;
// 第一行,不能删除中间元素
if (i > 0 || j == 0 || j == n - 1) {
st.add((long) x);
}
}
// 特殊处理只有一列的情况,此时只能删除第一个数或者分割线上那个数
if (n == 1) {
if (s * 2 == total || s * 2 - total == a[0][0] || s * 2 - total == row[0]) {
return true;
}
continue;
}
if (st.contains(s * 2 - total)) {
return true;
}
// 如果分割到更下面,那么可以删第一行的元素
if (i == 0) {
for (int x : row) {
st.add((long) x);
}
}
}
return false;
}
// 顺时针旋转矩阵 90°
private int[][] rotate(int[][] a) {
int m = a.length, n = a[0].length;
int[][] b = new int[n][m];
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
b[j][m - 1 - i] = a[i][j];
}
}
return b;
}
private void reverse(int[][] a) {
for (int i = 0, j = a.length - 1; i < j; i++, j--) {
int[] tmp = a[i];
a[i] = a[j];
a[j] = tmp;
}
}
}
###cpp
class Solution {
// 顺时针旋转矩阵 90°
vector<vector<int>> rotate(vector<vector<int>>& a) {
int m = a.size(), n = a[0].size();
vector b(n, vector<int>(m));
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
b[j][m - 1 - i] = a[i][j];
}
}
return b;
}
public:
bool canPartitionGrid(vector<vector<int>>& grid) {
long long total = 0;
for (auto& row : grid) {
for (int x : row) {
total += x;
}
}
auto check = [&](vector<vector<int>> a) -> bool {
int m = a.size(), n = a[0].size();
auto f = [&]() -> bool {
unordered_set<long long> st = {0}; // 0 对应不删除数字
long long s = 0;
for (int i = 0; i < m - 1; i++) {
auto& row = a[i];
for (int j = 0; j < n; j++) {
int x = row[j];
s += x;
// 第一行,不能删除中间元素
if (i > 0 || j == 0 || j == n - 1) {
st.insert(x);
}
}
// 特殊处理只有一列的情况,此时只能删除第一个数或者分割线上那个数
if (n == 1) {
if (s * 2 == total || s * 2 - total == a[0][0] || s * 2 - total == row[0]) {
return true;
}
continue;
}
if (st.contains(s * 2 - total)) {
return true;
}
// 如果分割到更下面,那么可以删第一行的元素
if (i == 0) {
for (int x : row) {
st.insert(x);
}
}
}
return false;
};
// 删除上半部分中的一个数
if (f()) {
return true;
}
ranges::reverse(a);
// 删除下半部分中的一个数
return f();
};
// 水平分割 or 垂直分割
return check(grid) || check(rotate(grid));
}
};
###go
func canPartitionGrid(grid [][]int) bool {
total := 0
for _, row := range grid {
for _, x := range row {
total += x
}
}
// 能否水平分割
check := func(a [][]int) bool {
m, n := len(a), len(a[0])
f := func() bool {
has := map[int]bool{0: true} // 0 对应不删除数字
s := 0
for i, row := range a[:m-1] {
for j, x := range row {
s += x
// 第一行,不能删除中间元素
if i > 0 || j == 0 || j == n-1 {
has[x] = true
}
}
// 特殊处理只有一列的情况,此时只能删除第一个数或者分割线上那个数
if n == 1 {
if s*2 == total || s*2-total == a[0][0] || s*2-total == row[0] {
return true
}
continue
}
if has[s*2-total] {
return true
}
// 如果分割到更下面,那么可以删第一行的元素
if i == 0 {
for _, x := range row {
has[x] = true
}
}
}
return false
}
// 删除上半部分中的一个数
if f() {
return true
}
slices.Reverse(a)
// 删除下半部分中的一个数
return f()
}
// 水平分割 or 垂直分割
return check(grid) || check(rotate(grid))
}
// 顺时针旋转矩阵 90°
func rotate(a [][]int) [][]int {
m, n := len(a), len(a[0])
b := make([][]int, n)
for i := range b {
b[i] = make([]int, m)
}
for i, row := range a {
for j, x := range row {
b[j][m-1-i] = x
}
}
return b
}
欢迎关注 B站@灵茶山艾府
请先完成本题的一维版本:238. 除了自身以外数组的乘积。
把矩阵平铺成一维数组,就是 238 题了。我们需要算出每个数左边所有数的乘积,以及右边所有数的乘积。
先算出从 $\textit{grid}[i][j]$ 的下一个元素开始,到最后一个元素 $\textit{grid}[n-1][m-1]$ 的乘积,记作 $\textit{suf}[i][j]$。这可以从最后一个数 $\textit{grid}[n-1][m-1]$ 开始,倒着遍历 $\textit{grid}$ 得到。
然后算出从第一个数 $\textit{grid}[0][0]$ 开始,到 $\textit{grid}[i][j]$ 的上一个元素的乘积,记作 $\textit{pre}[i][j]$。这可以从第一行第一列开始,正着遍历得到。
那么
$$
p[i][j] = \textit{pre}[i][j]\cdot \textit{suf}[i][j]
$$
代码实现时,可以先初始化 $p[i][j]=\textit{suf}[i][j]$,然后在计算 $\textit{pre}[i][j]$ 的过程中,把 $\textit{pre}[i][j]$ 乘到 $\textit{p}[i][j]$ 中,就得到了最终答案。这样写的话,$\textit{pre}$ 和 $\textit{suf}$ 可以直接用单个变量表示,无需创建数组。
代码实现时,注意取模。为什么可以在中途取模?原理见 模运算的世界:当加减乘除遇上取模。
class Solution:
def constructProductMatrix(self, grid: List[List[int]]) -> List[List[int]]:
MOD = 12345
n, m = len(grid), len(grid[0])
p = [[0] * m for _ in range(n)]
suf = 1 # 后缀乘积
for i in range(n - 1, -1, -1):
for j in range(m - 1, -1, -1):
p[i][j] = suf # p[i][j] 先初始化成后缀乘积
suf = suf * grid[i][j] % MOD
pre = 1 # 前缀乘积
for i, row in enumerate(grid):
for j, x in enumerate(row):
p[i][j] = p[i][j] * pre % MOD # 乘上前缀乘积
pre = pre * x % MOD
return p
class Solution {
public int[][] constructProductMatrix(int[][] grid) {
final int MOD = 12345;
int n = grid.length;
int m = grid[0].length;
int[][] p = new int[n][m];
long suf = 1; // 后缀乘积
for (int i = n - 1; i >= 0; i--) {
for (int j = m - 1; j >= 0; j--) {
p[i][j] = (int) suf; // p[i][j] 先初始化成后缀乘积
suf = suf * grid[i][j] % MOD;
}
}
long pre = 1; // 前缀乘积
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
p[i][j] = (int) (p[i][j] * pre % MOD); // 乘上前缀乘积
pre = pre * grid[i][j] % MOD;
}
}
return p;
}
}
class Solution {
public:
vector<vector<int>> constructProductMatrix(vector<vector<int>>& grid) {
constexpr int MOD = 12345;
int n = grid.size(), m = grid[0].size();
vector p(n, vector<int>(m));
long long suf = 1; // 后缀乘积
for (int i = n - 1; i >= 0; i--) {
for (int j = m - 1; j >= 0; j--) {
p[i][j] = suf; // p[i][j] 先初始化成后缀乘积
suf = suf * grid[i][j] % MOD;
}
}
long long pre = 1; // 前缀乘积
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
p[i][j] = p[i][j] * pre % MOD; // 乘上前缀乘积
pre = pre * grid[i][j] % MOD;
}
}
return p;
}
};
int** constructProductMatrix(int** grid, int gridSize, int* gridColSize, int* returnSize, int** returnColumnSizes) {
const int MOD = 12345;
int n = gridSize, m = gridColSize[0];
int** p = malloc(n * sizeof(int*));
*returnSize = n;
*returnColumnSizes = malloc(n * sizeof(int));
for (int i = 0; i < n; i++) {
p[i] = malloc(m * sizeof(int));
(*returnColumnSizes)[i] = m;
}
long long suf = 1; // 后缀乘积
for (int i = n - 1; i >= 0; i--) {
for (int j = m - 1; j >= 0; j--) {
p[i][j] = suf; // p[i][j] 先初始化成后缀乘积
suf = suf * grid[i][j] % MOD;
}
}
long long pre = 1; // 前缀乘积
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
p[i][j] = p[i][j] * pre % MOD; // 乘上前缀乘积
pre = pre * grid[i][j] % MOD;
}
}
return p;
}
func constructProductMatrix(grid [][]int) [][]int {
const mod = 12345
n, m := len(grid), len(grid[0])
p := make([][]int, n)
suf := 1 // 后缀乘积
for i := n - 1; i >= 0; i-- {
p[i] = make([]int, m)
for j := m - 1; j >= 0; j-- {
p[i][j] = suf // p[i][j] 先初始化成后缀乘积
suf = suf * grid[i][j] % mod
}
}
pre := 1 // 前缀乘积
for i, row := range grid {
for j, x := range row {
p[i][j] = p[i][j] * pre % mod // 乘上前缀乘积
pre = pre * x % mod
}
}
return p
}
var constructProductMatrix = function(grid) {
const MOD = 12345;
const n = grid.length, m = grid[0].length;
const p = Array.from({ length: n }, () => Array(m).fill(0));
let suf = 1; // 后缀乘积
for (let i = n - 1; i >= 0; i--) {
for (let j = m - 1; j >= 0; j--) {
p[i][j] = suf; // p[i][j] 先初始化成后缀乘积
suf = suf * grid[i][j] % MOD;
}
}
let pre = 1; // 前缀乘积
for (let i = 0; i < n; i++) {
for (let j = 0; j < m; j++) {
p[i][j] = p[i][j] * pre % MOD; // 乘上前缀乘积
pre = pre * grid[i][j] % MOD;
}
}
return p;
};
impl Solution {
pub fn construct_product_matrix(grid: Vec<Vec<i32>>) -> Vec<Vec<i32>> {
const MOD: i64 = 12345;
let n = grid.len();
let m = grid[0].len();
let mut p = vec![vec![0; m]; n];
let mut suf = 1; // 后缀乘积
for i in (0..n).rev() {
for j in (0..m).rev() {
p[i][j] = suf as i32; // p[i][j] 先初始化成后缀乘积
suf = suf * grid[i][j] as i64 % MOD;
}
}
let mut pre = 1; // 前缀乘积
for i in 0..n {
for j in 0..m {
p[i][j] = (p[i][j] as i64 * pre % MOD) as i32; // 乘上前缀乘积
pre = pre * grid[i][j] as i64 % MOD;
}
}
p
}
}
见下面动态规划题单的「专题:前后缀分解」。
欢迎关注 B站@灵茶山艾府
思路和 152 题是一样的,除了计算最大路径乘积,还要计算最小路径乘积(因为负负得正)。
定义 $\textit{dfs}(i,j)$ 表示从左上角 $(0,0)$ 到 $(i,j)$ 的最小路径乘积以及最大路径乘积($\textit{dfs}$ 返回两个数)。
设 $x = \textit{grid}[i][j]$。分类讨论如何到达 $(i,j)$:
两种情况取最小值(最大值),即为 $\textit{dfs}(i,j)$ 的返回值。
递归边界:$\textit{dfs}(0,0) = (x,x)$。
递归入口:$\textit{dfs}(m-1,n-1)$。取返回值中的最大路径乘积作为答案。如果答案是负数,返回 $-1$;否则返回答案模 $10^9+7$ 的结果。
⚠注意:题目要求算完了再取模。如果在中途取模,可能会把一个很大的数模成很小的数,导致计算错误。比如两个数 $10^9+8$ 和 $10^9$,取模之前是 $10^9+8$ 更大,但取模后这两个数分别变成 $1$ 和 $10^9$,后者更大。
###py
class Solution:
def maxProductPath(self, grid: List[List[int]]) -> int:
@cache # 缓存装饰器,避免重复计算 dfs(一行代码实现记忆化)
def dfs(i: int, j: int) -> Tuple[int, int]:
x = grid[i][j]
if i == j == 0:
return x, x
res_min, res_max = inf, -inf
if i > 0:
mn, mx = dfs(i - 1, j)
res_min = min(mn * x, mx * x)
res_max = max(mn * x, mx * x)
if j > 0:
mn, mx = dfs(i, j - 1)
res_min = min(res_min, mn * x, mx * x)
res_max = max(res_max, mn * x, mx * x)
return res_min, res_max
ans = dfs(len(grid) - 1, len(grid[0]) - 1)[1]
return -1 if ans < 0 else ans % 1_000_000_007
###java
class Solution {
public int maxProductPath(int[][] grid) {
int m = grid.length, n = grid[0].length;
long[][][] memo = new long[m][n][2];
for (long[][] row : memo) {
for (long[] p : row) {
p[0] = p[1] = Long.MIN_VALUE;
}
}
long ans = dfs(m - 1, n - 1, grid, memo)[1];
return ans < 0 ? -1 : (int) (ans % 1_000_000_007);
}
private long[] dfs(int i, int j, int[][] grid, long[][][] memo) {
long x = grid[i][j];
if (i == 0 && j == 0) {
return new long[]{x, x};
}
long[] p = memo[i][j];
if (p[0] != Long.MIN_VALUE) { // 之前计算过
return p;
}
long resMin = Long.MAX_VALUE;
long resMax = Long.MIN_VALUE;
if (i > 0) {
long[] res = dfs(i - 1, j, grid, memo);
long mn = res[0], mx = res[1];
resMin = Math.min(mn * x, mx * x);
resMax = Math.max(mn * x, mx * x);
}
if (j > 0) {
long[] res = dfs(i, j - 1, grid, memo);
long mn = res[0], mx = res[1];
resMin = Math.min(resMin, Math.min(mn * x, mx * x));
resMax = Math.max(resMax, Math.max(mn * x, mx * x));
}
p[0] = resMin;
p[1] = resMax; // 记忆化
return p;
}
}
###cpp
class Solution {
public:
int maxProductPath(vector<vector<int>>& grid) {
int m = grid.size(), n = grid[0].size();
vector memo(m, vector<array<long long, 2>>(n, {LLONG_MIN, LLONG_MIN}));
auto dfs = [&](this auto&& dfs, int i, int j) -> array<long long, 2> {
long long x = grid[i][j];
if (i == 0 && j == 0) {
return {x, x};
}
auto& res = memo[i][j]; // 注意这里是引用
if (res[0] != LLONG_MIN) { // 之前计算过
return res;
}
long long res_min = LLONG_MAX;
long long res_max = LLONG_MIN;
if (i > 0) {
auto [mn, mx] = dfs(i - 1, j);
res_min = min(mn * x, mx * x);
res_max = max(mn * x, mx * x);
}
if (j > 0) {
auto [mn, mx] = dfs(i, j - 1);
res_min = min(res_min, min(mn * x, mx * x));
res_max = max(res_max, max(mn * x, mx * x));
}
res = {res_min, res_max}; // 记忆化
return res;
};
long long ans = dfs(m - 1, n - 1)[1];
return ans < 0 ? -1 : ans % 1'000'000'007;
}
};
###go
func maxProductPath(grid [][]int) int {
m, n := len(grid), len(grid[0])
memo := make([][][2]int, m)
for i := range memo {
memo[i] = make([][2]int, n)
for j := range memo[i] {
memo[i][j] = [2]int{math.MinInt, math.MinInt}
}
}
var dfs func(int, int) (int, int)
dfs = func(i, j int) (int, int) {
x := grid[i][j]
if i == 0 && j == 0 {
return x, x
}
p := &memo[i][j]
if p[0] != math.MinInt { // 之前计算过
return p[0], p[1]
}
resMin := math.MaxInt
resMax := math.MinInt
if i > 0 {
mn, mx := dfs(i-1, j)
resMin = min(mn*x, mx*x)
resMax = max(mn*x, mx*x)
}
if j > 0 {
mn, mx := dfs(i, j-1)
resMin = min(resMin, mn*x, mx*x)
resMax = max(resMax, mn*x, mx*x)
}
p[0], p[1] = resMin, resMax // 记忆化
return resMin, resMax
}
_, ans := dfs(m-1, n-1)
if ans < 0 {
return -1
}
return ans % 1_000_000_007
}
把 $\textit{dfs}(i,j)$ 改成 $f[i][j]$。
###py
class Solution:
def maxProductPath(self, grid: List[List[int]]) -> int:
m, n = len(grid), len(grid[0])
f = [[None] * n for _ in range(m)]
for i, row in enumerate(grid):
for j, x in enumerate(row):
if i == j == 0:
f[0][0] = (x, x)
continue
res_min, res_max = inf, -inf
if i > 0:
mn, mx = f[i - 1][j]
res_min = min(mn * x, mx * x)
res_max = max(mn * x, mx * x)
if j > 0:
mn, mx = f[i][j - 1]
res_min = min(res_min, mn * x, mx * x)
res_max = max(res_max, mn * x, mx * x)
f[i][j] = (res_min, res_max)
ans = f[-1][-1][1]
return -1 if ans < 0 else ans % 1_000_000_007
###java
class Solution {
public int maxProductPath(int[][] grid) {
int m = grid.length, n = grid[0].length;
long[][][] f = new long[m][n][2];
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
long x = grid[i][j];
if (i == 0 && j == 0) {
f[0][0][0] = x;
f[0][0][1] = x;
continue;
}
long resMin = Long.MAX_VALUE;
long resMax = Long.MIN_VALUE;
if (i > 0) {
long mn = f[i - 1][j][0], mx = f[i - 1][j][1];
resMin = Math.min(mn * x, mx * x);
resMax = Math.max(mn * x, mx * x);
}
if (j > 0) {
long mn = f[i][j - 1][0], mx = f[i][j - 1][1];
resMin = Math.min(resMin, Math.min(mn * x, mx * x));
resMax = Math.max(resMax, Math.max(mn * x, mx * x));
}
f[i][j][0] = resMin;
f[i][j][1] = resMax;
}
}
long ans = f[m - 1][n - 1][1];
return ans < 0 ? -1 : (int) (ans % 1_000_000_007);
}
}
###cpp
class Solution {
public:
int maxProductPath(vector<vector<int>>& grid) {
int m = grid.size(), n = grid[0].size();
vector f(m, vector<array<long long, 2>>(n));
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
long long x = grid[i][j];
if (i == 0 && j == 0) {
f[0][0] = {x, x};
continue;
}
long long res_min = LLONG_MAX;
long long res_max = LLONG_MIN;
if (i > 0) {
auto [mn, mx] = f[i - 1][j];
res_min = min(mn * x, mx * x);
res_max = max(mn * x, mx * x);
}
if (j > 0) {
auto [mn, mx] = f[i][j - 1];
res_min = min(res_min, min(mn * x, mx * x));
res_max = max(res_max, max(mn * x, mx * x));
}
f[i][j] = {res_min, res_max};
}
}
long long ans = f[m - 1][n - 1][1];
return ans < 0 ? -1 : ans % 1'000'000'007;
}
};
###go
func maxProductPath(grid [][]int) int {
m, n := len(grid), len(grid[0])
f := make([][][2]int, m)
for i := range f {
f[i] = make([][2]int, n)
}
for i, row := range grid {
for j, x := range row {
if i == 0 && j == 0 {
f[0][0] = [2]int{x, x}
continue
}
resMin := math.MaxInt
resMax := math.MinInt
if i > 0 {
mn, mx := f[i-1][j][0], f[i-1][j][1]
resMin = min(mn*x, mx*x)
resMax = max(mn*x, mx*x)
}
if j > 0 {
mn, mx := f[i][j-1][0], f[i][j-1][1]
resMin = min(resMin, mn*x, mx*x)
resMax = max(resMax, mn*x, mx*x)
}
f[i][j] = [2]int{resMin, resMax}
}
}
ans := f[m-1][n-1][1]
if ans < 0 {
return -1
}
return ans % 1_000_000_007
}
原理见 64 题 我的题解。
###py
class Solution:
def maxProductPath(self, grid: List[List[int]]) -> int:
n = len(grid[0])
f_min = [0] * n
f_max = [0] * n
for i, row in enumerate(grid):
for j, x in enumerate(row):
if i == j == 0:
f_min[0] = f_max[0] = x
continue
res_min, res_max = inf, -inf
if i > 0:
mn, mx = f_min[j], f_max[j]
res_min = min(mn * x, mx * x)
res_max = max(mn * x, mx * x)
if j > 0:
mn, mx = f_min[j - 1], f_max[j - 1]
res_min = min(res_min, mn * x, mx * x)
res_max = max(res_max, mn * x, mx * x)
f_min[j] = res_min
f_max[j] = res_max
ans = f_max[-1]
return -1 if ans < 0 else ans % 1_000_000_007
###java
class Solution {
public int maxProductPath(int[][] grid) {
int m = grid.length, n = grid[0].length;
long[] fMin = new long[n];
long[] fMax = new long[n];
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
long x = grid[i][j];
if (i == 0 && j == 0) {
fMin[0] = fMax[0] = x;
continue;
}
long resMin = Long.MAX_VALUE;
long resMax = Long.MIN_VALUE;
if (i > 0) {
long mn = fMin[j], mx = fMax[j];
resMin = Math.min(mn * x, mx * x);
resMax = Math.max(mn * x, mx * x);
}
if (j > 0) {
long mn = fMin[j - 1], mx = fMax[j - 1];
resMin = Math.min(resMin, Math.min(mn * x, mx * x));
resMax = Math.max(resMax, Math.max(mn * x, mx * x));
}
fMin[j] = resMin;
fMax[j] = resMax;
}
}
long ans = fMax[n - 1];
return ans < 0 ? -1 : (int) (ans % 1_000_000_007);
}
}
###cpp
class Solution {
public:
int maxProductPath(vector<vector<int>>& grid) {
int m = grid.size(), n = grid[0].size();
vector<long long> f_min(n), f_max(n);
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
long long x = grid[i][j];
if (i == 0 && j == 0) {
f_min[0] = f_max[0] = x;
continue;
}
long long res_min = LLONG_MAX;
long long res_max = LLONG_MIN;
if (i > 0) {
long long mn = f_min[j], mx = f_max[j];
res_min = min(mn * x, mx * x);
res_max = max(mn * x, mx * x);
}
if (j > 0) {
long long mn = f_min[j - 1], mx = f_max[j - 1];
res_min = min(res_min, min(mn * x, mx * x));
res_max = max(res_max, max(mn * x, mx * x));
}
f_min[j] = res_min;
f_max[j] = res_max;
}
}
long long ans = f_max[n - 1];
return ans < 0 ? -1 : ans % 1'000'000'007;
}
};
###go
func maxProductPath(grid [][]int) int {
n := len(grid[0])
fMin := make([]int, n)
fMax := make([]int, n)
for i, row := range grid {
for j, x := range row {
if i == 0 && j == 0 {
fMin[0], fMax[0] = x, x
continue
}
resMin := math.MaxInt
resMax := math.MinInt
if i > 0 {
mn, mx := fMin[j], fMax[j]
resMin = min(mn*x, mx*x)
resMax = max(mn*x, mx*x)
}
if j > 0 {
mn, mx := fMin[j-1], fMax[j-1]
resMin = min(resMin, mn*x, mx*x)
resMax = max(resMax, mn*x, mx*x)
}
fMin[j] = resMin
fMax[j] = resMax
}
}
ans := fMax[n-1]
if ans < 0 {
return -1
}
return ans % 1_000_000_007
}
见下面动态规划题单的「二、网格图 DP」。
欢迎关注 B站@灵茶山艾府
枚举 $\textit{mat}$ 旋转 $0,1,2,3$ 次,判断旋转后的 $\textit{mat}$ 是否等于 $\textit{target}$。
class Solution:
# 48. 旋转图像
def rotate(self, matrix: List[List[int]]) -> None:
n = len(matrix)
for i, row in enumerate(matrix):
for j in range(i + 1, n): # 遍历对角线上方元素,做转置
row[j], matrix[j][i] = matrix[j][i], row[j]
row.reverse() # 行翻转
def findRotation(self, mat: List[List[int]], target: List[List[int]]) -> bool:
for _ in range(4):
if mat == target:
return True
self.rotate(mat)
return False
class Solution {
public boolean findRotation(int[][] mat, int[][] target) {
for (int i = 0; i < 4; i++) {
if (Arrays.deepEquals(mat, target)) {
return true;
}
rotate(mat);
}
return false;
}
// 48. 旋转图像
public void rotate(int[][] matrix) {
int n = matrix.length;
for (int i = 0; i < n; i++) {
int[] row = matrix[i];
for (int j = i + 1; j < n; j++) { // 遍历对角线上方元素,做转置
int tmp = row[j];
row[j] = matrix[j][i];
matrix[j][i] = tmp;
}
for (int j = 0; j < n / 2; j++) { // 遍历左半元素,做行翻转
int tmp = row[j];
row[j] = row[n - 1 - j];
row[n - 1 - j] = tmp;
}
}
}
}
class Solution {
// 48. 旋转图像
void rotate(vector<vector<int>>& matrix) {
int n = matrix.size();
for (int i = 0; i < n; i++) {
for (int j = i + 1; j < n; j++) { // 遍历对角线上方元素,做转置
swap(matrix[i][j], matrix[j][i]);
}
ranges::reverse(matrix[i]); // 行翻转
}
}
public:
bool findRotation(vector<vector<int>>& mat, vector<vector<int>>& target) {
for (int i = 0; i < 4; i++) {
if (mat == target) {
return true;
}
rotate(mat);
}
return false;
}
};
// 48. 旋转图像
func rotate(matrix [][]int) {
n := len(matrix)
for i, row := range matrix {
for j := i + 1; j < n; j++ { // 遍历对角线上方元素,做转置
row[j], matrix[j][i] = matrix[j][i], row[j]
}
slices.Reverse(row) // 行翻转
}
}
func findRotation(mat, target [][]int) bool {
for range 4 {
if slices.EqualFunc(mat, target, slices.Equal[[]int]) {
return true
}
rotate(mat)
}
return false
}
顺时针旋转 $90^\circ$ 后,位于 $(i,j)$ 的元素去哪了?
根据 48 题 我的题解,结论如下:
$$
(i,j)\xrightarrow{旋转\ 90^\circ} (j,n-1-i) \xrightarrow{旋转\ 90^\circ} (n-1-i,n-1-j) \xrightarrow{旋转\ 90^\circ} (n-1-j,i)
$$
所以对于 $\textit{mat}[i][j]$,它需要比较四个位置上的值:
如果对于某个旋转次数,所有的比较都为真,那么返回 $\texttt{true}$。否则返回 $\texttt{false}$。
class Solution:
def findRotation(self, mat: List[List[int]], target: List[List[int]]) -> bool:
ok = (1 << 4) - 1 # ok = [True] * 4
for i, row in enumerate(mat):
for j, x in enumerate(row):
if x != target[i][j]:
ok &= ~(1 << 0) # ok[0] = False
if x != target[j][-1 - i]:
ok &= ~(1 << 1) # ok[1] = False
if x != target[-1 - i][-1 - j]:
ok &= ~(1 << 2) # ok[2] = False
if x != target[-1 - j][i]:
ok &= ~(1 << 3) # ok[3] = False
if ok == 0: # 所有的 ok[i] 都是 False
return False
return True # 至少有一个 ok[i] 是 True
class Solution {
public boolean findRotation(int[][] mat, int[][] target) {
int n = mat.length;
int ok = (1 << 4) - 1; // boolean[] ok = {true, true, true, true};
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
int x = mat[i][j];
if (x != target[i][j]) {
ok &= ~(1 << 0); // ok[0] = false
}
if (x != target[j][n - 1 - i]) {
ok &= ~(1 << 1); // ok[1] = false
}
if (x != target[n - 1 - i][n - 1 - j]) {
ok &= ~(1 << 2); // ok[2] = false
}
if (x != target[n - 1 - j][i]) {
ok &= ~(1 << 3); // ok[3] = false
}
if (ok == 0) { // 所有的 ok[i] 都是 false
return false;
}
}
}
return true; // 至少有一个 ok[i] 是 true
}
}
class Solution {
public:
bool findRotation(vector<vector<int>>& mat, vector<vector<int>>& target) {
int n = mat.size();
int ok = (1 << 4) - 1; // bool ok[4] = {true, true, true, true}
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
int x = mat[i][j];
if (x != target[i][j]) {
ok &= ~(1 << 0); // ok[0] = false
}
if (x != target[j][n - 1 - i]) {
ok &= ~(1 << 1); // ok[1] = false
}
if (x != target[n - 1 - i][n - 1 - j]) {
ok &= ~(1 << 2); // ok[2] = false
}
if (x != target[n - 1 - j][i]) {
ok &= ~(1 << 3); // ok[3] = false
}
if (ok == 0) { // 所有的 ok[i] 都是 false
return false;
}
}
}
return true; // 至少有一个 ok[i] 是 true
}
};
func findRotation(mat, target [][]int) bool {
n := len(mat)
ok := 1<<4 - 1 // ok := [4]bool{true, true, true, true}
for i, row := range mat {
for j, x := range row {
if x != target[i][j] {
ok &^= 1 << 0 // ok[0] = false
}
if x != target[j][n-1-i] {
ok &^= 1 << 1 // ok[1] = false
}
if x != target[n-1-i][n-1-j] {
ok &^= 1 << 2 // ok[2] = false
}
if x != target[n-1-j][i] {
ok &^= 1 << 3 // ok[3] = false
}
if ok == 0 { // 所有的 ok[i] 都是 false
return false
}
}
}
return true // 至少有一个 ok[i] 是 true
}
欢迎关注 B站@灵茶山艾府
根据题意,交换的范围是行号 $[x,x+k-1]$,列号 $[y,y+k-1]$。
类似 344. 反转字符串,用双指针实现:
具体请看 视频讲解,欢迎点赞关注~
###py
class Solution:
def reverseSubmatrix(self, grid: List[List[int]], x: int, y: int, k: int) -> List[List[int]]:
l, r = x, x + k - 1
while l < r:
for j in range(y, y + k):
grid[l][j], grid[r][j] = grid[r][j], grid[l][j]
l += 1
r -= 1
return grid
###py
class Solution:
def reverseSubmatrix(self, grid: List[List[int]], x: int, y: int, k: int) -> List[List[int]]:
l, r = x, x + k - 1
while l < r:
grid[l][y: y + k], grid[r][y: y + k] = grid[r][y: y + k], grid[l][y: y + k]
l += 1
r -= 1
return grid
###java
class Solution {
public int[][] reverseSubmatrix(int[][] grid, int x, int y, int k) {
int l = x;
int r = x + k - 1;
while (l < r) {
for (int j = y; j < y + k; j++) {
int tmp = grid[l][j];
grid[l][j] = grid[r][j];
grid[r][j] = tmp;
}
l++;
r--;
}
return grid;
}
}
###cpp
class Solution {
public:
vector<vector<int>> reverseSubmatrix(vector<vector<int>>& grid, int x, int y, int k) {
int l = x, r = x + k - 1;
while (l < r) {
for (int j = y; j < y + k; j++) {
swap(grid[l][j], grid[r][j]);
}
l++;
r--;
}
return grid;
}
};
###go
func reverseSubmatrix(grid [][]int, x, y, k int) [][]int {
l, r := x, x+k-1
for l < r {
for j := y; j < y+k; j++ {
grid[l][j], grid[r][j] = grid[r][j], grid[l][j]
}
l++
r--
}
return grid
}
暴力枚举所有子矩形。把子矩形中的所有元素添加到一个数组 $a$ 中,然后把 $a$ 排序。排序后,不同元素之差的最小值一定来自 $a$ 的相邻元素,计算相邻不同元素之差的最小值。
本题视频讲解,欢迎点赞关注~
###py
class Solution:
def minAbsDiff(self, grid: List[List[int]], k: int) -> List[List[int]]:
m, n = len(grid), len(grid[0])
ans = [[0] * (n - k + 1) for _ in range(m - k + 1)]
for i in range(m - k + 1):
sub_grid = grid[i: i + k]
for j in range(n - k + 1):
a = []
for row in sub_grid:
a += row[j: j + k]
a.sort()
res = inf
for x, y in pairwise(a):
if x < y: # 题目要求相减的两个数必须不同
res = min(res, y - x)
if res < inf:
ans[i][j] = res
return ans
###java
class Solution {
public int[][] minAbsDiff(int[][] grid, int k) {
int m = grid.length;
int n = grid[0].length;
int[][] ans = new int[m - k + 1][n - k + 1];
int[] a = new int[k * k];
for (int i = 0; i <= m - k; i++) {
for (int j = 0; j <= n - k; j++) {
int idx = 0;
for (int x = 0; x < k; x++) {
for (int y = 0; y < k; y++) {
a[idx++] = grid[i + x][j + y];
}
}
Arrays.sort(a);
int res = Integer.MAX_VALUE;
for (int p = 1; p < a.length; p++) {
if (a[p] > a[p - 1]) { // 题目要求相减的两个数必须不同
res = Math.min(res, a[p] - a[p - 1]);
}
}
if (res < Integer.MAX_VALUE) {
ans[i][j] = res;
}
}
}
return ans;
}
}
###cpp
class Solution {
public:
vector<vector<int>> minAbsDiff(vector<vector<int>>& grid, int k) {
int m = grid.size(), n = grid[0].size();
vector ans(m - k + 1, vector<int>(n - k + 1));
for (int i = 0; i <= m - k; i++) {
for (int j = 0; j <= n - k; j++) {
vector<int> a;
for (int x = 0; x < k; x++) {
for (int y = 0; y < k; y++) {
a.push_back(grid[i + x][j + y]);
}
}
ranges::sort(a);
int res = INT_MAX;
for (int p = 1; p < a.size(); p++) {
if (a[p] > a[p - 1]) { // 题目要求相减的两个数必须不同
res = min(res, a[p] - a[p - 1]);
}
}
if (res < INT_MAX) {
ans[i][j] = res;
}
}
}
return ans;
}
};
###go
func minAbsDiff(grid [][]int, k int) [][]int {
m, n := len(grid), len(grid[0])
ans := make([][]int, m-k+1)
arr := make([]int, k*k)
for i := range ans {
ans[i] = make([]int, n-k+1)
for j := range ans[i] {
a := arr[:0] // 避免反复 make
for _, row := range grid[i : i+k] {
a = append(a, row[j:j+k]...)
}
slices.Sort(a)
res := math.MaxInt
for p := 1; p < len(a); p++ {
if a[p] > a[p-1] { // 题目要求相减的两个数必须不同
res = min(res, a[p]-a[p-1])
}
}
if res < math.MaxInt {
ans[i][j] = res
}
}
}
return ans
}
注:考虑用定长滑动窗口 + 有序集合 + 懒删除堆,用有序集合维护窗口(子矩阵)元素,用懒删除堆维护相邻不同元素之差。添加删除的时候更新相邻不同元素之差。
这样可以做到 $\mathcal{O}((m-k)nk\log k)$,但常数比较大。
前置知识:【图解】二维前缀和。
本题相当于统计有多少个二维前缀和 $\le k$。
###py
class Solution:
def countSubmatrices(self, grid: List[List[int]], k: int) -> int:
m, n = len(grid), len(grid[0])
s = [[0] * (n + 1) for _ in range(m + 1)]
ans = 0
for i, row in enumerate(grid):
for j, x in enumerate(row):
s[i + 1][j + 1] = s[i + 1][j] + s[i][j + 1] - s[i][j] + x
if s[i + 1][j + 1] <= k:
ans += 1
return ans
###java
class Solution {
public int countSubmatrices(int[][] grid, int k) {
int m = grid.length;
int n = grid[0].length;
int[][] sum = new int[m + 1][n + 1];
int ans = 0;
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
sum[i + 1][j + 1] = sum[i + 1][j] + sum[i][j + 1] - sum[i][j] + grid[i][j];
if (sum[i + 1][j + 1] <= k) {
ans++;
}
}
}
return ans;
}
}
###cpp
class Solution {
public:
int countSubmatrices(vector<vector<int>>& grid, int k) {
int m = grid.size(), n = grid[0].size();
vector sum(m + 1, vector<int>(n + 1));
int ans = 0;
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
sum[i + 1][j + 1] = sum[i + 1][j] + sum[i][j + 1] - sum[i][j] + grid[i][j];
ans += sum[i + 1][j + 1] <= k;
}
}
return ans;
}
};
###go
func countSubmatrices(grid [][]int, k int) (ans int) {
m, n := len(grid), len(grid[0])
sum := make([][]int, m+1)
sum[0] = make([]int, n+1)
for i, row := range grid {
sum[i+1] = make([]int, n+1)
for j, x := range row {
sum[i+1][j+1] = sum[i+1][j] + sum[i][j+1] - sum[i][j] + x
if sum[i+1][j+1] <= k {
ans++
}
}
}
return
}
注:如果原地计算二维前缀和,可以做到 $\mathcal{O}(1)$ 额外空间。
遍历每一行,同时用一个长为 $n$ 的数组 $\textit{colSum}$ 维护每一列的元素和。
遍历当前行时,一边更新 $\textit{colSum}[j]$,一边累加 $\textit{colSum}[j]$ 到变量 $s$ 中。
如果 $s\le k$ 则把答案加一,否则可以跳出循环(因为矩阵元素都非负)。
###py
class Solution:
def countSubmatrices(self, grid: List[List[int]], k: int) -> int:
col_sum = [0] * len(grid[0])
ans = 0
for row in grid:
s = 0
for j, x in enumerate(row):
col_sum[j] += x
s += col_sum[j]
if s > k:
break
ans += 1
return ans
###java
class Solution {
public int countSubmatrices(int[][] grid, int k) {
int n = grid[0].length;
int[] colSum = new int[n];
int ans = 0;
for (int[] row : grid) {
int s = 0;
for (int j = 0; j < n; j++) {
colSum[j] += row[j];
s += colSum[j];
if (s > k) {
break;
}
ans++;
}
}
return ans;
}
}
###cpp
class Solution {
public:
int countSubmatrices(vector<vector<int>>& grid, int k) {
int n = grid[0].size();
vector<int> col_sum(n);
int ans = 0;
for (auto& row : grid) {
int s = 0;
for (int j = 0; j < n; j++) {
col_sum[j] += row[j];
s += col_sum[j];
if (s > k) {
break;
}
ans++;
}
}
return ans;
}
};
###go
func countSubmatrices(grid [][]int, k int) (ans int) {
colSum := make([]int, len(grid[0]))
for _, row := range grid {
s := 0
for j, x := range row {
colSum[j] += x
s += colSum[j]
if s > k {
break
}
ans++
}
}
return
}
注:如果把每列元素和保存到 $\textit{grid}$ 的第一行,可以做到 $\mathcal{O}(1)$ 额外空间。
欢迎关注 B站@灵茶山艾府
做法类似 85. 最大矩形,枚举子矩形的底边(最后一行),定义 $\textit{heights}[j]$ 表示从 $\textit{matrix}[i][j]$ 往上有多少个连续的 $1$(柱子的高度),问题变成:
对于示例 1,以第三行为底边算出来的 $\textit{heights} = [2,0,3]$,下图重排后是 $[2,3,0]$。其中子数组 $[2,3]$,长为 $2$,最小值为 $2$,所以对应的子矩形面积为 $2\times 2 = 4$。
{:width=430px}
如何找到面积最大的子矩形?还是枚举。
枚举子数组的长度 $k = 1,2,\ldots,n$。由于我们可以重排 $\textit{heights}$,那么贪心地,把 $\textit{heights}$ 最大的 $k$ 个数排在一起,就可以让子数组的最小值(矩形的高)尽量大,从而得到最大的矩形面积。
对于 $\textit{heights}$ 的计算,如果 $\textit{matrix}[i][j]=0$,那么 $\textit{heights}[j] = 0$。否则,把高度增加 $1$。形象地说,就是在柱子下面垫一块石头,把柱子抬高。
###py
class Solution:
def largestSubmatrix(self, matrix: List[List[int]]) -> int:
n = len(matrix[0])
heights = [0] * n
ans = 0
for row in matrix: # 枚举子矩形的底边
for j, x in enumerate(row):
if x == 0:
heights[j] = 0
else:
heights[j] += 1
hs = sorted(heights) # 复制一份 heights 再排序
for i, h in enumerate(hs): # 把 hs[i:] 作为子数组
# 子数组长为 n-i,最小值为 h,对应的子矩形面积为 (n-i)*h
ans = max(ans, (n - i) * h)
return ans
###java
class Solution {
public int largestSubmatrix(int[][] matrix) {
int n = matrix[0].length;
int[] heights = new int[n];
int ans = 0;
for (int[] row : matrix) { // 枚举子矩形的底边
for (int j = 0; j < n; j++) {
if (row[j] == 0) {
heights[j] = 0;
} else {
heights[j]++;
}
}
int[] hs = heights.clone();
Arrays.sort(hs);
for (int i = 0; i < n; i++) { // 把 [i,n-1] 作为子数组
// 子数组长为 n-i,最小值为 hs[i],对应的子矩形面积为 (n-i)*hs[i]
ans = Math.max(ans, (n - i) * hs[i]);
}
}
return ans;
}
}
###cpp
class Solution {
public:
int largestSubmatrix(vector<vector<int>>& matrix) {
int n = matrix[0].size();
vector<int> heights(n);
int ans = 0;
for (auto& row : matrix) { // 枚举子矩形的底边
for (int j = 0; j < n; j++) {
int x = row[j];
if (x == 0) {
heights[j] = 0;
} else {
heights[j]++;
}
}
auto hs = heights;
ranges::sort(hs);
for (int i = 0; i < n; i++) { // 把 [i,n-1] 作为子数组
// 子数组长为 n-i,最小值为 hs[i],对应的子矩形面积为 (n-i)*hs[i]
ans = max(ans, (n - i) * hs[i]);
}
}
return ans;
}
};
###go
func largestSubmatrix(matrix [][]int) (ans int) {
n := len(matrix[0])
heights := make([]int, n)
for _, row := range matrix { // 枚举子矩形的底边
for j, x := range row {
if x == 0 {
heights[j] = 0
} else {
heights[j]++
}
}
hs := slices.Clone(heights)
slices.Sort(hs)
for i, h := range hs { // 把 hs[i:] 作为子数组
ans = max(ans, (n-i)*h) // 子数组长为 n-i,最小值为 h,对应的子矩形面积为 (n-i)*h
}
}
return
}
考察从 $i-1$ 行到 $i$ 行,$\textit{heights}$ 会如何变化:
举个例子。假设 $i-1$ 行的 $\textit{heights}$ 排序后是 $[0,{\color{red}0},{\color{red}0},1,{\color{red}2},{\color{red}3}]$,把红色数字加一,其余数字变成 $0$,得到 $[0,{\color{red}1},{\color{red}1},0,{\color{red}3},{\color{red}4}]$。把 $0$ 排在红色数字前面,得到 $[0,0,{\color{red}1},{\color{red}1},{\color{red}3},{\color{red}4}]$。注意红色数字的相对大小是不变的,无需再次排序。
一般地,如果已知 $i-1$ 行的 $\textit{heights}$ 排序后的结果,那么对于 $i$ 行,我们只需把高度变成 $0$ 的数据排在前面,其余(增加一的)高度的相对大小不变,无需再次排序。这样就可以把排序的时间从 $\mathcal{O}(n\log n)$ 优化成 $\mathcal{O}(n)$。
但是,如果直接对 $\textit{heights}$ 排序,我们就不知道每个高度对应矩阵的哪一列了。如何解决?创建一个 $0$ 到 $n-1$ 的下标数组(列号数组)$\textit{idx}$,对下标数组排序。
###py
class Solution:
def largestSubmatrix(self, matrix: List[List[int]]) -> int:
n = len(matrix[0])
heights = [0] * n
idx = list(range(n)) # 按照高度排序后的列号
ans = 0
for row in matrix:
zeros = []
non_zeros = []
for j in idx:
if row[j] == 0:
heights[j] = 0
zeros.append(j)
else:
heights[j] += 1
non_zeros.append(j)
idx = zeros + non_zeros # 把高度为 0 的列号排在其他高度前面
# heights[idx[i]] 是递增的
for i in range(len(zeros), n): # 高度 0 无需计算
ans = max(ans, (n - i) * heights[idx[i]])
return ans
###java
class Solution {
public int largestSubmatrix(int[][] matrix) {
int n = matrix[0].length;
int[] heights = new int[n];
int[] idx = new int[n]; // 按照高度排序后的列号
for (int i = 0; i < n; i++) {
idx[i] = i;
}
int[] nonZeros = new int[n]; // 避免在循环内反复申请内存
int ans = 0;
for (int[] row : matrix) {
int p = 0;
int q = 0;
for (int j : idx) {
if (row[j] == 0) {
heights[j] = 0;
idx[p++] = j; // 高度 0 排在前面
} else {
heights[j]++;
nonZeros[q++] = j;
}
}
// heights[idx[i]] 是递增的
for (int i = p; i < n; i++) { // 高度 0 无需计算
idx[i] = nonZeros[i - p]; // 把 nonZeros 复制到 idx 的 [p,n-1] 中
ans = Math.max(ans, (n - i) * heights[idx[i]]);
}
}
return ans;
}
}
###cpp
class Solution {
public:
int largestSubmatrix(vector<vector<int>>& matrix) {
int n = matrix[0].size();
vector<int> heights(n);
vector<int> idx(n); // 按照高度排序后的列号
ranges::iota(idx, 0); // idx[i] = i
vector<int> non_zeros(n); // 避免在循环内反复申请内存
int ans = 0;
for (auto& row : matrix) {
int p = 0, q = 0;
for (int j : idx) {
if (row[j] == 0) {
heights[j] = 0;
idx[p++] = j; // 高度 0 排在前面
} else {
heights[j]++;
non_zeros[q++] = j;
}
}
// heights[idx[i]] 是递增的
for (int i = p; i < n; i++) { // 高度 0 无需计算
idx[i] = non_zeros[i - p]; // 把 non_zeros 复制到 idx 的 [p,n-1] 中
ans = max(ans, (n - i) * heights[idx[i]]);
}
}
return ans;
}
};
###go
func largestSubmatrix(matrix [][]int) (ans int) {
n := len(matrix[0])
heights := make([]int, n)
idx := make([]int, n) // 按照高度排序后的列号
for i := range idx {
idx[i] = i
}
_nonZeros := make([]int, n) // 避免在循环内反复申请内存
for _, row := range matrix {
zeros := idx[:0]
nonZeros := _nonZeros[:0]
for _, j := range idx {
if row[j] == 0 {
heights[j] = 0
zeros = append(zeros, j)
} else {
heights[j]++
nonZeros = append(nonZeros, j)
}
}
idx = append(zeros, nonZeros...) // 把高度为 0 的列号排在其他高度前面
// heights[idx[i]] 是递增的
for i := len(zeros); i < n; i++ { // 高度 0 无需计算
ans = max(ans, (n-i)*heights[idx[i]])
}
}
return
}
见下面贪心题单的「§1.6 先枚举,再贪心」。
欢迎关注 B站@灵茶山艾府