阅读视图

发现新文章,点击刷新页面。

[Python3/Java/C++/Go/TypeScript] 一题一解:有序集合(清晰题解)

方法一:有序集合

题目需要我们将数组 $\textit{nums}$ 分割成 $k$ 个连续且不相交的子数组,并且第二个子数组的第一个元素与第 $k$ 个子数组的第一个元素的下标距离不超过 $\textit{dist}$,这等价于让我们从 $\textit{nums}$ 下标为 $1$ 的元素开始,找到一个窗口大小为 $\textit{dist}+1$ 的子数组,求其中前 $k-1$ 的最小元素之和。我们将 $k$ 减 $1$,这样我们只需要求出 $k$ 个最小元素之和,再加上 $\textit{nums}[0]$ 即可。

我们可以使用两个有序集合 $\textit{l}$ 和 $\textit{r}$ 分别维护大小为 $\textit{dist} + 1$ 的窗口元素,其中 $\textit{l}$ 维护最小的 $k$ 个元素,而 $\textit{r}$ 维护窗口的剩余元素。我们维护一个变量 $\textit{s}$ 表示 $\textit{nums}[0]$ 与 $l$ 中元素之和。初始时,我们将前 $\textit{dist}+2$ 个元素之和累加到 $\textit{s}$ 中,并将下标为 $[1, \textit{dist} + 1]$ 的所有元素加入到 $\textit{l}$ 中。如果 $\textit{l}$ 的大小大于 $k$,我们循环将 $\textit{l}$ 中的最大元素移动到 $\textit{r}$ 中,直到 $\textit{l}$ 的大小等于 $k$,过程中更新 $\textit{s}$ 的值。

那么此时初始答案 $\textit{ans} = \textit{s}$。

接下来我们从 $\textit{dist}+2$ 开始遍历 $\textit{nums}$,对于每一个元素 $\textit{nums}[i]$,我们需要将 $\textit{nums}[i-\textit{dist}-1]$ 从 $\textit{l}$ 或 $\textit{r}$ 中移除,然后将 $\textit{nums}[i]$ 加入到 $\textit{l}$ 或 $\textit{r}$ 中。如果 $\textit{nums}[i]$ 小于 $\textit{l}$ 中的最大元素,我们将 $\textit{nums}[i]$ 加入到 $\textit{l}$ 中,否则加入到 $\textit{r}$ 中。如果此时 $\textit{l}$ 的大小小于 $k$,我们将 $\textit{r}$ 中的最小元素移动到 $\textit{l}$ 中,直到 $\textit{l}$ 的大小等于 $k$。如果此时 $\textit{l}$ 的大小大于 $k$,我们将 $\textit{l}$ 中的最大元素移动到 $\textit{r}$ 中,直到 $\textit{l}$ 的大小等于 $k$。过程中更新 $\textit{s}$ 的值,并更新 $\textit{ans} = \min(\textit{ans}, \textit{s})$。

最终答案即为 $\textit{ans}$。

###python

class Solution:
    def minimumCost(self, nums: List[int], k: int, dist: int) -> int:
        def l2r():
            nonlocal s
            x = l.pop()
            s -= x
            r.add(x)

        def r2l():
            nonlocal s
            x = r.pop(0)
            l.add(x)
            s += x

        k -= 1
        s = sum(nums[: dist + 2])
        l = SortedList(nums[1 : dist + 2])
        r = SortedList()
        while len(l) > k:
            l2r()
        ans = s
        for i in range(dist + 2, len(nums)):
            x = nums[i - dist - 1]
            if x in l:
                l.remove(x)
                s -= x
            else:
                r.remove(x)
            y = nums[i]
            if y < l[-1]:
                l.add(y)
                s += y
            else:
                r.add(y)
            while len(l) < k:
                r2l()
            while len(l) > k:
                l2r()
            ans = min(ans, s)
        return ans

###java

class Solution {
    private final TreeMap<Integer, Integer> l = new TreeMap<>();
    private final TreeMap<Integer, Integer> r = new TreeMap<>();
    private long s;
    private int size;

    public long minimumCost(int[] nums, int k, int dist) {
        --k;
        s = nums[0];
        for (int i = 1; i < dist + 2; ++i) {
            s += nums[i];
            l.merge(nums[i], 1, Integer::sum);
        }
        size = dist + 1;
        while (size > k) {
            l2r();
        }
        long ans = s;
        for (int i = dist + 2; i < nums.length; ++i) {
            int x = nums[i - dist - 1];
            if (l.containsKey(x)) {
                if (l.merge(x, -1, Integer::sum) == 0) {
                    l.remove(x);
                }
                s -= x;
                --size;
            } else if (r.merge(x, -1, Integer::sum) == 0) {
                r.remove(x);
            }
            int y = nums[i];
            if (y < l.lastKey()) {
                l.merge(y, 1, Integer::sum);
                ++size;
                s += y;
            } else {
                r.merge(y, 1, Integer::sum);
            }
            while (size < k) {
                r2l();
            }
            while (size > k) {
                l2r();
            }
            ans = Math.min(ans, s);
        }
        return ans;
    }

    private void l2r() {
        int x = l.lastKey();
        s -= x;
        if (l.merge(x, -1, Integer::sum) == 0) {
            l.remove(x);
        }
        --size;
        r.merge(x, 1, Integer::sum);
    }

    private void r2l() {
        int x = r.firstKey();
        if (r.merge(x, -1, Integer::sum) == 0) {
            r.remove(x);
        }
        l.merge(x, 1, Integer::sum);
        s += x;
        ++size;
    }
}

###cpp

class Solution {
public:
    long long minimumCost(vector<int>& nums, int k, int dist) {
        --k;
        multiset<int> l(nums.begin() + 1, nums.begin() + dist + 2), r;
        long long s = accumulate(nums.begin(), nums.begin() + dist + 2, 0LL);
        while (l.size() > k) {
            int x = *l.rbegin();
            l.erase(l.find(x));
            s -= x;
            r.insert(x);
        }
        long long ans = s;
        for (int i = dist + 2; i < nums.size(); ++i) {
            int x = nums[i - dist - 1];
            auto it = l.find(x);
            if (it != l.end()) {
                l.erase(it);
                s -= x;
            } else {
                r.erase(r.find(x));
            }
            int y = nums[i];
            if (y < *l.rbegin()) {
                l.insert(y);
                s += y;
            } else {
                r.insert(y);
            }
            while (l.size() == k - 1) {
                int x = *r.begin();
                r.erase(r.find(x));
                l.insert(x);
                s += x;
            }
            while (l.size() == k + 1) {
                int x = *l.rbegin();
                l.erase(l.find(x));
                s -= x;
                r.insert(x);
            }
            ans = min(ans, s);
        }
        return ans;
    }
};

###go

func minimumCost(nums []int, k int, dist int) int64 {
k--
l := redblacktree.New[int, int]()
r := redblacktree.New[int, int]()
merge := func(st *redblacktree.Tree[int, int], x, v int) {
c, _ := st.Get(x)
if c+v == 0 {
st.Remove(x)
} else {
st.Put(x, c+v)
}
}

s := nums[0]
for _, x := range nums[1 : dist+2] {
s += x
merge(l, x, 1)
}
size := dist + 1

l2r := func() {
x := l.Right().Key
merge(l, x, -1)
s -= x
size--
merge(r, x, 1)
}
r2l := func() {
x := r.Left().Key
merge(r, x, -1)
merge(l, x, 1)
s += x
size++
}

for size > k {
l2r()
}

ans := s
for i := dist + 2; i < len(nums); i++ {
x := nums[i-dist-1]
if _, ok := l.Get(x); ok {
merge(l, x, -1)
s -= x
size--
} else {
merge(r, x, -1)
}
y := nums[i]
if y < l.Right().Key {
merge(l, y, 1)
s += y
size++
} else {
merge(r, y, 1)
}
for size < k {
r2l()
}
for size > k {
l2r()
}
ans = min(ans, s)

}
return int64(ans)
}

###ts

function minimumCost(nums: number[], k: number, dist: number): number {
    --k;
    const l = new TreapMultiSet<number>((a, b) => a - b);
    const r = new TreapMultiSet<number>((a, b) => a - b);
    let s = nums[0];
    for (let i = 1; i < dist + 2; ++i) {
        s += nums[i];
        l.add(nums[i]);
    }
    const l2r = () => {
        const x = l.pop()!;
        s -= x;
        r.add(x);
    };
    const r2l = () => {
        const x = r.shift()!;
        l.add(x);
        s += x;
    };
    while (l.size > k) {
        l2r();
    }
    let ans = s;
    for (let i = dist + 2; i < nums.length; ++i) {
        const x = nums[i - dist - 1];
        if (l.has(x)) {
            l.delete(x);
            s -= x;
        } else {
            r.delete(x);
        }
        const y = nums[i];
        if (y < l.last()!) {
            l.add(y);
            s += y;
        } else {
            r.add(y);
        }
        while (l.size < k) {
            r2l();
        }
        while (l.size > k) {
            l2r();
        }
        ans = Math.min(ans, s);
    }
    return ans;
}

type CompareFunction<T, R extends 'number' | 'boolean'> = (
    a: T,
    b: T,
) => R extends 'number' ? number : boolean;

interface ITreapMultiSet<T> extends Iterable<T> {
    add: (...value: T[]) => this;
    has: (value: T) => boolean;
    delete: (value: T) => void;

    bisectLeft: (value: T) => number;
    bisectRight: (value: T) => number;

    indexOf: (value: T) => number;
    lastIndexOf: (value: T) => number;

    at: (index: number) => T | undefined;
    first: () => T | undefined;
    last: () => T | undefined;

    lower: (value: T) => T | undefined;
    higher: (value: T) => T | undefined;
    floor: (value: T) => T | undefined;
    ceil: (value: T) => T | undefined;

    shift: () => T | undefined;
    pop: (index?: number) => T | undefined;

    count: (value: T) => number;

    keys: () => IterableIterator<T>;
    values: () => IterableIterator<T>;
    rvalues: () => IterableIterator<T>;
    entries: () => IterableIterator<[number, T]>;

    readonly size: number;
}

class TreapNode<T = number> {
    value: T;
    count: number;
    size: number;
    priority: number;
    left: TreapNode<T> | null;
    right: TreapNode<T> | null;

    constructor(value: T) {
        this.value = value;
        this.count = 1;
        this.size = 1;
        this.priority = Math.random();
        this.left = null;
        this.right = null;
    }

    static getSize(node: TreapNode<any> | null): number {
        return node?.size ?? 0;
    }

    static getFac(node: TreapNode<any> | null): number {
        return node?.priority ?? 0;
    }

    pushUp(): void {
        let tmp = this.count;
        tmp += TreapNode.getSize(this.left);
        tmp += TreapNode.getSize(this.right);
        this.size = tmp;
    }

    rotateRight(): TreapNode<T> {
        // eslint-disable-next-line @typescript-eslint/no-this-alias
        let node: TreapNode<T> = this;
        const left = node.left;
        node.left = left?.right ?? null;
        left && (left.right = node);
        left && (node = left);
        node.right?.pushUp();
        node.pushUp();
        return node;
    }

    rotateLeft(): TreapNode<T> {
        // eslint-disable-next-line @typescript-eslint/no-this-alias
        let node: TreapNode<T> = this;
        const right = node.right;
        node.right = right?.left ?? null;
        right && (right.left = node);
        right && (node = right);
        node.left?.pushUp();
        node.pushUp();
        return node;
    }
}

class TreapMultiSet<T = number> implements ITreapMultiSet<T> {
    private readonly root: TreapNode<T>;
    private readonly compareFn: CompareFunction<T, 'number'>;
    private readonly leftBound: T;
    private readonly rightBound: T;

    constructor(compareFn?: CompareFunction<T, 'number'>);
    constructor(compareFn: CompareFunction<T, 'number'>, leftBound: T, rightBound: T);
    constructor(
        compareFn: CompareFunction<T, any> = (a: any, b: any) => a - b,
        leftBound: any = -Infinity,
        rightBound: any = Infinity,
    ) {
        this.root = new TreapNode<T>(rightBound);
        this.root.priority = Infinity;
        this.root.left = new TreapNode<T>(leftBound);
        this.root.left.priority = -Infinity;
        this.root.pushUp();

        this.leftBound = leftBound;
        this.rightBound = rightBound;
        this.compareFn = compareFn;
    }

    get size(): number {
        return this.root.size - 2;
    }

    get height(): number {
        const getHeight = (node: TreapNode<T> | null): number => {
            if (node == null) return 0;
            return 1 + Math.max(getHeight(node.left), getHeight(node.right));
        };

        return getHeight(this.root);
    }

    /**
     *
     * @complexity `O(logn)`
     * @description Returns true if value is a member.
     */
    has(value: T): boolean {
        const compare = this.compareFn;
        const dfs = (node: TreapNode<T> | null, value: T): boolean => {
            if (node == null) return false;
            if (compare(node.value, value) === 0) return true;
            if (compare(node.value, value) < 0) return dfs(node.right, value);
            return dfs(node.left, value);
        };

        return dfs(this.root, value);
    }

    /**
     *
     * @complexity `O(logn)`
     * @description Add value to sorted set.
     */
    add(...values: T[]): this {
        const compare = this.compareFn;
        const dfs = (
            node: TreapNode<T> | null,
            value: T,
            parent: TreapNode<T>,
            direction: 'left' | 'right',
        ): void => {
            if (node == null) return;
            if (compare(node.value, value) === 0) {
                node.count++;
                node.pushUp();
            } else if (compare(node.value, value) > 0) {
                if (node.left) {
                    dfs(node.left, value, node, 'left');
                } else {
                    node.left = new TreapNode(value);
                    node.pushUp();
                }

                if (TreapNode.getFac(node.left) > node.priority) {
                    parent[direction] = node.rotateRight();
                }
            } else if (compare(node.value, value) < 0) {
                if (node.right) {
                    dfs(node.right, value, node, 'right');
                } else {
                    node.right = new TreapNode(value);
                    node.pushUp();
                }

                if (TreapNode.getFac(node.right) > node.priority) {
                    parent[direction] = node.rotateLeft();
                }
            }
            parent.pushUp();
        };

        values.forEach(value => dfs(this.root.left, value, this.root, 'left'));
        return this;
    }

    /**
     *
     * @complexity `O(logn)`
     * @description Remove value from sorted set if it is a member.
     * If value is not a member, do nothing.
     */
    delete(value: T): void {
        const compare = this.compareFn;
        const dfs = (
            node: TreapNode<T> | null,
            value: T,
            parent: TreapNode<T>,
            direction: 'left' | 'right',
        ): void => {
            if (node == null) return;

            if (compare(node.value, value) === 0) {
                if (node.count > 1) {
                    node.count--;
                    node?.pushUp();
                } else if (node.left == null && node.right == null) {
                    parent[direction] = null;
                } else {
                    // 旋到根节点
                    if (
                        node.right == null ||
                        TreapNode.getFac(node.left) > TreapNode.getFac(node.right)
                    ) {
                        parent[direction] = node.rotateRight();
                        dfs(parent[direction]?.right ?? null, value, parent[direction]!, 'right');
                    } else {
                        parent[direction] = node.rotateLeft();
                        dfs(parent[direction]?.left ?? null, value, parent[direction]!, 'left');
                    }
                }
            } else if (compare(node.value, value) > 0) {
                dfs(node.left, value, node, 'left');
            } else if (compare(node.value, value) < 0) {
                dfs(node.right, value, node, 'right');
            }

            parent?.pushUp();
        };

        dfs(this.root.left, value, this.root, 'left');
    }

    /**
     *
     * @complexity `O(logn)`
     * @description Returns an index to insert value in the sorted set.
     * If the value is already present, the insertion point will be before (to the left of) any existing values.
     */
    bisectLeft(value: T): number {
        const compare = this.compareFn;
        const dfs = (node: TreapNode<T> | null, value: T): number => {
            if (node == null) return 0;

            if (compare(node.value, value) === 0) {
                return TreapNode.getSize(node.left);
            } else if (compare(node.value, value) > 0) {
                return dfs(node.left, value);
            } else if (compare(node.value, value) < 0) {
                return dfs(node.right, value) + TreapNode.getSize(node.left) + node.count;
            }

            return 0;
        };

        return dfs(this.root, value) - 1;
    }

    /**
     *
     * @complexity `O(logn)`
     * @description Returns an index to insert value in the sorted set.
     * If the value is already present, the insertion point will be before (to the right of) any existing values.
     */
    bisectRight(value: T): number {
        const compare = this.compareFn;
        const dfs = (node: TreapNode<T> | null, value: T): number => {
            if (node == null) return 0;

            if (compare(node.value, value) === 0) {
                return TreapNode.getSize(node.left) + node.count;
            } else if (compare(node.value, value) > 0) {
                return dfs(node.left, value);
            } else if (compare(node.value, value) < 0) {
                return dfs(node.right, value) + TreapNode.getSize(node.left) + node.count;
            }

            return 0;
        };
        return dfs(this.root, value) - 1;
    }

    /**
     *
     * @complexity `O(logn)`
     * @description Returns the index of the first occurrence of a value in the set, or -1 if it is not present.
     */
    indexOf(value: T): number {
        const compare = this.compareFn;
        let isExist = false;

        const dfs = (node: TreapNode<T> | null, value: T): number => {
            if (node == null) return 0;

            if (compare(node.value, value) === 0) {
                isExist = true;
                return TreapNode.getSize(node.left);
            } else if (compare(node.value, value) > 0) {
                return dfs(node.left, value);
            } else if (compare(node.value, value) < 0) {
                return dfs(node.right, value) + TreapNode.getSize(node.left) + node.count;
            }

            return 0;
        };
        const res = dfs(this.root, value) - 1;
        return isExist ? res : -1;
    }

    /**
     *
     * @complexity `O(logn)`
     * @description Returns the index of the last occurrence of a value in the set, or -1 if it is not present.
     */
    lastIndexOf(value: T): number {
        const compare = this.compareFn;
        let isExist = false;

        const dfs = (node: TreapNode<T> | null, value: T): number => {
            if (node == null) return 0;

            if (compare(node.value, value) === 0) {
                isExist = true;
                return TreapNode.getSize(node.left) + node.count - 1;
            } else if (compare(node.value, value) > 0) {
                return dfs(node.left, value);
            } else if (compare(node.value, value) < 0) {
                return dfs(node.right, value) + TreapNode.getSize(node.left) + node.count;
            }

            return 0;
        };

        const res = dfs(this.root, value) - 1;
        return isExist ? res : -1;
    }

    /**
     *
     * @complexity `O(logn)`
     * @description Returns the item located at the specified index.
     * @param index The zero-based index of the desired code unit. A negative index will count back from the last item.
     */
    at(index: number): T | undefined {
        if (index < 0) index += this.size;
        if (index < 0 || index >= this.size) return undefined;

        const dfs = (node: TreapNode<T> | null, rank: number): T | undefined => {
            if (node == null) return undefined;

            if (TreapNode.getSize(node.left) >= rank) {
                return dfs(node.left, rank);
            } else if (TreapNode.getSize(node.left) + node.count >= rank) {
                return node.value;
            } else {
                return dfs(node.right, rank - TreapNode.getSize(node.left) - node.count);
            }
        };

        const res = dfs(this.root, index + 2);
        return ([this.leftBound, this.rightBound] as any[]).includes(res) ? undefined : res;
    }

    /**
     *
     * @complexity `O(logn)`
     * @description Find and return the element less than `val`, return `undefined` if no such element found.
     */
    lower(value: T): T | undefined {
        const compare = this.compareFn;
        const dfs = (node: TreapNode<T> | null, value: T): T | undefined => {
            if (node == null) return undefined;
            if (compare(node.value, value) >= 0) return dfs(node.left, value);

            const tmp = dfs(node.right, value);
            if (tmp == null || compare(node.value, tmp) > 0) {
                return node.value;
            } else {
                return tmp;
            }
        };

        const res = dfs(this.root, value) as any;
        return res === this.leftBound ? undefined : res;
    }

    /**
     *
     * @complexity `O(logn)`
     * @description Find and return the element greater than `val`, return `undefined` if no such element found.
     */
    higher(value: T): T | undefined {
        const compare = this.compareFn;
        const dfs = (node: TreapNode<T> | null, value: T): T | undefined => {
            if (node == null) return undefined;
            if (compare(node.value, value) <= 0) return dfs(node.right, value);

            const tmp = dfs(node.left, value);

            if (tmp == null || compare(node.value, tmp) < 0) {
                return node.value;
            } else {
                return tmp;
            }
        };

        const res = dfs(this.root, value) as any;
        return res === this.rightBound ? undefined : res;
    }

    /**
     *
     * @complexity `O(logn)`
     * @description Find and return the element less than or equal to `val`, return `undefined` if no such element found.
     */
    floor(value: T): T | undefined {
        const compare = this.compareFn;
        const dfs = (node: TreapNode<T> | null, value: T): T | undefined => {
            if (node == null) return undefined;
            if (compare(node.value, value) === 0) return node.value;
            if (compare(node.value, value) >= 0) return dfs(node.left, value);

            const tmp = dfs(node.right, value);
            if (tmp == null || compare(node.value, tmp) > 0) {
                return node.value;
            } else {
                return tmp;
            }
        };

        const res = dfs(this.root, value) as any;
        return res === this.leftBound ? undefined : res;
    }

    /**
     *
     * @complexity `O(logn)`
     * @description Find and return the element greater than or equal to `val`, return `undefined` if no such element found.
     */
    ceil(value: T): T | undefined {
        const compare = this.compareFn;
        const dfs = (node: TreapNode<T> | null, value: T): T | undefined => {
            if (node == null) return undefined;
            if (compare(node.value, value) === 0) return node.value;
            if (compare(node.value, value) <= 0) return dfs(node.right, value);

            const tmp = dfs(node.left, value);

            if (tmp == null || compare(node.value, tmp) < 0) {
                return node.value;
            } else {
                return tmp;
            }
        };

        const res = dfs(this.root, value) as any;
        return res === this.rightBound ? undefined : res;
    }

    /**
     * @complexity `O(logn)`
     * @description
     * Returns the last element from set.
     * If the set is empty, undefined is returned.
     */
    first(): T | undefined {
        const iter = this.inOrder();
        iter.next();
        const res = iter.next().value;
        return res === this.rightBound ? undefined : res;
    }

    /**
     * @complexity `O(logn)`
     * @description
     * Returns the last element from set.
     * If the set is empty, undefined is returned .
     */
    last(): T | undefined {
        const iter = this.reverseInOrder();
        iter.next();
        const res = iter.next().value;
        return res === this.leftBound ? undefined : res;
    }

    /**
     * @complexity `O(logn)`
     * @description
     * Removes the first element from an set and returns it.
     * If the set is empty, undefined is returned and the set is not modified.
     */
    shift(): T | undefined {
        const first = this.first();
        if (first === undefined) return undefined;
        this.delete(first);
        return first;
    }

    /**
     * @complexity `O(logn)`
     * @description
     * Removes the last element from an set and returns it.
     * If the set is empty, undefined is returned and the set is not modified.
     */
    pop(index?: number): T | undefined {
        if (index == null) {
            const last = this.last();
            if (last === undefined) return undefined;
            this.delete(last);
            return last;
        }

        const toDelete = this.at(index);
        if (toDelete == null) return;
        this.delete(toDelete);
        return toDelete;
    }

    /**
     *
     * @complexity `O(logn)`
     * @description
     * Returns number of occurrences of value in the sorted set.
     */
    count(value: T): number {
        const compare = this.compareFn;
        const dfs = (node: TreapNode<T> | null, value: T): number => {
            if (node == null) return 0;
            if (compare(node.value, value) === 0) return node.count;
            if (compare(node.value, value) < 0) return dfs(node.right, value);
            return dfs(node.left, value);
        };

        return dfs(this.root, value);
    }

    *[Symbol.iterator](): Generator<T, any, any> {
        yield* this.values();
    }

    /**
     * @description
     * Returns an iterable of keys in the set.
     */
    *keys(): Generator<T, any, any> {
        yield* this.values();
    }

    /**
     * @description
     * Returns an iterable of values in the set.
     */
    *values(): Generator<T, any, any> {
        const iter = this.inOrder();
        iter.next();
        const steps = this.size;
        for (let _ = 0; _ < steps; _++) {
            yield iter.next().value;
        }
    }

    /**
     * @description
     * Returns a generator for reversed order traversing the set.
     */
    *rvalues(): Generator<T, any, any> {
        const iter = this.reverseInOrder();
        iter.next();
        const steps = this.size;
        for (let _ = 0; _ < steps; _++) {
            yield iter.next().value;
        }
    }

    /**
     * @description
     * Returns an iterable of key, value pairs for every entry in the set.
     */
    *entries(): IterableIterator<[number, T]> {
        const iter = this.inOrder();
        iter.next();
        const steps = this.size;
        for (let i = 0; i < steps; i++) {
            yield [i, iter.next().value];
        }
    }

    private *inOrder(root: TreapNode<T> | null = this.root): Generator<T, any, any> {
        if (root == null) return;
        yield* this.inOrder(root.left);
        const count = root.count;
        for (let _ = 0; _ < count; _++) {
            yield root.value;
        }
        yield* this.inOrder(root.right);
    }

    private *reverseInOrder(root: TreapNode<T> | null = this.root): Generator<T, any, any> {
        if (root == null) return;
        yield* this.reverseInOrder(root.right);
        const count = root.count;
        for (let _ = 0; _ < count; _++) {
            yield root.value;
        }
        yield* this.reverseInOrder(root.left);
    }
}

时间复杂度 $O(n \times \log \textit{dist})$,空间复杂度 $O(\textit{dist})$。其中 $n$ 为数组 $\textit{nums}$ 的长度。


有任何问题,欢迎评论区交流,欢迎评论区提供其它解题思路(代码),也可以点个赞支持一下作者哈 😄~

每日一题-将数组分成最小总代价的子数组 II🔴

给你一个下标从 0 开始长度为 n 的整数数组 nums 和两个  整数 k 和 dist 。

一个数组的 代价 是数组中的 第一个 元素。比方说,[1,2,3] 的代价为 1 ,[3,4,1] 的代价为 3 。

你需要将 nums 分割成 k 个 连续且互不相交 的子数组,满足 第二 个子数组与第 k 个子数组中第一个元素的下标距离 不超过 dist 。换句话说,如果你将 nums 分割成子数组 nums[0..(i1 - 1)], nums[i1..(i2 - 1)], ..., nums[ik-1..(n - 1)] ,那么它需要满足 ik-1 - i1 <= dist 。

请你返回这些子数组的 最小 总代价。

 

示例 1:

输入:nums = [1,3,2,6,4,2], k = 3, dist = 3
输出:5
解释:将数组分割成 3 个子数组的最优方案是:[1,3] ,[2,6,4] 和 [2] 。这是一个合法分割,因为 ik-1 - i1 等于 5 - 2 = 3 ,等于 dist 。总代价为 nums[0] + nums[2] + nums[5] ,也就是 1 + 2 + 2 = 5 。
5 是分割成 3 个子数组的最小总代价。

示例 2:

输入:nums = [10,1,2,2,2,1], k = 4, dist = 3
输出:15
解释:将数组分割成 4 个子数组的最优方案是:[10] ,[1] ,[2] 和 [2,2,1] 。这是一个合法分割,因为 ik-1 - i1 等于 3 - 1 = 2 ,小于 dist 。总代价为 nums[0] + nums[1] + nums[2] + nums[3] ,也就是 10 + 1 + 2 + 2 = 15 。
分割 [10] ,[1] ,[2,2,2] 和 [1] 不是一个合法分割,因为 ik-1 和 i1 的差为 5 - 1 = 4 ,大于 dist 。
15 是分割成 4 个子数组的最小总代价。

示例 3:

输入:nums = [10,8,18,9], k = 3, dist = 1
输出:36
解释:将数组分割成 4 个子数组的最优方案是:[10] ,[8] 和 [18,9] 。这是一个合法分割,因为 ik-1 - i1 等于 2 - 1 = 1 ,等于 dist 。总代价为 nums[0] + nums[1] + nums[2] ,也就是 10 + 8 + 18 = 36 。
分割 [10] ,[8,18] 和 [9] 不是一个合法分割,因为 ik-1 和 i1 的差为 3 - 1 = 2 ,大于 dist 。
36 是分割成 3 个子数组的最小总代价。

 

提示:

  • 3 <= n <= 105
  • 1 <= nums[i] <= 109
  • 3 <= k <= n
  • k - 2 <= dist <= n - 2

两个有序集合维护前 k-1 小(Python/Java/C++/Go)

第一段的第一个数是确定的,即 $\textit{nums}[0]$。

如果知道了第二段的第一个数的位置(记作 $p$),第三段的第一个数的位置,……,第 $k$ 段的第一个数的位置(记作 $q$),那么这个划分方案也就确定了。

考虑到 $q-p \le \textit{dist}$,本题相当于在一个大小固定为 $\textit{dist}+1$ 的滑动窗口内,求前 $k-1$ 小的元素和。

仿照 480. 滑动窗口中位数,这可以用两个有序集合来做:

  1. 初始化两个有序集合 $L$ 和 $R$。注意:为方便计算,把 $k$ 减一。
  2. 把 $\textit{nums}[1]$ 到 $\textit{nums}[\textit{dist}+1]$ 加到 $L$ 中。
  3. 保留 $L$ 最小的 $k$ 个数,把其余数丢到 $R$ 中。
  4. 从 $i=\textit{dist}+2$ 开始滑窗。
  5. 先把 $\textit{out} = \textit{nums}[i-\textit{dist}-1]$ 移出窗口:如果 $\textit{out}$ 在 $L$ 中,就从 $L$ 中移除,否则从 $R$ 中移除。
  6. 然后把 $\textit{in} = \textit{nums}[i]$ 移入窗口:如果 $\textit{in}$ 小于 $L$ 中的最大元素,则加入 $L$,否则加入 $R$。
  7. 上面两步做完后,如果 $L$ 中的元素个数小于 $k$(等于 $k-1$),则从 $R$ 中取一个最小元素加入 $L$;反之,如果 $L$ 中的元素个数大于 $k$(等于 $k+1$),则从 $L$ 中取一个最大元素加入 $R$。

上述过程维护 $L$ 中元素之和 $\textit{sumL}$,取 $\textit{sumL}$ 的最小值,即为答案。

视频讲解 第四题。

###py

class Solution:
    def minimumCost(self, nums: List[int], k: int, dist: int) -> int:
        k -= 1
        sum_left = sum(nums[:dist + 2])
        L = SortedList(nums[1:dist + 2])
        R = SortedList()

        def L2R() -> None:
            x = L.pop()
            nonlocal sum_left
            sum_left -= x
            R.add(x)

        def R2L() -> None:
            x = R.pop(0)
            nonlocal sum_left
            sum_left += x
            L.add(x)

        while len(L) > k:
            L2R()

        ans = sum_left
        for i in range(dist + 2, len(nums)):
            # 移除 out
            out = nums[i - dist - 1]
            if out in L:
                sum_left -= out
                L.remove(out)
            else:
                R.remove(out)

            # 添加 in
            in_val = nums[i]
            if in_val < L[-1]:
                sum_left += in_val
                L.add(in_val)
            else:
                R.add(in_val)

            # 维护大小
            if len(L) == k - 1:
                R2L()
            elif len(L) == k + 1:
                L2R()

            ans = min(ans, sum_left)

        return ans

###cpp

class Solution {
public:
    long long minimumCost(vector<int>& nums, int k, int dist) {
        k--;
        long long sum = reduce(nums.begin(), nums.begin() + dist + 2, 0LL);
        multiset<int> L(nums.begin() + 1, nums.begin() + dist + 2), R;
        auto L2R = [&]() {
            int x = *L.rbegin();
            sum -= x;
            L.erase(L.find(x));
            R.insert(x);
        };
        auto R2L = [&]() {
            int x = *R.begin();
            sum += x;
            R.erase(R.find(x));
            L.insert(x);
        };
        while (L.size() > k) {
            L2R();
        }

        long long ans = sum;
        for (int i = dist + 2; i < nums.size(); i++) {
            // 移除 out
            int out = nums[i - dist - 1];
            auto it = L.find(out);
            if (it != L.end()) {
                sum -= out;
                L.erase(it);
            } else {
                R.erase(R.find(out));
            }

            // 添加 in
            int in = nums[i];
            if (in < *L.rbegin()) {
                sum += in;
                L.insert(in);
            } else {
                R.insert(in);
            }

            // 维护大小
            if (L.size() == k - 1) {
                R2L();
            } else if (L.size() == k + 1) {
                L2R();
            }

            ans = min(ans, sum);
        }
        return ans;
    }
};

###java

class Solution {
    public long minimumCost(int[] nums, int k, int dist) {
        k--;
        sumL = nums[0];
        for (int i = 1; i < dist + 2; i++) {
            sumL += nums[i];
            L.merge(nums[i], 1, Integer::sum);
        }
        sizeL = dist + 1;
        while (sizeL > k) {
            l2r();
        }

        long ans = sumL;
        for (int i = dist + 2; i < nums.length; i++) {
            // 移除 out
            int out = nums[i - dist - 1];
            if (L.containsKey(out)) {
                sumL -= out;
                sizeL--;
                removeOne(L, out);
            } else {
                removeOne(R, out);
            }

            // 添加 in
            int in = nums[i];
            if (in < L.lastKey()) {
                sumL += in;
                sizeL++;
                L.merge(in, 1, Integer::sum);
            } else {
                R.merge(in, 1, Integer::sum);
            }

            // 维护大小
            if (sizeL == k - 1) {
                r2l();
            } else if (sizeL == k + 1) {
                l2r();
            }

            ans = Math.min(ans, sumL);
        }
        return ans;
    }

    private long sumL;
    private int sizeL;
    private final TreeMap<Integer, Integer> L = new TreeMap<>();
    private final TreeMap<Integer, Integer> R = new TreeMap<>();

    private void l2r() {
        int x = L.lastKey();
        removeOne(L, x);
        sumL -= x;
        sizeL--;
        R.merge(x, 1, Integer::sum);
    }

    private void r2l() {
        int x = R.firstKey();
        removeOne(R, x);
        sumL += x;
        sizeL++;
        L.merge(x, 1, Integer::sum);
    }

    private void removeOne(Map<Integer, Integer> m, int x) {
        int cnt = m.get(x);
        if (cnt > 1) {
            m.put(x, cnt - 1);
        } else {
            m.remove(x);
        }
    }
}

###go

import "github.com/emirpasic/gods/trees/redblacktree"

func minimumCost(nums []int, k, dist int) int64 {
k--
L := redblacktree.NewWithIntComparator()
R := redblacktree.NewWithIntComparator()
add := func(t *redblacktree.Tree, x int) {
c, ok := t.Get(x)
if ok {
t.Put(x, c.(int)+1)
} else {
t.Put(x, 1)
}
}
del := func(t *redblacktree.Tree, x int) {
c, _ := t.Get(x)
if c.(int) > 1 {
t.Put(x, c.(int)-1)
} else {
t.Remove(x)
}
}

sumL := nums[0]
for _, x := range nums[1 : dist+2] {
sumL += x
add(L, x)
}
sizeL := dist + 1

l2r := func() {
x := L.Right().Key.(int)
sumL -= x
sizeL--
del(L, x)
add(R, x)
}
r2l := func() {
x := R.Left().Key.(int)
sumL += x
sizeL++
del(R, x)
add(L, x)
}
for sizeL > k {
l2r()
}

ans := sumL
for i := dist + 2; i < len(nums); i++ {
// 移除 out
out := nums[i-dist-1]
if _, ok := L.Get(out); ok {
sumL -= out
sizeL--
del(L, out)
} else {
del(R, out)
}

// 添加 in
in := nums[i]
if in < L.Right().Key.(int) {
sumL += in
sizeL++
add(L, in)
} else {
add(R, in)
}

// 维护大小
if sizeL == k-1 {
r2l()
} else if sizeL == k+1 {
l2r()
}

ans = min(ans, sumL)
}
return int64(ans)
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log \textit{dist})$,其中 $n$ 为 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(\textit{dist})$。

附:对顶堆+懒删除

###go

func minimumCost(nums []int, k int, dist int) int64 {
k--
l := &lazyHeap{todo: map[int]int{}} // 最大堆
r := &lazyHeap{todo: map[int]int{}} // 最小堆,所有元素取反
for _, x := range nums[1 : dist+2] {
l.push(x)
}
for l.size > k {
r.push(-l.pop())
}

mn := l.sum
for i := dist + 2; i < len(nums); i++ {
// 移除 out
out := nums[i-dist-1]
if out <= l.top() {
l.del(out)
} else {
r.del(-out)
}

// 添加 in
in := nums[i]
if in < l.top() {
l.push(in)
} else {
r.push(-in)
}

// 维护大小
if l.size == k-1 {
l.push(-r.pop())
} else if l.size == k+1 {
r.push(-l.pop())
}

mn = min(mn, l.sum)
}
return int64(nums[0] + mn)
}

type lazyHeap struct {
sort.IntSlice
todo map[int]int
size int // 实际大小
sum  int // 实际元素和
}

func (h lazyHeap) Less(i, j int) bool { return h.IntSlice[i] > h.IntSlice[j] } // 最大堆
func (h *lazyHeap) Push(v any)        { h.IntSlice = append(h.IntSlice, v.(int)) }
func (h *lazyHeap) Pop() any          { a := h.IntSlice; v := a[len(a)-1]; h.IntSlice = a[:len(a)-1]; return v }
func (h *lazyHeap) del(v int)         { h.todo[v]++; h.size--; h.sum -= v } // 懒删除
func (h *lazyHeap) push(v int) {
if h.todo[v] > 0 {
h.todo[v]--
} else {
heap.Push(h, v)
}
h.size++
h.sum += v
}
func (h *lazyHeap) pop() int { h.do(); h.size--; v := heap.Pop(h).(int); h.sum -= v; return v }
func (h *lazyHeap) top() int { h.do(); return h.IntSlice[0] }
func (h *lazyHeap) do() {
for h.Len() > 0 && h.todo[h.IntSlice[0]] > 0 {
h.todo[h.IntSlice[0]]--
heap.Pop(h)
}
}

专题训练

见下面数据结构题单的「§5.7 对顶堆」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

数据结构

解法:数据结构

由于第一个子数组的开头一定是 nums[0],因此把数组分成 $k$ 个子数组,等价于从 nums[1]nums[n - 1] 里选出 $(k - 1)$ 个数作为子数组的开头。

我们枚举最后一个子数组的开头 nums[i],则我们需要从 nums[i - delta]nums[i - 1] 中再选出 $(k - 2)$ 个数作为其它子数组的开头,这显然是一个长度为 delta 的滑动窗口。为了最小化答案,我们应该选择最小的 $(k - 2)$ 个数。

滑动窗口中的前 $k$ 小值是一个非常经典的问题,可以通过两个堆或者两个 multiset 维护。具体请看 leetcode 480. 滑动窗口的中位数

复杂度 $\mathcal{O}(n \log n)$。

滑动窗口的前 $k$ 小值在其它比赛中也是常见的套路,感兴趣的读者还可以继续练习:

  • 2023 ICPC 济南站 K 题:<某大家都懂的代码力量网站>/gym/104901/problem/K

参考代码(c++)

###c++

// 两个 multiset 维护滑动窗口中的前 K 小值
struct Magic {
    int K;
    // st1 保存前 K 小值,st2 保存其它值
    multiset<long long> st1, st2;
    // sm 表示 st1 中所有数的和
    long long sm;

    Magic(int K): K(K), sm(0) {}

    // 调整 st1 和 st2 的大小,保证调整后 st1 保存前 K 小值
    void adjust() {
        while (st1.size() < K && st2.size() > 0) {
            long long t = *(st2.begin());
            st1.insert(t);
            sm += t;
            st2.erase(st2.begin());
        }
        while (st1.size() > K) {
            long long t = *prev(st1.end());
            st2.insert(t);
            st1.erase(prev(st1.end()));
            sm -= t;
        }
    }

    // 插入元素 x
    void add(long long x) {
        if (!st2.empty() && x >= *(st2.begin())) st2.insert(x);
        else st1.insert(x), sm += x;
        adjust();
    }

    // 删除元素 x
    void del(long long x) {
        auto it = st1.find(x);
        if (it != st1.end()) st1.erase(it), sm -= x;
        else st2.erase(st2.find(x));
        adjust();
    }
};

class Solution {
public:
    long long minimumCost(vector<int>& nums, int K, int dist) {
        int n = nums.size();

        // 滑动窗口初始化
        Magic magic(K - 2);
        for (int i = 1; i < K - 1; i++) magic.add(nums[i]);

        long long ans = magic.sm + nums[K - 1];
        // 枚举最后一个数组的开头
        for (int i = K; i < n; i++) {
            int t = i - dist - 1;
            if (t > 0) magic.del(nums[t]);
            magic.add(nums[i - 1]);
            ans = min(ans, magic.sm + nums[i]);
        }

        return ans + nums[0];
    }
};

将数组分成最小总代价的子数组 I

方法一:排序

思路与算法

根据题意可知一个数组的代价是它的第一个元素。需要将给定数组 $\textit{nums}$ 分成 $3$ 个连续且没有交集的子数组,题目要求返回这 $3$ 子数组的最小代价和。
根据题意可知,第一个子数组的代价已确定为 $\textit{nums}[0]$。如果确定了第二个子数组的第一个数的位置和第三个子数组的第一个数的位置,此时子数组的划分方案也就确定。我们可以任意选择两个索引 $(i,j)$ 作为第二个子数组的起始位置和第三个子数组的起始位置,且满足 $1 \le i < j \le n -1$,其中 $n$ 表示给定数组 $\textit{nums}$ 的长度。此时,第二个子数组的代价为 $\textit{nums}[i]$,第三个子数组的代价为 $\textit{nums}[j]$。为保证代价和最小,此时可以在 $[1,n−1]$ 中的选择值最小的两个下标即可,可将子数组 $\textit{nums}[1 \cdots n-1]$ 按照从小到大排序,取最小的两个元素即可。

代码

###C++

class Solution {
public:
    int minimumCost(vector<int>& nums) {
        sort(nums.begin() + 1, nums.end());
        return reduce(nums.begin(), nums.begin() + 3, 0);
    }
};

###Java

class Solution {
    public int minimumCost(int[] nums) {
        Arrays.sort(nums, 1, nums.length);
        return nums[0] + nums[1] + nums[2];
    }
}

###C#

public class Solution {
    public int MinimumCost(int[] nums) {
        Array.Sort(nums, 1, nums.Length - 1);
        return nums.Take(3).Sum();
    }
}

###Go

func minimumCost(nums []int) int {
    sort.Ints(nums[1:])
    return nums[0] + nums[1] + nums[2]
}

###Python

class Solution:
    def minimumCost(self, nums: List[int]) -> int:
        nums[1:] = sorted(nums[1:])
        return sum(nums[:3])

###C

int cmp(const void *a, const void *b) {
    return (*(int *)a) - (*(int *)b);
}

int minimumCost(int *nums, int numsSize) {
    qsort(nums + 1, numsSize - 1, sizeof(int), cmp);
    return nums[0] + nums[1] + nums[2];
}

###JavaScript

var minimumCost = function(nums) {
    nums = [nums[0], ...nums.slice(1).sort((a, b) => a - b)];
    return nums.slice(0, 3).reduce((sum, num) => sum + num, 0);
};

###TypeScript

function minimumCost(nums: number[]): number {
    nums = [nums[0], ...nums.slice(1).sort((a, b) => a - b)];
    return nums.slice(0, 3).reduce((sum, num) => sum + num, 0);
};

###Rust

impl Solution {
    pub fn minimum_cost(mut nums: Vec<i32>) -> i32 {
        if nums.len() > 1 {
            let (first, rest) = nums.split_at_mut(1);
            rest.sort();
        }
        nums.iter().take(3).sum()
    }
}

复杂度分析

  • 时间复杂度:$O(n \log n)$,其中 $n$ 表示给定数组 $\textit{nums}$ 的长度。排序需要 $O(n \log n)$ 的时间。

  • 空间复杂度:$O(\log n)$。排序需要 $O(\log n)$ 的栈空间。

方法二:维护最小值和次小值

思路与算法

根据方法一可知,我们需要找到下标在 $[1,n−1]$ 中的两个最小元素,此时可以在遍历数组的过程中维护最小值 $\textit{first}$ 和次小值 $\textit{second}$,最终答案即为 $\textit{nums}[0] + \textit{first} + \textit{second}$。

代码

###C++

class Solution {
public:
    int minimumCost(vector<int> &nums) {
        int first = INT_MAX, second = INT_MAX;
        for (int i = 1; i < nums.size(); i++) {
            int x = nums[i];
            if (x < first) {
                second = first;
                first = x;
            } else if (x < second) {
                second = x;
            }
        }
        return nums[0] + first + second;
    }
};

###Java

class Solution {
    public int minimumCost(int[] nums) {
        int first = Integer.MAX_VALUE;
        int second = Integer.MAX_VALUE;

        for (int i = 1; i < nums.length; i++) {
            int x = nums[i];
            if (x < first) {
                second = first;
                first = x;
            } else if (x < second) {
                second = x;
            }
        }
        return nums[0] + first + second;
    }
}

###C#

public class Solution {
    public int MinimumCost(int[] nums) {
        int first = int.MaxValue;
        int second = int.MaxValue;
        
        for (int i = 1; i < nums.Length; i++) {
            int x = nums[i];
            if (x < first) {
                second = first;
                first = x;
            } else if (x < second) {
                second = x;
            }
        }
        return nums[0] + first + second;
    }
}

###Go

func minimumCost(nums []int) int {
    first := int(^uint(0) >> 1)
    second := int(^uint(0) >> 1)
    
    for i := 1; i < len(nums); i++ {
        x := nums[i]
        if x < first {
            second = first
            first = x
        } else if x < second {
            second = x
        }
    }
    return nums[0] + first + second
}

###Python

class Solution:
    def minimumCost(self, nums: List[int]) -> int:
        return nums[0] + sum(nsmallest(2, nums[1:]))

###C

int minimumCost(int* nums, int numsSize) {
    int first = INT_MAX;
    int second = INT_MAX;
    
    for (int i = 1; i < numsSize; i++) {
        int x = nums[i];
        if (x < first) {
            second = first;
            first = x;
        } else if (x < second) {
            second = x;
        }
    }
    return nums[0] + first + second;
}

###JavaScript

var minimumCost = function(nums) {
    let first = Number.MAX_SAFE_INTEGER;
    let second = Number.MAX_SAFE_INTEGER;
    
    for (let i = 1; i < nums.length; i++) {
        const x = nums[i];
        if (x < first) {
            second = first;
            first = x;
        } else if (x < second) {
            second = x;
        }
    }
    return nums[0] + first + second;
};

###TypeScript

function minimumCost(nums: number[]): number {
    let first: number = Number.MAX_SAFE_INTEGER;
    let second: number = Number.MAX_SAFE_INTEGER;
    
    for (let i = 1; i < nums.length; i++) {
        const x: number = nums[i];
        if (x < first) {
            second = first;
            first = x;
        } else if (x < second) {
            second = x;
        }
    }
    return nums[0] + first + second;
};

###Rust

impl Solution {
    pub fn minimum_cost(nums: Vec<i32>) -> i32 {
        let mut first = i32::MAX;
        let mut second = i32::MAX;
        
        for i in 1..nums.len() {
            let x = nums[i];
            if x < first {
                second = first;
                first = x;
            } else if x < second {
                second = x;
            }
        }
        nums[0] + first + second
    }
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 是数组 $\textit{nums}$ 的长度。

  • 空间复杂度:$O(1)$。

每日一题-将数组分成最小总代价的子数组 I🟢

给你一个长度为 n 的整数数组 nums 。

一个数组的 代价 是它的 第一个 元素。比方说,[1,2,3] 的代价是 1 ,[3,4,1] 的代价是 3 。

你需要将 nums 分成 3 个 连续且没有交集 的子数组。

请你返回这些子数组最小 代价 总和 。

 

示例 1:

输入:nums = [1,2,3,12]
输出:6
解释:最佳分割成 3 个子数组的方案是:[1] ,[2] 和 [3,12] ,总代价为 1 + 2 + 3 = 6 。
其他得到 3 个子数组的方案是:
- [1] ,[2,3] 和 [12] ,总代价是 1 + 2 + 12 = 15 。
- [1,2] ,[3] 和 [12] ,总代价是 1 + 3 + 12 = 16 。

示例 2:

输入:nums = [5,4,3]
输出:12
解释:最佳分割成 3 个子数组的方案是:[5] ,[4] 和 [3] ,总代价为 5 + 4 + 3 = 12 。
12 是所有分割方案里的最小总代价。

示例 3:

输入:nums = [10,3,1,1]
输出:12
解释:最佳分割成 3 个子数组的方案是:[10,3] ,[1] 和 [1] ,总代价为 10 + 1 + 1 = 12 。
12 是所有分割方案里的最小总代价。

 

提示:

  • 3 <= n <= 50
  • 1 <= nums[i] <= 50

从 O(nlogn) 到 O(n)(Python/Java/C++/Go)

题意:把数组分成三段,每一段取第一个数再求和,问和的最小值是多少。

第一段的第一个数是确定的,即 $\textit{nums}[0]$。

如果知道了第二段的第一个数的位置,和第三段的第一个数的位置,那么这个划分方案也就确定了。

这两个下标可以在 $[1,n-1]$ 中随意取。

所以问题变成求下标在 $[1,n-1]$ 中的两个最小的数。

视频讲解

方法一:直接排序

###py

class Solution:
    def minimumCost(self, nums: List[int]) -> int:
        return nums[0] + sum(sorted(nums[1:])[:2])

###java

class Solution {
    public int minimumCost(int[] nums) {
        Arrays.sort(nums, 1, nums.length);
        return nums[0] + nums[1] + nums[2];
    }
}

###cpp

class Solution {
public:
    int minimumCost(vector<int>& nums) {
        sort(nums.begin() + 1, nums.end());
        return reduce(nums.begin(), nums.begin() + 3, 0);
    }
};

###go

func minimumCost(nums []int) int {
slices.Sort(nums[1:])
return nums[0] + nums[1] + nums[2]
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n\log n)$,其中 $n$ 为 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。忽略排序的栈开销。

方法二:维护最小值和次小值

###py

class Solution:
    def minimumCost(self, nums: List[int]) -> int:
        return nums[0] + sum(nsmallest(2, nums[1:]))

###java

class Solution {
    public int minimumCost(int[] nums) {
        int fi = Integer.MAX_VALUE, se = Integer.MAX_VALUE;
        for (int i = 1; i < nums.length; i++) {
            int x = nums[i];
            if (x < fi) {
                se = fi;
                fi = x;
            } else if (x < se) {
                se = x;
            }
        }
        return nums[0] + fi + se;
    }
}

###cpp

class Solution {
public:
    int minimumCost(vector<int> &nums) {
        int fi = INT_MAX, se = INT_MAX;
        for (int i = 1; i < nums.size(); i++) {
            int x = nums[i];
            if (x < fi) {
                se = fi;
                fi = x;
            } else if (x < se) {
                se = x;
            }
        }
        return nums[0] + fi + se;
    }
};

###go

func minimumCost(nums []int) int {
fi, se := math.MaxInt, math.MaxInt
for _, x := range nums[1:] {
if x < fi {
se = fi
fi = x
} else if x < se {
se = x
}
}
return nums[0] + fi + se
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n)$,其中 $n$ 为 $\textit{nums}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、二叉树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA/一般树)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

首元素+后续两个最小元素,C0ms/python一行,双百

Problem: 100181. 将数组分成最小总代价的子数组 I

[TOC]

思路

就是计算首个元素、后续两个最小元素,三个元素之和。

Code

执行用时分布0ms击败100.00%;消耗内存分布5.47MB击败100.00%

###C

int minimumCost(int* nums, int numsSize) {
    int a1 = nums[1], a2 = 50;
    for (int i = 2; i < numsSize; ++ i) 
        if      (nums[i] <= a1) a2 = a1, a1 = nums[i];
        else if (nums[i] <  a2) a2 = nums[i];
    return nums[0] + a1 + a2;
}

###Python3

class Solution:
    def minimumCost(self, nums: List[int]) -> int:
        return nums[0] + sum(sorted(nums[1:])[:2])

您若还有不同方法,欢迎贴在评论区,一起交流探讨! ^_^

↓ 点个赞,点收藏,留个言,再划走,感谢您支持作者! ^_^

【模板】二分查找(Python/Java/C++/Go)

视频讲解二分查找 红蓝染色法【基础算法精讲 04】,包含闭区间、半闭半开、开区间三种写法。

相关题目34. 在排序数组中查找元素的第一个和最后一个位置,推荐阅读 题解 中的答疑。

手写二分

推荐写开区间二分,简单好记。

class Solution:
    def nextGreatestLetter(self, letters: List[str], target: str) -> str:
        n = len(letters)
        left, right = -1, n
        while left + 1 < right:
            mid = (left + right) // 2
            if letters[mid] > target:
                right = mid
            else:
                left = mid
        return letters[right] if right < n else letters[0]
class Solution {
    public char nextGreatestLetter(char[] letters, char target) {
        int n = letters.length;
        int left = -1;
        int right = n;
        while (left + 1 < right) {
            int mid = left + (right - left) / 2;
            if (letters[mid] > target) {
                right = mid;
            } else {
                left = mid;
            }
        }
        return right < n ? letters[right] : letters[0];
    }
}
class Solution {
public:
    char nextGreatestLetter(vector<char>& letters, char target) {
        int n = letters.size();
        int left = -1, right = n;
        while (left + 1 < right) {
            int mid = left + (right - left) / 2;
            (letters[mid] > target ? right : left) = mid;
        }
        return right < n ? letters[right] : letters[0];
    }
};
func nextGreatestLetter(letters []byte, target byte) byte {
n := len(letters)
left, right := -1, n
for left+1 < right {
mid := left + (right-left)/2
if letters[mid] > target {
right = mid
} else {
left = mid
}
}
if right == n {
return letters[0]
}
return letters[right]
}

库函数写法

class Solution:
    def nextGreatestLetter(self, letters: List[str], target: str) -> str:
        i = bisect_right(letters, target)
        return letters[i] if i < len(letters) else letters[0]
class Solution {
    public char nextGreatestLetter(char[] letters, char target) {
        int i = Arrays.binarySearch(letters, (char) (target + 1));
        if (i < 0) { // letters 中没有 target + 1
            i = ~i; // 根据 Arrays.binarySearch 的文档,~i 是第一个大于 target 的字母的下标
        }
        return i < letters.length ? letters[i] : letters[0];
    }
}
class Solution {
public:
    char nextGreatestLetter(vector<char>& letters, char target) {
        auto it = ranges::upper_bound(letters, target);
        return it < letters.end() ? *it : letters[0];
    }
};
func nextGreatestLetter(letters []byte, target byte) byte {
i, _ := slices.BinarySearch(letters, target+1)
if i == len(letters) {
return letters[0]
}
return letters[i]
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(\log n)$,其中 $n$ 是 $\textit{letters}$ 的长度。
  • 空间复杂度:$\mathcal{O}(1)$。

附:二分查找常用转化

需求 写法 如果不存在
$\ge x$ 的第一个元素的下标 $\texttt{lowerBound}(\textit{nums},x)$ 结果为 $n$
$> x$ 的第一个元素的下标 $\texttt{lowerBound}(\textit{nums},x+1)$ 结果为 $n$
$< x$ 的最后一个元素的下标 $\texttt{lowerBound}(\textit{nums},x)-1$ 结果为 $-1$
$\le x$ 的最后一个元素的下标 $\texttt{lowerBound}(\textit{nums},x+1)-1$ 结果为 $-1$
需求 写法
$< x$ 的元素个数 $\texttt{lowerBound}(\textit{nums},x)$
$\le x$ 的元素个数 $\texttt{lowerBound}(\textit{nums},x+1)$
$\ge x$ 的元素个数 $n - \texttt{lowerBound}(\textit{nums},x)$
$> x$ 的元素个数 $n - \texttt{lowerBound}(\textit{nums},x+1)$

注意 $< x$ 和 $\ge x$ 互为补集,元素个数之和为 $n$。$\le x$ 和 $> x$ 同理。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

每日一题-寻找比目标字母大的最小字母🟢

给你一个字符数组 letters,该数组按非递减顺序排序,以及一个字符 targetletters 里至少有两个不同的字符。

返回 letters 中大于 target 的最小的字符。如果不存在这样的字符,则返回 letters 的第一个字符。

 

示例 1:

输入: letters = ['c', 'f', 'j'],target = 'a'
输出: 'c'
解释:letters 中字典上比 'a' 大的最小字符是 'c'。

示例 2:

输入: letters = ['c','f','j'], target = 'c'
输出: 'f'
解释:letters 中字典顺序上大于 'c' 的最小字符是 'f'。

示例 3:

输入: letters = ['x','x','y','y'], target = 'z'
输出: 'x'
解释:letters 中没有一个字符在字典上大于 'z',所以我们返回 letters[0]。

 

提示:

  • 2 <= letters.length <= 104
  • letters[i] 是一个小写字母
  • letters非递减顺序排序
  • letters 最少包含两个不同的字母
  • target 是一个小写字母

【宫水三叶】简单二分运用题

二分

给定的数组「有序」,找到比 $target$ 大的最小字母,容易想到二分。

唯一需要注意的是,二分结束后需要再次 check,如果不满足,则取数组首位元素。

代码:

###Java

class Solution {
    public char nextGreatestLetter(char[] letters, char target) {
        int n = letters.length;
        int l = 0, r = n - 1;
        while (l < r) {
            int mid = l + r >> 1;
            if (letters[mid] > target) r = mid;
            else l = mid + 1;
        }
        return letters[r] > target ? letters[r] : letters[0];
    }
}
  • 时间复杂度:$O(\log{n})$
  • 空间复杂度:$O(1)$

加餐 & 其他「二分」

今日份加餐 【综合笔试题】两种强有力的字符串处理方式 🎉🎉

或者继续学习「二分」相关内容 🍭🍭🍭

不太清楚 check 的大于小于怎么确定的可以重点看看这篇:

题目 题解 难度 推荐指数
4. 寻找两个正序数组的中位数 LeetCode 题解链接 困难 🤩🤩🤩🤩
29. 两数相除 LeetCode 题解链接 中等 🤩🤩🤩
74. 搜索二维矩阵 LeetCode 题解链接 中等 🤩🤩🤩🤩
220. 存在重复元素 III LeetCode 题解链接 中等 🤩🤩🤩
354. 俄罗斯套娃信封问题 LeetCode 题解链接 困难 🤩🤩🤩
363. 矩形区域不超过 K 的最大数值和 LeetCode 题解链接 困难 🤩🤩🤩
852. 山脉数组的峰顶索引 LeetCode 题解链接 简单 🤩🤩🤩🤩🤩
1004. 最大连续1的个数 III LeetCode 题解链接 中等 🤩🤩🤩
1208. 尽可能使字符串相等 LeetCode 题解链接 中等 🤩🤩🤩
1482. 制作 m 束花所需的最少天数 LeetCode 题解链接 中等 🤩🤩🤩
1707. 与数组中元素的最大异或值 LeetCode 题解链接 困难 🤩🤩🤩
1751. 最多可以参加的会议数目 II LeetCode 题解链接 困难 🤩🤩🤩

注:以上目录整理来自 wiki,任何形式的转载引用请保留出处。


最后

如果有帮助到你,请给题解点个赞和收藏,让更多的人看到 ~ ("▔□▔)/

也欢迎你 关注我 和 加入我们的「组队打卡」小群 ,提供写「证明」&「思路」的高质量题解。

所有题解已经加入 刷题指南,欢迎 star 哦 ~

寻找比目标字母大的最小字母

方法一:线性查找

由于给定的列表已经按照递增顺序排序,因此可以从左到右遍历列表,找到第一个比目标字母大的字母,即为比目标字母大的最小字母。

如果目标字母小于列表中的最后一个字母,则一定可以在列表中找到比目标字母大的最小字母。如果目标字母大于或等于列表中的最后一个字母,则列表中不存在比目标字母大的字母,根据循环出现的顺序,列表的首个字母是比目标字母大的最小字母。

###Python

class Solution:
    def nextGreatestLetter(self, letters: List[str], target: str) -> str:
        return next((letter for letter in letters if letter > target), letters[0])

###Java

class Solution {
    public char nextGreatestLetter(char[] letters, char target) {
        int length = letters.length;
        char nextGreater = letters[0];
        for (int i = 0; i < length; i++) {
            if (letters[i] > target) {
                nextGreater = letters[i];
                break;
            }
        }
        return nextGreater;
    }
}

###C#

public class Solution {
    public char NextGreatestLetter(char[] letters, char target) {
        int length = letters.Length;
        char nextGreater = letters[0];
        for (int i = 0; i < length; i++) {
            if (letters[i] > target) {
                nextGreater = letters[i];
                break;
            }
        }
        return nextGreater;
    }
}

###C++

class Solution {
public:
    char nextGreatestLetter(vector<char>& letters, char target) {
        for (char letter : letters) {
            if (letter > target) {
                return letter;
            }
        }
        return letters[0];
    }
};

###C

char nextGreatestLetter(char* letters, int lettersSize, char target){
    for (int i = 0; i < lettersSize; i++) {
        if (letters[i] > target) {
            return letters[i];
        }
    }
    return letters[0];
}

###JavaScript

var nextGreatestLetter = function(letters, target) {
    const length = letters.length;
    let nextGreater = letters[0];
    for (let i = 0; i < length; i++) {
        if (letters[i] > target) {
            nextGreater = letters[i];
            break;
        }
    }
    return nextGreater;
};

###go

func nextGreatestLetter(letters []byte, target byte) byte {
    for _, letter := range letters {
        if letter > target {
            return letter
        }
    }
    return letters[0]
}

复杂度分析

  • 时间复杂度:$O(n)$,其中 $n$ 是列表 $\textit{letters}$ 的长度。需要遍历列表一次寻找比目标字母大的最小字母。

  • 空间复杂度:$O(1)$。

方法二:二分查找

利用列表有序的特点,可以使用二分查找降低时间复杂度。

首先比较目标字母和列表中的最后一个字母,当目标字母大于或等于列表中的最后一个字母时,答案是列表的首个字母。当目标字母小于列表中的最后一个字母时,列表中一定存在比目标字母大的字母,可以使用二分查找得到比目标字母大的最小字母。

初始时,二分查找的范围是整个列表的下标范围。每次比较当前下标处的字母和目标字母,如果当前下标处的字母大于目标字母,则在当前下标以及当前下标的左侧继续查找,否则在当前下标的右侧继续查找。

###Python

class Solution:
    def nextGreatestLetter(self, letters: List[str], target: str) -> str:
        return letters[bisect_right(letters, target)] if target < letters[-1] else letters[0]

###Java

class Solution {
    public char nextGreatestLetter(char[] letters, char target) {
        int length = letters.length;
        if (target >= letters[length - 1]) {
            return letters[0];
        }
        int low = 0, high = length - 1;
        while (low < high) {
            int mid = (high - low) / 2 + low;
            if (letters[mid] > target) {
                high = mid;
            } else {
                low = mid + 1;
            }
        }
        return letters[low];
    }
}

###C#

public class Solution {
    public char NextGreatestLetter(char[] letters, char target) {
        int length = letters.Length;
        if (target >= letters[length - 1]) {
            return letters[0];
        }
        int low = 0, high = length - 1;
        while (low < high) {
            int mid = (high - low) / 2 + low;
            if (letters[mid] > target) {
                high = mid;
            } else {
                low = mid + 1;
            }
        }
        return letters[low];
    }
}

###C++

class Solution {
public:
    char nextGreatestLetter(vector<char> &letters, char target) {
        return target < letters.back() ? *upper_bound(letters.begin(), letters.end() - 1, target) : letters[0];
    }
};

###C

char nextGreatestLetter(char* letters, int lettersSize, char target){
    if (target >= letters[lettersSize - 1]) {
        return letters[0];
    }
    int low = 0, high = lettersSize - 1;
    while (low < high) {
        int mid = (high - low) / 2 + low;
        if (letters[mid] > target) {
            high = mid;
        } else {
            low = mid + 1;
        }
    }
    return letters[low];
}

###JavaScript

var nextGreatestLetter = function(letters, target) {
    const length = letters.length;
    if (target >= letters[length - 1]) {
        return letters[0];
    }
    let low = 0, high = length - 1;
    while (low < high) {
        const mid = Math.floor((high - low) / 2) + low;
        if (letters[mid] > target) {
            high = mid;
        } else {
            low = mid + 1;
        }
    }
    return letters[low];
};

###go

func nextGreatestLetter(letters []byte, target byte) byte {
    if target >= letters[len(letters)-1] {
        return letters[0]
    }
    i := sort.Search(len(letters)-1, func(i int) bool { return letters[i] > target })
    return letters[i]
}

复杂度分析

  • 时间复杂度:$O(\log n)$,其中 $n$ 是列表 $\textit{letters}$ 的长度。二分查找的时间复杂度是 $O(\log n)$。

  • 空间复杂度:$O(1)$。

每日一题-转换字符串的最小成本 II🔴

给你两个下标从 0 开始的字符串 sourcetarget ,它们的长度均为 n 并且由 小写 英文字母组成。

另给你两个下标从 0 开始的字符串数组 originalchanged ,以及一个整数数组 cost ,其中 cost[i] 代表将字符串 original[i] 更改为字符串 changed[i] 的成本。

你从字符串 source 开始。在一次操作中,如果 存在 任意 下标 j 满足 cost[j] == z  、original[j] == x 以及 changed[j] == y ,你就可以选择字符串中的 子串 x 并以 z 的成本将其更改为 y 。 你可以执行 任意数量 的操作,但是任两次操作必须满足 以下两个 条件 之一

  • 在两次操作中选择的子串分别是 source[a..b]source[c..d] ,满足 b < c  d < a 。换句话说,两次操作中选择的下标 不相交
  • 在两次操作中选择的子串分别是 source[a..b]source[c..d] ,满足 a == c b == d 。换句话说,两次操作中选择的下标 相同

返回将字符串 source 转换为字符串 target 所需的 最小 成本。如果不可能完成转换,则返回 -1

注意,可能存在下标 ij 使得 original[j] == original[i]changed[j] == changed[i]

 

示例 1:

输入:source = "abcd", target = "acbe", original = ["a","b","c","c","e","d"], changed = ["b","c","b","e","b","e"], cost = [2,5,5,1,2,20]
输出:28
解释:将 "abcd" 转换为 "acbe",执行以下操作:
- 将子串 source[1..1] 从 "b" 改为 "c" ,成本为 5 。
- 将子串 source[2..2] 从 "c" 改为 "e" ,成本为 1 。
- 将子串 source[2..2] 从 "e" 改为 "b" ,成本为 2 。
- 将子串 source[3..3] 从 "d" 改为 "e" ,成本为 20 。
产生的总成本是 5 + 1 + 2 + 20 = 28 。 
可以证明这是可能的最小成本。

示例 2:

输入:source = "abcdefgh", target = "acdeeghh", original = ["bcd","fgh","thh"], changed = ["cde","thh","ghh"], cost = [1,3,5]
输出:9
解释:将 "abcdefgh" 转换为 "acdeeghh",执行以下操作:
- 将子串 source[1..3] 从 "bcd" 改为 "cde" ,成本为 1 。
- 将子串 source[5..7] 从 "fgh" 改为 "thh" ,成本为 3 。可以执行此操作,因为下标 [5,7] 与第一次操作选中的下标不相交。
- 将子串 source[5..7] 从 "thh" 改为 "ghh" ,成本为 5 。可以执行此操作,因为下标 [5,7] 与第一次操作选中的下标不相交,且与第二次操作选中的下标相同。
产生的总成本是 1 + 3 + 5 = 9 。
可以证明这是可能的最小成本。

示例 3:

输入:source = "abcdefgh", target = "addddddd", original = ["bcd","defgh"], changed = ["ddd","ddddd"], cost = [100,1578]
输出:-1
解释:无法将 "abcdefgh" 转换为 "addddddd" 。
如果选择子串 source[1..3] 执行第一次操作,以将 "abcdefgh" 改为 "adddefgh" ,你无法选择子串 source[3..7] 执行第二次操作,因为两次操作有一个共用下标 3 。
如果选择子串 source[3..7] 执行第一次操作,以将 "abcdefgh" 改为 "abcddddd" ,你无法选择子串 source[1..3] 执行第二次操作,因为两次操作有一个共用下标 3 。

 

提示:

  • 1 <= source.length == target.length <= 1000
  • sourcetarget 均由小写英文字母组成
  • 1 <= cost.length == original.length == changed.length <= 100
  • 1 <= original[i].length == changed[i].length <= source.length
  • original[i]changed[i] 均由小写英文字母组成
  • original[i] != changed[i]
  • 1 <= cost[i] <= 106

较为综合的一道题目(动态规划 + 字典树 + 最短路)

解题目分两步:

第一步,搭建思路,先不考虑具体实现。

第二步,根据框架实现代码,使代码满足时间复杂度和空间复杂度的要求。

有的题目第一步的比较难,有的比较侧重第二步。这个题目就属于后者。

解法思路

因为题目中要求任意两个变换要么在相同区间、要么不相交,所以可用动态规划求解:令 $dp[i]$ = 字符串 source[i...n-1] 变为字符串 target[i...n-1] 的最小代价。状态转移为,枚举 $j = i \sim n-1$,$dp[i] = \min(dp[i], f(i, j) + dp[j+1])$,其中,$f(i, j)$ 为字符串 source[i...j] 变为字符串 target[i...j] 的最小代价。

关键问题是如何求最小代价。根据题目描述,可以把字符串看作 ”点“,题目描述的 originalchanged 的变换看作 ”边“,那么可以构造出一个有向图。这样,对于子串对 source[a...b]target[c...d],改变所需要的最小代价就是字符串 source[a...b] 到字符串 target[c...d] 的最短路。

实现方式

现在我们拿到了 source[i...j]target[i...j],希望能够在 $O(1)$ 的时间内求出最小代价(否则仅枚举子字符串就需要 $O(n^2)$ 的时间,如果求最短路不是 $O(1)$,那么将超时)。但是,无论是用哈希表,还是用 set 存字符串,都不能做到 $O(1)$。

但是,注意看之前的转移方程,我们枚举 $j = i \sim n-1$ ,相当于从左到右依次 “添加” 字符。那么什么数据结构可以做到 $O(1)$ 时间来 “添加” 一个字符?字典树可以做到。

字符串 hash 也可以,但是 hash 有一定的不确定性,这里选择字典树。

这样答案就呼之欲出了:

  1. 首先用字典树,把 sourcetarget 中的字符串映射成整数,构造邻接矩阵,得出一个边数不超过 $200$ 的有向图;
  2. 用 floyd 算法求出所有字符串之间的最短路;
  3. 用动态规划求解原问题。我们可以在从左到右枚举 $j$ 的过程中,通过字典树分别求出 source[i...j]target[i...j] 所对应的整数(由于每次都是 “添加” 字符,故每次操作都是 $O(1)$ 的),然后通过查询邻接矩阵的方式得出两个字符串之间的最短距离。

代码

###c++

int f[200005][26], g[200005];

class Solution {
public:
    long long minimumCost(string source, string target, vector<string>& original, vector<string>& changed, vector<int>& cost) {
        int n = original.size() * 2;
        long long d[n][n];
        memset(d, 0x3f, sizeof(d));
        long long inf = 0x3f3f3f3f3f3f3f3fll;
        
        int p = 1;
        memset(f[0], -1, sizeof(f[0]));
        g[0] = -1;
        
        auto insert = [&](string& s) {
            int cur = 0;
            for(char c : s) {
                if(f[cur][c-'a'] == -1) {
                    f[cur][c-'a'] = p;
                    memset(f[p], -1, sizeof(f[p]));
                    g[p] = -1;
                    p++;
                }
                cur = f[cur][c-'a'];
            }
            return cur;
        };
        
        int m = 0;
        for(int i = 0; i < original.size(); ++i) {
            int from = insert(original[i]), to = insert(changed[i]);
            if(g[from] == -1)
                g[from] = m++;
            if(g[to] == -1)
                g[to] = m++;
            d[g[from]][g[to]] = min(d[g[from]][g[to]], (long long)cost[i]);
        }
        
        for(int k = 0; k < m; ++k) {
            for(int i = 0; i < m; ++i) {
                for(int j = 0; j < m; ++j) {
                    d[i][j] = min(d[i][j], d[i][k] + d[k][j]);
                }
            }
        }
        
        long long dp[source.size() + 1];
        dp[source.size()] = 0;
        
        for(int i = source.size() - 1; i >= 0; --i) {
            dp[i] = inf;
            if(source[i] == target[i])
                dp[i] = dp[i+1];
            for(int j = i, cur1 = 0, cur2 = 0; j < source.size(); ++j) {
                cur1 = f[cur1][source[j]-'a'];
                cur2 = f[cur2][target[j]-'a'];
                if(cur1 == -1 || cur2 == -1) break;
                if(g[cur1] != -1 && g[cur2] != -1)
                    dp[i] = min(dp[i], d[g[cur1]][g[cur2]] + dp[j+1]);
            }
        }
        
        if(dp[0] >= inf) return -1;
        return dp[0];
    }
};

字典树+Floyd+记忆化搜索/递推(Python/Java/C++/Go)

思路

  1. 把每个字符串转换成一个整数编号,这一步可以用字典树完成。见 208. 实现 Trie (前缀树)原理讲解
  2. 建图,从 $\textit{original}[i]$ 向 $\textit{changed}[i]$ 连边,边权为 $\textit{cost}[i]$。
  3. Floyd 算法求图中任意两点最短路,得到 $\textit{dis}$ 矩阵,原理请看【图解】带你发明 Floyd 算法!这里得到的 $\textit{dis}[i][j]$ 表示编号为 $i$ 的子串,通过若干次替换操作变成编号为 $j$ 的子串的最小成本。
  4. 动态规划。定义 $\textit{dfs}(i)$ 表示从 $\textit{source}[i]$ 开始向后修改的最小成本。
  5. 如果 $\textit{source}[i] = \textit{target}[i]$,可以不修改,$\textit{dfs}(i) = \textit{dfs}(i+1)$。
  6. 也可以从 $\textit{source}[i]$ 开始向后修改,利用字典树快速判断 $\textit{source}$ 和 $\textit{target}$ 的下标从 $i$ 到 $j$ 的子串是否在 $\textit{original}$ 和 $\textit{changed}$ 中,如果在就用 $\textit{dis}[x][y] + \textit{dfs}(j+1)$ 更新 $\textit{dfs}(i)$ 的最小值,其中 $x$ 和 $y$ 分别是 $\textit{source}$ 和 $\textit{target}$ 的这段子串对应的编号。
  7. 递归边界 $\textit{dfs}(n) = 0$。
  8. 递归入口 $\textit{dfs}(0)$,即为答案。如果答案是无穷大则返回 $-1$。

本题视频讲解

写法一:记忆化搜索

###py

class Solution:
    def minimumCost(self, source: str, target: str, original: List[str], changed: List[str], cost: List[int]) -> int:
        len_to_strs = defaultdict(set)
        dis = defaultdict(lambda: defaultdict(lambda: inf))
        for x, y, c in zip(original, changed, cost):
            len_to_strs[len(x)].add(x)  # 按照长度分组
            len_to_strs[len(y)].add(y)
            dis[x][y] = min(dis[x][y], c)
            dis[x][x] = 0
            dis[y][y] = 0

        # 不同长度的字符串必然在不同的连通块中,分别计算 Floyd
        for strs in len_to_strs.values():
            for k in strs:
                for i in strs:
                    if dis[i][k] == inf:  # 加上这句话,巨大优化!
                        continue
                    for j in strs:
                        dis[i][j] = min(dis[i][j], dis[i][k] + dis[k][j])

        # 返回把 source[:i] 变成 target[:i] 的最小成本
        @cache
        def dfs(i: int) -> int:
            if i == 0:
                return 0
            res = inf
            if source[i - 1] == target[i - 1]:
                res = dfs(i - 1)  # 不修改 source[i]
            for size, strs in len_to_strs.items():  # 枚举子串长度
                if i < size:
                    continue
                s = source[i - size: i]
                t = target[i - size: i]
                if s in strs and t in strs:  # 可以替换
                    res = min(res, dis[s][t] + dfs(i - size))
            return res
        ans = dfs(len(source))
        return ans if ans < inf else -1

###py

class Node:
    __slots__ = 'son', 'sid'

    def __init__(self):
        self.son = [None] * 26
        self.sid = -1  # 字符串的编号

class Solution:
    def minimumCost(self, source: str, target: str, original: List[str], changed: List[str], cost: List[int]) -> int:
        ord_a = ord('a')
        root = Node()
        sid = 0

        def put(s: str) -> int:
            o = root
            for c in s:
                i = ord(c) - ord_a
                if o.son[i] is None:
                    o.son[i] = Node()
                o = o.son[i]
            if o.sid < 0:
                nonlocal sid
                o.sid = sid
                sid += 1
            return o.sid

        # 初始化距离矩阵
        m = len(cost)
        dis = [[inf] * (m * 2) for _ in range(m * 2)]
        for x, y, c in zip(original, changed, cost):
            x = put(x)
            y = put(y)
            dis[x][y] = min(dis[x][y], c)

        # Floyd 求任意两点最短路
        for k in range(sid):
            for i in range(sid):
                if dis[i][k] == inf:  # 加上这句话,巨大优化!
                    continue
                for j in range(sid):
                    dis[i][j] = min(dis[i][j], dis[i][k] + dis[k][j])

        n = len(source)
        @cache
        def dfs(i: int) -> int:
            if i >= n:
                return 0
            res = inf
            if source[i] == target[i]:
                res = dfs(i + 1)  # 不修改 source[i]
            p, q = root, root
            for j in range(i, n):
                p = p.son[ord(source[j]) - ord_a]
                q = q.son[ord(target[j]) - ord_a]
                if p is None or q is None:
                    break
                if p.sid < 0 or q.sid < 0:
                    continue
                # 修改从 i 到 j 的这一段
                res = min(res, dis[p.sid][q.sid] + dfs(j + 1))
            return res
        ans = dfs(0)
        return ans if ans < inf else -1

###java

class Node {
    Node[] son = new Node[26];
    int sid = -1; // 字符串的编号
}

class Solution {
    private Node root = new Node();
    private int sid = 0;
    private char[] s, t;
    private int[][] dis;
    private long[] memo;

    public long minimumCost(String source, String target, String[] original, String[] changed, int[] cost) {
        // 初始化距离矩阵
        int m = cost.length;
        dis = new int[m * 2][m * 2];
        for (int i = 0; i < dis.length; i++) {
            Arrays.fill(dis[i], Integer.MAX_VALUE / 2);
            dis[i][i] = 0;
        }
        for (int i = 0; i < cost.length; i++) {
            int x = put(original[i]);
            int y = put(changed[i]);
            dis[x][y] = Math.min(dis[x][y], cost[i]);
        }

        // Floyd 求任意两点最短路
        for (int k = 0; k < sid; k++) {
            for (int i = 0; i < sid; i++) {
                if (dis[i][k] == Integer.MAX_VALUE / 2) {
                    continue;
                }
                for (int j = 0; j < sid; j++) {
                    dis[i][j] = Math.min(dis[i][j], dis[i][k] + dis[k][j]);
                }
            }
        }

        s = source.toCharArray();
        t = target.toCharArray();
        memo = new long[s.length];
        Arrays.fill(memo, -1);
        long ans = dfs(0);
        return ans < Long.MAX_VALUE / 2 ? ans : -1;
    }

    private int put(String s) {
        Node o = root;
        for (char b : s.toCharArray()) {
            int i = b - 'a';
            if (o.son[i] == null) {
                o.son[i] = new Node();
            }
            o = o.son[i];
        }
        if (o.sid < 0) {
            o.sid = sid++;
        }
        return o.sid;
    }

    private long dfs(int i) {
        if (i >= s.length) {
            return 0;
        }
        if (memo[i] != -1) { // 之前算过
            return memo[i];
        }
        long res = Long.MAX_VALUE / 2;
        if (s[i] == t[i]) {
            res = dfs(i + 1); // 不修改 source[i]
        }
        Node p = root, q = root;
        for (int j = i; j < s.length; j++) {
            p = p.son[s[j] - 'a'];
            q = q.son[t[j] - 'a'];
            if (p == null || q == null) {
                break;
            }
            if (p.sid < 0 || q.sid < 0) {
                continue;
            }
            // 修改从 i 到 j 的这一段
            int d = dis[p.sid][q.sid];
            if (d < Integer.MAX_VALUE / 2) {
                res = Math.min(res, d + dfs(j + 1));
            }
        }
        return memo[i] = res; // 记忆化
    }
}

###cpp

struct Node {
    Node* son[26]{};
    int sid = -1; // 字符串的编号
};

class Solution {
public:
    long long minimumCost(string source, string target, vector<string>& original, vector<string>& changed, vector<int>& cost) {
        Node* root = new Node();
        int sid = 0;
        auto put = [&](string& s) -> int {
            Node* o = root;
            for (char b : s) {
                int i = b - 'a';
                if (o->son[i] == nullptr) {
                    o->son[i] = new Node();
                }
                o = o->son[i];
            }
            if (o->sid < 0) {
                o->sid = sid++;
            }
            return o->sid;
        };

        // 初始化距离矩阵
        int m = cost.size();
        vector dis(m * 2, vector<int>(m * 2, INT_MAX / 2));
        for (int i = 0; i < m * 2; i++) {
            dis[i][i] = 0;
        }
        for (int i = 0; i < m; i++) {
            int x = put(original[i]);
            int y = put(changed[i]);
            dis[x][y] = min(dis[x][y], cost[i]);
        }

        // Floyd 求任意两点最短路
        for (int k = 0; k < sid; k++) {
            for (int i = 0; i < sid; i++) {
                if (dis[i][k] == INT_MAX / 2) { // 加上这句话,巨大优化!
                    continue;
                }
                for (int j = 0; j < sid; j++) {
                    dis[i][j] = min(dis[i][j], dis[i][k] + dis[k][j]);
                }
            }
        }

        int n = source.size();
        vector<long long> memo(n, -1);
        auto dfs = [&](this auto&& dfs, int i) -> long long {
            if (i >= n) {
                return 0;
            }
            auto& res = memo[i]; // 注意这里是引用
            if (res != -1) {
                return res;
            }
            res = LONG_LONG_MAX / 2;
            if (source[i] == target[i]) {
                res = dfs(i + 1); // 不修改 source[i]
            }
            Node* p = root;
            Node* q = root;
            for (int j = i; j < n; j++) {
                p = p->son[source[j] - 'a'];
                q = q->son[target[j] - 'a'];
                if (p == nullptr || q == nullptr) {
                    break;
                }
                if (p->sid < 0 || q->sid < 0) {
                    continue;
                }
                // 修改从 i 到 j 的这一段
                int d = dis[p->sid][q->sid];
                if (d < INT_MAX / 2) {
                    res = min(res, dis[p->sid][q->sid] + dfs(j + 1));
                }
            }
            return res;
        };
        long long ans = dfs(0);
        return ans < LONG_LONG_MAX / 2 ? ans : -1;
    }
};

###go

func minimumCost(source, target string, original, changed []string, cost []int) int64 {
const inf = math.MaxInt / 2
type node struct {
son [26]*node
sid int // 字符串的编号
}
root := &node{}
sid := 0
put := func(s string) int {
o := root
for _, b := range s {
b -= 'a'
if o.son[b] == nil {
o.son[b] = &node{sid: -1}
}
o = o.son[b]
}
if o.sid < 0 {
o.sid = sid
sid++
}
return o.sid
}

// 初始化距离矩阵
m := len(cost)
dis := make([][]int, m*2)
for i := range dis {
dis[i] = make([]int, m*2)
for j := range dis[i] {
if j != i {
dis[i][j] = inf
}
}
}
for i, c := range cost {
x := put(original[i])
y := put(changed[i])
dis[x][y] = min(dis[x][y], c)
}

// Floyd 求任意两点最短路
for k := 0; k < sid; k++ {
for i := 0; i < sid; i++ {
if dis[i][k] == inf {
continue
}
for j := 0; j < sid; j++ {
dis[i][j] = min(dis[i][j], dis[i][k]+dis[k][j])
}
}
}

n := len(source)
memo := make([]int, n)
for i := range memo {
memo[i] = -1
}
var dfs func(int) int
dfs = func(i int) int {
if i >= n {
return 0
}
ptr := &memo[i]
if *ptr != -1 {
return *ptr
}
res := inf
if source[i] == target[i] {
res = dfs(i + 1) // 不修改 source[i]
}
p, q := root, root
for j := i; j < n; j++ {
p = p.son[source[j]-'a']
q = q.son[target[j]-'a']
if p == nil || q == nil {
break
}
if p.sid >= 0 && q.sid >= 0 {
// 修改从 i 到 j 的这一段
res = min(res, dis[p.sid][q.sid]+dfs(j+1))
}
}
*ptr = res
return res
}
ans := dfs(0)
if ans == inf {
return -1
}
return int64(ans)
}

写法二:递推

也可以按照 动态规划入门:从记忆化搜索到递推【基础算法精讲 17】中讲的方法,1:1 翻译成递推。$f[i]$ 的含义与 $\textit{dfs}(i)$ 的含义是一样的。

###py

class Solution:
    def minimumCost(self, source: str, target: str, original: List[str], changed: List[str], cost: List[int]) -> int:
        len_to_strs = defaultdict(set)
        dis = defaultdict(lambda: defaultdict(lambda: inf))
        for x, y, c in zip(original, changed, cost):
            len_to_strs[len(x)].add(x)  # 按照长度分组
            len_to_strs[len(y)].add(y)
            dis[x][y] = min(dis[x][y], c)
            dis[x][x] = 0
            dis[y][y] = 0

        # 不同长度的字符串必然在不同的连通块中,分别计算 Floyd
        for strs in len_to_strs.values():
            for k in strs:
                for i in strs:
                    if dis[i][k] == inf:  # 加上这句话,巨大优化!
                        continue
                    for j in strs:
                        dis[i][j] = min(dis[i][j], dis[i][k] + dis[k][j])

        # f[i] 表示把 source[:i] 变成 target[:i] 的最小成本
        n = len(source)
        f = [0] + [inf] * n
        for i in range(1, n + 1):
            if source[i - 1] == target[i - 1]:
                f[i] = f[i - 1]  # 不修改 source[i]
            for size, strs in len_to_strs.items():  # 枚举子串长度
                if i < size:
                    continue
                s = source[i - size: i]
                t = target[i - size: i]
                if s in strs and t in strs:  # 可以替换
                    f[i] = min(f[i], dis[s][t] + f[i - size])
        return f[n] if f[n] < inf else -1

###py

class Node:
    __slots__ = 'son', 'sid'

    def __init__(self):
        self.son = [None] * 26
        self.sid = -1  # 字符串的编号

class Solution:
    def minimumCost(self, source: str, target: str, original: List[str], changed: List[str], cost: List[int]) -> int:
        ord_a = ord('a')
        root = Node()
        sid = 0

        def put(s: str) -> int:
            o = root
            for c in s:
                i = ord(c) - ord_a
                if o.son[i] is None:
                    o.son[i] = Node()
                o = o.son[i]
            if o.sid < 0:
                nonlocal sid
                o.sid = sid
                sid += 1
            return o.sid

        # 初始化距离矩阵
        m = len(cost)
        dis = [[inf] * (m * 2) for _ in range(m * 2)]
        for x, y, c in zip(original, changed, cost):
            x = put(x)
            y = put(y)
            dis[x][y] = min(dis[x][y], c)

        # Floyd 求任意两点最短路
        for k in range(sid):
            for i in range(sid):
                if dis[i][k] == inf:  # 加上这句话,巨大优化!
                    continue
                for j in range(sid):
                    dis[i][j] = min(dis[i][j], dis[i][k] + dis[k][j])

        n = len(source)
        f = [inf] * n + [0]
        for i in range(n - 1, -1, -1):
            if source[i] == target[i]:
                f[i] = f[i + 1]  # 不修改 source[i]
            p, q = root, root
            for j in range(i, n):
                p = p.son[ord(source[j]) - ord_a]
                q = q.son[ord(target[j]) - ord_a]
                if p is None or q is None:
                    break
                if p.sid < 0 or q.sid < 0:
                    continue
                # 修改从 i 到 j 的这一段
                f[i] = min(f[i], dis[p.sid][q.sid] + f[j + 1])
        return f[0] if f[0] < inf else -1

###java

class Node {
    Node[] son = new Node[26];
    int sid = -1; // 字符串的编号
}

class Solution {
    private Node root = new Node();
    private int sid = 0;

    public long minimumCost(String source, String target, String[] original, String[] changed, int[] cost) {
        // 初始化距离矩阵
        int m = cost.length;
        int[][] dis = new int[m * 2][m * 2];
        for (int i = 0; i < dis.length; i++) {
            Arrays.fill(dis[i], Integer.MAX_VALUE / 2);
            dis[i][i] = 0;
        }
        for (int i = 0; i < cost.length; i++) {
            int x = put(original[i]);
            int y = put(changed[i]);
            dis[x][y] = Math.min(dis[x][y], cost[i]);
        }

        // Floyd 求任意两点最短路
        for (int k = 0; k < sid; k++) {
            for (int i = 0; i < sid; i++) {
                if (dis[i][k] == Integer.MAX_VALUE / 2) {
                    continue;
                }
                for (int j = 0; j < sid; j++) {
                    dis[i][j] = Math.min(dis[i][j], dis[i][k] + dis[k][j]);
                }
            }
        }

        char[] s = source.toCharArray();
        char[] t = target.toCharArray();
        int n = s.length;
        long[] f = new long[n + 1];
        for (int i = n - 1; i >= 0; i--) {
            // 不修改 source[i]
            f[i] = s[i] == t[i] ? f[i + 1] : Long.MAX_VALUE / 2;
            Node p = root, q = root;
            for (int j = i; j < n; j++) {
                p = p.son[s[j] - 'a'];
                q = q.son[t[j] - 'a'];
                if (p == null || q == null) {
                    break;
                }
                if (p.sid < 0 || q.sid < 0) {
                    continue;
                }
                // 修改从 i 到 j 的这一段
                int d = dis[p.sid][q.sid];
                if (d < Integer.MAX_VALUE / 2) {
                    f[i] = Math.min(f[i], d + f[j + 1]);
                }
            }
        }
        return f[0] < Long.MAX_VALUE / 2 ? f[0] : -1;
    }

    private int put(String s) {
        Node o = root;
        for (char b : s.toCharArray()) {
            int i = b - 'a';
            if (o.son[i] == null) {
                o.son[i] = new Node();
            }
            o = o.son[i];
        }
        if (o.sid < 0) {
            o.sid = sid++;
        }
        return o.sid;
    }
}

###cpp

struct Node {
    Node* son[26]{};
    int sid = -1; // 字符串的编号
};

class Solution {
public:
    long long minimumCost(string source, string target, vector<string>& original, vector<string>& changed, vector<int>& cost) {
        Node* root = new Node();
        int sid = 0;
        auto put = [&](string& s) -> int {
            Node* o = root;
            for (char b: s) {
                int i = b - 'a';
                if (o->son[i] == nullptr) {
                    o->son[i] = new Node();
                }
                o = o->son[i];
            }
            if (o->sid < 0) {
                o->sid = sid++;
            }
            return o->sid;
        };

        // 初始化距离矩阵
        int m = cost.size();
        vector dis(m * 2, vector<int>(m * 2, INT_MAX / 2));
        for (int i = 0; i < m * 2; i++) {
            dis[i][i] = 0;
        }
        for (int i = 0; i < m; i++) {
            int x = put(original[i]);
            int y = put(changed[i]);
            dis[x][y] = min(dis[x][y], cost[i]);
        }

        // Floyd 求任意两点最短路
        for (int k = 0; k < sid; k++) {
            for (int i = 0; i < sid; i++) {
                if (dis[i][k] == INT_MAX / 2) { // 加上这句话,巨大优化!
                    continue;
                }
                for (int j = 0; j < sid; j++) {
                    dis[i][j] = min(dis[i][j], dis[i][k] + dis[k][j]);
                }
            }
        }

        int n = source.size();
        vector<long long> f(n + 1);
        for (int i = n - 1; i >= 0; i--) {
            // 不修改 source[i]
            f[i] = source[i] == target[i] ? f[i + 1] : LONG_LONG_MAX / 2;
            Node* p = root;
            Node* q = root;
            for (int j = i; j < n; j++) {
                p = p->son[source[j] - 'a'];
                q = q->son[target[j] - 'a'];
                if (p == nullptr || q == nullptr) {
                    break;
                }
                if (p->sid < 0 || q->sid < 0) {
                    continue;
                }
                // 修改从 i 到 j 的这一段
                int d = dis[p->sid][q->sid];
                if (d < INT_MAX / 2) {
                    f[i] = min(f[i], dis[p->sid][q->sid] + f[j + 1]);
                }
            }
        }
        return f[0] < LONG_LONG_MAX / 2 ? f[0] : -1;
    }
};

###go

func minimumCost(source, target string, original, changed []string, cost []int) int64 {
const inf = math.MaxInt / 2
type node struct {
son [26]*node
sid int // 字符串的编号
}
root := &node{}
sid := 0
put := func(s string) int {
o := root
for _, b := range s {
b -= 'a'
if o.son[b] == nil {
o.son[b] = &node{sid: -1}
}
o = o.son[b]
}
if o.sid < 0 {
o.sid = sid
sid++
}
return o.sid
}

// 初始化距离矩阵
m := len(cost)
dis := make([][]int, m*2)
for i := range dis {
dis[i] = make([]int, m*2)
for j := range dis[i] {
if j != i {
dis[i][j] = inf
}
}
}
for i, c := range cost {
x := put(original[i])
y := put(changed[i])
dis[x][y] = min(dis[x][y], c)
}

// Floyd 求任意两点最短路
for k := 0; k < sid; k++ {
for i := 0; i < sid; i++ {
if dis[i][k] == inf {
continue
}
for j := 0; j < sid; j++ {
dis[i][j] = min(dis[i][j], dis[i][k]+dis[k][j])
}
}
}

n := len(source)
f := make([]int, n+1)
for i := n - 1; i >= 0; i-- {
if source[i] == target[i] {
f[i] = f[i+1] // 不修改 source[i]
} else {
f[i] = inf
}
p, q := root, root
for j := i; j < n; j++ {
p = p.son[source[j]-'a']
q = q.son[target[j]-'a']
if p == nil || q == nil {
break
}
if p.sid >= 0 && q.sid >= 0 {
// 修改从 i 到 j 的这一段
f[i] = min(f[i], dis[p.sid][q.sid]+f[j+1])
}
}
}
if f[0] == inf {
return -1
}
return int64(f[0])
}

复杂度分析

  • 时间复杂度:$\mathcal{O}(n^2+mn+m^3)$,其中 $n$ 为 $\textit{source}$ 的长度,$m$ 为 $\textit{cost}$ 的长度。DP 需要 $\mathcal{O}(n^2)$ 的时间,把 $2m$ 个长度至多为 $n$ 的字符串插入字典树需要 $\mathcal{O}(mn)$ 的时间,Floyd 需要 $\mathcal{O}(m^3)$ 的时间。
  • 空间复杂度:$\mathcal{O}(n+mn+m^2)$。DP 需要 $\mathcal{O}(n)$ 的空间,把 $2m$ 个长度至多为 $n$ 的字符串插入字典树需要 $\mathcal{O}(mn)$ 的空间,Floyd 需要 $\mathcal{O}(m^2)$ 的空间。

专题训练

  1. 动态规划题单的「§5.2 最优划分」。
  2. 图论题单的「§3.2 全源最短路:Floyd 算法」。
  3. 数据结构题单的「六、字典树(trie)」。

分类题单

如何科学刷题?

  1. 滑动窗口与双指针(定长/不定长/单序列/双序列/三指针/分组循环)
  2. 二分算法(二分答案/最小化最大值/最大化最小值/第K小)
  3. 单调栈(基础/矩形面积/贡献法/最小字典序)
  4. 网格图(DFS/BFS/综合应用)
  5. 位运算(基础/性质/拆位/试填/恒等式/思维)
  6. 图论算法(DFS/BFS/拓扑排序/基环树/最短路/最小生成树/网络流)
  7. 动态规划(入门/背包/划分/状态机/区间/状压/数位/数据结构优化/树形/博弈/概率期望)
  8. 常用数据结构(前缀和/差分/栈/队列/堆/字典树/并查集/树状数组/线段树)
  9. 数学算法(数论/组合/概率期望/博弈/计算几何/随机算法)
  10. 贪心与思维(基本贪心策略/反悔/区间/字典序/数学/思维/脑筋急转弯/构造)
  11. 链表、树与回溯(前后指针/快慢指针/DFS/BFS/直径/LCA)
  12. 字符串(KMP/Z函数/Manacher/字符串哈希/AC自动机/后缀数组/子序列自动机)

我的题解精选(已分类)

欢迎关注 B站@灵茶山艾府

dijstra 模板题(对每个字母运行一次dijstra算法)

Problem: 2976. 转换字符串的最小成本 I

[TOC]

思路

dijstra 模板题(对每个字母运行一次dijstra算法)

Code

###Python3

class Solution:
    def minimumCost(self, source: str, target: str, original: List[str], changed: List[str], cost: List[int]) -> int:
        # 先建图,有向图(26个字母映射到0-25)
        g = [[] for _ in range(26)]
        dt = defaultdict()
        n = len(original)
        for i, x in enumerate(original):
            y = changed[i]
            c = cost[i]
            x0 = ord(x) - ord('a')
            y0 = ord(y) - ord('a')
            if (x0, y0) not in dt.keys():
                dt[(x0, y0)] = c 
                g[x0].append(y0)
            else:
                dt[(x0, y0)] = min(c, dt[(x0, y0)])
        
        # 运行dijstra 算法,统计所有点(26个字母)能到达的其他点的最短距离
        final_dt = defaultdict()
        
        # dijstra算法 统计x能到达的所有点点最短距路
        def dfs (x: int) -> None:
            dist = {} # 本次dijstra统计的距离,后续加到final_dt中
            unvisited_nodes = {} # 先用广度优先搜索将x能达到的点初始化到inf
            q = [(0, x, x)]
            vis = [0] * 26
            cur = [x]
            vis[x] = 1
            while cur:
                pre = cur
                cur = []
                for el in pre:
                    for y in g[el]:
                        if vis[y] == 0:
                            unvisited_nodes[(x, y)] = inf
                            vis[y] = 1
                            cur.append(y)
            # 开始最小路径搜索
            unvisited_nodes[(x, x)] = 0
            seen = set()
            # 使用 dijstra算法计算达到各店的最短值
            while unvisited_nodes:
                current_distance, x1, y1 = heapq.heappop(q)
                if y1 in seen:
                    continue
                seen.add(y1)
                for el in g[y1]:
                    if (x, el) not in unvisited_nodes: continue
                    new_distance = current_distance + dt[(y1,el)]
                    if new_distance < unvisited_nodes[(x, el)]:
                        unvisited_nodes[(x, el)] = new_distance
                        heapq.heappush(q, (new_distance, x, el))
                dist[(x,y1)] = current_distance
                unvisited_nodes.pop((x, y1))
            for k, v in dist.items():
                final_dt[k] = v

        # 对每个字母运行dijstra算法
        for i in range(26):
            dfs(i)
        ans = 0
        # 统计完,开始对每个字母改变计算答案,如果达不到,返回-1
        for i, x in enumerate(source):
            x1 = ord(x) - ord('a')
            y1 = ord(target[i]) - ord('a')
            if x1 != y1:
                if (x1, y1) not in final_dt.keys():
                    return - 1
                else:
                    ans += final_dt[(x1, y1)]
        return ans

每日一题-转换字符串的最小成本 I🟡

给你两个下标从 0 开始的字符串 sourcetarget ,它们的长度均为 n 并且由 小写 英文字母组成。

另给你两个下标从 0 开始的字符数组 originalchanged ,以及一个整数数组 cost ,其中 cost[i] 代表将字符 original[i] 更改为字符 changed[i] 的成本。

你从字符串 source 开始。在一次操作中,如果 存在 任意 下标 j 满足 cost[j] == z  、original[j] == x 以及 changed[j] == y 。你就可以选择字符串中的一个字符 x 并以 z 的成本将其更改为字符 y

返回将字符串 source 转换为字符串 target 所需的 最小 成本。如果不可能完成转换,则返回 -1

注意,可能存在下标 ij 使得 original[j] == original[i]changed[j] == changed[i]

 

示例 1:

输入:source = "abcd", target = "acbe", original = ["a","b","c","c","e","d"], changed = ["b","c","b","e","b","e"], cost = [2,5,5,1,2,20]
输出:28
解释:将字符串 "abcd" 转换为字符串 "acbe" :
- 更改下标 1 处的值 'b' 为 'c' ,成本为 5 。
- 更改下标 2 处的值 'c' 为 'e' ,成本为 1 。
- 更改下标 2 处的值 'e' 为 'b' ,成本为 2 。
- 更改下标 3 处的值 'd' 为 'e' ,成本为 20 。
产生的总成本是 5 + 1 + 2 + 20 = 28 。
可以证明这是可能的最小成本。

示例 2:

输入:source = "aaaa", target = "bbbb", original = ["a","c"], changed = ["c","b"], cost = [1,2]
输出:12
解释:要将字符 'a' 更改为 'b':
- 将字符 'a' 更改为 'c',成本为 1 
- 将字符 'c' 更改为 'b',成本为 2 
产生的总成本是 1 + 2 = 3。
将所有 'a' 更改为 'b',产生的总成本是 3 * 4 = 12 。

示例 3:

输入:source = "abcd", target = "abce", original = ["a"], changed = ["e"], cost = [10000]
输出:-1
解释:无法将 source 字符串转换为 target 字符串,因为下标 3 处的值无法从 'd' 更改为 'e' 。

 

提示:

  • 1 <= source.length == target.length <= 105
  • sourcetarget 均由小写英文字母组成
  • 1 <= cost.length== original.length == changed.length <= 2000
  • original[i]changed[i] 是小写英文字母
  • 1 <= cost[i] <= 106
  • original[i] != changed[i]
❌