import dgl
import networkx as nx
import matplotlib.pyplot as plt
import torch

# 假设你已经有了一个DGL图对象bg
# 如果没有，这里提供一个示例图的创建
# u = torch.tensor([0, 1, 2, 3, 2])
# v = torch.tensor([1, 2, 3, 4, 0])
# bg = dgl.graph((u, v))

def visualize_dgl_graph(g, node_labels=None, edge_labels=None, figsize=(32, 16)):
    """
    可视化DGL图
    
    参数:
    g: DGL图对象
    node_labels: 节点标签字典 (可选)
    edge_labels: 边标签字典 (可选)
    figsize: 图像大小
    """
    # 将DGL图转换为NetworkX图
    nx_g = dgl.to_networkx(g)
    
    # 创建图形
    plt.figure(figsize=figsize)
    
    # 设置布局
    pos = nx.spring_layout(nx_g, seed=42)
    
    # 绘制节点
    nx.draw_networkx_nodes(nx_g, pos, 
                          node_color='lightblue', 
                          node_size=500,
                          alpha=0.8)
    
    # 绘制边
    nx.draw_networkx_edges(nx_g, pos, 
                          edge_color='gray',
                          arrows=True,
                          arrowsize=20,
                          alpha=0.6)
    
    # 绘制节点标签
    if node_labels is None:
        node_labels = {i: str(i) for i in range(g.num_nodes())}
    nx.draw_networkx_labels(nx_g, pos, node_labels, font_size=8)
    
    # 绘制边标签（如果提供）
    if edge_labels is not None:
        nx.draw_networkx_edge_labels(nx_g, pos, edge_labels)
    
    plt.title(f"DGL\n num_nodes: {g.num_nodes()}, num_edges: {g.num_edges()}")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

# # 使用示例
# visualize_dgl_graph(bg)