# blob: b9f8a113d28b64e0dd5537cc58c14ab256ebadca (repository-viewer header; not part of the module)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import asyncio
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
from hugegraph_llm.models.llms.openai import OpenAIClient
class TestOpenAIClient(unittest.TestCase):
    """Unit tests for ``OpenAIClient`` driven entirely by mocked OpenAI SDK clients.

    ``setUp`` patches both the ``OpenAI`` and ``AsyncOpenAI`` names inside
    ``hugegraph_llm.models.llms.openai`` for every test, so building an
    ``OpenAIClient`` never instantiates a real SDK client, whichever half
    (sync or async) a given test exercises.
    """

    MODEL_NAME = "gpt-3.5-turbo"
    PROMPT = "What is the capital of France?"

    def setUp(self):
        """Install the SDK patches and build the canned responses shared by the tests."""
        # Stand-ins for the objects returned by OpenAI(...) and AsyncOpenAI(...).
        self.mock_client = MagicMock()
        self.mock_async_client = MagicMock()
        self._install_patch("hugegraph_llm.models.llms.openai.OpenAI", self.mock_client)
        self._install_patch("hugegraph_llm.models.llms.openai.AsyncOpenAI", self.mock_async_client)

        # Canned non-streaming completion whose first choice carries "Paris".
        self.mock_completion_response = MagicMock()
        self.mock_completion_response.choices = [MagicMock(message=MagicMock(content="Paris"))]
        self.mock_completion_response.usage = MagicMock()
        self.mock_completion_response.usage.model_dump_json.return_value = (
            '{"prompt_tokens": 10, "completion_tokens": 5}'
        )

        # Canned streaming chunks; the trailing empty-content chunk is expected
        # to be dropped by the streaming methods (see the streaming assertions).
        self.mock_streaming_chunks = [
            MagicMock(choices=[MagicMock(delta=MagicMock(content="Pa"))]),
            MagicMock(choices=[MagicMock(delta=MagicMock(content="ris"))]),
            MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]),
        ]

    def _install_patch(self, target, instance):
        """Patch ``target`` for the current test so that calling it returns ``instance``.

        The patch is stopped through ``addCleanup``, so it is undone even when
        the test fails or ``setUp`` raises later on.
        """
        patcher = patch(target)
        mocked_class = patcher.start()
        self.addCleanup(patcher.stop)
        mocked_class.return_value = instance
        return mocked_class

    def _expected_request(self, messages, **extra):
        """Return the keyword arguments ``chat.completions.create`` should receive."""
        return {
            "model": self.MODEL_NAME,
            "temperature": 0.01,
            "max_tokens": 8092,
            "messages": messages,
            **extra,
        }

    def _user_messages(self):
        """Return the single-user-message list expected for ``PROMPT``."""
        return [{"role": "user", "content": self.PROMPT}]

    def test_generate(self):
        """Test generate method with mocked OpenAI client."""
        self.mock_client.chat.completions.create.return_value = self.mock_completion_response

        openai_client = OpenAIClient(model_name=self.MODEL_NAME)
        response = openai_client.generate(prompt=self.PROMPT)

        # Verify the response
        self.assertIsInstance(response, str)
        self.assertEqual(response, "Paris")
        # Verify the API was called with correct parameters
        self.mock_client.chat.completions.create.assert_called_once_with(
            **self._expected_request(self._user_messages())
        )

    def test_generate_with_messages(self):
        """Test generate method with messages parameter."""
        self.mock_client.chat.completions.create.return_value = self.mock_completion_response

        openai_client = OpenAIClient(model_name=self.MODEL_NAME)
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France?"},
        ]
        response = openai_client.generate(messages=messages)

        # Verify the response
        self.assertIsInstance(response, str)
        self.assertEqual(response, "Paris")
        # The caller-supplied messages must be forwarded unchanged
        self.mock_client.chat.completions.create.assert_called_once_with(
            **self._expected_request(messages)
        )

    def test_agenerate(self):
        """Test agenerate method with mocked async OpenAI client."""
        # The async SDK call is awaited, so it has to be an AsyncMock
        self.mock_async_client.chat.completions.create = AsyncMock(
            return_value=self.mock_completion_response
        )
        openai_client = OpenAIClient(model_name=self.MODEL_NAME)

        async def run_async_test():
            response = await openai_client.agenerate(prompt=self.PROMPT)
            self.assertIsInstance(response, str)
            self.assertEqual(response, "Paris")
            # Verify the API was called with correct parameters
            self.mock_async_client.chat.completions.create.assert_called_once_with(
                **self._expected_request(self._user_messages())
            )

        asyncio.run(run_async_test())

    def test_stream_generate(self):
        """Test generate_streaming method with mocked OpenAI client."""
        # The streaming call returns an iterator of chunks
        self.mock_client.chat.completions.create.return_value = iter(self.mock_streaming_chunks)
        openai_client = OpenAIClient(model_name=self.MODEL_NAME)

        collected_tokens = []

        def on_token_callback(chunk):
            collected_tokens.append(chunk)

        # Collect all tokens from the generator
        tokens = list(
            openai_client.generate_streaming(
                prompt=self.PROMPT, on_token_callback=on_token_callback
            )
        )

        # Both the yielded tokens and the callback see only non-empty content
        self.assertEqual(tokens, ["Pa", "ris"])
        self.assertEqual(collected_tokens, ["Pa", "ris"])
        # Verify the API was called with correct parameters
        self.mock_client.chat.completions.create.assert_called_once_with(
            **self._expected_request(self._user_messages(), stream=True)
        )

    def test_agenerate_streaming(self):
        """Test agenerate_streaming method with mocked async OpenAI client."""

        # Async generator standing in for the SDK's streaming response
        async def async_streaming_chunks():
            for chunk in self.mock_streaming_chunks:
                yield chunk

        self.mock_async_client.chat.completions.create = AsyncMock(
            return_value=async_streaming_chunks()
        )
        openai_client = OpenAIClient(model_name=self.MODEL_NAME)

        async def run_async_streaming_test():
            collected_tokens = []

            def on_token_callback(chunk):
                collected_tokens.append(chunk)

            # Collect all tokens from the async generator
            tokens = []
            async for token in openai_client.agenerate_streaming(
                prompt=self.PROMPT, on_token_callback=on_token_callback
            ):
                tokens.append(token)

            # Both the yielded tokens and the callback see only non-empty content
            self.assertEqual(tokens, ["Pa", "ris"])
            self.assertEqual(collected_tokens, ["Pa", "ris"])
            # Verify the API was called with correct parameters
            self.mock_async_client.chat.completions.create.assert_called_once_with(
                **self._expected_request(self._user_messages(), stream=True)
            )

        asyncio.run(run_async_streaming_test())

    def test_generate_authentication_error(self):
        """Test generate method with authentication error."""
        # Set up the mock client to raise OpenAI's authentication error
        from openai import AuthenticationError

        # Minimal response object so the SDK exception can be constructed
        mock_response = MagicMock()
        mock_response.status_code = 401
        mock_response.headers = {}
        auth_error = AuthenticationError(
            message="Invalid API key",
            response=mock_response,
            body={"error": {"message": "Invalid API key"}},
        )
        self.mock_client.chat.completions.create.side_effect = auth_error

        openai_client = OpenAIClient(model_name=self.MODEL_NAME)
        # The call should return the authentication-failure message instead of raising
        result = openai_client.generate(prompt=self.PROMPT)
        self.assertEqual(result, "Error: The provided OpenAI API key is invalid")

    @patch("hugegraph_llm.models.llms.openai.tiktoken.encoding_for_model")
    def test_num_tokens_from_string(self, mock_encoding_for_model):
        """Test num_tokens_from_string method with mocked tiktoken."""
        # Mock encoding that always yields five token ids
        mock_encoding = MagicMock()
        mock_encoding.encode.return_value = [1, 2, 3, 4, 5]
        mock_encoding_for_model.return_value = mock_encoding

        openai_client = OpenAIClient(model_name=self.MODEL_NAME)
        token_count = openai_client.num_tokens_from_string("Hello, world!")

        # Verify the response
        self.assertIsInstance(token_count, int)
        self.assertEqual(token_count, 5)
        # Verify the encoding was looked up by model name and applied to the text
        mock_encoding_for_model.assert_called_once_with("gpt-3.5-turbo")
        mock_encoding.encode.assert_called_once_with("Hello, world!")

    def test_max_allowed_token_length(self):
        """Test max_allowed_token_length method."""
        openai_client = OpenAIClient(model_name=self.MODEL_NAME)
        max_tokens = openai_client.max_allowed_token_length()
        self.assertIsInstance(max_tokens, int)
        self.assertEqual(max_tokens, 8192)

    def test_get_llm_type(self):
        """Test get_llm_type method."""
        openai_client = OpenAIClient()
        llm_type = openai_client.get_llm_type()
        self.assertEqual(llm_type, "openai")
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()