flexrag.models.openai_model 源代码

import asyncio
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import httpx
import numpy as np
from openai import AzureOpenAI, OpenAI

from flexrag.prompt import ChatPrompt
from flexrag.utils import LOGGER_MANAGER, TIME_METER, configure

from .model_base import (
    ENCODERS,
    GENERATORS,
    EncoderBase,
    EncoderBaseConfig,
    GenerationConfig,
    GeneratorBase,
)

logger = LOGGER_MANAGER.get_logger("flexrag.models.openai")


[文档] @configure class OpenAIConfig: """The Base Configuration for OpenAI Client. :param is_azure: Whether the model is hosted on Azure. Default is False. :type is_azure: bool :param model_name: The name of the model to use. :type model_name: str :param base_url: The base URL of the OpenAI API. Default is None. :type base_url: Optional[str] :param api_key: The API key for OpenAI. Default is os.environ.get("OPENAI_API_KEY", "EMPTY"). :type api_key: str :param api_version: The API version to use. Default is "2024-07-01-preview". :type api_version: str :param verbose: Whether to show verbose logs. Default is False. :type verbose: bool :param proxy: The proxy to use for the HTTP client. Default is None. :type proxy: Optional[str] """ is_azure: bool = False model_name: Optional[str] = None base_url: Optional[str] = None api_key: str = os.environ.get("OPENAI_API_KEY", "EMPTY") api_version: str = "2024-07-01-preview" verbose: bool = False proxy: Optional[str] = None
[文档] @configure class OpenAIGeneratorConfig(OpenAIConfig): """Configuration for OpenAI Generator. :param allow_parallel: Whether to allow parallel generation. Default is True. :type allow_parallel: bool """ allow_parallel: bool = True
[文档] @GENERATORS("openai", config_class=OpenAIGeneratorConfig) class OpenAIGenerator(GeneratorBase): def __init__(self, cfg: OpenAIGeneratorConfig) -> None: # prepare proxy if cfg.proxy is not None: httpx_client = httpx.Client(proxies=cfg.proxy) else: httpx_client = None # prepare client if cfg.is_azure: self.client = AzureOpenAI( api_key=cfg.api_key, api_version=cfg.api_version, azure_endpoint=cfg.base_url, http_client=httpx_client, ) else: self.client = OpenAI( api_key=cfg.api_key, base_url=cfg.base_url, http_client=httpx_client, ) # set logger self.allow_parallel = cfg.allow_parallel assert cfg.model_name is not None, "`model_name` must be provided" self.model_name = cfg.model_name if not cfg.verbose: logger = logging.getLogger("httpx") logger.setLevel(logging.WARNING) # check client self._check() return @TIME_METER("openai_generate") def _chat( self, prompts: list[ChatPrompt], generation_config: GenerationConfig = GenerationConfig(), ) -> list[list[str]]: gen_cfg = self._get_options(generation_config) if self.allow_parallel: with ThreadPoolExecutor() as pool: responses = pool.map( lambda prompt: [ r.message.content for r in self.client.chat.completions.create( model=self.model_name, messages=prompt.to_list(), **gen_cfg, ).choices ], prompts, ) responses = list(responses) else: responses = [] for prompt in prompts: prompt = prompt.to_list() response = self.client.chat.completions.create( model=self.model_name, messages=prompt, **gen_cfg, ) responses.append([i.message.content for i in response.choices]) return responses @TIME_METER("openai_generate") async def async_chat( self, prompts: list[ChatPrompt], generation_config: GenerationConfig = GenerationConfig(), ) -> list[list[str]]: if not isinstance(prompts, list): prompts = [prompts] tasks = [] gen_cfg = self._get_options(generation_config) for prompt in prompts: prompt = prompt.to_list() tasks.append( asyncio.create_task( asyncio.to_thread( self.client.chat.completions.create, model=self.model_name, messages=prompt, **gen_cfg, ) ) ) responses = [[i.message.content for i in (await r).choices] for r in tasks] return responses @TIME_METER("openai_generate") def _generate( self, prefixes: list[str], generation_config: GenerationConfig = GenerationConfig(), ) -> list[list[str]]: if not isinstance(prefixes, list): prefixes = [prefixes] gen_cfg = self._get_options(generation_config) if self.allow_parallel: with ThreadPoolExecutor() as pool: responses = pool.map( lambda prefix: [ r.text for r in self.client.completions.create( model=self.model_name, prompt=prefix, **gen_cfg, ).choices ], prefixes, ) responses = list(responses) else: responses = [] for prefix in prefixes: response = self.client.completions.create( model=self.model_name, prompt=prefix, **gen_cfg, ) responses.append([i.text for i in response.choices]) return responses @TIME_METER("openai_generate") async def async_generate( self, prefixes: list[str], generation_config: GenerationConfig = GenerationConfig(), ) -> list[list[str]]: if not isinstance(prefixes, list): prefixes = [prefixes] tasks = [] gen_cfg = self._get_options(generation_config) for prefix in prefixes: tasks.append( asyncio.create_task( asyncio.to_thread( self.client.completions.create, model=self.model_name, prompt=prefix, **gen_cfg, ) ) ) responses = [[i.text for i in (await r).choices] for r in tasks] return responses def _get_options(self, generation_config: GenerationConfig) -> dict: if "llama-3" in self.model_name.lower(): extra_body = {"stop_token_ids": [128009]} # hotfix for llama-3 else: extra_body = None return { "temperature": ( generation_config.temperature if generation_config.do_sample else 0.0 ), "max_tokens": generation_config.max_new_tokens, "top_p": generation_config.top_p, "n": generation_config.sample_num, "extra_body": extra_body, "stop": list(generation_config.stop_str), } def _check(self): model_lists = [i.id for i in self.client.models.list().data] assert self.model_name in model_lists, f"Model {self.model_name} not found"
[文档] @configure class OpenAIEncoderConfig(OpenAIConfig, EncoderBaseConfig): """Configuration for OpenAI Encoder. :param embedding_size: The size of the embedding vector. If None, it will be determined from the model. Default is None. :type embedding_size: Optional[int] """ embedding_size: Optional[int] = None
[文档] @ENCODERS("openai", config_class=OpenAIEncoderConfig) class OpenAIEncoder(EncoderBase): def __init__(self, cfg: OpenAIEncoderConfig) -> None: super().__init__(cfg) # prepare proxy if cfg.proxy is not None: httpx_client = httpx.Client(proxies=cfg.proxy) else: httpx_client = None # prepare client if cfg.is_azure: self.client = AzureOpenAI( api_key=cfg.api_key, api_version=cfg.api_version, azure_endpoint=cfg.base_url, http_client=httpx_client, ) else: self.client = OpenAI( api_key=cfg.api_key, base_url=cfg.base_url, http_client=httpx_client, ) # set logger assert cfg.model_name is not None, "`model_name` must be provided" self.model_name = cfg.model_name self.dimension = cfg.embedding_size if not cfg.verbose: logger = logging.getLogger("httpx") logger.setLevel(logging.WARNING) # check client self._check() return @TIME_METER("openai_encode") def _encode(self, texts: list[str]) -> np.ndarray: if self.dimension is None: r = self.client.embeddings.create(model=self.model_name, input=texts) else: r = self.client.embeddings.create( model=self.model_name, input=texts, dimensions=self.dimension ) embeddings = [i.embedding for i in r.data] return np.array(embeddings) @TIME_METER("openai_encode") async def async_encode(self, texts: list[str]) -> np.ndarray: r = await asyncio.to_thread( self.client.embeddings.create, model=self.model_name, input=texts, dimensions=self.dimension, ) embeddings = [i.embedding for i in r.data] return np.array(embeddings) @property def embedding_size(self): if self.dimension is None: return len( self.client.embeddings.create(model=self.model_name, input="test") .data[0] .embedding ) return self.dimension def _check(self): model_lists = [i.id for i in self.client.models.list().data] assert self.model_name in model_lists, f"Model {self.model_name} not found"