diff --git a/.github/workflows/llama-cpp-rs-check.yml b/.github/workflows/llama-cpp-rs-check.yml index 9e6fe4b8..095721bb 100644 --- a/.github/workflows/llama-cpp-rs-check.yml +++ b/.github/workflows/llama-cpp-rs-check.yml @@ -18,7 +18,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 with: submodules: recursive - name: Install Compile Deps @@ -43,13 +43,13 @@ jobs: target: [ linux/arm64, linux/amd64 ] steps: - name: checkout - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 - name: Setup QEMU - uses: docker/setup-qemu-action@49b3bc8e6bdd4a60e6116a5414239cba5943d3cf + uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 with: platforms: arm64,amd64 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db + uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 - name: Build uses: docker/build-push-action@v6 with: @@ -61,7 +61,7 @@ jobs: runs-on: macos-latest steps: - name: checkout - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 with: submodules: recursive - name: Setup Rust @@ -73,7 +73,7 @@ jobs: runs-on: windows-latest steps: - name: checkout - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 with: submodules: recursive - name: Setup Rust @@ -81,4 +81,23 @@ jobs: - name: Build run: cargo build --features sampler - name: Test - run: cargo test --features sampler \ No newline at end of file + run: cargo test --features sampler + windows-vulkan: + name: Check that it builds on windows with vulkan + runs-on: windows-latest + steps: + - name: checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 + with: + submodules: recursive + - name: Setup Rust + uses: dtolnay/rust-toolchain@stable + - name: Install Vulkan SDK + uses: jakoch/install-vulkan-sdk-action@v1.0.5 + with: + vulkan_version: 1.3.296.0 + install_runtime: true + cache: true + stripdown: true + - name: Build + run: cargo build --features "sampler vulkan" --verbose diff --git a/.github/workflows/publish-upon-release.yml b/.github/workflows/publish-upon-release.yml index 1e3cc18b..b470e3f1 100644 --- a/.github/workflows/publish-upon-release.yml +++ b/.github/workflows/publish-upon-release.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 with: submodules: recursive - name: Publish crates for llama-cpp-sys-2 diff --git a/.github/workflows/update-llama-cpp.yml b/.github/workflows/update-llama-cpp.yml index 48e83f7e..230b3ee6 100644 --- a/.github/workflows/update-llama-cpp.yml +++ b/.github/workflows/update-llama-cpp.yml @@ -15,7 +15,7 @@ jobs: steps: - name: Set date run: echo "DATE=$(date -I)" >> $GITHUB_ENV - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 name: Checkout latest with: submodules: recursive diff --git a/.github/workflows/update-toml-version.yaml b/.github/workflows/update-toml-version.yaml index f7446d3e..5055e8ba 100644 --- a/.github/workflows/update-toml-version.yaml +++ 
b/.github/workflows/update-toml-version.yaml @@ -15,7 +15,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 with: submodules: recursive diff --git a/.gitmodules b/.gitmodules index 625b54c7..0dfa7e0d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "llama-cpp-sys-2/llama.cpp"] path = llama-cpp-sys-2/llama.cpp - url = https://github.com/ggerganov/llama.cpp + url = https://github.com/ggml-org/llama.cpp diff --git a/Cargo.lock b/Cargo.lock index 899deed9..683832a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "adler" @@ -68,9 +68,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.86" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" [[package]] name = "base64" @@ -80,9 +80,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "bindgen" -version = "0.69.4" +version = "0.69.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" dependencies = [ "bitflags", "cexpr", @@ -109,9 +109,9 @@ checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" [[package]] name = "cc" -version = "1.1.28" +version = "1.2.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e80e3b6a3ab07840e1cae9b0666a63970dc28e8ed5ffbcdacbfc760c281bfc1" +checksum = "956a5e21988b87f372569b66183b78babf23ebc2e744b733e4350a752c4dafac" dependencies = [ "jobserver", "libc", @@ -146,9 +146,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.19" +version = "4.5.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7be5744db7978a28d9df86a214130d106a89ce49644cbc4e3f0c22c3fba30615" +checksum = "fd60e63e9be68e5fb56422e397cf9baddded06dae1d2e523401542383bc72a9f" dependencies = [ "clap_builder", "clap_derive", @@ -156,9 +156,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.19" +version = "4.5.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5fbc17d3ef8278f55b282b2a2e75ae6f6c7d4bb70ed3d0382375104bfafdb4b" +checksum = "89cc6392a1f72bbeb820d71f32108f61fdaf18bc526e1d23954168a67759ef51" dependencies = [ "anstream", "anstyle", @@ -168,9 +168,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.18" +version = "4.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7" dependencies = [ "heck", "proc-macro2", @@ -180,15 +180,15 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.1" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" [[package]] name = "cmake" -version = "0.1.51" +version = "0.1.54" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" dependencies = [ "cc", ] @@ -277,7 +277,7 @@ checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b" [[package]] name = "embeddings" -version = "0.1.83" +version = "0.1.109" dependencies = [ "anyhow", "clap", @@ -293,27 +293,27 @@ checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" [[package]] name = "encoding_rs" -version = "0.8.34" +version = "0.8.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" dependencies = [ "cfg-if", ] [[package]] name = "enumflags2" -version = "0.7.10" +version = "0.7.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d232db7f5956f3f14313dc2f87985c58bd2c695ce124c8cdd984e08e15ac133d" +checksum = "ba2f4b465f5318854c6f8dd686ede6c0a9dc67d4b1ac241cf0eb51521a309147" dependencies = [ "enumflags2_derive", ] [[package]] name = "enumflags2_derive" -version = "0.7.10" +version = "0.7.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de0d48a183585823424a4ce1aa132d174a6a81bd540895822eb4c8373a8e49e8" +checksum = "fc4caf64a58d7a6d65ab00639b046ff54399a39f5f2554728895ace4b297cd79" dependencies = [ "proc-macro2", "quote", @@ -336,6 +336,15 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +[[package]] +name = "find_cuda_helper" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9f9e65c593dd01ac77daad909ea4ad17f0d6d1776193fc8ea766356177abdad" +dependencies = [ + "glob", +] + [[package]] name = "flate2" version = "1.0.30" @@ -383,9 +392,9 @@ dependencies = [ [[package]] name = "glob" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "heck" @@ -653,23 +662,26 @@ checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" [[package]] name = "llama-cpp-2" -version = "0.1.83" +version = "0.1.109" dependencies = [ "encoding_rs", "enumflags2", "llama-cpp-sys-2", "thiserror", "tracing", + "tracing-core", ] [[package]] name = "llama-cpp-sys-2" -version = "0.1.83" +version = "0.1.109" dependencies = [ "bindgen", "cc", "cmake", + "find_cuda_helper", "glob", + "walkdir", ] [[package]] @@ -726,6 +738,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "number_prefix" version = "0.4.0" @@ -740,9 +762,9 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "openssl" -version = "0.10.66" +version = "0.10.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" +checksum = 
"fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da" dependencies = [ "bitflags", "cfg-if", @@ -772,9 +794,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.103" +version = "0.9.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" +checksum = "8288979acd84749c744a9014b4382d42b8f7b2592847b5afb2ed29e5d16ede07" dependencies = [ "cc", "libc", @@ -788,6 +810,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "percent-encoding" version = "2.3.1" @@ -916,17 +944,27 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +[[package]] +name = "reranker" +version = "0.1.86" +dependencies = [ + "anyhow", + "clap", + "encoding_rs", + "hf-hub", + "llama-cpp-2", +] + [[package]] name = "ring" -version = "0.17.8" +version = "0.17.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +checksum = "70ac5d832aa16abd7d1def883a8545280c20a60f523a370aa3a9617c2b8550ee" dependencies = [ "cc", "cfg-if", "getrandom", "libc", - "spin", "untrusted", "windows-sys 0.52.0", ] @@ -987,6 +1025,15 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.23" @@ -1050,6 +1097,15 @@ dependencies = [ "serde", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shlex" version = "1.3.0" @@ -1058,13 +1114,14 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "simple" -version = "0.1.83" +version = "0.1.109" dependencies = [ "anyhow", "clap", "encoding_rs", "hf-hub", "llama-cpp-2", + "tracing-subscriber", ] [[package]] @@ -1073,12 +1130,6 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" - [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -1099,9 +1150,9 @@ checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" [[package]] name = "syn" -version = "2.0.66" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" +checksum = 
"25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -1133,24 +1184,34 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.63" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.63" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", "syn", ] +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + [[package]] name = "tinystr" version = "0.7.6" @@ -1163,9 +1224,9 @@ dependencies = [ [[package]] name = "tracing" -version = "0.1.40" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -1174,9 +1235,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", @@ -1185,11 +1246,50 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.32" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-serde" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "704b1aeb7be0d0a84fc9828cae51dab5970fee5088f83d1dd7ee6f6246fc6ff1" +dependencies = [ + "serde", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +dependencies = [ + "nu-ansi-term", + "serde", + "serde_json", + "sharded-slab", + "smallvec", + "thread_local", + "tracing-core", + "tracing-log", + "tracing-serde", ] [[package]] @@ -1259,12 +1359,28 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "vcpkg" version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1292,6 +1408,37 @@ dependencies = [ "rustix", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/Cargo.toml b/Cargo.toml index b7abe72e..047fac3b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,26 +1,29 @@ [workspace] resolver = "2" members = [ - "llama-cpp-sys-2", - "llama-cpp-2", - "examples/embeddings", - "examples/simple", + "llama-cpp-sys-2", + "llama-cpp-2", + "examples/embeddings", + "examples/simple", + "examples/reranker", ] [workspace.dependencies] # core library deps thiserror = "1" tracing = "0.1" +tracing-core = "0.1" # examples and benchmarks hf-hub = { version = "0.3.2" } criterion = "0.5.1" pprof = "0.13.0" -bindgen = "0.69.4" -cc = "1.1.28" -anyhow = "1.0.86" -clap = "4.5.19" -encoding_rs = "0.8.34" +bindgen = "0.69.5" +cc = "1.2.26" +anyhow = "1.0.98" +clap = "4.5.39" +encoding_rs = "0.8.35" +tracing-subscriber = { version = "0.3", features = ["json"] } [workspace.lints.rust] missing_docs = { level = "warn" } diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 00000000..1b5ec8b7 --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,176 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS diff --git a/LISENCE-MIT b/LISENCE-MIT new file mode 100644 index 00000000..7eadd881 --- /dev/null +++ b/LISENCE-MIT @@ -0,0 +1,25 @@ +Copyright (c) Dial AI + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md index a4bd84e1..ea57dbc9 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ [readme]: https://github.com/utilityai/llama-cpp-rs/tree/main/llama-cpp-2 -This is the home for [llama-cpp-2][crates.io]. 
It also contains the [llama-cpp-sys] bindings which are updated regularly +This is the home for [llama-cpp-2][crates.io]. It also contains the [llama-cpp-sys] bindings which are updated semi-regularly and in sync with [llama-cpp-2][crates.io]. This project was created with the explicit goal of staying as up to date as possible with llama.cpp, as a result it is diff --git a/examples/embeddings/Cargo.toml b/examples/embeddings/Cargo.toml index 3815fed7..eb993289 100644 --- a/examples/embeddings/Cargo.toml +++ b/examples/embeddings/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "embeddings" -version = "0.1.83" +version = "0.1.109" edition = "2021" [dependencies] diff --git a/examples/reranker/Cargo.toml b/examples/reranker/Cargo.toml new file mode 100644 index 00000000..fa32c2d3 --- /dev/null +++ b/examples/reranker/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "reranker" +version = "0.1.86" +edition = "2021" + +[dependencies] +llama-cpp-2 = { path = "../../llama-cpp-2", version = "0.1.86" } +hf-hub = { workspace = true } +clap = { workspace = true, features = ["derive"] } +anyhow = { workspace = true } +encoding_rs = { workspace = true } + +[features] +cuda = ["llama-cpp-2/cuda"] +metal = ["llama-cpp-2/metal"] +native = ["llama-cpp-2/native"] +vulkan = ["llama-cpp-2/vulkan"] + +[lints] +workspace = true \ No newline at end of file diff --git a/examples/reranker/README.md b/examples/reranker/README.md new file mode 100644 index 00000000..935c37ca --- /dev/null +++ b/examples/reranker/README.md @@ -0,0 +1,75 @@ +# Rust Reranker Implementation + +A Rust implementation of cross-encoder based reranking using llama-cpp-2. Cross-encoder reranking is a more accurate way to determine similarity between queries and documents compared to traditional embedding-based approaches. + +## Overview + +This implementation adds a new pooling type `LLAMA_POOLING_TYPE_RANK` which enables cross-encoder based reranking. Unlike traditional embedding approaches that encode query and document separately, this method: + +- Processes query and document pairs together in a single pass +- Directly evaluates semantic relationships between the pairs +- Outputs raw similarity scores indicating relevance + +## Installation + +```bash +# Follow instructions to clone repo. +# Navigate to the reranker example +cd examples/reranker + +# Build the project +cargo build --release +``` + +## Usage + +### Command Line Interface + +```bash +cargo run --release -- \ + --model-path "models/bge-reranker-v2-m3.gguf" \ + --query "what is panda?" \ + --documents "hi" \ + --documents "it's a bear" \ + --documents "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." \ + --pooling rank +``` +Should output (with bge-reranker-v2-m3-Q5_0): +rerank score 0: -6.551 +rerank score 1: -3.802 +rerank score 2: 4.522 + +### CLI Arguments + +- `--model-path`: Path to the GGUF model file +- `--query`: The search query +- `--documents`: One or more documents to rank against the query +- `--pooling`: Pooling type (options: none, mean, rank) + +### Pooling Types + +- `rank`: Performs cross-encoder reranking + + +Note: The raw scores are not normalized through a sigmoid function. If you need scores between 0-1, you'll need to implement sigmoid normalization in your application code.
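If you do want scores in the 0-1 range, a minimal post-processing sketch in plain Rust is shown below; the `sigmoid` helper and the hard-coded scores are illustrative assumptions, not part of this example's code:

```rust
/// Logistic (sigmoid) function: maps a raw cross-encoder score to (0, 1).
fn sigmoid(score: f32) -> f32 {
    1.0 / (1.0 + (-score).exp())
}

fn main() {
    // Raw rerank scores as printed by the example run above.
    for score in [-6.551_f32, -3.802, 4.522] {
        println!("raw {score:8.3} -> normalized {:.3}", sigmoid(score));
    }
}
```

The logistic function is monotonic, so this rescales the scores without changing the ranking order.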
+ + # Additional notes + + - Query and documents are concatenated using the format `{query}{eos}{sep}{document}` (see the hardcoded prompt construction in `src/main.rs`) + + ## Supported Models + + Some tested models: + + - [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) + - [jinaai/jina-reranker-v1-tiny-en](https://huggingface.co/jinaai/jina-reranker-v1-tiny-en) + + Other models have not been tested, but anything supported by llama.cpp should work. + + ## Implementation Details + + This is a close Rust port of the reranker implementation discussed in [llama.cpp PR #9510](https://github.com/ggerganov/llama.cpp/pull/9510). + + ## Potential issues + + The bos, eos, and sep tokens are currently hardcoded. Ideally they should be read from the model and used to build the prompt for each specific model. \ No newline at end of file diff --git a/examples/reranker/src/main.rs b/examples/reranker/src/main.rs new file mode 100644 index 00000000..22eae98e --- /dev/null +++ b/examples/reranker/src/main.rs @@ -0,0 +1,344 @@ +//! This is a translation of embedding.cpp in llama.cpp using llama-cpp-2. +#![allow( + clippy::cast_possible_wrap, + clippy::cast_possible_truncation, + clippy::cast_precision_loss, + clippy::cast_sign_loss +)] + +use std::io::Write; +use std::path::PathBuf; +use std::time::Duration; + +use anyhow::{bail, Context, Result}; +use clap::Parser; +use hf_hub::api::sync::ApiBuilder; + +use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType}; +use llama_cpp_2::context::LlamaContext; +use llama_cpp_2::ggml_time_us; +use llama_cpp_2::llama_backend::LlamaBackend; +use llama_cpp_2::llama_batch::LlamaBatch; +use llama_cpp_2::model::params::LlamaModelParams; +use llama_cpp_2::model::LlamaModel; +use llama_cpp_2::model::{AddBos, Special}; + +#[derive(clap::Parser, Debug, Clone)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Path to the model file + #[clap(long)] + model_path: PathBuf, + + /// The query to embed + #[clap(long)] + query: String, + + /// The documents to embed and compare against + #[clap(long, num_args = 1..)] + documents: Vec<String>, + + /// Pooling type (none, mean, or rank) + #[clap(long, default_value = "none")] + pooling: String, + + /// Whether to normalise the produced embeddings + #[clap(long, default_value_t = true)] + normalise: bool, + + #[clap(long, default_value_t = false)] + disable_gpu: bool, +} + +fn main() -> Result<()> { + let Args { + model_path, + query, + documents, + pooling, + normalise, + disable_gpu, + } = Args::parse(); + + // init LLM + let backend = LlamaBackend::init()?; + + // offload all layers to the gpu + let model_params = { + #[cfg(any(feature = "cuda", feature = "vulkan"))] + if !disable_gpu { + LlamaModelParams::default().with_n_gpu_layers(1000) + } else { + LlamaModelParams::default() + } + #[cfg(not(any(feature = "cuda", feature = "vulkan")))] + LlamaModelParams::default() + }; + + let model = LlamaModel::load_from_file(&backend, model_path, &model_params) + .with_context(|| "unable to load model")?; + // println!("pooling: {}", pooling); + let pooling_type = match pooling.as_str() { + "mean" => LlamaPoolingType::Mean, + "none" => LlamaPoolingType::None, + "rank" => LlamaPoolingType::Rank, + _ => LlamaPoolingType::Unspecified, + }; + + let ctx_params = LlamaContextParams::default() + .with_n_threads_batch(std::thread::available_parallelism()?.get().try_into()?)
+ .with_embeddings(true) + .with_pooling_type(pooling_type); + println!("ctx_params: {:?}", ctx_params); + let mut ctx = model + .new_context(&backend, ctx_params) + .with_context(|| "unable to create the llama_context")?; + + let n_embd = model.n_embd(); + + let prompt_lines = { + let mut lines = Vec::new(); + for doc in documents { + // Todo! update to get eos and sep from model instead of hardcoding + lines.push(format!("{query}{eos}{sep}{doc}", sep = "", eos = "")); + } + lines + }; + + println!("prompt_lines: {:?}", prompt_lines); + // tokenize the prompt + let tokens_lines_list = prompt_lines + .iter() + .map(|line| model.str_to_token(line, AddBos::Always)) + .collect::<Result<Vec<_>, _>>() + .with_context(|| format!("failed to tokenize {:?}", prompt_lines))?; + + let n_ctx = ctx.n_ctx() as usize; + let n_ctx_train = model.n_ctx_train(); + + eprintln!("n_ctx = {n_ctx}, n_ctx_train = {n_ctx_train}"); + + if tokens_lines_list.iter().any(|tok| n_ctx < tok.len()) { + bail!("One of the provided prompts exceeds the size of the context window"); + } + + // print the prompt token-by-token + eprintln!(); + + for (i, token_line) in tokens_lines_list.iter().enumerate() { + eprintln!("Prompt {i} --> {}", prompt_lines[i]); + eprintln!("Number of tokens: {}", token_line.len()); + for token in token_line { + // Attempt to convert token to string and print it; if it fails, print the token instead + match model.token_to_str(*token, Special::Tokenize) { + Ok(token_str) => eprintln!("{token} --> {token_str}"), + Err(e) => { + eprintln!("Failed to convert token to string, error: {e}"); + eprintln!("Token value: {token}"); + } + } + } + eprintln!(); + } + + std::io::stderr().flush()?; + + // create a llama_batch with the size of the context + // we use this object to submit token data for decoding + let mut batch = LlamaBatch::new(2048, 1); + + // Todo!
update to get n_embd to init vector size for better memory management + // let mut n_embd_count = if pooling == "none" { + // tokens_lines_list.iter().map(|tokens| tokens.len()).sum() + // } else { + // tokens_lines_list.len() + // }; + let mut embeddings_stored = 0; + let mut max_seq_id_batch = 0; + let mut output = Vec::with_capacity(tokens_lines_list.len()); + + let t_main_start = ggml_time_us(); + + for tokens in &tokens_lines_list { + // Flush the batch if the next prompt would exceed our batch size + if (batch.n_tokens() as usize + tokens.len()) > 2048 { + batch_decode( + &mut ctx, + &mut batch, + max_seq_id_batch, + n_embd, + &mut output, + normalise, + pooling.clone(), + )?; + embeddings_stored += if pooling == "none" { + batch.n_tokens() + } else { + max_seq_id_batch + }; + max_seq_id_batch = 0; + batch.clear(); + } + + batch.add_sequence(tokens, max_seq_id_batch, false)?; + max_seq_id_batch += 1; + } + // Handle final batch + batch_decode( + &mut ctx, + &mut batch, + max_seq_id_batch, + n_embd, + &mut output, + normalise, + pooling.clone(), + )?; + + let t_main_end = ggml_time_us(); + + for (j, embeddings) in output.iter().enumerate() { + if pooling == "none" { + eprintln!("embedding {j}: "); + for i in 0..n_embd as usize { + if !normalise { + eprint!("{:6.5} ", embeddings[i]); + } else { + eprint!("{:9.6} ", embeddings[i]); + } + } + eprintln!(); + } else if pooling == "rank" { + eprintln!("rerank score {j}: {:8.3}", embeddings[0]); + } else { + eprintln!("embedding {j}: "); + for i in 0..n_embd as usize { + if !normalise { + eprint!("{:6.5} ", embeddings[i]); + } else { + eprint!("{:9.6} ", embeddings[i]); + } + } + eprintln!(); + } + } + + let duration = Duration::from_micros((t_main_end - t_main_start) as u64); + let total_tokens: usize = tokens_lines_list.iter().map(Vec::len).sum(); + eprintln!( + "Created embeddings for {} tokens in {:.2} s, speed {:.2} t/s\n", + total_tokens, + duration.as_secs_f32(), + total_tokens as f32 / duration.as_secs_f32() + ); + + println!("{}", ctx.timings()); + + Ok(()) +} + +fn batch_decode( + ctx: &mut LlamaContext, + batch: &mut LlamaBatch, + s_batch: i32, + n_embd: i32, + output: &mut Vec<Vec<f32>>, + normalise: bool, + pooling: String, +) -> Result<()> { + eprintln!( + "{}: n_tokens = {}, n_seq = {}", + stringify!(batch_decode), + batch.n_tokens(), + s_batch + ); + + // Clear previous kv_cache values + ctx.clear_kv_cache(); + + ctx.decode(batch).with_context(|| "llama_decode() failed")?; + + for i in 0..s_batch { + let embeddings = ctx + .embeddings_seq_ith(i) + .with_context(|| "Failed to get sequence embeddings")?; + let normalized = if normalise { + if pooling == "rank" { + normalize_embeddings(&embeddings, -1) + } else { + normalize_embeddings(&embeddings, 2) + } + } else { + embeddings.to_vec() + }; + output.push(normalized); + } + + batch.clear(); + + Ok(()) +} + +/// Normalizes embeddings based on different normalization strategies +fn normalize_embeddings(input: &[f32], embd_norm: i32) -> Vec<f32> { + let n = input.len(); + let mut output = vec![0.0; n]; + + let sum = match embd_norm { + -1 => 1.0, // no normalization + 0 => { + // max absolute + let max_abs = input.iter().map(|x| x.abs()).fold(0.0f32, f32::max) / 32760.0; + max_abs as f64 + } + 2 => { + // euclidean norm + input + .iter() + .map(|x| (*x as f64).powi(2)) + .sum::<f64>() + .sqrt() + } + p => { + // p-norm + let sum = input.iter().map(|x| (x.abs() as f64).powi(p)).sum::<f64>(); + sum.powf(1.0 / p as f64) + } + }; + + let norm = if sum > 0.0 { 1.0 / sum } else { 0.0 }; + + for i in 0..n {
output[i] = (input[i] as f64 * norm) as f32; + } + + output +} + +// /// Calculates cosine similarity between two embedding vectors +// fn embedding_similarity_cos(embd1: &[f32], embd2: &[f32]) -> f32 { +// assert_eq!(embd1.len(), embd2.len(), "Embedding vectors must be the same length"); + +// let (sum, sum1, sum2) = embd1.iter().zip(embd2.iter()).fold( +// (0.0f64, 0.0f64, 0.0f64), +// |(sum, sum1, sum2), (e1, e2)| { +// let e1 = *e1 as f64; +// let e2 = *e2 as f64; +// ( +// sum + e1 * e2, +// sum1 + e1 * e1, +// sum2 + e2 * e2 +// ) +// } +// ); + +// // Handle zero vectors +// if sum1 == 0.0 || sum2 == 0.0 { +// return if sum1 == 0.0 && sum2 == 0.0 { +// 1.0 // two zero vectors are similar +// } else { +// 0.0 +// }; +// } + +// (sum / (sum1.sqrt() * sum2.sqrt())) as f32 +// } diff --git a/examples/simple/Cargo.toml b/examples/simple/Cargo.toml index e14648ad..28d0ee6d 100644 --- a/examples/simple/Cargo.toml +++ b/examples/simple/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple" -version = "0.1.83" +version = "0.1.109" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -8,13 +8,14 @@ edition = "2021" [dependencies] llama-cpp-2 = { path = "../../llama-cpp-2", version = "0.1.69" } hf-hub = { workspace = true } -clap = { workspace = true , features = ["derive"] } +clap = { workspace = true, features = ["derive"] } anyhow = { workspace = true } encoding_rs = { workspace = true } +tracing-subscriber = { workspace = true } [features] cuda = ["llama-cpp-2/cuda"] -metal = ["llama-cpp-2/metal"] +metal = ["llama-cpp-2/metal"] native = ["llama-cpp-2/native"] vulkan = ["llama-cpp-2/vulkan"] diff --git a/examples/simple/src/main.rs b/examples/simple/src/main.rs index 267d6864..9d4eef47 100644 --- a/examples/simple/src/main.rs +++ b/examples/simple/src/main.rs @@ -10,14 +10,15 @@ use anyhow::{anyhow, bail, Context, Result}; use clap::Parser; use hf_hub::api::sync::ApiBuilder; use llama_cpp_2::context::params::LlamaContextParams; -use llama_cpp_2::ggml_time_us; use llama_cpp_2::llama_backend::LlamaBackend; use llama_cpp_2::llama_batch::LlamaBatch; use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue; use llama_cpp_2::model::params::LlamaModelParams; use llama_cpp_2::model::LlamaModel; use llama_cpp_2::model::{AddBos, Special}; -use llama_cpp_2::token::data_array::LlamaTokenDataArray; +use llama_cpp_2::sampling::LlamaSampler; +use llama_cpp_2::{ggml_time_us, send_logs_to_tracing, LogOptions}; + use std::ffi::CString; use std::io::Write; use std::num::NonZeroU32; @@ -66,6 +67,8 @@ struct Args { help = "size of the prompt context (default: loaded from the model)" )] ctx_size: Option<NonZeroU32>, + #[arg(short = 'v', long, help = "enable verbose llama.cpp logs")] + verbose: bool, } /// Parse a single key-value pair @@ -131,8 +134,14 @@ fn main() -> Result<()> { threads, threads_batch, ctx_size, + verbose, } = Args::parse(); + if verbose { + tracing_subscriber::fmt().init(); + } + send_logs_to_tracing(LogOptions::default().with_logs_enabled(verbose)); + // init LLM let backend = LlamaBackend::init()?; @@ -174,9 +183,9 @@ fn main() -> Result<()> { .with_context(|| "unable to load model")?; // initialize the context - let mut ctx_params = LlamaContextParams::default() - .with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap()))) - .with_seed(seed.unwrap_or(1234)); + let mut ctx_params = + LlamaContextParams::default().with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap()))); + if let Some(threads) = threads { ctx_params =
ctx_params.with_n_threads(threads); } @@ -244,23 +253,25 @@ either reduce n_len or increase n_ctx" // The `Decoder` let mut decoder = encoding_rs::UTF_8.new_decoder(); + let mut sampler = LlamaSampler::chain_simple([ + LlamaSampler::dist(seed.unwrap_or(1234)), + LlamaSampler::greedy(), + ]); + while n_cur <= n_len { // sample the next token { - let candidates = ctx.candidates(); - - let candidates_p = LlamaTokenDataArray::from_iter(candidates, false); + let token = sampler.sample(&ctx, batch.n_tokens() - 1); - // sample the most likely token - let new_token_id = ctx.sample_token_greedy(candidates_p); + sampler.accept(token); // is it an end of stream? - if model.is_eog_token(new_token_id) { + if model.is_eog_token(token) { eprintln!(); break; } - let output_bytes = model.token_to_bytes(new_token_id, Special::Tokenize)?; + let output_bytes = model.token_to_bytes(token, Special::Tokenize)?; // use `Decoder.decode_to_string()` to avoid the intermediate buffer let mut output_string = String::with_capacity(32); let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false); @@ -268,7 +279,7 @@ either reduce n_len or increase n_ctx" std::io::stdout().flush()?; batch.clear(); - batch.add(new_token_id, n_cur, &[0], true)?; + batch.add(token, n_cur, &[0], true)?; } n_cur += 1; diff --git a/examples/usage.rs b/examples/usage.rs index 1b7d1f5d..323ad6c2 100644 --- a/examples/usage.rs +++ b/examples/usage.rs @@ -14,7 +14,7 @@ use llama_cpp_2::llama_batch::LlamaBatch; use llama_cpp_2::model::params::LlamaModelParams; use llama_cpp_2::model::LlamaModel; use llama_cpp_2::model::{AddBos, Special}; -use llama_cpp_2::token::data_array::LlamaTokenDataArray; +use llama_cpp_2::sampling::LlamaSampler; use std::io::Write; #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)] @@ -53,26 +53,22 @@ fn main() { // The `Decoder` let mut decoder = encoding_rs::UTF_8.new_decoder(); + let mut sampler = LlamaSampler::greedy(); while n_cur <= n_len { // sample the next token { - let candidates = ctx.candidates_ith(batch.n_tokens() - 1); + let token = sampler.sample(&ctx, batch.n_tokens() - 1); - let candidates_p = LlamaTokenDataArray::from_iter(candidates, false); - - // sample the most likely token - let new_token_id = ctx.sample_token_greedy(candidates_p); + sampler.accept(token); // is it an end of stream? 
- if new_token_id == model.token_eos() { + if token == model.token_eos() { eprintln!(); break; } - let output_bytes = model - .token_to_bytes(new_token_id, Special::Tokenize) - .unwrap(); + let output_bytes = model.token_to_bytes(token, Special::Tokenize).unwrap(); // use `Decoder.decode_to_string()` to avoid the intermediate buffer let mut output_string = String::with_capacity(32); let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false); @@ -80,7 +76,7 @@ fn main() { std::io::stdout().flush().unwrap(); batch.clear(); - batch.add(new_token_id, n_cur, &[0], true).unwrap(); + batch.add(token, n_cur, &[0], true).unwrap(); } n_cur += 1; diff --git a/llama-cpp-2/Cargo.toml b/llama-cpp-2/Cargo.toml index 7f18b2df..fb2f9f57 100644 --- a/llama-cpp-2/Cargo.toml +++ b/llama-cpp-2/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "llama-cpp-2" description = "llama.cpp bindings for Rust" -version = "0.1.83" +version = "0.1.109" edition = "2021" license = "MIT OR Apache-2.0" repository = "https://github.com/utilityai/llama-cpp-rs" @@ -9,23 +9,27 @@ repository = "https://github.com/utilityai/llama-cpp-rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -enumflags2 = "0.7.10" +enumflags2 = "0.7.11" llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.69" } thiserror = { workspace = true } tracing = { workspace = true } +tracing-core = { workspace = true } [dev-dependencies] encoding_rs = { workspace = true } [features] -default = ["openmp"] +default = ["openmp", "android-shared-stdcxx"] cuda = ["llama-cpp-sys-2/cuda"] +cuda-no-vmm = ["cuda", "llama-cpp-sys-2/cuda-no-vmm"] metal = ["llama-cpp-sys-2/metal"] dynamic-link = ["llama-cpp-sys-2/dynamic-link"] vulkan = ["llama-cpp-sys-2/vulkan"] native = ["llama-cpp-sys-2/native"] openmp = ["llama-cpp-sys-2/openmp"] sampler = [] +# Only has an impact on Android. +android-shared-stdcxx = ["llama-cpp-sys-2/shared-stdcxx"] [target.'cfg(all(target_os = "macos", any(target_arch = "aarch64", target_arch = "arm64")))'.dependencies] diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs index 80ee8f75..10f2d7eb 100644 --- a/llama-cpp-2/src/context.rs +++ b/llama-cpp-2/src/context.rs @@ -9,6 +9,7 @@ use crate::llama_batch::LlamaBatch; use crate::model::{LlamaLoraAdapter, LlamaModel}; use crate::timing::LlamaTimings; use crate::token::data::LlamaTokenData; +use crate::token::data_array::LlamaTokenDataArray; use crate::token::LlamaToken; use crate::{ DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError, @@ -17,7 +18,6 @@ use crate::{ pub mod kv_cache; pub mod params; -pub mod sample; pub mod session; /// Safe wrapper around `llama_context`. @@ -52,13 +52,13 @@ impl<'model> LlamaContext<'model> { } } - /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to n_ubatch. + /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to [`Self::n_ubatch`]. #[must_use] pub fn n_batch(&self) -> u32 { unsafe { llama_cpp_sys_2::llama_n_batch(self.context.as_ptr()) } } - /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to n_batch. + /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to [`Self::n_batch`]. 
#[must_use] pub fn n_ubatch(&self) -> u32 { unsafe { llama_cpp_sys_2::llama_n_ubatch(self.context.as_ptr()) } @@ -203,6 +203,21 @@ impl<'model> LlamaContext<'model> { }) } + /// Get the token data array for the last token in the context. + /// + /// This is a convience method that implements: + /// ```ignore + /// LlamaTokenDataArray::from_iter(ctx.candidates(), false) + /// ``` + /// + /// # Panics + /// + /// - underlying logits data is null + #[must_use] + pub fn token_data_array(&self) -> LlamaTokenDataArray { + LlamaTokenDataArray::from_iter(self.candidates(), false) + } + /// Token logits obtained from the last call to `decode()`. /// The logits for which `batch.logits[i] != 0` are stored contiguously /// in the order they have appeared in the batch. @@ -218,6 +233,7 @@ impl<'model> LlamaContext<'model> { /// /// - `n_vocab` does not fit into a usize /// - token data returned is null + #[must_use] pub fn get_logits(&self) -> &[f32] { let data = unsafe { llama_cpp_sys_2::llama_get_logits(self.context.as_ptr()) }; assert!(!data.is_null(), "logits data for last token is null"); @@ -238,6 +254,21 @@ impl<'model> LlamaContext<'model> { }) } + /// Get the token data array for the ith token in the context. + /// + /// This is a convience method that implements: + /// ```ignore + /// LlamaTokenDataArray::from_iter(ctx.candidates_ith(i), false) + /// ``` + /// + /// # Panics + /// + /// - logit `i` is not initialized. + #[must_use] + pub fn token_data_array_ith(&self, i: i32) -> LlamaTokenDataArray { + LlamaTokenDataArray::from_iter(self.candidates_ith(i), false) + } + /// Get the logits for the ith token in the context. /// /// # Panics @@ -267,12 +298,12 @@ impl<'model> LlamaContext<'model> { /// Reset the timings for the context. pub fn reset_timings(&mut self) { - unsafe { llama_cpp_sys_2::llama_reset_timings(self.context.as_ptr()) } + unsafe { llama_cpp_sys_2::llama_perf_context_reset(self.context.as_ptr()) } } /// Returns the timings for the context. pub fn timings(&mut self) -> LlamaTimings { - let timings = unsafe { llama_cpp_sys_2::llama_get_timings(self.context.as_ptr()) }; + let timings = unsafe { llama_cpp_sys_2::llama_perf_context(self.context.as_ptr()) }; LlamaTimings { timings } } @@ -287,7 +318,7 @@ impl<'model> LlamaContext<'model> { scale: f32, ) -> Result<(), LlamaLoraAdapterSetError> { let err_code = unsafe { - llama_cpp_sys_2::llama_lora_adapter_set( + llama_cpp_sys_2::llama_set_adapter_lora( self.context.as_ptr(), adapter.lora_adapter.as_ptr(), scale, @@ -311,7 +342,7 @@ impl<'model> LlamaContext<'model> { adapter: &mut LlamaLoraAdapter, ) -> Result<(), LlamaLoraAdapterRemoveError> { let err_code = unsafe { - llama_cpp_sys_2::llama_lora_adapter_remove( + llama_cpp_sys_2::llama_rm_adapter_lora( self.context.as_ptr(), adapter.lora_adapter.as_ptr(), ) diff --git a/llama-cpp-2/src/context/kv_cache.rs b/llama-cpp-2/src/context/kv_cache.rs index d5a8ed65..14f5b5a6 100644 --- a/llama-cpp-2/src/context/kv_cache.rs +++ b/llama-cpp-2/src/context/kv_cache.rs @@ -6,6 +6,7 @@ use std::num::{NonZeroU8, TryFromIntError}; /// Errors that can occur when attempting to prepare values for the kv cache #[derive(Debug, Eq, PartialEq, thiserror::Error)] +#[allow(clippy::module_name_repetitions)] pub enum KvCacheConversionError { /// Sequence id conversion to i32 failed #[error("Provided sequence id is too large for a i32")] @@ -27,21 +28,22 @@ impl LlamaContext<'_> { /// * `dest` - The sequence id to copy the cache to. /// * `size` - The size of the cache to copy. 
pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) { - unsafe { llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, 0, size) } + unsafe { llama_cpp_sys_2::llama_kv_self_seq_cp(self.context.as_ptr(), src, dest, 0, size) } } /// Copy the cache from one sequence to another. /// /// # Returns - /// A `Result` indicating whether the operation was successful. If the either position exceeds - /// the maximum i32 value, no copy is attempted and an `Err` is returned. + /// A `Result` indicating whether the operation was successful. /// /// # Parameters - /// /// * `src` - The sequence id to copy the cache from. /// * `dest` - The sequence id to copy the cache to. /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`. /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`. + /// + /// # Errors + /// If either position exceeds [`i32::MAX`]. pub fn copy_kv_cache_seq( &mut self, src: i32, @@ -51,12 +53,12 @@ ) -> Result<(), KvCacheConversionError> { let p0 = p0 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P0TooLarge(e))?; + .map_err(KvCacheConversionError::P0TooLarge)?; let p1 = p1 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P1TooLarge(e))?; + .map_err(KvCacheConversionError::P1TooLarge)?; unsafe { - llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1); + llama_cpp_sys_2::llama_kv_self_seq_cp(self.context.as_ptr(), src, dest, p0, p1); } Ok(()) } @@ -69,10 +71,12 @@ /// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned. /// /// # Parameters - /// /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`. /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`. + /// + /// # Errors + /// If the sequence id or either position exceeds [`i32::MAX`]. pub fn clear_kv_cache_seq( &mut self, src: Option<u32>, p0: Option<u32>, p1: Option<u32>, ) -> Result<bool, KvCacheConversionError> { let src = src .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::SeqIdTooLarge(e))?; + .map_err(KvCacheConversionError::SeqIdTooLarge)?; let p0 = p0 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P0TooLarge(e))?; + .map_err(KvCacheConversionError::P0TooLarge)?; let p1 = p1 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P1TooLarge(e))?; - Ok(unsafe { llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1) }) + .map_err(KvCacheConversionError::P1TooLarge)?; + Ok(unsafe { llama_cpp_sys_2::llama_kv_self_seq_rm(self.context.as_ptr(), src, p0, p1) }) } /// Returns the number of used KV cells (i.e.
have at least one sequence assigned to them) #[must_use] pub fn get_kv_cache_used_cells(&self) -> i32 { - unsafe { llama_cpp_sys_2::llama_get_kv_cache_used_cells(self.context.as_ptr()) } + unsafe { llama_cpp_sys_2::llama_kv_self_used_cells(self.context.as_ptr()) } } /// Clear the KV cache pub fn clear_kv_cache(&mut self) { - unsafe { llama_cpp_sys_2::llama_kv_cache_clear(self.context.as_ptr()) } + unsafe { llama_cpp_sys_2::llama_kv_self_clear(self.context.as_ptr()) } } /// Removes all tokens that do not belong to the specified sequence @@ -108,7 +112,7 @@ impl LlamaContext<'_> { /// /// * `seq_id` - The sequence id to keep pub fn llama_kv_cache_seq_keep(&mut self, seq_id: i32) { - unsafe { llama_cpp_sys_2::llama_kv_cache_seq_keep(self.context.as_ptr(), seq_id) } + unsafe { llama_cpp_sys_2::llama_kv_self_seq_keep(self.context.as_ptr(), seq_id) } } #[allow(clippy::doc_markdown)] @@ -118,8 +122,7 @@ impl LlamaContext<'_> { /// - explicitly with [`Self::kv_cache_update`] /// /// # Returns - /// A `Result` indicating whether the operation was successful. If either position - /// exceeds the maximum i32 value, no update is attempted and an `Err` is returned. + /// A `Result` indicating whether the operation was successful. /// /// # Parameters /// @@ -127,6 +130,9 @@ impl LlamaContext<'_> { /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`. /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`. /// * `delta` - The relative position to add to the tokens + /// + /// # Errors + /// If either position exceeds [`i32::MAX`]. pub fn kv_cache_seq_add( &mut self, seq_id: i32, @@ -136,12 +142,12 @@ impl LlamaContext<'_> { ) -> Result<(), KvCacheConversionError> { let p0 = p0 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P0TooLarge(e))?; + .map_err(KvCacheConversionError::P0TooLarge)?; let p1 = p1 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P1TooLarge(e))?; + .map_err(KvCacheConversionError::P1TooLarge)?; unsafe { - llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta); + llama_cpp_sys_2::llama_kv_self_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta); } Ok(()) } @@ -152,8 +158,7 @@ impl LlamaContext<'_> { /// - explicitly with [`Self::kv_cache_update`] /// /// # Returns - /// A `Result` indicating whether the operation was successful. If either position - /// exceeds the maximum i32 value, no update is attempted and an `Err` is returned. + /// A `Result` indicating whether the operation was successful. /// /// # Parameters /// @@ -161,6 +166,9 @@ impl LlamaContext<'_> { /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`. /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`. /// * `d` - The factor to divide the positions by + /// + /// # Errors + /// If either position exceeds [`i32::MAX`]. 
pub fn kv_cache_seq_div( &mut self, seq_id: i32, @@ -170,12 +178,12 @@ impl LlamaContext<'_> { ) -> Result<(), KvCacheConversionError> { let p0 = p0 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P0TooLarge(e))?; + .map_err(KvCacheConversionError::P0TooLarge)?; let p1 = p1 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P1TooLarge(e))?; + .map_err(KvCacheConversionError::P1TooLarge)?; let d = c_int::from(d.get()); - unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) } + unsafe { llama_cpp_sys_2::llama_kv_self_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) } Ok(()) } @@ -186,7 +194,7 @@ impl LlamaContext<'_> { /// * `seq_id` - The sequence id to get the max position for #[must_use] pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 { - unsafe { llama_cpp_sys_2::llama_kv_cache_seq_pos_max(self.context.as_ptr(), seq_id) } + unsafe { llama_cpp_sys_2::llama_kv_self_seq_pos_max(self.context.as_ptr(), seq_id) } } /// Defragment the KV cache @@ -194,130 +202,11 @@ impl LlamaContext<'_> { /// - lazily on next [`LlamaContext::decode`] /// - explicitly with [`Self::kv_cache_update`] pub fn kv_cache_defrag(&mut self) { - unsafe { llama_cpp_sys_2::llama_kv_cache_defrag(self.context.as_ptr()) } + unsafe { llama_cpp_sys_2::llama_kv_self_defrag(self.context.as_ptr()) } } /// Apply the KV cache updates (such as K-shifts, defragmentation, etc.) pub fn kv_cache_update(&mut self) { - unsafe { llama_cpp_sys_2::llama_kv_cache_update(self.context.as_ptr()) } - } - - /// Returns the number of tokens in the KV cache (slow, use only for debug) - /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times - #[must_use] - pub fn get_kv_cache_token_count(&self) -> i32 { - unsafe { llama_cpp_sys_2::llama_get_kv_cache_token_count(self.context.as_ptr()) } - } - - /// Create an empty KV cache view. (use only for debugging purposes) - /// - /// # Parameters - /// - /// * `n_max_seq` - Maximum number of sequences that can exist in a cell. It's not an error - /// if there are more sequences in a cell than this value, however they will - /// not be visible in the view `cells_sequences`. - #[must_use] - pub fn new_kv_cache_view(&self, n_max_seq: i32) -> KVCacheView { - let view = - unsafe { llama_cpp_sys_2::llama_kv_cache_view_init(self.context.as_ptr(), n_max_seq) }; - KVCacheView { view, ctx: self } - } -} - -/// Information associated with an individual cell in the KV cache view. -#[derive(Debug)] -pub struct KVCacheViewCell { - /// The position for this cell. Takes KV cache shifts into account. - /// May be negative if the cell is not populated. - pub pos: llama_cpp_sys_2::llama_pos, -} - -/// An updateable view of the KV cache. (use only for debugging purposes) -#[derive(Debug)] -pub struct KVCacheView<'a> { - ctx: &'a LlamaContext<'a>, - view: llama_cpp_sys_2::llama_kv_cache_view, -} - -impl<'a> KVCacheView<'a> { - /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) - pub fn update(&mut self) { - unsafe { - llama_cpp_sys_2::llama_kv_cache_view_update(self.ctx.context.as_ptr(), &mut self.view); - } - } - - /// Number of KV cache cells. This will be the same as the context size. - #[must_use] - pub fn n_cells(&self) -> i32 { - self.view.n_cells - } - - /// Number of tokens in the cache. For example, if there are two populated - /// cells, the first with 1 sequence id in it and the second with 2 sequence - /// ids then you'll have 3 tokens. 
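The safe wrappers above map one-to-one onto the renamed `llama_kv_self_*` calls. A minimal sketch of combining them to reclaim space in one sequence; the sequence id and positions are illustrative, and since the exact integer types of the `Option` position parameters are not visible in this hunk, plain literals are used so they infer to whatever the wrapper expects:

```rust
use llama_cpp_2::context::LlamaContext;

/// Sketch: drop the cached cells for positions 64..128 of sequence 0, then
/// shift the remaining cells back so the positions stay contiguous.
fn shift_context(ctx: &mut LlamaContext<'_>) -> Result<(), Box<dyn std::error::Error>> {
    // Remove cells in [64, 128) for sequence 0; errors if a position does not fit in an i32.
    ctx.clear_kv_cache_seq(Some(0), Some(64), Some(128))?;
    // Move every remaining cell at position >= 128 back by 64.
    ctx.kv_cache_seq_add(0, Some(128), None, -64)?;
    // The shift is applied lazily on the next decode, or eagerly here.
    ctx.kv_cache_update();
    Ok(())
}
```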
- #[must_use] - pub fn token_count(&self) -> i32 { - self.view.token_count - } - - /// Number of populated cache cells. - #[must_use] - pub fn used_cells(&self) -> i32 { - self.view.used_cells - } - - /// Maximum contiguous empty slots in the cache. - #[must_use] - pub fn max_contiguous(&self) -> i32 { - self.view.max_contiguous - } - - /// Index to the start of the `max_contiguous` slot range. Can be negative - /// when cache is full. - #[must_use] - pub fn max_contiguous_idx(&self) -> i32 { - self.view.max_contiguous_idx - } - - /// Information for individual cells. - /// - /// # Panics - /// - /// - if `n_cells` does not fit into usize. - pub fn cells(&self) -> impl Iterator { - unsafe { - std::slice::from_raw_parts( - self.view.cells, - usize::try_from(self.view.n_cells).expect("failed to fit n_cells into usize"), - ) - } - .iter() - .map(|&cell| KVCacheViewCell { pos: cell.pos }) - } - - /// The sequences for each cell. There will be `n_max_seq` items per cell. - /// - /// # Panics - /// - /// - if `n_cells * n_max_seq` does not fit into usize. - /// - if `n_max_seq` does not fit into usize. - pub fn cells_sequences(&self) -> impl Iterator { - unsafe { - std::slice::from_raw_parts( - self.view.cells_sequences, - usize::try_from(self.view.n_cells * self.view.n_seq_max) - .expect("failed to fit n_cells * n_max_seq into usize"), - ) - } - .chunks(usize::try_from(self.view.n_seq_max).expect("failed to fit n_max_seq into usize")) - } -} - -impl<'a> Drop for KVCacheView<'a> { - fn drop(&mut self) { - unsafe { - llama_cpp_sys_2::llama_kv_cache_view_free(&mut self.view); - } + unsafe { llama_cpp_sys_2::llama_kv_self_update(self.context.as_ptr()) } } } diff --git a/llama-cpp-2/src/context/params.rs b/llama-cpp-2/src/context/params.rs index 14eca8b0..892dc8dc 100644 --- a/llama-cpp-2/src/context/params.rs +++ b/llama-cpp-2/src/context/params.rs @@ -47,7 +47,7 @@ impl From for i32 { pub enum LlamaPoolingType { /// The pooling type is unspecified Unspecified = -1, - /// No pooling + /// No pooling None = 0, /// Mean pooling Mean = 1, @@ -55,6 +55,8 @@ pub enum LlamaPoolingType { Cls = 2, /// Last pooling Last = 3, + /// Rank pooling + Rank = 4, } /// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if @@ -66,6 +68,7 @@ impl From for LlamaPoolingType { 1 => Self::Mean, 2 => Self::Cls, 3 => Self::Last, + 4 => Self::Rank, _ => Self::Unspecified, } } @@ -79,6 +82,7 @@ impl From for i32 { LlamaPoolingType::Mean => 1, LlamaPoolingType::Cls => 2, LlamaPoolingType::Last => 3, + LlamaPoolingType::Rank => 4, LlamaPoolingType::Unspecified => -1, } } @@ -95,10 +99,8 @@ impl From for i32 { /// use llama_cpp_2::context::params::LlamaContextParams; /// ///let ctx_params = LlamaContextParams::default() -/// .with_n_ctx(NonZeroU32::new(2048)) -/// .with_seed(1234); +/// .with_n_ctx(NonZeroU32::new(2048)); /// -/// assert_eq!(ctx_params.seed(), 1234); /// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048)); /// ``` #[derive(Debug, Clone)] @@ -116,37 +118,6 @@ unsafe impl Send for LlamaContextParams {} unsafe impl Sync for LlamaContextParams {} impl LlamaContextParams { - /// Set the seed of the context - /// - /// # Examples - /// - /// ```rust - /// use llama_cpp_2::context::params::LlamaContextParams; - /// let params = LlamaContextParams::default(); - /// let params = params.with_seed(1234); - /// assert_eq!(params.seed(), 1234); - /// ``` - #[must_use] - pub fn with_seed(mut self, seed: u32) -> Self { - self.context_params.seed = seed; - self - } - - /// Get the seed of 
the context - /// - /// # Examples - /// - /// ```rust - /// use llama_cpp_2::context::params::LlamaContextParams; - /// let params = LlamaContextParams::default() - /// .with_seed(1234); - /// assert_eq!(params.seed(), 1234); - /// ``` - #[must_use] - pub fn seed(&self) -> u32 { - self.context_params.seed - } - /// Set the side of the context /// /// # Examples diff --git a/llama-cpp-2/src/context/sample.rs b/llama-cpp-2/src/context/sample.rs deleted file mode 100644 index cc0f85ee..00000000 --- a/llama-cpp-2/src/context/sample.rs +++ /dev/null @@ -1,141 +0,0 @@ -//! Sampling functions for the context. - -use crate::context::LlamaContext; -use crate::grammar::LlamaGrammar; -use crate::token::data_array::LlamaTokenDataArray; -use crate::token::LlamaToken; - -#[cfg(feature = "sampler")] -pub mod sampler; - -impl LlamaContext<'_> { - /// Accept a token into the grammar. - pub fn grammar_accept_token(&mut self, grammar: &mut LlamaGrammar, token: LlamaToken) { - unsafe { - llama_cpp_sys_2::llama_grammar_accept_token( - grammar.grammar.as_ptr(), - self.context.as_ptr(), - token.0, - ); - } - } - - /// Perform grammar sampling. - pub fn sample_grammar( - &mut self, - llama_token_data_array: &mut LlamaTokenDataArray, - llama_grammar: &LlamaGrammar, - ) { - unsafe { - llama_token_data_array.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_grammar( - self.context.as_ptr(), - c_llama_token_data_array, - llama_grammar.grammar.as_ptr(), - ); - }); - } - } - - /// See [`LlamaTokenDataArray::sample_temp`] - pub fn sample_temp(&mut self, token_data: &mut LlamaTokenDataArray, temperature: f32) { - token_data.sample_temp(Some(self), temperature); - } - - /// Sample a token greedily. Note that this *does not* take into account anything that has modified the probabilities - it only looks at logits. - /// - /// Most of the time [`LlamaTokenDataArray::sample_softmax`] or [`LlamaTokenDataArray::sample_token`] should be used instead. 
- /// - /// # Panics - /// - /// - if `token_data` is empty - #[must_use] - pub fn sample_token_greedy(&mut self, mut token_data: LlamaTokenDataArray) -> LlamaToken { - assert!(!token_data.data.is_empty(), "no tokens"); - let mut data_arr = llama_cpp_sys_2::llama_token_data_array { - data: token_data - .data - .as_mut_ptr() - .cast::(), - size: token_data.data.len(), - sorted: token_data.sorted, - }; - let token = unsafe { - llama_cpp_sys_2::llama_sample_token_greedy( - self.context.as_ptr(), - std::ptr::addr_of_mut!(data_arr), - ) - }; - LlamaToken(token) - } - - /// See [`LlamaTokenDataArray::sample_tail_free`] - pub fn sample_tail_free( - &mut self, - token_data: &mut LlamaTokenDataArray, - z: f32, - min_keep: usize, - ) { - token_data.sample_tail_free(Some(self), z, min_keep); - } - - /// See [`LlamaTokenDataArray::sample_typical`] - pub fn sample_typical( - &mut self, - token_data: &mut LlamaTokenDataArray, - p: f32, - min_keep: usize, - ) { - token_data.sample_typical(Some(self), p, min_keep); - } - - /// See [`LlamaTokenDataArray::sample_top_p`] - pub fn sample_top_p(&mut self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) { - token_data.sample_top_p(Some(self), p, min_keep); - } - - /// Minimum P sampling as described in [#3841](https://github.com/ggerganov/llama.cpp/pull/3841) - pub fn sample_min_p( - &mut self, - llama_token_data: &mut LlamaTokenDataArray, - p: f32, - min_keep: usize, - ) { - let ctx = self.context.as_ptr(); - unsafe { - llama_token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_min_p(ctx, c_llama_token_data_array, p, min_keep); - }); - } - } - - /// See [`LlamaTokenDataArray::sample_top_k`] - pub fn sample_top_k(&mut self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) { - token_data.sample_top_k(Some(self), k, min_keep); - } - - /// See [`LlamaTokenDataArray::sample_softmax`] - pub fn sample_token_softmax(&mut self, token_data: &mut LlamaTokenDataArray) { - token_data.sample_softmax(Some(self)); - } - - /// See [`LlamaTokenDataArray::sample_repetition_penalty`] - pub fn sample_repetition_penalty( - &mut self, - token_data: &mut LlamaTokenDataArray, - last_tokens: &[LlamaToken], - penalty_last_n: usize, - penalty_repeat: f32, - penalty_freq: f32, - penalty_present: f32, - ) { - token_data.sample_repetition_penalty( - Some(self), - last_tokens, - penalty_last_n, - penalty_repeat, - penalty_freq, - penalty_present, - ); - } -} diff --git a/llama-cpp-2/src/context/sample/sampler.rs b/llama-cpp-2/src/context/sample/sampler.rs deleted file mode 100644 index 948a1aa5..00000000 --- a/llama-cpp-2/src/context/sample/sampler.rs +++ /dev/null @@ -1,112 +0,0 @@ -//! Create a sampler struct to encapsulate the sampling process. This allows passing all the possible -//! sampling parameters around as a single struct, and also allow late binding of expensive context -//! like [`crate::context::LlamaContext`] or token history to the sampler. -//! -//! # Example -//! -//! **Llama.cpp default sampler** -//! -//! ```rust -//! use llama_cpp_2::context::sample::sampler::{Sampler, SampleStep}; -//! use llama_cpp_2::token::data::LlamaTokenData; -//! use llama_cpp_2::token::data_array::LlamaTokenDataArray; -//! use llama_cpp_2::token::LlamaToken; -//! -//! // Sample a token greedily and add to the history. -//! let mut finalizer = &|mut canidates: LlamaTokenDataArray, history: &mut Vec| { -//! canidates.sample_softmax(None); -//! let token = canidates.data[0]; -//! history.push(token.id()); -//! 
vec![token] -//! }; -//! -//! let mut history = vec![]; -//! let mut sampler = Sampler::new(finalizer); -//! -//! sampler.push_step(&|c, history| c.sample_repetition_penalty(None, history, 64, 1.1, 0.0, 0.0)); -//! sampler.push_step(&|c, _| c.sample_top_k(None, 40, 1)); -//! sampler.push_step(&|c, _| c.sample_tail_free(None, 1.0, 1)); -//! sampler.push_step(&|c, _| c.sample_typical(None, 1.0, 1)); -//! sampler.push_step(&|c, _| c.sample_top_p(None, 0.95, 1)); -//! sampler.push_step(&|c, _| c.sample_min_p(None, 0.05, 1)); -//! sampler.push_step(&|c, _| c.sample_temp(None, 0.5)); -//! -//! // random candidates -//! let candidates = LlamaTokenDataArray::from_iter((0..4).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32 / 6.0, 0.0)), false); -//! -//! for _ in 0..10 { -//! let tokens = sampler.sample(&mut history, candidates.clone()); -//! assert_eq!(tokens.len(), 1); -//! } -//! -//! assert_eq!(history.len(), 10); -//! ``` - -use crate::token::data::LlamaTokenData; -use crate::token::data_array::LlamaTokenDataArray; -use std::fmt::{Debug, Formatter}; - -/// A single step to sample tokens from the remaining candidates. -pub type SampleStep = dyn Fn(&mut LlamaTokenDataArray, &mut C); - -/// The final step to select tokens from the remaining candidates. -pub type SampleFinalizer = dyn Fn(LlamaTokenDataArray, &mut C) -> Vec; - -/// A series of sampling steps that will produce a vector of token data. -/// -/// `C` is dynamic context that will be passed to the sampling functions. Some sampling steps may -/// require state to be maintained across multiple samples, and this context can be used to store -/// that state. For example, [`LlamaTokenDataArray::sample_token_mirostat_v2`] requires a `mu` to be -/// shared across multiple samples. -pub struct Sampler<'a, C> { - /// The steps to take when sampling. - pub steps: Vec<&'a SampleStep>, - /// The final step to select one or more tokens from the remaining candidates. - pub finalizer: &'a SampleFinalizer, -} - -impl Debug for Sampler<'_, T> { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Sampler") - .field( - "steps", - &format!( - "{} steps of Box ()>", - &self.steps.len() - ), - ) - .field( - "finalizer", - &"Box Vec>", - ) - .finish() - } -} - -impl<'a, T> Sampler<'a, T> { - /// Create a new sampler with a given finalizer. - pub fn new(finalizer: &'a SampleFinalizer) -> Self { - Self { - steps: vec![], - finalizer, - } - } - - /// Adds a step to the sampler. - pub fn push_step(&mut self, step: &'a SampleStep) { - self.steps.push(step); - } - - /// Sample a token from the given candidates. - #[must_use] - pub fn sample( - &mut self, - context: &mut T, - mut candidates: LlamaTokenDataArray, - ) -> Vec { - for step in &self.steps { - step(&mut candidates, context); - } - (self.finalizer)(candidates, context) - } -} diff --git a/llama-cpp-2/src/grammar.rs b/llama-cpp-2/src/grammar.rs deleted file mode 100644 index 667a870b..00000000 --- a/llama-cpp-2/src/grammar.rs +++ /dev/null @@ -1,491 +0,0 @@ -//! The grammar module contains the grammar parser and the grammar struct. -//! -//! This allows creating a llama-cpp grammar. This is essentially a translation of the parser in -//! `common` to rust - -use std::collections::BTreeMap; -use std::fmt::{Debug, Formatter}; - -use llama_cpp_sys_2::{llama_grammar, llama_grammar_element, llama_gretype}; -use std::ptr::NonNull; -use std::str::FromStr; -use tracing::error; - -/// Details of extraneous characters after a rule error. 
-#[derive(thiserror::Error, Debug)] -#[error("Extraneous chars after rule {name:?}: {chars:?}")] -pub struct ExtraneousCharsAfterRule { - /// The name of the rule being parsed - pub name: String, - /// the extraneous characters - pub chars: String, - /// the rest of the input, this is still to be parsed. - pub rest: String, -} - -/// There was an error parsing the grammar. -#[derive(thiserror::Error, Debug)] -#[allow(clippy::module_name_repetitions)] -pub enum GrammarParseError { - /// There was an unexpected end of input. - #[error("Unexpected end of input")] - UnexpectedEndOfInput { - /// the stage of parsing that was being performed when we ran out of input. - parse_stage: &'static str, - }, - /// There was unexpected characters after a rule name but before "::=". There can only be whitespace. - #[error("Unexpected Chars after name {name:?} and before \"::=\": {chars}")] - UnexpectedCharsAfterName { - /// the name of the rule being parsed - name: String, - /// the unexpected characters - chars: String, - }, - /// There was no "::=" after a rule name. - #[error("Expected ::= after name {name:?}")] - ExpectedEqualsAfterName { - /// the name of the rule being parsed - name: String, - }, - /// There was no closing bracket in a nested rule. - #[error("Expected closing bracket in nested rule {name:?}")] - MissingClosingBracketInNestedRule { - /// the name of the rule being parsed - name: String, - }, - /// There was no rule before a postfix operator. - #[error("Missing rule before postfix operator in {name:?}")] - ExpectedRuleBeforePostfixOperator { - /// the name of the rule being parsed - name: String, - }, - /// There was an incorrect hex size. - #[error("Expected hex number with size {expected_size}, but number was {actual:?}")] - IncorrectHexSize { - /// the expected size of the hex number - expected_size: usize, - /// the actual hex number - actual: String, - }, - /// An unknown escape character was found. - #[error("Unknown escape {escape:?}")] - UnknownEscape { - /// the unknown character - escape: char, - }, - /// Failed to parse hex from a string. - #[error("Failed to parse hex from {string}: {error}")] - ParseHexError { - /// the error that occurred when parsing the hex - #[source] - error: std::num::ParseIntError, - /// the string that was being parsed - string: String, - }, - /// there was not space after the name - // todo: is this actually an error? - #[error("Missing space after name in {rest:?}")] - MissingSpaceAfterName { - /// the rest of the input, this is still to be parsed. - rest: String, - }, - /// There was unexpected characters after the rule. - #[error("{0}")] - ExtraneousCharsAfterRule(ExtraneousCharsAfterRule), -} - -/// A grammar for llama-cpp. 
-#[allow(clippy::module_name_repetitions)] -pub struct LlamaGrammar { - parse: ParseState, - pub(crate) grammar: NonNull, -} - -impl Clone for LlamaGrammar { - fn clone(&self) -> Self { - let grammar = unsafe { llama_cpp_sys_2::llama_grammar_copy(self.grammar.as_ptr()) }; - Self { - parse: self.parse.clone(), - grammar: NonNull::new(grammar).expect("copied grammar should never be null"), - } - } -} - -unsafe impl Send for LlamaGrammar {} - -unsafe impl Sync for LlamaGrammar {} - -#[allow(clippy::module_name_repetitions)] -impl Debug for LlamaGrammar { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("LlamaGrammar") - .field("grammar", &self.grammar) - .field("parse", &self.parse) - .finish() - } -} - -#[derive(Debug, Clone, PartialEq)] -struct ParseState { - symbol_ids: BTreeMap, - rules: Vec>, -} - -impl ParseState { - fn new() -> Self { - Self { - symbol_ids: BTreeMap::new(), - rules: Vec::new(), - } - } - - fn get_symbol_id(&mut self, name: &str) -> u32 { - let next_id = - u32::try_from(self.symbol_ids.len()).expect("too many rules (must fit into u32)"); - let result = self.symbol_ids.entry(name.to_string()).or_insert(next_id); - *result - } - - fn generate_symbol_id(&mut self, name: &str) -> u32 { - let next_id = - u32::try_from(self.symbol_ids.len()).expect("too many rules (must fit into u32)"); - let generated_name = format!("{name}_{next_id}"); - let None = self.symbol_ids.insert(generated_name, next_id) else { - panic!("Failed to create unique name for {name}"); - }; - next_id - } - - fn parse_rule<'a>(&mut self, rest: &'a str) -> Result, GrammarParseError> { - let rest = Self::consume_whitespace_and_comments(rest, true); - if rest.is_empty() { - return Ok(None); - } - let (name, rest) = Self::parse_name(rest)?; - let rest = rest.trim_start(); - let rule_id = self.get_symbol_id(name); - - let (after_name, rest) = - rest.split_once("::=") - .ok_or_else(|| GrammarParseError::ExpectedEqualsAfterName { - name: name.to_string(), - })?; - - if !after_name.is_empty() { - return Err(GrammarParseError::UnexpectedCharsAfterName { - name: name.to_string(), - chars: after_name.to_string(), - }); - } - - let rest = self.parse_alternatives(name, rule_id, rest, false)?; - - let Some((after_rule, rest)) = rest.split_once('\n') else { - return Ok(None); - }; - - if !after_rule.chars().all(char::is_whitespace) { - return Err(GrammarParseError::ExtraneousCharsAfterRule( - ExtraneousCharsAfterRule { - name: name.to_string(), - chars: after_rule.to_string(), - rest: rest.to_string(), - }, - )); - } - - Ok(Some(rest)) - } - - fn consume_whitespace_and_comments(mut rest: &str, allow_newlines: bool) -> &str { - loop { - rest = rest.trim_start_matches( - |c: char| if allow_newlines { true } else { c != '\n' } && c.is_whitespace(), - ); - if rest.starts_with('#') { - rest = rest.split_once('\n').map_or("", |(_comment, rest)| rest); - } else { - break; - } - } - rest - } - - fn parse_alternatives<'a>( - &mut self, - name: &str, - id: u32, - rest: &'a str, - nested: bool, - ) -> Result<&'a str, GrammarParseError> { - let mut rule = Vec::new(); - let rest = self.parse_sequence(rest.trim_start(), name, &mut rule, nested)?; - let mut rest = Self::consume_whitespace_and_comments(rest, nested); - while rest.starts_with('|') { - rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_ALT, - value: 0, - }); - rest = Self::consume_whitespace_and_comments(&rest[1..], true); - rest = self.parse_sequence(rest, name, &mut rule, nested)?; - } - rule.push(llama_grammar_element 
{ - type_: llama_cpp_sys_2::LLAMA_GRETYPE_END, - value: 0, - }); - self.add_rule(id, rule); - Ok(rest) - } - - fn add_rule(&mut self, id: u32, rule: Vec) { - let id = id as usize; - if self.rules.len() <= id { - self.rules.resize(id + 1, Vec::new()); - } - self.rules[id] = rule; - } - - #[allow(clippy::too_many_lines)] - fn parse_sequence<'a>( - &mut self, - mut rest: &'a str, - name: &str, - rule: &mut Vec, - nested: bool, - ) -> Result<&'a str, GrammarParseError> { - let mut last_sym_start = rule.len(); - while !rest.is_empty() { - let first_char = - rest.chars() - .next() - .ok_or(GrammarParseError::UnexpectedEndOfInput { - parse_stage: "sequence", - })?; - if first_char == '"' { - rest = &rest[1..]; - last_sym_start = rule.len(); - while !rest.starts_with('"') { - let (c, r) = Self::parse_char(rest)?; - rest = r; - rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_CHAR, - value: c as _, - }); - } - rest = Self::consume_whitespace_and_comments(&rest[1..], nested); - } else if first_char == '[' { - rest = &rest[1..]; - let start_type = if rest.starts_with('^') { - rest = &rest[1..]; - llama_cpp_sys_2::LLAMA_GRETYPE_CHAR_NOT - } else { - llama_cpp_sys_2::LLAMA_GRETYPE_CHAR - }; - last_sym_start = rule.len(); - while !rest.starts_with(']') { - let (c, r) = Self::parse_char(rest)?; - rest = r; - let gre_type = if last_sym_start < rule.len() { - llama_cpp_sys_2::LLAMA_GRETYPE_CHAR_ALT - } else { - start_type - }; - rule.push(llama_grammar_element { - type_: gre_type, - value: c as _, - }); - if rest.starts_with('-') && rest.get(1..).is_some_and(|r| !r.starts_with(']')) { - let (c, r) = Self::parse_char(&rest[1..])?; - rest = r; - rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_CHAR_RNG_UPPER, - value: c as _, - }); - } - } - rest = Self::consume_whitespace_and_comments(&rest[1..], nested); - } else if first_char.is_alphabetic() { - let (name, r) = Self::parse_name(rest)?; - rest = Self::consume_whitespace_and_comments(r, nested); - let ref_rule_id = self.get_symbol_id(name); - last_sym_start = rule.len(); - rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_RULE_REF, - value: ref_rule_id, - }); - } else if first_char == '(' { - rest = rest[1..].trim_start(); - let sub_rule_id = self.generate_symbol_id(name); - rest = self.parse_alternatives(name, sub_rule_id, rest, true)?; - last_sym_start = rule.len(); - rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_RULE_REF, - value: sub_rule_id, - }); - if !rest.starts_with(')') { - return Err(GrammarParseError::MissingClosingBracketInNestedRule { - name: name.to_string(), - }); - } - rest = Self::consume_whitespace_and_comments(&rest[1..], nested); - } else if first_char == '*' || first_char == '+' || first_char == '?' 
{ - if last_sym_start == rule.len() { - return Err(GrammarParseError::ExpectedRuleBeforePostfixOperator { - name: name.to_string(), - }); - } - let sub_rule_id = self.generate_symbol_id(name); - let mut sub_rule: Vec = - rule.iter().skip(last_sym_start).copied().collect(); - if rest.starts_with(['*', '+']) { - sub_rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_RULE_REF, - value: sub_rule_id, - }); - } - sub_rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_ALT, - value: 0, - }); - if rest.starts_with('+') { - sub_rule.extend(rule.iter().skip(last_sym_start).copied()); - } - sub_rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_END, - value: 0, - }); - self.add_rule(sub_rule_id, sub_rule); - - rule.truncate(last_sym_start); - rule.push(llama_grammar_element { - type_: llama_cpp_sys_2::LLAMA_GRETYPE_RULE_REF, - value: sub_rule_id, - }); - - rest = Self::consume_whitespace_and_comments(&rest[1..], nested); - } else { - break; - } - } - - Ok(rest) - } - - fn parse_hex(rest: &str, size: usize) -> Result<(llama_gretype, &str), GrammarParseError> { - if rest.len() < size { - return Err(GrammarParseError::IncorrectHexSize { - expected_size: size, - actual: rest.to_string(), - }); - } - - let (hex, rest) = rest.split_at(size); - let value = - u32::from_str_radix(hex, 16).map_err(|error| GrammarParseError::ParseHexError { - string: hex.to_string(), - error, - })?; - - Ok((value as llama_gretype, rest)) - } - - fn parse_char(rest: &str) -> Result<(llama_gretype, &str), GrammarParseError> { - if let Some(rest) = rest.strip_prefix('\\') { - let Some(escaped) = rest.chars().next() else { - return Err(GrammarParseError::UnexpectedEndOfInput { - parse_stage: "escape char", - }); - }; - let rest = &rest[escaped.len_utf8()..]; - match escaped { - 'x' => Self::parse_hex(rest, 2), - 'u' => Self::parse_hex(rest, 4), - 'U' => Self::parse_hex(rest, 8), - 't' => Ok((u32::from('\t') as llama_gretype, rest)), - 'r' => Ok((u32::from('\r') as llama_gretype, rest)), - 'n' => Ok((u32::from('\n') as llama_gretype, rest)), - '\\' => Ok((u32::from('\\') as llama_gretype, rest)), - '"' => Ok((u32::from('"') as llama_gretype, rest)), - '[' => Ok((u32::from('[') as llama_gretype, rest)), - ']' => Ok((u32::from(']') as llama_gretype, rest)), - c => Err(GrammarParseError::UnknownEscape { escape: c }), - } - } else if let Some(c) = rest.chars().next() { - Ok((u32::from(c) as llama_gretype, &rest[c.len_utf8()..])) - } else { - Err(GrammarParseError::UnexpectedEndOfInput { - parse_stage: "char", - }) - } - } - - fn parse_name(rest: &str) -> Result<(&str, &str), GrammarParseError> { - let name_end = rest - .find(|c: char| !c.is_alphanumeric() && c != '-' && c != '_') - .ok_or(GrammarParseError::MissingSpaceAfterName { - rest: rest.to_string(), - })?; - let name = &rest[..name_end]; - let rest = &rest[name_end..]; - Ok((name, rest)) - } -} - -/// An error that can occur creating a grammar from a string. -#[derive(thiserror::Error, Debug)] -pub enum LlamaGrammarFromStrError { - /// There was an error parsing the grammar. - #[error("Failed to parse grammar {0}")] - ParseError(#[from] GrammarParseError), - /// Llama-cpp returned null - this can occur for many reasons, but should ideally be caught on - /// the rust side beforehand. 
- #[error("llama-cpp returned null")] - LlamaCppNullError, -} - -impl FromStr for ParseState { - type Err = GrammarParseError; - - fn from_str(s: &str) -> Result { - let mut parse_state = ParseState::new(); - let mut remaining = Some(s); - while let Some(str) = remaining { - remaining = parse_state.parse_rule(str)?; - } - Ok(parse_state) - } -} - -impl FromStr for LlamaGrammar { - type Err = LlamaGrammarFromStrError; - - fn from_str(s: &str) -> Result { - let mut parse_state = ParseState::from_str(s)?; - - let n_rules = parse_state.rules.len(); - let root_id = parse_state.get_symbol_id("root"); - let mut vec = parse_state - .rules - .iter_mut() - .map(|v| v.as_ptr()) - .collect::>(); - let rules = vec.as_mut_ptr(); - - let grammar = - unsafe { llama_cpp_sys_2::llama_grammar_init(rules, n_rules, root_id as usize) }; - - Ok(Self { - parse: parse_state, - grammar: NonNull::new(grammar).ok_or(LlamaGrammarFromStrError::LlamaCppNullError)?, - }) - } -} - -impl Drop for LlamaGrammar { - fn drop(&mut self) { - unsafe { llama_cpp_sys_2::llama_grammar_free(self.grammar.as_ptr()) } - } -} - -#[cfg(test)] -mod tests; diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs index 2717c845..f2ac5313 100644 --- a/llama-cpp-2/src/lib.rs +++ b/llama-cpp-2/src/lib.rs @@ -23,10 +23,11 @@ use std::path::PathBuf; use std::string::FromUtf8Error; pub mod context; -pub mod grammar; pub mod llama_backend; pub mod llama_batch; +mod log; pub mod model; +pub mod sampling; pub mod timing; pub mod token; pub mod token_type; @@ -62,22 +63,41 @@ pub enum LLamaCppError { /// see [`EmbeddingsError`] #[error(transparent)] EmbeddingError(#[from] EmbeddingsError), + // See [`LlamaSamplerError`] } /// There was an error while getting the chat template from a model. #[derive(Debug, Eq, PartialEq, thiserror::Error)] pub enum ChatTemplateError { - /// the buffer was too small. - #[error("The buffer was too small. However, a buffer size of {0} would be just large enough.")] - BuffSizeError(usize), - /// gguf has no chat template - #[error("the model has no meta val - returned code {0}")] - MissingTemplate(i32), + /// gguf has no chat template (by that name) + #[error("chat template not found - returned null pointer")] + MissingTemplate, + + /// chat template contained a null byte + #[error("null byte in string {0}")] + NullError(#[from] NulError), + /// The chat template was not valid utf8. #[error(transparent)] Utf8Error(#[from] std::str::Utf8Error), } +/// Failed fetching metadata value +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum MetaValError { + /// The provided string contains an unexpected null-byte + #[error("null byte in string {0}")] + NullError(#[from] NulError), + + /// The returned data contains invalid UTF8 data + #[error("FromUtf8Error {0}")] + FromUtf8Error(#[from] FromUtf8Error), + + /// Got negative return value. This happens if the key or index queried does not exist. + #[error("Negative return value. Likely due to a missing index or key. 
Got return value: {0}")] + NegativeReturn(i32), +} + /// Failed to Load context #[derive(Debug, Eq, PartialEq, thiserror::Error)] pub enum LlamaContextLoadError { @@ -197,6 +217,8 @@ pub enum LlamaLoraAdapterRemoveError { /// get the time (in microseconds) according to llama.cpp /// ``` /// # use llama_cpp_2::llama_time_us; +/// # use llama_cpp_2::llama_backend::LlamaBackend; +/// let backend = LlamaBackend::init().unwrap(); /// let time = llama_time_us(); /// assert!(time > 0); /// ``` @@ -279,9 +301,6 @@ pub enum NewLlamaChatMessageError { /// Failed to apply model chat template. #[derive(Debug, thiserror::Error)] pub enum ApplyChatTemplateError { - /// the buffer was too small. - #[error("The buffer was too small. Please contact a maintainer and we will update it.")] - BuffSizeError, /// the string contained a null byte and thus could not be converted to a c string. #[error("{0}")] NulError(#[from] NulError), @@ -294,6 +313,8 @@ pub enum ApplyChatTemplateError { /// /// ``` /// # use std::time::Duration; +/// # use llama_cpp_2::llama_backend::LlamaBackend; +/// let backend = LlamaBackend::init().unwrap(); /// use llama_cpp_2::ggml_time_us; /// /// let start = ggml_time_us(); @@ -325,3 +346,76 @@ pub fn ggml_time_us() -> i64 { pub fn llama_supports_mlock() -> bool { unsafe { llama_cpp_sys_2::llama_supports_mlock() } } + +/// Options to configure how llama.cpp logs are intercepted. +#[derive(Default, Debug, Clone)] +pub struct LogOptions { + disabled: bool, +} + +impl LogOptions { + /// If enabled, logs are sent to tracing. If disabled, all logs are suppressed. Default is for + /// logs to be sent to tracing. + pub fn with_logs_enabled(mut self, enabled: bool) -> Self { + self.disabled = !enabled; + self + } +} + +extern "C" fn logs_to_trace( + level: llama_cpp_sys_2::ggml_log_level, + text: *const ::std::os::raw::c_char, + data: *mut ::std::os::raw::c_void, +) { + // In the "fast-path" (i.e. the vast majority of logs) we want to avoid needing to take the log state + // lock at all. Similarly, we try to avoid any heap allocations within this function. This is accomplished + // by being a dummy pass-through to tracing in the normal case of DEBUG/INFO/WARN/ERROR logs that are + // newline terminated and limiting the slow-path of locks and/or heap allocations for other cases. + use std::borrow::Borrow; + + let log_state = unsafe { &*(data as *const log::State) }; + + let text = unsafe { std::ffi::CStr::from_ptr(text) }; + let text = text.to_string_lossy(); + let text: &str = text.borrow(); + + if log_state.options.disabled { + return; + } + + // As best I can tell llama.cpp / ggml require all log format strings at call sites to have the '\n'. + // If it's missing, it means that you expect more logs via CONT (or there's a typo in the codebase). To + // distinguish typo from intentional support for CONT, we have to buffer until the next message comes in + // to know how to flush it. + + if level == llama_cpp_sys_2::GGML_LOG_LEVEL_CONT { + log_state.cont_buffered_log(text); + } else if text.ends_with('\n') { + log_state.emit_non_cont_line(level, text); + } else { + log_state.buffer_non_cont(level, text); + } +} + +/// Redirect llama.cpp logs into tracing. +pub fn send_logs_to_tracing(options: LogOptions) { + // TODO: Reinitialize the state to support calling send_logs_to_tracing multiple times. + + // We set up separate log states for llama.cpp and ggml to make sure that CONT logs between the two + // can't possibly interfere with each other. 
In other words, if llama.cpp emits a log without a trailing + // newline and calls a GGML function, the logs won't be weirdly intermixed and instead we'll llama.cpp logs + // will CONT previous llama.cpp logs and GGML logs will CONT previous ggml logs. + let llama_heap_state = Box::as_ref( + log::LLAMA_STATE + .get_or_init(|| Box::new(log::State::new(log::Module::LlamaCpp, options.clone()))), + ) as *const _; + let ggml_heap_state = Box::as_ref( + log::GGML_STATE.get_or_init(|| Box::new(log::State::new(log::Module::GGML, options))), + ) as *const _; + + unsafe { + // GGML has to be set after llama since setting llama sets ggml as well. + llama_cpp_sys_2::llama_log_set(Some(logs_to_trace), llama_heap_state as *mut _); + llama_cpp_sys_2::ggml_log_set(Some(logs_to_trace), ggml_heap_state as *mut _); + } +} diff --git a/llama-cpp-2/src/llama_backend.rs b/llama-cpp-2/src/llama_backend.rs index 938356f7..1cc3fa3d 100644 --- a/llama-cpp-2/src/llama_backend.rs +++ b/llama-cpp-2/src/llama_backend.rs @@ -70,6 +70,21 @@ impl LlamaBackend { Ok(LlamaBackend {}) } + /// Was the code built for a GPU backend & is a supported one available. + pub fn supports_gpu_offload(&self) -> bool { + unsafe { llama_cpp_sys_2::llama_supports_gpu_offload() } + } + + /// Does this platform support loading the model via mmap. + pub fn supports_mmap(&self) -> bool { + unsafe { llama_cpp_sys_2::llama_supports_mmap() } + } + + /// Does this platform support locking the model in RAM. + pub fn supports_mlock(&self) -> bool { + unsafe { llama_cpp_sys_2::llama_supports_mlock() } + } + /// Change the output of llama.cpp's logging to be voided instead of pushed to `stderr`. pub fn void_logs(&mut self) { unsafe extern "C" fn void_log( diff --git a/llama-cpp-2/src/llama_batch.rs b/llama-cpp-2/src/llama_batch.rs index e52bfa9e..b96588c7 100644 --- a/llama-cpp-2/src/llama_batch.rs +++ b/llama-cpp-2/src/llama_batch.rs @@ -10,6 +10,7 @@ pub struct LlamaBatch { allocated: usize, /// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed. pub(crate) initialized_logits: Vec, + #[allow(clippy::doc_markdown)] /// The llama_cpp batch. always initialize by `llama_cpp_sys_2::llama_batch_init(allocated, , )` pub(crate) llama_batch: llama_batch, } @@ -20,6 +21,9 @@ pub enum BatchAddError { /// There was not enough space in the batch to add the token. #[error("Insufficient Space of {0}")] InsufficientSpace(usize), + /// Empty buffer is provided for [`LlamaBatch::get_one`] + #[error("Empty buffer")] + EmptyBuffer, } impl LlamaBatch { @@ -149,6 +153,40 @@ impl LlamaBatch { } } + /// ``llama_batch_get_one`` + /// Return batch for single sequence of tokens + /// + /// NOTE: this is a helper function to facilitate transition to the new batch API + /// + /// # Errors + /// If the provided token buffer is empty. + /// + /// # Panics + /// If the number of tokens in ``tokens`` exceeds [`i32::MAX`]. + pub fn get_one(tokens: &[LlamaToken]) -> Result { + if tokens.is_empty() { + return Err(BatchAddError::EmptyBuffer); + } + let batch = unsafe { + let ptr = tokens.as_ptr() as *mut i32; + llama_cpp_sys_2::llama_batch_get_one( + ptr, + tokens + .len() + .try_into() + .expect("number of tokens exceeds i32::MAX"), + ) + }; + let batch = Self { + allocated: 0, + initialized_logits: vec![(tokens.len() - 1) + .try_into() + .expect("number of tokens exceeds i32::MAX + 1")], + llama_batch: batch, + }; + Ok(batch) + } + /// Returns the number of tokens in the batch. 
#[must_use] pub fn n_tokens(&self) -> i32 { @@ -170,7 +208,9 @@ impl Drop for LlamaBatch { /// # } fn drop(&mut self) { unsafe { - llama_batch_free(self.llama_batch); + if self.allocated > 0 { + llama_batch_free(self.llama_batch); + } } } } diff --git a/llama-cpp-2/src/log.rs b/llama-cpp-2/src/log.rs new file mode 100644 index 00000000..e77f94bb --- /dev/null +++ b/llama-cpp-2/src/log.rs @@ -0,0 +1,259 @@ +use super::LogOptions; +use std::sync::OnceLock; +use tracing_core::{callsite, field, identify_callsite, Interest, Kind, Metadata}; + +static FIELD_NAMES: &[&str] = &["message", "module"]; + +struct OverridableFields { + message: tracing::field::Field, + target: tracing::field::Field, +} + +macro_rules! log_cs { + ($level:expr, $cs:ident, $meta:ident, $fields:ident, $ty:ident) => { + struct $ty; + static $cs: $ty = $ty; + static $meta: Metadata<'static> = Metadata::new( + "log event", + "llama-cpp-2", + $level, + ::core::option::Option::None, + ::core::option::Option::None, + ::core::option::Option::None, + field::FieldSet::new(FIELD_NAMES, identify_callsite!(&$cs)), + Kind::EVENT, + ); + static $fields: std::sync::LazyLock = std::sync::LazyLock::new(|| { + let fields = $meta.fields(); + OverridableFields { + message: fields.field("message").unwrap(), + target: fields.field("module").unwrap(), + } + }); + + impl callsite::Callsite for $ty { + fn set_interest(&self, _: Interest) {} + fn metadata(&self) -> &'static Metadata<'static> { + &$meta + } + } + }; +} +log_cs!( + tracing_core::Level::DEBUG, + DEBUG_CS, + DEBUG_META, + DEBUG_FIELDS, + DebugCallsite +); +log_cs!( + tracing_core::Level::INFO, + INFO_CS, + INFO_META, + INFO_FIELDS, + InfoCallsite +); +log_cs!( + tracing_core::Level::WARN, + WARN_CS, + WARN_META, + WARN_FIELDS, + WarnCallsite +); +log_cs!( + tracing_core::Level::ERROR, + ERROR_CS, + ERROR_META, + ERROR_FIELDS, + ErrorCallsite +); + +#[derive(Clone, Copy)] +pub(super) enum Module { + GGML, + LlamaCpp, +} + +impl Module { + const fn name(&self) -> &'static str { + match self { + Module::GGML => "ggml", + Module::LlamaCpp => "llama.cpp", + } + } +} + +fn meta_for_level( + level: llama_cpp_sys_2::ggml_log_level, +) -> (&'static Metadata<'static>, &'static OverridableFields) { + match level { + llama_cpp_sys_2::GGML_LOG_LEVEL_DEBUG => (&DEBUG_META, &DEBUG_FIELDS), + llama_cpp_sys_2::GGML_LOG_LEVEL_INFO => (&INFO_META, &INFO_FIELDS), + llama_cpp_sys_2::GGML_LOG_LEVEL_WARN => (&WARN_META, &WARN_FIELDS), + llama_cpp_sys_2::GGML_LOG_LEVEL_ERROR => (&ERROR_META, &ERROR_FIELDS), + _ => { + unreachable!("Illegal log level to be called here") + } + } +} + +pub(super) struct State { + pub(super) options: LogOptions, + module: Module, + buffered: std::sync::Mutex>, + previous_level: std::sync::atomic::AtomicI32, + is_buffering: std::sync::atomic::AtomicBool, +} + +impl State { + pub(super) fn new(module: Module, options: LogOptions) -> Self { + Self { + options, + module, + buffered: Default::default(), + previous_level: Default::default(), + is_buffering: Default::default(), + } + } + + fn generate_log(target: Module, level: llama_cpp_sys_2::ggml_log_level, text: &str) { + // Annoying but tracing requires that the provided target name is a string literal and + // even &'static str isn't enough so we have to duplicate the generation AND we can't even + // extract the interrior module within llama.cpp/ggml to be able to propagate it forward. + // This happens because the target is part of a static variable injected by the macro that's + // initialized with said target. 
+ + let (module, text) = text + .char_indices() + .take_while(|(_, c)| c.is_ascii_lowercase() || *c == '_') + .last() + .and_then(|(pos, _)| { + let next_two = text.get(pos + 1..pos + 3); + if next_two == Some(": ") { + let (sub_module, text) = text.split_at(pos + 1); + let text = text.split_at(2).1; + Some((Some(format!("{}::{sub_module}", target.name())), text)) + } else { + None + } + }) + .unwrap_or((None, text)); + + let (meta, fields) = meta_for_level(level); + + tracing::dispatcher::get_default(|dispatcher| { + if dispatcher.enabled(meta) { + dispatcher.event(&tracing::Event::new( + meta, + &meta.fields().value_set(&[ + (&fields.message, Some(&text as &dyn tracing::field::Value)), + ( + &fields.target, + module.as_ref().map(|s| s as &dyn tracing::field::Value), + ), + ]), + )); + } + }); + } + + /// Append more text to the previously buffered log. The text may or may not end with a newline. + pub(super) fn cont_buffered_log(&self, text: &str) { + let mut lock = self.buffered.lock().unwrap(); + + if let Some((previous_log_level, mut buffer)) = lock.take() { + buffer.push_str(text); + if buffer.ends_with('\n') { + self.is_buffering + .store(false, std::sync::atomic::Ordering::Release); + Self::generate_log(self.module, previous_log_level, buffer.as_str()); + } else { + *lock = Some((previous_log_level, buffer)); + } + } else { + let level = self + .previous_level + .load(std::sync::atomic::Ordering::Acquire) + as llama_cpp_sys_2::ggml_log_level; + tracing::warn!( + inferred_level = level, + text = text, + origin = "crate", + "llama.cpp sent out a CONT log without any previously buffered message" + ); + *lock = Some((level, text.to_string())); + } + } + + /// Start buffering a message (the level is not CONT and the text does not yet end with a newline). + pub(super) fn buffer_non_cont(&self, level: llama_cpp_sys_2::ggml_log_level, text: &str) { + debug_assert!(!text.ends_with('\n')); + debug_assert_ne!(level, llama_cpp_sys_2::GGML_LOG_LEVEL_CONT); + + if let Some((previous_log_level, buffer)) = self + .buffered + .lock() + .unwrap() + .replace((level, text.to_string())) + { + tracing::warn!( + level = previous_log_level, + text = &buffer, + origin = "crate", + "Message buffered unnecessarily due to missing newline and not followed by a CONT" + ); + Self::generate_log(self.module, previous_log_level, buffer.as_str()) + } + + self.is_buffering + .store(true, std::sync::atomic::Ordering::Release); + self.previous_level + .store(level as i32, std::sync::atomic::Ordering::Release); + } + + // Emit a normal unbuffered log message (not the CONT log level and the text ends with a newline).
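The buffering state machine above is only exercised once `send_logs_to_tracing` is installed. A minimal wiring sketch, assuming the application brings its own subscriber; `tracing-subscriber` is an assumed dependency, not part of this diff:

```rust
use llama_cpp_2::{send_logs_to_tracing, LogOptions};

fn init_logging() {
    // Any tracing subscriber works; `tracing_subscriber::fmt` is just a common choice.
    tracing_subscriber::fmt::init();

    // Forward llama.cpp and ggml logs into tracing.
    // `LogOptions::default().with_logs_enabled(false)` would suppress them instead.
    send_logs_to_tracing(LogOptions::default());
}
```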
+ pub(super) fn emit_non_cont_line(&self, level: llama_cpp_sys_2::ggml_log_level, text: &str) { + debug_assert!(text.ends_with('\n')); + debug_assert_ne!(level, llama_cpp_sys_2::GGML_LOG_LEVEL_CONT); + + if self + .is_buffering + .swap(false, std::sync::atomic::Ordering::Acquire) + { + if let Some((buf_level, buf_text)) = self.buffered.lock().unwrap().take() { + // This warning indicates a bug within llama.cpp + tracing::warn!(level = buf_level, text = buf_text, origin = "crate", "llama.cpp message buffered spuriously due to missing \\n and being followed by a non-CONT message!"); + Self::generate_log(self.module, buf_level, buf_text.as_str()); + } + } + + self.previous_level + .store(level as i32, std::sync::atomic::Ordering::Release); + + let (text, newline) = text.split_at(text.len() - 1); + debug_assert_eq!(newline, "\n"); + + match level { + llama_cpp_sys_2::GGML_LOG_LEVEL_NONE => { + // TODO: Support logging this to stdout directly via options? + tracing::info!(no_log_level = true, text); + } + llama_cpp_sys_2::GGML_LOG_LEVEL_DEBUG + | llama_cpp_sys_2::GGML_LOG_LEVEL_INFO + | llama_cpp_sys_2::GGML_LOG_LEVEL_WARN + | llama_cpp_sys_2::GGML_LOG_LEVEL_ERROR => Self::generate_log(self.module, level, text), + llama_cpp_sys_2::GGML_LOG_LEVEL_CONT => unreachable!(), + _ => { + tracing::warn!( + level = level, + text = text, + origin = "crate", + "Unknown llama.cpp log level" + ) + } + } + } +} + +pub(super) static LLAMA_STATE: OnceLock> = OnceLock::new(); +pub(super) static GGML_STATE: OnceLock> = OnceLock::new(); diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs index 54c82bd5..b8cd26bb 100644 --- a/llama-cpp-2/src/model.rs +++ b/llama-cpp-2/src/model.rs @@ -1,10 +1,10 @@ //! A safe wrapper around `llama_model`. -use std::ffi::CStr; -use std::ffi::CString; +use std::ffi::{c_char, CStr, CString}; use std::num::NonZeroU16; use std::os::raw::c_int; use std::path::Path; use std::ptr::NonNull; +use std::str::Utf8Error; use crate::context::params::LlamaContextParams; use crate::context::LlamaContext; @@ -13,8 +13,9 @@ use crate::model::params::LlamaModelParams; use crate::token::LlamaToken; use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs}; use crate::{ - ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError, - LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError, TokenToStringError, + ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, + LlamaLoraAdapterInitError, LlamaModelLoadError, MetaValError, NewLlamaChatMessageError, + StringToTokenError, TokenToStringError, }; pub mod params; @@ -32,7 +33,43 @@ pub struct LlamaModel { #[repr(transparent)] #[allow(clippy::module_name_repetitions)] pub struct LlamaLoraAdapter { - pub(crate) lora_adapter: NonNull, + pub(crate) lora_adapter: NonNull, +} + +/// A performance-friendly wrapper around [LlamaModel::chat_template] which is then +/// fed into [LlamaModel::apply_chat_template] to convert a list of messages into an LLM +/// prompt. Internally the template is stored as a CString to avoid round-trip conversions +/// within the FFI. +#[derive(Eq, PartialEq, Clone, PartialOrd, Ord, Hash)] +pub struct LlamaChatTemplate(CString); + +impl LlamaChatTemplate { + /// Create a new template from a string. This can either be the name of a llama.cpp [chat template](https://github.com/ggerganov/llama.cpp/blob/8a8c4ceb6050bd9392609114ca56ae6d26f5b8f5/src/llama-chat.cpp#L27-L61) + /// like "chatml" or "llama3" or an actual Jinja template for llama.cpp to interpret. 
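A short sketch of constructing the new `LlamaChatTemplate` wrapper; the template name is illustrative, and an inline Jinja string would work the same way:

```rust
use llama_cpp_2::model::LlamaChatTemplate;

/// Sketch: pick a built-in llama.cpp template by name.
fn chatml_template() -> LlamaChatTemplate {
    // Fails only if the string contains an interior null byte.
    LlamaChatTemplate::new("chatml").expect("template name contains no null bytes")
}
```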
+ pub fn new(template: &str) -> Result { + Ok(Self(CString::new(template)?)) + } + + /// Accesses the template as a c string reference. + pub fn as_c_str(&self) -> &CStr { + &self.0 + } + + /// Attempts to convert the CString into a Rust str reference. + pub fn to_str(&self) -> Result<&str, Utf8Error> { + self.0.to_str() + } + + /// Convenience method to create an owned String. + pub fn to_string(&self) -> Result { + self.to_str().map(str::to_string) + } +} + +impl std::fmt::Debug for LlamaChatTemplate { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } } /// A Safe wrapper around `llama_chat_message` @@ -44,6 +81,9 @@ pub struct LlamaChatMessage { impl LlamaChatMessage { /// Create a new `LlamaChatMessage` + /// + /// # Errors + /// If either of ``role`` or ``content`` contain null bytes. pub fn new(role: String, content: String) -> Result { Ok(Self { role: CString::new(role)?, @@ -52,6 +92,15 @@ impl LlamaChatMessage { } } +/// The Rope type that's used within the model. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RopeType { + Norm, + NeoX, + MRope, + Vision, +} + /// How to determine if we should prepend a bos token to tokens #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum AddBos { @@ -75,6 +124,10 @@ unsafe impl Send for LlamaModel {} unsafe impl Sync for LlamaModel {} impl LlamaModel { + pub(crate) fn vocab_ptr(&self) -> *const llama_cpp_sys_2::llama_vocab { + unsafe { llama_cpp_sys_2::llama_model_get_vocab(self.model.as_ptr()) } + } + /// get the number of tokens the model was trained on /// /// # Panics @@ -100,31 +153,31 @@ impl LlamaModel { /// Get the beginning of stream token. #[must_use] pub fn token_bos(&self) -> LlamaToken { - let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.model.as_ptr()) }; + let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.vocab_ptr()) }; LlamaToken(token) } /// Get the end of stream token. #[must_use] pub fn token_eos(&self) -> LlamaToken { - let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.model.as_ptr()) }; + let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.vocab_ptr()) }; LlamaToken(token) } /// Get the newline token. #[must_use] pub fn token_nl(&self) -> LlamaToken { - let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.model.as_ptr()) }; + let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.vocab_ptr()) }; LlamaToken(token) } /// Check if a token represents the end of generation (end of turn, end of sequence, etc.) #[must_use] pub fn is_eog_token(&self, token: LlamaToken) -> bool { - unsafe { llama_cpp_sys_2::llama_token_is_eog(self.model.as_ptr(), token.0) } + unsafe { llama_cpp_sys_2::llama_token_is_eog(self.vocab_ptr(), token.0) } } - /// Get the decoder start token token. + /// Get the decoder start token. #[must_use] pub fn decode_start_token(&self) -> LlamaToken { let token = @@ -142,20 +195,33 @@ impl LlamaModel { token: LlamaToken, special: Special, ) -> Result { - self.token_to_str_with_size(token, 32, special) + let bytes = self.token_to_bytes(token, special)?; + Ok(String::from_utf8(bytes)?) } /// Convert single token to bytes. /// /// # Errors - /// /// See [`TokenToStringError`] for more information. + /// + /// # Panics + /// If a [`TokenToStringError::InsufficientBufferSpace`] error returned by + /// [`Self::token_to_bytes_with_size`] contains a positive nonzero value. This should never + /// happen. 
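The `# Errors` note added to `LlamaChatMessage::new` is easiest to see in use. A sketch of building the message list that would later be handed to `apply_chat_template` together with a `LlamaChatTemplate`:

```rust
use llama_cpp_2::model::LlamaChatMessage;
use llama_cpp_2::NewLlamaChatMessageError;

/// Sketch: build a two-message chat. `new` only fails if a role or content
/// string contains an interior null byte.
fn build_chat() -> Result<Vec<LlamaChatMessage>, NewLlamaChatMessageError> {
    Ok(vec![
        LlamaChatMessage::new("system".to_string(), "You are a helpful assistant.".to_string())?,
        LlamaChatMessage::new("user".to_string(), "Hello!".to_string())?,
    ])
}
```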
pub fn token_to_bytes( &self, token: LlamaToken, special: Special, ) -> Result, TokenToStringError> { - self.token_to_bytes_with_size(token, 32, special, None) + match self.token_to_bytes_with_size(token, 8, special, None) { + Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_bytes_with_size( + token, + (-i).try_into().expect("Error buffer size is positive"), + special, + None, + ), + x => x, + } } /// Convert a vector of tokens to a single string. @@ -168,15 +234,15 @@ impl LlamaModel { tokens: &[LlamaToken], special: Special, ) -> Result { - let mut builder = String::with_capacity(tokens.len() * 4); - for str in tokens + let mut builder: Vec = Vec::with_capacity(tokens.len() * 4); + for piece in tokens .iter() .copied() - .map(|t| self.token_to_str(t, special)) + .map(|t| self.token_to_bytes(t, special)) { - builder += &str?; + builder.extend_from_slice(&piece?); } - Ok(builder) + Ok(String::from_utf8(builder)?) } /// Convert a string to a Vector of tokens. @@ -212,7 +278,7 @@ impl LlamaModel { }; let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos)); - let mut buffer = Vec::with_capacity(tokens_estimation); + let mut buffer: Vec = Vec::with_capacity(tokens_estimation); let c_string = CString::new(str)?; let buffer_capacity = @@ -220,10 +286,10 @@ impl LlamaModel { let size = unsafe { llama_cpp_sys_2::llama_tokenize( - self.model.as_ptr(), + self.vocab_ptr(), c_string.as_ptr(), c_int::try_from(c_string.as_bytes().len())?, - buffer.as_mut_ptr(), + buffer.as_mut_ptr().cast::(), buffer_capacity, add_bos, true, @@ -236,10 +302,10 @@ impl LlamaModel { buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger ")); unsafe { llama_cpp_sys_2::llama_tokenize( - self.model.as_ptr(), + self.vocab_ptr(), c_string.as_ptr(), c_int::try_from(c_string.as_bytes().len())?, - buffer.as_mut_ptr(), + buffer.as_mut_ptr().cast::(), -size, add_bos, true, @@ -253,7 +319,7 @@ impl LlamaModel { // Safety: `size` < `capacity` and llama-cpp has initialized elements up to `size` unsafe { buffer.set_len(size) } - Ok(buffer.into_iter().map(LlamaToken).collect()) + Ok(buffer) } /// Get the type of a token. @@ -263,14 +329,14 @@ impl LlamaModel { /// If the token type is not known to this library. #[must_use] pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs { - let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.model.as_ptr(), id) }; + let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.vocab_ptr(), id) }; LlamaTokenAttrs::try_from(token_type).expect("token type is valid") } /// Convert a token to a string with a specified buffer size. /// - /// Generally you should use [`LlamaModel::token_to_str`] instead as 8 bytes is enough for most words and - /// the extra bytes do not really matter. + /// Generally you should use [`LlamaModel::token_to_str`] as it is able to decode tokens with + /// any length. /// /// # Errors /// @@ -294,8 +360,8 @@ impl LlamaModel { /// Convert a token to bytes with a specified buffer size. /// - /// Generally you should use [`LlamaModel::token_to_bytes`] instead as 8 bytes is enough for most words and - /// the extra bytes do not really matter. + /// Generally you should use [`LlamaModel::token_to_bytes`] as it is able to handle tokens of + /// any length. 
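Because `tokens_to_str` now accumulates raw bytes and performs a single UTF-8 conversion at the end, multi-byte characters that straddle token boundaries decode correctly. A small sketch; the `Special` import path is assumed to be `llama_cpp_2::model`:

```rust
use llama_cpp_2::model::{LlamaModel, Special};
use llama_cpp_2::token::LlamaToken;
use llama_cpp_2::TokenToStringError;

/// Sketch: decode a whole token sequence in one call; partial UTF-8 sequences
/// only become an error if the final byte stream is invalid.
fn detokenize(
    model: &LlamaModel,
    tokens: &[LlamaToken],
    special: Special,
) -> Result<String, TokenToStringError> {
    model.tokens_to_str(tokens, special)
}
```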
/// /// # Errors /// @@ -314,18 +380,16 @@ impl LlamaModel { lstrip: Option, ) -> Result, TokenToStringError> { if token == self.token_nl() { - return Ok(String::from("\n").into_bytes()); + return Ok(b"\n".to_vec()); } // unsure what to do with this in the face of the 'special' arg + attr changes let attrs = self.token_attr(token); - if attrs.contains(LlamaTokenAttr::Control) - && (token == self.token_bos() || token == self.token_eos()) - { - return Ok(Vec::new()); - } else if attrs.is_empty() + if attrs.is_empty() || attrs .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused) + || attrs.contains(LlamaTokenAttr::Control) + && (token == self.token_bos() || token == self.token_eos()) { return Ok(Vec::new()); } @@ -342,7 +406,7 @@ impl LlamaModel { let lstrip = lstrip.map_or(0, |it| i32::from(it.get())); let size = unsafe { llama_cpp_sys_2::llama_token_to_piece( - self.model.as_ptr(), + self.vocab_ptr(), token.0, buf, len, @@ -369,7 +433,7 @@ impl LlamaModel { /// without issue. #[must_use] pub fn n_vocab(&self) -> i32 { - unsafe { llama_cpp_sys_2::llama_n_vocab(self.model.as_ptr()) } + unsafe { llama_cpp_sys_2::llama_n_vocab(self.vocab_ptr()) } } /// The type of vocab the model was trained on. @@ -379,7 +443,8 @@ impl LlamaModel { /// If llama-cpp emits a vocab type that is not known to this library. #[must_use] pub fn vocab_type(&self) -> VocabType { - let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.model.as_ptr()) }; + // llama_cpp_sys_2::llama_model_get_vocab + let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.vocab_ptr()) }; VocabType::try_from(vocab_type).expect("invalid vocab type") } @@ -390,41 +455,143 @@ impl LlamaModel { unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) } } - /// Get chat template from model. + /// Returns the total size of all the tensors in the model in bytes. + pub fn size(&self) -> u64 { + unsafe { llama_cpp_sys_2::llama_model_size(self.model.as_ptr()) } + } + + /// Returns the number of parameters in the model. + pub fn n_params(&self) -> u64 { + unsafe { llama_cpp_sys_2::llama_model_n_params(self.model.as_ptr()) } + } + + /// Returns whether the model is a recurrent network (Mamba, RWKV, etc) + pub fn is_recurrent(&self) -> bool { + unsafe { llama_cpp_sys_2::llama_model_is_recurrent(self.model.as_ptr()) } + } + + /// Returns the number of layers within the model. + pub fn n_layer(&self) -> u32 { + // It's never possible for this to panic because while the API interface is defined as an int32_t, + // the field it's accessing is a uint32_t. + u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_layer(self.model.as_ptr()) }).unwrap() + } + + /// Returns the number of attention heads within the model. + pub fn n_head(&self) -> u32 { + // It's never possible for this to panic because while the API interface is defined as an int32_t, + // the field it's accessing is a uint32_t. + u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head(self.model.as_ptr()) }).unwrap() + } + + /// Returns the number of KV attention heads. + pub fn n_head_kv(&self) -> u32 { + // It's never possible for this to panic because while the API interface is defined as an int32_t, + // the field it's accessing is a uint32_t. 
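A small helper that exercises the newly exposed model-level accessors from this hunk; it only assumes a `LlamaModel` has been loaded elsewhere, and the function name is illustrative.

```rust
use llama_cpp_2::model::LlamaModel;

// Dump the model-level accessors added in this change set.
fn print_model_shape(model: &LlamaModel) {
    println!("parameters      : {}", model.n_params());
    println!("tensor bytes    : {}", model.size());
    println!("layers          : {}", model.n_layer());
    println!("attention heads : {}", model.n_head());
    println!("kv heads        : {}", model.n_head_kv());
    println!("embedding dim   : {}", model.n_embd());
    println!("vocab size      : {}", model.n_vocab());
    println!("recurrent       : {}", model.is_recurrent());
}
```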
+ u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) }).unwrap() + } + + /// Get metadata value as a string by key name + pub fn meta_val_str(&self, key: &str) -> Result { + let key_cstring = CString::new(key)?; + let key_ptr = key_cstring.as_ptr(); + + extract_meta_string( + |buf_ptr, buf_len| unsafe { + llama_cpp_sys_2::llama_model_meta_val_str( + self.model.as_ptr(), + key_ptr, + buf_ptr, + buf_len, + ) + }, + 256, + ) + } + + /// Get the number of metadata key/value pairs + pub fn meta_count(&self) -> i32 { + unsafe { llama_cpp_sys_2::llama_model_meta_count(self.model.as_ptr()) } + } + + /// Get metadata key name by index + pub fn meta_key_by_index(&self, index: i32) -> Result { + extract_meta_string( + |buf_ptr, buf_len| unsafe { + llama_cpp_sys_2::llama_model_meta_key_by_index( + self.model.as_ptr(), + index, + buf_ptr, + buf_len, + ) + }, + 256, + ) + } + + /// Get metadata value as a string by index + pub fn meta_val_str_by_index(&self, index: i32) -> Result { + extract_meta_string( + |buf_ptr, buf_len| unsafe { + llama_cpp_sys_2::llama_model_meta_val_str_by_index( + self.model.as_ptr(), + index, + buf_ptr, + buf_len, + ) + }, + 256, + ) + } + + /// Returns the rope type of the model. + pub fn rope_type(&self) -> Option { + match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } { + llama_cpp_sys_2::LLAMA_ROPE_TYPE_NONE => None, + llama_cpp_sys_2::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm), + llama_cpp_sys_2::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX), + llama_cpp_sys_2::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope), + llama_cpp_sys_2::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision), + rope_type => { + tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp"); + None + } + } + } + + /// Get chat template from model by name. If the name parameter is None, the default chat template will be returned. + /// + /// You supply this into [Self::apply_chat_template] to get back a string with the appropriate template + /// substitution applied to convert a list of messages into a prompt the LLM can use to complete + /// the chat. + /// + /// You could also use an external jinja parser, like [minijinja](https://github.com/mitsuhiko/minijinja), + /// to parse jinja templates not supported by the llama.cpp template engine. /// /// # Errors /// - /// * If the model has no chat template + /// * If the model has no chat template by that name /// * If the chat template is not a valid [`CString`]. 
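The metadata accessors pair naturally with `meta_count` for a full dump, and `rope_type` maps the llama.cpp constant onto the new `RopeType` enum. A sketch, again assuming an already-loaded model; names outside this diff are assumptions.

```rust
use llama_cpp_2::model::LlamaModel;

// Walk every GGUF metadata key/value pair and report the rope type.
fn dump_metadata(model: &LlamaModel) {
    for i in 0..model.meta_count() {
        // Both calls retry internally with a larger buffer if 256 bytes is not enough.
        let key = model.meta_key_by_index(i);
        let val = model.meta_val_str_by_index(i);
        if let (Ok(key), Ok(val)) = (key, val) {
            println!("{key} = {val}");
        }
    }

    match model.rope_type() {
        Some(rope) => println!("rope type: {rope:?}"),
        None => println!("model does not use RoPE"),
    }
}
```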
- #[allow(clippy::missing_panics_doc)] // we statically know this will not panic as - pub fn get_chat_template(&self, buf_size: usize) -> Result { - // longest known template is about 1200 bytes from llama.cpp - let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null"); - let chat_ptr = chat_temp.into_raw(); - let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes"); - - let ret = unsafe { - llama_cpp_sys_2::llama_model_meta_val_str( - self.model.as_ptr(), - chat_name.as_ptr(), - chat_ptr, - buf_size, - ) + pub fn chat_template( + &self, + name: Option<&str>, + ) -> Result { + let name_cstr = name.map(CString::new); + let name_ptr = match name_cstr { + Some(Ok(name)) => name.as_ptr(), + _ => std::ptr::null(), }; + let result = + unsafe { llama_cpp_sys_2::llama_model_chat_template(self.model.as_ptr(), name_ptr) }; - if ret < 0 { - return Err(ChatTemplateError::MissingTemplate(ret)); - } - - let template_c = unsafe { CString::from_raw(chat_ptr) }; - let template = template_c.to_str()?; - - let ret: usize = ret.try_into().unwrap(); - if template.len() < ret { - return Err(ChatTemplateError::BuffSizeError(ret + 1)); + // Convert result to Rust String if not null + if result.is_null() { + Err(ChatTemplateError::MissingTemplate) + } else { + let chat_template_cstr = unsafe { CStr::from_ptr(result) }; + let chat_template = CString::new(chat_template_cstr.to_bytes())?; + Ok(LlamaChatTemplate(chat_template)) } - - Ok(template.to_owned()) } /// Loads a model from a file. @@ -474,7 +641,7 @@ impl LlamaModel { let cstr = CString::new(path)?; let adapter = - unsafe { llama_cpp_sys_2::llama_lora_adapter_init(self.model.as_ptr(), cstr.as_ptr()) }; + unsafe { llama_cpp_sys_2::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) }; let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?; @@ -508,22 +675,32 @@ impl LlamaModel { /// Apply the models chat template to some messages. /// See https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template /// - /// `tmpl` of None means to use the default template provided by llama.cpp for the model + /// Unlike the llama.cpp apply_chat_template which just randomly uses the ChatML template when given + /// a null pointer for the template, this requires an explicit template to be specified. If you want to + /// use "chatml", then just do `LlamaChatTemplate::new("chatml")` or any other model name or template + /// string. + /// + /// Use [Self::chat_template] to retrieve the template baked into the model (this is the preferred + /// mechanism as using the wrong chat template can result in really unexpected responses from the LLM). + /// + /// You probably want to set `add_ass` to true so that the generated template string ends with a the + /// opening tag of the assistant. If you fail to leave a hanging chat tag, the model will likely generate + /// one into the output and the output may also have unexpected output aside from that. /// /// # Errors /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information. 
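Putting the pieces together: fetch the model's embedded template (or fall back to an explicitly named one) and render a prompt with the reworked `apply_chat_template` shown below. The fallback name, the message content, and the `Option`-based error handling are illustrative assumptions, not the crate's prescribed pattern.

```rust
use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel};

fn render_prompt(model: &LlamaModel, user_msg: &str) -> Option<String> {
    // Prefer the template embedded in the GGUF; fall back to a named built-in.
    let template = model
        .chat_template(None)
        .ok()
        .or_else(|| LlamaChatTemplate::new("chatml").ok())?;

    let chat = vec![LlamaChatMessage::new("user".to_string(), user_msg.to_string()).ok()?];

    // `add_ass = true` leaves the assistant turn open so generation continues from it.
    model.apply_chat_template(&template, &chat, true).ok()
}
```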
#[tracing::instrument(skip_all)] pub fn apply_chat_template( &self, - tmpl: Option, - chat: Vec, + tmpl: &LlamaChatTemplate, + chat: &[LlamaChatMessage], add_ass: bool, ) -> Result { // Buffer is twice the length of messages per their recommendation let message_length = chat.iter().fold(0, |acc, c| { acc + c.role.to_bytes().len() + c.content.to_bytes().len() }); - let mut buff = vec![0; message_length * 4]; + let mut buff: Vec = vec![0; message_length * 2]; // Build our llama_cpp_sys_2 chat messages let chat: Vec = chat @@ -534,36 +711,73 @@ impl LlamaModel { }) .collect(); - // Set the tmpl pointer - let tmpl = tmpl.map(CString::new); - let tmpl_ptr = match &tmpl { - Some(str) => str.as_ref().map_err(Clone::clone)?.as_ptr(), - None => std::ptr::null(), - }; + let tmpl_ptr = tmpl.0.as_ptr(); - let formatted_chat = unsafe { - let res = llama_cpp_sys_2::llama_chat_apply_template( - self.model.as_ptr(), + let res = unsafe { + llama_cpp_sys_2::llama_chat_apply_template( tmpl_ptr, chat.as_ptr(), chat.len(), add_ass, - buff.as_mut_ptr(), - buff.len() as i32, - ); - // A buffer twice the size should be sufficient for all models, if this is not the case for a new model, we can increase it - // The error message informs the user to contact a maintainer - if res > buff.len() as i32 { - return Err(ApplyChatTemplateError::BuffSizeError); - } - Ok::( - CStr::from_ptr(buff.as_mut_ptr()) - .to_string_lossy() - .to_string(), + buff.as_mut_ptr().cast::(), + buff.len().try_into().expect("Buffer size exceeds i32::MAX"), ) - }?; - Ok(formatted_chat) + }; + + if res > buff.len().try_into().expect("Buffer size exceeds i32::MAX") { + buff.resize(res.try_into().expect("res is negative"), 0); + + let res = unsafe { + llama_cpp_sys_2::llama_chat_apply_template( + tmpl_ptr, + chat.as_ptr(), + chat.len(), + add_ass, + buff.as_mut_ptr().cast::(), + buff.len().try_into().expect("Buffer size exceeds i32::MAX"), + ) + }; + assert_eq!(Ok(res), buff.len().try_into()); + } + buff.truncate(res.try_into().expect("res is negative")); + Ok(String::from_utf8(buff)?) + } +} + +/// Generic helper function for extracting string values from the C API +/// This are specifically useful for the the metadata functions, where we pass in a buffer +/// to be populated by a string, not yet knowing if the buffer is large enough. +/// If the buffer was not large enough, we get the correct length back, which can be used to +/// construct a buffer of appropriate size. +fn extract_meta_string(c_function: F, capacity: usize) -> Result +where + F: Fn(*mut c_char, usize) -> i32, +{ + let mut buffer = vec![0u8; capacity]; + + // call the foreign function + let result = c_function(buffer.as_mut_ptr() as *mut c_char, buffer.len()); + if result < 0 { + return Err(MetaValError::NegativeReturn(result)); + } + + // check if the response fit in our buffer + let returned_len = result as usize; + if returned_len >= capacity { + // buffer wasn't large enough, try again with the correct capacity. + return extract_meta_string(c_function, returned_len + 1); } + + // verify null termination + debug_assert_eq!( + buffer.get(returned_len), + Some(&0), + "should end with null byte" + ); + + // resize, convert, and return + buffer.truncate(returned_len); + Ok(String::from_utf8(buffer)?) 
} impl Drop for LlamaModel { diff --git a/llama-cpp-2/src/model/params/kv_overrides.rs b/llama-cpp-2/src/model/params/kv_overrides.rs index 8bbcbdd4..b17516a1 100644 --- a/llama-cpp-2/src/model/params/kv_overrides.rs +++ b/llama-cpp-2/src/model/params/kv_overrides.rs @@ -104,7 +104,7 @@ pub struct KvOverrideValueIterator<'a> { current: usize, } -impl<'a> Iterator for KvOverrideValueIterator<'a> { +impl Iterator for KvOverrideValueIterator<'_> { type Item = (CString, ParamOverrideValue); fn next(&mut self) -> Option { diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs new file mode 100644 index 00000000..96feb402 --- /dev/null +++ b/llama-cpp-2/src/sampling.rs @@ -0,0 +1,520 @@ +//! Safe wrapper around `llama_sampler`. + +use std::borrow::Borrow; +use std::ffi::{c_char, CString}; +use std::fmt::{Debug, Formatter}; + +use crate::context::LlamaContext; +use crate::model::LlamaModel; +use crate::token::data_array::LlamaTokenDataArray; +use crate::token::logit_bias::LlamaLogitBias; +use crate::token::LlamaToken; + +/// A safe wrapper around `llama_sampler`. +pub struct LlamaSampler { + pub(crate) sampler: *mut llama_cpp_sys_2::llama_sampler, +} + +impl Debug for LlamaSampler { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LlamaSamplerChain").finish() + } +} + +impl LlamaSampler { + /// Sample and accept a token from the idx-th output of the last evaluation + #[must_use] + pub fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken { + let token = unsafe { + llama_cpp_sys_2::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx) + }; + + LlamaToken(token) + } + + /// Applies this sampler to a [`LlamaTokenDataArray`]. + pub fn apply(&self, data_array: &mut LlamaTokenDataArray) { + data_array.apply_sampler(self); + } + + /// Accepts a token from the sampler, possibly updating the internal state of certain samplers + /// (e.g. grammar, repetition, etc.) + pub fn accept(&mut self, token: LlamaToken) { + unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.0) } + } + + /// Accepts several tokens from the sampler or context, possibly updating the internal state of + /// certain samplers (e.g. grammar, repetition, etc.) + pub fn accept_many(&mut self, tokens: impl IntoIterator>) { + for token in tokens { + unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.borrow().0) } + } + } + + /// Accepts several tokens from the sampler or context, possibly updating the internal state of + /// certain samplers (e.g. grammar, repetition, etc.) + #[must_use] + pub fn with_tokens( + mut self, + tokens: impl IntoIterator>, + ) -> Self { + self.accept_many(tokens); + self + } + + /// Resets the internal state of the sampler. + /// + /// This can be useful when you want to start fresh with a sampler without creating a new instance. + pub fn reset(&mut self) { + unsafe { + llama_cpp_sys_2::llama_sampler_reset(self.sampler); + } + } + + /// Gets the random seed used by this sampler. + /// + /// Returns: + /// - For random samplers (dist, mirostat, mirostat_v2): returns their current seed + /// - For sampler chains: returns the first non-default seed found in reverse order + /// - For all other samplers: returns 0xFFFFFFFF + #[must_use] + pub fn get_seed(&self) -> u32 { + unsafe { llama_cpp_sys_2::llama_sampler_get_seed(self.sampler) } + } + + /// Combines a list of samplers into a single sampler that applies each component sampler one + /// after another. 
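A sketch of a typical filter-then-select chain built from the constructors defined below, fed with the generation history so the stateful `penalties` sampler has something to penalize. The concrete parameter values, seed, and function name are arbitrary assumptions.

```rust
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::token::LlamaToken;
use llama_cpp_2::token::data::LlamaTokenData;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;

fn pick_next_token(
    candidate_data: Vec<LlamaTokenData>,
    history: &[LlamaToken],
) -> Option<LlamaToken> {
    let mut sampler = LlamaSampler::chain_simple([
        LlamaSampler::penalties(64, 1.1, 0.0, 0.0), // discourage recent repeats
        LlamaSampler::top_k(40),
        LlamaSampler::min_p(0.05, 1),
        LlamaSampler::temp(0.8),
        LlamaSampler::dist(1234), // the chain must end in a selecting sampler
    ])
    .with_tokens(history.iter().copied());

    let mut candidates = LlamaTokenDataArray::new(candidate_data, false);
    candidates.apply_sampler(&sampler);

    let picked = candidates.selected_token();
    if let Some(token) = picked {
        // Keep the chain's internal state (penalty window, etc.) in sync.
        sampler.accept(token);
    }
    picked
}
```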
+ /// + /// If you are using a chain to select a token, the chain should always end with one of + /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and + /// [`LlamaSampler::mirostat_v2`]. + #[must_use] + pub fn chain(samplers: impl IntoIterator, no_perf: bool) -> Self { + unsafe { + let chain = llama_cpp_sys_2::llama_sampler_chain_init( + llama_cpp_sys_2::llama_sampler_chain_params { no_perf }, + ); + + for sampler in samplers { + llama_cpp_sys_2::llama_sampler_chain_add(chain, sampler.sampler); + + // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now + // owned by the chain + std::mem::forget(sampler); + } + + Self { sampler: chain } + } + } + + /// Same as [`Self::chain`] with `no_perf = false`. + /// + /// # Example + /// ```rust + /// use llama_cpp_2::token::{ + /// LlamaToken, + /// data::LlamaTokenData, + /// data_array::LlamaTokenDataArray + /// }; + /// use llama_cpp_2::sampling::LlamaSampler; + /// use llama_cpp_2::llama_backend::LlamaBackend; + /// let backend = LlamaBackend::init().unwrap(); + /// + /// let mut data_array = LlamaTokenDataArray::new(vec![ + /// LlamaTokenData::new(LlamaToken(0), 0., 0.), + /// LlamaTokenData::new(LlamaToken(1), 1., 0.), + /// LlamaTokenData::new(LlamaToken(2), 2., 0.), + /// ], false); + /// + /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([ + /// LlamaSampler::temp(0.5), + /// LlamaSampler::greedy(), + /// ])); + /// + /// assert_eq!(data_array.data[0].logit(), 0.); + /// assert_eq!(data_array.data[1].logit(), 2.); + /// assert_eq!(data_array.data[2].logit(), 4.); + /// + /// assert_eq!(data_array.data.len(), 3); + /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2))); + /// ``` + #[must_use] + pub fn chain_simple(samplers: impl IntoIterator) -> Self { + Self::chain(samplers, false) + } + + #[allow(clippy::doc_markdown)] + /// Updates the logits l_i' = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original + /// value, the rest are set to -inf + /// + /// # Example: + /// ```rust + /// use llama_cpp_2::token::{ + /// LlamaToken, + /// data::LlamaTokenData, + /// data_array::LlamaTokenDataArray + /// }; + /// use llama_cpp_2::sampling::LlamaSampler; + /// + /// let mut data_array = LlamaTokenDataArray::new(vec![ + /// LlamaTokenData::new(LlamaToken(0), 0., 0.), + /// LlamaTokenData::new(LlamaToken(1), 1., 0.), + /// LlamaTokenData::new(LlamaToken(2), 2., 0.), + /// ], false); + /// + /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5)); + /// + /// assert_eq!(data_array.data[0].logit(), 0.); + /// assert_eq!(data_array.data[1].logit(), 2.); + /// assert_eq!(data_array.data[2].logit(), 4.); + /// ``` + #[must_use] + pub fn temp(t: f32) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp(t) }; + Self { sampler } + } + + /// Dynamic temperature implementation (a.k.a. entropy) described in the paper + /// . 
+ #[must_use] + pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) }; + Self { sampler } + } + + /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" + /// + /// + /// # Example: + /// ```rust + /// use llama_cpp_2::token::{ + /// LlamaToken, + /// data::LlamaTokenData, + /// data_array::LlamaTokenDataArray + /// }; + /// use llama_cpp_2::sampling::LlamaSampler; + /// + /// let mut data_array = LlamaTokenDataArray::new(vec![ + /// LlamaTokenData::new(LlamaToken(0), 0., 0.), + /// LlamaTokenData::new(LlamaToken(1), 1., 0.), + /// LlamaTokenData::new(LlamaToken(2), 2., 0.), + /// LlamaTokenData::new(LlamaToken(3), 3., 0.), + /// ], false); + /// + /// data_array.apply_sampler(&mut LlamaSampler::top_k(2)); + /// + /// assert_eq!(data_array.data.len(), 2); + /// assert_eq!(data_array.data[0].id(), LlamaToken(3)); + /// assert_eq!(data_array.data[1].id(), LlamaToken(2)); + /// ``` + #[must_use] + pub fn top_k(k: i32) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_k(k) }; + Self { sampler } + } + + /// Top-nσ sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" + /// + /// + /// This method filters logits by selecting only those within *n* standard deviations of the mean. + /// + /// # Parameters + /// - `n`: Number of standard deviations from the mean to include in sampling + /// + /// # Example + /// ```rust + /// use llama_cpp_2::sampling::LlamaSampler; + /// use llama_cpp_2::token::{ + /// LlamaToken, + /// data::LlamaTokenData, + /// data_array::LlamaTokenDataArray + /// }; + /// + /// let mut data_array = LlamaTokenDataArray::new(vec![ + /// LlamaTokenData::new(LlamaToken(0), 0.0, 0.0), + /// LlamaTokenData::new(LlamaToken(1), 1.0, 0.0), + /// LlamaTokenData::new(LlamaToken(2), 2.0, 0.0), + /// ], false); + /// + /// data_array.apply_sampler(&mut LlamaSampler::top_n_sigma(2.0)); + /// ``` + #[must_use] + pub fn top_n_sigma(n: f32) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_n_sigma(n) }; + Self { sampler } + } + + /// Locally Typical Sampling implementation described in the paper . + #[must_use] + pub fn typical(p: f32, min_keep: usize) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) }; + Self { sampler } + } + + /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" + /// + #[must_use] + pub fn top_p(p: f32, min_keep: usize) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep) }; + Self { sampler } + } + + /// Minimum P sampling as described in + #[must_use] + pub fn min_p(p: f32, min_keep: usize) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep) }; + Self { sampler } + } + + /// XTC sampler as described in + #[must_use] + pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed) }; + Self { sampler } + } + + /// Grammar sampler + /// + /// # Panics + /// If either of ``grammar_str`` or ``grammar_root`` contain null bytes. 
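A sketch of constraining output with the grammar sampler defined below: a tiny GBNF grammar restricted to a yes/no answer, terminated by `dist` so the chain still selects a token. The grammar text and function name are illustrative assumptions.

```rust
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::sampling::LlamaSampler;

// Only tokens that can continue a valid "yes" or "no" survive the grammar filter.
fn yes_no_sampler(model: &LlamaModel, seed: u32) -> LlamaSampler {
    let grammar = r#"root ::= "yes" | "no""#;
    LlamaSampler::chain_simple([
        LlamaSampler::grammar(model, grammar, "root"),
        LlamaSampler::dist(seed),
    ])
}
```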
+ #[must_use] + pub fn grammar(model: &LlamaModel, grammar_str: &str, grammar_root: &str) -> Self { + let grammar_str = CString::new(grammar_str).unwrap(); + let grammar_root = CString::new(grammar_root).unwrap(); + + let sampler = unsafe { + llama_cpp_sys_2::llama_sampler_init_grammar( + model.vocab_ptr(), + grammar_str.as_ptr(), + grammar_root.as_ptr(), + ) + }; + Self { sampler } + } + + /// Lazy grammar sampler, introduced in + /// + /// This sampler enforces grammar rules only when specific trigger words or tokens are encountered. + /// + /// # Panics + /// - If `grammar_str` or `grammar_root` contain null bytes + /// - If any trigger word contains null bytes + #[must_use] + pub fn grammar_lazy( + model: &LlamaModel, + grammar_str: &str, + grammar_root: &str, + trigger_words: impl IntoIterator>, + trigger_tokens: &[LlamaToken], + ) -> Self { + let grammar_str = CString::new(grammar_str).unwrap(); + let grammar_root = CString::new(grammar_root).unwrap(); + + let trigger_word_cstrings: Vec = trigger_words + .into_iter() + .map(|word| CString::new(word.as_ref()).unwrap()) + .collect(); + + let mut trigger_word_ptrs: Vec<*const c_char> = trigger_word_cstrings + .iter() + .map(|cs| cs.as_ptr()) + .collect(); + + let sampler = unsafe { + llama_cpp_sys_2::llama_sampler_init_grammar_lazy( + model.vocab_ptr(), + grammar_str.as_ptr(), + grammar_root.as_ptr(), + trigger_word_ptrs.as_mut_ptr(), + trigger_word_ptrs.len(), + trigger_tokens.as_ptr().cast(), + trigger_tokens.len(), + ) + }; + + Self { sampler } + } + + /// DRY sampler, designed by p-e-w, as described in: + /// , porting Koboldcpp + /// implementation authored by pi6am: + /// + /// # Panics + /// If any string in ``seq_breakers`` contains null bytes. + #[allow(missing_docs)] + #[must_use] + pub fn dry( + model: &LlamaModel, + multiplier: f32, + base: f32, + allowed_length: i32, + penalty_last_n: i32, + seq_breakers: impl IntoIterator>, + ) -> Self { + let seq_breakers: Vec = seq_breakers + .into_iter() + .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes")) + .collect(); + let mut seq_breaker_pointers: Vec<*const c_char> = + seq_breakers.iter().map(|s| s.as_ptr()).collect(); + + let sampler = unsafe { + llama_cpp_sys_2::llama_sampler_init_dry( + model.vocab_ptr(), + model + .n_ctx_train() + .try_into() + .expect("n_ctx_train exceeds i32::MAX"), + multiplier, + base, + allowed_length, + penalty_last_n, + seq_breaker_pointers.as_mut_ptr(), + seq_breaker_pointers.len(), + ) + }; + Self { sampler } + } + + /// Penalizes tokens for being present in the context. + /// + /// Parameters: + /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size) + /// - ``penalty_repeat``: 1.0 = disabled + /// - ``penalty_freq``: 0.0 = disabled + /// - ``penalty_present``: 0.0 = disabled + #[allow(clippy::too_many_arguments)] + #[must_use] + pub fn penalties( + penalty_last_n: i32, + penalty_repeat: f32, + penalty_freq: f32, + penalty_present: f32, + ) -> Self { + let sampler = unsafe { + llama_cpp_sys_2::llama_sampler_init_penalties( + penalty_last_n, + penalty_repeat, + penalty_freq, + penalty_present, + ) + }; + Self { sampler } + } + + /// Mirostat 1.0 algorithm described in the paper . Uses tokens instead of words. + /// + /// # Parameters: + /// - ``n_vocab``: [`LlamaModel::n_vocab`] + /// - ``seed``: Seed to initialize random generation with. + /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the + /// generated text. 
A higher value corresponds to more surprising or less predictable text, + /// while a lower value corresponds to less surprising or more predictable text. + /// - ``eta``: The learning rate used to update `mu` based on the error between the target and + /// observed surprisal of the sampled word. A larger learning rate will cause `mu` to be + /// updated more quickly, while a smaller learning rate will result in slower updates. + /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary + /// value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. + /// In the paper, they use `m = 100`, but you can experiment with different values to see how + /// it affects the performance of the algorithm. + #[must_use] + pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self { + let sampler = + unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) }; + Self { sampler } + } + + /// Mirostat 2.0 algorithm described in the paper . Uses tokens instead of words. + /// + /// # Parameters: + /// - ``seed``: Seed to initialize random generation with. + /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the + /// generated text. A higher value corresponds to more surprising or less predictable text, + /// while a lower value corresponds to less surprising or more predictable text. + /// - ``eta``: The learning rate used to update `mu` based on the error between the target and + /// observed surprisal of the sampled word. A larger learning rate will cause `mu` to be + /// updated more quickly, while a smaller learning rate will result in slower updates. + #[must_use] + pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta) }; + Self { sampler } + } + + /// Selects a token at random based on each token's probabilities + #[must_use] + pub fn dist(seed: u32) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_dist(seed) }; + Self { sampler } + } + + /// Selects the most likely token + /// + /// # Example: + /// ```rust + /// use llama_cpp_2::token::{ + /// LlamaToken, + /// data::LlamaTokenData, + /// data_array::LlamaTokenDataArray + /// }; + /// use llama_cpp_2::sampling::LlamaSampler; + /// + /// let mut data_array = LlamaTokenDataArray::new(vec![ + /// LlamaTokenData::new(LlamaToken(0), 0., 0.), + /// LlamaTokenData::new(LlamaToken(1), 1., 0.), + /// ], false); + /// + /// data_array.apply_sampler(&mut LlamaSampler::greedy()); + /// + /// assert_eq!(data_array.data.len(), 2); + /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1))); + /// ``` + #[must_use] + pub fn greedy() -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_greedy() }; + Self { sampler } + } + + /// Creates a sampler that applies bias values to specific tokens during sampling. 
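Because `dist` is seeded, two samplers created with the same seed make the same choice on identical candidates, which is handy for reproducible tests. A sketch under that assumption; backend initialization mirrors the doctests above and the logit values are arbitrary.

```rust
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::token::LlamaToken;
use llama_cpp_2::token::data::LlamaTokenData;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;

fn main() {
    let _backend = LlamaBackend::init().expect("backend init failed");

    let make = || {
        LlamaTokenDataArray::new(
            vec![
                LlamaTokenData::new(LlamaToken(0), 1.0, 0.0),
                LlamaTokenData::new(LlamaToken(1), 2.0, 0.0),
                LlamaTokenData::new(LlamaToken(2), 3.0, 0.0),
            ],
            false,
        )
    };

    let sampler = LlamaSampler::dist(42);
    assert_eq!(sampler.get_seed(), 42);

    // Identically seeded samplers applied to identical candidates select the same token.
    let mut a = make();
    let mut b = make();
    a.apply_sampler(&LlamaSampler::dist(42));
    b.apply_sampler(&LlamaSampler::dist(42));
    assert_eq!(a.selected_token(), b.selected_token());
}
```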
+ /// + /// # Parameters + /// - ``n_vocab``: [`LlamaModel::n_vocab`] + /// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs + /// + /// # Example + /// ```rust + /// use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; + /// use llama_cpp_2::sampling::LlamaSampler; + /// + /// let biases = vec![ + /// LlamaLogitBias::new(LlamaToken(1), 1.5), // Increase probability of token 1 + /// LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2 + /// ]; + /// + /// // Assuming vocab_size of 32000 + /// let sampler = LlamaSampler::logit_bias(32000, &biases); + /// ``` + #[must_use] + pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Self { + + let data = biases.as_ptr().cast::(); + + let sampler = unsafe { + llama_cpp_sys_2::llama_sampler_init_logit_bias( + n_vocab, + biases.len() as i32, + data, + ) + }; + + Self { sampler } + } + +} + +impl Drop for LlamaSampler { + fn drop(&mut self) { + unsafe { + llama_cpp_sys_2::llama_sampler_free(self.sampler); + } + } +} diff --git a/llama-cpp-2/src/timing.rs b/llama-cpp-2/src/timing.rs index 51cf682a..b45d9318 100644 --- a/llama-cpp-2/src/timing.rs +++ b/llama-cpp-2/src/timing.rs @@ -4,43 +4,35 @@ use std::fmt::{Debug, Display, Formatter}; /// A wrapper around `llama_timings`. #[derive(Clone, Copy, Debug)] pub struct LlamaTimings { - pub(crate) timings: llama_cpp_sys_2::llama_timings, + pub(crate) timings: llama_cpp_sys_2::llama_perf_context_data, } impl LlamaTimings { /// Create a new `LlamaTimings`. /// ``` /// # use llama_cpp_2::timing::LlamaTimings; - /// let timings = LlamaTimings::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7, 8, 9); - /// let timings_str = "load time = 3.00 ms - /// sample time = 4.00 ms / 7 runs (0.57 ms per token, 1750.00 tokens per second) - /// prompt eval time = 5.00 ms / 8 tokens (0.62 ms per token, 1600.00 tokens per second) - /// eval time = 6.00 ms / 9 runs (0.67 ms per token, 1500.00 tokens per second) - /// total time = 1.00 ms"; + /// let timings = LlamaTimings::new(1.0, 2.0, 3.0, 4.0, 5, 6); + /// let timings_str = "load time = 2.00 ms + /// prompt eval time = 3.00 ms / 5 tokens (0.60 ms per token, 1666.67 tokens per second) + /// eval time = 4.00 ms / 6 runs (0.67 ms per token, 1500.00 tokens per second)\n"; /// assert_eq!(timings_str, format!("{}", timings)); /// ``` #[allow(clippy::too_many_arguments)] #[must_use] pub fn new( t_start_ms: f64, - t_end_ms: f64, t_load_ms: f64, - t_sample_ms: f64, t_p_eval_ms: f64, t_eval_ms: f64, - n_sample: i32, n_p_eval: i32, n_eval: i32, ) -> Self { Self { - timings: llama_cpp_sys_2::llama_timings { + timings: llama_cpp_sys_2::llama_perf_context_data { t_start_ms, - t_end_ms, t_load_ms, - t_sample_ms, t_p_eval_ms, t_eval_ms, - n_sample, n_p_eval, n_eval, }, @@ -53,24 +45,12 @@ impl LlamaTimings { self.timings.t_start_ms } - /// Get the end time in milliseconds. - #[must_use] - pub fn t_end_ms(&self) -> f64 { - self.timings.t_end_ms - } - /// Get the load time in milliseconds. #[must_use] pub fn t_load_ms(&self) -> f64 { self.timings.t_load_ms } - /// Get the sample time in milliseconds. - #[must_use] - pub fn t_sample_ms(&self) -> f64 { - self.timings.t_sample_ms - } - /// Get the prompt evaluation time in milliseconds. #[must_use] pub fn t_p_eval_ms(&self) -> f64 { @@ -83,12 +63,6 @@ impl LlamaTimings { self.timings.t_eval_ms } - /// Get the number of samples. - #[must_use] - pub fn n_sample(&self) -> i32 { - self.timings.n_sample - } - /// Get the number of prompt evaluations. 
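With the sampling counters gone from `llama_perf_context_data`, throughput has to be derived from the remaining fields. A small helper, assuming a `LlamaTimings` obtained from a context elsewhere; the function name is an assumption.

```rust
use llama_cpp_2::timing::LlamaTimings;

// Decode throughput in tokens per second, matching the Display output below.
fn eval_tokens_per_second(timings: &LlamaTimings) -> f64 {
    1e3 / timings.t_eval_ms() * f64::from(timings.n_eval())
}
```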
#[must_use] pub fn n_p_eval(&self) -> i32 { @@ -106,21 +80,11 @@ impl LlamaTimings { self.timings.t_start_ms = t_start_ms; } - /// Set the end time in milliseconds. - pub fn set_t_end_ms(&mut self, t_end_ms: f64) { - self.timings.t_end_ms = t_end_ms; - } - /// Set the load time in milliseconds. pub fn set_t_load_ms(&mut self, t_load_ms: f64) { self.timings.t_load_ms = t_load_ms; } - /// Set the sample time in milliseconds. - pub fn set_t_sample_ms(&mut self, t_sample_ms: f64) { - self.timings.t_sample_ms = t_sample_ms; - } - /// Set the prompt evaluation time in milliseconds. pub fn set_t_p_eval_ms(&mut self, t_p_eval_ms: f64) { self.timings.t_p_eval_ms = t_p_eval_ms; @@ -131,11 +95,6 @@ impl LlamaTimings { self.timings.t_eval_ms = t_eval_ms; } - /// Set the number of samples. - pub fn set_n_sample(&mut self, n_sample: i32) { - self.timings.n_sample = n_sample; - } - /// Set the number of prompt evaluations. pub fn set_n_p_eval(&mut self, n_p_eval: i32) { self.timings.n_p_eval = n_p_eval; @@ -150,14 +109,6 @@ impl LlamaTimings { impl Display for LlamaTimings { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { writeln!(f, "load time = {:.2} ms", self.t_load_ms())?; - writeln!( - f, - "sample time = {:.2} ms / {} runs ({:.2} ms per token, {:.2} tokens per second)", - self.t_sample_ms(), - self.n_sample(), - self.t_sample_ms() / f64::from(self.n_sample()), - 1e3 / self.t_sample_ms() * f64::from(self.n_sample()) - )?; writeln!( f, "prompt eval time = {:.2} ms / {} tokens ({:.2} ms per token, {:.2} tokens per second)", @@ -174,10 +125,6 @@ impl Display for LlamaTimings { self.t_eval_ms() / f64::from(self.n_eval()), 1e3 / self.t_eval_ms() * f64::from(self.n_eval()) )?; - write!( - f, - "total time = {:.2} ms", - self.t_end_ms() - self.t_start_ms() - ) + Ok(()) } } diff --git a/llama-cpp-2/src/token.rs b/llama-cpp-2/src/token.rs index 3019420d..abb4fbbf 100644 --- a/llama-cpp-2/src/token.rs +++ b/llama-cpp-2/src/token.rs @@ -5,6 +5,7 @@ use std::fmt::Display; pub mod data; pub mod data_array; +pub mod logit_bias; /// A safe wrapper for `llama_token`. #[repr(transparent)] diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs index e81ab336..448864b9 100644 --- a/llama-cpp-2/src/token/data_array.rs +++ b/llama-cpp-2/src/token/data_array.rs @@ -1,23 +1,24 @@ -//! an rusty equivalent of `llama_token_data`. -use crate::context::LlamaContext; -use crate::token::data::LlamaTokenData; -use crate::token::LlamaToken; -use llama_cpp_sys_2::llama_token; -use std::cmp::min; +//! an rusty equivalent of `llama_token_data_array`. use std::ptr; +use crate::{sampling::LlamaSampler, token::data::LlamaTokenData}; + +use super::LlamaToken; + /// a safe wrapper around `llama_token_data_array`. #[derive(Debug, Clone, PartialEq)] #[allow(clippy::module_name_repetitions)] pub struct LlamaTokenDataArray { /// the underlying data pub data: Vec, + /// the index of the selected token in ``data`` + pub selected: Option, /// is the data sorted? pub sorted: bool, } impl LlamaTokenDataArray { - /// Create a new `LlamaTokenDataArray` from a vector and weather or not the data is sorted. + /// Create a new `LlamaTokenDataArray` from a vector and whether or not the data is sorted. 
/// /// ``` /// # use llama_cpp_2::token::data::LlamaTokenData; @@ -32,10 +33,14 @@ impl LlamaTokenDataArray { /// ``` #[must_use] pub fn new(data: Vec, sorted: bool) -> Self { - Self { data, sorted } + Self { + data, + selected: None, + sorted, + } } - /// Create a new `LlamaTokenDataArray` from an iterator and weather or not the data is sorted. + /// Create a new `LlamaTokenDataArray` from an iterator and whether or not the data is sorted. /// ``` /// # use llama_cpp_2::token::data::LlamaTokenData; /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; @@ -52,6 +57,12 @@ impl LlamaTokenDataArray { { Self::new(data.into_iter().collect(), sorted) } + + /// Returns the current selected token, if one exists. + #[must_use] + pub fn selected_token(&self) -> Option { + self.data.get(self.selected?).map(LlamaTokenData::id) + } } impl LlamaTokenDataArray { @@ -59,352 +70,89 @@ impl LlamaTokenDataArray { /// /// # Panics /// - /// Panics if some of the safety conditions are not met. (we cannot check all of them at runtime so breaking them is UB) + /// Panics if some of the safety conditions are not met. (we cannot check all of them at + /// runtime so breaking them is UB) /// /// SAFETY: - /// [modify] cannot change the data pointer. + /// The returned array formed by the data pointer and the length must entirely consist of + /// initialized token data and the length must be less than the capacity of this array's data + /// buffer. /// if the data is not sorted, sorted must be false. - /// the size of the data can only decrease (i.e you cannot add new elements). pub(crate) unsafe fn modify_as_c_llama_token_data_array( &mut self, modify: impl FnOnce(&mut llama_cpp_sys_2::llama_token_data_array) -> T, ) -> T { let size = self.data.len(); - let data = self.data.as_mut_ptr().cast(); + let data = self + .data + .as_mut_ptr() + .cast::(); + let mut c_llama_token_data_array = llama_cpp_sys_2::llama_token_data_array { data, size, + selected: self.selected.and_then(|s| s.try_into().ok()).unwrap_or(-1), sorted: self.sorted, }; + let result = modify(&mut c_llama_token_data_array); + assert!( - ptr::eq(data, c_llama_token_data_array.data), - "data pointer changed" + c_llama_token_data_array.size <= self.data.capacity(), + "Size of the returned array exceeds the data buffer's capacity!" ); - assert!(c_llama_token_data_array.size <= size, "size increased"); + if !ptr::eq(c_llama_token_data_array.data, data) { + ptr::copy( + c_llama_token_data_array.data, + data, + c_llama_token_data_array.size, + ); + } self.data.set_len(c_llama_token_data_array.size); + self.sorted = c_llama_token_data_array.sorted; + self.selected = c_llama_token_data_array + .selected + .try_into() + .ok() + .filter(|&s| s < self.data.len()); + result } - /// Repetition penalty described in [CTRL academic paper](https://arxiv.org/abs/1909.05858), with negative logit fix. - /// Frequency and presence penalties described in [OpenAI API](https://platform.openai.com/docs/api-reference/parameter-details). - /// - /// # Parameters - /// - /// * `ctx` - the context to use. May be `None` if you do not care to record the sample timings. - /// * `last_tokens` - the last tokens in the context. - /// - /// * `penalty_last_n` - the number of tokens back to consider for the repetition penalty. (0 for no penalty) - /// * `penalty_repeat` - the repetition penalty. (1.0 for no penalty) - /// * `penalty_freq` - the frequency penalty. (0.0 for no penalty) - /// * `penalty_present` - the presence penalty. 
(0.0 for no penalty) - /// - /// # Example - /// - /// ```rust - /// # use std::collections::BTreeMap; - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// let history = vec![ - /// LlamaToken::new(2), - /// LlamaToken::new(1), - /// LlamaToken::new(0), - /// ]; - /// - /// let candidates = vec![ - /// LlamaToken::new(0), - /// LlamaToken::new(1), - /// LlamaToken::new(2), - /// LlamaToken::new(3), - /// ]; - /// - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates.iter().map(|&token| LlamaTokenData::new(token, 0.0, 0.0)), false); - /// - /// candidates.sample_repetition_penalty(None, &history, 2, 1.1, 0.1, 0.1); - /// - /// let token_logits = candidates.data.into_iter().map(|token_data| (token_data.id(), token_data.logit())).collect::>(); - /// assert_eq!(token_logits[&LlamaToken(0)], 0.0, "expected no penalty as it is out of `penalty_last_n`"); - /// assert!(token_logits[&LlamaToken(1)] < 0.0, "expected penalty as it is in `penalty_last_n`"); - /// assert!(token_logits[&LlamaToken(2)] < 0.0, "expected penalty as it is in `penalty_last_n`"); - /// assert_eq!(token_logits[&LlamaToken(3)], 0.0, "expected no penalty as it is not in `history`"); - /// ``` - pub fn sample_repetition_penalty( - &mut self, - ctx: Option<&mut LlamaContext>, - last_tokens: &[LlamaToken], - penalty_last_n: usize, - penalty_repeat: f32, - penalty_freq: f32, - penalty_present: f32, - ) { - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - let penalty_last_n = min(penalty_last_n, last_tokens.len().saturating_sub(1)); + /// Modifies the data array by applying a sampler to it + pub fn apply_sampler(&mut self, sampler: &LlamaSampler) { unsafe { self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_repetition_penalties( - ctx, - c_llama_token_data_array, - // safe cast as LlamaToken is repr(transparent) - last_tokens.as_ptr().cast::(), - penalty_last_n, - penalty_repeat, - penalty_freq, - penalty_present, - ); + llama_cpp_sys_2::llama_sampler_apply(sampler.sampler, c_llama_token_data_array); }); } } - /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. 
- /// - /// # Example - /// - /// ```rust - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// - /// let lowest = LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0); - /// let middle = LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0); - /// let highest = LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0); - /// - /// let candidates = vec![lowest, middle, highest]; - /// - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); - /// candidates.sample_softmax(None); - /// - /// assert!(candidates.sorted); - /// assert_eq!(candidates.data[0].id(), highest.id()); - /// assert_eq!(candidates.data[0].logit(), highest.logit()); - /// assert!(candidates.data[0].p() > candidates.data[1].p()); - /// assert_eq!(candidates.data[1].id(), middle.id()); - /// assert_eq!(candidates.data[1].logit(), middle.logit()); - /// assert!(candidates.data[1].p() > candidates.data[2].p()); - /// assert_eq!(candidates.data[2].id(), lowest.id()); - /// assert_eq!(candidates.data[2].logit(), lowest.logit()); - /// ``` - pub fn sample_softmax(&mut self, ctx: Option<&mut LlamaContext>) { - unsafe { - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_softmax(ctx, c_llama_token_data_array); - }); - } - } - - /// Modify the logits of [`Self`] in place using temperature sampling. - /// - /// # Example - /// - /// ```rust - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// - /// let candidates = vec![ - /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), - /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), - /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0) - /// ]; - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); - /// - /// candidates.sample_temp(None, 0.5); - /// - /// assert_ne!(candidates.data[0].logit(), 0.1); - /// assert_ne!(candidates.data[1].logit(), 0.2); - /// assert_ne!(candidates.data[2].logit(), 0.7); - /// ``` - pub fn sample_temp(&mut self, ctx: Option<&mut LlamaContext>, temperature: f32) { - if temperature == 0.0 { - return; - } - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_temp(ctx, c_llama_token_data_array, temperature); - }); - } + /// Modifies the data array by applying a sampler to it + #[must_use] + pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self { + self.apply_sampler(sampler); + self } /// Randomly selects a token from the candidates based on their probabilities. 
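A sketch of the migration path away from the removed context-based `sample_*` helpers: build the array, run a sampler chain over it with `with_sampler`, then pick with the new `sample_token_greedy` / `sample_token` shown below. Logit values and the seed are arbitrary, and backend setup mirrors the doctests.

```rust
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::token::LlamaToken;
use llama_cpp_2::token::data::LlamaTokenData;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;

fn main() {
    let _backend = LlamaBackend::init().expect("backend init failed");

    let candidates = LlamaTokenDataArray::from_iter(
        (0..4).map(|i| LlamaTokenData::new(LlamaToken(i), i as f32, 0.0)),
        false,
    );

    // Filter with a chained sampler instead of the removed `sample_top_k`/`sample_temp`.
    let mut candidates = candidates.with_sampler(&mut LlamaSampler::chain_simple([
        LlamaSampler::top_k(2),
        LlamaSampler::temp(0.7),
    ]));

    // Greedy pick: token 3 has the largest logit and survives top-k.
    assert_eq!(candidates.sample_token_greedy(), LlamaToken(3));

    // Seeded random pick over the remaining candidates.
    let picked = candidates.sample_token(1234);
    println!("sampled {picked:?}");
}
```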
- pub fn sample_token(&mut self, ctx: &mut LlamaContext) -> LlamaToken { - let llama_token = unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_token(ctx.context.as_ptr(), c_llama_token_data_array) - }) - }; - LlamaToken(llama_token) - } - - /// Top-K sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751) - pub fn sample_top_k(&mut self, ctx: Option<&mut LlamaContext>, k: i32, min_keep: usize) { - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_top_k(ctx, c_llama_token_data_array, k, min_keep); - }); - } - } - - /// Tail Free Sampling described in [Tail-Free-Sampling](https://www.trentonbricken.com/Tail-Free-Sampling/). - pub fn sample_tail_free(&mut self, ctx: Option<&mut LlamaContext>, z: f32, min_keep: usize) { - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_tail_free(ctx, c_llama_token_data_array, z, min_keep); - }); - } - } - - /// Locally Typical Sampling implementation described in the [paper](https://arxiv.org/abs/2202.00666). - /// - /// # Example - /// - /// ```rust - /// - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// - /// let candidates = vec![ - /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), - /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), - /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0), - /// ]; - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); - /// candidates.sample_typical(None, 0.5, 1); - /// - /// ``` - pub fn sample_typical(&mut self, ctx: Option<&mut LlamaContext>, p: f32, min_keep: usize) { - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_typical(ctx, c_llama_token_data_array, p, min_keep); - }); - } - } - - /// Nucleus sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751) - /// - /// # Example - /// - /// ```rust - /// - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// - /// let candidates = vec![ - /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), - /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), - /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0), - /// ]; - /// - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); - /// candidates.sample_top_p(None, 0.5, 1); /// - /// assert_eq!(candidates.data.len(), 2); - /// assert_eq!(candidates.data[0].id(), LlamaToken::new(2)); - /// assert_eq!(candidates.data[1].id(), LlamaToken::new(1)); - /// ``` - pub fn sample_top_p(&mut self, ctx: Option<&mut LlamaContext>, p: f32, min_keep: usize) { - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_top_p(ctx, c_llama_token_data_array, p, min_keep); - }); - } - } - - /// Minimum P sampling as described in 
[#3841](https://github.com/ggerganov/llama.cpp/pull/3841) - /// - /// # Example - /// - /// ``` - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// - /// let candidates = vec![ - /// LlamaTokenData::new(LlamaToken::new(4), 0.0001, 0.0), - /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), - /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), - /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0), - /// ]; - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); - /// candidates.sample_min_p(None, 0.05, 1); - /// ``` - pub fn sample_min_p(&mut self, ctx: Option<&mut LlamaContext>, p: f32, min_keep: usize) { - let ctx = ctx.map_or(ptr::null_mut(), |ctx| ctx.context.as_ptr()); - unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_min_p(ctx, c_llama_token_data_array, p, min_keep); - }); - } - } - - /// Mirostat 2.0 algorithm described in the [paper](https://arxiv.org/abs/2007.14966). Uses tokens instead of words. - /// - /// # Parameters - /// - /// * `tau` The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - /// * `eta` The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// * `mu` Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - pub fn sample_token_mirostat_v2( - &mut self, - ctx: &mut LlamaContext, - tau: f32, - eta: f32, - mu: &mut f32, - ) -> LlamaToken { - let mu_ptr = ptr::from_mut(mu); - let token = unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_token_mirostat_v2( - ctx.context.as_ptr(), - c_llama_token_data_array, - tau, - eta, - mu_ptr, - ) - }) - }; - *mu = unsafe { *mu_ptr }; - LlamaToken(token) + /// # Panics + /// If the internal llama.cpp sampler fails to select a token. + pub fn sample_token(&mut self, seed: u32) -> LlamaToken { + self.apply_sampler(&LlamaSampler::dist(seed)); + self.selected_token() + .expect("Dist sampler failed to select a token!") } - /// Mirostat 1.0 algorithm described in the [paper](https://arxiv.org/abs/2007.14966). Uses tokens instead of words. - /// - /// # Parameters + /// Selects the token with the highest probability. /// - /// * `tau` The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - /// * `eta` The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// * `m` The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. 
In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. - /// * `mu` Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - pub fn sample_token_mirostat_v1( - &mut self, - ctx: &mut LlamaContext, - tau: f32, - eta: f32, - m: i32, - mu: &mut f32, - ) -> LlamaToken { - let mu_ptr = ptr::from_mut(mu); - let token = unsafe { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sample_token_mirostat( - ctx.context.as_ptr(), - c_llama_token_data_array, - tau, - eta, - m, - mu_ptr, - ) - }) - }; - *mu = unsafe { *mu_ptr }; - LlamaToken(token) + /// # Panics + /// If the internal llama.cpp sampler fails to select a token. + pub fn sample_token_greedy(&mut self) -> LlamaToken { + self.apply_sampler(&LlamaSampler::greedy()); + self.selected_token() + .expect("Greedy sampler failed to select a token!") } } diff --git a/llama-cpp-2/src/token/logit_bias.rs b/llama-cpp-2/src/token/logit_bias.rs new file mode 100644 index 00000000..631c9395 --- /dev/null +++ b/llama-cpp-2/src/token/logit_bias.rs @@ -0,0 +1,93 @@ +//! Safe wrapper around `llama_logit_bias`. +use crate::token::LlamaToken; + +/// A transparent wrapper around `llama_logit_bias`. +/// +/// Represents a bias to be applied to a specific token during text generation. +/// The bias modifies the likelihood of the token being selected. +/// +/// Do not rely on `repr(transparent)` for this type. It should be considered an implementation +/// detail and may change across minor versions. +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(transparent)] +#[allow(clippy::module_name_repetitions)] +pub struct LlamaLogitBias { + logit_bias: llama_cpp_sys_2::llama_logit_bias, +} + +impl LlamaLogitBias { + /// Creates a new logit bias for a specific token with the given bias value. + /// + /// # Examples + /// ``` + /// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; + /// let token = LlamaToken::new(1); + /// let bias = LlamaLogitBias::new(token, 1.5); + /// ``` + #[must_use] + pub fn new(LlamaToken(token): LlamaToken, bias: f32) -> Self { + Self { + logit_bias: llama_cpp_sys_2::llama_logit_bias { + token, + bias, + }, + } + } + + /// Gets the token this bias applies to. + /// + /// # Examples + /// ``` + /// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; + /// let token = LlamaToken::new(1); + /// let bias = LlamaLogitBias::new(token, 1.5); + /// assert_eq!(bias.token(), token); + /// ``` + #[must_use] + pub fn token(&self) -> LlamaToken { + LlamaToken(self.logit_bias.token) + } + + /// Gets the bias value. + /// + /// # Examples + /// ``` + /// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; + /// let token = LlamaToken::new(1); + /// let bias = LlamaLogitBias::new(token, 1.5); + /// assert_eq!(bias.bias(), 1.5); + /// ``` + #[must_use] + pub fn bias(&self) -> f32 { + self.logit_bias.bias + } + + /// Sets the token this bias applies to. 
+ /// + /// # Examples + /// ``` + /// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; + /// let token = LlamaToken::new(1); + /// let mut bias = LlamaLogitBias::new(token, 1.5); + /// let new_token = LlamaToken::new(2); + /// bias.set_token(new_token); + /// assert_eq!(bias.token(), new_token); + /// ``` + pub fn set_token(&mut self, token: LlamaToken) { + self.logit_bias.token = token.0; + } + + /// Sets the bias value. + /// + /// # Examples + /// ``` + /// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; + /// let token = LlamaToken::new(1); + /// let mut bias = LlamaLogitBias::new(token, 1.5); + /// bias.set_bias(2.0); + /// assert_eq!(bias.bias(), 2.0); + /// ``` + pub fn set_bias(&mut self, bias: f32) { + self.logit_bias.bias = bias; + } +} \ No newline at end of file diff --git a/llama-cpp-sys-2/Cargo.toml b/llama-cpp-sys-2/Cargo.toml index 5e26631a..068204da 100644 --- a/llama-cpp-sys-2/Cargo.toml +++ b/llama-cpp-sys-2/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "llama-cpp-sys-2" description = "Low Level Bindings to llama.cpp" -version = "0.1.83" +version = "0.1.109" edition = "2021" license = "MIT OR Apache-2.0" repository = "https://github.com/utilityai/llama-cpp-rs" @@ -12,9 +12,9 @@ include = [ "build.rs", "/src", - "/llama.cpp/common/*.h", - "/llama.cpp/common/*.hpp", - "/llama.cpp/common/*.cpp", + "/llama.cpp/common/**/*.h", + "/llama.cpp/common/**/*.hpp", + "/llama.cpp/common/**/*.cpp", "/llama.cpp/ggml/include/*.h", "/llama.cpp/ggml/src/*.h", "/llama.cpp/ggml/src/*.c", @@ -31,10 +31,12 @@ include = [ "/llama.cpp/ggml/src/ggml-metal.metal", "/llama.cpp/include/llama.h", + "/llama.cpp/include/llama-cpp.h", + "/llama.cpp/ggml/src/ggml-cpu/**/*", "/llama.cpp/ggml/src/ggml-cuda/**/*", - - "/llama.cpp/ggml/src/vulkan-shaders/**/*", + "/llama.cpp/ggml/src/ggml-metal/**/*", + "/llama.cpp/ggml/src/ggml-vulkan/**/*", "/llama.cpp/ggml/src/llamafile/sgemm.h", "/llama.cpp/ggml/src/llamafile/sgemm.cpp", @@ -45,7 +47,6 @@ include = [ "/llama.cpp/common/CMakeLists.txt", "/llama.cpp/ggml/CMakeLists.txt", "/llama.cpp/ggml/src/CMakeLists.txt", - "/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt", "/llama.cpp/src/CMakeLists.txt", "/llama.cpp/cmake", @@ -61,12 +62,18 @@ include = [ bindgen = { workspace = true } cc = { workspace = true, features = ["parallel"] } cmake = "0.1" -glob = "0.3.1" +find_cuda_helper = "0.2.0" +glob = "0.3.2" +walkdir = "2" [features] cuda = [] +# Disables the need to dynamically link against libcuda.so / cuda.dll +cuda-no-vmm = ["cuda"] metal = [] dynamic-link = [] vulkan = [] native = [] openmp = [] +# Only has an impact on Android. +shared-stdcxx = [] diff --git a/llama-cpp-sys-2/build.rs b/llama-cpp-sys-2/build.rs index 33b0ee19..8ecfb3a9 100644 --- a/llama-cpp-sys-2/build.rs +++ b/llama-cpp-sys-2/build.rs @@ -3,6 +3,24 @@ use glob::glob; use std::env; use std::path::{Path, PathBuf}; use std::process::Command; +use walkdir::DirEntry; + +enum WindowsVariant { + Msvc, + Other, +} + +enum AppleVariant { + MacOS, + Other, +} + +enum TargetOs { + Windows(WindowsVariant), + Apple(AppleVariant), + Linux, + Android, +} macro_rules! debug_log { ($($arg:tt)*) => { @@ -12,41 +30,38 @@ macro_rules! 
debug_log { }; } -fn get_cargo_target_dir() -> Result<PathBuf, Box<dyn std::error::Error>> { - let out_dir = std::path::PathBuf::from(std::env::var("OUT_DIR")?); - let profile = std::env::var("PROFILE")?; - let mut target_dir = None; - let mut sub_path = out_dir.as_path(); - while let Some(parent) = sub_path.parent() { - if parent.ends_with(&profile) { - target_dir = Some(parent); - break; +fn parse_target_os() -> Result<(TargetOs, String), String> { + let target = env::var("TARGET").unwrap(); + + if target.contains("windows") { + if target.ends_with("-windows-msvc") { + Ok((TargetOs::Windows(WindowsVariant::Msvc), target)) + } else { + Ok((TargetOs::Windows(WindowsVariant::Other), target)) + } + } else if target.contains("apple") { + if target.ends_with("-apple-darwin") { + Ok((TargetOs::Apple(AppleVariant::MacOS), target)) + } else { + Ok((TargetOs::Apple(AppleVariant::Other), target)) } - sub_path = parent; + } else if target.contains("android") { + Ok((TargetOs::Android, target)) + } else if target.contains("linux") { + Ok((TargetOs::Linux, target)) + } else { + Err(target) } - let target_dir = target_dir.ok_or("not found")?; - Ok(target_dir.to_path_buf()) } -fn copy_folder(src: &Path, dst: &Path) { - std::fs::create_dir_all(dst).expect("Failed to create dst directory"); - if cfg!(unix) { - std::process::Command::new("cp") - .arg("-rf") - .arg(src) - .arg(dst.parent().unwrap()) - .status() - .expect("Failed to execute cp command"); - } - - if cfg!(windows) { - std::process::Command::new("robocopy.exe") - .arg("/e") - .arg(src) - .arg(dst) - .status() - .expect("Failed to execute robocopy command"); - } +fn get_cargo_target_dir() -> Result<PathBuf, Box<dyn std::error::Error>> { + let out_dir = env::var("OUT_DIR")?; + let path = PathBuf::from(out_dir); + let target_dir = path + .ancestors() + .nth(3) + .ok_or("OUT_DIR is not deep enough")?; + Ok(target_dir.to_path_buf()) } fn extract_lib_names(out_dir: &Path, build_shared_libs: bool) -> Vec<String> { @@ -58,14 +73,12 @@ fn extract_lib_names(out_dir: &Path, build_shared_libs: bool) -> Vec<String> { } else { "*.a" } + } else if build_shared_libs { + "*.so" } else { - if build_shared_libs { - "*.so" - } else { - "*.a" - } + "*.a" }; - let libs_dir = out_dir.join("lib"); + let libs_dir = out_dir.join("lib*"); let pattern = libs_dir.join(lib_pattern); debug_log!("Extract libs {}", pattern.display()); @@ -82,6 +95,12 @@ fn extract_lib_names(out_dir: &Path, build_shared_libs: bool) -> Vec<String> { let lib_name = if stem_str.starts_with("lib") { stem_str.strip_prefix("lib").unwrap_or(stem_str) } else { + if path.extension() == Some(std::ffi::OsStr::new("a")) { + let target = path.parent().unwrap().join(format!("lib{}.a", stem_str)); + std::fs::rename(&path, &target).unwrap_or_else(|e| { + panic!("Failed to rename {path:?} to {target:?}: {e:?}"); + }) + } stem_str }; lib_names.push(lib_name.to_string()); @@ -101,7 +120,8 @@ fn extract_lib_assets(out_dir: &Path) -> Vec<PathBuf> { "*.so" }; - let libs_dir = out_dir.join("lib"); + let shared_libs_dir = if cfg!(windows) { "bin" } else { "lib" }; + let libs_dir = out_dir.join(shared_libs_dir); let pattern = libs_dir.join(shared_lib_pattern); debug_log!("Extract lib assets {}", pattern.display()); let mut files = Vec::new(); @@ -142,16 +162,24 @@ fn macos_link_search_path() -> Option<String> { None } +fn is_hidden(e: &DirEntry) -> bool { + e.file_name() + .to_str() + .map(|s| s.starts_with('.')) + .unwrap_or_default() +} + fn main() { + println!("cargo:rerun-if-changed=build.rs"); - let target = env::var("TARGET").unwrap(); + let (target_os, target_triple) = + parse_target_os().unwrap_or_else(|t| panic!("Failed
to parse target os {t}")); let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); let target_dir = get_cargo_target_dir().unwrap(); - let llama_dst = out_dir.join("llama.cpp"); let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("Failed to get CARGO_MANIFEST_DIR"); let llama_src = Path::new(&manifest_dir).join("llama.cpp"); - let build_shared_libs = cfg!(feature = "cuda") || cfg!(feature = "dynamic-link"); + let build_shared_libs = cfg!(feature = "dynamic-link"); let build_shared_libs = std::env::var("LLAMA_BUILD_SHARED_LIBS") .map(|v| v == "1") @@ -161,17 +189,40 @@ fn main() { .map(|v| v == "1") .unwrap_or(false); - debug_log!("TARGET: {}", target); + println!("cargo:rerun-if-env-changed=LLAMA_LIB_PROFILE"); + println!("cargo:rerun-if-env-changed=LLAMA_BUILD_SHARED_LIBS"); + println!("cargo:rerun-if-env-changed=LLAMA_STATIC_CRT"); + + debug_log!("TARGET: {}", target_triple); debug_log!("CARGO_MANIFEST_DIR: {}", manifest_dir); debug_log!("TARGET_DIR: {}", target_dir.display()); debug_log!("OUT_DIR: {}", out_dir.display()); debug_log!("BUILD_SHARED: {}", build_shared_libs); - // Prepare sherpa-onnx source - if !llama_dst.exists() { - debug_log!("Copy {} to {}", llama_src.display(), llama_dst.display()); - copy_folder(&llama_src, &llama_dst); + // Make sure that changes to the llama.cpp project trigger a rebuild. + let rebuild_on_children_of = [ + llama_src.join("src"), + llama_src.join("ggml/src"), + llama_src.join("common"), + ]; + for entry in walkdir::WalkDir::new(&llama_src) + .into_iter() + .filter_entry(|e| !is_hidden(e)) + { + let entry = entry.expect("Failed to obtain entry"); + let rebuild = entry + .file_name() + .to_str() + .map(|f| f.starts_with("CMake")) + .unwrap_or_default() + || rebuild_on_children_of + .iter() + .any(|src_folder| entry.path().starts_with(src_folder)); + if rebuild { + println!("cargo:rerun-if-changed={}", entry.path().display()); + } } + // Speed up build env::set_var( "CMAKE_BUILD_PARALLEL_LEVEL", @@ -184,8 +235,9 @@ fn main() { // Bindings let bindings = bindgen::Builder::default() .header("wrapper.h") - .clang_arg(format!("-I{}", llama_dst.join("include").display())) - .clang_arg(format!("-I{}", llama_dst.join("ggml/include").display())) + .clang_arg(format!("-I{}", llama_src.join("include").display())) + .clang_arg(format!("-I{}", llama_src.join("ggml/include").display())) + .clang_arg(format!("--target={}", target_triple)) .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) .derive_partialeq(true) .allowlist_function("ggml_.*") @@ -196,7 +248,6 @@ fn main() { .generate() .expect("Failed to generate bindings"); - // Write the generated bindings to an output file let bindings_path = out_dir.join("bindings.rs"); bindings @@ -204,13 +255,12 @@ fn main() { .expect("Failed to write bindings"); println!("cargo:rerun-if-changed=wrapper.h"); - println!("cargo:rerun-if-changed=./sherpa-onnx"); debug_log!("Bindings Created"); // Build with Cmake - let mut config = Config::new(&llama_dst); + let mut config = Config::new(&llama_src); // Would require extra source files to pointlessly // be included in what's uploaded to and downloaded from @@ -218,41 +268,130 @@ fn main() { config.define("LLAMA_BUILD_TESTS", "OFF"); config.define("LLAMA_BUILD_EXAMPLES", "OFF"); config.define("LLAMA_BUILD_SERVER", "OFF"); + config.define("LLAMA_BUILD_TOOLS", "OFF"); + config.define("LLAMA_CURL", "OFF"); config.define( "BUILD_SHARED_LIBS", if build_shared_libs { "ON" } else { "OFF" }, ); - if cfg!(target_os = "macos") { + if matches!(target_os, TargetOs::Apple(_)) { 
config.define("GGML_BLAS", "OFF"); } - if cfg!(windows) { - config.static_crt(static_crt); + if (matches!(target_os, TargetOs::Windows(WindowsVariant::Msvc)) + && matches!( + profile.as_str(), + "Release" | "RelWithDebInfo" | "MinSizeRel" + )) + { + // Debug Rust builds under MSVC turn off optimization even though we're ideally building the release profile of llama.cpp. + // Looks like an upstream bug: + // https://github.com/rust-lang/cmake-rs/issues/240 + // For now explicitly reinject the optimization flags that a CMake Release build is expected to have on in this scenario. + // This fixes CPU inference performance when part of a Rust debug build. + for flag in &["/O2", "/DNDEBUG", "/Ob2"] { + config.cflag(flag); + config.cxxflag(flag); + } } - - if cfg!(feature = "vulkan") { - config.define("GGML_VULKAN", "ON"); - if cfg!(windows) { - let vulkan_path = env::var("VULKAN_SDK").expect("Please install Vulkan SDK and ensure that VULKAN_SDK env variable is set"); - let vulkan_lib_path = Path::new(&vulkan_path).join("Lib"); - println!("cargo:rustc-link-search={}", vulkan_lib_path.display()); - println!("cargo:rustc-link-lib=vulkan-1"); + config.static_crt(static_crt); + + if matches!(target_os, TargetOs::Android) { + // build flags for android taken from this doc + // https://github.com/ggerganov/llama.cpp/blob/master/docs/android.md + let android_ndk = env::var("ANDROID_NDK") + .expect("Please install Android NDK and ensure that ANDROID_NDK env variable is set"); + + println!("cargo::rerun-if-env-changed=ANDROID_NDK"); + + config.define( + "CMAKE_TOOLCHAIN_FILE", + format!("{android_ndk}/build/cmake/android.toolchain.cmake"), + ); + if env::var("ANDROID_PLATFORM").is_ok() { + println!("cargo::rerun-if-env-changed=ANDROID_PLATFORM"); + } else { + config.define("ANDROID_PLATFORM", "android-28"); + } + if target_triple.contains("aarch64") || target_triple.contains("armv7") { + config.cflag("-march=armv8.7a"); + config.cxxflag("-march=armv8.7a"); + } else if target_triple.contains("x86_64") { + config.cflag("-march=x86-64"); + config.cxxflag("-march=x86-64"); + } else if target_triple.contains("i686") { + config.cflag("-march=i686"); + config.cxxflag("-march=i686"); + } else { + // Rather than guessing just fail. + panic!("Unsupported Android target {target_triple}"); } + config.define("GGML_LLAMAFILE", "OFF"); + if cfg!(feature = "shared-stdcxx") { + println!("cargo:rustc-link-lib=dylib=stdc++"); + println!("cargo:rustc-link-lib=c++_shared"); + } + } - if cfg!(target_os = "linux") { - println!("cargo:rustc-link-lib=vulkan"); + if matches!(target_os, TargetOs::Linux) + && target_triple.contains("aarch64") + && !env::var(format!("CARGO_FEATURE_{}", "native".to_uppercase())).is_ok() + { + // If the native feature is not enabled, we take off the native ARM64 support. + // It is useful in docker environments where the native feature is not enabled. 
+ config.define("GGML_NATIVE", "OFF"); + config.define("GGML_CPU_ARM_ARCH", "armv8-a"); + } + + if cfg!(feature = "vulkan") { + config.define("GGML_VULKAN", "ON"); + match target_os { + TargetOs::Windows(_) => { + let vulkan_path = env::var("VULKAN_SDK").expect( + "Please install Vulkan SDK and ensure that VULKAN_SDK env variable is set", + ); + let vulkan_lib_path = Path::new(&vulkan_path).join("Lib"); + println!("cargo:rustc-link-search={}", vulkan_lib_path.display()); + println!("cargo:rustc-link-lib=vulkan-1"); + + // workaround for this error: "FileTracker : error FTK1011: could not create the new file tracking log file" + // the issue is likely caused by nested cmake projects with ExternalProject_Add + // and windows' FileTracker thingy not properly picking up the inherited dir config + // ...either that, or it has to do with MSBuild FileTracker not respecting the path + // limit configuration set in the windows registry. I'm not sure which, but this + // makes my builds work. + env::set_var("TrackFileAccess", "false"); + + // since we disabled TrackFileAccess, we can now run into problems with parallel + // access to pdb files. /FS solves this. + config.cflag("/FS"); + config.cxxflag("/FS"); + } + TargetOs::Linux => { + println!("cargo:rustc-link-lib=vulkan"); + } + _ => (), } } if cfg!(feature = "cuda") { config.define("GGML_CUDA", "ON"); + + if cfg!(feature = "cuda-no-vmm") { + config.define("GGML_CUDA_NO_VMM", "ON"); + } } - if cfg!(feature = "openmp") { + // Android doesn't have OpenMP support AFAICT and openmp is a default feature. Do this here + // rather than modifying the defaults in Cargo.toml just in case someone enables the OpenMP feature + // and tries to build for Android anyway. + if cfg!(feature = "openmp") && !matches!(target_os, TargetOs::Android) { config.define("GGML_OPENMP", "ON"); + } else { + config.define("GGML_OPENMP", "OFF"); } // General @@ -265,58 +404,85 @@ fn main() { // Search paths println!("cargo:rustc-link-search={}", out_dir.join("lib").display()); + println!( + "cargo:rustc-link-search={}", + out_dir.join("lib64").display() + ); println!("cargo:rustc-link-search={}", build_dir.display()); - // Link libraries - let llama_libs_kind = if build_shared_libs { "dylib" } else { "static" }; - let llama_libs = extract_lib_names(&out_dir, build_shared_libs); + if cfg!(feature = "cuda") && !build_shared_libs { + println!("cargo:rerun-if-env-changed=CUDA_PATH"); - for lib in llama_libs { - debug_log!( - "LINK {}", - format!("cargo:rustc-link-lib={}={}", llama_libs_kind, lib) - ); - println!( - "{}", - format!("cargo:rustc-link-lib={}={}", llama_libs_kind, lib) - ); - } + for lib_dir in find_cuda_helper::find_cuda_lib_dirs() { + println!("cargo:rustc-link-search=native={}", lib_dir.display()); + } - // OpenMP - if cfg!(feature = "openmp") { - if target.contains("gnu") { - println!("cargo:rustc-link-lib=gomp"); + // Logic from ggml-cuda/CMakeLists.txt + println!("cargo:rustc-link-lib=static=cudart_static"); + if matches!(target_os, TargetOs::Windows(_)) { + println!("cargo:rustc-link-lib=static=cublas"); + println!("cargo:rustc-link-lib=static=cublasLt"); + } else { + println!("cargo:rustc-link-lib=static=cublas_static"); + println!("cargo:rustc-link-lib=static=cublasLt_static"); + } + + // Need to link against libcuda.so unless GGML_CUDA_NO_VMM is defined. 
+ if !cfg!(feature = "cuda-no-vmm") { + println!("cargo:rustc-link-lib=cuda"); } - } - // Windows debug - if cfg!(all(debug_assertions, windows)) { - println!("cargo:rustc-link-lib=dylib=msvcrtd"); + println!("cargo:rustc-link-lib=static=culibos"); } - // // macOS - if cfg!(target_os = "macos") { - println!("cargo:rustc-link-lib=framework=Foundation"); - println!("cargo:rustc-link-lib=framework=Metal"); - println!("cargo:rustc-link-lib=framework=MetalKit"); - println!("cargo:rustc-link-lib=framework=Accelerate"); - println!("cargo:rustc-link-lib=c++"); + // Link libraries + let llama_libs_kind = if build_shared_libs { "dylib" } else { "static" }; + let llama_libs = extract_lib_names(&out_dir, build_shared_libs); + assert_ne!(llama_libs.len(), 0); + + for lib in llama_libs { + let link = format!("cargo:rustc-link-lib={}={}", llama_libs_kind, lib); + debug_log!("LINK {link}",); + println!("{link}",); } - // Linux - if cfg!(target_os = "linux") { - println!("cargo:rustc-link-lib=dylib=stdc++"); + // OpenMP + if cfg!(feature = "openmp") && target_triple.contains("gnu") { + println!("cargo:rustc-link-lib=gomp"); } - if target.contains("apple") { - // On (older) OSX we need to link against the clang runtime, - // which is hidden in some non-default path. - // - // More details at https://github.com/alexcrichton/curl-rust/issues/279. - if let Some(path) = macos_link_search_path() { - println!("cargo:rustc-link-lib=clang_rt.osx"); - println!("cargo:rustc-link-search={}", path); + match target_os { + TargetOs::Windows(WindowsVariant::Msvc) => { + println!("cargo:rustc-link-lib=advapi32"); + if cfg!(debug_assertions) { + println!("cargo:rustc-link-lib=dylib=msvcrtd"); + } } + TargetOs::Linux => { + println!("cargo:rustc-link-lib=dylib=stdc++"); + } + TargetOs::Apple(variant) => { + println!("cargo:rustc-link-lib=framework=Foundation"); + println!("cargo:rustc-link-lib=framework=Metal"); + println!("cargo:rustc-link-lib=framework=MetalKit"); + println!("cargo:rustc-link-lib=framework=Accelerate"); + println!("cargo:rustc-link-lib=c++"); + + match variant { + AppleVariant::MacOS => { + // On (older) OSX we need to link against the clang runtime, + // which is hidden in some non-default path. + // + // More details at https://github.com/alexcrichton/curl-rust/issues/279. 
+ if let Some(path) = macos_link_search_path() { + println!("cargo:rustc-link-lib=clang_rt.osx"); + println!("cargo:rustc-link-search={}", path); + } + } + AppleVariant::Other => (), + } + } + _ => (), } // copy DLLs to target @@ -330,7 +496,7 @@ fn main() { debug_log!("HARD LINK {} TO {}", asset.display(), dst.display()); if !dst.exists() { std::fs::hard_link(asset.clone(), dst).unwrap(); - } + } // Copy DLLs to examples as well if target_dir.join("examples").exists() { @@ -349,4 +515,4 @@ fn main() { } } } -} \ No newline at end of file +} diff --git a/llama-cpp-sys-2/llama.cpp b/llama-cpp-sys-2/llama.cpp index 8f1d81a0..f8705144 160000 --- a/llama-cpp-sys-2/llama.cpp +++ b/llama-cpp-sys-2/llama.cpp @@ -1 +1 @@ -Subproject commit 8f1d81a0b6f50b9bad72db0b6fcd299ad9ecd48c +Subproject commit f87051445aae8d38137e47a117fcc5e752c62398 diff --git a/test-build.Dockerfile b/test-build.Dockerfile index 8540d2f9..383e0973 100644 --- a/test-build.Dockerfile +++ b/test-build.Dockerfile @@ -1,16 +1,16 @@ ARG CUDA_VERSION=12.3.1 ARG UBUNTU_VERSION=22.04 -FROM nvcr.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} as base-cuda +FROM nvcr.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} AS base-cuda # Install requirements for rustup install + bindgen: https://rust-lang.github.io/rust-bindgen/requirements.html -RUN DEBIAN_FRONTEND=noninteractive apt update -y && apt install -y curl llvm-dev libclang-dev clang pkg-config libssl-dev +RUN DEBIAN_FRONTEND=noninteractive apt update -y && apt install -y curl llvm-dev libclang-dev clang pkg-config libssl-dev cmake git RUN curl https://sh.rustup.rs -sSf | bash -s -- -y ENV PATH=/root/.cargo/bin:$PATH COPY . . RUN cargo build --bin simple --features cuda -FROM nvcr.io/nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} as base-cuda-runtime +FROM nvcr.io/nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} AS base-cuda-runtime COPY --from=base-cuda /target/debug/simple /usr/local/bin/simple
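Note on the new sampling API in this patch: the `LlamaLogitBias` wrapper and the `sample_token_greedy` helper are only exercised by their doc tests here. The sketch below shows how a downstream crate might construct and adjust biases; it uses only the constructors and accessors defined in the diff, while the final step of feeding the biases into a sampler is an assumption and not part of this patch (the companion `sample_token_greedy` follows the same apply-then-read-back pattern via `apply_sampler` and `selected_token`).

    use llama_cpp_2::token::{logit_bias::LlamaLogitBias, LlamaToken};

    fn main() {
        // Token ids below are illustrative only.
        let mut discourage = LlamaLogitBias::new(LlamaToken::new(2), -5.0);
        let encourage = LlamaLogitBias::new(LlamaToken::new(42), 2.0);

        // Biases are plain Copy values and can be adjusted after construction.
        discourage.set_bias(-10.0);
        assert_eq!(discourage.bias(), -10.0);
        assert_eq!(encourage.token(), LlamaToken::new(42));

        // Collecting them for a logit-bias sampler is assumed here; such a
        // sampler is not defined in this diff.
        let _biases = vec![discourage, encourage];
    }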
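On the build-script side, the rewritten `get_cargo_target_dir` assumes OUT_DIR has Cargo's usual shape `<target>/<profile>/build/<pkg>-<hash>/out`, so `.ancestors().nth(3)` lands on `<target>/<profile>`. A self-contained illustration of that assumption (the concrete path is made up for the example):

    use std::path::{Path, PathBuf};

    fn main() {
        // Layout the new get_cargo_target_dir() relies on:
        // <target>/<profile>/build/<pkg>-<hash>/out
        let out_dir = PathBuf::from("target/debug/build/llama-cpp-sys-2-0123abcd/out");
        let target_profile_dir = out_dir
            .ancestors()
            .nth(3)
            .expect("OUT_DIR is not deep enough");
        assert_eq!(target_profile_dir, Path::new("target/debug"));
        println!("{}", target_profile_dir.display());
    }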