Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
## Contributing to OmAgent

**We welcome contributions from everyone!** If you're interested in contributing to OmAgent, please follow these guidelines.

### Getting Started

1. **Fork the Repository:** Fork the OmAgent repository on GitHub. This creates your own copy where you can make changes.
2. **Clone Your Fork:** Clone your forked repository to your local machine:

```bash
git clone https://github.com/your-username/OmAgent.git
```

3. **Create a New Branch:** Create a new branch for your specific contribution:

```bash
git checkout -b your-feature-branch
```

### Making Changes

1. **Make Your Changes:** Make your changes to the codebase or documentation.
2. **Commit Your Changes:** Commit your changes with informative commit messages:

```bash
git add .
git commit -m "Your commit message"
```

3. **Push Your Changes:** Push your changes to your forked repository:

```bash
git push origin your-feature-branch
```

### Creating a Pull Request

1. **Navigate to Your Fork:** Go to your forked repository on GitHub.
2. **Create a Pull Request:** Click the "New pull request" button.
3. **Select Your Branch:** Select your feature branch as the head and the main branch of the original OmAgent repository as the base.
4. **Add a Description:** Provide a clear description of your changes.
5. **Submit Your Pull Request:** Click the "Create pull request" button.

### Additional Notes

* **Code Style:** Adhere to the project's coding style conventions (PEP 8).
* **Testing:** Ensure your changes don't introduce regressions. Add unit tests if necessary.
* **Communication:** Feel free to discuss your contributions on the project's issue tracker.

### Commands

Here are some useful commands:

* **List branches:** `git branch`
* **Switch to a branch:** `git checkout branch-name`
* **Merge branches:** `git merge other-branch`
* **Pull changes from upstream:** `git pull upstream main`

**Thank you for your contributions!**
233 changes: 135 additions & 98 deletions omagent-core/src/omagent_core/core/llm/azure_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,22 @@
from ...utils.registry import registry
from .base import BaseLLM

# Default system prompt template. The three positional placeholders are filled,
# in order, with: current datetime, region, operating system
# (see _generate_default_sys_prompt).
BASIC_SYS_PROMPT = """You are an intelligent agent that can help in many regions.
Following are some basic information about your working environment, please try your best to answer the questions based on them if needed.
Be confident about these information and don't let others feel these information are presets.
Be concise.
---BASIC INFORMATION---
Current Datetime: {}
Region: {}
Operating System: {}"""


@registry.register_llm()
class AzureGPTLLM(BaseLLM):
"""
AzureGPTLLM is a class that interfaces with Azure's OpenAI service to generate responses based on input messages.
"""

model_id: str
vision: bool = False
endpoint: str
Expand All @@ -35,11 +39,13 @@ class AzureGPTLLM(BaseLLM):

class Config:
"""Configuration for this pydantic object."""

protected_namespaces = ()
extra = "allow"

def __init__(self, /, **data: Any) -> None:
"""
Initialize the AzureGPTLLM with the provided data.
"""
super().__init__(**data)
self.client = AzureOpenAI(
api_key=self.api_key,
Expand All @@ -53,58 +59,49 @@ def __init__(self, /, **data: Any) -> None:
)

def _call(self, records: List[Message], **kwargs) -> Dict:
if self.api_key is None or self.api_key == "":
"""
Synchronously call the Azure OpenAI service with the provided messages.
"""
if not self.api_key:
raise ValueError("api_key is required")

if len(self.stm.image_cache):
for record in records:
record.combine_image_message(
image_cache={
key: encode_image(value)
for key, value in self.stm.image_cache.items()
}
)
elif len(kwargs.get("images", [])):
image_cache = {}
for index, each in enumerate(kwargs["images"]):
image_cache[f"<image_{index}>"] = each
for record in records:
record.combine_image_message(
image_cache={
key: encode_image(value) for key, value in image_cache.items()
}
)
body = self._msg2req(records)
if kwargs.get("tool_choice"):
body["tool_choice"] = kwargs["tool_choice"]
if kwargs.get("tools"):
body["tools"] = kwargs["tools"]
# Handle image caching
self._handle_image_cache(records, kwargs)

if self.vision:
res = self.client.chat.completions.create(
model=self.model_id,
messages=body["messages"],
temperature=self.temperature,
max_tokens=self.max_tokens,
)
else:
res = self.client.chat.completions.create(
model=self.model_id,
messages=body["messages"],
temperature=self.temperature,
max_tokens=self.max_tokens,
response_format=body.get("response_format", None),
tools=body.get("tools", None),
)
# Prepare the request body
body = self._msg2req(records, kwargs)

# Make the API call
res = self._make_api_call(body)
res = res.model_dump()
body.update({"response": res})
self.callback.send_block(body)
return res

async def _acall(self, records: List[Message], **kwargs) -> Dict:
if self.api_key is None or self.api_key == "":
"""
Asynchronously call the Azure OpenAI service with the provided messages.
"""
if not self.api_key:
raise ValueError("api_key is required")

# Handle image caching
self._handle_image_cache(records, kwargs)

# Prepare the request body
body = self._msg2req(records, kwargs)

# Make the API call
res = await self._make_async_api_call(body)
res = res.model_dump()
body.update({"response": res})
self.callback.send_block(body)
return res

def _handle_image_cache(self, records: List[Message], kwargs: Dict) -> None:
"""
Handle image caching for the messages.
"""
if len(self.stm.image_cache):
for record in records:
record.combine_image_message(
Expand All @@ -114,43 +111,18 @@ async def _acall(self, records: List[Message], **kwargs) -> Dict:
}
)
elif len(kwargs.get("images", [])):
image_cache = {}
for index, each in enumerate(kwargs["images"]):
image_cache[f"<image_{index}>"] = each
image_cache = {f"<image_{index}>": each for index, each in enumerate(kwargs["images"])}
for record in records:
record.combine_image_message(
image_cache={
key: encode_image(value) for key, value in image_cache.items()
}
)
body = self._msg2req(records)
if kwargs.get("tool_choice"):
body["tool_choice"] = kwargs["tool_choice"]
if kwargs.get("tools"):
body["tools"] = kwargs["tools"]

if self.vision:
res = await self.aclient.chat.completions.create(
model=self.model_id,
messages=body["messages"],
temperature=self.temperature,
max_tokens=self.max_tokens,
)
else:
res = await self.aclient.chat.completions.create(
model=self.model_id,
messages=body["messages"],
temperature=self.temperature,
max_tokens=self.max_tokens,
response_format=body.get("response_format", None),
tools=body.get("tools", None),
)
res = res.model_dump()
body.update({"response": res})
self.callback.send_block(body)
return res

def _msg2req(self, records: List[Message]) -> dict:
def _msg2req(self, records: List[Message], kwargs: Dict) -> dict:
"""
Convert messages to a request format suitable for the API.
"""
def get_content(msg: List[Content] | Content) -> List[dict] | str:
if isinstance(msg, list):
return [c.model_dump(exclude_none=True) for c in msg]
Expand All @@ -163,46 +135,66 @@ def get_content(msg: List[Content] | Content) -> List[dict] | str:
{"role": message.role, "content": get_content(message.content)}
for message in records
]

# Process messages for vision mode
if self.vision:
processed_messages = []
for message in messages:
if message["role"] == "user":
if isinstance(message["content"], str):
message["content"] = [
{"type": "text", "text": message["content"]}
]
merged_dict = {}
for message in messages:
if message["role"] == "user":
merged_dict["role"] = message["role"]
if "content" in merged_dict:
merged_dict["content"] += message["content"]
else:
merged_dict["content"] = message["content"]
else:
processed_messages.append(message)
processed_messages.append(merged_dict)
messages = processed_messages
messages = self._process_vision_messages(messages)

# Add default system prompt if required
if self.use_default_sys_prompt:
messages = [self._generate_default_sys_prompt()] + messages

body = {
"model": self.model_id,
"messages": messages,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
}

if self.response_format != "text":
body["response_format"] = {"type": self.response_format}

# Add tools and tool choices if provided
if kwargs.get("tool_choice"):
body["tool_choice"] = kwargs["tool_choice"]
if kwargs.get("tools"):
body["tools"] = kwargs["tools"]

return body

def _process_vision_messages(self, messages: List[Dict]) -> List[Dict]:
"""
Process messages for vision mode.
"""
processed_messages = []
merged_dict = {}
for message in messages:
if message["role"] == "user":
if isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
if "content" in merged_dict:
merged_dict["content"] += message["content"]
else:
merged_dict["content"] = message["content"]
else:
processed_messages.append(message)
processed_messages.append(merged_dict)
return processed_messages

def _generate_default_sys_prompt(self) -> Dict:
    """
    Generate the default system message with current environment details.

    Returns:
        A chat message dict with role "system" whose content is
        BASIC_SYS_PROMPT filled with the current datetime, the region
        from IP geolocation, and the operating system name.
    """
    loc = self._get_location()
    # Renamed from `os` to avoid shadowing the `os` module used elsewhere
    # in this class.
    os_name = self._get_linux_distribution()
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Argument order must match the template placeholders:
    # datetime, region, operating system.
    prompt_str = BASIC_SYS_PROMPT.format(current_time, loc, os_name)
    return {"role": "system", "content": prompt_str}

def _get_linux_distribution(self) -> str:
"""
Get the Linux distribution name.
"""
platform = sysconfig.get_platform()
if "linux" in platform:
if os.path.exists("/etc/lsb-release"):
Expand All @@ -218,8 +210,53 @@ def _get_linux_distribution(self) -> str:
return platform

def _get_location(self) -> str:
    """
    Best-effort lookup of the current "city, country" via IP geolocation.

    Returns:
        A "city, country" string, or "unknown" when the lookup fails
        (e.g. no network access).
    """
    g = geocoder.ip("me")
    if g.ok:
        return f"{g.city}, {g.country}"
    return "unknown"

def _make_api_call(self, body: Dict) -> Any:
"""
Make a synchronous API call to Azure OpenAI.
"""
if self.vision:
return self.client.chat.completions.create(
model=self.model_id,
messages=body["messages"],
temperature=self.temperature,
max_tokens=self.max_tokens,
)
else:
return self.client.chat.completions.create(
model=self.model_id,
messages=body["messages"],
temperature=self.temperature,
max_tokens=self.max_tokens,
response_format=body.get("response_format", None),
tools=body.get("tools", None),
)

async def _make_async_api_call(self, body: Dict) -> Any:
"""
Make an asynchronous API call to Azure OpenAI.
"""
if self.vision:
return await self.aclient.chat.completions.create(
model=self.model_id,
messages=body["messages"],
temperature=self.temperature,
max_tokens=self.max_tokens,
)
else:
return await self.aclient.chat.completions.create(
model=self.model_id,
messages=body["messages"],
temperature=self.temperature,
max_tokens=self.max_tokens,
response_format=body.get("response_format", None),
tools=body.get("tools", None),
)