# Source code for flexrag.models.anthropic_model
import asyncio
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import httpx
from flexrag.prompt import ChatPrompt
from flexrag.utils import LOGGER_MANAGER, TIME_METER, configure
from .model_base import GENERATORS, GenerationConfig, GeneratorBase
logger = LOGGER_MANAGER.get_logger("flexrag.models.anthropic")
# [docs]
@configure
class AnthropicGeneratorConfig:
    """Settings consumed by ``AnthropicGenerator``.

    :param model_name: Name of the Anthropic model to query. The field
        defaults to None, but the generator asserts that it has been set.
    :type model_name: Optional[str]
    :param base_url: Base URL of the API endpoint. Defaults to None.
    :type base_url: Optional[str]
    :param api_key: API key sent with every request. Defaults to the value of
        the ``ANTHROPIC_API_KEY`` environment variable as read when this class
        is defined, or ``"EMPTY"`` when the variable is unset.
    :type api_key: str
    :param verbose: When False, the ``httpx`` logger is limited to WARNING and
        above. Defaults to False.
    :type verbose: bool
    :param proxy: Proxy URL handed to the HTTP client. Defaults to None.
    :type proxy: Optional[str]
    :param allow_parallel: Whether prompts may be dispatched from a thread
        pool instead of one after another. Defaults to True.
    :type allow_parallel: bool
    """

    model_name: Optional[str] = None
    base_url: Optional[str] = None
    api_key: str = os.getenv("ANTHROPIC_API_KEY", "EMPTY")
    verbose: bool = False
    proxy: Optional[str] = None
    allow_parallel: bool = True
# [docs]
@GENERATORS("anthropic", config_class=AnthropicGeneratorConfig)
class AnthropicGenerator(GeneratorBase):
    """Chat generator that sends prompts to the Anthropic Messages API.

    The API has no "number of samples" option, so every prompt is submitted
    ``generation_config.sample_num`` times and the answers are collected into
    one inner list per prompt.
    """

    def __init__(self, cfg: AnthropicGeneratorConfig) -> None:
        """Create the API client described by ``cfg``.

        :param cfg: The generator configuration.
        :type cfg: AnthropicGeneratorConfig
        :raises AssertionError: If ``cfg.model_name`` is None.
        """
        from anthropic import Anthropic

        # Validate before building the HTTP client so that an invalid
        # configuration does not leave an open connection pool behind.
        assert cfg.model_name is not None, "model_name must be provided"

        self.client = Anthropic(
            api_key=cfg.api_key,
            base_url=cfg.base_url,
            http_client=self._build_http_client(cfg.proxy),
        )
        self.model_name = cfg.model_name
        self.allow_parallel = cfg.allow_parallel
        if not cfg.verbose:
            # Drop httpx records below WARNING unless verbose output is wanted.
            logging.getLogger("httpx").setLevel(logging.WARNING)

    @staticmethod
    def _build_http_client(proxy: Optional[str]) -> Optional[httpx.Client]:
        """Return an httpx client routed through ``proxy``, or None without one."""
        if proxy is None:
            return None
        try:
            return httpx.Client(proxy=proxy)
        except TypeError:
            # Older httpx releases only accept the legacy ``proxies`` keyword.
            return httpx.Client(proxies=proxy)

    @staticmethod
    def _extract_text(response) -> str:
        """Return the text of the first content block that carries text.

        An empty string is returned when the response holds no such block, so
        that an empty completion does not raise ``IndexError``.
        """
        for block in response.content:
            text = getattr(block, "text", None)
            if text is not None:
                return text
        return ""

    def _create_text(self, messages: list, gen_cfg: dict) -> str:
        """Send one blocking request and return the generated text."""
        response = self.client.messages.create(
            model=self.model_name,
            messages=messages,
            **gen_cfg,
        )
        return self._extract_text(response)

    @TIME_METER("anthropic_generate")
    def _chat(
        self,
        prompts: list[ChatPrompt],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        """Generate ``sample_num`` answers for every prompt.

        :param prompts: The chat prompts to answer.
        :type prompts: list[ChatPrompt]
        :param generation_config: Sampling options. The default instance is
            shared between calls and is only read here.
        :type generation_config: GenerationConfig
        :return: One list of answers per prompt, in prompt order.
        :rtype: list[list[str]]
        """
        gen_cfg = self._get_options(generation_config)
        sample_num = generation_config.sample_num

        def sample(prompt: ChatPrompt) -> list[str]:
            # as anthropic does not support sample_num, we sample multiple times
            messages = prompt.to_list()
            return [self._create_text(messages, gen_cfg) for _ in range(sample_num)]

        if self.allow_parallel:
            # Prompts run concurrently; the samples of one prompt stay sequential.
            with ThreadPoolExecutor() as pool:
                return list(pool.map(sample, prompts))
        return [sample(prompt) for prompt in prompts]

    @TIME_METER("anthropic_generate")
    async def async_chat(
        self,
        prompts: list[ChatPrompt],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        """Asynchronous counterpart of ``_chat``.

        Every request runs in a worker thread through ``asyncio.to_thread``
        because the client held by this class is the blocking one.

        :param prompts: The chat prompts to answer.
        :type prompts: list[ChatPrompt]
        :param generation_config: Sampling options. The default instance is
            shared between calls and is only read here.
        :type generation_config: GenerationConfig
        :return: One list of answers per prompt, in prompt order.
        :rtype: list[list[str]]
        """
        gen_cfg = self._get_options(generation_config)
        sample_num = generation_config.sample_num

        async def sample_once(messages: list) -> str:
            return await asyncio.to_thread(self._create_text, messages, gen_cfg)

        def sample(prompt: ChatPrompt):
            # as anthropic does not support sample_num, we sample multiple times
            messages = prompt.to_list()
            return asyncio.gather(*(sample_once(messages) for _ in range(sample_num)))

        # gather schedules every request at once and retrieves all outcomes,
        # so a failing request does not leave unobserved task exceptions.
        return await asyncio.gather(*(sample(prompt) for prompt in prompts))

    @TIME_METER("anthropic_generate")
    def _generate(
        self,
        prefixes: list[str],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        """Text completion is not supported by this generator.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError("The Anthropic text completion API is deprecated.")

    @TIME_METER("anthropic_generate")
    async def async_generate(
        self,
        prefixes: list[str],
        generation_config: GenerationConfig = GenerationConfig(),
    ) -> list[list[str]]:
        """Text completion is not supported by this generator.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError("The Anthropic text completion API is deprecated.")

    def _get_options(self, generation_config: GenerationConfig) -> dict:
        """Translate ``generation_config`` into request keyword arguments.

        The temperature is forced to 0.0 when ``do_sample`` is false.

        :param generation_config: Sampling options to translate.
        :type generation_config: GenerationConfig
        :return: Keyword arguments for ``client.messages.create``.
        :rtype: dict
        """
        if generation_config.do_sample:
            temperature = generation_config.temperature
        else:
            temperature = 0.0
        return {
            "temperature": temperature,
            "max_tokens": generation_config.max_new_tokens,
            "top_p": generation_config.top_p,
            "top_k": generation_config.top_k,
            "stop_sequences": generation_config.stop_str,
        }