Graph Data Handle
Data Handling of Graphs
Graph data typically includes node connectivity and node features. In rLLM, a simple graph data instance is defined by rllm.data.GraphData.
It generally contains the following information:
data.x: Node feature matrix, shape: [num_nodes, feature_dims]
data.adj: Adjacency matrix representing graph structure, shape: [num_nodes, num_nodes]
data.y: Node labels used for supervised training
An instance for storing graph data can be created as follows:
import torch
from rllm.data import GraphData
x = torch.tensor([[0],[1],[2]]).float()
y = torch.tensor([0, 1, 2]).long()
adj = torch.tensor([[0, 1, 0],
[1, 0, 0],
[0, 0, 1]])
data = GraphData(x=x, y=y, adj=adj)
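The shape conventions listed above imply that x, adj, and y must agree on the number of nodes. As a rough illustration, the check below verifies that agreement on plain nested lists, so it runs without any dependencies. The helper name check_graph_fields is hypothetical and is not part of rLLM, and whether GraphData performs a similar check internally is not stated here.

```python
def check_graph_fields(x, adj, y):
    """Raise ValueError if the three fields disagree on the number of nodes."""
    num_nodes = len(x)  # one feature row per node
    if len(adj) != num_nodes or any(len(row) != num_nodes for row in adj):
        raise ValueError("adj must have shape [num_nodes, num_nodes]")
    if len(y) != num_nodes:
        raise ValueError("y must hold one label per node")
    return num_nodes


# Same toy graph as in the example above, written as nested lists.
x = [[0.0], [1.0], [2.0]]
y = [0, 1, 2]
adj = [[0, 1, 0],
       [1, 0, 0],
       [0, 0, 1]]
num_nodes = check_graph_fields(x, adj, y)
```

Running such a check before constructing the graph object can surface a mismatched adjacency matrix early, rather than during training.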
GraphData also provides convenient functions for inferring information from graph data and performing operations on it, such as:
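# Fields can be set as attributes and read back by key, and vice versa.
data.train_mask = torch.tensor([True, True, False])
# torch.equal returns a single bool; `==` would return an element-wise tensor.
print(torch.equal(data.train_mask, data['train_mask']))
>>> True
data['test_mask'] = torch.tensor([False, False, True])
print(torch.equal(data.test_mask, data['test_mask']))
>>> True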
print(data.num_nodes)
>>> 3
print(data.num_classes)
>>> 3
# transfer data to device.
data.to('cuda')
data.to('cpu')
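The interchangeable attribute and key access shown above can be pictured with a small stand-in class. This is only a sketch of the general pattern, not rLLM's implementation: TinyGraph and its _store dictionary are hypothetical names, and plain Python lists stand in for tensors.

```python
class TinyGraph:
    """Hypothetical stand-in for a graph container; not rLLM's GraphData."""

    def __init__(self, **fields):
        # Keep every field in one dict so both access styles share state.
        object.__setattr__(self, "_store", dict(fields))

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails.
        store = object.__getattribute__(self, "_store")
        try:
            return store[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        # Attribute assignment writes into the shared dict.
        self._store[name] = value

    def __getitem__(self, key):
        return self._store[key]

    def __setitem__(self, key, value):
        self._store[key] = value

    @property
    def num_nodes(self):
        # One feature row per node.
        return len(self._store["x"])

    @property
    def num_classes(self):
        # Number of distinct label values.
        return len(set(self._store["y"]))


g = TinyGraph(x=[[0.0], [1.0], [2.0]], y=[0, 1, 2])
g.train_mask = [True, True, False]      # set as attribute
g["test_mask"] = [False, False, True]   # set by key
```

Because both access styles read and write the same dictionary, a field stored one way is immediately visible the other way, which is the behaviour the example above relies on.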
Graph Transforms
The Transform module provides a range of methods for modifying and preprocessing graph data features contained in subclasses of BaseGraph, such as GraphData and HeteroGraphData.
These methods can be applied explicitly after initialization or implicitly by passing them as the transform parameter when loading datasets.
The module also provides the GraphTransform class, which chains multiple transformation methods so they can be applied as a single transform.
import os.path as osp
import rllm.transforms.graph_transform as GT
from rllm.datasets.planetoid import PlanetoidDataset
transform = GT.GraphTransform([
    GT.NormalizeFeatures('l2'),  # normalize node features
    GT.GCNNorm()  # add self-loops and row-normalize adjacency
])
# `args.dataset` is assumed to come from the script's own argument parser.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')
dataset = PlanetoidDataset(path, args.dataset, transform=transform)
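To make the chaining idea concrete, the sketch below composes two steps on a toy graph stored as a dictionary of nested lists. It is written without rLLM so that it runs on its own. Chain, on_field, and both normalization functions are hypothetical names, and the adjacency step follows the description in the comment above (set self-loops, then row-normalize) rather than rLLM's actual GCNNorm code, which may differ.

```python
import math


def l2_normalize_rows(x):
    """Scale each feature row to unit L2 norm (all-zero rows are left unchanged)."""
    out = []
    for row in x:
        norm = math.sqrt(sum(v * v for v in row))
        out.append([v / norm for v in row] if norm > 0 else list(row))
    return out


def add_self_loops_and_row_normalize(adj):
    """Set the diagonal to 1, then divide each row by its sum."""
    n = len(adj)
    out = []
    for i in range(n):
        row = [1.0 if i == j else float(adj[i][j]) for j in range(n)]
        total = sum(row)  # at least 1 because of the self-loop
        out.append([v / total for v in row])
    return out


def on_field(field, fn):
    """Lift a function on one field into a step on the whole graph dict."""
    def step(graph):
        updated = dict(graph)  # shallow copy; the input dict is not modified
        updated[field] = fn(graph[field])
        return updated
    return step


class Chain:
    """Hypothetical composer: applies each step to the graph in order."""

    def __init__(self, steps):
        self.steps = list(steps)

    def __call__(self, graph):
        for step in self.steps:
            graph = step(graph)
        return graph


transform = Chain([
    on_field("x", l2_normalize_rows),
    on_field("adj", add_self_loops_and_row_normalize),
])

graph = {
    "x": [[3.0, 4.0], [0.0, 2.0], [1.0, 0.0]],
    "adj": [[0, 1, 0], [1, 0, 0], [0, 0, 1]],
}
result = transform(graph)
```

Treating each step as a callable that takes a graph and returns a graph is what lets a list of steps behave like a single transform, so the same object can be applied by hand or handed to a dataset loader.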