Python实现dag

最后发布时间:2026-03-19 14:36:52 浏览量:

在 Python 中实现一个有向无环图(DAG)通常有两种场景:

  1. 数据结构层面:你需要构建一个图结构,进行拓扑排序、依赖检查或遍历。
  2. 工作流引擎层面:你需要像 Airflow 或 Nextflow 那样,根据依赖关系调度执行任务。

下面我将分层次展示如何实现。


方案一:纯手工实现 DAG 类(适合学习与轻量级需求)

如果你需要理解原理,或者只需要一个简单的依赖管理工具,可以手动实现一个包含拓扑排序的 DAG 类。

核心功能

  • 添加节点和边(依赖)。
  • 检测环(确保是无环的)。
  • 拓扑排序(确定执行顺序)。
  • 并行执行模拟。
from collections import defaultdict, deque
from typing import List, Dict, Set, Callable, Any
import concurrent.futures

class DAG:
    def __init__(self):
        # 邻接表:节点 -> [依赖它的节点] (下游)
        self.graph = defaultdict(list)
        # 反向邻接表:节点 -> [它依赖的节点] (上游)
        self.reverse_graph = defaultdict(list)
        # 所有节点集合
        self.nodes = set()
        # 存储任务函数
        self.tasks = {}

    def add_node(self, node_id: str, task_func: Callable = None):
        """添加节点"""
        self.nodes.add(node_id)
        if task_func:
            self.tasks[node_id] = task_func

    def add_edge(self, from_node: str, to_node: str):
        """添加依赖边:from_node -> to_node (from 必须先于 to 执行)"""
        if from_node not in self.nodes or to_node not in self.nodes:
            raise ValueError("Node not found")
        
        self.graph[from_node].append(to_node)
        self.reverse_graph[to_node].append(from_node)

    def _has_cycle(self) -> bool:
        """使用 DFS 检测是否有环"""
        visited = set()
        rec_stack = set() # 递归栈,用于检测当前路径上的节点

        def dfs(node):
            visited.add(node)
            rec_stack.add(node)

            for neighbor in self.graph[node]:
                if neighbor not in visited:
                    if dfs(neighbor):
                        return True
                elif neighbor in rec_stack:
                    return True
            
            rec_stack.remove(node)
            return False

        for node in self.nodes:
            if node not in visited:
                if dfs(node):
                    return True
        return False

    def topological_sort(self) -> List[str]:
        """返回拓扑排序后的节点列表 (执行顺序)"""
        if self._has_cycle():
            raise ValueError("Graph contains a cycle! Not a DAG.")

        # 计算入度
        in_degree = {node: len(self.reverse_graph[node]) for node in self.nodes}
        queue = deque([node for node in self.nodes if in_degree[node] == 0])
        sorted_order = []

        while queue:
            node = queue.popleft()
            sorted_order.append(node)

            for neighbor in self.graph[node]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        if len(sorted_order) != len(self.nodes):
            raise ValueError("Cycle detected during sorting")
            
        return sorted_order

    def execute(self, max_workers: int = 4) -> Dict[str, Any]:
        """
        简单的并行执行引擎:
        只有当所有依赖完成后,任务才会被提交执行。
        """
        if self._has_cycle():
            raise ValueError("Cannot execute a graph with cycles.")

        results = {}
        # 记录每个节点的入度(动态拷贝)
        in_degree = {node: len(self.reverse_graph[node]) for node in self.nodes}
        # 就绪队列:入度为 0 的节点
        ready_queue = deque([node for node in self.nodes if in_degree[node] == 0])
        
        print(f"Starting execution with {max_workers} workers...")

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {} # future -> node_id
            
            while ready_queue or futures:
                # 1. 提交所有就绪的任务
                while ready_queue:
                    node = ready_queue.popleft()
                    if node in self.tasks:
                        print(f"Submitting task: {node}")
                        future = executor.submit(self.tasks[node], results)
                        futures[future] = node
                    else:
                        # 如果没有绑定函数,视为立即完成(仅作为依赖桩)
                        results[node] = None
                        self._update_dependents(node, in_degree, ready_queue)

                if not futures:
                    break

                # 2. 等待任意一个任务完成
                done, _ = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED)
                
                for future in done:
                    node = futures.pop(future)
                    try:
                        result = future.result()
                        results[node] = result
                        print(f"Task {node} completed.")
                    except Exception as e:
                        print(f"Task {node} failed: {e}")
                        raise e
                    
                    self._update_dependents(node, in_degree, ready_queue)

        return results

    def _update_dependents(self, completed_node: str, in_degree: dict, queue: deque):
        """当一个节点完成时,减少其下游节点的入度"""
        for neighbor in self.graph[completed_node]:
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                queue.append(neighbor)

# --- 使用示例 ---
if __name__ == "__main__":
    dag = DAG()

    # 定义任务
    def task_a(results): return "Result A"
    def task_b(results): return f"Result B (dep on {results['A']})"
    def task_c(results): return f"Result C (dep on {results['A']})"
    def task_d(results): return f"Result D (dep on {results['B']}, {results['C']})"

    # 构建图
    dag.add_node("A", task_a)
    dag.add_node("B", task_b)
    dag.add_node("C", task_c)
    dag.add_node("D", task_d)

    # 建立依赖: A -> B, A -> C, B -> D, C -> D
    dag.add_edge("A", "B")
    dag.add_edge("A", "C")
    dag.add_edge("B", "D")
    dag.add_edge("C", "D")

    # 1. 查看执行顺序 (拓扑排序)
    print("Execution Order:", dag.topological_sort())

    # 2. 执行
    final_results = dag.execute(max_workers=2)
    print("Final Results:", final_results)

方案二:使用 networkx 库(适合复杂图分析)

如果你需要复杂的图算法(如关键路径分析、可视化、子图提取),不要重复造轮子,使用 networkx

pip install networkx matplotlib
import networkx as nx
import matplotlib.pyplot as plt

# 创建有向图
G = nx.DiGraph()

# 添加节点和边
G.add_edge("A", "B")
G.add_edge("A", "C")
G.add_edge("B", "D")
G.add_edge("C", "D")
G.add_edge("D", "E")

# 1. 检查是否是 DAG
if nx.is_directed_acyclic_graph(G):
    print("这是一个有效的 DAG")
else:
    print("图中存在环!")

# 2. 获取拓扑排序
order = list(nx.topological_sort(G))
print("执行顺序:", order)

# 3. 获取所有祖先/后代
print("D 的祖先:", nx.ancestors(G, "D")) # ['A', 'B', 'C']
print("A 的后代:", nx.descendants(G, "A")) # ['B', 'C', 'D', 'E']

# 4. 可视化
nx.draw(G, with_labels=True, node_color='lightblue', arrowsize=20)
plt.show()

方案三:使用成熟的工作流引擎(生产环境推荐)

如果你是想实现类似 Nextflow 的生产级任务调度(支持重试、日志、持久化状态、分布式),请直接使用现有的 Python 生态工具,而不是自己写代码。

  1. Prefect (现代、Pythonic、易于上手)

    • 特点:代码即工作流,动态 DAG,本地运行无缝切换到云端。
    • 适用:数据管道、机器学习流程。
    from prefect import flow, task
    
    @task
    def extract():
        return [1, 2, 3]
    
    @task
    def transform(data):
        return [x * 2 for x in data]
    
    @flow
    def my_dag_flow():
        data = extract()
        # Prefect 自动根据变量依赖构建 DAG
        transformed = transform(data) 
        return transformed
    
    if __name__ == "__main__":
        my_dag_flow()
    
  2. Airflow (老牌、强大、基于配置)

    • 特点:基于时间调度,强大的 UI,适合复杂的 ETL。
    • 缺点:配置较重,动态性稍差。
  3. Snakemake (生物信息学专用)

    • 特点:类似 Makefile 的语法但用 Python 写,自动处理文件依赖。

总结

  • 学习/面试/简单脚本:使用 方案一(手写类),理解入度、队列和拓扑排序。
  • 图分析/可视化:使用 方案二 (networkx)。
  • 实际工程项目:使用 方案三 (PrefectAirflow),不要自己维护调度器,因为处理失败重试、并发锁和状态持久化非常复杂。