rllm.nn.encoder.HeteroTemporalEncoder
- class rllm.nn.encoder.HeteroTemporalEncoder(node_types: List[str], channels: int)
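Bases: Module
The HeteroTemporalEncoder used by the relational deep learning (RDL) model from the paper “RelBench: A Benchmark for Deep Learning on Relational Databases”. It embeds each node's timestamp relative to the timestamp of the seed node it belongs to.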
- Parameters:
node_types (List[str]) – The node types for which temporal embeddings are computed.
channels (int) – The number of channels, i.e. the dimensionality of each output embedding.
- Returns:
The forward method returns a dictionary mapping each node type to its temporal embeddings.
Example
>>> import torch
>>> from rllm.nn.encoder import HeteroTemporalEncoder
>>> enc = HeteroTemporalEncoder(node_types=["user", "item"], channels=16)
>>> seed_time = torch.tensor([1000.0, 1100.0])
>>> time_dict = {"user": torch.tensor([900.0]), "item": torch.tensor([950.0])}
>>> batch_dict = {"user": torch.tensor([0]), "item": torch.tensor([1])}
>>> out = enc(seed_time, time_dict, batch_dict)
>>> out["user"].shape
torch.Size([1, 16])
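As a worked illustration of what the example computes before any embedding happens, the sketch below evaluates the relative times for the same inputs in plain PyTorch, without calling rllm. It assumes, following the RelBench reference design rather than anything stated on this page, that a node's relative time is its seed node's timestamp minus its own timestamp, with batch_dict indexing into seed_time.

```python
import torch

# Inputs copied from the example above.
seed_time = torch.tensor([1000.0, 1100.0])
time_dict = {"user": torch.tensor([900.0]), "item": torch.tensor([950.0])}
batch_dict = {"user": torch.tensor([0]), "item": torch.tensor([1])}

# Assumed convention: relative time = timestamp of the node's seed node minus
# the node's own timestamp, so older nodes get larger positive values.
rel_time = {
    node_type: seed_time[batch_dict[node_type]] - time_dict[node_type]
    for node_type in time_dict
}
print(rel_time["user"])  # tensor([100.])
print(rel_time["item"])  # tensor([150.])
```

Under that assumption the user node lies 100 time units before its seed (1000 - 900) and the item node 150 units before its seed (1100 - 950); the encoder then maps each such scalar to a vector of size channels.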
- forward(seed_time: Tensor, time_dict: Dict[str, Tensor], batch_dict: Dict[str, Tensor]) → Dict[str, Tensor]
Compute relative temporal embeddings for each node type.
- Parameters:
seed_time (Tensor) – The reference timestamps of the seed nodes, of shape [num_seeds].
time_dict (Dict[str, Tensor]) – The timestamps of the nodes of each node type, each of shape [num_nodes].
batch_dict (Dict[str, Tensor]) – Batch assignment indices per node type, mapping each node to the index of its seed node in seed_time.
- Returns:
Temporal embeddings per node type, each of shape [num_nodes, channels].
- Return type:
Dict[str, Tensor]
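The following is a minimal, self-contained sketch of how an encoder with this interface could be built in plain PyTorch. It is modeled loosely on the RelBench reference design and is not rllm's actual code: the class names SinusoidalEncoding and HeteroTemporalEncoderSketch are hypothetical, the hand-written sinusoidal features stand in for whatever positional encoding rllm uses, and the conversion of timestamps from seconds to days is an assumption.

```python
import math
from typing import Dict, List

import torch
from torch import Tensor, nn


class SinusoidalEncoding(nn.Module):
    """Hypothetical helper: fixed sin/cos features of one scalar per node."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        assert channels % 2 == 0, "channels must be even"
        # Geometric frequency ladder, as in transformer positional encodings.
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(0, channels, 2).float() / channels
        )
        self.register_buffer("freqs", freqs)

    def forward(self, x: Tensor) -> Tensor:
        angles = x.unsqueeze(-1) * self.freqs  # [num_nodes, channels // 2]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)  # [num_nodes, channels]


class HeteroTemporalEncoderSketch(nn.Module):
    """Illustrative sketch only; not the rllm implementation."""

    def __init__(self, node_types: List[str], channels: int) -> None:
        super().__init__()
        # Parameter-free time features shared by all node types.
        self.encoding = SinusoidalEncoding(channels)
        # One learnable projection per node type.
        self.lin_dict = nn.ModuleDict(
            {node_type: nn.Linear(channels, channels) for node_type in node_types}
        )

    def forward(
        self,
        seed_time: Tensor,
        time_dict: Dict[str, Tensor],
        batch_dict: Dict[str, Tensor],
    ) -> Dict[str, Tensor]:
        out: Dict[str, Tensor] = {}
        for node_type, time in time_dict.items():
            # Time of each node relative to its seed node; assumed to be in
            # seconds and converted to days (an assumption, see lead-in).
            rel_time = (seed_time[batch_dict[node_type]] - time) / (60 * 60 * 24)
            out[node_type] = self.lin_dict[node_type](self.encoding(rel_time))
        return out


# Usage with the same inputs as the Example section.
enc = HeteroTemporalEncoderSketch(node_types=["user", "item"], channels=16)
seed_time = torch.tensor([1000.0, 1100.0])
time_dict = {"user": torch.tensor([900.0]), "item": torch.tensor([950.0])}
batch_dict = {"user": torch.tensor([0]), "item": torch.tensor([1])}
out = enc(seed_time, time_dict, batch_dict)
print(out["user"].shape)  # torch.Size([1, 16])
```

In this sketch the sinusoidal features carry no parameters, so all learning happens in the per-node-type linear layers, which let each node type project the shared time features differently.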