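在 Python 中实现一个有向无环图(DAG)通常有两种场景:一是把 DAG 当作数据结构,用于依赖管理和拓扑排序(可以手写,或借助 networkx 做更复杂的图分析);二是把 DAG 当作任务编排引擎,用于生产级的工作流调度(应直接使用 Prefect、Airflow、Snakemake 等现成框架)。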
下面我将分层次展示如何实现。
如果你需要理解原理,或者只需要一个简单的依赖管理工具,可以手动实现一个包含拓扑排序的 DAG 类。
核心功能:添加节点与依赖边、环检测、拓扑排序,以及一个基于线程池的简单并行执行引擎。
from collections import defaultdict, deque
from typing import List, Dict, Set, Callable, Any, Optional
import concurrent.futures


class DAG:
    """A minimal directed acyclic graph with topological sorting and a
    thread-pool execution engine.

    An edge ``u -> v`` means ``u`` must finish before ``v`` starts. Each node
    may carry a task callable. The task is invoked with one argument: a
    snapshot dict that maps every already-completed node id to its result.
    """

    def __init__(self):
        # Adjacency list: node -> nodes that depend on it (downstream).
        self.graph = defaultdict(list)
        # Reverse adjacency list: node -> nodes it depends on (upstream).
        self.reverse_graph = defaultdict(list)
        # Set of all node ids, for O(1) membership tests.
        self.nodes = set()
        # Node ids in insertion order. Iterating this instead of the set keeps
        # topological order and task submission order reproducible across runs
        # (the iteration order of a set of str depends on hash randomization).
        self._order: List[str] = []
        # node id -> task callable
        self.tasks = {}

    def add_node(self, node_id: str, task_func: Optional[Callable] = None):
        """Register ``node_id`` and optionally attach or replace its task.

        Re-adding an existing node keeps its original position. If
        ``task_func`` is None, any previously attached task is kept.
        """
        if node_id not in self.nodes:
            self._order.append(node_id)
            self.nodes.add(node_id)
        if task_func is not None:
            self.tasks[node_id] = task_func

    def add_edge(self, from_node: str, to_node: str):
        """Add a dependency edge ``from_node -> to_node``.

        ``from_node`` must run before ``to_node``.

        Raises:
            ValueError: if either endpoint was not added with ``add_node``.
        """
        if from_node not in self.nodes or to_node not in self.nodes:
            raise ValueError("Node not found")
        self.graph[from_node].append(to_node)
        self.reverse_graph[to_node].append(from_node)

    def _in_degrees(self) -> Dict[str, int]:
        """Return a fresh mapping of node -> number of upstream edges."""
        # .get avoids inserting empty entries into the defaultdict.
        return {node: len(self.reverse_graph.get(node, ())) for node in self._order}

    def _kahn_order(self) -> List[str]:
        """Return nodes in dependency order using Kahn's algorithm.

        The algorithm is iterative, so deep graphs cannot hit Python's
        recursion limit. If the graph contains a cycle, the nodes on the cycle
        (and everything downstream of it) never reach in-degree 0. The
        returned list is then shorter than the node count.
        """
        in_degree = self._in_degrees()
        queue = deque(node for node in self._order if in_degree[node] == 0)
        sorted_order = []
        while queue:
            node = queue.popleft()
            sorted_order.append(node)
            self._update_dependents(node, in_degree, queue)
        return sorted_order

    def _has_cycle(self) -> bool:
        """Return True if the graph contains at least one directed cycle."""
        return len(self._kahn_order()) != len(self._order)

    def topological_sort(self) -> List[str]:
        """Return the nodes in a valid execution order.

        Ties are broken by insertion order.

        Raises:
            ValueError: if the graph contains a cycle.
        """
        order = self._kahn_order()
        if len(order) != len(self._order):
            raise ValueError("Graph contains a cycle! Not a DAG.")
        return order

    def execute(self, max_workers: int = 4) -> Dict[str, Any]:
        """Run every task on a thread pool, respecting dependencies.

        A task is submitted only after all of its upstream nodes have
        completed. It receives a snapshot copy of the results so far, so
        worker threads never read a dict that the main thread is mutating.
        Nodes without a task complete immediately with result None (they act
        as dependency stubs).

        Returns:
            A mapping of node id -> task result.

        Raises:
            ValueError: if the graph contains a cycle.
            Exception: the first exception raised by a task is re-raised
                unchanged. Tasks still queued at that point are cancelled.
        """
        if self._has_cycle():
            raise ValueError("Cannot execute a graph with cycles.")

        results: Dict[str, Any] = {}
        # Remaining unmet dependencies per node; decremented as nodes finish.
        in_degree = self._in_degrees()
        ready_queue = deque(node for node in self._order if in_degree[node] == 0)

        print(f"Starting execution with {max_workers} workers...")

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures: Dict[concurrent.futures.Future, str] = {}  # future -> node_id

            while ready_queue or futures:
                # 1. Submit every task whose dependencies are all satisfied.
                while ready_queue:
                    node = ready_queue.popleft()
                    if node in self.tasks:
                        print(f"Submitting task: {node}")
                        future = executor.submit(self.tasks[node], dict(results))
                        futures[future] = node
                    else:
                        # No task bound: complete immediately as a pure dependency stub.
                        results[node] = None
                        self._update_dependents(node, in_degree, ready_queue)

                if not futures:
                    break

                # 2. Block until at least one running task finishes.
                done, _ = concurrent.futures.wait(
                    futures, return_when=concurrent.futures.FIRST_COMPLETED
                )
                for future in done:
                    node = futures.pop(future)
                    try:
                        result = future.result()
                    except Exception as e:
                        print(f"Task {node} failed: {e}")
                        # Stop queued work from running during executor shutdown;
                        # tasks that are already running cannot be cancelled.
                        for pending in futures:
                            pending.cancel()
                        raise
                    results[node] = result
                    print(f"Task {node} completed.")
                    self._update_dependents(node, in_degree, ready_queue)

        return results

    def _update_dependents(self, completed_node: str, in_degree: dict, queue: deque):
        """Decrement the in-degree of each downstream node of ``completed_node``.

        Nodes whose in-degree drops to 0 are appended to ``queue``.
        """
        for neighbor in self.graph[completed_node]:
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                queue.append(neighbor)


# --- Usage example ---
if __name__ == "__main__":
    dag = DAG()

    # Define tasks; each receives the results of already-finished nodes.
    def task_a(results):
        return "Result A"

    def task_b(results):
        return f"Result B (dep on {results['A']})"

    def task_c(results):
        return f"Result C (dep on {results['A']})"

    def task_d(results):
        return f"Result D (dep on {results['B']}, {results['C']})"

    # Build the graph.
    dag.add_node("A", task_a)
    dag.add_node("B", task_b)
    dag.add_node("C", task_c)
    dag.add_node("D", task_d)

    # Dependencies: A -> B, A -> C, B -> D, C -> D
    dag.add_edge("A", "B")
    dag.add_edge("A", "C")
    dag.add_edge("B", "D")
    dag.add_edge("C", "D")

    # 1. Inspect the execution order (topological sort).
    print("Execution Order:", dag.topological_sort())

    # 2. Execute.
    final_results = dag.execute(max_workers=2)
    print("Final Results:", final_results)
networkx
如果你需要复杂的图算法(如关键路径分析、可视化、子图提取),不要重复造轮子,直接使用 networkx。
pip install networkx matplotlib
import networkx as nx
import matplotlib.pyplot as plt

# Create a directed graph.
G = nx.DiGraph()

# Add edges; adding an edge also adds both endpoint nodes.
G.add_edge("A", "B")
G.add_edge("A", "C")
G.add_edge("B", "D")
G.add_edge("C", "D")
G.add_edge("D", "E")

# 1. Check that the graph is a DAG (contains no directed cycle).
if nx.is_directed_acyclic_graph(G):
    print("这是一个有效的 DAG")
else:
    print("图中存在环!")

# 2. Topological order (a valid execution order).
order = list(nx.topological_sort(G))
print("执行顺序:", order)

# 3. Ancestors / descendants. Both functions return sets with no defined
#    order, so sort them for stable, readable output.
print("D 的祖先:", sorted(nx.ancestors(G, "D")))  # ['A', 'B', 'C']
print("A 的后代:", sorted(nx.descendants(G, "A")))  # ['B', 'C', 'D', 'E']

# 4. Visualize; a fixed layout seed makes the drawing reproducible between runs.
pos = nx.spring_layout(G, seed=42)
nx.draw(G, pos, with_labels=True, node_color='lightblue', arrowsize=20)
plt.show()
如果你想实现类似 Nextflow 的生产级任务调度(支持重试、日志、状态持久化、分布式执行),请直接使用现有的 Python 生态工具,而不是自己从零编写。
Prefect (现代、Pythonic、易于上手)
from prefect import flow, task


@task
def extract():
    """Source step: emit the raw records."""
    return [1, 2, 3]


@task
def transform(data):
    """Transform step: double every record."""
    return [value * 2 for value in data]


@flow
def my_dag_flow():
    """Flow wiring extract -> transform.

    Feeding one task's return value into another is what gives Prefect the
    dependency between them; no explicit edge declaration is needed.
    """
    return transform(extract())


if __name__ == "__main__":
    my_dag_flow()
Airflow (老牌、强大、基于配置)
Snakemake(通用工作流引擎,在生物信息学领域尤为常用)
Prefect
Airflow
GitHub 开源生信云平台 DEMO