diff --git a/Cargo.lock b/Cargo.lock index b51e07de2..3e041a69e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -123,6 +123,21 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "allocator-api2" version = "0.2.18" @@ -224,6 +239,21 @@ dependencies = [ "syn 2.0.66", ] +[[package]] +name = "argminmax" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52424b59d69d69d5056d508b260553afd91c57e21849579cd1f50ee8b8b88eaa" +dependencies = [ + "num-traits", +] + +[[package]] +name = "array-init-cursor" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76" + [[package]] name = "arrayref" version = "0.3.7" @@ -375,7 +405,7 @@ dependencies = [ "arrow-schema 51.0.0", "arrow-select 51.0.0", "atoi", - "base64 0.22.0", + "base64 0.22.1", "chrono", "comfy-table", "half", @@ -396,7 +426,7 @@ dependencies = [ "arrow-schema 52.2.0", "arrow-select 52.2.0", "atoi", - "base64 0.22.0", + "base64 0.22.1", "chrono", "comfy-table", "half", @@ -510,7 +540,7 @@ dependencies = [ "arrow-schema 51.0.0", "chrono", "half", - "indexmap 2.1.0", + "indexmap 2.4.0", "lexical-core", "num", "serde", @@ -530,7 +560,7 @@ dependencies = [ "arrow-schema 52.2.0", "chrono", "half", - "indexmap 2.1.0", + "indexmap 2.4.0", "lexical-core", "num", "serde", @@ -709,6 +739,15 @@ dependencies = [ "futures-core", ] +[[package]] +name = "async-convert" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d416feee97712e43152cd42874de162b8f9b77295b1c85e5d92725cc8310bae" +dependencies = [ + "async-trait", +] + [[package]] name = "async-executor" version = "1.5.1" @@ -767,6 +806,32 @@ dependencies = [ "event-listener 2.5.3", ] +[[package]] +name = "async-openai" +version = "0.23.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0e5ff98f9e7c605df4c88783a0439d1dc667ce86bd79e99d4164f8b0c05ccc" +dependencies = [ + "async-convert", + "backoff", + "base64 0.22.1", + "bytes", + "derive_builder 0.20.0", + "eventsource-stream", + "futures", + "rand 0.8.5", + "reqwest 0.12.5", + "reqwest-eventsource", + "secrecy", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tokio-util 0.7.11", + "tracing", +] + [[package]] name = "async-priority-channel" version = "0.2.0" @@ -832,6 +897,28 @@ dependencies = [ "wasm-bindgen-futures", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + 
"proc-macro2", + "quote", + "syn 2.0.66", +] + [[package]] name = "async-task" version = "4.4.0" @@ -875,6 +962,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "atoi_simd" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" + [[package]] name = "atomic-waker" version = "1.1.1" @@ -1384,6 +1477,20 @@ dependencies = [ "tower-service", ] +[[package]] +name = "backoff" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" +dependencies = [ + "futures-core", + "getrandom 0.2.10", + "instant", + "pin-project-lite", + "rand 0.8.5", + "tokio", +] + [[package]] name = "backtrace" version = "0.3.69" @@ -1435,9 +1542,9 @@ checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" [[package]] name = "base64" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "base64-simd" @@ -1623,6 +1730,27 @@ dependencies = [ "syn_derive", ] +[[package]] +name = "brotli" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19483b140a7ac7174d34b5a581b406c64f84da5409d3e09cf4fff604f9270e67" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bs58" version = "0.5.0" @@ -1633,6 +1761,17 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "bstr" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" +dependencies = [ + "memchr", + "regex-automata 0.4.7", + "serde", +] + [[package]] name = "buf_redux" version = "0.8.4" @@ -1700,6 +1839,20 @@ name = "bytemuck" version = "1.16.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "102087e286b4677862ea56cf8fc58bb2cdfa8725c40ffb80fe3a008eb7f2fc83" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] [[package]] name = "byteorder" @@ -1915,6 +2068,17 @@ dependencies = [ "parse-zoneinfo", ] +[[package]] +name = "chrono-tz" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d59ae0466b83e838b81a54256c39d5d7c20b9d7daa10510a242d9b75abd5936e" +dependencies = [ + "chrono", + "chrono-tz-build 0.2.1", + "phf 0.11.2", +] + [[package]] name = "chrono-tz" version = "0.9.0" @@ -1922,8 +2086,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "93698b29de5e97ad0ae26447b344c482a7284c737d9ddc5f9e52b74a336671bb" dependencies = [ "chrono", - "chrono-tz-build", + "chrono-tz-build 0.3.0", + "phf 0.11.2", +] + +[[package]] +name = "chrono-tz-build" +version = "0.2.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "433e39f13c9a060046954e0592a8d0a4bcb1040125cbf91cb8ee58964cfb350f" +dependencies = [ + "parse-zoneinfo", "phf 0.11.2", + "phf_codegen 0.11.2", ] [[package]] @@ -2153,6 +2328,7 @@ version = "7.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" dependencies = [ + "crossterm", "strum 0.26.3", "strum_macros 0.26.4", "unicode-width", @@ -2165,7 +2341,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0436149c9f6a1935b13306206c739b1ba84fa81f551b5eb87fc2ca7a13700af" dependencies = [ "clap 4.5.4", - "derive_builder", + "derive_builder 0.12.0", "entities", "memchr", "once_cell", @@ -2451,6 +2627,28 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "crossterm" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" +dependencies = [ + "bitflags 2.4.0", + "crossterm_winapi", + "libc", + "parking_lot 0.12.1", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + [[package]] name = "crunchy" version = "0.2.2" @@ -2726,7 +2924,7 @@ dependencies = [ "glob", "half", "hashbrown 0.14.5", - "indexmap 2.1.0", + "indexmap 2.4.0", "itertools 0.12.1", "log 0.4.21", "num_cpus", @@ -2819,7 +3017,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a369332afd0ef5bd565f6db2139fb9f1dfdd0afa75a7f70f000b74208d76994f" dependencies = [ "arrow 52.2.0", - "base64 0.22.0", + "base64 0.22.1", "chrono", "datafusion-common", "datafusion-execution", @@ -2886,7 +3084,7 @@ dependencies = [ "datafusion-expr", "datafusion-physical-expr", "hashbrown 0.14.5", - "indexmap 2.1.0", + "indexmap 2.4.0", "itertools 0.12.1", "log 0.4.21", "paste", @@ -2906,7 +3104,7 @@ dependencies = [ "arrow-ord 52.2.0", "arrow-schema 52.2.0", "arrow-string 52.2.0", - "base64 0.22.0", + "base64 0.22.1", "chrono", "datafusion-common", "datafusion-execution", @@ -2915,7 +3113,7 @@ dependencies = [ "half", "hashbrown 0.14.5", "hex", - "indexmap 2.1.0", + "indexmap 2.4.0", "itertools 0.12.1", "log 0.4.21", "paste", @@ -2961,7 +3159,7 @@ dependencies = [ "futures", "half", "hashbrown 0.14.5", - "indexmap 2.1.0", + "indexmap 2.4.0", "itertools 0.12.1", "log 0.4.21", "once_cell", @@ -3082,7 +3280,16 @@ version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" dependencies = [ - "derive_builder_macro", + "derive_builder_macro 0.12.0", +] + +[[package]] +name = "derive_builder" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7" +dependencies = [ + "derive_builder_macro 0.20.0", ] [[package]] @@ -3097,16 +3304,38 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_builder_core" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d" +dependencies = [ + "darling 0.20.10", 
+ "proc-macro2", + "quote", + "syn 2.0.66", +] + [[package]] name = "derive_builder_macro" version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" dependencies = [ - "derive_builder_core", + "derive_builder_core 0.12.0", "syn 1.0.109", ] +[[package]] +name = "derive_builder_macro" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" +dependencies = [ + "derive_builder_core 0.20.0", + "syn 2.0.66", +] + [[package]] name = "derive_more" version = "0.99.18" @@ -3278,6 +3507,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56ce8c6da7551ec6c462cbaf3bfbc75131ebbfa1c944aeaa9dab51ca1c5f0c3b" +[[package]] +name = "dyn-clone" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" + [[package]] name = "ecdsa" version = "0.14.8" @@ -3337,9 +3572,9 @@ checksum = "3a68a4904193147e0a8dec3314640e6db742afd5f6e634f428a6af230d9b3591" [[package]] name = "either" -version = "1.9.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "elliptic-curve" @@ -3422,6 +3657,18 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5320ae4c3782150d900b79807611a59a99fc9a1d61d686faafc24b93fc8d7ca" +[[package]] +name = "enum_dispatch" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa18ce2bc66555b3218614519ac839ddb759a7d6720732f979ef8d13be147ecd" +dependencies = [ + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.66", +] + [[package]] name = "env_logger" version = "0.9.3" @@ -3778,6 +4025,12 @@ dependencies = [ "yansi", ] +[[package]] +name = "ethnum" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" + [[package]] name = "event-listener" version = "2.5.3" @@ -3795,6 +4048,17 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite", +] + [[package]] name = "exr" version = "1.72.0" @@ -3821,6 +4085,12 @@ dependencies = [ "once_cell", ] +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fancy-regex" version = "0.11.0" @@ -3831,6 +4101,22 @@ dependencies = [ "regex", ] +[[package]] +name = "fancy-regex" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7493d4c459da9f84325ad297371a6b2b8a162800873a22e3b6b6512e61d18c05" +dependencies = [ + "bit-set", + "regex", +] + +[[package]] +name = "fast-float" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"95765f67b4b18863968b4a1bd5bb576f732b29a4a28c7cd84c09fa3e2875f33c" + [[package]] name = "fastdivide" version = "0.4.1" @@ -3993,6 +4279,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +[[package]] +name = "foreign_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee1b05cbd864bcaecbd3455d6d967862d446e4ebfc3c2e5e5b9841e53cba6673" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -4372,7 +4664,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.1.0", - "indexmap 2.1.0", + "indexmap 2.4.0", "slab", "tokio", "tokio-util 0.7.11", @@ -4441,6 +4733,8 @@ checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ "ahash 0.8.11", "allocator-api2", + "rayon", + "serde", ] [[package]] @@ -4521,9 +4815,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.3.2" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] name = "hex" @@ -4854,7 +5148,7 @@ dependencies = [ "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows", + "windows 0.48.0", ] [[package]] @@ -5008,9 +5302,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.1.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" +checksum = "93ead53efc7ea8ed3cfb0c79fc8023fbb782a5432b52830b6518941cebe6505c" dependencies = [ "equivalent", "hashbrown 0.14.5", @@ -5055,7 +5349,7 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ - "hermit-abi 0.3.2", + "hermit-abi 0.3.9", "libc", "windows-sys 0.48.0", ] @@ -5088,7 +5382,7 @@ version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ - "hermit-abi 0.3.2", + "hermit-abi 0.3.9", "rustix 0.38.32", "windows-sys 0.48.0", ] @@ -5141,6 +5435,12 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +[[package]] +name = "itoap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9028f49264629065d057f340a86acb84867925865f73bbf8d47b4d149a7e88b8" + [[package]] name = "jetscii" version = "0.5.3" @@ -6028,11 +6328,21 @@ version = "0.11.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9106e1d747ffd48e6be5bb2d97fa706ed25b144fbee4d5c02eae110cd8d6badd" +[[package]] +name = "lz4" +version = "1.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "958b4caa893816eea05507c20cfe47574a43d9a697138a7872990bba8a0ece68" +dependencies = [ + "libc", + "lz4-sys", +] + [[package]] name = "lz4-sys" -version = "1.9.4" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57d27b317e207b10f69f5e75494119e391a96f48861ae870d1da6edac98ca900" +checksum = "109de74d5d2353660401699a4174a4ff23fcc649caf553df71933c7fb45ad868" dependencies = [ "cc", "libc", @@ -6133,6 +6443,16 @@ version = 
"0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg 1.1.0", + "rawpointer", +] + [[package]] name = "maybe-owned" version = "0.3.4" @@ -6181,6 +6501,15 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "memmap2" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" +dependencies = [ + "libc", +] + [[package]] name = "memmap2" version = "0.9.4" @@ -6297,13 +6626,14 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.10" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" dependencies = [ + "hermit-abi 0.3.9", "libc", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -6410,8 +6740,30 @@ dependencies = [ ] [[package]] -name = "murmurhash32" -version = "0.3.1" +name = "multiversion" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4851161a11d3ad0bf9402d90ffc3967bf231768bfd7aeb61755ad06dbf1a142" +dependencies = [ + "multiversion-macros", + "target-features", +] + +[[package]] +name = "multiversion-macros" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79a74ddee9e0c27d2578323c13905793e91622148f138ba29738f9dddb835e90" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", + "target-features", +] + +[[package]] +name = "murmurhash32" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2195bf6aa996a481483b29d62a7663eed3fe39600c460e323f8ff41e90bdd89b" @@ -6452,6 +6804,36 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "ndarray-stats" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17ebbe97acce52d06aebed4cd4a87c0941f4b2519b59b82b4feb5bd0ce003dfd" +dependencies = [ + "indexmap 2.4.0", + "itertools 0.13.0", + "ndarray", + "noisy_float", + "num-integer", + "num-traits", + "rand 0.8.5", +] + [[package]] name = "new_debug_unreachable" version = "1.0.4" @@ -6491,6 +6873,15 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb" +[[package]] +name = "noisy_float" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978fe6e6ebc0bf53de533cd456ca2d9de13de13856eda1518a285d7705a213af" +dependencies = [ + "num-traits", +] + [[package]] name = "nom" version = "7.1.3" @@ -6524,6 +6915,24 @@ version = 
"0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" +[[package]] +name = "now" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d89e9874397a1f0a52fc1f197a8effd9735223cb2390e9dcc83ac6cd02923d0" +dependencies = [ + "chrono", +] + +[[package]] +name = "ntapi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +dependencies = [ + "winapi", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -6642,7 +7051,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.3.2", + "hermit-abi 0.3.9", "libc", ] @@ -6683,7 +7092,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6da452820c715ce78221e8202ccc599b4a52f3e1eb3eedb487b680c81a8e3f3" dependencies = [ "async-trait", - "base64 0.22.0", + "base64 0.22.1", "bytes", "chrono", "futures", @@ -6839,7 +7248,7 @@ checksum = "1e32339a5dc40459130b3bd269e9892439f55b33e772d2a9d402a789baaf4e8a" dependencies = [ "futures-core", "futures-sink", - "indexmap 2.1.0", + "indexmap 2.4.0", "js-sys", "once_cell", "pin-project-lite", @@ -7088,6 +7497,16 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "parquet-format-safe" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1131c54b167dd4e4799ce762e1ab01549ebb94d5bdd13e6ec1b467491c378e1f" +dependencies = [ + "async-trait", + "futures", +] + [[package]] name = "parse-zoneinfo" version = "0.3.0" @@ -7294,7 +7713,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" dependencies = [ "fixedbitset 0.4.2", - "indexmap 2.1.0", + "indexmap 2.4.0", ] [[package]] @@ -7554,6 +7973,15 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "planus" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1691dd09e82f428ce8d6310bd6d5da2557c82ff17694d2a32cad7242aea89f" +dependencies = [ + "array-init-cursor", +] + [[package]] name = "platforms" version = "3.2.0" @@ -7567,7 +7995,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9d34169e64b3c7a80c8621a48adaf44e0cf62c78a9b25dd9dd35f1881a17cf9" dependencies = [ "base64 0.21.7", - "indexmap 2.1.0", + "indexmap 2.4.0", "line-wrap", "quick-xml 0.31.0", "serde", @@ -7608,6 +8036,416 @@ dependencies = [ "miniz_oxide 0.7.1", ] +[[package]] +name = "polars" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e3351ea4570e54cd556e6755b78fe7a2c85368d820c0307cca73c96e796a7ba" +dependencies = [ + "getrandom 0.2.10", + "polars-arrow", + "polars-core", + "polars-error", + "polars-io", + "polars-lazy", + "polars-ops", + "polars-parquet", + "polars-sql", + "polars-time", + "polars-utils", + "version_check 0.9.4", +] + +[[package]] +name = "polars-arrow" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba65fc4bcabbd64fca01fd30e759f8b2043f0963c57619e331d4b534576c0b47" +dependencies = [ + "ahash 0.8.11", + "atoi", + 
"atoi_simd", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "dyn-clone", + "either", + "ethnum", + "fast-float", + "foreign_vec", + "futures", + "getrandom 0.2.10", + "hashbrown 0.14.5", + "itoa 1.0.9", + "itoap", + "lz4", + "multiversion", + "num-traits", + "polars-arrow-format", + "polars-error", + "polars-utils", + "ryu", + "simdutf8", + "streaming-iterator", + "strength_reduce", + "version_check 0.9.4", + "zstd 0.13.2", +] + +[[package]] +name = "polars-arrow-format" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b0ef2474af9396b19025b189d96e992311e6a47f90c53cd998b36c4c64b84c" +dependencies = [ + "planus", + "serde", +] + +[[package]] +name = "polars-compute" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f099516af30ac9ae4b4480f4ad02aa017d624f2f37b7a16ad4e9ba52f7e5269" +dependencies = [ + "bytemuck", + "either", + "num-traits", + "polars-arrow", + "polars-error", + "polars-utils", + "strength_reduce", + "version_check 0.9.4", +] + +[[package]] +name = "polars-core" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2439484be228b8c302328e2f953e64cfd93930636e5c7ceed90339ece7fef6c" +dependencies = [ + "ahash 0.8.11", + "bitflags 2.4.0", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "comfy-table", + "either", + "hashbrown 0.14.5", + "indexmap 2.4.0", + "num-traits", + "once_cell", + "polars-arrow", + "polars-compute", + "polars-error", + "polars-row", + "polars-utils", + "rand 0.8.5", + "rand_distr", + "rayon", + "regex", + "smartstring", + "thiserror", + "version_check 0.9.4", + "xxhash-rust", +] + +[[package]] +name = "polars-error" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c9b06dfbe79cabe50a7f0a90396864b5ee2c0e0f8d6a9353b2343c29c56e937" +dependencies = [ + "polars-arrow-format", + "regex", + "simdutf8", + "thiserror", +] + +[[package]] +name = "polars-expr" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c630385a56a867c410a20f30772d088f90ec3d004864562b84250b35268f97" +dependencies = [ + "ahash 0.8.11", + "bitflags 2.4.0", + "once_cell", + "polars-arrow", + "polars-core", + "polars-io", + "polars-ops", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", + "smartstring", +] + +[[package]] +name = "polars-io" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d7363cd14e4696a28b334a56bd11013ff49cc96064818ab3f91a126e453462d" +dependencies = [ + "ahash 0.8.11", + "async-trait", + "atoi_simd", + "bytes", + "chrono", + "fast-float", + "futures", + "home", + "itoa 1.0.9", + "memchr", + "memmap2 0.7.1", + "num-traits", + "once_cell", + "percent-encoding 2.3.1", + "polars-arrow", + "polars-core", + "polars-error", + "polars-parquet", + "polars-time", + "polars-utils", + "rayon", + "regex", + "ryu", + "simdutf8", + "smartstring", + "tokio", + "tokio-util 0.7.11", +] + +[[package]] +name = "polars-lazy" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03877e74e42b5340ae52ded705f6d5d14563d90554c9177b01b91ed2412a56ed" +dependencies = [ + "ahash 0.8.11", + "bitflags 2.4.0", + "glob", + "memchr", + "once_cell", + "polars-arrow", + "polars-core", + "polars-expr", + "polars-io", + "polars-mem-engine", + "polars-ops", + "polars-pipe", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", + "smartstring", + "version_check 
0.9.4", +] + +[[package]] +name = "polars-mem-engine" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea9e17771af750c94bf959885e4b3f5b14149576c62ef3ec1c9ef5827b2a30f" +dependencies = [ + "polars-arrow", + "polars-core", + "polars-error", + "polars-expr", + "polars-io", + "polars-ops", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", +] + +[[package]] +name = "polars-ops" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6066552eb577d43b307027fb38096910b643ffb2c89a21628c7e41caf57848d0" +dependencies = [ + "ahash 0.8.11", + "argminmax", + "base64 0.22.1", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "either", + "hashbrown 0.14.5", + "hex", + "indexmap 2.4.0", + "memchr", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-error", + "polars-utils", + "rayon", + "regex", + "smartstring", + "unicode-reverse", + "version_check 0.9.4", +] + +[[package]] +name = "polars-parquet" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b35b2592a2e7ef7ce9942dc2120dc4576142626c0e661668e4c6b805042e461" +dependencies = [ + "ahash 0.8.11", + "async-stream", + "base64 0.22.1", + "brotli", + "ethnum", + "flate2", + "futures", + "lz4", + "num-traits", + "parquet-format-safe", + "polars-arrow", + "polars-compute", + "polars-error", + "polars-utils", + "simdutf8", + "snap", + "streaming-decompression", + "zstd 0.13.2", +] + +[[package]] +name = "polars-pipe" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "021bce7768c330687d735340395a77453aa18dd70d57c184cbb302311e87c1b9" +dependencies = [ + "crossbeam-channel", + "crossbeam-queue", + "enum_dispatch", + "hashbrown 0.14.5", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-expr", + "polars-io", + "polars-ops", + "polars-plan", + "polars-row", + "polars-utils", + "rayon", + "smartstring", + "uuid 1.8.0", + "version_check 0.9.4", +] + +[[package]] +name = "polars-plan" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "220d0d7c02d1c4375802b2813dbedcd1a184df39c43b74689e729ede8d5c2921" +dependencies = [ + "ahash 0.8.11", + "bytemuck", + "chrono-tz 0.8.6", + "either", + "hashbrown 0.14.5", + "once_cell", + "percent-encoding 2.3.1", + "polars-arrow", + "polars-core", + "polars-io", + "polars-ops", + "polars-parquet", + "polars-time", + "polars-utils", + "rayon", + "recursive", + "regex", + "smartstring", + "strum_macros 0.26.4", + "version_check 0.9.4", +] + +[[package]] +name = "polars-row" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1d70d87a2882a64a43b431aea1329cb9a2c4100547c95c417cc426bb82408b3" +dependencies = [ + "bytemuck", + "polars-arrow", + "polars-error", + "polars-utils", +] + +[[package]] +name = "polars-sql" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6fc1c9b778862f09f4a347f768dfdd3d0ba9957499d306d83c7103e0fa8dc5b" +dependencies = [ + "hex", + "once_cell", + "polars-arrow", + "polars-core", + "polars-error", + "polars-lazy", + "polars-ops", + "polars-plan", + "polars-time", + "rand 0.8.5", + "serde", + "serde_json", + "sqlparser", +] + +[[package]] +name = "polars-time" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"179f98313a15c0bfdbc8cc0f1d3076d08d567485b9952d46439f94fbc3085df5" +dependencies = [ + "atoi", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "now", + "once_cell", + "polars-arrow", + "polars-core", + "polars-error", + "polars-ops", + "polars-utils", + "regex", + "smartstring", +] + +[[package]] +name = "polars-utils" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53e6dd89fcccb1ec1a62f752c9a9f2d482a85e9255153f46efecc617b4996d50" +dependencies = [ + "ahash 0.8.11", + "bytemuck", + "hashbrown 0.14.5", + "indexmap 2.4.0", + "num-traits", + "once_cell", + "polars-error", + "raw-cpuid 11.0.1", + "rayon", + "smartstring", + "stacker", + "sysinfo", + "version_check 0.9.4", +] + [[package]] name = "polling" version = "2.8.0" @@ -7652,6 +8490,15 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" +[[package]] +name = "portable-atomic-util" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcdd8420072e66d54a407b3316991fe946ce3ab1083a7f575b2463866624704d" +dependencies = [ + "portable-atomic", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -7923,6 +8770,15 @@ dependencies = [ "prost 0.12.6", ] +[[package]] +name = "psm" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5787f7cda34e3033a72192c018bc5883100330f362ef279a8cbccfce8bb4e874" +dependencies = [ + "cc", +] + [[package]] name = "ptr_meta" version = "0.1.4" @@ -8420,6 +9276,12 @@ dependencies = [ "bitflags 2.4.0", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.10.0" @@ -8449,6 +9311,26 @@ dependencies = [ "rand_core 0.3.1", ] +[[package]] +name = "recursive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" +dependencies = [ + "recursive-proc-macro-impl", + "stacker", +] + +[[package]] +name = "recursive-proc-macro-impl" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" +dependencies = [ + "quote", + "syn 2.0.66", +] + [[package]] name = "redox_syscall" version = "0.3.5" @@ -8596,7 +9478,7 @@ version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "bytes", "futures-channel", "futures-core", @@ -8612,6 +9494,7 @@ dependencies = [ "js-sys", "log 0.4.21", "mime 0.3.17", + "mime_guess 2.0.4", "once_cell", "percent-encoding 2.3.1", "pin-project-lite", @@ -8637,6 +9520,22 @@ dependencies = [ "winreg 0.52.0", ] +[[package]] +name = "reqwest-eventsource" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime 0.3.17", + "nom", + "pin-project-lite", + "reqwest 0.12.5", + "thiserror", +] + [[package]] name = "rfc6979" version = "0.3.1" @@ -9014,7 +9913,7 @@ version = "2.1.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "rustls-pki-types", ] @@ -9229,6 +10128,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "secrecy" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9bd1c54ea06cfd2f6b63219704de0b9b4f72dcc2b8fdef820be6cd799780e91e" +dependencies = [ + "serde", + "zeroize", +] + [[package]] name = "security-framework" version = "2.9.2" @@ -9370,11 +10279,11 @@ version = "3.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ad483d2ab0149d5a5ebcd9972a3852711e0153d863bf5a5d0391d28883c4a20" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.1.0", + "indexmap 2.4.0", "serde", "serde_derive", "serde_json", @@ -9495,6 +10404,31 @@ dependencies = [ "dirs", ] +[[package]] +name = "shinkai-graphrag" +version = "0.1.0" +dependencies = [ + "anyhow", + "arrow 52.2.0", + "arrow-array 52.2.0", + "arrow-schema 52.2.0", + "async-openai", + "async-trait", + "futures", + "lancedb", + "ndarray", + "ndarray-stats", + "polars", + "polars-lazy", + "rand 0.8.5", + "reqwest 0.11.27", + "serde", + "serde_json", + "tiktoken-rs", + "tokio", + "uuid 1.8.0", +] + [[package]] name = "shinkai_crypto_identities" version = "0.1.1" @@ -9537,7 +10471,7 @@ dependencies = [ "aes-gcm", "anyhow", "async-channel", - "base64 0.22.0", + "base64 0.22.1", "chrono", "chrono-tz 0.5.3", "clap 3.2.25", @@ -9717,7 +10651,7 @@ dependencies = [ "aes-gcm", "anyhow", "async-channel", - "base64 0.22.0", + "base64 0.22.1", "chrono", "chrono-tz 0.5.3", "clap 3.2.25", @@ -9973,6 +10907,17 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "smartstring" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fb72c633efbaa2dd666986505016c32c3044395ceaf881518399d2f4127ee29" +dependencies = [ + "autocfg 1.1.0", + "static_assertions", + "version_check 0.9.4", +] + [[package]] name = "snafu" version = "0.7.5" @@ -9995,6 +10940,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "snap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" + [[package]] name = "socket2" version = "0.4.9" @@ -10100,6 +11051,19 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stacker" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c886bd4480155fd3ef527d45e9ac8dd7118a898a46530b7b94c3e21866259fce" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "winapi", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -10118,6 +11082,27 @@ version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e51f1e89f093f99e7432c491c382b88a6860a5adbe6bf02574bf0a08efff1978" +[[package]] +name = "streaming-decompression" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf6cc3b19bfb128a8ad11026086e31d3ce9ad23f8ea37354b31383a187c44cf3" +dependencies = [ + "fallible-streaming-iterator", +] + +[[package]] +name = 
"streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" + +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + [[package]] name = "string_cache" version = "0.8.7" @@ -10280,7 +11265,7 @@ checksum = "874dcfa363995604333cf947ae9f751ca3af4522c60886774c4963943b4746b1" dependencies = [ "bincode", "bitflags 1.3.2", - "fancy-regex", + "fancy-regex 0.11.0", "flate2", "fnv", "once_cell", @@ -10295,6 +11280,20 @@ dependencies = [ "yaml-rust", ] +[[package]] +name = "sysinfo" +version = "0.30.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a5b4ddaee55fb2bea2bf0e5000747e5f5c0de765e5a5ff87f4cd106439f4bb3" +dependencies = [ + "cfg-if", + "core-foundation-sys", + "libc", + "ntapi", + "once_cell", + "windows 0.52.0", +] + [[package]] name = "system-configuration" version = "0.5.1" @@ -10343,7 +11342,7 @@ checksum = "f8d0582f186c0a6d55655d24543f15e43607299425c5ad8352c242b914b31856" dependencies = [ "aho-corasick", "arc-swap", - "base64 0.22.0", + "base64 0.22.1", "bitpacking", "byteorder", "census", @@ -10360,7 +11359,7 @@ dependencies = [ "lru 0.12.3", "lz4_flex", "measure_time", - "memmap2", + "memmap2 0.9.4", "num_cpus", "once_cell", "oneshot", @@ -10493,6 +11492,12 @@ dependencies = [ "xattr", ] +[[package]] +name = "target-features" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" + [[package]] name = "target-lexicon" version = "0.12.14" @@ -10622,6 +11627,21 @@ dependencies = [ "weezl", ] +[[package]] +name = "tiktoken-rs" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c314e7ce51440f9e8f5a497394682a57b7c323d0f4d0a6b1b13c429056e0e234" +dependencies = [ + "anyhow", + "base64 0.21.7", + "bstr", + "fancy-regex 0.12.0", + "lazy_static", + "parking_lot 0.12.1", + "rustc-hash", +] + [[package]] name = "time" version = "0.1.45" @@ -10722,22 +11742,21 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.36.0" +version = "1.39.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" +checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" dependencies = [ "backtrace", "bytes", "libc", "mio", - "num_cpus", "parking_lot 0.12.1", "pin-project-lite", "signal-hook-registry", "socket2 0.5.6", "tokio-macros", "tracing", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -10752,9 +11771,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.2.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", @@ -10794,9 +11813,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.14" +version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" +checksum = 
"267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" dependencies = [ "futures-core", "pin-project-lite", @@ -10897,7 +11916,7 @@ version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ - "indexmap 2.1.0", + "indexmap 2.4.0", "toml_datetime", "winnow", ] @@ -10908,7 +11927,7 @@ version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d34d383cd00a163b4a5b85053df514d45bc330f6de7737edfe0a93311d1eaa03" dependencies = [ - "indexmap 2.1.0", + "indexmap 2.4.0", "serde", "serde_spanned", "toml_datetime", @@ -11250,6 +12269,15 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-reverse" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b6f4888ebc23094adfb574fdca9fdc891826287a6397d2cd28802ffd6f20c76" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "unicode-segmentation" version = "1.10.1" @@ -11397,7 +12425,7 @@ version = "4.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23" dependencies = [ - "indexmap 2.1.0", + "indexmap 2.4.0", "serde", "serde_json", "utoipa-gen", @@ -11758,6 +12786,25 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "windows" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" +dependencies = [ + "windows-core", + "windows-targets 0.52.3", +] + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.3", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -11993,6 +13040,12 @@ version = "0.13.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" +[[package]] +name = "xxhash-rust" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a5cbf750400958819fb6178eaa83bee5cd9c29a26a40cc241df8c70fdd46984" + [[package]] name = "yaml-rust" version = "0.4.5" @@ -12079,7 +13132,7 @@ dependencies = [ "crossbeam-utils", "displaydoc", "flate2", - "indexmap 2.1.0", + "indexmap 2.4.0", "num_enum", "thiserror", ] @@ -12100,7 +13153,7 @@ dependencies = [ "displaydoc", "flate2", "hmac 0.12.1", - "indexmap 2.1.0", + "indexmap 2.4.0", "lzma-rs", "memchr", "pbkdf2 0.12.2", diff --git a/Cargo.toml b/Cargo.toml index b4fdec92c..7e4bb91c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,13 +4,14 @@ members = [ "shinkai-libs/shinkai-dsl", "shinkai-libs/shinkai-sheet", "shinkai-libs/shinkai-fs-mirror", + "shinkai-libs/shinkai-graphrag", "shinkai-libs/shinkai-message-primitives", "shinkai-libs/shinkai-ocr", "shinkai-libs/shinkai-tcp-relayer", "shinkai-libs/shinkai-vector-resources", "shinkai-bin/*", "shinkai-cli-tools/*" -] +, "shinkai-libs/shinkai-graphrag"] resolver = "2" [workspace.package] diff --git a/shinkai-libs/shinkai-graphrag/.gitignore b/shinkai-libs/shinkai-graphrag/.gitignore new file mode 100644 index 000000000..74deb7343 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/.gitignore @@ -0,0 +1,2 @@ +.vscode +dataset \ No newline at end of file diff --git 
new file mode 100644
index 000000000..3545f3d77
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/Cargo.toml
@@ -0,0 +1,27 @@
+[package]
+name = "shinkai-graphrag"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+anyhow = "1.0.86"
+arrow = "52.1"
+arrow-array = "52.1"
+arrow-schema = "52.1"
+async-trait = "0.1.74"
+futures = "0.3.30"
+lancedb = "0.8.0"
+polars = { version = "0.41.3", features = ["dtype-struct", "is_in", "lazy", "parquet"] }
+polars-lazy = "0.41.3"
+rand = "0.8.5"
+serde = { version = "1.0.188", features = ["derive"] }
+serde_json = "1.0.117"
+tokio = { version = "1.36", features = ["full"] }
+uuid = { version = "1.6.1", features = ["v4"] }
+
+[dev-dependencies]
+async-openai = "0.23.4"
+ndarray = "0.16.1"
+ndarray-stats = "0.6.0"
+reqwest = { version = "0.11.26", features = ["json"] }
+tiktoken-rs = "0.5.9"
diff --git a/shinkai-libs/shinkai-graphrag/README.md b/shinkai-libs/shinkai-graphrag/README.md
new file mode 100644
index 000000000..a30d249d0
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/README.md
@@ -0,0 +1,38 @@
+# Shinkai GraphRAG
+
+Rust implementation of GraphRAG Global and Local search. Documentation can be found [here](https://microsoft.github.io/graphrag/).
+
+See the `tests` directory for complete examples of how to configure and run both search modes.
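+
+A minimal usage sketch for the global-search community context. This is a sketch, not a full program: the module paths assume this crate's directory layout, loading reports and entities from indexer output is omitted, and `count_tokens` is a rough stand-in for a real tokenizer such as `tiktoken-rs`. See `tests` for working end-to-end examples.
+
+```rust
+use shinkai_graphrag::context_builder::community_context::{
+    CommunityContextBuilderParams, GlobalCommunityContext,
+};
+use shinkai_graphrag::models::{CommunityReport, Entity};
+
+// Crude whitespace-based token count; swap in a real tokenizer for production.
+fn count_tokens(text: &str) -> usize {
+    text.split_whitespace().count()
+}
+
+fn build_global_context(reports: Vec<CommunityReport>, entities: Vec<Entity>) -> anyhow::Result<()> {
+    let context = GlobalCommunityContext::new(reports, Some(entities), count_tokens);
+    let (context_texts, context_data) = context.build_context(CommunityContextBuilderParams {
+        use_community_summary: false,
+        column_delimiter: "|".to_string(),
+        shuffle_data: true,
+        include_community_rank: true,
+        min_community_rank: 0,
+        community_rank_name: "rank".to_string(),
+        include_community_weight: true,
+        community_weight_name: "occurrence weight".to_string(),
+        normalize_community_weight: true,
+        max_tokens: 8000,
+        context_name: "Reports".to_string(),
+    })?;
+    println!("{} batch(es); data keys: {:?}", context_texts.len(), context_data.keys());
+    Ok(())
+}
+```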
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs
new file mode 100644
index 000000000..b7b7e8ac7
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs
@@ -0,0 +1,503 @@
+use std::{
+    collections::{HashMap, HashSet},
+    io::{Cursor, Read},
+};
+
+use polars::{
+    frame::DataFrame,
+    io::SerWriter,
+    prelude::{col, concat, CsvWriter, DataType, IntoLazy, LazyFrame, NamedFrom, SortMultipleOptions, UnionArgs},
+    series::Series,
+};
+use rand::prelude::SliceRandom;
+
+use crate::models::{CommunityReport, Entity};
+
+#[derive(Debug, Clone)]
+pub struct CommunityContextBuilderParams {
+    pub use_community_summary: bool,
+    pub column_delimiter: String,
+    pub shuffle_data: bool,
+    pub include_community_rank: bool,
+    pub min_community_rank: u32,
+    pub community_rank_name: String,
+    pub include_community_weight: bool,
+    pub community_weight_name: String,
+    pub normalize_community_weight: bool,
+    pub max_tokens: usize,
+    pub context_name: String,
+    // conversation_history: Option<ConversationHistory>,
+    // conversation_history_user_turns_only: bool,
+    // conversation_history_max_turns: Option<usize>,
+}
+
+pub struct GlobalCommunityContext {
+    community_reports: Vec<CommunityReport>,
+    entities: Option<Vec<Entity>>,
+    num_tokens_fn: fn(&str) -> usize,
+}
+
+impl GlobalCommunityContext {
+    pub fn new(
+        community_reports: Vec<CommunityReport>,
+        entities: Option<Vec<Entity>>,
+        num_tokens_fn: fn(&str) -> usize,
+    ) -> Self {
+        Self {
+            community_reports,
+            entities,
+            num_tokens_fn,
+        }
+    }
+
+    pub fn build_context(
+        &self,
+        context_builder_params: CommunityContextBuilderParams,
+    ) -> anyhow::Result<(Vec<String>, HashMap<String, DataFrame>)> {
+        let CommunityContextBuilderParams {
+            use_community_summary,
+            column_delimiter,
+            shuffle_data,
+            include_community_rank,
+            min_community_rank,
+            community_rank_name,
+            include_community_weight,
+            community_weight_name,
+            normalize_community_weight,
+            max_tokens,
+            context_name,
+        } = context_builder_params;
+
+        let (community_context, community_context_data) = CommunityContext::build_community_context(
+            self.community_reports.clone(),
+            self.entities.clone(),
+            self.num_tokens_fn,
+            use_community_summary,
+            &column_delimiter,
+            shuffle_data,
+            include_community_rank,
+            min_community_rank,
+            &community_rank_name,
+            include_community_weight,
+            &community_weight_name,
+            normalize_community_weight,
+            max_tokens,
+            false,
+            &context_name,
+        )?;
+
+        let final_context = community_context;
+        let final_context_data = community_context_data;
+
+        Ok((final_context, final_context_data))
+    }
+}
+
+pub struct CommunityContext {}
+
+impl CommunityContext {
+    pub fn build_community_context(
+        community_reports: Vec<CommunityReport>,
+        entities: Option<Vec<Entity>>,
+        num_tokens_fn: fn(&str) -> usize,
+        use_community_summary: bool,
+        column_delimiter: &str,
+        shuffle_data: bool,
+        include_community_rank: bool,
+        min_community_rank: u32,
+        community_rank_name: &str,
+        include_community_weight: bool,
+        community_weight_name: &str,
+        normalize_community_weight: bool,
+        max_tokens: usize,
+        single_batch: bool,
+        context_name: &str,
+    ) -> anyhow::Result<(Vec<String>, HashMap<String, DataFrame>)> {
+        let _is_included = |report: &CommunityReport| -> bool {
+            report.rank.is_some() && report.rank.unwrap() >= min_community_rank.into()
+        };
+
+        let _get_header = |attributes: Vec<String>| -> Vec<String> {
+            let mut header = vec!["id".to_string(), "title".to_string()];
+            let mut filtered_attributes: Vec<String> = attributes
+                .iter()
+                .filter(|&col| !header.contains(&col.to_string()))
+                .cloned()
+                .collect();
+
+            if !include_community_weight {
+                filtered_attributes.retain(|col| col != community_weight_name);
+            }
+
+            header.extend(filtered_attributes.into_iter().map(|s| s.to_string()));
+            header.push(if use_community_summary {
+                "summary".to_string()
+            } else {
+                "content".to_string()
+            });
+
+            if include_community_rank {
+                header.push(community_rank_name.to_string());
+            }
+
+            header
+        };
+
+        let _report_context_text = |report: &CommunityReport, attributes: &[String]| -> (String, Vec<String>) {
+            let mut context: Vec<String> = vec![report.short_id.clone().unwrap_or_default(), report.title.clone()];
+
+            for field in attributes {
+                let value = report
+                    .attributes
+                    .as_ref()
+                    .and_then(|attrs| attrs.get(field))
+                    .cloned()
+                    .unwrap_or_default();
+                context.push(value);
+            }
+
+            context.push(if use_community_summary {
+                report.summary.clone()
+            } else {
+                report.full_content.clone()
+            });
+
+            if include_community_rank {
+                context.push(report.rank.unwrap_or_default().to_string());
+            }
+
+            let result = context.join(column_delimiter) + "\n";
+            (result, context)
+        };
+
+        let compute_community_weights = entities.as_ref().is_some_and(|e| !e.is_empty())
+            && !community_reports.is_empty()
+            && include_community_weight
+            && (community_reports[0].attributes.is_none()
+                || !community_reports[0]
+                    .attributes
+                    .as_ref()
+                    .unwrap()
+                    .contains_key(community_weight_name));
+
+        let mut community_reports = community_reports;
+        if compute_community_weights {
+            community_reports = Self::_compute_community_weights(
+                community_reports,
+                entities.clone(),
+                community_weight_name,
+                normalize_community_weight,
+            );
+        }
+
+        let mut selected_reports: Vec<CommunityReport> = community_reports
+            .iter()
+            .filter(|&report| _is_included(report))
+            .cloned()
+            .collect();
+
+        if selected_reports.is_empty() {
+            return Ok((Vec::new(), HashMap::new()));
+        }
+
+        if shuffle_data {
+            let mut rng = rand::thread_rng();
+            selected_reports.shuffle(&mut rng);
+        }
+
+        let attributes = if let Some(attributes) = &community_reports[0].attributes {
+            attributes.keys().cloned().collect::<Vec<String>>()
+        } else {
+            Vec::new()
+        };
+
+        let header = _get_header(attributes.clone());
+        let mut all_context_text: Vec<String> = Vec::new();
+        let mut all_context_records: Vec<DataFrame> = Vec::new();
+
+        let mut batch = Batch::new();
+
+        batch.init_batch(context_name, &header, column_delimiter, num_tokens_fn);
+
+        for report in selected_reports {
+            let (new_context_text, new_context) = _report_context_text(&report, &attributes);
+            let new_tokens = num_tokens_fn(&new_context_text);
+
+            // add the current batch to the context data and start a new batch if we are in multi-batch mode
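+            // For example, with max_tokens = 8_000, a batch holding 7_900 tokens
+            // cannot absorb a 200-token row: the batch is flushed via cut_batch
+            // and the row opens the next batch, which re-pays the header cost
+            // through init_batch. (Token figures are illustrative.)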
+            if batch.batch_tokens + new_tokens > max_tokens {
+                batch.cut_batch(
+                    &mut all_context_text,
+                    &mut all_context_records,
+                    entities.clone(),
+                    &header,
+                    community_weight_name,
+                    community_rank_name,
+                    include_community_weight,
+                    include_community_rank,
+                    column_delimiter,
+                )?;
+
+                if single_batch {
+                    break;
+                }
+
+                batch.init_batch(context_name, &header, column_delimiter, num_tokens_fn);
+            }
+
+            batch.batch_text.push_str(&new_context_text);
+            batch.batch_tokens += new_tokens;
+            batch.batch_records.push(new_context);
+        }
+
+        if !all_context_text.contains(&batch.batch_text) {
+            batch.cut_batch(
+                &mut all_context_text,
+                &mut all_context_records,
+                entities.clone(),
+                &header,
+                community_weight_name,
+                community_rank_name,
+                include_community_weight,
+                include_community_rank,
+                column_delimiter,
+            )?;
+        }
+
+        if all_context_records.is_empty() {
+            eprintln!("Warning: No community records added when building community context.");
+            return Ok((Vec::new(), HashMap::new()));
+        }
+
+        let records_concat = concat(
+            all_context_records
+                .into_iter()
+                .map(|df| df.lazy())
+                .collect::<Vec<LazyFrame>>(),
+            UnionArgs::default(),
+        )?
+        .collect()?;
+
+        Ok((
+            all_context_text,
+            HashMap::from([(context_name.to_lowercase(), records_concat)]),
+        ))
+    }
+
+    fn _compute_community_weights(
+        community_reports: Vec<CommunityReport>,
+        entities: Option<Vec<Entity>>,
+        weight_attribute: &str,
+        normalize: bool,
+    ) -> Vec<CommunityReport> {
+        // Calculate a community's weight as the count of text units associated with entities within the community.
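+        // For example, if entities E1 (text_unit_ids [t1, t2]) and E2 (text_unit_ids
+        // [t2, t3]) both list community C, the weight of C is 3: unit ids are pooled
+        // per community and de-duplicated through a HashSet before counting.
+        // (Entity and unit ids here are illustrative.)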
+        if let Some(entities) = entities {
+            let mut community_reports = community_reports;
+            let mut community_text_units = std::collections::HashMap::new();
+            for entity in entities {
+                if let Some(community_ids) = entity.community_ids.clone() {
+                    for community_id in community_ids {
+                        community_text_units
+                            .entry(community_id)
+                            .or_insert_with(Vec::new)
+                            .extend(entity.text_unit_ids.clone());
+                    }
+                }
+            }
+            for report in &mut community_reports {
+                if report.attributes.is_none() {
+                    report.attributes = Some(std::collections::HashMap::new());
+                }
+                if let Some(attributes) = &mut report.attributes {
+                    attributes.insert(
+                        weight_attribute.to_string(),
+                        community_text_units
+                            .get(&report.community_id)
+                            .map(|text_units| text_units.iter().flatten().cloned().collect::<HashSet<_>>().len())
+                            .unwrap_or(0)
+                            .to_string(),
+                    );
+                }
+            }
+            if normalize {
+                // Normalize by max weight
+                let all_weights: Vec<f64> = community_reports
+                    .iter()
+                    .filter_map(|report| {
+                        report
+                            .attributes
+                            .as_ref()
+                            .and_then(|attributes| attributes.get(weight_attribute))
+                            .map(|weight| weight.parse::<f64>().unwrap_or(0.0))
+                    })
+                    .collect();
+                if let Some(max_weight) = all_weights.iter().cloned().max_by(|a, b| a.partial_cmp(b).unwrap()) {
+                    for report in &mut community_reports {
+                        if let Some(attributes) = &mut report.attributes {
+                            if let Some(weight) = attributes.get_mut(weight_attribute) {
+                                *weight = (weight.parse::<f64>().unwrap_or(0.0) / max_weight).to_string();
+                            }
+                        }
+                    }
+                }
+            }
+
+            return community_reports;
+        }
+        community_reports
+    }
+}
+
+struct Batch {
+    batch_text: String,
+    batch_tokens: usize,
+    batch_records: Vec<Vec<String>>,
+}
+
+impl Batch {
+    fn new() -> Self {
+        Batch {
+            batch_text: String::new(),
+            batch_tokens: 0,
+            batch_records: Vec::new(),
+        }
+    }
+
+    fn init_batch(
+        &mut self,
+        context_name: &str,
+        header: &[String],
+        column_delimiter: &str,
+        num_tokens_fn: fn(&str) -> usize,
+    ) {
+        self.batch_text = format!("-----{}-----\n{}\n", context_name, header.join(column_delimiter));
+        self.batch_tokens = num_tokens_fn(&self.batch_text);
+        self.batch_records.clear();
+    }
+
+    fn cut_batch(
+        &mut self,
+        all_context_text: &mut Vec<String>,
+        all_context_records: &mut Vec<DataFrame>,
+        entities: Option<Vec<Entity>>,
+        header: &[String],
+        community_weight_name: &str,
+        community_rank_name: &str,
+        include_community_weight: bool,
+        include_community_rank: bool,
+        column_delimiter: &str,
+    ) -> anyhow::Result<()> {
+        let weight_column = if include_community_weight && entities.as_ref().is_some_and(|e| !e.is_empty()) {
+            Some(community_weight_name)
+        } else {
+            None
+        };
+        let rank_column = if include_community_rank {
+            Some(community_rank_name)
+        } else {
+            None
+        };
+
+        let mut record_df = Self::_convert_report_context_to_df(
+            self.batch_records.clone(),
+            header.to_owned(),
+            weight_column,
+            rank_column,
+        )?;
+        if record_df.is_empty() {
+            return Ok(());
+        }
+
+        let column_delimiter = if column_delimiter.is_empty() {
+            b'|'
+        } else {
+            column_delimiter.as_bytes()[0]
+        };
+
+        let mut buffer = Cursor::new(Vec::new());
+        CsvWriter::new(&mut buffer)
+            .include_header(true)
+            .with_separator(column_delimiter)
+            .finish(&mut record_df)?;
+
+        let mut current_context_text = String::new();
+        buffer.set_position(0);
+        buffer.read_to_string(&mut current_context_text)?;
+
+        if all_context_text.contains(&current_context_text) {
+            return Ok(());
+        }
+
+        all_context_text.push(current_context_text);
+        all_context_records.push(record_df);
+
+        Ok(())
+    }
+
+    fn _convert_report_context_to_df(
+        context_records: Vec<Vec<String>>,
+        header: Vec<String>,
+        weight_column: Option<&str>,
+        rank_column: Option<&str>,
Option<&str>,
+    ) -> anyhow::Result<DataFrame> {
+        if context_records.is_empty() {
+            return Ok(DataFrame::empty());
+        }
+
+        let mut data_series = Vec::new();
+        for (index, header) in header.iter().enumerate() {
+            let records = context_records
+                .iter()
+                .map(|r| r.get(index).unwrap_or(&String::new()).to_owned())
+                .collect::<Vec<_>>();
+            let series = Series::new(header, records);
+            data_series.push(series);
+        }
+
+        let record_df = DataFrame::new(data_series)?;
+
+        Self::_rank_report_context(record_df, weight_column, rank_column)
+    }
+
+    fn _rank_report_context(
+        report_df: DataFrame,
+        weight_column: Option<&str>,
+        rank_column: Option<&str>,
+    ) -> anyhow::Result<DataFrame> {
+        let mut rank_attributes = Vec::new();
+
+        let mut report_df = report_df;
+
+        if let Some(weight_column) = weight_column {
+            rank_attributes.push(weight_column);
+            report_df = report_df
+                .lazy()
+                .with_column(col(weight_column).cast(DataType::Float64))
+                .collect()?;
+        }
+
+        if let Some(rank_column) = rank_column {
+            rank_attributes.push(rank_column);
+            report_df = report_df
+                .lazy()
+                .with_column(col(rank_column).cast(DataType::Float64))
+                .collect()?;
+        }
+
+        if !rank_attributes.is_empty() {
+            report_df = report_df
+                .lazy()
+                .sort(rank_attributes, SortMultipleOptions::new().with_order_descending(true))
+                .collect()?;
+        }
+
+        Ok(report_df)
+    }
+}
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/conversation_history.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/conversation_history.rs
new file mode 100644
index 000000000..1455d3264
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/conversation_history.rs
@@ -0,0 +1 @@
+pub struct ConversationHistory {}
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs
new file mode 100644
index 000000000..305d02af0
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/entity_extraction.rs
@@ -0,0 +1,50 @@
+use std::collections::HashSet;
+
+use crate::{
+    input::retrieval::entities::{get_entity_by_key, get_entity_by_name},
+    llm::base::BaseTextEmbedding,
+    models::Entity,
+    vector_stores::{lancedb::LanceDBVectorStore, vector_store::VectorStore},
+};
+
+pub async fn map_query_to_entities(
+    query: &str,
+    text_embedding_vectorstore: &LanceDBVectorStore,
+    text_embedder: &(dyn BaseTextEmbedding + Send + Sync),
+    all_entities: &Vec<Entity>,
+    embedding_vectorstore_key: &str,
+    include_entity_names: Option<Vec<String>>,
+    exclude_entity_names: Option<Vec<String>>,
+    k: usize,
+    oversample_scaler: usize,
+) -> anyhow::Result<Vec<Entity>> {
+    let include_entity_names = include_entity_names.unwrap_or_default();
+    let exclude_entity_names: HashSet<String> = exclude_entity_names.unwrap_or_default().into_iter().collect();
+    let mut matched_entities = Vec::new();
+
+    if !query.is_empty() {
+        let search_results = text_embedding_vectorstore
+            .similarity_search_by_text(query, text_embedder, k * oversample_scaler)
+            .await?;
+
+        for result in search_results {
+            if let Some(matched) = get_entity_by_key(all_entities, embedding_vectorstore_key, &result.document.id) {
+                matched_entities.push(matched);
+            }
+        }
+    } else {
+        let mut all_entities = all_entities.clone();
+        all_entities.sort_by(|a, b| b.rank.unwrap_or(0).cmp(&a.rank.unwrap_or(0)));
+        matched_entities = all_entities.iter().take(k).cloned().collect();
+    }
+
+    matched_entities.retain(|entity| !exclude_entity_names.contains(&entity.title));
+
+    let mut included_entities = Vec::new();
+    for entity_name in include_entity_names {
included_entities.extend(get_entity_by_name(all_entities, &entity_name));
+    }
+
+    included_entities.extend(matched_entities);
+    Ok(included_entities)
+}
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs
new file mode 100644
index 000000000..95cbbe4da
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/local_context.rs
@@ -0,0 +1,437 @@
+use std::{
+    cmp::Ordering,
+    collections::{HashMap, HashSet},
+};
+
+use polars::{frame::DataFrame, prelude::NamedFrom, series::Series};
+
+use crate::{
+    input::retrieval::{
+        entities::to_entity_dataframe,
+        relationships::{
+            get_candidate_relationships, get_entities_from_relationships, get_in_network_relationships,
+            get_out_network_relationships, to_relationship_dataframe,
+        },
+    },
+    models::{Entity, Relationship},
+};
+
+pub fn build_entity_context(
+    selected_entities: Vec<Entity>,
+    num_tokens_fn: fn(&str) -> usize,
+    max_tokens: usize,
+    include_entity_rank: bool,
+    rank_description: &str,
+    column_delimiter: &str,
+    context_name: &str,
+) -> anyhow::Result<(String, DataFrame)> {
+    if selected_entities.is_empty() {
+        return Ok((String::new(), DataFrame::default()));
+    }
+
+    let mut current_context_text = format!("-----{}-----\n", context_name);
+    let mut header = vec!["id".to_string(), "entity".to_string(), "description".to_string()];
+
+    if include_entity_rank {
+        header.push(rank_description.to_string());
+    }
+
+    let attribute_cols = if let Some(first_entity) = selected_entities.first().cloned() {
+        first_entity
+            .attributes
+            .unwrap_or_default()
+            .keys()
+            .cloned()
+            .collect::<Vec<_>>()
+    } else {
+        Vec::new()
+    };
+
+    header.extend(attribute_cols.clone());
+    current_context_text += &header.join(column_delimiter);
+
+    let mut current_tokens = num_tokens_fn(&current_context_text);
+    let mut records = HashMap::new();
+
+    for entity in selected_entities {
+        let mut new_context = vec![
+            entity.short_id.clone().unwrap_or_default(),
+            entity.title.clone(),
+            entity.description.clone().unwrap_or_default(),
+        ];
+
+        records
+            .entry("id")
+            .or_insert_with(Vec::new)
+            .push(entity.short_id.unwrap_or_default());
+        records.entry("entity").or_insert_with(Vec::new).push(entity.title);
+        records
+            .entry("description")
+            .or_insert_with(Vec::new)
+            .push(entity.description.unwrap_or_default());
+
+        if include_entity_rank {
+            new_context.push(entity.rank.unwrap_or(0).to_string());
+
+            records
+                .entry(rank_description)
+                .or_insert_with(Vec::new)
+                .push(entity.rank.map(|r| r.to_string()).unwrap_or_default());
+        }
+
+        for field in &attribute_cols {
+            let field_value = entity
+                .attributes
+                .as_ref()
+                .and_then(|attrs| attrs.get(field))
+                .cloned()
+                .unwrap_or_default();
+            new_context.push(field_value);
+
+            records.entry(field).or_insert_with(Vec::new).push(
+                entity
+                    .attributes
+                    .as_ref()
+                    .and_then(|attrs| attrs.get(field))
+                    .cloned()
+                    .unwrap_or_default(),
+            );
+        }
+
+        let new_context_text = new_context.join(column_delimiter);
+        let new_tokens = num_tokens_fn(&new_context_text);
+
+        if current_tokens + new_tokens > max_tokens {
+            break;
+        }
+
+        current_context_text += &format!("\n{}", new_context_text);
+        current_tokens += new_tokens;
+    }
+
+    let mut data_series = Vec::new();
+    for (header, data_values) in records {
+        if include_entity_rank && header == rank_description {
+            let data_values = data_values
+                .iter()
+                .map(|v| v.parse::<i32>().unwrap_or(0))
+                .collect::<Vec<_>>();
+            let series = Series::new(header, data_values);
+            data_series.push(series);
+        } else {
+            let series = Series::new(header, data_values);
+            data_series.push(series);
+        };
+    }
+
+    let record_df = if !data_series.is_empty() {
+        DataFrame::new(data_series)?
+    } else {
+        DataFrame::default()
+    };
+
+    Ok((current_context_text, record_df))
+}
+
+pub fn build_relationship_context(
+    selected_entities: &[Entity],
+    relationships: &[Relationship],
+    num_tokens_fn: fn(&str) -> usize,
+    include_relationship_weight: bool,
+    max_tokens: usize,
+    top_k_relationships: usize,
+    relationship_ranking_attribute: &str,
+    column_delimiter: &str,
+    context_name: &str,
+) -> anyhow::Result<(String, DataFrame)> {
+    // Filter relationships based on the criteria
+    let selected_relationships = _filter_relationships(
+        selected_entities,
+        relationships,
+        top_k_relationships,
+        relationship_ranking_attribute,
+    );
+
+    if selected_entities.is_empty() || selected_relationships.is_empty() {
+        return Ok((String::new(), DataFrame::default()));
+    }
+
+    let mut current_context_text = format!("-----{}-----\n", context_name);
+    let mut header = vec![
+        "id".to_string(),
+        "source".to_string(),
+        "target".to_string(),
+        "description".to_string(),
+    ];
+
+    if include_relationship_weight {
+        header.push("weight".to_string());
+    }
+
+    let attribute_cols = if let Some(first_rel) = selected_relationships.first().cloned() {
+        first_rel
+            .attributes
+            .unwrap_or_default()
+            .keys()
+            .cloned()
+            .collect::<Vec<_>>()
+    } else {
+        Vec::new()
+    };
+
+    let attribute_cols: Vec<String> = attribute_cols.into_iter().filter(|col| !header.contains(col)).collect();
+    header.extend(attribute_cols.clone());
+
+    current_context_text.push_str(&header.join(column_delimiter));
+    current_context_text.push('\n');
+
+    let mut current_tokens = num_tokens_fn(&current_context_text);
+    let mut records = HashMap::new();
+
+    for rel in selected_relationships {
+        let mut new_context = vec![
+            rel.short_id.clone().unwrap_or_default(),
+            rel.source.clone(),
+            rel.target.clone(),
+            rel.description.clone().unwrap_or_default(),
+        ];
+
+        records
+            .entry("id")
+            .or_insert_with(Vec::new)
+            .push(rel.short_id.unwrap_or_default());
+        records.entry("source").or_insert_with(Vec::new).push(rel.source);
+        records.entry("target").or_insert_with(Vec::new).push(rel.target);
+        records
+            .entry("description")
+            .or_insert_with(Vec::new)
+            .push(rel.description.unwrap_or_default());
+
+        if include_relationship_weight {
+            new_context.push(rel.weight.map_or(String::new(), |w| w.to_string()));
+
+            records
+                .entry("weight")
+                .or_insert_with(Vec::new)
+                .push(rel.weight.map(|r| r.to_string()).unwrap_or_default());
+        }
+
+        for field in &attribute_cols {
+            let field_value = rel
+                .attributes
+                .as_ref()
+                .and_then(|attrs| attrs.get(field))
+                .cloned()
+                .unwrap_or_default();
+            new_context.push(field_value);
+
+            records.entry(field).or_insert_with(Vec::new).push(
+                rel.attributes
+                    .as_ref()
+                    .and_then(|attrs| attrs.get(field))
+                    .cloned()
+                    .unwrap_or_default(),
+            );
+        }
+
+        let mut new_context_text = new_context.join(column_delimiter);
+        new_context_text.push('\n');
+        let new_tokens = num_tokens_fn(&new_context_text);
+
+        if current_tokens + new_tokens > max_tokens {
+            break;
+        }
+
+        current_context_text += new_context_text.as_str();
+        current_tokens += new_tokens;
+    }
+
+    let mut data_series = Vec::new();
+    for (header, data_values) in records {
+        if header == "weight" {
+            let data_values = data_values
+                .iter()
+                .map(|v| v.parse::<f64>().unwrap_or(0.0))
+                .collect::<Vec<_>>();
+            let series = Series::new(header, data_values);
+            data_series.push(series);
+        } else {
+            let series =
Series::new(header, data_values); + data_series.push(series); + }; + } + + let record_df = if !data_series.is_empty() { + DataFrame::new(data_series)? + } else { + DataFrame::default() + }; + + Ok((current_context_text, record_df)) +} + +fn _filter_relationships( + selected_entities: &[Entity], + relationships: &[Relationship], + top_k_relationships: usize, + relationship_ranking_attribute: &str, +) -> Vec { + // First priority: in-network relationships (i.e. relationships between selected entities) + let in_network_relationships = + get_in_network_relationships(selected_entities, relationships, relationship_ranking_attribute); + + // Second priority - out-of-network relationships + // (i.e. relationships between selected entities and other entities that are not within the selected entities) + let mut out_network_relationships = + get_out_network_relationships(selected_entities, relationships, relationship_ranking_attribute); + + if out_network_relationships.len() <= 1 { + return [in_network_relationships, out_network_relationships].concat(); + } + + // within out-of-network relationships, prioritize mutual relationships + // (i.e. relationships with out-network entities that are shared with multiple selected entities) + let selected_entity_names: HashSet = selected_entities.iter().map(|e| e.title.clone()).collect(); + + let out_network_source_names: Vec = out_network_relationships + .iter() + .filter(|r| !selected_entity_names.contains(&r.source)) + .map(|r| r.source.clone()) + .collect(); + + let out_network_target_names: Vec = out_network_relationships + .iter() + .filter(|r| !selected_entity_names.contains(&r.target)) + .map(|r| r.target.clone()) + .collect(); + + let out_network_entity_names: HashSet = out_network_source_names + .into_iter() + .chain(out_network_target_names) + .collect(); + + let mut out_network_entity_links: HashMap = HashMap::new(); + + for entity_name in out_network_entity_names { + let targets: HashSet = out_network_relationships + .iter() + .filter(|r| r.source == entity_name) + .map(|r| r.target.clone()) + .collect(); + + let sources: HashSet = out_network_relationships + .iter() + .filter(|r| r.target == entity_name) + .map(|r| r.source.clone()) + .collect(); + + out_network_entity_links.insert(entity_name, targets.union(&sources).count()); + } + + // sort out-network relationships by number of links and rank_attributes + for relationship in &mut out_network_relationships { + if relationship.attributes.is_none() { + relationship.attributes = Some(HashMap::new()); + } + + let links = if out_network_entity_links.contains_key(&relationship.source) { + *out_network_entity_links.get(&relationship.source).unwrap() + } else { + *out_network_entity_links.get(&relationship.target).unwrap() + }; + relationship + .attributes + .as_mut() + .unwrap() + .insert("links".to_string(), links.to_string()); + } + + // Sort by attributes[links] first, then by ranking_attribute + if relationship_ranking_attribute == "weight" { + out_network_relationships.sort_by(|a, b| { + let a_links = a + .attributes + .as_ref() + .and_then(|attrs| attrs.get("links")) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + let b_links = b + .attributes + .as_ref() + .and_then(|attrs| attrs.get("links")) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + + b_links + .partial_cmp(&a_links) + .unwrap_or(Ordering::Equal) + .then(b.weight.partial_cmp(&a.weight).unwrap_or(Ordering::Equal)) + }); + } else { + out_network_relationships.sort_by(|a, b| { + let a_links = a + .attributes + 
.as_ref() + .and_then(|attrs| attrs.get("links")) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + let b_links = b + .attributes + .as_ref() + .and_then(|attrs| attrs.get("links")) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + + let a_rank = a + .attributes + .as_ref() + .and_then(|attrs| attrs.get(relationship_ranking_attribute)) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0.0); + let b_rank = b + .attributes + .as_ref() + .and_then(|attrs| attrs.get(relationship_ranking_attribute)) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0.0); + + b_links + .partial_cmp(&a_links) + .unwrap_or(Ordering::Equal) + .then(b_rank.partial_cmp(&a_rank).unwrap_or(Ordering::Equal)) + }); + } + + let relationship_budget = top_k_relationships * selected_entities.len(); + out_network_relationships.truncate(relationship_budget); + + let mut selected_relationships = in_network_relationships; + selected_relationships.extend(out_network_relationships); + + selected_relationships +} + +pub fn get_candidate_context( + selected_entities: &[Entity], + entities: &[Entity], + relationships: &[Relationship], + include_entity_rank: bool, + entity_rank_description: &str, + include_relationship_weight: bool, +) -> anyhow::Result> { + let mut candidate_context = HashMap::new(); + + let candidate_relationships = get_candidate_relationships(selected_entities, relationships); + candidate_context.insert( + "relationships".to_string(), + to_relationship_dataframe(&candidate_relationships, include_relationship_weight)?, + ); + + let candidate_entities = get_entities_from_relationships(&candidate_relationships, entities); + candidate_context.insert( + "entities".to_string(), + to_entity_dataframe(&candidate_entities, include_entity_rank, entity_rank_description)?, + ); + + Ok(candidate_context) +} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs new file mode 100644 index 000000000..a2ed01737 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs @@ -0,0 +1,5 @@ +pub mod community_context; +pub mod conversation_history; +pub mod entity_extraction; +pub mod local_context; +pub mod source_context; diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/source_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/source_context.rs new file mode 100644 index 000000000..af4b0205b --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/source_context.rs @@ -0,0 +1,146 @@ +use std::collections::{HashMap, HashSet}; + +use polars::frame::DataFrame; +use polars::prelude::NamedFrom; +use polars::series::Series; +use rand::prelude::SliceRandom; +use rand::{rngs::StdRng, SeedableRng}; + +use crate::models::{Entity, Relationship, TextUnit}; + +pub fn build_text_unit_context( + text_units: Vec, + num_tokens_fn: fn(&str) -> usize, + column_delimiter: &str, + shuffle_data: bool, + max_tokens: usize, + context_name: &str, + random_state: u64, +) -> anyhow::Result<(String, HashMap)> { + if text_units.is_empty() { + return Ok((String::new(), HashMap::new())); + } + + let mut text_units = text_units; + + let mut unique_ids = HashSet::new(); + text_units.retain(|unit| unique_ids.insert(unit.id.clone())); + + if shuffle_data { + let mut rng = StdRng::seed_from_u64(random_state); + text_units.shuffle(&mut rng); + } + + let mut current_context_text = format!("-----{}-----\n", context_name); + let mut header = vec!["id".to_string(), "text".to_string()]; + + let attribute_cols = if let 
Some(text_unit) = text_units.first().cloned() {
+        text_unit
+            .attributes
+            .unwrap_or_default()
+            .keys()
+            .cloned()
+            .collect::<Vec<_>>()
+    } else {
+        Vec::new()
+    };
+
+    let attribute_cols: Vec<String> = attribute_cols.into_iter().filter(|col| !header.contains(col)).collect();
+    header.extend(attribute_cols.clone());
+    current_context_text += &header.join(column_delimiter);
+
+    let mut current_tokens = num_tokens_fn(&current_context_text);
+    let mut records = HashMap::new();
+
+    for unit in text_units {
+        let mut new_context = vec![unit.short_id.clone().unwrap_or_default(), unit.text.clone()];
+
+        records
+            .entry("id")
+            .or_insert_with(Vec::new)
+            .push(unit.short_id.unwrap_or_default());
+        records.entry("text").or_insert_with(Vec::new).push(unit.text);
+
+        for field in &attribute_cols {
+            let field_value = unit
+                .attributes
+                .as_ref()
+                .and_then(|attrs| attrs.get(field))
+                .cloned()
+                .unwrap_or_default();
+            new_context.push(field_value);
+
+            records.entry(field).or_insert_with(Vec::new).push(
+                unit.attributes
+                    .as_ref()
+                    .and_then(|attrs| attrs.get(field))
+                    .cloned()
+                    .unwrap_or_default(),
+            );
+        }
+
+        let new_context_text = new_context.join(column_delimiter);
+        let new_tokens = num_tokens_fn(&new_context_text);
+
+        if current_tokens + new_tokens > max_tokens {
+            break;
+        }
+
+        current_context_text += &format!("\n{}", new_context_text);
+        current_tokens += new_tokens;
+    }
+
+    let mut data_series = Vec::new();
+    for (header, data_values) in records {
+        let series = Series::new(header, data_values);
+        data_series.push(series);
+    }
+
+    let record_df = if !data_series.is_empty() {
+        DataFrame::new(data_series)?
+    } else {
+        DataFrame::default()
+    };
+
+    Ok((
+        current_context_text,
+        HashMap::from([(context_name.to_lowercase(), record_df)]),
+    ))
+}
+
+pub fn count_relationships(
+    text_unit: &TextUnit,
+    entity: &Entity,
+    relationships: &HashMap<String, Relationship>,
+) -> usize {
+    let matching_relationships: Vec<&Relationship> = if text_unit.relationship_ids.is_none() {
+        let entity_relationships: Vec<&Relationship> = relationships
+            .values()
+            .filter(|rel| rel.source == entity.title || rel.target == entity.title)
+            .collect();
+
+        let entity_relationships: Vec<&Relationship> = entity_relationships
+            .into_iter()
+            .filter(|rel| rel.text_unit_ids.is_some())
+            .collect();
+
+        entity_relationships
+            .into_iter()
+            .filter(|rel| rel.text_unit_ids.as_ref().unwrap().contains(&text_unit.id))
+            .collect()
+    } else {
+        let text_unit_relationships: Vec<&Relationship> = text_unit
+            .relationship_ids
+            .as_ref()
+            .unwrap()
+            .iter()
+            .filter_map(|rel_id| relationships.get(rel_id))
+            .collect();
+
+        text_unit_relationships
+            .into_iter()
+            .filter(|rel| rel.source == entity.title || rel.target == entity.title)
+            .collect()
+    };
+
+    matching_relationships.len()
+}
diff --git a/shinkai-libs/shinkai-graphrag/src/indexer_adapters.rs b/shinkai-libs/shinkai-graphrag/src/indexer_adapters.rs
new file mode 100644
index 000000000..9b2f97b71
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/indexer_adapters.rs
@@ -0,0 +1,143 @@
+use std::vec;
+
+use polars::prelude::*;
+use polars_lazy::dsl::col;
+
+use crate::{
+    input::loaders::dfs::{read_community_reports, read_entities, read_relationships, read_text_units},
+    models::{CommunityReport, Entity, Relationship, TextUnit},
+};
+
+pub fn read_indexer_entities(
+    final_nodes: &DataFrame,
+    final_entities: &DataFrame,
+    community_level: u32,
+) -> anyhow::Result<Vec<Entity>> {
+    let entity_df = final_nodes.clone();
+    let entity_df = filter_under_community_level(&entity_df, community_level)?;
+
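The filter_under_community_level helper used here (defined at the end of this file) keeps only rows whose level column is at or below the requested community level. A minimal usage sketch, assuming an i64 level column as the adapter does:

use polars::prelude::*;

fn main() -> PolarsResult<()> {
    let df = df!["title" => ["a", "b", "c"], "level" => [0i64, 1, 2]]?;
    // Same mask the adapter builds: level <= community_level.
    let mask = df.column("level")?.i64()?.lt_eq(1u32);
    let filtered = df.filter(&mask)?; // keeps "a" and "b"
    println!("{filtered}");
    Ok(())
}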
let entity_embedding_df = final_entities.clone(); + + let entity_df = entity_df + .lazy() + .rename(["title", "degree"], ["name", "rank"]) + .with_column(col("community").fill_null(lit(-1))) + .with_column(col("community").cast(DataType::Int32)) + .with_column(col("rank").cast(DataType::Int32)) + .group_by([col("name"), col("rank")]) + .agg([col("community").max()]) + .with_column(col("community").cast(DataType::String)) + .join( + entity_embedding_df.lazy(), + [col("name")], + [col("name")], + JoinArgs::new(JoinType::Inner), + ) + .collect()?; + + let entities = read_entities( + entity_df, + "id", + Some("human_readable_id"), + "name", + Some("type"), + Some("description"), + None, + Some("description_embedding"), + None, + Some("community"), + Some("text_unit_ids"), + None, + Some("rank"), + )?; + + Ok(entities) +} + +pub fn read_indexer_reports( + final_community_reports: &DataFrame, + final_nodes: &DataFrame, + community_level: u32, +) -> anyhow::Result> { + let entity_df = final_nodes.clone(); + let entity_df = filter_under_community_level(&entity_df, community_level)?; + + let filtered_community_df = entity_df + .lazy() + .with_column(col("community").fill_null(lit(-1))) + .with_column(col("community").cast(DataType::Int32)) + .group_by([col("title")]) + .agg([col("community").max()]) + .with_column(col("community").cast(DataType::String)) + .filter(len().over([col("community")]).gt(lit(1))) + .collect()?; + + let report_df = final_community_reports.clone(); + let report_df = filter_under_community_level(&report_df, community_level)?; + + let report_df = report_df + .lazy() + .join( + filtered_community_df.lazy(), + [col("community")], + [col("community")], + JoinArgs::new(JoinType::Inner), + ) + .collect()?; + + let reports = read_community_reports( + report_df, + "community", + Some("community"), + "title", + "community", + "summary", + "full_content", + Some("rank"), + None, + None, + )?; + Ok(reports) +} + +pub fn read_indexer_relationships(final_relationships: &DataFrame) -> anyhow::Result> { + let relationships = read_relationships( + final_relationships.clone(), + "id", + Some("human_readable_id"), + "source", + "target", + Some("description"), + None, + Some("weight"), + Some("text_unit_ids"), + None, + Some(vec!["rank"]), + )?; + + Ok(relationships) +} + +pub fn read_indexer_text_units(final_text_units: &DataFrame) -> anyhow::Result> { + let text_units = read_text_units( + final_text_units.clone(), + "id", + None, + "text", + Some("entity_ids"), + Some("relationship_ids"), + Some("n_tokens"), + Some("document_ids"), + Some("text_embedding"), + None, + )?; + + Ok(text_units) +} + +fn filter_under_community_level(df: &DataFrame, community_level: u32) -> anyhow::Result { + let mask = df.column("level")?.i64()?.lt_eq(community_level); + let result = df.filter(&mask)?; + + Ok(result) +} diff --git a/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs b/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs new file mode 100644 index 000000000..62066f1ba --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/input/loaders/dfs.rs @@ -0,0 +1,631 @@ +use std::collections::{HashMap, HashSet}; + +use polars::{ + frame::DataFrame, + prelude::{AnyValue, ChunkedArray, IntoVec, StringChunked}, +}; + +use crate::{ + models::{CommunityReport, Entity, Relationship, TextUnit}, + vector_stores::{ + lancedb::LanceDBVectorStore, + vector_store::{VectorStore, VectorStoreDocument}, + }, +}; + +pub async fn store_entity_semantic_embeddings( + entities: Vec, + vectorstore: &mut LanceDBVectorStore, 
+) -> anyhow::Result<()> { + let documents: Vec = entities + .into_iter() + .map(|entity| { + let mut attributes = HashMap::new(); + attributes.insert("title".to_string(), entity.title.clone()); + if let Some(entity_attributes) = entity.attributes { + attributes.extend(entity_attributes); + } + + VectorStoreDocument { + id: entity.id, + text: entity.description, + vector: entity + .description_embedding + .map(|v| v.into_iter().map(|f| f as f32).collect()), + attributes, + } + }) + .collect(); + + vectorstore.load_documents(documents, true).await?; + Ok(()) +} + +pub fn read_entities( + df: DataFrame, + id_col: &str, + short_id_col: Option<&str>, + title_col: &str, + type_col: Option<&str>, + description_col: Option<&str>, + name_embedding_col: Option<&str>, + description_embedding_col: Option<&str>, + graph_embedding_col: Option<&str>, + community_col: Option<&str>, + text_unit_ids_col: Option<&str>, + document_ids_col: Option<&str>, + rank_col: Option<&str>, + // attributes_cols: Option>, +) -> anyhow::Result> { + let df_column_names = df.get_column_names(); + let column_names = [ + Some(id_col), + short_id_col, + Some(title_col), + type_col, + description_col, + name_embedding_col, + description_embedding_col, + graph_embedding_col, + community_col, + text_unit_ids_col, + document_ids_col, + rank_col, + ] + .iter() + .filter_map(|&v| { + v.map(|v| v.to_string()) + .filter(|v| df_column_names.contains(&v.as_str())) + }) + .collect::>(); + + let column_names = column_names.into_iter().collect::>().into_vec(); + + let mut df = df; + df.as_single_chunk_par(); + let mut iters = df + .columns(column_names.clone())? + .iter() + .map(|s| s.iter()) + .collect::>(); + + let mut rows = Vec::new(); + for _row in 0..df.height() { + let mut row_values = Vec::new(); + for iter in &mut iters { + let value = iter.next(); + if let Some(value) = value { + row_values.push(value); + } + } + rows.push(row_values); + } + + let mut entities = Vec::new(); + for (idx, row) in rows.iter().enumerate() { + let report = Entity { + id: get_field(row, id_col, &column_names) + .map(|id| id.to_string()) + .unwrap_or_default(), + short_id: Some( + short_id_col + .and_then(|short_id| get_field(row, short_id, &column_names)) + .map(|short_id| short_id.to_string()) + .unwrap_or(idx.to_string()), + ), + title: get_field(row, title_col, &column_names) + .map(|title| title.to_string()) + .unwrap_or_default(), + entity_type: type_col + .and_then(|type_col| get_field(row, type_col, &column_names)) + .map(|entity_type| entity_type.to_string()), + description: description_col + .and_then(|description_col| get_field(row, description_col, &column_names)) + .map(|description| description.to_string()), + name_embedding: name_embedding_col.map(|name_embedding_col| { + get_field(row, name_embedding_col, &column_names) + .map(|name_embedding| match name_embedding { + AnyValue::List(series) => series + .f64() + .unwrap_or(&ChunkedArray::from_vec(name_embedding_col, vec![])) + .iter() + .map(|v| v.unwrap_or(0.0)) + .collect::>(), + value => vec![value.to_string().parse::().unwrap_or(0.0)], + }) + .unwrap_or_default() + }), + description_embedding: description_embedding_col.map(|description_embedding_col| { + get_field(row, description_embedding_col, &column_names) + .map(|description_embedding| match description_embedding { + AnyValue::List(series) => series + .f64() + .unwrap_or(&ChunkedArray::from_vec(description_embedding_col, vec![])) + .iter() + .map(|v| v.unwrap_or(0.0)) + .collect::>(), + value => 
vec![value.to_string().parse::().unwrap_or(0.0)], + }) + .unwrap_or_default() + }), + graph_embedding: graph_embedding_col.map(|graph_embedding_col| { + get_field(row, graph_embedding_col, &column_names) + .map(|graph_embedding| match graph_embedding { + AnyValue::List(series) => series + .f64() + .unwrap_or(&ChunkedArray::from_vec(graph_embedding_col, vec![])) + .iter() + .map(|v| v.unwrap_or(0.0)) + .collect::>(), + value => vec![value.to_string().parse::().unwrap_or(0.0)], + }) + .unwrap_or_default() + }), + community_ids: community_col.map(|community_col| { + get_field(row, community_col, &column_names) + .map(|community_ids| match community_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_default() + }), + text_unit_ids: text_unit_ids_col.map(|text_unit_ids_col| { + get_field(row, text_unit_ids_col, &column_names) + .map(|text_unit_ids| match text_unit_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_default() + }), + document_ids: document_ids_col.map(|document_ids_col| { + get_field(row, document_ids_col, &column_names) + .map(|document_ids| match document_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_default() + }), + rank: rank_col.and_then(|rank_col| { + get_field(row, rank_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0)) + }), + attributes: None, + }; + entities.push(report); + } + + let mut unique_entities: Vec = Vec::new(); + let mut entity_ids: HashSet = HashSet::new(); + + for entity in entities { + if !entity_ids.contains(&entity.id) { + unique_entities.push(entity.clone()); + entity_ids.insert(entity.id); + } + } + + Ok(unique_entities) +} + +pub fn read_community_reports( + df: DataFrame, + id_col: &str, + short_id_col: Option<&str>, + title_col: &str, + community_col: &str, + summary_col: &str, + content_col: &str, + rank_col: Option<&str>, + _summary_embedding_col: Option<&str>, + _content_embedding_col: Option<&str>, + // attributes_cols: Option<&[&str]>, +) -> anyhow::Result> { + let df_column_names = df.get_column_names(); + let column_names = [ + Some(id_col), + short_id_col, + Some(title_col), + Some(community_col), + Some(summary_col), + Some(content_col), + rank_col, + ] + .iter() + .filter_map(|&v| { + v.map(|v| v.to_string()) + .filter(|v| df_column_names.contains(&v.as_str())) + }) + .collect::>(); + + let column_names: Vec = column_names.into_iter().collect::>().into_vec(); + + let mut df = df; + df.as_single_chunk_par(); + let mut iters = df + .columns(column_names.clone())? 
+ .iter() + .map(|s| s.iter()) + .collect::>(); + + let mut rows = Vec::new(); + for _row in 0..df.height() { + let mut row_values = Vec::new(); + for iter in &mut iters { + let value = iter.next(); + if let Some(value) = value { + row_values.push(value); + } + } + rows.push(row_values); + } + + let mut reports = Vec::new(); + for (idx, row) in rows.iter().enumerate() { + let report = CommunityReport { + id: get_field(row, id_col, &column_names) + .map(|id| id.to_string()) + .unwrap_or_default(), + short_id: Some( + short_id_col + .and_then(|short_id| get_field(row, short_id, &column_names)) + .map(|short_id| short_id.to_string()) + .unwrap_or(idx.to_string()), + ), + title: get_field(row, title_col, &column_names) + .map(|title| title.to_string()) + .unwrap_or_default(), + community_id: get_field(row, community_col, &column_names) + .map(|community| community.to_string()) + .unwrap_or_default(), + summary: get_field(row, summary_col, &column_names) + .map(|summary| summary.to_string()) + .unwrap_or_default(), + full_content: get_field(row, content_col, &column_names) + .map(|content| content.to_string()) + .unwrap_or_default(), + rank: rank_col.and_then(|rank_col| { + get_field(row, rank_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0.0)) + }), + summary_embedding: None, + full_content_embedding: None, + attributes: None, + }; + reports.push(report); + } + + let mut unique_reports: Vec = Vec::new(); + let mut report_ids: HashSet = HashSet::new(); + + for report in reports { + if !report_ids.contains(&report.id) { + unique_reports.push(report.clone()); + report_ids.insert(report.id); + } + } + + Ok(unique_reports) +} + +pub fn read_relationships( + df: DataFrame, + id_col: &str, + short_id_col: Option<&str>, + source_col: &str, + target_col: &str, + description_col: Option<&str>, + description_embedding_col: Option<&str>, + weight_col: Option<&str>, + text_unit_ids_col: Option<&str>, + document_ids_col: Option<&str>, + attributes_cols: Option>, +) -> anyhow::Result> { + let df_column_names = df.get_column_names(); + let mut column_names = [ + Some(id_col), + short_id_col, + Some(source_col), + Some(target_col), + description_col, + description_embedding_col, + weight_col, + text_unit_ids_col, + document_ids_col, + ] + .iter() + .filter_map(|&v| { + v.map(|v| v.to_string()) + .filter(|v| df_column_names.contains(&v.as_str())) + }) + .collect::>(); + + if let Some(cols) = attributes_cols.as_ref() { + cols.iter().for_each(|col| { + column_names.insert(col.to_string()); + }); + } + + let column_names = column_names.into_iter().collect::>(); + + let mut df = df; + df.as_single_chunk_par(); + let mut iters = df + .columns(column_names.clone())? 
+ .iter() + .map(|s| s.iter()) + .collect::>(); + + let mut rows = Vec::new(); + for _row in 0..df.height() { + let mut row_values = Vec::new(); + for iter in &mut iters { + let value = iter.next(); + if let Some(value) = value { + row_values.push(value); + } + } + rows.push(row_values); + } + + let mut relationships = Vec::new(); + for (idx, row) in rows.iter().enumerate() { + let report = Relationship { + id: get_field(row, id_col, &column_names) + .map(|id| id.to_string()) + .unwrap_or_default(), + short_id: Some( + short_id_col + .and_then(|short_id| get_field(row, short_id, &column_names)) + .map(|short_id| short_id.to_string()) + .unwrap_or(idx.to_string()), + ), + source: get_field(row, source_col, &column_names) + .map(|source| source.to_string()) + .unwrap_or_default(), + target: get_field(row, target_col, &column_names) + .map(|target| target.to_string()) + .unwrap_or_default(), + description: description_col + .and_then(|description| get_field(row, description, &column_names)) + .map(|description| description.to_string()), + description_embedding: description_embedding_col.map(|description_embedding_col| { + get_field(row, description_embedding_col, &column_names) + .map(|description_embedding| match description_embedding { + AnyValue::List(series) => series + .f64() + .unwrap_or(&ChunkedArray::from_vec(description_embedding_col, vec![])) + .iter() + .map(|v| v.unwrap_or(0.0)) + .collect::>(), + value => vec![value.to_string().parse::().unwrap_or(0.0)], + }) + .unwrap_or_default() + }), + weight: weight_col.and_then(|weight_col| { + get_field(row, weight_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0.0)) + }), + text_unit_ids: text_unit_ids_col.map(|text_unit_ids_col| { + get_field(row, text_unit_ids_col, &column_names) + .map(|text_unit_ids| match text_unit_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_default() + }), + document_ids: document_ids_col.map(|document_ids_col| { + get_field(row, document_ids_col, &column_names) + .map(|document_ids| match document_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_default() + }), + attributes: attributes_cols.as_ref().map(|cols| { + cols.iter() + .map(|col| { + get_field(row, col, &column_names) + .map(|v| (col.to_string(), v.to_string())) + .unwrap_or((String::new(), String::new())) + }) + .collect::>() + }), + }; + relationships.push(report); + } + + let mut unique_relationships: Vec = Vec::new(); + let mut relationship_ids: HashSet = HashSet::new(); + + for relationship in relationships { + if !relationship_ids.contains(&relationship.id) { + unique_relationships.push(relationship.clone()); + relationship_ids.insert(relationship.id); + } + } + + Ok(unique_relationships) +} + +pub fn read_text_units( + df: DataFrame, + id_col: &str, + short_id_col: Option<&str>, + text_col: &str, + entities_col: Option<&str>, + relationships_col: Option<&str>, + tokens_col: Option<&str>, + document_ids_col: Option<&str>, + embedding_col: Option<&str>, + attributes_cols: Option>, +) -> anyhow::Result> { + let df_column_names = df.get_column_names(); + let mut column_names = [ + Some(id_col), + short_id_col, + Some(text_col), + entities_col, + relationships_col, + tokens_col, + document_ids_col, + embedding_col, + 
] + .iter() + .filter_map(|&v| { + v.map(|v| v.to_string()) + .filter(|v| df_column_names.contains(&v.as_str())) + }) + .collect::>(); + + if let Some(cols) = attributes_cols.as_ref() { + cols.iter().for_each(|col| { + column_names.insert(col.to_string()); + }); + } + + let column_names = column_names.into_iter().collect::>(); + + let mut df = df; + df.as_single_chunk_par(); + let mut iters = df + .columns(column_names.clone())? + .iter() + .map(|s| s.iter()) + .collect::>(); + + let mut rows = Vec::new(); + for _row in 0..df.height() { + let mut row_values = Vec::new(); + for iter in &mut iters { + let value = iter.next(); + if let Some(value) = value { + row_values.push(value); + } + } + rows.push(row_values); + } + + let mut text_units = Vec::new(); + for (idx, row) in rows.iter().enumerate() { + let report = TextUnit { + id: get_field(row, id_col, &column_names) + .map(|id| id.to_string()) + .unwrap_or_default(), + short_id: Some( + short_id_col + .and_then(|short_id| get_field(row, short_id, &column_names)) + .map(|short_id| short_id.to_string()) + .unwrap_or(idx.to_string()), + ), + text: get_field(row, text_col, &column_names) + .map(|text| text.to_string()) + .unwrap_or_default(), + entity_ids: entities_col.map(|entities_col| { + get_field(row, entities_col, &column_names) + .map(|entity_ids| match entity_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_default() + }), + relationship_ids: relationships_col.map(|relationships_col| { + get_field(row, relationships_col, &column_names) + .map(|relationship_ids| match relationship_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_default() + }), + text_embedding: embedding_col.map(|embedding_col| { + get_field(row, embedding_col, &column_names) + .map(|embedding| match embedding { + AnyValue::List(series) => series + .f64() + .unwrap_or(&ChunkedArray::from_vec(embedding_col, vec![])) + .iter() + .map(|v| v.unwrap_or(0.0)) + .collect::>(), + value => vec![value.to_string().parse::().unwrap_or(0.0)], + }) + .unwrap_or_default() + }), + n_tokens: tokens_col.and_then(|tokens_col| { + get_field(row, tokens_col, &column_names).map(|v| v.to_string().parse::().unwrap_or(0)) + }), + document_ids: document_ids_col.map(|document_ids_col| { + get_field(row, document_ids_col, &column_names) + .map(|document_ids| match document_ids { + AnyValue::List(series) => series + .str() + .unwrap_or(&StringChunked::default()) + .iter() + .map(|v| v.unwrap_or("").to_string()) + .collect::>(), + value => vec![value.to_string()], + }) + .unwrap_or_default() + }), + attributes: attributes_cols.as_ref().map(|cols| { + cols.iter() + .map(|col| { + get_field(row, col, &column_names) + .map(|v| (col.to_string(), v.to_string())) + .unwrap_or((String::new(), String::new())) + }) + .collect::>() + }), + }; + text_units.push(report); + } + + let mut unique_text_units: Vec = Vec::new(); + let mut text_unit_ids: HashSet = HashSet::new(); + + for unit in text_units { + if !text_unit_ids.contains(&unit.id) { + unique_text_units.push(unit.clone()); + text_unit_ids.insert(unit.id); + } + } + + Ok(unique_text_units) +} + +fn get_field<'a>(row: &'a [AnyValue<'a>], column_name: &'a str, column_names: &'a [String]) -> Option> { + match 
column_names.iter().position(|x| x == column_name) { + Some(index) => row.get(index).cloned(), + None => None, + } +} diff --git a/shinkai-libs/shinkai-graphrag/src/input/loaders/mod.rs b/shinkai-libs/shinkai-graphrag/src/input/loaders/mod.rs new file mode 100644 index 000000000..289f06621 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/input/loaders/mod.rs @@ -0,0 +1 @@ +pub mod dfs; \ No newline at end of file diff --git a/shinkai-libs/shinkai-graphrag/src/input/mod.rs b/shinkai-libs/shinkai-graphrag/src/input/mod.rs new file mode 100644 index 000000000..6ce304fa8 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/input/mod.rs @@ -0,0 +1,2 @@ +pub mod loaders; +pub mod retrieval; diff --git a/shinkai-libs/shinkai-graphrag/src/input/retrieval/community_reports.rs b/shinkai-libs/shinkai-graphrag/src/input/retrieval/community_reports.rs new file mode 100644 index 000000000..002daca00 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/input/retrieval/community_reports.rs @@ -0,0 +1,106 @@ +use std::collections::{HashMap, HashSet}; + +use polars::{frame::DataFrame, prelude::NamedFrom, series::Series}; + +use crate::models::{CommunityReport, Entity}; + +pub fn get_candidate_communities( + selected_entities: Vec, + community_reports: Vec, + include_community_rank: bool, + use_community_summary: bool, +) -> anyhow::Result { + let mut selected_community_ids: HashSet = HashSet::new(); + for entity in &selected_entities { + if let Some(community_ids) = &entity.community_ids { + selected_community_ids.extend(community_ids.iter().cloned()); + } + } + + let mut selected_reports: Vec = Vec::new(); + for community in &community_reports { + if selected_community_ids.contains(&community.id) { + selected_reports.push(community.clone()); + } + } + + to_community_report_dataframe(selected_reports, include_community_rank, use_community_summary) +} + +pub fn to_community_report_dataframe( + reports: Vec, + include_community_rank: bool, + use_community_summary: bool, +) -> anyhow::Result { + if reports.is_empty() { + return Ok(DataFrame::default()); + } + + let mut header = vec!["id".to_string(), "title".to_string()]; + let attribute_cols: Vec = reports[0] + .attributes + .as_ref() + .map(|attrs| attrs.keys().filter(|&col| !header.contains(col)).cloned().collect()) + .unwrap_or_default(); + + header.extend(attribute_cols.iter().cloned()); + header.push(if use_community_summary { "summary" } else { "content" }.to_string()); + if include_community_rank { + header.push("rank".to_string()); + } + + let mut records = HashMap::new(); + for report in reports { + records + .entry("id") + .or_insert_with(Vec::new) + .push(report.short_id.unwrap_or_default()); + records.entry("title").or_insert_with(Vec::new).push(report.title); + + for field in &attribute_cols { + records.entry(field).or_insert_with(Vec::new).push( + report + .attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(), + ); + } + + if use_community_summary { + records.entry("summary").or_insert_with(Vec::new).push(report.summary); + } else { + records + .entry("content") + .or_insert_with(Vec::new) + .push(report.full_content); + } + + if include_community_rank { + records + .entry("rank") + .or_insert_with(Vec::new) + .push(report.rank.map(|r| r.to_string()).unwrap_or_default()); + } + } + + let mut data_series = Vec::new(); + for (header, data_values) in records { + if header == "rank" { + let data_values = data_values + .iter() + .map(|v| v.parse::().unwrap_or(0.0)) + .collect::>(); + let series = 
Series::new(header, data_values); + data_series.push(series); + } else { + let series = Series::new(header, data_values); + data_series.push(series); + }; + } + + let record_df = DataFrame::new(data_series)?; + + Ok(record_df) +} diff --git a/shinkai-libs/shinkai-graphrag/src/input/retrieval/entities.rs b/shinkai-libs/shinkai-graphrag/src/input/retrieval/entities.rs new file mode 100644 index 000000000..f81315204 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/input/retrieval/entities.rs @@ -0,0 +1,146 @@ +use std::collections::HashMap; + +use polars::{frame::DataFrame, prelude::NamedFrom, series::Series}; +use uuid::Uuid; + +use crate::models::Entity; + +pub fn get_entity_by_key(entities: &Vec, key: &str, value: &str) -> Option { + for entity in entities { + match key { + "id" => { + if entity.id == value + || is_valid_uuid(value) && entity.id == Uuid::parse_str(value).unwrap().to_string().replace('-', "") + { + return Some(entity.clone()); + } + } + "short_id" => { + if entity.short_id.as_ref().unwrap_or(&"".to_string()) == value + || is_valid_uuid(value) + && entity.short_id.as_ref().unwrap_or(&"".to_string()) + == Uuid::parse_str(value).unwrap().to_string().replace('-', "").as_str() + { + return Some(entity.clone()); + } + } + "title" => { + if entity.title == value { + return Some(entity.clone()); + } + } + "entity_type" => { + if entity.entity_type.as_ref().unwrap_or(&"".to_string()) == value { + return Some(entity.clone()); + } + } + "description" => { + if entity.description.as_ref().unwrap_or(&"".to_string()) == value { + return Some(entity.clone()); + } + } + _ => {} + } + } + None +} + +pub fn get_entity_by_name(entities: &[Entity], entity_name: &str) -> Vec { + entities + .iter() + .filter(|entity| entity.title == entity_name) + .cloned() + .collect() +} + +pub fn to_entity_dataframe( + entities: &Vec, + include_entity_rank: bool, + rank_description: &str, +) -> anyhow::Result { + if entities.is_empty() { + return Ok(DataFrame::default()); + } + + let mut header = vec!["id".to_string(), "entity".to_string(), "description".to_string()]; + + if include_entity_rank { + header.push(rank_description.to_string()); + } + + let attribute_cols = if let Some(first_entity) = entities.first().cloned() { + first_entity + .attributes + .unwrap_or_default() + .keys() + .cloned() + .collect::>() + } else { + Vec::new() + }; + + let attribute_cols: Vec = attribute_cols.into_iter().filter(|col| !header.contains(col)).collect(); + header.extend(attribute_cols.clone()); + + let mut records = HashMap::new(); + + for entity in entities { + records + .entry("id") + .or_insert_with(Vec::new) + .push(entity.short_id.clone().unwrap_or_default()); + records + .entry("entity") + .or_insert_with(Vec::new) + .push(entity.title.clone()); + records + .entry("description") + .or_insert_with(Vec::new) + .push(entity.description.clone().unwrap_or_default()); + + if include_entity_rank { + records + .entry("rank") + .or_insert_with(Vec::new) + .push(entity.rank.map(|r| r.to_string()).unwrap_or_default()); + } + + for field in &attribute_cols { + records.entry(field).or_insert_with(Vec::new).push( + entity + .attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(), + ); + } + } + + let mut data_series = Vec::new(); + for (header, data_values) in records { + if header == "rank" { + let data_values = data_values + .iter() + .map(|v| v.parse::().unwrap_or(0.0)) + .collect::>(); + let series = Series::new(header, data_values); + data_series.push(series); + } else { + let 
series = Series::new(header, data_values); + data_series.push(series); + }; + } + + let record_df = if !data_series.is_empty() { + DataFrame::new(data_series)? + } else { + DataFrame::default() + }; + + Ok(record_df) +} + +pub fn is_valid_uuid(value: &str) -> bool { + Uuid::parse_str(value).is_ok() +} diff --git a/shinkai-libs/shinkai-graphrag/src/input/retrieval/mod.rs b/shinkai-libs/shinkai-graphrag/src/input/retrieval/mod.rs new file mode 100644 index 000000000..b56ec376b --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/input/retrieval/mod.rs @@ -0,0 +1,4 @@ +pub mod community_reports; +pub mod entities; +pub mod relationships; +pub mod text_units; diff --git a/shinkai-libs/shinkai-graphrag/src/input/retrieval/relationships.rs b/shinkai-libs/shinkai-graphrag/src/input/retrieval/relationships.rs new file mode 100644 index 000000000..8c1ea7c94 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/input/retrieval/relationships.rs @@ -0,0 +1,263 @@ +use std::{cmp::Ordering, collections::HashMap}; + +use polars::{frame::DataFrame, prelude::NamedFrom, series::Series}; + +use crate::models::{Entity, Relationship}; + +pub fn get_in_network_relationships( + selected_entities: &[Entity], + relationships: &[Relationship], + ranking_attribute: &str, +) -> Vec { + let selected_entity_names: Vec = selected_entities.iter().map(|entity| entity.title.clone()).collect(); + + let selected_relationships: Vec = relationships + .iter() + .filter(|relationship| { + selected_entity_names.contains(&relationship.source) && selected_entity_names.contains(&relationship.target) + }) + .cloned() + .collect(); + + if selected_relationships.len() <= 1 { + return selected_relationships; + } + + // Sort by ranking attribute + sort_relationships_by_ranking_attribute(selected_relationships, selected_entities.to_vec(), ranking_attribute) +} + +pub fn get_out_network_relationships( + selected_entities: &[Entity], + relationships: &[Relationship], + ranking_attribute: &str, +) -> Vec { + let selected_entity_names: Vec = selected_entities.iter().map(|e| e.title.clone()).collect(); + + let source_relationships: Vec = relationships + .iter() + .filter(|r| selected_entity_names.contains(&r.source) && !selected_entity_names.contains(&r.target)) + .cloned() + .collect(); + + let target_relationships: Vec = relationships + .iter() + .filter(|r| selected_entity_names.contains(&r.target) && !selected_entity_names.contains(&r.source)) + .cloned() + .collect(); + + let selected_relationships = [source_relationships, target_relationships].concat(); + + sort_relationships_by_ranking_attribute(selected_relationships, selected_entities.to_vec(), ranking_attribute) +} + +pub fn get_candidate_relationships(selected_entities: &[Entity], relationships: &[Relationship]) -> Vec { + let selected_entity_names: Vec = selected_entities.iter().map(|entity| entity.title.clone()).collect(); + + relationships + .iter() + .filter(|relationship| { + selected_entity_names.contains(&relationship.source) || selected_entity_names.contains(&relationship.target) + }) + .cloned() + .collect() +} + +pub fn get_entities_from_relationships(relationships: &[Relationship], entities: &[Entity]) -> Vec { + let selected_entity_names: Vec = relationships + .iter() + .flat_map(|relationship| vec![relationship.source.clone(), relationship.target.clone()]) + .collect(); + + entities + .iter() + .filter(|entity| selected_entity_names.contains(&entity.title)) + .cloned() + .collect() +} + +pub fn sort_relationships_by_ranking_attribute( + relationships: Vec, + entities: 
Vec, + ranking_attribute: &str, +) -> Vec { + if relationships.is_empty() { + return relationships; + } + + let mut relationships = relationships; + + let attribute_names: Vec = if let Some(attributes) = &relationships[0].attributes { + attributes.keys().cloned().collect() + } else { + Vec::new() + }; + + if attribute_names.contains(&ranking_attribute.to_string()) { + relationships.sort_by(|a, b| { + let a_rank = a + .attributes + .as_ref() + .and_then(|attrs| attrs.get(ranking_attribute)) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + let b_rank = b + .attributes + .as_ref() + .and_then(|attrs| attrs.get(ranking_attribute)) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + b_rank.cmp(&a_rank) + }); + } else if ranking_attribute == "weight" { + relationships.sort_by(|a, b| { + let a_weight = a.weight.unwrap_or(0.0); + let b_weight = b.weight.unwrap_or(0.0); + b_weight.partial_cmp(&a_weight).unwrap_or(Ordering::Equal) + }); + } else { + relationships = calculate_relationship_combined_rank(relationships, entities, ranking_attribute); + relationships.sort_by(|a, b| { + let a_rank = a + .attributes + .as_ref() + .and_then(|attrs| attrs.get(ranking_attribute)) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + let b_rank = b + .attributes + .as_ref() + .and_then(|attrs| attrs.get(ranking_attribute)) + .and_then(|rank| rank.parse::().ok()) + .unwrap_or(0); + b_rank.cmp(&a_rank) + }); + } + + relationships +} + +pub fn calculate_relationship_combined_rank( + relationships: Vec, + entities: Vec, + ranking_attribute: &str, +) -> Vec { + let mut relationships = relationships; + let entity_mappings: HashMap<_, _> = entities.iter().map(|e| (e.title.clone(), e)).collect(); + + for relationship in relationships.iter_mut() { + if relationship.attributes.is_none() { + relationship.attributes = Some(HashMap::new()); + } + + let source_rank = entity_mappings + .get(&relationship.source) + .and_then(|e| e.rank) + .unwrap_or(0); + let target_rank = entity_mappings + .get(&relationship.target) + .and_then(|e| e.rank) + .unwrap_or(0); + + if let Some(attributes) = &mut relationship.attributes { + attributes.insert(ranking_attribute.to_string(), (source_rank + target_rank).to_string()); + } + } + + relationships +} + +pub fn to_relationship_dataframe( + relationships: &Vec, + include_relationship_weight: bool, +) -> anyhow::Result { + if relationships.is_empty() { + return Ok(DataFrame::default()); + } + + let mut header = vec![ + "id".to_string(), + "source".to_string(), + "target".to_string(), + "description".to_string(), + ]; + + if include_relationship_weight { + header.push("weight".to_string()); + } + + let attribute_cols = if let Some(relationship) = relationships.first().cloned() { + relationship + .attributes + .unwrap_or_default() + .keys() + .cloned() + .collect::>() + } else { + Vec::new() + }; + + let attribute_cols: Vec = attribute_cols.into_iter().filter(|col| !header.contains(col)).collect(); + header.extend(attribute_cols.clone()); + + let mut records = HashMap::new(); + + for rel in relationships { + records + .entry("id") + .or_insert_with(Vec::new) + .push(rel.short_id.clone().unwrap_or_default()); + records + .entry("source") + .or_insert_with(Vec::new) + .push(rel.source.clone()); + records + .entry("target") + .or_insert_with(Vec::new) + .push(rel.target.clone()); + records + .entry("description") + .or_insert_with(Vec::new) + .push(rel.description.clone().unwrap_or_default()); + + if include_relationship_weight { + records + .entry("weight") + 
.or_insert_with(Vec::new) + .push(rel.weight.map(|r| r.to_string()).unwrap_or_default()); + } + + for field in &attribute_cols { + records.entry(field).or_insert_with(Vec::new).push( + rel.attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(), + ); + } + } + + let mut data_series = Vec::new(); + for (header, data_values) in records { + if header == "weight" { + let data_values = data_values + .iter() + .map(|v| v.parse::().unwrap_or(0.0)) + .collect::>(); + let series = Series::new(header, data_values); + data_series.push(series); + } else { + let series = Series::new(header, data_values); + data_series.push(series); + }; + } + + let record_df = if !data_series.is_empty() { + DataFrame::new(data_series)? + } else { + DataFrame::default() + }; + + Ok(record_df) +} diff --git a/shinkai-libs/shinkai-graphrag/src/input/retrieval/text_units.rs b/shinkai-libs/shinkai-graphrag/src/input/retrieval/text_units.rs new file mode 100644 index 000000000..3849361ec --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/input/retrieval/text_units.rs @@ -0,0 +1,81 @@ +use std::collections::{HashMap, HashSet}; + +use polars::{frame::DataFrame, prelude::NamedFrom, series::Series}; + +use crate::models::{Entity, TextUnit}; + +pub fn get_candidate_text_units(selected_entities: &Vec, text_units: &[TextUnit]) -> anyhow::Result { + let mut selected_text_ids: HashSet = HashSet::new(); + + for entity in selected_entities { + if let Some(ids) = &entity.text_unit_ids { + for id in ids { + selected_text_ids.insert(id.to_string()); + } + } + } + + let selected_text_units: Vec = text_units + .iter() + .filter(|unit| selected_text_ids.contains(&unit.id)) + .cloned() + .collect(); + + to_text_unit_dataframe(selected_text_units) +} + +pub fn to_text_unit_dataframe(text_units: Vec) -> anyhow::Result { + if text_units.is_empty() { + return Ok(DataFrame::default()); + } + + let mut header = vec!["id".to_string(), "text".to_string()]; + + let attribute_cols = if let Some(text_unit) = text_units.first().cloned() { + text_unit + .attributes + .unwrap_or_default() + .keys() + .cloned() + .collect::>() + } else { + Vec::new() + }; + + let attribute_cols: Vec = attribute_cols.into_iter().filter(|col| !header.contains(col)).collect(); + header.extend(attribute_cols.clone()); + + let mut records = HashMap::new(); + + for unit in text_units { + records + .entry("id") + .or_insert_with(Vec::new) + .push(unit.short_id.clone().unwrap_or_default()); + records.entry("text").or_insert_with(Vec::new).push(unit.text.clone()); + + for field in &attribute_cols { + records.entry(field).or_insert_with(Vec::new).push( + unit.attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(), + ); + } + } + + let mut data_series = Vec::new(); + for (header, data_values) in records { + let series = Series::new(header, data_values); + data_series.push(series); + } + + let record_df = if !data_series.is_empty() { + DataFrame::new(data_series)? 
+    } else {
+        DataFrame::default()
+    };
+
+    Ok(record_df)
+}
diff --git a/shinkai-libs/shinkai-graphrag/src/lib.rs b/shinkai-libs/shinkai-graphrag/src/lib.rs
new file mode 100644
index 000000000..35060f002
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/lib.rs
@@ -0,0 +1,7 @@
+pub mod context_builder;
+pub mod indexer_adapters;
+pub mod input;
+pub mod llm;
+pub mod models;
+pub mod search;
+pub mod vector_stores;
diff --git a/shinkai-libs/shinkai-graphrag/src/llm/base.rs b/shinkai-libs/shinkai-graphrag/src/llm/base.rs
new file mode 100644
index 000000000..3432a19c5
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/llm/base.rs
@@ -0,0 +1,44 @@
+use std::collections::HashMap;
+
+use async_trait::async_trait;
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Default)]
+pub struct BaseLLMCallback {
+    pub response: Vec<String>,
+}
+
+impl BaseLLMCallback {
+    pub fn on_llm_new_token(&mut self, token: &str) {
+        self.response.push(token.to_string());
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum MessageType {
+    String(String),
+    Strings(Vec<String>),
+    Dictionary(Vec<HashMap<String, String>>),
+}
+
+#[derive(Debug, Clone)]
+pub struct LLMParams {
+    pub max_tokens: u32,
+    pub temperature: f32,
+}
+
+#[async_trait]
+pub trait BaseLLM {
+    async fn agenerate(
+        &self,
+        messages: MessageType,
+        streaming: bool,
+        callbacks: Option<Vec<BaseLLMCallback>>,
+        llm_params: LLMParams,
+    ) -> anyhow::Result<String>;
+}
+
+#[async_trait]
+pub trait BaseTextEmbedding {
+    async fn aembed(&self, text: &str) -> anyhow::Result<Vec<f32>>;
+}
diff --git a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs
new file mode 100644
index 000000000..6cf245d4d
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs
@@ -0,0 +1 @@
+pub mod base;
diff --git a/shinkai-libs/shinkai-graphrag/src/models.rs b/shinkai-libs/shinkai-graphrag/src/models.rs
new file mode 100644
index 000000000..1d6a63e01
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/models.rs
@@ -0,0 +1,59 @@
+use std::collections::HashMap;
+
+#[derive(Debug, Clone)]
+pub struct CommunityReport {
+    pub id: String,
+    pub short_id: Option<String>,
+    pub title: String,
+    pub community_id: String,
+    pub summary: String,
+    pub full_content: String,
+    pub rank: Option<f64>,
+    pub summary_embedding: Option<Vec<f32>>,
+    pub full_content_embedding: Option<Vec<f32>>,
+    pub attributes: Option<HashMap<String, String>>,
+}
+
+#[derive(Debug, Clone)]
+pub struct Entity {
+    pub id: String,
+    pub short_id: Option<String>,
+    pub title: String,
+    pub entity_type: Option<String>,
+    pub description: Option<String>,
+    pub description_embedding: Option<Vec<f32>>,
+    pub name_embedding: Option<Vec<f32>>,
+    pub graph_embedding: Option<Vec<f32>>,
+    pub community_ids: Option<Vec<String>>,
+    pub text_unit_ids: Option<Vec<String>>,
+    pub document_ids: Option<Vec<String>>,
+    pub rank: Option<i32>,
+    pub attributes: Option<HashMap<String, String>>,
+}
+
+#[derive(Debug, Clone)]
+pub struct Relationship {
+    pub id: String,
+    pub short_id: Option<String>,
+    pub source: String,
+    pub target: String,
+    pub weight: Option<f64>,
+    pub description: Option<String>,
+    pub description_embedding: Option<Vec<f32>>,
+    pub text_unit_ids: Option<Vec<String>>,
+    pub document_ids: Option<Vec<String>>,
+    pub attributes: Option<HashMap<String, String>>,
+}
+
+#[derive(Debug, Clone)]
+pub struct TextUnit {
+    pub id: String,
+    pub short_id: Option<String>,
+    pub text: String,
+    pub text_embedding: Option<Vec<f32>>,
+    pub entity_ids: Option<Vec<String>>,
+    pub relationship_ids: Option<Vec<String>>,
+    pub n_tokens: Option<usize>,
+    pub document_ids: Option<Vec<String>>,
+    pub attributes: Option<HashMap<String, String>>,
+}
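+
+// Illustrative sketch (not part of the original sources): constructing an
+// `Entity` by hand, e.g. in a unit test. The field types follow the struct
+// definition above; all values here are made up.
+//
+// let entity = Entity {
+//     id: "e1".to_string(),
+//     short_id: Some("1".to_string()),
+//     title: "ACME Corp".to_string(),
+//     entity_type: Some("organization".to_string()),
+//     description: Some("A fictional company".to_string()),
+//     description_embedding: None,
+//     name_embedding: None,
+//     graph_embedding: None,
+//     community_ids: Some(vec!["c0".to_string()]),
+//     text_unit_ids: Some(vec!["t0".to_string()]),
+//     document_ids: None,
+//     rank: Some(3),
+//     attributes: None,
+// };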
diff --git a/shinkai-libs/shinkai-graphrag/src/search/base.rs b/shinkai-libs/shinkai-graphrag/src/search/base.rs
new file mode 100644
index 000000000..a7ba7eaef
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/search/base.rs
@@ -0,0 +1,29 @@
+use std::collections::HashMap;
+
+use polars::frame::DataFrame;
+
+#[derive(Debug, Clone)]
+pub enum ResponseType {
+    String(String),
+    KeyPoints(Vec<KeyPoint>),
+}
+
+#[derive(Debug, Clone)]
+pub enum ContextData {
+    String(String),
+    DataFrames(Vec<DataFrame>),
+    Dictionary(HashMap<String, DataFrame>),
+}
+
+#[derive(Debug, Clone)]
+pub enum ContextText {
+    String(String),
+    Strings(Vec<String>),
+    Dictionary(HashMap<String, String>),
+}
+
+#[derive(Debug, Clone)]
+pub struct KeyPoint {
+    pub answer: String,
+    pub score: i32,
+}
diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs
new file mode 100644
index 000000000..5aed5257c
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs
@@ -0,0 +1,2 @@
+pub mod prompts;
+pub mod search;
diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/prompts.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/prompts.rs
new file mode 100644
index 000000000..7c9fef5cb
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/prompts.rs
@@ -0,0 +1,164 @@
+// Copyright (c) 2024 Microsoft Corporation.
+// Licensed under the MIT License
+
+// System prompts for global search.
+
+pub const MAP_SYSTEM_PROMPT: &str = r#"
+---Role---
+
+You are a helpful assistant responding to questions about data in the tables provided.
+
+
+---Goal---
+
+Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables.
+
+You should use the data provided in the data tables below as the primary context for generating the response.
+If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.
+
+Each key point in the response should have the following element:
+- Description: A comprehensive description of the point.
+- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0.
+
+The response should be JSON formatted as follows:
+{{
+    "points": [
+        {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}},
+        {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}}
+    ]
+}}
+
+The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
+
+Points supported by data should list the relevant reports as references as follows:
+"This is an example sentence supported by data references [Data: Reports (report ids)]"
+
+**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+For example:
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]"
+
+where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables.
+
+Do not include information where the supporting evidence for it is not provided.
+
+
+---Data tables---
+
+{context_data}
+
+---Goal---
+
+Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables.
+ +You should use the data provided in the data tables below as the primary context for generating the response. +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +Each key point in the response should have the following element: +- Description: A comprehensive description of the point. +- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +Points supported by data should list the relevant reports as references as follows: +"This is an example sentence supported by data references [Data: Reports (report ids)]" + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables. + +Do not include information where the supporting evidence for it is not provided. + +The response should be JSON formatted as follows: +{{ + "points": [ + {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}}, + {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}} + ] +}} +"#; + +pub const REDUCE_SYSTEM_PROMPT: &str = r#" +---Role--- + +You are a helpful assistant responding to questions about a dataset by synthesizing perspectives from multiple analysts. + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset. + +Note that the analysts' reports provided below are ranked in the **descending order of importance**. + +If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format. + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. 
+ +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + + +---Analyst Reports--- + +{report_data} + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset. + +Note that the analysts' reports provided below are ranked in the **descending order of importance**. + +If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. +"#; + +pub const NO_DATA_ANSWER: &str = "I am sorry but I am unable to answer this question given the provided data."; + +pub const GENERAL_KNOWLEDGE_INSTRUCTION: &str = r#" +The response may also include relevant real-world knowledge outside the dataset, but it must be explicitly annotated with a verification tag [LLM: verify]. For example: +"This is an example sentence supported by real-world knowledge [LLM: verify]." 
+"#; diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/search.rs new file mode 100644 index 000000000..34e68e3af --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/search.rs @@ -0,0 +1,381 @@ +use futures::future::join_all; +use serde_json::Value; +use std::collections::HashMap; +use std::time::Instant; + +use crate::context_builder::community_context::{CommunityContextBuilderParams, GlobalCommunityContext}; +use crate::context_builder::conversation_history::ConversationHistory; +use crate::llm::base::{BaseLLM, BaseLLMCallback, LLMParams, MessageType}; +use crate::search::base::{ContextData, ContextText, KeyPoint, ResponseType}; +use crate::search::global_search::prompts::NO_DATA_ANSWER; + +use super::prompts::{GENERAL_KNOWLEDGE_INSTRUCTION, MAP_SYSTEM_PROMPT, REDUCE_SYSTEM_PROMPT}; + +#[derive(Debug, Clone)] +pub struct SearchResult { + pub response: ResponseType, + pub context_data: ContextData, + pub context_text: ContextText, + pub completion_time: f64, + pub llm_calls: usize, + pub prompt_tokens: usize, +} + +pub struct GlobalSearchResult { + pub response: ResponseType, + pub context_data: ContextData, + pub context_text: ContextText, + pub completion_time: f64, + pub llm_calls: usize, + pub prompt_tokens: usize, + pub map_responses: Vec, + pub reduce_context_data: ContextData, + pub reduce_context_text: ContextText, +} + +#[derive(Debug, Clone, Default)] +pub struct GlobalSearchLLMCallback { + response: Vec, + map_response_contexts: Vec, + map_response_outputs: Vec, +} + +impl GlobalSearchLLMCallback { + pub fn on_map_response_start(&mut self, map_response_contexts: Vec) { + self.map_response_contexts = map_response_contexts; + } + + pub fn on_map_response_end(&mut self, map_response_outputs: Vec) { + self.map_response_outputs = map_response_outputs; + } +} + +pub struct GlobalSearch { + llm: Box, + context_builder: GlobalCommunityContext, + num_tokens_fn: fn(&str) -> usize, + context_builder_params: CommunityContextBuilderParams, + map_system_prompt: String, + reduce_system_prompt: String, + response_type: String, + allow_general_knowledge: bool, + general_knowledge_inclusion_prompt: String, + callbacks: Option>, + max_data_tokens: usize, + map_llm_params: LLMParams, + reduce_llm_params: LLMParams, +} + +pub struct GlobalSearchParams { + pub llm: Box, + pub context_builder: GlobalCommunityContext, + pub num_tokens_fn: fn(&str) -> usize, + pub map_system_prompt: Option, + pub reduce_system_prompt: Option, + pub response_type: String, + pub allow_general_knowledge: bool, + pub general_knowledge_inclusion_prompt: Option, + pub callbacks: Option>, + pub max_data_tokens: usize, + pub map_llm_params: LLMParams, + pub reduce_llm_params: LLMParams, + pub context_builder_params: CommunityContextBuilderParams, +} + +impl GlobalSearch { + pub fn new(global_search_params: GlobalSearchParams) -> Self { + let GlobalSearchParams { + llm, + context_builder, + num_tokens_fn, + map_system_prompt, + reduce_system_prompt, + response_type, + allow_general_knowledge, + general_knowledge_inclusion_prompt, + callbacks, + max_data_tokens, + map_llm_params, + reduce_llm_params, + context_builder_params, + } = global_search_params; + + let map_system_prompt = map_system_prompt.unwrap_or(MAP_SYSTEM_PROMPT.to_string()); + let reduce_system_prompt = reduce_system_prompt.unwrap_or(REDUCE_SYSTEM_PROMPT.to_string()); + let general_knowledge_inclusion_prompt = + 
+impl GlobalSearch {
+    pub fn new(global_search_params: GlobalSearchParams) -> Self {
+        let GlobalSearchParams {
+            llm,
+            context_builder,
+            num_tokens_fn,
+            map_system_prompt,
+            reduce_system_prompt,
+            response_type,
+            allow_general_knowledge,
+            general_knowledge_inclusion_prompt,
+            callbacks,
+            max_data_tokens,
+            map_llm_params,
+            reduce_llm_params,
+            context_builder_params,
+        } = global_search_params;
+
+        let map_system_prompt = map_system_prompt.unwrap_or(MAP_SYSTEM_PROMPT.to_string());
+        let reduce_system_prompt = reduce_system_prompt.unwrap_or(REDUCE_SYSTEM_PROMPT.to_string());
+        let general_knowledge_inclusion_prompt =
+            general_knowledge_inclusion_prompt.unwrap_or(GENERAL_KNOWLEDGE_INSTRUCTION.to_string());
+
+        GlobalSearch {
+            llm,
+            context_builder,
+            num_tokens_fn,
+            context_builder_params,
+            map_system_prompt,
+            reduce_system_prompt,
+            response_type,
+            allow_general_knowledge,
+            general_knowledge_inclusion_prompt,
+            callbacks,
+            max_data_tokens,
+            map_llm_params,
+            reduce_llm_params,
+        }
+    }
+
+    pub async fn asearch(
+        &self,
+        query: String,
+        _conversation_history: Option<ConversationHistory>,
+    ) -> anyhow::Result<GlobalSearchResult> {
+        // Step 1: Generate answers for each batch of community short summaries
+        let start_time = Instant::now();
+        let (context_chunks, context_records) = self
+            .context_builder
+            .build_context(self.context_builder_params.clone())?;
+
+        let mut callbacks = match &self.callbacks {
+            Some(callbacks) => {
+                let mut llm_callbacks = Vec::new();
+                for callback in callbacks {
+                    let mut callback = callback.clone();
+                    callback.on_map_response_start(context_chunks.clone());
+                    llm_callbacks.push(callback);
+                }
+                Some(llm_callbacks)
+            }
+            None => None,
+        };
+
+        let map_responses: Vec<_> = join_all(
+            context_chunks
+                .iter()
+                .map(|data| self._map_response_single_batch(data, self.map_llm_params.clone())),
+        )
+        .await;
+
+        let map_responses: Result<Vec<SearchResult>, _> = map_responses.into_iter().collect();
+        let map_responses = map_responses?;
+
+        callbacks = match &callbacks {
+            Some(callbacks) => {
+                let mut llm_callbacks = Vec::new();
+                for callback in callbacks {
+                    let mut callback = callback.clone();
+                    callback.on_map_response_end(map_responses.clone());
+                    llm_callbacks.push(callback);
+                }
+                Some(llm_callbacks)
+            }
+            None => None,
+        };
+
+        let map_llm_calls: usize = map_responses.iter().map(|response| response.llm_calls).sum();
+        let map_prompt_tokens: usize = map_responses.iter().map(|response| response.prompt_tokens).sum();
+
+        // Step 2: Combine the intermediate answers from step 1 to generate the final answer
+        let reduce_response = self
+            ._reduce_response(map_responses.clone(), &query, callbacks, self.reduce_llm_params.clone())
+            .await?;
+
+        Ok(GlobalSearchResult {
+            response: reduce_response.response,
+            context_data: ContextData::Dictionary(context_records),
+            context_text: ContextText::Strings(context_chunks),
+            completion_time: start_time.elapsed().as_secs_f64(),
+            llm_calls: map_llm_calls + reduce_response.llm_calls,
+            prompt_tokens: map_prompt_tokens + reduce_response.prompt_tokens,
+            map_responses,
+            reduce_context_data: reduce_response.context_data,
+            reduce_context_text: reduce_response.context_text,
+        })
+    }
+
+    async fn _map_response_single_batch(
+        &self,
+        context_data: &str,
+        llm_params: LLMParams,
+    ) -> anyhow::Result<SearchResult> {
+        let start_time = Instant::now();
+        let search_prompt = self.map_system_prompt.replace("{context_data}", context_data);
+
+        let search_messages = vec![
+            HashMap::from([
+                ("role".to_string(), "system".to_string()),
+                ("content".to_string(), search_prompt.clone()),
+            ]),
+            HashMap::from([
+                ("role".to_string(), "user".to_string()),
+                ("content".to_string(), "Respond using JSON".to_string()),
+            ]),
+        ];
+
+        let search_response = self
+            .llm
+            .agenerate(MessageType::Dictionary(search_messages), false, None, llm_params)
+            .await?;
+
+        let processed_response = self.parse_search_response(&search_response);
+
+        Ok(SearchResult {
+            response: ResponseType::KeyPoints(processed_response),
+            context_data: ContextData::String(context_data.to_string()),
+            context_text: ContextText::String(context_data.to_string()),
+            completion_time: start_time.elapsed().as_secs_f64(),
+            llm_calls: 1,
+            prompt_tokens: (self.num_tokens_fn)(&search_prompt),
+        })
+    }
+
+    fn parse_search_response(&self, search_response: &str) -> Vec<KeyPoint> {
+        // The map prompt asks for `{"points": [{"description": ..., "score": ...}]}`;
+        // tolerate the doubled braces used for templating in the prompt text.
+        let search_response = &search_response.replace("{{", "{").replace("}}", "}");
+        let parsed_elements: Value = serde_json::from_str(search_response).unwrap_or_default();
+
+        if let Some(points) = parsed_elements.get("points") {
+            if let Some(points) = points.as_array() {
+                return points
+                    .iter()
+                    .filter(|element| element.get("description").is_some() && element.get("score").is_some())
+                    .map(|element| KeyPoint {
+                        // Use `as_str` so the answer is the raw text rather than
+                        // a JSON-quoted string.
+                        answer: element
+                            .get("description")
+                            .and_then(|description| description.as_str())
+                            .unwrap_or_default()
+                            .to_string(),
+                        score: element
+                            .get("score")
+                            .and_then(|score| score.as_i64())
+                            .unwrap_or(0) as i32,
+                    })
+                    .collect::<Vec<KeyPoint>>();
+            }
+        }
+
+        vec![KeyPoint {
+            answer: "".to_string(),
+            score: 0,
+        }]
+    }
+
+    async fn _reduce_response(
+        &self,
+        map_responses: Vec<SearchResult>,
+        query: &str,
+        callbacks: Option<Vec<GlobalSearchLLMCallback>>,
+        llm_params: LLMParams,
+    ) -> anyhow::Result<SearchResult> {
+        let start_time = Instant::now();
+        let mut key_points: Vec<HashMap<String, String>> = Vec::new();
+
+        for (index, response) in map_responses.iter().enumerate() {
+            if let ResponseType::KeyPoints(response_list) = &response.response {
+                for key_point in response_list {
+                    let mut point = HashMap::new();
+                    point.insert("analyst".to_string(), (index + 1).to_string());
+                    point.insert("answer".to_string(), key_point.answer.clone());
+                    point.insert("score".to_string(), key_point.score.to_string());
+                    key_points.push(point);
+                }
+            }
+        }
+
+        let filtered_key_points: Vec<HashMap<String, String>> = key_points
+            .into_iter()
+            .filter(|point| point.get("score").unwrap().parse::<i32>().unwrap() > 0)
+            .collect();
+
+        if filtered_key_points.is_empty() && !self.allow_general_knowledge {
+            eprintln!(
+                "Warning: All map responses have score 0 (i.e., no relevant information found from the dataset), \
+                 returning a canned 'I do not know' answer. You can try enabling `allow_general_knowledge` to \
+                 encourage the LLM to incorporate relevant general knowledge, at the risk of increasing hallucinations."
+            );
+
+            return Ok(SearchResult {
+                response: ResponseType::String(NO_DATA_ANSWER.to_string()),
+                context_data: ContextData::String("".to_string()),
+                context_text: ContextText::String("".to_string()),
+                completion_time: start_time.elapsed().as_secs_f64(),
+                llm_calls: 0,
+                prompt_tokens: 0,
+            });
+        }
+
+        let mut sorted_key_points = filtered_key_points;
+        sorted_key_points.sort_by(|a, b| {
+            b.get("score")
+                .unwrap()
+                .parse::<i32>()
+                .unwrap()
+                .cmp(&a.get("score").unwrap().parse::<i32>().unwrap())
+        });
+
+        let mut data: Vec<String> = Vec::new();
+        let mut total_tokens = 0;
+        for point in sorted_key_points {
+            let mut formatted_response_data: Vec<String> = Vec::new();
+            formatted_response_data.push(format!("----Analyst {}----", point.get("analyst").unwrap()));
+            formatted_response_data.push(format!("Importance Score: {}", point.get("score").unwrap()));
+            formatted_response_data.push(point.get("answer").unwrap().to_string());
+            let formatted_response_text = formatted_response_data.join("\n");
+
+            if total_tokens + (self.num_tokens_fn)(&formatted_response_text) > self.max_data_tokens {
+                break;
+            }
+
+            data.push(formatted_response_text.clone());
+            total_tokens += (self.num_tokens_fn)(&formatted_response_text);
+        }
+        let text_data = data.join("\n\n");
+
+        let search_prompt = format!(
+            "{}\n{}",
+            self.reduce_system_prompt
+                .replace("{report_data}", &text_data)
+                .replace("{response_type}", &self.response_type),
+            if self.allow_general_knowledge {
+                self.general_knowledge_inclusion_prompt.clone()
+            } else {
+                String::new()
+            }
+        );
+
+        let search_messages = vec![
+            HashMap::from([
+                ("role".to_string(), "system".to_string()),
+                ("content".to_string(), search_prompt.clone()),
+            ]),
+            HashMap::from([
+                ("role".to_string(), "user".to_string()),
+                ("content".to_string(), query.to_string()),
+            ]),
+        ];
+
+        let llm_callbacks = match callbacks {
+            Some(callbacks) => {
+                let mut llm_callbacks = Vec::new();
+                for callback in callbacks {
+                    llm_callbacks.push(BaseLLMCallback {
+                        response: callback.response.clone(),
+                    });
+                }
+                Some(llm_callbacks)
+            }
+            None => None,
+        };
+
+        let search_response = self
+            .llm
+            .agenerate(
+                MessageType::Dictionary(search_messages),
+                true,
+                llm_callbacks,
+                llm_params,
+            )
+            .await?;
+
+        Ok(SearchResult {
+            response: ResponseType::String(search_response),
+            context_data: ContextData::String(text_data.clone()),
+            context_text: ContextText::String(text_data),
+            completion_time: start_time.elapsed().as_secs_f64(),
+            llm_calls: 1,
+            prompt_tokens: (self.num_tokens_fn)(&search_prompt),
+        })
+    }
+}
diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs
new file mode 100644
index 000000000..4b394a22a
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/mixed_context.rs
@@ -0,0 +1,603 @@
+use std::collections::HashMap;
+
+use polars::{
+    frame::DataFrame,
+    prelude::{is_in, NamedFrom},
+    series::Series,
+};
+
+use crate::{
+    context_builder::{
+        community_context::CommunityContext,
+        entity_extraction::map_query_to_entities,
+        local_context::{build_entity_context, build_relationship_context, get_candidate_context},
+        source_context::{build_text_unit_context, count_relationships},
+    },
+    input::retrieval::{community_reports::get_candidate_communities, text_units::get_candidate_text_units},
+    llm::base::BaseTextEmbedding,
+    models::{CommunityReport, Entity, Relationship, TextUnit},
+    vector_stores::lancedb::LanceDBVectorStore,
+};
+
+#[derive(Debug, Clone)]
+pub struct MixedContextBuilderParams {
+    pub query: String,
+    pub include_entity_names: Option<Vec<String>>,
+    pub exclude_entity_names: Option<Vec<String>>,
+    pub max_tokens: usize,
+    pub text_unit_prop: f32,
+    pub community_prop: f32,
+    pub top_k_mapped_entities: usize,
+    pub top_k_relationships: usize,
+    pub include_community_rank: bool,
+    pub include_entity_rank: bool,
+    pub rank_description: String,
+    pub include_relationship_weight: bool,
+    pub relationship_ranking_attribute: String,
+    pub return_candidate_context: bool,
+    pub use_community_summary: bool,
+    pub min_community_rank: u32,
+    pub community_context_name: String,
+    pub column_delimiter: String,
+    // pub conversation_history: Option<ConversationHistory>,
+    // pub conversation_history_max_turns: Option<usize>,
+    // pub conversation_history_user_turns_only: bool,
+}
+
+pub fn default_local_context_params() -> MixedContextBuilderParams {
+    MixedContextBuilderParams {
+        query: String::new(),
+        include_entity_names: None,
+        exclude_entity_names: None,
+        max_tokens: 8000,
+        text_unit_prop: 0.5,
+        community_prop: 0.25,
+        top_k_mapped_entities: 10,
+        top_k_relationships: 10,
+        include_community_rank: false,
+        include_entity_rank: false,
+        rank_description: "number of relationships".to_string(),
+        include_relationship_weight: false,
+        relationship_ranking_attribute: "rank".to_string(),
+        return_candidate_context: false,
+        use_community_summary: false,
+        min_community_rank: 0,
+        community_context_name: "Reports".to_string(),
+        column_delimiter: "|".to_string(),
+    }
+}
+
+pub struct LocalSearchMixedContext {
+    entities: HashMap<String, Entity>,
+    entity_text_embeddings: LanceDBVectorStore,
+    text_embedder: Box<dyn BaseTextEmbedding + Send + Sync>,
+    text_units: HashMap<String, TextUnit>,
+    community_reports: HashMap<String, CommunityReport>,
+    relationships: HashMap<String, Relationship>,
+    num_tokens_fn: fn(&str) -> usize,
+    embedding_vectorstore_key: String,
+}
+
+impl LocalSearchMixedContext {
+    pub fn new(
+        entities: Vec<Entity>,
+        entity_text_embeddings: LanceDBVectorStore,
+        text_embedder: Box<dyn BaseTextEmbedding + Send + Sync>,
+        text_units: Option<Vec<TextUnit>>,
+        community_reports: Option<Vec<CommunityReport>>,
+        relationships: Option<Vec<Relationship>>,
+        num_tokens_fn: fn(&str) -> usize,
+        embedding_vectorstore_key: String,
+    ) -> Self {
+        let mut context = LocalSearchMixedContext {
+            entities: HashMap::new(),
+            entity_text_embeddings,
+            text_embedder,
+            text_units: HashMap::new(),
+            community_reports: HashMap::new(),
+            relationships: HashMap::new(),
+            num_tokens_fn,
+            embedding_vectorstore_key,
+        };
+
+        for entity in entities {
+            context.entities.insert(entity.id.clone(), entity);
+        }
+
+        if let Some(units) = text_units {
+            for unit in units {
+                context.text_units.insert(unit.id.replace('"', ""), unit);
+            }
+        }
+
+        if let Some(reports) = community_reports {
+            for report in reports {
+                context.community_reports.insert(report.id.clone(), report);
+            }
+        }
+
+        if let Some(relations) = relationships {
+            for relation in relations {
+                context.relationships.insert(relation.id.replace('"', ""), relation);
+            }
+        }
+
+        context
+    }
+
+    pub async fn build_context(
+        &self,
+        context_builder_params: MixedContextBuilderParams,
+    ) -> anyhow::Result<(String, HashMap<String, DataFrame>)> {
+        let MixedContextBuilderParams {
+            query,
+            include_entity_names,
+            exclude_entity_names,
+            max_tokens,
+            text_unit_prop,
+            community_prop,
+            top_k_mapped_entities,
+            top_k_relationships,
+            include_community_rank,
+            include_entity_rank,
+            rank_description,
+            include_relationship_weight,
+            relationship_ranking_attribute,
+            return_candidate_context,
+            use_community_summary,
+            min_community_rank,
+            community_context_name,
+            column_delimiter,
+        } = context_builder_params;
+
+        let include_entity_names = include_entity_names.unwrap_or_default();
+        let exclude_entity_names = exclude_entity_names.unwrap_or_default();
+
+        if community_prop + text_unit_prop > 1.0 {
+            return Err(anyhow::anyhow!(
+                "The sum of community_prop and text_unit_prop must be less than or equal to 1.0"
+            ));
+        }
+
+        let selected_entities = map_query_to_entities(
+            &query,
+            &self.entity_text_embeddings,
+            &*self.text_embedder,
+            &self.entities.values().cloned().collect::<Vec<Entity>>(),
+            &self.embedding_vectorstore_key,
+            Some(include_entity_names),
+            Some(exclude_entity_names),
+            top_k_mapped_entities,
+            2,
+        )
+        .await?;
+
+        let mut final_context = Vec::new();
+        let mut final_context_data = HashMap::new();
+
+        // build community context
+        let community_tokens = std::cmp::max((max_tokens as f32 * community_prop) as usize, 0);
+        let (community_context, community_context_data) = self._build_community_context(
+            selected_entities.clone(),
+            community_tokens,
+            use_community_summary,
+            &column_delimiter,
+            include_community_rank,
+            min_community_rank,
+            return_candidate_context,
+            &community_context_name,
+        )?;
+
+        if !community_context.trim().is_empty() {
+            final_context.push(community_context);
+            final_context_data.extend(community_context_data);
+        }
+
+        // build local (i.e. entity-relationship-covariate) context
+        let local_prop = 1_f32 - community_prop - text_unit_prop;
+        let local_tokens = std::cmp::max((max_tokens as f32 * local_prop) as usize, 0);
+        let (local_context, local_context_data) = self._build_local_context(
+            selected_entities.clone(),
+            local_tokens,
+            include_entity_rank,
+            &rank_description,
+            include_relationship_weight,
+            top_k_relationships,
+            &relationship_ranking_attribute,
+            return_candidate_context,
+            &column_delimiter,
+        )?;
+
+        if !local_context.trim().is_empty() {
+            final_context.push(local_context);
+            final_context_data.extend(local_context_data);
+        }
+
+        // build text unit context
+        let text_unit_tokens = std::cmp::max((max_tokens as f32 * text_unit_prop) as usize, 0);
+        let (text_unit_context, text_unit_context_data) = self._build_text_unit_context(
+            selected_entities.clone(),
+            text_unit_tokens,
+            return_candidate_context,
+            "|",
+            "Sources",
+        )?;
+
+        if !text_unit_context.trim().is_empty() {
+            final_context.push(text_unit_context);
+            final_context_data.extend(text_unit_context_data);
+        }
+
+        Ok((final_context.join("\n\n"), final_context_data))
+    }
+
+    fn _build_community_context(
+        &self,
+        selected_entities: Vec<Entity>,
+        max_tokens: usize,
+        use_community_summary: bool,
+        column_delimiter: &str,
+        include_community_rank: bool,
+        min_community_rank: u32,
+        return_candidate_context: bool,
+        context_name: &str,
+    ) -> anyhow::Result<(String, HashMap<String, DataFrame>)> {
+        if selected_entities.is_empty() || self.community_reports.is_empty() {
+            return Ok((
+                "".to_string(),
+                HashMap::from([(context_name.to_lowercase(), DataFrame::default())]),
+            ));
+        }
+
+        let mut community_matches: HashMap<String, i32> = HashMap::new();
+        for entity in &selected_entities {
+            if let Some(community_ids) = &entity.community_ids {
+                for community_id in community_ids {
+                    *community_matches.entry(community_id.to_string()).or_insert(0) += 1;
+                }
+            }
+        }
+
+        let mut selected_communities: Vec<CommunityReport> = Vec::new();
+        for community_id in community_matches.keys() {
+            if let Some(community) = self.community_reports.get(community_id) {
+                selected_communities.push(community.clone());
+            }
+        }
+
+        for community in &mut selected_communities {
+            if community.attributes.is_none() {
+                community.attributes = Some(HashMap::new());
+            }
+            if let Some(attributes) = &mut community.attributes {
+                attributes.insert("matches".to_string(), community_matches[&community.id].to_string());
+            }
+        }
+
+        selected_communities.sort_by(|a, b| {
+            let a_matches = a
+                .attributes
+                .as_ref()
+                .unwrap()
+                .get("matches")
+                .unwrap()
+                .parse::<i32>()
+                .unwrap();
+            let b_matches = b
+                .attributes
+                .as_ref()
+                .unwrap()
+                .get("matches")
+                .unwrap()
+                .parse::<i32>()
+                .unwrap();
+            let a_rank = a.rank.unwrap();
+            let b_rank = b.rank.unwrap();
+            (b_matches, b_rank).partial_cmp(&(a_matches, a_rank)).unwrap()
+        });
+
+        for community in &mut selected_communities {
+            if let Some(attributes) = &mut community.attributes {
+                attributes.remove("matches");
+            }
+        }
+
+        let (context_text, context_data) = CommunityContext::build_community_context(
+            selected_communities,
+            None,
+            self.num_tokens_fn,
+            use_community_summary,
+            column_delimiter,
+            false,
+            include_community_rank,
+            min_community_rank,
+            "rank",
+            true,
+            "occurrence weight",
+            true,
+            max_tokens,
+            true,
+            context_name,
+        )?;
+
+        let mut context_text_result = "".to_string();
+        if !context_text.is_empty() {
+            context_text_result = context_text.join("\n\n");
+        }
+
+        let mut context_data = context_data;
+        if return_candidate_context {
+            let candidate_context_data = get_candidate_communities(
+                selected_entities,
+                self.community_reports.values().cloned().collect(),
+                use_community_summary,
+                include_community_rank,
+            )?;
+
+            let context_key = context_name.to_lowercase();
+            if !context_data.contains_key(&context_key) {
+                let mut new_data = candidate_context_data.clone();
+                new_data.with_column(Series::new("in_context", vec![false; candidate_context_data.height()]))?;
+                context_data.insert(context_key.to_string(), new_data);
+            } else {
+                let existing_data = context_data.get(&context_key).unwrap();
+                if candidate_context_data.get_column_names().contains(&"id")
+                    && existing_data.get_column_names().contains(&"id")
+                {
+                    let existing_ids = existing_data.column("id")?;
+                    let context_ids = candidate_context_data.column("id")?;
+                    let mut new_data = candidate_context_data.clone();
+                    let in_context = is_in(context_ids, existing_ids)?;
+                    let in_context = Series::new("in_context", in_context);
+                    new_data.with_column(in_context)?;
+                    context_data.insert(context_key.to_string(), new_data);
+                } else {
+                    let mut existing_data = existing_data.clone();
+                    existing_data.with_column(Series::new("in_context", vec![true; existing_data.height()]))?;
+                    context_data.insert(context_key.to_string(), existing_data);
+                }
+            }
+        }
+
+        Ok((context_text_result, context_data))
+    }
+
+    fn _build_local_context(
+        &self,
+        selected_entities: Vec<Entity>,
+        max_tokens: usize,
+        include_entity_rank: bool,
+        rank_description: &str,
+        include_relationship_weight: bool,
+        top_k_relationships: usize,
+        relationship_ranking_attribute: &str,
+        return_candidate_context: bool,
+        column_delimiter: &str,
+    ) -> anyhow::Result<(String, HashMap<String, DataFrame>)> {
+        let (entity_context, entity_context_data) = build_entity_context(
+            selected_entities.clone(),
+            self.num_tokens_fn,
+            max_tokens,
+            include_entity_rank,
+            rank_description,
+            column_delimiter,
+            "Entities",
+        )?;
+
+        let entity_tokens = (self.num_tokens_fn)(&entity_context);
+
+        let mut added_entities = Vec::new();
+        let mut final_context = Vec::new();
+        let mut final_context_data = HashMap::new();
+
+        for entity in &selected_entities {
+            let mut current_context = Vec::new();
+            let mut current_context_data = HashMap::new();
+            added_entities.push(entity.clone());
+
+            let (relationship_context, relationship_context_data) = build_relationship_context(
+                &added_entities,
+                &self.relationships.values().cloned().collect::<Vec<Relationship>>(),
+                self.num_tokens_fn,
+                include_relationship_weight,
+                max_tokens,
+                top_k_relationships,
+                relationship_ranking_attribute,
+                column_delimiter,
+                "Relationships",
+            )?;
+
+            current_context.push(relationship_context.clone());
+            current_context_data.insert("relationships".to_string(), relationship_context_data);
+
+            let total_tokens = entity_tokens + (self.num_tokens_fn)(&relationship_context);
+
+            if total_tokens > max_tokens {
+                eprintln!("Reached token limit - reverting to previous context state");
+                break;
+            }
+
+            final_context = current_context;
+            final_context_data = current_context_data;
+        }
+
+        let mut final_context_text = entity_context.to_string();
+        final_context_text.push_str("\n\n");
+        final_context_text.push_str(&final_context.join("\n\n"));
+        final_context_data.insert("entities".to_string(), entity_context_data.clone());
+
+        if return_candidate_context {
+            let entities = self.entities.values().cloned().collect::<Vec<Entity>>();
+            let relationships = self.relationships.values().cloned().collect::<Vec<Relationship>>();
+
+            let candidate_context_data = get_candidate_context(
+                &selected_entities,
+                &entities,
+                &relationships,
+                include_entity_rank,
+                rank_description,
+                include_relationship_weight,
+            )?;
+
+            for (key, candidate_df) in candidate_context_data {
+                if !final_context_data.contains_key(&key) {
+                    final_context_data.insert(key.clone(), candidate_df);
+                } else {
+                    let in_context_df = final_context_data.get_mut(&key).unwrap();
+
+                    if in_context_df.get_column_names().contains(&"id")
+                        && candidate_df.get_column_names().contains(&"id")
+                    {
+                        let context_ids = in_context_df.column("id")?;
+                        let candidate_ids = candidate_df.column("id")?;
+                        let mut new_data = candidate_df.clone();
+                        let in_context = is_in(candidate_ids, context_ids)?;
+                        let in_context = Series::new("in_context", in_context);
+                        new_data.with_column(in_context)?;
+                        final_context_data.insert(key.clone(), new_data);
+                    } else {
+                        in_context_df.with_column(Series::new("in_context", vec![true; in_context_df.height()]))?;
+                    }
+                }
+            }
+        } else {
+            for (_key, context_df) in final_context_data.iter_mut() {
+                context_df.with_column(Series::new("in_context", vec![true; context_df.height()]))?;
+            }
+        }
+
+        Ok((final_context_text, final_context_data))
+    }
+
+    fn _build_text_unit_context(
+        &self,
+        selected_entities: Vec<Entity>,
+        max_tokens: usize,
+        return_candidate_context: bool,
+        column_delimiter: &str,
+        context_name: &str,
+    ) -> anyhow::Result<(String, HashMap<String, DataFrame>)> {
+        if selected_entities.is_empty() || self.text_units.is_empty() {
+            return Ok((String::new(), HashMap::new()));
+        }
+
+        let mut selected_text_units: Vec<TextUnit> = Vec::new();
+
+        for (index, entity) in selected_entities.iter().enumerate() {
+            if let Some(text_unit_ids) = &entity.text_unit_ids {
+                for text_id in text_unit_ids {
+                    if !selected_text_units.iter().any(|unit| &unit.id == text_id)
+                        && self.text_units.contains_key(text_id)
+                    {
+                        let mut selected_unit = self.text_units[text_id].clone();
+                        let num_relationships = count_relationships(&selected_unit, entity, &self.relationships);
+
+                        if selected_unit.attributes.is_none() {
+                            selected_unit.attributes = Some(HashMap::new());
+                        }
+
+                        if let Some(attributes) = &mut selected_unit.attributes {
+                            attributes.insert("entity_order".to_string(), index.to_string());
attributes.insert("num_relationships".to_string(), num_relationships.to_string()); + } + + selected_text_units.push(selected_unit); + } + } + } + } + + selected_text_units.sort_by(|a, b| { + let a_order = a + .attributes + .as_ref() + .unwrap() + .get("entity_order") + .unwrap() + .parse::() + .unwrap(); + let b_order = b + .attributes + .as_ref() + .unwrap() + .get("entity_order") + .unwrap() + .parse::() + .unwrap(); + + let a_relationships = a + .attributes + .as_ref() + .unwrap() + .get("num_relationships") + .unwrap() + .parse::() + .unwrap(); + let b_relationships = b + .attributes + .as_ref() + .unwrap() + .get("num_relationships") + .unwrap() + .parse::() + .unwrap(); + + a_order + .cmp(&b_order) + .then_with(|| b_relationships.cmp(&a_relationships)) + }); + + for unit in &mut selected_text_units { + unit.attributes.as_mut().unwrap().remove("entity_order"); + unit.attributes.as_mut().unwrap().remove("num_relationships"); + } + + let (context_text, context_data) = build_text_unit_context( + selected_text_units, + self.num_tokens_fn, + column_delimiter, + false, + max_tokens, + context_name, + 86, + )?; + + let mut context_data = context_data; + if return_candidate_context { + let candidate_context_data = get_candidate_text_units( + &selected_entities, + &self.text_units.values().cloned().collect::>(), + )?; + + let context_key = context_name.to_lowercase(); + if !context_data.contains_key(&context_key) { + let mut new_data = candidate_context_data.clone(); + new_data.with_column(Series::new("in_context", vec![false; candidate_context_data.height()]))?; + context_data.insert(context_key.to_string(), new_data); + } else { + let existing_data = context_data.get(&context_key).unwrap(); + if candidate_context_data + .get_column_names() + .contains(&"id".to_string().as_str()) + && existing_data.get_column_names().contains(&"id".to_string().as_str()) + { + let existing_ids = existing_data.column("id")?; + let context_ids = candidate_context_data.column("id")?; + let mut new_data = candidate_context_data.clone(); + let in_context = is_in(context_ids, existing_ids)?; + let in_context = Series::new("in_context", in_context); + new_data.with_column(in_context)?; + context_data.insert(context_key.to_string(), new_data); + } else { + let mut existing_data = existing_data.clone(); + existing_data.with_column(Series::new("in_context", vec![true; existing_data.height()]))?; + context_data.insert(context_key.to_string(), existing_data); + } + } + } + + Ok((context_text, context_data)) + } +} diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/mod.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/mod.rs new file mode 100644 index 000000000..eb73d0830 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/mod.rs @@ -0,0 +1,3 @@ +pub mod mixed_context; +pub mod prompts; +pub mod search; diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/prompts.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/prompts.rs new file mode 100644 index 000000000..7da6a63eb --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/prompts.rs @@ -0,0 +1,69 @@ +// Copyright (c) 2024 Microsoft Corporation. +// Licensed under the MIT License + +// System prompts for local search. + +pub const LOCAL_SEARCH_SYSTEM_PROMPT: &str = r#" +---Role--- + +You are a helpful assistant responding to questions about data in the tables provided. 
+ + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. + +If you don't know the answer, just say so. Do not make anything up. + +Points supported by data should list their data references as follows: + +"This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]." + +Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16), Reports (1), Entities (5, 7); Relationships (23); Claims (2, 7, 34, 46, 64, +more)]." + +where 15, 16, 1, 5, 7, 23, 2, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + + +---Data tables--- + +{context_data} + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. + +If you don't know the answer, just say so. Do not make anything up. + +Points supported by data should list their data references as follows: + +"This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]." + +Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16), Reports (1), Entities (5, 7); Relationships (23); Claims (2, 7, 34, 46, 64, +more)]." + +where 15, 16, 1, 5, 7, 23, 2, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. 
+"#; diff --git a/shinkai-libs/shinkai-graphrag/src/search/local_search/search.rs b/shinkai-libs/shinkai-graphrag/src/search/local_search/search.rs new file mode 100644 index 000000000..1f87d06b4 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/local_search/search.rs @@ -0,0 +1,98 @@ +use std::{collections::HashMap, time::Instant}; + +use crate::{ + llm::base::{BaseLLM, LLMParams, MessageType}, + search::base::{ContextData, ContextText, ResponseType}, +}; + +use super::{ + mixed_context::{LocalSearchMixedContext, MixedContextBuilderParams}, + prompts::LOCAL_SEARCH_SYSTEM_PROMPT, +}; + +pub struct LocalSearchResult { + pub response: ResponseType, + pub context_data: ContextData, + pub context_text: ContextText, + pub completion_time: f64, + pub llm_calls: usize, + pub prompt_tokens: usize, +} + +pub struct LocalSearch { + llm: Box, + context_builder: LocalSearchMixedContext, + num_tokens_fn: fn(&str) -> usize, + system_prompt: String, + response_type: String, + llm_params: LLMParams, + context_builder_params: MixedContextBuilderParams, +} + +impl LocalSearch { + pub fn new( + llm: Box, + context_builder: LocalSearchMixedContext, + num_tokens_fn: fn(&str) -> usize, + llm_params: LLMParams, + context_builder_params: MixedContextBuilderParams, + response_type: String, + system_prompt: Option, + ) -> Self { + let system_prompt = system_prompt.unwrap_or(LOCAL_SEARCH_SYSTEM_PROMPT.to_string()); + + LocalSearch { + llm, + context_builder, + num_tokens_fn, + system_prompt, + response_type, + llm_params, + context_builder_params, + } + } + + pub async fn asearch(&self, query: String) -> anyhow::Result { + let start_time = Instant::now(); + + let mut context_builder_params = self.context_builder_params.clone(); + context_builder_params.query.clone_from(&query); + + let (context_text, context_records) = self.context_builder.build_context(context_builder_params).await?; + + let search_prompt = self + .system_prompt + .replace("{context_data}", &context_text) + .replace("{response_type}", &self.response_type); + + let search_messages = vec![ + HashMap::from([ + ("role".to_string(), "system".to_string()), + ("content".to_string(), search_prompt.clone()), + ]), + HashMap::from([ + ("role".to_string(), "user".to_string()), + ("content".to_string(), query.to_string()), + ]), + ]; + + let search_response = self + .llm + .agenerate( + MessageType::Dictionary(search_messages), + false, + None, + self.llm_params.clone(), + ) + .await?; + + Ok(LocalSearchResult { + response: ResponseType::String(search_response), + context_data: ContextData::Dictionary(context_records), + context_text: ContextText::String(context_text), + completion_time: start_time.elapsed().as_secs_f64(), + llm_calls: 1, + prompt_tokens: (self.num_tokens_fn)(&search_prompt), + }) + } +} diff --git a/shinkai-libs/shinkai-graphrag/src/search/mod.rs b/shinkai-libs/shinkai-graphrag/src/search/mod.rs new file mode 100644 index 000000000..7266f8dab --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/mod.rs @@ -0,0 +1,3 @@ +pub mod base; +pub mod global_search; +pub mod local_search; diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs new file mode 100644 index 000000000..78f9ab8ae --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs @@ -0,0 +1,257 @@ +use std::{collections::HashMap, sync::Arc}; + +use arrow::datatypes::Float32Type; +use arrow_array::{Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, 
diff --git a/shinkai-libs/shinkai-graphrag/src/search/mod.rs b/shinkai-libs/shinkai-graphrag/src/search/mod.rs
new file mode 100644
index 000000000..7266f8dab
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/search/mod.rs
@@ -0,0 +1,3 @@
+pub mod base;
+pub mod global_search;
+pub mod local_search;
diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs
new file mode 100644
index 000000000..78f9ab8ae
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/lancedb.rs
@@ -0,0 +1,257 @@
+use std::{collections::HashMap, sync::Arc};
+
+use arrow::datatypes::Float32Type;
+use arrow_array::{Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray};
+use futures::TryStreamExt;
+use lancedb::{
+    arrow::arrow_schema::{DataType, Field, Schema},
+    connect,
+    query::{ExecutableQuery, QueryBase},
+    Connection, Table,
+};
+use serde_json::json;
+
+use crate::llm::base::BaseTextEmbedding;
+
+use super::vector_store::{VectorStore, VectorStoreDocument, VectorStoreSearchResult};
+
+pub struct LanceDBVectorStore {
+    collection_name: String,
+    db_connection: Option<Connection>,
+    document_collection: Option<Table>,
+}
+
+impl LanceDBVectorStore {
+    pub fn new(collection_name: String) -> Self {
+        LanceDBVectorStore {
+            collection_name,
+            db_connection: None,
+            document_collection: None,
+        }
+    }
+
+    pub async fn connect(&mut self, db_uri: &str) -> anyhow::Result<()> {
+        let connection = connect(db_uri).execute().await?;
+        self.db_connection = Some(connection);
+        Ok(())
+    }
+
+    async fn similarity_search_by_vector(
+        &self,
+        query_embedding: Vec<f32>,
+        k: usize,
+    ) -> anyhow::Result<Vec<VectorStoreSearchResult>> {
+        if let Some(document_collection) = &self.document_collection {
+            let records = document_collection
+                .query()
+                .limit(k)
+                .nearest_to(query_embedding)?
+                .execute()
+                .await?
+                .try_collect::<Vec<_>>()
+                .await?;
+
+            let mut results = Vec::new();
+            for record in records {
+                let id_col = record
+                    .column_by_name("id")
+                    .unwrap()
+                    .as_any()
+                    .downcast_ref::<StringArray>()
+                    .unwrap();
+                let text_col = record
+                    .column_by_name("text")
+                    .unwrap()
+                    .as_any()
+                    .downcast_ref::<StringArray>()
+                    .unwrap();
+                let vector_col = record
+                    .column_by_name("vector")
+                    .unwrap()
+                    .as_any()
+                    .downcast_ref::<FixedSizeListArray>()
+                    .unwrap();
+                let attributes_col = record
+                    .column_by_name("attributes")
+                    .unwrap()
+                    .as_any()
+                    .downcast_ref::<StringArray>()
+                    .unwrap();
+
+                let distance_col = record
+                    .column_by_name("_distance")
+                    .unwrap()
+                    .as_any()
+                    .downcast_ref::<Float32Array>()
+                    .unwrap();
+
+                if id_col.is_empty()
+                    || text_col.is_empty()
+                    || vector_col.is_empty()
+                    || attributes_col.is_empty()
+                    || distance_col.is_empty()
+                {
+                    continue;
+                }
+
+                for i in 0..record.num_rows() {
+                    let id = id_col.value(i).to_string();
+                    let text = text_col.value(i).to_string();
+                    let vector: Vec<f32> = vector_col
+                        .value(i)
+                        .as_any()
+                        .downcast_ref::<Float32Array>()
+                        .unwrap()
+                        .iter()
+                        .map(|value| value.unwrap())
+                        .collect();
+                    let attributes: HashMap<String, String> = serde_json::from_str(attributes_col.value(i))?;
+
+                    let distance = distance_col.value(i);
+
+                    let score = 1.0 - distance.abs();
+
+                    let doc = VectorStoreDocument {
+                        id,
+                        text: Some(text),
+                        vector: Some(vector),
+                        attributes,
+                    };
+
+                    results.push(VectorStoreSearchResult { document: doc, score });
+                }
+            }
+
+            return Ok(results);
+        }
+
+        Ok(Vec::new())
+    }
+}
+
+impl VectorStore for LanceDBVectorStore {
+    async fn similarity_search_by_text(
+        &self,
+        text: &str,
+        text_embedder: &(dyn BaseTextEmbedding + Send + Sync),
+        k: usize,
+    ) -> anyhow::Result<Vec<VectorStoreSearchResult>> {
+        let query_embedding = text_embedder.aembed(text).await?;
+
+        if query_embedding.is_empty() {
+            return Ok(vec![]);
+        }
+
+        let results = self.similarity_search_by_vector(query_embedding, k).await?;
+        Ok(results)
+    }
+
+    async fn load_documents(&mut self, documents: Vec<VectorStoreDocument>, overwrite: bool) -> anyhow::Result<()> {
+        let db_connection = self
+            .db_connection
+            .as_ref()
+            .ok_or_else(|| anyhow::anyhow!("LanceDB connection is not established"))?;
+
+        let data: Vec<_> = documents
+            .into_iter()
+            .filter(|document| document.vector.is_some())
+            .collect();
+
+        let vector_dimension = if !data.is_empty() {
+            data[0].vector.as_ref().map(|v| v.len()).unwrap_or_default()
+        } else {
+            0
+        };
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Utf8, false),
Field::new("text", DataType::Utf8, true), + Field::new( + "vector", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + vector_dimension.try_into().unwrap(), + ), + true, + ), + Field::new("attributes", DataType::Utf8, false), + ])); + + let batches = if !data.is_empty() { + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from( + data.iter().map(|document| document.id.clone()).collect::>(), + )), + Arc::new(StringArray::from( + data.iter() + .map(|document| document.text.clone().unwrap_or_default()) + .collect::>(), + )), + Arc::new(FixedSizeListArray::from_iter_primitive::( + data.iter() + .map(|document| { + Some( + document + .vector + .as_ref() + .map(|v| v.iter().map(|f| Some(*f)).collect::>()) + .unwrap_or_default(), + ) + }) + .collect::>(), + vector_dimension.try_into().unwrap(), + )), + Arc::new(StringArray::from( + data.iter() + .map(|document| json!(document.attributes).to_string()) + .collect::>(), + )), + ], + )?; + + Some(RecordBatchIterator::new(vec![Ok(batch)], schema.clone())) + } else { + None + }; + + if overwrite { + let _ = db_connection.drop_table(&self.collection_name).await; + + if let Some(batches) = batches { + let table = db_connection + .create_table(&self.collection_name, Box::new(batches)) + .execute() + .await?; + + self.document_collection = Some(table); + } else { + let table = db_connection + .create_empty_table(&self.collection_name, schema.clone()) + .execute() + .await?; + + self.document_collection = Some(table); + } + } else { + let table = match db_connection.open_table(&self.collection_name).execute().await { + Ok(table) => table, + Err(_) => { + db_connection + .create_empty_table(&self.collection_name, schema.clone()) + .execute() + .await? 
+ } + }; + + if let Some(batches) = batches { + table.add(batches).execute().await?; + } + + self.document_collection = Some(table); + } + + Ok(()) + } +} diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/mod.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/mod.rs new file mode 100644 index 000000000..e8083f72b --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/mod.rs @@ -0,0 +1,2 @@ +pub mod lancedb; +pub mod vector_store; diff --git a/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs b/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs new file mode 100644 index 000000000..d1745fe75 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/vector_stores/vector_store.rs @@ -0,0 +1,30 @@ +use std::collections::HashMap; + +use crate::llm::base::BaseTextEmbedding; + +pub struct VectorStoreSearchResult { + pub document: VectorStoreDocument, + pub score: f32, +} + +pub struct VectorStoreDocument { + pub id: String, + pub text: Option, + pub vector: Option>, + pub attributes: HashMap, +} + +pub trait VectorStore { + fn similarity_search_by_text( + &self, + text: &str, + text_embedder: &(dyn BaseTextEmbedding + Send + Sync), + k: usize, + ) -> impl std::future::Future>> + Send; + + fn load_documents( + &mut self, + documents: Vec, + overwrite: bool, + ) -> impl std::future::Future> + Send; +} diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs new file mode 100644 index 000000000..1dd7e696a --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs @@ -0,0 +1,203 @@ +use polars::{io::SerReader, prelude::ParquetReader}; +use shinkai_graphrag::{ + context_builder::community_context::{CommunityContextBuilderParams, GlobalCommunityContext}, + indexer_adapters::{read_indexer_entities, read_indexer_reports}, + llm::base::LLMParams, + search::global_search::search::{GlobalSearch, GlobalSearchParams}, +}; +use utils::{ + ollama::OllamaChat, + openai::{num_tokens, ChatOpenAI}, +}; + +mod utils; + +// #[tokio::test] +async fn ollama_global_search_test() -> Result<(), Box> { + let base_url = "http://localhost:11434"; + let model = "llama3.1"; + + let llm = OllamaChat::new(base_url, model); + + // Load community reports + // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip + + let input_dir = "./dataset"; + let community_report_table = "create_final_community_reports"; + let entity_table = "create_final_nodes"; + let entity_embedding_table = "create_final_entities"; + + let community_level = 2; + + let mut entity_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_table)).unwrap(); + let entity_df = ParquetReader::new(&mut entity_file).finish().unwrap(); + + let mut report_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, community_report_table)).unwrap(); + let report_df = ParquetReader::new(&mut report_file).finish().unwrap(); + + let mut entity_embedding_file = + std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_embedding_table)).unwrap(); + let entity_embedding_df = ParquetReader::new(&mut entity_embedding_file).finish().unwrap(); + + let reports = read_indexer_reports(&report_df, &entity_df, community_level)?; + let entities = read_indexer_entities(&entity_df, &entity_embedding_df, community_level)?; + + println!("Reports: {:?}", report_df.head(Some(5))); + + // Build global context based on community reports + + // Using tiktoken for token count estimation + let 
diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs
new file mode 100644
index 000000000..1dd7e696a
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs
@@ -0,0 +1,203 @@
+use polars::{io::SerReader, prelude::ParquetReader};
+use shinkai_graphrag::{
+    context_builder::community_context::{CommunityContextBuilderParams, GlobalCommunityContext},
+    indexer_adapters::{read_indexer_entities, read_indexer_reports},
+    llm::base::LLMParams,
+    search::global_search::search::{GlobalSearch, GlobalSearchParams},
+};
+use utils::{
+    ollama::OllamaChat,
+    openai::{num_tokens, ChatOpenAI},
+};
+
+mod utils;
+
+// #[tokio::test]
+async fn ollama_global_search_test() -> Result<(), Box<dyn std::error::Error>> {
+    let base_url = "http://localhost:11434";
+    let model = "llama3.1";
+
+    let llm = OllamaChat::new(base_url, model);
+
+    // Load community reports
+    // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip
+
+    let input_dir = "./dataset";
+    let community_report_table = "create_final_community_reports";
+    let entity_table = "create_final_nodes";
+    let entity_embedding_table = "create_final_entities";
+
+    let community_level = 2;
+
+    let mut entity_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_table)).unwrap();
+    let entity_df = ParquetReader::new(&mut entity_file).finish().unwrap();
+
+    let mut report_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, community_report_table)).unwrap();
+    let report_df = ParquetReader::new(&mut report_file).finish().unwrap();
+
+    let mut entity_embedding_file =
+        std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_embedding_table)).unwrap();
+    let entity_embedding_df = ParquetReader::new(&mut entity_embedding_file).finish().unwrap();
+
+    let reports = read_indexer_reports(&report_df, &entity_df, community_level)?;
+    let entities = read_indexer_entities(&entity_df, &entity_embedding_df, community_level)?;
+
+    println!("Reports: {:?}", report_df.head(Some(5)));
+
+    // Build global context based on community reports
+
+    // Using tiktoken for token count estimation
+    let context_builder = GlobalCommunityContext::new(reports, Some(entities), num_tokens);
+
+    let context_builder_params = CommunityContextBuilderParams {
+        use_community_summary: false, // false: use full community reports; true: use short community summaries
+        shuffle_data: true,
+        include_community_rank: true,
+        min_community_rank: 0,
+        community_rank_name: String::from("rank"),
+        include_community_weight: true,
+        community_weight_name: String::from("occurrence weight"),
+        normalize_community_weight: true,
+        max_tokens: 12_000, // adjust to your model's context window (with an 8k model, ~5_000 is a reasonable setting)
+        context_name: String::from("Reports"),
+        column_delimiter: String::from("|"),
+    };
+
+    // OllamaChat forwards max_tokens as num_ctx and passes temperature through (see tests/utils/ollama.rs)
+    let map_llm_params = LLMParams {
+        max_tokens: 12_000,
+        temperature: 0.0,
+    };
+
+    let reduce_llm_params = LLMParams {
+        max_tokens: 12_000,
+        temperature: 0.0,
+    };
+
+    // Perform global search
+
+    let search_engine = GlobalSearch::new(GlobalSearchParams {
+        llm: Box::new(llm),
+        context_builder,
+        num_tokens_fn: num_tokens,
+        map_system_prompt: None,
+        reduce_system_prompt: None,
+        response_type: String::from("multiple paragraphs"),
+        allow_general_knowledge: false,
+        general_knowledge_inclusion_prompt: None,
+        callbacks: None,
+        max_data_tokens: 12_000,
+        map_llm_params,
+        reduce_llm_params,
+        context_builder_params,
+    });
+
+    let result = search_engine
+        .asearch(
+            "What is the major conflict in this story and who are the protagonist and antagonist?".to_string(),
+            None,
+        )
+        .await?;
+
+    println!("Response: {:?}", result.response);
+
+    println!("Context: {:?}", result.context_data);
+
+    println!("LLM calls: {}. LLM tokens: {}", result.llm_calls, result.prompt_tokens);
+
+    Ok(())
+}
+
+// #[tokio::test]
+async fn openai_global_search_test() -> Result<(), Box<dyn std::error::Error>> {
+    let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap();
+    let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap();
+
+    let llm = ChatOpenAI::new(Some(api_key), &llm_model, 5);
+
+    // Load community reports
+    // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip
+
+    let input_dir = "./dataset";
+    let community_report_table = "create_final_community_reports";
+    let entity_table = "create_final_nodes";
+    let entity_embedding_table = "create_final_entities";
+
+    let community_level = 2;
+
+    let mut entity_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_table)).unwrap();
+    let entity_df = ParquetReader::new(&mut entity_file).finish().unwrap();
+
+    let mut report_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, community_report_table)).unwrap();
+    let report_df = ParquetReader::new(&mut report_file).finish().unwrap();
+
+    let mut entity_embedding_file =
+        std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_embedding_table)).unwrap();
+    let entity_embedding_df = ParquetReader::new(&mut entity_embedding_file).finish().unwrap();
+
+    let reports = read_indexer_reports(&report_df, &entity_df, community_level)?;
+    let entities = read_indexer_entities(&entity_df, &entity_embedding_df, community_level)?;
+
+    println!("Reports: {:?}", report_df.head(Some(5)));
+
+    // Build global context based on community reports
+
+    let context_builder = GlobalCommunityContext::new(reports, Some(entities), num_tokens);
+
+    let context_builder_params = CommunityContextBuilderParams {
+        use_community_summary: false, // false: use full community reports; true: use short community summaries
+        shuffle_data: true,
+        include_community_rank: true,
+        min_community_rank: 0,
+        community_rank_name: String::from("rank"),
+        include_community_weight: true,
+        community_weight_name: String::from("occurrence weight"),
+        normalize_community_weight: true,
+        max_tokens: 12_000, // adjust to your model's context window (with an 8k model, ~5_000 is a reasonable setting)
+        context_name: String::from("Reports"),
+        column_delimiter: String::from("|"),
+    };
+
+    let map_llm_params = LLMParams {
+        max_tokens: 1000,
+        temperature: 0.0,
+    };
+
+    let reduce_llm_params = LLMParams {
+        max_tokens: 2000,
+        temperature: 0.0,
+    };
+
+    // Perform global search
+
+    let search_engine = GlobalSearch::new(GlobalSearchParams {
+        llm: Box::new(llm),
+        context_builder,
+        num_tokens_fn: num_tokens,
+        map_system_prompt: None,
+        reduce_system_prompt: None,
+        response_type: String::from("multiple paragraphs"),
+        allow_general_knowledge: false,
+        general_knowledge_inclusion_prompt: None,
+        callbacks: None,
+        max_data_tokens: 12_000,
+        map_llm_params,
+        reduce_llm_params,
+        context_builder_params,
+    });
+
+    let result = search_engine
+        .asearch(
+            "What is the major conflict in this story and who are the protagonist and antagonist?".to_string(),
+            None,
+        )
+        .await?;
+
+    println!("Response: {:?}", result.response);
+
+    println!("Context: {:?}", result.context_data);
+
+    println!("LLM calls: {}. LLM tokens: {}", result.llm_calls, result.prompt_tokens);
+
+    Ok(())
+}
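The `max_tokens` comments above amount to simple budget arithmetic: reserve part of the model window for the prompt scaffolding and the response, and give the rest to report context. A tiny, hypothetical helper (not part of the patch) makes the calculation explicit:

    // Sketch: derive a context budget from a model window, keeping headroom
    // for the system prompt, the question, and the generated answer.
    fn context_budget(model_window: usize, reserve: usize) -> usize {
        model_window.saturating_sub(reserve)
    }

    #[test]
    fn budget_for_8k_model() {
        // An 8k-window model with ~3k reserved leaves ~5k for reports,
        // in line with the max_tokens guidance in the comments above.
        assert_eq!(context_budget(8_192, 3_192), 5_000);
    }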
diff --git a/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs
new file mode 100644
index 000000000..1deb90fd7
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/tests/local_search_tests.rs
@@ -0,0 +1,279 @@
+use polars::{io::SerReader, prelude::ParquetReader};
+use shinkai_graphrag::{
+    indexer_adapters::{
+        read_indexer_entities, read_indexer_relationships, read_indexer_reports, read_indexer_text_units,
+    },
+    input::loaders::dfs::store_entity_semantic_embeddings,
+    llm::base::LLMParams,
+    search::local_search::{
+        mixed_context::{default_local_context_params, LocalSearchMixedContext},
+        search::LocalSearch,
+    },
+    vector_stores::lancedb::LanceDBVectorStore,
+};
+use utils::{
+    ollama::{OllamaChat, OllamaEmbedding},
+    openai::{num_tokens, ChatOpenAI, OpenAIEmbedding},
+};
+
+mod utils;
+
+// #[tokio::test]
+async fn ollama_local_search_test() -> Result<(), Box<dyn std::error::Error>> {
+    let base_url = "http://localhost:11434";
+    let llm_model = "llama3.1";
+    let llm = OllamaChat::new(base_url, llm_model);
+
+    let embedding_model = "snowflake-arctic-embed:m";
+    let text_embedder = OllamaEmbedding::new(base_url, embedding_model);
+
+    // Load datasets
+
+    let input_dir = "./dataset";
+    let lancedb_uri = format!("{}/lancedb", input_dir);
+
+    let community_report_table = "create_final_community_reports";
+    let entity_table = "create_final_nodes";
+    let entity_embedding_table = "create_final_entities";
+    let relationship_table = "create_final_relationships";
+    let text_unit_table = "create_final_text_units";
+    let community_level = 2;
+
+    // Read entities
+    let mut entity_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_table)).unwrap();
+    let entity_df = ParquetReader::new(&mut entity_file).finish().unwrap();
+
+    let mut entity_embedding_file =
+        std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_embedding_table)).unwrap();
+    let entity_embedding_df = ParquetReader::new(&mut entity_embedding_file).finish().unwrap();
+
+    let entities = read_indexer_entities(&entity_df, &entity_embedding_df, community_level)?;
+
+    let mut description_embedding_store = LanceDBVectorStore::new("entity_description_embeddings".to_string());
+    description_embedding_store.connect(&lancedb_uri).await?;
+
+    store_entity_semantic_embeddings(entities.clone(), &mut description_embedding_store).await?;
+
+    println!("Entities ({}): {:?}", entity_df.height(), entity_df.head(Some(5)));
+
+    // Read relationships
+    let mut relationship_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, relationship_table)).unwrap();
+    let relationship_df = ParquetReader::new(&mut relationship_file).finish().unwrap();
+
+    let relationships = read_indexer_relationships(&relationship_df)?;
+
+    println!(
+        "Relationships ({}): {:?}",
+        relationship_df.height(),
+        relationship_df.head(Some(5))
+    );
+
+    // Read community reports
+    let mut report_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, community_report_table)).unwrap();
+    let report_df = ParquetReader::new(&mut report_file).finish().unwrap();
+
+    let reports = read_indexer_reports(&report_df, &entity_df, community_level)?;
+
+    println!("Reports ({}): {:?}", report_df.height(), report_df.head(Some(5)));
+
+    // Read text units
+    let mut text_unit_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, text_unit_table)).unwrap();
+    let text_unit_df = ParquetReader::new(&mut text_unit_file).finish().unwrap();
+
+    let text_units = read_indexer_text_units(&text_unit_df)?;
+
+    println!(
+        "Text units ({}): {:?}",
+        text_unit_df.height(),
+        text_unit_df.head(Some(5))
+    );
+
+    // Create local search context builder
+    let context_builder = LocalSearchMixedContext::new(
+        entities,
+        description_embedding_store,
+        Box::new(text_embedder),
+        Some(text_units),
+        Some(reports),
+        Some(relationships),
+        num_tokens,
+        "id".to_string(),
+    );
+
+    // Create local search engine
+    let mut local_context_params = default_local_context_params();
+    local_context_params.text_unit_prop = 0.5;
+    local_context_params.community_prop = 0.1;
+    local_context_params.top_k_mapped_entities = 10;
+    local_context_params.top_k_relationships = 10;
+    local_context_params.include_entity_rank = true;
+    local_context_params.include_relationship_weight = true;
+    local_context_params.include_community_rank = false;
+    local_context_params.return_candidate_context = false;
+    local_context_params.max_tokens = 12_000;
+
+    let llm_params = LLMParams {
+        max_tokens: 12_000,
+        temperature: 0.0,
+    };
+
+    let search_engine = LocalSearch::new(
+        Box::new(llm),
+        context_builder,
+        num_tokens,
+        llm_params,
+        local_context_params,
+        String::from("multiple paragraphs"),
+        None,
+    );
+
+    let result = search_engine.asearch("Tell me about Agent Mercer".to_string()).await?;
+    println!("Response: {:?}\n", result.response);
+
+    match result.context_data {
+        shinkai_graphrag::search::base::ContextData::Dictionary(dict) => {
+            for (entity, df) in dict.iter() {
+                println!("Data: {} ({})", entity, df.height());
+                println!("{:?}", df.head(Some(10)));
+            }
+        }
+        data => {
+            println!("Context data: {:?}", data);
+        }
+    }
+
+    Ok(())
+}
+
+// #[tokio::test]
+async fn openai_local_search_test() -> Result<(), Box<dyn std::error::Error>> {
+    let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap();
+    let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap();
+    let embedding_model = std::env::var("GRAPHRAG_EMBEDDING_MODEL").unwrap();
+
+    let llm = ChatOpenAI::new(Some(api_key.clone()), &llm_model, 5);
+    let text_embedder = OpenAIEmbedding::new(Some(api_key), &embedding_model, 8191, 5);
+
+    // Load community reports
+    // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip
+
+    let input_dir = "./dataset";
+    let lancedb_uri = format!("{}/lancedb", input_dir);
+
+    let community_report_table = "create_final_community_reports";
+    let entity_table = "create_final_nodes";
+    let entity_embedding_table = "create_final_entities";
+    let relationship_table = "create_final_relationships";
+    let text_unit_table = "create_final_text_units";
+    let community_level = 2;
+
+    // Read entities
+    let mut entity_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_table)).unwrap();
+    let entity_df = ParquetReader::new(&mut entity_file).finish().unwrap();
+
+    let mut entity_embedding_file =
+        std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_embedding_table)).unwrap();
+    let entity_embedding_df = ParquetReader::new(&mut entity_embedding_file).finish().unwrap();
+
+    let entities = read_indexer_entities(&entity_df, &entity_embedding_df, community_level)?;
+
+    let mut description_embedding_store = LanceDBVectorStore::new("entity_description_embeddings".to_string());
+    description_embedding_store.connect(&lancedb_uri).await?;
+
+    store_entity_semantic_embeddings(entities.clone(), &mut description_embedding_store).await?;
+
+    println!("Entities ({}): {:?}", entity_df.height(), entity_df.head(Some(5)));
+
+    // Read relationships
+    let mut relationship_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, relationship_table)).unwrap();
+    let relationship_df = ParquetReader::new(&mut relationship_file).finish().unwrap();
+
+    let relationships = read_indexer_relationships(&relationship_df)?;
+
+    println!(
+        "Relationships ({}): {:?}",
+        relationship_df.height(),
+        relationship_df.head(Some(5))
+    );
+
+    // Read community reports
+    let mut report_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, community_report_table)).unwrap();
+    let report_df = ParquetReader::new(&mut report_file).finish().unwrap();
+
+    let reports = read_indexer_reports(&report_df, &entity_df, community_level)?;
+
+    println!("Reports ({}): {:?}", report_df.height(), report_df.head(Some(5)));
+
+    // Read text units
+    let mut text_unit_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, text_unit_table)).unwrap();
+    let text_unit_df = ParquetReader::new(&mut text_unit_file).finish().unwrap();
+
+    let text_units = read_indexer_text_units(&text_unit_df)?;
+
+    println!(
+        "Text units ({}): {:?}",
+        text_unit_df.height(),
+        text_unit_df.head(Some(5))
+    );
+
+    // Create local search context builder
+    let context_builder = LocalSearchMixedContext::new(
+        entities,
+        description_embedding_store,
+        Box::new(text_embedder),
+        Some(text_units),
+        Some(reports),
+        Some(relationships),
+        num_tokens,
+        "id".to_string(),
+    );
+
+    // Create local search engine
+    let mut local_context_params = default_local_context_params();
+    local_context_params.text_unit_prop = 0.5;
+    local_context_params.community_prop = 0.1;
+    local_context_params.top_k_mapped_entities = 10;
+    local_context_params.top_k_relationships = 10;
+    local_context_params.include_entity_rank = true;
+    local_context_params.include_relationship_weight = true;
+    local_context_params.include_community_rank = false;
+    local_context_params.return_candidate_context = false;
+    local_context_params.max_tokens = 12_000;
+
+    let llm_params = LLMParams {
+        max_tokens: 2000,
+        temperature: 0.0,
+    };
+
+    let search_engine = LocalSearch::new(
+        Box::new(llm),
+        context_builder,
+        num_tokens,
+        llm_params,
+        local_context_params,
+        String::from("multiple paragraphs"),
+        None,
+    );
+
+    let result = search_engine.asearch("Tell me about Agent Mercer".to_string()).await?;
+    println!("Response: {:?}\n", result.response);
+
+    let result = search_engine
+        .asearch("Tell me about Dr. Jordan Hayes".to_string())
+        .await?;
+    println!("Response: {:?}\n", result.response);
+
+    match result.context_data {
+        shinkai_graphrag::search::base::ContextData::Dictionary(dict) => {
+            for (entity, df) in dict.iter() {
+                println!("Data: {} ({})", entity, df.height());
+                println!("{:?}", df.head(Some(10)));
+            }
+        }
+        data => {
+            println!("Context data: {:?}", data);
+        }
+    }
+
+    Ok(())
+}
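The `text_unit_prop` and `community_prop` settings used in both tests appear to split the `max_tokens` budget by proportion, with the remainder left for the entity and relationship tables; this is an inference from the parameter names rather than a documented contract. A rough worked example under that assumption:

    #[test]
    fn context_proportions_partition_the_budget() {
        // Assumed split for max_tokens = 12_000 with the values used above.
        let max_tokens = 12_000.0_f64;
        let text_unit_tokens = 0.5 * max_tokens; // text_unit_prop -> 6_000
        let community_tokens = 0.1 * max_tokens; // community_prop -> 1_200
        let graph_tokens = max_tokens - text_unit_tokens - community_tokens;
        assert_eq!(graph_tokens, 4_800.0); // remainder for entity/relationship tables
    }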
diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs b/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs
new file mode 100644
index 000000000..3ef32f620
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs
@@ -0,0 +1,2 @@
+pub mod ollama;
+pub mod openai;
diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs
new file mode 100644
index 000000000..4b72fe4b2
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs
@@ -0,0 +1,105 @@
+use async_trait::async_trait;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+use shinkai_graphrag::llm::base::{BaseLLM, BaseLLMCallback, BaseTextEmbedding, LLMParams, MessageType};
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct OllamaChatResponse {
+    pub model: String,
+    pub created_at: String,
+    pub message: OllamaChatMessage,
+}
+
+#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
+pub struct OllamaChatMessage {
+    pub role: String,
+    pub content: String,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct OllamaEmbeddingResponse {
+    pub model: String,
+    pub embeddings: Vec<Vec<f32>>,
+}
+
+pub struct OllamaChat {
+    base_url: String,
+    model: String,
+}
+
+impl OllamaChat {
+    pub fn new(base_url: &str, model: &str) -> Self {
+        OllamaChat {
+            base_url: base_url.to_string(),
+            model: model.to_string(),
+        }
+    }
+}
+
+#[async_trait]
+impl BaseLLM for OllamaChat {
+    async fn agenerate(
+        &self,
+        messages: MessageType,
+        _streaming: bool,
+        _callbacks: Option<Vec<BaseLLMCallback>>,
+        llm_params: LLMParams,
+    ) -> anyhow::Result<String> {
+        let client = Client::new();
+        let chat_url = format!("{}{}", &self.base_url, "/api/chat");
+
+        let messages_json = match messages {
+            MessageType::String(message) => json!(message),
+            MessageType::Strings(messages) => json!(messages),
+            MessageType::Dictionary(messages) => json!(messages),
+        };
+
+        let payload = json!({
+            "model": self.model,
+            "messages": messages_json,
+            "options": {
+                "num_ctx": llm_params.max_tokens,
+                "temperature": llm_params.temperature,
+            },
+            "stream": false,
+        });
+
+        let response = client.post(chat_url).json(&payload).send().await?;
+        let response = response.json::<OllamaChatResponse>().await?;
+
+        Ok(response.message.content)
+    }
+}
+
+pub struct OllamaEmbedding {
+    base_url: String,
+    model: String,
+}
+
+impl OllamaEmbedding {
+    pub fn new(base_url: &str, model: &str) -> Self {
+        OllamaEmbedding {
+            base_url: base_url.to_string(),
+            model: model.to_string(),
+        }
+    }
+}
+
+#[async_trait]
+impl BaseTextEmbedding for OllamaEmbedding {
+    async fn aembed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
+        let client = Client::new();
+        let embedding_url = format!("{}{}", &self.base_url, "/api/embed");
+
+        let payload = json!({
+            "model": self.model,
+            "input": text,
+        });
+
+        let response = client.post(embedding_url).json(&payload).send().await?;
+        let response = response.json::<OllamaEmbeddingResponse>().await?;
+
+        Ok(response.embeddings.first().cloned().unwrap_or_default())
+    }
+}
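Putting the two Ollama clients together, a call site looks roughly like the sketch below, inside an async context. The prompt and parameter values are illustrative, and the sketch assumes `MessageType::Dictionary` wraps a `Vec<HashMap<String, String>>`, which the `json!` arm above would serialize into the `[{role, content}]` message list Ollama's /api/chat expects:

    use std::collections::HashMap;

    let chat = OllamaChat::new("http://localhost:11434", "llama3.1");
    let embedder = OllamaEmbedding::new("http://localhost:11434", "snowflake-arctic-embed:m");

    // Embed a query string.
    let vector = embedder.aembed("Agent Mercer").await?;

    // Ask the chat model a question as a single user message.
    let mut message = HashMap::new();
    message.insert("role".to_string(), "user".to_string());
    message.insert("content".to_string(), "Say hello.".to_string());

    let params = LLMParams { max_tokens: 4_096, temperature: 0.0 };
    let answer = chat
        .agenerate(MessageType::Dictionary(vec![message]), false, None, params)
        .await?;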
diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs
new file mode 100644
index 000000000..8a1eaec53
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs
@@ -0,0 +1,254 @@
+use async_openai::{
+    config::OpenAIConfig,
+    types::{
+        ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, CreateChatCompletionRequestArgs,
+        CreateEmbeddingRequestArgs,
+    },
+    Client,
+};
+use async_trait::async_trait;
+use ndarray::{Array1, Array2, Axis};
+use ndarray_stats::SummaryStatisticsExt;
+use shinkai_graphrag::llm::base::{BaseLLM, BaseLLMCallback, BaseTextEmbedding, LLMParams, MessageType};
+use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer};
+
+pub struct ChatOpenAI {
+    pub api_key: Option<String>,
+    pub model: String,
+    pub max_retries: usize,
+}
+
+impl ChatOpenAI {
+    pub fn new(api_key: Option<String>, model: &str, max_retries: usize) -> Self {
+        ChatOpenAI {
+            api_key,
+            model: model.to_string(),
+            max_retries,
+        }
+    }
+
+    pub async fn agenerate(
+        &self,
+        messages: MessageType,
+        streaming: bool,
+        callbacks: Option<Vec<BaseLLMCallback>>,
+        llm_params: LLMParams,
+    ) -> anyhow::Result<String> {
+        let mut retry_count = 0;
+
+        loop {
+            match self
+                ._agenerate(messages.clone(), streaming, callbacks.clone(), llm_params.clone())
+                .await
+            {
+                Ok(response) => return Ok(response),
+                Err(e) => {
+                    if retry_count < self.max_retries {
+                        retry_count += 1;
+                        continue;
+                    }
+                    return Err(e);
+                }
+            }
+        }
+    }
+
+    async fn _agenerate(
+        &self,
+        messages: MessageType,
+        _streaming: bool,
+        _callbacks: Option<Vec<BaseLLMCallback>>,
+        llm_params: LLMParams,
+    ) -> anyhow::Result<String> {
+        let client = match &self.api_key {
+            Some(api_key) => Client::with_config(OpenAIConfig::new().with_api_key(api_key)),
+            None => Client::new(),
+        };
+
+        let messages = match messages {
+            MessageType::String(message) => vec![message],
+            MessageType::Strings(messages) => messages,
+            MessageType::Dictionary(messages) => {
+                let messages = messages
+                    .iter()
+                    .map(|message_map| {
+                        message_map
+                            .iter()
+                            .map(|(key, value)| format!("{}: {}", key, value))
+                            .collect::<Vec<_>>()
+                            .join("\n")
+                    })
+                    .collect();
+                messages
+            }
+        };
+
+        let request_messages = messages
+            .into_iter()
+            .map(|m| ChatCompletionRequestSystemMessageArgs::default().content(m).build())
+            .collect::<Vec<_>>();
+
+        let request_messages: Result<Vec<_>, _> = request_messages.into_iter().collect();
+        let request_messages = request_messages?;
+        let request_messages = request_messages
+            .into_iter()
+            .map(|m| Into::<ChatCompletionRequestMessage>::into(m.clone()))
+            .collect::<Vec<_>>();
+
+        let request = CreateChatCompletionRequestArgs::default()
+            .max_tokens(llm_params.max_tokens)
+            .temperature(llm_params.temperature)
+            .model(self.model.clone())
+            .messages(request_messages)
+            .build()?;
+
+        let response = client.chat().create(request).await?;
+
+        if let Some(choice) = response.choices.first() {
+            return Ok(choice.message.content.clone().unwrap_or_default());
+        }
+
+        Ok(String::new())
+    }
+}
+
+#[async_trait]
+impl BaseLLM for ChatOpenAI {
+    async fn agenerate(
+        &self,
+        messages: MessageType,
+        streaming: bool,
+        callbacks: Option<Vec<BaseLLMCallback>>,
+        llm_params: LLMParams,
+    ) -> anyhow::Result<String> {
+        self.agenerate(messages, streaming, callbacks, llm_params).await
+    }
+}
+
+pub struct OpenAIEmbedding {
+    pub api_key: Option<String>,
+    pub model: String,
+    pub max_tokens: usize, // 8191
+    pub max_retries: usize,
+}
+
+impl OpenAIEmbedding {
+    pub fn new(api_key: Option<String>, model: &str, max_tokens: usize, max_retries: usize) -> Self {
+        OpenAIEmbedding {
+            api_key,
+            model: model.to_string(),
+            max_tokens,
+            max_retries,
+        }
+    }
+
+    async fn _aembed_with_retry(&self, text: &str) -> anyhow::Result<Vec<f32>> {
+        let mut retry_count = 0;
+
+        loop {
+            match self._aembed(text).await {
+                Ok(response) => return Ok(response),
+                Err(e) => {
+                    if retry_count < self.max_retries {
+                        retry_count += 1;
+                        continue;
+                    }
+                    return Err(e);
+                }
+            }
+        }
+    }
+
+    async fn _aembed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
+        let client = match &self.api_key {
+            Some(api_key) => Client::with_config(OpenAIConfig::new().with_api_key(api_key)),
+            None => Client::new(),
+        };
+
+        let request = CreateEmbeddingRequestArgs::default()
+            .model(&self.model)
+            .input([text.to_string()])
+            .build()?;
+
+        let response = client.embeddings().create(request).await?;
+        let embedding = response
+            .data
+            .first()
+            .map(|data| data.embedding.clone())
+            .unwrap_or_default();
+
+        Ok(embedding)
+    }
+}
+
+#[async_trait]
+impl BaseTextEmbedding for OpenAIEmbedding {
+    async fn aembed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
+        let token_chunks = chunk_text(text, self.max_tokens);
+        let mut chunk_embeddings = Vec::new();
+        let mut chunk_lens = Vec::new();
+
+        for chunk in token_chunks {
+            let embedding = self._aembed_with_retry(&chunk).await?;
+            chunk_embeddings.push(embedding);
+            chunk_lens.push(chunk.len());
+        }
+
+        if chunk_embeddings.len() == 1 {
+            return Ok(chunk_embeddings.swap_remove(0));
+        }
+
+        let rows = chunk_embeddings.len();
+        let cols = chunk_embeddings[0].len();
+        let flat_embeddings: Vec<f32> = chunk_embeddings.into_iter().flatten().collect();
+        let array_embeddings = Array2::from_shape_vec((rows, cols), flat_embeddings).unwrap();
+
+        let array_lens = Array1::from_iter(chunk_lens.into_iter().map(|x| x as f32));
+
+        // Calculate the weighted average
+        let weighted_avg = array_embeddings.weighted_mean_axis(Axis(0), &array_lens).unwrap();
+
+        // Normalize the embeddings
+        let norm = weighted_avg.mapv(|x| x.powi(2)).sum().sqrt();
+        let normalized_embeddings = weighted_avg / norm;
+
+        Ok(normalized_embeddings.to_vec())
+    }
+}
+
+pub fn num_tokens(text: &str) -> usize {
+    let token_encoder = Tokenizer::Cl100kBase;
+    let bpe = get_bpe_from_tokenizer(token_encoder).unwrap();
+    bpe.encode_with_special_tokens(text).len()
+}
+
+fn batched<T>(iterable: impl Iterator<Item = T>, n: usize) -> impl Iterator<Item = Vec<T>> {
+    if n < 1 {
+        panic!("n must be at least one");
+    }
+
+    let mut it = iterable.peekable();
+    std::iter::from_fn(move || {
+        let mut batch = Vec::with_capacity(n);
+        for _ in 0..n {
+            if let Some(item) = it.next() {
+                batch.push(item);
+            } else {
+                break;
+            }
+        }
+        if batch.is_empty() {
+            None
+        } else {
+            Some(batch)
+        }
+    })
+}
+
+fn chunk_text(text: &str, max_tokens: usize) -> impl Iterator<Item = String> + '_ {
+    let token_encoder = Tokenizer::Cl100kBase;
+    let bpe = get_bpe_from_tokenizer(token_encoder).unwrap();
+    let tokens = bpe.encode_with_special_tokens(text);
+
+    let chunk_iterator = batched(tokens.into_iter(), max_tokens);
+    chunk_iterator.map(move |chunk| bpe.decode(chunk).unwrap())
+}
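The chunk-combination logic in `aembed` above is a length-weighted mean followed by L2 normalization. A small worked example of the same `ndarray`/`ndarray_stats` calls, with illustrative two-dimensional vectors:

    #[test]
    fn weighted_mean_then_normalize() {
        use ndarray::{array, Axis};
        use ndarray_stats::SummaryStatisticsExt;

        // Two chunk embeddings; the first chunk is three times longer,
        // so it dominates the weighted mean.
        let embeddings = array![[1.0_f32, 0.0], [0.0, 1.0]];
        let lens = array![3.0_f32, 1.0];

        let mean = embeddings.weighted_mean_axis(Axis(0), &lens).unwrap(); // [0.75, 0.25]
        let norm = mean.mapv(|x| x.powi(2)).sum().sqrt();
        let unit = mean / norm; // ~[0.9487, 0.3162]

        // The result is a unit vector, as the normalization step intends.
        assert!((unit.dot(&unit) - 1.0).abs() < 1e-5);
    }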