Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Emit input and output tokens #34

Merged
merged 1 commit on Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions .github/workflows/lint-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,18 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install dependencies
run: pip install -r requirements-dev.txt
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Create and activate virtual environment
run: |
python -m venv venv
source venv/bin/activate
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-dev.txt
- name: Run pylint
run: |
pylint --recursive=y tests/**/*.py
Expand All @@ -31,7 +41,24 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Create and activate virtual environment
run: |
python -m venv venv
source venv/bin/activate
- name: Install dependencies
run: pip install -r requirements-dev.txt
run: |
python -m pip install --upgrade pip
pip install -r requirements-dev.txt
- name: Debug information
run: |
which python
python --version
pip list
python -c "import sys; print(sys.path)"
python -c "import attrs; print(attrs.__file__)"
- name: Run unit tests
run: pytest tests/unit
13 changes: 12 additions & 1 deletion predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import jinja2
import torch # pylint: disable=import-error
import cog # pylint: disable=import-error
from cog import BasePredictor, ConcatenateIterator, Input
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs # pylint: disable=import-error
Expand Down Expand Up @@ -281,7 +282,12 @@ async def predict( # pylint: disable=invalid-overridden-method, arguments-diffe
)

request_id = uuid4().hex
generator = self.engine.generate(prompt, sampling_params, request_id)

generator = self.engine.generate(
prompt,
sampling_params,
request_id,
)
start = 0

async for result in generator:
Expand All @@ -302,6 +308,11 @@ async def predict( # pylint: disable=invalid-overridden-method, arguments-diffe
self.log(f"Generation took {time.time() - start:.2f}s")
self.log(f"Formatted prompt: {prompt}")

if not self._testing:
# pylint: disable=no-member, undefined-loop-variable
cog.emit_metric("input_token_count", len(result.prompt_token_ids))
cog.emit_metric("output_token_count", len(result.outputs[0].token_ids))

def load_config(self, weights: str) -> PredictorConfig:
"""
Load the predictor configuration from the specified weights directory or
Expand Down
4 changes: 3 additions & 1 deletion requirements-dev.txt

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/end_to_end/local/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_predict():

try:

with open(config_filename, "w") as temp_config:
with open(config_filename, "w", encoding="utf-8") as temp_config:
json.dump(predictor_config, temp_config, indent=4)

weights_url = "https://weights.replicate.delivery/default/internal-testing/EleutherAI/pythia-70m/model.tar" # pylint: disable=line-too-long
Expand Down
84 changes: 47 additions & 37 deletions tests/unit/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,42 +135,52 @@ async def test_setup_with_invalid_predictor_config():

@pytest.mark.asyncio
async def test_predict(mock_dependencies):
class MockResult: # pylint: disable=too-few-public-methods
"""
Use this to mock the result object that the engine returns.
"""
def __init__(self, text):
self.outputs = [MagicMock(text=text)]

# Define an async generator function
async def mock_generate(*args, **kwargs): # pylint: disable=unused-argument
yield MockResult("Generated text")

mock_dependencies["engine"].generate = mock_generate

predictor = Predictor()
predictor.log = MagicMock()
with patch.object(Predictor, 'setup') as mock_setup:
def setup_side_effect(*args): # pylint: disable=unused-argument
predictor.engine = mock_dependencies["engine"]
predictor.prompt_template = None
predictor.config = PredictorConfig()
mock_setup.side_effect = setup_side_effect
await predictor.setup("dummy_weights")

# Mock the tokenizer
predictor.tokenizer = MagicMock()
predictor.tokenizer.chat_template = None
predictor.tokenizer.eos_token_id = 0

# Call the predict method
result = predictor.predict(
prompt="Test prompt", prompt_template=MockInput(default=None)
)

# Consume the async generator
texts = []
async for item in result:
texts.append(item)
with patch('predict.cog.emit_metric') as mock_emit_metric:
class MockOutput:
def __init__(self, text):
self.text = text
self.token_ids = [4, 5, 6] # Generated tokens

class MockResult:
def __init__(self, text):
self.outputs = [MockOutput(text)]
self.prompt_token_ids = [1, 2, 3] # Input tokens


# Define an async generator function
async def mock_generate(*args, **kwargs): # pylint: disable=unused-argument
yield MockResult("Generated text")

mock_dependencies["engine"].generate = mock_generate

predictor = Predictor()
predictor.log = MagicMock()
with patch.object(Predictor, 'setup') as mock_setup:
def setup_side_effect(*args): # pylint: disable=unused-argument
predictor.engine = mock_dependencies["engine"]
predictor.prompt_template = None
predictor.config = PredictorConfig()
predictor._testing = False # pylint: disable=protected-access
mock_setup.side_effect = setup_side_effect
await predictor.setup("dummy_weights")

assert texts == ["Generated text"]
# Mock the tokenizer
predictor.tokenizer = MagicMock()
predictor.tokenizer.chat_template = None
predictor.tokenizer.eos_token_id = 0

# Call the predict method
result = predictor.predict(
prompt="Test prompt", prompt_template=MockInput(default=None)
)

# Consume the async generator
texts = []
async for item in result:
texts.append(item)

assert texts == ["Generated text"]
# Assert that emit_metric was called with the expected arguments
mock_emit_metric.assert_any_call("input_token_count", 3)
mock_emit_metric.assert_any_call("output_token_count", 3)
Loading