当前位置 博文首页 > 韦全敏的博客:图神经网络框架DGL教程-第4章:图数据处理管道
更多图神经网络和深度学习内容请关注:
DGL在 dgl.data 里实现了很多常用的图数据集。它们遵循了由 dgl.data.DGLDataset
类定义的标准的数据处理管道。 DGL推荐用户将图数据处理为 dgl.data.DGLDataset
的子类。该类为导入、处理和保存图数据提供了简单而干净的解决方案。
DGLDataset 是处理、导入和保存 dgl.data
中定义的图数据集的基类。 它实现了用于处理图数据的基本模版。下面的流程图展示了这个模版的工作方式。
即
has_cache()
. If true, goto 5.download()
to download the data.process()
to process the data.save()
to save the processed dataset on disk and goto 6.load()
to load the processed dataset from disk.为了处理位于远程服务器或本地磁盘上的图数据集,下面的例子中定义了一个类,称为 MyDataset
, 它继承自 dgl.data.DGLDataset。
from dgl.data import DGLDataset
class MyDataset(DGLDataset):
""" 用于在DGL中自定义图数据集的模板:
Parameters
----------
url : str
下载原始数据集的url。
raw_dir : str
指定下载数据(未经处理的数据)的存储目录或已下载数据的存储目录。默认: ~/.dgl/
save_dir : str
处理完成的数据集的保存目录。默认:raw_dir指定的值
force_reload : bool
是否重新导入数据集。默认:False
verbose : bool
是否打印进度信息。
"""
def __init__(self,
url=None,
raw_dir=None,
save_dir=None,
force_reload=False,
verbose=False):
super(MyDataset, self).__init__(name='dataset_name',
url=url,
raw_dir=raw_dir,
save_dir=save_dir,
force_reload=force_reload,
verbose=verbose)
def download(self):
# 将原始数据下载到本地磁盘
pass
def process(self):
# 将原始数据处理为图、标签和数据集划分的掩码
pass
def __getitem__(self, idx):
# 通过idx得到与之对应的一个样本
pass
def __len__(self):
# 数据样本的数量
pass
def save(self):
# 将处理后的数据保存至 `self.save_path`
pass
def load(self):
# 从 `self.save_path` 导入处理后的数据
pass
def has_cache(self):
# 检查在 `self.save_path` 中是否存有处理后的数据
pass
Using backend: pytorch
DGLDataset
类有抽象函数 process()
, __getitem__(idx)
和 __len__()
。子类必须实现这些函数。同时DGL也建议实现保存save()
和导入load
函数, 因为对于处理后的大型数据集,这么做可以节省大量的时间, 并且有多个已有的API可以简化此操作(请参阅 4.4 保存和加载数据)。
请注意, DGLDataset
的目的是提供一种标准且方便的方式来导入图数据。 用户可以存储有关数据集的图、特征、标签、掩码,以及诸如类别数、标签数等基本信息。 诸如采样、划分或特征归一化等操作建议在 DGLDataset 子类之外完成。
如果用户的数据集已经在本地磁盘中,请确保它被存放在目录 raw_dir
中。 如果用户想在任何地方运行代码而又不想自己下载数据并将其移动到正确的目录中,则可以通过实现函数 download()
来自动完成。
如果数据集是一个zip
文件,可以直接继承 dgl.data.DGLBuiltinDataset
类。后者支持解压缩zip文件。 否则用户需要自己实现 download()
,具体可以参考 QM7bDataset 类:
import os
from dgl.data.utils import download
def download(self):
# 存储文件的路径
file_path = os.path.join(self.raw_dir, self.name + '.mat')
# 下载文件
download(self.url, path=file_path)
上面的代码将一个.mat文件下载到目录 self.raw_dir
。如果文件是.gz、.tar、.tar.gz或.tgz文件,请使用 extract_archive() 函数进行解压缩。以下代码展示了如何在 BitcoinOTCDataset 类中下载一个.gz文件:
from dgl.data.utils import download, check_sha1
def download(self):
# 存储文件的路径,请确保使用与原始文件名相同的后缀
gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')
# 下载文件
download(self.url, path=gz_file_path)
# 检查 SHA-1
if not check_sha1(gz_file_path, self._sha1_str):
raise UserWarning('File {} is downloaded but the content hash does not match.'
'The repo may be outdated or download may be incomplete. '
'Otherwise you can create an issue for it.'.format(self.name + '.csv.gz'))
# 将文件解压缩到目录self.raw_dir下的self.name目录中
self._extract_gz(gz_file_path, self.raw_path)
上面的代码会将文件解压缩到 self.raw_dir
下的目录 self.name
中。 如果该类继承自 dgl.data.DGLBuiltinDataset
来处理zip文件, 则它也会将文件解压缩到目录 self.name
中。
一个可选项是用户可以按照上面的示例检查下载后文件的SHA-1字符串,以防作者在远程服务器上更改了文件。
用户可以在 process()
函数中实现数据处理。该函数假定原始数据已经位于 self.raw_dir
目录中。
图上的机器学习任务通常有三种类型:整图分类、节点分类和链接预测。本节将展示如何处理与这些任务相关的数据集。
本节重点介绍了处理图、特征和划分掩码的标准方法。用户指南将以内置数据集为例,并跳过从文件构建图的实现。 用户可以参考 1.4 从外部源创建图 以查看如何从外部数据源构建图的完整指南。
整图分类数据集与用小批次训练的典型机器学习任务中的大多数数据集类似。 因此,需要将原始数据处理为 dgl.DGLGraph
对象的列表和标签张量的列表。 此外,如果原始数据已被拆分为多个文件,则可以添加参数 split
以导入数据的特定部分。
下面以 QM7bDataset
为例:
from dgl.data import DGLDataset
class QM7bDataset(DGLDataset):
_url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \
'datasets/qm7b.mat'
_sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
super(QM7bDataset, self).__init__(name='qm7b',
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
mat_path = self.raw_path + '.mat'
# 将数据处理为图列表和标签列表
self.graphs, self.label = self._load_graph(mat_path)
def _load_graph(self, filename):
data = io.loadmat(filename)
labels = F.tensor(data['T'], dtype=F.data_type_dict['float32'])
feats = data['X']
num_graphs = labels.shape[0]
graphs = []
for i in range(num_graphs):
edge_list = feats[i].nonzero()
g = dgl_graph(edge_list)
g.edata['h'] = F.tensor(feats[i][edge_list[0], edge_list[1]].reshape(-1, 1),
dtype=F.data_type_dict['float32'])
graphs.append(g)
return graphs, labels
def __getitem__(self, idx):
""" 通过idx获取对应的图和标签
Parameters
----------
idx : int
Item index
Returns
-------
(dgl.DGLGraph, Tensor)
"""
return self.graphs[idx], self.label[idx]
def __len__(self):
"""数据集中图的数量"""
return len(self.graphs)
函数 process()
将原始数据处理为图列表和标签列表。用户必须实现 __getitem__(idx)
和 __len__()
以进行迭代。 DGL建议让 __getitem__(idx)
返回如上面代码所示的元组 (图,标签)
。 用户可以参考 QM7bDataset源代码 以获得 self._load_graph()
和 __getitem__
的详细信息。
用户还可以向类添加属性以指示一些有用的数据集信息。在 QM7bDataset
中, 用户可以添加属性 num_labels
来指示此多任务数据集中的预测任务总数:
@property
def num_labels(self):
"""每个图的标签数,即预测任务数。"""
return 14
在编写完这些代码之后,用户可以按如下所示的方式来使用 QM7bDataset:
import dgl
import torch
from dgl.dataloading import GraphDataLoader
# 数据导入
dataset = QM7bDataset()
num_labels = dataset.num_labels
# 创建 dataloaders
dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)
# 训练
for epoch in range(100):
for g, labels in dataloader:
# 用户自己的训练代码
pass
训练整图分类模型的完整指南可以在 5.4 整图分类 中找到。
有关整图分类数据集的更多示例,用户可以参考 5.4 整图分类:
Graph isomorphism network dataset
Mini graph classification dataset
QM7b dataset
TU dataset
与整图分类不同,节点分类通常在单个图上进行。因此数据集的划分是在图的节点集上进行。 DGL建议使用节点掩码
来指定数据集的划分。 本节以内置数据集 CitationGraphDataset 为例:
from dgl.data import DGLBuiltinDataset
from dgl.data.utils import _get_dgl_url
class CitationGraphDataset(DGLBuiltinDataset):
_urls = {
'cora_v2' : 'dataset/cora_v2.zip',
'citeseer' : 'dataset/citeseer.zip',
'pubmed' : 'dataset/pubmed.zip',
}
def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
assert name.lower() in ['cora', 'citeseer', 'pubmed']
if name.lower() == 'cora':
name = 'cora_v2'
url = _get_dgl_url(self._urls[name])
super(CitationGraphDataset, self).__init__(name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
# 跳过一些处理的代码
# === 跳过数据处理 ===
# 构建图
g = dgl.graph(graph)
# 划分掩码
g.ndata['train_mask'] = train_mask
g.ndata['val_mask'] = val_mask
g.