Source code for rllm.nn.encoder.heterotemporal_encoder
from typing import Dict, List
import torch
from torch import Tensor
from torch.nn import ModuleDict
from .col_encoder._positional_encoder import PositionalEncoder
[docs]
class HeteroTemporalEncoder(torch.nn.Module):
r"""HeteroTemporalEncoder for RDL model from paper
`"RelBench: A Benchmark for Deep Learning on Relational Databases"
<https://arxiv.org/abs/2407.20060>`_.
Args:
node_types (List[str]): The list of node types.
channels (int): The number of channels.
Returns:
The ``forward`` method returns a dictionary from node type to
temporal embeddings.
Example:
>>> import torch
>>> enc = HeteroTemporalEncoder(node_types=["user", "item"], channels=16)
>>> seed_time = torch.tensor([1000.0, 1100.0])
>>> time_dict = {"user": torch.tensor([900.0]), "item": torch.tensor([950.0])}
>>> batch_dict = {"user": torch.tensor([0]), "item": torch.tensor([1])}
>>> out = enc(seed_time, time_dict, batch_dict)
>>> out["user"].shape
torch.Size([1, 16])
"""
def __init__(self, node_types: List[str], channels: int):
super().__init__()
self.encoder_dict = ModuleDict(
{node_type: PositionalEncoder(channels) for node_type in node_types}
)
self.lin_dict = ModuleDict(
{node_type: torch.nn.Linear(channels, channels) for node_type in node_types}
)
[docs]
def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
for encoder in self.encoder_dict.values():
encoder.reset_parameters()
for lin in self.lin_dict.values():
lin.reset_parameters()
[docs]
def forward(
self,
seed_time: Tensor,
time_dict: Dict[str, Tensor],
batch_dict: Dict[str, Tensor],
) -> Dict[str, Tensor]:
r"""Compute relative temporal embeddings for each node type.
Args:
seed_time (Tensor): The reference timestamps for seed nodes of
shape :obj:`[num_seeds]`.
time_dict (Dict[str, Tensor]): Timestamps per node type.
batch_dict (Dict[str, Tensor]): Batch assignment indices per
node type, mapping each node to a seed node.
Returns:
Dict[str, Tensor]: Temporal embeddings per node type of shape
:obj:`[num_nodes, channels]`.
"""
out_dict: Dict[str, Tensor] = {}
for node_type, time in time_dict.items():
rel_time = seed_time[batch_dict[node_type]] - time
rel_time = rel_time / (60 * 60 * 24) # Convert seconds to days.
x = self.encoder_dict[node_type](rel_time)
x = self.lin_dict[node_type](x)
out_dict[node_type] = x
return out_dict