# Source code for rllm.nn.models.rect
import torch
from torch import Tensor
import torch.nn.functional as F
from rllm.nn.conv.graph_conv import GCNConv
[docs]
class RECT_L(torch.nn.Module):
    r"""The RECT model, or more specifically its supervised part RECT-L,
    from the `"Network Embedding with Completely-imbalanced Labels"
    <https://arxiv.org/abs/2007.03545>`__ paper.

    In particular, a GCN model is trained that reconstructs semantic class
    knowledge. A :class:`GCNConv` layer followed by PReLU (and dropout
    during training) encodes the nodes. A linear layer then maps the
    hidden representation back to :obj:`in_dim` features.

    Args:
        in_dim (int): Size of each input sample.
        hidden_dim (int): Intermediate size of each sample.
        dropout (float, optional): The dropout probability.
            (default: :obj:`0.0`)

    Example:
        >>> import torch
        >>> model = RECT_L(in_dim=16, hidden_dim=8)
        >>> x = torch.randn(4, 16)
        >>> adj = torch.tensor([[0, 1, 2], [1, 2, 3]])
        >>> model(x, adj).shape
        torch.Size([4, 16])
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.prelu = torch.nn.PReLU()
        # Encoder: in_dim -> hidden_dim graph convolution.
        self.conv = GCNConv(in_dim, hidden_dim)
        # Decoder: maps hidden embeddings back to the input feature size.
        self.lin = torch.nn.Linear(hidden_dim, in_dim)
        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.conv.reset_parameters()
        # Re-initializes both weight and bias of the decoder...
        self.lin.reset_parameters()
        # ...then overrides only the weight with Xavier-uniform init.
        torch.nn.init.xavier_uniform_(self.lin.weight.data)

    def forward(self, x: Tensor, adj: Tensor) -> Tensor:
        """Encode node features and reconstruct semantic targets.

        Args:
            x (Tensor): Input node features with :obj:`in_dim` columns.
            adj (Tensor): Graph connectivity, passed unchanged to
                :class:`GCNConv` (its expected format is defined there).

        Returns:
            Tensor: Reconstructed output features with :obj:`in_dim`
            columns per node.
        """
        x = self.prelu(self.conv(x, adj))
        # Dropout is only active in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)

    @torch.jit.export
    def embed(self, x: Tensor, adj: Tensor) -> Tensor:
        """Compute hidden embeddings without tracking gradients.

        This runs only the encoder (GCNConv followed by PReLU). Neither
        dropout nor the decoder is applied.

        Args:
            x (Tensor): Input node features with :obj:`in_dim` columns.
            adj (Tensor): Graph connectivity, passed unchanged to
                :class:`GCNConv`.

        Returns:
            Tensor: Hidden node embeddings with :obj:`hidden_dim` columns.
        """
        with torch.no_grad():
            return self.prelu(self.conv(x, adj))

    @torch.jit.export
    def get_semantic_labels(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
        r"""Replace labels with corresponding class-center embeddings.

        For every node selected by :obj:`mask`, this returns the mean of
        :obj:`x` over all selected nodes that share its label.

        Args:
            x (Tensor): Node embeddings, a 2-D tensor ``[num_nodes, dim]``.
            y (Tensor): Non-negative integer class labels, one per node.
            mask (Tensor): Mask selecting labeled nodes.

        Returns:
            Tensor: Class-center embeddings for labeled nodes, with the
            same dtype and device as :obj:`x`. The result has size
            ``[num_selected, dim]``, which is ``[0, dim]`` if the mask
            selects nothing.
        """
        with torch.no_grad():
            x = x[mask]
            y = y[mask]
            # Guard the empty selection: max() of an empty tensor raises.
            num_classes = int(y.max()) + 1 if y.numel() > 0 else 0
            # Allocate with x's dtype/device so non-float32 inputs are not
            # silently down-cast and no host->device copy is needed.
            sums = x.new_zeros((num_classes, x.size(1)))
            sums.index_add_(0, y, x)
            # Labels absent from y get count 0. Clamping to 1 avoids 0/0 = NaN
            # in those rows, which are never gathered below anyway.
            counts = torch.bincount(y, minlength=num_classes).clamp(min=1)
            centers = sums / counts.to(x.dtype).unsqueeze(-1)
            return centers[y]

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.in_dim}, {self.hidden_dim})"