|  | 
| 4 | 4 | #include "llama.h" | 
| 5 | 5 | 
 | 
| 6 | 6 | #include <ctime> | 
|  | 7 | +#include <cstdio> | 
| 7 | 8 | #include <algorithm> | 
| 8 | 9 | 
 | 
| 9 | 10 | #if defined(_MSC_VER) | 
| @@ -70,6 +71,29 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu | 
| 70 | 71 |     } | 
| 71 | 72 | } | 
| 72 | 73 | 
 | 
|  | 74 | +// plain, pipe-friendly output: one embedding per line | 
|  | 75 | +static void print_raw_embeddings(const float * emb, | 
|  | 76 | +                                 int n_embd_count, | 
|  | 77 | +                                 int n_embd, | 
|  | 78 | +                                 const llama_model * model, | 
|  | 79 | +                                 enum llama_pooling_type pooling_type, | 
|  | 80 | +                                 int embd_normalize) { | 
|  | 81 | +    const uint32_t n_cls_out = llama_model_n_cls_out(model); | 
|  | 82 | +    const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK); | 
|  | 83 | +    const int cols = is_rank ? std::min<int>(n_embd, (int) n_cls_out) : n_embd; | 
|  | 84 | + | 
|  | 85 | +    for (int j = 0; j < n_embd_count; ++j) { | 
|  | 86 | +        for (int i = 0; i < cols; ++i) { | 
|  | 87 | +            if (embd_normalize == 0) { | 
|  | 88 | +                printf("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : "")); | 
|  | 89 | +            } else { | 
|  | 90 | +                printf("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : "")); | 
|  | 91 | +            } | 
|  | 92 | +        } | 
|  | 93 | +        printf("\n"); | 
|  | 94 | +    } | 
|  | 95 | +} | 
|  | 96 | + | 
| 73 | 97 | int main(int argc, char ** argv) { | 
| 74 | 98 |     common_params params; | 
| 75 | 99 | 
 | 
| @@ -259,6 +283,10 @@ int main(int argc, char ** argv) { | 
| 259 | 283 |     float * out = emb + e * n_embd; | 
| 260 | 284 |     batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); | 
| 261 | 285 | 
 | 
|  | 286 | +    if (params.embd_out == "raw") { | 
|  | 287 | +        print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize); | 
|  | 288 | +    } | 
|  | 289 | + | 
| 262 | 290 |     if (params.embd_out.empty()) { | 
| 263 | 291 |         LOG("\n"); | 
| 264 | 292 | 
 | 
|  | 
0 commit comments