在 Python 中实现有向无环图(DAG)通常有三种方式,分别对应不同场景:手写实现(理解原理或轻量依赖管理)、使用 networkx 做图分析,以及使用 Prefect/Airflow 等工作流框架做生产级任务调度。
下面我将分层次展示如何实现。
方式一:纯 Python 手动实现(无第三方依赖)。如果你需要理解原理,或者只需要一个简单的依赖管理工具,可以手动实现一个支持拓扑排序的 DAG 类。
核心功能:添加节点与依赖边、环检测、拓扑排序,以及按依赖关系用线程池并行执行任务。
import concurrent.futures
from collections import defaultdict, deque
from typing import Any, Callable, Dict, List, Optional, Set
class DAG:
    """A minimal directed acyclic graph with topological sort and a threaded executor.

    Edges point from a prerequisite to the node that depends on it:
    ``add_edge("A", "B")`` means A must finish before B starts.
    """

    def __init__(self):
        # Adjacency list: node -> nodes that depend on it (downstream)
        self.graph = defaultdict(list)
        # Reverse adjacency list: node -> nodes it depends on (upstream)
        self.reverse_graph = defaultdict(list)
        # Set of all registered node ids
        self.nodes = set()
        # node id -> task callable; nodes without an entry act as placeholders
        self.tasks = {}

    def add_node(self, node_id: str, task_func: Optional[Callable[..., Any]] = None):
        """Register ``node_id``, optionally binding ``task_func`` to it.

        Passing ``None`` leaves any previously bound task untouched; passing a
        callable replaces it.
        """
        self.nodes.add(node_id)
        if task_func is not None:
            self.tasks[node_id] = task_func

    def add_edge(self, from_node: str, to_node: str):
        """Add the dependency edge ``from_node -> to_node`` (from runs before to).

        Raises:
            ValueError: if either endpoint was not registered with ``add_node``.
        """
        if from_node not in self.nodes or to_node not in self.nodes:
            raise ValueError("Node not found")
        self.graph[from_node].append(to_node)
        self.reverse_graph[to_node].append(from_node)

    def _in_degrees(self) -> Dict[str, int]:
        """Return a fresh node -> number-of-upstream-edges mapping."""
        return {node: len(self.reverse_graph.get(node, ())) for node in self.nodes}

    def _kahn_order(self) -> List[str]:
        """Run Kahn's algorithm and return every node it manages to order.

        Nodes on (or downstream of) a cycle never reach in-degree 0, so the
        returned list is shorter than ``self.nodes`` exactly when a cycle exists.
        """
        in_degree = self._in_degrees()
        queue = deque(node for node, degree in in_degree.items() if degree == 0)
        order = []
        while queue:
            node = queue.popleft()
            order.append(node)
            self._update_dependents(node, in_degree, queue)
        return order

    def _has_cycle(self) -> bool:
        """Return True if the graph contains a cycle (including self-loops).

        Iterative (Kahn's algorithm) rather than recursive DFS, so long
        dependency chains cannot hit Python's recursion limit.
        """
        return len(self._kahn_order()) != len(self.nodes)

    def topological_sort(self) -> List[str]:
        """Return the nodes in a valid execution order.

        The relative order of mutually independent nodes is unspecified (it
        follows the iteration order of ``self.nodes``).

        Raises:
            ValueError: if the graph contains a cycle.
        """
        order = self._kahn_order()
        # A single Kahn pass both sorts and detects cycles; no separate DFS needed.
        if len(order) != len(self.nodes):
            raise ValueError("Graph contains a cycle! Not a DAG.")
        return order

    def execute(self, max_workers: int = 4) -> Dict[str, Any]:
        """Run every node's task on a thread pool, respecting dependencies.

        A task is submitted only after all of its upstream nodes have finished.
        Each task is called as ``task(results)``, where ``results`` maps finished
        node ids to their return values. Nodes without a bound task complete
        immediately with result ``None``.

        NOTE: ``results`` is the live dict this method keeps filling in from the
        calling thread, so tasks should only read their upstream entries and
        should not iterate over or mutate it.

        Args:
            max_workers: size of the thread pool.

        Returns:
            Mapping of node id -> task return value.

        Raises:
            ValueError: if the graph contains a cycle.
            Exception: the first exception raised by a task is re-raised, after
                cancelling tasks that were submitted but have not started yet.
        """
        if self._has_cycle():
            raise ValueError("Cannot execute a graph with cycles.")

        results = {}
        # Remaining unfinished-upstream count per node (private copy, mutated below)
        in_degree = self._in_degrees()
        # Nodes whose dependencies are all satisfied
        ready_queue = deque(node for node, degree in in_degree.items() if degree == 0)

        print(f"Starting execution with {max_workers} workers...")
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {}  # future -> node_id
            while ready_queue or futures:
                # 1. Submit every ready task; task-less nodes finish immediately
                while ready_queue:
                    node = ready_queue.popleft()
                    if node in self.tasks:
                        print(f"Submitting task: {node}")
                        future = executor.submit(self.tasks[node], results)
                        futures[future] = node
                    else:
                        # Placeholder node: treat as done so its dependents unlock
                        results[node] = None
                        self._update_dependents(node, in_degree, ready_queue)
                if not futures:
                    break
                # 2. Block until at least one in-flight task finishes
                done, _ = concurrent.futures.wait(
                    futures, return_when=concurrent.futures.FIRST_COMPLETED
                )
                for future in done:
                    node = futures.pop(future)
                    try:
                        result = future.result()
                    except Exception as exc:
                        print(f"Task {node} failed: {exc}")
                        # Fail fast: drop queued-but-unstarted work. Running tasks
                        # cannot be cancelled; leaving the `with` block waits for them.
                        for pending in futures:
                            pending.cancel()
                        raise
                    results[node] = result
                    print(f"Task {node} completed.")
                    self._update_dependents(node, in_degree, ready_queue)
        return results

    def _update_dependents(self, completed_node: str, in_degree: Dict[str, int], queue: deque):
        """Mark ``completed_node`` as finished.

        Decrements the in-degree of each downstream node and appends to ``queue``
        every node whose upstream dependencies are now all satisfied.
        """
        for neighbor in self.graph[completed_node]:
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                queue.append(neighbor)
# --- Usage example ---
if __name__ == "__main__":
    # Each task receives the shared results dict and reads its upstream outputs from it.
    def task_a(results):
        return "Result A"

    def task_b(results):
        return f"Result B (dep on {results['A']})"

    def task_c(results):
        return f"Result C (dep on {results['A']})"

    def task_d(results):
        return f"Result D (dep on {results['B']}, {results['C']})"

    dag = DAG()
    # Register nodes together with their task functions
    for node_id, func in (("A", task_a), ("B", task_b), ("C", task_c), ("D", task_d)):
        dag.add_node(node_id, func)
    # Dependencies (diamond shape): A -> B, A -> C, B -> D, C -> D
    for upstream, downstream in (("A", "B"), ("A", "C"), ("B", "D"), ("C", "D")):
        dag.add_edge(upstream, downstream)

    # 1. Inspect the execution order (topological sort)
    print("Execution Order:", dag.topological_sort())
    # 2. Run the tasks on a two-worker pool
    final_results = dag.execute(max_workers=2)
    print("Final Results:", final_results)
方式二:使用 networkx 库(适合复杂图分析)。如果你需要复杂的图算法(如关键路径分析、可视化、子图提取),不要重复造轮子,直接使用 networkx(先安装依赖):
pip install networkx matplotlib
import networkx as nx
import matplotlib.pyplot as plt

# Build a directed graph; add_edges_from creates any missing nodes automatically.
G = nx.DiGraph()
G.add_edges_from([("A", "B"), ("A", "C"), ("B", "D"), ("C", "D"), ("D", "E")])

# 1. Validate the DAG first: nx.topological_sort raises NetworkXUnfeasible on a
#    cyclic graph, so only compute the order when the check passes.
if nx.is_directed_acyclic_graph(G):
    print("这是一个有效的 DAG")
    # 2. Topological order (a valid execution order)
    order = list(nx.topological_sort(G))
    print("执行顺序:", order)
else:
    print("图中存在环!")

# 3. Ancestors / descendants are returned as sets; sort them for stable output.
print("D 的祖先:", sorted(nx.ancestors(G, "D")))  # ['A', 'B', 'C']
print("A 的后代:", sorted(nx.descendants(G, "A")))  # ['B', 'C', 'D', 'E']

# 4. Visualize with matplotlib
nx.draw(G, with_labels=True, node_color="lightblue", arrowsize=20)
plt.show()
方式三:使用工作流框架(生产级任务调度)。如果你想实现类似 Nextflow 的生产级任务调度(支持重试、日志、状态持久化、分布式执行),请直接使用 Python 生态中现成的工作流工具,而不是自己从头编写调度器。
Prefect (现代、Pythonic、易于上手)
from prefect import flow, task


@task
def extract():
    """Produce the raw input records."""
    records = [1, 2, 3]
    return records


@task
def transform(data):
    """Return a new list with every record doubled."""
    doubled = [value * 2 for value in data]
    return doubled


@flow
def my_dag_flow():
    """Run extract, then transform on its output.

    Prefect builds the extract -> transform dependency from the data passed
    between the two task calls.
    """
    return transform(extract())


if __name__ == "__main__":
    my_dag_flow()
Airflow(老牌、功能强大,用 Python 代码定义 DAG,侧重定时调度)
Snakemake(基于规则的工作流工具,在生物信息学领域广泛使用)
总结:学习原理或只需轻量依赖管理时,手写上面的 DAG 类即可;需要复杂图分析时,使用现成的图库(networkx);生产级任务调度请使用工作流框架(如 Prefect 或 Airflow),不要自己维护调度器,因为处理失败重试、并发锁和状态持久化非常复杂。