Commit f6ed21f

iamlemec and abetlen authored
feat: Allow for possibly non-pooled embeddings (abetlen#1380)
* allow for possibly non-pooled embeddings
* add more to embeddings section in README.md

---------

Co-authored-by: Andrei <[email protected]>
1 parent fcfea66 commit f6ed21f

File tree

5 files changed: +67 -20 lines changed


README.md (+7 -1)

@@ -575,7 +575,7 @@ llama = Llama(
 
 ### Embeddings
 
-To generate text embeddings use [`create_embedding`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_embedding).
+To generate text embeddings use [`create_embedding`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_embedding) or [`embed`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.embed). Note that you must pass `embedding=True` to the constructor upon model creation for these to work properly.
 
 ```python
 import llama_cpp
@@ -589,6 +589,12 @@ embeddings = llm.create_embedding("Hello, world!")
 embeddings = llm.create_embedding(["Hello, world!", "Goodbye, world!"])
 ```
 
+There are two primary notions of embeddings in a Transformer-style model: *token level* and *sequence level*. Sequence level embeddings are produced by "pooling" token level embeddings together, usually by averaging them or using the first token.
+
+Models that are explicitly geared towards embeddings will usually return sequence level embeddings by default, one for each input string. Non-embedding models, such as those designed for text generation, will typically return only token level embeddings, one for each token in each sequence. Thus the dimensionality of the return type will be one higher for token level embeddings.
+
+It is possible to control pooling behavior in some cases using the `pooling_type` flag on model creation. You can ensure token level embeddings from any model using `LLAMA_POOLING_TYPE_NONE`. The reverse, getting a generation-oriented model to yield sequence level embeddings, is currently not possible, but you can always do the pooling manually.
+
 ### Adjusting the Context Window
 
 The context window of the Llama models determines the maximum number of tokens that can be processed at once. By default, this is set to 512 tokens, but can be adjusted based on your requirements.
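As a rough illustration of the behavior described in the new README paragraphs, the sketch below disables pooling to obtain token level embeddings and then pools them manually. This is a minimal sketch, not part of the commit: the model path is hypothetical, and it assumes an embedding-capable GGUF file and a `Llama` constructor that accepts the `pooling_type` flag mentioned above.

```python
import numpy as np
import llama_cpp

# Hypothetical model path; embedding=True is required for embed()/create_embedding().
llm = llama_cpp.Llama(
    model_path="./models/example.gguf",
    embedding=True,
    pooling_type=llama_cpp.LLAMA_POOLING_TYPE_NONE,  # force token level embeddings
)

# With pooling disabled, embed() yields one vector per token of the input string.
token_vectors = llm.embed("Hello, world!")  # roughly: n_tokens x n_embd

# Manual sequence level pooling: average over the token axis.
sequence_vector = np.asarray(token_vectors).mean(axis=0)  # length n_embd
```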

llama_cpp/_internals.py (+14)

@@ -273,6 +273,10 @@ def n_ctx(self) -> int:
         assert self.ctx is not None
         return llama_cpp.llama_n_ctx(self.ctx)
 
+    def pooling_type(self) -> int:
+        assert self.ctx is not None
+        return llama_cpp.llama_pooling_type(self.ctx)
+
     def kv_cache_clear(self):
         assert self.ctx is not None
         llama_cpp.llama_kv_cache_clear(self.ctx)
@@ -641,6 +645,16 @@ def _should_add_bos(model: _LlamaModel) -> bool:
     return llama_cpp.llama_vocab_type(model.model) == llama_cpp.LLAMA_VOCAB_TYPE_SPM
 
 
+# Embedding functions
+
+
+def _normalize_embedding(embedding):
+    norm = float(np.linalg.norm(embedding))
+    if norm == 0.0:
+        return embedding
+    return [v / norm for v in embedding]
+
+
 # Python wrappers over common/sampling structs
 
 
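The new `_normalize_embedding` helper is a plain L2 normalization with a guard for the all-zero vector. A small standalone sketch of its behavior, reproducing the helper so it runs without the package:

```python
import numpy as np

def _normalize_embedding(embedding):
    # L2-normalize one embedding vector; a zero vector is returned unchanged
    # so we never divide by zero.
    norm = float(np.linalg.norm(embedding))
    if norm == 0.0:
        return embedding
    return [v / norm for v in embedding]

print(_normalize_embedding([3.0, 4.0]))  # [0.6, 0.8] -- unit length
print(_normalize_embedding([0.0, 0.0]))  # [0.0, 0.0] -- passed through as-is
```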

llama_cpp/llama.py (+39 -18)

@@ -50,6 +50,7 @@
     _LlamaTokenDataArray, # type: ignore
     _LlamaSamplingParams, # type: ignore
     _LlamaSamplingContext, # type: ignore
+    _normalize_embedding, # type: ignore
 )
 from ._logger import set_verbose
 from ._utils import suppress_stdout_stderr
@@ -760,7 +761,7 @@ def create_embedding(
         input = input if isinstance(input, list) else [input]
 
         # get numeric embeddings
-        embeds: List[List[float]]
+        embeds: Union[List[List[float]], List[List[List[float]]]]
         total_tokens: int
         embeds, total_tokens = self.embed(input, return_count=True) # type: ignore
 
@@ -787,7 +788,7 @@ def create_embedding(
     def embed(
         self,
         input: Union[str, List[str]],
-        normalize: bool = True,
+        normalize: bool = False,
         truncate: bool = True,
         return_count: bool = False,
     ):
@@ -803,6 +804,10 @@ def embed(
         n_embd = self.n_embd()
         n_batch = self.n_batch
 
+        # get pooling information
+        pooling_type = self.pooling_type()
+        logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE
+
         if self.context_params.embeddings == False:
             raise RuntimeError(
                 "Llama model must be created with embedding=True to call this method"
@@ -820,29 +825,37 @@ def embed(
         self._batch.reset()
 
         # decode and fetch embeddings
-        data: List[List[float]] = []
+        data: Union[List[List[float]], List[List[List[float]]]] = []
 
-        def decode_batch(n_seq: int):
+        def decode_batch(seq_sizes: List[int]):
             assert self._ctx.ctx is not None
             llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
             self._ctx.decode(self._batch)
             self._batch.reset()
 
             # store embeddings
-            for i in range(n_seq):
-                ptr = llama_cpp.llama_get_embeddings_seq(
-                    self._ctx.ctx, i
-                )
-                if not ptr:
-                    raise RuntimeError("Failed to get embeddings from sequence pooling type is not set")
-                embedding: List[float] = ptr[:n_embd]
-                if normalize:
-                    norm = float(np.linalg.norm(embedding))
-                    embedding = [v / norm for v in embedding]
-                data.append(embedding)
+            if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE:
+                pos: int = 0
+                for i, size in enumerate(seq_sizes):
+                    ptr = llama_cpp.llama_get_embeddings(self._ctx.ctx)
+                    embedding: List[List[float]] = [
+                        ptr[pos + j * n_embd : pos + (j + 1) * n_embd] for j in range(size)
+                    ]
+                    if normalize:
+                        embedding = [_normalize_embedding(e) for e in embedding]
+                    data.append(embedding)
+                    pos += size
+            else:
+                for i in range(len(seq_sizes)):
+                    ptr = llama_cpp.llama_get_embeddings_seq(self._ctx.ctx, i)
+                    embedding: List[float] = ptr[:n_embd]
+                    if normalize:
+                        embedding = _normalize_embedding(embedding)
+                    data.append(embedding)
 
         # init state
         total_tokens = 0
+        s_batch = []
         t_batch = 0
         p_batch = 0
 
@@ -863,17 +876,21 @@ def decode_batch(n_seq: int):
 
             # time to eval batch
             if t_batch + n_tokens > n_batch:
-                decode_batch(p_batch)
+                decode_batch(s_batch)
+                s_batch = []
                 t_batch = 0
                 p_batch = 0
 
             # add to batch
-            self._batch.add_sequence(tokens, p_batch, False)
+            self._batch.add_sequence(tokens, p_batch, logits_all)
+
+            # update batch stats
+            s_batch.append(n_tokens)
             t_batch += n_tokens
             p_batch += 1
 
         # hanlde last batch
-        decode_batch(p_batch)
+        decode_batch(s_batch)
 
         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
@@ -1845,6 +1862,10 @@ def token_nl(self) -> int:
         """Return the newline token."""
         return self._model.token_nl()
 
+    def pooling_type(self) -> str:
+        """Return the pooling type."""
+        return self._ctx.pooling_type()
+
     @staticmethod
     def logits_to_logprobs(
         logits: Union[npt.NDArray[np.single], List], axis: int = -1
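At the Python API level, the visible changes in this file are that `embed()` no longer normalizes by default and that the shape of its result now depends on the context's pooling type. A minimal usage sketch, with a hypothetical model path, not taken from the commit itself:

```python
import llama_cpp

# Hypothetical embedding model; embedding=True is still required.
llm = llama_cpp.Llama(model_path="./models/example.gguf", embedding=True)

# normalize now defaults to False; pass normalize=True for unit-length vectors.
vec = llm.embed("Hello, world!", normalize=True)

# return_count=True also reports the total number of tokens processed,
# which create_embedding() uses to fill in its usage statistics.
vecs, total_tokens = llm.embed(["Hello, world!", "Goodbye, world!"], return_count=True)
```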

llama_cpp/llama_cpp.py (+6)

@@ -1189,6 +1189,12 @@ def llama_rope_type(model: llama_model_p, /) -> int:
     ...
 
 
+# LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_model * model);
+@ctypes_function("llama_pooling_type", [llama_model_p_ctypes], ctypes.c_int)
+def llama_pooling_type(model: llama_model_p, /) -> int:
+    ...
+
+
 # LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
 @ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
 def llama_n_vocab(model: llama_model_p, /) -> int:
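The binding exposes the llama.cpp pooling enum as a plain int; in this commit it is consumed through `_LlamaContext.pooling_type()` and the new `Llama.pooling_type()` wrapper rather than called directly. A hedged sketch of checking it through the high-level wrapper (hypothetical model path):

```python
import llama_cpp

llm = llama_cpp.Llama(model_path="./models/example.gguf", embedding=True)

# Llama.pooling_type() forwards through _LlamaContext.pooling_type() to the
# llama_pooling_type binding and returns the raw llama.cpp enum value.
if llm.pooling_type() == llama_cpp.LLAMA_POOLING_TYPE_NONE:
    print("pooling disabled: embed() yields token level embeddings")
else:
    print("pooled: embed() yields one sequence level embedding per input")
```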

llama_cpp/llama_types.py (+1 -1)

@@ -24,7 +24,7 @@ class EmbeddingUsage(TypedDict):
 class Embedding(TypedDict):
     index: int
     object: str
-    embedding: List[float]
+    embedding: Union[List[float], List[List[float]]]
 
 
 class CreateEmbeddingResponse(TypedDict):
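With the widened `embedding` field, an OpenAI-style response entry can now carry either shape. A short sketch with purely illustrative values:

```python
from llama_cpp.llama_types import Embedding

# Sequence level result: one flat vector for the input string.
pooled: Embedding = {"index": 0, "object": "embedding", "embedding": [0.1, 0.2, 0.3]}

# Token level result (pooling disabled): one vector per token.
token_level: Embedding = {
    "index": 0,
    "object": "embedding",
    "embedding": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
}
```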
