From 0dae8ed9df50a0849bbe221c2b1032b2a756e15b Mon Sep 17 00:00:00 2001 From: nauyisu022 <59754221+nauyisu022@users.noreply.github.com> Date: Fri, 6 Sep 2024 19:42:43 +0800 Subject: [PATCH] update --- docs/guides/generation_details.md | 2 - trustllm_pkg/setup.py | 2 - trustllm_pkg/trustllm/config.py | 5 +-- .../trustllm/utils/generation_utils.py | 42 +------------------ 4 files changed, 2 insertions(+), 49 deletions(-) diff --git a/docs/guides/generation_details.md b/docs/guides/generation_details.md index 2f15a7c..074cee6 100644 --- a/docs/guides/generation_details.md +++ b/docs/guides/generation_details.md @@ -70,8 +70,6 @@ config.claude_api = "claude api" config.openai_key = "openai api" -config.palm_api = "palm api" - config.ernie_client_id = "ernie client id" config.ernie_client_secret = "ernie client secret" diff --git a/trustllm_pkg/setup.py b/trustllm_pkg/setup.py index 723d819..7812d64 100644 --- a/trustllm_pkg/setup.py +++ b/trustllm_pkg/setup.py @@ -25,9 +25,7 @@ 'python-dotenv', 'urllib3', 'anthropic', - 'google.generativeai', 'google-api-python-client', - 'google.ai.generativelanguage', 'replicate', 'zhipuai>=2.0.1' ], diff --git a/trustllm_pkg/trustllm/config.py b/trustllm_pkg/trustllm/config.py index 27c1e0f..cc03a7b 100644 --- a/trustllm_pkg/trustllm/config.py +++ b/trustllm_pkg/trustllm/config.py @@ -9,7 +9,6 @@ deepinfra_api = None ernie_api = None claude_api = None -palm_api = None replicate_api = None zhipu_api = None @@ -38,11 +37,10 @@ zhipu_model = ["glm-4", "glm-3-turbo"] claude_model = ["claude-2", "claude-instant-1"] openai_model = ["chatgpt", "gpt-4"] -google_model = ["bison-001", "gemini"] wenxin_model = ["ernie"] replicate_model=["vicuna-7b","vicuna-13b","vicuna-33b","chatglm3-6b","llama3-70b","llama3-8b"] -online_model = deepinfra_model + zhipu_model + claude_model + openai_model + google_model + wenxin_model+replicate_model +online_model = deepinfra_model + zhipu_model + claude_model + openai_model + wenxin_model+replicate_model model_info = { "online_model": online_model, @@ -50,7 +48,6 @@ "deepinfra_model": deepinfra_model, 'claude_model': claude_model, 'openai_model': openai_model, - 'google_model': google_model, 'wenxin_model': wenxin_model, 'replicate_model':replicate_model, "model_mapping": { diff --git a/trustllm_pkg/trustllm/utils/generation_utils.py b/trustllm_pkg/trustllm/utils/generation_utils.py index 7697b43..546300e 100644 --- a/trustllm_pkg/trustllm/utils/generation_utils.py +++ b/trustllm_pkg/trustllm/utils/generation_utils.py @@ -1,7 +1,6 @@ import os, json from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT -import google.generativeai as genai -from google.generativeai.types import safety_types + from fastchat.model import load_model, get_conversation_template from openai import OpenAI,AzureOpenAI from tenacity import retry, wait_random_exponential, stop_after_attempt @@ -17,16 +16,6 @@ model_mapping = model_info['model_mapping'] rev_model_mapping = {value: key for key, value in model_mapping.items()} -# Define safety settings to allow harmful content generation -safety_setting = [ - {"category": safety_types.HarmCategory.HARM_CATEGORY_DEROGATORY, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE}, - {"category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE}, - {"category": safety_types.HarmCategory.HARM_CATEGORY_SEXUAL, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE}, - {"category": safety_types.HarmCategory.HARM_CATEGORY_TOXICITY, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE}, - {"category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE}, - {"category": safety_types.HarmCategory.HARM_CATEGORY_DANGEROUS, "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE}, -] - # Retrieve model information def get_models(): return model_mapping, online_model_list @@ -98,31 +87,7 @@ def claude_api(string, model, temperature): return completion.completion -@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6)) -def gemini_api(string, temperature): - genai.configure(api_key=trustllm.config.gemini_api) - model = genai.GenerativeModel('gemini-pro') - response = model.generate_content(string, temperature=temperature, safety_settings=safety_setting) - return response - - -@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6)) -def palm_api(string, model, temperature): - genai.configure(api_key=trustllm.config.palm_api) - - model_mapping = { - 'bison-001': 'models/text-bison-001', - } - completion = genai.generate_text( - model=model_mapping[model], # models/text-bison-001 - prompt=string, - temperature=temperature, - # The maximum length of the response - max_output_tokens=4000, - safety_settings=safety_setting - ) - return completion.result @retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6)) @@ -148,11 +113,6 @@ def zhipu_api(string, model, temperature): def gen_online(model_name, prompt, temperature, replicate=False, deepinfra=False): if model_name in model_info['wenxin_model']: res = get_ernie_res(prompt, temperature=temperature) - elif model_name in model_info['google_model']: - if model_name == 'bison-001': - res = palm_api(prompt, model=model_name, temperature=temperature) - elif model_name == 'gemini-pro': - res = gemini_api(prompt, temperature=temperature) elif model_name in model_info['openai_model']: res = get_res_openai(prompt, model=model_name, temperature=temperature) elif model_name in model_info['deepinfra_model']: