From 604ab0e18c88ae7cdbf17089d0518be0bf184ea5 Mon Sep 17 00:00:00 2001 From: benolt Date: Wed, 31 Jul 2024 16:53:03 +0200 Subject: [PATCH 01/23] GraphRAG Global Search --- Cargo.lock | 874 +++++++++++++++++- Cargo.toml | 3 +- shinkai-libs/shinkai-graphrag/Cargo.toml | 13 + .../src/context_builder/context_builder.rs | 18 + .../src/context_builder/mod.rs | 1 + shinkai-libs/shinkai-graphrag/src/lib.rs | 3 + shinkai-libs/shinkai-graphrag/src/llm/llm.rs | 35 + shinkai-libs/shinkai-graphrag/src/llm/mod.rs | 1 + .../src/search/global_search.rs | 255 +++++ .../shinkai-graphrag/src/search/mod.rs | 1 + 10 files changed, 1163 insertions(+), 41 deletions(-) create mode 100644 shinkai-libs/shinkai-graphrag/Cargo.toml create mode 100644 shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/lib.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/llm/llm.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/llm/mod.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/search/global_search.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/search/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 9aa8faba4..90c968bde 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -218,6 +218,21 @@ dependencies = [ "syn 2.0.66", ] +[[package]] +name = "argminmax" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52424b59d69d69d5056d508b260553afd91c57e21849579cd1f50ee8b8b88eaa" +dependencies = [ + "num-traits", +] + +[[package]] +name = "array-init-cursor" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76" + [[package]] name = "arrayref" version = "0.3.7" @@ -369,7 +384,7 @@ dependencies = [ "arrow-schema 51.0.0", "arrow-select 51.0.0", "atoi", - "base64 0.22.0", + "base64 0.22.1", "chrono", "comfy-table", "half", @@ -390,7 +405,7 @@ dependencies = [ "arrow-schema 52.2.0", "arrow-select 52.2.0", "atoi", - "base64 0.22.0", + "base64 0.22.1", "chrono", "comfy-table", "half", @@ -880,6 +895,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "atoi_simd" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" + [[package]] name = "atomic-waker" version = "1.1.1" @@ -1440,9 +1461,9 @@ checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" [[package]] name = "base64" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "base64-simd" @@ -1699,6 +1720,20 @@ name = "bytemuck" version = "1.16.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "102087e286b4677862ea56cf8fc58bb2cdfa8725c40ffb80fe3a008eb7f2fc83" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b" +dependencies = [ + "proc-macro2 1.0.84", + "quote 1.0.36", + "syn 2.0.66", +] [[package]] name = "byteorder" @@ -1914,6 +1949,17 @@ dependencies = [ "parse-zoneinfo", ] +[[package]] +name = "chrono-tz" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d59ae0466b83e838b81a54256c39d5d7c20b9d7daa10510a242d9b75abd5936e" +dependencies = [ + "chrono", + "chrono-tz-build 0.2.1", + "phf 0.11.2", +] + [[package]] name = "chrono-tz" version = "0.9.0" @@ -1921,8 +1967,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "93698b29de5e97ad0ae26447b344c482a7284c737d9ddc5f9e52b74a336671bb" dependencies = [ "chrono", - "chrono-tz-build", + "chrono-tz-build 0.3.0", + "phf 0.11.2", +] + +[[package]] +name = "chrono-tz-build" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "433e39f13c9a060046954e0592a8d0a4bcb1040125cbf91cb8ee58964cfb350f" +dependencies = [ + "parse-zoneinfo", "phf 0.11.2", + "phf_codegen 0.11.2", ] [[package]] @@ -2152,6 +2209,7 @@ version = "7.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" dependencies = [ + "crossterm", "strum 0.26.3", "strum_macros 0.26.4", "unicode-width", @@ -2433,6 +2491,28 @@ version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +[[package]] +name = "crossterm" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" +dependencies = [ + "bitflags 2.4.0", + "crossterm_winapi", + "libc", + "parking_lot 0.12.1", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + [[package]] name = "crunchy" version = "0.2.2" @@ -2801,7 +2881,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a369332afd0ef5bd565f6db2139fb9f1dfdd0afa75a7f70f000b74208d76994f" dependencies = [ "arrow 52.2.0", - "base64 0.22.0", + "base64 0.22.1", "chrono", "datafusion-common", "datafusion-execution", @@ -2888,7 +2968,7 @@ dependencies = [ "arrow-ord 52.2.0", "arrow-schema 52.2.0", "arrow-string 52.2.0", - "base64 0.22.0", + "base64 0.22.1", "chrono", "datafusion-common", "datafusion-execution", @@ -3256,6 +3336,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56ce8c6da7551ec6c462cbaf3bfbc75131ebbfa1c944aeaa9dab51ca1c5f0c3b" +[[package]] +name = "dyn-clone" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" + [[package]] name = "ecdsa" version = "0.14.8" @@ -3315,9 +3401,9 @@ checksum = "3a68a4904193147e0a8dec3314640e6db742afd5f6e634f428a6af230d9b3591" [[package]] name = "either" -version = "1.9.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "elliptic-curve" @@ -3400,6 +3486,18 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5320ae4c3782150d900b79807611a59a99fc9a1d61d686faafc24b93fc8d7ca" +[[package]] +name = "enum_dispatch" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa18ce2bc66555b3218614519ac839ddb759a7d6720732f979ef8d13be147ecd" +dependencies = [ + "once_cell", + "proc-macro2 1.0.84", + "quote 1.0.36", + "syn 2.0.66", +] + [[package]] name = "env_logger" version = "0.9.3" @@ -3756,6 +3854,12 @@ dependencies = [ "yansi", ] +[[package]] +name = "ethnum" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" + [[package]] name = "event-listener" version = "2.5.3" @@ -3820,6 +3924,12 @@ dependencies = [ "once_cell", ] +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fancy-regex" version = "0.11.0" @@ -3830,6 +3940,12 @@ dependencies = [ "regex", ] +[[package]] +name = "fast-float" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95765f67b4b18863968b4a1bd5bb576f732b29a4a28c7cd84c09fa3e2875f33c" + [[package]] name = "fastdivide" version = "0.4.1" @@ -3992,6 +4108,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +[[package]] +name = "foreign_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee1b05cbd864bcaecbd3455d6d967862d446e4ebfc3c2e5e5b9841e53cba6673" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -4440,6 +4562,8 @@ checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ "ahash 0.8.11", "allocator-api2", + "rayon", + "serde", ] [[package]] @@ -4520,9 +4644,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.3.2" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] name = "hex" @@ -4853,7 +4977,7 @@ dependencies = [ "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows", + "windows 0.48.0", ] [[package]] @@ -5054,7 +5178,7 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ - "hermit-abi 0.3.2", + "hermit-abi 0.3.9", "libc", "windows-sys 0.48.0", ] @@ -5087,7 +5211,7 @@ version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ - "hermit-abi 0.3.2", + "hermit-abi 0.3.9", "rustix 0.38.32", "windows-sys 0.48.0", ] @@ -5140,6 +5264,12 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +[[package]] +name = "itoap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9028f49264629065d057f340a86acb84867925865f73bbf8d47b4d149a7e88b8" + [[package]] name = "jetscii" version = "0.5.3" @@ -6021,11 +6151,21 @@ version = "0.11.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9106e1d747ffd48e6be5bb2d97fa706ed25b144fbee4d5c02eae110cd8d6badd" +[[package]] +name = "lz4" +version = "1.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "958b4caa893816eea05507c20cfe47574a43d9a697138a7872990bba8a0ece68" +dependencies = [ + "libc", + "lz4-sys", +] + [[package]] name = "lz4-sys" -version = "1.9.4" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57d27b317e207b10f69f5e75494119e391a96f48861ae870d1da6edac98ca900" +checksum = "109de74d5d2353660401699a4174a4ff23fcc649caf553df71933c7fb45ad868" dependencies = [ "cc", "libc", @@ -6055,6 +6195,12 @@ dependencies = [ "libc", ] +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + [[package]] name = "markup5ever" version = "0.10.1" @@ -6164,6 +6310,15 @@ version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" +[[package]] +name = "memmap2" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" +dependencies = [ + "libc", +] + [[package]] name = "memmap2" version = "0.9.4" @@ -6271,13 +6426,14 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.10" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" dependencies = [ + "hermit-abi 0.3.9", "libc", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -6383,6 +6539,28 @@ dependencies = [ "twoway", ] +[[package]] +name = "multiversion" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4851161a11d3ad0bf9402d90ffc3967bf231768bfd7aeb61755ad06dbf1a142" +dependencies = [ + "multiversion-macros", + "target-features", +] + +[[package]] +name = "multiversion-macros" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79a74ddee9e0c27d2578323c13905793e91622148f138ba29738f9dddb835e90" +dependencies = [ + "proc-macro2 1.0.84", + "quote 1.0.36", + "syn 1.0.109", + "target-features", +] + [[package]] name = "murmurhash32" version = "0.3.1" @@ -6498,6 +6676,24 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" +[[package]] +name = "now" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d89e9874397a1f0a52fc1f197a8effd9735223cb2390e9dcc83ac6cd02923d0" +dependencies = [ + "chrono", +] + +[[package]] +name = "ntapi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +dependencies = [ + "winapi", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -6616,7 +6812,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.3.2", + "hermit-abi 0.3.9", "libc", ] @@ -6657,7 +6853,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6da452820c715ce78221e8202ccc599b4a52f3e1eb3eedb487b680c81a8e3f3" dependencies = [ "async-trait", - "base64 0.22.0", + "base64 0.22.1", "bytes", "chrono", "futures", @@ -7062,6 +7258,12 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "parquet-format-safe" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1131c54b167dd4e4799ce762e1ab01549ebb94d5bdd13e6ec1b467491c378e1f" + [[package]] name = "parse-zoneinfo" version = "0.3.0" @@ -7148,6 +7350,28 @@ dependencies = [ "hmac 0.12.1", ] +[[package]] +name = "pcre2" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3be55c43ac18044541d58d897e8f4c55157218428953ebd39d86df3ba0286b2b" +dependencies = [ + "libc", + "log 0.4.21", + "pcre2-sys", +] + +[[package]] +name = "pcre2-sys" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "550f5d18fb1b90c20b87e161852c10cde77858c3900c5059b5ad2a1449f11d8a" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "pddl-ish-parser" version = "0.0.4" @@ -7528,6 +7752,15 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "planus" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1691dd09e82f428ce8d6310bd6d5da2557c82ff17694d2a32cad7242aea89f" +dependencies = [ + "array-init-cursor", +] + [[package]] name = "platforms" version = "3.2.0" @@ -7582,6 +7815,402 @@ dependencies = [ "miniz_oxide 0.7.1", ] +[[package]] +name = "polars" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e3351ea4570e54cd556e6755b78fe7a2c85368d820c0307cca73c96e796a7ba" +dependencies = [ + "getrandom 0.2.10", + "polars-arrow", + "polars-core", + "polars-error", + "polars-io", + "polars-lazy", + "polars-ops", + "polars-parquet", + "polars-sql", + "polars-time", + "polars-utils", + "version_check 0.9.4", +] + +[[package]] +name = "polars-arrow" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba65fc4bcabbd64fca01fd30e759f8b2043f0963c57619e331d4b534576c0b47" +dependencies = [ + "ahash 0.8.11", + "atoi", + "atoi_simd", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "dyn-clone", + "either", + "ethnum", + "fast-float", + "foreign_vec", + "getrandom 0.2.10", + "hashbrown 0.14.5", + "itoa 1.0.9", + "itoap", + "lz4", + "multiversion", + "num-traits", + "polars-arrow-format", + "polars-error", + "polars-utils", + "ryu", + "simdutf8", + "streaming-iterator", + "strength_reduce", + "version_check 0.9.4", + "zstd 0.13.2", +] + +[[package]] +name = "polars-arrow-format" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b0ef2474af9396b19025b189d96e992311e6a47f90c53cd998b36c4c64b84c" +dependencies = [ + "planus", + "serde", +] + +[[package]] +name = "polars-compute" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f099516af30ac9ae4b4480f4ad02aa017d624f2f37b7a16ad4e9ba52f7e5269" +dependencies = [ + "bytemuck", + "either", + "num-traits", + "polars-arrow", + "polars-error", + "polars-utils", + "strength_reduce", + "version_check 0.9.4", +] + +[[package]] +name = "polars-core" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2439484be228b8c302328e2f953e64cfd93930636e5c7ceed90339ece7fef6c" +dependencies = [ + "ahash 0.8.11", + "bitflags 2.4.0", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "comfy-table", + "either", + "hashbrown 0.14.5", + "indexmap 2.1.0", + "num-traits", + "once_cell", + "polars-arrow", + "polars-compute", + "polars-error", + "polars-row", + "polars-utils", + "rand 0.8.5", + "rand_distr", + "rayon", + "regex", + "smartstring", + "thiserror", + "version_check 0.9.4", + "xxhash-rust", +] + +[[package]] +name = "polars-error" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c9b06dfbe79cabe50a7f0a90396864b5ee2c0e0f8d6a9353b2343c29c56e937" +dependencies = [ + "polars-arrow-format", + "regex", + "simdutf8", + "thiserror", +] + +[[package]] +name = "polars-expr" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c630385a56a867c410a20f30772d088f90ec3d004864562b84250b35268f97" +dependencies = [ + "ahash 0.8.11", + "bitflags 2.4.0", + "once_cell", + "polars-arrow", + "polars-core", + "polars-io", + "polars-ops", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", + "smartstring", +] + +[[package]] +name = "polars-io" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d7363cd14e4696a28b334a56bd11013ff49cc96064818ab3f91a126e453462d" +dependencies = [ + "ahash 0.8.11", + "atoi_simd", + "bytes", + "chrono", + "fast-float", + "home", + "itoa 1.0.9", + "memchr", + "memmap2 0.7.1", + "num-traits", + "once_cell", + "percent-encoding 2.3.0", + "polars-arrow", + "polars-core", + "polars-error", + "polars-time", + "polars-utils", + "rayon", + "regex", + "ryu", + "simdutf8", + "smartstring", +] + +[[package]] +name = "polars-lazy" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03877e74e42b5340ae52ded705f6d5d14563d90554c9177b01b91ed2412a56ed" +dependencies = [ + "ahash 0.8.11", + "bitflags 2.4.0", + "glob", + "memchr", + "once_cell", + "polars-arrow", + "polars-core", + "polars-expr", + "polars-io", + "polars-mem-engine", + "polars-ops", + "polars-pipe", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", + "smartstring", + "version_check 0.9.4", +] + +[[package]] +name = "polars-mem-engine" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea9e17771af750c94bf959885e4b3f5b14149576c62ef3ec1c9ef5827b2a30f" +dependencies = [ + "polars-arrow", + "polars-core", + "polars-error", + "polars-expr", + "polars-io", + "polars-ops", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", +] + +[[package]] +name = "polars-ops" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6066552eb577d43b307027fb38096910b643ffb2c89a21628c7e41caf57848d0" +dependencies = [ + "ahash 0.8.11", + "argminmax", + "base64 0.22.1", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "either", + "hashbrown 0.14.5", + "hex", + "indexmap 2.1.0", + "memchr", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-error", + "polars-utils", + "rayon", + "regex", + "smartstring", + "unicode-reverse", + "version_check 0.9.4", +] + +[[package]] +name = "polars-parquet" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b35b2592a2e7ef7ce9942dc2120dc4576142626c0e661668e4c6b805042e461" +dependencies = [ + "ahash 0.8.11", + "base64 0.22.1", + "ethnum", + "num-traits", + "parquet-format-safe", + "polars-arrow", + "polars-compute", + "polars-error", + "polars-utils", + "simdutf8", + "streaming-decompression", +] + +[[package]] +name = "polars-pipe" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "021bce7768c330687d735340395a77453aa18dd70d57c184cbb302311e87c1b9" +dependencies = [ + "crossbeam-channel", + "crossbeam-queue", + "enum_dispatch", + "hashbrown 0.14.5", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-expr", + "polars-io", + "polars-ops", + "polars-plan", + "polars-row", + "polars-utils", + "rayon", + "smartstring", + "uuid 1.8.0", + "version_check 0.9.4", +] + +[[package]] +name = "polars-plan" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "220d0d7c02d1c4375802b2813dbedcd1a184df39c43b74689e729ede8d5c2921" +dependencies = [ + "ahash 0.8.11", + "bytemuck", + "chrono-tz 0.8.6", + "either", + "hashbrown 0.14.5", + "once_cell", + "percent-encoding 2.3.0", + "polars-arrow", + "polars-core", + "polars-io", + "polars-ops", + "polars-time", + "polars-utils", + "rayon", + "recursive", + "regex", + "smartstring", + "strum_macros 0.26.4", + "version_check 0.9.4", +] + +[[package]] +name = "polars-row" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1d70d87a2882a64a43b431aea1329cb9a2c4100547c95c417cc426bb82408b3" +dependencies = [ + "bytemuck", + "polars-arrow", + "polars-error", + "polars-utils", +] + +[[package]] +name = "polars-sql" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6fc1c9b778862f09f4a347f768dfdd3d0ba9957499d306d83c7103e0fa8dc5b" +dependencies = [ + "hex", + "once_cell", + "polars-arrow", + "polars-core", + "polars-error", + "polars-lazy", + "polars-ops", + "polars-plan", + "polars-time", + "rand 0.8.5", + "serde", + "serde_json", + "sqlparser", +] + +[[package]] +name = "polars-time" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "179f98313a15c0bfdbc8cc0f1d3076d08d567485b9952d46439f94fbc3085df5" +dependencies = [ + "atoi", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "now", + "once_cell", + "polars-arrow", + "polars-core", + "polars-error", + "polars-ops", + "polars-utils", + "regex", + "smartstring", +] + +[[package]] +name = "polars-utils" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53e6dd89fcccb1ec1a62f752c9a9f2d482a85e9255153f46efecc617b4996d50" +dependencies = [ + "ahash 0.8.11", + "bytemuck", + "hashbrown 0.14.5", + "indexmap 2.1.0", + "num-traits", + "once_cell", + "polars-error", + "raw-cpuid 11.0.1", + "rayon", + "smartstring", + "stacker", + "sysinfo", + "version_check 0.9.4", +] + [[package]] name = "polling" version = "2.8.0" @@ -7906,6 +8535,15 @@ dependencies = [ "prost 0.12.6", ] +[[package]] +name = "psm" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5787f7cda34e3033a72192c018bc5883100330f362ef279a8cbccfce8bb4e874" +dependencies = [ + "cc", +] + [[package]] name = "ptr_meta" version = "0.1.4" @@ -8441,6 +9079,26 @@ dependencies = [ "rand_core 0.3.1", ] +[[package]] +name = "recursive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" +dependencies = [ + "recursive-proc-macro-impl", + "stacker", +] + +[[package]] +name = "recursive-proc-macro-impl" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" +dependencies = [ + "quote 1.0.36", + "syn 2.0.66", +] + [[package]] name = "redox_syscall" version = "0.3.5" @@ -8594,7 +9252,7 @@ version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "bytes", "futures-channel", "futures-core", @@ -8949,6 +9607,16 @@ dependencies = [ "serde_json", ] +[[package]] +name = "rust_decimal_macros" +version = "1.35.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a05bf7103af0797dbce0667c471946b29b9eaea34652eff67324f360fec027de" +dependencies = [ + "quote 1.0.36", + "rust_decimal", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -9069,7 +9737,7 @@ version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "rustls-pki-types", ] @@ -9421,11 +10089,11 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.8.1" +version = "3.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ad483d2ab0149d5a5ebcd9972a3852711e0153d863bf5a5d0391d28883c4a20" +checksum = "69cecfa94848272156ea67b2b1a53f20fc7bc638c4a46d2f8abde08f05f4b857" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", @@ -9439,9 +10107,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.8.1" +version = "3.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65569b702f41443e8bc8bbb1c5779bd0450bbe723b56198980e80ec45780bce2" +checksum = "a8fee4991ef4f274617a51ad4af30519438dacb2f56ac773b08a1922ff743350" dependencies = [ "darling 0.20.10", "proc-macro2 1.0.84", @@ -9550,6 +10218,19 @@ dependencies = [ "dirs", ] +[[package]] +name = "shinkai-graphrag" +version = "0.1.0" +dependencies = [ + "async-trait", + "futures", + "polars", + "serde", + "serde_json", + "tiktoken", + "tokio", +] + [[package]] name = "shinkai_crypto_identities" version = "0.1.1" @@ -9592,7 +10273,7 @@ dependencies = [ "aes-gcm", "anyhow", "async-channel", - "base64 0.22.0", + "base64 0.22.1", "chrono", "chrono-tz 0.5.3", "clap 3.2.25", @@ -9772,7 +10453,7 @@ dependencies = [ "aes-gcm", "anyhow", "async-channel", - "base64 0.22.0", + "base64 0.22.1", "chrono", "chrono-tz 0.5.3", "clap 3.2.25", @@ -10025,6 +10706,17 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "smartstring" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fb72c633efbaa2dd666986505016c32c3044395ceaf881518399d2f4127ee29" +dependencies = [ + "autocfg 1.1.0", + "static_assertions", + "version_check 0.9.4", +] + [[package]] name = "snafu" version = "0.7.5" @@ -10152,6 +10844,19 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stacker" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c886bd4480155fd3ef527d45e9ac8dd7118a898a46530b7b94c3e21866259fce" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "winapi", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -10170,6 +10875,27 @@ version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e51f1e89f093f99e7432c491c382b88a6860a5adbe6bf02574bf0a08efff1978" +[[package]] +name = "streaming-decompression" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf6cc3b19bfb128a8ad11026086e31d3ce9ad23f8ea37354b31383a187c44cf3" +dependencies = [ + "fallible-streaming-iterator", +] + +[[package]] +name = "streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" + +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + [[package]] name = "string_cache" version = "0.8.7" @@ -10358,6 +11084,20 @@ dependencies = [ "yaml-rust", ] +[[package]] +name = "sysinfo" +version = "0.30.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a5b4ddaee55fb2bea2bf0e5000747e5f5c0de765e5a5ff87f4cd106439f4bb3" +dependencies = [ + "cfg-if", + "core-foundation-sys", + "libc", + "ntapi", + "once_cell", + "windows 0.52.0", +] + [[package]] name = "system-configuration" version = "0.5.1" @@ -10406,7 +11146,7 @@ checksum = "f8d0582f186c0a6d55655d24543f15e43607299425c5ad8352c242b914b31856" dependencies = [ "aho-corasick", "arc-swap", - "base64 0.22.0", + "base64 0.22.1", "bitpacking", "byteorder", "census", @@ -10423,7 +11163,7 @@ dependencies = [ "lru 0.12.3", "lz4_flex", "measure_time", - "memmap2", + "memmap2 0.9.4", "num_cpus", "once_cell", "oneshot", @@ -10556,6 +11296,12 @@ dependencies = [ "xattr", ] +[[package]] +name = "target-features" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" + [[package]] name = "target-lexicon" version = "0.12.14" @@ -10685,6 +11431,21 @@ dependencies = [ "weezl", ] +[[package]] +name = "tiktoken" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0d8d54b2ba3a0f0b14743f9f0e0f6884ee03ad8789f1ce6a4b5650c0d74fad8" +dependencies = [ + "anyhow", + "base64 0.21.7", + "lazy_static", + "maplit", + "pcre2", + "rust_decimal", + "rust_decimal_macros", +] + [[package]] name = "time" version = "0.1.45" @@ -10785,22 +11546,21 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.36.0" +version = "1.39.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" +checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" dependencies = [ "backtrace", "bytes", "libc", "mio", - "num_cpus", "parking_lot 0.12.1", "pin-project-lite", "signal-hook-registry", "socket2 0.5.6", "tokio-macros", "tracing", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -10815,9 +11575,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.2.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2 1.0.84", "quote 1.0.36", @@ -10857,9 +11617,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.14" +version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" +checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" dependencies = [ "futures-core", "pin-project-lite", @@ -11313,6 +12073,15 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-reverse" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b6f4888ebc23094adfb574fdca9fdc891826287a6397d2cd28802ffd6f20c76" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "unicode-segmentation" version = "1.10.1" @@ -11827,6 +12596,25 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "windows" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" +dependencies = [ + "windows-core", + "windows-targets 0.52.3", +] + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.3", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -12062,6 +12850,12 @@ version = "0.13.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" +[[package]] +name = "xxhash-rust" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a5cbf750400958819fb6178eaa83bee5cd9c29a26a40cc241df8c70fdd46984" + [[package]] name = "yaml-rust" version = "0.4.5" diff --git a/Cargo.toml b/Cargo.toml index b4fdec92c..7e4bb91c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,13 +4,14 @@ members = [ "shinkai-libs/shinkai-dsl", "shinkai-libs/shinkai-sheet", "shinkai-libs/shinkai-fs-mirror", + "shinkai-libs/shinkai-graphrag", "shinkai-libs/shinkai-message-primitives", "shinkai-libs/shinkai-ocr", "shinkai-libs/shinkai-tcp-relayer", "shinkai-libs/shinkai-vector-resources", "shinkai-bin/*", "shinkai-cli-tools/*" -] +, "shinkai-libs/shinkai-graphrag"] resolver = "2" [workspace.package] diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml new file mode 100644 index 000000000..591d7d48c --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "shinkai-graphrag" +version = "0.1.0" +edition = "2021" + +[dependencies] +async-trait = "0.1.74" +futures = "0.3.30" +polars = "0.41.3" +serde = { version = "1.0.188", features = ["derive"] } +serde_json = "1.0.117" +tiktoken = "1.0.1" +tokio = { version = "1.36", features = ["full"] } \ No newline at end of file diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs new file mode 100644 index 000000000..705818a77 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs @@ -0,0 +1,18 @@ +use async_trait::async_trait; +// use polars::prelude::*; +use std::collections::HashMap; + +// TODO: Serialize and Deserialize polars::frame::DataFrame +type DataFrame = Vec; + +#[async_trait] +pub trait GlobalContextBuilder { + /// Build the context for the global search mode. + async fn build_context( + &self, + conversation_history: Option, + context_builder_params: Option>, + ) -> (Vec, HashMap); +} + +pub struct ConversationHistory {} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs new file mode 100644 index 000000000..709d766d9 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs @@ -0,0 +1 @@ +pub mod context_builder; diff --git a/shinkai-libs/shinkai-graphrag/src/lib.rs b/shinkai-libs/shinkai-graphrag/src/lib.rs new file mode 100644 index 000000000..08bc3d655 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/lib.rs @@ -0,0 +1,3 @@ +pub mod context_builder; +pub mod llm; +pub mod search; diff --git a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs new file mode 100644 index 000000000..de33da38b --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs @@ -0,0 +1,35 @@ +use async_trait::async_trait; + +pub struct BaseLLMCallback { + response: Vec, +} + +impl BaseLLMCallback { + pub fn new() -> Self { + BaseLLMCallback { response: Vec::new() } + } + + pub fn on_llm_new_token(&mut self, token: &str) { + self.response.push(token.to_string()); + } +} + +#[async_trait] +pub trait BaseLLM { + async fn generate(&self, messages: Vec, streaming: bool, callbacks: Option>) + -> String; + + async fn agenerate( + &self, + messages: Vec, + streaming: bool, + callbacks: Option>, + ) -> String; +} + +#[async_trait] +pub trait BaseTextEmbedding { + async fn embed(&self, text: &str) -> Vec; + + async fn aembed(&self, text: &str) -> Vec; +} diff --git a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs new file mode 100644 index 000000000..214bbef7c --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs @@ -0,0 +1 @@ +pub mod llm; diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs new file mode 100644 index 000000000..0a18ec55a --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs @@ -0,0 +1,255 @@ +use futures::future::join_all; +//use polars::frame::DataFrame; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::time::Instant; +use tiktoken::encoding::Encoding; +use tokio::sync::Semaphore; + +use crate::context_builder::context_builder::{ConversationHistory, GlobalContextBuilder}; +use crate::llm::llm::BaseLLM; + +// TODO: Serialize and Deserialize polars::frame::DataFrame +type DataFrame = Vec; + +#[derive(Debug, Serialize, Deserialize)] +struct SearchResult { + response: ResponseType, + context_data: ContextData, + context_text: ContextText, + completion_time: f64, + llm_calls: u32, + prompt_tokens: u32, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum ResponseType { + String(String), + Dictionary(HashMap), + Dictionaries(Vec>), +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum ContextData { + String(String), + DataFrames(Vec), + Dictionary(HashMap), +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum ContextText { + String(String), + Strings(Vec), + Dictionary(HashMap), +} + +#[derive(Serialize, Deserialize)] +pub struct GlobalSearchResult { + response: ResponseType, + context_data: ContextData, + context_text: ContextText, + completion_time: f64, + llm_calls: i32, + prompt_tokens: i32, + map_responses: Vec, + reduce_context_data: ContextData, + reduce_context_text: ContextText, +} + +struct GlobalSearchLLMCallback { + map_response_contexts: Vec, + map_response_outputs: Vec, +} + +impl GlobalSearchLLMCallback { + pub fn new() -> Self { + GlobalSearchLLMCallback { + map_response_contexts: Vec::new(), + map_response_outputs: Vec::new(), + } + } + + pub fn on_map_response_start(&mut self, map_response_contexts: Vec) { + self.map_response_contexts = map_response_contexts; + } + + pub fn on_map_response_end(&mut self, map_response_outputs: Vec) { + self.map_response_outputs = map_response_outputs; + } +} + +pub struct GlobalSearch { + llm: Box, + context_builder: Box, + token_encoder: Option, + llm_params: Option>, + context_builder_params: Option>, + map_system_prompt: String, + reduce_system_prompt: String, + response_type: String, + allow_general_knowledge: bool, + general_knowledge_inclusion_prompt: String, + callbacks: Option>, + max_data_tokens: usize, + map_llm_params: HashMap, + reduce_llm_params: HashMap, + semaphore: Semaphore, +} + +impl GlobalSearch { + pub fn new( + llm: Box, + context_builder: Box, + token_encoder: Option, + map_system_prompt: String, + reduce_system_prompt: String, + response_type: String, + allow_general_knowledge: bool, + general_knowledge_inclusion_prompt: String, + json_mode: bool, + callbacks: Option>, + max_data_tokens: usize, + map_llm_params: HashMap, + reduce_llm_params: HashMap, + context_builder_params: Option>, + concurrent_coroutines: usize, + ) -> Self { + let mut map_llm_params = map_llm_params; + + if json_mode { + map_llm_params.insert( + "response_format".to_string(), + serde_json::json!({"type": "json_object"}), + ); + } else { + map_llm_params.remove("response_format"); + } + + let semaphore = Semaphore::new(concurrent_coroutines); + + GlobalSearch { + llm, + context_builder, + token_encoder, + llm_params: None, + context_builder_params, + map_system_prompt, + reduce_system_prompt, + response_type, + allow_general_knowledge, + general_knowledge_inclusion_prompt, + callbacks, + max_data_tokens, + map_llm_params, + reduce_llm_params, + semaphore, + } + } + + pub async fn asearch( + &self, + query: String, + conversation_history: Option, + ) -> GlobalSearchResult { + // Step 1: Generate answers for each batch of community short summaries + let start_time = Instant::now(); + let (context_chunks, context_records) = self + .context_builder + .build_context(conversation_history, self.context_builder_params) + .await; + + if let Some(callbacks) = &self.callbacks { + for callback in callbacks { + callback.on_map_response_start(context_chunks); + } + } + + let map_responses: Vec<_> = join_all( + context_chunks + .iter() + .map(|data| self._map_response_single_batch(data, &query, &self.map_llm_params)), + ) + .await; + + if let Some(callbacks) = &self.callbacks { + for callback in callbacks { + callback.on_map_response_end(&map_responses); + } + } + + let map_llm_calls: usize = map_responses.iter().map(|response| response.llm_calls).sum(); + let map_prompt_tokens: usize = map_responses.iter().map(|response| response.prompt_tokens).sum(); + + // Step 2: Combine the intermediate answers from step 2 to generate the final answer + let reduce_response = self + ._reduce_response(&map_responses, &query, self.reduce_llm_params) + .await; + + GlobalSearchResult { + response: reduce_response.response, + context_data: ContextData::Dictionary(context_records), + context_text: ContextText::Strings(context_chunks), + completion_time: start_time.elapsed().as_secs_f64(), + llm_calls: map_llm_calls + reduce_response.llm_calls, + prompt_tokens: map_prompt_tokens + reduce_response.prompt_tokens, + map_responses, + reduce_context_data: reduce_response.context_data, + reduce_context_text: reduce_response.context_text, + } + } + + async fn _reduce_response( + &self, + map_responses: Vec, + query: &str, + reduce_llm_params: HashMap, + ) -> SearchResult { + let start_time = Instant::now(); + let mut key_points = Vec::new(); + + for (index, response) in map_responses.iter().enumerate() { + if let ResponseType::Dictionaries(response_list) = response.response { + for element in response_list { + if let (Some(answer), Some(score)) = (element.get("answer"), element.get("score")) { + key_points.push((index, answer.clone(), score.clone())); + } + } + } + } + + let filtered_key_points: Vec<_> = key_points + .into_iter() + .filter(|(_, _, score)| score.as_f64().unwrap_or(0.0) > 0.0) + .collect(); + + if filtered_key_points.is_empty() && !self.allow_general_knowledge { + return SearchResult { + response: ResponseType::String("NO_DATA_ANSWER".to_string()), + context_data: ContextData::String("".to_string()), + context_text: ContextText::String("".to_string()), + completion_time: start_time.elapsed().as_secs_f64(), + llm_calls: 0, + prompt_tokens: 0, + }; + } + + let mut sorted_key_points = filtered_key_points; + sorted_key_points.sort_by(|a, b| { + b.2.as_f64() + .unwrap_or(0.0) + .partial_cmp(&a.2.as_f64().unwrap_or(0.0)) + .unwrap() + }); + + // TODO: Implement rest of the function + + SearchResult { + response: ResponseType::String("Combined response".to_string()), + context_data: ContextData::String("".to_string()), + context_text: ContextText::String("".to_string()), + completion_time: start_time.elapsed().as_secs_f64(), + llm_calls: 0, + prompt_tokens: 0, + } + } +} diff --git a/shinkai-libs/shinkai-graphrag/src/search/mod.rs b/shinkai-libs/shinkai-graphrag/src/search/mod.rs new file mode 100644 index 000000000..a12441830 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/mod.rs @@ -0,0 +1 @@ +pub mod global_search; From 123bbb74fb2a8611db3fe18da528432bb2fdf292 Mon Sep 17 00:00:00 2001 From: benolt Date: Thu, 1 Aug 2024 15:12:24 +0200 Subject: [PATCH 02/23] global search, open ai chat --- Cargo.lock | 176 +++++++++--- shinkai-libs/shinkai-graphrag/Cargo.toml | 4 +- shinkai-libs/shinkai-graphrag/src/llm/llm.rs | 23 +- shinkai-libs/shinkai-graphrag/src/llm/mod.rs | 2 + .../shinkai-graphrag/src/llm/openai.rs | 107 +++++++ .../shinkai-graphrag/src/llm/utils.rs | 7 + .../src/search/global_search.rs | 260 ++++++++++++++---- 7 files changed, 468 insertions(+), 111 deletions(-) create mode 100644 shinkai-libs/shinkai-graphrag/src/llm/openai.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/llm/utils.rs diff --git a/Cargo.lock b/Cargo.lock index 90c968bde..37c30ea8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -718,6 +718,15 @@ dependencies = [ "futures-core", ] +[[package]] +name = "async-convert" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d416feee97712e43152cd42874de162b8f9b77295b1c85e5d92725cc8310bae" +dependencies = [ + "async-trait", +] + [[package]] name = "async-executor" version = "1.5.1" @@ -1410,6 +1419,20 @@ dependencies = [ "tower-service", ] +[[package]] +name = "backoff" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" +dependencies = [ + "futures-core", + "getrandom 0.2.10", + "instant", + "pin-project-lite", + "rand 0.8.5", + "tokio", +] + [[package]] name = "backtrace" version = "0.3.69" @@ -1659,6 +1682,17 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "bstr" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" +dependencies = [ + "memchr", + "regex-automata 0.4.7", + "serde", +] + [[package]] name = "buf_redux" version = "0.8.4" @@ -2222,7 +2256,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0436149c9f6a1935b13306206c739b1ba84fa81f551b5eb87fc2ca7a13700af" dependencies = [ "clap 4.5.4", - "derive_builder", + "derive_builder 0.12.0", "entities", "memchr", "once_cell", @@ -3138,7 +3172,16 @@ version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" dependencies = [ - "derive_builder_macro", + "derive_builder_macro 0.12.0", +] + +[[package]] +name = "derive_builder" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7" +dependencies = [ + "derive_builder_macro 0.20.0", ] [[package]] @@ -3153,16 +3196,38 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_builder_core" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d" +dependencies = [ + "darling 0.20.10", + "proc-macro2 1.0.84", + "quote 1.0.36", + "syn 2.0.66", +] + [[package]] name = "derive_builder_macro" version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" dependencies = [ - "derive_builder_core", + "derive_builder_core 0.12.0", "syn 1.0.109", ] +[[package]] +name = "derive_builder_macro" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" +dependencies = [ + "derive_builder_core 0.20.0", + "syn 2.0.66", +] + [[package]] name = "derive_more" version = "0.99.18" @@ -3898,6 +3963,17 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite", +] + [[package]] name = "exr" version = "1.72.0" @@ -3940,6 +4016,16 @@ dependencies = [ "regex", ] +[[package]] +name = "fancy-regex" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7493d4c459da9f84325ad297371a6b2b8a162800873a22e3b6b6512e61d18c05" +dependencies = [ + "bit-set", + "regex", +] + [[package]] name = "fast-float" version = "0.2.0" @@ -7350,28 +7436,6 @@ dependencies = [ "hmac 0.12.1", ] -[[package]] -name = "pcre2" -version = "0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3be55c43ac18044541d58d897e8f4c55157218428953ebd39d86df3ba0286b2b" -dependencies = [ - "libc", - "log 0.4.21", - "pcre2-sys", -] - -[[package]] -name = "pcre2-sys" -version = "0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "550f5d18fb1b90c20b87e161852c10cde77858c3900c5059b5ad2a1449f11d8a" -dependencies = [ - "cc", - "libc", - "pkg-config", -] - [[package]] name = "pddl-ish-parser" version = "0.0.4" @@ -9160,6 +9224,12 @@ dependencies = [ "regex-syntax 0.8.2", ] +[[package]] +name = "regex-automata" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" + [[package]] name = "regex-lite" version = "0.1.5" @@ -9293,6 +9363,22 @@ dependencies = [ "winreg 0.52.0", ] +[[package]] +name = "reqwest-eventsource" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime 0.3.17", + "nom", + "pin-project-lite", + "reqwest 0.12.5", + "thiserror", +] + [[package]] name = "rfc6979" version = "0.3.1" @@ -9607,16 +9693,6 @@ dependencies = [ "serde_json", ] -[[package]] -name = "rust_decimal_macros" -version = "1.35.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a05bf7103af0797dbce0667c471946b29b9eaea34652eff67324f360fec027de" -dependencies = [ - "quote 1.0.36", - "rust_decimal", -] - [[package]] name = "rustc-demangle" version = "0.1.23" @@ -9952,6 +10028,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "secrecy" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9bd1c54ea06cfd2f6b63219704de0b9b4f72dcc2b8fdef820be6cd799780e91e" +dependencies = [ + "serde", + "zeroize", +] + [[package]] name = "security-framework" version = "2.9.2" @@ -10222,12 +10308,14 @@ dependencies = [ name = "shinkai-graphrag" version = "0.1.0" dependencies = [ + "anyhow", + "async-openai", "async-trait", "futures", "polars", "serde", "serde_json", - "tiktoken", + "tiktoken-rs", "tokio", ] @@ -11069,7 +11157,7 @@ checksum = "874dcfa363995604333cf947ae9f751ca3af4522c60886774c4963943b4746b1" dependencies = [ "bincode", "bitflags 1.3.2", - "fancy-regex", + "fancy-regex 0.11.0", "flate2", "fnv", "once_cell", @@ -11432,18 +11520,18 @@ dependencies = [ ] [[package]] -name = "tiktoken" -version = "1.0.1" +name = "tiktoken-rs" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0d8d54b2ba3a0f0b14743f9f0e0f6884ee03ad8789f1ce6a4b5650c0d74fad8" +checksum = "c314e7ce51440f9e8f5a497394682a57b7c323d0f4d0a6b1b13c429056e0e234" dependencies = [ "anyhow", "base64 0.21.7", + "bstr", + "fancy-regex 0.12.0", "lazy_static", - "maplit", - "pcre2", - "rust_decimal", - "rust_decimal_macros", + "parking_lot 0.12.1", + "rustc-hash", ] [[package]] diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml index 591d7d48c..9057f8ea0 100644 --- a/shinkai-libs/shinkai-graphrag/Cargo.toml +++ b/shinkai-libs/shinkai-graphrag/Cargo.toml @@ -4,10 +4,12 @@ version = "0.1.0" edition = "2021" [dependencies] +anyhow = "1.0.86" +async-openai = "0.23.4" async-trait = "0.1.74" futures = "0.3.30" polars = "0.41.3" serde = { version = "1.0.188", features = ["derive"] } serde_json = "1.0.117" -tiktoken = "1.0.1" +tiktoken-rs = "0.5.9" tokio = { version = "1.36", features = ["full"] } \ No newline at end of file diff --git a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs index de33da38b..4c8f0f68e 100644 --- a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs +++ b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs @@ -1,7 +1,11 @@ +use std::collections::HashMap; + use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +#[derive(Debug, Clone)] pub struct BaseLLMCallback { - response: Vec, + pub response: Vec, } impl BaseLLMCallback { @@ -14,22 +18,25 @@ impl BaseLLMCallback { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum MessageType { + String(String), + Strings(Vec), + Dictionary(Vec>), +} + #[async_trait] pub trait BaseLLM { - async fn generate(&self, messages: Vec, streaming: bool, callbacks: Option>) - -> String; - async fn agenerate( &self, - messages: Vec, + messages: MessageType, streaming: bool, callbacks: Option>, - ) -> String; + llm_params: HashMap, + ) -> anyhow::Result; } #[async_trait] pub trait BaseTextEmbedding { - async fn embed(&self, text: &str) -> Vec; - async fn aembed(&self, text: &str) -> Vec; } diff --git a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs index 214bbef7c..00cb1d9e1 100644 --- a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs @@ -1 +1,3 @@ pub mod llm; +pub mod openai; +pub mod utils; diff --git a/shinkai-libs/shinkai-graphrag/src/llm/openai.rs b/shinkai-libs/shinkai-graphrag/src/llm/openai.rs new file mode 100644 index 000000000..a0b3986b6 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/llm/openai.rs @@ -0,0 +1,107 @@ +use std::collections::HashMap; + +use async_openai::{ + config::OpenAIConfig, + types::{ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, CreateChatCompletionRequestArgs}, + Client, +}; + +use super::llm::{BaseLLMCallback, MessageType}; + +pub struct ChatOpenAI { + pub api_key: Option, + pub model: String, + pub max_retries: usize, +} + +impl ChatOpenAI { + pub fn new(api_key: Option, model: String, max_retries: usize) -> Self { + ChatOpenAI { + api_key, + model, + max_retries, + } + } + + pub async fn agenerate( + &self, + messages: MessageType, + streaming: bool, + callbacks: Option>, + llm_params: HashMap, + ) -> anyhow::Result { + let mut retry_count = 0; + + loop { + match self + ._agenerate(messages.clone(), streaming, callbacks.clone(), llm_params.clone()) + .await + { + Ok(response) => return Ok(response), + Err(e) => { + if retry_count < self.max_retries { + retry_count += 1; + continue; + } + return Err(e); + } + } + } + } + + async fn _agenerate( + &self, + messages: MessageType, + streaming: bool, + callbacks: Option>, + llm_params: HashMap, + ) -> anyhow::Result { + let client = match &self.api_key { + Some(api_key) => Client::with_config(OpenAIConfig::new().with_api_key(api_key)), + None => Client::new(), + }; + + let messages = match messages { + MessageType::String(message) => vec![message], + MessageType::Strings(messages) => messages, + MessageType::Dictionary(messages) => { + let messages = messages + .iter() + .map(|message_map| { + message_map + .iter() + .map(|(key, value)| format!("{}: {}", key, value)) + .collect::>() + .join("\n") + }) + .collect(); + messages + } + }; + + let request_messages = messages + .into_iter() + .map(|m| ChatCompletionRequestSystemMessageArgs::default().content(m).build()) + .collect::>(); + + let request_messages: Result, _> = request_messages.into_iter().collect(); + let request_messages = request_messages?; + let request_messages = request_messages + .into_iter() + .map(|m| Into::::into(m.clone())) + .collect::>(); + + let request = CreateChatCompletionRequestArgs::default() + .model(self.model.clone()) + .messages(request_messages) + .build()?; + + let response = client.chat().create(request).await?; + + if let Some(choice) = response.choices.get(0) { + return Ok(choice.message.content.clone().unwrap_or_default()); + } + + return Ok(String::new()); + } +} diff --git a/shinkai-libs/shinkai-graphrag/src/llm/utils.rs b/shinkai-libs/shinkai-graphrag/src/llm/utils.rs new file mode 100644 index 000000000..a6b4dfc54 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/llm/utils.rs @@ -0,0 +1,7 @@ +use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer}; + +pub fn num_tokens(text: &str, token_encoder: Option) -> usize { + let token_encoder = token_encoder.unwrap_or_else(|| Tokenizer::Cl100kBase); + let bpe = get_bpe_from_tokenizer(token_encoder).unwrap(); + bpe.encode_ordinary(text).len() +} diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs index 0a18ec55a..0dde10c0c 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs @@ -1,62 +1,72 @@ use futures::future::join_all; //use polars::frame::DataFrame; use serde::{Deserialize, Serialize}; +use serde_json::Value; use std::collections::HashMap; use std::time::Instant; -use tiktoken::encoding::Encoding; -use tokio::sync::Semaphore; +use tiktoken_rs::tokenizer::Tokenizer; use crate::context_builder::context_builder::{ConversationHistory, GlobalContextBuilder}; -use crate::llm::llm::BaseLLM; +use crate::llm::llm::{BaseLLM, BaseLLMCallback, MessageType}; +use crate::llm::utils::num_tokens; // TODO: Serialize and Deserialize polars::frame::DataFrame type DataFrame = Vec; -#[derive(Debug, Serialize, Deserialize)] -struct SearchResult { +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchResult { response: ResponseType, context_data: ContextData, context_text: ContextText, completion_time: f64, - llm_calls: u32, - prompt_tokens: u32, + llm_calls: usize, + prompt_tokens: usize, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum ResponseType { String(String), Dictionary(HashMap), Dictionaries(Vec>), + KeyPoints(Vec), } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum ContextData { String(String), DataFrames(Vec), Dictionary(HashMap), } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum ContextText { String(String), Strings(Vec), Dictionary(HashMap), } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct KeyPoint { + answer: String, + score: i32, +} + #[derive(Serialize, Deserialize)] pub struct GlobalSearchResult { response: ResponseType, context_data: ContextData, context_text: ContextText, completion_time: f64, - llm_calls: i32, - prompt_tokens: i32, + llm_calls: usize, + prompt_tokens: usize, map_responses: Vec, reduce_context_data: ContextData, reduce_context_text: ContextText, } -struct GlobalSearchLLMCallback { +#[derive(Debug, Clone)] +pub struct GlobalSearchLLMCallback { + response: Vec, map_response_contexts: Vec, map_response_outputs: Vec, } @@ -64,6 +74,7 @@ struct GlobalSearchLLMCallback { impl GlobalSearchLLMCallback { pub fn new() -> Self { GlobalSearchLLMCallback { + response: Vec::new(), map_response_contexts: Vec::new(), map_response_outputs: Vec::new(), } @@ -81,10 +92,8 @@ impl GlobalSearchLLMCallback { pub struct GlobalSearch { llm: Box, context_builder: Box, - token_encoder: Option, - llm_params: Option>, + token_encoder: Option, context_builder_params: Option>, - map_system_prompt: String, reduce_system_prompt: String, response_type: String, allow_general_knowledge: bool, @@ -93,15 +102,13 @@ pub struct GlobalSearch { max_data_tokens: usize, map_llm_params: HashMap, reduce_llm_params: HashMap, - semaphore: Semaphore, } impl GlobalSearch { pub fn new( llm: Box, context_builder: Box, - token_encoder: Option, - map_system_prompt: String, + token_encoder: Option, reduce_system_prompt: String, response_type: String, allow_general_knowledge: bool, @@ -112,7 +119,6 @@ impl GlobalSearch { map_llm_params: HashMap, reduce_llm_params: HashMap, context_builder_params: Option>, - concurrent_coroutines: usize, ) -> Self { let mut map_llm_params = map_llm_params; @@ -125,15 +131,11 @@ impl GlobalSearch { map_llm_params.remove("response_format"); } - let semaphore = Semaphore::new(concurrent_coroutines); - GlobalSearch { llm, context_builder, token_encoder, - llm_params: None, context_builder_params, - map_system_prompt, reduce_system_prompt, response_type, allow_general_knowledge, @@ -142,7 +144,6 @@ impl GlobalSearch { max_data_tokens, map_llm_params, reduce_llm_params, - semaphore, } } @@ -150,42 +151,59 @@ impl GlobalSearch { &self, query: String, conversation_history: Option, - ) -> GlobalSearchResult { + ) -> anyhow::Result { // Step 1: Generate answers for each batch of community short summaries let start_time = Instant::now(); let (context_chunks, context_records) = self .context_builder - .build_context(conversation_history, self.context_builder_params) + .build_context(conversation_history, self.context_builder_params.clone()) .await; - if let Some(callbacks) = &self.callbacks { - for callback in callbacks { - callback.on_map_response_start(context_chunks); + let mut callbacks = match &self.callbacks { + Some(callbacks) => { + let mut llm_callbacks = Vec::new(); + for callback in callbacks { + let mut callback = callback.clone(); + callback.on_map_response_start(context_chunks.clone()); + llm_callbacks.push(callback); + } + Some(llm_callbacks) } - } + None => None, + }; let map_responses: Vec<_> = join_all( context_chunks .iter() - .map(|data| self._map_response_single_batch(data, &query, &self.map_llm_params)), + .map(|data| self._map_response_single_batch(data, &query, self.map_llm_params.clone())), ) .await; - if let Some(callbacks) = &self.callbacks { - for callback in callbacks { - callback.on_map_response_end(&map_responses); + let map_responses: Result, _> = map_responses.into_iter().collect(); + let map_responses = map_responses?; + + callbacks = match &callbacks { + Some(callbacks) => { + let mut llm_callbacks = Vec::new(); + for callback in callbacks { + let mut callback = callback.clone(); + callback.on_map_response_end(map_responses.clone()); + llm_callbacks.push(callback); + } + Some(llm_callbacks) } - } + None => None, + }; let map_llm_calls: usize = map_responses.iter().map(|response| response.llm_calls).sum(); let map_prompt_tokens: usize = map_responses.iter().map(|response| response.prompt_tokens).sum(); // Step 2: Combine the intermediate answers from step 2 to generate the final answer let reduce_response = self - ._reduce_response(&map_responses, &query, self.reduce_llm_params) - .await; + ._reduce_response(map_responses.clone(), &query, callbacks, self.reduce_llm_params.clone()) + .await?; - GlobalSearchResult { + Ok(GlobalSearchResult { response: reduce_response.response, context_data: ContextData::Dictionary(context_records), context_text: ContextText::Strings(context_chunks), @@ -195,61 +213,187 @@ impl GlobalSearch { map_responses, reduce_context_data: reduce_response.context_data, reduce_context_text: reduce_response.context_text, + }) + } + + async fn _map_response_single_batch( + &self, + context_data: &str, + query: &str, + llm_params: HashMap, + ) -> anyhow::Result { + let start_time = Instant::now(); + let search_prompt = String::new(); + let mut search_messages = Vec::new(); + search_messages.push(HashMap::from([ + ("role".to_string(), "system".to_string()), + ("content".to_string(), search_prompt.clone()), + ])); + search_messages.push(HashMap::from([ + ("role".to_string(), "user".to_string()), + ("content".to_string(), query.to_string()), + ])); + + let search_response = self + .llm + .agenerate(MessageType::Dictionary(search_messages), false, None, llm_params) + .await?; + + let processed_response = self.parse_search_response(&search_response); + + Ok(SearchResult { + response: ResponseType::KeyPoints(processed_response), + context_data: ContextData::String(context_data.to_string()), + context_text: ContextText::String(context_data.to_string()), + completion_time: start_time.elapsed().as_secs_f64(), + llm_calls: 1, + prompt_tokens: num_tokens(&search_prompt, self.token_encoder), + }) + } + + fn parse_search_response(&self, search_response: &str) -> Vec { + let parsed_elements: Value = serde_json::from_str(search_response).unwrap_or_default(); + + if let Some(points) = parsed_elements.get("points") { + if let Some(points) = points.as_array() { + return points + .iter() + .map(|element| KeyPoint { + answer: element + .get("description") + .unwrap_or(&Value::String("".to_string())) + .to_string(), + score: element + .get("score") + .unwrap_or(&Value::Number(serde_json::Number::from(0))) + .as_i64() + .unwrap_or(0) as i32, + }) + .collect::>(); + } } + + Vec::new() } async fn _reduce_response( &self, map_responses: Vec, query: &str, + callbacks: Option>, reduce_llm_params: HashMap, - ) -> SearchResult { + ) -> anyhow::Result { let start_time = Instant::now(); - let mut key_points = Vec::new(); + let mut key_points: Vec> = Vec::new(); for (index, response) in map_responses.iter().enumerate() { - if let ResponseType::Dictionaries(response_list) = response.response { + if let ResponseType::Dictionaries(response_list) = &response.response { for element in response_list { if let (Some(answer), Some(score)) = (element.get("answer"), element.get("score")) { - key_points.push((index, answer.clone(), score.clone())); + let mut point = HashMap::new(); + point.insert("analyst".to_string(), (index + 1).to_string()); + point.insert("answer".to_string(), answer.to_string()); + point.insert("score".to_string(), score.to_string()); + key_points.push(point); } } } } - let filtered_key_points: Vec<_> = key_points + let filtered_key_points: Vec> = key_points .into_iter() - .filter(|(_, _, score)| score.as_f64().unwrap_or(0.0) > 0.0) + .filter(|point| point.get("score").unwrap().parse::().unwrap() > 0) .collect(); if filtered_key_points.is_empty() && !self.allow_general_knowledge { - return SearchResult { + return Ok(SearchResult { response: ResponseType::String("NO_DATA_ANSWER".to_string()), context_data: ContextData::String("".to_string()), context_text: ContextText::String("".to_string()), completion_time: start_time.elapsed().as_secs_f64(), llm_calls: 0, prompt_tokens: 0, - }; + }); } let mut sorted_key_points = filtered_key_points; sorted_key_points.sort_by(|a, b| { - b.2.as_f64() - .unwrap_or(0.0) - .partial_cmp(&a.2.as_f64().unwrap_or(0.0)) + b.get("score") + .unwrap() + .parse::() .unwrap() + .cmp(&a.get("score").unwrap().parse::().unwrap()) }); - // TODO: Implement rest of the function + let mut data: Vec = Vec::new(); + let mut total_tokens = 0; + for point in sorted_key_points { + let mut formatted_response_data: Vec = Vec::new(); + formatted_response_data.push(format!("----Analyst {}----", point.get("analyst").unwrap())); + formatted_response_data.push(format!("Importance Score: {}", point.get("score").unwrap())); + formatted_response_data.push(point.get("answer").unwrap().to_string()); + let formatted_response_text = formatted_response_data.join("\n"); + if total_tokens + num_tokens(&formatted_response_text, self.token_encoder) > self.max_data_tokens { + break; + } + data.push(formatted_response_text.clone()); + total_tokens += num_tokens(&formatted_response_text, self.token_encoder); + } + let text_data = data.join("\n\n"); - SearchResult { - response: ResponseType::String("Combined response".to_string()), - context_data: ContextData::String("".to_string()), - context_text: ContextText::String("".to_string()), + let search_prompt = format!( + "{}\n{}", + self.reduce_system_prompt + .replace("{report_data}", &text_data) + .replace("{response_type}", &self.response_type), + if self.allow_general_knowledge { + self.general_knowledge_inclusion_prompt.clone() + } else { + String::new() + } + ); + + let search_messages = vec![ + HashMap::from([ + ("role".to_string(), "system".to_string()), + ("content".to_string(), search_prompt.clone()), + ]), + HashMap::from([ + ("role".to_string(), "user".to_string()), + ("content".to_string(), query.to_string()), + ]), + ]; + + let llm_callbacks = match callbacks { + Some(callbacks) => { + let mut llm_callbacks = Vec::new(); + for callback in callbacks { + llm_callbacks.push(BaseLLMCallback { + response: callback.response.clone(), + }); + } + Some(llm_callbacks) + } + None => None, + }; + + let search_response = self + .llm + .agenerate( + MessageType::Dictionary(search_messages), + true, + llm_callbacks, + reduce_llm_params, + ) + .await?; + + Ok(SearchResult { + response: ResponseType::String(search_response), + context_data: ContextData::String(text_data.clone()), + context_text: ContextText::String(text_data), completion_time: start_time.elapsed().as_secs_f64(), - llm_calls: 0, - prompt_tokens: 0, - } + llm_calls: 1, + prompt_tokens: num_tokens(&search_prompt, self.token_encoder), + }) } } From f22dd6b4c50a6e500c6dc62e270a7b666ef534fa Mon Sep 17 00:00:00 2001 From: benolt Date: Fri, 2 Aug 2024 14:19:04 +0200 Subject: [PATCH 03/23] read indexer entities, indexer reports --- Cargo.lock | 77 +++++++ shinkai-libs/shinkai-graphrag/Cargo.toml | 3 +- .../src/context_builder/indexer_entities.rs | 207 ++++++++++++++++++ .../src/context_builder/indexer_reports.rs | 135 ++++++++++++ .../src/context_builder/mod.rs | 2 + .../shinkai-graphrag/src/llm/openai.rs | 6 +- 6 files changed, 426 insertions(+), 4 deletions(-) create mode 100644 shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs diff --git a/Cargo.lock b/Cargo.lock index 37c30ea8e..9f8f3f732 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -117,6 +117,21 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "allocator-api2" version = "0.2.18" @@ -861,6 +876,28 @@ dependencies = [ "wasm-bindgen-futures", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2 1.0.84", + "quote 1.0.36", + "syn 2.0.66", +] + [[package]] name = "async-task" version = "4.4.0" @@ -1672,6 +1709,27 @@ dependencies = [ "syn_derive", ] +[[package]] +name = "brotli" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19483b140a7ac7174d34b5a581b406c64f84da5409d3e09cf4fff604f9270e67" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bs58" version = "0.5.0" @@ -7349,6 +7407,10 @@ name = "parquet-format-safe" version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1131c54b167dd4e4799ce762e1ab01549ebb94d5bdd13e6ec1b467491c378e1f" +dependencies = [ + "async-trait", + "futures", +] [[package]] name = "parse-zoneinfo" @@ -7916,6 +7978,7 @@ dependencies = [ "ethnum", "fast-float", "foreign_vec", + "futures", "getrandom 0.2.10", "hashbrown 0.14.5", "itoa 1.0.9", @@ -8031,10 +8094,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d7363cd14e4696a28b334a56bd11013ff49cc96064818ab3f91a126e453462d" dependencies = [ "ahash 0.8.11", + "async-trait", "atoi_simd", "bytes", "chrono", "fast-float", + "futures", "home", "itoa 1.0.9", "memchr", @@ -8045,6 +8110,7 @@ dependencies = [ "polars-arrow", "polars-core", "polars-error", + "polars-parquet", "polars-time", "polars-utils", "rayon", @@ -8052,6 +8118,8 @@ dependencies = [ "ryu", "simdutf8", "smartstring", + "tokio", + "tokio-util 0.7.11", ] [[package]] @@ -8135,8 +8203,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b35b2592a2e7ef7ce9942dc2120dc4576142626c0e661668e4c6b805042e461" dependencies = [ "ahash 0.8.11", + "async-stream", "base64 0.22.1", + "brotli", "ethnum", + "flate2", + "futures", + "lz4", "num-traits", "parquet-format-safe", "polars-arrow", @@ -8144,7 +8217,9 @@ dependencies = [ "polars-error", "polars-utils", "simdutf8", + "snap", "streaming-decompression", + "zstd 0.13.2", ] [[package]] @@ -8190,6 +8265,7 @@ dependencies = [ "polars-core", "polars-io", "polars-ops", + "polars-parquet", "polars-time", "polars-utils", "rayon", @@ -10313,6 +10389,7 @@ dependencies = [ "async-trait", "futures", "polars", + "polars-lazy", "serde", "serde_json", "tiktoken-rs", diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml index 9057f8ea0..5cd0f1549 100644 --- a/shinkai-libs/shinkai-graphrag/Cargo.toml +++ b/shinkai-libs/shinkai-graphrag/Cargo.toml @@ -8,7 +8,8 @@ anyhow = "1.0.86" async-openai = "0.23.4" async-trait = "0.1.74" futures = "0.3.30" -polars = "0.41.3" +polars = { version = "0.41.3", features = ["lazy", "parquet"] } +polars-lazy = "0.41.3" serde = { version = "1.0.188", features = ["derive"] } serde_json = "1.0.117" tiktoken-rs = "0.5.9" diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs new file mode 100644 index 000000000..0265a05f7 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs @@ -0,0 +1,207 @@ +use std::collections::HashMap; + +use polars::prelude::*; +use polars_lazy::dsl::col; +use serde::{Deserialize, Serialize}; + +use super::indexer_reports::filter_under_community_level; + +pub fn read_indexer_entities( + final_nodes: &DataFrame, + final_entities: &DataFrame, + community_level: u32, +) -> anyhow::Result> { + let entity_df = final_nodes.clone(); + let mut entity_df = filter_under_community_level(&entity_df, community_level)?; + + let entity_df = entity_df.rename("title", "name")?.rename("degree", "rank")?; + + let entity_df = entity_df + .clone() + .lazy() + .with_column(col("community").fill_null(lit(-1))) + .collect()?; + let entity_df = entity_df + .clone() + .lazy() + .with_column(col("community").cast(DataType::Int32)) + .collect()?; + let entity_df = entity_df + .clone() + .lazy() + .with_column(col("rank").cast(DataType::Int32)) + .collect()?; + + let entity_embedding_df = final_entities.clone(); + + let entity_df = entity_df + .clone() + .lazy() + .group_by([col("name"), col("rank")]) + .agg([col("community").max()]) + .collect()?; + + let entity_df = entity_df + .clone() + .lazy() + .with_column(col("community").cast(DataType::String)) + .collect()?; + + let entity_df = entity_df + .clone() + .lazy() + .join( + entity_embedding_df.clone().lazy(), + [col("name")], + [col("name")], + JoinArgs::new(JoinType::Inner), + ) + .collect()?; + + let entity_df = entity_df + .clone() + .lazy() + .filter(len().over([col("name")]).gt(lit(1))) + .collect()?; + + let entities = read_entities( + &entity_df, + "id", + Some("human_readable_id"), + "name", + Some("type"), + Some("description"), + None, + Some("description_embedding"), + None, + Some("community"), + Some("text_unit_ids"), + None, + Some("rank"), + )?; + + Ok(entities) +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct Entity { + pub id: String, + pub short_id: Option, + pub title: String, + pub entity_type: Option, + pub description: Option, + pub description_embedding: Option>, + pub name_embedding: Option>, + pub graph_embedding: Option>, + pub community_ids: Option>, + pub text_unit_ids: Option>, + pub document_ids: Option>, + pub rank: Option, + pub attributes: Option>, +} + +pub fn read_entities( + df: &DataFrame, + id_col: &str, + short_id_col: Option<&str>, + title_col: &str, + type_col: Option<&str>, + description_col: Option<&str>, + name_embedding_col: Option<&str>, + description_embedding_col: Option<&str>, + graph_embedding_col: Option<&str>, + community_col: Option<&str>, + text_unit_ids_col: Option<&str>, + document_ids_col: Option<&str>, + rank_col: Option<&str>, + // attributes_cols: Option>, +) -> anyhow::Result> { + let column_names = [ + id_col, + short_id_col.unwrap_or("short_id"), + title_col, + type_col.unwrap_or("type"), + description_col.unwrap_or("description"), + name_embedding_col.unwrap_or("name_embedding"), + description_embedding_col.unwrap_or("description_embedding"), + graph_embedding_col.unwrap_or("graph_embedding"), + community_col.unwrap_or("community_ids"), + text_unit_ids_col.unwrap_or("text_unit_ids"), + document_ids_col.unwrap_or("document_ids"), + rank_col.unwrap_or("degree"), + ]; + + let mut df = df.clone(); + df.as_single_chunk_par(); + let mut iters = df.columns(column_names)?.iter().map(|s| s.iter()).collect::>(); + + let mut rows = Vec::new(); + for _row in 0..df.height() { + let mut row_values = Vec::new(); + for iter in &mut iters { + let value = iter.next(); + if let Some(value) = value { + row_values.push(value.to_string()); + } + } + rows.push(row_values); + } + + let mut entities = Vec::new(); + for row in rows { + let report = Entity { + id: row.get(0).unwrap_or(&String::new()).to_string(), + short_id: Some(row.get(1).unwrap_or(&String::new()).to_string()), + title: row.get(2).unwrap_or(&String::new()).to_string(), + entity_type: Some(row.get(3).unwrap_or(&String::new()).to_string()), + description: Some(row.get(4).unwrap_or(&String::new()).to_string()), + name_embedding: Some( + row.get(5) + .unwrap_or(&String::new()) + .split(',') + .map(|v| v.parse::().unwrap_or(0.0)) + .collect(), + ), + description_embedding: Some( + row.get(6) + .unwrap_or(&String::new()) + .split(',') + .map(|v| v.parse::().unwrap_or(0.0)) + .collect(), + ), + graph_embedding: Some( + row.get(7) + .unwrap_or(&String::new()) + .split(',') + .map(|v| v.parse::().unwrap_or(0.0)) + .collect(), + ), + community_ids: Some( + row.get(8) + .unwrap_or(&String::new()) + .split(',') + .map(|v| v.to_string()) + .collect(), + ), + text_unit_ids: Some( + row.get(9) + .unwrap_or(&String::new()) + .split(',') + .map(|v| v.to_string()) + .collect(), + ), + document_ids: Some( + row.get(10) + .unwrap_or(&String::new()) + .split(',') + .map(|v| v.to_string()) + .collect(), + ), + rank: Some(row.get(11).and_then(|v| v.parse::().ok()).unwrap_or(0)), + attributes: None, + }; + entities.push(report); + } + + Ok(entities) +} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs new file mode 100644 index 000000000..07811ac8b --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs @@ -0,0 +1,135 @@ +use std::collections::HashMap; + +use polars::prelude::*; +use polars_lazy::dsl::col; +use serde::{Deserialize, Serialize}; + +pub fn read_indexer_reports( + final_community_reports: &DataFrame, + final_nodes: &DataFrame, + community_level: u32, +) -> anyhow::Result> { + let entity_df = final_nodes.clone(); + let entity_df = filter_under_community_level(&entity_df, community_level)?; + + let entity_df = entity_df + .clone() + .lazy() + .with_column(col("community").fill_null(lit(-1))) + .collect()?; + let entity_df = entity_df + .clone() + .lazy() + .with_column(col("community").cast(DataType::Int32)) + .collect()?; + + let entity_df = entity_df + .clone() + .lazy() + .with_column(col("community").cast(DataType::String)) + .collect()?; + + let entity_df = entity_df + .clone() + .lazy() + .group_by([col("title")]) + .agg([col("community").max()]) + .collect()?; + + let filtered_community_df = entity_df + .clone() + .lazy() + .filter(len().over([col("community")]).gt(lit(1))) + .collect()?; + + let report_df = final_community_reports.clone(); + let report_df = filter_under_community_level(&report_df, community_level)?; + + let report_df = report_df + .clone() + .lazy() + .join( + filtered_community_df.clone().lazy(), + [col("community")], + [col("community")], + JoinArgs::new(JoinType::Inner), + ) + .collect()?; + + let reports = read_community_reports(&report_df, "community", Some("community"), None, None)?; + Ok(reports) +} + +pub fn filter_under_community_level(df: &DataFrame, community_level: u32) -> anyhow::Result { + let mask = df.column("level")?.i32()?.lt_eq(community_level); + let result = df.filter(&mask)?; + + Ok(result) +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct CommunityReport { + pub id: String, + pub short_id: Option, + pub title: String, + pub community_id: String, + pub summary: String, + pub full_content: String, + pub rank: Option, + pub summary_embedding: Option>, + pub full_content_embedding: Option>, + pub attributes: Option>, +} + +pub fn read_community_reports( + df: &DataFrame, + _id_col: &str, + _short_id_col: Option<&str>, + // title_col: &str, + // community_col: &str, + // summary_col: &str, + // content_col: &str, + // rank_col: Option<&str>, + _summary_embedding_col: Option<&str>, + _content_embedding_col: Option<&str>, + // attributes_cols: Option<&[&str]>, +) -> anyhow::Result> { + let mut df = df.clone(); + df.as_single_chunk_par(); + let mut iters = df + .columns(["community", "title", "summary", "full_content", "rank"])? + .iter() + .map(|s| s.iter()) + .collect::>(); + + let mut rows = Vec::new(); + for _row in 0..df.height() { + let mut row_values = Vec::new(); + for iter in &mut iters { + let value = iter.next(); + if let Some(value) = value { + row_values.push(value.to_string()); + } + } + rows.push(row_values); + } + + let mut reports = Vec::new(); + for row in rows { + let report = CommunityReport { + id: row.get(0).unwrap_or(&String::new()).to_string(), + short_id: Some(row.get(0).unwrap_or(&String::new()).to_string()), + title: row.get(1).unwrap_or(&String::new()).to_string(), + community_id: row.get(0).unwrap_or(&String::new()).to_string(), + summary: row.get(3).unwrap_or(&String::new()).to_string(), + full_content: row.get(4).unwrap_or(&String::new()).to_string(), + rank: Some(row.get(5).and_then(|v| v.parse::().ok()).unwrap_or(0.0)), + summary_embedding: None, + full_content_embedding: None, + attributes: None, + }; + reports.push(report); + } + + Ok(reports) +} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs index 709d766d9..08173d23d 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs @@ -1 +1,3 @@ pub mod context_builder; +pub mod indexer_entities; +pub mod indexer_reports; diff --git a/shinkai-libs/shinkai-graphrag/src/llm/openai.rs b/shinkai-libs/shinkai-graphrag/src/llm/openai.rs index a0b3986b6..644f1101f 100644 --- a/shinkai-libs/shinkai-graphrag/src/llm/openai.rs +++ b/shinkai-libs/shinkai-graphrag/src/llm/openai.rs @@ -52,9 +52,9 @@ impl ChatOpenAI { async fn _agenerate( &self, messages: MessageType, - streaming: bool, - callbacks: Option>, - llm_params: HashMap, + _streaming: bool, + _callbacks: Option>, + _llm_params: HashMap, ) -> anyhow::Result { let client = match &self.api_key { Some(api_key) => Client::with_config(OpenAIConfig::new().with_api_key(api_key)), From df03fcea2de4d41917cefada1470395d292e4083 Mon Sep 17 00:00:00 2001 From: benolt Date: Wed, 7 Aug 2024 16:40:27 +0200 Subject: [PATCH 04/23] build community context batch 1 --- Cargo.lock | 54 ++- shinkai-libs/shinkai-graphrag/Cargo.toml | 1 + .../src/context_builder/community_context.rs | 326 ++++++++++++++++++ .../src/context_builder/indexer_entities.rs | 2 +- .../src/context_builder/indexer_reports.rs | 2 +- .../src/context_builder/mod.rs | 1 + .../src/search/global_search.rs | 7 +- 7 files changed, 372 insertions(+), 21 deletions(-) create mode 100644 shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs diff --git a/Cargo.lock b/Cargo.lock index 9f8f3f732..27b6ce90a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -811,6 +811,32 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-openai" +version = "0.23.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0e5ff98f9e7c605df4c88783a0439d1dc667ce86bd79e99d4164f8b0c05ccc" +dependencies = [ + "async-convert", + "backoff", + "base64 0.22.1", + "bytes", + "derive_builder 0.20.0", + "eventsource-stream", + "futures", + "rand 0.8.5", + "reqwest 0.12.5", + "reqwest-eventsource", + "secrecy", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tokio-util 0.7.11", + "tracing", +] + [[package]] name = "async-priority-channel" version = "0.2.0" @@ -6339,12 +6365,6 @@ dependencies = [ "libc", ] -[[package]] -name = "maplit" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" - [[package]] name = "markup5ever" version = "0.10.1" @@ -9300,12 +9320,6 @@ dependencies = [ "regex-syntax 0.8.2", ] -[[package]] -name = "regex-automata" -version = "0.4.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" - [[package]] name = "regex-lite" version = "0.1.5" @@ -9414,6 +9428,7 @@ dependencies = [ "js-sys", "log 0.4.21", "mime 0.3.17", + "mime_guess 2.0.4", "once_cell", "percent-encoding 2.3.1", "pin-project-lite", @@ -10251,9 +10266,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.9.0" +version = "3.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cecfa94848272156ea67b2b1a53f20fc7bc638c4a46d2f8abde08f05f4b857" +checksum = "0ad483d2ab0149d5a5ebcd9972a3852711e0153d863bf5a5d0391d28883c4a20" dependencies = [ "base64 0.22.1", "chrono", @@ -10269,9 +10284,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.9.0" +version = "3.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8fee4991ef4f274617a51ad4af30519438dacb2f56ac773b08a1922ff743350" +checksum = "65569b702f41443e8bc8bbb1c5779bd0450bbe723b56198980e80ec45780bce2" dependencies = [ "darling 0.20.10", "proc-macro2 1.0.84", @@ -10390,6 +10405,7 @@ dependencies = [ "futures", "polars", "polars-lazy", + "rand 0.8.5", "serde", "serde_json", "tiktoken-rs", @@ -10904,6 +10920,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "snap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" + [[package]] name = "socket2" version = "0.4.9" diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml index 5cd0f1549..5a5ceafb7 100644 --- a/shinkai-libs/shinkai-graphrag/Cargo.toml +++ b/shinkai-libs/shinkai-graphrag/Cargo.toml @@ -10,6 +10,7 @@ async-trait = "0.1.74" futures = "0.3.30" polars = { version = "0.41.3", features = ["lazy", "parquet"] } polars-lazy = "0.41.3" +rand = "0.8.5" serde = { version = "1.0.188", features = ["derive"] } serde_json = "1.0.117" tiktoken-rs = "0.5.9" diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs new file mode 100644 index 000000000..3b8350a3d --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs @@ -0,0 +1,326 @@ +use std::{ + collections::HashMap, + io::{Cursor, Read}, +}; + +use polars::{ + frame::DataFrame, + io::SerWriter, + prelude::{col, CsvWriter, DataType, IntoLazy, NamedFrom, SortMultipleOptions}, + series::Series, +}; +use rand::prelude::SliceRandom; +use tiktoken_rs::tokenizer::Tokenizer; + +use crate::llm::utils::num_tokens; + +use super::{context_builder::ConversationHistory, indexer_entities::Entity, indexer_reports::CommunityReport}; + +pub struct GlobalCommunityContext { + community_reports: Vec, + entities: Option>, + token_encoder: Option, + random_state: i32, +} + +impl GlobalCommunityContext { + pub fn new( + community_reports: Vec, + entities: Option>, + token_encoder: Option, + random_state: Option, + ) -> Self { + Self { + community_reports, + entities, + token_encoder, + random_state: random_state.unwrap_or(86), + } + } + + pub async fn build_context( + &self, + conversation_history: Option, + context_builder_params: Option>, + ) -> (Vec, HashMap) { + (vec![], HashMap::new()) + } +} + +pub fn build_community_context( + community_reports: Vec, + entities: Option>, + token_encoder: Option, + use_community_summary: bool, + column_delimiter: &str, + shuffle_data: bool, + include_community_rank: bool, + min_community_rank: i32, + community_rank_name: &str, + include_community_weight: bool, + community_weight_name: &str, + normalize_community_weight: bool, + max_tokens: i32, + single_batch: bool, + context_name: &str, + random_state: i32, +) -> anyhow::Result<(Vec, HashMap)> { + let _is_included = |report: &CommunityReport| -> bool { + report.rank.is_some() && report.rank.unwrap() >= min_community_rank.into() + }; + + let _get_header = |attributes: Vec| -> Vec { + let mut header = vec!["id".to_string(), "title".to_string()]; + let mut filtered_attributes: Vec = attributes + .iter() + .filter(|&col| !header.contains(&col.to_string())) + .cloned() + .collect(); + + if !include_community_weight { + filtered_attributes.retain(|col| col != community_weight_name); + } + + header.extend(filtered_attributes.into_iter().map(|s| s.to_string())); + header.push(if use_community_summary { + "summary".to_string() + } else { + "content".to_string() + }); + + if include_community_rank { + header.push(community_rank_name.to_string()); + } + + header + }; + + let _report_context_text = |report: &CommunityReport, attributes: &[String]| -> (String, Vec) { + let mut context: Vec = vec![report.short_id.clone().unwrap_or_default(), report.title.clone()]; + + for field in attributes { + let value = report + .attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(); + context.push(value); + } + + context.push(if use_community_summary { + report.summary.clone() + } else { + report.full_content.clone() + }); + + if include_community_rank { + context.push(report.rank.unwrap_or_default().to_string()); + } + + let result = context.join(column_delimiter) + "\n"; + (result, context) + }; + + let compute_community_weights = entities.is_some() + && !community_reports.is_empty() + && include_community_weight + && (community_reports[0].attributes.is_none() + || !community_reports[0] + .attributes + .clone() + .unwrap() + .contains_key(community_weight_name)); + + let mut community_reports = community_reports; + if compute_community_weights { + community_reports = _compute_community_weights( + community_reports, + entities.clone(), + community_weight_name, + normalize_community_weight, + ); + } + + let mut selected_reports: Vec = community_reports + .iter() + .filter(|&report| _is_included(report)) + .cloned() + .collect(); + + if selected_reports.is_empty() { + return Ok((Vec::new(), HashMap::new())); + } + + if shuffle_data { + let mut rng = rand::thread_rng(); + selected_reports.shuffle(&mut rng); + } + + let attributes = if let Some(attributes) = &community_reports[0].attributes { + attributes.keys().cloned().collect::>() + } else { + Vec::new() + }; + + let header = _get_header(attributes); + let mut all_context_text: Vec = Vec::new(); + let mut all_context_records: Vec = Vec::new(); + + let mut batch_text = String::new(); + let mut batch_tokens = 0; + let mut batch_records: Vec> = Vec::new(); + + let mut _init_batch = || { + batch_text = format!("-----{}-----\n{}\n", context_name, header.join(column_delimiter)); + batch_tokens = num_tokens(&batch_text, token_encoder); + batch_records = Vec::new(); + }; + + let _cut_batch = |batch_records: Vec>, header: Vec| -> anyhow::Result<()> { + let weight_column = if include_community_weight && entities.is_some() { + Some(community_weight_name) + } else { + None + }; + let rank_column = if include_community_rank { + Some(community_rank_name) + } else { + None + }; + + let mut record_df = _convert_report_context_to_df(batch_records, header, weight_column, rank_column)?; + if record_df.is_empty() { + return Ok(()); + } + + let mut buffer = Cursor::new(Vec::new()); + CsvWriter::new(buffer.clone()).finish(&mut record_df).unwrap(); + + let mut current_context_text = String::new(); + buffer.read_to_string(&mut current_context_text)?; + + all_context_text.push(current_context_text); + all_context_records.push(record_df); + + Ok(()) + }; + + _init_batch(); + + Ok((vec![], HashMap::new())) +} + +fn _compute_community_weights( + community_reports: Vec, + entities: Option>, + weight_attribute: &str, + normalize: bool, +) -> Vec { + // Calculate a community's weight as the count of text units associated with entities within the community. + if let Some(entities) = entities { + let mut community_reports = community_reports.clone(); + let mut community_text_units = std::collections::HashMap::new(); + for entity in entities { + if let Some(community_ids) = entity.community_ids.clone() { + for community_id in community_ids { + community_text_units + .entry(community_id) + .or_insert_with(Vec::new) + .extend(entity.text_unit_ids.clone()); + } + } + } + for report in &mut community_reports { + if report.attributes.is_none() { + report.attributes = Some(std::collections::HashMap::new()); + } + if let Some(attributes) = &mut report.attributes { + attributes.insert( + weight_attribute.to_string(), + community_text_units + .get(&report.community_id) + .map(|text_units| text_units.len()) + .unwrap_or(0) + .to_string(), + ); + } + } + if normalize { + // Normalize by max weight + let all_weights: Vec = community_reports + .iter() + .filter_map(|report| { + report + .attributes + .as_ref() + .and_then(|attributes| attributes.get(weight_attribute)) + .map(|weight| weight.parse::().unwrap_or(0.0)) + }) + .collect(); + if let Some(max_weight) = all_weights.iter().cloned().max_by(|a, b| a.partial_cmp(b).unwrap()) { + for mut report in community_reports { + if let Some(attributes) = &mut report.attributes { + if let Some(weight) = attributes.get_mut(weight_attribute) { + *weight = (weight.parse::().unwrap_or(0.0) / max_weight).to_string(); + } + } + } + } + } + } + community_reports +} + +fn _convert_report_context_to_df( + context_records: Vec>, + header: Vec, + weight_column: Option<&str>, + rank_column: Option<&str>, +) -> anyhow::Result { + if context_records.is_empty() { + return Ok(DataFrame::empty()); + } + + let mut data_series = Vec::new(); + for (header, records) in header.iter().zip(context_records.iter()) { + let series = Series::new(header, records); + data_series.push(series); + } + + let record_df = DataFrame::new(data_series)?; + + return _rank_report_context(record_df, weight_column, rank_column); +} + +fn _rank_report_context( + report_df: DataFrame, + weight_column: Option<&str>, + rank_column: Option<&str>, +) -> anyhow::Result { + let weight_column = weight_column.unwrap_or("occurrence weight"); + let rank_column = rank_column.unwrap_or("rank"); + + let mut rank_attributes = Vec::new(); + rank_attributes.push(weight_column); + let report_df = report_df + .clone() + .lazy() + .with_column(col(weight_column).cast(DataType::Float64)) + .collect()?; + + rank_attributes.push(rank_column); + let report_df = report_df + .clone() + .lazy() + .with_column(col(rank_column).cast(DataType::Float64)) + .collect()?; + + let report_df = report_df + .clone() + .lazy() + .sort(rank_attributes, SortMultipleOptions::new().with_order_descending(true)) + .collect()?; + + Ok(report_df) +} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs index 0265a05f7..337831da0 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs @@ -83,7 +83,7 @@ pub fn read_indexer_entities( Ok(entities) } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct Entity { pub id: String, pub short_id: Option, diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs index 07811ac8b..fcfca58a6 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs @@ -67,7 +67,7 @@ pub fn filter_under_community_level(df: &DataFrame, community_level: u32) -> any Ok(result) } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct CommunityReport { pub id: String, pub short_id: Option, diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs index 08173d23d..0abed5320 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs @@ -1,3 +1,4 @@ +pub mod community_context; pub mod context_builder; pub mod indexer_entities; pub mod indexer_reports; diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs index 0dde10c0c..55636214d 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs @@ -6,7 +6,8 @@ use std::collections::HashMap; use std::time::Instant; use tiktoken_rs::tokenizer::Tokenizer; -use crate::context_builder::context_builder::{ConversationHistory, GlobalContextBuilder}; +use crate::context_builder::community_context::GlobalCommunityContext; +use crate::context_builder::context_builder::ConversationHistory; use crate::llm::llm::{BaseLLM, BaseLLMCallback, MessageType}; use crate::llm::utils::num_tokens; @@ -91,7 +92,7 @@ impl GlobalSearchLLMCallback { pub struct GlobalSearch { llm: Box, - context_builder: Box, + context_builder: GlobalCommunityContext, token_encoder: Option, context_builder_params: Option>, reduce_system_prompt: String, @@ -107,7 +108,7 @@ pub struct GlobalSearch { impl GlobalSearch { pub fn new( llm: Box, - context_builder: Box, + context_builder: GlobalCommunityContext, token_encoder: Option, reduce_system_prompt: String, response_type: String, From e23dd8fc246464ce9a6321a1121d6b831b997ada Mon Sep 17 00:00:00 2001 From: benolt Date: Thu, 8 Aug 2024 17:27:05 +0200 Subject: [PATCH 05/23] build community context, global search test --- shinkai-libs/shinkai-graphrag/.gitignore | 1 + shinkai-libs/shinkai-graphrag/Cargo.toml | 7 +- .../src/context_builder/community_context.rs | 584 +++++++++++------- .../src/context_builder/context_builder.rs | 31 +- .../src/context_builder/indexer_entities.rs | 29 +- .../src/context_builder/indexer_reports.rs | 8 +- shinkai-libs/shinkai-graphrag/src/llm/llm.rs | 9 +- shinkai-libs/shinkai-graphrag/src/llm/mod.rs | 1 - .../src/search/global_search.rs | 88 ++- .../tests/it/global_search_tests.rs | 109 ++++ .../shinkai-graphrag/tests/it/utils/mod.rs | 1 + .../{src/llm => tests/it/utils}/openai.rs | 45 +- shinkai-libs/shinkai-graphrag/tests/it_mod.rs | 4 + 13 files changed, 602 insertions(+), 315 deletions(-) create mode 100644 shinkai-libs/shinkai-graphrag/.gitignore create mode 100644 shinkai-libs/shinkai-graphrag/tests/it/global_search_tests.rs create mode 100644 shinkai-libs/shinkai-graphrag/tests/it/utils/mod.rs rename shinkai-libs/shinkai-graphrag/{src/llm => tests/it/utils}/openai.rs (70%) create mode 100644 shinkai-libs/shinkai-graphrag/tests/it_mod.rs diff --git a/shinkai-libs/shinkai-graphrag/.gitignore b/shinkai-libs/shinkai-graphrag/.gitignore new file mode 100644 index 000000000..122af2cf4 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/.gitignore @@ -0,0 +1 @@ +dataset \ No newline at end of file diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml index 5a5ceafb7..378f20e4f 100644 --- a/shinkai-libs/shinkai-graphrag/Cargo.toml +++ b/shinkai-libs/shinkai-graphrag/Cargo.toml @@ -5,13 +5,16 @@ edition = "2021" [dependencies] anyhow = "1.0.86" -async-openai = "0.23.4" async-trait = "0.1.74" futures = "0.3.30" -polars = { version = "0.41.3", features = ["lazy", "parquet"] } +polars = { version = "0.41.3", features = ["dtype-struct", "lazy", "parquet"] } polars-lazy = "0.41.3" rand = "0.8.5" serde = { version = "1.0.188", features = ["derive"] } serde_json = "1.0.117" tiktoken-rs = "0.5.9" +tokio = { version = "1.36", features = ["full"] } + +[dev-dependencies] +async-openai = "0.23.4" tokio = { version = "1.36", features = ["full"] } \ No newline at end of file diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs index 3b8350a3d..e46219021 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs @@ -6,7 +6,7 @@ use std::{ use polars::{ frame::DataFrame, io::SerWriter, - prelude::{col, CsvWriter, DataType, IntoLazy, NamedFrom, SortMultipleOptions}, + prelude::{col, concat, CsvWriter, DataType, IntoLazy, LazyFrame, NamedFrom, SortMultipleOptions, UnionArgs}, series::Series, }; use rand::prelude::SliceRandom; @@ -14,13 +14,12 @@ use tiktoken_rs::tokenizer::Tokenizer; use crate::llm::utils::num_tokens; -use super::{context_builder::ConversationHistory, indexer_entities::Entity, indexer_reports::CommunityReport}; +use super::{context_builder::ContextBuilderParams, indexer_entities::Entity, indexer_reports::CommunityReport}; pub struct GlobalCommunityContext { community_reports: Vec, entities: Option>, token_encoder: Option, - random_state: i32, } impl GlobalCommunityContext { @@ -28,157 +27,346 @@ impl GlobalCommunityContext { community_reports: Vec, entities: Option>, token_encoder: Option, - random_state: Option, ) -> Self { Self { community_reports, entities, token_encoder, - random_state: random_state.unwrap_or(86), } } pub async fn build_context( &self, - conversation_history: Option, - context_builder_params: Option>, - ) -> (Vec, HashMap) { - (vec![], HashMap::new()) + context_builder_params: ContextBuilderParams, + ) -> anyhow::Result<(Vec, HashMap)> { + let ContextBuilderParams { + use_community_summary, + column_delimiter, + shuffle_data, + include_community_rank, + min_community_rank, + community_rank_name, + include_community_weight, + community_weight_name, + normalize_community_weight, + max_tokens, + context_name, + } = context_builder_params; + + let (community_context, community_context_data) = CommunityContext::build_community_context( + self.community_reports.clone(), + self.entities.clone(), + self.token_encoder.clone(), + use_community_summary, + &column_delimiter, + shuffle_data, + include_community_rank, + min_community_rank, + &community_rank_name, + include_community_weight, + &community_weight_name, + normalize_community_weight, + max_tokens, + false, + &context_name, + )?; + + let final_context = community_context; + let final_context_data = community_context_data; + + Ok((final_context, final_context_data)) } } -pub fn build_community_context( - community_reports: Vec, - entities: Option>, - token_encoder: Option, - use_community_summary: bool, - column_delimiter: &str, - shuffle_data: bool, - include_community_rank: bool, - min_community_rank: i32, - community_rank_name: &str, - include_community_weight: bool, - community_weight_name: &str, - normalize_community_weight: bool, - max_tokens: i32, - single_batch: bool, - context_name: &str, - random_state: i32, -) -> anyhow::Result<(Vec, HashMap)> { - let _is_included = |report: &CommunityReport| -> bool { - report.rank.is_some() && report.rank.unwrap() >= min_community_rank.into() - }; - - let _get_header = |attributes: Vec| -> Vec { - let mut header = vec!["id".to_string(), "title".to_string()]; - let mut filtered_attributes: Vec = attributes +pub struct CommunityContext {} + +impl CommunityContext { + pub fn build_community_context( + community_reports: Vec, + entities: Option>, + token_encoder: Option, + use_community_summary: bool, + column_delimiter: &str, + shuffle_data: bool, + include_community_rank: bool, + min_community_rank: u32, + community_rank_name: &str, + include_community_weight: bool, + community_weight_name: &str, + normalize_community_weight: bool, + max_tokens: usize, + single_batch: bool, + context_name: &str, + ) -> anyhow::Result<(Vec, HashMap)> { + let _is_included = |report: &CommunityReport| -> bool { + report.rank.is_some() && report.rank.unwrap() >= min_community_rank.into() + }; + + let _get_header = |attributes: Vec| -> Vec { + let mut header = vec!["id".to_string(), "title".to_string()]; + let mut filtered_attributes: Vec = attributes + .iter() + .filter(|&col| !header.contains(&col.to_string())) + .cloned() + .collect(); + + if !include_community_weight { + filtered_attributes.retain(|col| col != community_weight_name); + } + + header.extend(filtered_attributes.into_iter().map(|s| s.to_string())); + header.push(if use_community_summary { + "summary".to_string() + } else { + "content".to_string() + }); + + if include_community_rank { + header.push(community_rank_name.to_string()); + } + + header + }; + + let _report_context_text = |report: &CommunityReport, attributes: &[String]| -> (String, Vec) { + let mut context: Vec = vec![report.short_id.clone().unwrap_or_default(), report.title.clone()]; + + for field in attributes { + let value = report + .attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(); + context.push(value); + } + + context.push(if use_community_summary { + report.summary.clone() + } else { + report.full_content.clone() + }); + + if include_community_rank { + context.push(report.rank.unwrap_or_default().to_string()); + } + + let result = context.join(column_delimiter) + "\n"; + (result, context) + }; + + let compute_community_weights = entities.is_some() + && !community_reports.is_empty() + && include_community_weight + && (community_reports[0].attributes.is_none() + || !community_reports[0] + .attributes + .clone() + .unwrap() + .contains_key(community_weight_name)); + + let mut community_reports = community_reports; + if compute_community_weights { + community_reports = Self::_compute_community_weights( + community_reports, + entities.clone(), + community_weight_name, + normalize_community_weight, + ); + } + + let mut selected_reports: Vec = community_reports .iter() - .filter(|&col| !header.contains(&col.to_string())) + .filter(|&report| _is_included(report)) .cloned() .collect(); - if !include_community_weight { - filtered_attributes.retain(|col| col != community_weight_name); + if selected_reports.is_empty() { + return Ok((Vec::new(), HashMap::new())); } - header.extend(filtered_attributes.into_iter().map(|s| s.to_string())); - header.push(if use_community_summary { - "summary".to_string() - } else { - "content".to_string() - }); - - if include_community_rank { - header.push(community_rank_name.to_string()); + if shuffle_data { + let mut rng = rand::thread_rng(); + selected_reports.shuffle(&mut rng); } - header - }; + let attributes = if let Some(attributes) = &community_reports[0].attributes { + attributes.keys().cloned().collect::>() + } else { + Vec::new() + }; + + let header = _get_header(attributes.clone()); + let mut all_context_text: Vec = Vec::new(); + let mut all_context_records: Vec = Vec::new(); + + let mut batch = Batch::new(); + + batch.init_batch(context_name, &header, column_delimiter, token_encoder); + + for report in selected_reports { + let (new_context_text, new_context) = _report_context_text(&report, &attributes); + let new_tokens = num_tokens(&new_context_text, token_encoder); + + // add the current batch to the context data and start a new batch if we are in multi-batch mode + if batch.batch_tokens + new_tokens > max_tokens { + batch.cut_batch( + &mut all_context_text, + &mut all_context_records, + entities.clone(), + &header, + community_weight_name, + community_rank_name, + include_community_weight, + include_community_rank, + )?; + + if single_batch { + break; + } - let _report_context_text = |report: &CommunityReport, attributes: &[String]| -> (String, Vec) { - let mut context: Vec = vec![report.short_id.clone().unwrap_or_default(), report.title.clone()]; + batch.init_batch(context_name, &header, column_delimiter, token_encoder); + } - for field in attributes { - let value = report - .attributes - .as_ref() - .and_then(|attrs| attrs.get(field)) - .cloned() - .unwrap_or_default(); - context.push(value); + batch.batch_text.push_str(&new_context_text); + batch.batch_tokens += new_tokens; + batch.batch_records.push(new_context); } - context.push(if use_community_summary { - report.summary.clone() - } else { - report.full_content.clone() - }); + if !all_context_text.contains(&batch.batch_text) { + batch.cut_batch( + &mut all_context_text, + &mut all_context_records, + entities.clone(), + &header, + community_weight_name, + community_rank_name, + include_community_weight, + include_community_rank, + )?; + } - if include_community_rank { - context.push(report.rank.unwrap_or_default().to_string()); + if all_context_records.is_empty() { + eprintln!("Warning: No community records added when building community context."); + return Ok((Vec::new(), HashMap::new())); } - let result = context.join(column_delimiter) + "\n"; - (result, context) - }; + let records_concat = concat( + all_context_records + .into_iter() + .map(|df| df.lazy()) + .collect::>(), + UnionArgs::default(), + )? + .collect()?; - let compute_community_weights = entities.is_some() - && !community_reports.is_empty() - && include_community_weight - && (community_reports[0].attributes.is_none() - || !community_reports[0] - .attributes - .clone() - .unwrap() - .contains_key(community_weight_name)); + Ok(( + all_context_text, + HashMap::from([(context_name.to_lowercase(), records_concat)]), + )) + } - let mut community_reports = community_reports; - if compute_community_weights { - community_reports = _compute_community_weights( - community_reports, - entities.clone(), - community_weight_name, - normalize_community_weight, - ); + fn _compute_community_weights( + community_reports: Vec, + entities: Option>, + weight_attribute: &str, + normalize: bool, + ) -> Vec { + // Calculate a community's weight as the count of text units associated with entities within the community. + if let Some(entities) = entities { + let mut community_reports = community_reports.clone(); + let mut community_text_units = std::collections::HashMap::new(); + for entity in entities { + if let Some(community_ids) = entity.community_ids.clone() { + for community_id in community_ids { + community_text_units + .entry(community_id) + .or_insert_with(Vec::new) + .extend(entity.text_unit_ids.clone()); + } + } + } + for report in &mut community_reports { + if report.attributes.is_none() { + report.attributes = Some(std::collections::HashMap::new()); + } + if let Some(attributes) = &mut report.attributes { + attributes.insert( + weight_attribute.to_string(), + community_text_units + .get(&report.community_id) + .map(|text_units| text_units.len()) + .unwrap_or(0) + .to_string(), + ); + } + } + if normalize { + // Normalize by max weight + let all_weights: Vec = community_reports + .iter() + .filter_map(|report| { + report + .attributes + .as_ref() + .and_then(|attributes| attributes.get(weight_attribute)) + .map(|weight| weight.parse::().unwrap_or(0.0)) + }) + .collect(); + if let Some(max_weight) = all_weights.iter().cloned().max_by(|a, b| a.partial_cmp(b).unwrap()) { + for mut report in community_reports { + if let Some(attributes) = &mut report.attributes { + if let Some(weight) = attributes.get_mut(weight_attribute) { + *weight = (weight.parse::().unwrap_or(0.0) / max_weight).to_string(); + } + } + } + } + } + } + community_reports } +} - let mut selected_reports: Vec = community_reports - .iter() - .filter(|&report| _is_included(report)) - .cloned() - .collect(); +struct Batch { + batch_text: String, + batch_tokens: usize, + batch_records: Vec>, +} - if selected_reports.is_empty() { - return Ok((Vec::new(), HashMap::new())); +impl Batch { + fn new() -> Self { + Batch { + batch_text: String::new(), + batch_tokens: 0, + batch_records: Vec::new(), + } } - if shuffle_data { - let mut rng = rand::thread_rng(); - selected_reports.shuffle(&mut rng); + fn init_batch( + &mut self, + context_name: &str, + header: &Vec, + column_delimiter: &str, + token_encoder: Option, + ) { + self.batch_text = format!("-----{}-----\n{}\n", context_name, header.join(column_delimiter)); + self.batch_tokens = num_tokens(&self.batch_text, token_encoder); + self.batch_records.clear(); } - let attributes = if let Some(attributes) = &community_reports[0].attributes { - attributes.keys().cloned().collect::>() - } else { - Vec::new() - }; - - let header = _get_header(attributes); - let mut all_context_text: Vec = Vec::new(); - let mut all_context_records: Vec = Vec::new(); - - let mut batch_text = String::new(); - let mut batch_tokens = 0; - let mut batch_records: Vec> = Vec::new(); - - let mut _init_batch = || { - batch_text = format!("-----{}-----\n{}\n", context_name, header.join(column_delimiter)); - batch_tokens = num_tokens(&batch_text, token_encoder); - batch_records = Vec::new(); - }; - - let _cut_batch = |batch_records: Vec>, header: Vec| -> anyhow::Result<()> { - let weight_column = if include_community_weight && entities.is_some() { + fn cut_batch( + &mut self, + all_context_text: &mut Vec, + all_context_records: &mut Vec, + entities: Option>, + header: &Vec, + community_weight_name: &str, + community_rank_name: &str, + include_community_weight: bool, + include_community_rank: bool, + ) -> anyhow::Result<()> { + let weight_column = if include_community_weight && entities.is_some_and(|e| !e.is_empty()) { Some(community_weight_name) } else { None @@ -189,7 +377,12 @@ pub fn build_community_context( None }; - let mut record_df = _convert_report_context_to_df(batch_records, header, weight_column, rank_column)?; + let mut record_df = Self::_convert_report_context_to_df( + self.batch_records.clone(), + header.clone(), + weight_column, + rank_column, + )?; if record_df.is_empty() { return Ok(()); } @@ -204,123 +397,64 @@ pub fn build_community_context( all_context_records.push(record_df); Ok(()) - }; - - _init_batch(); - - Ok((vec![], HashMap::new())) -} + } -fn _compute_community_weights( - community_reports: Vec, - entities: Option>, - weight_attribute: &str, - normalize: bool, -) -> Vec { - // Calculate a community's weight as the count of text units associated with entities within the community. - if let Some(entities) = entities { - let mut community_reports = community_reports.clone(); - let mut community_text_units = std::collections::HashMap::new(); - for entity in entities { - if let Some(community_ids) = entity.community_ids.clone() { - for community_id in community_ids { - community_text_units - .entry(community_id) - .or_insert_with(Vec::new) - .extend(entity.text_unit_ids.clone()); - } - } + fn _convert_report_context_to_df( + context_records: Vec>, + header: Vec, + weight_column: Option<&str>, + rank_column: Option<&str>, + ) -> anyhow::Result { + if context_records.is_empty() { + return Ok(DataFrame::empty()); } - for report in &mut community_reports { - if report.attributes.is_none() { - report.attributes = Some(std::collections::HashMap::new()); - } - if let Some(attributes) = &mut report.attributes { - attributes.insert( - weight_attribute.to_string(), - community_text_units - .get(&report.community_id) - .map(|text_units| text_units.len()) - .unwrap_or(0) - .to_string(), - ); - } - } - if normalize { - // Normalize by max weight - let all_weights: Vec = community_reports - .iter() - .filter_map(|report| { - report - .attributes - .as_ref() - .and_then(|attributes| attributes.get(weight_attribute)) - .map(|weight| weight.parse::().unwrap_or(0.0)) - }) - .collect(); - if let Some(max_weight) = all_weights.iter().cloned().max_by(|a, b| a.partial_cmp(b).unwrap()) { - for mut report in community_reports { - if let Some(attributes) = &mut report.attributes { - if let Some(weight) = attributes.get_mut(weight_attribute) { - *weight = (weight.parse::().unwrap_or(0.0) / max_weight).to_string(); - } - } - } - } + + let mut data_series = Vec::new(); + for (header, records) in header.iter().zip(context_records.iter()) { + let series = Series::new(header, records); + data_series.push(series); } - } - community_reports -} -fn _convert_report_context_to_df( - context_records: Vec>, - header: Vec, - weight_column: Option<&str>, - rank_column: Option<&str>, -) -> anyhow::Result { - if context_records.is_empty() { - return Ok(DataFrame::empty()); - } + let record_df = DataFrame::new(data_series)?; - let mut data_series = Vec::new(); - for (header, records) in header.iter().zip(context_records.iter()) { - let series = Series::new(header, records); - data_series.push(series); + return Self::_rank_report_context(record_df, weight_column, rank_column); } - let record_df = DataFrame::new(data_series)?; + fn _rank_report_context( + report_df: DataFrame, + weight_column: Option<&str>, + rank_column: Option<&str>, + ) -> anyhow::Result { + let mut rank_attributes = Vec::new(); - return _rank_report_context(record_df, weight_column, rank_column); -} + let mut report_df = report_df; -fn _rank_report_context( - report_df: DataFrame, - weight_column: Option<&str>, - rank_column: Option<&str>, -) -> anyhow::Result { - let weight_column = weight_column.unwrap_or("occurrence weight"); - let rank_column = rank_column.unwrap_or("rank"); - - let mut rank_attributes = Vec::new(); - rank_attributes.push(weight_column); - let report_df = report_df - .clone() - .lazy() - .with_column(col(weight_column).cast(DataType::Float64)) - .collect()?; + if let Some(weight_column) = weight_column { + rank_attributes.push(weight_column); + report_df = report_df + .clone() + .lazy() + .with_column(col(weight_column).cast(DataType::Float64)) + .collect()?; + } - rank_attributes.push(rank_column); - let report_df = report_df - .clone() - .lazy() - .with_column(col(rank_column).cast(DataType::Float64)) - .collect()?; + if let Some(rank_column) = rank_column { + rank_attributes.push(rank_column); + report_df = report_df + .clone() + .lazy() + .with_column(col(rank_column).cast(DataType::Float64)) + .collect()?; + } - let report_df = report_df - .clone() - .lazy() - .sort(rank_attributes, SortMultipleOptions::new().with_order_descending(true)) - .collect()?; + if !rank_attributes.is_empty() { + report_df = report_df + .clone() + .lazy() + .sort(rank_attributes, SortMultipleOptions::new().with_order_descending(true)) + .collect()?; + } - Ok(report_df) + Ok(report_df) + } } diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs index 705818a77..87db20231 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs @@ -1,18 +1,19 @@ -use async_trait::async_trait; -// use polars::prelude::*; -use std::collections::HashMap; - -// TODO: Serialize and Deserialize polars::frame::DataFrame -type DataFrame = Vec; - -#[async_trait] -pub trait GlobalContextBuilder { - /// Build the context for the global search mode. - async fn build_context( - &self, - conversation_history: Option, - context_builder_params: Option>, - ) -> (Vec, HashMap); +#[derive(Debug, Clone)] +pub struct ContextBuilderParams { + //conversation_history: Option, + pub use_community_summary: bool, + pub column_delimiter: String, + pub shuffle_data: bool, + pub include_community_rank: bool, + pub min_community_rank: u32, + pub community_rank_name: String, + pub include_community_weight: bool, + pub community_weight_name: String, + pub normalize_community_weight: bool, + pub max_tokens: usize, + pub context_name: String, + // conversation_history_user_turns_only: bool, + // conversation_history_max_turns: Option, } pub struct ConversationHistory {} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs index 337831da0..1548d9671 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs @@ -117,19 +117,22 @@ pub fn read_entities( // attributes_cols: Option>, ) -> anyhow::Result> { let column_names = [ - id_col, - short_id_col.unwrap_or("short_id"), - title_col, - type_col.unwrap_or("type"), - description_col.unwrap_or("description"), - name_embedding_col.unwrap_or("name_embedding"), - description_embedding_col.unwrap_or("description_embedding"), - graph_embedding_col.unwrap_or("graph_embedding"), - community_col.unwrap_or("community_ids"), - text_unit_ids_col.unwrap_or("text_unit_ids"), - document_ids_col.unwrap_or("document_ids"), - rank_col.unwrap_or("degree"), - ]; + Some(id_col), + short_id_col, + Some(title_col), + type_col, + description_col, + name_embedding_col, + description_embedding_col, + graph_embedding_col, + community_col, + text_unit_ids_col, + document_ids_col, + rank_col, + ] + .iter() + .filter_map(|&v| v.map(|v| v.to_string())) + .collect::>(); let mut df = df.clone(); df.as_single_chunk_par(); diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs index fcfca58a6..9f8b9c507 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs @@ -61,7 +61,7 @@ pub fn read_indexer_reports( } pub fn filter_under_community_level(df: &DataFrame, community_level: u32) -> anyhow::Result { - let mask = df.column("level")?.i32()?.lt_eq(community_level); + let mask = df.column("level")?.i64()?.lt_eq(community_level); let result = df.filter(&mask)?; Ok(result) @@ -121,9 +121,9 @@ pub fn read_community_reports( short_id: Some(row.get(0).unwrap_or(&String::new()).to_string()), title: row.get(1).unwrap_or(&String::new()).to_string(), community_id: row.get(0).unwrap_or(&String::new()).to_string(), - summary: row.get(3).unwrap_or(&String::new()).to_string(), - full_content: row.get(4).unwrap_or(&String::new()).to_string(), - rank: Some(row.get(5).and_then(|v| v.parse::().ok()).unwrap_or(0.0)), + summary: row.get(2).unwrap_or(&String::new()).to_string(), + full_content: row.get(3).unwrap_or(&String::new()).to_string(), + rank: Some(row.get(4).and_then(|v| v.parse::().ok()).unwrap_or(0.0)), summary_embedding: None, full_content_embedding: None, attributes: None, diff --git a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs index 4c8f0f68e..0a8482144 100644 --- a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs +++ b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs @@ -25,6 +25,13 @@ pub enum MessageType { Dictionary(Vec>), } +#[derive(Debug, Clone)] +pub struct LLMParams { + pub max_tokens: u32, + pub temperature: f32, + pub response_format: HashMap, +} + #[async_trait] pub trait BaseLLM { async fn agenerate( @@ -32,7 +39,7 @@ pub trait BaseLLM { messages: MessageType, streaming: bool, callbacks: Option>, - llm_params: HashMap, + llm_params: LLMParams, ) -> anyhow::Result; } diff --git a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs index 00cb1d9e1..247bfe098 100644 --- a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs @@ -1,3 +1,2 @@ pub mod llm; -pub mod openai; pub mod utils; diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs index 55636214d..3a12ee6cd 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs @@ -1,30 +1,26 @@ use futures::future::join_all; -//use polars::frame::DataFrame; -use serde::{Deserialize, Serialize}; +use polars::frame::DataFrame; use serde_json::Value; use std::collections::HashMap; use std::time::Instant; use tiktoken_rs::tokenizer::Tokenizer; use crate::context_builder::community_context::GlobalCommunityContext; -use crate::context_builder::context_builder::ConversationHistory; -use crate::llm::llm::{BaseLLM, BaseLLMCallback, MessageType}; +use crate::context_builder::context_builder::{ContextBuilderParams, ConversationHistory}; +use crate::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType}; use crate::llm::utils::num_tokens; -// TODO: Serialize and Deserialize polars::frame::DataFrame -type DataFrame = Vec; - -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub struct SearchResult { - response: ResponseType, - context_data: ContextData, - context_text: ContextText, - completion_time: f64, - llm_calls: usize, - prompt_tokens: usize, + pub response: ResponseType, + pub context_data: ContextData, + pub context_text: ContextText, + pub completion_time: f64, + pub llm_calls: usize, + pub prompt_tokens: usize, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub enum ResponseType { String(String), Dictionary(HashMap), @@ -32,37 +28,36 @@ pub enum ResponseType { KeyPoints(Vec), } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub enum ContextData { String(String), DataFrames(Vec), Dictionary(HashMap), } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub enum ContextText { String(String), Strings(Vec), Dictionary(HashMap), } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub struct KeyPoint { - answer: String, - score: i32, + pub answer: String, + pub score: i32, } -#[derive(Serialize, Deserialize)] pub struct GlobalSearchResult { - response: ResponseType, - context_data: ContextData, - context_text: ContextText, - completion_time: f64, - llm_calls: usize, - prompt_tokens: usize, - map_responses: Vec, - reduce_context_data: ContextData, - reduce_context_text: ContextText, + pub response: ResponseType, + pub context_data: ContextData, + pub context_text: ContextText, + pub completion_time: f64, + pub llm_calls: usize, + pub prompt_tokens: usize, + pub map_responses: Vec, + pub reduce_context_data: ContextData, + pub reduce_context_text: ContextText, } #[derive(Debug, Clone)] @@ -94,15 +89,15 @@ pub struct GlobalSearch { llm: Box, context_builder: GlobalCommunityContext, token_encoder: Option, - context_builder_params: Option>, + context_builder_params: ContextBuilderParams, reduce_system_prompt: String, response_type: String, allow_general_knowledge: bool, general_knowledge_inclusion_prompt: String, callbacks: Option>, max_data_tokens: usize, - map_llm_params: HashMap, - reduce_llm_params: HashMap, + map_llm_params: LLMParams, + reduce_llm_params: LLMParams, } impl GlobalSearch { @@ -117,19 +112,18 @@ impl GlobalSearch { json_mode: bool, callbacks: Option>, max_data_tokens: usize, - map_llm_params: HashMap, - reduce_llm_params: HashMap, - context_builder_params: Option>, + map_llm_params: LLMParams, + reduce_llm_params: LLMParams, + context_builder_params: ContextBuilderParams, ) -> Self { let mut map_llm_params = map_llm_params; if json_mode { - map_llm_params.insert( - "response_format".to_string(), - serde_json::json!({"type": "json_object"}), - ); + map_llm_params + .response_format + .insert("type".to_string(), "json_object".to_string()); } else { - map_llm_params.remove("response_format"); + map_llm_params.response_format.remove("response_format"); } GlobalSearch { @@ -151,14 +145,14 @@ impl GlobalSearch { pub async fn asearch( &self, query: String, - conversation_history: Option, + _conversation_history: Option, ) -> anyhow::Result { // Step 1: Generate answers for each batch of community short summaries let start_time = Instant::now(); let (context_chunks, context_records) = self .context_builder - .build_context(conversation_history, self.context_builder_params.clone()) - .await; + .build_context(self.context_builder_params.clone()) + .await?; let mut callbacks = match &self.callbacks { Some(callbacks) => { @@ -221,7 +215,7 @@ impl GlobalSearch { &self, context_data: &str, query: &str, - llm_params: HashMap, + llm_params: LLMParams, ) -> anyhow::Result { let start_time = Instant::now(); let search_prompt = String::new(); @@ -282,7 +276,7 @@ impl GlobalSearch { map_responses: Vec, query: &str, callbacks: Option>, - reduce_llm_params: HashMap, + llm_params: LLMParams, ) -> anyhow::Result { let start_time = Instant::now(); let mut key_points: Vec> = Vec::new(); @@ -384,7 +378,7 @@ impl GlobalSearch { MessageType::Dictionary(search_messages), true, llm_callbacks, - reduce_llm_params, + llm_params, ) .await?; diff --git a/shinkai-libs/shinkai-graphrag/tests/it/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/it/global_search_tests.rs new file mode 100644 index 000000000..065d36246 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/tests/it/global_search_tests.rs @@ -0,0 +1,109 @@ +use polars::{io::SerReader, prelude::ParquetReader}; +use shinkai_graphrag::{ + context_builder::{ + community_context::GlobalCommunityContext, context_builder::ContextBuilderParams, + indexer_entities::read_indexer_entities, indexer_reports::read_indexer_reports, + }, + llm::llm::LLMParams, + search::global_search::GlobalSearch, +}; +use tiktoken_rs::tokenizer::Tokenizer; + +use crate::it::utils::openai::ChatOpenAI; + +#[tokio::test] +async fn global_search_test() -> Result<(), Box> { + let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap(); + let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap(); + + let llm = ChatOpenAI::new(Some(api_key), llm_model, 20); + let token_encoder = Tokenizer::Cl100kBase; + + // Load community reports + // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip + + let input_dir = "./dataset"; + let community_report_table = "create_final_community_reports"; + let entity_table = "create_final_nodes"; + let entity_embedding_table = "create_final_entities"; + + let community_level = 2; + + let mut entity_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_table)).unwrap(); + let entity_df = ParquetReader::new(&mut entity_file).finish().unwrap(); + + let mut report_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, community_report_table)).unwrap(); + let report_df = ParquetReader::new(&mut report_file).finish().unwrap(); + + let mut entity_embedding_file = + std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_embedding_table)).unwrap(); + let entity_embedding_df = ParquetReader::new(&mut entity_embedding_file).finish().unwrap(); + + let reports = read_indexer_reports(&report_df, &entity_df, community_level)?; + let entities = read_indexer_entities(&entity_df, &entity_embedding_df, community_level)?; + + println!("Reports: {:?}", report_df.head(Some(5))); + + // Build global context based on community reports + + let context_builder = GlobalCommunityContext::new(reports, Some(entities), Some(token_encoder)); + + let context_builder_params = ContextBuilderParams { + use_community_summary: false, // False means using full community reports. True means using community short summaries. + shuffle_data: true, + include_community_rank: true, + min_community_rank: 0, + community_rank_name: String::from("rank"), + include_community_weight: true, + community_weight_name: String::from("occurrence weight"), + normalize_community_weight: true, + max_tokens: 12_000, // change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000) + context_name: String::from("Reports"), + column_delimiter: String::from("|"), + }; + + let map_llm_params = LLMParams { + max_tokens: 1000, + temperature: 0.0, + response_format: std::collections::HashMap::from([("type".to_string(), "json_object".to_string())]), + }; + + let reduce_llm_params = LLMParams { + max_tokens: 2000, + temperature: 0.0, + response_format: std::collections::HashMap::new(), + }; + + // Perform global search + + let search_engine = GlobalSearch::new( + Box::new(llm), + context_builder, + Some(token_encoder), + String::from(""), + String::from("multiple paragraphs"), + false, + String::from(""), + true, + None, + 12_000, + map_llm_params, + reduce_llm_params, + context_builder_params, + ); + + let result = search_engine + .asearch( + "What is the major conflict in this story and who are the protagonist and antagonist?".to_string(), + None, + ) + .await?; + + println!("Response: {:?}", result.response); + + println!("Context: {:?}", result.context_data); + + println!("LLM calls: {}. LLM tokens: {}", result.llm_calls, result.prompt_tokens); + + Ok(()) +} diff --git a/shinkai-libs/shinkai-graphrag/tests/it/utils/mod.rs b/shinkai-libs/shinkai-graphrag/tests/it/utils/mod.rs new file mode 100644 index 000000000..d8c308735 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/tests/it/utils/mod.rs @@ -0,0 +1 @@ +pub mod openai; diff --git a/shinkai-libs/shinkai-graphrag/src/llm/openai.rs b/shinkai-libs/shinkai-graphrag/tests/it/utils/openai.rs similarity index 70% rename from shinkai-libs/shinkai-graphrag/src/llm/openai.rs rename to shinkai-libs/shinkai-graphrag/tests/it/utils/openai.rs index 644f1101f..1325ef6aa 100644 --- a/shinkai-libs/shinkai-graphrag/src/llm/openai.rs +++ b/shinkai-libs/shinkai-graphrag/tests/it/utils/openai.rs @@ -1,12 +1,13 @@ -use std::collections::HashMap; - use async_openai::{ config::OpenAIConfig, - types::{ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, CreateChatCompletionRequestArgs}, + types::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionResponseFormat, + ChatCompletionResponseFormatType, CreateChatCompletionRequestArgs, + }, Client, }; - -use super::llm::{BaseLLMCallback, MessageType}; +use async_trait::async_trait; +use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType}; pub struct ChatOpenAI { pub api_key: Option, @@ -28,7 +29,7 @@ impl ChatOpenAI { messages: MessageType, streaming: bool, callbacks: Option>, - llm_params: HashMap, + llm_params: LLMParams, ) -> anyhow::Result { let mut retry_count = 0; @@ -54,7 +55,7 @@ impl ChatOpenAI { messages: MessageType, _streaming: bool, _callbacks: Option>, - _llm_params: HashMap, + llm_params: LLMParams, ) -> anyhow::Result { let client = match &self.api_key { Some(api_key) => Client::with_config(OpenAIConfig::new().with_api_key(api_key)), @@ -91,7 +92,24 @@ impl ChatOpenAI { .map(|m| Into::::into(m.clone())) .collect::>(); + let response_format = if llm_params + .response_format + .get_key_value("type") + .is_some_and(|(_k, v)| v == "json_object") + { + ChatCompletionResponseFormat { + r#type: ChatCompletionResponseFormatType::JsonObject, + } + } else { + ChatCompletionResponseFormat { + r#type: ChatCompletionResponseFormatType::Text, + } + }; + let request = CreateChatCompletionRequestArgs::default() + .max_tokens(llm_params.max_tokens) + .temperature(llm_params.temperature) + //.response_format(response_format) .model(self.model.clone()) .messages(request_messages) .build()?; @@ -105,3 +123,16 @@ impl ChatOpenAI { return Ok(String::new()); } } + +#[async_trait] +impl BaseLLM for ChatOpenAI { + async fn agenerate( + &self, + messages: MessageType, + streaming: bool, + callbacks: Option>, + llm_params: LLMParams, + ) -> anyhow::Result { + self.agenerate(messages, streaming, callbacks, llm_params).await + } +} diff --git a/shinkai-libs/shinkai-graphrag/tests/it_mod.rs b/shinkai-libs/shinkai-graphrag/tests/it_mod.rs new file mode 100644 index 000000000..4c5c9ed27 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/tests/it_mod.rs @@ -0,0 +1,4 @@ +mod it { + mod global_search_tests; + mod utils; +} From 43174eb347481f3ca69237dacab0e157ac35c429 Mon Sep 17 00:00:00 2001 From: benolt Date: Fri, 9 Aug 2024 14:05:19 +0200 Subject: [PATCH 06/23] add prompts, global search adjustments --- Cargo.lock | 4 +- shinkai-libs/shinkai-graphrag/.gitignore | 1 + .../src/context_builder/community_context.rs | 27 ++- .../{ => global_search}/global_search.rs | 91 +++++++--- .../src/search/global_search/mod.rs | 2 + .../src/search/global_search/prompts.rs | 164 ++++++++++++++++++ .../tests/{it => }/global_search_tests.rs | 30 ++-- shinkai-libs/shinkai-graphrag/tests/it_mod.rs | 4 - .../tests/{it => }/utils/mod.rs | 0 .../tests/{it => }/utils/openai.rs | 2 +- 10 files changed, 272 insertions(+), 53 deletions(-) rename shinkai-libs/shinkai-graphrag/src/search/{ => global_search}/global_search.rs (81%) create mode 100644 shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/search/global_search/prompts.rs rename shinkai-libs/shinkai-graphrag/tests/{it => }/global_search_tests.rs (85%) delete mode 100644 shinkai-libs/shinkai-graphrag/tests/it_mod.rs rename shinkai-libs/shinkai-graphrag/tests/{it => }/utils/mod.rs (100%) rename shinkai-libs/shinkai-graphrag/tests/{it => }/utils/openai.rs (98%) diff --git a/Cargo.lock b/Cargo.lock index 27b6ce90a..8d3c41100 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8126,7 +8126,7 @@ dependencies = [ "memmap2 0.7.1", "num-traits", "once_cell", - "percent-encoding 2.3.0", + "percent-encoding 2.3.1", "polars-arrow", "polars-core", "polars-error", @@ -8280,7 +8280,7 @@ dependencies = [ "either", "hashbrown 0.14.5", "once_cell", - "percent-encoding 2.3.0", + "percent-encoding 2.3.1", "polars-arrow", "polars-core", "polars-io", diff --git a/shinkai-libs/shinkai-graphrag/.gitignore b/shinkai-libs/shinkai-graphrag/.gitignore index 122af2cf4..74deb7343 100644 --- a/shinkai-libs/shinkai-graphrag/.gitignore +++ b/shinkai-libs/shinkai-graphrag/.gitignore @@ -1 +1,2 @@ +.vscode dataset \ No newline at end of file diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs index e46219021..1f6be4ffe 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs @@ -155,13 +155,13 @@ impl CommunityContext { (result, context) }; - let compute_community_weights = entities.is_some() + let compute_community_weights = entities.as_ref().is_some_and(|e| !e.is_empty()) && !community_reports.is_empty() && include_community_weight && (community_reports[0].attributes.is_none() || !community_reports[0] .attributes - .clone() + .as_ref() .unwrap() .contains_key(community_weight_name)); @@ -219,6 +219,7 @@ impl CommunityContext { community_rank_name, include_community_weight, include_community_rank, + column_delimiter, )?; if single_batch { @@ -243,6 +244,7 @@ impl CommunityContext { community_rank_name, include_community_weight, include_community_rank, + column_delimiter, )?; } @@ -365,8 +367,9 @@ impl Batch { community_rank_name: &str, include_community_weight: bool, include_community_rank: bool, + column_delimiter: &str, ) -> anyhow::Result<()> { - let weight_column = if include_community_weight && entities.is_some_and(|e| !e.is_empty()) { + let weight_column = if include_community_weight && entities.as_ref().is_some_and(|e| !e.is_empty()) { Some(community_weight_name) } else { None @@ -387,10 +390,20 @@ impl Batch { return Ok(()); } + let column_delimiter = if column_delimiter.is_empty() { + b'|' + } else { + column_delimiter.as_bytes()[0] + }; + let mut buffer = Cursor::new(Vec::new()); - CsvWriter::new(buffer.clone()).finish(&mut record_df).unwrap(); + CsvWriter::new(&mut buffer) + .include_header(true) + .with_separator(column_delimiter) + .finish(&mut record_df)?; let mut current_context_text = String::new(); + buffer.set_position(0); buffer.read_to_string(&mut current_context_text)?; all_context_text.push(current_context_text); @@ -410,7 +423,11 @@ impl Batch { } let mut data_series = Vec::new(); - for (header, records) in header.iter().zip(context_records.iter()) { + for (index, header) in header.iter().enumerate() { + let records = context_records + .iter() + .map(|r| r.get(index).unwrap_or(&String::new()).to_owned()) + .collect::>(); let series = Series::new(header, records); data_series.push(series); } diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs similarity index 81% rename from shinkai-libs/shinkai-graphrag/src/search/global_search.rs rename to shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs index 3a12ee6cd..b5c14795e 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs @@ -9,6 +9,9 @@ use crate::context_builder::community_context::GlobalCommunityContext; use crate::context_builder::context_builder::{ContextBuilderParams, ConversationHistory}; use crate::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType}; use crate::llm::utils::num_tokens; +use crate::search::global_search::prompts::NO_DATA_ANSWER; + +use super::prompts::{GENERAL_KNOWLEDGE_INSTRUCTION, MAP_SYSTEM_PROMPT, REDUCE_SYSTEM_PROMPT}; #[derive(Debug, Clone)] pub struct SearchResult { @@ -90,6 +93,7 @@ pub struct GlobalSearch { context_builder: GlobalCommunityContext, token_encoder: Option, context_builder_params: ContextBuilderParams, + map_system_prompt: String, reduce_system_prompt: String, response_type: String, allow_general_knowledge: bool, @@ -100,22 +104,42 @@ pub struct GlobalSearch { reduce_llm_params: LLMParams, } +pub struct GlobalSearchParams { + pub llm: Box, + pub context_builder: GlobalCommunityContext, + pub token_encoder: Option, + pub map_system_prompt: Option, + pub reduce_system_prompt: Option, + pub response_type: String, + pub allow_general_knowledge: bool, + pub general_knowledge_inclusion_prompt: Option, + pub json_mode: bool, + pub callbacks: Option>, + pub max_data_tokens: usize, + pub map_llm_params: LLMParams, + pub reduce_llm_params: LLMParams, + pub context_builder_params: ContextBuilderParams, +} + impl GlobalSearch { - pub fn new( - llm: Box, - context_builder: GlobalCommunityContext, - token_encoder: Option, - reduce_system_prompt: String, - response_type: String, - allow_general_knowledge: bool, - general_knowledge_inclusion_prompt: String, - json_mode: bool, - callbacks: Option>, - max_data_tokens: usize, - map_llm_params: LLMParams, - reduce_llm_params: LLMParams, - context_builder_params: ContextBuilderParams, - ) -> Self { + pub fn new(global_search_params: GlobalSearchParams) -> Self { + let GlobalSearchParams { + llm, + context_builder, + token_encoder, + map_system_prompt, + reduce_system_prompt, + response_type, + allow_general_knowledge, + general_knowledge_inclusion_prompt, + json_mode, + callbacks, + max_data_tokens, + map_llm_params, + reduce_llm_params, + context_builder_params, + } = global_search_params; + let mut map_llm_params = map_llm_params; if json_mode { @@ -126,11 +150,17 @@ impl GlobalSearch { map_llm_params.response_format.remove("response_format"); } + let map_system_prompt = map_system_prompt.unwrap_or(MAP_SYSTEM_PROMPT.to_string()); + let reduce_system_prompt = reduce_system_prompt.unwrap_or(REDUCE_SYSTEM_PROMPT.to_string()); + let general_knowledge_inclusion_prompt = + general_knowledge_inclusion_prompt.unwrap_or(GENERAL_KNOWLEDGE_INSTRUCTION.to_string()); + GlobalSearch { llm, context_builder, token_encoder, context_builder_params, + map_system_prompt, reduce_system_prompt, response_type, allow_general_knowledge, @@ -218,7 +248,8 @@ impl GlobalSearch { llm_params: LLMParams, ) -> anyhow::Result { let start_time = Instant::now(); - let search_prompt = String::new(); + let search_prompt = self.map_system_prompt.replace("{context_data}", context_data); + let mut search_messages = Vec::new(); search_messages.push(HashMap::from([ ("role".to_string(), "system".to_string()), @@ -253,6 +284,7 @@ impl GlobalSearch { if let Some(points) = points.as_array() { return points .iter() + .filter(|element| element.get("description").is_some() && element.get("score").is_some()) .map(|element| KeyPoint { answer: element .get("description") @@ -268,7 +300,10 @@ impl GlobalSearch { } } - Vec::new() + vec![KeyPoint { + answer: "".to_string(), + score: 0, + }] } async fn _reduce_response( @@ -282,15 +317,13 @@ impl GlobalSearch { let mut key_points: Vec> = Vec::new(); for (index, response) in map_responses.iter().enumerate() { - if let ResponseType::Dictionaries(response_list) = &response.response { - for element in response_list { - if let (Some(answer), Some(score)) = (element.get("answer"), element.get("score")) { - let mut point = HashMap::new(); - point.insert("analyst".to_string(), (index + 1).to_string()); - point.insert("answer".to_string(), answer.to_string()); - point.insert("score".to_string(), score.to_string()); - key_points.push(point); - } + if let ResponseType::KeyPoints(response_list) = &response.response { + for key_point in response_list { + let mut point = HashMap::new(); + point.insert("analyst".to_string(), (index + 1).to_string()); + point.insert("answer".to_string(), key_point.answer.clone()); + point.insert("score".to_string(), key_point.score.to_string()); + key_points.push(point); } } } @@ -301,8 +334,10 @@ impl GlobalSearch { .collect(); if filtered_key_points.is_empty() && !self.allow_general_knowledge { + eprintln!("Warning: All map responses have score 0 (i.e., no relevant information found from the dataset), returning a canned 'I do not know' answer. You can try enabling `allow_general_knowledge` to encourage the LLM to incorporate relevant general knowledge, at the risk of increasing hallucinations."); + return Ok(SearchResult { - response: ResponseType::String("NO_DATA_ANSWER".to_string()), + response: ResponseType::String(NO_DATA_ANSWER.to_string()), context_data: ContextData::String("".to_string()), context_text: ContextText::String("".to_string()), completion_time: start_time.elapsed().as_secs_f64(), @@ -328,9 +363,11 @@ impl GlobalSearch { formatted_response_data.push(format!("Importance Score: {}", point.get("score").unwrap())); formatted_response_data.push(point.get("answer").unwrap().to_string()); let formatted_response_text = formatted_response_data.join("\n"); + if total_tokens + num_tokens(&formatted_response_text, self.token_encoder) > self.max_data_tokens { break; } + data.push(formatted_response_text.clone()); total_tokens += num_tokens(&formatted_response_text, self.token_encoder); } diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs new file mode 100644 index 000000000..79f16f1e0 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs @@ -0,0 +1,2 @@ +pub mod global_search; +pub mod prompts; diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/prompts.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/prompts.rs new file mode 100644 index 000000000..7c9fef5cb --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/prompts.rs @@ -0,0 +1,164 @@ +// Copyright (c) 2024 Microsoft Corporation. +// Licensed under the MIT License + +// System prompts for global search. + +pub const MAP_SYSTEM_PROMPT: &str = r#" +---Role--- + +You are a helpful assistant responding to questions about data in the tables provided. + + +---Goal--- + +Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables. + +You should use the data provided in the data tables below as the primary context for generating the response. +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +Each key point in the response should have the following element: +- Description: A comprehensive description of the point. +- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0. + +The response should be JSON formatted as follows: +{{ + "points": [ + {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}}, + {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}} + ] +}} + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +Points supported by data should list the relevant reports as references as follows: +"This is an example sentence supported by data references [Data: Reports (report ids)]" + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables. + +Do not include information where the supporting evidence for it is not provided. + + +---Data tables--- + +{context_data} + +---Goal--- + +Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables. + +You should use the data provided in the data tables below as the primary context for generating the response. +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +Each key point in the response should have the following element: +- Description: A comprehensive description of the point. +- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +Points supported by data should list the relevant reports as references as follows: +"This is an example sentence supported by data references [Data: Reports (report ids)]" + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables. + +Do not include information where the supporting evidence for it is not provided. + +The response should be JSON formatted as follows: +{{ + "points": [ + {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}}, + {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}} + ] +}} +"#; + +pub const REDUCE_SYSTEM_PROMPT: &str = r#" +---Role--- + +You are a helpful assistant responding to questions about a dataset by synthesizing perspectives from multiple analysts. + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset. + +Note that the analysts' reports provided below are ranked in the **descending order of importance**. + +If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format. + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + + +---Analyst Reports--- + +{report_data} + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset. + +Note that the analysts' reports provided below are ranked in the **descending order of importance**. + +If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. +"#; + +pub const NO_DATA_ANSWER: &str = "I am sorry but I am unable to answer this question given the provided data."; + +pub const GENERAL_KNOWLEDGE_INSTRUCTION: &str = r#" +The response may also include relevant real-world knowledge outside the dataset, but it must be explicitly annotated with a verification tag [LLM: verify]. For example: +"This is an example sentence supported by real-world knowledge [LLM: verify]." +"#; diff --git a/shinkai-libs/shinkai-graphrag/tests/it/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs similarity index 85% rename from shinkai-libs/shinkai-graphrag/tests/it/global_search_tests.rs rename to shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs index 065d36246..8aea91032 100644 --- a/shinkai-libs/shinkai-graphrag/tests/it/global_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs @@ -5,18 +5,19 @@ use shinkai_graphrag::{ indexer_entities::read_indexer_entities, indexer_reports::read_indexer_reports, }, llm::llm::LLMParams, - search::global_search::GlobalSearch, + search::global_search::global_search::{GlobalSearch, GlobalSearchParams}, }; use tiktoken_rs::tokenizer::Tokenizer; +use utils::openai::ChatOpenAI; -use crate::it::utils::openai::ChatOpenAI; +mod utils; #[tokio::test] async fn global_search_test() -> Result<(), Box> { let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap(); let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap(); - let llm = ChatOpenAI::new(Some(api_key), llm_model, 20); + let llm = ChatOpenAI::new(Some(api_key), llm_model, 5); let token_encoder = Tokenizer::Cl100kBase; // Load community reports @@ -76,21 +77,22 @@ async fn global_search_test() -> Result<(), Box> { // Perform global search - let search_engine = GlobalSearch::new( - Box::new(llm), + let search_engine = GlobalSearch::new(GlobalSearchParams { + llm: Box::new(llm), context_builder, - Some(token_encoder), - String::from(""), - String::from("multiple paragraphs"), - false, - String::from(""), - true, - None, - 12_000, + token_encoder: Some(token_encoder), + map_system_prompt: None, + reduce_system_prompt: None, + response_type: String::from("multiple paragraphs"), + allow_general_knowledge: false, + general_knowledge_inclusion_prompt: None, + json_mode: true, + callbacks: None, + max_data_tokens: 12_000, map_llm_params, reduce_llm_params, context_builder_params, - ); + }); let result = search_engine .asearch( diff --git a/shinkai-libs/shinkai-graphrag/tests/it_mod.rs b/shinkai-libs/shinkai-graphrag/tests/it_mod.rs deleted file mode 100644 index 4c5c9ed27..000000000 --- a/shinkai-libs/shinkai-graphrag/tests/it_mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod it { - mod global_search_tests; - mod utils; -} diff --git a/shinkai-libs/shinkai-graphrag/tests/it/utils/mod.rs b/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs similarity index 100% rename from shinkai-libs/shinkai-graphrag/tests/it/utils/mod.rs rename to shinkai-libs/shinkai-graphrag/tests/utils/mod.rs diff --git a/shinkai-libs/shinkai-graphrag/tests/it/utils/openai.rs b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs similarity index 98% rename from shinkai-libs/shinkai-graphrag/tests/it/utils/openai.rs rename to shinkai-libs/shinkai-graphrag/tests/utils/openai.rs index 1325ef6aa..7eab1460e 100644 --- a/shinkai-libs/shinkai-graphrag/tests/it/utils/openai.rs +++ b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs @@ -92,7 +92,7 @@ impl ChatOpenAI { .map(|m| Into::::into(m.clone())) .collect::>(); - let response_format = if llm_params + let _response_format = if llm_params .response_format .get_key_value("type") .is_some_and(|(_k, v)| v == "json_object") From 88ecbcc41e261b10d3b2a81e4f621eb3d7c7183b Mon Sep 17 00:00:00 2001 From: benolt Date: Tue, 13 Aug 2024 15:08:22 +0200 Subject: [PATCH 07/23] read indexer entities and reports improvements, compute community weights --- shinkai-libs/shinkai-graphrag/Cargo.toml | 3 +- .../src/context_builder/community_context.rs | 10 +- .../src/context_builder/indexer_entities.rs | 224 +++++++++++------- .../src/context_builder/indexer_reports.rs | 123 ++++++---- .../shinkai-graphrag/src/llm/utils.rs | 2 +- 5 files changed, 226 insertions(+), 136 deletions(-) diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml index 378f20e4f..9a9eee401 100644 --- a/shinkai-libs/shinkai-graphrag/Cargo.toml +++ b/shinkai-libs/shinkai-graphrag/Cargo.toml @@ -16,5 +16,4 @@ tiktoken-rs = "0.5.9" tokio = { version = "1.36", features = ["full"] } [dev-dependencies] -async-openai = "0.23.4" -tokio = { version = "1.36", features = ["full"] } \ No newline at end of file +async-openai = "0.23.4" \ No newline at end of file diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs index 1f6be4ffe..c5b096ca1 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, io::{Cursor, Read}, }; @@ -276,7 +276,7 @@ impl CommunityContext { ) -> Vec { // Calculate a community's weight as the count of text units associated with entities within the community. if let Some(entities) = entities { - let mut community_reports = community_reports.clone(); + let mut community_reports = community_reports; let mut community_text_units = std::collections::HashMap::new(); for entity in entities { if let Some(community_ids) = entity.community_ids.clone() { @@ -297,7 +297,7 @@ impl CommunityContext { weight_attribute.to_string(), community_text_units .get(&report.community_id) - .map(|text_units| text_units.len()) + .map(|text_units| text_units.iter().flatten().cloned().collect::>().len()) .unwrap_or(0) .to_string(), ); @@ -316,7 +316,7 @@ impl CommunityContext { }) .collect(); if let Some(max_weight) = all_weights.iter().cloned().max_by(|a, b| a.partial_cmp(b).unwrap()) { - for mut report in community_reports { + for report in &mut community_reports { if let Some(attributes) = &mut report.attributes { if let Some(weight) = attributes.get_mut(weight_attribute) { *weight = (weight.parse::().unwrap_or(0.0) / max_weight).to_string(); @@ -325,6 +325,8 @@ impl CommunityContext { } } } + + return community_reports; } community_reports } diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs index 1548d9671..f550a2aef 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use polars::prelude::*; use polars_lazy::dsl::col; @@ -12,44 +12,19 @@ pub fn read_indexer_entities( community_level: u32, ) -> anyhow::Result> { let entity_df = final_nodes.clone(); - let mut entity_df = filter_under_community_level(&entity_df, community_level)?; + let entity_df = filter_under_community_level(&entity_df, community_level)?; - let entity_df = entity_df.rename("title", "name")?.rename("degree", "rank")?; + let entity_embedding_df = final_entities.clone(); let entity_df = entity_df - .clone() .lazy() + .rename(["title", "degree"], ["name", "rank"]) .with_column(col("community").fill_null(lit(-1))) - .collect()?; - let entity_df = entity_df - .clone() - .lazy() .with_column(col("community").cast(DataType::Int32)) - .collect()?; - let entity_df = entity_df - .clone() - .lazy() .with_column(col("rank").cast(DataType::Int32)) - .collect()?; - - let entity_embedding_df = final_entities.clone(); - - let entity_df = entity_df - .clone() - .lazy() .group_by([col("name"), col("rank")]) .agg([col("community").max()]) - .collect()?; - - let entity_df = entity_df - .clone() - .lazy() .with_column(col("community").cast(DataType::String)) - .collect()?; - - let entity_df = entity_df - .clone() - .lazy() .join( entity_embedding_df.clone().lazy(), [col("name")], @@ -58,12 +33,6 @@ pub fn read_indexer_entities( ) .collect()?; - let entity_df = entity_df - .clone() - .lazy() - .filter(len().over([col("name")]).gt(lit(1))) - .collect()?; - let entities = read_entities( &entity_df, "id", @@ -134,9 +103,15 @@ pub fn read_entities( .filter_map(|&v| v.map(|v| v.to_string())) .collect::>(); + let column_names = column_names.into_iter().collect::>().into_vec(); + let mut df = df.clone(); df.as_single_chunk_par(); - let mut iters = df.columns(column_names)?.iter().map(|s| s.iter()).collect::>(); + let mut iters = df + .columns(column_names.clone())? + .iter() + .map(|s| s.iter()) + .collect::>(); let mut rows = Vec::new(); for _row in 0..df.height() { @@ -144,67 +119,144 @@ pub fn read_entities( for iter in &mut iters { let value = iter.next(); if let Some(value) = value { - row_values.push(value.to_string()); + row_values.push(value); } } rows.push(row_values); } let mut entities = Vec::new(); - for row in rows { + for (idx, row) in rows.iter().enumerate() { let report = Entity { - id: row.get(0).unwrap_or(&String::new()).to_string(), - short_id: Some(row.get(1).unwrap_or(&String::new()).to_string()), - title: row.get(2).unwrap_or(&String::new()).to_string(), - entity_type: Some(row.get(3).unwrap_or(&String::new()).to_string()), - description: Some(row.get(4).unwrap_or(&String::new()).to_string()), - name_embedding: Some( - row.get(5) - .unwrap_or(&String::new()) - .split(',') - .map(|v| v.parse::().unwrap_or(0.0)) - .collect(), - ), - description_embedding: Some( - row.get(6) - .unwrap_or(&String::new()) - .split(',') - .map(|v| v.parse::().unwrap_or(0.0)) - .collect(), - ), - graph_embedding: Some( - row.get(7) - .unwrap_or(&String::new()) - .split(',') - .map(|v| v.parse::().unwrap_or(0.0)) - .collect(), - ), - community_ids: Some( - row.get(8) - .unwrap_or(&String::new()) - .split(',') - .map(|v| v.to_string()) - .collect(), - ), - text_unit_ids: Some( - row.get(9) - .unwrap_or(&String::new()) - .split(',') - .map(|v| v.to_string()) - .collect(), + id: get_field(&row, id_col, &column_names) + .map(|id| id.to_string()) + .unwrap_or(String::new()), + short_id: Some( + short_id_col + .map(|short_id| get_field(&row, short_id, &column_names)) + .flatten() + .map(|short_id| short_id.to_string()) + .unwrap_or(idx.to_string()), ), - document_ids: Some( - row.get(10) - .unwrap_or(&String::new()) - .split(',') - .map(|v| v.to_string()) - .collect(), - ), - rank: Some(row.get(11).and_then(|v| v.parse::().ok()).unwrap_or(0)), + title: get_field(&row, title_col, &column_names) + .map(|title| title.to_string()) + .unwrap_or(String::new()), + entity_type: type_col + .map(|type_col| get_field(&row, type_col, &column_names)) + .flatten() + .map(|entity_type| entity_type.to_string()), + description: description_col + .map(|description_col| get_field(&row, description_col, &column_names)) + .flatten() + .map(|description| description.to_string()), + name_embedding: name_embedding_col.map(|name_embedding_col| { + get_field(&row, name_embedding_col, &column_names) + .map(|name_embedding| match name_embedding { + AnyValue::List(series) => series + .f64() + .unwrap_or(&ChunkedArray::from_vec(name_embedding_col, vec![])) + .iter() + .map(|v| v.unwrap_or(0.0)) + .collect::>(), + value => vec![value.to_string().parse::().unwrap_or(0.0)], + }) + .unwrap_or_else(|| Vec::new()) + }), + description_embedding: description_embedding_col.map(|description_embedding_col| { + get_field(&row, description_embedding_col, &column_names) + .map(|description_embedding| match description_embedding { + AnyValue::List(series) => series + .f64() + .unwrap_or(&ChunkedArray::from_vec(description_embedding_col, vec![])) + .iter() + .map(|v| v.unwrap_or(0.0)) + .collect::>(), + value => vec![value.to_string().parse::().unwrap_or(0.0)], + }) + .unwrap_or_else(|| Vec::new()) + }), + graph_embedding: graph_embedding_col.map(|graph_embedding_col| { + get_field(&row, graph_embedding_col, &column_names) + .map(|graph_embedding| match graph_embedding { + AnyValue::List(series) => series + .f64() + .unwrap_or(&ChunkedArray::from_vec(graph_embedding_col, vec![])) + .iter() + .map(|v| v.unwrap_or(0.0)) + .collect::>(), + value => vec![value.to_string().parse::().unwrap_or(0.0)], + }) + .unwrap_or_else(|| Vec::new()) + }), + community_ids: community_col.map(|community_col| { + get_field(&row, community_col, &column_names) + .map(|community_ids| match community_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_else(|| Vec::new()) + }), + text_unit_ids: text_unit_ids_col.map(|text_unit_ids_col| { + get_field(&row, text_unit_ids_col, &column_names) + .map(|text_unit_ids| match text_unit_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_else(|| Vec::new()) + }), + document_ids: document_ids_col.map(|document_ids_col| { + get_field(&row, document_ids_col, &column_names) + .map(|document_ids| match document_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_else(|| Vec::new()) + }), + rank: rank_col + .map(|rank_col| { + get_field(&row, rank_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0)) + }) + .flatten(), attributes: None, }; entities.push(report); } - Ok(entities) + let mut unique_entities: Vec = Vec::new(); + let mut entity_ids: HashSet = HashSet::new(); + + for entity in entities { + if !entity_ids.contains(&entity.id) { + unique_entities.push(entity.clone()); + entity_ids.insert(entity.id); + } + } + + Ok(unique_entities) +} + +pub fn get_field<'a>( + row: &'a Vec>, + column_name: &'a str, + column_names: &'a Vec, +) -> Option> { + match column_names.iter().position(|x| x == column_name) { + Some(index) => row.get(index).cloned(), + None => None, + } } diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs index 9f8b9c507..cae8dc607 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs @@ -1,9 +1,11 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use polars::prelude::*; use polars_lazy::dsl::col; use serde::{Deserialize, Serialize}; +use super::indexer_entities::get_field; + pub fn read_indexer_reports( final_community_reports: &DataFrame, final_nodes: &DataFrame, @@ -12,33 +14,13 @@ pub fn read_indexer_reports( let entity_df = final_nodes.clone(); let entity_df = filter_under_community_level(&entity_df, community_level)?; - let entity_df = entity_df - .clone() + let filtered_community_df = entity_df .lazy() .with_column(col("community").fill_null(lit(-1))) - .collect()?; - let entity_df = entity_df - .clone() - .lazy() .with_column(col("community").cast(DataType::Int32)) - .collect()?; - - let entity_df = entity_df - .clone() - .lazy() - .with_column(col("community").cast(DataType::String)) - .collect()?; - - let entity_df = entity_df - .clone() - .lazy() .group_by([col("title")]) .agg([col("community").max()]) - .collect()?; - - let filtered_community_df = entity_df - .clone() - .lazy() + .with_column(col("community").cast(DataType::String)) .filter(len().over([col("community")]).gt(lit(1))) .collect()?; @@ -46,7 +28,6 @@ pub fn read_indexer_reports( let report_df = filter_under_community_level(&report_df, community_level)?; let report_df = report_df - .clone() .lazy() .join( filtered_community_df.clone().lazy(), @@ -56,7 +37,18 @@ pub fn read_indexer_reports( ) .collect()?; - let reports = read_community_reports(&report_df, "community", Some("community"), None, None)?; + let reports = read_community_reports( + &report_df, + "community", + Some("community"), + "title", + "community", + "summary", + "full_content", + Some("rank"), + None, + None, + )?; Ok(reports) } @@ -83,21 +75,36 @@ pub struct CommunityReport { pub fn read_community_reports( df: &DataFrame, - _id_col: &str, - _short_id_col: Option<&str>, - // title_col: &str, - // community_col: &str, - // summary_col: &str, - // content_col: &str, - // rank_col: Option<&str>, + id_col: &str, + short_id_col: Option<&str>, + title_col: &str, + community_col: &str, + summary_col: &str, + content_col: &str, + rank_col: Option<&str>, _summary_embedding_col: Option<&str>, _content_embedding_col: Option<&str>, // attributes_cols: Option<&[&str]>, ) -> anyhow::Result> { + let column_names = [ + Some(id_col), + short_id_col, + Some(title_col), + Some(community_col), + Some(summary_col), + Some(content_col), + rank_col, + ] + .iter() + .filter_map(|&v| v.map(|v| v.to_string())) + .collect::>(); + + let column_names: Vec = column_names.into_iter().collect::>().into_vec(); + let mut df = df.clone(); df.as_single_chunk_par(); let mut iters = df - .columns(["community", "title", "summary", "full_content", "rank"])? + .columns(column_names.clone())? .iter() .map(|s| s.iter()) .collect::>(); @@ -108,22 +115,42 @@ pub fn read_community_reports( for iter in &mut iters { let value = iter.next(); if let Some(value) = value { - row_values.push(value.to_string()); + row_values.push(value); } } rows.push(row_values); } let mut reports = Vec::new(); - for row in rows { + for (idx, row) in rows.iter().enumerate() { let report = CommunityReport { - id: row.get(0).unwrap_or(&String::new()).to_string(), - short_id: Some(row.get(0).unwrap_or(&String::new()).to_string()), - title: row.get(1).unwrap_or(&String::new()).to_string(), - community_id: row.get(0).unwrap_or(&String::new()).to_string(), - summary: row.get(2).unwrap_or(&String::new()).to_string(), - full_content: row.get(3).unwrap_or(&String::new()).to_string(), - rank: Some(row.get(4).and_then(|v| v.parse::().ok()).unwrap_or(0.0)), + id: get_field(&row, id_col, &column_names) + .map(|id| id.to_string()) + .unwrap_or(String::new()), + short_id: Some( + short_id_col + .map(|short_id| get_field(&row, short_id, &column_names)) + .flatten() + .map(|short_id| short_id.to_string()) + .unwrap_or(idx.to_string()), + ), + title: get_field(&row, title_col, &column_names) + .map(|title| title.to_string()) + .unwrap_or(String::new()), + community_id: get_field(&row, community_col, &column_names) + .map(|community| community.to_string()) + .unwrap_or(String::new()), + summary: get_field(&row, summary_col, &column_names) + .map(|summary| summary.to_string()) + .unwrap_or(String::new()), + full_content: get_field(&row, content_col, &column_names) + .map(|content| content.to_string()) + .unwrap_or(String::new()), + rank: rank_col + .map(|rank_col| { + get_field(&row, rank_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0.0)) + }) + .flatten(), summary_embedding: None, full_content_embedding: None, attributes: None, @@ -131,5 +158,15 @@ pub fn read_community_reports( reports.push(report); } - Ok(reports) + let mut unique_reports: Vec = Vec::new(); + let mut report_ids: HashSet = HashSet::new(); + + for report in reports { + if !report_ids.contains(&report.id) { + unique_reports.push(report.clone()); + report_ids.insert(report.id); + } + } + + Ok(unique_reports) } diff --git a/shinkai-libs/shinkai-graphrag/src/llm/utils.rs b/shinkai-libs/shinkai-graphrag/src/llm/utils.rs index a6b4dfc54..1599ce78f 100644 --- a/shinkai-libs/shinkai-graphrag/src/llm/utils.rs +++ b/shinkai-libs/shinkai-graphrag/src/llm/utils.rs @@ -3,5 +3,5 @@ use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer}; pub fn num_tokens(text: &str, token_encoder: Option) -> usize { let token_encoder = token_encoder.unwrap_or_else(|| Tokenizer::Cl100kBase); let bpe = get_bpe_from_tokenizer(token_encoder).unwrap(); - bpe.encode_ordinary(text).len() + bpe.encode_with_special_tokens(text).len() } From b3a0b9295ec230167a067161ce24841cb4c684e2 Mon Sep 17 00:00:00 2001 From: benolt Date: Tue, 13 Aug 2024 18:36:37 +0200 Subject: [PATCH 08/23] improvements, disable global search test --- .../src/context_builder/community_context.rs | 3 --- .../src/context_builder/indexer_entities.rs | 8 ++++---- .../src/context_builder/indexer_reports.rs | 8 ++++---- .../shinkai-graphrag/tests/global_search_tests.rs | 2 +- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs index c5b096ca1..d823b0070 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs @@ -451,7 +451,6 @@ impl Batch { if let Some(weight_column) = weight_column { rank_attributes.push(weight_column); report_df = report_df - .clone() .lazy() .with_column(col(weight_column).cast(DataType::Float64)) .collect()?; @@ -460,7 +459,6 @@ impl Batch { if let Some(rank_column) = rank_column { rank_attributes.push(rank_column); report_df = report_df - .clone() .lazy() .with_column(col(rank_column).cast(DataType::Float64)) .collect()?; @@ -468,7 +466,6 @@ impl Batch { if !rank_attributes.is_empty() { report_df = report_df - .clone() .lazy() .sort(rank_attributes, SortMultipleOptions::new().with_order_descending(true)) .collect()?; diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs index f550a2aef..26d8566fd 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs @@ -26,7 +26,7 @@ pub fn read_indexer_entities( .agg([col("community").max()]) .with_column(col("community").cast(DataType::String)) .join( - entity_embedding_df.clone().lazy(), + entity_embedding_df.lazy(), [col("name")], [col("name")], JoinArgs::new(JoinType::Inner), @@ -34,7 +34,7 @@ pub fn read_indexer_entities( .collect()?; let entities = read_entities( - &entity_df, + entity_df, "id", Some("human_readable_id"), "name", @@ -70,7 +70,7 @@ pub struct Entity { } pub fn read_entities( - df: &DataFrame, + df: DataFrame, id_col: &str, short_id_col: Option<&str>, title_col: &str, @@ -105,7 +105,7 @@ pub fn read_entities( let column_names = column_names.into_iter().collect::>().into_vec(); - let mut df = df.clone(); + let mut df = df; df.as_single_chunk_par(); let mut iters = df .columns(column_names.clone())? diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs index cae8dc607..1b59b9c59 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs @@ -30,7 +30,7 @@ pub fn read_indexer_reports( let report_df = report_df .lazy() .join( - filtered_community_df.clone().lazy(), + filtered_community_df.lazy(), [col("community")], [col("community")], JoinArgs::new(JoinType::Inner), @@ -38,7 +38,7 @@ pub fn read_indexer_reports( .collect()?; let reports = read_community_reports( - &report_df, + report_df, "community", Some("community"), "title", @@ -74,7 +74,7 @@ pub struct CommunityReport { } pub fn read_community_reports( - df: &DataFrame, + df: DataFrame, id_col: &str, short_id_col: Option<&str>, title_col: &str, @@ -101,7 +101,7 @@ pub fn read_community_reports( let column_names: Vec = column_names.into_iter().collect::>().into_vec(); - let mut df = df.clone(); + let mut df = df; df.as_single_chunk_par(); let mut iters = df .columns(column_names.clone())? diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs index 8aea91032..c08c2548a 100644 --- a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs @@ -12,7 +12,7 @@ use utils::openai::ChatOpenAI; mod utils; -#[tokio::test] +// #[tokio::test] async fn global_search_test() -> Result<(), Box> { let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap(); let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap(); From 327c9967c5d702a1fab456e571cbdd72335835b4 Mon Sep 17 00:00:00 2001 From: benolt Date: Thu, 15 Aug 2024 07:01:01 +0200 Subject: [PATCH 09/23] decouple openai tokenizer --- shinkai-libs/shinkai-graphrag/Cargo.toml | 4 ++-- .../src/context_builder/community_context.rs | 23 ++++++++----------- shinkai-libs/shinkai-graphrag/src/llm/mod.rs | 1 - .../shinkai-graphrag/src/llm/utils.rs | 7 ------ .../src/search/global_search/global_search.rs | 20 +++++++--------- .../tests/global_search_tests.rs | 8 +++---- .../shinkai-graphrag/tests/utils/openai.rs | 7 ++++++ 7 files changed, 30 insertions(+), 40 deletions(-) delete mode 100644 shinkai-libs/shinkai-graphrag/src/llm/utils.rs diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml index 9a9eee401..18650385c 100644 --- a/shinkai-libs/shinkai-graphrag/Cargo.toml +++ b/shinkai-libs/shinkai-graphrag/Cargo.toml @@ -12,8 +12,8 @@ polars-lazy = "0.41.3" rand = "0.8.5" serde = { version = "1.0.188", features = ["derive"] } serde_json = "1.0.117" -tiktoken-rs = "0.5.9" tokio = { version = "1.36", features = ["full"] } [dev-dependencies] -async-openai = "0.23.4" \ No newline at end of file +async-openai = "0.23.4" +tiktoken-rs = "0.5.9" \ No newline at end of file diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs index d823b0070..a6a1a7485 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs @@ -10,28 +10,25 @@ use polars::{ series::Series, }; use rand::prelude::SliceRandom; -use tiktoken_rs::tokenizer::Tokenizer; - -use crate::llm::utils::num_tokens; use super::{context_builder::ContextBuilderParams, indexer_entities::Entity, indexer_reports::CommunityReport}; pub struct GlobalCommunityContext { community_reports: Vec, entities: Option>, - token_encoder: Option, + num_tokens_fn: fn(&str) -> usize, } impl GlobalCommunityContext { pub fn new( community_reports: Vec, entities: Option>, - token_encoder: Option, + num_tokens_fn: fn(&str) -> usize, ) -> Self { Self { community_reports, entities, - token_encoder, + num_tokens_fn, } } @@ -56,7 +53,7 @@ impl GlobalCommunityContext { let (community_context, community_context_data) = CommunityContext::build_community_context( self.community_reports.clone(), self.entities.clone(), - self.token_encoder.clone(), + self.num_tokens_fn, use_community_summary, &column_delimiter, shuffle_data, @@ -84,7 +81,7 @@ impl CommunityContext { pub fn build_community_context( community_reports: Vec, entities: Option>, - token_encoder: Option, + num_tokens_fn: fn(&str) -> usize, use_community_summary: bool, column_delimiter: &str, shuffle_data: bool, @@ -202,11 +199,11 @@ impl CommunityContext { let mut batch = Batch::new(); - batch.init_batch(context_name, &header, column_delimiter, token_encoder); + batch.init_batch(context_name, &header, column_delimiter, num_tokens_fn); for report in selected_reports { let (new_context_text, new_context) = _report_context_text(&report, &attributes); - let new_tokens = num_tokens(&new_context_text, token_encoder); + let new_tokens = num_tokens_fn(&new_context_text); // add the current batch to the context data and start a new batch if we are in multi-batch mode if batch.batch_tokens + new_tokens > max_tokens { @@ -226,7 +223,7 @@ impl CommunityContext { break; } - batch.init_batch(context_name, &header, column_delimiter, token_encoder); + batch.init_batch(context_name, &header, column_delimiter, num_tokens_fn); } batch.batch_text.push_str(&new_context_text); @@ -352,10 +349,10 @@ impl Batch { context_name: &str, header: &Vec, column_delimiter: &str, - token_encoder: Option, + num_tokens_fn: fn(&str) -> usize, ) { self.batch_text = format!("-----{}-----\n{}\n", context_name, header.join(column_delimiter)); - self.batch_tokens = num_tokens(&self.batch_text, token_encoder); + self.batch_tokens = num_tokens_fn(&self.batch_text); self.batch_records.clear(); } diff --git a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs index 247bfe098..214bbef7c 100644 --- a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs @@ -1,2 +1 @@ pub mod llm; -pub mod utils; diff --git a/shinkai-libs/shinkai-graphrag/src/llm/utils.rs b/shinkai-libs/shinkai-graphrag/src/llm/utils.rs deleted file mode 100644 index 1599ce78f..000000000 --- a/shinkai-libs/shinkai-graphrag/src/llm/utils.rs +++ /dev/null @@ -1,7 +0,0 @@ -use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer}; - -pub fn num_tokens(text: &str, token_encoder: Option) -> usize { - let token_encoder = token_encoder.unwrap_or_else(|| Tokenizer::Cl100kBase); - let bpe = get_bpe_from_tokenizer(token_encoder).unwrap(); - bpe.encode_with_special_tokens(text).len() -} diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs index b5c14795e..0368d4a59 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs @@ -3,12 +3,10 @@ use polars::frame::DataFrame; use serde_json::Value; use std::collections::HashMap; use std::time::Instant; -use tiktoken_rs::tokenizer::Tokenizer; use crate::context_builder::community_context::GlobalCommunityContext; use crate::context_builder::context_builder::{ContextBuilderParams, ConversationHistory}; use crate::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType}; -use crate::llm::utils::num_tokens; use crate::search::global_search::prompts::NO_DATA_ANSWER; use super::prompts::{GENERAL_KNOWLEDGE_INSTRUCTION, MAP_SYSTEM_PROMPT, REDUCE_SYSTEM_PROMPT}; @@ -26,8 +24,6 @@ pub struct SearchResult { #[derive(Debug, Clone)] pub enum ResponseType { String(String), - Dictionary(HashMap), - Dictionaries(Vec>), KeyPoints(Vec), } @@ -91,7 +87,7 @@ impl GlobalSearchLLMCallback { pub struct GlobalSearch { llm: Box, context_builder: GlobalCommunityContext, - token_encoder: Option, + num_tokens_fn: fn(&str) -> usize, context_builder_params: ContextBuilderParams, map_system_prompt: String, reduce_system_prompt: String, @@ -107,7 +103,7 @@ pub struct GlobalSearch { pub struct GlobalSearchParams { pub llm: Box, pub context_builder: GlobalCommunityContext, - pub token_encoder: Option, + pub num_tokens_fn: fn(&str) -> usize, pub map_system_prompt: Option, pub reduce_system_prompt: Option, pub response_type: String, @@ -126,7 +122,7 @@ impl GlobalSearch { let GlobalSearchParams { llm, context_builder, - token_encoder, + num_tokens_fn, map_system_prompt, reduce_system_prompt, response_type, @@ -158,7 +154,7 @@ impl GlobalSearch { GlobalSearch { llm, context_builder, - token_encoder, + num_tokens_fn, context_builder_params, map_system_prompt, reduce_system_prompt, @@ -273,7 +269,7 @@ impl GlobalSearch { context_text: ContextText::String(context_data.to_string()), completion_time: start_time.elapsed().as_secs_f64(), llm_calls: 1, - prompt_tokens: num_tokens(&search_prompt, self.token_encoder), + prompt_tokens: (self.num_tokens_fn)(&search_prompt), }) } @@ -364,12 +360,12 @@ impl GlobalSearch { formatted_response_data.push(point.get("answer").unwrap().to_string()); let formatted_response_text = formatted_response_data.join("\n"); - if total_tokens + num_tokens(&formatted_response_text, self.token_encoder) > self.max_data_tokens { + if total_tokens + (self.num_tokens_fn)(&formatted_response_text) > self.max_data_tokens { break; } data.push(formatted_response_text.clone()); - total_tokens += num_tokens(&formatted_response_text, self.token_encoder); + total_tokens += (self.num_tokens_fn)(&formatted_response_text); } let text_data = data.join("\n\n"); @@ -425,7 +421,7 @@ impl GlobalSearch { context_text: ContextText::String(text_data), completion_time: start_time.elapsed().as_secs_f64(), llm_calls: 1, - prompt_tokens: num_tokens(&search_prompt, self.token_encoder), + prompt_tokens: (self.num_tokens_fn)(&search_prompt), }) } } diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs index c08c2548a..42bcd5834 100644 --- a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs @@ -7,8 +7,7 @@ use shinkai_graphrag::{ llm::llm::LLMParams, search::global_search::global_search::{GlobalSearch, GlobalSearchParams}, }; -use tiktoken_rs::tokenizer::Tokenizer; -use utils::openai::ChatOpenAI; +use utils::openai::{num_tokens, ChatOpenAI}; mod utils; @@ -18,7 +17,6 @@ async fn global_search_test() -> Result<(), Box> { let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap(); let llm = ChatOpenAI::new(Some(api_key), llm_model, 5); - let token_encoder = Tokenizer::Cl100kBase; // Load community reports // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip @@ -47,7 +45,7 @@ async fn global_search_test() -> Result<(), Box> { // Build global context based on community reports - let context_builder = GlobalCommunityContext::new(reports, Some(entities), Some(token_encoder)); + let context_builder = GlobalCommunityContext::new(reports, Some(entities), num_tokens); let context_builder_params = ContextBuilderParams { use_community_summary: false, // False means using full community reports. True means using community short summaries. @@ -80,7 +78,7 @@ async fn global_search_test() -> Result<(), Box> { let search_engine = GlobalSearch::new(GlobalSearchParams { llm: Box::new(llm), context_builder, - token_encoder: Some(token_encoder), + num_tokens_fn: num_tokens, map_system_prompt: None, reduce_system_prompt: None, response_type: String::from("multiple paragraphs"), diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs index 7eab1460e..95e7e7b80 100644 --- a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs +++ b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs @@ -8,6 +8,7 @@ use async_openai::{ }; use async_trait::async_trait; use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType}; +use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer}; pub struct ChatOpenAI { pub api_key: Option, @@ -136,3 +137,9 @@ impl BaseLLM for ChatOpenAI { self.agenerate(messages, streaming, callbacks, llm_params).await } } + +pub fn num_tokens(text: &str) -> usize { + let token_encoder = Tokenizer::Cl100kBase; + let bpe = get_bpe_from_tokenizer(token_encoder).unwrap(); + bpe.encode_with_special_tokens(text).len() +} From 9b66f2d5ac02523817b9d40680f50a41a4dfd45d Mon Sep 17 00:00:00 2001 From: benolt Date: Fri, 16 Aug 2024 12:13:53 +0200 Subject: [PATCH 10/23] test global search with llama 3.1 --- Cargo.lock | 1 + shinkai-libs/shinkai-graphrag/Cargo.toml | 1 + shinkai-libs/shinkai-graphrag/src/llm/llm.rs | 6 + .../src/search/global_search/global_search.rs | 11 +- .../tests/global_search_tests.rs | 106 +++++++++++++++++- .../shinkai-graphrag/tests/utils/mod.rs | 1 + .../shinkai-graphrag/tests/utils/ollama.rs | 100 +++++++++++++++++ .../shinkai-graphrag/tests/utils/openai.rs | 3 +- 8 files changed, 224 insertions(+), 5 deletions(-) create mode 100644 shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs diff --git a/Cargo.lock b/Cargo.lock index 8d3c41100..1bdfcdd91 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10406,6 +10406,7 @@ dependencies = [ "polars", "polars-lazy", "rand 0.8.5", + "reqwest 0.11.27", "serde", "serde_json", "tiktoken-rs", diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml index 18650385c..7975fa77c 100644 --- a/shinkai-libs/shinkai-graphrag/Cargo.toml +++ b/shinkai-libs/shinkai-graphrag/Cargo.toml @@ -16,4 +16,5 @@ tokio = { version = "1.36", features = ["full"] } [dev-dependencies] async-openai = "0.23.4" +reqwest = { version = "0.11.26", features = ["json"] } tiktoken-rs = "0.5.9" \ No newline at end of file diff --git a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs index 0a8482144..5fa5cf633 100644 --- a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs +++ b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs @@ -40,6 +40,7 @@ pub trait BaseLLM { streaming: bool, callbacks: Option>, llm_params: LLMParams, + search_phase: Option, ) -> anyhow::Result; } @@ -47,3 +48,8 @@ pub trait BaseLLM { pub trait BaseTextEmbedding { async fn aembed(&self, text: &str) -> Vec; } + +pub enum GlobalSearchPhase { + Map, + Reduce, +} diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs index 0368d4a59..8bb60b955 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs @@ -6,7 +6,7 @@ use std::time::Instant; use crate::context_builder::community_context::GlobalCommunityContext; use crate::context_builder::context_builder::{ContextBuilderParams, ConversationHistory}; -use crate::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType}; +use crate::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType}; use crate::search::global_search::prompts::NO_DATA_ANSWER; use super::prompts::{GENERAL_KNOWLEDGE_INSTRUCTION, MAP_SYSTEM_PROMPT, REDUCE_SYSTEM_PROMPT}; @@ -258,7 +258,13 @@ impl GlobalSearch { let search_response = self .llm - .agenerate(MessageType::Dictionary(search_messages), false, None, llm_params) + .agenerate( + MessageType::Dictionary(search_messages), + false, + None, + llm_params, + Some(GlobalSearchPhase::Map), + ) .await?; let processed_response = self.parse_search_response(&search_response); @@ -412,6 +418,7 @@ impl GlobalSearch { true, llm_callbacks, llm_params, + Some(GlobalSearchPhase::Reduce), ) .await?; diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs index 42bcd5834..5c888a8df 100644 --- a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs @@ -7,12 +7,114 @@ use shinkai_graphrag::{ llm::llm::LLMParams, search::global_search::global_search::{GlobalSearch, GlobalSearchParams}, }; -use utils::openai::{num_tokens, ChatOpenAI}; +use utils::{ + ollama::Ollama, + openai::{num_tokens, ChatOpenAI}, +}; mod utils; // #[tokio::test] -async fn global_search_test() -> Result<(), Box> { +async fn ollama_global_search_test() -> Result<(), Box> { + let base_url = "http://localhost:11434"; + let model_type = "llama3.1"; + + let llm = Ollama::new(base_url.to_string(), model_type.to_string()); + + // Load community reports + // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip + + let input_dir = "./dataset"; + let community_report_table = "create_final_community_reports"; + let entity_table = "create_final_nodes"; + let entity_embedding_table = "create_final_entities"; + + let community_level = 2; + + let mut entity_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_table)).unwrap(); + let entity_df = ParquetReader::new(&mut entity_file).finish().unwrap(); + + let mut report_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, community_report_table)).unwrap(); + let report_df = ParquetReader::new(&mut report_file).finish().unwrap(); + + let mut entity_embedding_file = + std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_embedding_table)).unwrap(); + let entity_embedding_df = ParquetReader::new(&mut entity_embedding_file).finish().unwrap(); + + let reports = read_indexer_reports(&report_df, &entity_df, community_level)?; + let entities = read_indexer_entities(&entity_df, &entity_embedding_df, community_level)?; + + println!("Reports: {:?}", report_df.head(Some(5))); + + // Build global context based on community reports + + // Using tiktoken for token count estimation + let context_builder = GlobalCommunityContext::new(reports, Some(entities), num_tokens); + + let context_builder_params = ContextBuilderParams { + use_community_summary: false, // False means using full community reports. True means using community short summaries. + shuffle_data: true, + include_community_rank: true, + min_community_rank: 0, + community_rank_name: String::from("rank"), + include_community_weight: true, + community_weight_name: String::from("occurrence weight"), + normalize_community_weight: true, + max_tokens: 5000, // change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000) + context_name: String::from("Reports"), + column_delimiter: String::from("|"), + }; + + // LLM params are ignored for Ollama + let map_llm_params = LLMParams { + max_tokens: 1000, + temperature: 0.0, + response_format: std::collections::HashMap::from([("type".to_string(), "json_object".to_string())]), + }; + + let reduce_llm_params = LLMParams { + max_tokens: 2000, + temperature: 0.0, + response_format: std::collections::HashMap::new(), + }; + + // Perform global search + + let search_engine = GlobalSearch::new(GlobalSearchParams { + llm: Box::new(llm), + context_builder, + num_tokens_fn: num_tokens, + map_system_prompt: None, + reduce_system_prompt: None, + response_type: String::from("multiple paragraphs"), + allow_general_knowledge: false, + general_knowledge_inclusion_prompt: None, + json_mode: true, + callbacks: None, + max_data_tokens: 5000, + map_llm_params, + reduce_llm_params, + context_builder_params, + }); + + let result = search_engine + .asearch( + "What is the major conflict in this story and who are the protagonist and antagonist?".to_string(), + None, + ) + .await?; + + println!("Response: {:?}", result.response); + + println!("Context: {:?}", result.context_data); + + println!("LLM calls: {}. LLM tokens: {}", result.llm_calls, result.prompt_tokens); + + Ok(()) +} + +// #[tokio::test] +async fn openai_global_search_test() -> Result<(), Box> { let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap(); let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap(); diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs b/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs index d8c308735..3ef32f620 100644 --- a/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs +++ b/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs @@ -1 +1,2 @@ +pub mod ollama; pub mod openai; diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs new file mode 100644 index 000000000..41d3619b8 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs @@ -0,0 +1,100 @@ +use async_trait::async_trait; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType}; + +#[derive(Serialize, Deserialize, Debug)] +pub struct OllamaResponse { + pub model: String, + pub created_at: String, + pub message: OllamaMessage, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct OllamaMessage { + pub role: String, + pub content: String, +} + +pub struct Ollama { + base_url: String, + model_type: String, +} + +impl Ollama { + pub fn new(base_url: String, model_type: String) -> Self { + Ollama { base_url, model_type } + } +} + +#[async_trait] +impl BaseLLM for Ollama { + async fn agenerate( + &self, + messages: MessageType, + _streaming: bool, + _callbacks: Option>, + _llm_params: LLMParams, + search_phase: Option, + ) -> anyhow::Result { + let client = Client::new(); + let chat_url = format!("{}{}", &self.base_url, "/api/chat"); + + let messages_json = match messages { + MessageType::String(message) => json![message], + MessageType::Strings(messages) => json!(messages), + MessageType::Dictionary(messages) => { + let messages = match search_phase { + Some(GlobalSearchPhase::Map) => { + // Filter out system messages and convert them to user messages + messages + .into_iter() + .filter(|map| map.get_key_value("role").is_some_and(|(_, v)| v == "system")) + .map(|map| { + map.into_iter() + .map(|(key, value)| { + if key == "role" { + return (key, "user".to_string()); + } + (key, value) + }) + .collect() + }) + .collect() + } + Some(GlobalSearchPhase::Reduce) => { + // Convert roles to user + messages + .into_iter() + .map(|map| { + map.into_iter() + .map(|(key, value)| { + if key == "role" { + return (key, "user".to_string()); + } + (key, value) + }) + .collect() + }) + .collect() + } + _ => messages, + }; + + json!(messages) + } + }; + + let payload = json!({ + "model": self.model_type, + "messages": messages_json, + "stream": false, + }); + + let response = client.post(chat_url).json(&payload).send().await?; + let response = response.json::().await?; + + Ok(response.message.content) + } +} diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs index 95e7e7b80..255d5b4e5 100644 --- a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs +++ b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs @@ -7,7 +7,7 @@ use async_openai::{ Client, }; use async_trait::async_trait; -use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType}; +use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType}; use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer}; pub struct ChatOpenAI { @@ -133,6 +133,7 @@ impl BaseLLM for ChatOpenAI { streaming: bool, callbacks: Option>, llm_params: LLMParams, + _search_phase: Option, ) -> anyhow::Result { self.agenerate(messages, streaming, callbacks, llm_params).await } From ecb3b7484b4de61c94763cfafae1bff08bfbc896 Mon Sep 17 00:00:00 2001 From: benolt Date: Thu, 15 Aug 2024 14:05:37 +0200 Subject: [PATCH 11/23] GraphRAG Local Search --- .../src/context_builder/community_context.rs | 8 +- .../src/context_builder/context_builder.rs | 4 +- .../src/context_builder/indexer_entities.rs | 22 +-- .../src/context_builder/indexer_reports.rs | 19 +-- shinkai-libs/shinkai-graphrag/src/lib.rs | 2 + shinkai-libs/shinkai-graphrag/src/models.rs | 59 +++++++ .../shinkai-graphrag/src/search/base.rs | 29 ++++ .../src/search/global_search/global_search.rs | 34 +--- .../src/search/local_search/local_search.rs | 96 +++++++++++ .../src/search/local_search/mixed_context.rs | 155 ++++++++++++++++++ .../src/search/local_search/mod.rs | 3 + .../src/search/local_search/prompts.rs | 69 ++++++++ .../shinkai-graphrag/src/search/mod.rs | 2 + .../src/vector_stores/lancedb.rs | 1 + .../shinkai-graphrag/src/vector_stores/mod.rs | 1 + .../tests/global_search_tests.rs | 4 +- 16 files changed, 436 insertions(+), 72 deletions(-) create mode 100644 shinkai-libs/shinkai-graphrag/src/models.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/search/base.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/search/local_search/local_search.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/search/local_search/mod.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/search/local_search/prompts.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/vector_stores/mod.rs diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs index a6a1a7485..fd30fcfcb 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs @@ -11,7 +11,9 @@ use polars::{ }; use rand::prelude::SliceRandom; -use super::{context_builder::ContextBuilderParams, indexer_entities::Entity, indexer_reports::CommunityReport}; +use crate::models::{CommunityReport, Entity}; + +use super::context_builder::GlobalSearchContextBuilderParams; pub struct GlobalCommunityContext { community_reports: Vec, @@ -34,9 +36,9 @@ impl GlobalCommunityContext { pub async fn build_context( &self, - context_builder_params: ContextBuilderParams, + context_builder_params: GlobalSearchContextBuilderParams, ) -> anyhow::Result<(Vec, HashMap)> { - let ContextBuilderParams { + let GlobalSearchContextBuilderParams { use_community_summary, column_delimiter, shuffle_data, diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs index 87db20231..18fffb419 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs @@ -1,6 +1,5 @@ #[derive(Debug, Clone)] -pub struct ContextBuilderParams { - //conversation_history: Option, +pub struct GlobalSearchContextBuilderParams { pub use_community_summary: bool, pub column_delimiter: String, pub shuffle_data: bool, @@ -12,6 +11,7 @@ pub struct ContextBuilderParams { pub normalize_community_weight: bool, pub max_tokens: usize, pub context_name: String, + //conversation_history: Option, // conversation_history_user_turns_only: bool, // conversation_history_max_turns: Option, } diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs index 26d8566fd..d6e38d3da 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs @@ -1,8 +1,9 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use polars::prelude::*; use polars_lazy::dsl::col; -use serde::{Deserialize, Serialize}; + +use crate::models::Entity; use super::indexer_reports::filter_under_community_level; @@ -52,23 +53,6 @@ pub fn read_indexer_entities( Ok(entities) } -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Entity { - pub id: String, - pub short_id: Option, - pub title: String, - pub entity_type: Option, - pub description: Option, - pub description_embedding: Option>, - pub name_embedding: Option>, - pub graph_embedding: Option>, - pub community_ids: Option>, - pub text_unit_ids: Option>, - pub document_ids: Option>, - pub rank: Option, - pub attributes: Option>, -} - pub fn read_entities( df: DataFrame, id_col: &str, diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs index 1b59b9c59..a7f2fd4ba 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs @@ -1,8 +1,9 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use polars::prelude::*; use polars_lazy::dsl::col; -use serde::{Deserialize, Serialize}; + +use crate::models::CommunityReport; use super::indexer_entities::get_field; @@ -59,20 +60,6 @@ pub fn filter_under_community_level(df: &DataFrame, community_level: u32) -> any Ok(result) } -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CommunityReport { - pub id: String, - pub short_id: Option, - pub title: String, - pub community_id: String, - pub summary: String, - pub full_content: String, - pub rank: Option, - pub summary_embedding: Option>, - pub full_content_embedding: Option>, - pub attributes: Option>, -} - pub fn read_community_reports( df: DataFrame, id_col: &str, diff --git a/shinkai-libs/shinkai-graphrag/src/lib.rs b/shinkai-libs/shinkai-graphrag/src/lib.rs index 08bc3d655..f0c118869 100644 --- a/shinkai-libs/shinkai-graphrag/src/lib.rs +++ b/shinkai-libs/shinkai-graphrag/src/lib.rs @@ -1,3 +1,5 @@ pub mod context_builder; pub mod llm; +pub mod models; pub mod search; +pub mod vector_stores; diff --git a/shinkai-libs/shinkai-graphrag/src/models.rs b/shinkai-libs/shinkai-graphrag/src/models.rs new file mode 100644 index 000000000..1d6a63e01 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/models.rs @@ -0,0 +1,59 @@ +use std::collections::HashMap; + +#[derive(Debug, Clone)] +pub struct CommunityReport { + pub id: String, + pub short_id: Option, + pub title: String, + pub community_id: String, + pub summary: String, + pub full_content: String, + pub rank: Option, + pub summary_embedding: Option>, + pub full_content_embedding: Option>, + pub attributes: Option>, +} + +#[derive(Debug, Clone)] +pub struct Entity { + pub id: String, + pub short_id: Option, + pub title: String, + pub entity_type: Option, + pub description: Option, + pub description_embedding: Option>, + pub name_embedding: Option>, + pub graph_embedding: Option>, + pub community_ids: Option>, + pub text_unit_ids: Option>, + pub document_ids: Option>, + pub rank: Option, + pub attributes: Option>, +} + +#[derive(Debug, Clone)] +pub struct Relationship { + pub id: String, + pub short_id: Option, + pub source: String, + pub target: String, + pub weight: Option, + pub description: Option, + pub description_embedding: Option>, + pub text_unit_ids: Option>, + pub document_ids: Option>, + pub attributes: Option>, +} + +#[derive(Debug, Clone)] +pub struct TextUnit { + pub id: String, + pub short_id: Option, + pub text: String, + pub text_embedding: Option>, + pub entity_ids: Option>, + pub relationship_ids: Option>, + pub n_tokens: Option, + pub document_ids: Option>, + pub attributes: Option>, +} diff --git a/shinkai-libs/shinkai-graphrag/src/search/base.rs b/shinkai-libs/shinkai-graphrag/src/search/base.rs new file mode 100644 index 000000000..a7ba7eaef --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/base.rs @@ -0,0 +1,29 @@ +use std::collections::HashMap; + +use polars::frame::DataFrame; + +#[derive(Debug, Clone)] +pub enum ResponseType { + String(String), + KeyPoints(Vec), +} + +#[derive(Debug, Clone)] +pub enum ContextData { + String(String), + DataFrames(Vec), + Dictionary(HashMap), +} + +#[derive(Debug, Clone)] +pub enum ContextText { + String(String), + Strings(Vec), + Dictionary(HashMap), +} + +#[derive(Debug, Clone)] +pub struct KeyPoint { + pub answer: String, + pub score: i32, +} diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs index 8bb60b955..c537198a3 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs @@ -1,12 +1,12 @@ use futures::future::join_all; -use polars::frame::DataFrame; use serde_json::Value; use std::collections::HashMap; use std::time::Instant; use crate::context_builder::community_context::GlobalCommunityContext; -use crate::context_builder::context_builder::{ContextBuilderParams, ConversationHistory}; +use crate::context_builder::context_builder::{ConversationHistory, GlobalSearchContextBuilderParams}; use crate::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType}; +use crate::search::base::{ContextData, ContextText, KeyPoint, ResponseType}; use crate::search::global_search::prompts::NO_DATA_ANSWER; use super::prompts::{GENERAL_KNOWLEDGE_INSTRUCTION, MAP_SYSTEM_PROMPT, REDUCE_SYSTEM_PROMPT}; @@ -21,32 +21,6 @@ pub struct SearchResult { pub prompt_tokens: usize, } -#[derive(Debug, Clone)] -pub enum ResponseType { - String(String), - KeyPoints(Vec), -} - -#[derive(Debug, Clone)] -pub enum ContextData { - String(String), - DataFrames(Vec), - Dictionary(HashMap), -} - -#[derive(Debug, Clone)] -pub enum ContextText { - String(String), - Strings(Vec), - Dictionary(HashMap), -} - -#[derive(Debug, Clone)] -pub struct KeyPoint { - pub answer: String, - pub score: i32, -} - pub struct GlobalSearchResult { pub response: ResponseType, pub context_data: ContextData, @@ -88,7 +62,7 @@ pub struct GlobalSearch { llm: Box, context_builder: GlobalCommunityContext, num_tokens_fn: fn(&str) -> usize, - context_builder_params: ContextBuilderParams, + context_builder_params: GlobalSearchContextBuilderParams, map_system_prompt: String, reduce_system_prompt: String, response_type: String, @@ -114,7 +88,7 @@ pub struct GlobalSearchParams { pub max_data_tokens: usize, pub map_llm_params: LLMParams, pub reduce_llm_params: LLMParams, - pub context_builder_params: ContextBuilderParams, + pub context_builder_params: GlobalSearchContextBuilderParams, } impl GlobalSearch { diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/local_search.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/local_search.rs new file mode 100644 index 000000000..955551f80 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/local_search.rs @@ -0,0 +1,96 @@ +use std::{collections::HashMap, time::Instant}; + +use crate::{ + llm::llm::{BaseLLM, LLMParams, MessageType}, + search::base::{ContextData, ContextText, ResponseType}, +}; + +use super::{ + mixed_context::{LocalSearchContextBuilderParams, LocalSearchMixedContext}, + prompts::LOCAL_SEARCH_SYSTEM_PROMPT, +}; + +pub struct LocalSearchResult { + pub response: ResponseType, + pub context_data: ContextData, + pub context_text: ContextText, + pub completion_time: f64, + pub llm_calls: usize, + pub prompt_tokens: usize, +} + +pub struct LocalSearch { + llm: Box, + context_builder: LocalSearchMixedContext, + num_tokens_fn: fn(&str) -> usize, + system_prompt: String, + response_type: String, + llm_params: LLMParams, + context_builder_params: LocalSearchContextBuilderParams, +} + +impl LocalSearch { + pub fn new( + llm: Box, + context_builder: LocalSearchMixedContext, + num_tokens_fn: fn(&str) -> usize, + llm_params: LLMParams, + context_builder_params: LocalSearchContextBuilderParams, + response_type: String, + system_prompt: Option, + ) -> Self { + let system_prompt = system_prompt.unwrap_or(LOCAL_SEARCH_SYSTEM_PROMPT.to_string()); + + LocalSearch { + llm, + context_builder, + num_tokens_fn, + system_prompt, + response_type, + llm_params, + context_builder_params, + } + } + + pub async fn asearch(&self, query: String) -> anyhow::Result { + let start_time = Instant::now(); + let (context_text, context_records) = self + .context_builder + .build_context(self.context_builder_params.clone()) + .await?; + + let search_prompt = self + .system_prompt + .replace("{context_data}", &context_text) + .replace("{response_type}", &self.response_type); + + let mut search_messages = Vec::new(); + search_messages.push(HashMap::from([ + ("role".to_string(), "system".to_string()), + ("content".to_string(), search_prompt.clone()), + ])); + search_messages.push(HashMap::from([ + ("role".to_string(), "user".to_string()), + ("content".to_string(), query.to_string()), + ])); + + let search_response = self + .llm + .agenerate( + MessageType::Dictionary(search_messages), + false, + None, + self.llm_params.clone(), + ) + .await?; + + Ok(LocalSearchResult { + response: ResponseType::String(search_response), + context_data: ContextData::Dictionary(context_records), + context_text: ContextText::String(context_text), + completion_time: start_time.elapsed().as_secs_f64(), + llm_calls: 1, + prompt_tokens: (self.num_tokens_fn)(&search_prompt), + }) + } +} diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs new file mode 100644 index 000000000..5cadc1653 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs @@ -0,0 +1,155 @@ +use std::collections::HashMap; + +use polars::frame::DataFrame; + +use crate::{ + llm::llm::BaseTextEmbedding, + models::{CommunityReport, Entity, Relationship, TextUnit}, + vector_stores::lancedb::LanceDBVectorStore, +}; + +#[derive(Debug, Clone)] +pub struct LocalSearchContextBuilderParams { + pub query: String, + pub include_entity_names: Option>, + pub exclude_entity_names: Option>, + pub max_tokens: i32, + pub text_unit_prop: f32, + pub community_prop: f32, + pub top_k_mapped_entities: i32, + pub top_k_relationships: i32, + pub include_community_rank: bool, + pub include_entity_rank: bool, + pub rank_description: String, + pub include_relationship_weight: bool, + pub relationship_ranking_attribute: String, + pub return_candidate_context: bool, + pub use_community_summary: bool, + pub min_community_rank: i32, + pub community_context_name: String, + pub column_delimiter: String, + // pub conversation_history: Option, + // pub conversation_history_max_turns: Option, + // pub conversation_history_user_turns_only: bool, +} + +pub fn default_local_context_params() -> LocalSearchContextBuilderParams { + LocalSearchContextBuilderParams { + query: String::new(), + include_entity_names: None, + exclude_entity_names: None, + max_tokens: 8000, + text_unit_prop: 0.5, + community_prop: 0.25, + top_k_mapped_entities: 10, + top_k_relationships: 10, + include_community_rank: false, + include_entity_rank: false, + rank_description: "number of relationships".to_string(), + include_relationship_weight: false, + relationship_ranking_attribute: "rank".to_string(), + return_candidate_context: false, + use_community_summary: false, + min_community_rank: 0, + community_context_name: "Reports".to_string(), + column_delimiter: "|".to_string(), + } +} + +pub struct LocalSearchMixedContext { + entities: HashMap, + entity_text_embeddings: LanceDBVectorStore, + text_embedder: Box, + text_units: HashMap, + community_reports: HashMap, + relationships: HashMap, + num_tokens_fn: fn(&str) -> usize, + embedding_vectorstore_key: String, +} + +impl LocalSearchMixedContext { + pub fn new( + entities: Vec, + entity_text_embeddings: LanceDBVectorStore, + text_embedder: Box, + text_units: Option>, + community_reports: Option>, + relationships: Option>, + num_tokens_fn: fn(&str) -> usize, + embedding_vectorstore_key: String, + ) -> Self { + let mut context = LocalSearchMixedContext { + entities: HashMap::new(), + entity_text_embeddings, + text_embedder, + text_units: HashMap::new(), + community_reports: HashMap::new(), + relationships: HashMap::new(), + num_tokens_fn, + embedding_vectorstore_key, + }; + + for entity in entities { + context.entities.insert(entity.id.clone(), entity); + } + + if let Some(units) = text_units { + for unit in units { + context.text_units.insert(unit.id.clone(), unit); + } + } + + if let Some(reports) = community_reports { + for report in reports { + context.community_reports.insert(report.id.clone(), report); + } + } + + if let Some(relations) = relationships { + for relation in relations { + context.relationships.insert(relation.id.clone(), relation); + } + } + + context + } + + pub async fn build_context( + &self, + context_builder_params: LocalSearchContextBuilderParams, + ) -> anyhow::Result<(String, HashMap)> { + let LocalSearchContextBuilderParams { + query, + include_entity_names, + exclude_entity_names, + max_tokens, + text_unit_prop, + community_prop, + top_k_mapped_entities, + top_k_relationships, + include_community_rank, + include_entity_rank, + rank_description, + include_relationship_weight, + relationship_ranking_attribute, + return_candidate_context, + use_community_summary, + min_community_rank, + community_context_name, + column_delimiter, + } = context_builder_params; + + let include_entity_names = include_entity_names.unwrap_or_default(); + let exclude_entity_names = exclude_entity_names.unwrap_or_default(); + + if community_prop + text_unit_prop > 1.0 { + return Err(anyhow::anyhow!( + "The sum of community_prop and text_unit_prop must be less than or equal to 1.0" + )); + } + + let context_text = String::new(); + let context_records = HashMap::new(); + Ok((context_text, context_records)) + } +} diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/mod.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/mod.rs new file mode 100644 index 000000000..2908c58f2 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/mod.rs @@ -0,0 +1,3 @@ +pub mod local_search; +pub mod mixed_context; +pub mod prompts; diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/prompts.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/prompts.rs new file mode 100644 index 000000000..7da6a63eb --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/prompts.rs @@ -0,0 +1,69 @@ +// Copyright (c) 2024 Microsoft Corporation. +// Licensed under the MIT License + +// System prompts for local search. + +pub const LOCAL_SEARCH_SYSTEM_PROMPT: &str = r#" +---Role--- + +You are a helpful assistant responding to questions about data in the tables provided. + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. + +If you don't know the answer, just say so. Do not make anything up. + +Points supported by data should list their data references as follows: + +"This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]." + +Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16), Reports (1), Entities (5, 7); Relationships (23); Claims (2, 7, 34, 46, 64, +more)]." + +where 15, 16, 1, 5, 7, 23, 2, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + + +---Data tables--- + +{context_data} + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. + +If you don't know the answer, just say so. Do not make anything up. + +Points supported by data should list their data references as follows: + +"This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]." + +Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16), Reports (1), Entities (5, 7); Relationships (23); Claims (2, 7, 34, 46, 64, +more)]." + +where 15, 16, 1, 5, 7, 23, 2, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. +"#; diff --git a/shinkai-libs/shinkai-graphrag/src/search/mod.rs b/shinkai-libs/shinkai-graphrag/src/search/mod.rs index a12441830..7266f8dab 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/mod.rs @@ -1 +1,3 @@ +pub mod base; pub mod global_search; +pub mod local_search; diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs new file mode 100644 index 000000000..e40415053 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs @@ -0,0 +1 @@ +pub struct LanceDBVectorStore {} diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/mod.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/mod.rs new file mode 100644 index 000000000..832ccb1b0 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/mod.rs @@ -0,0 +1 @@ +pub mod lancedb; diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs index 5c888a8df..4c67feb4e 100644 --- a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs @@ -1,7 +1,7 @@ use polars::{io::SerReader, prelude::ParquetReader}; use shinkai_graphrag::{ context_builder::{ - community_context::GlobalCommunityContext, context_builder::ContextBuilderParams, + community_context::GlobalCommunityContext, context_builder::GlobalSearchContextBuilderParams, indexer_entities::read_indexer_entities, indexer_reports::read_indexer_reports, }, llm::llm::LLMParams, @@ -149,7 +149,7 @@ async fn openai_global_search_test() -> Result<(), Box> { let context_builder = GlobalCommunityContext::new(reports, Some(entities), num_tokens); - let context_builder_params = ContextBuilderParams { + let context_builder_params = GlobalSearchContextBuilderParams { use_community_summary: false, // False means using full community reports. True means using community short summaries. shuffle_data: true, include_community_rank: true, From 6a59c03f8bb9f6f728849588a53915da1f400b7b Mon Sep 17 00:00:00 2001 From: benolt Date: Wed, 21 Aug 2024 17:02:37 +0200 Subject: [PATCH 12/23] mixed context, entities, community reports --- Cargo.lock | 1 + shinkai-libs/shinkai-graphrag/Cargo.toml | 3 +- .../src/context_builder/community_context.rs | 24 ++- .../src/context_builder/context_builder.rs | 18 -- .../src/context_builder/mod.rs | 2 - .../indexer_entities.rs | 0 .../indexer_reports.rs | 0 .../src/indexer_adapters/mod.rs | 2 + shinkai-libs/shinkai-graphrag/src/lib.rs | 2 + shinkai-libs/shinkai-graphrag/src/llm/llm.rs | 1 + .../src/retrieval/community_reports.rs | 40 ++++ .../src/retrieval/entity_extraction.rs | 100 ++++++++++ .../shinkai-graphrag/src/retrieval/mod.rs | 2 + .../src/search/global_search/global_search.rs | 11 +- .../src/search/local_search/local_search.rs | 7 +- .../src/search/local_search/mixed_context.rs | 183 ++++++++++++++++-- .../shinkai-graphrag/src/vector_stores/mod.rs | 1 + .../src/vector_stores/vector_store.rs | 21 ++ .../tests/global_search_tests.rs | 10 +- 19 files changed, 375 insertions(+), 53 deletions(-) rename shinkai-libs/shinkai-graphrag/src/{context_builder => indexer_adapters}/indexer_entities.rs (100%) rename shinkai-libs/shinkai-graphrag/src/{context_builder => indexer_adapters}/indexer_reports.rs (100%) create mode 100644 shinkai-libs/shinkai-graphrag/src/indexer_adapters/mod.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/retrieval/community_reports.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/retrieval/entity_extraction.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/retrieval/mod.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs diff --git a/Cargo.lock b/Cargo.lock index 1bdfcdd91..f3f804d53 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10411,6 +10411,7 @@ dependencies = [ "serde_json", "tiktoken-rs", "tokio", + "uuid 1.8.0", ] [[package]] diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml index 7975fa77c..0212ad59d 100644 --- a/shinkai-libs/shinkai-graphrag/Cargo.toml +++ b/shinkai-libs/shinkai-graphrag/Cargo.toml @@ -7,12 +7,13 @@ edition = "2021" anyhow = "1.0.86" async-trait = "0.1.74" futures = "0.3.30" -polars = { version = "0.41.3", features = ["dtype-struct", "lazy", "parquet"] } +polars = { version = "0.41.3", features = ["dtype-struct", "is_in", "lazy", "parquet"] } polars-lazy = "0.41.3" rand = "0.8.5" serde = { version = "1.0.188", features = ["derive"] } serde_json = "1.0.117" tokio = { version = "1.36", features = ["full"] } +uuid = { version = "1.6.1", features = ["v4"] } [dev-dependencies] async-openai = "0.23.4" diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs index fd30fcfcb..053d6aecd 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs @@ -13,7 +13,23 @@ use rand::prelude::SliceRandom; use crate::models::{CommunityReport, Entity}; -use super::context_builder::GlobalSearchContextBuilderParams; +#[derive(Debug, Clone)] +pub struct CommunityContextBuilderParams { + pub use_community_summary: bool, + pub column_delimiter: String, + pub shuffle_data: bool, + pub include_community_rank: bool, + pub min_community_rank: u32, + pub community_rank_name: String, + pub include_community_weight: bool, + pub community_weight_name: String, + pub normalize_community_weight: bool, + pub max_tokens: usize, + pub context_name: String, + // conversation_history: Option, + // conversation_history_user_turns_only: bool, + // conversation_history_max_turns: Option, +} pub struct GlobalCommunityContext { community_reports: Vec, @@ -34,11 +50,11 @@ impl GlobalCommunityContext { } } - pub async fn build_context( + pub fn build_context( &self, - context_builder_params: GlobalSearchContextBuilderParams, + context_builder_params: CommunityContextBuilderParams, ) -> anyhow::Result<(Vec, HashMap)> { - let GlobalSearchContextBuilderParams { + let CommunityContextBuilderParams { use_community_summary, column_delimiter, shuffle_data, diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs index 18fffb419..1455d3264 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs @@ -1,19 +1 @@ -#[derive(Debug, Clone)] -pub struct GlobalSearchContextBuilderParams { - pub use_community_summary: bool, - pub column_delimiter: String, - pub shuffle_data: bool, - pub include_community_rank: bool, - pub min_community_rank: u32, - pub community_rank_name: String, - pub include_community_weight: bool, - pub community_weight_name: String, - pub normalize_community_weight: bool, - pub max_tokens: usize, - pub context_name: String, - //conversation_history: Option, - // conversation_history_user_turns_only: bool, - // conversation_history_max_turns: Option, -} - pub struct ConversationHistory {} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs index 0abed5320..43d390fb6 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs @@ -1,4 +1,2 @@ pub mod community_context; pub mod context_builder; -pub mod indexer_entities; -pub mod indexer_reports; diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs b/shinkai-libs/shinkai-graphrag/src/indexer_adapters/indexer_entities.rs similarity index 100% rename from shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs rename to shinkai-libs/shinkai-graphrag/src/indexer_adapters/indexer_entities.rs diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs b/shinkai-libs/shinkai-graphrag/src/indexer_adapters/indexer_reports.rs similarity index 100% rename from shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs rename to shinkai-libs/shinkai-graphrag/src/indexer_adapters/indexer_reports.rs diff --git a/shinkai-libs/shinkai-graphrag/src/indexer_adapters/mod.rs b/shinkai-libs/shinkai-graphrag/src/indexer_adapters/mod.rs new file mode 100644 index 000000000..c49aae604 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/indexer_adapters/mod.rs @@ -0,0 +1,2 @@ +pub mod indexer_entities; +pub mod indexer_reports; diff --git a/shinkai-libs/shinkai-graphrag/src/lib.rs b/shinkai-libs/shinkai-graphrag/src/lib.rs index f0c118869..30ceb1057 100644 --- a/shinkai-libs/shinkai-graphrag/src/lib.rs +++ b/shinkai-libs/shinkai-graphrag/src/lib.rs @@ -1,5 +1,7 @@ pub mod context_builder; +pub mod indexer_adapters; pub mod llm; pub mod models; +pub mod retrieval; pub mod search; pub mod vector_stores; diff --git a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs index 5fa5cf633..8d27e1563 100644 --- a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs +++ b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs @@ -47,6 +47,7 @@ pub trait BaseLLM { #[async_trait] pub trait BaseTextEmbedding { async fn aembed(&self, text: &str) -> Vec; + fn embed(&self, text: &str) -> Vec; } pub enum GlobalSearchPhase { diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/community_reports.rs b/shinkai-libs/shinkai-graphrag/src/retrieval/community_reports.rs new file mode 100644 index 000000000..37efd2e10 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/retrieval/community_reports.rs @@ -0,0 +1,40 @@ +use std::collections::HashSet; + +use polars::frame::DataFrame; + +use crate::models::{CommunityReport, Entity}; + +pub fn get_candidate_communities( + selected_entities: Vec, + community_reports: Vec, + include_community_rank: bool, + use_community_summary: bool, +) -> anyhow::Result { + let mut selected_community_ids: HashSet = HashSet::new(); + for entity in &selected_entities { + if let Some(community_ids) = &entity.community_ids { + selected_community_ids.extend(community_ids.iter().cloned()); + } + } + + let mut selected_reports: Vec = Vec::new(); + for community in &community_reports { + if selected_community_ids.contains(&community.id) { + selected_reports.push(community.clone()); + } + } + + to_community_report_dataframe(selected_reports, include_community_rank, use_community_summary) +} + +pub fn to_community_report_dataframe( + reports: Vec, + include_community_rank: bool, + use_community_summary: bool, +) -> anyhow::Result { + if reports.is_empty() { + return Ok(DataFrame::default()); + } + + Ok(DataFrame::default()) +} diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/entity_extraction.rs b/shinkai-libs/shinkai-graphrag/src/retrieval/entity_extraction.rs new file mode 100644 index 000000000..9bb784569 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/retrieval/entity_extraction.rs @@ -0,0 +1,100 @@ +use std::collections::HashSet; +use uuid::Uuid; + +use crate::{llm::llm::BaseTextEmbedding, models::Entity, vector_stores::vector_store::VectorStore}; + +pub fn map_query_to_entities( + query: &str, + text_embedding_vectorstore: &Box, + text_embedder: &Box, + all_entities: &Vec, + embedding_vectorstore_key: &str, + include_entity_names: Option>, + exclude_entity_names: Option>, + k: usize, + oversample_scaler: usize, +) -> Vec { + let include_entity_names = include_entity_names.unwrap_or_else(Vec::new); + let exclude_entity_names: HashSet = exclude_entity_names.unwrap_or_else(Vec::new).into_iter().collect(); + let mut matched_entities = Vec::new(); + + if !query.is_empty() { + let search_results = text_embedding_vectorstore.similarity_search_by_text( + query, + &|t| text_embedder.embed(t), + k * oversample_scaler, + ); + + for result in search_results { + if let Some(matched) = get_entity_by_key(all_entities, &embedding_vectorstore_key, &result.document.id) { + matched_entities.push(matched); + } + } + } else { + let mut all_entities = all_entities.clone(); + all_entities.sort_by(|a, b| b.rank.unwrap_or(0).cmp(&a.rank.unwrap_or(0))); + matched_entities = all_entities.iter().take(k).cloned().collect(); + } + + matched_entities.retain(|entity| !exclude_entity_names.contains(&entity.title)); + + let mut included_entities = Vec::new(); + for entity_name in include_entity_names { + included_entities.extend(get_entity_by_name(all_entities, &entity_name)); + } + + included_entities.extend(matched_entities); + included_entities +} + +pub fn get_entity_by_key(entities: &Vec, key: &str, value: &str) -> Option { + for entity in entities { + match key { + "id" => { + if entity.id == value + || is_valid_uuid(value) && entity.id == Uuid::parse_str(value).unwrap().to_string().replace("-", "") + { + return Some(entity.clone()); + } + } + "short_id" => { + if entity.short_id.as_ref().unwrap_or(&"".to_string()) == value + || is_valid_uuid(value) + && entity.short_id.as_ref().unwrap_or(&"".to_string()) + == Uuid::parse_str(value).unwrap().to_string().replace("-", "").as_str() + { + return Some(entity.clone()); + } + } + "title" => { + if entity.title == value { + return Some(entity.clone()); + } + } + "entity_type" => { + if entity.entity_type.as_ref().unwrap_or(&"".to_string()) == value { + return Some(entity.clone()); + } + } + "description" => { + if entity.description.as_ref().unwrap_or(&"".to_string()) == value { + return Some(entity.clone()); + } + } + _ => {} + } + } + None +} + +pub fn get_entity_by_name(entities: &Vec, entity_name: &str) -> Vec { + entities + .iter() + .filter(|entity| entity.title == entity_name) + .cloned() + .collect() +} + +pub fn is_valid_uuid(value: &str) -> bool { + Uuid::parse_str(value).is_ok() +} diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/mod.rs b/shinkai-libs/shinkai-graphrag/src/retrieval/mod.rs new file mode 100644 index 000000000..12a881a40 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/retrieval/mod.rs @@ -0,0 +1,2 @@ +pub mod community_reports; +pub mod entity_extraction; diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs index c537198a3..f4c15fb1c 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs @@ -3,8 +3,8 @@ use serde_json::Value; use std::collections::HashMap; use std::time::Instant; -use crate::context_builder::community_context::GlobalCommunityContext; -use crate::context_builder::context_builder::{ConversationHistory, GlobalSearchContextBuilderParams}; +use crate::context_builder::community_context::{CommunityContextBuilderParams, GlobalCommunityContext}; +use crate::context_builder::context_builder::ConversationHistory; use crate::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType}; use crate::search::base::{ContextData, ContextText, KeyPoint, ResponseType}; use crate::search::global_search::prompts::NO_DATA_ANSWER; @@ -62,7 +62,7 @@ pub struct GlobalSearch { llm: Box, context_builder: GlobalCommunityContext, num_tokens_fn: fn(&str) -> usize, - context_builder_params: GlobalSearchContextBuilderParams, + context_builder_params: CommunityContextBuilderParams, map_system_prompt: String, reduce_system_prompt: String, response_type: String, @@ -88,7 +88,7 @@ pub struct GlobalSearchParams { pub max_data_tokens: usize, pub map_llm_params: LLMParams, pub reduce_llm_params: LLMParams, - pub context_builder_params: GlobalSearchContextBuilderParams, + pub context_builder_params: CommunityContextBuilderParams, } impl GlobalSearch { @@ -151,8 +151,7 @@ impl GlobalSearch { let start_time = Instant::now(); let (context_chunks, context_records) = self .context_builder - .build_context(self.context_builder_params.clone()) - .await?; + .build_context(self.context_builder_params.clone())?; let mut callbacks = match &self.callbacks { Some(callbacks) => { diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/local_search.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/local_search.rs index 955551f80..fbbe65638 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/local_search/local_search.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/local_search.rs @@ -6,7 +6,7 @@ use crate::{ }; use super::{ - mixed_context::{LocalSearchContextBuilderParams, LocalSearchMixedContext}, + mixed_context::{LocalSearchMixedContext, MixedContextBuilderParams}, prompts::LOCAL_SEARCH_SYSTEM_PROMPT, }; @@ -26,7 +26,7 @@ pub struct LocalSearch { system_prompt: String, response_type: String, llm_params: LLMParams, - context_builder_params: LocalSearchContextBuilderParams, + context_builder_params: MixedContextBuilderParams, } impl LocalSearch { @@ -35,7 +35,7 @@ impl LocalSearch { context_builder: LocalSearchMixedContext, num_tokens_fn: fn(&str) -> usize, llm_params: LLMParams, - context_builder_params: LocalSearchContextBuilderParams, + context_builder_params: MixedContextBuilderParams, response_type: String, system_prompt: Option, ) -> Self { @@ -81,6 +81,7 @@ impl LocalSearch { false, None, self.llm_params.clone(), + None, ) .await?; diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs index 5cadc1653..c5d71309e 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs @@ -1,23 +1,29 @@ use std::collections::HashMap; -use polars::frame::DataFrame; +use polars::{ + frame::DataFrame, + prelude::{is_in, NamedFrom}, + series::Series, +}; use crate::{ + context_builder::community_context::CommunityContext, llm::llm::BaseTextEmbedding, models::{CommunityReport, Entity, Relationship, TextUnit}, - vector_stores::lancedb::LanceDBVectorStore, + retrieval::{community_reports::get_candidate_communities, entity_extraction::map_query_to_entities}, + vector_stores::vector_store::VectorStore, }; #[derive(Debug, Clone)] -pub struct LocalSearchContextBuilderParams { +pub struct MixedContextBuilderParams { pub query: String, pub include_entity_names: Option>, pub exclude_entity_names: Option>, - pub max_tokens: i32, + pub max_tokens: usize, pub text_unit_prop: f32, pub community_prop: f32, - pub top_k_mapped_entities: i32, - pub top_k_relationships: i32, + pub top_k_mapped_entities: usize, + pub top_k_relationships: usize, pub include_community_rank: bool, pub include_entity_rank: bool, pub rank_description: String, @@ -25,7 +31,7 @@ pub struct LocalSearchContextBuilderParams { pub relationship_ranking_attribute: String, pub return_candidate_context: bool, pub use_community_summary: bool, - pub min_community_rank: i32, + pub min_community_rank: u32, pub community_context_name: String, pub column_delimiter: String, // pub conversation_history: Option, @@ -33,8 +39,8 @@ pub struct LocalSearchContextBuilderParams { // pub conversation_history_user_turns_only: bool, } -pub fn default_local_context_params() -> LocalSearchContextBuilderParams { - LocalSearchContextBuilderParams { +pub fn default_local_context_params() -> MixedContextBuilderParams { + MixedContextBuilderParams { query: String::new(), include_entity_names: None, exclude_entity_names: None, @@ -58,7 +64,7 @@ pub fn default_local_context_params() -> LocalSearchContextBuilderParams { pub struct LocalSearchMixedContext { entities: HashMap, - entity_text_embeddings: LanceDBVectorStore, + entity_text_embeddings: Box, text_embedder: Box, text_units: HashMap, community_reports: HashMap, @@ -70,7 +76,7 @@ pub struct LocalSearchMixedContext { impl LocalSearchMixedContext { pub fn new( entities: Vec, - entity_text_embeddings: LanceDBVectorStore, + entity_text_embeddings: Box, text_embedder: Box, text_units: Option>, community_reports: Option>, @@ -116,9 +122,9 @@ impl LocalSearchMixedContext { pub async fn build_context( &self, - context_builder_params: LocalSearchContextBuilderParams, + context_builder_params: MixedContextBuilderParams, ) -> anyhow::Result<(String, HashMap)> { - let LocalSearchContextBuilderParams { + let MixedContextBuilderParams { query, include_entity_names, exclude_entity_names, @@ -148,8 +154,159 @@ impl LocalSearchMixedContext { )); } + let selected_entities = map_query_to_entities( + &query, + &self.entity_text_embeddings, + &self.text_embedder, + &self.entities.values().cloned().collect::>(), + &self.embedding_vectorstore_key, + Some(include_entity_names), + Some(exclude_entity_names), + top_k_mapped_entities, + 2, + ); + + let community_tokens = std::cmp::max((max_tokens as f32 * community_prop) as usize, 0); + let context_text = String::new(); let context_records = HashMap::new(); Ok((context_text, context_records)) } + + fn _build_community_context( + &self, + selected_entities: Vec, + max_tokens: usize, + use_community_summary: bool, + column_delimiter: &str, + include_community_rank: bool, + min_community_rank: u32, + return_candidate_context: bool, + context_name: &str, + ) -> anyhow::Result<(String, HashMap)> { + if selected_entities.is_empty() || self.community_reports.is_empty() { + return Ok(( + "".to_string(), + HashMap::from([(context_name.to_lowercase(), DataFrame::default())]), + )); + } + + let mut community_matches: HashMap = HashMap::new(); + for entity in &selected_entities { + if let Some(community_ids) = &entity.community_ids { + for community_id in community_ids { + *community_matches.entry(community_id.to_string()).or_insert(0) += 1; + } + } + } + + let mut selected_communities: Vec = Vec::new(); + for community_id in community_matches.keys() { + if let Some(community) = self.community_reports.get(community_id) { + selected_communities.push(community.clone()); + } + } + + for community in &mut selected_communities { + if community.attributes.is_none() { + community.attributes = Some(HashMap::new()); + } + if let Some(attributes) = &mut community.attributes { + attributes.insert("matches".to_string(), community_matches[&community.id].to_string()); + } + } + + selected_communities.sort_by(|a, b| { + let a_matches = a + .attributes + .as_ref() + .unwrap() + .get("matches") + .unwrap() + .parse::() + .unwrap(); + let b_matches = b + .attributes + .as_ref() + .unwrap() + .get("matches") + .unwrap() + .parse::() + .unwrap(); + let a_rank = a.rank.unwrap(); + let b_rank = b.rank.unwrap(); + (b_matches, b_rank).partial_cmp(&(a_matches, a_rank)).unwrap() + }); + + for community in &mut selected_communities { + if let Some(attributes) = &mut community.attributes { + attributes.remove("matches"); + } + } + + let (context_text, context_data) = CommunityContext::build_community_context( + selected_communities, + None, + self.num_tokens_fn, + use_community_summary, + column_delimiter, + false, + include_community_rank, + min_community_rank, + "rank", + true, + "occurrence weight", + true, + max_tokens, + true, + context_name, + )?; + + let mut context_text_result = "".to_string(); + if !context_text.is_empty() { + context_text_result = context_text.join("\n\n"); + } + + let mut context_data = context_data; + if return_candidate_context { + let candidate_context_data = get_candidate_communities( + selected_entities, + self.community_reports.values().cloned().collect(), + use_community_summary, + include_community_rank, + )?; + + let context_key = context_name.to_lowercase(); + if !context_data.contains_key(&context_key) { + let mut new_data = candidate_context_data.clone(); + new_data + .with_column(Series::new("in_context", vec![false; candidate_context_data.height()])) + .unwrap(); + context_data.insert(context_key.to_string(), new_data); + } else { + let existing_data = context_data.get(&context_key).unwrap(); + if candidate_context_data + .get_column_names() + .contains(&"id".to_string().as_str()) + && existing_data.get_column_names().contains(&"id".to_string().as_str()) + { + let existing_ids = existing_data.column("id")?; + let context_ids = candidate_context_data.column("id")?; + let mut new_data = candidate_context_data.clone(); + let in_context = is_in(context_ids, existing_ids)?; + let in_context = Series::new("in_context", in_context); + new_data.with_column(in_context)?; + context_data.insert(context_key.to_string(), new_data); + } else { + let mut existing_data = existing_data.clone(); + existing_data + .with_column(Series::new("in_context", vec![true; existing_data.height()])) + .unwrap(); + context_data.insert(context_key.to_string(), existing_data); + } + } + } + + Ok((context_text_result, context_data)) + } } diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/mod.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/mod.rs index 832ccb1b0..e8083f72b 100644 --- a/shinkai-libs/shinkai-graphrag/src/vector_stores/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/mod.rs @@ -1 +1,2 @@ pub mod lancedb; +pub mod vector_store; diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs new file mode 100644 index 000000000..b72d5763b --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs @@ -0,0 +1,21 @@ +use std::collections::HashMap; + +pub struct VectorStoreSearchResult { + pub document: VectorStoreDocument, + pub score: f64, +} + +pub struct VectorStoreDocument { + pub id: String, + pub text: Option, + pub attributes: HashMap, +} + +pub trait VectorStore { + fn similarity_search_by_text( + &self, + text: &str, + text_embedder: &dyn Fn(&str) -> Vec, + k: usize, + ) -> Vec; +} diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs index 4c67feb4e..a808c7ff2 100644 --- a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs @@ -1,9 +1,7 @@ use polars::{io::SerReader, prelude::ParquetReader}; use shinkai_graphrag::{ - context_builder::{ - community_context::GlobalCommunityContext, context_builder::GlobalSearchContextBuilderParams, - indexer_entities::read_indexer_entities, indexer_reports::read_indexer_reports, - }, + context_builder::community_context::{CommunityContextBuilderParams, GlobalCommunityContext}, + indexer_adapters::{indexer_entities::read_indexer_entities, indexer_reports::read_indexer_reports}, llm::llm::LLMParams, search::global_search::global_search::{GlobalSearch, GlobalSearchParams}, }; @@ -51,7 +49,7 @@ async fn ollama_global_search_test() -> Result<(), Box> { // Using tiktoken for token count estimation let context_builder = GlobalCommunityContext::new(reports, Some(entities), num_tokens); - let context_builder_params = ContextBuilderParams { + let context_builder_params = CommunityContextBuilderParams { use_community_summary: false, // False means using full community reports. True means using community short summaries. shuffle_data: true, include_community_rank: true, @@ -149,7 +147,7 @@ async fn openai_global_search_test() -> Result<(), Box> { let context_builder = GlobalCommunityContext::new(reports, Some(entities), num_tokens); - let context_builder_params = GlobalSearchContextBuilderParams { + let context_builder_params = CommunityContextBuilderParams { use_community_summary: false, // False means using full community reports. True means using community short summaries. shuffle_data: true, include_community_rank: true, From b38fe3b5b7bf7492b62d960e0f4df7d8497bb4d2 Mon Sep 17 00:00:00 2001 From: benolt Date: Thu, 22 Aug 2024 15:12:00 +0200 Subject: [PATCH 13/23] local context, relationships --- .../src/context_builder/entity_extraction.rs | 52 +++ .../src/context_builder/local_context.rs | 403 ++++++++++++++++++ .../src/context_builder/mod.rs | 2 + .../src/retrieval/community_reports.rs | 72 +++- .../src/retrieval/entities.rs | 55 +++ .../src/retrieval/entity_extraction.rs | 100 ----- .../shinkai-graphrag/src/retrieval/mod.rs | 3 +- .../src/retrieval/relationships.rs | 141 ++++++ .../src/search/local_search/mixed_context.rs | 112 ++++- 9 files changed, 834 insertions(+), 106 deletions(-) create mode 100644 shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/retrieval/entities.rs delete mode 100644 shinkai-libs/shinkai-graphrag/src/retrieval/entity_extraction.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/retrieval/relationships.rs diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs new file mode 100644 index 000000000..ecf216f22 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs @@ -0,0 +1,52 @@ +use std::collections::HashSet; + +use crate::{ + llm::llm::BaseTextEmbedding, + models::Entity, + retrieval::entities::{get_entity_by_key, get_entity_by_name}, + vector_stores::vector_store::VectorStore, +}; + +pub fn map_query_to_entities( + query: &str, + text_embedding_vectorstore: &Box, + text_embedder: &Box, + all_entities: &Vec, + embedding_vectorstore_key: &str, + include_entity_names: Option>, + exclude_entity_names: Option>, + k: usize, + oversample_scaler: usize, +) -> Vec { + let include_entity_names = include_entity_names.unwrap_or_else(Vec::new); + let exclude_entity_names: HashSet = exclude_entity_names.unwrap_or_else(Vec::new).into_iter().collect(); + let mut matched_entities = Vec::new(); + + if !query.is_empty() { + let search_results = text_embedding_vectorstore.similarity_search_by_text( + query, + &|t| text_embedder.embed(t), + k * oversample_scaler, + ); + + for result in search_results { + if let Some(matched) = get_entity_by_key(all_entities, &embedding_vectorstore_key, &result.document.id) { + matched_entities.push(matched); + } + } + } else { + let mut all_entities = all_entities.clone(); + all_entities.sort_by(|a, b| b.rank.unwrap_or(0).cmp(&a.rank.unwrap_or(0))); + matched_entities = all_entities.iter().take(k).cloned().collect(); + } + + matched_entities.retain(|entity| !exclude_entity_names.contains(&entity.title)); + + let mut included_entities = Vec::new(); + for entity_name in include_entity_names { + included_entities.extend(get_entity_by_name(all_entities, &entity_name)); + } + + included_entities.extend(matched_entities); + included_entities +} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs new file mode 100644 index 000000000..77d4e5da9 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs @@ -0,0 +1,403 @@ +use std::{ + cmp::Ordering, + collections::{HashMap, HashSet}, +}; + +use polars::{frame::DataFrame, prelude::NamedFrom, series::Series}; + +use crate::{ + models::{Entity, Relationship}, + retrieval::relationships::{get_in_network_relationships, get_out_network_relationships}, +}; + +pub fn build_entity_context( + selected_entities: Vec, + num_tokens_fn: fn(&str) -> usize, + max_tokens: usize, + include_entity_rank: bool, + rank_description: &str, + column_delimiter: &str, + context_name: &str, +) -> anyhow::Result<(String, DataFrame)> { + if selected_entities.is_empty() { + return Ok((String::new(), DataFrame::default())); + } + + let mut current_context_text = format!("-----{}-----\n", context_name); + let mut header = vec!["id".to_string(), "entity".to_string(), "description".to_string()]; + + if include_entity_rank { + header.push(rank_description.to_string()); + } + + let attribute_cols = if let Some(first_entity) = selected_entities.first().cloned() { + first_entity + .attributes + .unwrap_or_default() + .keys() + .map(|s| s.clone()) + .collect::>() + } else { + Vec::new() + }; + + header.extend(attribute_cols.clone()); + current_context_text += &header.join(column_delimiter); + + let mut current_tokens = num_tokens_fn(¤t_context_text); + let mut records = HashMap::new(); + + for entity in selected_entities { + let mut new_context = vec![ + entity.short_id.clone().unwrap_or_default(), + entity.title.clone(), + entity.description.clone().unwrap_or_default(), + ]; + + records + .entry("id") + .or_insert_with(Vec::new) + .push(entity.short_id.unwrap_or_default()); + records.entry("entity").or_insert_with(Vec::new).push(entity.title); + records + .entry("description") + .or_insert_with(Vec::new) + .push(entity.description.unwrap_or_default()); + + if include_entity_rank { + new_context.push(entity.rank.unwrap_or(0).to_string()); + + records + .entry("rank") + .or_insert_with(Vec::new) + .push(entity.rank.map(|r| r.to_string()).unwrap_or_default()); + } + + for field in &attribute_cols { + let field_value = entity + .attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(); + new_context.push(field_value); + + records.entry(field).or_insert_with(Vec::new).push( + entity + .attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(), + ); + } + + let new_context_text = new_context.join(column_delimiter); + let new_tokens = num_tokens_fn(&new_context_text); + + if current_tokens + new_tokens > max_tokens { + break; + } + + current_context_text += &format!("\n{}", new_context_text); + current_tokens += new_tokens; + } + + let mut data_series = Vec::new(); + for (header, data_values) in records { + if header == "rank" { + let data_values = data_values + .iter() + .map(|v| v.parse::().unwrap_or(0.0)) + .collect::>(); + let series = Series::new(header, data_values); + data_series.push(series); + } else { + let series = Series::new(header, data_values); + data_series.push(series); + }; + } + + let record_df = if !data_series.is_empty() { + DataFrame::new(data_series)? + } else { + DataFrame::default() + }; + + Ok((current_context_text, record_df)) +} + +pub fn build_relationship_context( + selected_entities: &Vec, + relationships: &Vec, + num_tokens_fn: fn(&str) -> usize, + include_relationship_weight: bool, + max_tokens: usize, + top_k_relationships: usize, + relationship_ranking_attribute: &str, + column_delimiter: &str, + context_name: &str, +) -> anyhow::Result<(String, DataFrame)> { + // Filter relationships based on the criteria + let selected_relationships = _filter_relationships( + &selected_entities, + &relationships, + top_k_relationships, + relationship_ranking_attribute, + ); + + if selected_entities.is_empty() || selected_relationships.is_empty() { + return Ok((String::new(), DataFrame::default())); + } + + let mut current_context_text = format!("-----{}-----\n", context_name); + let mut header = vec![ + "id".to_string(), + "source".to_string(), + "target".to_string(), + "description".to_string(), + ]; + + if include_relationship_weight { + header.push("weight".to_string()); + } + + let attribute_cols = if let Some(first_rel) = selected_relationships.first().cloned() { + first_rel + .attributes + .unwrap_or_default() + .keys() + .map(|s| s.clone()) + .collect::>() + } else { + Vec::new() + }; + + let attribute_cols: Vec = attribute_cols.into_iter().filter(|col| !header.contains(col)).collect(); + header.extend(attribute_cols.clone()); + + current_context_text.push_str(&header.join(column_delimiter)); + current_context_text.push('\n'); + + let mut current_tokens = num_tokens_fn(¤t_context_text); + let mut records = HashMap::new(); + + for rel in selected_relationships { + let mut new_context = vec![ + rel.short_id.clone().unwrap_or_default(), + rel.source.clone(), + rel.target.clone(), + rel.description.clone().unwrap_or_default(), + ]; + + records + .entry("id") + .or_insert_with(Vec::new) + .push(rel.short_id.unwrap_or_default()); + records.entry("source").or_insert_with(Vec::new).push(rel.source); + records.entry("target").or_insert_with(Vec::new).push(rel.target); + records + .entry("description") + .or_insert_with(Vec::new) + .push(rel.description.unwrap_or_default()); + + if include_relationship_weight { + new_context.push(rel.weight.map_or(String::new(), |w| w.to_string())); + + records + .entry("weight") + .or_insert_with(Vec::new) + .push(rel.weight.map(|r| r.to_string()).unwrap_or_default()); + } + + for field in &attribute_cols { + let field_value = rel + .attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(); + new_context.push(field_value); + + records.entry(field).or_insert_with(Vec::new).push( + rel.attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(), + ); + } + + let mut new_context_text = new_context.join(column_delimiter); + new_context_text.push('\n'); + let new_tokens = num_tokens_fn(&new_context_text); + + if current_tokens + new_tokens > max_tokens { + break; + } + + current_context_text += new_context_text.as_str(); + current_tokens += new_tokens; + } + + let mut data_series = Vec::new(); + for (header, data_values) in records { + if header == "weight" { + let data_values = data_values + .iter() + .map(|v| v.parse::().unwrap_or(0.0)) + .collect::>(); + let series = Series::new(header, data_values); + data_series.push(series); + } else { + let series = Series::new(header, data_values); + data_series.push(series); + }; + } + + let record_df = if !data_series.is_empty() { + DataFrame::new(data_series)? + } else { + DataFrame::default() + }; + + Ok((current_context_text, record_df)) +} + +fn _filter_relationships( + selected_entities: &Vec, + relationships: &Vec, + top_k_relationships: usize, + relationship_ranking_attribute: &str, +) -> Vec { + // First priority: in-network relationships (i.e. relationships between selected entities) + let in_network_relationships = + get_in_network_relationships(selected_entities, relationships, relationship_ranking_attribute); + + // Second priority - out-of-network relationships + // (i.e. relationships between selected entities and other entities that are not within the selected entities) + let mut out_network_relationships = + get_out_network_relationships(selected_entities, relationships, relationship_ranking_attribute); + + if out_network_relationships.len() <= 1 { + return [in_network_relationships, out_network_relationships].concat(); + } + + // within out-of-network relationships, prioritize mutual relationships + // (i.e. relationships with out-network entities that are shared with multiple selected entities) + let selected_entity_names: HashSet = selected_entities.iter().map(|e| e.title.clone()).collect(); + + let out_network_source_names: Vec = out_network_relationships + .iter() + .filter(|r| !selected_entity_names.contains(&r.source)) + .map(|r| r.source.clone()) + .collect(); + + let out_network_target_names: Vec = out_network_relationships + .iter() + .filter(|r| !selected_entity_names.contains(&r.target)) + .map(|r| r.target.clone()) + .collect(); + + let out_network_entity_names: HashSet = out_network_source_names + .into_iter() + .chain(out_network_target_names.into_iter()) + .collect(); + + let mut out_network_entity_links: HashMap = HashMap::new(); + + for entity_name in out_network_entity_names { + let targets: HashSet = out_network_relationships + .iter() + .filter(|r| r.source == entity_name) + .map(|r| r.target.clone()) + .collect(); + + let sources: HashSet = out_network_relationships + .iter() + .filter(|r| r.target == entity_name) + .map(|r| r.source.clone()) + .collect(); + + out_network_entity_links.insert(entity_name, targets.union(&sources).count()); + } + + // sort out-network relationships by number of links and rank_attributes + for relationship in &mut out_network_relationships { + if relationship.attributes.is_none() { + relationship.attributes = Some(HashMap::new()); + } + + let links = if out_network_entity_links.contains_key(&relationship.source) { + *out_network_entity_links.get(&relationship.source).unwrap() + } else { + *out_network_entity_links.get(&relationship.target).unwrap() + }; + relationship + .attributes + .as_mut() + .unwrap() + .insert("links".to_string(), links.to_string()); + } + + // Sort by attributes[links] first, then by ranking_attribute + if relationship_ranking_attribute == "weight" { + out_network_relationships.sort_by(|a, b| { + let a_links = a + .attributes + .as_ref() + .and_then(|attrs| attrs.get("links")) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + let b_links = b + .attributes + .as_ref() + .and_then(|attrs| attrs.get("links")) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + + b_links + .partial_cmp(&a_links) + .unwrap_or(Ordering::Equal) + .then(b.weight.partial_cmp(&a.weight).unwrap_or(Ordering::Equal)) + }); + } else { + out_network_relationships.sort_by(|a, b| { + let a_links = a + .attributes + .as_ref() + .and_then(|attrs| attrs.get("links")) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + let b_links = b + .attributes + .as_ref() + .and_then(|attrs| attrs.get("links")) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + + let a_rank = a + .attributes + .as_ref() + .and_then(|attrs| attrs.get(relationship_ranking_attribute)) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0.0); + let b_rank = b + .attributes + .as_ref() + .and_then(|attrs| attrs.get(relationship_ranking_attribute)) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0.0); + + b_links + .partial_cmp(&a_links) + .unwrap_or(Ordering::Equal) + .then(b_rank.partial_cmp(&a_rank).unwrap_or(Ordering::Equal)) + }); + } + + let relationship_budget = top_k_relationships * selected_entities.len(); + out_network_relationships.truncate(relationship_budget); + + Vec::new() +} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs index 43d390fb6..c0e14f261 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs @@ -1,2 +1,4 @@ pub mod community_context; pub mod context_builder; +pub mod entity_extraction; +pub mod local_context; diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/community_reports.rs b/shinkai-libs/shinkai-graphrag/src/retrieval/community_reports.rs index 37efd2e10..504f93df2 100644 --- a/shinkai-libs/shinkai-graphrag/src/retrieval/community_reports.rs +++ b/shinkai-libs/shinkai-graphrag/src/retrieval/community_reports.rs @@ -1,6 +1,6 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; -use polars::frame::DataFrame; +use polars::{frame::DataFrame, prelude::NamedFrom, series::Series}; use crate::models::{CommunityReport, Entity}; @@ -36,5 +36,71 @@ pub fn to_community_report_dataframe( return Ok(DataFrame::default()); } - Ok(DataFrame::default()) + let mut header = vec!["id".to_string(), "title".to_string()]; + let attribute_cols: Vec = reports[0] + .attributes + .as_ref() + .map(|attrs| attrs.keys().filter(|&col| !header.contains(&col)).cloned().collect()) + .unwrap_or_default(); + + header.extend(attribute_cols.iter().cloned()); + header.push(if use_community_summary { "summary" } else { "content" }.to_string()); + if include_community_rank { + header.push("rank".to_string()); + } + + let mut records = HashMap::new(); + for report in reports { + records + .entry("id") + .or_insert_with(Vec::new) + .push(report.short_id.unwrap_or_default()); + records.entry("title").or_insert_with(Vec::new).push(report.title); + + for field in &attribute_cols { + records.entry(field).or_insert_with(Vec::new).push( + report + .attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(), + ); + } + + if use_community_summary { + records.entry("summary").or_insert_with(Vec::new).push(report.summary); + } else { + records + .entry("content") + .or_insert_with(Vec::new) + .push(report.full_content); + } + + if include_community_rank { + records + .entry("rank") + .or_insert_with(Vec::new) + .push(report.rank.map(|r| r.to_string()).unwrap_or_default()); + } + } + + let mut data_series = Vec::new(); + for (header, data_values) in records { + if header == "rank" { + let data_values = data_values + .iter() + .map(|v| v.parse::().unwrap_or(0.0)) + .collect::>(); + let series = Series::new(header, data_values); + data_series.push(series); + } else { + let series = Series::new(header, data_values); + data_series.push(series); + }; + } + + let record_df = DataFrame::new(data_series)?; + + Ok(record_df) } diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/entities.rs b/shinkai-libs/shinkai-graphrag/src/retrieval/entities.rs new file mode 100644 index 000000000..fa3fb102c --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/retrieval/entities.rs @@ -0,0 +1,55 @@ +use uuid::Uuid; + +use crate::models::Entity; + +pub fn get_entity_by_key(entities: &Vec, key: &str, value: &str) -> Option { + for entity in entities { + match key { + "id" => { + if entity.id == value + || is_valid_uuid(value) && entity.id == Uuid::parse_str(value).unwrap().to_string().replace("-", "") + { + return Some(entity.clone()); + } + } + "short_id" => { + if entity.short_id.as_ref().unwrap_or(&"".to_string()) == value + || is_valid_uuid(value) + && entity.short_id.as_ref().unwrap_or(&"".to_string()) + == Uuid::parse_str(value).unwrap().to_string().replace("-", "").as_str() + { + return Some(entity.clone()); + } + } + "title" => { + if entity.title == value { + return Some(entity.clone()); + } + } + "entity_type" => { + if entity.entity_type.as_ref().unwrap_or(&"".to_string()) == value { + return Some(entity.clone()); + } + } + "description" => { + if entity.description.as_ref().unwrap_or(&"".to_string()) == value { + return Some(entity.clone()); + } + } + _ => {} + } + } + None +} + +pub fn get_entity_by_name(entities: &Vec, entity_name: &str) -> Vec { + entities + .iter() + .filter(|entity| entity.title == entity_name) + .cloned() + .collect() +} + +pub fn is_valid_uuid(value: &str) -> bool { + Uuid::parse_str(value).is_ok() +} diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/entity_extraction.rs b/shinkai-libs/shinkai-graphrag/src/retrieval/entity_extraction.rs deleted file mode 100644 index 9bb784569..000000000 --- a/shinkai-libs/shinkai-graphrag/src/retrieval/entity_extraction.rs +++ /dev/null @@ -1,100 +0,0 @@ -use std::collections::HashSet; -use uuid::Uuid; - -use crate::{llm::llm::BaseTextEmbedding, models::Entity, vector_stores::vector_store::VectorStore}; - -pub fn map_query_to_entities( - query: &str, - text_embedding_vectorstore: &Box, - text_embedder: &Box, - all_entities: &Vec, - embedding_vectorstore_key: &str, - include_entity_names: Option>, - exclude_entity_names: Option>, - k: usize, - oversample_scaler: usize, -) -> Vec { - let include_entity_names = include_entity_names.unwrap_or_else(Vec::new); - let exclude_entity_names: HashSet = exclude_entity_names.unwrap_or_else(Vec::new).into_iter().collect(); - let mut matched_entities = Vec::new(); - - if !query.is_empty() { - let search_results = text_embedding_vectorstore.similarity_search_by_text( - query, - &|t| text_embedder.embed(t), - k * oversample_scaler, - ); - - for result in search_results { - if let Some(matched) = get_entity_by_key(all_entities, &embedding_vectorstore_key, &result.document.id) { - matched_entities.push(matched); - } - } - } else { - let mut all_entities = all_entities.clone(); - all_entities.sort_by(|a, b| b.rank.unwrap_or(0).cmp(&a.rank.unwrap_or(0))); - matched_entities = all_entities.iter().take(k).cloned().collect(); - } - - matched_entities.retain(|entity| !exclude_entity_names.contains(&entity.title)); - - let mut included_entities = Vec::new(); - for entity_name in include_entity_names { - included_entities.extend(get_entity_by_name(all_entities, &entity_name)); - } - - included_entities.extend(matched_entities); - included_entities -} - -pub fn get_entity_by_key(entities: &Vec, key: &str, value: &str) -> Option { - for entity in entities { - match key { - "id" => { - if entity.id == value - || is_valid_uuid(value) && entity.id == Uuid::parse_str(value).unwrap().to_string().replace("-", "") - { - return Some(entity.clone()); - } - } - "short_id" => { - if entity.short_id.as_ref().unwrap_or(&"".to_string()) == value - || is_valid_uuid(value) - && entity.short_id.as_ref().unwrap_or(&"".to_string()) - == Uuid::parse_str(value).unwrap().to_string().replace("-", "").as_str() - { - return Some(entity.clone()); - } - } - "title" => { - if entity.title == value { - return Some(entity.clone()); - } - } - "entity_type" => { - if entity.entity_type.as_ref().unwrap_or(&"".to_string()) == value { - return Some(entity.clone()); - } - } - "description" => { - if entity.description.as_ref().unwrap_or(&"".to_string()) == value { - return Some(entity.clone()); - } - } - _ => {} - } - } - None -} - -pub fn get_entity_by_name(entities: &Vec, entity_name: &str) -> Vec { - entities - .iter() - .filter(|entity| entity.title == entity_name) - .cloned() - .collect() -} - -pub fn is_valid_uuid(value: &str) -> bool { - Uuid::parse_str(value).is_ok() -} diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/mod.rs b/shinkai-libs/shinkai-graphrag/src/retrieval/mod.rs index 12a881a40..57119bfe4 100644 --- a/shinkai-libs/shinkai-graphrag/src/retrieval/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/retrieval/mod.rs @@ -1,2 +1,3 @@ pub mod community_reports; -pub mod entity_extraction; +pub mod entities; +pub mod relationships; diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/relationships.rs b/shinkai-libs/shinkai-graphrag/src/retrieval/relationships.rs new file mode 100644 index 000000000..8d2683184 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/retrieval/relationships.rs @@ -0,0 +1,141 @@ +use std::{cmp::Ordering, collections::HashMap}; + +use crate::models::{Entity, Relationship}; + +pub fn get_in_network_relationships( + selected_entities: &Vec, + relationships: &Vec, + ranking_attribute: &str, +) -> Vec { + let selected_entity_names: Vec = selected_entities.iter().map(|entity| entity.title.clone()).collect(); + + let selected_relationships: Vec = relationships + .clone() + .into_iter() + .filter(|relationship| { + selected_entity_names.contains(&relationship.source) && selected_entity_names.contains(&relationship.target) + }) + .collect(); + + if selected_relationships.len() <= 1 { + return selected_relationships; + } + + // Sort by ranking attribute + sort_relationships_by_ranking_attribute(selected_relationships, selected_entities.to_vec(), ranking_attribute) +} + +pub fn get_out_network_relationships( + selected_entities: &Vec, + relationships: &Vec, + ranking_attribute: &str, +) -> Vec { + let selected_entity_names: Vec = selected_entities.iter().map(|e| e.title.clone()).collect(); + + let source_relationships: Vec = relationships + .iter() + .filter(|r| selected_entity_names.contains(&r.source) && !selected_entity_names.contains(&r.target)) + .cloned() + .collect(); + + let target_relationships: Vec = relationships + .iter() + .filter(|r| selected_entity_names.contains(&r.target) && !selected_entity_names.contains(&r.source)) + .cloned() + .collect(); + + let selected_relationships = [source_relationships, target_relationships].concat(); + + sort_relationships_by_ranking_attribute(selected_relationships, selected_entities.to_vec(), ranking_attribute) +} + +pub fn sort_relationships_by_ranking_attribute( + relationships: Vec, + entities: Vec, + ranking_attribute: &str, +) -> Vec { + if relationships.is_empty() { + return relationships; + } + + let mut relationships = relationships; + + let attribute_names: Vec = if let Some(attributes) = &relationships[0].attributes { + attributes.keys().cloned().collect() + } else { + Vec::new() + }; + + if attribute_names.contains(&ranking_attribute.to_string()) { + relationships.sort_by(|a, b| { + let a_rank = a + .attributes + .as_ref() + .and_then(|attrs| attrs.get(ranking_attribute)) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + let b_rank = b + .attributes + .as_ref() + .and_then(|attrs| attrs.get(ranking_attribute)) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + b_rank.cmp(&a_rank) + }); + } else if ranking_attribute == "weight" { + relationships.sort_by(|a, b| { + let a_weight = a.weight.unwrap_or(0.0); + let b_weight = b.weight.unwrap_or(0.0); + b_weight.partial_cmp(&a_weight).unwrap_or(Ordering::Equal) + }); + } else { + relationships = calculate_relationship_combined_rank(relationships, entities, ranking_attribute); + relationships.sort_by(|a, b| { + let a_rank = a + .attributes + .as_ref() + .and_then(|attrs| attrs.get(ranking_attribute)) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + let b_rank = b + .attributes + .as_ref() + .and_then(|attrs| attrs.get(ranking_attribute)) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + b_rank.cmp(&a_rank) + }); + } + + relationships +} + +pub fn calculate_relationship_combined_rank( + relationships: Vec, + entities: Vec, + ranking_attribute: &str, +) -> Vec { + let mut relationships = relationships; + let entity_mappings: HashMap<_, _> = entities.iter().map(|e| (e.title.clone(), e)).collect(); + + for relationship in relationships.iter_mut() { + if relationship.attributes.is_none() { + relationship.attributes = Some(HashMap::new()); + } + + let source_rank = entity_mappings + .get(&relationship.source) + .and_then(|e| e.rank) + .unwrap_or(0); + let target_rank = entity_mappings + .get(&relationship.target) + .and_then(|e| e.rank) + .unwrap_or(0); + + if let Some(attributes) = &mut relationship.attributes { + attributes.insert(ranking_attribute.to_string(), (source_rank + target_rank).to_string()); + } + } + + relationships +} diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs index c5d71309e..ae8764829 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs @@ -7,10 +7,14 @@ use polars::{ }; use crate::{ - context_builder::community_context::CommunityContext, + context_builder::{ + community_context::CommunityContext, + entity_extraction::map_query_to_entities, + local_context::{build_entity_context, build_relationship_context}, + }, llm::llm::BaseTextEmbedding, models::{CommunityReport, Entity, Relationship, TextUnit}, - retrieval::{community_reports::get_candidate_communities, entity_extraction::map_query_to_entities}, + retrieval::community_reports::get_candidate_communities, vector_stores::vector_store::VectorStore, }; @@ -166,7 +170,44 @@ impl LocalSearchMixedContext { 2, ); + let mut final_context = Vec::new(); + let mut final_context_data = HashMap::new(); + let community_tokens = std::cmp::max((max_tokens as f32 * community_prop) as usize, 0); + let (community_context, community_context_data) = self._build_community_context( + selected_entities.clone(), + community_tokens, + use_community_summary, + &column_delimiter, + include_community_rank, + min_community_rank, + return_candidate_context, + &community_context_name, + )?; + + if !community_context.trim().is_empty() { + final_context.push(community_context); + final_context_data.extend(community_context_data); + } + + let local_prop = 1 as f32 - community_prop - text_unit_prop; + let local_tokens = std::cmp::max((max_tokens as f32 * local_prop) as usize, 0); + let (local_context, local_context_data) = self._build_local_context( + selected_entities.clone(), + local_tokens, + include_entity_rank, + &rank_description, + include_relationship_weight, + top_k_relationships, + &relationship_ranking_attribute, + return_candidate_context, + &column_delimiter, + )?; + + if !local_context.trim().is_empty() { + final_context.push(local_context); + final_context_data.extend(local_context_data); + } let context_text = String::new(); let context_records = HashMap::new(); @@ -309,4 +350,71 @@ impl LocalSearchMixedContext { Ok((context_text_result, context_data)) } + + fn _build_local_context( + &self, + selected_entities: Vec, + max_tokens: usize, + include_entity_rank: bool, + rank_description: &str, + include_relationship_weight: bool, + top_k_relationships: usize, + relationship_ranking_attribute: &str, + return_candidate_context: bool, + column_delimiter: &str, + ) -> anyhow::Result<(String, HashMap)> { + let (entity_context, entity_context_data) = build_entity_context( + selected_entities.clone(), + self.num_tokens_fn, + max_tokens, + include_entity_rank, + rank_description, + column_delimiter, + "Entities", + )?; + + let entity_tokens = (self.num_tokens_fn)(&entity_context); + + let mut added_entities = Vec::new(); + let mut final_context = Vec::new(); + let mut final_context_data = HashMap::new(); + + for entity in selected_entities { + let mut current_context = Vec::new(); + let mut current_context_data = HashMap::new(); + added_entities.push(entity); + + let (relationship_context, relationship_context_data) = build_relationship_context( + &added_entities, + &self.relationships.values().cloned().collect(), + self.num_tokens_fn, + include_relationship_weight, + max_tokens, + top_k_relationships, + relationship_ranking_attribute, + column_delimiter, + "Relationships", + )?; + + current_context.push(relationship_context.clone()); + current_context_data.insert("relationships", relationship_context_data); + + let total_tokens = entity_tokens + (self.num_tokens_fn)(&relationship_context); + + if total_tokens > max_tokens { + eprintln!("Reached token limit - reverting to previous context state"); + break; + } + + final_context = current_context; + final_context_data = current_context_data; + } + + let mut final_context_text = entity_context.to_string(); + final_context_text.push_str("\n\n"); + final_context_text.push_str(&final_context.join("\n\n")); + final_context_data.insert("entities", entity_context_data.clone()); + + Ok(("".to_string(), HashMap::new())) + } } From af79dc8e4d754bd0e8d86e8486632426ea95a462 Mon Sep 17 00:00:00 2001 From: benolt Date: Fri, 23 Aug 2024 13:54:04 +0200 Subject: [PATCH 14/23] build local context, text unit context, relationships --- .../src/context_builder/local_context.rs | 33 ++- .../src/context_builder/mod.rs | 1 + .../src/context_builder/source_context.rs | 144 ++++++++++++ .../src/retrieval/entities.rs | 91 ++++++++ .../shinkai-graphrag/src/retrieval/mod.rs | 1 + .../src/retrieval/relationships.rs | 125 +++++++++++ .../src/retrieval/text_units.rs | 84 +++++++ .../src/search/local_search/mixed_context.rs | 212 ++++++++++++++++-- 8 files changed, 674 insertions(+), 17 deletions(-) create mode 100644 shinkai-libs/shinkai-graphrag/src/context_builder/source_context.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/retrieval/text_units.rs diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs index 77d4e5da9..4d31cf61a 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs @@ -7,7 +7,13 @@ use polars::{frame::DataFrame, prelude::NamedFrom, series::Series}; use crate::{ models::{Entity, Relationship}, - retrieval::relationships::{get_in_network_relationships, get_out_network_relationships}, + retrieval::{ + entities::to_entity_dataframe, + relationships::{ + get_candidate_relationships, get_entities_from_relationships, get_in_network_relationships, + get_out_network_relationships, to_relationship_dataframe, + }, + }, }; pub fn build_entity_context( @@ -401,3 +407,28 @@ fn _filter_relationships( Vec::new() } + +pub fn get_candidate_context( + selected_entities: &Vec, + entities: &Vec, + relationships: &Vec, + include_entity_rank: bool, + entity_rank_description: &str, + include_relationship_weight: bool, +) -> anyhow::Result> { + let mut candidate_context = HashMap::new(); + + let candidate_relationships = get_candidate_relationships(selected_entities, relationships); + candidate_context.insert( + "relationships".to_string(), + to_relationship_dataframe(&candidate_relationships, include_relationship_weight)?, + ); + + let candidate_entities = get_entities_from_relationships(&candidate_relationships, entities); + candidate_context.insert( + "entities".to_string(), + to_entity_dataframe(&candidate_entities, include_entity_rank, entity_rank_description)?, + ); + + Ok(candidate_context) +} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs index c0e14f261..f781afd72 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs @@ -2,3 +2,4 @@ pub mod community_context; pub mod context_builder; pub mod entity_extraction; pub mod local_context; +pub mod source_context; diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/source_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/source_context.rs new file mode 100644 index 000000000..0746cf94b --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/source_context.rs @@ -0,0 +1,144 @@ +use std::collections::HashMap; + +use polars::frame::DataFrame; +use polars::prelude::NamedFrom; +use polars::series::Series; +use rand::prelude::SliceRandom; +use rand::{rngs::StdRng, SeedableRng}; + +use crate::models::{Entity, Relationship, TextUnit}; + +pub fn build_text_unit_context( + text_units: Vec, + num_tokens_fn: fn(&str) -> usize, + column_delimiter: &str, + shuffle_data: bool, + max_tokens: usize, + context_name: &str, + random_state: u64, +) -> anyhow::Result<(String, HashMap)> { + if text_units.is_empty() { + return Ok((String::new(), HashMap::new())); + } + + let mut text_units = text_units; + + if shuffle_data { + let mut rng = StdRng::seed_from_u64(random_state); + text_units.shuffle(&mut rng); + } + + let mut current_context_text = format!("-----{}-----\n", context_name); + let mut header = vec!["id".to_string(), "text".to_string()]; + + let attribute_cols = if let Some(text_unit) = text_units.first().cloned() { + text_unit + .attributes + .unwrap_or_default() + .keys() + .map(|s| s.clone()) + .collect::>() + } else { + Vec::new() + }; + + let attribute_cols: Vec = attribute_cols.into_iter().filter(|col| !header.contains(col)).collect(); + header.extend(attribute_cols.clone()); + current_context_text += &header.join(column_delimiter); + + let mut current_tokens = num_tokens_fn(¤t_context_text); + let mut records = HashMap::new(); + + for unit in text_units { + let mut new_context = vec![unit.short_id.clone().unwrap_or_default(), unit.text.clone()]; + + records + .entry("id") + .or_insert_with(Vec::new) + .push(unit.short_id.unwrap_or_default()); + records.entry("text").or_insert_with(Vec::new).push(unit.text); + + for field in &attribute_cols { + let field_value = unit + .attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(); + new_context.push(field_value); + + records.entry(field).or_insert_with(Vec::new).push( + unit.attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(), + ); + } + + let new_context_text = new_context.join(column_delimiter); + let new_tokens = num_tokens_fn(&new_context_text); + + if current_tokens + new_tokens > max_tokens { + break; + } + + current_context_text += &format!("\n{}", new_context_text); + current_tokens += new_tokens; + } + + let mut data_series = Vec::new(); + for (header, data_values) in records { + let series = Series::new(header, data_values); + data_series.push(series); + } + + let record_df = if !data_series.is_empty() { + DataFrame::new(data_series)? + } else { + DataFrame::default() + }; + + Ok(( + current_context_text, + HashMap::from([(context_name.to_lowercase(), record_df)]), + )) +} + +pub fn count_relationships( + text_unit: &TextUnit, + entity: &Entity, + relationships: &HashMap, +) -> usize { + let matching_relationships: Vec<&Relationship> = if text_unit.relationship_ids.is_none() { + let entity_relationships: Vec<&Relationship> = relationships + .values() + .filter(|rel| rel.source == entity.title || rel.target == entity.title) + .collect(); + + let entity_relationships: Vec<&Relationship> = entity_relationships + .into_iter() + .filter(|rel| rel.text_unit_ids.is_some()) + .collect(); + + entity_relationships + .into_iter() + .filter(|rel| rel.text_unit_ids.as_ref().unwrap().contains(&text_unit.id)) + .collect() + } else { + let text_unit_relationships: Vec<&Relationship> = text_unit + .relationship_ids + .as_ref() + .unwrap() + .iter() + .filter_map(|rel_id| relationships.get(rel_id)) + .collect(); + + text_unit_relationships + .into_iter() + .filter(|rel| rel.source == entity.title || rel.target == entity.title) + .collect() + }; + + matching_relationships.len() +} diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/entities.rs b/shinkai-libs/shinkai-graphrag/src/retrieval/entities.rs index fa3fb102c..bc62e48f1 100644 --- a/shinkai-libs/shinkai-graphrag/src/retrieval/entities.rs +++ b/shinkai-libs/shinkai-graphrag/src/retrieval/entities.rs @@ -1,3 +1,6 @@ +use std::collections::HashMap; + +use polars::{frame::DataFrame, prelude::NamedFrom, series::Series}; use uuid::Uuid; use crate::models::Entity; @@ -50,6 +53,94 @@ pub fn get_entity_by_name(entities: &Vec, entity_name: &str) -> Vec, + include_entity_rank: bool, + rank_description: &str, +) -> anyhow::Result { + if entities.is_empty() { + return Ok(DataFrame::default()); + } + + let mut header = vec!["id".to_string(), "entity".to_string(), "description".to_string()]; + + if include_entity_rank { + header.push(rank_description.to_string()); + } + + let attribute_cols = if let Some(first_entity) = entities.first().cloned() { + first_entity + .attributes + .unwrap_or_default() + .keys() + .map(|s| s.clone()) + .collect::>() + } else { + Vec::new() + }; + + let attribute_cols: Vec = attribute_cols.into_iter().filter(|col| !header.contains(col)).collect(); + header.extend(attribute_cols.clone()); + + let mut records = HashMap::new(); + + for entity in entities { + records + .entry("id") + .or_insert_with(Vec::new) + .push(entity.short_id.clone().unwrap_or_default()); + records + .entry("entity") + .or_insert_with(Vec::new) + .push(entity.title.clone()); + records + .entry("description") + .or_insert_with(Vec::new) + .push(entity.description.clone().unwrap_or_default()); + + if include_entity_rank { + records + .entry("rank") + .or_insert_with(Vec::new) + .push(entity.rank.map(|r| r.to_string()).unwrap_or_default()); + } + + for field in &attribute_cols { + records.entry(field).or_insert_with(Vec::new).push( + entity + .attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(), + ); + } + } + + let mut data_series = Vec::new(); + for (header, data_values) in records { + if header == "rank" { + let data_values = data_values + .iter() + .map(|v| v.parse::().unwrap_or(0.0)) + .collect::>(); + let series = Series::new(header, data_values); + data_series.push(series); + } else { + let series = Series::new(header, data_values); + data_series.push(series); + }; + } + + let record_df = if !data_series.is_empty() { + DataFrame::new(data_series)? + } else { + DataFrame::default() + }; + + Ok(record_df) +} + pub fn is_valid_uuid(value: &str) -> bool { Uuid::parse_str(value).is_ok() } diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/mod.rs b/shinkai-libs/shinkai-graphrag/src/retrieval/mod.rs index 57119bfe4..b56ec376b 100644 --- a/shinkai-libs/shinkai-graphrag/src/retrieval/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/retrieval/mod.rs @@ -1,3 +1,4 @@ pub mod community_reports; pub mod entities; pub mod relationships; +pub mod text_units; diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/relationships.rs b/shinkai-libs/shinkai-graphrag/src/retrieval/relationships.rs index 8d2683184..83353dd86 100644 --- a/shinkai-libs/shinkai-graphrag/src/retrieval/relationships.rs +++ b/shinkai-libs/shinkai-graphrag/src/retrieval/relationships.rs @@ -1,5 +1,7 @@ use std::{cmp::Ordering, collections::HashMap}; +use polars::{frame::DataFrame, prelude::NamedFrom, series::Series}; + use crate::models::{Entity, Relationship}; pub fn get_in_network_relationships( @@ -49,6 +51,34 @@ pub fn get_out_network_relationships( sort_relationships_by_ranking_attribute(selected_relationships, selected_entities.to_vec(), ranking_attribute) } +pub fn get_candidate_relationships( + selected_entities: &Vec, + relationships: &Vec, +) -> Vec { + let selected_entity_names: Vec = selected_entities.iter().map(|entity| entity.title.clone()).collect(); + + relationships + .iter() + .cloned() + .filter(|relationship| { + selected_entity_names.contains(&relationship.source) || selected_entity_names.contains(&relationship.target) + }) + .collect() +} + +pub fn get_entities_from_relationships(relationships: &Vec, entities: &Vec) -> Vec { + let selected_entity_names: Vec = relationships + .iter() + .flat_map(|relationship| vec![relationship.source.clone(), relationship.target.clone()]) + .collect(); + + entities + .iter() + .cloned() + .filter(|entity| selected_entity_names.contains(&entity.title)) + .collect() +} + pub fn sort_relationships_by_ranking_attribute( relationships: Vec, entities: Vec, @@ -139,3 +169,98 @@ pub fn calculate_relationship_combined_rank( relationships } + +pub fn to_relationship_dataframe( + relationships: &Vec, + include_relationship_weight: bool, +) -> anyhow::Result { + if relationships.is_empty() { + return Ok(DataFrame::default()); + } + + let mut header = vec![ + "id".to_string(), + "source".to_string(), + "target".to_string(), + "description".to_string(), + ]; + + if include_relationship_weight { + header.push("weight".to_string()); + } + + let attribute_cols = if let Some(relationship) = relationships.first().cloned() { + relationship + .attributes + .unwrap_or_default() + .keys() + .map(|s| s.clone()) + .collect::>() + } else { + Vec::new() + }; + + let attribute_cols: Vec = attribute_cols.into_iter().filter(|col| !header.contains(col)).collect(); + header.extend(attribute_cols.clone()); + + let mut records = HashMap::new(); + + for rel in relationships { + records + .entry("id") + .or_insert_with(Vec::new) + .push(rel.short_id.clone().unwrap_or_default()); + records + .entry("source") + .or_insert_with(Vec::new) + .push(rel.source.clone()); + records + .entry("target") + .or_insert_with(Vec::new) + .push(rel.target.clone()); + records + .entry("description") + .or_insert_with(Vec::new) + .push(rel.description.clone().unwrap_or_default()); + + if include_relationship_weight { + records + .entry("weight") + .or_insert_with(Vec::new) + .push(rel.weight.map(|r| r.to_string()).unwrap_or_default()); + } + + for field in &attribute_cols { + records.entry(field).or_insert_with(Vec::new).push( + rel.attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(), + ); + } + } + + let mut data_series = Vec::new(); + for (header, data_values) in records { + if header == "weight" { + let data_values = data_values + .iter() + .map(|v| v.parse::().unwrap_or(0.0)) + .collect::>(); + let series = Series::new(header, data_values); + data_series.push(series); + } else { + let series = Series::new(header, data_values); + data_series.push(series); + }; + } + + let record_df = if !data_series.is_empty() { + DataFrame::new(data_series)? + } else { + DataFrame::default() + }; + + Ok(record_df) +} diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/text_units.rs b/shinkai-libs/shinkai-graphrag/src/retrieval/text_units.rs new file mode 100644 index 000000000..ad839b88e --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/retrieval/text_units.rs @@ -0,0 +1,84 @@ +use std::collections::{HashMap, HashSet}; + +use polars::{frame::DataFrame, prelude::NamedFrom, series::Series}; + +use crate::models::{Entity, TextUnit}; + +pub fn get_candidate_text_units( + selected_entities: &Vec, + text_units: &Vec, +) -> anyhow::Result { + let mut selected_text_ids: HashSet = HashSet::new(); + + for entity in selected_entities { + if let Some(ids) = &entity.text_unit_ids { + for id in ids { + selected_text_ids.insert(id.to_string()); + } + } + } + + let selected_text_units: Vec = text_units + .iter() + .cloned() + .filter(|unit| selected_text_ids.contains(&unit.id)) + .collect(); + + to_text_unit_dataframe(selected_text_units) +} + +pub fn to_text_unit_dataframe(text_units: Vec) -> anyhow::Result { + if text_units.is_empty() { + return Ok(DataFrame::default()); + } + + let mut header = vec!["id".to_string(), "text".to_string()]; + + let attribute_cols = if let Some(text_unit) = text_units.first().cloned() { + text_unit + .attributes + .unwrap_or_default() + .keys() + .map(|s| s.clone()) + .collect::>() + } else { + Vec::new() + }; + + let attribute_cols: Vec = attribute_cols.into_iter().filter(|col| !header.contains(col)).collect(); + header.extend(attribute_cols.clone()); + + let mut records = HashMap::new(); + + for unit in text_units { + records + .entry("id") + .or_insert_with(Vec::new) + .push(unit.short_id.clone().unwrap_or_default()); + records.entry("text").or_insert_with(Vec::new).push(unit.text.clone()); + + for field in &attribute_cols { + records.entry(field).or_insert_with(Vec::new).push( + unit.attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(), + ); + } + } + + let mut data_series = Vec::new(); + for (header, data_values) in records { + let series = Series::new(header, data_values); + data_series.push(series); + } + + let record_df = if !data_series.is_empty() { + DataFrame::new(data_series)? + } else { + DataFrame::default() + }; + + Ok(record_df) +} diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs index ae8764829..2fd3864a2 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs @@ -10,11 +10,12 @@ use crate::{ context_builder::{ community_context::CommunityContext, entity_extraction::map_query_to_entities, - local_context::{build_entity_context, build_relationship_context}, + local_context::{build_entity_context, build_relationship_context, get_candidate_context}, + source_context::{build_text_unit_context, count_relationships}, }, llm::llm::BaseTextEmbedding, models::{CommunityReport, Entity, Relationship, TextUnit}, - retrieval::community_reports::get_candidate_communities, + retrieval::{community_reports::get_candidate_communities, text_units::get_candidate_text_units}, vector_stores::vector_store::VectorStore, }; @@ -173,6 +174,7 @@ impl LocalSearchMixedContext { let mut final_context = Vec::new(); let mut final_context_data = HashMap::new(); + // build community context let community_tokens = std::cmp::max((max_tokens as f32 * community_prop) as usize, 0); let (community_context, community_context_data) = self._build_community_context( selected_entities.clone(), @@ -190,6 +192,7 @@ impl LocalSearchMixedContext { final_context_data.extend(community_context_data); } + // build local (i.e. entity-relationship-covariate) context let local_prop = 1 as f32 - community_prop - text_unit_prop; let local_tokens = std::cmp::max((max_tokens as f32 * local_prop) as usize, 0); let (local_context, local_context_data) = self._build_local_context( @@ -209,9 +212,22 @@ impl LocalSearchMixedContext { final_context_data.extend(local_context_data); } - let context_text = String::new(); - let context_records = HashMap::new(); - Ok((context_text, context_records)) + // build text unit context + let text_unit_tokens = std::cmp::max((max_tokens as f32 * text_unit_prop) as usize, 0); + let (text_unit_context, text_unit_context_data) = self._build_text_unit_context( + selected_entities.clone(), + text_unit_tokens, + return_candidate_context, + "|", + "Sources", + )?; + + if !text_unit_context.trim().is_empty() { + final_context.push(text_unit_context); + final_context_data.extend(text_unit_context_data); + } + + Ok((final_context.join("\n\n"), final_context_data)) } fn _build_community_context( @@ -320,9 +336,7 @@ impl LocalSearchMixedContext { let context_key = context_name.to_lowercase(); if !context_data.contains_key(&context_key) { let mut new_data = candidate_context_data.clone(); - new_data - .with_column(Series::new("in_context", vec![false; candidate_context_data.height()])) - .unwrap(); + new_data.with_column(Series::new("in_context", vec![false; candidate_context_data.height()]))?; context_data.insert(context_key.to_string(), new_data); } else { let existing_data = context_data.get(&context_key).unwrap(); @@ -340,9 +354,7 @@ impl LocalSearchMixedContext { context_data.insert(context_key.to_string(), new_data); } else { let mut existing_data = existing_data.clone(); - existing_data - .with_column(Series::new("in_context", vec![true; existing_data.height()])) - .unwrap(); + existing_data.with_column(Series::new("in_context", vec![true; existing_data.height()]))?; context_data.insert(context_key.to_string(), existing_data); } } @@ -379,10 +391,10 @@ impl LocalSearchMixedContext { let mut final_context = Vec::new(); let mut final_context_data = HashMap::new(); - for entity in selected_entities { + for entity in &selected_entities { let mut current_context = Vec::new(); let mut current_context_data = HashMap::new(); - added_entities.push(entity); + added_entities.push(entity.clone()); let (relationship_context, relationship_context_data) = build_relationship_context( &added_entities, @@ -397,7 +409,7 @@ impl LocalSearchMixedContext { )?; current_context.push(relationship_context.clone()); - current_context_data.insert("relationships", relationship_context_data); + current_context_data.insert("relationships".to_string(), relationship_context_data); let total_tokens = entity_tokens + (self.num_tokens_fn)(&relationship_context); @@ -413,8 +425,176 @@ impl LocalSearchMixedContext { let mut final_context_text = entity_context.to_string(); final_context_text.push_str("\n\n"); final_context_text.push_str(&final_context.join("\n\n")); - final_context_data.insert("entities", entity_context_data.clone()); + final_context_data.insert("entities".to_string(), entity_context_data.clone()); + + if return_candidate_context { + let entities = self.entities.values().cloned().collect(); + let relationships = self.relationships.values().cloned().collect(); + + let candidate_context_data = get_candidate_context( + &selected_entities, + &entities, + &relationships, + include_entity_rank, + rank_description, + include_relationship_weight, + )?; + + for (key, candidate_df) in candidate_context_data { + if !final_context_data.contains_key(&key) { + final_context_data.insert(key.clone(), candidate_df); + } else { + let in_context_df = final_context_data.get_mut(&key).unwrap(); + + if in_context_df.get_column_names().contains(&"id".to_string().as_str()) + && candidate_df.get_column_names().contains(&"id".to_string().as_str()) + { + let context_ids = in_context_df.column("id")?; + let candidate_ids = candidate_df.column("id")?; + let mut new_data = candidate_df.clone(); + let in_context = is_in(candidate_ids, context_ids)?; + let in_context = Series::new("in_context", in_context); + new_data.with_column(in_context)?; + final_context_data.insert(key.clone(), new_data); + } else { + in_context_df.with_column(Series::new("in_context", vec![true; in_context_df.height()]))?; + } + } + } + } else { + for (_key, context_df) in final_context_data.iter_mut() { + context_df.with_column(Series::new("in_context", vec![true; context_df.height()]))?; + } + } + + Ok((final_context_text, final_context_data)) + } + + fn _build_text_unit_context( + &self, + selected_entities: Vec, + max_tokens: usize, + return_candidate_context: bool, + column_delimiter: &str, + context_name: &str, + ) -> anyhow::Result<(String, HashMap)> { + if selected_entities.is_empty() || self.text_units.is_empty() { + return Ok((String::new(), HashMap::new())); + } + + let mut selected_text_units: Vec = Vec::new(); + + for (index, entity) in selected_entities.iter().enumerate() { + if let Some(text_unit_ids) = &entity.text_unit_ids { + for text_id in text_unit_ids { + if !selected_text_units.iter().any(|unit| &unit.id == text_id) + && self.text_units.contains_key(text_id) + { + let mut selected_unit = self.text_units[text_id].clone(); + let num_relationships = count_relationships(&selected_unit, entity, &self.relationships); + selected_unit + .attributes + .as_mut() + .unwrap_or(&mut HashMap::new()) + .insert("entity_order".to_string(), index.to_string()); + selected_unit + .attributes + .as_mut() + .unwrap_or(&mut HashMap::new()) + .insert("num_relationships".to_string(), num_relationships.to_string()); + selected_text_units.push(selected_unit); + } + } + } + } + + selected_text_units.sort_by(|a, b| { + let a_order = a + .attributes + .as_ref() + .unwrap() + .get("entity_order") + .unwrap() + .parse::() + .unwrap(); + let b_order = b + .attributes + .as_ref() + .unwrap() + .get("entity_order") + .unwrap() + .parse::() + .unwrap(); + + let a_relationships = a + .attributes + .as_ref() + .unwrap() + .get("num_relationships") + .unwrap() + .parse::() + .unwrap(); + let b_relationships = b + .attributes + .as_ref() + .unwrap() + .get("num_relationships") + .unwrap() + .parse::() + .unwrap(); + + a_order + .cmp(&b_order) + .then_with(|| b_relationships.cmp(&a_relationships)) + }); + + for unit in &mut selected_text_units { + unit.attributes.as_mut().unwrap().remove("entity_order"); + unit.attributes.as_mut().unwrap().remove("num_relationships"); + } + + let (context_text, context_data) = build_text_unit_context( + selected_text_units, + self.num_tokens_fn, + column_delimiter, + false, + max_tokens, + context_name, + 86, + )?; + + let mut context_data = context_data; + if return_candidate_context { + let candidate_context_data = + get_candidate_text_units(&selected_entities, &self.text_units.values().cloned().collect())?; + + let context_key = context_name.to_lowercase(); + if !context_data.contains_key(&context_key) { + let mut new_data = candidate_context_data.clone(); + new_data.with_column(Series::new("in_context", vec![false; candidate_context_data.height()]))?; + context_data.insert(context_key.to_string(), new_data); + } else { + let existing_data = context_data.get(&context_key).unwrap(); + if candidate_context_data + .get_column_names() + .contains(&"id".to_string().as_str()) + && existing_data.get_column_names().contains(&"id".to_string().as_str()) + { + let existing_ids = existing_data.column("id")?; + let context_ids = candidate_context_data.column("id")?; + let mut new_data = candidate_context_data.clone(); + let in_context = is_in(context_ids, existing_ids)?; + let in_context = Series::new("in_context", in_context); + new_data.with_column(in_context)?; + context_data.insert(context_key.to_string(), new_data); + } else { + let mut existing_data = existing_data.clone(); + existing_data.with_column(Series::new("in_context", vec![true; existing_data.height()]))?; + context_data.insert(context_key.to_string(), existing_data); + } + } + } - Ok(("".to_string(), HashMap::new())) + Ok((context_text, context_data)) } } From 863e0667ad4b46722458647ff51c602a22384a3e Mon Sep 17 00:00:00 2001 From: benolt Date: Mon, 26 Aug 2024 16:20:34 +0200 Subject: [PATCH 15/23] vector store, lancedb connect --- Cargo.lock | 4 + shinkai-libs/shinkai-graphrag/Cargo.toml | 6 +- .../src/context_builder/entity_extraction.rs | 2 +- .../src/context_builder/local_context.rs | 4 +- .../shinkai-graphrag/src/indexer_adapters.rs | 106 ++++++++++ .../src/indexer_adapters/indexer_reports.rs | 159 --------------- .../src/indexer_adapters/mod.rs | 2 - .../loaders/dfs.rs} | 181 +++++++++++++----- .../shinkai-graphrag/src/input/loaders/mod.rs | 1 + .../shinkai-graphrag/src/input/mod.rs | 2 + .../retrieval/community_reports.rs | 0 .../src/{ => input}/retrieval/entities.rs | 0 .../src/{ => input}/retrieval/mod.rs | 0 .../{ => input}/retrieval/relationships.rs | 0 .../src/{ => input}/retrieval/text_units.rs | 0 shinkai-libs/shinkai-graphrag/src/lib.rs | 2 +- .../src/search/local_search/mixed_context.rs | 2 +- .../src/vector_stores/lancedb.rs | 120 +++++++++++- .../src/vector_stores/vector_store.rs | 3 + .../tests/global_search_tests.rs | 2 +- 20 files changed, 377 insertions(+), 219 deletions(-) create mode 100644 shinkai-libs/shinkai-graphrag/src/indexer_adapters.rs delete mode 100644 shinkai-libs/shinkai-graphrag/src/indexer_adapters/indexer_reports.rs delete mode 100644 shinkai-libs/shinkai-graphrag/src/indexer_adapters/mod.rs rename shinkai-libs/shinkai-graphrag/src/{indexer_adapters/indexer_entities.rs => input/loaders/dfs.rs} (64%) create mode 100644 shinkai-libs/shinkai-graphrag/src/input/loaders/mod.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/input/mod.rs rename shinkai-libs/shinkai-graphrag/src/{ => input}/retrieval/community_reports.rs (100%) rename shinkai-libs/shinkai-graphrag/src/{ => input}/retrieval/entities.rs (100%) rename shinkai-libs/shinkai-graphrag/src/{ => input}/retrieval/mod.rs (100%) rename shinkai-libs/shinkai-graphrag/src/{ => input}/retrieval/relationships.rs (100%) rename shinkai-libs/shinkai-graphrag/src/{ => input}/retrieval/text_units.rs (100%) diff --git a/Cargo.lock b/Cargo.lock index f3f804d53..e5beb0300 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10400,9 +10400,13 @@ name = "shinkai-graphrag" version = "0.1.0" dependencies = [ "anyhow", + "arrow 52.2.0", + "arrow-array 52.2.0", + "arrow-schema 52.2.0", "async-openai", "async-trait", "futures", + "lancedb", "polars", "polars-lazy", "rand 0.8.5", diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml index 0212ad59d..ca6579cc7 100644 --- a/shinkai-libs/shinkai-graphrag/Cargo.toml +++ b/shinkai-libs/shinkai-graphrag/Cargo.toml @@ -5,8 +5,12 @@ edition = "2021" [dependencies] anyhow = "1.0.86" +arrow = "52.1" +arrow-array = "52.1" +arrow-schema = "52.1" async-trait = "0.1.74" futures = "0.3.30" +lancedb = "0.8.0" polars = { version = "0.41.3", features = ["dtype-struct", "is_in", "lazy", "parquet"] } polars-lazy = "0.41.3" rand = "0.8.5" @@ -18,4 +22,4 @@ uuid = { version = "1.6.1", features = ["v4"] } [dev-dependencies] async-openai = "0.23.4" reqwest = { version = "0.11.26", features = ["json"] } -tiktoken-rs = "0.5.9" \ No newline at end of file +tiktoken-rs = "0.5.9" diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs index ecf216f22..b6d9547fb 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs @@ -1,9 +1,9 @@ use std::collections::HashSet; use crate::{ + input::retrieval::entities::{get_entity_by_key, get_entity_by_name}, llm::llm::BaseTextEmbedding, models::Entity, - retrieval::entities::{get_entity_by_key, get_entity_by_name}, vector_stores::vector_store::VectorStore, }; diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs index 4d31cf61a..eb73a6a53 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs @@ -6,14 +6,14 @@ use std::{ use polars::{frame::DataFrame, prelude::NamedFrom, series::Series}; use crate::{ - models::{Entity, Relationship}, - retrieval::{ + input::retrieval::{ entities::to_entity_dataframe, relationships::{ get_candidate_relationships, get_entities_from_relationships, get_in_network_relationships, get_out_network_relationships, to_relationship_dataframe, }, }, + models::{Entity, Relationship}, }; pub fn build_entity_context( diff --git a/shinkai-libs/shinkai-graphrag/src/indexer_adapters.rs b/shinkai-libs/shinkai-graphrag/src/indexer_adapters.rs new file mode 100644 index 000000000..af5722e10 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/indexer_adapters.rs @@ -0,0 +1,106 @@ +use polars::prelude::*; +use polars_lazy::dsl::col; + +use crate::{ + input::loaders::dfs::{read_community_reports, read_entities}, + models::{CommunityReport, Entity}, +}; + +pub fn read_indexer_entities( + final_nodes: &DataFrame, + final_entities: &DataFrame, + community_level: u32, +) -> anyhow::Result> { + let entity_df = final_nodes.clone(); + let entity_df = filter_under_community_level(&entity_df, community_level)?; + + let entity_embedding_df = final_entities.clone(); + + let entity_df = entity_df + .lazy() + .rename(["title", "degree"], ["name", "rank"]) + .with_column(col("community").fill_null(lit(-1))) + .with_column(col("community").cast(DataType::Int32)) + .with_column(col("rank").cast(DataType::Int32)) + .group_by([col("name"), col("rank")]) + .agg([col("community").max()]) + .with_column(col("community").cast(DataType::String)) + .join( + entity_embedding_df.lazy(), + [col("name")], + [col("name")], + JoinArgs::new(JoinType::Inner), + ) + .collect()?; + + let entities = read_entities( + entity_df, + "id", + Some("human_readable_id"), + "name", + Some("type"), + Some("description"), + None, + Some("description_embedding"), + None, + Some("community"), + Some("text_unit_ids"), + None, + Some("rank"), + )?; + + Ok(entities) +} + +pub fn read_indexer_reports( + final_community_reports: &DataFrame, + final_nodes: &DataFrame, + community_level: u32, +) -> anyhow::Result> { + let entity_df = final_nodes.clone(); + let entity_df = filter_under_community_level(&entity_df, community_level)?; + + let filtered_community_df = entity_df + .lazy() + .with_column(col("community").fill_null(lit(-1))) + .with_column(col("community").cast(DataType::Int32)) + .group_by([col("title")]) + .agg([col("community").max()]) + .with_column(col("community").cast(DataType::String)) + .filter(len().over([col("community")]).gt(lit(1))) + .collect()?; + + let report_df = final_community_reports.clone(); + let report_df = filter_under_community_level(&report_df, community_level)?; + + let report_df = report_df + .lazy() + .join( + filtered_community_df.lazy(), + [col("community")], + [col("community")], + JoinArgs::new(JoinType::Inner), + ) + .collect()?; + + let reports = read_community_reports( + report_df, + "community", + Some("community"), + "title", + "community", + "summary", + "full_content", + Some("rank"), + None, + None, + )?; + Ok(reports) +} + +fn filter_under_community_level(df: &DataFrame, community_level: u32) -> anyhow::Result { + let mask = df.column("level")?.i64()?.lt_eq(community_level); + let result = df.filter(&mask)?; + + Ok(result) +} diff --git a/shinkai-libs/shinkai-graphrag/src/indexer_adapters/indexer_reports.rs b/shinkai-libs/shinkai-graphrag/src/indexer_adapters/indexer_reports.rs deleted file mode 100644 index a7f2fd4ba..000000000 --- a/shinkai-libs/shinkai-graphrag/src/indexer_adapters/indexer_reports.rs +++ /dev/null @@ -1,159 +0,0 @@ -use std::collections::HashSet; - -use polars::prelude::*; -use polars_lazy::dsl::col; - -use crate::models::CommunityReport; - -use super::indexer_entities::get_field; - -pub fn read_indexer_reports( - final_community_reports: &DataFrame, - final_nodes: &DataFrame, - community_level: u32, -) -> anyhow::Result> { - let entity_df = final_nodes.clone(); - let entity_df = filter_under_community_level(&entity_df, community_level)?; - - let filtered_community_df = entity_df - .lazy() - .with_column(col("community").fill_null(lit(-1))) - .with_column(col("community").cast(DataType::Int32)) - .group_by([col("title")]) - .agg([col("community").max()]) - .with_column(col("community").cast(DataType::String)) - .filter(len().over([col("community")]).gt(lit(1))) - .collect()?; - - let report_df = final_community_reports.clone(); - let report_df = filter_under_community_level(&report_df, community_level)?; - - let report_df = report_df - .lazy() - .join( - filtered_community_df.lazy(), - [col("community")], - [col("community")], - JoinArgs::new(JoinType::Inner), - ) - .collect()?; - - let reports = read_community_reports( - report_df, - "community", - Some("community"), - "title", - "community", - "summary", - "full_content", - Some("rank"), - None, - None, - )?; - Ok(reports) -} - -pub fn filter_under_community_level(df: &DataFrame, community_level: u32) -> anyhow::Result { - let mask = df.column("level")?.i64()?.lt_eq(community_level); - let result = df.filter(&mask)?; - - Ok(result) -} - -pub fn read_community_reports( - df: DataFrame, - id_col: &str, - short_id_col: Option<&str>, - title_col: &str, - community_col: &str, - summary_col: &str, - content_col: &str, - rank_col: Option<&str>, - _summary_embedding_col: Option<&str>, - _content_embedding_col: Option<&str>, - // attributes_cols: Option<&[&str]>, -) -> anyhow::Result> { - let column_names = [ - Some(id_col), - short_id_col, - Some(title_col), - Some(community_col), - Some(summary_col), - Some(content_col), - rank_col, - ] - .iter() - .filter_map(|&v| v.map(|v| v.to_string())) - .collect::>(); - - let column_names: Vec = column_names.into_iter().collect::>().into_vec(); - - let mut df = df; - df.as_single_chunk_par(); - let mut iters = df - .columns(column_names.clone())? - .iter() - .map(|s| s.iter()) - .collect::>(); - - let mut rows = Vec::new(); - for _row in 0..df.height() { - let mut row_values = Vec::new(); - for iter in &mut iters { - let value = iter.next(); - if let Some(value) = value { - row_values.push(value); - } - } - rows.push(row_values); - } - - let mut reports = Vec::new(); - for (idx, row) in rows.iter().enumerate() { - let report = CommunityReport { - id: get_field(&row, id_col, &column_names) - .map(|id| id.to_string()) - .unwrap_or(String::new()), - short_id: Some( - short_id_col - .map(|short_id| get_field(&row, short_id, &column_names)) - .flatten() - .map(|short_id| short_id.to_string()) - .unwrap_or(idx.to_string()), - ), - title: get_field(&row, title_col, &column_names) - .map(|title| title.to_string()) - .unwrap_or(String::new()), - community_id: get_field(&row, community_col, &column_names) - .map(|community| community.to_string()) - .unwrap_or(String::new()), - summary: get_field(&row, summary_col, &column_names) - .map(|summary| summary.to_string()) - .unwrap_or(String::new()), - full_content: get_field(&row, content_col, &column_names) - .map(|content| content.to_string()) - .unwrap_or(String::new()), - rank: rank_col - .map(|rank_col| { - get_field(&row, rank_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0.0)) - }) - .flatten(), - summary_embedding: None, - full_content_embedding: None, - attributes: None, - }; - reports.push(report); - } - - let mut unique_reports: Vec = Vec::new(); - let mut report_ids: HashSet = HashSet::new(); - - for report in reports { - if !report_ids.contains(&report.id) { - unique_reports.push(report.clone()); - report_ids.insert(report.id); - } - } - - Ok(unique_reports) -} diff --git a/shinkai-libs/shinkai-graphrag/src/indexer_adapters/mod.rs b/shinkai-libs/shinkai-graphrag/src/indexer_adapters/mod.rs deleted file mode 100644 index c49aae604..000000000 --- a/shinkai-libs/shinkai-graphrag/src/indexer_adapters/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod indexer_entities; -pub mod indexer_reports; diff --git a/shinkai-libs/shinkai-graphrag/src/indexer_adapters/indexer_entities.rs b/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs similarity index 64% rename from shinkai-libs/shinkai-graphrag/src/indexer_adapters/indexer_entities.rs rename to shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs index d6e38d3da..8cd531f0b 100644 --- a/shinkai-libs/shinkai-graphrag/src/indexer_adapters/indexer_entities.rs +++ b/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs @@ -1,56 +1,39 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; -use polars::prelude::*; -use polars_lazy::dsl::col; +use polars::{ + frame::DataFrame, + prelude::{AnyValue, ChunkedArray, IntoVec, StringChunked}, +}; -use crate::models::Entity; +use crate::{ + models::{CommunityReport, Entity}, + vector_stores::vector_store::{VectorStore, VectorStoreDocument}, +}; -use super::indexer_reports::filter_under_community_level; +pub fn store_entity_semantic_embeddings( + entities: Vec, + mut vectorstore: Box, +) -> Box { + let documents: Vec = entities + .into_iter() + .map(|entity| { + let mut attributes = HashMap::new(); + attributes.insert("title".to_string(), entity.title.clone()); + if let Some(entity_attributes) = entity.attributes { + attributes.extend(entity_attributes); + } -pub fn read_indexer_entities( - final_nodes: &DataFrame, - final_entities: &DataFrame, - community_level: u32, -) -> anyhow::Result> { - let entity_df = final_nodes.clone(); - let entity_df = filter_under_community_level(&entity_df, community_level)?; - - let entity_embedding_df = final_entities.clone(); - - let entity_df = entity_df - .lazy() - .rename(["title", "degree"], ["name", "rank"]) - .with_column(col("community").fill_null(lit(-1))) - .with_column(col("community").cast(DataType::Int32)) - .with_column(col("rank").cast(DataType::Int32)) - .group_by([col("name"), col("rank")]) - .agg([col("community").max()]) - .with_column(col("community").cast(DataType::String)) - .join( - entity_embedding_df.lazy(), - [col("name")], - [col("name")], - JoinArgs::new(JoinType::Inner), - ) - .collect()?; - - let entities = read_entities( - entity_df, - "id", - Some("human_readable_id"), - "name", - Some("type"), - Some("description"), - None, - Some("description_embedding"), - None, - Some("community"), - Some("text_unit_ids"), - None, - Some("rank"), - )?; - - Ok(entities) + VectorStoreDocument { + id: entity.id, + text: entity.description, + vector: entity.description_embedding, + attributes, + } + }) + .collect(); + + vectorstore.load_documents(documents, true); + vectorstore } pub fn read_entities( @@ -234,7 +217,105 @@ pub fn read_entities( Ok(unique_entities) } -pub fn get_field<'a>( +pub fn read_community_reports( + df: DataFrame, + id_col: &str, + short_id_col: Option<&str>, + title_col: &str, + community_col: &str, + summary_col: &str, + content_col: &str, + rank_col: Option<&str>, + _summary_embedding_col: Option<&str>, + _content_embedding_col: Option<&str>, + // attributes_cols: Option<&[&str]>, +) -> anyhow::Result> { + let column_names = [ + Some(id_col), + short_id_col, + Some(title_col), + Some(community_col), + Some(summary_col), + Some(content_col), + rank_col, + ] + .iter() + .filter_map(|&v| v.map(|v| v.to_string())) + .collect::>(); + + let column_names: Vec = column_names.into_iter().collect::>().into_vec(); + + let mut df = df; + df.as_single_chunk_par(); + let mut iters = df + .columns(column_names.clone())? + .iter() + .map(|s| s.iter()) + .collect::>(); + + let mut rows = Vec::new(); + for _row in 0..df.height() { + let mut row_values = Vec::new(); + for iter in &mut iters { + let value = iter.next(); + if let Some(value) = value { + row_values.push(value); + } + } + rows.push(row_values); + } + + let mut reports = Vec::new(); + for (idx, row) in rows.iter().enumerate() { + let report = CommunityReport { + id: get_field(&row, id_col, &column_names) + .map(|id| id.to_string()) + .unwrap_or(String::new()), + short_id: Some( + short_id_col + .map(|short_id| get_field(&row, short_id, &column_names)) + .flatten() + .map(|short_id| short_id.to_string()) + .unwrap_or(idx.to_string()), + ), + title: get_field(&row, title_col, &column_names) + .map(|title| title.to_string()) + .unwrap_or(String::new()), + community_id: get_field(&row, community_col, &column_names) + .map(|community| community.to_string()) + .unwrap_or(String::new()), + summary: get_field(&row, summary_col, &column_names) + .map(|summary| summary.to_string()) + .unwrap_or(String::new()), + full_content: get_field(&row, content_col, &column_names) + .map(|content| content.to_string()) + .unwrap_or(String::new()), + rank: rank_col + .map(|rank_col| { + get_field(&row, rank_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0.0)) + }) + .flatten(), + summary_embedding: None, + full_content_embedding: None, + attributes: None, + }; + reports.push(report); + } + + let mut unique_reports: Vec = Vec::new(); + let mut report_ids: HashSet = HashSet::new(); + + for report in reports { + if !report_ids.contains(&report.id) { + unique_reports.push(report.clone()); + report_ids.insert(report.id); + } + } + + Ok(unique_reports) +} + +fn get_field<'a>( row: &'a Vec>, column_name: &'a str, column_names: &'a Vec, diff --git a/shinkai-libs/shinkai-graphrag/src/input/loaders/mod.rs b/shinkai-libs/shinkai-graphrag/src/input/loaders/mod.rs new file mode 100644 index 000000000..289f06621 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/input/loaders/mod.rs @@ -0,0 +1 @@ +pub mod dfs; \ No newline at end of file diff --git a/shinkai-libs/shinkai-graphrag/src/input/mod.rs b/shinkai-libs/shinkai-graphrag/src/input/mod.rs new file mode 100644 index 000000000..6ce304fa8 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/input/mod.rs @@ -0,0 +1,2 @@ +pub mod loaders; +pub mod retrieval; diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/community_reports.rs b/shinkai-libs/shinkai-graphrag/src/input/retrieval/community_reports.rs similarity index 100% rename from shinkai-libs/shinkai-graphrag/src/retrieval/community_reports.rs rename to shinkai-libs/shinkai-graphrag/src/input/retrieval/community_reports.rs diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/entities.rs b/shinkai-libs/shinkai-graphrag/src/input/retrieval/entities.rs similarity index 100% rename from shinkai-libs/shinkai-graphrag/src/retrieval/entities.rs rename to shinkai-libs/shinkai-graphrag/src/input/retrieval/entities.rs diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/mod.rs b/shinkai-libs/shinkai-graphrag/src/input/retrieval/mod.rs similarity index 100% rename from shinkai-libs/shinkai-graphrag/src/retrieval/mod.rs rename to shinkai-libs/shinkai-graphrag/src/input/retrieval/mod.rs diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/relationships.rs b/shinkai-libs/shinkai-graphrag/src/input/retrieval/relationships.rs similarity index 100% rename from shinkai-libs/shinkai-graphrag/src/retrieval/relationships.rs rename to shinkai-libs/shinkai-graphrag/src/input/retrieval/relationships.rs diff --git a/shinkai-libs/shinkai-graphrag/src/retrieval/text_units.rs b/shinkai-libs/shinkai-graphrag/src/input/retrieval/text_units.rs similarity index 100% rename from shinkai-libs/shinkai-graphrag/src/retrieval/text_units.rs rename to shinkai-libs/shinkai-graphrag/src/input/retrieval/text_units.rs diff --git a/shinkai-libs/shinkai-graphrag/src/lib.rs b/shinkai-libs/shinkai-graphrag/src/lib.rs index 30ceb1057..35060f002 100644 --- a/shinkai-libs/shinkai-graphrag/src/lib.rs +++ b/shinkai-libs/shinkai-graphrag/src/lib.rs @@ -1,7 +1,7 @@ pub mod context_builder; pub mod indexer_adapters; +pub mod input; pub mod llm; pub mod models; -pub mod retrieval; pub mod search; pub mod vector_stores; diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs index 2fd3864a2..f161eb524 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs @@ -13,9 +13,9 @@ use crate::{ local_context::{build_entity_context, build_relationship_context, get_candidate_context}, source_context::{build_text_unit_context, count_relationships}, }, + input::retrieval::{community_reports::get_candidate_communities, text_units::get_candidate_text_units}, llm::llm::BaseTextEmbedding, models::{CommunityReport, Entity, Relationship, TextUnit}, - retrieval::{community_reports::get_candidate_communities, text_units::get_candidate_text_units}, vector_stores::vector_store::VectorStore, }; diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs index e40415053..5e7afd338 100644 --- a/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs @@ -1 +1,119 @@ -pub struct LanceDBVectorStore {} +use std::sync::Arc; + +use arrow::datatypes::Float64Type; +use arrow_array::{FixedSizeListArray, Float64Array, RecordBatch, RecordBatchIterator, StringArray}; +use lancedb::{ + arrow::arrow_schema::{DataType, Field, Schema}, + connect, Connection, +}; +use serde_json::json; + +use super::vector_store::{VectorStore, VectorStoreDocument, VectorStoreSearchResult}; + +pub struct LanceDBVectorStore { + collection_name: String, + db_connection: Option, +} + +impl LanceDBVectorStore { + pub fn new(collection_name: String) -> Self { + LanceDBVectorStore { + collection_name, + db_connection: None, + } + } + + pub async fn connect(&mut self, db_uri: &str) -> anyhow::Result<()> { + let connection = connect(db_uri).execute().await?; + self.db_connection = Some(connection); + Ok(()) + } + + fn similarity_search_by_vector(&self, query_embedding: Vec, k: usize) -> Vec { + Vec::new() + } +} + +impl VectorStore for LanceDBVectorStore { + fn similarity_search_by_text( + &self, + text: &str, + text_embedder: &dyn Fn(&str) -> Vec, + k: usize, + ) -> Vec { + let query_embedding = text_embedder(text); + + if query_embedding.is_empty() { + return vec![]; + } + + self.similarity_search_by_vector(query_embedding, k) + } + + fn load_documents(&mut self, documents: Vec, overwrite: bool) -> anyhow::Result<()> { + let data: Vec<_> = documents + .into_iter() + .filter(|document| document.vector.is_some()) + .collect(); + + let data = if data.is_empty() { None } else { Some(data) }; + + let vector_len = data + .as_ref() + .and_then(|data| data.first()) + .and_then(|document| document.vector.as_ref()) + .map(|vector| vector.len()) + .unwrap_or_default(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("text", DataType::Utf8, true), + Field::new( + "vector", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float64, false)), + vector_len.try_into().unwrap_or_default(), + ), + true, + ), + Field::new("attributes", DataType::Utf8, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from( + data.as_ref() + .map(|data| data.iter().map(|document| document.id.clone()).collect::>()) + .unwrap_or_default(), + )), + Arc::new(StringArray::from( + data.as_ref() + .map(|data| { + data.iter() + .map(|document| document.text.clone().unwrap_or_default()) + .collect::>() + }) + .unwrap_or_default(), + )), + Arc::new(FixedSizeListArray::from_iter_primitive::( + data.as_ref() + .map(|data| data.iter().map(|document| document.vector.clone()).collect::>()), + vector_len.try_into().unwrap_or_default(), + )), + Arc::new(StringArray::from( + data.as_ref() + .map(|data| { + data.iter() + .map(|document| json!(document.attributes).to_string()) + .collect::>() + }) + .unwrap_or_default(), + )), + ], + ); + + let batch_iterator = RecordBatchIterator::new(vec![batch], schema.clone()); + Ok(()) + } +} diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs index b72d5763b..d8a9bc246 100644 --- a/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs @@ -8,6 +8,7 @@ pub struct VectorStoreSearchResult { pub struct VectorStoreDocument { pub id: String, pub text: Option, + pub vector: Option>, pub attributes: HashMap, } @@ -18,4 +19,6 @@ pub trait VectorStore { text_embedder: &dyn Fn(&str) -> Vec, k: usize, ) -> Vec; + + fn load_documents(&mut self, documents: Vec, overwrite: bool) -> anyhow::Result<()>; } diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs index a808c7ff2..06a1f1903 100644 --- a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs @@ -1,7 +1,7 @@ use polars::{io::SerReader, prelude::ParquetReader}; use shinkai_graphrag::{ context_builder::community_context::{CommunityContextBuilderParams, GlobalCommunityContext}, - indexer_adapters::{indexer_entities::read_indexer_entities, indexer_reports::read_indexer_reports}, + indexer_adapters::{read_indexer_entities, read_indexer_reports}, llm::llm::LLMParams, search::global_search::global_search::{GlobalSearch, GlobalSearchParams}, }; From 114b9a6ae094004370d736636017a088bd49f358 Mon Sep 17 00:00:00 2001 From: benolt Date: Tue, 27 Aug 2024 13:51:54 +0200 Subject: [PATCH 16/23] similarity search, read relationships, text units --- .../src/context_builder/entity_extraction.rs | 22 +- .../shinkai-graphrag/src/indexer_adapters.rs | 41 ++- .../shinkai-graphrag/src/input/loaders/dfs.rs | 319 +++++++++++++++++- .../src/search/local_search/mixed_context.rs | 13 +- .../src/vector_stores/lancedb.rs | 236 +++++++++---- .../src/vector_stores/vector_store.rs | 12 +- 6 files changed, 553 insertions(+), 90 deletions(-) diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs index b6d9547fb..933936d0b 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs @@ -1,33 +1,31 @@ -use std::collections::HashSet; +use std::{collections::HashSet, sync::Arc}; use crate::{ input::retrieval::entities::{get_entity_by_key, get_entity_by_name}, llm::llm::BaseTextEmbedding, models::Entity, - vector_stores::vector_store::VectorStore, + vector_stores::{lancedb::LanceDBVectorStore, vector_store::VectorStore}, }; -pub fn map_query_to_entities( +pub async fn map_query_to_entities( query: &str, - text_embedding_vectorstore: &Box, - text_embedder: &Box, + text_embedding_vectorstore: &LanceDBVectorStore, + text_embedder: &Box, all_entities: &Vec, embedding_vectorstore_key: &str, include_entity_names: Option>, exclude_entity_names: Option>, k: usize, oversample_scaler: usize, -) -> Vec { +) -> anyhow::Result> { let include_entity_names = include_entity_names.unwrap_or_else(Vec::new); let exclude_entity_names: HashSet = exclude_entity_names.unwrap_or_else(Vec::new).into_iter().collect(); let mut matched_entities = Vec::new(); if !query.is_empty() { - let search_results = text_embedding_vectorstore.similarity_search_by_text( - query, - &|t| text_embedder.embed(t), - k * oversample_scaler, - ); + let search_results = text_embedding_vectorstore + .similarity_search_by_text(query, text_embedder, k * oversample_scaler) + .await?; for result in search_results { if let Some(matched) = get_entity_by_key(all_entities, &embedding_vectorstore_key, &result.document.id) { @@ -48,5 +46,5 @@ pub fn map_query_to_entities( } included_entities.extend(matched_entities); - included_entities + Ok(included_entities) } diff --git a/shinkai-libs/shinkai-graphrag/src/indexer_adapters.rs b/shinkai-libs/shinkai-graphrag/src/indexer_adapters.rs index af5722e10..9b2f97b71 100644 --- a/shinkai-libs/shinkai-graphrag/src/indexer_adapters.rs +++ b/shinkai-libs/shinkai-graphrag/src/indexer_adapters.rs @@ -1,9 +1,11 @@ +use std::vec; + use polars::prelude::*; use polars_lazy::dsl::col; use crate::{ - input::loaders::dfs::{read_community_reports, read_entities}, - models::{CommunityReport, Entity}, + input::loaders::dfs::{read_community_reports, read_entities, read_relationships, read_text_units}, + models::{CommunityReport, Entity, Relationship, TextUnit}, }; pub fn read_indexer_entities( @@ -98,6 +100,41 @@ pub fn read_indexer_reports( Ok(reports) } +pub fn read_indexer_relationships(final_relationships: &DataFrame) -> anyhow::Result> { + let relationships = read_relationships( + final_relationships.clone(), + "id", + Some("human_readable_id"), + "source", + "target", + Some("description"), + None, + Some("weight"), + Some("text_unit_ids"), + None, + Some(vec!["rank"]), + )?; + + Ok(relationships) +} + +pub fn read_indexer_text_units(final_text_units: &DataFrame) -> anyhow::Result> { + let text_units = read_text_units( + final_text_units.clone(), + "id", + None, + "text", + Some("entity_ids"), + Some("relationship_ids"), + Some("n_tokens"), + Some("document_ids"), + Some("text_embedding"), + None, + )?; + + Ok(text_units) +} + fn filter_under_community_level(df: &DataFrame, community_level: u32) -> anyhow::Result { let mask = df.column("level")?.i64()?.lt_eq(community_level); let result = df.filter(&mask)?; diff --git a/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs b/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs index 8cd531f0b..378b09a39 100644 --- a/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs +++ b/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs @@ -6,14 +6,17 @@ use polars::{ }; use crate::{ - models::{CommunityReport, Entity}, - vector_stores::vector_store::{VectorStore, VectorStoreDocument}, + models::{CommunityReport, Entity, Relationship, TextUnit}, + vector_stores::{ + lancedb::LanceDBVectorStore, + vector_store::{VectorStore, VectorStoreDocument}, + }, }; -pub fn store_entity_semantic_embeddings( +pub async fn store_entity_semantic_embeddings( entities: Vec, - mut vectorstore: Box, -) -> Box { + mut vectorstore: LanceDBVectorStore, +) -> anyhow::Result { let documents: Vec = entities .into_iter() .map(|entity| { @@ -32,8 +35,8 @@ pub fn store_entity_semantic_embeddings( }) .collect(); - vectorstore.load_documents(documents, true); - vectorstore + vectorstore.load_documents(documents, true).await?; + Ok(vectorstore) } pub fn read_entities( @@ -315,6 +318,308 @@ pub fn read_community_reports( Ok(unique_reports) } +pub fn read_relationships( + df: DataFrame, + id_col: &str, + short_id_col: Option<&str>, + source_col: &str, + target_col: &str, + description_col: Option<&str>, + description_embedding_col: Option<&str>, + weight_col: Option<&str>, + text_unit_ids_col: Option<&str>, + document_ids_col: Option<&str>, + attributes_cols: Option>, +) -> anyhow::Result> { + let mut column_names = [ + Some(id_col), + short_id_col, + Some(source_col), + Some(target_col), + description_col, + description_embedding_col, + weight_col, + text_unit_ids_col, + document_ids_col, + ] + .iter() + .filter_map(|&v| v.map(|v| v.to_string())) + .collect::>(); + + attributes_cols.as_ref().map(|cols| { + cols.iter().for_each(|col| { + column_names.insert(col.to_string()); + }); + }); + + let column_names = column_names.into_iter().collect::>(); + + let mut df = df; + df.as_single_chunk_par(); + let mut iters = df + .columns(column_names.clone())? + .iter() + .map(|s| s.iter()) + .collect::>(); + + let mut rows = Vec::new(); + for _row in 0..df.height() { + let mut row_values = Vec::new(); + for iter in &mut iters { + let value = iter.next(); + if let Some(value) = value { + row_values.push(value); + } + } + rows.push(row_values); + } + + let mut relationships = Vec::new(); + for (idx, row) in rows.iter().enumerate() { + let report = Relationship { + id: get_field(&row, id_col, &column_names) + .map(|id| id.to_string()) + .unwrap_or(String::new()), + short_id: Some( + short_id_col + .map(|short_id| get_field(&row, short_id, &column_names)) + .flatten() + .map(|short_id| short_id.to_string()) + .unwrap_or(idx.to_string()), + ), + source: get_field(&row, source_col, &column_names) + .map(|source| source.to_string()) + .unwrap_or(String::new()), + target: get_field(&row, target_col, &column_names) + .map(|target| target.to_string()) + .unwrap_or(String::new()), + description: description_col + .map(|description| get_field(&row, description, &column_names)) + .flatten() + .map(|description| description.to_string()), + description_embedding: description_embedding_col.map(|description_embedding_col| { + get_field(&row, description_embedding_col, &column_names) + .map(|description_embedding| match description_embedding { + AnyValue::List(series) => series + .f64() + .unwrap_or(&ChunkedArray::from_vec(description_embedding_col, vec![])) + .iter() + .map(|v| v.unwrap_or(0.0)) + .collect::>(), + value => vec![value.to_string().parse::().unwrap_or(0.0)], + }) + .unwrap_or_else(|| Vec::new()) + }), + weight: weight_col + .map(|weight_col| { + get_field(&row, weight_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0.0)) + }) + .flatten(), + text_unit_ids: text_unit_ids_col.map(|text_unit_ids_col| { + get_field(&row, text_unit_ids_col, &column_names) + .map(|text_unit_ids| match text_unit_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_else(|| Vec::new()) + }), + document_ids: document_ids_col.map(|document_ids_col| { + get_field(&row, document_ids_col, &column_names) + .map(|document_ids| match document_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_else(|| Vec::new()) + }), + attributes: attributes_cols.as_ref().map(|cols| { + cols.iter() + .map(|col| { + get_field(&row, col, &column_names) + .map(|v| (col.to_string(), v.to_string())) + .unwrap_or((String::new(), String::new())) + }) + .collect::>() + }), + }; + relationships.push(report); + } + + let mut unique_relationships: Vec = Vec::new(); + let mut relationship_ids: HashSet = HashSet::new(); + + for relationship in relationships { + if !relationship_ids.contains(&relationship.id) { + unique_relationships.push(relationship.clone()); + relationship_ids.insert(relationship.id); + } + } + + Ok(unique_relationships) +} + +pub fn read_text_units( + df: DataFrame, + id_col: &str, + short_id_col: Option<&str>, + text_col: &str, + entities_col: Option<&str>, + relationships_col: Option<&str>, + tokens_col: Option<&str>, + document_ids_col: Option<&str>, + embedding_col: Option<&str>, + attributes_cols: Option>, +) -> anyhow::Result> { + let mut column_names = [ + Some(id_col), + short_id_col, + Some(text_col), + entities_col, + relationships_col, + tokens_col, + document_ids_col, + embedding_col, + ] + .iter() + .filter_map(|&v| v.map(|v| v.to_string())) + .collect::>(); + + attributes_cols.as_ref().map(|cols| { + cols.iter().for_each(|col| { + column_names.insert(col.to_string()); + }); + }); + + let column_names = column_names.into_iter().collect::>(); + + let mut df = df; + df.as_single_chunk_par(); + let mut iters = df + .columns(column_names.clone())? + .iter() + .map(|s| s.iter()) + .collect::>(); + + let mut rows = Vec::new(); + for _row in 0..df.height() { + let mut row_values = Vec::new(); + for iter in &mut iters { + let value = iter.next(); + if let Some(value) = value { + row_values.push(value); + } + } + rows.push(row_values); + } + + let mut text_units = Vec::new(); + for (idx, row) in rows.iter().enumerate() { + let report = TextUnit { + id: get_field(&row, id_col, &column_names) + .map(|id| id.to_string()) + .unwrap_or(String::new()), + short_id: Some( + short_id_col + .map(|short_id| get_field(&row, short_id, &column_names)) + .flatten() + .map(|short_id| short_id.to_string()) + .unwrap_or(idx.to_string()), + ), + text: get_field(&row, text_col, &column_names) + .map(|text| text.to_string()) + .unwrap_or(String::new()), + entity_ids: entities_col.map(|entities_col| { + get_field(&row, entities_col, &column_names) + .map(|entity_ids| match entity_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_else(|| Vec::new()) + }), + relationship_ids: relationships_col.map(|relationships_col| { + get_field(&row, relationships_col, &column_names) + .map(|relationship_ids| match relationship_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_else(|| Vec::new()) + }), + text_embedding: embedding_col.map(|embedding_col| { + get_field(&row, embedding_col, &column_names) + .map(|embedding| match embedding { + AnyValue::List(series) => series + .f64() + .unwrap_or(&ChunkedArray::from_vec(embedding_col, vec![])) + .iter() + .map(|v| v.unwrap_or(0.0)) + .collect::>(), + value => vec![value.to_string().parse::().unwrap_or(0.0)], + }) + .unwrap_or_else(|| Vec::new()) + }), + n_tokens: tokens_col + .map(|tokens_col| { + get_field(&row, tokens_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0)) + }) + .flatten(), + document_ids: document_ids_col.map(|document_ids_col| { + get_field(&row, document_ids_col, &column_names) + .map(|document_ids| match document_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_else(|| Vec::new()) + }), + attributes: attributes_cols.as_ref().map(|cols| { + cols.iter() + .map(|col| { + get_field(&row, col, &column_names) + .map(|v| (col.to_string(), v.to_string())) + .unwrap_or((String::new(), String::new())) + }) + .collect::>() + }), + }; + text_units.push(report); + } + + let mut unique_text_units: Vec = Vec::new(); + let mut text_unit_ids: HashSet = HashSet::new(); + + for unit in text_units { + if !text_unit_ids.contains(&unit.id) { + unique_text_units.push(unit.clone()); + text_unit_ids.insert(unit.id); + } + } + + Ok(unique_text_units) +} + fn get_field<'a>( row: &'a Vec>, column_name: &'a str, diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs index f161eb524..09bd00b92 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs @@ -16,7 +16,7 @@ use crate::{ input::retrieval::{community_reports::get_candidate_communities, text_units::get_candidate_text_units}, llm::llm::BaseTextEmbedding, models::{CommunityReport, Entity, Relationship, TextUnit}, - vector_stores::vector_store::VectorStore, + vector_stores::lancedb::LanceDBVectorStore, }; #[derive(Debug, Clone)] @@ -69,8 +69,8 @@ pub fn default_local_context_params() -> MixedContextBuilderParams { pub struct LocalSearchMixedContext { entities: HashMap, - entity_text_embeddings: Box, - text_embedder: Box, + entity_text_embeddings: LanceDBVectorStore, + text_embedder: Box, text_units: HashMap, community_reports: HashMap, relationships: HashMap, @@ -81,8 +81,8 @@ pub struct LocalSearchMixedContext { impl LocalSearchMixedContext { pub fn new( entities: Vec, - entity_text_embeddings: Box, - text_embedder: Box, + entity_text_embeddings: LanceDBVectorStore, + text_embedder: Box, text_units: Option>, community_reports: Option>, relationships: Option>, @@ -169,7 +169,8 @@ impl LocalSearchMixedContext { Some(exclude_entity_names), top_k_mapped_entities, 2, - ); + ) + .await?; let mut final_context = Vec::new(); let mut final_context_data = HashMap::new(); diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs index 5e7afd338..8ee793426 100644 --- a/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs @@ -1,18 +1,24 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use arrow::datatypes::Float64Type; -use arrow_array::{FixedSizeListArray, Float64Array, RecordBatch, RecordBatchIterator, StringArray}; +use arrow_array::{Array, Float64Array, ListArray, RecordBatch, RecordBatchIterator, StringArray}; +use futures::TryStreamExt; use lancedb::{ arrow::arrow_schema::{DataType, Field, Schema}, - connect, Connection, + connect, + query::{ExecutableQuery, QueryBase}, + Connection, Table, }; use serde_json::json; +use crate::llm::llm::BaseTextEmbedding; + use super::vector_store::{VectorStore, VectorStoreDocument, VectorStoreSearchResult}; pub struct LanceDBVectorStore { collection_name: String, db_connection: Option, + document_collection: Option, } impl LanceDBVectorStore { @@ -20,6 +26,7 @@ impl LanceDBVectorStore { LanceDBVectorStore { collection_name, db_connection: None, + document_collection: None, } } @@ -29,91 +36,200 @@ impl LanceDBVectorStore { Ok(()) } - fn similarity_search_by_vector(&self, query_embedding: Vec, k: usize) -> Vec { - Vec::new() + async fn similarity_search_by_vector( + &self, + query_embedding: Vec, + k: usize, + ) -> anyhow::Result> { + if let Some(document_collection) = &self.document_collection { + let records = document_collection + .query() + .limit(k) + .nearest_to(query_embedding)? + .execute() + .await? + .try_collect::>() + .await?; + + let mut results = Vec::new(); + for record in records { + let id_col = record + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let text_col = record + .column_by_name("text") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let vector_col = record + .column_by_name("vector") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let attributes_col = record + .column_by_name("attributes") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + let distance_col = record + .column_by_name("_distance") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + if id_col.len() == 0 + || text_col.len() == 0 + || vector_col.len() == 0 + || attributes_col.len() == 0 + || distance_col.len() == 0 + { + continue; + } + + let id = id_col.value(0).to_string(); + let text = text_col.value(0).to_string(); + let vector: Vec = vector_col + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|value| value.unwrap()) + .collect(); + let attributes: HashMap = serde_json::from_str(attributes_col.value(0))?; + + let distance = distance_col.value(0); + + let score = 1.0 - distance.abs(); + + let doc = VectorStoreDocument { + id, + text: Some(text), + vector: Some(vector), + attributes, + }; + + results.push(VectorStoreSearchResult { document: doc, score }); + } + + return Ok(results); + } + + Ok(Vec::new()) } } impl VectorStore for LanceDBVectorStore { - fn similarity_search_by_text( + async fn similarity_search_by_text( &self, text: &str, - text_embedder: &dyn Fn(&str) -> Vec, + text_embedder: &Box, k: usize, - ) -> Vec { - let query_embedding = text_embedder(text); + ) -> anyhow::Result> { + let query_embedding = text_embedder.embed(text); if query_embedding.is_empty() { - return vec![]; + return Ok(vec![]); } - self.similarity_search_by_vector(query_embedding, k) + let results = self.similarity_search_by_vector(query_embedding, k).await?; + Ok(results) } - fn load_documents(&mut self, documents: Vec, overwrite: bool) -> anyhow::Result<()> { + async fn load_documents(&mut self, documents: Vec, overwrite: bool) -> anyhow::Result<()> { + let db_connection = self + .db_connection + .as_ref() + .ok_or_else(|| anyhow::anyhow!("LanceDB connection is not established"))?; + let data: Vec<_> = documents .into_iter() .filter(|document| document.vector.is_some()) .collect(); - let data = if data.is_empty() { None } else { Some(data) }; - - let vector_len = data - .as_ref() - .and_then(|data| data.first()) - .and_then(|document| document.vector.as_ref()) - .map(|vector| vector.len()) - .unwrap_or_default(); - let schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Utf8, false), Field::new("text", DataType::Utf8, true), Field::new( "vector", - DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Float64, false)), - vector_len.try_into().unwrap_or_default(), - ), + DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), true, ), Field::new("attributes", DataType::Utf8, false), ])); - let batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(StringArray::from( - data.as_ref() - .map(|data| data.iter().map(|document| document.id.clone()).collect::>()) - .unwrap_or_default(), - )), - Arc::new(StringArray::from( - data.as_ref() - .map(|data| { - data.iter() - .map(|document| document.text.clone().unwrap_or_default()) - .collect::>() - }) - .unwrap_or_default(), - )), - Arc::new(FixedSizeListArray::from_iter_primitive::( - data.as_ref() - .map(|data| data.iter().map(|document| document.vector.clone()).collect::>()), - vector_len.try_into().unwrap_or_default(), - )), - Arc::new(StringArray::from( - data.as_ref() - .map(|data| { - data.iter() - .map(|document| json!(document.attributes).to_string()) - .collect::>() - }) - .unwrap_or_default(), - )), - ], - ); - - let batch_iterator = RecordBatchIterator::new(vec![batch], schema.clone()); + let batches = if !data.is_empty() { + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from( + data.iter().map(|document| document.id.clone()).collect::>(), + )), + Arc::new(StringArray::from( + data.iter() + .map(|document| document.text.clone().unwrap_or_default()) + .collect::>(), + )), + Arc::new(ListArray::from_iter_primitive::( + data.iter() + .map(|document| { + Some( + document + .vector + .as_ref() + .map(|v| v.iter().map(|f| Some(f.clone())).collect::>()) + .unwrap_or_default(), + ) + }) + .collect::>(), + )), + Arc::new(StringArray::from( + data.iter() + .map(|document| json!(document.attributes).to_string()) + .collect::>(), + )), + ], + )?; + + Some(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) + } else { + None + }; + + if overwrite { + if let Some(batches) = batches { + let table = db_connection + .create_table(&self.collection_name, Box::new(batches)) + .execute() + .await?; + + self.document_collection = Some(table); + } else { + let table = db_connection + .create_empty_table(&self.collection_name, schema.clone()) + .execute() + .await?; + + self.document_collection = Some(table); + } + } else { + let table = db_connection.open_table(&self.collection_name).execute().await?; + + if let Some(batches) = batches { + table.add(batches).execute().await?; + } + + self.document_collection = Some(table); + } + Ok(()) } } diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs index d8a9bc246..aa863f09f 100644 --- a/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs @@ -1,5 +1,7 @@ use std::collections::HashMap; +use crate::llm::llm::BaseTextEmbedding; + pub struct VectorStoreSearchResult { pub document: VectorStoreDocument, pub score: f64, @@ -16,9 +18,13 @@ pub trait VectorStore { fn similarity_search_by_text( &self, text: &str, - text_embedder: &dyn Fn(&str) -> Vec, + text_embedder: &Box, k: usize, - ) -> Vec; + ) -> impl std::future::Future>> + Send; - fn load_documents(&mut self, documents: Vec, overwrite: bool) -> anyhow::Result<()>; + fn load_documents( + &mut self, + documents: Vec, + overwrite: bool, + ) -> impl std::future::Future> + Send; } From 711ccd37d1e6ff625b5ee59d381dda5d796ea85f Mon Sep 17 00:00:00 2001 From: benolt Date: Wed, 28 Aug 2024 15:17:09 +0200 Subject: [PATCH 17/23] test local search with openai, adjustments --- Cargo.lock | 112 ++++++++++--- shinkai-libs/shinkai-graphrag/Cargo.toml | 2 + .../src/context_builder/entity_extraction.rs | 4 +- .../shinkai-graphrag/src/input/loaders/dfs.rs | 34 +++- .../src/input/retrieval/relationships.rs | 23 ++- .../src/input/retrieval/text_units.rs | 9 +- .../src/llm/{llm.rs => base.rs} | 10 +- shinkai-libs/shinkai-graphrag/src/llm/mod.rs | 2 +- .../src/search/global_search/mod.rs | 2 +- .../{global_search.rs => search.rs} | 41 ++--- .../src/search/local_search/mixed_context.rs | 10 +- .../src/search/local_search/mod.rs | 2 +- .../{local_search.rs => search.rs} | 30 ++-- .../src/vector_stores/lancedb.rs | 46 ++++-- .../src/vector_stores/vector_store.rs | 6 +- .../tests/global_search_tests.rs | 10 +- .../tests/local_search_tests.rs | 150 ++++++++++++++++++ .../shinkai-graphrag/tests/utils/ollama.rs | 2 +- .../shinkai-graphrag/tests/utils/openai.rs | 142 +++++++++++++++-- 19 files changed, 481 insertions(+), 156 deletions(-) rename shinkai-libs/shinkai-graphrag/src/llm/{llm.rs => base.rs} (79%) rename shinkai-libs/shinkai-graphrag/src/search/global_search/{global_search.rs => search.rs} (93%) rename shinkai-libs/shinkai-graphrag/src/search/local_search/{local_search.rs => search.rs} (77%) create mode 100644 shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs diff --git a/Cargo.lock b/Cargo.lock index e5beb0300..5700f0254 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -534,7 +534,7 @@ dependencies = [ "arrow-schema 51.0.0", "chrono", "half", - "indexmap 2.1.0", + "indexmap 2.4.0", "lexical-core", "num", "serde", @@ -554,7 +554,7 @@ dependencies = [ "arrow-schema 52.2.0", "chrono", "half", - "indexmap 2.1.0", + "indexmap 2.4.0", "lexical-core", "num", "serde", @@ -2906,7 +2906,7 @@ dependencies = [ "glob", "half", "hashbrown 0.14.5", - "indexmap 2.1.0", + "indexmap 2.4.0", "itertools 0.12.1", "log 0.4.21", "num_cpus", @@ -3066,7 +3066,7 @@ dependencies = [ "datafusion-expr", "datafusion-physical-expr", "hashbrown 0.14.5", - "indexmap 2.1.0", + "indexmap 2.4.0", "itertools 0.12.1", "log 0.4.21", "paste", @@ -3095,7 +3095,7 @@ dependencies = [ "half", "hashbrown 0.14.5", "hex", - "indexmap 2.1.0", + "indexmap 2.4.0", "itertools 0.12.1", "log 0.4.21", "paste", @@ -3141,7 +3141,7 @@ dependencies = [ "futures", "half", "hashbrown 0.14.5", - "indexmap 2.1.0", + "indexmap 2.4.0", "itertools 0.12.1", "log 0.4.21", "once_cell", @@ -4663,7 +4663,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.1.0", - "indexmap 2.1.0", + "indexmap 2.4.0", "slab", "tokio", "tokio-util 0.7.11", @@ -5301,9 +5301,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.1.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" +checksum = "93ead53efc7ea8ed3cfb0c79fc8023fbb782a5432b52830b6518941cebe6505c" dependencies = [ "equivalent", "hashbrown 0.14.5", @@ -6426,6 +6426,16 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg 1.1.0", + "rawpointer", +] + [[package]] name = "maybe-owned" version = "0.3.4" @@ -6768,6 +6778,36 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "ndarray-stats" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17ebbe97acce52d06aebed4cd4a87c0941f4b2519b59b82b4feb5bd0ce003dfd" +dependencies = [ + "indexmap 2.4.0", + "itertools 0.13.0", + "ndarray", + "noisy_float", + "num-integer", + "num-traits", + "rand 0.8.5", +] + [[package]] name = "new_debug_unreachable" version = "1.0.4" @@ -6807,6 +6847,15 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb" +[[package]] +name = "noisy_float" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978fe6e6ebc0bf53de533cd456ca2d9de13de13856eda1518a285d7705a213af" +dependencies = [ + "num-traits", +] + [[package]] name = "nom" version = "7.1.3" @@ -7173,7 +7222,7 @@ checksum = "1e32339a5dc40459130b3bd269e9892439f55b33e772d2a9d402a789baaf4e8a" dependencies = [ "futures-core", "futures-sink", - "indexmap 2.1.0", + "indexmap 2.4.0", "js-sys", "once_cell", "pin-project-lite", @@ -7638,7 +7687,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" dependencies = [ "fixedbitset 0.4.2", - "indexmap 2.1.0", + "indexmap 2.4.0", ] [[package]] @@ -7920,7 +7969,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9d34169e64b3c7a80c8621a48adaf44e0cf62c78a9b25dd9dd35f1881a17cf9" dependencies = [ "base64 0.21.7", - "indexmap 2.1.0", + "indexmap 2.4.0", "line-wrap", "quick-xml 0.31.0", "serde", @@ -8057,7 +8106,7 @@ dependencies = [ "comfy-table", "either", "hashbrown 0.14.5", - "indexmap 2.1.0", + "indexmap 2.4.0", "num-traits", "once_cell", "polars-arrow", @@ -8201,7 +8250,7 @@ dependencies = [ "either", "hashbrown 0.14.5", "hex", - "indexmap 2.1.0", + "indexmap 2.4.0", "memchr", "num-traits", "polars-arrow", @@ -8359,7 +8408,7 @@ dependencies = [ "ahash 0.8.11", "bytemuck", "hashbrown 0.14.5", - "indexmap 2.1.0", + "indexmap 2.4.0", "num-traits", "once_cell", "polars-error", @@ -8415,6 +8464,15 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" +[[package]] +name = "portable-atomic-util" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcdd8420072e66d54a407b3316991fe946ce3ab1083a7f575b2463866624704d" +dependencies = [ + "portable-atomic", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -9210,6 +9268,12 @@ dependencies = [ "bitflags 2.4.0", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.10.0" @@ -9623,7 +9687,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9cbd33e0b668aea0ab238b9164523aca929096f9f40834700d71d91dd4888882" dependencies = [ "either", - "indexmap 2.1.0", + "indexmap 2.4.0", "rquickjs-core", "rquickjs-macro", ] @@ -9638,7 +9702,7 @@ dependencies = [ "chrono", "dlopen", "either", - "indexmap 2.1.0", + "indexmap 2.4.0", "phf 0.11.2", "relative-path", "rquickjs-sys", @@ -9653,7 +9717,7 @@ dependencies = [ "convert_case 0.6.0", "fnv", "ident_case", - "indexmap 2.1.0", + "indexmap 2.4.0", "phf_generator 0.11.2", "phf_shared 0.11.2", "proc-macro-crate 1.3.1", @@ -10274,7 +10338,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.1.0", + "indexmap 2.4.0", "serde", "serde_derive", "serde_json", @@ -10407,6 +10471,8 @@ dependencies = [ "async-trait", "futures", "lancedb", + "ndarray", + "ndarray-stats", "polars", "polars-lazy", "rand 0.8.5", @@ -11913,7 +11979,7 @@ version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ - "indexmap 2.1.0", + "indexmap 2.4.0", "toml_datetime", "winnow", ] @@ -11924,7 +11990,7 @@ version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d34d383cd00a163b4a5b85053df514d45bc330f6de7737edfe0a93311d1eaa03" dependencies = [ - "indexmap 2.1.0", + "indexmap 2.4.0", "serde", "serde_spanned", "toml_datetime", @@ -12428,7 +12494,7 @@ version = "4.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23" dependencies = [ - "indexmap 2.1.0", + "indexmap 2.4.0", "serde", "serde_json", "utoipa-gen", @@ -13135,7 +13201,7 @@ dependencies = [ "crossbeam-utils", "displaydoc", "flate2", - "indexmap 2.1.0", + "indexmap 2.4.0", "num_enum", "thiserror", ] diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml index ca6579cc7..3545f3d77 100644 --- a/shinkai-libs/shinkai-graphrag/Cargo.toml +++ b/shinkai-libs/shinkai-graphrag/Cargo.toml @@ -21,5 +21,7 @@ uuid = { version = "1.6.1", features = ["v4"] } [dev-dependencies] async-openai = "0.23.4" +ndarray = "0.16.1" +ndarray-stats = "0.6.0" reqwest = { version = "0.11.26", features = ["json"] } tiktoken-rs = "0.5.9" diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs index 933936d0b..8e01adb0a 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs @@ -1,8 +1,8 @@ -use std::{collections::HashSet, sync::Arc}; +use std::collections::HashSet; use crate::{ input::retrieval::entities::{get_entity_by_key, get_entity_by_name}, - llm::llm::BaseTextEmbedding, + llm::base::BaseTextEmbedding, models::Entity, vector_stores::{lancedb::LanceDBVectorStore, vector_store::VectorStore}, }; diff --git a/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs b/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs index 378b09a39..b644967f6 100644 --- a/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs +++ b/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs @@ -15,8 +15,8 @@ use crate::{ pub async fn store_entity_semantic_embeddings( entities: Vec, - mut vectorstore: LanceDBVectorStore, -) -> anyhow::Result { + vectorstore: &mut LanceDBVectorStore, +) -> anyhow::Result<()> { let documents: Vec = entities .into_iter() .map(|entity| { @@ -29,14 +29,16 @@ pub async fn store_entity_semantic_embeddings( VectorStoreDocument { id: entity.id, text: entity.description, - vector: entity.description_embedding, + vector: entity + .description_embedding + .map(|v| v.into_iter().map(|f| f as f32).collect()), attributes, } }) .collect(); vectorstore.load_documents(documents, true).await?; - Ok(vectorstore) + Ok(()) } pub fn read_entities( @@ -55,6 +57,7 @@ pub fn read_entities( rank_col: Option<&str>, // attributes_cols: Option>, ) -> anyhow::Result> { + let df_column_names = df.get_column_names(); let column_names = [ Some(id_col), short_id_col, @@ -70,7 +73,10 @@ pub fn read_entities( rank_col, ] .iter() - .filter_map(|&v| v.map(|v| v.to_string())) + .filter_map(|&v| { + v.map(|v| v.to_string()) + .filter(|v| df_column_names.contains(&v.as_str())) + }) .collect::>(); let column_names = column_names.into_iter().collect::>().into_vec(); @@ -233,6 +239,7 @@ pub fn read_community_reports( _content_embedding_col: Option<&str>, // attributes_cols: Option<&[&str]>, ) -> anyhow::Result> { + let df_column_names = df.get_column_names(); let column_names = [ Some(id_col), short_id_col, @@ -243,7 +250,10 @@ pub fn read_community_reports( rank_col, ] .iter() - .filter_map(|&v| v.map(|v| v.to_string())) + .filter_map(|&v| { + v.map(|v| v.to_string()) + .filter(|v| df_column_names.contains(&v.as_str())) + }) .collect::>(); let column_names: Vec = column_names.into_iter().collect::>().into_vec(); @@ -331,6 +341,7 @@ pub fn read_relationships( document_ids_col: Option<&str>, attributes_cols: Option>, ) -> anyhow::Result> { + let df_column_names = df.get_column_names(); let mut column_names = [ Some(id_col), short_id_col, @@ -343,7 +354,10 @@ pub fn read_relationships( document_ids_col, ] .iter() - .filter_map(|&v| v.map(|v| v.to_string())) + .filter_map(|&v| { + v.map(|v| v.to_string()) + .filter(|v| df_column_names.contains(&v.as_str())) + }) .collect::>(); attributes_cols.as_ref().map(|cols| { @@ -479,6 +493,7 @@ pub fn read_text_units( embedding_col: Option<&str>, attributes_cols: Option>, ) -> anyhow::Result> { + let df_column_names = df.get_column_names(); let mut column_names = [ Some(id_col), short_id_col, @@ -490,7 +505,10 @@ pub fn read_text_units( embedding_col, ] .iter() - .filter_map(|&v| v.map(|v| v.to_string())) + .filter_map(|&v| { + v.map(|v| v.to_string()) + .filter(|v| df_column_names.contains(&v.as_str())) + }) .collect::>(); attributes_cols.as_ref().map(|cols| { diff --git a/shinkai-libs/shinkai-graphrag/src/input/retrieval/relationships.rs b/shinkai-libs/shinkai-graphrag/src/input/retrieval/relationships.rs index 83353dd86..0651f9561 100644 --- a/shinkai-libs/shinkai-graphrag/src/input/retrieval/relationships.rs +++ b/shinkai-libs/shinkai-graphrag/src/input/retrieval/relationships.rs @@ -5,14 +5,14 @@ use polars::{frame::DataFrame, prelude::NamedFrom, series::Series}; use crate::models::{Entity, Relationship}; pub fn get_in_network_relationships( - selected_entities: &Vec, - relationships: &Vec, + selected_entities: &[Entity], + relationships: &[Relationship], ranking_attribute: &str, ) -> Vec { let selected_entity_names: Vec = selected_entities.iter().map(|entity| entity.title.clone()).collect(); let selected_relationships: Vec = relationships - .clone() + .to_owned() .into_iter() .filter(|relationship| { selected_entity_names.contains(&relationship.source) && selected_entity_names.contains(&relationship.target) @@ -28,8 +28,8 @@ pub fn get_in_network_relationships( } pub fn get_out_network_relationships( - selected_entities: &Vec, - relationships: &Vec, + selected_entities: &[Entity], + relationships: &[Relationship], ranking_attribute: &str, ) -> Vec { let selected_entity_names: Vec = selected_entities.iter().map(|e| e.title.clone()).collect(); @@ -51,22 +51,19 @@ pub fn get_out_network_relationships( sort_relationships_by_ranking_attribute(selected_relationships, selected_entities.to_vec(), ranking_attribute) } -pub fn get_candidate_relationships( - selected_entities: &Vec, - relationships: &Vec, -) -> Vec { +pub fn get_candidate_relationships(selected_entities: &[Entity], relationships: &[Relationship]) -> Vec { let selected_entity_names: Vec = selected_entities.iter().map(|entity| entity.title.clone()).collect(); relationships .iter() - .cloned() .filter(|relationship| { selected_entity_names.contains(&relationship.source) || selected_entity_names.contains(&relationship.target) }) + .cloned() .collect() } -pub fn get_entities_from_relationships(relationships: &Vec, entities: &Vec) -> Vec { +pub fn get_entities_from_relationships(relationships: &[Relationship], entities: &[Entity]) -> Vec { let selected_entity_names: Vec = relationships .iter() .flat_map(|relationship| vec![relationship.source.clone(), relationship.target.clone()]) @@ -74,8 +71,8 @@ pub fn get_entities_from_relationships(relationships: &Vec, entiti entities .iter() - .cloned() .filter(|entity| selected_entity_names.contains(&entity.title)) + .cloned() .collect() } @@ -194,7 +191,7 @@ pub fn to_relationship_dataframe( .attributes .unwrap_or_default() .keys() - .map(|s| s.clone()) + .cloned() .collect::>() } else { Vec::new() diff --git a/shinkai-libs/shinkai-graphrag/src/input/retrieval/text_units.rs b/shinkai-libs/shinkai-graphrag/src/input/retrieval/text_units.rs index ad839b88e..3849361ec 100644 --- a/shinkai-libs/shinkai-graphrag/src/input/retrieval/text_units.rs +++ b/shinkai-libs/shinkai-graphrag/src/input/retrieval/text_units.rs @@ -4,10 +4,7 @@ use polars::{frame::DataFrame, prelude::NamedFrom, series::Series}; use crate::models::{Entity, TextUnit}; -pub fn get_candidate_text_units( - selected_entities: &Vec, - text_units: &Vec, -) -> anyhow::Result { +pub fn get_candidate_text_units(selected_entities: &Vec, text_units: &[TextUnit]) -> anyhow::Result { let mut selected_text_ids: HashSet = HashSet::new(); for entity in selected_entities { @@ -20,8 +17,8 @@ pub fn get_candidate_text_units( let selected_text_units: Vec = text_units .iter() - .cloned() .filter(|unit| selected_text_ids.contains(&unit.id)) + .cloned() .collect(); to_text_unit_dataframe(selected_text_units) @@ -39,7 +36,7 @@ pub fn to_text_unit_dataframe(text_units: Vec) -> anyhow::Result>() } else { Vec::new() diff --git a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs b/shinkai-libs/shinkai-graphrag/src/llm/base.rs similarity index 79% rename from shinkai-libs/shinkai-graphrag/src/llm/llm.rs rename to shinkai-libs/shinkai-graphrag/src/llm/base.rs index 8d27e1563..4dc414de4 100644 --- a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs +++ b/shinkai-libs/shinkai-graphrag/src/llm/base.rs @@ -3,16 +3,12 @@ use std::collections::HashMap; use async_trait::async_trait; use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct BaseLLMCallback { pub response: Vec, } impl BaseLLMCallback { - pub fn new() -> Self { - BaseLLMCallback { response: Vec::new() } - } - pub fn on_llm_new_token(&mut self, token: &str) { self.response.push(token.to_string()); } @@ -29,7 +25,6 @@ pub enum MessageType { pub struct LLMParams { pub max_tokens: u32, pub temperature: f32, - pub response_format: HashMap, } #[async_trait] @@ -46,8 +41,7 @@ pub trait BaseLLM { #[async_trait] pub trait BaseTextEmbedding { - async fn aembed(&self, text: &str) -> Vec; - fn embed(&self, text: &str) -> Vec; + async fn aembed(&self, text: &str) -> anyhow::Result>; } pub enum GlobalSearchPhase { diff --git a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs index 214bbef7c..6cf245d4d 100644 --- a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs @@ -1 +1 @@ -pub mod llm; +pub mod base; diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs index 79f16f1e0..5aed5257c 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs @@ -1,2 +1,2 @@ -pub mod global_search; pub mod prompts; +pub mod search; diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/search.rs similarity index 93% rename from shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs rename to shinkai-libs/shinkai-graphrag/src/search/global_search/search.rs index f4c15fb1c..716f55f77 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/search.rs @@ -5,7 +5,7 @@ use std::time::Instant; use crate::context_builder::community_context::{CommunityContextBuilderParams, GlobalCommunityContext}; use crate::context_builder::context_builder::ConversationHistory; -use crate::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType}; +use crate::llm::base::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType}; use crate::search::base::{ContextData, ContextText, KeyPoint, ResponseType}; use crate::search::global_search::prompts::NO_DATA_ANSWER; @@ -33,7 +33,7 @@ pub struct GlobalSearchResult { pub reduce_context_text: ContextText, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct GlobalSearchLLMCallback { response: Vec, map_response_contexts: Vec, @@ -41,14 +41,6 @@ pub struct GlobalSearchLLMCallback { } impl GlobalSearchLLMCallback { - pub fn new() -> Self { - GlobalSearchLLMCallback { - response: Vec::new(), - map_response_contexts: Vec::new(), - map_response_outputs: Vec::new(), - } - } - pub fn on_map_response_start(&mut self, map_response_contexts: Vec) { self.map_response_contexts = map_response_contexts; } @@ -83,7 +75,6 @@ pub struct GlobalSearchParams { pub response_type: String, pub allow_general_knowledge: bool, pub general_knowledge_inclusion_prompt: Option, - pub json_mode: bool, pub callbacks: Option>, pub max_data_tokens: usize, pub map_llm_params: LLMParams, @@ -102,7 +93,6 @@ impl GlobalSearch { response_type, allow_general_knowledge, general_knowledge_inclusion_prompt, - json_mode, callbacks, max_data_tokens, map_llm_params, @@ -112,14 +102,6 @@ impl GlobalSearch { let mut map_llm_params = map_llm_params; - if json_mode { - map_llm_params - .response_format - .insert("type".to_string(), "json_object".to_string()); - } else { - map_llm_params.response_format.remove("response_format"); - } - let map_system_prompt = map_system_prompt.unwrap_or(MAP_SYSTEM_PROMPT.to_string()); let reduce_system_prompt = reduce_system_prompt.unwrap_or(REDUCE_SYSTEM_PROMPT.to_string()); let general_knowledge_inclusion_prompt = @@ -219,15 +201,16 @@ impl GlobalSearch { let start_time = Instant::now(); let search_prompt = self.map_system_prompt.replace("{context_data}", context_data); - let mut search_messages = Vec::new(); - search_messages.push(HashMap::from([ - ("role".to_string(), "system".to_string()), - ("content".to_string(), search_prompt.clone()), - ])); - search_messages.push(HashMap::from([ - ("role".to_string(), "user".to_string()), - ("content".to_string(), query.to_string()), - ])); + let mut search_messages = vec![ + HashMap::from([ + ("role".to_string(), "system".to_string()), + ("content".to_string(), search_prompt.clone()), + ]), + HashMap::from([ + ("role".to_string(), "user".to_string()), + ("content".to_string(), query.to_string()), + ]), + ]; let search_response = self .llm diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs index 09bd00b92..b9487dc01 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs @@ -14,7 +14,7 @@ use crate::{ source_context::{build_text_unit_context, count_relationships}, }, input::retrieval::{community_reports::get_candidate_communities, text_units::get_candidate_text_units}, - llm::llm::BaseTextEmbedding, + llm::base::BaseTextEmbedding, models::{CommunityReport, Entity, Relationship, TextUnit}, vector_stores::lancedb::LanceDBVectorStore, }; @@ -194,7 +194,7 @@ impl LocalSearchMixedContext { } // build local (i.e. entity-relationship-covariate) context - let local_prop = 1 as f32 - community_prop - text_unit_prop; + let local_prop = 1_f32 - community_prop - text_unit_prop; let local_tokens = std::cmp::max((max_tokens as f32 * local_prop) as usize, 0); let (local_context, local_context_data) = self._build_local_context( selected_entities.clone(), @@ -566,8 +566,10 @@ impl LocalSearchMixedContext { let mut context_data = context_data; if return_candidate_context { - let candidate_context_data = - get_candidate_text_units(&selected_entities, &self.text_units.values().cloned().collect())?; + let candidate_context_data = get_candidate_text_units( + &selected_entities, + &self.text_units.values().cloned().collect::>(), + )?; let context_key = context_name.to_lowercase(); if !context_data.contains_key(&context_key) { diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/mod.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/mod.rs index 2908c58f2..eb73d0830 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/local_search/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/mod.rs @@ -1,3 +1,3 @@ -pub mod local_search; pub mod mixed_context; pub mod prompts; +pub mod search; diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/local_search.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/search.rs similarity index 77% rename from shinkai-libs/shinkai-graphrag/src/search/local_search/local_search.rs rename to shinkai-libs/shinkai-graphrag/src/search/local_search/search.rs index fbbe65638..6c555edf8 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/local_search/local_search.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/search.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, time::Instant}; use crate::{ - llm::llm::{BaseLLM, LLMParams, MessageType}, + llm::base::{BaseLLM, LLMParams, MessageType}, search::base::{ContextData, ContextText, ResponseType}, }; @@ -54,25 +54,27 @@ impl LocalSearch { pub async fn asearch(&self, query: String) -> anyhow::Result { let start_time = Instant::now(); - let (context_text, context_records) = self - .context_builder - .build_context(self.context_builder_params.clone()) - .await?; + + let mut context_builder_params = self.context_builder_params.clone(); + context_builder_params.query = query.clone(); + + let (context_text, context_records) = self.context_builder.build_context(context_builder_params).await?; let search_prompt = self .system_prompt .replace("{context_data}", &context_text) .replace("{response_type}", &self.response_type); - let mut search_messages = Vec::new(); - search_messages.push(HashMap::from([ - ("role".to_string(), "system".to_string()), - ("content".to_string(), search_prompt.clone()), - ])); - search_messages.push(HashMap::from([ - ("role".to_string(), "user".to_string()), - ("content".to_string(), query.to_string()), - ])); + let search_messages = vec![ + HashMap::from([ + ("role".to_string(), "system".to_string()), + ("content".to_string(), search_prompt.clone()), + ]), + HashMap::from([ + ("role".to_string(), "user".to_string()), + ("content".to_string(), query.to_string()), + ]), + ]; let search_response = self .llm diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs index 8ee793426..fb520dd53 100644 --- a/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, sync::Arc}; -use arrow::datatypes::Float64Type; -use arrow_array::{Array, Float64Array, ListArray, RecordBatch, RecordBatchIterator, StringArray}; +use arrow::datatypes::Float32Type; +use arrow_array::{Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray}; use futures::TryStreamExt; use lancedb::{ arrow::arrow_schema::{DataType, Field, Schema}, @@ -11,7 +11,7 @@ use lancedb::{ }; use serde_json::json; -use crate::llm::llm::BaseTextEmbedding; +use crate::llm::base::BaseTextEmbedding; use super::vector_store::{VectorStore, VectorStoreDocument, VectorStoreSearchResult}; @@ -38,7 +38,7 @@ impl LanceDBVectorStore { async fn similarity_search_by_vector( &self, - query_embedding: Vec, + query_embedding: Vec, k: usize, ) -> anyhow::Result> { if let Some(document_collection) = &self.document_collection { @@ -69,7 +69,7 @@ impl LanceDBVectorStore { .column_by_name("vector") .unwrap() .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap(); let attributes_col = record .column_by_name("attributes") @@ -82,24 +82,24 @@ impl LanceDBVectorStore { .column_by_name("_distance") .unwrap() .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap(); - if id_col.len() == 0 - || text_col.len() == 0 - || vector_col.len() == 0 - || attributes_col.len() == 0 - || distance_col.len() == 0 + if id_col.is_empty() + || text_col.is_empty() + || vector_col.is_empty() + || attributes_col.is_empty() + || distance_col.is_empty() { continue; } let id = id_col.value(0).to_string(); let text = text_col.value(0).to_string(); - let vector: Vec = vector_col + let vector: Vec = vector_col .value(0) .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap() .iter() .map(|value| value.unwrap()) @@ -134,7 +134,7 @@ impl VectorStore for LanceDBVectorStore { text_embedder: &Box, k: usize, ) -> anyhow::Result> { - let query_embedding = text_embedder.embed(text); + let query_embedding = text_embedder.aembed(text).await?; if query_embedding.is_empty() { return Ok(vec![]); @@ -155,12 +155,21 @@ impl VectorStore for LanceDBVectorStore { .filter(|document| document.vector.is_some()) .collect(); + let vector_dimension = if !data.is_empty() { + data[0].vector.as_ref().map(|v| v.len()).unwrap_or_default() + } else { + 0 + }; + let schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Utf8, false), Field::new("text", DataType::Utf8, true), Field::new( "vector", - DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + vector_dimension.try_into().unwrap(), + ), true, ), Field::new("attributes", DataType::Utf8, false), @@ -178,18 +187,19 @@ impl VectorStore for LanceDBVectorStore { .map(|document| document.text.clone().unwrap_or_default()) .collect::>(), )), - Arc::new(ListArray::from_iter_primitive::( + Arc::new(FixedSizeListArray::from_iter_primitive::( data.iter() .map(|document| { Some( document .vector .as_ref() - .map(|v| v.iter().map(|f| Some(f.clone())).collect::>()) + .map(|v| v.iter().map(|f| Some(*f)).collect::>()) .unwrap_or_default(), ) }) .collect::>(), + vector_dimension.try_into().unwrap(), )), Arc::new(StringArray::from( data.iter() @@ -205,6 +215,8 @@ impl VectorStore for LanceDBVectorStore { }; if overwrite { + let _ = db_connection.drop_table(&self.collection_name).await; + if let Some(batches) = batches { let table = db_connection .create_table(&self.collection_name, Box::new(batches)) diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs index aa863f09f..e6a3503ab 100644 --- a/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs @@ -1,16 +1,16 @@ use std::collections::HashMap; -use crate::llm::llm::BaseTextEmbedding; +use crate::llm::base::BaseTextEmbedding; pub struct VectorStoreSearchResult { pub document: VectorStoreDocument, - pub score: f64, + pub score: f32, } pub struct VectorStoreDocument { pub id: String, pub text: Option, - pub vector: Option>, + pub vector: Option>, pub attributes: HashMap, } diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs index 06a1f1903..f95eaa32b 100644 --- a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs @@ -2,8 +2,8 @@ use polars::{io::SerReader, prelude::ParquetReader}; use shinkai_graphrag::{ context_builder::community_context::{CommunityContextBuilderParams, GlobalCommunityContext}, indexer_adapters::{read_indexer_entities, read_indexer_reports}, - llm::llm::LLMParams, - search::global_search::global_search::{GlobalSearch, GlobalSearchParams}, + llm::base::LLMParams, + search::global_search::search::{GlobalSearch, GlobalSearchParams}, }; use utils::{ ollama::Ollama, @@ -67,13 +67,11 @@ async fn ollama_global_search_test() -> Result<(), Box> { let map_llm_params = LLMParams { max_tokens: 1000, temperature: 0.0, - response_format: std::collections::HashMap::from([("type".to_string(), "json_object".to_string())]), }; let reduce_llm_params = LLMParams { max_tokens: 2000, temperature: 0.0, - response_format: std::collections::HashMap::new(), }; // Perform global search @@ -87,7 +85,6 @@ async fn ollama_global_search_test() -> Result<(), Box> { response_type: String::from("multiple paragraphs"), allow_general_knowledge: false, general_knowledge_inclusion_prompt: None, - json_mode: true, callbacks: None, max_data_tokens: 5000, map_llm_params, @@ -164,13 +161,11 @@ async fn openai_global_search_test() -> Result<(), Box> { let map_llm_params = LLMParams { max_tokens: 1000, temperature: 0.0, - response_format: std::collections::HashMap::from([("type".to_string(), "json_object".to_string())]), }; let reduce_llm_params = LLMParams { max_tokens: 2000, temperature: 0.0, - response_format: std::collections::HashMap::new(), }; // Perform global search @@ -184,7 +179,6 @@ async fn openai_global_search_test() -> Result<(), Box> { response_type: String::from("multiple paragraphs"), allow_general_knowledge: false, general_knowledge_inclusion_prompt: None, - json_mode: true, callbacks: None, max_data_tokens: 12_000, map_llm_params, diff --git a/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs new file mode 100644 index 000000000..767a43bf5 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs @@ -0,0 +1,150 @@ +use polars::{io::SerReader, prelude::ParquetReader}; +use shinkai_graphrag::{ + indexer_adapters::{ + read_indexer_entities, read_indexer_relationships, read_indexer_reports, read_indexer_text_units, + }, + input::loaders::dfs::store_entity_semantic_embeddings, + llm::base::LLMParams, + search::local_search::{ + mixed_context::{LocalSearchMixedContext, MixedContextBuilderParams}, + search::LocalSearch, + }, + vector_stores::lancedb::LanceDBVectorStore, +}; +use utils::openai::{num_tokens, ChatOpenAI, OpenAIEmbedding}; + +mod utils; + +#[tokio::test] +async fn openai_local_search_test() -> Result<(), Box> { + let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap(); + let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap(); + let embedding_model = std::env::var("GRAPHRAG_EMBEDDING_MODEL").unwrap(); + + let llm = ChatOpenAI::new(Some(api_key.clone()), llm_model, 5); + let text_embedder = OpenAIEmbedding::new(Some(api_key), embedding_model, 8191, 5); + + // Load community reports + // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip + + let input_dir = "./dataset"; + let lancedb_uri = format!("{}/lancedb", input_dir); + + let community_report_table = "create_final_community_reports"; + let entity_table = "create_final_nodes"; + let entity_embedding_table = "create_final_entities"; + let relationship_table = "create_final_relationships"; + let text_unit_table = "create_final_text_units"; + let community_level = 2; + + // Read entities + let mut entity_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_table)).unwrap(); + let entity_df = ParquetReader::new(&mut entity_file).finish().unwrap(); + + let mut entity_embedding_file = + std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_embedding_table)).unwrap(); + let entity_embedding_df = ParquetReader::new(&mut entity_embedding_file).finish().unwrap(); + + let entities = read_indexer_entities(&entity_df, &entity_embedding_df, community_level)?; + + let mut description_embedding_store = LanceDBVectorStore::new("entity_description_embeddings".to_string()); + description_embedding_store.connect(&lancedb_uri).await?; + + store_entity_semantic_embeddings(entities.clone(), &mut description_embedding_store).await?; + + println!("Entities ({}): {:?}", entity_df.height(), entity_df.head(Some(5))); + + // Read relationships + let mut relationship_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, relationship_table)).unwrap(); + let relationship_df = ParquetReader::new(&mut relationship_file).finish().unwrap(); + + let relationships = read_indexer_relationships(&relationship_df)?; + + println!( + "Relationships ({}): {:?}", + relationship_df.height(), + relationship_df.head(Some(5)) + ); + + // Read community reports + let mut report_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, community_report_table)).unwrap(); + let report_df = ParquetReader::new(&mut report_file).finish().unwrap(); + + let reports = read_indexer_reports(&report_df, &entity_df, community_level)?; + + println!("Reports ({}): {:?}", report_df.height(), report_df.head(Some(5))); + + // Read text units + let mut text_unit_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, text_unit_table)).unwrap(); + let text_unit_df = ParquetReader::new(&mut text_unit_file).finish().unwrap(); + + let text_units = read_indexer_text_units(&text_unit_df)?; + + println!( + "Text units ({}): {:?}", + text_unit_df.height(), + text_unit_df.head(Some(5)) + ); + + // Create local search context builder + let context_builder = LocalSearchMixedContext::new( + entities, + description_embedding_store, + Box::new(text_embedder), + Some(text_units), + Some(reports), + Some(relationships), + num_tokens, + "id".to_string(), + ); + + // Create local search engine + let local_context_params = MixedContextBuilderParams { + text_unit_prop: 0.5, + community_prop: 0.1, + top_k_mapped_entities: 10, + top_k_relationships: 10, + include_entity_rank: true, + include_relationship_weight: true, + include_community_rank: false, + return_candidate_context: false, + max_tokens: 12_000, + + query: "".to_string(), + include_entity_names: None, + exclude_entity_names: None, + rank_description: "number of relationships".to_string(), + relationship_ranking_attribute: "rank".to_string(), + use_community_summary: false, + min_community_rank: 0, + community_context_name: "Reports".to_string(), + column_delimiter: "|".to_string(), + }; + + let llm_params = LLMParams { + max_tokens: 2000, + temperature: 0.0, + }; + + let search_engine = LocalSearch::new( + Box::new(llm), + context_builder, + num_tokens, + llm_params, + local_context_params, + String::from("multiple paragraphs"), + None, + ); + + let result = search_engine.asearch("Tell me about Agent Mercer".to_string()).await?; + println!("Response: {:?}", result.response); + + let result = search_engine + .asearch("Tell me about Dr. Jordan Hayes".to_string()) + .await?; + println!("Response: {:?}", result.response); + + println!("Context: {:?}", result.context_data); + + Ok(()) +} diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs index 41d3619b8..029336608 100644 --- a/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs +++ b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::json; -use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType}; +use shinkai_graphrag::llm::base::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType}; #[derive(Serialize, Deserialize, Debug)] pub struct OllamaResponse { diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs index 255d5b4e5..4c67a3f17 100644 --- a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs +++ b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs @@ -2,12 +2,16 @@ use async_openai::{ config::OpenAIConfig, types::{ ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionResponseFormat, - ChatCompletionResponseFormatType, CreateChatCompletionRequestArgs, + ChatCompletionResponseFormatType, CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs, }, Client, }; use async_trait::async_trait; -use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType}; +use ndarray::{Array1, Array2, Axis}; +use ndarray_stats::SummaryStatisticsExt; +use shinkai_graphrag::llm::base::{ + BaseLLM, BaseLLMCallback, BaseTextEmbedding, GlobalSearchPhase, LLMParams, MessageType, +}; use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer}; pub struct ChatOpenAI { @@ -93,24 +97,9 @@ impl ChatOpenAI { .map(|m| Into::::into(m.clone())) .collect::>(); - let _response_format = if llm_params - .response_format - .get_key_value("type") - .is_some_and(|(_k, v)| v == "json_object") - { - ChatCompletionResponseFormat { - r#type: ChatCompletionResponseFormatType::JsonObject, - } - } else { - ChatCompletionResponseFormat { - r#type: ChatCompletionResponseFormatType::Text, - } - }; - let request = CreateChatCompletionRequestArgs::default() .max_tokens(llm_params.max_tokens) .temperature(llm_params.temperature) - //.response_format(response_format) .model(self.model.clone()) .messages(request_messages) .build()?; @@ -139,8 +128,127 @@ impl BaseLLM for ChatOpenAI { } } +pub struct OpenAIEmbedding { + pub api_key: Option, + pub model: String, + pub max_tokens: usize, // 8191 + pub max_retries: usize, +} + +impl OpenAIEmbedding { + pub fn new(api_key: Option, model: String, max_tokens: usize, max_retries: usize) -> Self { + OpenAIEmbedding { + api_key, + model, + max_tokens, + max_retries, + } + } + + async fn _aembed_with_retry(&self, text: &str) -> anyhow::Result> { + let mut retry_count = 0; + + loop { + match self._aembed(text).await { + Ok(response) => return Ok(response), + Err(e) => { + if retry_count < self.max_retries { + retry_count += 1; + continue; + } + return Err(e); + } + } + } + } + + async fn _aembed(&self, text: &str) -> anyhow::Result> { + let client = match &self.api_key { + Some(api_key) => Client::with_config(OpenAIConfig::new().with_api_key(api_key)), + None => Client::new(), + }; + + let request = CreateEmbeddingRequestArgs::default() + .model(&self.model) + .input([text.to_string()]) + .build()?; + + let response = client.embeddings().create(request).await?; + let embedding = response + .data + .get(0) + .map(|data| data.embedding.clone()) + .unwrap_or_default(); + + Ok(embedding) + } +} + +#[async_trait] +impl BaseTextEmbedding for OpenAIEmbedding { + async fn aembed(&self, text: &str) -> anyhow::Result> { + let token_chunks = chunk_text(text, self.max_tokens); + let mut chunk_embeddings = Vec::new(); + let mut chunk_lens = Vec::new(); + + for chunk in token_chunks { + let embedding = self._aembed_with_retry(&chunk).await?; + chunk_embeddings.push(embedding); + chunk_lens.push(chunk.len()); + } + + let rows = chunk_embeddings.len(); + let cols = chunk_embeddings[0].len(); + let flat_embeddings: Vec = chunk_embeddings.into_iter().flatten().collect(); + let array_embeddings = Array2::from_shape_vec((rows, cols), flat_embeddings).unwrap(); + + let array_lens = Array1::from_iter(chunk_lens.into_iter().map(|x| x as f32)); + + // Calculate the weighted average + let weighted_avg = array_embeddings.weighted_mean_axis(Axis(0), &array_lens).unwrap(); + + // Normalize the embeddings + let norm = weighted_avg.mapv(|x| x.powi(2)).sum().sqrt(); + let normalized_embeddings = weighted_avg / norm; + + Ok(normalized_embeddings.to_vec()) + } +} + pub fn num_tokens(text: &str) -> usize { let token_encoder = Tokenizer::Cl100kBase; let bpe = get_bpe_from_tokenizer(token_encoder).unwrap(); bpe.encode_with_special_tokens(text).len() } + +fn batched(iterable: impl Iterator, n: usize) -> impl Iterator> { + if n < 1 { + panic!("n must be at least one"); + } + + let mut it = iterable.peekable(); + std::iter::from_fn(move || { + let mut batch = Vec::with_capacity(n); + for _ in 0..n { + if let Some(item) = it.next() { + batch.push(item); + } else { + break; + } + } + if batch.is_empty() { + None + } else { + Some(batch) + } + }) +} + +fn chunk_text<'a>(text: &'a str, max_tokens: usize) -> impl Iterator + 'a { + let token_encoder = Tokenizer::Cl100kBase; + let bpe = get_bpe_from_tokenizer(token_encoder).unwrap(); + let tokens = bpe.encode_with_special_tokens(text); + + let chunk_iterator = batched(tokens.into_iter(), max_tokens); + chunk_iterator.map(move |chunk| bpe.decode(chunk).unwrap()) +} From 888c6bab5ccde9959259a88cf686fa1651a0b440 Mon Sep 17 00:00:00 2001 From: benolt Date: Thu, 29 Aug 2024 16:00:52 +0200 Subject: [PATCH 18/23] test context --- .../src/context_builder/local_context.rs | 6 +-- .../src/search/global_search/search.rs | 4 +- .../tests/local_search_tests.rs | 45 +++++++++---------- .../shinkai-graphrag/tests/utils/openai.rs | 8 +++- 4 files changed, 32 insertions(+), 31 deletions(-) diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs index eb73a6a53..a883e577f 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs @@ -74,7 +74,7 @@ pub fn build_entity_context( new_context.push(entity.rank.unwrap_or(0).to_string()); records - .entry("rank") + .entry(rank_description) .or_insert_with(Vec::new) .push(entity.rank.map(|r| r.to_string()).unwrap_or_default()); } @@ -111,10 +111,10 @@ pub fn build_entity_context( let mut data_series = Vec::new(); for (header, data_values) in records { - if header == "rank" { + if include_entity_rank && header == rank_description { let data_values = data_values .iter() - .map(|v| v.parse::().unwrap_or(0.0)) + .map(|v| v.parse::().unwrap_or(0)) .collect::>(); let series = Series::new(header, data_values); data_series.push(series); diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/search.rs index 716f55f77..917b2b784 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/global_search/search.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/search.rs @@ -100,8 +100,6 @@ impl GlobalSearch { context_builder_params, } = global_search_params; - let mut map_llm_params = map_llm_params; - let map_system_prompt = map_system_prompt.unwrap_or(MAP_SYSTEM_PROMPT.to_string()); let reduce_system_prompt = reduce_system_prompt.unwrap_or(REDUCE_SYSTEM_PROMPT.to_string()); let general_knowledge_inclusion_prompt = @@ -201,7 +199,7 @@ impl GlobalSearch { let start_time = Instant::now(); let search_prompt = self.map_system_prompt.replace("{context_data}", context_data); - let mut search_messages = vec![ + let search_messages = vec![ HashMap::from([ ("role".to_string(), "system".to_string()), ("content".to_string(), search_prompt.clone()), diff --git a/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs index 767a43bf5..33a767b57 100644 --- a/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs @@ -6,7 +6,7 @@ use shinkai_graphrag::{ input::loaders::dfs::store_entity_semantic_embeddings, llm::base::LLMParams, search::local_search::{ - mixed_context::{LocalSearchMixedContext, MixedContextBuilderParams}, + mixed_context::{default_local_context_params, LocalSearchMixedContext}, search::LocalSearch, }, vector_stores::lancedb::LanceDBVectorStore, @@ -99,27 +99,16 @@ async fn openai_local_search_test() -> Result<(), Box> { ); // Create local search engine - let local_context_params = MixedContextBuilderParams { - text_unit_prop: 0.5, - community_prop: 0.1, - top_k_mapped_entities: 10, - top_k_relationships: 10, - include_entity_rank: true, - include_relationship_weight: true, - include_community_rank: false, - return_candidate_context: false, - max_tokens: 12_000, - - query: "".to_string(), - include_entity_names: None, - exclude_entity_names: None, - rank_description: "number of relationships".to_string(), - relationship_ranking_attribute: "rank".to_string(), - use_community_summary: false, - min_community_rank: 0, - community_context_name: "Reports".to_string(), - column_delimiter: "|".to_string(), - }; + let mut local_context_params = default_local_context_params(); + local_context_params.text_unit_prop = 0.5; + local_context_params.community_prop = 0.1; + local_context_params.top_k_mapped_entities = 10; + local_context_params.top_k_relationships = 10; + local_context_params.include_entity_rank = true; + local_context_params.include_relationship_weight = true; + local_context_params.include_community_rank = false; + local_context_params.return_candidate_context = false; + local_context_params.max_tokens = 12_000; let llm_params = LLMParams { max_tokens: 2000, @@ -144,7 +133,17 @@ async fn openai_local_search_test() -> Result<(), Box> { .await?; println!("Response: {:?}", result.response); - println!("Context: {:?}", result.context_data); + match result.context_data { + shinkai_graphrag::search::base::ContextData::Dictionary(dict) => { + for (entity, df) in dict.iter() { + println!("Data: {} ({})", entity, df.height()); + println!("{:?}", df.head(Some(10))); + } + } + data => { + println!("Context data: {:?}", data); + } + } Ok(()) } diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs index 4c67a3f17..ab85366f5 100644 --- a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs +++ b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs @@ -1,8 +1,8 @@ use async_openai::{ config::OpenAIConfig, types::{ - ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionResponseFormat, - ChatCompletionResponseFormatType, CreateChatCompletionRequestArgs, CreateEmbeddingRequestArgs, + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, CreateChatCompletionRequestArgs, + CreateEmbeddingRequestArgs, }, Client, }; @@ -197,6 +197,10 @@ impl BaseTextEmbedding for OpenAIEmbedding { chunk_lens.push(chunk.len()); } + if chunk_embeddings.len() == 1 { + return Ok(chunk_embeddings.swap_remove(0)); + } + let rows = chunk_embeddings.len(); let cols = chunk_embeddings[0].len(); let flat_embeddings: Vec = chunk_embeddings.into_iter().flatten().collect(); From 55860bb322ac5ec5fe9d97cc33177d14c16a8043 Mon Sep 17 00:00:00 2001 From: benolt Date: Fri, 30 Aug 2024 15:00:26 +0200 Subject: [PATCH 19/23] lancedb search by vector u;dates --- .../src/vector_stores/lancedb.rs | 62 +++++++++++-------- .../tests/local_search_tests.rs | 4 +- 2 files changed, 39 insertions(+), 27 deletions(-) diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs index fb520dd53..83c52cbff 100644 --- a/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs @@ -94,30 +94,32 @@ impl LanceDBVectorStore { continue; } - let id = id_col.value(0).to_string(); - let text = text_col.value(0).to_string(); - let vector: Vec = vector_col - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .map(|value| value.unwrap()) - .collect(); - let attributes: HashMap = serde_json::from_str(attributes_col.value(0))?; - - let distance = distance_col.value(0); - - let score = 1.0 - distance.abs(); - - let doc = VectorStoreDocument { - id, - text: Some(text), - vector: Some(vector), - attributes, - }; - - results.push(VectorStoreSearchResult { document: doc, score }); + for i in 0..record.num_rows() { + let id = id_col.value(i).to_string(); + let text = text_col.value(i).to_string(); + let vector: Vec = vector_col + .value(i) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|value| value.unwrap()) + .collect(); + let attributes: HashMap = serde_json::from_str(attributes_col.value(i))?; + + let distance = distance_col.value(i); + + let score = 1.0 - distance.abs(); + + let doc = VectorStoreDocument { + id, + text: Some(text), + vector: Some(vector), + attributes, + }; + + results.push(VectorStoreSearchResult { document: doc, score }); + } } return Ok(results); @@ -233,7 +235,17 @@ impl VectorStore for LanceDBVectorStore { self.document_collection = Some(table); } } else { - let table = db_connection.open_table(&self.collection_name).execute().await?; + let table = match db_connection.open_table(&self.collection_name).execute().await { + Ok(table) => table, + Err(_) => { + let table = db_connection + .create_empty_table(&self.collection_name, schema.clone()) + .execute() + .await?; + + table + } + }; if let Some(batches) = batches { table.add(batches).execute().await?; diff --git a/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs index 33a767b57..e20cca6a3 100644 --- a/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs @@ -126,12 +126,12 @@ async fn openai_local_search_test() -> Result<(), Box> { ); let result = search_engine.asearch("Tell me about Agent Mercer".to_string()).await?; - println!("Response: {:?}", result.response); + println!("Response: {:?}\n", result.response); let result = search_engine .asearch("Tell me about Dr. Jordan Hayes".to_string()) .await?; - println!("Response: {:?}", result.response); + println!("Response: {:?}\n", result.response); match result.context_data { shinkai_graphrag::search::base::ContextData::Dictionary(dict) => { From fabaecbfa587b202fcc6f0a2ef2a520c73ffa318 Mon Sep 17 00:00:00 2001 From: benolt Date: Mon, 2 Sep 2024 15:00:33 +0200 Subject: [PATCH 20/23] test with ollama, adjust relationships --- .../src/context_builder/local_context.rs | 5 +- .../tests/global_search_tests.rs | 8 +- .../tests/local_search_tests.rs | 144 +++++++++++++++++- .../shinkai-graphrag/tests/utils/ollama.rs | 67 ++++++-- .../shinkai-graphrag/tests/utils/openai.rs | 8 +- 5 files changed, 208 insertions(+), 24 deletions(-) diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs index a883e577f..ca4f36154 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs @@ -405,7 +405,10 @@ fn _filter_relationships( let relationship_budget = top_k_relationships * selected_entities.len(); out_network_relationships.truncate(relationship_budget); - Vec::new() + let mut selected_relationships = in_network_relationships; + selected_relationships.extend(out_network_relationships); + + selected_relationships } pub fn get_candidate_context( diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs index f95eaa32b..f72583cfe 100644 --- a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs @@ -6,7 +6,7 @@ use shinkai_graphrag::{ search::global_search::search::{GlobalSearch, GlobalSearchParams}, }; use utils::{ - ollama::Ollama, + ollama::OllamaChat, openai::{num_tokens, ChatOpenAI}, }; @@ -15,9 +15,9 @@ mod utils; // #[tokio::test] async fn ollama_global_search_test() -> Result<(), Box> { let base_url = "http://localhost:11434"; - let model_type = "llama3.1"; + let model = "llama3.1"; - let llm = Ollama::new(base_url.to_string(), model_type.to_string()); + let llm = OllamaChat::new(base_url, model); // Load community reports // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip @@ -113,7 +113,7 @@ async fn openai_global_search_test() -> Result<(), Box> { let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap(); let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap(); - let llm = ChatOpenAI::new(Some(api_key), llm_model, 5); + let llm = ChatOpenAI::new(Some(api_key), &llm_model, 5); // Load community reports // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip diff --git a/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs index e20cca6a3..eb471fb11 100644 --- a/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs @@ -11,18 +11,156 @@ use shinkai_graphrag::{ }, vector_stores::lancedb::LanceDBVectorStore, }; -use utils::openai::{num_tokens, ChatOpenAI, OpenAIEmbedding}; +use utils::{ + ollama::OllamaChat, + openai::{num_tokens, ChatOpenAI, OpenAIEmbedding}, +}; mod utils; +#[tokio::test] +async fn ollama_local_search_test() -> Result<(), Box> { + let base_url = "http://localhost:11434"; + let llm_model = "llama3.1"; + let llm = OllamaChat::new(base_url, llm_model); + + // Using OpenAI embeddings since the dataset was created with OpenAI embeddings + let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap(); + let embedding_model = std::env::var("GRAPHRAG_EMBEDDING_MODEL").unwrap(); + let text_embedder = OpenAIEmbedding::new(Some(api_key), &embedding_model, 8191, 5); + + // Load community reports + // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip + + let input_dir = "./dataset"; + let lancedb_uri = format!("{}/lancedb", input_dir); + + let community_report_table = "create_final_community_reports"; + let entity_table = "create_final_nodes"; + let entity_embedding_table = "create_final_entities"; + let relationship_table = "create_final_relationships"; + let text_unit_table = "create_final_text_units"; + let community_level = 2; + + // Read entities + let mut entity_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_table)).unwrap(); + let entity_df = ParquetReader::new(&mut entity_file).finish().unwrap(); + + let mut entity_embedding_file = + std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_embedding_table)).unwrap(); + let entity_embedding_df = ParquetReader::new(&mut entity_embedding_file).finish().unwrap(); + + let entities = read_indexer_entities(&entity_df, &entity_embedding_df, community_level)?; + + let mut description_embedding_store = LanceDBVectorStore::new("entity_description_embeddings".to_string()); + description_embedding_store.connect(&lancedb_uri).await?; + + store_entity_semantic_embeddings(entities.clone(), &mut description_embedding_store).await?; + + println!("Entities ({}): {:?}", entity_df.height(), entity_df.head(Some(5))); + + // Read relationships + let mut relationship_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, relationship_table)).unwrap(); + let relationship_df = ParquetReader::new(&mut relationship_file).finish().unwrap(); + + let relationships = read_indexer_relationships(&relationship_df)?; + + println!( + "Relationships ({}): {:?}", + relationship_df.height(), + relationship_df.head(Some(5)) + ); + + // Read community reports + let mut report_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, community_report_table)).unwrap(); + let report_df = ParquetReader::new(&mut report_file).finish().unwrap(); + + let reports = read_indexer_reports(&report_df, &entity_df, community_level)?; + + println!("Reports ({}): {:?}", report_df.height(), report_df.head(Some(5))); + + // Read text units + let mut text_unit_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, text_unit_table)).unwrap(); + let text_unit_df = ParquetReader::new(&mut text_unit_file).finish().unwrap(); + + let text_units = read_indexer_text_units(&text_unit_df)?; + + println!( + "Text units ({}): {:?}", + text_unit_df.height(), + text_unit_df.head(Some(5)) + ); + + // Create local search context builder + let context_builder = LocalSearchMixedContext::new( + entities, + description_embedding_store, + Box::new(text_embedder), + Some(text_units), + Some(reports), + Some(relationships), + num_tokens, + "id".to_string(), + ); + + // Create local search engine + let mut local_context_params = default_local_context_params(); + local_context_params.text_unit_prop = 0.5; + local_context_params.community_prop = 0.1; + local_context_params.top_k_mapped_entities = 10; + local_context_params.top_k_relationships = 10; + local_context_params.include_entity_rank = true; + local_context_params.include_relationship_weight = true; + local_context_params.include_community_rank = false; + local_context_params.return_candidate_context = false; + local_context_params.max_tokens = 12_000; + + let llm_params = LLMParams { + max_tokens: 2000, + temperature: 0.0, + }; + + let search_engine = LocalSearch::new( + Box::new(llm), + context_builder, + num_tokens, + llm_params, + local_context_params, + String::from("multiple paragraphs"), + None, + ); + + let result = search_engine.asearch("Tell me about Agent Mercer".to_string()).await?; + println!("Response: {:?}\n", result.response); + + let result = search_engine + .asearch("Tell me about Dr. Jordan Hayes".to_string()) + .await?; + println!("Response: {:?}\n", result.response); + + match result.context_data { + shinkai_graphrag::search::base::ContextData::Dictionary(dict) => { + for (entity, df) in dict.iter() { + println!("Data: {} ({})", entity, df.height()); + println!("{:?}", df.head(Some(10))); + } + } + data => { + println!("Context data: {:?}", data); + } + } + + Ok(()) +} + #[tokio::test] async fn openai_local_search_test() -> Result<(), Box> { let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap(); let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap(); let embedding_model = std::env::var("GRAPHRAG_EMBEDDING_MODEL").unwrap(); - let llm = ChatOpenAI::new(Some(api_key.clone()), llm_model, 5); - let text_embedder = OpenAIEmbedding::new(Some(api_key), embedding_model, 8191, 5); + let llm = ChatOpenAI::new(Some(api_key.clone()), &llm_model, 5); + let text_embedder = OpenAIEmbedding::new(Some(api_key), &embedding_model, 8191, 5); // Load community reports // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs index 029336608..b34a12471 100644 --- a/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs +++ b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs @@ -2,34 +2,45 @@ use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::json; -use shinkai_graphrag::llm::base::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType}; +use shinkai_graphrag::llm::base::{ + BaseLLM, BaseLLMCallback, BaseTextEmbedding, GlobalSearchPhase, LLMParams, MessageType, +}; #[derive(Serialize, Deserialize, Debug)] -pub struct OllamaResponse { +pub struct OllamaChatResponse { pub model: String, pub created_at: String, - pub message: OllamaMessage, + pub message: OllamaChatMessage, } #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] -pub struct OllamaMessage { +pub struct OllamaChatMessage { pub role: String, pub content: String, } -pub struct Ollama { +#[derive(Serialize, Deserialize, Debug)] +pub struct OllamaEmbeddingResponse { + pub model: String, + pub embeddings: Vec>, +} + +pub struct OllamaChat { base_url: String, - model_type: String, + model: String, } -impl Ollama { - pub fn new(base_url: String, model_type: String) -> Self { - Ollama { base_url, model_type } +impl OllamaChat { + pub fn new(base_url: &str, model: &str) -> Self { + OllamaChat { + base_url: base_url.to_string(), + model: model.to_string(), + } } } #[async_trait] -impl BaseLLM for Ollama { +impl BaseLLM for OllamaChat { async fn agenerate( &self, messages: MessageType, @@ -87,14 +98,46 @@ impl BaseLLM for Ollama { }; let payload = json!({ - "model": self.model_type, + "model": self.model, "messages": messages_json, "stream": false, }); let response = client.post(chat_url).json(&payload).send().await?; - let response = response.json::().await?; + let response = response.json::().await?; Ok(response.message.content) } } + +pub struct OllamaEmbedding { + base_url: String, + model: String, +} + +impl OllamaEmbedding { + pub fn new(base_url: &str, model: &str) -> Self { + OllamaEmbedding { + base_url: base_url.to_string(), + model: model.to_string(), + } + } +} + +#[async_trait] +impl BaseTextEmbedding for OllamaEmbedding { + async fn aembed(&self, text: &str) -> anyhow::Result> { + let client = Client::new(); + let embedding_url = format!("{}{}", &self.base_url, "/api/embedding"); + + let payload = json!({ + "model": self.model, + "input": text, + }); + + let response = client.post(embedding_url).json(&payload).send().await?; + let response = response.json::().await?; + + Ok(response.embeddings.first().cloned().unwrap_or_default()) + } +} diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs index ab85366f5..392184f48 100644 --- a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs +++ b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs @@ -21,10 +21,10 @@ pub struct ChatOpenAI { } impl ChatOpenAI { - pub fn new(api_key: Option, model: String, max_retries: usize) -> Self { + pub fn new(api_key: Option, model: &str, max_retries: usize) -> Self { ChatOpenAI { api_key, - model, + model: model.to_string(), max_retries, } } @@ -136,10 +136,10 @@ pub struct OpenAIEmbedding { } impl OpenAIEmbedding { - pub fn new(api_key: Option, model: String, max_tokens: usize, max_retries: usize) -> Self { + pub fn new(api_key: Option, model: &str, max_tokens: usize, max_retries: usize) -> Self { OpenAIEmbedding { api_key, - model, + model: model.to_string(), max_tokens, max_retries, } From b3c5ddb793cd1cb7418f28fdbaa7c7b09850dbea Mon Sep 17 00:00:00 2001 From: benolt Date: Tue, 3 Sep 2024 11:08:03 +0200 Subject: [PATCH 21/23] ollama search adjustments, fix build text unit context --- .../src/context_builder/community_context.rs | 12 +- ...ext_builder.rs => conversation_history.rs} | 0 .../src/context_builder/entity_extraction.rs | 8 +- .../src/context_builder/local_context.rs | 24 +-- .../src/context_builder/mod.rs | 2 +- .../src/context_builder/source_context.rs | 8 +- .../shinkai-graphrag/src/input/loaders/dfs.rs | 171 ++++++++---------- .../src/input/retrieval/community_reports.rs | 2 +- .../src/input/retrieval/entities.rs | 8 +- .../src/input/retrieval/relationships.rs | 4 +- shinkai-libs/shinkai-graphrag/src/llm/base.rs | 6 - .../src/search/global_search/search.rs | 19 +- .../src/search/local_search/mixed_context.rs | 32 ++-- .../src/search/local_search/search.rs | 3 +- .../src/vector_stores/lancedb.rs | 8 +- .../src/vector_stores/vector_store.rs | 2 +- .../tests/global_search_tests.rs | 8 +- .../tests/local_search_tests.rs | 6 +- .../shinkai-graphrag/tests/utils/ollama.rs | 52 +----- .../shinkai-graphrag/tests/utils/openai.rs | 14 +- 20 files changed, 159 insertions(+), 230 deletions(-) rename shinkai-libs/shinkai-graphrag/src/context_builder/{context_builder.rs => conversation_history.rs} (100%) diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs index 053d6aecd..b7b7e8ac7 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs @@ -365,7 +365,7 @@ impl Batch { fn init_batch( &mut self, context_name: &str, - header: &Vec, + header: &[String], column_delimiter: &str, num_tokens_fn: fn(&str) -> usize, ) { @@ -379,7 +379,7 @@ impl Batch { all_context_text: &mut Vec, all_context_records: &mut Vec, entities: Option>, - header: &Vec, + header: &[String], community_weight_name: &str, community_rank_name: &str, include_community_weight: bool, @@ -399,7 +399,7 @@ impl Batch { let mut record_df = Self::_convert_report_context_to_df( self.batch_records.clone(), - header.clone(), + header.to_owned(), weight_column, rank_column, )?; @@ -423,6 +423,10 @@ impl Batch { buffer.set_position(0); buffer.read_to_string(&mut current_context_text)?; + if all_context_text.contains(¤t_context_text) { + return Ok(()); + } + all_context_text.push(current_context_text); all_context_records.push(record_df); @@ -451,7 +455,7 @@ impl Batch { let record_df = DataFrame::new(data_series)?; - return Self::_rank_report_context(record_df, weight_column, rank_column); + Self::_rank_report_context(record_df, weight_column, rank_column) } fn _rank_report_context( diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/conversation_history.rs similarity index 100% rename from shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs rename to shinkai-libs/shinkai-graphrag/src/context_builder/conversation_history.rs diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs index 8e01adb0a..305d02af0 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs @@ -10,7 +10,7 @@ use crate::{ pub async fn map_query_to_entities( query: &str, text_embedding_vectorstore: &LanceDBVectorStore, - text_embedder: &Box, + text_embedder: &(dyn BaseTextEmbedding + Send + Sync), all_entities: &Vec, embedding_vectorstore_key: &str, include_entity_names: Option>, @@ -18,8 +18,8 @@ pub async fn map_query_to_entities( k: usize, oversample_scaler: usize, ) -> anyhow::Result> { - let include_entity_names = include_entity_names.unwrap_or_else(Vec::new); - let exclude_entity_names: HashSet = exclude_entity_names.unwrap_or_else(Vec::new).into_iter().collect(); + let include_entity_names = include_entity_names.unwrap_or_default(); + let exclude_entity_names: HashSet = exclude_entity_names.unwrap_or_default().into_iter().collect(); let mut matched_entities = Vec::new(); if !query.is_empty() { @@ -28,7 +28,7 @@ pub async fn map_query_to_entities( .await?; for result in search_results { - if let Some(matched) = get_entity_by_key(all_entities, &embedding_vectorstore_key, &result.document.id) { + if let Some(matched) = get_entity_by_key(all_entities, embedding_vectorstore_key, &result.document.id) { matched_entities.push(matched); } } diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs index ca4f36154..95cbbe4da 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs @@ -41,7 +41,7 @@ pub fn build_entity_context( .attributes .unwrap_or_default() .keys() - .map(|s| s.clone()) + .cloned() .collect::>() } else { Vec::new() @@ -134,8 +134,8 @@ pub fn build_entity_context( } pub fn build_relationship_context( - selected_entities: &Vec, - relationships: &Vec, + selected_entities: &[Entity], + relationships: &[Relationship], num_tokens_fn: fn(&str) -> usize, include_relationship_weight: bool, max_tokens: usize, @@ -146,8 +146,8 @@ pub fn build_relationship_context( ) -> anyhow::Result<(String, DataFrame)> { // Filter relationships based on the criteria let selected_relationships = _filter_relationships( - &selected_entities, - &relationships, + selected_entities, + relationships, top_k_relationships, relationship_ranking_attribute, ); @@ -173,7 +173,7 @@ pub fn build_relationship_context( .attributes .unwrap_or_default() .keys() - .map(|s| s.clone()) + .cloned() .collect::>() } else { Vec::new() @@ -271,8 +271,8 @@ pub fn build_relationship_context( } fn _filter_relationships( - selected_entities: &Vec, - relationships: &Vec, + selected_entities: &[Entity], + relationships: &[Relationship], top_k_relationships: usize, relationship_ranking_attribute: &str, ) -> Vec { @@ -307,7 +307,7 @@ fn _filter_relationships( let out_network_entity_names: HashSet = out_network_source_names .into_iter() - .chain(out_network_target_names.into_iter()) + .chain(out_network_target_names) .collect(); let mut out_network_entity_links: HashMap = HashMap::new(); @@ -412,9 +412,9 @@ fn _filter_relationships( } pub fn get_candidate_context( - selected_entities: &Vec, - entities: &Vec, - relationships: &Vec, + selected_entities: &[Entity], + entities: &[Entity], + relationships: &[Relationship], include_entity_rank: bool, entity_rank_description: &str, include_relationship_weight: bool, diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs index f781afd72..a2ed01737 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs @@ -1,5 +1,5 @@ pub mod community_context; -pub mod context_builder; +pub mod conversation_history; pub mod entity_extraction; pub mod local_context; pub mod source_context; diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/source_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/source_context.rs index 0746cf94b..af4b0205b 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/source_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/source_context.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use polars::frame::DataFrame; use polars::prelude::NamedFrom; @@ -23,6 +23,9 @@ pub fn build_text_unit_context( let mut text_units = text_units; + let mut unique_ids = HashSet::new(); + text_units.retain(|unit| unique_ids.insert(unit.id.clone())); + if shuffle_data { let mut rng = StdRng::seed_from_u64(random_state); text_units.shuffle(&mut rng); @@ -35,8 +38,7 @@ pub fn build_text_unit_context( text_unit .attributes .unwrap_or_default() - .keys() - .map(|s| s.clone()) + .keys().cloned() .collect::>() } else { Vec::new() diff --git a/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs b/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs index b644967f6..62066f1ba 100644 --- a/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs +++ b/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs @@ -104,29 +104,26 @@ pub fn read_entities( let mut entities = Vec::new(); for (idx, row) in rows.iter().enumerate() { let report = Entity { - id: get_field(&row, id_col, &column_names) + id: get_field(row, id_col, &column_names) .map(|id| id.to_string()) - .unwrap_or(String::new()), + .unwrap_or_default(), short_id: Some( short_id_col - .map(|short_id| get_field(&row, short_id, &column_names)) - .flatten() + .and_then(|short_id| get_field(row, short_id, &column_names)) .map(|short_id| short_id.to_string()) .unwrap_or(idx.to_string()), ), - title: get_field(&row, title_col, &column_names) + title: get_field(row, title_col, &column_names) .map(|title| title.to_string()) - .unwrap_or(String::new()), + .unwrap_or_default(), entity_type: type_col - .map(|type_col| get_field(&row, type_col, &column_names)) - .flatten() + .and_then(|type_col| get_field(row, type_col, &column_names)) .map(|entity_type| entity_type.to_string()), description: description_col - .map(|description_col| get_field(&row, description_col, &column_names)) - .flatten() + .and_then(|description_col| get_field(row, description_col, &column_names)) .map(|description| description.to_string()), name_embedding: name_embedding_col.map(|name_embedding_col| { - get_field(&row, name_embedding_col, &column_names) + get_field(row, name_embedding_col, &column_names) .map(|name_embedding| match name_embedding { AnyValue::List(series) => series .f64() @@ -136,10 +133,10 @@ pub fn read_entities( .collect::>(), value => vec![value.to_string().parse::().unwrap_or(0.0)], }) - .unwrap_or_else(|| Vec::new()) + .unwrap_or_default() }), description_embedding: description_embedding_col.map(|description_embedding_col| { - get_field(&row, description_embedding_col, &column_names) + get_field(row, description_embedding_col, &column_names) .map(|description_embedding| match description_embedding { AnyValue::List(series) => series .f64() @@ -149,10 +146,10 @@ pub fn read_entities( .collect::>(), value => vec![value.to_string().parse::().unwrap_or(0.0)], }) - .unwrap_or_else(|| Vec::new()) + .unwrap_or_default() }), graph_embedding: graph_embedding_col.map(|graph_embedding_col| { - get_field(&row, graph_embedding_col, &column_names) + get_field(row, graph_embedding_col, &column_names) .map(|graph_embedding| match graph_embedding { AnyValue::List(series) => series .f64() @@ -162,10 +159,10 @@ pub fn read_entities( .collect::>(), value => vec![value.to_string().parse::().unwrap_or(0.0)], }) - .unwrap_or_else(|| Vec::new()) + .unwrap_or_default() }), community_ids: community_col.map(|community_col| { - get_field(&row, community_col, &column_names) + get_field(row, community_col, &column_names) .map(|community_ids| match community_ids { AnyValue::List(series) => series .str() @@ -175,10 +172,10 @@ pub fn read_entities( .collect::>(), value => vec![value.to_string()], }) - .unwrap_or_else(|| Vec::new()) + .unwrap_or_default() }), text_unit_ids: text_unit_ids_col.map(|text_unit_ids_col| { - get_field(&row, text_unit_ids_col, &column_names) + get_field(row, text_unit_ids_col, &column_names) .map(|text_unit_ids| match text_unit_ids { AnyValue::List(series) => series .str() @@ -188,10 +185,10 @@ pub fn read_entities( .collect::>(), value => vec![value.to_string()], }) - .unwrap_or_else(|| Vec::new()) + .unwrap_or_default() }), document_ids: document_ids_col.map(|document_ids_col| { - get_field(&row, document_ids_col, &column_names) + get_field(row, document_ids_col, &column_names) .map(|document_ids| match document_ids { AnyValue::List(series) => series .str() @@ -201,13 +198,11 @@ pub fn read_entities( .collect::>(), value => vec![value.to_string()], }) - .unwrap_or_else(|| Vec::new()) + .unwrap_or_default() + }), + rank: rank_col.and_then(|rank_col| { + get_field(row, rank_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0)) }), - rank: rank_col - .map(|rank_col| { - get_field(&row, rank_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0)) - }) - .flatten(), attributes: None, }; entities.push(report); @@ -281,33 +276,30 @@ pub fn read_community_reports( let mut reports = Vec::new(); for (idx, row) in rows.iter().enumerate() { let report = CommunityReport { - id: get_field(&row, id_col, &column_names) + id: get_field(row, id_col, &column_names) .map(|id| id.to_string()) - .unwrap_or(String::new()), + .unwrap_or_default(), short_id: Some( short_id_col - .map(|short_id| get_field(&row, short_id, &column_names)) - .flatten() + .and_then(|short_id| get_field(row, short_id, &column_names)) .map(|short_id| short_id.to_string()) .unwrap_or(idx.to_string()), ), - title: get_field(&row, title_col, &column_names) + title: get_field(row, title_col, &column_names) .map(|title| title.to_string()) - .unwrap_or(String::new()), - community_id: get_field(&row, community_col, &column_names) + .unwrap_or_default(), + community_id: get_field(row, community_col, &column_names) .map(|community| community.to_string()) - .unwrap_or(String::new()), - summary: get_field(&row, summary_col, &column_names) + .unwrap_or_default(), + summary: get_field(row, summary_col, &column_names) .map(|summary| summary.to_string()) - .unwrap_or(String::new()), - full_content: get_field(&row, content_col, &column_names) + .unwrap_or_default(), + full_content: get_field(row, content_col, &column_names) .map(|content| content.to_string()) - .unwrap_or(String::new()), - rank: rank_col - .map(|rank_col| { - get_field(&row, rank_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0.0)) - }) - .flatten(), + .unwrap_or_default(), + rank: rank_col.and_then(|rank_col| { + get_field(row, rank_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0.0)) + }), summary_embedding: None, full_content_embedding: None, attributes: None, @@ -360,11 +352,11 @@ pub fn read_relationships( }) .collect::>(); - attributes_cols.as_ref().map(|cols| { + if let Some(cols) = attributes_cols.as_ref() { cols.iter().for_each(|col| { column_names.insert(col.to_string()); }); - }); + } let column_names = column_names.into_iter().collect::>(); @@ -391,28 +383,26 @@ pub fn read_relationships( let mut relationships = Vec::new(); for (idx, row) in rows.iter().enumerate() { let report = Relationship { - id: get_field(&row, id_col, &column_names) + id: get_field(row, id_col, &column_names) .map(|id| id.to_string()) - .unwrap_or(String::new()), + .unwrap_or_default(), short_id: Some( short_id_col - .map(|short_id| get_field(&row, short_id, &column_names)) - .flatten() + .and_then(|short_id| get_field(row, short_id, &column_names)) .map(|short_id| short_id.to_string()) .unwrap_or(idx.to_string()), ), - source: get_field(&row, source_col, &column_names) + source: get_field(row, source_col, &column_names) .map(|source| source.to_string()) - .unwrap_or(String::new()), - target: get_field(&row, target_col, &column_names) + .unwrap_or_default(), + target: get_field(row, target_col, &column_names) .map(|target| target.to_string()) - .unwrap_or(String::new()), + .unwrap_or_default(), description: description_col - .map(|description| get_field(&row, description, &column_names)) - .flatten() + .and_then(|description| get_field(row, description, &column_names)) .map(|description| description.to_string()), description_embedding: description_embedding_col.map(|description_embedding_col| { - get_field(&row, description_embedding_col, &column_names) + get_field(row, description_embedding_col, &column_names) .map(|description_embedding| match description_embedding { AnyValue::List(series) => series .f64() @@ -422,15 +412,13 @@ pub fn read_relationships( .collect::>(), value => vec![value.to_string().parse::().unwrap_or(0.0)], }) - .unwrap_or_else(|| Vec::new()) + .unwrap_or_default() + }), + weight: weight_col.and_then(|weight_col| { + get_field(row, weight_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0.0)) }), - weight: weight_col - .map(|weight_col| { - get_field(&row, weight_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0.0)) - }) - .flatten(), text_unit_ids: text_unit_ids_col.map(|text_unit_ids_col| { - get_field(&row, text_unit_ids_col, &column_names) + get_field(row, text_unit_ids_col, &column_names) .map(|text_unit_ids| match text_unit_ids { AnyValue::List(series) => series .str() @@ -440,10 +428,10 @@ pub fn read_relationships( .collect::>(), value => vec![value.to_string()], }) - .unwrap_or_else(|| Vec::new()) + .unwrap_or_default() }), document_ids: document_ids_col.map(|document_ids_col| { - get_field(&row, document_ids_col, &column_names) + get_field(row, document_ids_col, &column_names) .map(|document_ids| match document_ids { AnyValue::List(series) => series .str() @@ -453,12 +441,12 @@ pub fn read_relationships( .collect::>(), value => vec![value.to_string()], }) - .unwrap_or_else(|| Vec::new()) + .unwrap_or_default() }), attributes: attributes_cols.as_ref().map(|cols| { cols.iter() .map(|col| { - get_field(&row, col, &column_names) + get_field(row, col, &column_names) .map(|v| (col.to_string(), v.to_string())) .unwrap_or((String::new(), String::new())) }) @@ -511,11 +499,11 @@ pub fn read_text_units( }) .collect::>(); - attributes_cols.as_ref().map(|cols| { + if let Some(cols) = attributes_cols.as_ref() { cols.iter().for_each(|col| { column_names.insert(col.to_string()); }); - }); + } let column_names = column_names.into_iter().collect::>(); @@ -542,21 +530,20 @@ pub fn read_text_units( let mut text_units = Vec::new(); for (idx, row) in rows.iter().enumerate() { let report = TextUnit { - id: get_field(&row, id_col, &column_names) + id: get_field(row, id_col, &column_names) .map(|id| id.to_string()) - .unwrap_or(String::new()), + .unwrap_or_default(), short_id: Some( short_id_col - .map(|short_id| get_field(&row, short_id, &column_names)) - .flatten() + .and_then(|short_id| get_field(row, short_id, &column_names)) .map(|short_id| short_id.to_string()) .unwrap_or(idx.to_string()), ), - text: get_field(&row, text_col, &column_names) + text: get_field(row, text_col, &column_names) .map(|text| text.to_string()) - .unwrap_or(String::new()), + .unwrap_or_default(), entity_ids: entities_col.map(|entities_col| { - get_field(&row, entities_col, &column_names) + get_field(row, entities_col, &column_names) .map(|entity_ids| match entity_ids { AnyValue::List(series) => series .str() @@ -566,10 +553,10 @@ pub fn read_text_units( .collect::>(), value => vec![value.to_string()], }) - .unwrap_or_else(|| Vec::new()) + .unwrap_or_default() }), relationship_ids: relationships_col.map(|relationships_col| { - get_field(&row, relationships_col, &column_names) + get_field(row, relationships_col, &column_names) .map(|relationship_ids| match relationship_ids { AnyValue::List(series) => series .str() @@ -579,10 +566,10 @@ pub fn read_text_units( .collect::>(), value => vec![value.to_string()], }) - .unwrap_or_else(|| Vec::new()) + .unwrap_or_default() }), text_embedding: embedding_col.map(|embedding_col| { - get_field(&row, embedding_col, &column_names) + get_field(row, embedding_col, &column_names) .map(|embedding| match embedding { AnyValue::List(series) => series .f64() @@ -592,15 +579,13 @@ pub fn read_text_units( .collect::>(), value => vec![value.to_string().parse::().unwrap_or(0.0)], }) - .unwrap_or_else(|| Vec::new()) + .unwrap_or_default() + }), + n_tokens: tokens_col.and_then(|tokens_col| { + get_field(row, tokens_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0)) }), - n_tokens: tokens_col - .map(|tokens_col| { - get_field(&row, tokens_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0)) - }) - .flatten(), document_ids: document_ids_col.map(|document_ids_col| { - get_field(&row, document_ids_col, &column_names) + get_field(row, document_ids_col, &column_names) .map(|document_ids| match document_ids { AnyValue::List(series) => series .str() @@ -610,12 +595,12 @@ pub fn read_text_units( .collect::>(), value => vec![value.to_string()], }) - .unwrap_or_else(|| Vec::new()) + .unwrap_or_default() }), attributes: attributes_cols.as_ref().map(|cols| { cols.iter() .map(|col| { - get_field(&row, col, &column_names) + get_field(row, col, &column_names) .map(|v| (col.to_string(), v.to_string())) .unwrap_or((String::new(), String::new())) }) @@ -638,11 +623,7 @@ pub fn read_text_units( Ok(unique_text_units) } -fn get_field<'a>( - row: &'a Vec>, - column_name: &'a str, - column_names: &'a Vec, -) -> Option> { +fn get_field<'a>(row: &'a [AnyValue<'a>], column_name: &'a str, column_names: &'a [String]) -> Option> { match column_names.iter().position(|x| x == column_name) { Some(index) => row.get(index).cloned(), None => None, diff --git a/shinkai-libs/shinkai-graphrag/src/input/retrieval/community_reports.rs b/shinkai-libs/shinkai-graphrag/src/input/retrieval/community_reports.rs index 504f93df2..002daca00 100644 --- a/shinkai-libs/shinkai-graphrag/src/input/retrieval/community_reports.rs +++ b/shinkai-libs/shinkai-graphrag/src/input/retrieval/community_reports.rs @@ -40,7 +40,7 @@ pub fn to_community_report_dataframe( let attribute_cols: Vec = reports[0] .attributes .as_ref() - .map(|attrs| attrs.keys().filter(|&col| !header.contains(&col)).cloned().collect()) + .map(|attrs| attrs.keys().filter(|&col| !header.contains(col)).cloned().collect()) .unwrap_or_default(); header.extend(attribute_cols.iter().cloned()); diff --git a/shinkai-libs/shinkai-graphrag/src/input/retrieval/entities.rs b/shinkai-libs/shinkai-graphrag/src/input/retrieval/entities.rs index bc62e48f1..f81315204 100644 --- a/shinkai-libs/shinkai-graphrag/src/input/retrieval/entities.rs +++ b/shinkai-libs/shinkai-graphrag/src/input/retrieval/entities.rs @@ -10,7 +10,7 @@ pub fn get_entity_by_key(entities: &Vec, key: &str, value: &str) -> Opti match key { "id" => { if entity.id == value - || is_valid_uuid(value) && entity.id == Uuid::parse_str(value).unwrap().to_string().replace("-", "") + || is_valid_uuid(value) && entity.id == Uuid::parse_str(value).unwrap().to_string().replace('-', "") { return Some(entity.clone()); } @@ -19,7 +19,7 @@ pub fn get_entity_by_key(entities: &Vec, key: &str, value: &str) -> Opti if entity.short_id.as_ref().unwrap_or(&"".to_string()) == value || is_valid_uuid(value) && entity.short_id.as_ref().unwrap_or(&"".to_string()) - == Uuid::parse_str(value).unwrap().to_string().replace("-", "").as_str() + == Uuid::parse_str(value).unwrap().to_string().replace('-', "").as_str() { return Some(entity.clone()); } @@ -45,7 +45,7 @@ pub fn get_entity_by_key(entities: &Vec, key: &str, value: &str) -> Opti None } -pub fn get_entity_by_name(entities: &Vec, entity_name: &str) -> Vec { +pub fn get_entity_by_name(entities: &[Entity], entity_name: &str) -> Vec { entities .iter() .filter(|entity| entity.title == entity_name) @@ -73,7 +73,7 @@ pub fn to_entity_dataframe( .attributes .unwrap_or_default() .keys() - .map(|s| s.clone()) + .cloned() .collect::>() } else { Vec::new() diff --git a/shinkai-libs/shinkai-graphrag/src/input/retrieval/relationships.rs b/shinkai-libs/shinkai-graphrag/src/input/retrieval/relationships.rs index 0651f9561..8c1ea7c94 100644 --- a/shinkai-libs/shinkai-graphrag/src/input/retrieval/relationships.rs +++ b/shinkai-libs/shinkai-graphrag/src/input/retrieval/relationships.rs @@ -12,11 +12,11 @@ pub fn get_in_network_relationships( let selected_entity_names: Vec = selected_entities.iter().map(|entity| entity.title.clone()).collect(); let selected_relationships: Vec = relationships - .to_owned() - .into_iter() + .iter() .filter(|relationship| { selected_entity_names.contains(&relationship.source) && selected_entity_names.contains(&relationship.target) }) + .cloned() .collect(); if selected_relationships.len() <= 1 { diff --git a/shinkai-libs/shinkai-graphrag/src/llm/base.rs b/shinkai-libs/shinkai-graphrag/src/llm/base.rs index 4dc414de4..3432a19c5 100644 --- a/shinkai-libs/shinkai-graphrag/src/llm/base.rs +++ b/shinkai-libs/shinkai-graphrag/src/llm/base.rs @@ -35,7 +35,6 @@ pub trait BaseLLM { streaming: bool, callbacks: Option>, llm_params: LLMParams, - search_phase: Option, ) -> anyhow::Result; } @@ -43,8 +42,3 @@ pub trait BaseLLM { pub trait BaseTextEmbedding { async fn aembed(&self, text: &str) -> anyhow::Result>; } - -pub enum GlobalSearchPhase { - Map, - Reduce, -} diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/search.rs index 917b2b784..34e68e3af 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/global_search/search.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/search.rs @@ -4,8 +4,8 @@ use std::collections::HashMap; use std::time::Instant; use crate::context_builder::community_context::{CommunityContextBuilderParams, GlobalCommunityContext}; -use crate::context_builder::context_builder::ConversationHistory; -use crate::llm::base::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType}; +use crate::context_builder::conversation_history::ConversationHistory; +use crate::llm::base::{BaseLLM, BaseLLMCallback, LLMParams, MessageType}; use crate::search::base::{ContextData, ContextText, KeyPoint, ResponseType}; use crate::search::global_search::prompts::NO_DATA_ANSWER; @@ -149,7 +149,7 @@ impl GlobalSearch { let map_responses: Vec<_> = join_all( context_chunks .iter() - .map(|data| self._map_response_single_batch(data, &query, self.map_llm_params.clone())), + .map(|data| self._map_response_single_batch(data, self.map_llm_params.clone())), ) .await; @@ -193,7 +193,6 @@ impl GlobalSearch { async fn _map_response_single_batch( &self, context_data: &str, - query: &str, llm_params: LLMParams, ) -> anyhow::Result { let start_time = Instant::now(); @@ -206,19 +205,13 @@ impl GlobalSearch { ]), HashMap::from([ ("role".to_string(), "user".to_string()), - ("content".to_string(), query.to_string()), + ("content".to_string(), "Respond using JSON".to_string()), ]), ]; let search_response = self .llm - .agenerate( - MessageType::Dictionary(search_messages), - false, - None, - llm_params, - Some(GlobalSearchPhase::Map), - ) + .agenerate(MessageType::Dictionary(search_messages), false, None, llm_params) .await?; let processed_response = self.parse_search_response(&search_response); @@ -234,6 +227,7 @@ impl GlobalSearch { } fn parse_search_response(&self, search_response: &str) -> Vec { + let search_response = &search_response.replace("{{", "{").replace("}}", "}"); let parsed_elements: Value = serde_json::from_str(search_response).unwrap_or_default(); if let Some(points) = parsed_elements.get("points") { @@ -372,7 +366,6 @@ impl GlobalSearch { true, llm_callbacks, llm_params, - Some(GlobalSearchPhase::Reduce), ) .await?; diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs index b9487dc01..4b394a22a 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs @@ -106,7 +106,7 @@ impl LocalSearchMixedContext { if let Some(units) = text_units { for unit in units { - context.text_units.insert(unit.id.clone(), unit); + context.text_units.insert(unit.id.replace('"', ""), unit); } } @@ -118,7 +118,7 @@ impl LocalSearchMixedContext { if let Some(relations) = relationships { for relation in relations { - context.relationships.insert(relation.id.clone(), relation); + context.relationships.insert(relation.id.replace('"', ""), relation); } } @@ -162,7 +162,7 @@ impl LocalSearchMixedContext { let selected_entities = map_query_to_entities( &query, &self.entity_text_embeddings, - &self.text_embedder, + &*self.text_embedder, &self.entities.values().cloned().collect::>(), &self.embedding_vectorstore_key, Some(include_entity_names), @@ -399,7 +399,7 @@ impl LocalSearchMixedContext { let (relationship_context, relationship_context_data) = build_relationship_context( &added_entities, - &self.relationships.values().cloned().collect(), + &self.relationships.values().cloned().collect::>(), self.num_tokens_fn, include_relationship_weight, max_tokens, @@ -429,8 +429,8 @@ impl LocalSearchMixedContext { final_context_data.insert("entities".to_string(), entity_context_data.clone()); if return_candidate_context { - let entities = self.entities.values().cloned().collect(); - let relationships = self.relationships.values().cloned().collect(); + let entities = self.entities.values().cloned().collect::>(); + let relationships = self.relationships.values().cloned().collect::>(); let candidate_context_data = get_candidate_context( &selected_entities, @@ -493,16 +493,16 @@ impl LocalSearchMixedContext { { let mut selected_unit = self.text_units[text_id].clone(); let num_relationships = count_relationships(&selected_unit, entity, &self.relationships); - selected_unit - .attributes - .as_mut() - .unwrap_or(&mut HashMap::new()) - .insert("entity_order".to_string(), index.to_string()); - selected_unit - .attributes - .as_mut() - .unwrap_or(&mut HashMap::new()) - .insert("num_relationships".to_string(), num_relationships.to_string()); + + if selected_unit.attributes.is_none() { + selected_unit.attributes = Some(HashMap::new()); + } + + if let Some(attributes) = &mut selected_unit.attributes { + attributes.insert("entity_order".to_string(), index.to_string()); + attributes.insert("num_relationships".to_string(), num_relationships.to_string()); + } + selected_text_units.push(selected_unit); } } diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/search.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/search.rs index 6c555edf8..1f87d06b4 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/local_search/search.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/search.rs @@ -56,7 +56,7 @@ impl LocalSearch { let start_time = Instant::now(); let mut context_builder_params = self.context_builder_params.clone(); - context_builder_params.query = query.clone(); + context_builder_params.query.clone_from(&query); let (context_text, context_records) = self.context_builder.build_context(context_builder_params).await?; @@ -83,7 +83,6 @@ impl LocalSearch { false, None, self.llm_params.clone(), - None, ) .await?; diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs index 83c52cbff..78f9ab8ae 100644 --- a/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs @@ -133,7 +133,7 @@ impl VectorStore for LanceDBVectorStore { async fn similarity_search_by_text( &self, text: &str, - text_embedder: &Box, + text_embedder: &(dyn BaseTextEmbedding + Send + Sync), k: usize, ) -> anyhow::Result> { let query_embedding = text_embedder.aembed(text).await?; @@ -238,12 +238,10 @@ impl VectorStore for LanceDBVectorStore { let table = match db_connection.open_table(&self.collection_name).execute().await { Ok(table) => table, Err(_) => { - let table = db_connection + db_connection .create_empty_table(&self.collection_name, schema.clone()) .execute() - .await?; - - table + .await? } }; diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs index e6a3503ab..d1745fe75 100644 --- a/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs @@ -18,7 +18,7 @@ pub trait VectorStore { fn similarity_search_by_text( &self, text: &str, - text_embedder: &Box, + text_embedder: &(dyn BaseTextEmbedding + Send + Sync), k: usize, ) -> impl std::future::Future>> + Send; diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs index f72583cfe..1dd7e696a 100644 --- a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs @@ -58,19 +58,19 @@ async fn ollama_global_search_test() -> Result<(), Box> { include_community_weight: true, community_weight_name: String::from("occurrence weight"), normalize_community_weight: true, - max_tokens: 5000, // change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000) + max_tokens: 12_000, // change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000) context_name: String::from("Reports"), column_delimiter: String::from("|"), }; // LLM params are ignored for Ollama let map_llm_params = LLMParams { - max_tokens: 1000, + max_tokens: 12_000, temperature: 0.0, }; let reduce_llm_params = LLMParams { - max_tokens: 2000, + max_tokens: 12_000, temperature: 0.0, }; @@ -86,7 +86,7 @@ async fn ollama_global_search_test() -> Result<(), Box> { allow_general_knowledge: false, general_knowledge_inclusion_prompt: None, callbacks: None, - max_data_tokens: 5000, + max_data_tokens: 12_000, map_llm_params, reduce_llm_params, context_builder_params, diff --git a/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs index eb471fb11..79a00fbf8 100644 --- a/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs @@ -18,7 +18,7 @@ use utils::{ mod utils; -#[tokio::test] +// #[tokio::test] async fn ollama_local_search_test() -> Result<(), Box> { let base_url = "http://localhost:11434"; let llm_model = "llama3.1"; @@ -116,7 +116,7 @@ async fn ollama_local_search_test() -> Result<(), Box> { local_context_params.max_tokens = 12_000; let llm_params = LLMParams { - max_tokens: 2000, + max_tokens: 12_000, temperature: 0.0, }; @@ -153,7 +153,7 @@ async fn ollama_local_search_test() -> Result<(), Box> { Ok(()) } -#[tokio::test] +// #[tokio::test] async fn openai_local_search_test() -> Result<(), Box> { let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap(); let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap(); diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs index b34a12471..5efe5494a 100644 --- a/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs +++ b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs @@ -2,9 +2,7 @@ use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::json; -use shinkai_graphrag::llm::base::{ - BaseLLM, BaseLLMCallback, BaseTextEmbedding, GlobalSearchPhase, LLMParams, MessageType, -}; +use shinkai_graphrag::llm::base::{BaseLLM, BaseLLMCallback, BaseTextEmbedding, LLMParams, MessageType}; #[derive(Serialize, Deserialize, Debug)] pub struct OllamaChatResponse { @@ -46,8 +44,7 @@ impl BaseLLM for OllamaChat { messages: MessageType, _streaming: bool, _callbacks: Option>, - _llm_params: LLMParams, - search_phase: Option, + llm_params: LLMParams, ) -> anyhow::Result { let client = Client::new(); let chat_url = format!("{}{}", &self.base_url, "/api/chat"); @@ -55,51 +52,16 @@ impl BaseLLM for OllamaChat { let messages_json = match messages { MessageType::String(message) => json![message], MessageType::Strings(messages) => json!(messages), - MessageType::Dictionary(messages) => { - let messages = match search_phase { - Some(GlobalSearchPhase::Map) => { - // Filter out system messages and convert them to user messages - messages - .into_iter() - .filter(|map| map.get_key_value("role").is_some_and(|(_, v)| v == "system")) - .map(|map| { - map.into_iter() - .map(|(key, value)| { - if key == "role" { - return (key, "user".to_string()); - } - (key, value) - }) - .collect() - }) - .collect() - } - Some(GlobalSearchPhase::Reduce) => { - // Convert roles to user - messages - .into_iter() - .map(|map| { - map.into_iter() - .map(|(key, value)| { - if key == "role" { - return (key, "user".to_string()); - } - (key, value) - }) - .collect() - }) - .collect() - } - _ => messages, - }; - - json!(messages) - } + MessageType::Dictionary(messages) => json!(messages), }; let payload = json!({ "model": self.model, "messages": messages_json, + "options": { + "num_ctx": llm_params.max_tokens, + "temperature": llm_params.temperature, + }, "stream": false, }); diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs index 392184f48..8a1eaec53 100644 --- a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs +++ b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs @@ -9,9 +9,7 @@ use async_openai::{ use async_trait::async_trait; use ndarray::{Array1, Array2, Axis}; use ndarray_stats::SummaryStatisticsExt; -use shinkai_graphrag::llm::base::{ - BaseLLM, BaseLLMCallback, BaseTextEmbedding, GlobalSearchPhase, LLMParams, MessageType, -}; +use shinkai_graphrag::llm::base::{BaseLLM, BaseLLMCallback, BaseTextEmbedding, LLMParams, MessageType}; use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer}; pub struct ChatOpenAI { @@ -106,11 +104,11 @@ impl ChatOpenAI { let response = client.chat().create(request).await?; - if let Some(choice) = response.choices.get(0) { + if let Some(choice) = response.choices.first() { return Ok(choice.message.content.clone().unwrap_or_default()); } - return Ok(String::new()); + Ok(String::new()) } } @@ -122,7 +120,6 @@ impl BaseLLM for ChatOpenAI { streaming: bool, callbacks: Option>, llm_params: LLMParams, - _search_phase: Option, ) -> anyhow::Result { self.agenerate(messages, streaming, callbacks, llm_params).await } @@ -175,8 +172,7 @@ impl OpenAIEmbedding { let response = client.embeddings().create(request).await?; let embedding = response - .data - .get(0) + .data.first() .map(|data| data.embedding.clone()) .unwrap_or_default(); @@ -248,7 +244,7 @@ fn batched(iterable: impl Iterator, n: usize) -> impl Iterator(text: &'a str, max_tokens: usize) -> impl Iterator + 'a { +fn chunk_text(text: &str, max_tokens: usize) -> impl Iterator + '_ { let token_encoder = Tokenizer::Cl100kBase; let bpe = get_bpe_from_tokenizer(token_encoder).unwrap(); let tokens = bpe.encode_with_special_tokens(text); From 190f0714a1471985e1039ab346046cef16f350f4 Mon Sep 17 00:00:00 2001 From: benolt Date: Thu, 5 Sep 2024 13:52:28 +0200 Subject: [PATCH 22/23] fix ollama embed url --- shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs index 5efe5494a..4b72fe4b2 100644 --- a/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs +++ b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs @@ -90,7 +90,7 @@ impl OllamaEmbedding { impl BaseTextEmbedding for OllamaEmbedding { async fn aembed(&self, text: &str) -> anyhow::Result> { let client = Client::new(); - let embedding_url = format!("{}{}", &self.base_url, "/api/embedding"); + let embedding_url = format!("{}{}", &self.base_url, "/api/embed"); let payload = json!({ "model": self.model, From fb535434574f1a571c5497ae763010ee208673c0 Mon Sep 17 00:00:00 2001 From: benolt Date: Fri, 6 Sep 2024 14:15:57 +0200 Subject: [PATCH 23/23] update ollama test --- shinkai-libs/shinkai-graphrag/README.md | 5 +++++ .../shinkai-graphrag/tests/local_search_tests.rs | 16 ++++------------ 2 files changed, 9 insertions(+), 12 deletions(-) create mode 100644 shinkai-libs/shinkai-graphrag/README.md diff --git a/shinkai-libs/shinkai-graphrag/README.md b/shinkai-libs/shinkai-graphrag/README.md new file mode 100644 index 000000000..a30d249d0 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/README.md @@ -0,0 +1,5 @@ +# Shinkai GraphRAG + +Rust implementation of GraphRAG Global and Local search. Documentation can be found [here](https://microsoft.github.io/graphrag/). + +See `tests` to see how to configure and use them. \ No newline at end of file diff --git a/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs index 79a00fbf8..1deb90fd7 100644 --- a/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs @@ -12,7 +12,7 @@ use shinkai_graphrag::{ vector_stores::lancedb::LanceDBVectorStore, }; use utils::{ - ollama::OllamaChat, + ollama::{OllamaChat, OllamaEmbedding}, openai::{num_tokens, ChatOpenAI, OpenAIEmbedding}, }; @@ -24,13 +24,10 @@ async fn ollama_local_search_test() -> Result<(), Box> { let llm_model = "llama3.1"; let llm = OllamaChat::new(base_url, llm_model); - // Using OpenAI embeddings since the dataset was created with OpenAI embeddings - let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap(); - let embedding_model = std::env::var("GRAPHRAG_EMBEDDING_MODEL").unwrap(); - let text_embedder = OpenAIEmbedding::new(Some(api_key), &embedding_model, 8191, 5); + let embedding_model = "snowflake-arctic-embed:m"; + let text_embedder = OllamaEmbedding::new(base_url, embedding_model); - // Load community reports - // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip + // Load datasets let input_dir = "./dataset"; let lancedb_uri = format!("{}/lancedb", input_dir); @@ -133,11 +130,6 @@ async fn ollama_local_search_test() -> Result<(), Box> { let result = search_engine.asearch("Tell me about Agent Mercer".to_string()).await?; println!("Response: {:?}\n", result.response); - let result = search_engine - .asearch("Tell me about Dr. Jordan Hayes".to_string()) - .await?; - println!("Response: {:?}\n", result.response); - match result.context_data { shinkai_graphrag::search::base::ContextData::Dictionary(dict) => { for (entity, df) in dict.iter() {