rllm.nn.conv.graph_conv.GCNConv
- class rllm.nn.conv.graph_conv.GCNConv(in_dim: int, out_dim: int, bias: bool = True, normalize: bool = False)
Bases: MessagePassing
The GCN (Graph Convolutional Network) model implementation with message passing, based on the “Semi-supervised Classification with Graph Convolutional Networks” paper.
This model applies convolution operations to graph-structured data, allowing for the aggregation of feature information from neighboring nodes.
\[\mathbf{X}^{\prime} = \mathbf{\tilde{A}} \mathbf{X} \mathbf{W}\]
- Parameters:
in_dim (int) – Size of each input sample.
out_dim (int) – Size of each output sample.
bias (bool) – If set to False, no bias term is added to the final output.
normalize (bool) – If set to True, the adjacency matrix is symmetrically normalized as \(\mathbf{\tilde{A}} = \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\), where \(\mathbf{D}\) is the degree matrix of the graph.
Shapes:
input:
node features \((|\mathcal{V}|, F_{in})\)
edge_index: a sparse adjacency matrix \((|\mathcal{V}|, |\mathcal{V}|)\) or an edge list \((2, |\mathcal{E}|)\)
output:
node features \((|\mathcal{V}|, F_{out})\)
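The propagation rule and the symmetric normalization described above can be illustrated with a small dense reference computation. This is only a sketch in plain PyTorch, not the rllm implementation: the helper name `sym_normalize` and the stand-in weight `w` are hypothetical, the degree is assumed to be the row sum of the adjacency matrix, and no self-loops are added because the formula above does not mention them.

```python
import torch


def sym_normalize(adj: torch.Tensor) -> torch.Tensor:
    """Return D^{-1/2} A D^{-1/2} for a dense adjacency matrix (hypothetical helper)."""
    deg = adj.sum(dim=1)  # assumption: degree = row sum of A
    # Isolated nodes have degree 0; use 0 instead of inf for their inverse square root.
    deg_inv_sqrt = torch.where(deg > 0, deg.pow(-0.5), torch.zeros_like(deg))
    # Scale row i by d_i^{-1/2} and column j by d_j^{-1/2}.
    return deg_inv_sqrt.unsqueeze(1) * adj * deg_inv_sqrt.unsqueeze(0)


# Undirected path graph 0 - 1 - 2.
adj = torch.tensor([[0.0, 1.0, 0.0],
                    [1.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0]])
adj_norm = sym_normalize(adj)

x = torch.randn(3, 4)  # X: (|V|, F_in)
w = torch.randn(4, 2)  # W: (F_in, F_out), a stand-in for the learned weight
x_out = adj_norm @ x @ w  # X' = A~ X W, shape (|V|, F_out)
```

With degrees (1, 2, 1), each edge between an end node and the middle node receives the weight 1/sqrt(2), and the normalized matrix stays symmetric.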
- forward(x: Tensor, edge_index: Tensor, edge_weight: Tensor | None = None, dim_size: int | None = None) → Tensor
Apply a single GCN message-passing step.
- Parameters:
x (Tensor) – Input node features with shape [num_nodes, in_dim].
edge_index (Union[Tensor, SparseTensor]) – Graph connectivity in edge-list or sparse adjacency format.
edge_weight (Optional[Tensor]) – Optional edge weights used during neighborhood aggregation.
dim_size (Optional[int]) – Optional number of destination nodes for aggregation output shape inference.
- Returns:
Output node features with shape [num_nodes, out_dim].
- Return type:
Tensor
Example
>>> import torch
>>> from rllm.nn.conv.graph_conv import GCNConv
>>> conv = GCNConv(16, 8)
>>> x = torch.randn(5, 16)
>>> edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
>>> out = conv(x, edge_index)
>>> out.shape
torch.Size([5, 8])
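To make the roles of edge_weight and dim_size more concrete, the following sketch shows one way weighted neighborhood aggregation over an edge list can be written in plain PyTorch. It is an illustration under stated assumptions rather than the library's code: the function `aggregate` is hypothetical, row 0 of edge_index is assumed to hold source nodes and row 1 destination nodes, and the aggregation is assumed to be a weighted sum.

```python
import torch


def aggregate(x, edge_index, edge_weight=None, dim_size=None):
    """Sum weighted source-node features into each destination node (hypothetical helper)."""
    src, dst = edge_index  # assumed convention: row 0 = source, row 1 = destination
    if edge_weight is None:
        edge_weight = torch.ones(src.numel(), dtype=x.dtype)  # unweighted edges
    if dim_size is None:
        dim_size = x.size(0)  # default: one output row per input node
    messages = x[src] * edge_weight.unsqueeze(1)  # one message per edge
    out = torch.zeros(dim_size, x.size(1), dtype=x.dtype)
    out.index_add_(0, dst, messages)  # scatter-add messages by destination
    return out


x = torch.tensor([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
edge_index = torch.tensor([[0, 1, 2], [1, 2, 1]])  # edges 0->1, 1->2, 2->1
edge_weight = torch.tensor([1.0, 0.5, 2.0])
out = aggregate(x, edge_index, edge_weight)
# Node 0 receives nothing, node 1 receives 1*x[0] + 2*x[2], node 2 receives 0.5*x[1].
```

Under these assumptions, dim_size only fixes the number of rows in the aggregated output, which matters when the destination node set differs in size from the input node set.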