[Python3/Java/C++/Go/TypeScript] 一题一解:树形动态规划(清晰题解)
方法一:树形动态规划
对每个节点 $u$,我们维护一个二维数组 $f_u[j][pre]$,表示在以 $u$ 为根的子树中,预算不超过 $j$ 且 $u$ 的上司是否购买了股票(其中 $pre=1$ 表示购买,而 $pre=0$ 表示未购买)的情况下,可以获得的最大利润。那么答案就是 $f_1[\text{budget}][0]$。
对节点 $u$,函数 $\text{dfs}(u)$ 返回一个 $(\text{budget}+1) \times 2$ 的二维数组 $f$,表示在以 $u$ 为根的子树中,不超过预算 $j$ 且 $u$ 的上司是否购买了股票的情况下,可以获得的最大利润。
对 $u$,我们要考虑两件事:
- 节点 $u$ 本身是否买股票(会占用一部分预算 $\text{cost}$,其中 $\text{cost} = \lfloor \text{present}[u] / (pre + 1) \rfloor$)。并增加利润 $\text{future}[u] - \text{cost}$。
- 节点 $u$ 的子节点 $v$ 如何分配预算以最大化利润。我们把每个子节点的 $\text{dfs}(v)$ 看成“物品”,用背包把子树的利润合并到当前 $u$ 的 $\text{nxt}$ 数组中。
具体实现时,我们先初始化一个 $(\text{budget}+1) \times 2$ 的二维数组 $\text{nxt}$,表示当前已经合并了子节点的利润。然后对于每个子节点 $v$,我们递归调用 $\text{dfs}(v)$ 得到子节点的利润数组 $\text{fv}$,并用背包把 $\text{fv}$ 合并到 $\text{nxt}$ 中。
合并公式为:
$$
\text{nxt}[j][pre] = \max(\text{nxt}[j][pre], \text{nxt}[j - j_v][pre] + \text{fv}[j_v][pre])
$$
其中 $j_v$ 表示分配给子节点 $v$ 的预算。
合并完所有子节点后的 $\text{nxt}[j][pre]$ 表示在 $u$ 本身尚未决定是否购买股票的情况下,且 $u$ 的上次购买状态为 $pre$ 时,把预算 $j$ 全部用于子节点所能获得的最大利润。
最后,我们决定 $u$ 是否购买股票。
- 如果 $j \lt \text{cost}$,则 $u$ 无法购买股票,此时 $f[j][pre] = \text{nxt}[j][0]$。
- 如果 $j \geq \text{cost}$,则 $u$ 可以选择购买或不购买股票,此时 $f[j][pre] = \max(\text{nxt}[j][0], \text{nxt}[j - \text{cost}][1] + (\text{future}[u] - \text{cost}))$。
最后返回 $f$ 即可。
答案为 $\text{dfs}(1)[\text{budget}][0]$。
###python
class Solution:
def maxProfit(
self,
n: int,
present: List[int],
future: List[int],
hierarchy: List[List[int]],
budget: int,
) -> int:
max = lambda a, b: a if a > b else b
g = [[] for _ in range(n + 1)]
for u, v in hierarchy:
g[u].append(v)
def dfs(u: int):
nxt = [[0, 0] for _ in range(budget + 1)]
for v in g[u]:
fv = dfs(v)
for j in range(budget, -1, -1):
for jv in range(j + 1):
for pre in (0, 1):
val = nxt[j - jv][pre] + fv[jv][pre]
if val > nxt[j][pre]:
nxt[j][pre] = val
f = [[0, 0] for _ in range(budget + 1)]
price = future[u - 1]
for j in range(budget + 1):
for pre in (0, 1):
cost = present[u - 1] // (pre + 1)
if j >= cost:
f[j][pre] = max(nxt[j][0], nxt[j - cost][1] + (price - cost))
else:
f[j][pre] = nxt[j][0]
return f
return dfs(1)[budget][0]
###java
class Solution {
private List<Integer>[] g;
private int[] present;
private int[] future;
private int budget;
public int maxProfit(int n, int[] present, int[] future, int[][] hierarchy, int budget) {
this.present = present;
this.future = future;
this.budget = budget;
g = new ArrayList[n + 1];
Arrays.setAll(g, k -> new ArrayList<>());
for (int[] e : hierarchy) {
g[e[0]].add(e[1]);
}
return dfs(1)[budget][0];
}
private int[][] dfs(int u) {
int[][] nxt = new int[budget + 1][2];
for (int v : g[u]) {
int[][] fv = dfs(v);
for (int j = budget; j >= 0; j--) {
for (int jv = 0; jv <= j; jv++) {
for (int pre = 0; pre < 2; pre++) {
int val = nxt[j - jv][pre] + fv[jv][pre];
if (val > nxt[j][pre]) {
nxt[j][pre] = val;
}
}
}
}
}
int[][] f = new int[budget + 1][2];
int price = future[u - 1];
for (int j = 0; j <= budget; j++) {
for (int pre = 0; pre < 2; pre++) {
int cost = present[u - 1] / (pre + 1);
if (j >= cost) {
f[j][pre] = Math.max(nxt[j][0], nxt[j - cost][1] + (price - cost));
} else {
f[j][pre] = nxt[j][0];
}
}
}
return f;
}
}
###cpp
class Solution {
public:
int maxProfit(int n, vector<int>& present, vector<int>& future, vector<vector<int>>& hierarchy, int budget) {
vector<vector<int>> g(n + 1);
for (auto& e : hierarchy) {
g[e[0]].push_back(e[1]);
}
auto dfs = [&](const auto& dfs, int u) -> vector<array<int, 2>> {
vector<array<int, 2>> nxt(budget + 1);
for (int j = 0; j <= budget; j++) nxt[j] = {0, 0};
for (int v : g[u]) {
auto fv = dfs(dfs, v);
for (int j = budget; j >= 0; j--) {
for (int jv = 0; jv <= j; jv++) {
for (int pre = 0; pre < 2; pre++) {
int val = nxt[j - jv][pre] + fv[jv][pre];
if (val > nxt[j][pre]) {
nxt[j][pre] = val;
}
}
}
}
}
vector<array<int, 2>> f(budget + 1);
int price = future[u - 1];
for (int j = 0; j <= budget; j++) {
for (int pre = 0; pre < 2; pre++) {
int cost = present[u - 1] / (pre + 1);
if (j >= cost) {
f[j][pre] = max(nxt[j][0], nxt[j - cost][1] + (price - cost));
} else {
f[j][pre] = nxt[j][0];
}
}
}
return f;
};
return dfs(dfs, 1)[budget][0];
}
};
###go
func maxProfit(n int, present []int, future []int, hierarchy [][]int, budget int) int {
g := make([][]int, n+1)
for _, e := range hierarchy {
u, v := e[0], e[1]
g[u] = append(g[u], v)
}
var dfs func(u int) [][2]int
dfs = func(u int) [][2]int {
nxt := make([][2]int, budget+1)
for _, v := range g[u] {
fv := dfs(v)
for j := budget; j >= 0; j-- {
for jv := 0; jv <= j; jv++ {
for pre := 0; pre < 2; pre++ {
nxt[j][pre] = max(nxt[j][pre], nxt[j-jv][pre]+fv[jv][pre])
}
}
}
}
f := make([][2]int, budget+1)
price := future[u-1]
for j := 0; j <= budget; j++ {
for pre := 0; pre < 2; pre++ {
cost := present[u-1] / (pre + 1)
if j >= cost {
buyProfit := nxt[j-cost][1] + (price - cost)
f[j][pre] = max(nxt[j][0], buyProfit)
} else {
f[j][pre] = nxt[j][0]
}
}
}
return f
}
return dfs(1)[budget][0]
}
###ts
function maxProfit(
n: number,
present: number[],
future: number[],
hierarchy: number[][],
budget: number,
): number {
const g: number[][] = Array.from({ length: n + 1 }, () => []);
for (const [u, v] of hierarchy) {
g[u].push(v);
}
const dfs = (u: number): number[][] => {
const nxt: number[][] = Array.from({ length: budget + 1 }, () => [0, 0]);
for (const v of g[u]) {
const fv = dfs(v);
for (let j = budget; j >= 0; j--) {
for (let jv = 0; jv <= j; jv++) {
for (let pre = 0; pre < 2; pre++) {
nxt[j][pre] = Math.max(nxt[j][pre], nxt[j - jv][pre] + fv[jv][pre]);
}
}
}
}
const f: number[][] = Array.from({ length: budget + 1 }, () => [0, 0]);
const price = future[u - 1];
for (let j = 0; j <= budget; j++) {
for (let pre = 0; pre < 2; pre++) {
const cost = Math.floor(present[u - 1] / (pre + 1));
if (j >= cost) {
const profitIfBuy = nxt[j - cost][1] + (price - cost);
f[j][pre] = Math.max(nxt[j][0], profitIfBuy);
} else {
f[j][pre] = nxt[j][0];
}
}
}
return f;
};
return dfs(1)[budget][0];
}
时间复杂度 $O(n \times \text{budget}^2)$,空间复杂度 $O(n \times \text{budget})$。
有任何问题,欢迎评论区交流,欢迎评论区提供其它解题思路(代码),也可以点个赞支持一下作者哈 😄~