From 086206cb8f2eedfc76e4415fa59e39ef205bf2ba Mon Sep 17 00:00:00 2001 From: Tanner Kvarfordt Date: Wed, 10 Jul 2024 22:47:39 -0700 Subject: [PATCH] Updates to the Text Generation Endpoint (#42) * Updated README * Updated text generation task * Major rev --- README.md | 1 + audio_classification_test.go | 2 +- conversational_test.go | 2 +- examples/audio_classification/main.go | 2 +- examples/conversational/main.go | 2 +- examples/fill_mask/main.go | 2 +- examples/image_classification/main.go | 2 +- examples/image_segmentation/main.go | 2 +- examples/image_to_text/main.go | 2 +- examples/object_detection/main.go | 2 +- examples/question_answering/main.go | 2 +- examples/sentence_similarity/main.go | 2 +- examples/speech_recognition/main.go | 2 +- examples/summarization/main.go | 2 +- examples/table_question_answering/main.go | 2 +- examples/text_classification/main.go | 2 +- examples/text_generation/main.go | 25 +-- examples/text_to_image/main.go | 2 +- examples/token_classification/main.go | 2 +- examples/translation/main.go | 2 +- examples/zeroshot/main.go | 2 +- fill_mask_test.go | 2 +- go.mod | 2 +- image_classification_test.go | 2 +- image_segmentation_test.go | 2 +- image_to_text_test.go | 2 +- object_detection_test.go | 2 +- question_answering_test.go | 2 +- sentence_similarity_test.go | 2 +- setup_test.go | 2 +- speech_recognition_test.go | 2 +- summarization_test.go | 2 +- table_question_answering_test.go | 2 +- text_classification_test.go | 2 +- text_generation.go | 180 ++++++++++++++-------- text_generation_test.go | 57 +++---- text_to_image_test.go | 2 +- token_classification_test.go | 2 +- translation_test.go | 2 +- zeroshot_classification_test.go | 2 +- 40 files changed, 177 insertions(+), 158 deletions(-) diff --git a/README.md b/README.md index 969830a..34970a8 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ See the [examples](./examples) directory. ## Resources - [Hugging Face](https://huggingface.co/) + - [Inference API JSON Schema](https://huggingface.github.io/text-generation-inference/openapi.json) - [Model Hub](https://huggingface.co/models) - [Datasets](https://huggingface.co/datasets) - [Hugging Face Inference API](https://api-inference.huggingface.co/docs/python/html/index.html) (HF API) diff --git a/audio_classification_test.go b/audio_classification_test.go index 606782a..b70d48e 100644 --- a/audio_classification_test.go +++ b/audio_classification_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) func TestAudioClassificationRequest(t *testing.T) { diff --git a/conversational_test.go b/conversational_test.go index 485adef..a103f18 100644 --- a/conversational_test.go +++ b/conversational_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "testing" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" "github.com/google/go-cmp/cmp" ) diff --git a/examples/audio_classification/main.go b/examples/audio_classification/main.go index 99f64e3..711e053 100644 --- a/examples/audio_classification/main.go +++ b/examples/audio_classification/main.go @@ -5,7 +5,7 @@ import ( "os" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/examples/conversational/main.go b/examples/conversational/main.go index 187fc9a..60915b1 100644 --- a/examples/conversational/main.go +++ b/examples/conversational/main.go @@ -9,7 +9,7 @@ import ( "syscall" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/examples/fill_mask/main.go b/examples/fill_mask/main.go index 09d96a3..3586c7a 100644 --- a/examples/fill_mask/main.go +++ b/examples/fill_mask/main.go @@ -7,7 +7,7 @@ import ( "strings" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/examples/image_classification/main.go b/examples/image_classification/main.go index 4118958..aad6902 100644 --- a/examples/image_classification/main.go +++ b/examples/image_classification/main.go @@ -4,7 +4,7 @@ import ( "fmt" "os" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/examples/image_segmentation/main.go b/examples/image_segmentation/main.go index 729b1eb..2d2663c 100644 --- a/examples/image_segmentation/main.go +++ b/examples/image_segmentation/main.go @@ -12,7 +12,7 @@ import ( "strings" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" "golang.org/x/image/font" "golang.org/x/image/font/basicfont" "golang.org/x/image/math/fixed" diff --git a/examples/image_to_text/main.go b/examples/image_to_text/main.go index 07c43e1..fea7fe8 100644 --- a/examples/image_to_text/main.go +++ b/examples/image_to_text/main.go @@ -6,7 +6,7 @@ import ( "os" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/examples/object_detection/main.go b/examples/object_detection/main.go index 5913807..17714d3 100644 --- a/examples/object_detection/main.go +++ b/examples/object_detection/main.go @@ -13,7 +13,7 @@ import ( "os" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/examples/question_answering/main.go b/examples/question_answering/main.go index c2dca70..1df24c2 100644 --- a/examples/question_answering/main.go +++ b/examples/question_answering/main.go @@ -5,7 +5,7 @@ import ( "os" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/examples/sentence_similarity/main.go b/examples/sentence_similarity/main.go index 65e0aa8..d154f34 100644 --- a/examples/sentence_similarity/main.go +++ b/examples/sentence_similarity/main.go @@ -6,7 +6,7 @@ import ( "strings" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/examples/speech_recognition/main.go b/examples/speech_recognition/main.go index 3ca0203..4c3988c 100644 --- a/examples/speech_recognition/main.go +++ b/examples/speech_recognition/main.go @@ -5,7 +5,7 @@ import ( "os" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/examples/summarization/main.go b/examples/summarization/main.go index 89e8399..4a4164f 100644 --- a/examples/summarization/main.go +++ b/examples/summarization/main.go @@ -5,7 +5,7 @@ import ( "os" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/examples/table_question_answering/main.go b/examples/table_question_answering/main.go index d77a7b1..768e0fe 100644 --- a/examples/table_question_answering/main.go +++ b/examples/table_question_answering/main.go @@ -6,7 +6,7 @@ import ( "text/tabwriter" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const TableRows = 2 diff --git a/examples/text_classification/main.go b/examples/text_classification/main.go index 2076564..d78ab61 100644 --- a/examples/text_classification/main.go +++ b/examples/text_classification/main.go @@ -6,7 +6,7 @@ import ( "strings" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/examples/text_generation/main.go b/examples/text_generation/main.go index 38fd705..79823d1 100644 --- a/examples/text_generation/main.go +++ b/examples/text_generation/main.go @@ -3,10 +3,9 @@ package main import ( "fmt" "os" - "strings" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" @@ -19,13 +18,9 @@ func init() { } func main() { - inputs := []string{ - "The answer to life, the universe, and everything is", - "Somebody once told me that the world is gonna roll me", - } - const numReturnSequences = 3 + input := "The answer to life, the universe, and everything is" - fmt.Printf("Inputs: [\"%s\"]\n", strings.Join(inputs, `", "`)) + fmt.Printf("Input: \"%s\"\n", input) type ChanRv struct { resps []*hfapigo.TextGenerationResponse @@ -36,9 +31,8 @@ func main() { fmt.Print("Sending request") go func() { resps, err := hfapigo.SendTextGenerationRequest(hfapigo.RecommendedTextGenerationModel, &hfapigo.TextGenerationRequest{ - Inputs: inputs, - Parameters: *hfapigo.NewTextGenerationParameters().SetNumReturnSequences(numReturnSequences), - Options: *hfapigo.NewOptions().SetWaitForModel(true), + Input: input, + Options: *hfapigo.NewOptions().SetWaitForModel(true), }) ch <- ChanRv{resps, err} }() @@ -51,14 +45,7 @@ func main() { fmt.Println(chrv.err) return } - for i := range inputs { - fmt.Printf("\nInput %d results:\n", i) - for _, gt := range chrv.resps[i].GeneratedTexts { - gt = strings.Replace(gt, "\n", " ", -1) - gt = strings.Replace(gt, "\r", " ", -1) - fmt.Println(gt) - } - } + fmt.Printf("Response: %s\n", chrv.resps[0].GeneratedText) return default: fmt.Print(".") diff --git a/examples/text_to_image/main.go b/examples/text_to_image/main.go index 8a09a05..03075ba 100644 --- a/examples/text_to_image/main.go +++ b/examples/text_to_image/main.go @@ -10,7 +10,7 @@ import ( "os" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/examples/token_classification/main.go b/examples/token_classification/main.go index 1f909f3..e369f2e 100644 --- a/examples/token_classification/main.go +++ b/examples/token_classification/main.go @@ -6,7 +6,7 @@ import ( "strings" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/examples/translation/main.go b/examples/translation/main.go index 5721df9..595dfff 100644 --- a/examples/translation/main.go +++ b/examples/translation/main.go @@ -6,7 +6,7 @@ import ( "strings" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/examples/zeroshot/main.go b/examples/zeroshot/main.go index 00597df..0252a01 100644 --- a/examples/zeroshot/main.go +++ b/examples/zeroshot/main.go @@ -6,7 +6,7 @@ import ( "strings" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/fill_mask_test.go b/fill_mask_test.go index 6441173..fe015e1 100644 --- a/fill_mask_test.go +++ b/fill_mask_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "testing" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" "github.com/google/go-cmp/cmp" ) diff --git a/go.mod b/go.mod index 5e0fc30..850aad4 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/Kardbord/hfapigo/v2 +module github.com/Kardbord/hfapigo/v3 go 1.17 diff --git a/image_classification_test.go b/image_classification_test.go index 25f3253..dd9a91b 100644 --- a/image_classification_test.go +++ b/image_classification_test.go @@ -3,7 +3,7 @@ package hfapigo_test import ( "testing" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) func TestImageClassificationRequest(t *testing.T) { diff --git a/image_segmentation_test.go b/image_segmentation_test.go index fc80f83..c105882 100644 --- a/image_segmentation_test.go +++ b/image_segmentation_test.go @@ -3,7 +3,7 @@ package hfapigo_test import ( "testing" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) func TestImageSegmentationRequest(t *testing.T) { diff --git a/image_to_text_test.go b/image_to_text_test.go index ad96e55..8c76e98 100644 --- a/image_to_text_test.go +++ b/image_to_text_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) func TestImageToText(t *testing.T) { diff --git a/object_detection_test.go b/object_detection_test.go index 405b7c3..f484858 100644 --- a/object_detection_test.go +++ b/object_detection_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) func TestObjectDetectionRequest(t *testing.T) { diff --git a/question_answering_test.go b/question_answering_test.go index 8b11657..b9e45f5 100644 --- a/question_answering_test.go +++ b/question_answering_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "testing" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" "github.com/google/go-cmp/cmp" ) diff --git a/sentence_similarity_test.go b/sentence_similarity_test.go index 1846c6d..67e0b13 100644 --- a/sentence_similarity_test.go +++ b/sentence_similarity_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "testing" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" "github.com/google/go-cmp/cmp" ) diff --git a/setup_test.go b/setup_test.go index 21d9055..28b7d96 100644 --- a/setup_test.go +++ b/setup_test.go @@ -5,7 +5,7 @@ import ( "os" "testing" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) const HuggingFaceTokenEnv = "HUGGING_FACE_TOKEN" diff --git a/speech_recognition_test.go b/speech_recognition_test.go index 375bfed..712726f 100644 --- a/speech_recognition_test.go +++ b/speech_recognition_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) func TestSpeechRecognitionRequest(t *testing.T) { diff --git a/summarization_test.go b/summarization_test.go index c9e7e34..c58928d 100644 --- a/summarization_test.go +++ b/summarization_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "testing" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" "github.com/google/go-cmp/cmp" ) diff --git a/table_question_answering_test.go b/table_question_answering_test.go index 9a2fdf5..9d4d399 100644 --- a/table_question_answering_test.go +++ b/table_question_answering_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "testing" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" "github.com/google/go-cmp/cmp" ) diff --git a/text_classification_test.go b/text_classification_test.go index 23cbc2b..af19e6c 100644 --- a/text_classification_test.go +++ b/text_classification_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "testing" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" "github.com/google/go-cmp/cmp" ) diff --git a/text_generation.go b/text_generation.go index 0555b06..331de8a 100644 --- a/text_generation.go +++ b/text_generation.go @@ -6,103 +6,159 @@ import ( "fmt" ) -const RecommendedTextGenerationModel = "gpt2-large" +const ( + RecommendedTextGenerationModel = "microsoft/phi-2" + TextGenerationGrammarTypeJSON = "json" + TextGenerationGrammarTypeRegex = "regex" +) type TextGenerationRequest struct { // (Required) a string to be generated from - Inputs []string `json:"inputs,omitempty"` + Input string `json:"inputs,omitempty"` Parameters TextGenerationParameters `json:"parameters,omitempty"` Options Options `json:"options,omitempty"` } type TextGenerationParameters struct { - // (Default: None). Integer to define the top tokens considered within the sample operation to create new text. - TopK *int `json:"top_k,omitempty"` - - // (Default: None). Float to define the tokens that are within the sample` operation of text generation. Add - // tokens in the sample for more probable to least probable until the sum of the probabilities is greater - // than top_p. - TopP *float64 `json:"top_p,omitempty"` - - // (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling, - // 0 means top_k=1, 100.0 is getting closer to uniform probability. - Temperature *float64 `json:"temperature,omitempty"` - - // (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized - // to not be picked in successive generation passes. - RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"` - - // (Default: None). Int (0-250). The amount of new tokens to be generated, this does not include the input - // length it is a estimate of the size of generated text you want. Each new tokens slows down the request, - // so look for balance between response times and length of text generated. - MaxNewTokens *int `json:"max_new_tokens,omitempty"` - - // (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum. - // Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens - // for best results. - MaxTime *float64 `json:"max_time,omitempty"` - - // (Default: True). Bool. If set to False, the return results will not contain the original query making it - // easier for prompting. - ReturnFullText *bool `json:"return_full_text,omitempty"` - - // (Default: 1). Integer. The number of proposition you want to be returned. - NumReturnSequences *int `json:"num_return_sequences,omitempty"` + BestOf *int `json:"best_of,omitempty"` + DecoderInputDetails *bool `json:"decoder_input_details,omitempty"` + Details *bool `json:"details,omitempty"` + DoSample *bool `json:"do_sample,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + Grammar *string `json:"grammar,omitempty"` + MaxNewTokens *int `json:"max_new_tokens,omitempty"` + RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"` + ReturnFullText *bool `json:"return_full_text,omitempty"` + Seed *int64 `json:"seed,omitempty"` + Stop []string `json:"stop,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopK *int `json:"top_k,omitempty"` + TopNTokens *int `json:"top_n_tokens,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Truncate *int `json:"truncate,omitempty"` + TypicalP *float64 `json:"typical_p,omitempty"` + Watermark *bool `json:"watermark,omitempty"` } func NewTextGenerationParameters() *TextGenerationParameters { return &TextGenerationParameters{} } -func (params *TextGenerationParameters) SetTopK(topK int) *TextGenerationParameters { - params.TopK = &topK +func (params *TextGenerationParameters) SetBestOf(bestOf int) *TextGenerationParameters { + params.BestOf = &bestOf return params } -func (params *TextGenerationParameters) SetTopP(topP float64) *TextGenerationParameters { - params.TopP = &topP +func (params *TextGenerationParameters) SetDecoderInputDetails(decoderInputDetails bool) *TextGenerationParameters { + params.DecoderInputDetails = &decoderInputDetails return params } -func (params *TextGenerationParameters) SetTempurature(temp float64) *TextGenerationParameters { - params.Temperature = &temp +func (params *TextGenerationParameters) SetDetails(details bool) *TextGenerationParameters { + params.Details = &details return params } -func (params *TextGenerationParameters) SetRepetitionPenaly(penalty float64) *TextGenerationParameters { - params.RepetitionPenalty = &penalty +func (params *TextGenerationParameters) SetDoSample(doSample bool) *TextGenerationParameters { + params.DoSample = &doSample + return params +} +func (params *TextGenerationParameters) SetFrequencyPenalty(frequencyPenalty float64) *TextGenerationParameters { + params.FrequencyPenalty = &frequencyPenalty + return params +} +func (params *TextGenerationParameters) SetGrammar(grammar string) *TextGenerationParameters { + params.Grammar = &grammar return params } func (params *TextGenerationParameters) SetMaxNewTokens(maxNewTokens int) *TextGenerationParameters { params.MaxNewTokens = &maxNewTokens return params } -func (params *TextGenerationParameters) SetMaxTime(maxTime float64) *TextGenerationParameters { - params.MaxTime = &maxTime +func (params *TextGenerationParameters) SetRepetitionPenalty(repetitionPenalty float64) *TextGenerationParameters { + params.RepetitionPenalty = &repetitionPenalty return params } func (params *TextGenerationParameters) SetReturnFullText(returnFullText bool) *TextGenerationParameters { params.ReturnFullText = &returnFullText return params } -func (params *TextGenerationParameters) SetNumReturnSequences(numReturnSequences int) *TextGenerationParameters { - params.NumReturnSequences = &numReturnSequences +func (params *TextGenerationParameters) SetSeed(seed int64) *TextGenerationParameters { + params.Seed = &seed + return params +} +func (params *TextGenerationParameters) SetStop(stop []string) *TextGenerationParameters { + params.Stop = stop + return params +} +func (params *TextGenerationParameters) SetTemperature(temperature float64) *TextGenerationParameters { + params.Temperature = &temperature + return params +} +func (params *TextGenerationParameters) SetTopK(topK int) *TextGenerationParameters { + params.TopK = &topK + return params +} +func (params *TextGenerationParameters) SetTopNTokens(topNTokens int) *TextGenerationParameters { + params.TopNTokens = &topNTokens + return params +} +func (params *TextGenerationParameters) SetTopP(topP float64) *TextGenerationParameters { + params.TopP = &topP + return params +} +func (params *TextGenerationParameters) SetTruncate(truncate int) *TextGenerationParameters { + params.Truncate = &truncate + return params +} +func (params *TextGenerationParameters) SetTypicalP(typicalP float64) *TextGenerationParameters { + params.TypicalP = &typicalP + return params +} +func (params *TextGenerationParameters) SetWatermark(watermark bool) *TextGenerationParameters { + params.Watermark = &watermark + return params +} +func (params *TextGenerationParameters) SetRepetitionPenaly(penalty float64) *TextGenerationParameters { + params.RepetitionPenalty = &penalty return params } type TextGenerationResponse struct { - // A list of generated texts. The length of this list is the value of - // NumReturnSequences in the request. - GeneratedTexts []string + GeneratedText string `json:"generated_text,omitempty"` + Details TextGenerationResponseDetails `json:"details,omitempty"` } -type textGenerationResponseSequence struct { - GeneratedText string `json:"generated_text,omitempty"` +type TextGenerationResponseDetails struct { + BestOfSequences []*TextGenerationBestOfSequence `json:"best_of_sequences,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + GeneratedTokens int `json:"generated_tokens,omitempty"` + Prefill []*TextGenerationPrefillToken `json:"prefill,omitempty"` + Seed int64 `json:"seed,omitempty"` + Tokens []*TextGenerationToken `json:"tokens,omitempty"` + TopTokens []*TextGenerationToken `json:"top_tokens,omitempty"` } -func (tgs textGenerationResponseSequence) String() string { - return tgs.GeneratedText +type TextGenerationBestOfSequence struct { + FinishReason string `json:"finish_reason,omitempty"` + GeneratedText string `json:"generated_text,omitempty"` + GeneratedTokens int `json:"generated_tokens,omitempty"` + Prefill []*TextGenerationPrefillToken `json:"prefill,omitempty"` + Seed int64 `json:"seed,omitempty"` + Tokens []*TextGenerationToken `json:"tokens,omitempty"` + TopTokens [][]*TextGenerationToken `json:"top_tokens,omitempty"` +} + +type TextGenerationPrefillToken struct { + ID int `json:"id,omitempty"` + LogProb float64 `json:"logprob,omitempty"` + Text string `json:"text,omitempty"` +} + +type TextGenerationToken struct { + TextGenerationPrefillToken + Special bool `json:"special,omitempty"` } func SendTextGenerationRequest(model string, request *TextGenerationRequest) ([]*TextGenerationResponse, error) { if request == nil { - return nil, errors.New("nil SummarizationRequest") + return nil, errors.New("nil TextGenerationRequest") } jsonBuf, err := json.Marshal(request) @@ -115,21 +171,13 @@ func SendTextGenerationRequest(model string, request *TextGenerationRequest) ([] return nil, err } - tgrespsRaw := make([][]*textGenerationResponseSequence, len(request.Inputs)) - err = json.Unmarshal(respBody, &tgrespsRaw) + tgresps := make([]*TextGenerationResponse, 1) + err = json.Unmarshal(respBody, &tgresps) if err != nil { return nil, err } - if len(tgrespsRaw) != len(request.Inputs) { - return nil, fmt.Errorf("expected %d responses, got %d; response=%s", len(request.Inputs), len(tgrespsRaw), string(respBody)) - } - - tgresps := make([]*TextGenerationResponse, len(request.Inputs)) - for i := range tgrespsRaw { - tgresps[i] = &TextGenerationResponse{} - for _, t := range tgrespsRaw[i] { - tgresps[i].GeneratedTexts = append(tgresps[i].GeneratedTexts, t.GeneratedText) - } + if len(tgresps) < 1 { + return nil, fmt.Errorf("expected at least 1 response, got none; response=%s", string(respBody)) } return tgresps, nil diff --git a/text_generation_test.go b/text_generation_test.go index aedfece..8030968 100644 --- a/text_generation_test.go +++ b/text_generation_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "testing" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" "github.com/google/go-cmp/cmp" ) @@ -12,7 +12,7 @@ func TestMarshalUnMarshalTextGenerationRequest(t *testing.T) { // No options { tgExpected := hfapigo.TextGenerationRequest{ - Inputs: []string{"The answer to the universe is"}, + Input: "The answer to the universe is", } jsonBuf, err := json.Marshal(tgExpected) @@ -34,9 +34,8 @@ func TestMarshalUnMarshalTextGenerationRequest(t *testing.T) { // Options { tgExpected := hfapigo.TextGenerationRequest{ - Inputs: []string{"The answer to the universe is"}, + Input: "The answer to the universe is", Parameters: *hfapigo.NewTextGenerationParameters(). - SetMaxTime(12.2). SetMaxNewTokens(240). SetReturnFullText(false), Options: *hfapigo.NewOptions().SetWaitForModel(true), @@ -62,57 +61,41 @@ func TestMarshalUnMarshalTextGenerationRequest(t *testing.T) { func TestTextGenerationRequest(t *testing.T) { // Basic request { - inputs := []string{"The answer to the universe is"} - const returnSeqs = 1 + input := "The answer to the universe is" tgresps, err := hfapigo.SendTextGenerationRequest(hfapigo.RecommendedTextGenerationModel, &hfapigo.TextGenerationRequest{ - Inputs: inputs, + Input: input, Options: *hfapigo.NewOptions().SetWaitForModel(true), }) if err != nil { t.Fatal(err) } - if len(tgresps) != len(inputs) { - t.Fatalf("expected %d response", len(inputs)) + if len(tgresps) != 1 { + t.Fatalf("expected 1 response, got %d", len(tgresps)) } - for i := range inputs { - if len(tgresps[i].GeneratedTexts) != returnSeqs { - t.Fatalf("expected non-empty list of generated texts") - } - for j := 0; j < returnSeqs; j++ { - if tgresps[i].GeneratedTexts[j] == "" { - t.Fatal("expected non-empty generated text") - } - } + if tgresps[0].GeneratedText == "" { + t.Fatal("expected non-empty generated text") } } // More complicated request { - inputs := []string{ - "The answer to the universe is", - "There once was a ship that put to sea", - } - const returnSeqs = 3 + input := "There once was a ship that put to sea" tgresps, err := hfapigo.SendTextGenerationRequest(hfapigo.RecommendedTextGenerationModel, &hfapigo.TextGenerationRequest{ - Inputs: inputs, - Parameters: *hfapigo.NewTextGenerationParameters().SetRepetitionPenaly(50.235).SetReturnFullText(false).SetNumReturnSequences(returnSeqs), + Input: input, + Parameters: *hfapigo.NewTextGenerationParameters().SetRepetitionPenaly(50.235).SetReturnFullText(false).SetDetails(true), Options: *hfapigo.NewOptions().SetWaitForModel(true), }) if err != nil { t.Fatal(err) } - if len(tgresps) != len(inputs) { - t.Fatalf("expected %d responses", len(inputs)) - } - for i := range inputs { - if len(tgresps[i].GeneratedTexts) != returnSeqs { - t.Fatalf("expected non-empty list of generated texts") - } - for j := 0; j < returnSeqs; j++ { - if tgresps[i].GeneratedTexts[j] == "" { - t.Fatal("expected non-empty generated text") - } - } + if len(tgresps) != 1 { + t.Fatalf("expected 1 response, got %d", len(tgresps)) + } + if tgresps[0].GeneratedText == "" { + t.Fatal("expected non-empty generated text") + } + if tgresps[0].Details.FinishReason == "" { + t.Fatal("expected non-empty finish reason") } } diff --git a/text_to_image_test.go b/text_to_image_test.go index 9f737a7..0167fd7 100644 --- a/text_to_image_test.go +++ b/text_to_image_test.go @@ -3,7 +3,7 @@ package hfapigo_test import ( "testing" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" ) func TestTextToImage(t *testing.T) { diff --git a/token_classification_test.go b/token_classification_test.go index 6515d67..a2e363d 100644 --- a/token_classification_test.go +++ b/token_classification_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "testing" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" "github.com/google/go-cmp/cmp" ) diff --git a/translation_test.go b/translation_test.go index 1b30f95..637f989 100644 --- a/translation_test.go +++ b/translation_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "testing" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" "github.com/google/go-cmp/cmp" ) diff --git a/zeroshot_classification_test.go b/zeroshot_classification_test.go index a6e8d36..4399737 100644 --- a/zeroshot_classification_test.go +++ b/zeroshot_classification_test.go @@ -5,7 +5,7 @@ import ( "fmt" "testing" - "github.com/Kardbord/hfapigo/v2" + "github.com/Kardbord/hfapigo/v3" "github.com/google/go-cmp/cmp" )