Can't deploy locally on a Mac with M2 Apple Silicon chip #64
Can somebody provide an install guide for Macs with an Apple Silicon chip?
They are talking about the model.
vLLM only works on Linux. You'll need to set up an endpoint with Hugging Face Transformers instead. This code worked for me:

```python
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ValidationError
from typing import List, Dict, Union, Optional
import base64
from PIL import Image
import io
import logging
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
import time

# Configure logging
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Check if MPS is available
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Load model and processor
model_path = "/Users/admin/Desktop/AI/GUI Agents/UI-TARS/ui-tars-2b-sft"
processor = AutoProcessor.from_pretrained(model_path)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map=device
)

# Pydantic models for request validation
class ImageUrl(BaseModel):
    url: str

class ContentItem(BaseModel):
    type: str
    text: Optional[str] = None
    image_url: Optional[ImageUrl] = None

class Message(BaseModel):
    role: str
    content: Union[str, List[ContentItem]]  # Can be either a string or a list of content items

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    max_tokens: Optional[int] = 1000
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.7
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    seed: Optional[int] = None

class Choice(BaseModel):
    index: int
    message: Message
    finish_reason: str

class ChatCompletionResponse(BaseModel):
    id: str
    object: str
    created: int
    model: str
    choices: List[Choice]

def decode_base64_image(base64_string: str) -> Image.Image:
    try:
        # Remove the data URL prefix if present
        if "base64," in base64_string:
            base64_string = base64_string.split("base64,")[1]
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data))
        logging.debug(f"Decoded image size: {image.size}")
        return image
    except Exception as e:
        logging.error(f"Error decoding image: {str(e)}")
        raise

@app.middleware("http")
async def log_requests(request: Request, call_next):
    body = await request.body()
    logging.info(f"Raw request body: {body.decode()}")
    response = await call_next(request)
    return response

@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    try:
        logging.info("Received chat completion request")
        logging.info(f"Full request data: {request.dict()}")

        # Process the last message
        last_message = request.messages[-1]
        messages = []

        # Convert the message to Qwen2-VL format
        content_list = []
        if isinstance(last_message.content, str):
            content_list.append({
                "type": "text",
                "text": last_message.content
            })
            logging.debug(f"Added text: {last_message.content[:100]}...")
        else:
            for content_item in last_message.content:
                if content_item.type == "text":
                    content_list.append({
                        "type": "text",
                        "text": content_item.text
                    })
                    logging.debug(f"Added text: {content_item.text[:100]}...")
                elif content_item.type == "image_url":
                    image_url = content_item.image_url.url
                    if image_url.startswith("data:image"):
                        logging.debug("Decoding base64 image")
                        image = decode_base64_image(image_url)
                        content_list.append({
                            "type": "image",
                            "image": image
                        })

        messages.append({
            "role": "user",
            "content": content_list
        })

        # Prepare text input using the chat template
        logging.info("Preparing model inputs")
        text = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Collect image inputs
        image_inputs = []
        for message in messages:
            for content in message['content']:
                if content['type'] == 'image':
                    image_inputs.append(content['image'])

        # Prepare model inputs
        inputs = processor(
            text=[text],
            images=image_inputs if image_inputs else None,  # pass None rather than an empty list for text-only requests
            padding=True,
            return_tensors="pt"
        )

        # Log input tensor shapes and memory usage
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                logging.debug(f"Input tensor {key}: shape={value.shape}, dtype={value.dtype}")
                logging.debug(f"Memory used by {key}: {value.element_size() * value.nelement() / 1024 / 1024:.2f} MB")

        # Move inputs to device
        inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

        logging.info("Generating response")
        # Generate response
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=request.max_tokens or 512,
            do_sample=True,
            temperature=request.temperature or 0.7,
            pad_token_id=processor.tokenizer.pad_token_id
        )

        # Trim the prompt tokens and decode the response
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)
        ]
        response_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]  # Get the first response since we only have one

        # Create the response object
        response = ChatCompletionResponse(
            id=str(time.time()),
            object="chat.completion",
            created=int(time.time()),
            model=request.model,
            choices=[
                Choice(
                    index=0,
                    message=Message(
                        role="assistant",
                        content=[ContentItem(type="text", text=response_text)]
                    ),
                    finish_reason="stop"
                )
            ]
        )
        logging.info("Returning response")
        return response
    except ValidationError as e:
        logging.error(f"Validation error: {e.json()}")
        raise HTTPException(status_code=422, detail=str(e))
    except Exception as e:
        logging.error(f"Error processing request: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)
```
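For reference, here's a minimal client-side sketch for exercising the endpoint above. It assumes the server is running locally on port 8001 and sends a text-only request in the OpenAI-style schema the server defines; the model name is just echoed back in the response, so any string works.

```python
import requests

# Hypothetical smoke test against the local /v1/chat/completions endpoint above.
payload = {
    "model": "ui-tars-2b-sft",  # echoed back in the response; not used for routing
    "messages": [
        {"role": "user", "content": "Describe what a click action looks like."}
    ],
    "max_tokens": 128,
    "temperature": 0.7,
}

resp = requests.post("http://localhost:8001/v1/chat/completions", json=payload, timeout=120)
resp.raise_for_status()

# The server returns "content" as a list of content items; grab the text part.
data = resp.json()
print(data["choices"][0]["message"]["content"][0]["text"])
```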
Can't deploy locally on a Mac with M2 Apple Silicon chip.
RuntimeError: Failed to infer device type
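For what it's worth, `Failed to infer device type` is the error vLLM typically raises when it cannot detect a supported accelerator (CUDA, ROCm, and similar back-ends), which is why it comes up on Apple Silicon. A quick sanity check of what PyTorch actually sees on an M-series Mac (assuming a recent PyTorch build with MPS support):

```python
import torch

# On an M2 Mac there is no CUDA device, which is what vLLM looks for;
# the only GPU back-end PyTorch exposes is MPS (Metal Performance Shaders).
print("CUDA available:", torch.cuda.is_available())          # expected: False
print("MPS available:", torch.backends.mps.is_available())   # expected: True
```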