Commit beafd97: Emit input and output tokens

Emits input and output token counts with cog.emit_metric. Updates tests accordingly.
joehoover committed Aug 1, 2024
1 parent: b32418b
Showing 5 changed files with 93 additions and 43 deletions.
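In outline, the commit reads the final result object from vLLM's generator once streaming finishes and reports its token counts through cog.emit_metric. A minimal sketch of that pattern, simplified from the predict.py diff below (the standalone helper and its name are illustrative, not part of the repo):

    import cog  # Cog SDK; emit_metric reports a named numeric metric for the prediction

    def report_token_counts(result):
        # `result` stands in for the final item yielded by vLLM's async generator.
        cog.emit_metric("input_token_count", len(result.prompt_token_ids))
        cog.emit_metric("output_token_count", len(result.outputs[0].token_ids))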
33 changes: 30 additions & 3 deletions .github/workflows/lint-and-test.yml
@@ -19,8 +19,18 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
-      - name: Install dependencies
-        run: pip install -r requirements-dev.txt
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Create and activate virtual environment
+        run: |
+          python -m venv venv
+          source venv/bin/activate
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements-dev.txt
       - name: Run pylint
         run: |
           pylint --recursive y tests/**/*.py
@@ -31,7 +41,24 @@
     runs-on: ubuntu-latest
     steps:
      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Create and activate virtual environment
+        run: |
+          python -m venv venv
+          source venv/bin/activate
       - name: Install dependencies
-        run: pip install -r requirements-dev.txt
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements-dev.txt
+      - name: Debug information
+        run: |
+          which python
+          python --version
+          pip list
+          python -c "import sys; print(sys.path)"
+          python -c "import attrs; print(attrs.__file__)"
       - name: Run unit tests
         run: pytest tests/unit
13 changes: 12 additions & 1 deletion predict.py
@@ -10,6 +10,7 @@
 
 import jinja2
 import torch  # pylint: disable=import-error
+import cog  # pylint: disable=import-error
 from cog import BasePredictor, ConcatenateIterator, Input
 from vllm import AsyncLLMEngine
 from vllm.engine.arg_utils import AsyncEngineArgs  # pylint: disable=import-error
@@ -281,7 +282,12 @@ async def predict(  # pylint: disable=invalid-overridden-method, arguments-differ
         )
 
         request_id = uuid4().hex
-        generator = self.engine.generate(prompt, sampling_params, request_id)
+
+        generator = self.engine.generate(
+            prompt,
+            sampling_params,
+            request_id,
+        )
         start = 0
 
         async for result in generator:
@@ -302,6 +308,11 @@ async def predict(  # pylint: disable=invalid-overridden-method, arguments-differ
         self.log(f"Generation took {time.time() - start:.2f}s")
         self.log(f"Formatted prompt: {prompt}")
 
+        if not self._testing:
+            # pylint: disable=no-member, undefined-loop-variable
+            cog.emit_metric("input_token_count", len(result.prompt_token_ids))
+            cog.emit_metric("output_token_count", len(result.outputs[0].token_ids))
+
     def load_config(self, weights: str) -> PredictorConfig:
         """
         Load the predictor configuration from the specified weights directory or
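One detail worth noting: the new metric code reads `result` after the `async for` loop has finished, which is why the diff disables pylint's undefined-loop-variable warning. In Python, the loop variable keeps its last value once iteration completes. A self-contained sketch of that behavior (all names here are illustrative, not from the repo):

    import asyncio

    class FakeResult:  # stands in for vLLM's per-request result object
        def __init__(self, prompt_token_ids, token_ids):
            self.prompt_token_ids = prompt_token_ids
            self.outputs = [type("Out", (), {"token_ids": token_ids})()]

    async def fake_generator():
        # vLLM yields cumulative results; the last one carries the full output.
        yield FakeResult([1, 2, 3], [4])
        yield FakeResult([1, 2, 3], [4, 5, 6])

    async def main():
        async for result in fake_generator():
            pass  # stream partial text to the client here
        # After the loop, `result` is still bound to the final item.
        print(len(result.prompt_token_ids))      # -> 3 input tokens
        print(len(result.outputs[0].token_ids))  # -> 3 output tokens

    asyncio.run(main())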
4 changes: 3 additions & 1 deletion requirements-dev.txt

(Diff not rendered by default.)

2 changes: 1 addition & 1 deletion tests/end_to_end/local/test_predict.py
@@ -19,7 +19,7 @@ def test_predict():
 
     try:
 
-        with open(config_filename, "w") as temp_config:
+        with open(config_filename, "w", encoding="utf-8") as temp_config:
             json.dump(predictor_config, temp_config, indent=4)
 
         weights_url = "https://weights.replicate.delivery/default/internal-testing/EleutherAI/pythia-70m/model.tar"  # pylint: disable=line-too-long
84 changes: 47 additions & 37 deletions tests/unit/test_predict.py
@@ -135,42 +135,52 @@ async def test_setup_with_invalid_predictor_config():
 
 @pytest.mark.asyncio
 async def test_predict(mock_dependencies):
-    class MockResult:  # pylint: disable=too-few-public-methods
-        """
-        Use this to mock the result object that the engine returns.
-        """
-        def __init__(self, text):
-            self.outputs = [MagicMock(text=text)]
-
-    # Define an async generator function
-    async def mock_generate(*args, **kwargs):  # pylint: disable=unused-argument
-        yield MockResult("Generated text")
-
-    mock_dependencies["engine"].generate = mock_generate
-
-    predictor = Predictor()
-    predictor.log = MagicMock()
-    with patch.object(Predictor, 'setup') as mock_setup:
-        def setup_side_effect(*args):  # pylint: disable=unused-argument
-            predictor.engine = mock_dependencies["engine"]
-            predictor.prompt_template = None
-            predictor.config = PredictorConfig()
-        mock_setup.side_effect = setup_side_effect
-        await predictor.setup("dummy_weights")
-
-    # Mock the tokenizer
-    predictor.tokenizer = MagicMock()
-    predictor.tokenizer.chat_template = None
-    predictor.tokenizer.eos_token_id = 0
-
-    # Call the predict method
-    result = predictor.predict(
-        prompt="Test prompt", prompt_template=MockInput(default=None)
-    )
-
-    # Consume the async generator
-    texts = []
-    async for item in result:
-        texts.append(item)
-
-    assert texts == ["Generated text"]
+    with patch('predict.cog.emit_metric') as mock_emit_metric:
+        class MockOutput:
+            def __init__(self, text):
+                self.text = text
+                self.token_ids = [4, 5, 6]  # Generated tokens
+
+        class MockResult:
+            def __init__(self, text):
+                self.outputs = [MockOutput(text)]
+                self.prompt_token_ids = [1, 2, 3]  # Input tokens
+
+
+        # Define an async generator function
+        async def mock_generate(*args, **kwargs):  # pylint: disable=unused-argument
+            yield MockResult("Generated text")
+
+        mock_dependencies["engine"].generate = mock_generate
+
+        predictor = Predictor()
+        predictor.log = MagicMock()
+        with patch.object(Predictor, 'setup') as mock_setup:
+            def setup_side_effect(*args):  # pylint: disable=unused-argument
+                predictor.engine = mock_dependencies["engine"]
+                predictor.prompt_template = None
+                predictor.config = PredictorConfig()
+                predictor._testing = False  # pylint: disable=protected-access
+            mock_setup.side_effect = setup_side_effect
+            await predictor.setup("dummy_weights")
+
+        # Mock the tokenizer
+        predictor.tokenizer = MagicMock()
+        predictor.tokenizer.chat_template = None
+        predictor.tokenizer.eos_token_id = 0
+
+        # Call the predict method
+        result = predictor.predict(
+            prompt="Test prompt", prompt_template=MockInput(default=None)
+        )
+
+        # Consume the async generator
+        texts = []
+        async for item in result:
+            texts.append(item)
+
+        assert texts == ["Generated text"]
+        # Assert that emit_metric was called with the expected arguments
+        mock_emit_metric.assert_any_call("input_token_count", 3)
+        mock_emit_metric.assert_any_call("output_token_count", 3)
