from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, List, Dict, Union, Tuple, Sequence, Callable, Optional
from rllm.llm.types import ChatMessage
from rllm.llm.parser.base import BaseOutputParser
from rllm.llm.prompt.utils import (
get_template_vars,
prompt_to_messages,
messages_to_prompt,
)
from rllm.llm.llm_module.base import BaseLLM
# Module-level alias for the default messages -> prompt-string converter.
# ChatPromptTemplate.format() takes a parameter named ``messages_to_prompt``
# that shadows the imported helper inside that method, so the default
# implementation is reached through this alias instead.
default_messages_to_prompt = messages_to_prompt
class BasePromptTemplate(ABC):
    """Abstract base shared by completion and chat prompt templates.

    Holds the state common to every template: metadata, the variable names
    discovered in the template text, keyword values applied on every format
    call, an optional output parser, and two kinds of variable mappings.
    It also provides the helpers that resolve those mappings. Subclasses
    implement the actual formatting.
    """

    def __init__(
        self,
        metadata: Dict[str, Any],
        template_vars: List[str],
        kwargs: Dict[str, str],
        output_parser: Optional[BaseOutputParser] = None,
        template_var_mappings: Optional[Dict[str, Any]] = None,
        function_mappings: Optional[Dict[str, Callable]] = None,
    ):
        """Store the shared template state.

        Args:
            metadata: Arbitrary metadata attached to the template.
            template_vars: Variable names found in the template text.
            kwargs: Keyword values merged into every format call.
            output_parser: Optional parser; how it is applied is up to the
                subclass.
            template_var_mappings: Maps caller-facing variable names to the
                names used inside the template string. Falsy means
                "no renaming".
            function_mappings: Maps a variable name to a callable whose
                return value supplies that variable at format time. Falsy
                means "no computed variables".
        """
        # Falsy mapping arguments are replaced by fresh empty dicts so the
        # helpers below can always iterate them.
        self.function_mappings = function_mappings or {}
        self.template_var_mappings = template_var_mappings or {}
        self.output_parser = output_parser
        self.kwargs = kwargs
        self.template_vars = template_vars
        self.metadata = metadata

    def _map_template_vars(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``kwargs`` with keys renamed per ``template_var_mappings``.

        Keys without a mapping keep their original name. If two keys end up
        with the same name, the later one in ``kwargs`` wins.
        """
        renames = self.template_var_mappings or {}
        remapped: Dict[str, Any] = {}
        for name, value in kwargs.items():
            remapped[renames.get(name, name)] = value
        return remapped

    def _map_function_vars(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Evaluate ``function_mappings`` and merge the results with ``kwargs``.

        Callers may supply functions instead of fixed values for format
        variables. Every mapped callable is invoked with the full ``kwargs``
        as keyword arguments, and its return value becomes the value of the
        corresponding variable. A computed value takes precedence over a
        fixed value of the same name in ``kwargs``.
        """
        functions = self.function_mappings or {}
        # TODO: decide which variables each function should receive; for
        # now every function is called with all kwargs.
        # Computed values are inserted first so they shadow fixed values.
        resolved = {name: fn(**kwargs) for name, fn in functions.items()}
        for name, value in kwargs.items():
            resolved.setdefault(name, value)
        return resolved

    def _map_all_vars(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Apply function mappings first, then template-variable renaming.

        Functions are therefore evaluated against the caller's names, and
        only afterwards are keys renamed to the names that appear in the
        template string.
        """
        return self._map_template_vars(self._map_function_vars(kwargs))

    @abstractmethod
    def partial_format(self, **kwargs) -> "BasePromptTemplate":
        """Return a template with some variables pre-filled."""
        raise NotImplementedError

    @abstractmethod
    def format(self, llm: Optional[BaseLLM] = None, **kwargs) -> str:
        """Render the template to a single prompt string."""
        raise NotImplementedError

    @abstractmethod
    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs
    ) -> List[ChatMessage]:
        """Render the template to a list of chat messages."""
        raise NotImplementedError

    @abstractmethod
    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        """Return the raw, unformatted template text."""
        raise NotImplementedError
class PromptTemplate(BasePromptTemplate):
    r"""Template used for completion.

    Wraps a single ``str.format``-style template string. The original class
    implemented only ``get_template`` and left the abstract
    ``partial_format`` / ``format`` / ``format_messages`` unimplemented,
    which made the class impossible to instantiate.
    """

    def __init__(
        self,
        template: str,
        output_parser: Optional[BaseOutputParser] = None,
        metadata: Optional[Dict[str, Any]] = None,
        template_var_mappings: Optional[Dict[str, Any]] = None,
        function_mappings: Optional[Dict[str, Callable]] = None,
        **kwargs,
    ) -> None:
        """Create a completion template.

        Args:
            template: ``str.format``-style template text. Its variables are
                discovered with ``get_template_vars``.
            output_parser: Optional output parser applied when formatting.
            metadata: Arbitrary metadata; defaults to an empty dict.
            template_var_mappings: Caller-name -> template-name renames.
            function_mappings: Variable name -> callable computing its value.
            **kwargs: Default variable values merged into every format call.
        """
        if metadata is None:
            metadata = {}
        self.template = template
        template_vars = get_template_vars(template)
        super().__init__(
            template_vars=template_vars,
            kwargs=kwargs,
            metadata=metadata,
            output_parser=output_parser,
            template_var_mappings=template_var_mappings,
            function_mappings=function_mappings,
        )

    def partial_format(self, **kwargs: Any) -> "PromptTemplate":
        """Return a deep copy with ``kwargs`` pre-filled; ``self`` is unchanged."""
        prompt = deepcopy(self)
        prompt.kwargs.update(kwargs)
        return prompt

    def format(
        self,
        llm: Optional[BaseLLM] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        **kwargs: Any,
    ) -> str:
        """Render the template to a prompt string.

        Args:
            llm: Accepted for interface compatibility; unused.
            completion_to_prompt: Optional final transform applied to the
                rendered string. This mirrors the ``messages_to_prompt`` hook
                of the chat template.
            **kwargs: Variable values; these override values stored via the
                constructor or ``partial_format``.

        Returns:
            The formatted prompt string.

        Raises:
            KeyError: if a variable used in the template has no value.
        """
        del llm  # unused
        all_kwargs = {**self.kwargs, **kwargs}
        mapped_all_kwargs = self._map_all_vars(all_kwargs)
        prompt = self.template.format(**mapped_all_kwargs)
        if self.output_parser is not None:
            # NOTE(review): assumes BaseOutputParser exposes
            # ``format(str) -> str`` as in LlamaIndex, from which this module
            # is ported; confirm against rllm.llm.parser.base.
            prompt = self.output_parser.format(prompt)
        if completion_to_prompt is not None:
            prompt = completion_to_prompt(prompt)
        return prompt

    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs: Any
    ) -> List[ChatMessage]:
        """Render the template and wrap the result as chat messages.

        Presumably ``prompt_to_messages`` wraps the string in a single user
        message; confirm in rllm.llm.prompt.utils.
        """
        del llm  # unused
        prompt = self.format(**kwargs)
        return prompt_to_messages(prompt)

    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        """Return the raw template string."""
        return self.template
class ChatPromptTemplate(BasePromptTemplate):
    r"""Template used for chat.

    Holds a list of message templates whose ``content`` strings are
    ``str.format``-style templates.
    """

    def __init__(
        self,
        message_templates: List[ChatMessage],
        output_parser: Optional[BaseOutputParser] = None,
        metadata: Optional[Dict[str, Any]] = None,
        template_var_mappings: Optional[Dict[str, Any]] = None,
        function_mappings: Optional[Dict[str, Callable]] = None,
        **kwargs,
    ):
        """Create a chat template.

        Args:
            message_templates: Messages whose ``content`` is a template.
            output_parser: Optional parser whose ``format_messages`` is
                applied to the formatted messages.
            metadata: Arbitrary metadata; defaults to an empty dict.
            template_var_mappings: Caller-name -> template-name renames.
            function_mappings: Variable name -> callable computing its value.
            **kwargs: Default variable values merged into every format call.
        """
        if metadata is None:
            metadata = {}
        # BasePromptTemplate.__init__ does not accept ``message_templates``
        # (forwarding it raised TypeError), so store it on the instance
        # directly. format_messages() and get_template() read it.
        self.message_templates = message_templates
        template_vars: List[str] = []
        for message_template in message_templates:
            template_vars.extend(get_template_vars(message_template.content or ""))
        super().__init__(
            kwargs=kwargs,
            metadata=metadata,
            output_parser=output_parser,
            template_vars=template_vars,
            template_var_mappings=template_var_mappings,
            function_mappings=function_mappings,
        )

    @classmethod
    def from_messages(
        cls,
        message_templates: Union[List[Tuple[str, str]], List[ChatMessage]],
        **kwargs,
    ) -> "ChatPromptTemplate":
        """Build a template from ``(role, content)`` tuples or ChatMessages.

        Tuple input is detected by inspecting the first element. An empty
        list is accepted and produces a template with no messages; the
        previous code raised IndexError on it.
        """
        if message_templates and isinstance(message_templates[0], tuple):
            message_templates = [
                ChatMessage.from_str(role=role, content=content)
                for role, content in message_templates
            ]
        return cls(message_templates=message_templates, **kwargs)

    def partial_format(self, **kwargs) -> "ChatPromptTemplate":
        """Return a deep copy with ``kwargs`` pre-filled; ``self`` is unchanged."""
        prompt = deepcopy(self)
        prompt.kwargs.update(kwargs)
        return prompt

    def format(
        self,
        llm: Optional[BaseLLM] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        **kwargs,
    ) -> str:
        """Format the messages and join them into a single prompt string.

        Args:
            llm: Accepted for interface compatibility; unused.
            messages_to_prompt: Optional converter from messages to a string.
                When omitted, the module's default converter is used.
            **kwargs: Variable values passed on to ``format_messages``.
        """
        del llm  # unused
        messages = self.format_messages(**kwargs)
        if messages_to_prompt is not None:
            return messages_to_prompt(messages)
        return default_messages_to_prompt(messages)

    def format_messages(
        self, llm: Optional[BaseLLM] = None, **kwargs
    ) -> List[ChatMessage]:
        """Format the prompt into a list of chat messages.

        Stored kwargs are merged with the call's kwargs (call wins), function
        and template-variable mappings are applied, and each message template
        is formatted with only the variables its own content references. The
        output parser's ``format_messages`` is applied last, if set.

        Raises:
            KeyError: if a variable used in a message has no value.
        """
        del llm  # unused
        all_kwargs = {
            **self.kwargs,
            **kwargs,
        }
        mapped_all_kwargs = self._map_all_vars(all_kwargs)

        messages: List[ChatMessage] = []
        for message_template in self.message_templates:
            content_template = message_template.content or ""
            template_vars = get_template_vars(content_template)
            # Pass each message only the (already mapped) variables it uses.
            relevant_kwargs = {
                k: v for k, v in mapped_all_kwargs.items() if k in template_vars
            }
            content = content_template.format(**relevant_kwargs)

            # Copy so the stored template message is never mutated.
            message: ChatMessage = message_template.copy()
            message.content = content
            messages.append(message)

        if self.output_parser is not None:
            messages = self.output_parser.format_messages(messages)
        return messages

    def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        """Return the unformatted message templates rendered as one string."""
        return default_messages_to_prompt(self.message_templates)