当前位置 博文首页 > 韦全敏的博客:图神经网络框架DGL教程-第4章:图数据处理管道

    韦全敏的博客:图神经网络框架DGL教程-第4章:图数据处理管道

    作者:[db:作者] 时间:2021-07-08 15:37

    更多图神经网络和深度学习内容请关注:
    在这里插入图片描述

    第4章:图数据处理管道

    DGL在 dgl.data 里实现了很多常用的图数据集。它们遵循了由 dgl.data.DGLDataset 类定义的标准的数据处理管道。 DGL推荐用户将图数据处理为 dgl.data.DGLDataset 的子类。该类为导入、处理和保存图数据提供了简单而干净的解决方案。

    4.1 DGLDataset类

    DGLDataset 是处理、导入和保存 dgl.data 中定义的图数据集的基类。 它实现了用于处理图数据的基本模版。下面的流程图展示了这个模版的工作方式。

    在这里插入图片描述

    1. Check whether there is a dataset cache on disk (already processed and stored on the disk) by invoking has_cache(). If true, goto 5.
    2. Call download() to download the data.
    3. Call process() to process the data.
    4. Call save() to save the processed dataset on disk and goto 6.
    5. Call load() to load the processed dataset from disk.
    6. Done.

    为了处理位于远程服务器或本地磁盘上的图数据集,下面的例子中定义了一个类,称为 MyDataset, 它继承自 dgl.data.DGLDataset。

    from dgl.data import DGLDataset
    
    class MyDataset(DGLDataset):
        """ 用于在DGL中自定义图数据集的模板:
    
        Parameters
        ----------
        url : str
            下载原始数据集的url。
        raw_dir : str
            指定下载数据(未经处理的数据)的存储目录或已下载数据的存储目录。默认: ~/.dgl/
        save_dir : str
            处理完成的数据集的保存目录。默认:raw_dir指定的值
        force_reload : bool
            是否重新导入数据集。默认:False
        verbose : bool
            是否打印进度信息。
        """
        def __init__(self,
                     url=None,
                     raw_dir=None,
                     save_dir=None,
                     force_reload=False,
                     verbose=False):
            super(MyDataset, self).__init__(name='dataset_name',
                                            url=url,
                                            raw_dir=raw_dir,
                                            save_dir=save_dir,
                                            force_reload=force_reload,
                                            verbose=verbose)
    
        def download(self):
            # 将原始数据下载到本地磁盘
            pass
    
        def process(self):
            # 将原始数据处理为图、标签和数据集划分的掩码
            pass
    
        def __getitem__(self, idx):
            # 通过idx得到与之对应的一个样本
            pass
    
        def __len__(self):
            # 数据样本的数量
            pass
    
        def save(self):
            # 将处理后的数据保存至 `self.save_path`
            pass
    
        def load(self):
            # 从 `self.save_path` 导入处理后的数据
            pass
    
        def has_cache(self):
            # 检查在 `self.save_path` 中是否存有处理后的数据
            pass
    
    Using backend: pytorch
    

    DGLDataset 类有抽象函数 process()__getitem__(idx)__len__()。子类必须实现这些函数。同时DGL也建议实现保存save()和导入load函数, 因为对于处理后的大型数据集,这么做可以节省大量的时间, 并且有多个已有的API可以简化此操作(请参阅 4.4 保存和加载数据)。

    请注意, DGLDataset 的目的是提供一种标准且方便的方式来导入图数据。 用户可以存储有关数据集的图、特征、标签、掩码,以及诸如类别数、标签数等基本信息。 诸如采样、划分或特征归一化等操作建议在 DGLDataset 子类之外完成。

    4.2 下载原始数据(可选)

    如果用户的数据集已经在本地磁盘中,请确保它被存放在目录 raw_dir 中。 如果用户想在任何地方运行代码而又不想自己下载数据并将其移动到正确的目录中,则可以通过实现函数 download() 来自动完成。

    如果数据集是一个zip文件,可以直接继承 dgl.data.DGLBuiltinDataset 类。后者支持解压缩zip文件。 否则用户需要自己实现 download(),具体可以参考 QM7bDataset 类:

    import os
    from dgl.data.utils import download
    
    def download(self):
        # 存储文件的路径
        file_path = os.path.join(self.raw_dir, self.name + '.mat')
        # 下载文件
        download(self.url, path=file_path)
    

    上面的代码将一个.mat文件下载到目录 self.raw_dir。如果文件是.gz、.tar、.tar.gz或.tgz文件,请使用 extract_archive() 函数进行解压缩。以下代码展示了如何在 BitcoinOTCDataset 类中下载一个.gz文件:

    from dgl.data.utils import download, check_sha1
    
    def download(self):
        # 存储文件的路径,请确保使用与原始文件名相同的后缀
        gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')
        # 下载文件
        download(self.url, path=gz_file_path)
        # 检查 SHA-1
        if not check_sha1(gz_file_path, self._sha1_str):
            raise UserWarning('File {} is downloaded but the content hash does not match.'
                              'The repo may be outdated or download may be incomplete. '
                              'Otherwise you can create an issue for it.'.format(self.name + '.csv.gz'))
        # 将文件解压缩到目录self.raw_dir下的self.name目录中
        self._extract_gz(gz_file_path, self.raw_path)
    

    上面的代码会将文件解压缩到 self.raw_dir 下的目录 self.name 中。 如果该类继承自 dgl.data.DGLBuiltinDataset 来处理zip文件, 则它也会将文件解压缩到目录 self.name 中。

    一个可选项是用户可以按照上面的示例检查下载后文件的SHA-1字符串,以防作者在远程服务器上更改了文件。

    4.3 处理数据

    用户可以在 process() 函数中实现数据处理。该函数假定原始数据已经位于 self.raw_dir 目录中。

    图上的机器学习任务通常有三种类型:整图分类、节点分类和链接预测。本节将展示如何处理与这些任务相关的数据集。

    本节重点介绍了处理图、特征和划分掩码的标准方法。用户指南将以内置数据集为例,并跳过从文件构建图的实现。 用户可以参考 1.4 从外部源创建图 以查看如何从外部数据源构建图的完整指南。

    处理整图分类数据集

    整图分类数据集与用小批次训练的典型机器学习任务中的大多数数据集类似。 因此,需要将原始数据处理为 dgl.DGLGraph 对象的列表和标签张量的列表。 此外,如果原始数据已被拆分为多个文件,则可以添加参数 split 以导入数据的特定部分。

    下面以 QM7bDataset 为例:

    from dgl.data import DGLDataset
    
    class QM7bDataset(DGLDataset):
        _url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \
               'datasets/qm7b.mat'
        _sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'
    
        def __init__(self, raw_dir=None, force_reload=False, verbose=False):
            super(QM7bDataset, self).__init__(name='qm7b',
                                              url=self._url,
                                              raw_dir=raw_dir,
                                              force_reload=force_reload,
                                              verbose=verbose)
    
        def process(self):
            mat_path = self.raw_path + '.mat'
            # 将数据处理为图列表和标签列表
            self.graphs, self.label = self._load_graph(mat_path)
    	
    	def _load_graph(self, filename):
            data = io.loadmat(filename)
            labels = F.tensor(data['T'], dtype=F.data_type_dict['float32'])
            feats = data['X']
            num_graphs = labels.shape[0]
            graphs = []
            for i in range(num_graphs):
                edge_list = feats[i].nonzero()
                g = dgl_graph(edge_list)
                g.edata['h'] = F.tensor(feats[i][edge_list[0], edge_list[1]].reshape(-1, 1),
                                        dtype=F.data_type_dict['float32'])
                graphs.append(g)
            return graphs, labels
    	
        def __getitem__(self, idx):
            """ 通过idx获取对应的图和标签
    
            Parameters
            ----------
            idx : int
                Item index
    
            Returns
            -------
            (dgl.DGLGraph, Tensor)
            """
            return self.graphs[idx], self.label[idx]
    
        def __len__(self):
            """数据集中图的数量"""
            return len(self.graphs)
    

    函数 process() 将原始数据处理为图列表和标签列表。用户必须实现 __getitem__(idx)__len__() 以进行迭代。 DGL建议让 __getitem__(idx) 返回如上面代码所示的元组 (图,标签)。 用户可以参考 QM7bDataset源代码 以获得 self._load_graph()__getitem__ 的详细信息。

    用户还可以向类添加属性以指示一些有用的数据集信息。在 QM7bDataset 中, 用户可以添加属性 num_labels 来指示此多任务数据集中的预测任务总数:

    @property
    def num_labels(self):
        """每个图的标签数,即预测任务数。"""
        return 14
    

    在编写完这些代码之后,用户可以按如下所示的方式来使用 QM7bDataset:

    import dgl
    import torch
    
    from dgl.dataloading import GraphDataLoader
    
    # 数据导入
    dataset = QM7bDataset()
    num_labels = dataset.num_labels
    
    # 创建 dataloaders
    dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)
    
    # 训练
    for epoch in range(100):
        for g, labels in dataloader:
            # 用户自己的训练代码
            pass
    

    训练整图分类模型的完整指南可以在 5.4 整图分类 中找到。

    有关整图分类数据集的更多示例,用户可以参考 5.4 整图分类:

    • Graph isomorphism network dataset

    • Mini graph classification dataset

    • QM7b dataset

    • TU dataset

    处理节点分类数据集

    与整图分类不同,节点分类通常在单个图上进行。因此数据集的划分是在图的节点集上进行。 DGL建议使用节点掩码来指定数据集的划分。 本节以内置数据集 CitationGraphDataset 为例:

    from dgl.data import DGLBuiltinDataset
    from dgl.data.utils import _get_dgl_url
    
    class CitationGraphDataset(DGLBuiltinDataset):
        _urls = {
            'cora_v2' : 'dataset/cora_v2.zip',
            'citeseer' : 'dataset/citeseer.zip',
            'pubmed' : 'dataset/pubmed.zip',
        }
    
        def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
            assert name.lower() in ['cora', 'citeseer', 'pubmed']
            if name.lower() == 'cora':
                name = 'cora_v2'
            url = _get_dgl_url(self._urls[name])
            super(CitationGraphDataset, self).__init__(name,
                                                       url=url,
                                                       raw_dir=raw_dir,
                                                       force_reload=force_reload,
                                                       verbose=verbose)
    
        def process(self):
            # 跳过一些处理的代码
            # === 跳过数据处理 ===
    
            # 构建图
            g = dgl.graph(graph)
    
            # 划分掩码
            g.ndata['train_mask'] = train_mask
            g.ndata['val_mask'] = val_mask
            g.