Source code for rllm.transforms.graph_transforms.knn_graph
from functools import lru_cache
from typing import Optional
from torch import Tensor
from rllm.transforms.graph_transforms import EdgeTransform
from rllm.transforms.graph_transforms.functional import knn_graph
[docs]
class KNNGraph(EdgeTransform):  # TODO: add force_undirected option.
    r"""Builds a k-NN adjacency matrix from node features.

    The constructor only stores its arguments; :meth:`forward` hands them,
    together with the node features, to ``knn_graph``.

    The latest result is memoized *per instance*: calling :meth:`forward`
    again with the very same tensor object, whose version counter has not
    changed, and with unchanged settings returns the previously computed
    adjacency object without calling ``knn_graph`` again.

    Args:
        num_neighbors (int, optional): The number of neighbors. (default: 6)
        mode (str[`connectivity`, `distance`], optional):
            Type of returned matrix: `connectivity` will return the
            connectivity matrix with ones and zeros, while `distance`
            will return the distances between neighbors
            according to the given metric.
            (default: `connectivity`)
        metric (str[`minkowski`, `cosine`, `l1`, `l2`, ...], optional):
            Metric to use for distance computation.
            Default is `minkowski`, which results in the
            standard Euclidean distance when p = 2.
            (default: `minkowski`)
        p (float): Power parameter for the Minkowski metric (default: `2`).
        metric_params (dict, optional): Additional keyword arguments for the
            metric function. (default: :obj:`None`)
        include_self (bool, optional): If set to :obj:`True`, the graph will
            contain self-loops. (default: :obj:`False`)
        n_jobs (int): Number of workers to use for computation. (default: 1)
    """

    def __init__(
        self,
        num_neighbors: Optional[int] = 6,
        mode: Optional[str] = "connectivity",
        metric: Optional[str] = "minkowski",
        p: Optional[float] = 2,
        metric_params: Optional[dict] = None,
        include_self: Optional[bool] = False,
        n_jobs: int = 1,
    ):
        # NOTE(review): the base-class initializer is not invoked here, which
        # matches the previous behavior -- confirm EdgeTransform needs none.
        self.num_neighbors = num_neighbors
        self.mode = mode
        self.metric = metric
        self.p = p
        self.metric_params = metric_params
        self.include_self = include_self
        self.n_jobs = n_jobs
        # Single-entry memo of the latest ``forward`` call, stored as
        # ``(x, version, settings, adjacency)``; ``None`` while empty.
        # It lives on the instance, so it is released with the instance
        # instead of sitting in a class-wide ``lru_cache``.
        self._memo = None

    def _settings(self) -> tuple:
        """Return a snapshot of the arguments forwarded to ``knn_graph``."""
        params = self.metric_params
        if isinstance(params, dict):
            # Shallow copy, so a later in-place edit of the caller's dict
            # makes the snapshot compare unequal and invalidates the memo.
            params = dict(params)
        return (
            self.num_neighbors,
            self.mode,
            self.metric,
            self.p,
            params,
            self.include_self,
            self.n_jobs,
        )

    @staticmethod
    def _version_of(x):
        """Return ``x._version``, or ``None`` if it cannot be read."""
        try:
            return x._version
        except (AttributeError, RuntimeError):
            # No usable version counter: in-place edits could not be
            # detected, so the caller skips memoization for this input.
            return None

    def _memo_matches(self, x, version, settings) -> bool:
        """Tell whether the stored memo was computed for this exact call."""
        # ``getattr`` keeps instances created without ``__init__`` working.
        memo = getattr(self, "_memo", None)
        if memo is None:
            return False
        cached_x, cached_version, cached_settings, _ = memo
        if cached_x is not x or cached_version != version:
            return False
        try:
            return bool(cached_settings == settings)
        except (TypeError, ValueError, RuntimeError):
            # A setting whose ``==`` has no plain truth value (e.g. an
            # array-like inside ``metric_params``) counts as "changed".
            return False

    def forward(self, x: Tensor) -> Tensor:
        """Compute the k-NN adjacency for ``x`` via ``knn_graph``.

        Args:
            x (Tensor): Node features, passed to ``knn_graph`` unchanged.

        Returns:
            Tensor: Whatever ``knn_graph`` returns for ``x`` and the stored
            settings. A repeated call with the same, unmodified tensor
            object and unchanged settings returns the memoized object.
        """
        version = self._version_of(x)
        settings = self._settings()
        if version is not None and self._memo_matches(x, version, settings):
            return self._memo[3]

        knn_adj = knn_graph(
            x,
            self.num_neighbors,
            self.mode,
            self.metric,
            self.p,
            self.metric_params,
            self.include_self,
            self.n_jobs,
        )
        # Only remember results whose input can later be validated; the memo
        # holds one input and one result, so memory use stays bounded.
        self._memo = (
            None if version is None else (x, version, settings, knn_adj)
        )
        return knn_adj