Commit 8923a7f7 by Leo

upload code

parent 4606f1b4
import dgl
import dgl
import networkx as nx
import matplotlib.pyplot as plt
import torch
# 假设你已经有了一个DGL图对象bg
# 如果没有,这里提供一个示例图的创建
# u = torch.tensor([0, 1, 2, 3, 2])
# v = torch.tensor([1, 2, 3, 4, 0])
# bg = dgl.graph((u, v))
def visualize_dgl_graph(g, node_labels=None, edge_labels=None, figsize=(32, 16)):
"""
可视化DGL图
参数:
g: DGL图对象
node_labels: 节点标签字典 (可选)
edge_labels: 边标签字典 (可选)
figsize: 图像大小
"""
# 将DGL图转换为NetworkX图
nx_g = dgl.to_networkx(g)
# 创建图形
plt.figure(figsize=figsize)
# 设置布局
pos = nx.spring_layout(nx_g, seed=42)
# 绘制节点
nx.draw_networkx_nodes(nx_g, pos,
node_color='lightblue',
node_size=500,
alpha=0.8)
# 绘制边
nx.draw_networkx_edges(nx_g, pos,
edge_color='gray',
arrows=True,
arrowsize=20,
alpha=0.6)
# 绘制节点标签
if node_labels is None:
node_labels = {i: str(i) for i in range(g.num_nodes())}
nx.draw_networkx_labels(nx_g, pos, node_labels, font_size=8)
# 绘制边标签(如果提供)
if edge_labels is not None:
nx.draw_networkx_edge_labels(nx_g, pos, edge_labels)
plt.title(f"DGL\n num_nodes: {g.num_nodes()}, num_edges: {g.num_edges()}")
plt.axis('off')
plt.tight_layout()
plt.show()
# # 使用示例
# visualize_dgl_graph(bg)
\ No newline at end of file
{
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 基本图"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![图](..\\picture\\image.png)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Graph(num_nodes=4, num_edges=4,\n",
" ndata_schemes={}\n",
" edata_schemes={})\n",
"tensor([0, 1, 2, 3])\n",
"(tensor([0, 0, 0, 1]), tensor([1, 2, 3, 3]))\n",
"(tensor([0, 0, 0, 1]), tensor([1, 2, 3, 3]), tensor([0, 1, 2, 3]))\n"
]
}
],
"source": [
"import dgl\n",
"import torch as th\n",
"\n",
"# 边 0->1, 0->2, 0->3, 1->3\n",
"\n",
"u, v = th.tensor([0, 0, 0, 1]), th.tensor([1, 2, 3, 3])\n",
"g = dgl.graph((u, v))\n",
"print(g) # 图中节点的数量是DGL通过给定的图的边列表中最大的点ID推断所得出的\n",
"\n",
"# 获取节点的ID\n",
"print(g.nodes())\n",
"# 获取边的对应端点\n",
"print(g.edges())\n",
"# 获取边的对应端点和边ID\n",
"print(g.edges(form='all'))\n",
"\n",
"# 如果具有最大ID的节点没有边,在创建图的时候,用户需要明确地指明节点的数量。\n",
"g = dgl.graph((u, v), num_nodes=8)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([0, 0, 0, 1, 1, 2, 3, 3]), tensor([1, 2, 3, 0, 3, 0, 0, 1]))"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"##转化成无向图\n",
"bg = dgl.to_bidirected(g)\n",
"bg.edges()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Graph(num_nodes=6, num_edges=4,\n",
" ndata_schemes={}\n",
" edata_schemes={})"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"## 节点特征和边特征\n",
"import dgl\n",
"import torch as th\n",
"g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0])) # 6个节点,4条边\n",
"g\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Graph(num_nodes=6, num_edges=4,\n",
" ndata_schemes={'x': Scheme(shape=(3,), dtype=torch.float32)}\n",
" edata_schemes={'x': Scheme(shape=(), dtype=torch.int32)})"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"g.ndata['x'] = th.ones(g.num_nodes(), 3) # 长度为3的节点特征\n",
"g.edata['x'] = th.ones(g.num_edges(), dtype=th.int32) # 标量整型特征\n",
"g"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([1., 1., 1.])\n",
"tensor([1, 1], dtype=torch.int32)\n"
]
}
],
"source": [
"# 不同名称的特征可以具有不同形状\n",
"g.ndata['y'] = th.randn(g.num_nodes(), 5)\n",
"print(g.ndata['x'][1]) # 获取节点1的特征\n",
"print(g.edata['x'][th.tensor([0, 3])]) # 获取边0和3的特征"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Graph(num_nodes=4, num_edges=4,\n",
" ndata_schemes={}\n",
" edata_schemes={'w': Scheme(shape=(), dtype=torch.float32)})"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"## 加权图\n",
"# 边 0->1, 0->2, 0->3, 1->3\n",
"edges = th.tensor([0, 0, 0, 1]), th.tensor([1, 2, 3, 3])\n",
"weights = th.tensor([0.1, 0.6, 0.9, 0.7]) # 每条边的权重\n",
"g = dgl.graph(edges)\n",
"g.edata['w'] = weights # 将其命名为 'w'\n",
"g"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Graph(num_nodes=100, num_edges=500,\n",
" ndata_schemes={}\n",
" edata_schemes={})"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"## 外部创建接口\n",
"## scipy构建\n",
"import dgl\n",
"import torch as th\n",
"import scipy.sparse as sp\n",
"spmat = sp.rand(100, 100, density=0.05) # 5%非零项\n",
"dgl.from_scipy(spmat) # 来自SciPy"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"## networkx\n",
"import networkx as nx\n",
"nx_g = nx.path_graph(5) # 一条链路0-1-2-3-4\n",
"g=dgl.from_networkx(nx_g) # 来自NetworkX\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"## 保存dgl图\n",
"dgl.save_graphs(\"graph.dgl\", g)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"read_g=dgl.load_graphs('graph.dgl')"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"([Graph(num_nodes=5, num_edges=8,\n",
" ndata_schemes={}\n",
" edata_schemes={})],\n",
" {})"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"read_g"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "lbb",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.20"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment