diff --git a/llm/data_gemma/huggingface_api.py b/llm/data_gemma/huggingface_api.py index 1c46142..13dc3c2 100644 --- a/llm/data_gemma/huggingface_api.py +++ b/llm/data_gemma/huggingface_api.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """HF Pipeline API based LLM Interface. For example usage, see: https://huggingface.co/google/gemma-2-27b @@ -81,7 +80,7 @@ def query(self, prompt: str) -> base.LLMCall: start = time.time() input_ids = self.tokenizer(prompt, return_tensors='pt').to('cuda') - outputs = self.model.generate(**input_ids) + outputs = self.model.generate(**input_ids, max_new_tokens=4096) ans = '' err = ''