rllm.nn.conv.graph_conv.GATConv

class rllm.nn.conv.graph_conv.GATConv(in_dim: int | Tuple[int, int], out_dim: int, num_heads: int = 8, concat: bool = False, negative_slope: float = 0.2, dropout: float = 0.6, bias: bool = True, skip_connection: bool = False, **kwargs)[source]

Bases: MessagePassing

The GAT (Graph Attention Network) model implementation with message passing, based on the “Graph Attention Networks” paper.

In particular, this implementation uses sparse attention mechanisms to handle graph-structured data, similar to <https://github.com/Diego999/pyGAT>.

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j}\]

where the attention coefficients \(\alpha_{i,j}\) are computed as:

\[\alpha_{i,j} =\frac{\exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} \mathbf{\Theta} \mathbf{x}_i+ \mathbf{a}^{\top} \mathbf{ \Theta}\mathbf{x}_j\right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left( \mathbf{a}^{\top} \mathbf{\Theta} \mathbf{x}_i + \mathbf{a}^{\top}\mathbf{\Theta}\mathbf{x}_k \right)\right)}.\]
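
The two formulas above can be traced by hand on a tiny graph. The following is a minimal sketch in plain Python, not the library's implementation: the helper names (`attention_row`, `aggregate`), the toy values of `theta` and `a`, and the single-head setting are all assumptions made for illustration.

```python
import math

def leaky_relu(z, negative_slope=0.2):
    # LeakyReLU: identity for non-negative inputs, scaled for negative ones.
    return z if z >= 0 else negative_slope * z

def matvec(mat, vec):
    # Dense matrix-vector product on nested lists.
    return [sum(m * v for m, v in zip(row, vec)) for row in mat]

def dot(u, v):
    return sum(p * q for p, q in zip(u, v))

def attention_row(i, neighbors, x, theta, a, negative_slope=0.2):
    """Return {j: alpha_ij} for every j in N(i) united with {i}."""
    candidates = sorted(set(neighbors) | {i})
    # s_k = a^T (Theta x_k), the per-node score used in the formula.
    s = {k: dot(a, matvec(theta, x[k])) for k in candidates}
    logits = {j: leaky_relu(s[i] + s[j], negative_slope) for j in candidates}
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits.values())
    exps = {j: math.exp(v - m) for j, v in logits.items()}
    total = sum(exps.values())
    return {j: e / total for j, e in exps.items()}

def aggregate(alpha, x, theta):
    """Weighted sum of projected neighbor features: sum_j alpha_ij * Theta x_j."""
    out = [0.0] * len(theta)
    for j, w in alpha.items():
        h = matvec(theta, x[j])
        out = [o + w * hj for o, hj in zip(out, h)]
    return out

# Toy data (hypothetical): 3 nodes, F_in = 2, F_out = 3.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
theta = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]
a = [0.5, -0.5, 1.0]

alpha_0 = attention_row(0, neighbors=[1, 2], x=x, theta=theta, a=a)
x0_new = aggregate(alpha_0, x, theta)
```

Here `alpha_0` holds one coefficient for each of nodes 0, 1 and 2, the coefficients sum to 1, and `x0_new` has `F_out = 3` entries.
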
Parameters:
  • in_dim (int or Tuple[int, int]) – Size of each input sample, or a tuple of source and destination sizes for bipartite graphs.

  • out_dim (int) – Size of each output sample.

  • num_heads (int) – Number of attention heads, the default value is 8.

  • concat (bool) – If set to False, the multi-head attentions are averaged instead of concatenated.

  • negative_slope (float) – Negative slope of the LeakyReLU activation, the default value is 0.2.

  • dropout (float) – Dropout probability of the normalized attention coefficients, which exposes each node to a stochastically sampled neighborhood during training. The default value is 0.6.

  • bias (bool) – If set to False, no bias terms are added into the final output.

  • skip_connection (bool) – If set to True, the layer will add a learnable skip-connection. (default: False)
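
The effect of `concat` on the output can be illustrated without the library. This is a minimal sketch in plain Python, assuming the usual GAT convention that concatenation yields `num_heads * out_dim` features per node while averaging keeps `out_dim`; this page does not state the concatenated size explicitly, and `combine_heads` is a hypothetical helper, not part of rllm.

```python
def combine_heads(head_outputs, concat):
    """Merge per-head outputs.

    head_outputs: list of H matrices, each with N rows of F_out floats.
    Returns N rows of H * F_out floats if concat is True, else N rows of F_out.
    """
    num_heads = len(head_outputs)
    num_nodes = len(head_outputs[0])
    if concat:
        # Lay the heads side by side along the feature dimension.
        return [
            [v for head in head_outputs for v in head[n]]
            for n in range(num_nodes)
        ]
    # Average the heads feature by feature.
    return [
        [sum(vals) / num_heads for vals in zip(*(head[n] for head in head_outputs))]
        for n in range(num_nodes)
    ]

# Two heads, two nodes, F_out = 3 (toy values).
head_a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
head_b = [[3.0, 2.0, 1.0], [6.0, 5.0, 4.0]]

concatenated = combine_heads([head_a, head_b], concat=True)
averaged = combine_heads([head_a, head_b], concat=False)
```

With these inputs, `concatenated` has 6 features per node and `averaged` has 3, which matches the shape in the class example where `concat=False` and `out_dim=8` give an output of width 8.
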

Shapes:

  • input:

    node features \((|\mathcal{V}|, F_{in})\)

    edge_index is sparse adjacency matrix \((|\mathcal{V}|, |\mathcal{V}|)\) or edge list \((2, |\mathcal{E}|)\)

  • output:

    node features \((|\mathcal{V}|, F_{out})\)
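
The two accepted connectivity formats describe the same graph. The sketch below converts between them in plain Python, using dense nested lists as a stand-in for a sparse matrix. It assumes that row 0 of the edge list holds source nodes and row 1 holds target nodes, and that entry `(s, d)` of the adjacency matrix marks an edge from `s` to `d`; the page does not spell out this orientation, and both helper functions are hypothetical.

```python
def edge_list_to_adjacency(edge_index, num_nodes):
    """Build a (num_nodes x num_nodes) 0/1 matrix from a (2, |E|) edge list."""
    src_row, dst_row = edge_index
    adj = [[0] * num_nodes for _ in range(num_nodes)]
    for s, d in zip(src_row, dst_row):
        adj[s][d] = 1  # assumed orientation: row = source, column = target
    return adj

def adjacency_to_edge_list(adj):
    """Recover a (2, |E|) edge list from a 0/1 adjacency matrix, in row-major order."""
    src_row, dst_row = [], []
    for s, row in enumerate(adj):
        for d, value in enumerate(row):
            if value:
                src_row.append(s)
                dst_row.append(d)
    return [src_row, dst_row]

# Same connectivity as in the class example: edges 0->1, 1->2, 2->3.
edge_index = [[0, 1, 2], [1, 2, 3]]
adj = edge_list_to_adjacency(edge_index, num_nodes=4)
round_trip = adjacency_to_edge_list(adj)
```

The adjacency form always has \(|\mathcal{V}|^2\) entries when stored densely, which is why a sparse representation or the edge list is preferred for large graphs.
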

forward(x: Tensor | Tuple[Tensor, Tensor], edge_index: Tensor, return_attention_weights: bool = False) → Tensor | Tuple[Tensor, Tuple[Tensor, Tensor]][source]

Compute node embeddings with graph attention message passing.

Parameters:
  • x (Union[Tensor, Tuple[Tensor, Tensor]]) –

    • Tensor input features for homogeneous graphs.

    • Tuple of source and destination node features for bipartite graphs.

  • edge_index (Union[Tensor, SparseTensor]) – Graph connectivity in edge-list or sparse adjacency format.

  • return_attention_weights (bool) – If True, also return edge-level attention coefficients.

Returns:

If return_attention_weights is False, returns output node features. Otherwise returns output features and corresponding attention weights.

Return type:

Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]], Tuple[Tensor, Tuple[SparseTensor, Tensor]]]

Example

>>> import torch
>>> from rllm.nn.conv.graph_conv import GATConv
>>> conv = GATConv(16, 8, num_heads=2, concat=False)
>>> x = torch.randn(4, 16)
>>> edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
>>> out = conv(x, edge_index)
>>> out.shape
torch.Size([4, 8])

message_and_aggregate(edge_index, alpha, x_src, dim_size)[source]

The message and aggregation interface to be overridden by subclasses.