Source code for rllm.llm.predictor

from typing import Any, List, Literal, Optional
import time

import pandas as pd
from tqdm import tqdm

from rllm.llm.prompt.default_prompt import (
    DEFAULT_SCENARIO_CLASSIFICATION_TMPL,
    DEFAULT_SCENARIO_REGRESSION_TMPL
)
from rllm.llm.prompt.utils import (
    generate_sample_description,
    get_template_vars
)


from rllm.llm.llm_module.general_llm import LLM
from rllm.llm.prompt.base import BasePromptTemplate
from rllm.llm.prompt.base import PromptTemplate


class Predictor:
    r"""Predictor for relational data.

    Data should be organized into a :class:`pandas.DataFrame`, with any
    prediction labels removed if present. Every row of the dataframe is
    sent to the LLM together with the prompt; a row whose prediction keeps
    failing is reported through ``tqdm.write`` and yields an empty string
    instead of aborting the whole run.

    Args:
        prompt (Optional[:class:`rllm.llm.prompt.base.BasePromptTemplate`]):
            The prompt used to instruct the llm to make a prediction.
            When omitted, a default template is selected from ``type``.
        llm (:class:`rllm.llm.llm_module.general_llm.LLM`):
            The llm used for prediction, it is recommended to be
            initialized with LangChain.
        type (Optional[Literal['classification', 'regression']]):
            Task type. Only consulted (and then required) when ``prompt``
            is not given.

    Raises:
        AssertionError: If ``prompt`` is ``None`` and ``type`` is not
            ``'classification'`` or ``'regression'``.

    .. code-block:: python

        import pandas as pd
        from langchain_openai import OpenAI
        from rllm.llm import LangChainLLM, Predictor

        # labels in dataframe should be removed.
        data = pd.read_csv('data.csv')
        scenario = 'Your_task_description'
        labels = 'Your_task_labels'

        llm = LangChainLLM(OpenAI(openai_api_key="YOUR_API_KEY"))
        predictor = Predictor(llm=llm, type='classification')
        outputs = predictor(data.head(10), scenario=scenario, labels=labels)
    """

    # Task types for which a default prompt template exists.
    _TASK_TYPES = ('classification', 'regression')
    # Number of times a single row is submitted before giving up on it.
    _MAX_ATTEMPTS = 3
    # Linear backoff: the wait before retry k is ``k * _RETRY_BACKOFF_SECONDS``.
    _RETRY_BACKOFF_SECONDS = 1.5
    # Pause after every row (successful or not) before moving to the next.
    _ROW_PAUSE_SECONDS = 0.5

    def __init__(
        self,
        prompt: Optional[BasePromptTemplate] = None,
        llm: LLM = None,
        type: Optional[Literal['classification', 'regression']] = None,
    ) -> None:
        # NOTE: Only support `PromptTemplate` so far
        self._llm = llm

        if prompt is not None:
            self.prompt = prompt
            return

        # Raised explicitly rather than via `assert` so the check is not
        # stripped under `python -O`; AssertionError is kept so existing
        # callers that catch it keep working.
        if type not in self._TASK_TYPES:
            raise AssertionError(
                "type must be 'classification' or 'regression'!"
            )

        function_mapping = {
            'sample_description': generate_sample_description
        }
        default_template = (
            DEFAULT_SCENARIO_CLASSIFICATION_TMPL
            if type == 'classification'
            else DEFAULT_SCENARIO_REGRESSION_TMPL
        )
        self.prompt = PromptTemplate(
            default_template,
            function_mappings=function_mapping
        )

    def invoke(
        self,
        df: pd.DataFrame,
        **kwargs,
    ) -> List[str]:
        r"""Predict every row of ``df`` and collect the llm outputs.

        Args:
            df (:class:`pandas.DataFrame`): Rows to predict, one llm call
                per row.
            **kwargs: Values for the prompt's template variables; they are
                forwarded unchanged to ``llm.predict``.

        Returns:
            One output per row, in row order. A row that still fails after
            the last attempt contributes an empty string.

        Raises:
            AssertionError: If a variable required by the prompt template
                is neither passed in ``kwargs`` nor present in the prompt's
                ``function_mappings``.
        """
        # Check if all variables in the prompt are provided.
        input_variables = set(kwargs) | set(self.prompt.function_mappings)
        for var in get_template_vars(self.prompt.template):
            if var not in input_variables:
                # Explicit raise: must survive `python -O` (see __init__).
                raise AssertionError(
                    f"Variable '{var}' not found in input variables."
                )

        # Make prediction, remember `row` is a default argument.
        outputs = []
        for index, row in tqdm(df.iterrows(), total=len(df)):
            output = ""  # stays empty when every attempt fails
            for attempt in range(1, self._MAX_ATTEMPTS + 1):
                try:
                    output = self._llm.predict(self.prompt, row=row, **kwargs)
                    break
                except Exception as exc:
                    # Best effort by design: one bad row must not abort
                    # the whole batch, so the failure is only reported.
                    if attempt == self._MAX_ATTEMPTS:
                        tqdm.write(
                            f"Prediction failed for row {index} after "
                            f"{attempt} attempts: {exc}"
                        )
                    else:
                        # retry backoff
                        time.sleep(self._RETRY_BACKOFF_SECONDS * attempt)
            outputs.append(output)
            time.sleep(self._ROW_PAUSE_SECONDS)

        return outputs

    def __call__(
        self,
        df: pd.DataFrame,
        **kwargs,
    ) -> Any:
        r"""Shorthand for :meth:`invoke`."""
        return self.invoke(df, **kwargs)