普通视图

发现新文章,点击刷新页面。
昨天 — 2026年4月14日首页

python 费用流

作者 981377660LMT
2022年11月6日 13:47

代码

###python3

INF = int(1e18)

class Solution:
    def minimumTotalDistance(self, robot: List[int], factory: List[List[int]]) -> int:
        n, m = len(robot), len(factory)
        STRAT, END = n + m + 3, n + m + 4
        mcmf = MinCostMaxFlow(n + m + 10, STRAT, END)
        for i in range(n):
            mcmf.addEdge(STRAT, i, 1, 0)
        for i in range(n):
            for j in range(m):
                mcmf.addEdge(i, n + j, 1, abs(robot[i] - factory[j][0]))
        for i in range(m):
            mcmf.addEdge(n + i, END, factory[i][1], 0)
        return mcmf.work()[1]



class Edge:
    __slots__ = ("fromV", "toV", "cap", "cost", "flow")

    def __init__(self, fromV: int, toV: int, cap: int, cost: int, flow: int) -> None:
        self.fromV = fromV
        self.toV = toV
        self.cap = cap
        self.cost = cost
        self.flow = flow


class MinCostMaxFlow:
    """最小费用流的连续最短路算法复杂度为流量*最短路算法复杂度"""

    __slots__ = ("_n", "_start", "_end", "_edges", "_reGraph", "_dist", "_visited", "_curEdges")

    def __init__(self, n: int, start: int, end: int):
        """
        Args:
            n (int): 包含虚拟点在内的总点数
            start (int): (虚拟)源点
            end (int): (虚拟)汇点
        """
        assert 0 <= start < n and 0 <= end < n
        self._n = n
        self._start = start
        self._end = end
        self._edges: List["Edge"] = []
        self._reGraph: List[List[int]] = [[] for _ in range(n + 10)]  # 残量图存储的是边的下标

        self._dist = [INF] * (n + 10)
        self._visited = [False] * (n + 10)
        self._curEdges = [0] * (n + 10)

    def addEdge(self, fromV: int, toV: int, cap: int, cost: int) -> None:
        """原边索引为i 反向边索引为i^1"""
        self._edges.append(Edge(fromV, toV, cap, cost, 0))
        self._edges.append(Edge(toV, fromV, 0, -cost, 0))
        len_ = len(self._edges)
        self._reGraph[fromV].append(len_ - 2)
        self._reGraph[toV].append(len_ - 1)

    def work(self) -> Tuple[int, int]:
        """
        Returns:
            Tuple[int, int]: [最大流,最小费用]
        """
        maxFlow, minCost = 0, 0
        while self._spfa():
            # !如果流量限定为1,那么一次dfs只会找到一条费用最小的增广流
            # !如果流量限定为INF,那么一次dfs不只会找到一条费用最小的增广流
            flow = self._dfs(self._start, self._end, INF)
            maxFlow += flow
            minCost += flow * self._dist[self._end]
        return maxFlow, minCost

    def slope(self) -> List[Tuple[int, int]]:
        """
        Returns:
            List[Tuple[int, int]]: 流量为a时,最小费用是b
        """
        res = [(0, 0)]
        flow, cost = 0, 0
        while self._spfa():
            deltaFlow = self._dfs(self._start, self._end, INF)
            flow += deltaFlow
            cost += deltaFlow * self._dist[self._end]
            res.append((flow, cost))  # type: ignore
        return res

    def _spfa(self) -> bool:
        """spfa沿着最短路寻找增广路径  有负cost的边不能用dijkstra"""
        n, start, end, edges, reGraph, visited = (
            self._n,
            self._start,
            self._end,
            self._edges,
            self._reGraph,
            self._visited,
        )

        self._curEdges = [0] * n
        self._dist = dist = [INF] * n
        dist[start] = 0
        queue = deque([start])

        while queue:
            cur = queue.popleft()
            visited[cur] = False
            for edgeIndex in reGraph[cur]:
                edge = edges[edgeIndex]
                cost, remain, next = edge.cost, edge.cap - edge.flow, edge.toV
                if remain > 0 and dist[cur] + cost < dist[next]:
                    dist[next] = dist[cur] + cost
                    if not visited[next]:
                        visited[next] = True
                        if queue and dist[queue[0]] > dist[next]:
                            queue.appendleft(next)
                        else:
                            queue.append(next)

        return dist[end] != INF

    def _dfs(self, cur: int, end: int, flow: int) -> int:
        if cur == end:
            return flow

        visited, reGraph, curEdges, edges, dist = (
            self._visited,
            self._reGraph,
            self._curEdges,
            self._edges,
            self._dist,
        )

        visited[cur] = True
        res = flow
        index = curEdges[cur]
        while res and index < len(reGraph[cur]):
            edgeIndex = reGraph[cur][index]
            next, remain = edges[edgeIndex].toV, edges[edgeIndex].cap - edges[edgeIndex].flow
            if remain > 0 and not visited[next] and dist[next] == dist[cur] + edges[edgeIndex].cost:
                delta = self._dfs(next, end, remain if remain < res else res)
                res -= delta
                edges[edgeIndex].flow += delta
                edges[edgeIndex ^ 1].flow -= delta
            curEdges[cur] += 1
            index = curEdges[cur]

        visited[cur] = False
        return flow - res

###python3

INF = int(1e18)

class Solution:
    def minimumTotalDistance(self, robot: List[int], factory: List[List[int]]) -> int:
        boys, girls = robot, []  
        for pos, limit in factory:
            girls.extend([pos] * limit)  
        costMatrix = [[0] * len(girls) for _ in range(len(boys))]
        for i, pos1 in enumerate(boys):
            for j, pos2 in enumerate(girls):
                costMatrix[i][j] = -abs(pos1 - pos2)  # 最大权匹配转换为最小权匹配
        return -KM(costMatrix)[0]

        
def KM(costMatrix: List[List[int]]) -> Tuple[int, Tuple[List[int], List[int]]]:
    """KM算法求带权二分图的最大权匹配

    Args
    ----------
    costMatrix (List[List[int]]):
        二分图的权值矩阵,不存在的边应初始化为`-INF`

    Returns
    ----------
    Tuple[int, Tuple[List[int], List[int]]]:
        `最大权匹配值, 匹配对的行索引、列索引`

    Examples
    ----------
    >>> costMatrix = [[1, 2, 3], [2, 4, 6], [3, 6, 9]]
    >>> maxSum, (rows, cols) = KuhnMunkres(costMatrix)
    >>> maxSum
    14
    >>> rows cols
    [0, 1, 2] [0, 1, 2]
    >>> sum(costMatrix[i][j] for i, j in zip(rows, cols))
    14
    """
    max_ = max(len(costMatrix), len(costMatrix[0]))
    _match = [-1] * max_  # 记录每个女生匹配到的男生 如果没有则为-1
    _graph = costMatrix  # 记录每个男生和每个女生之间的`好感度`
    _visitedBoy = set()  # 记录每一轮匹配匹配过的男生
    _visitedGirl = set()  # 记录每一轮匹配匹配过的女生
    _expBoy = [max(row) for row in costMatrix]  # 每个男生的期望值
    _expGirl = [0] * max_  # 每个女生的期望值,为0表示只要有一个男生就可以
    _slack = []  # 记录每个女生如果能被男生倾心最少还需要多少期望值
    _pre = []
    _row = len(costMatrix)
    _col = len(costMatrix[0])

    def dfs(boy: int) -> int:
        _visitedBoy.add(boy)
        for girl in range(_col):
            if girl in _visitedGirl:
                continue
            delta = _expBoy[boy] + _expGirl[girl] - _graph[boy][girl]
            # 符合要求
            if delta == 0:
                _visitedGirl.add(girl)
                _pre[girl + _row] = boy
                if _match[girl] == -1:
                    return girl + _row
                _pre[_match[girl]] = girl + _row
                nextRes = dfs(_match[girl])  # 找到增广
                if nextRes > 0:
                    return nextRes
            # 女生要得到男生的倾心 还需多少期望值
            elif _slack[boy] > delta:
                _slack[boy] = delta

        return -1

    for boy in range(_row):
        _visitedBoy.clear()
        _visitedGirl.clear()
        _slack = [INF] * _col
        _pre = [-1] * (_row + _col)
        visited = False
        cand = -1
        # 记录每轮匹配中男生女生是否被尝试匹配过
        while True:
            if not visited:
                cand = dfs(boy)
                visited = True
            else:
                for r in range(_row):
                    if _slack[r] == 0:
                        _slack[r] = INF
                        cand = dfs(r)
                        if cand > 0:
                            break

            if cand > 0:
                tmp = cand
                while tmp > 0:
                    _match[tmp - _row] = _pre[tmp]
                    tmp = _pre[_pre[tmp]]
                break
            else:
                # 如果不能找到 就降低期望值
                # 最小可降低的期望值
                delta = INF
                for c in range(_row):
                    if c in _visitedBoy and _slack[c] < delta:
                        delta = _slack[c]
                for r in range(_row):
                    if r in _visitedBoy:
                        # 所有访问过的男生降低期望值
                        _expBoy[r] -= delta
                        _slack[r] -= delta
                for c in range(_col):
                    if c in _visitedGirl:
                        # 所有访问过的女生增加期望值
                        _expGirl[c] += delta

    # 匹配完成 求出所有配对的好感度的和
    res, rows, cols = 0, [], []
    for girl, boy in enumerate(_match):
        if boy != -1:
            res += _graph[boy][girl]
            rows.append(boy)
            cols.append(girl)
    return res, (rows, cols)
❌
❌