From dcbf83686d53ce42db3ae9fe98d5b3ce9c555345 Mon Sep 17 00:00:00 2001 From: Aisuko Date: Mon, 6 Nov 2023 01:44:21 +0000 Subject: [PATCH 1/3] Add support to CTransformer Signed-off-by: GitHub --- extra/grpc/ctransformers/Makefile | 18 + extra/grpc/ctransformers/README.md | 5 + extra/grpc/ctransformers/backend_pb2.py | 61 ++++ extra/grpc/ctransformers/backend_pb2_grpc.py | 363 +++++++++++++++++++ extra/grpc/ctransformers/c_transformers.py | 106 ++++++ extra/grpc/ctransformers/run.sh | 14 + extra/grpc/huggingface/backend_pb2_grpc.py | 1 + extra/requirements.txt | 7 - 8 files changed, 568 insertions(+), 7 deletions(-) create mode 100644 extra/grpc/ctransformers/Makefile create mode 100644 extra/grpc/ctransformers/README.md create mode 100644 extra/grpc/ctransformers/backend_pb2.py create mode 100644 extra/grpc/ctransformers/backend_pb2_grpc.py create mode 100644 extra/grpc/ctransformers/c_transformers.py create mode 100755 extra/grpc/ctransformers/run.sh delete mode 100644 extra/requirements.txt diff --git a/extra/grpc/ctransformers/Makefile b/extra/grpc/ctransformers/Makefile new file mode 100644 index 00000000000..cda763b7897 --- /dev/null +++ b/extra/grpc/ctransformers/Makefile @@ -0,0 +1,18 @@ +.PONY: ctransformers +ctransformers: + @echo "Creating virtual environment..." + @conda create -n ctransformers python=3.11 -y + @echo "Virtual environment created." + + @echo "Activating virtual environment..." + @. activate ctransformers + + @echo "Installing dependencies..." + @pip install grpcio==1.59.0 protobuf==4.24.4 + + # Install ctransformers from JLLLLLL's cuBLAS wheels will append cu117to version of ctransformer, this will cause creating from file failed. + @echo "Installing ctransformers..." + @pip install ctransformers==0.2.27 --prefer-binary --extra-index-url=https://jllllll.github.io/ctransformers-cuBLAS-wheels/AVX2/cu117 + + @echo "Deactivating virtual environment..." + @. deactivate \ No newline at end of file diff --git a/extra/grpc/ctransformers/README.md b/extra/grpc/ctransformers/README.md new file mode 100644 index 00000000000..93027ed83ec --- /dev/null +++ b/extra/grpc/ctransformers/README.md @@ -0,0 +1,5 @@ +# Creating a separate environment for ctransformers project + +``` +make ctransformers +``` diff --git a/extra/grpc/ctransformers/backend_pb2.py b/extra/grpc/ctransformers/backend_pb2.py new file mode 100644 index 00000000000..12e8bf51e15 --- /dev/null +++ b/extra/grpc/ctransformers/backend_pb2.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: backend.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rbackend.proto\x12\x07\x62\x61\x63kend\"\x0f\n\rHealthMessage\"\x96\x06\n\x0ePredictOptions\x12\x0e\n\x06Prompt\x18\x01 \x01(\t\x12\x0c\n\x04Seed\x18\x02 \x01(\x05\x12\x0f\n\x07Threads\x18\x03 \x01(\x05\x12\x0e\n\x06Tokens\x18\x04 \x01(\x05\x12\x0c\n\x04TopK\x18\x05 \x01(\x05\x12\x0e\n\x06Repeat\x18\x06 \x01(\x05\x12\r\n\x05\x42\x61tch\x18\x07 \x01(\x05\x12\r\n\x05NKeep\x18\x08 \x01(\x05\x12\x13\n\x0bTemperature\x18\t \x01(\x02\x12\x0f\n\x07Penalty\x18\n \x01(\x02\x12\r\n\x05\x46\x31\x36KV\x18\x0b \x01(\x08\x12\x11\n\tDebugMode\x18\x0c \x01(\x08\x12\x13\n\x0bStopPrompts\x18\r \x03(\t\x12\x11\n\tIgnoreEOS\x18\x0e \x01(\x08\x12\x19\n\x11TailFreeSamplingZ\x18\x0f \x01(\x02\x12\x10\n\x08TypicalP\x18\x10 \x01(\x02\x12\x18\n\x10\x46requencyPenalty\x18\x11 \x01(\x02\x12\x17\n\x0fPresencePenalty\x18\x12 \x01(\x02\x12\x10\n\x08Mirostat\x18\x13 \x01(\x05\x12\x13\n\x0bMirostatETA\x18\x14 \x01(\x02\x12\x13\n\x0bMirostatTAU\x18\x15 \x01(\x02\x12\x12\n\nPenalizeNL\x18\x16 \x01(\x08\x12\x11\n\tLogitBias\x18\x17 \x01(\t\x12\r\n\x05MLock\x18\x19 \x01(\x08\x12\x0c\n\x04MMap\x18\x1a \x01(\x08\x12\x16\n\x0ePromptCacheAll\x18\x1b \x01(\x08\x12\x15\n\rPromptCacheRO\x18\x1c \x01(\x08\x12\x0f\n\x07Grammar\x18\x1d \x01(\t\x12\x0f\n\x07MainGPU\x18\x1e \x01(\t\x12\x13\n\x0bTensorSplit\x18\x1f \x01(\t\x12\x0c\n\x04TopP\x18 \x01(\x02\x12\x17\n\x0fPromptCachePath\x18! \x01(\t\x12\r\n\x05\x44\x65\x62ug\x18\" \x01(\x08\x12\x17\n\x0f\x45mbeddingTokens\x18# \x03(\x05\x12\x12\n\nEmbeddings\x18$ \x01(\t\x12\x14\n\x0cRopeFreqBase\x18% \x01(\x02\x12\x15\n\rRopeFreqScale\x18& \x01(\x02\x12\x1b\n\x13NegativePromptScale\x18\' \x01(\x02\x12\x16\n\x0eNegativePrompt\x18( \x01(\t\x12\x0e\n\x06NDraft\x18) \x01(\x05\"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\x0c\"\x86\x06\n\x0cModelOptions\x12\r\n\x05Model\x18\x01 \x01(\t\x12\x13\n\x0b\x43ontextSize\x18\x02 \x01(\x05\x12\x0c\n\x04Seed\x18\x03 \x01(\x05\x12\x0e\n\x06NBatch\x18\x04 \x01(\x05\x12\x11\n\tF16Memory\x18\x05 \x01(\x08\x12\r\n\x05MLock\x18\x06 \x01(\x08\x12\x0c\n\x04MMap\x18\x07 \x01(\x08\x12\x11\n\tVocabOnly\x18\x08 \x01(\x08\x12\x0f\n\x07LowVRAM\x18\t \x01(\x08\x12\x12\n\nEmbeddings\x18\n \x01(\x08\x12\x0c\n\x04NUMA\x18\x0b \x01(\x08\x12\x12\n\nNGPULayers\x18\x0c \x01(\x05\x12\x0f\n\x07MainGPU\x18\r \x01(\t\x12\x13\n\x0bTensorSplit\x18\x0e \x01(\t\x12\x0f\n\x07Threads\x18\x0f \x01(\x05\x12\x19\n\x11LibrarySearchPath\x18\x10 \x01(\t\x12\x14\n\x0cRopeFreqBase\x18\x11 \x01(\x02\x12\x15\n\rRopeFreqScale\x18\x12 \x01(\x02\x12\x12\n\nRMSNormEps\x18\x13 \x01(\x02\x12\x0c\n\x04NGQA\x18\x14 \x01(\x05\x12\x11\n\tModelFile\x18\x15 \x01(\t\x12\x0e\n\x06\x44\x65vice\x18\x16 \x01(\t\x12\x11\n\tUseTriton\x18\x17 \x01(\x08\x12\x15\n\rModelBaseName\x18\x18 \x01(\t\x12\x18\n\x10UseFastTokenizer\x18\x19 \x01(\x08\x12\x14\n\x0cPipelineType\x18\x1a \x01(\t\x12\x15\n\rSchedulerType\x18\x1b \x01(\t\x12\x0c\n\x04\x43UDA\x18\x1c \x01(\x08\x12\x10\n\x08\x43\x46GScale\x18\x1d \x01(\x02\x12\x0f\n\x07IMG2IMG\x18\x1e \x01(\x08\x12\x11\n\tCLIPModel\x18\x1f \x01(\t\x12\x15\n\rCLIPSubfolder\x18 \x01(\t\x12\x10\n\x08\x43LIPSkip\x18! \x01(\x05\x12\x11\n\tTokenizer\x18\" \x01(\t\x12\x10\n\x08LoraBase\x18# \x01(\t\x12\x13\n\x0bLoraAdapter\x18$ \x01(\t\x12\x11\n\tNoMulMatQ\x18% \x01(\x08\x12\x12\n\nDraftModel\x18\' \x01(\t\x12\x11\n\tAudioPath\x18& \x01(\t\x12\x14\n\x0cQuantization\x18( \x01(\t\"*\n\x06Result\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\"%\n\x0f\x45mbeddingResult\x12\x12\n\nembeddings\x18\x01 \x03(\x02\"C\n\x11TranscriptRequest\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0f\n\x07threads\x18\x04 \x01(\r\"N\n\x10TranscriptResult\x12,\n\x08segments\x18\x01 \x03(\x0b\x32\x1a.backend.TranscriptSegment\x12\x0c\n\x04text\x18\x02 \x01(\t\"Y\n\x11TranscriptSegment\x12\n\n\x02id\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x03\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x03\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x0e\n\x06tokens\x18\x05 \x03(\x05\"\xd7\x01\n\x14GenerateImageRequest\x12\x0e\n\x06height\x18\x01 \x01(\x05\x12\r\n\x05width\x18\x02 \x01(\x05\x12\x0c\n\x04mode\x18\x03 \x01(\x05\x12\x0c\n\x04step\x18\x04 \x01(\x05\x12\x0c\n\x04seed\x18\x05 \x01(\x05\x12\x17\n\x0fpositive_prompt\x18\x06 \x01(\t\x12\x17\n\x0fnegative_prompt\x18\x07 \x01(\t\x12\x0b\n\x03\x64st\x18\x08 \x01(\t\x12\x0b\n\x03src\x18\t \x01(\t\x12\x18\n\x10\x45nableParameters\x18\n \x01(\t\x12\x10\n\x08\x43LIPSkip\x18\x0b \x01(\x05\"6\n\nTTSRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x0b\n\x03\x64st\x18\x03 \x01(\t\"6\n\x14TokenizationResponse\x12\x0e\n\x06length\x18\x01 \x01(\x05\x12\x0e\n\x06tokens\x18\x02 \x03(\x05\"\x8e\x01\n\x0fMemoryUsageData\x12\r\n\x05total\x18\x01 \x01(\x04\x12:\n\tbreakdown\x18\x02 \x03(\x0b\x32\'.backend.MemoryUsageData.BreakdownEntry\x1a\x30\n\x0e\x42reakdownEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x04:\x02\x38\x01\"\xad\x01\n\x0eStatusResponse\x12,\n\x05state\x18\x01 \x01(\x0e\x32\x1d.backend.StatusResponse.State\x12(\n\x06memory\x18\x02 \x01(\x0b\x32\x18.backend.MemoryUsageData\"C\n\x05State\x12\x11\n\rUNINITIALIZED\x10\x00\x12\x08\n\x04\x42USY\x10\x01\x12\t\n\x05READY\x10\x02\x12\x12\n\x05\x45RROR\x10\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01\x32\xf4\x04\n\x07\x42\x61\x63kend\x12\x32\n\x06Health\x12\x16.backend.HealthMessage\x1a\x0e.backend.Reply\"\x00\x12\x34\n\x07Predict\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x12\x35\n\tLoadModel\x12\x15.backend.ModelOptions\x1a\x0f.backend.Result\"\x00\x12<\n\rPredictStream\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x30\x01\x12@\n\tEmbedding\x12\x17.backend.PredictOptions\x1a\x18.backend.EmbeddingResult\"\x00\x12\x41\n\rGenerateImage\x12\x1d.backend.GenerateImageRequest\x1a\x0f.backend.Result\"\x00\x12M\n\x12\x41udioTranscription\x12\x1a.backend.TranscriptRequest\x1a\x19.backend.TranscriptResult\"\x00\x12-\n\x03TTS\x12\x13.backend.TTSRequest\x1a\x0f.backend.Result\"\x00\x12J\n\x0eTokenizeString\x12\x17.backend.PredictOptions\x1a\x1d.backend.TokenizationResponse\"\x00\x12;\n\x06Status\x12\x16.backend.HealthMessage\x1a\x17.backend.StatusResponse\"\x00\x42Z\n\x19io.skynet.localai.backendB\x0eLocalAIBackendP\x01Z+github.com/go-skynet/LocalAI/pkg/grpc/protob\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'backend_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\031io.skynet.localai.backendB\016LocalAIBackendP\001Z+github.com/go-skynet/LocalAI/pkg/grpc/proto' + _MEMORYUSAGEDATA_BREAKDOWNENTRY._options = None + _MEMORYUSAGEDATA_BREAKDOWNENTRY._serialized_options = b'8\001' + _globals['_HEALTHMESSAGE']._serialized_start=26 + _globals['_HEALTHMESSAGE']._serialized_end=41 + _globals['_PREDICTOPTIONS']._serialized_start=44 + _globals['_PREDICTOPTIONS']._serialized_end=834 + _globals['_REPLY']._serialized_start=836 + _globals['_REPLY']._serialized_end=860 + _globals['_MODELOPTIONS']._serialized_start=863 + _globals['_MODELOPTIONS']._serialized_end=1637 + _globals['_RESULT']._serialized_start=1639 + _globals['_RESULT']._serialized_end=1681 + _globals['_EMBEDDINGRESULT']._serialized_start=1683 + _globals['_EMBEDDINGRESULT']._serialized_end=1720 + _globals['_TRANSCRIPTREQUEST']._serialized_start=1722 + _globals['_TRANSCRIPTREQUEST']._serialized_end=1789 + _globals['_TRANSCRIPTRESULT']._serialized_start=1791 + _globals['_TRANSCRIPTRESULT']._serialized_end=1869 + _globals['_TRANSCRIPTSEGMENT']._serialized_start=1871 + _globals['_TRANSCRIPTSEGMENT']._serialized_end=1960 + _globals['_GENERATEIMAGEREQUEST']._serialized_start=1963 + _globals['_GENERATEIMAGEREQUEST']._serialized_end=2178 + _globals['_TTSREQUEST']._serialized_start=2180 + _globals['_TTSREQUEST']._serialized_end=2234 + _globals['_TOKENIZATIONRESPONSE']._serialized_start=2236 + _globals['_TOKENIZATIONRESPONSE']._serialized_end=2290 + _globals['_MEMORYUSAGEDATA']._serialized_start=2293 + _globals['_MEMORYUSAGEDATA']._serialized_end=2435 + _globals['_MEMORYUSAGEDATA_BREAKDOWNENTRY']._serialized_start=2387 + _globals['_MEMORYUSAGEDATA_BREAKDOWNENTRY']._serialized_end=2435 + _globals['_STATUSRESPONSE']._serialized_start=2438 + _globals['_STATUSRESPONSE']._serialized_end=2611 + _globals['_STATUSRESPONSE_STATE']._serialized_start=2544 + _globals['_STATUSRESPONSE_STATE']._serialized_end=2611 + _globals['_BACKEND']._serialized_start=2614 + _globals['_BACKEND']._serialized_end=3242 +# @@protoc_insertion_point(module_scope) diff --git a/extra/grpc/ctransformers/backend_pb2_grpc.py b/extra/grpc/ctransformers/backend_pb2_grpc.py new file mode 100644 index 00000000000..79a7677fb27 --- /dev/null +++ b/extra/grpc/ctransformers/backend_pb2_grpc.py @@ -0,0 +1,363 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import backend_pb2 as backend__pb2 + + +class BackendStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Health = channel.unary_unary( + '/backend.Backend/Health', + request_serializer=backend__pb2.HealthMessage.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, + ) + self.Predict = channel.unary_unary( + '/backend.Backend/Predict', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, + ) + self.LoadModel = channel.unary_unary( + '/backend.Backend/LoadModel', + request_serializer=backend__pb2.ModelOptions.SerializeToString, + response_deserializer=backend__pb2.Result.FromString, + ) + self.PredictStream = channel.unary_stream( + '/backend.Backend/PredictStream', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, + ) + self.Embedding = channel.unary_unary( + '/backend.Backend/Embedding', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.EmbeddingResult.FromString, + ) + self.GenerateImage = channel.unary_unary( + '/backend.Backend/GenerateImage', + request_serializer=backend__pb2.GenerateImageRequest.SerializeToString, + response_deserializer=backend__pb2.Result.FromString, + ) + self.AudioTranscription = channel.unary_unary( + '/backend.Backend/AudioTranscription', + request_serializer=backend__pb2.TranscriptRequest.SerializeToString, + response_deserializer=backend__pb2.TranscriptResult.FromString, + ) + self.TTS = channel.unary_unary( + '/backend.Backend/TTS', + request_serializer=backend__pb2.TTSRequest.SerializeToString, + response_deserializer=backend__pb2.Result.FromString, + ) + self.TokenizeString = channel.unary_unary( + '/backend.Backend/TokenizeString', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.TokenizationResponse.FromString, + ) + self.Status = channel.unary_unary( + '/backend.Backend/Status', + request_serializer=backend__pb2.HealthMessage.SerializeToString, + response_deserializer=backend__pb2.StatusResponse.FromString, + ) + + +class BackendServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Health(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Predict(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def LoadModel(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def PredictStream(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Embedding(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GenerateImage(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def AudioTranscription(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TTS(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TokenizeString(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Status(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_BackendServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Health': grpc.unary_unary_rpc_method_handler( + servicer.Health, + request_deserializer=backend__pb2.HealthMessage.FromString, + response_serializer=backend__pb2.Reply.SerializeToString, + ), + 'Predict': grpc.unary_unary_rpc_method_handler( + servicer.Predict, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.Reply.SerializeToString, + ), + 'LoadModel': grpc.unary_unary_rpc_method_handler( + servicer.LoadModel, + request_deserializer=backend__pb2.ModelOptions.FromString, + response_serializer=backend__pb2.Result.SerializeToString, + ), + 'PredictStream': grpc.unary_stream_rpc_method_handler( + servicer.PredictStream, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.Reply.SerializeToString, + ), + 'Embedding': grpc.unary_unary_rpc_method_handler( + servicer.Embedding, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.EmbeddingResult.SerializeToString, + ), + 'GenerateImage': grpc.unary_unary_rpc_method_handler( + servicer.GenerateImage, + request_deserializer=backend__pb2.GenerateImageRequest.FromString, + response_serializer=backend__pb2.Result.SerializeToString, + ), + 'AudioTranscription': grpc.unary_unary_rpc_method_handler( + servicer.AudioTranscription, + request_deserializer=backend__pb2.TranscriptRequest.FromString, + response_serializer=backend__pb2.TranscriptResult.SerializeToString, + ), + 'TTS': grpc.unary_unary_rpc_method_handler( + servicer.TTS, + request_deserializer=backend__pb2.TTSRequest.FromString, + response_serializer=backend__pb2.Result.SerializeToString, + ), + 'TokenizeString': grpc.unary_unary_rpc_method_handler( + servicer.TokenizeString, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.TokenizationResponse.SerializeToString, + ), + 'Status': grpc.unary_unary_rpc_method_handler( + servicer.Status, + request_deserializer=backend__pb2.HealthMessage.FromString, + response_serializer=backend__pb2.StatusResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'backend.Backend', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Backend(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Health(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health', + backend__pb2.HealthMessage.SerializeToString, + backend__pb2.Reply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Predict(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.Reply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def LoadModel(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel', + backend__pb2.ModelOptions.SerializeToString, + backend__pb2.Result.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def PredictStream(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.Reply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Embedding(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.EmbeddingResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GenerateImage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage', + backend__pb2.GenerateImageRequest.SerializeToString, + backend__pb2.Result.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def AudioTranscription(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription', + backend__pb2.TranscriptRequest.SerializeToString, + backend__pb2.TranscriptResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def TTS(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS', + backend__pb2.TTSRequest.SerializeToString, + backend__pb2.Result.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def TokenizeString(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.TokenizationResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Status(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status', + backend__pb2.HealthMessage.SerializeToString, + backend__pb2.StatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/extra/grpc/ctransformers/c_transformers.py b/extra/grpc/ctransformers/c_transformers.py new file mode 100644 index 00000000000..df4636e7e1f --- /dev/null +++ b/extra/grpc/ctransformers/c_transformers.py @@ -0,0 +1,106 @@ +""" +This is the extra gRPC server of LocalAI +""" + +from __future__ import annotations +from typing import List +from concurrent import futures +import time +import argparse +import signal +import sys +import os + +import grpc +import backend_pb2 +import backend_pb2_grpc + +from ctransformers import AutoModelForCausalLM + +# Adapted from https://github.com/marella/ctransformers/tree/main#supported-models +# License: MIT +# Adapted by AIsuko +class ModelType: + GPT = "gpt2" + GPT_J_GPT4_ALL_J= "gptj" + GPT_NEOX_STABLE_LM = "gpt_neox" + FALCON= "falcon" + LLaMA_LLaMA2 = "llama" + MPT="mpt" + STAR_CODER_CHAT="gpt_bigcode" + DOLLY_V2="dolly-v2" + REPLIT="replit" + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 +# If MAX_WORKERS are specified in the environment use it, otherwise default to 1 +MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) + + +class BackendServicer(backend_pb2_grpc.BackendServicer): + """ + BackendServicer is the class that implements the gRPC service + """ + def Health(self, request, context): + return backend_pb2.Reply(message=bytes("OK", 'utf-8')) + + def LoadModel(self, request, context): + try: + model_path = request.Model + if not os.path.exists(model_path): + return backend_pb2.Result(success=False, message=f"Model path {model_path} does not exist") + + model_type = request.ModelType + if model_type not in ModelType.__dict__.values(): + return backend_pb2.Result(success=False, message=f"Model type {model_type} not supported") + llm = AutoModelForCausalLM.from_pretrained(model_file=model_path, model_type=model_type) + self.model=llm + except Exception as err: + return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") + return backend_pb2.Result(message="Model loaded successfully", success=True) + + def Predict(self, request, context): + #TODO + return super().Predict(request, context) + + def PredictStream(self, request, context): + return super().PredictStream(request, context) + + def TokenizeString(self, request, context): + try: + token: List[int]=self.model.tokenize(request.prompt, add_bos_token=False) + l=len(token) + except Exception as err: + return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") + return backend_pb2.TokenizationResponse(length=l, token=token) + +def serve(address): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)) + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + server.add_insecure_port(address) + server.start() + print("Server started. Listening on: " + address, file=sys.stderr) + + # Define the signal handler function + def signal_handler(sig, frame): + print("Received termination signal. Shutting down...") + server.stop(0) + sys.exit(0) + + # Set the signal handlers for SIGINT and SIGTERM + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + while True: + time.sleep(_ONE_DAY_IN_SECONDS) + except KeyboardInterrupt: + server.stop(0) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run the gRPC server.") + parser.add_argument( + "--addr", default="localhost:50051", help="The address to bind the server to." + ) + args = parser.parse_args() + + serve(args.addr) \ No newline at end of file diff --git a/extra/grpc/ctransformers/run.sh b/extra/grpc/ctransformers/run.sh new file mode 100755 index 00000000000..11cd5fdb4b4 --- /dev/null +++ b/extra/grpc/ctransformers/run.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +## +## A bash script wrapper that runs the ctransformers server with conda + +export PATH=$PATH:/opt/conda/bin + +# Activate conda environment +source activate ctransformers + +# get the directory where the bash script is located +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +python $DIR/c_transformers.py $@ diff --git a/extra/grpc/huggingface/backend_pb2_grpc.py b/extra/grpc/huggingface/backend_pb2_grpc.py index 79a7677fb27..fc605ead42b 100644 --- a/extra/grpc/huggingface/backend_pb2_grpc.py +++ b/extra/grpc/huggingface/backend_pb2_grpc.py @@ -32,6 +32,7 @@ def __init__(self, channel): self.PredictStream = channel.unary_stream( '/backend.Backend/PredictStream', request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, ) self.Embedding = channel.unary_unary( diff --git a/extra/requirements.txt b/extra/requirements.txt deleted file mode 100644 index fb3cc0122a9..00000000000 --- a/extra/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -sentence_transformers -grpcio -google -protobuf -six -omegaconf -compel \ No newline at end of file From 1270c8e2eed9d7babe707b2a29c9d22b81dd020c Mon Sep 17 00:00:00 2001 From: Aisuko Date: Tue, 7 Nov 2023 01:10:50 +0000 Subject: [PATCH 2/3] rename folder Signed-off-by: GitHub --- extra/grpc/c_transformers/c_transformers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/extra/grpc/c_transformers/c_transformers.py b/extra/grpc/c_transformers/c_transformers.py index 848b0681903..eb3e5787e27 100644 --- a/extra/grpc/c_transformers/c_transformers.py +++ b/extra/grpc/c_transformers/c_transformers.py @@ -16,7 +16,6 @@ import backend_pb2_grpc from ctransformers import AutoModelForCausalLM -from ctransformer.llm import Config # Adapted from https://github.com/marella/ctransformers/tree/main#supported-models # License: MIT @@ -60,7 +59,6 @@ def LoadModel(self, request, context): return backend_pb2.Result(message="Model loaded successfully", success=True) def Predict(self, request, context): - #TODO return super().Predict(request, context) def PredictStream(self, request, context): From dc751d630bdb2a6e7932ac011409dfaec52c79e2 Mon Sep 17 00:00:00 2001 From: Aisuko Date: Wed, 8 Nov 2023 00:32:55 +0000 Subject: [PATCH 3/3] Add prediction Signed-off-by: GitHub --- extra/grpc/c_transformers/c_transformers.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/extra/grpc/c_transformers/c_transformers.py b/extra/grpc/c_transformers/c_transformers.py index eb3e5787e27..3aa5f896fa2 100644 --- a/extra/grpc/c_transformers/c_transformers.py +++ b/extra/grpc/c_transformers/c_transformers.py @@ -15,7 +15,7 @@ import backend_pb2 import backend_pb2_grpc -from ctransformers import AutoModelForCausalLM +from ctransformers import AutoModelForCausalLM, AutoConfig, Config # Adapted from https://github.com/marella/ctransformers/tree/main#supported-models # License: MIT @@ -50,8 +50,7 @@ def LoadModel(self, request, context): return backend_pb2.Result(success=False, message=f"Model path {model_path} does not exist") model_type = request.ModelType if model_type not in ModelType.__dict__.values(): - return backend_pb2.Result(success=False, message=f"Model type {model_type} not supported") - + return backend_pb2.Result(success=False, message=f"Model type {model_type} not supported") llm = AutoModelForCausalLM.from_pretrained(model_file=model_path, model_type=model_type) self.model=llm except Exception as err: @@ -59,7 +58,11 @@ def LoadModel(self, request, context): return backend_pb2.Result(message="Model loaded successfully", success=True) def Predict(self, request, context): - return super().Predict(request, context) + try: + generated_text=self.model(request.prompt) + except Exception as err: + return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") + return backend_pb2.Result(message=bytes(generated_text), encoding="utf-8") def PredictStream(self, request, context): return super().PredictStream(request, context)