diff --git a/predict.py b/predict.py index 39fc7a6..197289c 100644 --- a/predict.py +++ b/predict.py @@ -6,7 +6,6 @@ from uuid import uuid4 from dataclasses import dataclass, field from pprint import pprint -import inspect import jinja2 import torch # pylint: disable=import-error @@ -159,28 +158,17 @@ async def setup( f"Using prompt template from `predictor_config.json`: {self.config.prompt_template}" ) self.tokenizer.chat_template = self.config.prompt_template - self.prompt_template = None elif self.tokenizer.chat_template: print( f"Using prompt template from `tokenizer`: {self.tokenizer.chat_template}" ) - self.prompt_template = None else: print( "No prompt template specified in `predictor_config.json` or " f"`tokenizer`, defaulting to: {PROMPT_TEMPLATE}" ) - self.tokenizer.chat_template = None - self.prompt_template = PROMPT_TEMPLATE - - self._testing = True - # generator = self.predict( - # **dict(self._defaults, **{"max_tokens": 3, "prompt": "hi"}) - # ) - # test_output = "".join([tok async for tok in generator]) - # print("Test prediction output:", test_output) - self._testing = False + self.tokenizer.chat_template = PROMPT_TEMPLATE async def predict( # pylint: disable=invalid-overridden-method, arguments-differ, too-many-arguments, too-many-locals self, @@ -230,8 +218,7 @@ async def predict( # pylint: disable=invalid-overridden-method, arguments-diffe ) -> ConcatenateIterator[str]: start = time.time() - if prompt_template or self.prompt_template: - prompt_template = prompt_template or self.prompt_template + if prompt_template: prompt = format_prompt( prompt=prompt, prompt_template=prompt_template, @@ -345,9 +332,3 @@ def load_config(self, weights: str) -> PredictorConfig: config = PredictorConfig() pprint(config) return config - - _defaults = { - key: param.default.default - for key, param in inspect.signature(predict).parameters.items() - if hasattr(param.default, "default") - }