diff --git a/backend/python/transformers/transformers_server.py b/backend/python/transformers/transformers_server.py index e87e75cf5d3..e6c06babf1f 100755 --- a/backend/python/transformers/transformers_server.py +++ b/backend/python/transformers/transformers_server.py @@ -15,7 +15,7 @@ import grpc import torch - +import torch.cuda from transformers import AutoTokenizer, AutoModel _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -70,14 +70,10 @@ def LoadModel(self, request, context): try: self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True) # trust_remote_code is needed to use the encode method with embeddings models like jinai-v2 self.tokenizer = AutoTokenizer.from_pretrained(model_name) - - if request.CUDA: + if request.CUDA or torch.cuda.is_available(): try: - # TODO: also tensorflow, make configurable - import torch.cuda - if torch.cuda.is_available(): - print("Loading model", model_name, "to CUDA.", file=sys.stderr) - self.model = self.model.to("cuda") + print("Loading model", model_name, "to CUDA.", file=sys.stderr) + self.model = self.model.to("cuda") except Exception as err: print("Not using CUDA:", err, file=sys.stderr) except Exception as err: @@ -113,6 +109,47 @@ def Embedding(self, request, context): print("Embeddings:", sentence_embeddings, file=sys.stderr) return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings) + def Predict(self, request, context): + """ + Generates text based on the given prompt and sampling parameters. + + Args: + request: The predict request. + context: The gRPC context. + + Returns: + backend_pb2.Reply: The predict result. + """ + if request.TopP == 0: + request.TopP = 0.9 + + max_tokens = 200 + if request.Tokens > 0: + max_tokens = request.Tokens + + inputs = self.tokenizer.tokenizer(request.Prompt, return_tensors="pt").input_ids + outputs = self.model.generate(inputs,max_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP) + + generated_text = self.tokenizer.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] + # Remove prompt from response if present + if request.Prompt in generated_text: + generated_text = generated_text.replace(request.Prompt, "") + + return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8')) + + def PredictStream(self, request, context): + """ + Generates text based on the given prompt and sampling parameters, and streams the results. + + Args: + request: The predict stream request. + context: The gRPC context. + + Returns: + backend_pb2.Result: The predict stream result. + """ + yield self.Predict(request, context) + def serve(address): server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)) diff --git a/embedded/models/dolphin-2.5-mixtral-8x7b.yaml b/embedded/models/dolphin-2.5-mixtral-8x7b.yaml index dbbeac0ef51..b6df47997cd 100644 --- a/embedded/models/dolphin-2.5-mixtral-8x7b.yaml +++ b/embedded/models/dolphin-2.5-mixtral-8x7b.yaml @@ -5,6 +5,7 @@ parameters: temperature: 0.2 top_k: 40 top_p: 0.95 + seed: -1 template: chat_message: | <|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "user"}}user{{end}} diff --git a/embedded/models/llava.yaml b/embedded/models/llava.yaml index 551eb26b55d..2e571f212e2 100644 --- a/embedded/models/llava.yaml +++ b/embedded/models/llava.yaml @@ -17,6 +17,7 @@ parameters: temperature: 0.2 top_k: 40 top_p: 0.95 + seed: -1 template: chat: | diff --git a/embedded/models/mistral-openorca.yaml b/embedded/models/mistral-openorca.yaml index 3a41c766c8d..fbab4e39ad1 100644 --- a/embedded/models/mistral-openorca.yaml +++ b/embedded/models/mistral-openorca.yaml @@ -5,6 +5,7 @@ parameters: temperature: 0.2 top_k: 40 top_p: 0.95 + seed: -1 template: chat_message: | <|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "user"}}user{{end}} diff --git a/embedded/models/mixtral-instruct.yaml b/embedded/models/mixtral-instruct.yaml index c9c55869032..3272557a717 100644 --- a/embedded/models/mixtral-instruct.yaml +++ b/embedded/models/mixtral-instruct.yaml @@ -4,6 +4,7 @@ parameters: model: huggingface://TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF/mixtral-8x7b-instruct-v0.1.Q2_K.gguf temperature: 0.2 top_k: 40 + seed: -1 top_p: 0.95 template: chat: &chat | diff --git a/embedded/models/tinyllama-chat.yaml b/embedded/models/tinyllama-chat.yaml index 7c9a7579216..48c44f9fc85 100644 --- a/embedded/models/tinyllama-chat.yaml +++ b/embedded/models/tinyllama-chat.yaml @@ -4,6 +4,7 @@ parameters: model: huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q8_0.gguf temperature: 0.2 top_k: 40 + seed: -1 top_p: 0.95 template: chat_message: | diff --git a/examples/configurations/phi-2.yaml b/examples/configurations/phi-2.yaml index 67cef0cc088..c09aa6ce261 100644 --- a/examples/configurations/phi-2.yaml +++ b/examples/configurations/phi-2.yaml @@ -10,6 +10,7 @@ parameters: temperature: 0.2 top_k: 40 top_p: 0.95 + seed: -1 template: chat: &template | Instruct: {{.Input}}