Source code for rllm.nn.conv.graph_conv.gat_conv

from typing import Union, Tuple

import torch
from torch import Tensor
from torch.sparse import Tensor as SparseTensor
import torch.nn.init as init

from rllm.utils import set_values, seg_softmax
from rllm.nn.conv.graph_conv import MessagePassing


class GATConv(MessagePassing):
    r"""The GAT (Graph Attention Network) model implementation with message
    passing, based on the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`__ paper.

    In particular, this implementation utilizes sparse attention mechanisms
    to handle graph-structured data,
    similar to <https://github.com/Diego999/pyGAT>.

    .. math::
        \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}}
        \alpha_{i,j}\mathbf{\Theta}_t\mathbf{x}_{j}

    where the attention coefficients :math:`\alpha_{i,j}` are computed as:

    .. math::
        \alpha_{i,j} =\frac{\exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        \mathbf{\Theta} \mathbf{x}_i+ \mathbf{a}^{\top} \mathbf{
        \Theta}\mathbf{x}_j\right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathrm{LeakyReLU}\left(
        \mathbf{a}^{\top} \mathbf{\Theta} \mathbf{x}_i
        + \mathbf{a}^{\top}\mathbf{\Theta}\mathbf{x}_k
        \right)\right)}.

    Args:
        in_dim (int or Tuple[int, int]): Size of each input sample. A pair
            ``(src_dim, dst_dim)`` creates separate projections for source
            and destination node features.
        out_dim (int): Size of each output sample (per attention head).
        num_heads (int): Number of multi-head-attentions.
            (default: :obj:`8`)
        concat (bool): If set to `False`, the multi-head attentions
            are averaged instead of concatenated.
            (default: :obj:`False`)
        negative_slope (float): LeakyReLU angle of the negative slope.
            (default: :obj:`0.2`)
        dropout (float): Dropout probability of the normalized
            attention coefficients which exposes each node to a
            stochastically sampled neighborhood during training.
            (default: :obj:`0.6`)
        bias (bool): If set to `False`,
            no bias terms are added into the final output.
            (default: :obj:`True`)
        skip_connection (bool): If set to :obj:`True`, the layer will add
            a learnable skip-connection. (default: :obj:`False`)

    Shapes:
        - **input:**
            node features :math:`(|\mathcal{V}|, F_{in})`

            edge_index is sparse adjacency matrix
            :math:`(|\mathcal{V}|, |\mathcal{V}|)`
            or edge list :math:`(2, |\mathcal{E}|)`
        - **output:**
            node features :math:`(|\mathcal{V}|, F_{out})` when heads are
            averaged, or :math:`(|\mathcal{V}|, H \cdot F_{out})` when
            ``concat=True``
    """

    node_dim = 0
    head_dim = 1

    def __init__(
        self,
        in_dim: Union[int, Tuple[int, int]],
        out_dim: int,
        num_heads: int = 8,
        concat: bool = False,
        negative_slope: float = 0.2,
        dropout: float = 0.6,
        bias: bool = True,
        skip_connection: bool = False,
        **kwargs,
    ):
        super().__init__(aggr="add", **kwargs)

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.concat = concat
        self.skip_connection = skip_connection

        # Linear projections: one shared layer for a single input size,
        # otherwise separate source / destination layers.
        if isinstance(in_dim, int):
            self.lin = torch.nn.Linear(in_dim, out_dim * num_heads, bias=False)
        else:
            in_dim = tuple(in_dim)
            self.lin_src = torch.nn.Linear(
                in_dim[0], out_dim * num_heads, bias=False
            )
            self.lin_dst = torch.nn.Linear(
                in_dim[1], out_dim * num_heads, bias=False
            )

        if self.skip_connection:
            # The skip path is applied to destination features, so it maps
            # from the destination input size to the final output width.
            self.lin_skip = torch.nn.Linear(
                in_features=in_dim[1] if isinstance(in_dim, tuple) else in_dim,
                out_features=out_dim * num_heads if self.concat else out_dim,
                bias=False,
            )

        # attention weights
        self.attn_src = torch.nn.Parameter(torch.empty(1, num_heads, out_dim))
        self.attn_dst = torch.nn.Parameter(torch.empty(1, num_heads, out_dim))

        if bias:
            bias_dim = num_heads * out_dim if concat else out_dim
            self.bias = torch.nn.Parameter(torch.empty(bias_dim))
        else:
            self.register_parameter("bias", None)

        self.leaky_relu = torch.nn.LeakyReLU(negative_slope)
        self.dropout = torch.nn.Dropout(p=dropout)

        # Per-call state shared between ``forward`` and
        # ``message_and_aggregate``; defined up front so the hook can also
        # be invoked before the first ``forward`` call.
        self.return_attention_weights = False
        self.attn_weights = None

        # The Linear layers were just initialised by their constructors;
        # only the raw parameters created above still need values.
        self._init_attention_and_bias()

    def _init_attention_and_bias(self):
        """Initialise the attention vectors (Xavier normal) and zero the bias."""
        if self.bias is not None:
            init.zeros_(self.bias)
        init.xavier_normal_(self.attn_src)
        init.xavier_normal_(self.attn_dst)

    def reset_parameters(self):
        """Re-initialise every learnable parameter of the layer.

        This covers the linear projections (including the optional skip
        projection) as well as the attention vectors and the bias, so that
        no trained state survives a reset.
        """
        for name in ("lin", "lin_src", "lin_dst", "lin_skip"):
            layer = getattr(self, name, None)
            if layer is not None:
                layer.reset_parameters()
        self._init_attention_and_bias()

    def forward(
        self,
        x: Union[Tensor, Tuple[Tensor, Tensor]],
        edge_index: Union[Tensor, SparseTensor],
        return_attention_weights: bool = False,
    ) -> Union[
        Tensor,
        Tuple[Tensor, Tuple[Tensor, Tensor]],
        Tuple[Tensor, Tuple[SparseTensor, Tensor]],
    ]:
        r"""Compute node embeddings with graph attention message passing.

        Args:
            x (Union[Tensor, Tuple[Tensor, Tensor]]):
                - Tensor input features for homogeneous graphs.
                - Tuple of source and destination node features for
                  bipartite graphs.
            edge_index (Union[Tensor, SparseTensor]): Graph connectivity in
                edge-list or sparse adjacency format.
            return_attention_weights (bool): If True, also return edge-level
                attention coefficients.

        Returns:
            Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]],
            Tuple[Tensor, Tuple[SparseTensor, Tensor]]]:
                If ``return_attention_weights`` is False, returns output
                node features. Otherwise returns output features and
                corresponding attention weights.

        Example:
            >>> import torch
            >>> from rllm.nn.conv.graph_conv import GATConv
            >>> conv = GATConv(16, 8, num_heads=2, concat=False)
            >>> x = torch.randn(4, 16)
            >>> edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
            >>> out = conv(x, edge_index)
            >>> out.shape
            torch.Size([4, 8])
        """
        # Read back by ``message_and_aggregate`` during propagation.
        self.return_attention_weights = return_attention_weights

        # Linear projection
        skip_res = None
        if isinstance(x, Tensor):
            if self.skip_connection:
                skip_res = self.lin_skip(x)
            x = self.lin(x).view(-1, self.num_heads, self.out_dim)  # (N, H, D)
            x_src = x_dst = x
        else:
            if self.skip_connection:
                skip_res = self.lin_skip(x[1])
            x_src = self.lin_src(x[0]).view(
                -1, self.num_heads, self.out_dim
            )  # (N_src, H, D)
            x_dst = self.lin_dst(x[1]).view(
                -1, self.num_heads, self.out_dim
            )  # (N_dst, H, D)

        num_nodes = x_dst.size(0)  # N_dst

        # node attention
        alpha_src = (x_src * self.attn_src).sum(dim=-1)  # (N_src, H)
        alpha_dst = (x_dst * self.attn_dst).sum(dim=-1)  # (N_dst, H)

        # out: (N_dst, H, D)
        out = self.propagate(
            None,
            edge_index=edge_index,
            alpha=(alpha_src, alpha_dst),
            x_src=x_src,
            dim_size=num_nodes,
        )

        if self.concat:
            out = out.view(-1, self.num_heads * self.out_dim)  # (N_dst, H * D)
        else:
            out = out.mean(dim=self.head_dim, keepdim=False)  # (N_dst, D)

        if skip_res is not None:
            out = out + skip_res

        if self.bias is not None:
            # Out-of-place add: ``out`` may be a view of the aggregation
            # result, which should not be mutated.
            out = out + self.bias

        if not return_attention_weights:
            return out

        if edge_index.is_sparse:
            return out, (
                set_values(edge_index, self.attn_weights),
                self.attn_weights,
            )
        return out, (edge_index, self.attn_weights)

    def message_and_aggregate(self, edge_index, alpha, x_src, dim_size):
        """Build attention-weighted messages and aggregate them per target.

        Args:
            edge_index: Graph connectivity as passed to :meth:`forward`; it
                is normalised through ``__unify_edgeindex__`` first.
            alpha: Pair ``(alpha_src, alpha_dst)`` of per-node, per-head
                attention logits.
            x_src: Projected source node features, ``(N_src, H, D)``.
            dim_size: Number of destination nodes; used as the segment
                count of the softmax and as the aggregation output size.

        Returns:
            The result of ``self.aggr_module`` over the weighted messages,
            indexed by ``edge_index[1]``.
        """
        edge_index, _ = self.__unify_edgeindex__(edge_index)
        x_src = self.retrieve_feats(
            x_src, edge_index, dim=0, retrieve_dim=self.node_dim
        )

        # attention scores
        # alpha_src / alpha_dst: (E, H)
        alpha_src, alpha_dst = self.retrieve_feats(
            alpha, edge_index, retrieve_dim=self.node_dim
        )
        # alpha: (E, H)
        alpha = self.leaky_relu(alpha_src + alpha_dst)
        # Normalise over the edges that share the same edge_index[1] entry.
        alpha = seg_softmax(alpha, edge_index[1], num_segs=dim_size)
        alpha = self.dropout(alpha)

        # Keep a detached copy (taken after dropout) only when the caller
        # asked for the attention weights.
        self.attn_weights = (
            alpha.clone().detach() if self.return_attention_weights else None
        )

        # msgs: (E, H, D)
        msgs = x_src * alpha.unsqueeze(-1)
        return self.aggr_module(
            msgs, edge_index[1], dim=self.node_dim, dim_size=dim_size
        )

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}({self.in_dim}, "
            f"{self.out_dim}, num_heads={self.num_heads})"
        )