diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3cc84bdf392..49afef71a0a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,9 +2,9 @@ name: CI on: # Triggers the workflow on push or pull request events but only for the "main" branch push: - branches: [ "master" , "release/0.4.x" ] + branches: [ "master" , "release/1.0.x" ] pull_request: - branches: [ "master" , "release/0.4.x" ] + branches: [ "master" , "release/1.0.x" ] concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} @@ -102,6 +102,7 @@ jobs: cp -r $GITHUB_WORKSPACE/tests/clickhouse-test ./ mkdir queries cp -r $GITHUB_WORKSPACE/tests/queries/4_cnch_stateless queries/ + cp -r $GITHUB_WORKSPACE/tests/queries/8_cnch_S3_only queries/ cp -r $GITHUB_WORKSPACE/tests/queries/shell_config.sh queries/ # We need skip-list to skip some tests. cp $GITHUB_WORKSPACE/tests/queries/skip_list.json queries/ diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3581abcc7f3..accedc5261b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -11,7 +11,7 @@ jobs: - name: Setup Git Proxy. if: ${{ runner.name != 'ec2-aws-id4-10.10.129.157' }} run: | - git config --global http.proxy http://${{ secrets.HTTP_PROXY }} + git config --global http.https://github.com.proxy http://${{ secrets.VIP_PROXY }}:3128 - name: Setup Environment Varialbes run: | export PROJECT_NAME=byconity-$(cat /etc/hostname) @@ -48,17 +48,17 @@ jobs: tag: ${{ github.ref }} overwrite: true file_glob: true + - name: Login to Docker Hub. + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Build Docker. run: | cd ByConity/docker/debian/release BUILD_TYPE=Release CMAKE_FLAGS="-DENABLE_JAVA_EXTENSIONS=0" TAG=${{ github.ref_name }} make image-github cd - docker images - - name: Login to Docker Hub. 
- uses: docker/login-action@v3 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Push Dockers. run: | docker push byconity/byconity:${{ github.ref_name }} @@ -68,6 +68,22 @@ jobs: docker tag byconity/byconity:${{ github.ref_name }} byconity/byconity:stable docker push byconity/byconity:stable docker image rm -f byconity/byconity:stable + - name: Login to Volcano Container Registry + uses: docker/login-action@v3 + with: + registry: byconity-cn-beijing.cr.volces.com + username: ${{ secrets.VOLC_CR_USER }} + password: ${{ secrets.VOLC_CR_PASS }} + - name: Push Dockers to Volcano CR + run: | + docker tag byconity/byconity:${{ github.ref_name }} byconity-cn-beijing.cr.volces.com/byconity/byconity:${{ github.ref_name}} + docker push byconity-cn-beijing.cr.volces.com/byconity/byconity:${{ github.ref_name }} + - name: Tag as Stable then push to Volcano CR + if: "!github.event.release.prerelease" + run: | + docker tag byconity/byconity:${{ github.ref_name }} byconity-cn-beijing.cr.volces.com/byconity/byconity:stable + docker push byconity-cn-beijing.cr.volces.com/byconity/byconity:stable + docker image rm -f byconity-cn-beijing.cr.volces.com/byconity/byconity:stable - name: Cleanup Data. 
if: always() run: | diff --git a/.gitmodules b/.gitmodules index cddbf96b3d5..5ee1a75824c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -286,7 +286,7 @@ url = https://github.com/google/benchmark.git [submodule "contrib/hivemetastore"] path = contrib/hivemetastore - url = https://github.com/ClickHouse/hive-metastore.git + url = https://github.com/ByConity/hivemetastore.git [submodule "contrib/udns"] path = contrib/udns url = https://github.com/ortclib/udns.git diff --git a/base/common/ThreadLocal.h b/base/common/ThreadLocal.h index 6e61ba56412..ac07d1dfa16 100644 --- a/base/common/ThreadLocal.h +++ b/base/common/ThreadLocal.h @@ -102,7 +102,10 @@ template class ThreadLocalManagedBase { public: ThreadLocalManagedBase() noexcept { - bthread_key_create(&key, [](void *obj) { delete static_cast(obj); }); + bthread_key_create(&key, [](void * obj) { + static_cast(obj)->~T(); + free(obj); + }); } ~ThreadLocalManagedBase() noexcept { bthread_key_delete(key); } @@ -128,9 +131,3 @@ class ThreadLocalManagedBase { private: bthread_key_t key; }; - -template -class ThreadLocalManaged : public ThreadLocalManagedBase> { -public: - static void *create() { return new T(); } -}; diff --git a/base/common/bthread_exception.cpp b/base/common/bthread_exception.cpp index e13790f60ca..6505213b83d 100644 --- a/base/common/bthread_exception.cpp +++ b/base/common/bthread_exception.cpp @@ -1,5 +1,14 @@ #include +template +class ThreadLocalManagedUntracked : public ThreadLocalManagedBase> { +public: + static void *create() { + void * p = malloc(sizeof(T)); + return new (p) T(); + } +}; + namespace { struct __cxa_eh_globals { void * caughtExceptions; @@ -11,7 +20,7 @@ namespace __cxxabiv1 { namespace { __cxa_eh_globals * __globals () { - static ThreadLocalManaged<__cxa_eh_globals> eh_globals; + static ThreadLocalManagedUntracked<__cxa_eh_globals> eh_globals; return eh_globals.get(); } } diff --git a/base/common/chrono_io.h b/base/common/chrono_io.h index cb70b94a354..4c8b7ccc6cf 100644 --- 
a/base/common/chrono_io.h +++ b/base/common/chrono_io.h @@ -39,7 +39,7 @@ std::string to_string(const std::chrono::time_point & tp) // Don't use DateLUT because it shows weird characters for // TimePoint::max(). I wish we could use C++20 format, but it's not // there yet. - // return DateLUT::instance().timeToString(std::chrono::system_clock::to_time_t(tp)); + // return DateLUT::serverTimezoneInstance().timeToString(std::chrono::system_clock::to_time_t(tp)); auto in_time_t = std::chrono::system_clock::to_time_t(tp); return to_string(in_time_t); diff --git a/base/daemon/BaseDaemon.cpp b/base/daemon/BaseDaemon.cpp index 9a345a72cbf..fbe49056370 100644 --- a/base/daemon/BaseDaemon.cpp +++ b/base/daemon/BaseDaemon.cpp @@ -1128,7 +1128,7 @@ void BaseDaemon::shouldSetupWatchdog(char * argv0_) void BaseDaemon::setupWatchdog() { /// Initialize in advance to avoid double initialization in forked processes. - DateLUT::instance(); + DateLUT::serverTimezoneInstance(); std::string original_process_name; if (argv0) diff --git a/ci_scripts/config/users.xml b/ci_scripts/config/users.xml index a7859c3d20b..4aa4a732bc3 100644 --- a/ci_scripts/config/users.xml +++ b/ci_scripts/config/users.xml @@ -21,6 +21,7 @@ 6 8589934592 50000 + 0 diff --git a/contrib/hualloc/hu_alloc.cpp b/contrib/hualloc/hu_alloc.cpp index 0a2f598a9c6..44093081bd1 100644 --- a/contrib/hualloc/hu_alloc.cpp +++ b/contrib/hualloc/hu_alloc.cpp @@ -71,7 +71,7 @@ void* ReclaimThread(void *args) { // keep & max can be separate for large & segment spaces const char * sleep_second = std::getenv("HUALLOC_CLAIM_INTERVAL"); - int sleep = 3; + int sleep = 1; try { if (sleep_second && std::strlen(sleep_second) > 0) @@ -79,7 +79,7 @@ void* ReclaimThread(void *args) } catch(...) 
{ - sleep = 3; + sleep = 1; } yint cached = *(yint *) args; diff --git a/contrib/hualloc/hu_alloc.h b/contrib/hualloc/hu_alloc.h index 036b3bcd47b..2a0b5484959 100644 --- a/contrib/hualloc/hu_alloc.h +++ b/contrib/hualloc/hu_alloc.h @@ -1843,7 +1843,7 @@ static void DumpLocalAllocMasksLocked(char *segment) //////////////////////////////////////////////////////////////////////////////////////////////////////// static yint ReclaimKeepSize = 1 * 1024 * (1ull << 20); -static yint ReclaimMaxReclaim = 512 * (1ull << 20); +static yint ReclaimMaxReclaim = 1* 1024 * (1ull << 20); static void hu_init() { diff --git a/contrib/incubator-brpc b/contrib/incubator-brpc index 275bf4ff355..aa0318be51b 160000 --- a/contrib/incubator-brpc +++ b/contrib/incubator-brpc @@ -1 +1 @@ -Subproject commit 275bf4ff35537eab940a84c615da17eee2b4cd9b +Subproject commit aa0318be51bc4aed735b7759452b7cd25e3c34dd diff --git a/contrib/libhdfs3-open b/contrib/libhdfs3-open index 84f9550229e..507a309506d 160000 --- a/contrib/libhdfs3-open +++ b/contrib/libhdfs3-open @@ -1 +1 @@ -Subproject commit 84f9550229e5a6836e6fab9ff5532557e213c0a8 +Subproject commit 507a309506db9d50d500448fc1fafd5959e0e368 diff --git a/docker/CI/multi-servers/server.yml b/docker/CI/multi-servers/server.yml index b3954ec2a0f..3c5d5ad9480 100644 --- a/docker/CI/multi-servers/server.yml +++ b/docker/CI/multi-servers/server.yml @@ -1,11 +1,15 @@ logger: - level: debug + level: trace log: /var/log/byconity/out.log errorlog: /var/log/byconity/err.log testlog: /var/log/byconity/test.log size: 1000M count: 10 console: true +additional_services: + GIS: 1 + VectorSearch: 1 + FullTextSearch: 1 http_port: 21557 rpc_port: 30605 tcp_port: 52145 @@ -15,6 +19,13 @@ exchange_status_port: 60611 interserver_http_port: 30491 mysql_port: 9004 listen_host: "0.0.0.0" +prometheus: + endpoint: "/metrics" + port: 0 + metrics: true + events: true + asynchronous_metrics: true + part_metrics: false cnch_type: server max_connections: 4096 keep_alive_timeout: 3 
@@ -89,7 +100,7 @@ service_discovery: psm: data.cnch.server node: - host: server-0 - hostname: server + hostname: server-0 ports: port: - name: PORT2 diff --git a/docker/CI/multi-servers/worker.yml b/docker/CI/multi-servers/worker.yml index b6f9b141788..314f12df597 100644 --- a/docker/CI/multi-servers/worker.yml +++ b/docker/CI/multi-servers/worker.yml @@ -55,6 +55,11 @@ storage_configuration: default: local_disk disk: local_disk hdfs_addr: "hdfs://hdfs-namenode:9000" +cnch_unique_table_log: + database: cnch_system + table: cnch_unique_table_log + flush_max_row_count: 10000 + flush_interval_milliseconds: 7500 query_log: database: system table: query_log diff --git a/docker/CI/multi-workers/server.yml b/docker/CI/multi-workers/server.yml index 05a296a5a95..109b3dd933e 100644 --- a/docker/CI/multi-workers/server.yml +++ b/docker/CI/multi-workers/server.yml @@ -5,6 +5,10 @@ logger: testlog: /var/log/byconity/test.log size: 1000M count: 10 +additional_services: + GIS: 1 + VectorSearch: 1 + FullTextSearch: 1 http_port: 21557 rpc_port: 30605 tcp_port: 52145 @@ -14,6 +18,13 @@ exchange_status_port: 60611 interserver_http_port: 30491 mysql_port: 9004 listen_host: "0.0.0.0" +prometheus: + endpoint: "/metrics" + port: 0 + metrics: true + events: true + asynchronous_metrics: true + part_metrics: false cnch_type: server max_connections: 4096 keep_alive_timeout: 3 diff --git a/docker/CI/s3/server.yml b/docker/CI/s3/server.yml index b383bfee3a1..6d53ac43097 100644 --- a/docker/CI/s3/server.yml +++ b/docker/CI/s3/server.yml @@ -1,12 +1,16 @@ # Auto-generated! Please do not modify this file directly. Refer to 'convert-hdfs-configs-to-s3.sh'. 
logger: - level: debug + level: trace log: /var/log/byconity/out.log errorlog: /var/log/byconity/err.log testlog: /var/log/byconity/test.log size: 1000M count: 10 console: true +additional_services: + GIS: 1 + VectorSearch: 1 + FullTextSearch: 1 http_port: 21557 rpc_port: 30605 tcp_port: 52145 @@ -16,6 +20,13 @@ exchange_status_port: 60611 interserver_http_port: 30491 mysql_port: 9004 listen_host: "0.0.0.0" +prometheus: + endpoint: "/metrics" + port: 0 + metrics: true + events: true + asynchronous_metrics: true + part_metrics: false cnch_type: server max_connections: 4096 keep_alive_timeout: 3 @@ -98,7 +109,7 @@ service_discovery: psm: data.cnch.server node: - host: server-0 - hostname: server + hostname: server-0 ports: port: - name: PORT2 diff --git a/docker/CI/s3/worker.yml b/docker/CI/s3/worker.yml index bf0b8b4bfed..503691fab5b 100644 --- a/docker/CI/s3/worker.yml +++ b/docker/CI/s3/worker.yml @@ -63,6 +63,11 @@ storage_configuration: disk: s3_disk # To avoid break hard-coded test cases. 
cnch_default_policy: cnch_default_hdfs +cnch_unique_table_log: + database: cnch_system + table: cnch_unique_table_log + flush_max_row_count: 10000 + flush_interval_milliseconds: 7500 query_log: database: system table: query_log diff --git a/docker/ci-deploy/config/daemon-manager.yml b/docker/ci-deploy/config/daemon-manager.yml index 12cb9ba0fff..58ee09e88bf 100644 --- a/docker/ci-deploy/config/daemon-manager.yml +++ b/docker/ci-deploy/config/daemon-manager.yml @@ -35,6 +35,10 @@ daemon_manager: name: DEDUP_WORKER interval: 3000 disable: 0 + - + name: TXN_GC + interval: 3000 + disable: 0 hdfs_addr: "hdfs://COMPOSE_PROJECT_NAME-hdfs-namenode:9000" storage_configuration: diff --git a/docker/ci-deploy/config/server.yml b/docker/ci-deploy/config/server.yml index 19dc34443b3..52a901aa678 100644 --- a/docker/ci-deploy/config/server.yml +++ b/docker/ci-deploy/config/server.yml @@ -14,6 +14,13 @@ exchange_status_port: 60611 interserver_http_port: 30491 mysql_port: 9004 listen_host: "0.0.0.0" +prometheus: + endpoint: "/metrics" + port: 0 + metrics: true + events: true + asynchronous_metrics: true + part_metrics: false cnch_type: server max_connections: 4096 keep_alive_timeout: 3 diff --git a/docker/ci-deploy/config/worker.yml b/docker/ci-deploy/config/worker.yml index 86dc15dc655..d00225c80d4 100644 --- a/docker/ci-deploy/config/worker.yml +++ b/docker/ci-deploy/config/worker.yml @@ -13,6 +13,13 @@ exchange_port: 47447 exchange_status_port: 60611 interserver_http_port: 30491 listen_host: "0.0.0.0" +prometheus: + endpoint: "/metrics" + port: 0 + metrics: true + events: true + asynchronous_metrics: true + part_metrics: false cnch_type: worker max_connections: 4096 keep_alive_timeout: 3 diff --git a/docker/ci-deploy/config_for_s3_storage/daemon-manager.yml b/docker/ci-deploy/config_for_s3_storage/daemon-manager.yml index ec0aa2cabcd..e89d5d958f5 100644 --- a/docker/ci-deploy/config_for_s3_storage/daemon-manager.yml +++ b/docker/ci-deploy/config_for_s3_storage/daemon-manager.yml 
@@ -35,6 +35,10 @@ daemon_manager: name: DEDUP_WORKER interval: 3000 disable: 0 + - + name: TXN_GC + interval: 3000 + disable: 0 storage_configuration: disks: diff --git a/docker/ci-deploy/config_for_s3_storage/server.yml b/docker/ci-deploy/config_for_s3_storage/server.yml index 2660f8493ef..762b89ee281 100644 --- a/docker/ci-deploy/config_for_s3_storage/server.yml +++ b/docker/ci-deploy/config_for_s3_storage/server.yml @@ -15,6 +15,13 @@ exchange_status_port: 60611 interserver_http_port: 30491 mysql_port: 9004 listen_host: "0.0.0.0" +prometheus: + endpoint: "/metrics" + port: 0 + metrics: true + events: true + asynchronous_metrics: true + part_metrics: false cnch_type: server max_connections: 4096 keep_alive_timeout: 3 diff --git a/docker/ci-deploy/docker-compose-s3.yml b/docker/ci-deploy/docker-compose-s3.yml index 6ef279d79a9..d4caac79ae4 100644 --- a/docker/ci-deploy/docker-compose-s3.yml +++ b/docker/ci-deploy/docker-compose-s3.yml @@ -121,6 +121,7 @@ services: - SYS_PTRACE minio: ports: + - 19000:9000 - 19001:9001 environment: - MINIO_ROOT_USER=minio diff --git a/docker/debian/base/Dockerfile b/docker/debian/base/Dockerfile index 8714f79c859..b0f5933870e 100644 --- a/docker/debian/base/Dockerfile +++ b/docker/debian/base/Dockerfile @@ -4,13 +4,11 @@ ARG FDB_VERSION=7.1.27 WORKDIR /downloads RUN apk --no-cache add wget RUN wget -qO cmake.3.17.tar.gz https://cmake.org/files/v3.17/cmake-3.17.0-Linux-x86_64.tar.gz -# RUN wget https://apt.llvm.org/llvm.sh -RUN wget -q https://mirrors.tuna.tsinghua.edu.cn/llvm-apt/llvm.sh RUN wget -qO foundationdb-clients.deb https://github.com/apple/foundationdb/releases/download/${FDB_VERSION}/foundationdb-clients_${FDB_VERSION}-1_amd64.deb RUN wget -qO tini https://github.com/krallin/tini/releases/download/v0.19.0/tini # Debian builder -FROM debian:11.6 as debian-builder +FROM debian:11.11 as debian-builder LABEL description="Debian image for compiling" LABEL org.opencontainers.image.source="https://github.com/ByConity/ByConity" 
@@ -24,8 +22,10 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install --no-instal COPY --from=downloader /downloads /downloads WORKDIR /downloads RUN tar --strip-components=1 -xzf cmake.3.17.tar.gz -C /usr/local -RUN chmod +x llvm.sh && ./llvm.sh 11 -m https://mirrors.tuna.tsinghua.edu.cn/llvm-apt -RUN dpkg -i foundationdb-clients.deb + +# Use the official LLVM script and mirror +RUN wget -qO - https://apt.llvm.org/llvm.sh | bash -s 11 +RUN dpkg -i foundationdb-clients.deb || apt-get install --fix-missing -y RUN rm -rf /downloads COPY build.sh / @@ -39,7 +39,7 @@ RUN ldconfig ENV CC=clang-11 CXX=clang++-11 # Base runner image -FROM debian:11.6-slim as debian-runner +FROM debian:11.11-slim as debian-runner LABEL description="Base Debian image for runtime" LABEL org.opencontainers.image.source="https://github.com/ByConity/ByConity" diff --git a/docker/debian/base/build.sh b/docker/debian/base/build.sh index 6166aaa19de..3cf1d64b3e5 100644 --- a/docker/debian/base/build.sh +++ b/docker/debian/base/build.sh @@ -15,7 +15,11 @@ sed -i \ -e "s/set (VERSION_SCM [^) ]*/set (VERSION_SCM $VERSION_SCM/g;" \ cmake/version.cmake -cmake -DCMAKE_BUILD_TYPE=${BUILD_TYPE} ${CMAKE_FLAGS} -DENABLE_BREAKPAD=ON -DCMAKE_INSTALL_PREFIX=build_install -S . -B build_docker +# Set a default value if BREAKPAD_STATUS is empty or unset +BREAKPAD_STATUS=${BREAKPAD_STATUS:-"OFF"} +echo "Breakpad status: $BREAKPAD_STATUS" + +cmake -DCMAKE_BUILD_TYPE=${BUILD_TYPE} ${CMAKE_FLAGS} -DENABLE_BREAKPAD=${BREAKPAD_STATUS} -DCMAKE_INSTALL_PREFIX=build_install -S . 
-B build_docker NUM_JOBS=$(( ($(nproc || grep -c ^processor /proc/cpuinfo) + 1) / 2 )) ninja -C build_docker -j $NUM_JOBS install diff --git a/docker/docker-compose/byconity-multi-cluster/daemon-manager.yml b/docker/docker-compose/byconity-multi-cluster/daemon-manager.yml index 1e5ed11f4ae..39e0df00978 100644 --- a/docker/docker-compose/byconity-multi-cluster/daemon-manager.yml +++ b/docker/docker-compose/byconity-multi-cluster/daemon-manager.yml @@ -104,6 +104,10 @@ daemon_manager: name: DEDUP_WORKER interval: 3000 disable: 0 + - + name: TXN_GC + interval: 3000 + disable: 0 hdfs_addr: "hdfs://hdfs-namenode:9000" -cnch_config: "/config/cnch-config.yml" \ No newline at end of file +cnch_config: "/config/cnch-config.yml" diff --git a/docker/docker-compose/byconity-multi-cluster/server.yml b/docker/docker-compose/byconity-multi-cluster/server.yml index f7982fe324f..d735f0aa349 100644 --- a/docker/docker-compose/byconity-multi-cluster/server.yml +++ b/docker/docker-compose/byconity-multi-cluster/server.yml @@ -14,6 +14,13 @@ exchange_port: 47447 exchange_status_port: 60611 interserver_http_port: 30491 listen_host: "0.0.0.0" +prometheus: + endpoint: "/metrics" + port: 0 + metrics: true + events: true + asynchronous_metrics: true + part_metrics: false cnch_type: server max_connections: 4096 keep_alive_timeout: 3 diff --git a/docker/docker-compose/byconity-multi-cluster/worker.yml b/docker/docker-compose/byconity-multi-cluster/worker.yml index b27eb29ffed..7949c720190 100644 --- a/docker/docker-compose/byconity-multi-cluster/worker.yml +++ b/docker/docker-compose/byconity-multi-cluster/worker.yml @@ -13,6 +13,13 @@ exchange_port: 47447 exchange_status_port: 60611 interserver_http_port: 30491 listen_host: "0.0.0.0" +prometheus: + endpoint: "/metrics" + port: 0 + metrics: true + events: true + asynchronous_metrics: true + part_metrics: false cnch_type: worker max_connections: 4096 keep_alive_timeout: 3 diff --git 
a/docker/docker-compose/byconity-multiworkers-cluster/daemon-manager.yml b/docker/docker-compose/byconity-multiworkers-cluster/daemon-manager.yml index 1e5ed11f4ae..39e0df00978 100644 --- a/docker/docker-compose/byconity-multiworkers-cluster/daemon-manager.yml +++ b/docker/docker-compose/byconity-multiworkers-cluster/daemon-manager.yml @@ -104,6 +104,10 @@ daemon_manager: name: DEDUP_WORKER interval: 3000 disable: 0 + - + name: TXN_GC + interval: 3000 + disable: 0 hdfs_addr: "hdfs://hdfs-namenode:9000" -cnch_config: "/config/cnch-config.yml" \ No newline at end of file +cnch_config: "/config/cnch-config.yml" diff --git a/docker/docker-compose/byconity-multiworkers-cluster/server.yml b/docker/docker-compose/byconity-multiworkers-cluster/server.yml index 84dbefbc7f0..e92a3e920ef 100644 --- a/docker/docker-compose/byconity-multiworkers-cluster/server.yml +++ b/docker/docker-compose/byconity-multiworkers-cluster/server.yml @@ -14,6 +14,13 @@ exchange_port: 47447 exchange_status_port: 60611 interserver_http_port: 30491 listen_host: "0.0.0.0" +prometheus: + endpoint: "/metrics" + port: 0 + metrics: true + events: true + asynchronous_metrics: true + part_metrics: false cnch_type: server max_connections: 4096 keep_alive_timeout: 3 diff --git a/docker/docker-compose/byconity-multiworkers-cluster/worker.yml b/docker/docker-compose/byconity-multiworkers-cluster/worker.yml index b27eb29ffed..7949c720190 100644 --- a/docker/docker-compose/byconity-multiworkers-cluster/worker.yml +++ b/docker/docker-compose/byconity-multiworkers-cluster/worker.yml @@ -13,6 +13,13 @@ exchange_port: 47447 exchange_status_port: 60611 interserver_http_port: 30491 listen_host: "0.0.0.0" +prometheus: + endpoint: "/metrics" + port: 0 + metrics: true + events: true + asynchronous_metrics: true + part_metrics: false cnch_type: worker max_connections: 4096 keep_alive_timeout: 3 diff --git a/docker/docker-compose/byconity-simple-cluster/daemon-manager.yml 
b/docker/docker-compose/byconity-simple-cluster/daemon-manager.yml index 3f329583c4a..9b13b0587ad 100644 --- a/docker/docker-compose/byconity-simple-cluster/daemon-manager.yml +++ b/docker/docker-compose/byconity-simple-cluster/daemon-manager.yml @@ -44,7 +44,7 @@ service_discovery: psm: data.cnch.server node: host: server-0 - hostname: server + hostname: server-0 ports: port: - @@ -200,5 +200,9 @@ daemon_manager: name: DEDUP_WORKER interval: 3000 disable: 0 + - + name: TXN_GC + interval: 3000 + disable: 0 hdfs_addr: "hdfs://hdfs-namenode:9000" -cnch_config: "/config/cnch-config.yml" \ No newline at end of file +cnch_config: "/config/cnch-config.yml" diff --git a/docker/docker-compose/byconity-simple-cluster/server.yml b/docker/docker-compose/byconity-simple-cluster/server.yml index 743b5ba44b3..f249082dc0d 100644 --- a/docker/docker-compose/byconity-simple-cluster/server.yml +++ b/docker/docker-compose/byconity-simple-cluster/server.yml @@ -1,11 +1,15 @@ logger: - level: debug + level: trace log: /var/log/byconity/out.log errorlog: /var/log/byconity/err.log testlog: /var/log/byconity/test.log size: 1000M count: 10 console: true +additional_services: + GIS: 1 + VectorSearch: 1 + FullTextSearch: 1 http_port: 21557 rpc_port: 30605 tcp_port: 52145 @@ -15,6 +19,13 @@ exchange_status_port: 60611 interserver_http_port: 30491 mysql_port: 9004 listen_host: "0.0.0.0" +prometheus: + endpoint: "/metrics" + port: 0 + metrics: true + events: true + asynchronous_metrics: true + part_metrics: false cnch_type: server max_connections: 4096 keep_alive_timeout: 3 @@ -46,7 +57,7 @@ service_discovery: psm: data.cnch.server node: host: server-0 - hostname: server + hostname: server-0 ports: port: - diff --git a/docker/docker-compose/byconity-simple-cluster/worker.yml b/docker/docker-compose/byconity-simple-cluster/worker.yml index 6bd4dd0be43..82b2d20908c 100644 --- a/docker/docker-compose/byconity-simple-cluster/worker.yml +++ b/docker/docker-compose/byconity-simple-cluster/worker.yml 
@@ -15,6 +15,13 @@ exchange_status_port: 60611 interserver_http_port: 30491 mysql_port: 9004 listen_host: "0.0.0.0" +prometheus: + endpoint: "/metrics" + port: 0 + metrics: true + events: true + asynchronous_metrics: true + part_metrics: false cnch_type: worker vw_name: vw_default max_connections: 4096 @@ -49,7 +56,7 @@ service_discovery: headless_service: cnch-server-headless node: host: server-0 - hostname: server + hostname: server-0 ports: port: - diff --git a/docker/packager/binary/build.sh b/docker/packager/binary/build.sh index 348ed846072..aa276b2ade7 100755 --- a/docker/packager/binary/build.sh +++ b/docker/packager/binary/build.sh @@ -33,7 +33,9 @@ ccache_status ccache --zero-stats ||: # Build everything -cmake --debug-trycompile -DCMAKE_VERBOSE_MAKEFILE=1 -LA "-DCMAKE_BUILD_TYPE=$BUILD_TYPE" -DENABLE_CHECK_HEAVY_BUILDS=0 -DENABLE_BREAKPAD=ON "${CMAKE_FLAGS[@]}" .. +# Set a default value if BREAKPAD_STATUS is empty or unset +BREAKPAD_STATUS=${BREAKPAD_STATUS:-"ON"} +cmake --debug-trycompile -DCMAKE_VERBOSE_MAKEFILE=1 -LA "-DCMAKE_BUILD_TYPE=$BUILD_TYPE" -DENABLE_CHECK_HEAVY_BUILDS=0 -DENABLE_BREAKPAD=$BREAKPAD_STATUS "${CMAKE_FLAGS[@]}" .. # No quotes because I want it to expand to nothing if empty. # shellcheck disable=SC2086 # No quotes because I want it to expand to nothing if empty. diff --git a/programs/client/Client.cpp b/programs/client/Client.cpp index 07cfc6c187e..df1c42f2110 100644 --- a/programs/client/Client.cpp +++ b/programs/client/Client.cpp @@ -19,7 +19,9 @@ * All Bytedance's Modifications are Copyright (2023) Bytedance Ltd. and/or its affiliates. 
*/ +#include "Common/CurrentThread.h" #include "ConnectionParameters.h" +#include "Core/Protocol.h" #include "QueryFuzzer.h" #include "Storages/HDFS/HDFSCommon.h" #include "Suggest.h" @@ -457,7 +459,7 @@ class Client : public Poco::Util::Application if (current_time % 3 != 0) return false; - auto days = DateLUT::instance().toDayNum(current_time).toUnderType(); + auto days = DateLUT::sessionInstance().toDayNum(current_time).toUnderType(); for (auto d : chineseNewYearIndicators) { /// Let's celebrate until Lantern Festival @@ -552,11 +554,20 @@ class Client : public Poco::Util::Application int mainImpl() { UseSSL use_ssl; + MainThreadStatus::getInstance(); registerFormats(); registerFunctions(); registerAggregateFunctions(); + { + // All that just to set DB::CurrentThread::get().getGlobalContext() + // which is required for client timezone (pushed from server) to work. + auto thread_group = std::make_shared(); + const_cast(thread_group->global_context) = context; + CurrentThread::attachTo(thread_group); + } + /// Batch mode is enabled if one of the following is true: /// - -e (--query) command line option is present. /// The value of the option is used as the text of query (or of multiple queries). @@ -612,7 +623,7 @@ class Client : public Poco::Util::Application connect(); /// Initialize DateLUT here to avoid counting time spent here as query execution time. 
- const auto local_tz = DateLUT::instance().getTimeZone(); + const auto local_tz = DateLUT::sessionInstance().getTimeZone(); if (is_interactive) { @@ -1686,12 +1697,28 @@ class Client : public Poco::Util::Application context->applySettingsChanges(settings_ast.as()->changes); }; const auto * insert = parsed_query->as(); - if (insert && insert->settings_ast) + if (const auto * select = parsed_query->as(); select && select->settings()) + apply_query_settings(*select->settings()); + else if (const auto * select_with_union = parsed_query->as()) + { + const ASTs & children = select_with_union->list_of_selects->children; + if (!children.empty()) + { + // On the client it is enough to apply settings only for the + // last SELECT, since the only thing that is important to apply + // on the client is format settings. + const auto * last_select = children.back()->as(); + if (last_select && last_select->settings()) + { + apply_query_settings(*last_select->settings()); + } + } + } + else if (const auto * query_with_output = parsed_query->as(); + query_with_output && query_with_output->settings_ast) + apply_query_settings(*query_with_output->settings_ast); + else if (insert && insert->settings_ast) apply_query_settings(*insert->settings_ast); - /// FIXME: try to prettify this cast using `as<>()` - const auto * with_output = dynamic_cast(parsed_query.get()); - if (with_output && with_output->settings_ast) - apply_query_settings(*with_output->settings_ast); if (!connection->checkConnected()) connect(); @@ -2149,6 +2176,10 @@ class Client : public Poco::Util::Application case Protocol::Server::QueryMetrics: return true; + case Protocol::Server::TimezoneUpdate: + onTimezoneUpdate(packet.server_timezone); + return true; + default: throw Exception( ErrorCodes::UNKNOWN_PACKET_FROM_SERVER, "Unknown packet {} from server {}", packet.type, connection->getDescription()); @@ -2181,9 +2212,13 @@ class Client : public Poco::Util::Application columns_description = 
ColumnsDescription::parse(packet.multistring_message[1]); return receiveSampleBlock(out, columns_description); + case Protocol::Server::TimezoneUpdate: + onTimezoneUpdate(packet.server_timezone); + break; + default: throw NetException( - "Unexpected packet from server (expected Data, Exception or Log, got " + "Unexpected packet from server (expected Data, Exception or Log or TimezoneUpdate , got " + String(Protocol::Server::toString(packet.type)) + ")", ErrorCodes::UNEXPECTED_PACKET_FROM_SERVER); } @@ -2212,6 +2247,10 @@ class Client : public Poco::Util::Application onLogData(packet.block); break; + case Protocol::Server::TimezoneUpdate: + onTimezoneUpdate(packet.server_timezone); + break; + default: throw NetException( "Unexpected packet from server (expected Exception, EndOfStream or Log, got " @@ -2226,7 +2265,7 @@ class Client : public Poco::Util::Application { auto packet_type = connection->checkPacket(); - while (packet_type && *packet_type == Protocol::Server::Log) + while (packet_type && (*packet_type == Protocol::Server::Log || *packet_type == Protocol::Server::TimezoneUpdate)) { receiveAndProcessPacket(false); packet_type = connection->checkPacket(); @@ -2469,9 +2508,17 @@ class Client : public Poco::Util::Application } } + void onTimezoneUpdate(const String & tz) + { + context->setSetting("session_timezone", tz); + } + static void showClientVersion() { - std::cout << R"( + #define RESET_ "\033[0m" + #define LIGHT_CYAN_ "\033[96m" + + std::cout << LIGHT_CYAN_ << R"( ______ _ _ _ | ___ \ | | | | | | | |_/ /_ _| |_ ___| |_| | ___ _ _ ___ ___ @@ -2480,7 +2527,7 @@ class Client : public Poco::Util::Application \____/ \__, |\__\___\_| |_/\___/ \__,_|___/\___| __/ | |___/ - )" << std::endl; + )" << RESET_ << std::endl; std::cout << VERSION_NAME << " client version " << VERSION_STRING << VERSION_OFFICIAL << "." 
<< std::endl; } diff --git a/programs/copier/ClusterCopierApp.cpp b/programs/copier/ClusterCopierApp.cpp index 69f2fd3383c..3a3b0bf9cb6 100644 --- a/programs/copier/ClusterCopierApp.cpp +++ b/programs/copier/ClusterCopierApp.cpp @@ -39,7 +39,7 @@ void ClusterCopierApp::initialize(Poco::Util::Application & self) time_t timestamp = Poco::Timestamp().epochTime(); auto curr_pid = Poco::Process::id(); - process_id = std::to_string(DateLUT::instance().toNumYYYYMMDDhhmmss(timestamp)) + "_" + std::to_string(curr_pid); + process_id = std::to_string(DateLUT::serverTimezoneInstance().toNumYYYYMMDDhhmmss(timestamp)) + "_" + std::to_string(curr_pid); host_id = escapeForFileName(getFQDNOrHostName()) + '#' + process_id; process_path = fs::weakly_canonical(fs::path(base_dir) / ("clickhouse-copier_" + process_id)); fs::create_directories(process_path); diff --git a/programs/dumper/Dumper.cpp b/programs/dumper/Dumper.cpp index 822eb81435f..9c883a638d8 100644 --- a/programs/dumper/Dumper.cpp +++ b/programs/dumper/Dumper.cpp @@ -329,9 +329,8 @@ void ClickHouseDumper::initHDFS() /// Options load from command line argument use priority -100 in layeredconfiguration, so construct /// hdfs params from config directly rather than from config file - HDFSConnectionParams hdfs_params = HDFSConnectionParams(HDFSConnectionParams::CONN_NNPROXY, - config().getString("hdfs_user", "clickhouse"), config().getString("output_hdfs_nnproxy", "nnproxy")); - hdfs_params.lookupOnNeed(); + HDFSConnectionParams hdfs_params = HDFSConnectionParams::parseFromMisusedNNProxyStr( + config().getString("output_hdfs_nnproxy", "nnproxy"), config().getString("hdfs_user", "clickhouse")); global_context->setHdfsConnectionParams(hdfs_params); /// register default hdfs file system bool has_hdfs_disk = false; diff --git a/programs/keeper/Keeper.cpp b/programs/keeper/Keeper.cpp index 7ad6fd6ae7b..f50afe46097 100644 --- a/programs/keeper/Keeper.cpp +++ b/programs/keeper/Keeper.cpp @@ -379,8 +379,8 @@ int Keeper::main(const 
std::vector & /*args*/) /// Initialize DateLUT early, to not interfere with running time of first query. LOG_DEBUG(log, "Initializing DateLUT."); - DateLUT::instance(); - LOG_TRACE(log, "Initialized DateLUT with time zone '{}'.", DateLUT::instance().getTimeZone()); + DateLUT::serverTimezoneInstance(); + LOG_TRACE(log, "Initialized DateLUT with time zone '{}'.", DateLUT::serverTimezoneInstance().getTimeZone()); /// Don't want to use DNS cache DNSResolver::instance().setDisableCacheFlag(); diff --git a/programs/meta-inspector/MetastoreInspector.cpp b/programs/meta-inspector/MetastoreInspector.cpp index dbc300525c1..837cc294fc5 100644 --- a/programs/meta-inspector/MetastoreInspector.cpp +++ b/programs/meta-inspector/MetastoreInspector.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -350,7 +351,7 @@ class MetastoreInspector : public Poco::Util::Application { MetaCommand cmd = MetaCommand::parse(command); std::string full_key = name_space.empty() ? cmd.key : Catalog::escapeString(name_space) + '_' + cmd.key; - size_t key_offset = name_space.empty() ? 0 : name_space.size() + 1; + size_t key_offset = name_space.empty() ? 0 : Catalog::escapeString(name_space).size() + 1; switch (cmd.type) { case MetaCommandType::HELP: @@ -362,6 +363,14 @@ class MetastoreInspector : public Poco::Util::Application { std::string value; metastore_ptr->get(full_key, value); + // try parse large KV before really dump metadata. 
+ DB::Protos::DataModelLargeKVMeta large_kv_model; + if (Catalog::tryParseLargeKVMetaModel(value, large_kv_model)) + { + std::cout << "Large KV base value: \n" << large_kv_model.DebugString() << std::endl; + tryGetLargeValue(metastore_ptr, name_space, full_key, value); + std::cout << "Original value : " << std::endl; + } dumpMetadata(cmd.key, value); break; } diff --git a/programs/obfuscator/Obfuscator.cpp b/programs/obfuscator/Obfuscator.cpp index 00020fb1cef..a3327438a73 100644 --- a/programs/obfuscator/Obfuscator.cpp +++ b/programs/obfuscator/Obfuscator.cpp @@ -461,7 +461,7 @@ class DateTimeModel : public IModel const DateLUTImpl & date_lut; public: - explicit DateTimeModel(UInt64 seed_) : seed(seed_), date_lut(DateLUT::instance()) {} + explicit DateTimeModel(UInt64 seed_) : seed(seed_), date_lut(DateLUT::serverTimezoneInstance()) { } void train(const IColumn &) override {} void finalize() override {} diff --git a/programs/server/Server.cpp b/programs/server/Server.cpp index e13643ca0e1..fbc145e771e 100644 --- a/programs/server/Server.cpp +++ b/programs/server/Server.cpp @@ -595,10 +595,8 @@ int Server::main(const std::vector & /*args*/) global_context->initCnchConfig(config()); global_context->setBlockPrivilegedOp(config().getBool("restrict_tenanted_users_to_privileged_operations", false)); global_context->initRootConfig(config()); - global_context->initPreloadThrottler(); const auto & root_config = global_context->getRootConfig(); - // Initialize global thread pool. Do it before we fetch configs from zookeeper // nodes (`from_zk`), because ZooKeeper interface uses the pool. We will // ignore `max_thread_pool_size` in configs we fetch from ZK, but oh well. @@ -854,8 +852,8 @@ int Server::main(const std::vector & /*args*/) /// Initialize DateLUT early, to not interfere with running time of first query. 
LOG_DEBUG(log, "Initializing DateLUT."); - DateLUT::instance(); - LOG_TRACE(log, "Initialized DateLUT with time zone '{}'.", DateLUT::instance().getTimeZone()); + DateLUT::serverTimezoneInstance(); + LOG_TRACE(log, "Initialized DateLUT with time zone '{}'.", DateLUT::serverTimezoneInstance().getTimeZone()); /// Storage with temporary data for processing of heavy queries. { @@ -1023,27 +1021,31 @@ int Server::main(const std::vector & /*args*/) #if USE_HUALLOC if (config->getBool("hualloc_numa_aware", false)) { - size_t max_numa_node = SystemUtils::getMaxNumaNode(); - std::vector numa_nodes_cpu_mask = SystemUtils::getNumaNodesCpuMask(); - bool hualloc_enable_mbind = config->getBool("hualloc_enable_mbind", false); - int mbind_mode = config->getInt("hualloc_mbind_mode", 1); - - /* - *mbind mode - #define MPOL_DEFAULT 0 - #define MPOL_PREFERRED 1 - #define MPOL_BIND 2 - #define MPOL_INTERLEAVE 3 - #define MPOL_LOCAL 4 - #define MPOL_MAX 5 - */ - huallocSetNumaInfo( - max_numa_node, - numa_nodes_cpu_mask, - hualloc_enable_mbind, - mbind_mode, - huallocLogPrint - ); + static std::once_flag numa_aware_init_flag; + std::call_once(numa_aware_init_flag, [&]() + { + size_t max_numa_node = SystemUtils::getMaxNumaNode(); + std::vector numa_nodes_cpu_mask = SystemUtils::getNumaNodesCpuMask(); + bool hualloc_enable_mbind = config->getBool("hualloc_enable_mbind", false); + int mbind_mode = config->getInt("hualloc_mbind_mode", 1); + + /* + *mbind mode + #define MPOL_DEFAULT 0 + #define MPOL_PREFERRED 1 + #define MPOL_BIND 2 + #define MPOL_INTERLEAVE 3 + #define MPOL_LOCAL 4 + #define MPOL_MAX 5 + */ + huallocSetNumaInfo( + max_numa_node, + numa_nodes_cpu_mask, + hualloc_enable_mbind, + mbind_mode, + huallocLogPrint + ); + }); } double default_hualloc_cache_ratio = config->getDouble("hualloc_cache_ratio", 0.25); @@ -1224,8 +1226,8 @@ int Server::main(const std::vector & /*args*/) /// A cache for gin index store GinIndexStoreCacheSettings ginindex_store_cache_settings; 
ginindex_store_cache_settings.lru_max_size = config().getUInt64("ginindex_store_cache_size", 5368709120); //5GB - ginindex_store_cache_settings.mapping_bucket_size = config().getUInt64("ginindex_store_cache_bucket", 5000); //5000 - ginindex_store_cache_settings.cache_shard_num = config().getUInt64("ginindex_store_cache_shard", 8); //8 + ginindex_store_cache_settings.mapping_bucket_size = config().getUInt64("ginindex_store_cache_bucket", 1000); //1000 + ginindex_store_cache_settings.cache_shard_num = config().getUInt64("ginindex_store_cache_shard", 2); //2 ginindex_store_cache_settings.lru_update_interval = config().getUInt64("ginindex_store_cache_lru_update_interval", 60); //60 seconds global_context->setGinIndexStoreFactory(ginindex_store_cache_settings); diff --git a/src/Access/AccessControlManager.cpp b/src/Access/AccessControlManager.cpp index 00dcbe65577..245ba1f56b8 100644 --- a/src/Access/AccessControlManager.cpp +++ b/src/Access/AccessControlManager.cpp @@ -148,6 +148,15 @@ AccessControlManager::AccessControlManager() { } +bool AccessControlManager::isSensitiveGrantee(const String & grantee) const +{ + auto pos = grantee.find('.'); + + if (pos == String::npos || pos == 0) + return false; + + return isSensitiveTenant(grantee.substr(0, pos)); +} bool AccessControlManager::isSensitiveTenant(const String & tenant) const { @@ -437,8 +446,7 @@ void AccessControlManager::checkSettingNameIsAllowed(const std::string_view & se custom_settings_prefixes->checkSettingNameIsAllowed(setting_name); } - -std::shared_ptr AccessControlManager::getContextAccess( +ContextAccessParams AccessControlManager::getContextAccessParams( const UUID & user_id, const std::vector & current_roles, bool use_default_roles, @@ -446,6 +454,7 @@ std::shared_ptr AccessControlManager::getContextAccess( const String & current_database, const ClientInfo & client_info, const String & tenant, + bool has_tenant_id_in_username, bool load_roles) const { ContextAccessParams params; @@ -460,8 +469,9 @@ 
std::shared_ptr AccessControlManager::getContextAccess( params.http_method = client_info.http_method; params.address = client_info.current_address.host(); params.quota_key = client_info.quota_key; - params.has_tenant_id_in_username = !tenant.empty(); - params.enable_sensitive_permission = sensitive_permission_tenants->isSensitivePermissionEnabled(tenant); + params.has_tenant_id_in_username = has_tenant_id_in_username; + params.enable_sensitive_permission = + has_tenant_id_in_username ? isSensitiveTenant(tenant) : false; params.load_roles = load_roles; /// Extract the last entry from comma separated list of X-Forwarded-For addresses. @@ -474,8 +484,7 @@ std::shared_ptr AccessControlManager::getContextAccess( boost::trim(last_forwarded_address); params.forwarded_address = last_forwarded_address; } - - return getContextAccess(params); + return params; } diff --git a/src/Access/AccessControlManager.h b/src/Access/AccessControlManager.h index 6e25122615b..69325a37cc9 100644 --- a/src/Access/AccessControlManager.h +++ b/src/Access/AccessControlManager.h @@ -49,8 +49,6 @@ class AccessControlManager : public MultipleAccessStorage AccessControlManager(); ~AccessControlManager() override; - bool isSensitiveTenant(const String & tenant) const; - /// Initializes access storage (user directories). 
void setUpFromMainConfig(const Poco::Util::AbstractConfiguration & config_, const String & config_path_, const zkutil::GetZooKeeper & get_zookeeper_function_); @@ -131,7 +129,7 @@ class AccessControlManager : public MultipleAccessStorage void setSelectFromMySQLRequiresGrant(bool enable) { select_from_mysql_requires_grant = enable; } bool doesSelectFromMySQLRequireGrant() const { return select_from_mysql_requires_grant; } - std::shared_ptr getContextAccess( + ContextAccessParams getContextAccessParams( const UUID & user_id, const std::vector & current_roles, bool use_default_roles, @@ -139,6 +137,7 @@ class AccessControlManager : public MultipleAccessStorage const String & current_database, const ClientInfo & client_info, const String & tenant, + bool has_tenant_id_in_username, bool load_roles) const; std::shared_ptr getContextAccess(const ContextAccessParams & params) const; @@ -170,8 +169,13 @@ class AccessControlManager : public MultipleAccessStorage const ExternalAuthenticators & getExternalAuthenticators() const; + bool isSensitiveGrantee(const String & grantee) const; + std::function sensitive_resource_getter; +private: + bool isSensitiveTenant(const String & tenant) const; + private: class ContextAccessCache; class CustomSettingsPrefixes; diff --git a/src/Access/AccessFlags.h b/src/Access/AccessFlags.h index 7e32f423564..95fe013c6a6 100644 --- a/src/Access/AccessFlags.h +++ b/src/Access/AccessFlags.h @@ -112,7 +112,7 @@ class AccessFlags static AccessFlags allFlagsGrantableOnColumnLevel(); private: - static constexpr size_t NUM_FLAGS = 128; + static constexpr size_t NUM_FLAGS = 256; using Flags = std::bitset; Flags flags; diff --git a/src/Access/AccessRights.cpp b/src/Access/AccessRights.cpp index cbbf5cda06d..f6be42163f3 100644 --- a/src/Access/AccessRights.cpp +++ b/src/Access/AccessRights.cpp @@ -192,6 +192,7 @@ namespace }; + /* must be synced with the Level definition in ContextAccess.cpp */ enum Level { GLOBAL_LEVEL, @@ -275,30 +276,56 @@ struct 
AccessRightsBase::Node calculateMinMaxFlags(); } + template void revoke(const AccessFlags & flags_) { removeGrantsRec(flags_); optimizeTree(); } - template + template void revoke(const AccessFlags & flags_, const std::string_view & name, const Args &... subnames) { - auto & child = getChild(name); + if constexpr (if_exists) + { + auto * child = tryGetChild(name); - child.revoke(flags_, subnames...); - eraseChildIfPossible(child); + if (!child) + return; + + child->template revoke(flags_, subnames...); + eraseChildIfPossible(*child); + } + else + { + auto & child = getChild(name); + + child.template revoke(flags_, subnames...); + eraseChildIfPossible(child); + } calculateMinMaxFlags(); } - template + template void revoke(const AccessFlags & flags_, const std::vector & names) { for (const auto & name : names) { - auto & child = getChild(name); - child.revoke(flags_); - eraseChildIfPossible(child); + if constexpr (if_exists) + { + auto * child = tryGetChild(name); + if (!child) + continue; + + child->template revoke(flags_); + eraseChildIfPossible(*child); + } + else + { + auto & child = getChild(name); + child.template revoke(flags_); + eraseChildIfPossible(child); + } } calculateMinMaxFlags(); } @@ -350,58 +377,59 @@ struct AccessRightsBase::Node return true; } - bool isGranted(const std::unordered_set &, const AccessFlags & flags_) const requires Permission + bool isGranted(int sensitive_level, const AccessFlags & flags_) const requires Permission { + /* sensitive resource is not granted */ + if (level < sensitive_level) + return false; + return isGranted(flags_); } template - bool isGranted(const std::unordered_set & sensitive_columns, const AccessFlags & flags_, const std::string_view & name, const Args &... subnames) const requires Permission + bool isGranted(int sensitive_level, const AccessFlags & flags_, const std::string_view & name, const Args &... 
subnames) const requires Permission { AccessFlags flags_to_check = flags_ - min_flags_with_children; + if (!max_flags_with_children.contains(flags_to_check)) + return false; - const Node * child = tryGetChild(name); // to reject, this should fail + const Node * child = tryGetChild(name); if (child) - { - return child->isGranted(sensitive_columns, flags_to_check, subnames...); - } - else - { - auto current_node_name = node_name ? *node_name : "NULL"; - return name == current_node_name && flags.contains(flags_to_check); - } + return child->isGranted(sensitive_level, flags_to_check, subnames...); + + /* sensitive resource is not granted */ + if (level < sensitive_level) + return false; + + return flags.contains(flags_to_check); } template - bool isGranted(const std::unordered_set & sensitive_columns, const AccessFlags & flags_, const std::vector & names) const requires Permission + bool isGranted(int sensitive_level, const AccessFlags & flags_, const std::unordered_set & names) const requires Permission { AccessFlags flags_to_check = flags_ - min_flags_with_children; + if (!max_flags_with_children.contains(flags_to_check)) + return false; for (const auto & name : names) { const Node * child = tryGetChild(name); if (child) { - if (sensitive_columns.contains(name)) - { - if (!child->isGranted(sensitive_columns, flags_to_check, name) || !flags.contains(flags_to_check)) // For sensitive column, must have permissions granted for both table and the column - return false; - } - else if (!child->isGranted(sensitive_columns, flags_to_check, name)) + if (!child->isGranted(sensitive_level, flags_to_check, name)) return false; } else { - if (sensitive_columns.contains(name)) - { - auto current_node_name = node_name ? 
*node_name : "NULL"; - if (name != current_node_name || !flags.contains(flags_to_check)) - return false; - } + /* sensitive resource is not granted */ + if (level < sensitive_level) + return false; + if (!flags.contains(flags_to_check)) return false; } } + return true; } @@ -511,7 +539,7 @@ struct AccessRightsBase::Node bool canEraseChild([[maybe_unused]] const Node & child) const { if constexpr (IsSensitive) - return false; + return !child.max_flags_with_children; else return ((flags & child.getAllGrantableFlags()) == child.flags) && !child.children; } @@ -607,8 +635,8 @@ struct AccessRightsBase::Node auto flags_go = node_go ? node_go->flags : parent_fl_go; auto revokes = parent_fl - flags; auto revokes_go = parent_fl_go - flags_go - revokes; - auto grants_go = flags_go - parent_fl_go; - auto grants = flags - parent_fl - grants_go; + auto grants_go = IsSensitive ? flags_go : flags_go - parent_fl_go; + auto grants = IsSensitive ? flags - grants_go : flags - parent_fl - grants_go; if (revokes) res.push_back(ProtoElement{revokes, full_name, false, true}); @@ -713,7 +741,7 @@ struct AccessRightsBase::Node for (auto & [lhs_childname, lhs_child] : *children) { if (!rhs.tryGetChild(lhs_childname)) - lhs_child.flags |= rhs.flags & lhs_child.getAllGrantableFlags(); + lhs_child.addGrantsRec(rhs.flags, COLUMN_LEVEL); } } } @@ -731,7 +759,7 @@ struct AccessRightsBase::Node for (auto & [lhs_childname, lhs_child] : *children) { if (!rhs.tryGetChild(lhs_childname)) - lhs_child.flags &= rhs.flags; + lhs_child.removeGrantsRec(~rhs.flags); } } } @@ -930,14 +958,14 @@ void AccessRightsBase::grantWithGrantOption(const AccessRightsEleme template -template +template void AccessRightsBase::revokeImpl(const AccessFlags & flags, const Args &... 
args) { auto helper = [&](std::unique_ptr & root_node) { if (!root_node) return; - root_node->revoke(flags, args...); + root_node->template revoke(flags, args...); if (!root_node->flags && !root_node->children) root_node = nullptr; }; @@ -948,78 +976,95 @@ void AccessRightsBase::revokeImpl(const AccessFlags & flags, const } template -template +template void AccessRightsBase::revokeImplHelper(const AccessRightsElement & element) { assert(!element.grant_option || grant_option); if (element.any_database) - revokeImpl(element.access_flags); + revokeImpl(element.access_flags); else if (element.any_table) - revokeImpl(element.access_flags, element.database); + revokeImpl(element.access_flags, element.database); else if (element.any_column) - revokeImpl(element.access_flags, element.database, element.table); + revokeImpl(element.access_flags, element.database, element.table); else - revokeImpl(element.access_flags, element.database, element.table, element.columns); + revokeImpl(element.access_flags, element.database, element.table, element.columns); } template -template +template void AccessRightsBase::revokeImpl(const AccessRightsElement & element) { if constexpr (grant_option) { - revokeImplHelper(element); + revokeImplHelper(element); } else { if (element.grant_option) - revokeImplHelper(element); + revokeImplHelper(element); else - revokeImplHelper(element); + revokeImplHelper(element); } } template -template +template void AccessRightsBase::revokeImpl(const AccessRightsElements & elements) { for (const auto & element : elements) - revokeImpl(element); + revokeImpl(element); } template -void AccessRightsBase::revoke(const AccessFlags & flags) { revokeImpl(flags); } +void AccessRightsBase::revoke(const AccessFlags & flags) { revokeImpl(flags); } template -void AccessRightsBase::revoke(const AccessFlags & flags, const std::string_view & database) { revokeImpl(flags, database); } +void AccessRightsBase::revoke(const AccessFlags & flags, const std::string_view & database) 
{ revokeImpl(flags, database); } template -void AccessRightsBase::revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) { revokeImpl(flags, database, table); } +void AccessRightsBase::revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) { revokeImpl(flags, database, table); } template -void AccessRightsBase::revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) { revokeImpl(flags, database, table, column); } +void AccessRightsBase::revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) { revokeImpl(flags, database, table, column); } template -void AccessRightsBase::revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) { revokeImpl(flags, database, table, columns); } +void AccessRightsBase::revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) { revokeImpl(flags, database, table, columns); } template -void AccessRightsBase::revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) { revokeImpl(flags, database, table, columns); } +void AccessRightsBase::revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) { revokeImpl(flags, database, table, columns); } template -void AccessRightsBase::revoke(const AccessRightsElement & element) { revokeImpl(element); } +void AccessRightsBase::revoke(const AccessRightsElement & element) { revokeImpl(element); } template -void AccessRightsBase::revoke(const AccessRightsElements & elements) { revokeImpl(elements); } +void AccessRightsBase::revoke(const AccessRightsElements & elements) { 
revokeImpl(elements); } template -void AccessRightsBase::revokeGrantOption(const AccessFlags & flags) { revokeImpl(flags); } +void AccessRightsBase::tryRevoke(const AccessFlags & flags) { revokeImpl(flags); } template -void AccessRightsBase::revokeGrantOption(const AccessFlags & flags, const std::string_view & database) { revokeImpl(flags, database); } +void AccessRightsBase::tryRevoke(const AccessFlags & flags, const std::string_view & database) { revokeImpl(flags, database); } template -void AccessRightsBase::revokeGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) { revokeImpl(flags, database, table); } +void AccessRightsBase::tryRevoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) { revokeImpl(flags, database, table); } template -void AccessRightsBase::revokeGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) { revokeImpl(flags, database, table, column); } +void AccessRightsBase::tryRevoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) { revokeImpl(flags, database, table, column); } template -void AccessRightsBase::revokeGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) { revokeImpl(flags, database, table, columns); } +void AccessRightsBase::tryRevoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) { revokeImpl(flags, database, table, columns); } template -void AccessRightsBase::revokeGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) { revokeImpl(flags, database, table, columns); } +void AccessRightsBase::tryRevoke(const AccessFlags & flags, const 
std::string_view & database, const std::string_view & table, const Strings & columns) { revokeImpl(flags, database, table, columns); } template -void AccessRightsBase::revokeGrantOption(const AccessRightsElement & element) { revokeImpl(element); } +void AccessRightsBase::tryRevoke(const AccessRightsElement & element) { revokeImpl(element); } template -void AccessRightsBase::revokeGrantOption(const AccessRightsElements & elements) { revokeImpl(elements); } +void AccessRightsBase::tryRevoke(const AccessRightsElements & elements) { revokeImpl(elements); } + +template +void AccessRightsBase::revokeGrantOption(const AccessFlags & flags) { revokeImpl(flags); } +template +void AccessRightsBase::revokeGrantOption(const AccessFlags & flags, const std::string_view & database) { revokeImpl(flags, database); } +template +void AccessRightsBase::revokeGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) { revokeImpl(flags, database, table); } +template +void AccessRightsBase::revokeGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) { revokeImpl(flags, database, table, column); } +template +void AccessRightsBase::revokeGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) { revokeImpl(flags, database, table, columns); } +template +void AccessRightsBase::revokeGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) { revokeImpl(flags, database, table, columns); } +template +void AccessRightsBase::revokeGrantOption(const AccessRightsElement & element) { revokeImpl(element); } +template +void AccessRightsBase::revokeGrantOption(const AccessRightsElements & elements) { revokeImpl(elements); } template @@ -1155,17 +1200,15 @@ void AccessRightsBase::makeIntersection(const AccessRightsBase & 
root_node, const std::unique_ptr & other_root_node) { if (!root_node) - { - if (other_root_node) - root_node = std::make_unique(*other_root_node); return; - } - if (other_root_node) + if (!other_root_node) { - root_node->makeIntersection(*other_root_node); - if (!root_node->flags && !root_node->children) - root_node = nullptr; + root_node = nullptr; + return; } + root_node->makeIntersection(*other_root_node); + if (!root_node->flags && !root_node->children) + root_node = nullptr; }; helper(root, other.root); helper(root_with_grant_option, other.root_with_grant_option); @@ -1219,14 +1262,14 @@ void AccessRightsBase::logTree() const } template -bool SensitiveAccessRights::isGrantedImpl(const std::unordered_set & sensitive_columns, const AccessFlags & flags, const Args &... args) const +bool SensitiveAccessRights::isGrantedImpl(int sensitive_level, const AccessFlags & flags, const Args &... args) const { auto helper = [&](const std::unique_ptr & root_node) -> bool { if (!root_node) return flags.isEmpty(); - return root_node->isGranted(sensitive_columns, flags, args...); + return root_node->isGranted(sensitive_level, flags, args...); }; if constexpr (grant_option) return helper(root_with_grant_option); @@ -1235,52 +1278,51 @@ bool SensitiveAccessRights::isGrantedImpl(const std::unordered_set -bool SensitiveAccessRights::isGrantedImplHelper(const std::unordered_set & sensitive_columns, const AccessRightsElement & element) const +bool SensitiveAccessRights::isGrantedImplHelper(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessRightsElement & element) const { assert(!element.grant_option || grant_option); if (element.any_database) - return isGrantedImpl(sensitive_columns, element.access_flags); + return isGrantedImpl(sensitive_level, element.access_flags); else if (element.any_table) - return isGrantedImpl(sensitive_columns, element.access_flags, element.database); + return isGrantedImpl(sensitive_level, element.access_flags, 
element.database); else if (element.any_column) - return isGrantedImpl(sensitive_columns, element.access_flags, element.database, element.table); + return isGrantedImpl(sensitive_level, element.access_flags, element.database, element.table); else - return isGrantedImpl(sensitive_columns, element.access_flags, element.database, element.table, element.columns); + return isGrantedImpl(sensitive_level, element.access_flags, element.database, element.table, sensitive_columns); } template -bool SensitiveAccessRights::isGrantedImpl(const std::unordered_set & sensitive_columns, const AccessRightsElement & element) const +bool SensitiveAccessRights::isGrantedImpl(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessRightsElement & element) const { if constexpr (grant_option) { - return isGrantedImplHelper(sensitive_columns, element); + return isGrantedImplHelper(sensitive_level, sensitive_columns, element); } else { if (element.grant_option) - return isGrantedImplHelper(sensitive_columns, element); + return isGrantedImplHelper(sensitive_level, sensitive_columns, element); else - return isGrantedImplHelper(sensitive_columns, element); + return isGrantedImplHelper(sensitive_level, sensitive_columns, element); } } template -bool SensitiveAccessRights::isGrantedImpl(const std::unordered_set & sensitive_columns, const AccessRightsElements & elements) const +bool SensitiveAccessRights::isGrantedImpl(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessRightsElements & elements) const { for (const auto & element : elements) - if (!isGrantedImpl(sensitive_columns, element)) + if (!isGrantedImpl(sensitive_level, sensitive_columns, element)) return false; return true; } - -bool SensitiveAccessRights::isGranted(const std::unordered_set & sensitive_columns, const AccessFlags & flags) const { return isGrantedImpl(sensitive_columns, flags); } -bool SensitiveAccessRights::isGranted(const std::unordered_set & sensitive_columns, const 
AccessFlags & flags, const std::string_view & database) const { return isGrantedImpl(sensitive_columns, flags, database); } -bool SensitiveAccessRights::isGranted(const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database, const std::string_view & table) const { return isGrantedImpl(sensitive_columns, flags, database, table); } -bool SensitiveAccessRights::isGranted(const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return isGrantedImpl(sensitive_columns, flags, database, table, column); } -bool SensitiveAccessRights::isGranted(const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return isGrantedImpl(sensitive_columns, flags, database, table, columns); } -bool SensitiveAccessRights::isGranted(const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return isGrantedImpl(sensitive_columns, flags, database, table, columns); } -bool SensitiveAccessRights::isGranted(const std::unordered_set & sensitive_columns, const AccessRightsElement & element) const { return isGrantedImpl(sensitive_columns, element); } -bool SensitiveAccessRights::isGranted(const std::unordered_set & sensitive_columns, const AccessRightsElements & elements) const { return isGrantedImpl(sensitive_columns, elements); } +bool SensitiveAccessRights::isGranted(int sensitive_level, const std::unordered_set &, const AccessFlags & flags) const { return isGrantedImpl(sensitive_level, flags); } +bool SensitiveAccessRights::isGranted(int sensitive_level, const std::unordered_set &, const AccessFlags & flags, const std::string_view & database) const { return isGrantedImpl(sensitive_level, flags, 
database); } +bool SensitiveAccessRights::isGranted(int sensitive_level, const std::unordered_set &, const AccessFlags & flags, const std::string_view & database, const std::string_view & table) const { return isGrantedImpl(sensitive_level, flags, database, table); } +bool SensitiveAccessRights::isGranted(int sensitive_level, const std::unordered_set &, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return isGrantedImpl(sensitive_level, flags, database, table, column); } +bool SensitiveAccessRights::isGranted(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector &) const { return isGrantedImpl(sensitive_level, flags, database, table, sensitive_columns); } +bool SensitiveAccessRights::isGranted(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings &) const { return isGrantedImpl(sensitive_level, flags, database, table, sensitive_columns); } +bool SensitiveAccessRights::isGranted(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessRightsElement & element) const { return isGrantedImpl(sensitive_level, sensitive_columns, element); } +bool SensitiveAccessRights::isGranted(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessRightsElements & elements) const { return isGrantedImpl(sensitive_level, sensitive_columns, elements); } template class AccessRightsBase; template class AccessRightsBase; diff --git a/src/Access/AccessRights.h b/src/Access/AccessRights.h index 7deb076048c..ad75fb3f3ff 100644 --- a/src/Access/AccessRights.h +++ b/src/Access/AccessRights.h @@ -68,6 +68,15 @@ class AccessRightsBase void revoke(const AccessRightsElement & element); void revoke(const AccessRightsElements & 
elements); + void tryRevoke(const AccessFlags & flags); + void tryRevoke(const AccessFlags & flags, const std::string_view & database); + void tryRevoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table); + void tryRevoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column); + void tryRevoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns); + void tryRevoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns); + void tryRevoke(const AccessRightsElement & element); + void tryRevoke(const AccessRightsElements & elements); + void revokeGrantOption(const AccessFlags & flags); void revokeGrantOption(const AccessFlags & flags, const std::string_view & database); void revokeGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table); @@ -117,16 +126,16 @@ class AccessRightsBase template void grantImplHelper(const AccessRightsElement & element); - template + template void revokeImpl(const AccessFlags & flags, const Args &... 
args); - template + template void revokeImpl(const AccessRightsElement & element); - template + template void revokeImpl(const AccessRightsElements & elements); - template + template void revokeImplHelper(const AccessRightsElement & element); @@ -182,26 +191,27 @@ class SensitiveAccessRights : public AccessRightsBase public: using Base = AccessRightsBase; using Base::Base; - bool isGranted(const std::unordered_set & sensitive_columns, const AccessFlags & flags) const; - bool isGranted(const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database) const; - bool isGranted(const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database, const std::string_view & table) const; - bool isGranted(const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; - bool isGranted(const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; - bool isGranted(const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) const; - bool isGranted(const std::unordered_set & sensitive_columns, const AccessRightsElement & element) const; - bool isGranted(const std::unordered_set & sensitive_columns, const AccessRightsElements & elements) const; + + bool isGranted(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessFlags & flags) const; + bool isGranted(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database) const; + bool isGranted(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database, const 
std::string_view & table) const; + bool isGranted(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; + bool isGranted(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; + bool isGranted(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) const; + bool isGranted(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessRightsElement & element) const; + bool isGranted(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessRightsElements & elements) const; private: template - bool isGrantedImpl(const std::unordered_set & sensitive_columns, const AccessFlags & flags, const Args &... args) const; + bool isGrantedImpl(int sensitive_level, const AccessFlags & flags, const Args &... 
args) const; template - bool isGrantedImpl(const std::unordered_set & sensitive_columns, const AccessRightsElement & element) const; + bool isGrantedImpl(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessRightsElement & element) const; template - bool isGrantedImpl(const std::unordered_set & sensitive_columns, const AccessRightsElements & elements) const; + bool isGrantedImpl(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessRightsElements & elements) const; template - bool isGrantedImplHelper(const std::unordered_set & sensitive_columns, const AccessRightsElement & element) const; + bool isGrantedImplHelper(int sensitive_level, const std::unordered_set & sensitive_columns, const AccessRightsElement & element) const; }; } diff --git a/src/Access/AccessType.h b/src/Access/AccessType.h index 95b9f447c98..e2d169a9d90 100644 --- a/src/Access/AccessType.h +++ b/src/Access/AccessType.h @@ -124,6 +124,7 @@ enum class AccessType implicitly enabled by the grant CREATE_TABLE on any table */ \ M(CREATE_FUNCTION, "", GLOBAL, CREATE) /* allows to execute CREATE FUNCTION */ \ M(CREATE_BINDING, "", GLOBAL, CREATE) /* allows to execute CREATE BINDING */ \ + M(CREATE_PREPARED_STATEMENT, "", GLOBAL, CREATE) /* allows to execute CREATE PREPARED STATEMENT */ \ M(CREATE, "", GROUP, ALL) /* allows to execute {CREATE|ATTACH} */ \ \ M(DROP_DATABASE, "", DATABASE, DROP) /* allows to execute {DROP|DETACH} DATABASE */ \ @@ -133,6 +134,7 @@ enum class AccessType M(DROP_DICTIONARY, "", DICTIONARY, DROP) /* allows to execute {DROP|DETACH} DICTIONARY */ \ M(DROP_FUNCTION, "", GLOBAL, DROP) /* allows to execute DROP FUNCTION */\ M(DROP_BINDING, "", GLOBAL, DROP) /* allows to execute DROP BINDING */\ + M(DROP_PREPARED_STATEMENT, "", GLOBAL, DROP) /* allows to execute DROP PREPARED STATEMENT */ \ M(DROP, "", GROUP, ALL) /* allows to execute {DROP|DETACH} */ \ \ M(TRUNCATE, "TRUNCATE TABLE", TABLE, ALL) \ @@ -163,6 +165,7 @@ enum class 
AccessType M(CREATE_SETTINGS_PROFILE, "CREATE PROFILE", GLOBAL, ACCESS_MANAGEMENT) \ M(ALTER_SETTINGS_PROFILE, "ALTER PROFILE", GLOBAL, ACCESS_MANAGEMENT) \ M(DROP_SETTINGS_PROFILE, "DROP PROFILE", GLOBAL, ACCESS_MANAGEMENT) \ + M(SET_SENSITIVE, "", COLUMN, ACCESS_MANAGEMENT) \ M(SHOW_USERS, "SHOW CREATE USER", GLOBAL, SHOW_ACCESS) \ M(SHOW_ROLES, "SHOW CREATE ROLE", GLOBAL, SHOW_ACCESS) \ M(SHOW_ROW_POLICIES, "SHOW POLICIES, SHOW CREATE ROW POLICY, SHOW CREATE POLICY", GLOBAL, SHOW_ACCESS) \ diff --git a/src/Access/AeolusAccessUtil.h b/src/Access/AeolusAccessUtil.h new file mode 100644 index 00000000000..856baf98aff --- /dev/null +++ b/src/Access/AeolusAccessUtil.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include +#include + +namespace DB +{ + +bool aeolusCheck(const Context & context, const String & full_table_name) +{ + String access_table_names = context.getSettingsRef().access_table_names; + + if (access_table_names.empty()) + access_table_names = context.getSettingsRef().accessible_table_names; + + if (access_table_names.empty()) + return true; + + std::unordered_set allowed_tables; + std::vector tables; + boost::split(tables, access_table_names, boost::is_any_of(" ,")); + allowed_tables.insert(tables.begin(), tables.end()); + + for (auto & table : tables) + { + char * begin = table.data(); + char * end = begin + table.size(); + Tokens tokens(begin, end); + IParser::Pos token_iterator(tokens, context.getSettingsRef().max_parser_depth); + auto pos = token_iterator; + Expected expected; + String database_name, table_name; + if (!parseDatabaseAndTableName(pos, expected, database_name, table_name)) + continue; + + StorageID table_id{database_name, table_name}; + /// tryGetTable below requires resolved table id + StorageID resolved = context.tryResolveStorageID(table_id); + if (!resolved) + continue; + + /// access_table_names need to have resolved name, otherwise tryGetTable below will fail + if (table_id.database_name.empty() && 
!resolved.database_name.empty()) + allowed_tables.emplace(resolved.getDatabaseName() + "." + resolved.getTableName()); + } + + return allowed_tables.count(full_table_name) > 0; +} + +} diff --git a/src/Access/ContextAccess.cpp b/src/Access/ContextAccess.cpp index 296dcf2cabb..ee0a25131c5 100644 --- a/src/Access/ContextAccess.cpp +++ b/src/Access/ContextAccess.cpp @@ -40,6 +40,15 @@ namespace ErrorCodes namespace { + /* must be synced with the Level definition in AccessRights.cpp */ + enum Level + { + GLOBAL_LEVEL, + DATABASE_LEVEL, + TABLE_LEVEL, + COLUMN_LEVEL, + }; + static const std::unordered_set always_accessible_tables { /// Constant tables "one", @@ -68,12 +77,14 @@ namespace "columns", "mutations", "users", + "dictionaries", /// Specific to the current session "settings", "current_roles", "enabled_roles", "quota_usage", + "processes", /// The following tables hide some rows if the current user doesn't have corresponding SHOW privileges. /// For IDE tools to get schema info @@ -387,7 +398,8 @@ void ContextAccess::calculateAccessRights() const } LOG_TRACE(trace_log, "Settings: readonly={}, allow_ddl={}, allow_introspection_functions={}", params.readonly, params.allow_ddl, params.allow_introspection); LOG_TRACE(trace_log, "List of all grants: {}", access->toString()); - LOG_TRACE(trace_log, "List of all sensitive grants: {}", sensitive_access->toString()); + if (params.enable_sensitive_permission) + LOG_TRACE(trace_log, "List of all sensitive grants: {}", sensitive_access->toString()); LOG_TRACE(trace_log, "List of all grants including implicit: {}", access_with_implicit->toString()); } } @@ -513,7 +525,7 @@ std::shared_ptr ContextAccess::getSensitiveAccessRi return nothing_granted; } -bool ContextAccess::isSensitiveImpl(std::unordered_set & cols, const std::string_view & database, const std::string_view & table = {}, const std::vector & columns = {}) const +int ContextAccess::isSensitiveImpl(std::unordered_set & cols, const std::string_view & database, const 
std::string_view & table = {}, const std::vector & columns = {}) const { auto sensitive_resource = manager->sensitive_resource_getter(formatTenantDatabaseName(std::string(database))); if (!sensitive_resource) @@ -536,7 +548,7 @@ bool ContextAccess::isSensitiveImpl(std::unordered_set & cols, } if (!cols.empty()) - return true; + return COLUMN_LEVEL; } } @@ -544,33 +556,21 @@ bool ContextAccess::isSensitiveImpl(std::unordered_set & cols, { for (auto & sensitive_table : sensitive_resource->tables()) { - if (sensitive_table.table() == table) - return sensitive_table.is_sensitive(); + if (sensitive_table.table() != table) + continue; + + if (sensitive_table.is_sensitive()) + return TABLE_LEVEL; } } if (!database.empty()) { - return sensitive_resource->is_sensitive(); + if (sensitive_resource->is_sensitive()) + return DATABASE_LEVEL; } - return false; -} - -template -bool ContextAccess::checkSensitivePermissions(std::unordered_set & cols, const Args &... args) const -{ - auto tenant_id = getCurrentTenantId(); - - // Only apply sensitive permission checks on tenanted users only - if (tenant_id.empty()) - return false; - - // Only enable sensitive permission check for selected tenants - if (!params.enable_sensitive_permission) - return false; - - return isSensitive(cols, args...); + return GLOBAL_LEVEL; } template @@ -657,21 +657,42 @@ bool ContextAccess::checkAccessImplHelper(const AccessFlags & flags, const Args auto acs = getAccessRightsWithImplicit(); bool granted; bool check_sensitive_permissions = false; - std::unordered_set sensitive_columns; if constexpr (grant_option) granted = acs->hasGrantOption(flags, args...); - else if (checkSensitivePermissions(sensitive_columns, args...)) - { - check_sensitive_permissions = true; - granted = getSensitiveAccessRights()->isGranted(sensitive_columns, flags, args...); - } else granted = acs->isGranted(flags, args...); if (granted) granted = checkTenantsAccess(args...); + while (granted) + { + auto tenant_id = 
getCurrentTenantId(); + + // Only apply sensitive permission checks on tenanted users only + if (tenant_id.empty()) + break; + + // Only enable sensitive permission check for selected tenants + if (!params.enable_sensitive_permission) + break; + + if (roles_info->is_admin) + break; + + std::unordered_set sensitive_columns; + int sensitive_level = isSensitive(sensitive_columns, args...); + + if (sensitive_level != GLOBAL_LEVEL) + { + check_sensitive_permissions = true; + //std::vector cols{sensitive_columns.begin(), sensitive_columns.end()}; + granted = getSensitiveAccessRights()->isGranted(sensitive_level, sensitive_columns, flags, args...); + } + break; + } + if (!granted) { if (grant_option && acs->isGranted(flags, args...)) @@ -686,7 +707,7 @@ bool ContextAccess::checkAccessImplHelper(const AccessFlags & flags, const Args return access_denied( "Not enough privileges. To execute this query it's necessary to have grant" - + (check_sensitive_permissions && !grant_option ? std::string("(sensitive) ") : std::string(" ")) + + (check_sensitive_permissions && !grant_option ? std::string(" SENSITIVE ") : std::string(" ")) + AccessRightsElement{flags, args...}.toStringWithoutOptions() + (grant_option ? 
" WITH GRANT OPTION" : ""), ErrorCodes::ACCESS_DENIED); } @@ -700,15 +721,16 @@ bool ContextAccess::checkAccessImplHelper(const AccessFlags & flags, const Args const AccessFlags dictionary_ddl = AccessType::CREATE_DICTIONARY | AccessType::DROP_DICTIONARY; const AccessFlags function_ddl = AccessType::CREATE_FUNCTION | AccessType::DROP_FUNCTION; const AccessFlags binding_ddl = AccessType::CREATE_BINDING | AccessType::DROP_BINDING; + const AccessFlags prepared_statement_ddl = AccessType::CREATE_PREPARED_STATEMENT | AccessType::DROP_PREPARED_STATEMENT; const AccessFlags table_and_dictionary_ddl = table_ddl | dictionary_ddl; - const AccessFlags table_and_dictionary_and_function_ddl_and_binding = table_ddl | dictionary_ddl | function_ddl | binding_ddl; + const AccessFlags table_and_dictionary_and_function_ddl_and_binding = table_ddl | dictionary_ddl | function_ddl | binding_ddl | prepared_statement_ddl; const AccessFlags write_table_access = AccessType::INSERT | AccessType::OPTIMIZE; const AccessFlags write_dcl_access = AccessType::ACCESS_MANAGEMENT - AccessType::SHOW_ACCESS; const AccessFlags not_readonly_flags = write_table_access | table_and_dictionary_and_function_ddl_and_binding | write_dcl_access | AccessType::SYSTEM | AccessType::KILL_QUERY; const AccessFlags not_readonly_1_flags = AccessType::CREATE_TEMPORARY_TABLE; - const AccessFlags ddl_flags = table_ddl | dictionary_ddl | function_ddl | binding_ddl; + const AccessFlags ddl_flags = table_ddl | dictionary_ddl | function_ddl | binding_ddl | prepared_statement_ddl; const AccessFlags introspection_flags = AccessType::INTROSPECTION; }; static const PrecalculatedFlags precalc; @@ -798,12 +820,12 @@ bool ContextAccess::checkAccessImpl(const AccessRightsElements & elements) const return true; } -bool ContextAccess::isSensitive(std::unordered_set & /*cols*/) const { return false; } -bool ContextAccess::isSensitive(std::unordered_set & cols, const std::string_view & database) const { return isSensitiveImpl(cols, 
database); } -bool ContextAccess::isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table) const { return isSensitiveImpl(cols, database, table); } -bool ContextAccess::isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return isSensitiveImpl(cols, database, table, {column}); } -bool ContextAccess::isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return isSensitiveImpl(cols, database, table, columns); } -bool ContextAccess::isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return isSensitiveImpl(cols, database, table, {columns.begin(), columns.end()}); } +int ContextAccess::isSensitive(std::unordered_set & /*cols*/) { return GLOBAL_LEVEL; } +int ContextAccess::isSensitive(std::unordered_set & cols, const std::string_view & database) const { return isSensitiveImpl(cols, database); } +int ContextAccess::isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table) const { return isSensitiveImpl(cols, database, table); } +int ContextAccess::isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return isSensitiveImpl(cols, database, table, {column}); } +int ContextAccess::isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return isSensitiveImpl(cols, database, table, columns); } +int ContextAccess::isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return isSensitiveImpl(cols, database, table, {columns.begin(), columns.end()}); } bool 
ContextAccess::isGranted(const AccessFlags & flags) const { return checkAccessImpl(flags); } bool ContextAccess::isGranted(const AccessFlags & flags, const std::string_view & database) const { return checkAccessImpl(flags, database); } @@ -945,4 +967,9 @@ void ContextAccess::checkAdminOption(const std::vector & role_ids) const { void ContextAccess::checkAdminOption(const std::vector & role_ids, const Strings & names_of_roles) const { checkAdminOptionImpl(role_ids, names_of_roles); } void ContextAccess::checkAdminOption(const std::vector & role_ids, const std::unordered_map & names_of_roles) const { checkAdminOptionImpl(role_ids, names_of_roles); } +bool ContextAccessParams::dependsOnSettingName(std::string_view setting_name) +{ + return (setting_name == "readonly") || (setting_name == "allow_ddl") || (setting_name == "allow_introspection_functions"); +} + } diff --git a/src/Access/ContextAccess.h b/src/Access/ContextAccess.h index 513685c5f43..91098e95ef3 100644 --- a/src/Access/ContextAccess.h +++ b/src/Access/ContextAccess.h @@ -60,6 +60,8 @@ struct ContextAccessParams friend bool operator >(const ContextAccessParams & lhs, const ContextAccessParams & rhs) { return rhs < lhs; } friend bool operator <=(const ContextAccessParams & lhs, const ContextAccessParams & rhs) { return !(rhs < lhs); } friend bool operator >=(const ContextAccessParams & lhs, const ContextAccessParams & rhs) { return !(lhs < rhs); } + + static bool dependsOnSettingName(std::string_view setting_name); }; @@ -116,13 +118,13 @@ class ContextAccess : public std::enable_shared_from_this void checkGrantOption(const AccessRightsElement & element) const; void checkGrantOption(const AccessRightsElements & elements) const; - bool isSensitiveImpl(std::unordered_set & cols, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; - bool isSensitive(std::unordered_set & cols) const; - bool isSensitive(std::unordered_set & cols, const std::string_view & 
database) const; - bool isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table) const; - bool isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; - bool isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; - bool isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table, const Strings & columns) const; + int isSensitiveImpl(std::unordered_set & cols, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; + static int isSensitive(std::unordered_set & cols); + int isSensitive(std::unordered_set & cols, const std::string_view & database) const; + int isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table) const; + int isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; + int isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; + int isSensitive(std::unordered_set & cols, const std::string_view & database, const std::string_view & table, const Strings & columns) const; /// Checks if a specified access is granted, and returns false if not. /// Empty database means the current database. @@ -218,9 +220,6 @@ class ContextAccess : public std::enable_shared_from_this template bool checkAdminOptionImplHelper(const Container & role_ids, const GetNameFunction & get_name_function) const; - template - bool checkSensitivePermissions(std::unordered_set & cols, const Args &... 
args) const; - const AccessControlManager * manager = nullptr; const Params params; bool is_full_access = false; diff --git a/src/Access/DiskAccessStorage.cpp b/src/Access/DiskAccessStorage.cpp index 11b2065a0a0..e97ec6d37c9 100644 --- a/src/Access/DiskAccessStorage.cpp +++ b/src/Access/DiskAccessStorage.cpp @@ -136,6 +136,7 @@ namespace std::shared_ptr quota; std::shared_ptr profile; AccessEntityPtr res; + bool sensitive_tenant = false; for (const auto & query : queries) { @@ -176,12 +177,16 @@ namespace } else if (auto * grant_query = query->as()) { + /* sensitive permissions were serialized first */ + if (grant_query->is_sensitive) + sensitive_tenant = true; + if (!user && !role) throw Exception("A user or role should be attached before grant in file " + file_path, ErrorCodes::INCORRECT_ACCESS_ENTITY_DEFINITION); if (user) - InterpreterGrantQuery::updateUserFromQuery(*user, *grant_query); + InterpreterGrantQuery::updateUserFromQuery(*user, *grant_query, sensitive_tenant); else - InterpreterGrantQuery::updateRoleFromQuery(*role, *grant_query); + InterpreterGrantQuery::updateRoleFromQuery(*role, *grant_query, sensitive_tenant); } else throw Exception("No interpreter found for query " + query->getID(), ErrorCodes::INCORRECT_ACCESS_ENTITY_DEFINITION); @@ -215,7 +220,10 @@ namespace ASTs queries; queries.push_back(InterpreterShowCreateAccessEntityQuery::getAttachQuery(entity)); if ((entity.getType() == EntityType::USER) || (entity.getType() == EntityType::ROLE)) - boost::range::push_back(queries, InterpreterShowGrantsQuery::getAttachGrantQueries(entity)); + { + boost::range::push_back(queries, InterpreterShowGrantsQuery::getAttachGrantQueries(entity, true)); + boost::range::push_back(queries, InterpreterShowGrantsQuery::getAttachGrantQueries(entity, false)); + } /// Serialize the list of ATTACH queries to a string. 
WriteBufferFromOwnString buf; @@ -396,7 +404,7 @@ void DiskAccessStorage::clear() { entries_by_id.clear(); for (auto type : collections::range(EntityType::MAX)) - // collections::range(MAX_CONDITION_TYPE) give us a range of [0, MAX_CONDITION_TYPE) + // collections::range(MAX_CONDITION_TYPE) give us a range of [0, MAX_CONDITION_TYPE) // coverity[overrun-local] entries_by_name_and_type[static_cast(type)].clear(); } diff --git a/src/Access/EnabledRolesInfo.h b/src/Access/EnabledRolesInfo.h index 1795a573bee..872822d7dd4 100644 --- a/src/Access/EnabledRolesInfo.h +++ b/src/Access/EnabledRolesInfo.h @@ -20,6 +20,7 @@ struct EnabledRolesInfo AccessRights access; SensitiveAccessRights sensitive_access; SettingsProfileElements settings_from_enabled_roles; + bool is_admin = false; Strings getCurrentRolesNames() const; Strings getEnabledRolesNames() const; diff --git a/src/Access/KVAccessStorage.cpp b/src/Access/KVAccessStorage.cpp index 191d6db3f4d..5ad45c9e15b 100644 --- a/src/Access/KVAccessStorage.cpp +++ b/src/Access/KVAccessStorage.cpp @@ -109,21 +109,29 @@ namespace /// Reads a file containing ATTACH queries and then parses it to build an access entity. - AccessEntityPtr convertFromSqlToEntity(const String & create_sql) + AccessEntityPtr convertFromSqlToEntity(const String & create_sql, const String & sensitive_sql) { /// Parse the create sql. 
ASTs queries; - ParserAttachAccessEntity parser; - const char * begin = create_sql.data(); /// begin of current query - const char * pos = begin; /// parser moves pos from begin to the end of current query - const char * end = begin + create_sql.size(); - while (pos < end) - { - queries.emplace_back(parseQueryAndMovePosition(parser, pos, end, "", true, 0, DBMS_DEFAULT_MAX_PARSER_DEPTH)); - while (isWhitespaceASCII(*pos) || *pos == ';') - ++pos; - } + auto parse_sql = [&](const String & sql, bool sensitive_mode) { + ParserAttachAccessEntity parser; + const char * begin = sql.data(); /// begin of current query + const char * pos = begin; /// parser moves pos from begin to the end of current query + const char * end = begin + sql.size(); + while (pos < end) + { + ASTPtr ast = parseQueryAndMovePosition(parser, pos, end, "", true, 0, DBMS_DEFAULT_MAX_PARSER_DEPTH); + /* ignore ATTACH queries in sensitive mode */ + if (!sensitive_mode || ast->as()) + queries.emplace_back(ast); + while (isWhitespaceASCII(*pos) || *pos == ';') + ++pos; + } + }; + + parse_sql(create_sql, false); + parse_sql(sensitive_sql, true); /// Interpret the AST to build an access entity. 
std::shared_ptr user; @@ -132,6 +140,7 @@ namespace std::shared_ptr quota; std::shared_ptr profile; AccessEntityPtr res; + bool sensitive_tenant = !sensitive_sql.empty(); for (const auto & query : queries) { @@ -172,12 +181,16 @@ namespace } else if (auto * grant_query = query->as()) { + /* sensitive permissions were serialized first */ + if (grant_query->is_sensitive) + sensitive_tenant = true; + if (!user && !role) throw Exception("A user or role should be attached before grant in sql: " + create_sql, ErrorCodes::INCORRECT_ACCESS_ENTITY_DEFINITION); if (user) - InterpreterGrantQuery::updateUserFromQuery(*user, *grant_query); + InterpreterGrantQuery::updateUserFromQuery(*user, *grant_query, sensitive_tenant); else - InterpreterGrantQuery::updateRoleFromQuery(*role, *grant_query); + InterpreterGrantQuery::updateRoleFromQuery(*role, *grant_query, sensitive_tenant); } else throw Exception("No interpreter found for query " + query->getID(), ErrorCodes::INCORRECT_ACCESS_ENTITY_DEFINITION); @@ -191,13 +204,18 @@ namespace /// Writes ATTACH queries for building a specified access entity to a file. - String convertFromEntityToSql(const IAccessEntity & entity) + String convertFromEntityToSql(const IAccessEntity & entity, bool sensitive_mode) { /// Build list of ATTACH queries. ASTs queries; queries.push_back(InterpreterShowCreateAccessEntityQuery::getAttachQuery(entity)); if ((entity.getType() == EntityType::USER) || (entity.getType() == EntityType::ROLE)) - boost::range::push_back(queries, InterpreterShowGrantsQuery::getAttachGrantQueries(entity)); + { + ASTs grants = InterpreterShowGrantsQuery::getAttachGrantQueries(entity, sensitive_mode); + if (grants.empty() && sensitive_mode) + return {}; + boost::range::push_back(queries, grants); + } /// Serialize the list of ATTACH queries to a string. 
WriteBufferFromOwnString buf; @@ -206,12 +224,16 @@ namespace formatAST(*query, buf, false, true); buf.write(";\n", 2); } + return buf.str(); } class ConcurrentAccessGuard { public: - ConcurrentAccessGuard(const UUID &uuid) + ConcurrentAccessGuard & operator=(const ConcurrentAccessGuard &) = delete; + ConcurrentAccessGuard(const ConcurrentAccessGuard &) = delete; + ConcurrentAccessGuard() = delete; + explicit ConcurrentAccessGuard(const UUID &uuid) { { std::scoped_lock lock(map_mtx); @@ -349,7 +371,7 @@ UUID KVAccessStorage::updateCache(EntityType type, const AccessEntityModel & ent lock.unlock(); auto name_shard = getShard(entity_model.name()); auto & name_map = entries_by_name_and_type[static_cast(type)][name_shard]; - const auto & entity = entity_ ? entity_ : convertFromSqlToEntity(entity_model.create_sql()); + const auto & entity = entity_ ? entity_ : convertFromSqlToEntity(entity_model.create_sql(), entity_model.sensitive_sql()); lock.lock(); auto & entry = uuid_map[uuid]; @@ -373,6 +395,8 @@ UUID KVAccessStorage::updateCache(EntityType type, const AccessEntityModel & ent // Always get entity from KV to ensure that we have the most updated Entity at all times std::optional KVAccessStorage::findImpl(EntityType type, const String & name) const { + Notifications notifications; + SCOPE_EXIT({ notify(notifications); }); auto entity_model = catalog->tryGetAccessEntity(type, name); if (!entity_model) @@ -385,7 +409,7 @@ std::optional KVAccessStorage::findImpl(EntityType type, const String & na return std::nullopt; } - return updateCache(type, *entity_model, nullptr); + return updateCache(type, *entity_model, ¬ifications); } @@ -480,7 +504,10 @@ UUID KVAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool replac AccessEntityModel new_entity_model; RPCHelpers::fillUUID(uuid, *(new_entity_model.mutable_uuid())); new_entity_model.set_name(name); - new_entity_model.set_create_sql(convertFromEntityToSql(*new_entity)); + 
new_entity_model.set_create_sql(convertFromEntityToSql(*new_entity, false)); + auto sensitive_sql = convertFromEntityToSql(*new_entity, true); + if (!sensitive_sql.empty()) + new_entity_model.set_sensitive_sql(sensitive_sql); catalog->putAccessEntity(type, new_entity_model, old_entity_model, replace_if_exists); updateCache(type, new_entity_model, new_entity); @@ -533,7 +560,10 @@ void KVAccessStorage::updateImpl(const UUID & uuid, const UpdateFunc & update_fu AccessEntityModel new_entity_model; RPCHelpers::fillUUID(uuid, *(new_entity_model.mutable_uuid())); new_entity_model.set_name(new_name); - new_entity_model.set_create_sql(convertFromEntityToSql(*new_entity)); + new_entity_model.set_create_sql(convertFromEntityToSql(*new_entity, false)); + auto sensitive_sql = convertFromEntityToSql(*new_entity, true); + if (!sensitive_sql.empty()) + new_entity_model.set_sensitive_sql(sensitive_sql); catalog->putAccessEntity(old_entry.type, new_entity_model, old_entry.entity_model); Notifications notifications; @@ -552,9 +582,9 @@ void KVAccessStorage::updateImpl(const UUID & uuid, const UpdateFunc & update_fu if (new_entity_model.commit_time() < entry->commit_time) throw Exception("Concurrent rbac update, model had been overwritten by another server", ErrorCodes::CONCURRENT_RBAC_UPDATE); - entry->entity = new_entity; + entry->entity = std::move(new_entity); entry->commit_time = new_entity_model.commit_time(); - entry->entity_model = new_entity_model; + entry->entity_model = std::move(new_entity_model); if (name_changed) { diff --git a/src/Access/RoleCache.cpp b/src/Access/RoleCache.cpp index 96c41b0f9f7..db2d6a22d9e 100644 --- a/src/Access/RoleCache.cpp +++ b/src/Access/RoleCache.cpp @@ -43,6 +43,8 @@ namespace roles_info.enabled_roles_with_admin_option.emplace(role_id); roles_info.names_of_roles[role_id] = role->getName(); + if (roles_info.names_of_roles[role_id].ends_with("AccountAdmin") && is_current_role) + roles_info.is_admin = true; 
roles_info.access.makeUnion(role->access); roles_info.sensitive_access.makeUnion(role->sensitive_access); roles_info.settings_from_enabled_roles.merge(role->settings); diff --git a/src/Access/tests/gtest_access_rights_ops.cpp b/src/Access/tests/gtest_access_rights_ops.cpp new file mode 100644 index 00000000000..293f07c6e58 --- /dev/null +++ b/src/Access/tests/gtest_access_rights_ops.cpp @@ -0,0 +1,102 @@ +#include +#include + +using namespace DB; + + +TEST(AccessRights, Union) +{ + AccessRights lhs, rhs; + lhs.grant(AccessType::CREATE_TABLE, "db1", "tb1"); + rhs.grant(AccessType::SELECT, "db2"); + lhs.makeUnion(rhs); + ASSERT_EQ(lhs.toString(), "GRANT CREATE TABLE ON db1.tb1, GRANT SELECT ON db2.*"); + + lhs.clear(); + rhs.clear(); + rhs.grant(AccessType::SELECT, "db2"); + lhs.grant(AccessType::CREATE_TABLE, "db1", "tb1"); + lhs.makeUnion(rhs); + ASSERT_EQ(lhs.toString(), "GRANT CREATE TABLE ON db1.tb1, GRANT SELECT ON db2.*"); + + lhs = {}; + rhs = {}; + lhs.grant(AccessType::SELECT); + rhs.grant(AccessType::SELECT, "db1", "tb1"); + lhs.makeUnion(rhs); + ASSERT_EQ(lhs.toString(), "GRANT SELECT ON *.*"); + + lhs = {}; + rhs = {}; + lhs.grant(AccessType::SELECT, "db1", "tb1", Strings{"col1", "col2"}); + rhs.grant(AccessType::SELECT, "db1", "tb1", Strings{"col2", "col3"}); + lhs.makeUnion(rhs); + ASSERT_EQ(lhs.toString(), "GRANT SELECT(col1, col2, col3) ON db1.tb1"); + + lhs = {}; + rhs = {}; + lhs.grant(AccessType::SELECT, "db1", "tb1", Strings{"col1", "col2"}); + rhs.grantWithGrantOption(AccessType::SELECT, "db1", "tb1", Strings{"col2", "col3"}); + lhs.makeUnion(rhs); + ASSERT_EQ(lhs.toString(), "GRANT SELECT(col1) ON db1.tb1, GRANT SELECT(col2, col3) ON db1.tb1 WITH GRANT OPTION"); + + lhs = {}; + rhs = {}; + lhs.grant(AccessType::INSERT); + rhs.grant(AccessType::ALL, "db1"); + lhs.makeUnion(rhs); + ASSERT_EQ(lhs.toString(), + "GRANT INSERT ON *.*, " + "GRANT SHOW, SELECT, ALTER, CREATE DATABASE, CREATE TABLE, CREATE VIEW, " + "CREATE DICTIONARY, DROP DATABASE, 
DROP TABLE, DROP VIEW, DROP DICTIONARY, " + "TRUNCATE, OPTIMIZE, SET SENSITIVE, " + "SYSTEM MERGES, SYSTEM TTL MERGES, SYSTEM FETCHES, " + "SYSTEM MOVES, SYSTEM SENDS, SYSTEM REPLICATION QUEUES, " + "SYSTEM DROP REPLICA, SYSTEM SYNC REPLICA, SYSTEM RESTART REPLICA, " + "SYSTEM RESTORE REPLICA, SYSTEM RECALCULATE METRICS, SYSTEM FLUSH DISTRIBUTED, SYSTEM CONSUME, dictGet ON db1.*"); +} + + +TEST(AccessRights, Intersection) +{ + AccessRights lhs, rhs; + lhs.grant(AccessType::CREATE_TABLE, "db1", "tb1"); + rhs.grant(AccessType::SELECT, "db2"); + lhs.makeIntersection(rhs); + ASSERT_EQ(lhs.toString(), "GRANT USAGE ON *.*"); + + lhs.clear(); + rhs.clear(); + lhs.grant(AccessType::SELECT, "db2"); + rhs.grant(AccessType::CREATE_TABLE, "db1", "tb1"); + lhs.makeIntersection(rhs); + ASSERT_EQ(lhs.toString(), "GRANT USAGE ON *.*"); + + lhs = {}; + rhs = {}; + lhs.grant(AccessType::SELECT); + rhs.grant(AccessType::SELECT, "db1", "tb1"); + lhs.makeIntersection(rhs); + ASSERT_EQ(lhs.toString(), "GRANT SELECT ON db1.tb1"); + + lhs = {}; + rhs = {}; + lhs.grant(AccessType::SELECT, "db1", "tb1", Strings{"col1", "col2"}); + rhs.grant(AccessType::SELECT, "db1", "tb1", Strings{"col2", "col3"}); + lhs.makeIntersection(rhs); + ASSERT_EQ(lhs.toString(), "GRANT SELECT(col2) ON db1.tb1"); + + lhs = {}; + rhs = {}; + lhs.grant(AccessType::SELECT, "db1", "tb1", Strings{"col1", "col2"}); + rhs.grantWithGrantOption(AccessType::SELECT, "db1", "tb1", Strings{"col2", "col3"}); + lhs.makeIntersection(rhs); + ASSERT_EQ(lhs.toString(), "GRANT SELECT(col2) ON db1.tb1"); + + lhs = {}; + rhs = {}; + lhs.grant(AccessType::INSERT); + rhs.grant(AccessType::ALL, "db1"); + lhs.makeIntersection(rhs); + ASSERT_EQ(lhs.toString(), "GRANT INSERT ON db1.*"); +} diff --git a/src/Advisor/tests/gtest_materialized_view.cpp b/src/Advisor/tests/gtest_materialized_view.cpp index 80792670da6..cb98795378f 100644 --- a/src/Advisor/tests/gtest_materialized_view.cpp +++ b/src/Advisor/tests/gtest_materialized_view.cpp @@ 
-124,7 +124,7 @@ TEST_F(MaterializedViewAdviseTest, TestMaterializedViewFilterAndProject) EXPECT_CONTAINS(advises.front()->getOptimizedValue(), "d_month_seq - 1"); } -TEST_F(MaterializedViewAdviseTest, TestMaterializedViewFilterAndProject2) +TEST_F(MaterializedViewAdviseTest, DISABLED_TestMaterializedViewFilterAndProject2) { auto advises = getAdvises( {"select d_month_seq - 1, d_date_sk + 1 from date_dim where d_week_seq = 1", @@ -196,7 +196,7 @@ TEST_F(MaterializedViewAdviseTest, TestMaterializedViewCaseWhen) EXPECT_GE(advises.size(), 1); } -TEST_F(MaterializedViewAdviseTest, TestTPCDSQ6) +TEST_F(MaterializedViewAdviseTest, DISABLED_TestTPCDSQ6) { std::string sql = tester->loadQuery("q6").sql.front().first; auto advises = getAdvises({sql}); diff --git a/src/AggregateFunctions/AggregateFunctionBitmapFromColumn.h b/src/AggregateFunctions/AggregateFunctionBitmapFromColumn.h index ae6fdd7e221..8a644c0b61e 100644 --- a/src/AggregateFunctions/AggregateFunctionBitmapFromColumn.h +++ b/src/AggregateFunctions/AggregateFunctionBitmapFromColumn.h @@ -39,7 +39,8 @@ struct AggregateFunctionBitMapFromColumnData { // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena using Allocator = MixedAlignedArenaAllocator; - using Array = PODArray; + /// to reduce memory allocation, initial size is set to 128 + using Array = PODArray; Array value; BitMap64 bitmap; @@ -77,7 +78,7 @@ class AggregateFunctionBitMapFromColumn final : public IAggregateFunctionDataHel { auto & cur_value = this->data(place).value; auto & cur_bitmap = this->data(place).bitmap; - auto & rhs_value = this->data(rhs).value; + const auto & rhs_value = this->data(rhs).value; auto & rhs_bitmap = const_cast(this->data(rhs).bitmap); if (!cur_value.empty()) @@ -130,7 +131,7 @@ class AggregateFunctionBitMapFromColumn final : public IAggregateFunctionDataHel buf.read(bitmap_chars.data(), bytes); auto & bitmap = this->data(place).bitmap; - bitmap.read(bitmap_chars.data(), bytes); + bitmap = 
roaring::Roaring64Map::readSafe(bitmap_chars.data(), bytes); } } diff --git a/src/AggregateFunctions/AggregateFunctionGenArrayMonth.cpp b/src/AggregateFunctions/AggregateFunctionGenArrayMonth.cpp index 6b3783caccf..c459e50d611 100644 --- a/src/AggregateFunctions/AggregateFunctionGenArrayMonth.cpp +++ b/src/AggregateFunctions/AggregateFunctionGenArrayMonth.cpp @@ -35,7 +35,7 @@ AggregateFunctionPtr createAggregateFunctionGenArrayMonth(const std::string & na String date_start = parameters[1].safeGet(); // use local timezone on default - String timezone = DateLUT::instance().getTimeZone(); + String timezone = DateLUT::sessionInstance().getTimeZone(); if (parameters.size() == 3) { timezone = parameters[2].safeGet(); } diff --git a/src/AggregateFunctions/AggregateFunctionSimpleState.h b/src/AggregateFunctions/AggregateFunctionSimpleState.h index 19488e7d98f..b1aa1f94cd4 100644 --- a/src/AggregateFunctions/AggregateFunctionSimpleState.h +++ b/src/AggregateFunctions/AggregateFunctionSimpleState.h @@ -35,18 +35,20 @@ class AggregateFunctionSimpleState final : public IAggregateFunctionHelpergetReturnType()->getName()); + // Need to make a clone to avoid recursive reference. + auto storage_type_out = DataTypeFactory::instance().get(nested_func->getReturnType()->getName()); // Need to make a new function with promoted argument types because SimpleAggregates requires arg_type = return_type. AggregateFunctionProperties properties; - auto function - = AggregateFunctionFactory::instance().get(nested_func->getName(), {storage_type}, nested_func->getParameters(), properties); - + auto function + = AggregateFunctionFactory::instance().get(nested_func->getName(), {storage_type_out}, nested_func->getParameters(), properties); + + // Need to make a clone because it'll be customized. 
+ auto storage_type_arg = DataTypeFactory::instance().get(nested_func->getReturnType()->getName()); DataTypeCustomNamePtr custom_name = std::make_unique(function, DataTypes{nested_func->getReturnType()}, params); - storage_type->setCustomization(std::make_unique(std::move(custom_name), nullptr)); - return storage_type; + storage_type_arg->setCustomization(std::make_unique(std::move(custom_name), nullptr)); + return storage_type_arg; } bool isState() const override diff --git a/src/AggregateFunctions/AggregateFunctionSketchEstimate.cpp b/src/AggregateFunctions/AggregateFunctionSketchEstimate.cpp index eab9defba88..c16b89cee42 100644 --- a/src/AggregateFunctions/AggregateFunctionSketchEstimate.cpp +++ b/src/AggregateFunctions/AggregateFunctionSketchEstimate.cpp @@ -31,16 +31,18 @@ AggregateFunctionPtr createAggregateFunctionWithK(const DataTypes & argument_typ const IDataType & argument_type = *argument_types[0]; WhichDataType which(argument_type); + bool ignore_wrong_date = argument_types.size() == 2; + if (which.idx == TypeIndex::SketchBinary) { - return std::make_shared::template AggregateFunction>(argument_types, params); + return std::make_shared::template AggregateFunction>(argument_types, params, ignore_wrong_date); } else if (which.isAggregateFunction()) { - return std::make_shared::template AggregateFunction>(argument_types, params); + return std::make_shared::template AggregateFunction>(argument_types, params, ignore_wrong_date); } - return std::make_shared::template AggregateFunction>(argument_types, params); + return std::make_shared::template AggregateFunction>(argument_types, params, ignore_wrong_date); } AggregateFunctionPtr createAggregateFunctionHllSketchEstimate @@ -67,7 +69,7 @@ AggregateFunctionPtr createAggregateFunctionHllSketchEstimate precision = precision_param; } - if (argument_types.size() != 1) + if (argument_types.size() != 1 && argument_types.size() != 2) throw Exception("Incorrect number of arguments for aggregate function " + name, 
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); switch (precision) @@ -137,7 +139,7 @@ AggregateFunctionPtr createAggregateFunctionHllSketchUnion precision = precision_param; } - if (argument_types.size() != 1) + if (argument_types.size() != 1 && argument_types.size() != 2) throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); if (argument_types[0]->getTypeId() != TypeIndex::SketchBinary) throw Exception("Incorrect type of arguments for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); @@ -190,7 +192,7 @@ AggregateFunctionPtr createAggregateFunctionHllSketchUnion AggregateFunctionPtr createAggregateFunctionKllSketchEstimate (const std::string & name, const DataTypes & argument_types, const Array & params, const Settings *) { - if (argument_types.size() != 1) + if (argument_types.size() != 1 && argument_types.size() != 2) throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); Float64 quantile; @@ -208,27 +210,28 @@ AggregateFunctionPtr createAggregateFunctionKllSketchEstimate "Aggregate function " + name + " first parameter should between 0 and 1.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); String type_name = params[1].safeGet(); - + bool ignore_wrong_date = argument_types.size() == 2; + if (type_name == "UInt8") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "UInt16") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "UInt32") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "UInt64") - return std::make_shared>(quantile, argument_types, params); + return 
std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "Int8") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "Int16") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "Int32") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "Int64") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "Float32") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "Float64") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else throw Exception( "Aggregate function " + name + " second parameter not correct.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); @@ -242,7 +245,7 @@ AggregateFunctionPtr createAggregateFunctionKllSketchEstimate AggregateFunctionPtr createAggregateFunctionQuantilesSketchEstimate (const std::string & name, const DataTypes & argument_types, const Array & params, const Settings *) { - if (argument_types.size() != 1) + if (argument_types.size() != 1 && argument_types.size() != 2) throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); Float64 quantile; @@ -260,27 +263,28 @@ AggregateFunctionPtr createAggregateFunctionQuantilesSketchEstimate "Aggregate function " + name + " first parameter should between 0 and 1.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); String type_name = 
params[1].safeGet(); + bool ignore_wrong_date = argument_types.size() == 2; if (type_name == "UInt8") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "UInt16") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "UInt32") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "UInt64") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "Int8") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "Int16") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "Int32") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "Int64") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "Float32") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else if (type_name == "Float64") - return std::make_shared>(quantile, argument_types, params); + return std::make_shared>(quantile, argument_types, params, ignore_wrong_date); else throw Exception( "Aggregate function " + name + " second parameter not correct.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); @@ -293,7 +297,7 @@ AggregateFunctionPtr 
createAggregateFunctionQuantilesSketchEstimate AggregateFunctionPtr createAggregateFunctionQuantilesSketchUnion (const std::string & name, const DataTypes & argument_types, const Array & params, const Settings *) { - if (argument_types.size() != 1) + if (argument_types.size() != 1 && argument_types.size() != 2) throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); if (!params.empty()) @@ -305,27 +309,28 @@ AggregateFunctionPtr createAggregateFunctionQuantilesSketchUnion } String type_name = params[0].safeGet(); + bool ignore_wrong_date = argument_types.size() == 2; if (type_name == "UInt8") - return std::make_shared>(argument_types, params); + return std::make_shared>(argument_types, params, ignore_wrong_date); else if (type_name == "UInt16") - return std::make_shared>(argument_types, params); + return std::make_shared>(argument_types, params, ignore_wrong_date); else if (type_name == "UInt32") - return std::make_shared>(argument_types, params); + return std::make_shared>(argument_types, params, ignore_wrong_date); else if (type_name == "UInt64") - return std::make_shared>(argument_types, params); + return std::make_shared>(argument_types, params, ignore_wrong_date); else if (type_name == "Int8") - return std::make_shared>(argument_types, params); + return std::make_shared>(argument_types, params, ignore_wrong_date); else if (type_name == "Int16") - return std::make_shared>(argument_types, params); + return std::make_shared>(argument_types, params, ignore_wrong_date); else if (type_name == "Int32") - return std::make_shared>(argument_types, params); + return std::make_shared>(argument_types, params, ignore_wrong_date); else if (type_name == "Int64") - return std::make_shared>(argument_types, params); + return std::make_shared>(argument_types, params, ignore_wrong_date); else if (type_name == "Float32") - return std::make_shared>(argument_types, params); + return 
std::make_shared>(argument_types, params, ignore_wrong_date); else if (type_name == "Float64") - return std::make_shared>(argument_types, params); + return std::make_shared>(argument_types, params, ignore_wrong_date); else throw Exception( "Aggregate function " + name + " second parameter not correct.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); diff --git a/src/AggregateFunctions/AggregateFunctionSketchEstimate.h b/src/AggregateFunctions/AggregateFunctionSketchEstimate.h index 826e5e3ec62..0623daeaa60 100644 --- a/src/AggregateFunctions/AggregateFunctionSketchEstimate.h +++ b/src/AggregateFunctions/AggregateFunctionSketchEstimate.h @@ -37,8 +37,8 @@ class AggregateFunctionHLLSketchUnion final : public IAggregateFunctionDataHelper, AggregateFunctionHLLSketchUnion> { public: - AggregateFunctionHLLSketchUnion(const DataTypes & argument_types_, const Array & params_) - : IAggregateFunctionDataHelper, AggregateFunctionHLLSketchUnion>(argument_types_, params_){} + AggregateFunctionHLLSketchUnion(const DataTypes & argument_types_, const Array & params_, bool ignore_wrong_data_ = false) + : IAggregateFunctionDataHelper, AggregateFunctionHLLSketchUnion>(argument_types_, params_), ignore_wrong_data(ignore_wrong_data_) {} String getName() const override { @@ -52,9 +52,19 @@ class AggregateFunctionHLLSketchUnion final void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override { - const auto & value = static_cast(*columns[0]).getDataAt(row_num); - datasketches::hll_sketch hll_sketch_data = datasketches::hll_sketch::deserialize(value.data, value.size, AggregateFunctionHllSketchAllocator()); - this->data(place).u.update(hll_sketch_data); + try + { + const auto & value = static_cast(*columns[0]).getDataAt(row_num); + if (ignore_wrong_data && value.size == 0) + return; + datasketches::hll_sketch hll_sketch_data = datasketches::hll_sketch::deserialize(value.data, value.size, AggregateFunctionHllSketchAllocator()); + 
this->data(place).u.update(hll_sketch_data); + } + catch (std::exception & e) + { + if (!ignore_wrong_data) + throw e; + } } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override @@ -84,6 +94,7 @@ class AggregateFunctionHLLSketchUnion final bool allocatesMemoryInArena() const override { return true; } private: + bool ignore_wrong_data = false; inline datasketches::hll_sketch readHLLSketch(ReadBuffer & buf) const { String d; @@ -97,8 +108,12 @@ class AggregateFunctionHllSketchEstimate final : public IAggregateFunctionDataHelper, AggregateFunctionHllSketchEstimate> { public: - AggregateFunctionHllSketchEstimate(const DataTypes & argument_types_, const Array & params_) - : IAggregateFunctionDataHelper, AggregateFunctionHllSketchEstimate>(argument_types_, params_) {} + AggregateFunctionHllSketchEstimate(const DataTypes & argument_types_, const Array & params_, bool ignore_wrong_data_ = false) + : IAggregateFunctionDataHelper, AggregateFunctionHllSketchEstimate>(argument_types_, params_), ignore_wrong_data(ignore_wrong_data_) + { + if (params_.size() == 2) + use_composite_estimate = true; + } String getName() const override { @@ -112,24 +127,36 @@ class AggregateFunctionHllSketchEstimate final void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override { - // String is for new datatype "Sketch" - if constexpr (std::is_same_v) + try { - const auto & value = static_cast(*columns[0]).getDataAt(row_num); - datasketches::hll_sketch hllSketch = datasketches::hll_sketch::deserialize(value.data, value.size, AggregateFunctionHllSketchAllocator()); - this->data(place).u.update(hllSketch); + // String is for new datatype "Sketch" + if constexpr (std::is_same_v) + { + const auto & value = static_cast(*columns[0]).getDataAt(row_num); + if (ignore_wrong_data && value.size == 0) + return; + datasketches::hll_sketch hllSketch = datasketches::hll_sketch::deserialize(value.data, value.size, 
AggregateFunctionHllSketchAllocator()); + this->data(place).u.update(hllSketch); + } + else if constexpr (std::is_same_v) + { + //the format of this value should be the same with serialize + const auto & value = static_cast(*columns[0]).getDataAt(row_num); + if (ignore_wrong_data && value.size == 0) + return; + ReadBuffer buf(const_cast(value.data), value.size); + this->data(place).u.update(readHllSketch(buf)); + } + else + { + StringRef value = columns[0]->getDataAt(row_num); + this->data(place).u.update(value.toString()); + } } - else if constexpr (std::is_same_v) + catch (std::exception & e) { - //the format of this value should be the same with serialize - const auto & value = static_cast(*columns[0]).getDataAt(row_num); - ReadBuffer buf(const_cast(value.data), value.size); - this->data(place).u.update(readHllSketch(buf)); - } - else - { - StringRef value = columns[0]->getDataAt(row_num); - this->data(place).u.update(value.toString()); + if (!ignore_wrong_data) + throw e; } } @@ -152,12 +179,17 @@ class AggregateFunctionHllSketchEstimate final void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * ) const override { - static_cast(to).getData().push_back(this->data(place).u.get_estimate()); + if (use_composite_estimate) + static_cast(to).getData().push_back(this->data(place).u.get_composite_estimate()); + else + static_cast(to).getData().push_back(this->data(place).u.get_estimate()); } bool allocatesMemoryInArena() const override { return true; } private: + bool ignore_wrong_data = false; + bool use_composite_estimate = false; inline datasketches::hll_sketch readHllSketch(ReadBuffer & buf) const { String d; @@ -181,8 +213,8 @@ class AggregateFunctionKllSketchEstimate final : public IAggregateFunctionDataHelper, AggregateFunctionKllSketchEstimate> { public: - AggregateFunctionKllSketchEstimate(const double quantile_, const DataTypes & argument_types_, const Array & params_) - : IAggregateFunctionDataHelper, 
AggregateFunctionKllSketchEstimate>(argument_types_, params_),quantile(quantile_) {} + AggregateFunctionKllSketchEstimate(const double quantile_, const DataTypes & argument_types_, const Array & params_, bool ignore_wrong_data_ = false) + : IAggregateFunctionDataHelper, AggregateFunctionKllSketchEstimate>(argument_types_, params_),quantile(quantile_), ignore_wrong_data(ignore_wrong_data_) {} Float64 quantile = 0; @@ -198,9 +230,18 @@ class AggregateFunctionKllSketchEstimate final void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override { - const auto & value = static_cast(*columns[0]).getDataAt(row_num); - datasketches::kll_sketch kll_sketch_data = datasketches::kll_sketch::deserialize(value.data, value.size, datasketches::serde(), std::less(), AggregateFunctionHllSketchAllocator()); - this->data(place).u.merge(kll_sketch_data); + try { + const auto & value = static_cast(*columns[0]).getDataAt(row_num); + if (ignore_wrong_data && value.size == 0) + return; + datasketches::kll_sketch kll_sketch_data = datasketches::kll_sketch::deserialize(value.data, value.size, datasketches::serde(), std::less(), AggregateFunctionHllSketchAllocator()); + this->data(place).u.merge(kll_sketch_data); + } + catch (std::exception & e) + { + if (!ignore_wrong_data) + throw e; + } } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override @@ -229,6 +270,7 @@ class AggregateFunctionKllSketchEstimate final bool allocatesMemoryInArena() const override { return true; } private: + bool ignore_wrong_data = false; inline datasketches::kll_sketch readKllSketch(ReadBuffer & buf) const { String d; @@ -252,8 +294,8 @@ class AggregateFunctionQuantilesSketchEstimate final : public IAggregateFunctionDataHelper, AggregateFunctionQuantilesSketchEstimate> { public: - AggregateFunctionQuantilesSketchEstimate(const double quantile_, const DataTypes & argument_types_, const Array & params_) - : IAggregateFunctionDataHelper, 
AggregateFunctionQuantilesSketchEstimate>(argument_types_, params_),quantile(quantile_) {} + AggregateFunctionQuantilesSketchEstimate(const double quantile_, const DataTypes & argument_types_, const Array & params_, bool ignore_wrong_data_ = false) + : IAggregateFunctionDataHelper, AggregateFunctionQuantilesSketchEstimate>(argument_types_, params_),quantile(quantile_), ignore_wrong_data(ignore_wrong_data_) {} Float64 quantile = 0; @@ -269,9 +311,19 @@ class AggregateFunctionQuantilesSketchEstimate final void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override { - const auto & value = static_cast(*columns[0]).getDataAt(row_num); - datasketches::quantiles_sketch quantiles_sketch_data = datasketches::quantiles_sketch::deserialize(value.data, value.size, datasketches::serde(), std::less(), AggregateFunctionHllSketchAllocator()); - this->data(place).u.merge(quantiles_sketch_data); + try + { + const auto & value = static_cast(*columns[0]).getDataAt(row_num); + if (ignore_wrong_data && value.size == 0) + return; + datasketches::quantiles_sketch quantiles_sketch_data = datasketches::quantiles_sketch::deserialize(value.data, value.size, datasketches::serde(), std::less(), AggregateFunctionHllSketchAllocator()); + this->data(place).u.merge(quantiles_sketch_data); + } + catch (std::exception & e) + { + if (!ignore_wrong_data) + throw e; + } } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override @@ -299,6 +351,7 @@ class AggregateFunctionQuantilesSketchEstimate final bool allocatesMemoryInArena() const override { return true; } private: + bool ignore_wrong_data = false; inline datasketches::quantiles_sketch readQuantilesSketch(ReadBuffer & buf) const { String d; @@ -309,11 +362,11 @@ class AggregateFunctionQuantilesSketchEstimate final template class AggregateFunctionQuantilesSketchUnion final - : public IAggregateFunctionDataHelper, AggregateFunctionQuantilesSketchEstimate> + : public 
IAggregateFunctionDataHelper, AggregateFunctionQuantilesSketchUnion> { public: - AggregateFunctionQuantilesSketchUnion(const DataTypes & argument_types_, const Array & params_) - : IAggregateFunctionDataHelper, AggregateFunctionQuantilesSketchEstimate>(argument_types_, params_){} + AggregateFunctionQuantilesSketchUnion(const DataTypes & argument_types_, const Array & params_, bool ignore_wrong_data_ = false) + : IAggregateFunctionDataHelper, AggregateFunctionQuantilesSketchUnion>(argument_types_, params_), ignore_wrong_data(ignore_wrong_data_) {} String getName() const override { @@ -327,9 +380,19 @@ class AggregateFunctionQuantilesSketchUnion final void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override { - const auto & value = static_cast(*columns[0]).getDataAt(row_num); - datasketches::quantiles_sketch quantiles_sketch_data = datasketches::quantiles_sketch::deserialize(value.data, value.size, datasketches::serde(), std::less(), AggregateFunctionHllSketchAllocator()); - this->data(place).u.merge(quantiles_sketch_data); + try + { + const auto & value = static_cast(*columns[0]).getDataAt(row_num); + if (ignore_wrong_data && value.size == 0) + return; + datasketches::quantiles_sketch quantiles_sketch_data = datasketches::quantiles_sketch::deserialize(value.data, value.size, datasketches::serde(), std::less(), AggregateFunctionHllSketchAllocator()); + this->data(place).u.merge(quantiles_sketch_data); + } + catch (std::exception & e) + { + if (!ignore_wrong_data) + throw e; + } } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override @@ -359,6 +422,7 @@ class AggregateFunctionQuantilesSketchUnion final bool allocatesMemoryInArena() const override { return true; } private: + bool ignore_wrong_data = false; inline datasketches::quantiles_sketch readQuantilesSketch(ReadBuffer & buf) const { String d; diff --git a/src/AggregateFunctions/AggregateFunctionThetaSketchEstimate.cpp 
b/src/AggregateFunctions/AggregateFunctionThetaSketchEstimate.cpp new file mode 100644 index 00000000000..20125088e65 --- /dev/null +++ b/src/AggregateFunctions/AggregateFunctionThetaSketchEstimate.cpp @@ -0,0 +1,137 @@ +#include + +#include +#include + +#include +#include + +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} + +namespace +{ + + template + struct WithK + { + template + using AggregateFunction = AggregateFunctionThetaSketchEstimate; + }; + + template + AggregateFunctionPtr createAggregateFunctionWithK(const DataTypes & argument_types, const Array & params) + { + const IDataType & argument_type = *argument_types[0]; + WhichDataType which(argument_type); + + bool ignore_wrong_date = argument_types.size() == 2; + + if (which.isSketchBinary()) + { + return std::make_shared::template AggregateFunction>(argument_types, params, ignore_wrong_date); + } + else if (which.isAggregateFunction()) + { + return std::make_shared::template AggregateFunction>(argument_types, params, ignore_wrong_date); + } + else if (which.isString()) + { + return std::make_shared::template AggregateFunction>(argument_types, params, ignore_wrong_date); + } + else + { + throw Exception("Incorrect columns type for aggregate function: " + argument_type.getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + } + + + AggregateFunctionPtr createAggregateFunctionThetaSketchEstimate + (const std::string & name, const DataTypes & argument_types, const Array & params, const Settings *) + { + UInt8 precision = 15; + if (!params.empty()) + { + if (params.size() != 1) + { + throw Exception( + "Aggregate function " + name + " requires one parameter or less.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + } + + UInt64 precision_param = applyVisitor(FieldVisitorConvertToNumber(), params[0]); + // This range is hardcoded below + if (precision_param > 26 || precision_param < 5) + { + throw 
Exception( + "Parameter for aggregate function " + name + "is out or range: [5, 26].", ErrorCodes::ARGUMENT_OUT_OF_BOUND); + } + + precision = precision_param; + } + if (argument_types.size() != 1 && argument_types.size() != 2) + throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + switch (precision) + { + case 5: + return createAggregateFunctionWithK<5>(argument_types, params); + case 6: + return createAggregateFunctionWithK<6>(argument_types, params); + case 7: + return createAggregateFunctionWithK<7>(argument_types, params); + case 8: + return createAggregateFunctionWithK<8>(argument_types, params); + case 9: + return createAggregateFunctionWithK<9>(argument_types, params); + case 10: + return createAggregateFunctionWithK<10>(argument_types, params); + case 11: + return createAggregateFunctionWithK<11>(argument_types, params); + case 12: + return createAggregateFunctionWithK<12>(argument_types, params); + case 13: + return createAggregateFunctionWithK<13>(argument_types, params); + case 14: + return createAggregateFunctionWithK<14>(argument_types, params); + case 15: + return createAggregateFunctionWithK<15>(argument_types, params); + case 16: + return createAggregateFunctionWithK<16>(argument_types, params); + case 17: + return createAggregateFunctionWithK<17>(argument_types, params); + case 18: + return createAggregateFunctionWithK<18>(argument_types, params); + case 19: + return createAggregateFunctionWithK<19>(argument_types, params); + case 20: + return createAggregateFunctionWithK<20>(argument_types, params); + case 21: + return createAggregateFunctionWithK<21>(argument_types, params); + case 22: + return createAggregateFunctionWithK<22>(argument_types, params); + case 23: + return createAggregateFunctionWithK<23>(argument_types, params); + case 24: + return createAggregateFunctionWithK<24>(argument_types, params); + case 25: + return 
createAggregateFunctionWithK<25>(argument_types, params); + case 26: + return createAggregateFunctionWithK<26>(argument_types, params); + } + + __builtin_unreachable(); + } +} + +void registerAggregateFunctionThetaSketchEstimate(AggregateFunctionFactory & factory) +{ + factory.registerFunction("thetaSketchEstimate", createAggregateFunctionThetaSketchEstimate); +} +} diff --git a/src/AggregateFunctions/AggregateFunctionThetaSketchEstimate.h b/src/AggregateFunctions/AggregateFunctionThetaSketchEstimate.h new file mode 100644 index 00000000000..3341aded950 --- /dev/null +++ b/src/AggregateFunctions/AggregateFunctionThetaSketchEstimate.h @@ -0,0 +1,121 @@ +// +// Created by vita.lai on 2022/7/9. +// +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + + +namespace DB +{ +template +struct AggregateFunctionThetaSketchEstimateData +{ + datasketches::theta_union sk_union; + AggregateFunctionThetaSketchEstimateData():sk_union(datasketches::theta_union::builder().set_lg_k(K).build()){} + static String getName() { return "theta_sketch"; } +}; + +template +class AggregateFunctionThetaSketchEstimate final + : public IAggregateFunctionDataHelper, AggregateFunctionThetaSketchEstimate> +{ +public: + AggregateFunctionThetaSketchEstimate(const DataTypes & argument_types_, const Array & params_, bool ignore_wrong_data_ = false) + : IAggregateFunctionDataHelper, AggregateFunctionThetaSketchEstimate>(argument_types_, params_), ignore_wrong_data(ignore_wrong_data_) {} + + String getName() const override + { + return "thetaSketchEstimate"; + } + + DataTypePtr getReturnType() const override + { + return std::make_shared(); + } + + void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override + { + try + { + // String is for new datatype "Sketch" + if constexpr (std::is_same_v) + { + const auto & value = static_cast(*columns[0]).getDataAt(row_num); + if (ignore_wrong_data 
&& value.size == 0) + return; + // datasketches::compact_theta_sketch thetaSketch = datasketches::compact_theta_sketch::deserialize(value.data, value.size, datasketches::DEFAULT_SEED, AggregateFunctionThetaSketchAllocator()); + this->data(place).sk_union.update(datasketches::wrapped_compact_theta_sketch::wrap(value.data, value.size)); + } + else if constexpr (std::is_same_v) + { + //the format of this value should be the same with serialize + const auto & value = static_cast(*columns[0]).getDataAt(row_num); + if (ignore_wrong_data && value.size == 0) + return; + // ReadBuffer buf(const_cast(value.data), value.size); + // this->data(place).sk_union.update(readThetaSketch(buf)); + this->data(place).sk_union.update(datasketches::wrapped_compact_theta_sketch::wrap(value.data, value.size)); + } + else + { + StringRef value = columns[0]->getDataAt(row_num); + datasketches::update_theta_sketch sk_update = datasketches::update_theta_sketch::builder().build(); + sk_update.update(value.toString()); + this->data(place).sk_union.update(sk_update); + } + } + catch (std::exception & e) + { + if (!ignore_wrong_data) + throw e; + } + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override + { + this->data(place).sk_union.update(this->data(rhs).sk_union.get_result()); + } + + void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf) const override + { + std::ostringstream oss; + this->data(place).sk_union.get_result().serialize(oss); + writeBinary(oss.str(), buf); + } + + void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, Arena *) const override + { + this->data(place).sk_union.update(readThetaSketch(buf)); + } + + void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * ) const override + { + static_cast(to).getData().push_back(this->data(place).sk_union.get_result().get_estimate()); + } + + bool allocatesMemoryInArena() const override { return true; } + +private: + bool 
ignore_wrong_data = false; + inline datasketches::compact_theta_sketch readThetaSketch(ReadBuffer & buf) const + { + String d; + readBinary(d, buf); + return datasketches::compact_theta_sketch::deserialize(d.data(), d.size(), datasketches::DEFAULT_SEED, AggregateFunctionThetaSketchAllocator()); + } +}; +} diff --git a/src/AggregateFunctions/AggregateFunctionUniq.h b/src/AggregateFunctions/AggregateFunctionUniq.h index 3bb5d53f2f6..b4c3547b69e 100644 --- a/src/AggregateFunctions/AggregateFunctionUniq.h +++ b/src/AggregateFunctions/AggregateFunctionUniq.h @@ -413,10 +413,10 @@ class AggregateFunctionUniq final : public IAggregateFunctionDataHelper & is_cancelled, Arena *) const override { if constexpr (is_able_to_parallelize_merge) - this->data(place).set.merge(this->data(rhs).set, &thread_pool); + this->data(place).set.merge(this->data(rhs).set, &thread_pool, &is_cancelled); else this->data(place).set.merge(this->data(rhs).set); } @@ -514,10 +514,10 @@ class AggregateFunctionUniqVariadic final : public IAggregateFunctionDataHelper< bool isAbleToParallelizeMerge() const override { return is_able_to_parallelize_merge; } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, ThreadPool & thread_pool, Arena *) const override + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, ThreadPool & thread_pool, std::atomic & is_cancelled, Arena *) const override { if constexpr (is_able_to_parallelize_merge) - this->data(place).set.merge(this->data(rhs).set, &thread_pool); + this->data(place).set.merge(this->data(rhs).set, &thread_pool, &is_cancelled); else this->data(place).set.merge(this->data(rhs).set); } diff --git a/src/AggregateFunctions/IAggregateFunction.h b/src/AggregateFunctions/IAggregateFunction.h index d4aa34f1857..944575674a1 100644 --- a/src/AggregateFunctions/IAggregateFunction.h +++ b/src/AggregateFunctions/IAggregateFunction.h @@ -168,7 +168,7 @@ class IAggregateFunction : public std::enable_shared_from_this & 
/*is_cancelled*/, Arena * /*arena*/) const { throw Exception(ErrorCodes::NOT_IMPLEMENTED, "merge() with thread pool parameter isn't implemented for {} ", getName()); } diff --git a/src/AggregateFunctions/IAggregateFunctionMySql.h b/src/AggregateFunctions/IAggregateFunctionMySql.h index 625790d69cb..51e4f59bd0b 100644 --- a/src/AggregateFunctions/IAggregateFunctionMySql.h +++ b/src/AggregateFunctions/IAggregateFunctionMySql.h @@ -192,9 +192,9 @@ class IAggregateFunctionMySql : public IAggregateFunction function->merge(place, rhs, arena); } bool isAbleToParallelizeMerge() const override { return function->isAbleToParallelizeMerge(); } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, ThreadPool & thread_pool, Arena * arena) const override + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, ThreadPool & thread_pool, std::atomic & is_cancelled, Arena * arena) const override { - function->merge(place, rhs, thread_pool, arena); + function->merge(place, rhs, thread_pool, is_cancelled, arena); } void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf) const override { function->serialize(place, buf); } void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, Arena * arena) const override diff --git a/src/AggregateFunctions/UniqExactSet.h b/src/AggregateFunctions/UniqExactSet.h index cf05d54e541..33eac5eaf66 100644 --- a/src/AggregateFunctions/UniqExactSet.h +++ b/src/AggregateFunctions/UniqExactSet.h @@ -26,7 +26,7 @@ class UniqExactSet asTwoLevel().insert(std::forward(arg)); } - auto merge(const UniqExactSet & other, ThreadPool * thread_pool = nullptr) + auto merge(const UniqExactSet & other, ThreadPool * thread_pool = nullptr, std::atomic * is_cancelled = nullptr) { if (isSingleLevel() && other.isTwoLevel()) convertToTwoLevel(); @@ -49,7 +49,7 @@ class UniqExactSet { auto next_bucket_to_merge = std::make_shared(0); - auto thread_func = [&lhs, &rhs, next_bucket_to_merge, thread_group = 
CurrentThread::getGroup()]() + auto thread_func = [&lhs, &rhs, next_bucket_to_merge, is_cancelled, thread_group = CurrentThread::getGroup()]() { if (thread_group) CurrentThread::attachToIfDetached(thread_group); @@ -57,6 +57,8 @@ class UniqExactSet while (true) { + if (is_cancelled->load(std::memory_order_seq_cst)) + return; const auto bucket = next_bucket_to_merge->fetch_add(1); if (bucket >= rhs.NUM_BUCKETS) return; diff --git a/src/AggregateFunctions/registerAggregateFunctions.cpp b/src/AggregateFunctions/registerAggregateFunctions.cpp index 2e04caddb1e..838d11cdb4e 100644 --- a/src/AggregateFunctions/registerAggregateFunctions.cpp +++ b/src/AggregateFunctions/registerAggregateFunctions.cpp @@ -124,6 +124,7 @@ void registerAggregateFunctionNdvBuckets(AggregateFunctionFactory & factory); void registerAggregateFunctionNdvBucketsExtend(AggregateFunctionFactory & factory); void registerAggregateFunctionNothing(AggregateFunctionFactory & factory); void registerAggregateFunctionHllSketchEstimate(AggregateFunctionFactory &); +void registerAggregateFunctionThetaSketchEstimate(AggregateFunctionFactory &); void registerAggregateFunctionAuc(AggregateFunctionFactory &); void registerAggregateFunctionFastAuc(AggregateFunctionFactory &); void registerAggregateFunctionFastAuc2(AggregateFunctionFactory &); @@ -259,6 +260,7 @@ void registerAggregateFunctions() registerAggregateFunctionNdvBucketsExtend(factory); registerAggregateFunctionNothing(factory); registerAggregateFunctionHllSketchEstimate(factory); + registerAggregateFunctionThetaSketchEstimate(factory); registerAggregateFunctionAuc(factory); registerAggregateFunctionFastAuc(factory); registerAggregateFunctionFastAuc2(factory); diff --git a/src/Analyzers/ASTEquals.cpp b/src/Analyzers/ASTEquals.cpp index 29a40d52d38..b08fac5f859 100644 --- a/src/Analyzers/ASTEquals.cpp +++ b/src/Analyzers/ASTEquals.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include namespace DB::ASTEquality @@ -57,6 +58,13 @@ bool 
compareNode(const ASTWindowDefinition & left, const ASTWindowDefinition & r left.frame_end_preceding == right.frame_end_preceding; } +bool compareNode(const ASTClusterByElement & left, const ASTClusterByElement & right) +{ + return left.split_number == right.split_number && + left.is_with_range == right.is_with_range && + left.is_user_defined_expression == right.is_user_defined_expression; +} + bool compareNode(const ASTSubquery & left, const ASTSubquery & right) { return left.cte_name == right.cte_name && left.database_of_view == right.database_of_view; @@ -144,6 +152,9 @@ bool compareTree(const ASTPtr & left, const ASTPtr & right, const SubtreeCompara case ASTType::ASTTableIdentifier: node_equals = compareNode(left->as(), right->as()); break; + case ASTType::ASTClusterByElement: + node_equals = compareNode(left->as(), right->as()); + break; default: node_equals = left->getID() == right->getID(); // align with ScopeAwareHash } diff --git a/src/Analyzers/Analysis.cpp b/src/Analyzers/Analysis.cpp index 08251361410..56218553d2d 100644 --- a/src/Analyzers/Analysis.cpp +++ b/src/Analyzers/Analysis.cpp @@ -15,6 +15,10 @@ #include +#include +#include +#include + namespace DB { @@ -32,6 +36,11 @@ namespace DB throw Exception("Object already exists in " #container, ErrorCodes::LOGICAL_ERROR); \ } while(false) \ +namespace ErrorCodes +{ + extern const int INCORRECT_RESULT_OF_SCALAR_SUBQUERY; +} + void Analysis::setScope(IAST & statement, ScopePtr scope) { LOG_TRACE(logger, "scope of ast {}: {}", statement.dumpTree(0), scope->toString()); @@ -344,6 +353,14 @@ FieldDescriptions & Analysis::getOutputDescription(IAST & ast) MAP_GET(output_descriptions, &ast); } +bool Analysis::hasOutputDescription(IAST & ast) +{ + if (auto * subquery = ast.as()) + return hasOutputDescription(*subquery->children[0]); + + return output_descriptions.contains(&ast); +} + void Analysis::setRegisteredWindow(ASTSelectQuery & select_query, const String & name, ResolvedWindowPtr & window) { 
MAP_SET(registered_windows[&select_query], name, window); @@ -428,4 +445,143 @@ void Analysis::addUsedFunctionArgument(const String & func_name, ColumnsWithType function_arguments[func_name].emplace_back((*arg.column)[0].toString()); } } + +const Block & Analysis::getScalarSubqueryResult(const ASTPtr & subquery, ContextPtr context) +{ + auto hash = subquery->getTreeHash(); + String hash_str = toString(hash.first) + "_" + toString(hash.second); + + if (!executed_scalar_subqueries.count(hash_str)) + { + // ContextMutablePtr subquery_context = Context::createCopy(context); + // Settings subquery_settings = context->getSettings(); + // subquery_settings.max_result_rows = 1; + // subquery_settings.extremes = false; + // subquery_context->setSettings(subquery_settings); + SelectQueryOptions subquery_options; + auto & ast_subquery = subquery->as(); + auto & inner_query = ast_subquery.children.front(); + if (!inner_query->as()) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Unrecognized query type '{}' when executing subquery {}", + inner_query->getID(), + inner_query->formatForErrorMessage()); + InterpreterSelectQueryUseOptimizer interpreter{inner_query, std::const_pointer_cast(context), subquery_options}; + auto io = interpreter.execute(); + PullingAsyncPipelineExecutor executor(io.pipeline); + Block block; + + while (block.rows() == 0 && executor.pull(block)) + ; + + if (block.rows() > 1) + throw Exception( + ErrorCodes::INCORRECT_RESULT_OF_SCALAR_SUBQUERY, + "Scalar subquery returned more than one row: {}", + subquery->formatForErrorMessage()); + + if (block.rows() == 0) + { + auto types = interpreter.getSampleBlock().getDataTypes(); + if (types.size() != 1) + types = {std::make_shared(types)}; + + auto & type = types[0]; + if (!type->isNullable()) + { + if (!type->canBeInsideNullable()) + throw Exception( + ErrorCodes::INCORRECT_RESULT_OF_SCALAR_SUBQUERY, + "Scalar subquery returned empty result of type {} which cannot be Nullable", + type->getName()); + + type = 
makeNullable(type); + } + + auto null_column = type->createColumn(); + null_column->insert(Null{}); + block.clear(); + block.insert(ColumnWithTypeAndName{ColumnPtr{std::move(null_column)}, type, ""}); + } + else + { + block = materializeBlock(block); + size_t columns = block.columns(); + + if (columns == 1) + { + auto & column = block.getByPosition(0); + /// Here we wrap type to nullable if we can. + /// It is needed cause if subquery return no rows, it's result will be Null. + /// In case of many columns, do not check it cause tuple can't be nullable. + if (!column.type->isNullable() && column.type->canBeInsideNullable()) + { + column.type = makeNullable(column.type); + column.column = makeNullable(column.column); + } + } + else + { + ColumnWithTypeAndName ctn; + ctn.type = std::make_shared(block.getDataTypes()); + ctn.column = ColumnTuple::create(block.getColumns()); + block = Block{ctn}; + } + + Block tmp_block; + while (tmp_block.rows() == 0 && executor.pull(tmp_block)) + ; + + if (tmp_block.rows() > 0) + throw Exception( + ErrorCodes::INCORRECT_RESULT_OF_SCALAR_SUBQUERY, + "Scalar subquery returned more than one row: {}", + subquery->formatForErrorMessage()); + } + + executed_scalar_subqueries.emplace(hash_str, std::move(block)); + } + + return executed_scalar_subqueries.at(hash_str); +} + +SetPtr Analysis::getInSubqueryResult(const ASTPtr & subquery, ContextPtr context) +{ + auto hash = subquery->getTreeHash(); + String hash_str = toString(hash.first) + "_" + toString(hash.second); + + if (!executed_in_subqueries.count(hash_str)) + { + // ContextMutablePtr subquery_context = Context::createCopy(context); + SelectQueryOptions subquery_options; + auto & ast_subquery = subquery->as(); + auto & inner_query = ast_subquery.children.front(); + if (!inner_query->as()) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Unrecognized query type '{}' when executing subquery {}", + inner_query->getID(), + inner_query->formatForErrorMessage()); + 
InterpreterSelectQueryUseOptimizer interpreter{inner_query, std::const_pointer_cast(context), subquery_options}; + BlockIO io = interpreter.execute(); + PullingAsyncPipelineExecutor executor(io.pipeline); + SizeLimits limites(context->getSettingsRef().max_rows_in_set, context->getSettingsRef().max_bytes_in_set, OverflowMode::THROW); + SetPtr set = std::make_shared(limites, true, context->getSettingsRef().transform_null_in); + set->setHeader(interpreter.getSampleBlock()); + Block block; + + while (executor.pull(block)) + { + if (block.rows() == 0) + continue; + set->insertFromBlock(block); + } + + set->finishInsert(); + executed_in_subqueries.emplace(hash_str, set); + } + + return executed_in_subqueries.at(hash_str); +} } diff --git a/src/Analyzers/Analysis.h b/src/Analyzers/Analysis.h index 81bfe451182..cbe3fd6184f 100644 --- a/src/Analyzers/Analysis.h +++ b/src/Analyzers/Analysis.h @@ -20,6 +20,8 @@ #include #include #include +#include +#include #include #include #include @@ -68,12 +70,14 @@ struct JoinEqualityCondition ASTPtr right_ast; DataTypePtr left_coercion; DataTypePtr right_coercion; + bool null_safe; - JoinEqualityCondition(ASTPtr left_ast_, ASTPtr right_ast_, DataTypePtr left_coercion_, DataTypePtr right_coercion_) + JoinEqualityCondition(ASTPtr left_ast_, ASTPtr right_ast_, DataTypePtr left_coercion_, DataTypePtr right_coercion_, bool null_safe_) : left_ast(std::move(left_ast_)) , right_ast(std::move(right_ast_)) , left_coercion(std::move(left_coercion_)) , right_coercion(std::move(right_coercion_)) + , null_safe(null_safe_) {} }; @@ -401,6 +405,7 @@ struct Analysis std::unordered_map output_descriptions; void setOutputDescription(IAST & ast, const FieldDescriptions & field_descs); FieldDescriptions & getOutputDescription(IAST & ast); + bool hasOutputDescription(IAST & ast); /// Sub column optimization std::unordered_map sub_column_references; @@ -467,6 +472,12 @@ struct Analysis std::unordered_map> function_arguments; void addUsedFunctionArgument(const 
String & func_name, ColumnsWithTypeAndName & processed_arguments); + + std::map executed_scalar_subqueries; + const Block & getScalarSubqueryResult(const ASTPtr & subquery, ContextPtr context); + + std::map executed_in_subqueries; + SetPtr getInSubqueryResult(const ASTPtr & subquery, ContextPtr context); }; } diff --git a/src/Analyzers/ExprAnalyzer.cpp b/src/Analyzers/ExprAnalyzer.cpp index d9e377bcea2..8b52222d60a 100644 --- a/src/Analyzers/ExprAnalyzer.cpp +++ b/src/Analyzers/ExprAnalyzer.cpp @@ -37,6 +37,7 @@ #include #include #include +#include #include #include #include @@ -50,6 +51,7 @@ #include #include #include +#include "Core/Field.h" #include @@ -164,7 +166,7 @@ class ExprAnalyzerVisitor : public ASTVisitorgetSettingsRef().early_execute_scalar_subquery || (isInLambda() && context->getSettingsRef().execute_subquery_in_lambda); + + if (early_execute_subquery) + { + const auto & block = analysis.getScalarSubqueryResult(node, context); + const auto & col_with_type = block.getByPosition(0); + auto lit = std::make_shared(col_with_type.column->operator[](0)); + lit->alias = node->tryGetAlias(); + lit->prefer_alias_to_column_name = node->as().prefer_alias_to_column_name; + node = addTypeConversionToAST(std::move(lit), col_with_type.type->getName()); + return {col_with_type.column, col_with_type.type, ""}; + } + + if (isInLambda()) + throw Exception("Subquery in lambda is not supported by ApplyStep", ErrorCodes::SYNTAX_ERROR); // when a scalar subquery has 0 rows, it returns NULL, hence we change its type to Nullable type // note that this feature is not compatible with subquery with multiple output returning Tuple type @@ -678,6 +697,10 @@ ColumnWithTypeAndName ExprAnalyzerVisitor::analyzeExistsSubquery(ASTFunctionPtr { handleSubquery(function->arguments->children[0], false); } + + if (isInLambda()) + throw Exception("Subquery in lambda is not supported by ApplyStep", ErrorCodes::SYNTAX_ERROR); + 
analysis.exists_subqueries[options.select_query].push_back(function); analysis.subquery_support_semi_anti[function] = ac.only_and; return {nullptr, std::make_shared(), function->getColumnName()}; @@ -688,12 +711,72 @@ void ExprAnalyzerVisitor::processSubqueryArgsWithCoercion(ASTPtr & lhs_ast, ASTP AnalyzeContext ac{.only_and = false}; auto lhs_type = process(lhs_ast, ac).type; auto rhs_type = handleSubquery(rhs_ast, false); + // TODO: we should only execute uncorrelated subqueries + // TODO: handle type mismatch + bool early_execute_subquery + = context->getSettingsRef().early_execute_in_subquery || (isInLambda() && context->getSettingsRef().execute_subquery_in_lambda); + + if (early_execute_subquery) + { + auto set = analysis.getInSubqueryResult(rhs_ast, context); + auto set_columns = set->getSetElements(); + Tuple coll; + for (size_t row_id = 0; row_id < set_columns[0]->size(); ++row_id) + { + if (set_columns.size() == 1) + { + coll.push_back(set_columns[0]->operator[](row_id)); + } + else + { + Tuple nested_tuple; + for (const auto & set_column : set_columns) + nested_tuple.push_back(set_column->operator[](row_id)); + coll.push_back(std::move(nested_tuple)); + } + } + // TODO: better handle empty set + if (coll.empty()) + { + if (set_columns.size() == 1) + { + coll.push_back(Null{}); + } + else + { + Tuple nested_tuple(set_columns.size(), Field()); + coll.push_back(nested_tuple); + } + } + + rhs_ast = std::make_shared(std::move(coll)); + return; + } + + if (isInLambda()) + throw Exception("Subquery in lambda is not supported by ApplyStep", ErrorCodes::SYNTAX_ERROR); if (!JoinCommon::isJoinCompatibleTypes(lhs_type, rhs_type)) { DataTypePtr super_type = nullptr; if (enable_implicit_type_conversion) - super_type = getLeastSupertype(DataTypes{lhs_type, rhs_type}, allow_extended_conversion); + { + if (context->getSettingsRef().convert_to_right_type_for_in_subquery) + { + if (const auto * type_tuple = typeid_cast(rhs_type.get())) + { + DataTypes elem_types = 
type_tuple->getElements(); + std::transform(elem_types.begin(), elem_types.end(), elem_types.begin(), &JoinCommon::convertTypeToNullable); + super_type = std::make_shared(elem_types, type_tuple->getElementNames()); + } + else + { + super_type = JoinCommon::convertTypeToNullable(rhs_type); + } + } + else + super_type = getLeastSupertype(DataTypes{lhs_type, rhs_type}, allow_extended_conversion); + } if (!super_type) throw Exception("Incompatible types for IN prediacte", ErrorCodes::TYPE_MISMATCH); if (!lhs_type->equals(*super_type)) @@ -831,9 +914,6 @@ DataTypePtr ExprAnalyzerVisitor::handleSubquery(const ASTPtr & subquery, bool us if (!options.select_query) throw Exception("Provide query node if subquery is allowed", ErrorCodes::LOGICAL_ERROR); - if (isInLambda()) - throw Exception("Subquery is not support in lambda", ErrorCodes::SYNTAX_ERROR); - QueryAnalyzer::analyze(subquery->children[0], currentScope(), context, analysis); auto & output_columns = analysis.getOutputDescription(*subquery); diff --git a/src/Analyzers/ExprAnalyzer.h b/src/Analyzers/ExprAnalyzer.h index 299529263f2..b0ae7052f65 100644 --- a/src/Analyzers/ExprAnalyzer.h +++ b/src/Analyzers/ExprAnalyzer.h @@ -136,7 +136,7 @@ class ExprAnalyzer * !! 
Caution: expression will be modified in these cases: `untuple` function */ static DataTypePtr analyze( - ASTPtr expression, ScopePtr scope, ContextPtr context, Analysis & analysis, ExprAnalyzerOptions options = ExprAnalyzerOptions{}); + ASTPtr & expression, ScopePtr scope, ContextPtr context, Analysis & analysis, ExprAnalyzerOptions options = ExprAnalyzerOptions{}); }; } diff --git a/src/Analyzers/ExpressionVisitor.h b/src/Analyzers/ExpressionVisitor.h index 264faf7bc5c..c4f92c937d7 100644 --- a/src/Analyzers/ExpressionVisitor.h +++ b/src/Analyzers/ExpressionVisitor.h @@ -167,7 +167,7 @@ class ExpressionTraversalIncludeSubqueryVisitor: public AnalyzerExpressionVisito { process(ast.children[0]); auto subquery = ast.children[1]; - process(subquery->children[0]); + process(subquery->as().children[0]); } public: diff --git a/src/Analyzers/QueryAnalyzer.cpp b/src/Analyzers/QueryAnalyzer.cpp index 6933b9188a8..e3a99f34d0a 100644 --- a/src/Analyzers/QueryAnalyzer.cpp +++ b/src/Analyzers/QueryAnalyzer.cpp @@ -133,6 +133,7 @@ class QueryAnalyzerVisitor : public ASTVisitor , enable_implicit_type_conversion(context->getSettingsRef().enable_implicit_type_conversion) , allow_extended_conversion(context->getSettingsRef().allow_extended_type_conversion) , enable_subcolumn_optimization_through_union(context->getSettingsRef().enable_subcolumn_optimization_through_union) + , enable_implicit_arg_type_convert(context->getSettingsRef().enable_implicit_arg_type_convert) { } @@ -145,6 +146,7 @@ class QueryAnalyzerVisitor : public ASTVisitor const bool enable_implicit_type_conversion; const bool allow_extended_conversion; const bool enable_subcolumn_optimization_through_union; + const bool enable_implicit_arg_type_convert; // MySQL implicit cast rules Poco::Logger * logger = &Poco::Logger::get("QueryAnalyzerVisitor"); @@ -204,6 +206,7 @@ class QueryAnalyzerVisitor : public ASTVisitor void rewriteSelectInANSIMode(ASTSelectQuery & select_query, const Aliases & aliases, const NameSet & 
source_columns_set); void normalizeAliases(ASTPtr & expr, ASTPtr & aliases_expr); void normalizeAliases(ASTPtr & expr, const Aliases & aliases, const NameSet & source_columns_set); + DataTypePtr getCommonType(const DataTypes & types); }; static NameSet collectNames(ScopePtr scope); @@ -448,10 +451,7 @@ void QueryAnalyzerVisitor::analyzeSetOperation(ASTPtr & node, ASTs & selects) DataTypePtr output_type; // promote output type to super type if necessary - if (context->getSettingsRef().enable_implicit_arg_type_convert) - output_type = getLeastSupertype(elem_types, true); - else - output_type = getLeastSupertype(elem_types, allow_extended_conversion); + output_type = getCommonType(elem_types); output_desc.emplace_back( first_input_desc[column_idx].name, output_type, @@ -869,6 +869,9 @@ ScopePtr QueryAnalyzerVisitor::analyzeJoinUsing( NameSet seen_names; FieldDescriptions output_fields; + bool make_nullable_for_left = isRightOrFull(table_join.kind) && context->getSettingsRef().join_use_nulls; + bool make_nullable_for_right = isLeftOrFull(table_join.kind) && context->getSettingsRef().join_use_nulls; + if (use_ansi_semantic) { auto resolve_join_key @@ -916,7 +919,7 @@ ScopePtr QueryAnalyzerVisitor::analyzeJoinUsing( { try { - output_type = getLeastSupertype(DataTypes{left_type, right_type}, allow_extended_conversion); + output_type = getCommonType(DataTypes{left_type, right_type}); } catch (DB::Exception & ex) { @@ -936,18 +939,22 @@ ScopePtr QueryAnalyzerVisitor::analyzeJoinUsing( } /// Step 2. 
add non join fields - auto add_non_join_fields = [&](ScopePtr scope, std::vector & join_fields_list) { + auto add_non_join_fields = [&](ScopePtr scope, std::vector & join_fields_list, bool make_nullable) { std::unordered_set join_fields{join_fields_list.begin(), join_fields_list.end()}; for (size_t i = 0; i < scope->size(); ++i) { if (join_fields.find(i) == join_fields.end()) + { output_fields.push_back(scope->at(i)); + if (make_nullable) + output_fields.back().type = JoinCommon::convertTypeToNullable(output_fields.back().type); + } } }; - add_non_join_fields(left_scope, left_join_fields); - add_non_join_fields(right_scope, right_join_fields); + add_non_join_fields(left_scope, left_join_fields, make_nullable_for_left); + add_non_join_fields(right_scope, right_join_fields, make_nullable_for_right); } else { @@ -956,7 +963,7 @@ ScopePtr QueryAnalyzerVisitor::analyzeJoinUsing( /// Step 1. resolve join key for (size_t i = 0, true_index = 0; i < expr_list.size(); ++i) { - const auto & join_key_ast = expr_list[i]; + auto & join_key_ast = expr_list[i]; String key_name = join_key_ast->getAliasOrColumnName(); // see also 00702_join_with_using: @@ -1007,7 +1014,7 @@ ScopePtr QueryAnalyzerVisitor::analyzeJoinUsing( { try { - output_type = getLeastSupertype(DataTypes{left_type, right_type}, allow_extended_conversion); + output_type = getCommonType(DataTypes{left_type, right_type}); } catch (DB::Exception & ex) { @@ -1042,6 +1049,8 @@ ScopePtr QueryAnalyzerVisitor::analyzeJoinUsing( else { output_fields.emplace_back(input_field); + if (make_nullable_for_left) + output_fields.back().type = JoinCommon::convertTypeToNullable(output_fields.back().type); } } @@ -1063,6 +1072,8 @@ ScopePtr QueryAnalyzerVisitor::analyzeJoinUsing( if (!right_join_field_reverse_map.count(i)) { output_fields.emplace_back(input_field.withNewName(new_name)); + if (make_nullable_for_right) + output_fields.back().type = JoinCommon::convertTypeToNullable(output_fields.back().type); } else if 
(required_columns.count(new_name) && !source_columns.count(new_name)) { @@ -1096,18 +1107,35 @@ ScopePtr QueryAnalyzerVisitor::analyzeJoinOn( ScopePtr output_scope; { FieldDescriptions output_fields; + bool make_nullable_for_left = isRightOrFull(table_join.kind) && context->getSettingsRef().join_use_nulls; + bool make_nullable_for_right = isLeftOrFull(table_join.kind) && context->getSettingsRef().join_use_nulls; + auto update_type = [&](DataTypePtr & type, bool make_nullable) + { + if (make_nullable) + return JoinCommon::convertTypeToNullable(type); + return type; + }; if (use_ansi_semantic) { for (const auto & f : left_scope->getFields()) + { output_fields.emplace_back(f); + output_fields.back().type = update_type(output_fields.back().type, make_nullable_for_left); + } for (const auto & f : right_scope->getFields()) + { output_fields.emplace_back(f); + output_fields.back().type = update_type(output_fields.back().type, make_nullable_for_right); + } } else { for (const auto & f : left_scope->getFields()) + { output_fields.emplace_back(f); + output_fields.back().type = update_type(output_fields.back().type, make_nullable_for_left); + } auto source_names = collectNames(left_scope); bool check_identifier_begin_valid = context->getSettingsRef().check_identifier_begin_valid; @@ -1116,6 +1144,7 @@ ScopePtr QueryAnalyzerVisitor::analyzeJoinOn( { auto new_name = qualifyJoinedName(f.name, right_table_qualifier, source_names, check_identifier_begin_valid); output_fields.emplace_back(f.withNewName(new_name)); + output_fields.back().type = update_type(output_fields.back().type, make_nullable_for_right); } } @@ -1191,7 +1220,7 @@ ScopePtr QueryAnalyzerVisitor::analyzeJoinOn( DataTypePtr right_coercion = nullptr; // for non-ASOF join, inequality_conditions will be included in join filter, so don't have to do type coercion - if (func->name == "equals" || isAsofJoin(table_join)) + if ((func->name == "equals" || func->name == "bitEquals") || isAsofJoin(table_join)) { DataTypePtr 
left_type = analysis.getExpressionType(left_ast); DataTypePtr right_type = analysis.getExpressionType(right_ast); @@ -1204,7 +1233,7 @@ ScopePtr QueryAnalyzerVisitor::analyzeJoinOn( { try { - super_type = getLeastSupertype(DataTypes{left_type, right_type}, allow_extended_conversion); + super_type = getCommonType(DataTypes{left_type, right_type}); } catch (DB::Exception & ex) { @@ -1226,8 +1255,8 @@ ScopePtr QueryAnalyzerVisitor::analyzeJoinOn( } } - if (func->name == "equals") - equality_conditions.emplace_back(left_ast, right_ast, left_coercion, right_coercion); + if (func->name == "equals" || func->name == "bitEquals") + equality_conditions.emplace_back(left_ast, right_ast, left_coercion, right_coercion, func->name == "bitEquals"); else inequality_conditions.emplace_back(left_ast, right_ast, inequality, left_coercion, right_coercion); @@ -1275,7 +1304,7 @@ ScopePtr QueryAnalyzerVisitor::analyzeArrayJoin(ASTArrayJoin & array_join, ASTSe ArrayJoinDescriptions array_join_descs; FieldDescriptions output_fields = source_scope->getFields(); NameSet name_set; - for (const auto & array_join_expr : array_join_expression_list->children) + for (auto & array_join_expr : array_join_expression_list->children) { if (array_join_expr->tryGetAlias() == array_join_expr->getColumnName() && !array_join_expr->as()) throw Exception("No alias for non-trivial value in ARRAY JOIN: " + array_join_expr->tryGetAlias(), ErrorCodes::ALIAS_REQUIRED); @@ -1299,11 +1328,15 @@ ScopePtr QueryAnalyzerVisitor::analyzeArrayJoin(ASTArrayJoin & array_join, ASTSe ArrayJoinDescription array_join_desc; array_join_desc.expr = array_join_expr; - if (col_ref && array_join_expr->tryGetAlias().empty()) + if (col_ref && array_join_expr->tryGetAlias().empty()) // ARRAY JOIN `arr` { output_fields[col_ref->local_index] = FieldDescription{output_fields[col_ref->local_index].name, array_type->getNestedType()}; } - else + else if (col_ref && !array_join_expr->tryGetAlias().empty() && array_join_expr->tryGetAlias() == 
output_fields[col_ref->local_index].name) // ARRAY JOIN `arr` as `arr` + { + output_fields[col_ref->local_index] = FieldDescription{output_fields[col_ref->local_index].name, array_type->getNestedType()}; + } + else // ARRAY JOIN `arr` as `arr2` { array_join_desc.create_new_field = true; output_fields.emplace_back(output_name, array_type->getNestedType()); @@ -1375,7 +1408,7 @@ void QueryAnalyzerVisitor::analyzeWhere(ASTSelectQuery & select_query, ScopePtr ExprAnalyzerOptions expr_options{"WHERE expression"}; expr_options.selectQuery(select_query).subquerySupport(ExprAnalyzerOptions::SubquerySupport::CORRELATED).subqueryToSemiAnti(true); - auto filter_type = ExprAnalyzer::analyze(select_query.where(), source_scope, context, analysis, expr_options); + auto filter_type = ExprAnalyzer::analyze(select_query.refWhere(), source_scope, context, analysis, expr_options); if (auto inner_type = removeNullable(removeLowCardinality(filter_type))) { @@ -1400,7 +1433,7 @@ ASTs QueryAnalyzerVisitor::analyzeSelect(ASTSelectQuery & select_query, ScopePtr .aggregateSupport(ExprAnalyzerOptions::AggregateSupport::ALLOWED) .windowSupport(ExprAnalyzerOptions::WindowSupport::ALLOWED); - auto add_select_expression = [&](const ASTPtr & expression) { + auto add_select_expression = [&](ASTPtr & expression) { auto expression_type = ExprAnalyzer::analyze(expression, source_scope, context, analysis, expr_options); auto get_output_name = [&](const ASTPtr & expr) -> String { @@ -1448,7 +1481,7 @@ ASTs QueryAnalyzerVisitor::analyzeSelect(ASTSelectQuery & select_query, ScopePtr { if (source_scope->at(field_index).substituted_by_asterisk) { - auto field_reference = std::make_shared(field_index); + ASTPtr field_reference = std::make_shared(field_index); add_select_expression(field_reference); } } @@ -1467,7 +1500,7 @@ ASTs QueryAnalyzerVisitor::analyzeSelect(ASTSelectQuery & select_query, ScopePtr if (source_scope->at(field_index).substituted_by_asterisk && 
source_scope->at(field_index).prefix.hasSuffix(prefix)) { matched = true; - auto field_reference = std::make_shared(field_index); + ASTPtr field_reference = std::make_shared(field_index); add_select_expression(field_reference); } } @@ -1481,7 +1514,7 @@ ASTs QueryAnalyzerVisitor::analyzeSelect(ASTSelectQuery & select_query, ScopePtr if (source_scope->at(field_index).substituted_by_asterisk && asterisk_pattern->isColumnMatching(source_scope->at(field_index).name)) { - auto field_reference = std::make_shared(field_index); + ASTPtr field_reference = std::make_shared(field_index); add_select_expression(field_reference); } } @@ -1506,7 +1539,7 @@ ASTs QueryAnalyzerVisitor::analyzeSelect(ASTSelectQuery & select_query, ScopePtr { auto tuple_ast = function->arguments->children[0]; auto literal = std::make_shared(UInt64(++tid)); - auto func = makeASTFunction("tupleElement", tuple_ast, literal); + ASTPtr func = makeASTFunction("tupleElement", tuple_ast, literal); add_select_expression(func); } } @@ -1648,7 +1681,7 @@ void QueryAnalyzerVisitor::analyzeHaving(ASTSelectQuery & select_query, ScopePtr expr_options.selectQuery(select_query) .subquerySupport(ExprAnalyzerOptions::SubquerySupport::CORRELATED) .aggregateSupport(ExprAnalyzerOptions::AggregateSupport::ALLOWED); - ExprAnalyzer::analyze(select_query.having(), source_scope, context, analysis, expr_options); + ExprAnalyzer::analyze(select_query.refHaving(), source_scope, context, analysis, expr_options); } void QueryAnalyzerVisitor::analyzeOrderBy(ASTSelectQuery & select_query, ASTs & select_expressions, ScopePtr output_scope) @@ -2152,6 +2185,14 @@ void QueryAnalyzerVisitor::normalizeAliases(ASTPtr & expr, const Aliases & alias QueryNormalizer(normalizer_data).visit(expr); } +DataTypePtr QueryAnalyzerVisitor::getCommonType(const DataTypes & types) +{ + if (enable_implicit_arg_type_convert) + return getLeastSupertype(types, true); + else + return getLeastSupertype(types, allow_extended_conversion); +} + NameSet 
collectNames(ScopePtr scope) { NameSet result; diff --git a/src/Analyzers/QueryRewriter.cpp b/src/Analyzers/QueryRewriter.cpp index cf5040d1ec8..ceff262f74d 100644 --- a/src/Analyzers/QueryRewriter.cpp +++ b/src/Analyzers/QueryRewriter.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -48,6 +49,7 @@ #include #include +#include "Core/SettingsEnums.h" namespace DB { @@ -195,6 +197,8 @@ namespace { if (context->getSettingsRef().rewrite_like_function) SimpleFunctionVisitor().visit(query); + if (context->getSettingsRef().dialect_type != DialectType::CLICKHOUSE && !context->getSettingsRef().only_full_group_by) + SubstituteSelectItemToAnyFunction(context).visit(query); } void expandView(ASTPtr & query, ContextMutablePtr context, int & graphviz_index) diff --git a/src/Analyzers/ReplaceViewWithSubqueryVisitor.h b/src/Analyzers/ReplaceViewWithSubqueryVisitor.h index 3cc3249fb15..e6091d26c0d 100644 --- a/src/Analyzers/ReplaceViewWithSubqueryVisitor.h +++ b/src/Analyzers/ReplaceViewWithSubqueryVisitor.h @@ -19,10 +19,17 @@ #include #include #include "Parsers/ASTIdentifier.h" +#include +#include namespace DB { +namespace ErrorCodes +{ + extern const int ACCESS_DENIED; +} + struct ReplaceViewWithSubquery { using TypeToVisit = ASTTableExpression; @@ -47,6 +54,19 @@ struct ReplaceViewWithSubquery if (dynamic_cast(table.get())) { auto table_metadata_snapshot = table->getInMemoryMetadataPtr(); + { + // check access rights. + auto access = context->getAccess(); + if (!access->isGranted(AccessType::SELECT, database_name, table_name)) + { + throw Exception( + ErrorCodes::ACCESS_DENIED, + "{}: Not enough privileges. 
To execute this query it's necessary to have grant SELECT on {}", + context->getUserName(), + table->getStorageID().getFullTableName()); + } + } + auto subquery = table_metadata_snapshot->getSelectQuery().inner_query->clone(); const auto alias = table_expression.database_and_table_name->tryGetAlias(); table_expression.database_and_table_name = {}; diff --git a/src/Analyzers/RewriteFusionMerge.cpp b/src/Analyzers/RewriteFusionMerge.cpp index 4021aea448e..8f1b053a1b9 100644 --- a/src/Analyzers/RewriteFusionMerge.cpp +++ b/src/Analyzers/RewriteFusionMerge.cpp @@ -88,12 +88,12 @@ namespace UInt64 field_time = timestamp.safeGet(); if (field_time > mills_test) { - String date = DateLUT::instance().dateToString(field_time / 1000); + String date = DateLUT::serverTimezoneInstance().dateToString(field_time / 1000); return std::make_shared(Field(date)); } else { - String date = DateLUT::instance().dateToString(field_time); + String date = DateLUT::serverTimezoneInstance().dateToString(field_time); return std::make_shared(Field(date)); } } diff --git a/src/Analyzers/SimpleFunctionVisitor.cpp b/src/Analyzers/SimpleFunctionVisitor.cpp index 47da564ef8f..0aaa351f043 100644 --- a/src/Analyzers/SimpleFunctionVisitor.cpp +++ b/src/Analyzers/SimpleFunctionVisitor.cpp @@ -3,6 +3,7 @@ #include #include #include +#include "common/types.h" namespace DB @@ -27,7 +28,11 @@ void SimpleFunctionVisitor::visit(ASTFunction * func) if ((func->name == "like" || func->name == "notLike") && func->arguments->children.size() == 2 && func->arguments->children[1]->as()) { - Field converted = convertFieldToType(func->arguments->children[1]->as()->value, DataTypeString()); + auto & pattern = func->arguments->children[1]->as()->value; + if (pattern.getType() != Field::Types::String) + return; + + Field converted = convertFieldToType(pattern, DataTypeString()); String text = converted.safeGet(); for (auto & s : text) diff --git a/src/Analyzers/SubstituteSelectItemToAnyFunction.cpp 
b/src/Analyzers/SubstituteSelectItemToAnyFunction.cpp new file mode 100644 index 00000000000..4ff06e59cd5 --- /dev/null +++ b/src/Analyzers/SubstituteSelectItemToAnyFunction.cpp @@ -0,0 +1,258 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +void SubstituteSelectItemToAnyFunction::visit(ASTPtr & ast) +{ + if (auto * select = ast->as()) + visit(select); + + for (auto & child : ast->children) + visit(child); +} + +static bool needVisit(const ASTPtr & node) +{ + return !(node->as() || node->as() || node->as() || node->as()); +} + +bool SubstituteSelectItemToAnyFunction::hasAggregate(ASTPtr & ast) +{ + if (auto * func = ast->as()) + { + auto function_type = getFunctionType(*func, context); + if (function_type == FunctionType::AGGREGATE_FUNCTION || function_type == FunctionType::GROUPING_OPERATION) + return true; + if (function_type == FunctionType::IN_SUBQUERY || function_type == FunctionType::EXISTS_SUBQUERY + || function_type == FunctionType::LAMBDA_EXPRESSION || function_type == FunctionType::WINDOW_FUNCTION) + return false; + } + if (!needVisit(ast)) + return false; + bool has_aggregate = false; + for (auto & child : ast->children) + has_aggregate |= hasAggregate(child); + return has_aggregate; +} + +void SubstituteSelectItemToAnyFunction::visit(ASTSelectQuery * select_query) +{ + bool has_aggregate = false; + bool has_group_by = false; + + if (select_query->groupBy()) + has_group_by = true; + ASTs select_expressions; + for (auto & select_item : select_query->refSelect()->children) + { + if (select_item->as() || select_item->as() || select_item->as() + || (select_item->as() && select_item->as()->name == "untuple")) + return; + has_aggregate |= hasAggregate(select_item); + + if (has_group_by) + select_expressions.push_back(select_item); + } + + std::unordered_set grouping_names; + QualifiedNames 
grouping_qualified_names; + if (has_group_by) + { + // get grouping + bool allow_group_by_position = context->getSettingsRef().enable_positional_arguments && !select_query->group_by_with_rollup + && !select_query->group_by_with_cube && !select_query->group_by_with_grouping_sets; + auto get_grouping_expressions = [&](ASTs & grouping_expr_list) { + for (ASTPtr grouping_expr : grouping_expr_list) + { + if (allow_group_by_position) + if (auto * literal = grouping_expr->as(); literal && literal->tryGetAlias().empty() + && // avoid aliased expr being interpreted as positional argument + // e.g. SELECT 1 AS a ORDER BY a + literal->value.getType() == Field::Types::UInt64) + { + auto index = literal->value.get(); + if (index > select_expressions.size() || index < 1) + return; + grouping_expr = select_expressions[index - 1]; + } + + if (auto * grouping_identifier = grouping_expr->as()) + { + auto grouping_prefix = QualifiedName::extractQualifiedName(*grouping_identifier); + grouping_qualified_names.emplace_back(grouping_prefix); + } + grouping_names.emplace(grouping_expr->getAliasOrColumnName()); + } + }; + + if (select_query->group_by_with_grouping_sets) + { + for (auto & grouping_set_element : select_query->groupBy()->children) + get_grouping_expressions(grouping_set_element->children); + } + else + { + get_grouping_expressions(select_query->groupBy()->children); + } + } + + if (!has_aggregate && !has_group_by) + return; + + // process select + NameAndQualifiedName processed_identifier_qualified_names; + NameSet aliases; + SubstituteIdentifierToAnyFunction::Data select_data{grouping_qualified_names, processed_identifier_qualified_names, aliases, context, true, true}; + SubstituteIdentifierToAnyFunction select_visitor(select_data); + + for (auto & select_item : select_query->refSelect()->children) + { + String name = select_item->getAliasOrColumnName(); + if (grouping_names.contains(name)) + { + } + else if (auto * select_identifier = select_item->as()) + { + 
select_visitor.setAddAlias(true); + if (aliases.contains(name) || !select_identifier->isShort()) + select_visitor.setAddAlias(false); + select_visitor.visit(select_item); + } + else if (select_item->as()) + { + select_visitor.setAddAlias(false); + select_visitor.visit(select_item); + } + if (!select_item->tryGetAlias().empty()) + aliases.emplace(select_item->tryGetAlias()); + } + + // process having and order by + if (!processed_identifier_qualified_names.empty()) + { + SubstituteIdentifierToAnyFunction::Data expression_data{{}, processed_identifier_qualified_names, {}, context, false, false}; + SubstituteIdentifierToAnyFunction expression_visitor(expression_data); + if (select_query->having()) + expression_visitor.visit(select_query->refHaving()); + if (select_query->orderBy()) + expression_visitor.visit(select_query->refOrderBy()); + } +} + +void SubstituteIdentifierToAnyFunction::visit(ASTIdentifier & node, ASTPtr & ast, Data & data) +{ + String name = node.getAliasOrColumnName(); + if (!data.identifier_aliases.contains(name) && data.aliases.contains(name)) + return; + auto qualified_name = QualifiedName::extractQualifiedName(node); + if (data.process_grouping) + { + bool group_matched = false; + for (const auto & grouping_qualified_name : data.grouping_qualified_names) + { + if (qualified_name.hasSuffix(grouping_qualified_name) || grouping_qualified_name.hasSuffix(qualified_name)) + { + group_matched = true; + break; + } + } + + if (!group_matched) + { + bool has_alias = !node.tryGetAlias().empty(); + if (has_alias) + node.setAlias(""); + ast = makeASTFunction("any", ast); + if (has_alias || data.add_alias) + ast->setAlias(name); + data.identifier_aliases.emplace(name); + data.processed_identifier_qualified_names.emplace_back(name, qualified_name); + } + } + else + { + for (const auto & [name_, it_qualified_name] : data.processed_identifier_qualified_names) + { + if (name_ == node.name() && node.isShort()) + return; + else if 
(it_qualified_name.hasSuffix(qualified_name) || qualified_name.hasSuffix(it_qualified_name)) + { + ast = makeASTFunction("any", ast); + return; + } + } + } +} + +void SubstituteIdentifierToAnyFunction::visitChildren(IAST * node, Data & data) +{ + if (auto * func_node = node->as()) + { + auto function_type = getFunctionType(*func_node, data.context); + if (function_type == FunctionType::AGGREGATE_FUNCTION || function_type == FunctionType::GROUPING_OPERATION) + return; + + if (func_node->tryGetQueryArgument()) + return; + /// We skip the first argument. We also assume that the lambda function can not have parameters. + size_t first_pos = 0; + if (func_node->name == "lambda") + first_pos = 1; + + if (func_node->arguments) + { + auto & func_children = func_node->arguments->children; + + for (size_t i = first_pos; i < func_children.size(); ++i) + { + auto & child = func_children[i]; + + if (needVisit(child)) + visit(child, data); + } + } + + if (func_node->window_definition) + { + visitChildren(func_node->window_definition.get(), data); + } + } + else if (!node->as() && !node->as() && !node->as()) + { + for (auto & child : node->children) + if (needVisit(child)) + visit(child, data); + } +} + +void SubstituteIdentifierToAnyFunction::visit(ASTPtr & ast, Data & data) +{ + if (auto * node_id = ast->as()) + { + visit(*node_id, ast, data); + return; + } + + visitChildren(ast.get(), data); +} + +} diff --git a/src/Analyzers/SubstituteSelectItemToAnyFunction.h b/src/Analyzers/SubstituteSelectItemToAnyFunction.h new file mode 100644 index 00000000000..b4763a4164a --- /dev/null +++ b/src/Analyzers/SubstituteSelectItemToAnyFunction.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "Core/Names.h" +#include "QueryPlan/SymbolAllocator.h" + +namespace DB +{ +class ASTFunction; +struct QualifiedName; +class ASTSelectQuery; + +class SubstituteSelectItemToAnyFunction +{ +public: + explicit 
SubstituteSelectItemToAnyFunction(ContextMutablePtr context_) : context(std::move(context_)) + { + } + + void visit(ASTPtr & ast); + void visit(ASTSelectQuery * select_query); + + bool hasAggregate(ASTPtr & ast); + +private: + ContextMutablePtr context; +}; + +using QualifiedNames = std::vector; +using NameAndQualifiedName = std::vector>; +class SubstituteIdentifierToAnyFunction +{ +public: + struct Data + { + const QualifiedNames & grouping_qualified_names; + NameAndQualifiedName & processed_identifier_qualified_names; + const NameSet & aliases; + ContextPtr context; + bool add_alias; + bool process_grouping; + NameSet identifier_aliases; + explicit Data(const QualifiedNames & grouping_qualified_names_ + , NameAndQualifiedName & processed_identifier_qualified_names_ + , const NameSet & aliases_ + , ContextPtr context_ + , bool add_alias_ + , bool process_grouping_) + : grouping_qualified_names(grouping_qualified_names_) + , processed_identifier_qualified_names(processed_identifier_qualified_names_) + , aliases(aliases_) + , context(context_) + , add_alias(add_alias_) + , process_grouping(process_grouping_) + { + identifier_aliases = {}; + } + }; + + explicit SubstituteIdentifierToAnyFunction(Data & data) + : visitor_data(data) + {} + + void setAddAlias(bool add_alias){ visitor_data.add_alias = add_alias; } + void visit(ASTPtr & ast) + { + visit(ast, visitor_data); + } + +private: + Data & visitor_data; + + static void visit(ASTPtr & ast, Data & data); + static void visit(ASTIdentifier &, ASTPtr &, Data &); + + static void visitChildren(IAST * node, Data & data); +}; + + +} diff --git a/src/Analyzers/TypeAnalyzer.cpp b/src/Analyzers/TypeAnalyzer.cpp index 0feef280584..f059ac7e858 100644 --- a/src/Analyzers/TypeAnalyzer.cpp +++ b/src/Analyzers/TypeAnalyzer.cpp @@ -53,7 +53,8 @@ DataTypePtr TypeAnalyzer::getType(const ConstASTPtr & expr) const Analysis analysis; ExprAnalyzerOptions options; options.expandUntuple(false); - return 
ExprAnalyzer::analyze(REMOVE_CONST(expr), &scope, context, analysis, options); + ASTPtr tmp_ast = REMOVE_CONST(expr); + return ExprAnalyzer::analyze(tmp_ast, &scope, context, analysis, options); } DataTypePtr TypeAnalyzer::getTypeWithoutCheck(const ConstASTPtr & expr) const @@ -63,7 +64,8 @@ DataTypePtr TypeAnalyzer::getTypeWithoutCheck(const ConstASTPtr & expr) const options.expandUntuple(false); options.aggregateSupport(ExprAnalyzerOptions::AggregateSupport::ALLOWED); options.windowSupport(ExprAnalyzerOptions::WindowSupport::ALLOWED); - return ExprAnalyzer::analyze(REMOVE_CONST(expr), &scope, context, analysis, options); + ASTPtr tmp_ast = REMOVE_CONST(expr); + return ExprAnalyzer::analyze(tmp_ast, &scope, context, analysis, options); } ExpressionTypes TypeAnalyzer::getExpressionTypes(const ConstASTPtr & expr) const @@ -71,7 +73,8 @@ ExpressionTypes TypeAnalyzer::getExpressionTypes(const ConstASTPtr & expr) const Analysis analysis; ExprAnalyzerOptions options; options.expandUntuple(false); - ExprAnalyzer::analyze(REMOVE_CONST(expr), &scope, context, analysis, options); + ASTPtr tmp_ast = REMOVE_CONST(expr); + ExprAnalyzer::analyze(tmp_ast, &scope, context, analysis, options); return analysis.getExpressionTypes(); } diff --git a/src/Analyzers/function_utils.cpp b/src/Analyzers/function_utils.cpp index 7ea1afa88d6..6cfaa1bc728 100644 --- a/src/Analyzers/function_utils.cpp +++ b/src/Analyzers/function_utils.cpp @@ -69,8 +69,8 @@ ASTPtr getLambdaExpressionBody(ASTFunction & lambda) bool isComparisonFunction(const ASTFunction & function) { - return function.name == "equals" || function.name == "less" || function.name == "lessOrEquals" || function.name == "greater" - || function.name == "greaterOrEquals"; + return function.name == "equals" || function.name == "bitEquals" || function.name == "less" || function.name == "lessOrEquals" + || function.name == "greater" || function.name == "greaterOrEquals"; } bool functionIsInSubquery(const ASTFunction & function) diff --git 
a/src/Analyzers/resolveNamesAsMySQL.cpp b/src/Analyzers/resolveNamesAsMySQL.cpp index 49a6c8b838d..702deadfe53 100644 --- a/src/Analyzers/resolveNamesAsMySQL.cpp +++ b/src/Analyzers/resolveNamesAsMySQL.cpp @@ -1,5 +1,6 @@ #include +#include #include #include #include @@ -55,6 +56,11 @@ namespace { } void rewriteName(ASTPtr & ast, const ASTPtr & root_expression); + void rewriteChildren(ASTPtr & ast, const ASTPtr & root_expression) + { + for (auto & child : ast->children) + rewriteName(child, root_expression); + } private: std::vector levels; @@ -73,11 +79,12 @@ namespace break; } } - else if ((ast->as() && ast->as()->name != "lambda") || ast->as()) - { - for (auto & child : ast->children) - rewriteName(child, root_expression); - } + else if (const auto * func = ast->as(); func + && !AggregateUtils::isAggregateFunction(*func) /* prefer source column under aggregate function */ + && func->name != "lambda") + rewriteChildren(ast, root_expression); + else if (ast->as()) + rewriteChildren(ast, root_expression); } void collectNamedExpressions(const ASTPtr & expression, NamedExpressions & named_expressions) diff --git a/src/Catalog/Catalog.cpp b/src/Catalog/Catalog.cpp index da4650d4627..aca8e07d0fc 100644 --- a/src/Catalog/Catalog.cpp +++ b/src/Catalog/Catalog.cpp @@ -24,6 +24,8 @@ #include #include #include +#include +#include #include #include #include @@ -112,6 +114,14 @@ namespace ProfileEvents extern const Event GetSQLBindingsFailed; extern const Event RemoveSQLBindingSuccess; extern const Event RemoveSQLBindingFailed; + extern const Event UpdatePreparedStatementSuccess; + extern const Event UpdatePreparedStatementFailed; + extern const Event GetPreparedStatementSuccess; + extern const Event GetPreparedStatementFailed; + extern const Event GetPreparedStatementsSuccess; + extern const Event GetPreparedStatementsFailed; + extern const Event RemovePreparedStatementSuccess; + extern const Event RemovePreparedStatementFailed; extern const Event CreateDatabaseSuccess; 
extern const Event CreateDatabaseFailed; extern const Event GetDatabaseSuccess; @@ -276,6 +286,10 @@ namespace ProfileEvents extern const Event GetAllUndoBufferFailed; extern const Event GetUndoBufferIteratorSuccess; extern const Event GetUndoBufferIteratorFailed; + extern const Event GetUndoBuffersWithKeysSuccess; + extern const Event GetUndoBuffersWithKeysFailed; + extern const Event ClearUndoBuffersByKeysSuccess; + extern const Event ClearUndoBuffersByKeysFailed; extern const Event GetTransactionRecordsSuccess; extern const Event GetTransactionRecordsFailed; extern const Event GetTransactionRecordsTxnIdsSuccess; @@ -514,6 +528,7 @@ namespace ErrorCodes extern const int UNKNOWN_MASKING_POLICY_NAME; extern const int BUCKET_TABLE_ENGINE_MISMATCH; extern const int ACCESS_ENTITY_ALREADY_EXISTS; + extern const int METASTORE_ERROR_KEY; } namespace Catalog @@ -557,6 +572,9 @@ namespace Catalog topology_key = name_space; else topology_key = name_space + "_" + config.topology_key; + + // Add background task to do some GC job for Catalog. 
+ bg_task = std::make_shared(Context::createCopy(context.shared_from_this()), meta_proxy->getMetastore(), name_space); }, ProfileEvents::CatalogConstructorSuccess, ProfileEvents::CatalogConstructorFailed); @@ -957,7 +975,11 @@ namespace Catalog return; } - auto storage = tryGetTableByUUID(context, UUIDHelpers::UUIDToString(uuid), TxnTimestamp::maxTS()); + StoragePtr storage; + if (auto query_context = CurrentThread::getGroup()->query_context.lock()) + storage = tryGetTableByUUID(*query_context, UUIDHelpers::UUIDToString(uuid), TxnTimestamp::maxTS()); + else + storage = tryGetTableByUUID(context, UUIDHelpers::UUIDToString(uuid), TxnTimestamp::maxTS()); if (auto pcm = context.getPartCacheManager(); pcm && storage) { @@ -1145,7 +1167,7 @@ namespace Catalog // Set cluster status after Alter table is successful to update PartCacheManager with new table metadata if (is_modify_cluster_by) - setTableClusterStatus(storage->getStorageUUID(), false, new_table->getTableHashForClusterBy().getDeterminHash()); + setTableClusterStatus(storage->getStorageUUID(), false, new_table->getTableHashForClusterBy()); if (auto cache_manager = context.getPartCacheManager(); cache_manager) { @@ -1159,7 +1181,7 @@ namespace Catalog { // update cache with nullptr and latest table commit_time to prevent an old version be inserted into cache. // the cache will be reloaded in following getTable - cache_manager->insertStorageCache(storage->getStorageID(), nullptr, table->commit_time(), host_port.topology_version); + cache_manager->insertStorageCache(storage->getStorageID(), nullptr, table->commit_time(), host_port.topology_version, query_context); } } }, @@ -1255,7 +1277,7 @@ namespace Catalog /// update table name in table meta entry so that we can get table part metrics correctly. 
if (auto cache_manager = context.getPartCacheManager(); cache_manager && is_local_server) { - cache_manager->insertStorageCache(StorageID{from_database, from_table, UUIDHelpers::toUUID(table_uuid)}, nullptr, ts, host_port.topology_version); + cache_manager->insertStorageCache(StorageID{from_database, from_table, UUIDHelpers::toUUID(table_uuid)}, nullptr, ts, host_port.topology_version, context); cache_manager->updateTableNameInMetaEntry(table_uuid, to_database, to_table); } @@ -1327,7 +1349,6 @@ namespace Catalog throw Exception("Table not found: " + database + "." + name, ErrorCodes::UNKNOWN_TABLE); } - auto cache_manager = context.getPartCacheManager(); bool is_host_server = false; const auto host_server = context.getCnchTopologyMaster()->getTargetServer(table_id->uuid(), getServerVwNameFrom(*table_id), true); @@ -1337,10 +1358,12 @@ namespace Catalog if (is_host_server && cache_manager) { - auto cached_storage = cache_manager->getStorageFromCache(UUIDHelpers::toUUID(table_id->uuid()), host_server.topology_version); + auto cached_storage = cache_manager->getStorageFromCache(UUIDHelpers::toUUID(table_id->uuid()), host_server.topology_version, query_context); if (cached_storage && cached_storage->commit_time <= ts && cached_storage->getStorageID().database_name == database && cached_storage->getStorageID().table_name == name) { res = cached_storage; + //TODO:(@lianwenlong) force fetch global object schema from catalog + initStorageObjectSchema(res); return; } } @@ -1365,7 +1388,7 @@ namespace Catalog /// Try insert the storage into cache. 
if (res && is_host_server && cache_manager) - cache_manager->insertStorageCache(res->getStorageID(), res, table->commit_time(), host_server.topology_version); + cache_manager->insertStorageCache(res->getStorageID(), res, table->commit_time(), host_server.topology_version, query_context); }, ProfileEvents::GetTableSuccess, ProfileEvents::GetTableFailed); @@ -1407,7 +1430,7 @@ namespace Catalog { if (current_topology_version != PairInt64(0, 0)) { - auto cached_storage = cache_manager->getStorageFromCache(UUIDHelpers::toUUID(uuid), current_topology_version); + auto cached_storage = cache_manager->getStorageFromCache(UUIDHelpers::toUUID(uuid), current_topology_version, query_context); if (cached_storage && cached_storage->commit_time <= ts) { auto host_server = current_topology.getTargetServer(uuid, cached_storage->getServerVwName()); @@ -1438,7 +1461,7 @@ namespace Catalog { auto host_server = current_topology.getTargetServer(uuid, res->getServerVwName()); if (!host_server.empty() && isLocalServer(host_server.getRPCAddress(), std::to_string(context.getRPCPort()))) - cache_manager->insertStorageCache(res->getStorageID(), res, table->commit_time(), current_topology_version); + cache_manager->insertStorageCache(res->getStorageID(), res, table->commit_time(), current_topology_version, query_context); } }, ProfileEvents::TryGetTableByUUIDSuccess, @@ -2381,7 +2404,7 @@ namespace Catalog { if (!part->deleted && !table_definition_hash.match(part->table_definition_hash)) { - setTableClusterStatus(storage->getStorageUUID(), false, table_definition_hash.getDeterminHash()); + setTableClusterStatus(storage->getStorageUUID(), false, table_definition_hash); break; } } @@ -2719,7 +2742,7 @@ namespace Catalog return; } } - getPartitionsFromMetastore(*cnch_table, partitions); + getPartitionsFromMetastore(*cnch_table, partitions, nullptr); } for (auto it = partitions.begin(); it != partitions.end(); it++) @@ -2779,7 +2802,12 @@ namespace Catalog return partition_ids; } - 
PrunedPartitions Catalog::getPartitionsByPredicate(ContextPtr session_context, const ConstStoragePtr & storage, const SelectQueryInfo & query_info, const Names & column_names_to_return) + PrunedPartitions Catalog::getPartitionsByPredicate( + ContextPtr session_context, + const ConstStoragePtr & storage, + const SelectQueryInfo & query_info, + const Names & column_names_to_return, + const bool & ignore_ttl) { PrunedPartitions pruned_partitions; auto getPartitionsLocally = [&]() @@ -2789,7 +2817,7 @@ namespace Catalog return; auto all_partitions = getPartitionList(storage, nullptr); pruned_partitions.total_partition_number = all_partitions.size(); - pruned_partitions.partitions = cnch_mergetree->selectPartitionsByPredicate(query_info, all_partitions, column_names_to_return, session_context); + pruned_partitions.partitions = cnch_mergetree->selectPartitionsByPredicate(query_info, all_partitions, column_names_to_return, session_context, ignore_ttl); }; const auto host_port = context.getCnchTopologyMaster()->getTargetServer( UUIDHelpers::UUIDToString(storage->getStorageUUID()), storage->getServerVwName(), true); @@ -2800,7 +2828,7 @@ namespace Catalog try { auto host_with_rpc = host_port.getRPCAddress(); - pruned_partitions = context.getCnchServerClientPool().get(host_with_rpc)->fetchPartitions(host_with_rpc, storage, query_info, column_names_to_return, session_context->getCurrentTransactionID()); + pruned_partitions = context.getCnchServerClientPool().get(host_with_rpc)->fetchPartitions(host_with_rpc, storage, query_info, column_names_to_return, session_context->getCurrentTransactionID(), ignore_ttl); LOG_TRACE(log, "Fetched {}/{} partitions from remote host {}", pruned_partitions.partitions.size(), pruned_partitions.total_partition_number, host_port.toDebugString()); } catch (...) 
@@ -2816,8 +2844,9 @@ namespace Catalog } - template - void Catalog::getPartitionsFromMetastore(const MergeTreeMetaBase & table, Map & partition_list) + template + void + Catalog::getPartitionsFromMetastore(const MergeTreeMetaBase & table, Map & partition_list, std::shared_ptr lock_holder) { runWithMetricSupport( [&] { @@ -2830,7 +2859,8 @@ namespace Catalog Protos::PartitionMeta partition_meta; partition_meta.ParseFromString(it->value()); auto partition_ptr = createPartitionFromMetaModel(table, partition_meta); - auto partition_info = std::make_shared(table_uuid, partition_ptr, partition_meta.id()); + auto partition_lock = lock_holder ? (*lock_holder).getPartitionLock(partition_meta.id()) : nullptr; + auto partition_info = std::make_shared(table_uuid, partition_ptr, partition_meta.id(), partition_lock); if (partition_meta.has_gctime()) partition_info->gctime = partition_meta.gctime(); partition_list.emplace(partition_meta.id(), std::move(partition_info)); @@ -2840,8 +2870,9 @@ namespace Catalog ProfileEvents::GetPartitionsFromMetastoreFailed); } - template void Catalog::getPartitionsFromMetastore(const MergeTreeMetaBase &, PartitionMap &); - template void Catalog::getPartitionsFromMetastore>(const MergeTreeMetaBase &, ScanWaitFreeMap &); + template void Catalog::getPartitionsFromMetastore(const MergeTreeMetaBase &, PartitionMap &, std::shared_ptr); + template void Catalog::getPartitionsFromMetastore>( + const MergeTreeMetaBase &, ScanWaitFreeMap &, std::shared_ptr); Strings Catalog::getPartitionIDsFromMetastore(const ConstStoragePtr & storage) { @@ -3462,7 +3493,7 @@ namespace Catalog { if (!part->deleted && !table_definition_hash.match(part->table_definition_hash)) { - setTableClusterStatus(table->getStorageUUID(), false, table_definition_hash.getDeterminHash()); + setTableClusterStatus(table->getStorageUUID(), false, table_definition_hash); break; } } @@ -3499,6 +3530,8 @@ namespace Catalog if (!part_models.parts().empty()) 
context.getPartCacheManager()->insertDataPartsIntoCache( *table, part_models.parts(), is_merged_parts, false, host_port.topology_version); + if (!staged_part_models.parts().empty()) + context.getPartCacheManager()->insertStagedPartsIntoCache(*table, staged_part_models.parts(), host_port.topology_version); if (!commit_data.delete_bitmaps.empty()) { context.getPartCacheManager()->insertDeleteBitmapsIntoCache( @@ -3905,6 +3938,53 @@ namespace Catalog ProfileEvents::ClearUndoBufferFailed); } + void Catalog::clearUndoBuffersByKeys(const TxnTimestamp & txnID, const std::vector & keys) + { + runWithMetricSupport( + [&] { + const String undo_buffer_key_prefix = meta_proxy->undoBufferKeyPrefix(name_space, txnID, false); + const String undo_buffer_key_prefix_rev = meta_proxy->undoBufferKeyPrefix(name_space, txnID, true); + + BatchCommitRequest batch_writes; + for (const auto & key : keys) + { + if (!key.starts_with(undo_buffer_key_prefix) && !key.starts_with(undo_buffer_key_prefix_rev)) + { + throw Exception(ErrorCodes::METASTORE_ERROR_KEY, "Expected key {} but receive {}", undo_buffer_key_prefix, key); + } + batch_writes.AddDelete(key); + } + + BatchCommitResponse resp; + meta_proxy->batchWrite(batch_writes, resp); + }, + ProfileEvents::ClearUndoBuffersByKeysSuccess, + ProfileEvents::ClearUndoBuffersByKeysFailed); + } + + std::unordered_map, UndoResources>> Catalog::getUndoBuffersWithKeys(const TxnTimestamp & txnID) + { + std::unordered_map, UndoResources>> res; + runWithMetricSupport( + [&] { + auto get_func = [&](bool write_undo_buffer_new_key) { + auto it = meta_proxy->getUndoBuffer(name_space, txnID.toUInt64(), write_undo_buffer_new_key); + while (it->next()) + { + UndoResource resource = UndoResource::deserialize(it->value()); + resource.txn_id = txnID; + res[resource.uuid()].first.emplace_back(it->key()); + res[resource.uuid()].second.emplace_back(std::move(resource)); + } + }; + /// Get both old and new undo buffer keys; + get_func(true); + get_func(false); + }, + 
ProfileEvents::GetUndoBuffersWithKeysSuccess, + ProfileEvents::GetUndoBuffersWithKeysFailed); + return res; + } std::unordered_map Catalog::getUndoBuffer(const TxnTimestamp & txnID) { std::unordered_map res; @@ -4165,7 +4245,7 @@ namespace Catalog return getTransactionRecords(std::vector(txn_ids.begin(), txn_ids.end()), 100000); } - std::vector Catalog::getTransactionRecordsForGC(size_t max_result_number) + std::vector Catalog::getTransactionRecordsForGC(String & start_key, size_t max_result_number) { std::vector res; /// if exception occurs during get txn record, just return the partial result; @@ -4174,9 +4254,20 @@ namespace Catalog [&] { try { - auto it = meta_proxy->getAllTransactionRecord(name_space, max_result_number); + auto it = meta_proxy->getAllTransactionRecord(name_space, start_key, max_result_number); - while (it->next()) + if (!it->next()) + { + if (start_key.empty()) + return; + + start_key.clear(); + auto it = meta_proxy->getAllTransactionRecord(name_space, start_key, max_result_number); + if (!it->next()) + return; + } + + do { auto record = TransactionRecord::deserialize(it->value()); if (record.isSecondary()) @@ -4203,7 +4294,15 @@ namespace Catalog } res.push_back(std::move(record)); } - } + + } while (it->next()); + + // Save key so we can resume iteration in the next call. + if (!res.empty()) + start_key = meta_proxy->transactionRecordKey(name_space, res.back().txnID()); + + if (res.size() < max_result_number || max_result_number == 0) + start_key.clear(); } catch (...) 
{ @@ -5169,6 +5268,7 @@ namespace Catalog void Catalog::createMutation(const StorageID & storage_id, const String & mutation_name, const String & mutate_text) { + LOG_TRACE(log, "createMutation: {}, {}", storage_id.getNameForLogs(), mutation_name); runWithMetricSupport( [&] { meta_proxy->createMutation(name_space, UUIDHelpers::UUIDToString(storage_id.uuid), mutation_name, mutate_text); }, ProfileEvents::CreateMutationSuccess, @@ -5177,6 +5277,7 @@ namespace Catalog void Catalog::removeMutation(const StorageID & storage_id, const String & mutation_name) { + LOG_TRACE(log, "removeMutation: {}, {}", storage_id.getNameForLogs(), mutation_name); runWithMetricSupport( [&] { meta_proxy->removeMutation(name_space, UUIDHelpers::UUIDToString(storage_id.uuid), mutation_name); }, ProfileEvents::RemoveMutationSuccess, @@ -5223,14 +5324,15 @@ namespace Catalog } } - void Catalog::setTableClusterStatus(const UUID & table_uuid, const bool clustered, const UInt64 & table_definition_hash) + void Catalog::setTableClusterStatus(const UUID & table_uuid, const bool clustered, const TableDefinitionHash & table_definition_hash) { + LOG_TRACE(log, "setTableClusterStatus: {} to {}", UUIDHelpers::UUIDToString(table_uuid), clustered); runWithMetricSupport( [&] { - meta_proxy->setTableClusterStatus(name_space, UUIDHelpers::UUIDToString(table_uuid), clustered, table_definition_hash); + meta_proxy->setTableClusterStatus(name_space, UUIDHelpers::UUIDToString(table_uuid), clustered, table_definition_hash.getDeterminHash()); /// keep the cache status up to date. 
if (context.getPartCacheManager()) - context.getPartCacheManager()->setTableClusterStatus(table_uuid, clustered); + context.getPartCacheManager()->setTableClusterStatus(table_uuid, clustered, table_definition_hash); }, ProfileEvents::SetTableClusterStatusSuccess, ProfileEvents::SetTableClusterStatusFailed); @@ -6230,7 +6332,12 @@ namespace Catalog for (const String & dependency : dependencies) batch_write.AddDelete(MetastoreProxy::viewDependencyKey(name_space, dependency, table_id.uuid())); - batch_write.AddPut(SinglePutRequest(MetastoreProxy::tableStoreKey(name_space, table_id.uuid(), ts.toUInt64()), table.SerializeAsString())); + addPotentialLargeKVToBatchwrite( + meta_proxy->getMetastore(), + batch_write, + name_space, + MetastoreProxy::tableStoreKey(name_space, table_id.uuid(), ts.toUInt64()), + table.SerializeAsString()); // use database name and table name in table_id is required because it may different with that in table data model. batch_write.AddPut(SinglePutRequest( MetastoreProxy::tableTrashKey(name_space, table_id.database(), table_id.name(), ts.toUInt64()), table_id.SerializeAsString())); @@ -6623,6 +6730,42 @@ namespace Catalog ProfileEvents::RemoveSQLBindingFailed); } + void Catalog::updatePreparedStatement(const PreparedStatementItemPtr & data) + { + runWithMetricSupport( + [&] { meta_proxy->updatePreparedStatement(name_space, data); }, + ProfileEvents::UpdatePreparedStatementSuccess, + ProfileEvents::UpdatePreparedStatementFailed); + } + + PreparedStatements Catalog::getPreparedStatements() + { + PreparedStatements res; + runWithMetricSupport( + [&] { res = meta_proxy->getPreparedStatements(name_space); }, + ProfileEvents::GetPreparedStatementSuccess, + ProfileEvents::GetPreparedStatementFailed); + return res; + } + + PreparedStatementItemPtr Catalog::getPreparedStatement(const String & name) + { + PreparedStatementItemPtr res; + runWithMetricSupport( + [&] { res = meta_proxy->getPreparedStatement(name_space, name); }, + 
ProfileEvents::GetPreparedStatementSuccess, + ProfileEvents::GetPreparedStatementFailed); + return res; + } + + void Catalog::removePreparedStatement(const String & name) + { + runWithMetricSupport( + [&] { meta_proxy->removePreparedStatement(name_space, name); }, + ProfileEvents::RemovePreparedStatementSuccess, + ProfileEvents::RemovePreparedStatementFailed); + } + void Catalog::setMergeMutateThreadStartTime(const StorageID & storage_id, const UInt64 & startup_time) const { meta_proxy->setMergeMutateThreadStartTime(name_space, UUIDHelpers::UUIDToString(storage_id.uuid), startup_time); diff --git a/src/Catalog/Catalog.h b/src/Catalog/Catalog.h index ec3d2ffc64b..a578b61097b 100644 --- a/src/Catalog/Catalog.h +++ b/src/Catalog/Catalog.h @@ -40,6 +40,7 @@ #include #include #include +#include #include #include #include "common/types.h" @@ -49,7 +50,7 @@ #include #include #include -// #include +#include namespace DB::ErrorCodes { @@ -81,6 +82,8 @@ enum class VisibilityLevel All }; +class CatalogBackgroundTask; + class Catalog { public: @@ -129,7 +132,7 @@ class Catalog ////////////// - void updateSQLBinding(const SQLBindingItemPtr data); + void updateSQLBinding(SQLBindingItemPtr data); SQLBindings getSQLBindings(); @@ -139,6 +142,14 @@ class Catalog void removeSQLBinding(const String & uuid, const String & tenant_id, const bool & is_re_expression); + void updatePreparedStatement(const PreparedStatementItemPtr & data); + + PreparedStatements getPreparedStatements(); + + PreparedStatementItemPtr getPreparedStatement(const String & name); + + void removePreparedStatement(const String & name); + ///////////////////////////// /// Database related API ///////////////////////////// @@ -343,12 +354,13 @@ class Catalog std::vector getLastModificationTimeHints(const ConstStoragePtr & table); - template - void getPartitionsFromMetastore(const MergeTreeMetaBase & table, Map & partition_list); + /// Caller should garrantee that `lock_holder` lives longer than this call. 
+ template + void getPartitionsFromMetastore(const MergeTreeMetaBase & table, Map & partition_list, std::shared_ptr lock_holder); Strings getPartitionIDs(const ConstStoragePtr & storage, const Context * session_context); - PrunedPartitions getPartitionsByPredicate(ContextPtr session_context, const ConstStoragePtr & storage, const SelectQueryInfo & query_info, const Names & column_names_to_return); + PrunedPartitions getPartitionsByPredicate(ContextPtr session_context, const ConstStoragePtr & storage, const SelectQueryInfo & query_info, const Names & column_names_to_return, const bool & ignore_ttl); /// dictionary related APIs void createDictionary(const StorageID & storage_id, const String & create_query); @@ -466,6 +478,21 @@ class Catalog /// clear undo buffer void clearUndoBuffer(const TxnTimestamp & txnID, const String & rpc_address, PlanSegmentInstanceId instance_id); + /** + * @brief Clean all undo buffers with given keys (in the same table). + * + * @param txnID Currently, this will be used to verify if the undo buffers are in the same table. + * @param keys Keys of the undo buffers. + */ + void clearUndoBuffersByKeys(const TxnTimestamp & txnID, const std::vector & keys); + + /** + * @brief get Undo Buffers with there keys (in metastore). These keys can be further used to manipulate the data. + * + * @param txnID Transaction ID. + * @return map + */ + std::unordered_map, UndoResources>> getUndoBuffersWithKeys(const TxnTimestamp & txnID); /// return storage uuid -> undo resources std::unordered_map getUndoBuffer(const TxnTimestamp & txnID); std::unordered_map @@ -514,7 +541,7 @@ class Catalog std::vector getTransactionRecords(const std::vector & txn_ids, size_t batch_size = 0); /// clean zombie records. If the total transaction record number is too large, it may be impossible to get all of them. 
We can /// pass a max_result_number to only get part of them and clean zombie records repeatedlly - std::vector getTransactionRecordsForGC(size_t max_result_number); + std::vector getTransactionRecordsForGC(String & start_key, size_t max_result_number); TransactionRecords getTransactionRecords(const ServerDataPartsVector & parts, const DeleteBitmapMetaPtrVector & bitmaps); /// Clear intents written by zombie transaction. @@ -647,7 +674,7 @@ class Catalog std::multimap getAllMutations(); void fillMutationsByStorage(const StorageID & storage_id, std::map & out_mutations); - void setTableClusterStatus(const UUID & table_uuid, const bool clustered, const UInt64 & table_definition_hash); + void setTableClusterStatus(const UUID & table_uuid, const bool clustered, const TableDefinitionHash & table_definition_hash); void getTableClusterStatus(const UUID & table_uuid, bool & clustered); bool isTableClustered(const UUID & table_uuid); @@ -889,6 +916,8 @@ class Catalog void commitCheckpointVersion(const UUID & uuid, std::shared_ptr checkpoint_version); void cleanTableVersions(const UUID & uuid, std::vector> versions_to_clean); + void shutDown() {bg_task.reset();} + private: Poco::Logger * log = &Poco::Logger::get("Catalog"); Context & context; @@ -900,6 +929,8 @@ class Catalog std::mutex all_storage_nhut_mutex; CatalogSettings settings; + std::shared_ptr bg_task; + std::shared_ptr tryGetDatabaseFromMetastore(const String & database, const UInt64 & ts); std::shared_ptr tryGetTableFromMetastore(const String & table_uuid, const UInt64 & ts, bool with_prev_versions = false, bool with_deleted = false); diff --git a/src/Catalog/CatalogBackgroundTask.cpp b/src/Catalog/CatalogBackgroundTask.cpp new file mode 100644 index 00000000000..912f5811f4f --- /dev/null +++ b/src/Catalog/CatalogBackgroundTask.cpp @@ -0,0 +1,118 @@ +#include +#include +#include +#include + + +namespace DB +{ + +namespace Catalog +{ + +CatalogBackgroundTask::CatalogBackgroundTask( + const ContextPtr & context_, + 
const std::shared_ptr & metastore_, + const String & name_space_) + : context(context_), + metastore(metastore_), + name_space(name_space_) +{ + task_holder = context->getSchedulePool().createTask( + "CatalogBGTask", + [this](){ + execute(); + } + ); + + task_holder->activate(); + // wait for server startup + task_holder->scheduleAfter(30*1000); +} + +CatalogBackgroundTask::~CatalogBackgroundTask() +{ + try + { + task_holder->deactivate(); + } + catch (...) + { + tryLogCurrentException(log); + } +} + +void CatalogBackgroundTask::execute() +{ + // only server can perform catalog bg task + if (context->getServerType() != ServerType::cnch_server) + return; + + LOG_DEBUG(log, "Try execute catalog bg task."); + try + { + cleanStaleLargeKV(); + } + catch (...) + { + tryLogCurrentException(log, "Exception happens while executing catalog bg task."); + } + + // execute every 1 hour. + task_holder->scheduleAfter(60*60*1000); +} + +void CatalogBackgroundTask::cleanStaleLargeKV() +{ + // only leader can execute clean job + if (!context->getCnchServerManager()->isLeader()) + return; + + // scan large kv records + std::unordered_map uuid_to_key; + String large_kv_reference_prefix = MetastoreProxy::largeKVReferencePrefix(name_space); + auto it = metastore->getByPrefix(large_kv_reference_prefix); + + while (it->next()) + { + String uuid = it->key().substr(large_kv_reference_prefix.size()); + uuid_to_key.emplace(uuid, it->value()); + } + + // check for each large KV if still been referenced by stored key + for (const auto & [uuid, key] : uuid_to_key) + { + String value; + metastore->get(key, value); + if (!value.empty()) + { + Protos::DataModelLargeKVMeta large_kv_model; + if (tryParseLargeKVMetaModel(value, large_kv_model) && large_kv_model.uuid() == uuid) + continue; + } + + // remove large KV because it is not been referenced by original key + BatchCommitRequest batch_write; + BatchCommitResponse resp; + + auto large_kv_it = 
metastore->getByPrefix(MetastoreProxy::largeKVDataPrefix(name_space, uuid)); + while (large_kv_it->next()) + batch_write.AddDelete(large_kv_it->key()); + + batch_write.AddDelete(MetastoreProxy::largeKVReferenceKey(name_space, uuid)); + + try + { + metastore->batchWrite(batch_write, resp); + LOG_DEBUG(log, "Removed large KV(uuid: {}) from metastore.", uuid); + } + catch (...) + { + tryLogCurrentException(log, "Error occurs while removing large kv."); + } + } +} + +} + +} diff --git a/src/Catalog/CatalogBackgroundTask.h b/src/Catalog/CatalogBackgroundTask.h new file mode 100644 index 00000000000..d252cf040ec --- /dev/null +++ b/src/Catalog/CatalogBackgroundTask.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include + +namespace DB +{ + +namespace Catalog +{ + +class CatalogBackgroundTask +{ + +public: + CatalogBackgroundTask( + const ContextPtr & context_, + const std::shared_ptr & metastore_, + const String & name_space_); + + ~CatalogBackgroundTask(); + + void execute(); + +private: + + void cleanStaleLargeKV(); + + Poco::Logger * log = &Poco::Logger::get("CatalogBGTask"); + + ContextPtr context; + std::shared_ptr metastore; + String name_space; + + BackgroundSchedulePool::TaskHolder task_holder; +}; + +} + +} diff --git a/src/Catalog/CatalogMetricHelper.h b/src/Catalog/CatalogMetricHelper.h index ab55a2f033f..ccf9d597715 100644 --- a/src/Catalog/CatalogMetricHelper.h +++ b/src/Catalog/CatalogMetricHelper.h @@ -26,16 +26,16 @@ namespace Catalog { using Job = std::function; - static void runWithMetricSupport(const Job & job, const ProfileEvents::Event & /*success*/, const ProfileEvents::Event & /*failed*/) + static void runWithMetricSupport(const Job & job, const ProfileEvents::Event & success, const ProfileEvents::Event & failed) { try { job(); - //ProfileEvents::increment(success); + ProfileEvents::increment(success); } catch (...) 
{ - //ProfileEvents::increment(failed); + ProfileEvents::increment(failed); throw; } } diff --git a/src/Catalog/DataModelPartWrapper.cpp b/src/Catalog/DataModelPartWrapper.cpp index 86315896629..91ef3610a73 100644 --- a/src/Catalog/DataModelPartWrapper.cpp +++ b/src/Catalog/DataModelPartWrapper.cpp @@ -196,14 +196,23 @@ void ServerDataPart::setVirtualPartSize(const UInt64 & vp_size) const { virtual_ UInt64 ServerDataPart::getVirtualPartSize() const { return virtual_part_size; } -UInt64 ServerDataPart::deletedRowsCount(const MergeTreeMetaBase & storage) const +UInt64 ServerDataPart::deletedRowsCount(const MergeTreeMetaBase & storage, bool ignore_error) const { UInt64 res = 0; /// For unique table, deletedRowsCount is calculated from delete_bitmap. if (storage.getInMemoryMetadataPtr()->hasUniqueKey()) { if (delete_bitmap_metas.empty()) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Delete bitmap meta for part {} is empty whose engine is unique table, it's a bug!", name()); + { + if (ignore_error) + { + LOG_DEBUG(storage.getLogger(), "Delete bitmap meta for part {} is empty whose engine is unique table, it's a bug!", name()); + return 0; + } + else + throw Exception(ErrorCodes::LOGICAL_ERROR, "Delete bitmap meta for part {} is empty whose engine is unique table, it's a bug!", name()); + + } for (const auto & delete_bitmap_meta: delete_bitmap_metas) res += delete_bitmap_meta->cardinality(); diff --git a/src/Catalog/DataModelPartWrapper.h b/src/Catalog/DataModelPartWrapper.h index 72fb4420246..ea20a1b5e06 100644 --- a/src/Catalog/DataModelPartWrapper.h +++ b/src/Catalog/DataModelPartWrapper.h @@ -94,7 +94,7 @@ class ServerDataPart : public std::enable_shared_from_this, publ mutable std::forward_list delete_bitmap_metas; - UInt64 deletedRowsCount(const MergeTreeMetaBase & storage) const; + UInt64 deletedRowsCount(const MergeTreeMetaBase & storage, bool ignore_error = false) const; const ImmutableDeleteBitmapPtr & getDeleteBitmap(const MergeTreeMetaBase & storage, bool 
is_unique_new_part) const; diff --git a/src/Catalog/IMetastore.h b/src/Catalog/IMetastore.h index 1496d76f252..e9dbe959b48 100644 --- a/src/Catalog/IMetastore.h +++ b/src/Catalog/IMetastore.h @@ -117,6 +117,11 @@ class IMetaStore * get limitations of the kv store */ virtual uint32_t getMaxBatchSize() = 0; + + /*** + * get limitation single a KV size + */ + virtual uint32_t getMaxKVSize() = 0; }; } diff --git a/src/Catalog/LargeKVHandler.cpp b/src/Catalog/LargeKVHandler.cpp new file mode 100644 index 00000000000..70021b7147b --- /dev/null +++ b/src/Catalog/LargeKVHandler.cpp @@ -0,0 +1,194 @@ +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int CORRUPTED_DATA; +} + +namespace Catalog +{ + +// Using SHA-1 value of the KV as its UUID so that we can perform CAS based on it. +String getUUIDForLargeKV(const String & key, const String & value) +{ + Poco::SHA1Engine engine; + engine.update(key.data(), key.size()); + engine.update(value.data(), value.size()); + const std::vector & sha1_value = engine.digest(); + String hexed_hash; + hexed_hash.resize(sha1_value.size() * 2); + boost::algorithm::hex(sha1_value.begin(), sha1_value.end(), hexed_hash.data()); + return hexed_hash; +} + +bool tryParseLargeKVMetaModel(const String & serialized, Protos::DataModelLargeKVMeta & model) +{ + if (serialized.compare(0, 4, MAGIC_NUMBER) == 0) + return model.ParseFromArray(serialized.c_str() + 4, serialized.size()-4); + + return false; +} + +void tryGetLargeValue(const std::shared_ptr & metastore, const String & name_space, const String & key, String & value) +{ + Protos::DataModelLargeKVMeta large_kv_model; + + if (!tryParseLargeKVMetaModel(value, large_kv_model)) + return; + + String kv_id = large_kv_model.uuid(); + UInt32 subkv_number = large_kv_model.subkv_number(); + + String resolved; + + if (large_kv_model.has_value_size()) + resolved.reserve(large_kv_model.value_size()); + + if (subkv_number < 10) + { + 
std::vector request_keys(subkv_number); + for (size_t i=0; imultiGet(request_keys); + for (const auto & [subvalue, _] : sub_values) + resolved += subvalue; + } + else + { + auto it = metastore->getByPrefix(MetastoreProxy::largeKVDataPrefix(name_space, kv_id)); + while (it->next()) + resolved += it->value(); + } + + //check kv uuid(KV hash) to verity the data integrity + if (getUUIDForLargeKV(key, resolved) != kv_id) + throw Exception(fmt::format("Cannot resolve value of big KV. Data may be corrupted. Origin value size : {}, resolved size : {}" + , large_kv_model.value_size(), resolved.size()), ErrorCodes::CORRUPTED_DATA); + + value.swap(resolved); +} + +LargeKVWrapperPtr tryGetLargeKVWrapper( + const std::shared_ptr & metastore, + const String & name_space, + const String & key, + const String & value, + bool if_not_exists, + const String & expected) +{ + const size_t max_allowed_kv_size = metastore->getMaxKVSize(); + size_t value_size = value.size(); + + auto transform_expected_value = [&]() + { + Protos::DataModelLargeKVMeta expected_large_kv_model; + expected_large_kv_model.set_uuid(getUUIDForLargeKV(key, expected)); + expected_large_kv_model.set_subkv_number(1 + ((expected.size() - 1) / max_allowed_kv_size)); + expected_large_kv_model.set_value_size(expected.size()); + + return MAGIC_NUMBER + expected_large_kv_model.SerializeAsString(); + }; + + bool current_value_is_large_kv_format = false; + + if (!expected.empty() && expected.size() < max_allowed_kv_size) + { + // If expected value is not empty, we need to get current value of the key to decide if we should build + // large KV format expected value: + // If the current value is large KV format, we should serialize current expected value as large KV format too. + // Otherwise, just keep it as it is. 
+ String current_value; + metastore->get(key, current_value); + Protos::DataModelLargeKVMeta large_kv_data; + if (tryParseLargeKVMetaModel(current_value, large_kv_data)) + current_value_is_large_kv_format = true; + } + + size_t expected_value_size = current_value_is_large_kv_format ? 0 : expected.size(); + + // Both the expected value and the insert value are need to be take into account. + if (value_size + expected_value_size > max_allowed_kv_size) + { + String large_kv_id = getUUIDForLargeKV(key, value); + + std::vector puts; + UInt64 sub_key_index = 0; + // split serialized data to make substrings match the KV size limitation + for (size_t i=0; i= max_allowed_kv_size || current_value_is_large_kv_format) + base_req.expected_value = transform_expected_value(); + else + base_req.expected_value = expected; + } + + LargeKVWrapperPtr wrapper = std::make_shared(std::move(base_req)); + wrapper->sub_requests.swap(puts); + + return wrapper; + } + else + { + SinglePutRequest base_req(key, value); + base_req.if_not_exists = if_not_exists; + if (!expected.empty()) + { + if (current_value_is_large_kv_format) + base_req.expected_value = transform_expected_value(); + else + base_req.expected_value = expected; + } + + LargeKVWrapperPtr wrapper = std::make_shared(std::move(base_req)); + return wrapper; + } +} + +void addPotentialLargeKVToBatchwrite( + const std::shared_ptr & metastore, + BatchCommitRequest & batch_request, + const String & name_space, + const String & key, + const String & value, + bool if_not_eixts, + const String & expected) +{ + LargeKVWrapperPtr largekv_wrapper = tryGetLargeKVWrapper(metastore, name_space, key, value, if_not_eixts, expected); + + for (auto & sub_req : largekv_wrapper->sub_requests) + batch_request.AddPut(sub_req); + + batch_request.AddPut(largekv_wrapper->base_request); +} + +} + +} diff --git a/src/Catalog/LargeKVHandler.h b/src/Catalog/LargeKVHandler.h new file mode 100644 index 00000000000..5d9f0397482 --- /dev/null +++ 
b/src/Catalog/LargeKVHandler.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include +#include + +namespace DB +{ + +namespace Catalog +{ + +static const char * MAGIC_NUMBER = "LGKV"; + +struct LargeKVWrapper +{ + LargeKVWrapper(SinglePutRequest && base) + : base_request(std::move(base)) + { + } + + SinglePutRequest base_request; + std::vector sub_requests; + + bool isLargeKV() { return sub_requests.size() > 0; } +}; + +using LargeKVWrapperPtr = std::shared_ptr; + +LargeKVWrapperPtr tryGetLargeKVWrapper( + const std::shared_ptr & metastore, + const String & name_space, + const String & key, + const String & value, + bool if_not_exists = false, + const String & expected = ""); + + +bool tryParseLargeKVMetaModel(const String & serialized, Protos::DataModelLargeKVMeta & model); + +void tryGetLargeValue(const std::shared_ptr & metastore, const String & name_space, const String & key, String & value); + +void addPotentialLargeKVToBatchwrite( + const std::shared_ptr & metastore, + BatchCommitRequest & batch_request, + const String & name_space, + const String & key, + const String & value, + bool if_not_eixts = false, + const String & expected = ""); +} + +} diff --git a/src/Catalog/MetastoreByteKVImpl.h b/src/Catalog/MetastoreByteKVImpl.h index b58b858c77d..c3426193cd8 100644 --- a/src/Catalog/MetastoreByteKVImpl.h +++ b/src/Catalog/MetastoreByteKVImpl.h @@ -122,6 +122,9 @@ class MetastoreByteKVImpl : public IMetaStore // leave some margin uint32_t getMaxBatchSize() final { return MAX_BYTEKV_BATCH_SIZE - 1000; } + // leave some margin + uint32_t getMaxKVSize() final { return MAX_BYTEKV_KV_SIZE - 200; } + public: std::shared_ptr client; diff --git a/src/Catalog/MetastoreFDBImpl.h b/src/Catalog/MetastoreFDBImpl.h index 5465dbb43b0..8e2887966c7 100644 --- a/src/Catalog/MetastoreFDBImpl.h +++ b/src/Catalog/MetastoreFDBImpl.h @@ -18,6 +18,7 @@ #include #include #include +#include namespace DB { @@ -33,7 +34,7 @@ namespace Catalog class MetastoreFDBImpl : public IMetaStore { 
// Limitations of FDB (in bytes) -#define MAX_FDB_KV_SIZE 10000 +#define MAX_FDB_KV_SIZE 100000 //Hard limit.Keys cannot exceed 10,000 bytes in size. Values cannot exceed 100,000 bytes in size #define MAX_FDB_TRANSACTION_SIZE 10000000 public: @@ -105,6 +106,9 @@ class MetastoreFDBImpl : public IMetaStore // leave some margin uint32_t getMaxBatchSize() final { return MAX_FDB_TRANSACTION_SIZE - 1000; } + // leave some margin + uint32_t getMaxKVSize() final { return MAX_FDB_KV_SIZE - 200; } + private: /// convert metastore specific error code to Clickhouse error code for processing convenience in upper layer. static int toCommonErrorCode(const fdb_error_t & error_t); diff --git a/src/Catalog/MetastoreProxy.cpp b/src/Catalog/MetastoreProxy.cpp index edcd952b9a1..1121e69b49b 100644 --- a/src/Catalog/MetastoreProxy.cpp +++ b/src/Catalog/MetastoreProxy.cpp @@ -20,11 +20,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -342,8 +344,16 @@ void MetastoreProxy::createTable(const String & name_space, const UUID & db_uuid BatchCommitRequest batch_write; batch_write.AddPut(SinglePutRequest(nonHostUpdateKey(name_space, uuid), "0", true)); - // insert table meta - batch_write.AddPut(SinglePutRequest(tableStoreKey(name_space, uuid, table_data.commit_time()), serialized_meta, true)); + + // insert table meta. 
Handle by largeKVHandler in case the table meta exceeds KV size limitation + addPotentialLargeKVToBatchwrite( + metastore_ptr, + batch_write, + name_space, + tableStoreKey(name_space, uuid, table_data.commit_time()), + serialized_meta, + true/*if_not_exists*/); + /// add dependency mapping if need for (const String & dependency : dependencies) batch_write.AddPut(SinglePutRequest(viewDependencyKey(name_space, dependency, uuid), uuid)); @@ -419,14 +429,34 @@ void MetastoreProxy::dropUDF(const String & name_space, const String &resolved_n void MetastoreProxy::updateTable(const String & name_space, const String & table_uuid, const String & table_info_new, const UInt64 & ts) { - metastore_ptr->put(tableStoreKey(name_space, table_uuid, ts), table_info_new); + if (table_info_new.size() > metastore_ptr->getMaxKVSize()) + { + BatchCommitRequest batch_write; + addPotentialLargeKVToBatchwrite( + metastore_ptr, + batch_write, + name_space, + tableStoreKey(name_space, table_uuid, ts), + table_info_new); + BatchCommitResponse resp; + metastore_ptr->batchWrite(batch_write, resp); + } + else + metastore_ptr->put(tableStoreKey(name_space, table_uuid, ts), table_info_new); } void MetastoreProxy::updateTableWithID(const String & name_space, const Protos::TableIdentifier & table_id, const DB::Protos::DataModelTable & table_data) { BatchCommitRequest batch_write; batch_write.AddPut(SinglePutRequest(tableUUIDMappingKey(name_space, table_id.database(), table_id.name()), table_id.SerializeAsString())); - batch_write.AddPut(SinglePutRequest(tableStoreKey(name_space, table_id.uuid(), table_data.commit_time()), table_data.SerializeAsString())); + + addPotentialLargeKVToBatchwrite( + metastore_ptr, + batch_write, + name_space, + tableStoreKey(name_space, table_id.uuid(), table_data.commit_time()), + table_data.SerializeAsString()); + BatchCommitResponse resp; metastore_ptr->batchWrite(batch_write, resp); } @@ -436,7 +466,10 @@ void MetastoreProxy::getTableByUUID(const String & name_space, 
const String & ta auto it = metastore_ptr->getByPrefix(tableStorePrefix(name_space, table_uuid)); while(it->next()) { - tables_info.emplace_back(it->value()); + String table_meta = it->value(); + /// NOTE: Too many large KVs will cause severe performance regression. It rarely happens + tryGetLargeValue(metastore_ptr, name_space, it->key(), table_meta); + tables_info.emplace_back(std::move(table_meta)); } } @@ -830,10 +863,14 @@ void MetastoreProxy::prepareRenameTable(const String & name_space, RPCHelpers::fillUUID(to_db_uuid, *identifier.mutable_db_uuid()); batch_write.AddPut(SinglePutRequest(tableUUIDMappingKey(name_space, to_table.database(), to_table.name()), identifier.SerializeAsString(), true)); - String meta_data; - to_table.SerializeToString(&meta_data); /// add new table meta data with new name - batch_write.AddPut(SinglePutRequest(tableStoreKey(name_space, table_uuid, to_table.commit_time()), meta_data, true)); + addPotentialLargeKVToBatchwrite( + metastore_ptr, + batch_write, + name_space, + tableStoreKey(name_space, table_uuid, to_table.commit_time()), + to_table.SerializeAsString(), + true/*if_not_exists*/); } bool MetastoreProxy::alterTable(const String & name_space, const Protos::DataModelTable & table, const Strings & masks_to_remove, const Strings & masks_to_add) @@ -841,7 +878,14 @@ bool MetastoreProxy::alterTable(const String & name_space, const Protos::DataMod BatchCommitRequest batch_write; String table_uuid = UUIDHelpers::UUIDToString(RPCHelpers::createUUID(table.uuid())); - batch_write.AddPut(SinglePutRequest(tableStoreKey(name_space, table_uuid, table.commit_time()), table.SerializeAsString(), true)); + + addPotentialLargeKVToBatchwrite( + metastore_ptr, + batch_write, + name_space, + tableStoreKey(name_space, table_uuid, table.commit_time()), + table.SerializeAsString(), + true/*if_not_exists*/); Protos::TableIdentifier identifier; identifier.set_database(table.database()); @@ -936,7 +980,6 @@ void MetastoreProxy::prepareAddDataParts( if 
(parts.empty()) return; - std::unordered_set existing_partitions{current_partitions.begin(), current_partitions.end()}; std::unordered_set partitions_found_in_deleting_set; std::unordered_map partition_map; @@ -959,18 +1002,23 @@ void MetastoreProxy::prepareAddDataParts( batch_write.AddPut(SinglePutRequest(manifestKeyForPart(name_space, table_uuid, txn_id, info_ptr->getPartName()), part_meta)); if (deleting_partitions.count(info_ptr->partition_id) && !partitions_found_in_deleting_set.count(info_ptr->partition_id)) - { partitions_found_in_deleting_set.emplace(info_ptr->partition_id); - partition_map.emplace(info_ptr->partition_id, it->partition_minmax()); - } - if (!existing_partitions.count(info_ptr->partition_id) && !partition_map.count(info_ptr->partition_id)) + if (!partition_map.count(info_ptr->partition_id)) partition_map.emplace(info_ptr->partition_id, it->partition_minmax()); } if (update_sync_list) batch_write.AddPut(SinglePutRequest(syncListKey(name_space, table_uuid, commit_time), std::to_string(commit_time))); + // Prepare partition metadata. 
Skip those already exists non-deleting partitions + for (const auto & exist_partition : current_partitions) + { + auto it = partition_map.find(exist_partition); + if (it != partition_map.end() && !partitions_found_in_deleting_set.count(exist_partition)) + partition_map.erase(it); + } + Protos::PartitionMeta partition_model; for (auto it = partition_map.begin(); it != partition_map.end(); it++) { @@ -994,7 +1042,6 @@ void MetastoreProxy::prepareAddStagedParts( if (parts.empty()) return; - std::unordered_set existing_partitions{current_partitions.begin(), current_partitions.end()}; std::unordered_map partition_map; size_t expected_staged_part_size = expected_staged_parts.size(); if (expected_staged_part_size != static_cast(parts.size())) @@ -1006,10 +1053,18 @@ void MetastoreProxy::prepareAddStagedParts( String part_meta = it->SerializeAsString(); batch_write.AddPut(SinglePutRequest(stagedDataPartKey(name_space, table_uuid, info_ptr->getPartName()), part_meta, expected_staged_parts[it - parts.begin()])); - if (!existing_partitions.count(info_ptr->partition_id) && !partition_map.count(info_ptr->partition_id)) + if (!partition_map.count(info_ptr->partition_id)) partition_map.emplace(info_ptr->partition_id, it->partition_minmax()); } + // Prepare partition metadata. 
Skip those already exists partitions + for (const auto & exist_partition : current_partitions) + { + auto it = partition_map.find(exist_partition); + if (it != partition_map.end()) + partition_map.erase(it); + } + Protos::PartitionMeta partition_model; for (auto & it : partition_map) { @@ -1219,9 +1274,11 @@ std::vector> MetastoreProxy::getTransactionRecords(con return metastore_ptr->multiGet(txn_keys); } -IMetaStore::IteratorPtr MetastoreProxy::getAllTransactionRecord(const String & name_space, const size_t & max_result_number) +IMetaStore::IteratorPtr +MetastoreProxy::getAllTransactionRecord(const String & name_space, const String & start_key, const size_t & max_result_number) { - return metastore_ptr->getByPrefix(escapeString(name_space) + "_" + TRANSACTION_RECORD_PREFIX, max_result_number); + return metastore_ptr->getByPrefix( + escapeString(name_space) + "_" + TRANSACTION_RECORD_PREFIX, max_result_number, DEFAULT_SCAN_BATCH_COUNT, start_key); } std::pair MetastoreProxy::updateTransactionRecord(const String & name_space, const UInt64 & txn_id, const String & txn_data_old, const String & txn_data_new) @@ -2372,6 +2429,56 @@ void MetastoreProxy::removeSQLBinding(const String & name_space, const String & metastore_ptr->batchWrite(batch_write, resp); } +void MetastoreProxy::updatePreparedStatement(const String & name_space, const PreparedStatementItemPtr & data) +{ + BatchCommitRequest batch_write; + + Protos::PreparedStatementItem prepared_statement; + prepared_statement.set_name(data->name); + prepared_statement.set_create_statement(data->create_statement); + batch_write.AddPut(SinglePutRequest(preparedStatementKey(name_space, data->name), prepared_statement.SerializeAsString())); + BatchCommitResponse resp; + metastore_ptr->batchWrite(batch_write, resp); +} + +PreparedStatements MetastoreProxy::getPreparedStatements(const String & name_space) +{ + PreparedStatements res; + auto prepared_prefix = preparedStatementPrefix(name_space); + auto it = 
metastore_ptr->getByPrefix(prepared_prefix); + while (it->next()) + { + Protos::PreparedStatementItem prepared_statement; + prepared_statement.ParseFromString(it->value()); + PreparedStatementItemPtr statement = std::make_shared(prepared_statement.name(), prepared_statement.create_statement()); + res.emplace_back(statement); + } + + return res; +} +PreparedStatementItemPtr MetastoreProxy::getPreparedStatement(const String & name_space, const String & name) +{ + String value; + auto prepared_statement_key = preparedStatementKey(name_space, name); + metastore_ptr->get(prepared_statement_key, value); + + if (value.empty()) + return nullptr; + + Protos::PreparedStatementItem prepared_statement; + prepared_statement.ParseFromString(value); + PreparedStatementItemPtr prepared = std::make_shared(prepared_statement.name(), prepared_statement.create_statement()); + return prepared; +} + +void MetastoreProxy::removePreparedStatement(const String & name_space, const String & name) +{ + BatchCommitRequest batch_write; + batch_write.AddDelete(preparedStatementKey(name_space, name)); + BatchCommitResponse resp; + metastore_ptr->batchWrite(batch_write, resp); +} + void MetastoreProxy::createVirtualWarehouse(const String & name_space, const String & vw_name, const VirtualWarehouseData & data) { auto vw_key = VWKey(name_space, vw_name); @@ -3371,7 +3478,9 @@ std::shared_ptr MetastoreProxy::getSensitive String MetastoreProxy::getAccessEntity(EntityType type, const String & name_space, const String & name) const { String data; - metastore_ptr->get(accessEntityKey(type, name_space, name), data); + String access_entity_key = accessEntityKey(type, name_space, name); + metastore_ptr->get(access_entity_key, data); + tryGetLargeValue(metastore_ptr, name_space, access_entity_key, data); return data; } @@ -3394,7 +3503,16 @@ std::vector> MetastoreProxy::getEntities(EntityType ty requests.push_back(accessEntityKey(type, name_space, s)); } - return metastore_ptr->multiGet(requests); + auto res 
= metastore_ptr->multiGet(requests); + + for (size_t i=0; igetByPrefix(accessEntityPrefix(type, name_space)); while (it->next()) { - models.push_back(it->value()); + String value = it->value(); + /// NOTE: Too many large KVs will cause severe performance regression. + tryGetLargeValue(metastore_ptr, name_space, it->key(), value); + models.push_back(std::move(value)); } return models; } @@ -3431,12 +3552,30 @@ bool MetastoreProxy::putAccessEntity(EntityType type, const String & name_space, BatchCommitRequest batch_write; BatchCommitResponse resp; auto is_rename = !old_access_entity.name().empty() && new_access_entity.name() != old_access_entity.name(); - auto put_access_entity_request = SinglePutRequest(accessEntityKey(type, name_space, new_access_entity.name()), new_access_entity.SerializeAsString(), !replace_if_exists); String uuid = UUIDHelpers::UUIDToString(RPCHelpers::createUUID(new_access_entity.uuid())); String serialized_old_access_entity = old_access_entity.SerializeAsString(); if (!serialized_old_access_entity.empty() && !is_rename) - put_access_entity_request.expected_value = serialized_old_access_entity; - batch_write.AddPut(put_access_entity_request); + { + addPotentialLargeKVToBatchwrite( + metastore_ptr, + batch_write, + name_space, + accessEntityKey(type, name_space, new_access_entity.name()), + new_access_entity.SerializeAsString(), + !replace_if_exists, + serialized_old_access_entity); + } + else + { + addPotentialLargeKVToBatchwrite( + metastore_ptr, + batch_write, + name_space, + accessEntityKey(type, name_space, new_access_entity.name()), + new_access_entity.SerializeAsString(), + !replace_if_exists); + } + batch_write.AddPut(SinglePutRequest(accessEntityUUIDNameMappingKey(name_space, uuid), new_access_entity.name(), !replace_if_exists)); if (is_rename) batch_write.AddDelete(accessEntityKey(type, name_space, old_access_entity.name())); // delete old one in case of rename @@ -3446,21 +3585,22 @@ bool MetastoreProxy::putAccessEntity(EntityType 
type, const String & name_space, } catch (Exception & e) { + auto puts_size = batch_write.puts.size(); if (e.code() == ErrorCodes::METASTORE_COMMIT_CAS_FAILURE) { - if (resp.puts.count(0) && replace_if_exists && !serialized_old_access_entity.empty()) + if (resp.puts.count(puts_size-2) && replace_if_exists && !serialized_old_access_entity.empty()) { throw Exception( "Access Entity has recently been changed in catalog. Please try the request again.", ErrorCodes::METASTORE_ACCESS_ENTITY_CAS_ERROR); } - else if (resp.puts.count(0) && !replace_if_exists) + else if (resp.puts.count(puts_size-2) && !replace_if_exists) { throw Exception( "Access Entity with the same name already exists in catalog. Please use another name and try again.", ErrorCodes::METASTORE_ACCESS_ENTITY_EXISTS_ERROR); } - else if (resp.puts.count(1) && !replace_if_exists) + else if (resp.puts.count(puts_size-1) && !replace_if_exists) { throw Exception( "Access Entity with the same UUID already exists in catalog. Please use another name and try again.", diff --git a/src/Catalog/MetastoreProxy.h b/src/Catalog/MetastoreProxy.h index 0040066ff66..c4e3269c2b5 100644 --- a/src/Catalog/MetastoreProxy.h +++ b/src/Catalog/MetastoreProxy.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -113,6 +114,7 @@ namespace DB::Catalog #define COLUMN_STATISTICS_PREFIX "CS_" #define COLUMN_STATISTICS_TAG_PREFIX "CST_" // deprecated, just remove it #define SQL_BINDING_PREFIX "SBI_" +#define PREPARED_STATEMENT_PREFIX "PSTAT_" #define FILESYS_LOCK_PREFIX "FSLK_" #define UDF_STORE_PREFIX "UDF_" #define MERGEMUTATE_THREAD_START_TIME "MTST_" @@ -132,6 +134,9 @@ namespace DB::Catalog #define MANIFEST_DATA_PREFIX "MFST_" #define MANIFEST_LIST_PREFIX "MFSTS_" +#define LARGE_KV_DATA_PREFIX "LGKV_" +#define LARGE_KV_REFERENCE "LGKVRF_" + using EntityType = IAccessEntity::Type; struct EntityMetastorePrefix { @@ -742,6 +747,20 @@ class MetastoreProxy return ss.str(); } + static String 
preparedStatementKey(const String name_space, const String & key) + { + std::stringstream ss; + ss << escapeString(name_space) << '_' << PREPARED_STATEMENT_PREFIX << '_' << key; + return ss.str(); + } + + static String preparedStatementPrefix(const String name_space) + { + std::stringstream ss; + ss << escapeString(name_space) << '_' << PREPARED_STATEMENT_PREFIX; + return ss.str(); + } + static String tableStatisticKey(const String name_space, const String & uuid, const StatisticsTag & tag) { std::stringstream ss; @@ -954,6 +973,29 @@ class MetastoreProxy return manifestListPrefix(name_space, uuid) + toString(table_version); } + static String largeKVDataPrefix(const String & name_space, const String & uuid) + { + return escapeString(name_space) + '_' + LARGE_KV_DATA_PREFIX + uuid + '_'; + } + + static String largeKVDataKey(const String & name_space, const String & uuid, UInt64 index) + { + // keep records in the kv storage with the same order as index. Support at most 10k sub-kv + std::ostringstream oss; + oss << std::setw(5) << std::setfill('0') << index; + return largeKVDataPrefix(name_space, uuid) + oss.str(); + } + + static String largeKVReferencePrefix(const String & name_space) + { + return escapeString(name_space) + '_' + LARGE_KV_REFERENCE; + } + + static String largeKVReferenceKey(const String & name_space, const String & uuid) + { + return largeKVReferencePrefix(name_space) + uuid; + } + // parse the first key in format of '{prefix}{escapedString(first_key)}_postfix' // note that prefix should contains _, like TCS_ // return [first_key, postfix] @@ -965,7 +1007,8 @@ class MetastoreProxy void removeTransactionRecord(const String & name_space, const UInt64 & txn_id); void removeTransactionRecords(const String & name_space, const std::vector & txn_ids); String getTransactionRecord(const String & name_space, const UInt64 & txn_id); - IMetaStore::IteratorPtr getAllTransactionRecord(const String & name_space, const size_t & max_result_number = 0); + 
IMetaStore::IteratorPtr + getAllTransactionRecord(const String & name_space, const String & start_key = "", const size_t & max_result_number = 0); std::pair updateTransactionRecord(const String & name_space, const UInt64 & txn_id, const String & txn_data_old, const String & txn_data_new); std::vector> getTransactionRecords(const String & name_space, const std::vector & txn_ids); @@ -1037,7 +1080,7 @@ class MetastoreProxy void updateTableWithID(const String & name_space, const Protos::TableIdentifier & table_id, const DB::Protos::DataModelTable & table_data); void getTableByUUID(const String & name_space, const String & table_uuid, Strings & tables_info); void clearTableMeta(const String & name_space, const String & database, const String & table, const String & uuid, const Strings & dependencies, const UInt64 & ts = 0); - static void prepareRenameTable(const String & name_space, const String & table_uuid, const String & from_db, const String & from_table, const UUID & to_db_uuid, Protos::DataModelTable & to_table, BatchCommitRequest & batch_write); + void prepareRenameTable(const String & name_space, const String & table_uuid, const String & from_db, const String & from_table, const UUID & to_db_uuid, Protos::DataModelTable & to_table, BatchCommitRequest & batch_write); bool alterTable(const String & name_space, const Protos::DataModelTable & table, const Strings & masks_to_remove, const Strings & masks_to_add); Strings getAllTablesInDB(const String & name_space, const String & database); IMetaStore::IteratorPtr getAllTablesMeta(const String & name_space); @@ -1224,6 +1267,11 @@ class MetastoreProxy SQLBindingItemPtr getSQLBinding(const String & name_space, const String & uuid, const String & tenant_id, const bool & is_re_expression); void removeSQLBinding(const String & name_space, const String & uuid, const String & tenant_id, const bool & is_re_expression); + void updatePreparedStatement(const String & name_space, const PreparedStatementItemPtr & data); + 
PreparedStatements getPreparedStatements(const String & name_space); + PreparedStatementItemPtr getPreparedStatement(const String & name_space, const String & name); + void removePreparedStatement(const String & name_space, const String & name); + void updateTableStatistics(const String & name_space, const String & uuid, const std::unordered_map & data); // new api std::unordered_map getTableStatistics(const String & name_space, const String & uuid); diff --git a/src/Client/Connection.cpp b/src/Client/Connection.cpp index 223f713e273..14abe7f8897 100644 --- a/src/Client/Connection.cpp +++ b/src/Client/Connection.cpp @@ -1040,6 +1040,11 @@ Packet Connection::receivePacket() case Protocol::Server::ReadTaskRequest: return res; + + case Protocol::Server::TimezoneUpdate: + readStringBinary(server_timezone, *in); + res.server_timezone = server_timezone; + return res; default: /// In unknown state, disconnect - to not leave unsynchronised connection. diff --git a/src/Client/Connection.h b/src/Client/Connection.h index 7de4ceac26d..8ad348c71d4 100644 --- a/src/Client/Connection.h +++ b/src/Client/Connection.h @@ -96,6 +96,8 @@ struct Packet BlockStreamProfileInfo profile_info; std::vector part_uuids; + std::string server_timezone; + Packet() : type(Protocol::Server::Hello) {} }; diff --git a/src/Client/HedgedConnections.cpp b/src/Client/HedgedConnections.cpp index 8455ef3117e..355cd67544f 100644 --- a/src/Client/HedgedConnections.cpp +++ b/src/Client/HedgedConnections.cpp @@ -262,6 +262,7 @@ Packet HedgedConnections::drain() case Protocol::Server::Totals: case Protocol::Server::Extremes: case Protocol::Server::EndOfStream: + case Protocol::Server::TimezoneUpdate: break; case Protocol::Server::Exception: diff --git a/src/Client/MultiplexedConnections.cpp b/src/Client/MultiplexedConnections.cpp index 3eee6eb7f0f..5489e85f4e6 100644 --- a/src/Client/MultiplexedConnections.cpp +++ b/src/Client/MultiplexedConnections.cpp @@ -305,6 +305,7 @@ Packet 
MultiplexedConnections::drain() case Protocol::Server::Totals: case Protocol::Server::Extremes: case Protocol::Server::EndOfStream: + case Protocol::Server::TimezoneUpdate: break; case Protocol::Server::ProfileInfo: @@ -383,6 +384,7 @@ Packet MultiplexedConnections::receivePacketUnlocked(AsyncCallback async_callbac case Protocol::Server::Totals: case Protocol::Server::Extremes: case Protocol::Server::Log: + case Protocol::Server::TimezoneUpdate: break; case Protocol::Server::EndOfStream: diff --git a/src/CloudServices/CloudMergeTreeDedupWorker.cpp b/src/CloudServices/CloudMergeTreeDedupWorker.cpp index 364dc981ce1..c82edb21282 100644 --- a/src/CloudServices/CloudMergeTreeDedupWorker.cpp +++ b/src/CloudServices/CloudMergeTreeDedupWorker.cpp @@ -69,7 +69,7 @@ CloudMergeTreeDedupWorker::CloudMergeTreeDedupWorker(StorageCloudMergeTree & sto { /// init current_deduper before iterate std::lock_guard lock(current_deduper_mutex); - current_deduper = std::make_unique(storage, context); + current_deduper = std::make_unique(storage, context, CnchDedupHelper::DedupMode::UPSERT); } if (storage.getSettings()->duplicate_auto_repair) @@ -218,7 +218,7 @@ void CloudMergeTreeDedupWorker::iterate() std::vector locks_to_acquire = CnchDedupHelper::getLocksToAcquire( scope, txn->getTransactionID(), *cnch_table, storage.getSettings()->unique_acquire_write_lock_timeout.value.totalMilliseconds()); lock_watch.restart(); - cnch_lock = txn->createLockHolder(std::move(locks_to_acquire)); + cnch_lock = std::make_shared(context, std::move(locks_to_acquire)); if (!cnch_lock->tryLock()) { if (auto unique_table_log = context->getCloudUniqueTableLog()) @@ -261,6 +261,8 @@ void CloudMergeTreeDedupWorker::iterate() return; } + txn->appendLockHolder(cnch_lock); + /// Sorts by commit time std::sort(staged_parts.begin(), staged_parts.end(), [](auto & lhs, auto & rhs) { return lhs->commit_time < rhs->commit_time; diff --git a/src/CloudServices/CnchBGThreadCommon.h b/src/CloudServices/CnchBGThreadCommon.h 
index 2725e0b8ac2..73822fd1fb5 100644 --- a/src/CloudServices/CnchBGThreadCommon.h +++ b/src/CloudServices/CnchBGThreadCommon.h @@ -15,8 +15,18 @@ #pragma once +#include +#include +#include + namespace DB { +namespace ErrorCodes +{ + extern const int UNKNOWN_CNCH_BG_THREAD_ACTION; + extern const int UNKNOWN_CNCH_BG_THREAD_TYPE; +} + /** * Use enum in nested namespace instead enum class. * Because we want to pass it to protos easily, while it offers weaker compile-time check. @@ -24,10 +34,14 @@ namespace DB */ namespace CnchBGThread { + /// NOTE: when introducing a new type, remember to update + /// 1. {Server|Daemon}{Min|Max}Type accordingly + /// 2. toString(CnchBGThreadType type) enum Type : unsigned int { Empty = 0, + /// server types PartGC = 1, MergeMutate = 2, Consumer = 3, @@ -40,22 +54,22 @@ namespace CnchBGThread PartMover = 10, ManifestCheckpoint = 11, - ServerMinType = PartGC, - ServerMaxType = ManifestCheckpoint, - + /// DM types GlobalGC = 20, /// reserve several entries TxnGC = 21, AutoStatistics = 22, - DaemonMinType = GlobalGC, - DaemonMaxType = AutoStatistics, - ResourceReport = 30, /// worker - WorkerMinType = ResourceReport, /// Enum to mark start of worker types - WorkerMaxType = ResourceReport, /// Enum to mark end of worker types + /// worker types (perhaps this should not be included in CnchBGThread?) 
+ ResourceReport = 30, }; - constexpr unsigned int NumType = WorkerMaxType + 1; + constexpr unsigned int ServerMinType = PartGC; + constexpr unsigned int ServerMaxType = ManifestCheckpoint; + constexpr unsigned int NumServerType = ServerMaxType + 1; + constexpr unsigned int DaemonMinType = GlobalGC; + constexpr unsigned int DaemonMaxType = AutoStatistics; + /// when introducing a new type, remember to update toCnchBGThreadAction() enum Action : unsigned int { Start = 0, @@ -116,14 +130,29 @@ constexpr auto toString(CnchBGThreadType type) __builtin_unreachable(); } -constexpr auto isServerBGThreadType(CnchBGThreadType t) +constexpr auto isServerBGThreadType(size_t t) +{ + return CnchBGThread::ServerMinType <= t && t <= CnchBGThread::ServerMaxType; +} + +inline CnchBGThreadType toServerBGThreadType(size_t t) +{ + if (unlikely(!isServerBGThreadType(t))) + throw Exception(ErrorCodes::UNKNOWN_CNCH_BG_THREAD_TYPE, "Unknown server bg thread type: {}", t); + return static_cast(t); +} + +constexpr auto iDaemonBGThreadType(size_t t) { - return CnchBGThreadType::ServerMinType <= t && t <= CnchBGThreadType::ServerMaxType; + return CnchBGThread::DaemonMinType <= t && t <= CnchBGThread::DaemonMaxType; } -constexpr auto iDaemonBGThreadType(CnchBGThreadType t) +inline CnchBGThreadAction toCnchBGThreadAction(size_t action) { - return CnchBGThreadType::DaemonMinType <= t && t <= CnchBGThreadType::DaemonMaxType; + if (unlikely(action > CnchBGThreadAction::Wakeup)) + throw Exception(ErrorCodes::UNKNOWN_CNCH_BG_THREAD_ACTION, "Unknown bg thread action: {}", action); + + return static_cast(action); } constexpr auto toString(CnchBGThreadAction action) diff --git a/src/CloudServices/CnchBGThreadPartitionSelector.cpp b/src/CloudServices/CnchBGThreadPartitionSelector.cpp index 0766340e8ab..29edcefa173 100644 --- a/src/CloudServices/CnchBGThreadPartitionSelector.cpp +++ b/src/CloudServices/CnchBGThreadPartitionSelector.cpp @@ -48,7 +48,7 @@ 
CnchBGThreadPartitionSelector::CnchBGThreadPartitionSelector(ContextMutablePtr g if (!res) break; - auto * col_uuid = checkAndGetColumn(*res.getByName("uuid").column); + auto * col_uuid = checkAndGetColumn(*res.getByName("uuid").column); auto * col_partition = checkAndGetColumn(*res.getByName("partition_id").column); auto * col_insert = checkAndGetColumn(*res.getByName("insert_parts").column); auto * col_insert_time = checkAndGetColumn(*res.getByName("last_insert_time").column); diff --git a/src/CloudServices/CnchBGThreadsMap.cpp b/src/CloudServices/CnchBGThreadsMap.cpp index 26a6d729252..00532b7e4c6 100644 --- a/src/CloudServices/CnchBGThreadsMap.cpp +++ b/src/CloudServices/CnchBGThreadsMap.cpp @@ -237,8 +237,8 @@ void CnchBGThreadsMap::cleanup() CnchBGThreadsMapArray::CnchBGThreadsMapArray(ContextPtr global_context_) : WithContext(global_context_) { - for (auto i = size_t(CnchBGThreadType::ServerMinType); i <= size_t(CnchBGThreadType::ServerMaxType); ++i) - threads_array[i] = std::make_unique(global_context_, CnchBGThreadType(i)); + for (auto i = CnchBGThread::ServerMinType; i <= CnchBGThread::ServerMaxType; ++i) + threads_array[i] = std::make_unique(global_context_, static_cast(i)); if (global_context_->getServerType() == ServerType::cnch_worker && global_context_->getResourceManagerClient()) { @@ -263,9 +263,9 @@ CnchBGThreadsMapArray::~CnchBGThreadsMapArray() void CnchBGThreadsMapArray::shutdown() { - ThreadPool pool(size_t(CnchBGThreadType::ServerMaxType) - size_t(CnchBGThreadType::ServerMinType) + 1); + ThreadPool pool(CnchBGThread::ServerMaxType - CnchBGThread::ServerMinType + 1); - for (auto i = size_t(CnchBGThreadType::ServerMinType); i <= size_t(CnchBGThreadType::ServerMaxType); ++i) + for (auto i = CnchBGThread::ServerMinType; i <= CnchBGThread::ServerMaxType; ++i) { if (auto * t = threads_array[i].get()) pool.scheduleOrThrowOnError([t] { t->stopAll(); }); @@ -285,7 +285,7 @@ void CnchBGThreadsMapArray::cleanThread() { try { - for (auto i = 
size_t(CnchBGThreadType::ServerMinType); i <= size_t(CnchBGThreadType::ServerMaxType); ++i) + for (auto i = CnchBGThread::ServerMinType; i <= CnchBGThread::ServerMaxType; ++i) threads_array[i]->cleanup(); } catch (...) diff --git a/src/CloudServices/CnchBGThreadsMap.h b/src/CloudServices/CnchBGThreadsMap.h index 73f348c62ad..6cf6a32b316 100644 --- a/src/CloudServices/CnchBGThreadsMap.h +++ b/src/CloudServices/CnchBGThreadsMap.h @@ -25,6 +25,11 @@ namespace DB { +namespace ErrorCodes +{ + extern const int UNKNOWN_CNCH_BG_THREAD_TYPE; +} + using UUIDToBGThreads = std::unordered_map; namespace ResourceManagement { @@ -90,20 +95,13 @@ class CnchBGThreadsMapArray : protected WithContext, private boost::noncopyable inline CnchBGThreadsMap * at(size_t type) { - try - { - auto * res = threads_array.at(type).get(); - if (unlikely(!res)) - { - throw Exception(ErrorCodes::LOGICAL_ERROR, "CnchBGThread for type {} is not initialized.", toString(static_cast(type))); - } - return res; - } - catch(...) - { - /// Show a better exception message. - throw Exception(ErrorCodes::LOGICAL_ERROR, "CnchBGThread for type {} is not initialized. 
Maybe the enum CnchBGThread is mismatch.", toString(static_cast(type))); - } + if (unlikely(!isServerBGThreadType(type))) + throw Exception(ErrorCodes::UNKNOWN_CNCH_BG_THREAD_TYPE, "Unknown server bg thread type: {}", type); + + auto * res = threads_array.at(type).get(); + if (unlikely(!res)) + throw Exception(ErrorCodes::LOGICAL_ERROR, "CnchBGThread for type {} is not initialized.", type); + return res; } void cleanThread(); @@ -113,7 +111,7 @@ class CnchBGThreadsMapArray : protected WithContext, private boost::noncopyable bool isResourceReportRegistered(); private: - std::array, CnchBGThread::NumType> threads_array; + std::array, CnchBGThread::NumServerType> threads_array; std::unique_ptr resource_reporter_task; diff --git a/src/CloudServices/CnchCreateQueryHelper.cpp b/src/CloudServices/CnchCreateQueryHelper.cpp index 8d660cc7c7d..faed0b9d04a 100644 --- a/src/CloudServices/CnchCreateQueryHelper.cpp +++ b/src/CloudServices/CnchCreateQueryHelper.cpp @@ -15,12 +15,15 @@ #include +#include #include #include #include #include +#include #include #include +#include #include #include #include @@ -32,6 +35,12 @@ namespace DB { +namespace ErrorCodes +{ + extern const int DUPLICATE_COLUMN; + extern const int INCORRECT_QUERY; +} + std::shared_ptr getASTCreateQueryFromString(const String & query, const ContextPtr & context) { ParserCreateQuery parser_create; @@ -51,6 +60,77 @@ std::shared_ptr getASTCreateQueryFromStorage(const IStorage & st return getASTCreateQueryFromString(storage.getCreateTableSql(), context); } +StoragePtr createStorageFromQuery(ASTCreateQuery & create_query, ContextMutablePtr context) +{ + ColumnsDescription columns; + IndicesDescription indices; + ConstraintsDescription constraints; + ForeignKeysDescription foreign_keys; + UniqueNotEnforcedDescription unique_not_enforced; + + if (create_query.columns_list) + { + if (create_query.columns_list->columns) + { + // Set attach = true to avoid making columns nullable due to ANSI settings, because the dialect 
change + // should NOT affect existing tables. + columns = InterpreterCreateQuery::getColumnsDescription(*create_query.columns_list->columns, context, /* attach= */ true); + } + + if (create_query.columns_list->indices) + for (const auto & index : create_query.columns_list->indices->children) + indices.push_back(IndexDescription::getIndexFromAST(index->clone(), columns, context)); + + if (create_query.columns_list->constraints) + for (const auto & constraint : create_query.columns_list->constraints->children) + constraints.constraints.push_back(std::dynamic_pointer_cast(constraint->clone())); + + if (create_query.columns_list->foreign_keys) + for (const auto & foreign_key : create_query.columns_list->foreign_keys->children) + foreign_keys.foreign_keys.push_back(std::dynamic_pointer_cast(foreign_key->clone())); + + if (create_query.columns_list->unique) + for (const auto & unique : create_query.columns_list->unique->children) + unique_not_enforced.unique.push_back(std::dynamic_pointer_cast(unique->clone())); + } + else + throw Exception("Incorrect CREATE query: required list of column descriptions or AS section or SELECT.", ErrorCodes::INCORRECT_QUERY); + + /// Even if query has list of columns, canonicalize it (unfold Nested columns). 
+ ASTPtr new_columns = InterpreterCreateQuery::formatColumns(columns, ParserSettings::valueOf(context->getSettingsRef())); + ASTPtr new_indices = InterpreterCreateQuery::formatIndices(indices); + ASTPtr new_constraints = InterpreterCreateQuery::formatConstraints(constraints); + ASTPtr new_foreign_keys = InterpreterCreateQuery::formatForeignKeys(foreign_keys); + ASTPtr new_unique_not_enforced = InterpreterCreateQuery::formatUnique(unique_not_enforced); + + if (create_query.columns_list->columns) + create_query.columns_list->replace(create_query.columns_list->columns, new_columns); + + if (create_query.columns_list->indices) + create_query.columns_list->replace(create_query.columns_list->indices, new_indices); + + if (create_query.columns_list->constraints) + create_query.columns_list->replace(create_query.columns_list->constraints, new_constraints); + + if (create_query.columns_list->foreign_keys) + create_query.columns_list->replace(create_query.columns_list->foreign_keys, new_foreign_keys); + + if (create_query.columns_list->unique) + create_query.columns_list->replace(create_query.columns_list->unique, new_unique_not_enforced); + + /// Check for duplicates + std::set all_columns; + for (const auto & column : columns) + { + if (!all_columns.emplace(column.name).second) + throw Exception("Column " + backQuoteIfNeed(column.name) + " already exists", ErrorCodes::DUPLICATE_COLUMN); + } + + /// Table constructing + return StorageFactory::instance().get(create_query, "", context, context->getGlobalContext(), columns, constraints, foreign_keys, unique_not_enforced, false); +} + +/// TODO: impl based on createStorageFromQuery(create_query, context) ? 
StoragePtr createStorageFromQuery(const String & query, const ContextPtr & context) { auto ast = getASTCreateQueryFromString(query, context); @@ -90,23 +170,35 @@ StoragePtr createStorageFromQuery(const String & query, const ContextPtr & conte false /*has_force_restore_data_flag*/); } -void replaceCnchWithCloud(ASTCreateQuery & create_query, const String & new_table_name, const String & cnch_db, const String & cnch_table) +void replaceCnchWithCloud( + ASTStorage * storage, + const String & cnch_database, + const String & cnch_table, + WorkerEngineType engine_type, + const Strings & engine_args) { - if (!new_table_name.empty()) - create_query.table = new_table_name; - - auto * storage = create_query.storage; + if (!startsWith(storage->engine->name, "Cnch")) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Expect Cnch-family Engine but got {}", storage->engine->name); auto engine = std::make_shared(); - if (auto pos = storage->engine->name.find("Cnch"); pos != std::string::npos) - engine->name = String(storage->engine->name).replace(pos, strlen("Cnch"), "Cloud"); - + engine->name = storage->engine->name.replace(0, strlen("Cnch"), toString(engine_type)); engine->arguments = std::make_shared(); - engine->arguments->children.push_back(std::make_shared(cnch_db)); - engine->arguments->children.push_back(std::make_shared(cnch_table)); - if (storage->unique_key && storage->engine->arguments && storage->engine->arguments->children.size()) - /// NOTE: Used to pass the version column for unique table here. 
- engine->arguments->children.push_back(storage->engine->arguments->children[0]); + engine->arguments->children.emplace_back(std::make_shared(cnch_database)); + engine->arguments->children.emplace_back(std::make_shared(cnch_table)); + if (!engine_args.empty()) + { + for (const auto & arg : engine_args) + { + engine->arguments->children.emplace_back(std::make_shared(arg)); + } + } + else if (storage->engine->arguments) + { + for (const auto & arg : storage->engine->arguments->children) + { + engine->arguments->children.push_back(arg); + } + } storage->set(storage->engine, engine); } diff --git a/src/CloudServices/CnchCreateQueryHelper.h b/src/CloudServices/CnchCreateQueryHelper.h index fe8d5afa12d..e64371ef9e3 100644 --- a/src/CloudServices/CnchCreateQueryHelper.h +++ b/src/CloudServices/CnchCreateQueryHelper.h @@ -16,6 +16,7 @@ #pragma once #include +#include #include #include #include @@ -23,8 +24,27 @@ namespace DB { +/// used in worker RPC, don't break backward compatibility +enum class WorkerEngineType : uint8_t +{ + CLOUD = 0, // CloudMergeTree + DICT = 1, // DictCloudMergeTree (for BitEngine dict table) +}; + +inline static String toString(WorkerEngineType type) +{ + switch (type) + { + case WorkerEngineType::CLOUD: + return "Cloud"; + case WorkerEngineType::DICT: + return "DictCloud"; + } +} + class ASTCreateQuery; class ASTSetQuery; +class ASTStorage; /// see Databases/DatabaseOnDisk.h extern String getObjectDefinitionFromCreateQuery(const ASTPtr & query, std::optional attach); @@ -32,9 +52,17 @@ extern String getObjectDefinitionFromCreateQuery(const ASTPtr & query, std::opti std::shared_ptr getASTCreateQueryFromString(const String & query, const ContextPtr & context); std::shared_ptr getASTCreateQueryFromStorage(const IStorage & storage, const ContextPtr & context); +StoragePtr createStorageFromQuery(ASTCreateQuery & create_query, ContextMutablePtr context); StoragePtr createStorageFromQuery(const String & query, const ContextPtr & context); -void 
replaceCnchWithCloud(ASTCreateQuery & create_query, const String & new_table_name, const String & cnch_db, const String & cnch_table); +/// change storage engine from Cnch-family to Cloud-family +/// TODO: can we get rid of engine_args? +void replaceCnchWithCloud( + ASTStorage * storage, + const String & cnch_database, + const String & cnch_table, + WorkerEngineType engine_type = WorkerEngineType::CLOUD, + const Strings & engine_args = {}); void modifyOrAddSetting(ASTSetQuery & set_query, const String & name, Field value); void modifyOrAddSetting(ASTCreateQuery & create_query, const String & name, Field value); diff --git a/src/CloudServices/CnchDataWriter.cpp b/src/CloudServices/CnchDataWriter.cpp index 86c6a9b52df..b6b7d10ead7 100644 --- a/src/CloudServices/CnchDataWriter.cpp +++ b/src/CloudServices/CnchDataWriter.cpp @@ -16,7 +16,7 @@ #include #include #include - +#include #include #include #include @@ -50,9 +50,9 @@ namespace ProfileEvents { -extern const Event CnchWriteDataElapsedMilliseconds; + extern const Event CnchWriteDataElapsedMilliseconds; + extern const Event PreloadSubmitTotalOps; } - namespace DB { namespace ErrorCodes @@ -63,9 +63,17 @@ namespace ErrorCodes extern const int BUCKET_TABLE_ENGINE_MISMATCH; } +bool DumpedData::isEmpty() const +{ + return parts.empty() && bitmaps.empty() && staged_parts.empty(); +} + void DumpedData::extend(DumpedData && data) { - auto extendImpl = [](auto & src, auto && dst) { + if (data.isEmpty()) + return; + + auto extendImpl = [] (auto & src, auto && dst) { if (src.empty()) { src = std::move(dst); @@ -80,6 +88,10 @@ void DumpedData::extend(DumpedData && data) extendImpl(parts, std::move(data.parts)); extendImpl(bitmaps, std::move(data.bitmaps)); extendImpl(staged_parts, std::move(data.staged_parts)); + + if (dedup_mode != data.dedup_mode) + throw Exception( + ErrorCodes::LOGICAL_ERROR, "Dedup mode is mismatch, {}/{}", typeToString(dedup_mode), typeToString(data.dedup_mode)); } using DumpCancelPred = std::function; 
@@ -145,7 +157,7 @@ DumpedData CnchDataWriter::dumpAndCommitCnchParts( { if (temp_parts.empty() && temp_bitmaps.empty() && temp_staged_parts.empty()) // Nothing to dump and commit, returns - return {}; + return {.dedup_mode = dedup_mode}; LOG_DEBUG( storage.getLogger(), @@ -172,7 +184,7 @@ DumpedData CnchDataWriter::dumpCnchParts( { if (temp_parts.empty() && temp_bitmaps.empty() && temp_staged_parts.empty()) // Nothing to dump, returns - return {}; + return {.dedup_mode = dedup_mode}; Stopwatch watch; @@ -211,7 +223,7 @@ DumpedData CnchDataWriter::dumpCnchParts( auto txn_id = curr_txn->getTransactionID(); /// Write undo buffer first before dump to vfs - std::vector undo_resources; + UndoResources undo_resources; undo_resources.reserve(temp_parts.size() + temp_bitmaps.size() + temp_staged_parts.size()); /// For local parts and stage parts, the remote parts can be at different disk, /// so we record the disk name of each part in the undo buffer. @@ -269,39 +281,80 @@ DumpedData CnchDataWriter::dumpCnchParts( } /// Parallel dumping to shared storage - DumpedData result; + DumpedData result{.dedup_mode = dedup_mode}; S3ObjectMetadata::PartGeneratorID part_generator_id(S3ObjectMetadata::PartGeneratorID::TRANSACTION, curr_txn->getTransactionID().toString()); MergeTreeCNCHDataDumper dumper(storage, part_generator_id); watch.restart(); - ThreadPool dump_pool(std::min( - static_cast(storage.getSettings()->cnch_parallel_dumping_threads), std::max(temp_staged_parts.size(), temp_parts.size()))); + size_t pool_size = std::min(static_cast(storage.getSettings()->cnch_parallel_dumping_threads), std::max(temp_staged_parts.size(), temp_parts.size())); + /// make sure pool_size >= 1 + pool_size = pool_size >= 1 ? 
pool_size : 1; result.parts.resize(temp_parts.size()); - /// TODO: only use pool if > 1 parts - for (size_t i = 0; i < temp_parts.size(); ++i) - { - dump_pool.scheduleOrThrowOnError([&, i]() { + /// parallel dump delete bitmaps + // TODO: dump all bitmaps to one file to avoid creating too many small files on vfs + result.bitmaps = dumpDeleteBitmaps(storage, temp_bitmaps); + result.staged_parts.resize(temp_staged_parts.size()); + + auto dump_parts = [&, this](size_t i) -> void { + for (; i < temp_parts.size(); i += pool_size) + { const auto & temp_part = temp_parts[i]; auto dumped_part = dumper.dumpTempPart(temp_part, part_disks[i]); LOG_TRACE(storage.getLogger(), "Dumped part {}", temp_part->name); result.parts[i] = std::move(dumped_part); - }); - } - dump_pool.wait(); - // TODO: dump all bitmaps to one file to avoid creating too many small files on vfs - result.bitmaps = dumpDeleteBitmaps(storage, temp_bitmaps); - result.staged_parts.resize(temp_staged_parts.size()); - for (size_t i = 0; i < temp_staged_parts.size(); ++i) - { - dump_pool.scheduleOrThrowOnError([&, i]() { + } + }; + + auto dump_staged_parts = [&, this](size_t i) -> void { + for (; i < temp_staged_parts.size(); i += pool_size) + { const auto & temp_staged_part = temp_staged_parts[i]; auto staged_part = dumper.dumpTempPart(temp_staged_part, part_disks[i + temp_parts.size()]); LOG_TRACE(storage.getLogger(), "Dumped staged part {}", temp_staged_part->name); result.staged_parts[i] = std::move(staged_part); - }); + } + }; + + if (pool_size > 1) + { + ThreadPool dump_pool(pool_size); + for (size_t thread_id = 1; thread_id <= pool_size; thread_id++) + { + dump_pool.scheduleOrThrowOnError([&dump_parts, i = thread_id - 1, thread_group = CurrentThread::getGroup()] + { + SCOPE_EXIT_SAFE({ + if (thread_group) + CurrentThread::detachQueryIfNotDetached(); + }); + if (thread_group) + CurrentThread::attachTo(thread_group); + dump_parts(i); + }); + } + dump_pool.wait(); + + for (size_t thread_id = 1; thread_id <= 
pool_size; thread_id++) + { + dump_pool.scheduleOrThrowOnError([&dump_staged_parts, i = thread_id - 1, thread_group = CurrentThread::getGroup()] + { + SCOPE_EXIT_SAFE({ + if (thread_group) + CurrentThread::detachQueryIfNotDetached(); + }); + if (thread_group) + CurrentThread::attachTo(thread_group); + dump_staged_parts(i); + }); + } + dump_pool.wait(); + } + else + { + assert(pool_size == 1); + dump_parts(0); + dump_staged_parts(0); } - dump_pool.wait(); LOG_DEBUG( storage.getLogger(), @@ -326,6 +379,7 @@ void CnchDataWriter::commitDumpedParts(const DumpedData & dumped_data) return; TxnTimestamp txn_id = context->getCurrentTransactionID(); + UInt32 dedup_impl_version = 0; try { @@ -335,8 +389,7 @@ void CnchDataWriter::commitDumpedParts(const DumpedData & dumped_data) if (settings.debug_cnch_force_commit_parts_rpc) { auto server_client = context->getCnchServerClient("0.0.0.0", context->getRPCPort()); - server_client->commitParts(txn_id, type, storage, dumped_parts, delete_bitmaps, dumped_staged_parts, task_id, false, - consumer_group, tpl, binlog, peak_memory_usage); + server_client->commitParts(txn_id, type, storage, dumped_data, task_id, false, consumer_group, tpl, binlog, peak_memory_usage); } else { @@ -347,7 +400,8 @@ void CnchDataWriter::commitDumpedParts(const DumpedData & dumped_data) { auto is_server = context->getServerType() == ServerType::cnch_server; CnchServerClientPtr server_client; - if (auto worker_txn = dynamic_pointer_cast(context->getCurrentTransaction()); worker_txn && worker_txn->tryGetServerClient()) + auto worker_txn = dynamic_pointer_cast(context->getCurrentTransaction()); + if (worker_txn && worker_txn->tryGetServerClient()) { /// case: client submits INSERTs directly to worker server_client = worker_txn->getServerClient(); @@ -362,8 +416,9 @@ void CnchDataWriter::commitDumpedParts(const DumpedData & dumped_data) throw Exception("Server with transaction " + txn_id.toString() + " is unknown", ErrorCodes::LOGICAL_ERROR); } - 
server_client->precommitParts( - context, txn_id, type, storage, dumped_parts, delete_bitmaps, dumped_staged_parts, task_id, is_server, consumer_group, tpl, binlog, peak_memory_usage); + dedup_impl_version = server_client->precommitParts(context, txn_id, type, storage, dumped_data, task_id, is_server, consumer_group, tpl, binlog, peak_memory_usage); + if (worker_txn) + worker_txn->setDedupImplVersion(dedup_impl_version); } } catch (const Exception &) @@ -372,20 +427,18 @@ void CnchDataWriter::commitDumpedParts(const DumpedData & dumped_data) throw; } - if (auto part_log = context->getPartLog(storage.getDatabaseName())) - { - // for (auto & dumped_part : dumped_parts) - // part_log->add(PartLog::createElement(PartLogElement::COMMIT_PART, dumped_part, watch.elapsed())); - } + /// part log will be written in InsertAction::postCommit LOG_DEBUG( storage.getLogger(), - "Committed {} parts, {} bitmaps, {} staged parts in transaction {}, elapsed {} ms", + "Committed {} parts, {} bitmaps, {} staged parts in transaction {}, elapsed {} ms, dedup mode is {}, dedup impl version is {}", dumped_parts.size(), delete_bitmaps.size(), dumped_staged_parts.size(), toString(UInt64(txn_id)), - watch.elapsedMilliseconds()); + watch.elapsedMilliseconds(), + typeToString(dumped_data.dedup_mode), + dedup_impl_version); } void CnchDataWriter::initialize(size_t max_threads) @@ -560,7 +613,9 @@ void CnchDataWriter::commitPreparedCnchParts(const DumpedData & dumped_data, con } // Precommit stage. 
Write intermediate parts to KV - auto action = txn->createAction(storage_ptr, dumped_data.parts, dumped_data.bitmaps, dumped_data.staged_parts); + auto action + = txn->createAction(storage_ptr, dumped_data.parts, dumped_data.bitmaps, dumped_data.staged_parts); + action->as()->checkAndSetDedupMode(dumped_data.dedup_mode); txn->appendAction(action); action->executeV2(); } @@ -650,10 +705,18 @@ void CnchDataWriter::commitPreparedCnchParts(const DumpedData & dumped_data, con void CnchDataWriter::publishStagedParts(const MergeTreeDataPartsCNCHVector & staged_parts, const LocalDeleteBitmaps & bitmaps_to_dump) { + if (dedup_mode != CnchDedupHelper::DedupMode::APPEND) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Dedup mode is not append, but got {} when publish staged parts for table {}, it's a bug!", + typeToString(dedup_mode), + storage.getCnchStorageID().getNameForLogs()); DumpedData items; + items.dedup_mode = dedup_mode; + TxnTimestamp txn_id = context->getCurrentTransactionID(); - for (auto & staged_part : staged_parts) + for (const auto & staged_part : staged_parts) { // new part that shares the data file with the staged part Protos::DataModelPart new_part_model; @@ -683,7 +746,7 @@ void CnchDataWriter::publishStagedParts(const MergeTreeDataPartsCNCHVector & sta /// prepare undo resources /// setMetadata() return reference, so need to cast move - std::vector undo_resources; + UndoResources undo_resources; for (auto & part : items.parts) undo_resources.emplace_back( std::move(UndoResource(txn_id, UndoResourceType::Part, part->info.getPartNameWithHintMutation()).setMetadataOnly(true))); @@ -715,23 +778,25 @@ void CnchDataWriter::publishStagedParts(const MergeTreeDataPartsCNCHVector & sta void CnchDataWriter::preload(const MutableMergeTreeDataPartsCNCHVector & dumped_parts) { - if (context->tryGetPreloadThrottler()) - context->tryGetPreloadThrottler()->add(1); - const auto & settings = context->getSettingsRef(); + + if (!settings.parts_preload_level || 
(!storage.getSettings()->parts_preload_level && !storage.getSettings()->enable_preload_parts) + || !(storage.getSettings()->enable_local_disk_cache)) + return; + try { Stopwatch timer; auto server_client = context->getCnchServerClientPool().get(); MutableMergeTreeDataPartsCNCHVector preload_parts; std::copy_if(dumped_parts.begin(), dumped_parts.end(), std::back_inserter(preload_parts), [](const auto & part) { - return !part->deleted && !part->isPartial(); + return !part->deleted; }); if (!preload_parts.empty()) { - auto max_timeout = std::max(30 * 1000L, settings.max_execution_time.totalMilliseconds()); - server_client->submitPreloadTask(storage, preload_parts, max_timeout); + ProfileEvents::increment(ProfileEvents::PreloadSubmitTotalOps, 1, Metrics::MetricType::Rate); + server_client->submitPreloadTask(storage, preload_parts, settings.preload_send_rpc_max_ms); LOG_DEBUG( storage.getLogger(), "Finish submit preload {} task for {} parts to server {}, elapsed {} ms", @@ -740,7 +805,6 @@ void CnchDataWriter::preload(const MutableMergeTreeDataPartsCNCHVector & dumped_ server_client->getRPCAddress(), timer.elapsedMilliseconds()); } - // TODO: invalidate deleted part's disk cache } catch (...) 
{ diff --git a/src/CloudServices/CnchDataWriter.h b/src/CloudServices/CnchDataWriter.h index 91eb44b6209..b837c3e9eb9 100644 --- a/src/CloudServices/CnchDataWriter.h +++ b/src/CloudServices/CnchDataWriter.h @@ -26,6 +26,7 @@ #include #include #include +#include namespace DB { @@ -36,7 +37,9 @@ struct DumpedData MutableMergeTreeDataPartsCNCHVector parts; DeleteBitmapMetaPtrVector bitmaps; MutableMergeTreeDataPartsCNCHVector staged_parts; + CnchDedupHelper::DedupMode dedup_mode = CnchDedupHelper::DedupMode::APPEND; + bool isEmpty() const; void extend(DumpedData && data); }; @@ -87,6 +90,19 @@ class CnchDataWriter : private boost::noncopyable void setPeakMemoryUsage(UInt64 peak_memory_usage_) { peak_memory_usage = peak_memory_usage_; } + void setDedupMode(CnchDedupHelper::DedupMode dedup_mode_) + { + dedup_mode = dedup_mode_; + res.dedup_mode = dedup_mode; + } + + CnchDedupHelper::DedupMode getDedupMode() const + { + return dedup_mode; + } + + bool isNeedDedupStage() const { return dedup_mode != CnchDedupHelper::DedupMode::APPEND; } + DumpedData res; private: @@ -109,6 +125,8 @@ class CnchDataWriter : private boost::noncopyable UInt64 peak_memory_usage; + CnchDedupHelper::DedupMode dedup_mode = CnchDedupHelper::DedupMode::APPEND; + UUID newPartID(const MergeTreePartInfo& part_info, UInt64 txn_timestamp); }; diff --git a/src/CloudServices/CnchDedupHelper.cpp b/src/CloudServices/CnchDedupHelper.cpp index 37faf3706c1..ccc6eba0715 100644 --- a/src/CloudServices/CnchDedupHelper.cpp +++ b/src/CloudServices/CnchDedupHelper.cpp @@ -18,10 +18,17 @@ #include #include #include +#include +#include +#include +#include +#include namespace DB::ErrorCodes { extern const int LOGICAL_ERROR; +extern const int ABORTED; +extern const int CNCH_LOCK_ACQUIRE_FAILED; } namespace DB::CnchDedupHelper @@ -259,4 +266,110 @@ void DedupScope::filterParts(MergeTreeDataPartsCNCHVector & parts) const parts.end()); } +UInt64 getWriteLockTimeout(StorageCnchMergeTree & cnch_table, ContextPtr 
local_context) +{ + UInt64 session_value = local_context->getSettingsRef().unique_acquire_write_lock_timeout.value.totalMilliseconds(); + return session_value == 0 ? cnch_table.getSettings()->unique_acquire_write_lock_timeout.value.totalMilliseconds() : session_value; +} + +void acquireLockAndFillDedupTask(StorageCnchMergeTree & cnch_table, DedupTask & dedup_task, CnchServerTransaction & txn, ContextPtr local_context) +{ + /// Note: when txn is launched by worker, local_context is global context which means session settings will not take effect. TBD: support later. + TxnTimestamp ts; + std::sort(dedup_task.new_parts.begin(), dedup_task.new_parts.end(), [](auto & lhs, auto & rhs) { return lhs->info < rhs->info; }); + std::sort(dedup_task.delete_bitmaps_for_new_parts.begin(), dedup_task.delete_bitmaps_for_new_parts.end(), LessDeleteBitmapMeta()); + CnchLockHolderPtr cnch_lock; + MergeTreeDataPartsCNCHVector visible_parts, staged_parts; + bool force_normal_dedup = false; + Stopwatch watch; + do + { + CnchDedupHelper::DedupScope scope = CnchDedupHelper::getDedupScope(cnch_table, dedup_task.new_parts, force_normal_dedup); + + std::vector locks_to_acquire = CnchDedupHelper::getLocksToAcquire( + scope, txn.getTransactionID(), cnch_table, CnchDedupHelper::getWriteLockTimeout(cnch_table, local_context)); + watch.restart(); + cnch_lock = std::make_shared(local_context, std::move(locks_to_acquire)); + if (!cnch_lock->tryLock()) + { + if (auto unique_table_log = local_context->getCloudUniqueTableLog()) + { + auto current_log = UniqueTable::createUniqueTableLog(UniqueTableLogElement::ERROR, cnch_table.getCnchStorageID()); + current_log.txn_id = txn.getTransactionID(); + current_log.metric = ErrorCodes::CNCH_LOCK_ACQUIRE_FAILED; + current_log.event_msg = "Failed to acquire lock for txn " + txn.getTransactionID().toString(); + unique_table_log->add(current_log); + } + throw Exception("Failed to acquire lock for txn " + txn.getTransactionID().toString(), 
ErrorCodes::CNCH_LOCK_ACQUIRE_FAILED); + } + dedup_task.statistics.acquire_lock_cost += watch.elapsedMilliseconds(); + + watch.restart(); + ts = local_context->getTimestamp(); /// must get a new ts after locks are acquired + visible_parts = CnchDedupHelper::getVisiblePartsToDedup(scope, cnch_table, ts); + staged_parts = CnchDedupHelper::getStagedPartsToDedup(scope, cnch_table, ts); + dedup_task.statistics.get_metadata_cost += watch.elapsedMilliseconds(); + + /// In some case, visible parts or staged parts doesn't have same bucket definition or not a bucket part, we need to convert bucket lock to normal lock. + /// Otherwise, it may lead to duplicated data. + if (scope.isBucketLock() && !cnch_table.getSettings()->enable_bucket_level_unique_keys + && !CnchDedupHelper::checkBucketParts(cnch_table, visible_parts, staged_parts)) + { + force_normal_dedup = true; + cnch_lock->unlock(); + LOG_TRACE(txn.getLogger(), "Check bucket parts failed, switch to normal lock to dedup."); + continue; + } + else + { + /// Filter staged parts if lock scope is bucket level + scope.filterParts(staged_parts); + break; + } + } while (true); + + if (unlikely(local_context->getSettingsRef().unique_sleep_seconds_after_acquire_lock.totalSeconds())) + { + /// Test purpose only + std::this_thread::sleep_for(std::chrono::seconds(local_context->getSettingsRef().unique_sleep_seconds_after_acquire_lock.totalSeconds())); + } + + for (auto & visible_part: visible_parts) + { + dedup_task.visible_parts.emplace_back(std::const_pointer_cast(visible_part)); + for (const auto & bitmap_model : visible_part->delete_bitmap_metas) + dedup_task.delete_bitmaps_for_visible_parts.emplace_back(createFromModel(cnch_table, *bitmap_model)); + } + for (auto & staged_part: staged_parts) + { + dedup_task.staged_parts.emplace_back(std::const_pointer_cast(staged_part)); + for (const auto & bitmap_model: staged_part->delete_bitmap_metas) + dedup_task.delete_bitmaps_for_staged_parts.emplace_back(createFromModel(cnch_table, 
*bitmap_model)); + } + txn.appendLockHolder(cnch_lock); +} + +void executeDedupTask(StorageCnchMergeTree & cnch_table, DedupTask & dedup_task, const TxnTimestamp & txn_id, ContextPtr local_context) +{ + /// Precondition: parts already be sorted. + cnch_table.getDeleteBitmapMetaForCnchParts(dedup_task.visible_parts, dedup_task.delete_bitmaps_for_visible_parts, /*force_found=*/true); + cnch_table.getDeleteBitmapMetaForCnchParts(dedup_task.new_parts, dedup_task.delete_bitmaps_for_new_parts, /*force_found=*/false); + cnch_table.getDeleteBitmapMetaForCnchParts(dedup_task.staged_parts, dedup_task.delete_bitmaps_for_staged_parts, /*force_found=*/false); + MergeTreeDataDeduper deduper(cnch_table, local_context, dedup_task.dedup_mode); + LocalDeleteBitmaps bitmaps_to_dump = deduper.dedupParts( + txn_id, + {dedup_task.visible_parts.begin(), dedup_task.visible_parts.end()}, + {dedup_task.staged_parts.begin(), dedup_task.staged_parts.end()}, + {dedup_task.new_parts.begin(), dedup_task.new_parts.end()}); + + Stopwatch watch; + CnchDataWriter cnch_writer(cnch_table, local_context, ManipulationType::Insert); + cnch_writer.publishStagedParts({dedup_task.staged_parts.begin(), dedup_task.staged_parts.end()}, bitmaps_to_dump); + LOG_DEBUG( + cnch_table.getLogger(), + "Publish staged parts take {} ms, txn id: {}, dedup mode: {}", + watch.elapsedMilliseconds(), + txn_id.toUInt64(), + typeToString(dedup_task.dedup_mode)); +} } diff --git a/src/CloudServices/CnchDedupHelper.h b/src/CloudServices/CnchDedupHelper.h index 33b590d1e90..f5ed9ce548f 100644 --- a/src/CloudServices/CnchDedupHelper.h +++ b/src/CloudServices/CnchDedupHelper.h @@ -20,6 +20,8 @@ #include #include #include +#include +#include #include #include #include @@ -28,23 +30,52 @@ namespace DB { class MergeTreeMetaBase; class StorageCnchMergeTree; +class DeleteBitmapMeta; +using DeleteBitmapMetaPtr = std::shared_ptr; +using DeleteBitmapMetaPtrVector = std::vector; +class CnchServerTransaction; } namespace DB::CnchDedupHelper 
{ +enum class DedupMode : unsigned int +{ + APPEND = 0, + UPSERT, + THROW, + IGNORE +}; + +inline String typeToString(DedupMode type) +{ + switch (type) + { + case DedupMode::APPEND: + return "APPEND"; + case DedupMode::UPSERT: + return "UPSERT"; + case DedupMode::THROW: + return "THROW"; + case DedupMode::IGNORE: + return "IGNORE"; + default: + return "Unknown"; + } +} + class DedupScope { public: - enum class DedupMode + enum class DedupLevel { TABLE, PARTITION, }; - enum class LockMode + enum class LockLevel { NORMAL, /// For NORMAL lock mode, if dedup mode is table, it's table level. Otherwise, it's partition level. BUCKET, /// BUCKET level lock mode. @@ -63,35 +94,35 @@ class DedupScope static DedupScope TableDedup() { - static DedupScope table_scope{DedupMode::TABLE}; + static DedupScope table_scope{DedupLevel::TABLE}; return table_scope; } static DedupScope TableDedupWithBucket(const BucketSet & buckets_) { - DedupScope table_scope{DedupMode::TABLE, LockMode::BUCKET}; + DedupScope table_scope{DedupLevel::TABLE, LockLevel::BUCKET}; table_scope.buckets = buckets_; return table_scope; } static DedupScope PartitionDedup(const NameOrderedSet & partitions_) { - DedupScope partition_scope{DedupMode::PARTITION}; + DedupScope partition_scope{DedupLevel::PARTITION}; partition_scope.partitions = partitions_; return partition_scope; } static DedupScope PartitionDedupWithBucket(const BucketWithPartitionSet & bucket_with_partition_set_) { - DedupScope partition_scope{DedupMode::PARTITION, LockMode::BUCKET}; + DedupScope partition_scope{DedupLevel::PARTITION, LockLevel::BUCKET}; partition_scope.bucket_with_partition_set = bucket_with_partition_set_; for (const auto & bucket_with_partition : partition_scope.bucket_with_partition_set) partition_scope.partitions.insert(bucket_with_partition.first); return partition_scope; } - bool isTableDedup() const { return dedup_mode == DedupMode::TABLE; } - bool isBucketLock() const { return lock_mode == LockMode::BUCKET; } + bool 
isTableDedup() const { return dedup_level == DedupLevel::TABLE; } + bool isBucketLock() const { return lock_level == LockLevel::BUCKET; } const NameOrderedSet & getPartitions() const { return partitions; } @@ -103,10 +134,10 @@ class DedupScope void filterParts(MergeTreeDataPartsCNCHVector & parts) const; private: - DedupScope(DedupMode dedup_mode_, LockMode lock_mode_ = LockMode::NORMAL) : dedup_mode(dedup_mode_), lock_mode(lock_mode_) { } + DedupScope(DedupLevel dedup_level_, LockLevel lock_level_ = LockLevel::NORMAL) : dedup_level(dedup_level_), lock_level(lock_level_) { } - DedupMode dedup_mode; - LockMode lock_mode; + DedupLevel dedup_level; + LockLevel lock_level; NameOrderedSet partitions; BucketSet buckets; @@ -141,4 +172,48 @@ bool checkBucketParts( const MergeTreeDataPartsCNCHVector & visible_parts, const MergeTreeDataPartsCNCHVector & staged_parts); +struct DedupTask +{ + DedupMode dedup_mode; + StorageID storage_id; + MutableMergeTreeDataPartsCNCHVector new_parts; + DeleteBitmapMetaPtrVector delete_bitmaps_for_new_parts; + + MutableMergeTreeDataPartsCNCHVector staged_parts; + DeleteBitmapMetaPtrVector delete_bitmaps_for_staged_parts; + + MutableMergeTreeDataPartsCNCHVector visible_parts; + DeleteBitmapMetaPtrVector delete_bitmaps_for_visible_parts; + + struct Statistics + { + /// Record time cost for each stage(ms) + UInt64 acquire_lock_cost = 0; + UInt64 get_metadata_cost = 0; + UInt64 execute_task_cost = 0; + UInt64 other_cost = 0; + UInt64 total_cost = 0; + + String toString() + { + return fmt::format( + "[acquire lock cost {} ms, get metadata cost {} ms, execute task cost {} ms, other cost {} ms, total cost {} ms]", + acquire_lock_cost, + get_metadata_cost, + execute_task_cost, + other_cost, + total_cost); + } + } statistics; + + explicit DedupTask(const DedupMode & dedup_mode_, const StorageID & storage_id_) : dedup_mode(dedup_mode_), storage_id(storage_id_) { } +}; +using DedupTaskPtr = std::shared_ptr; + +UInt64 
getWriteLockTimeout(StorageCnchMergeTree & cnch_table, ContextPtr local_context); + +void acquireLockAndFillDedupTask(StorageCnchMergeTree & cnch_table, DedupTask & dedup_task, CnchServerTransaction & txn, ContextPtr local_context); + +void executeDedupTask(StorageCnchMergeTree & cnch_table, DedupTask & dedup_task, const TxnTimestamp & txn_id, ContextPtr local_context); + } diff --git a/src/CloudServices/CnchMergeMutateThread.cpp b/src/CloudServices/CnchMergeMutateThread.cpp index b957eb6b28c..87360a13d72 100644 --- a/src/CloudServices/CnchMergeMutateThread.cpp +++ b/src/CloudServices/CnchMergeMutateThread.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -61,8 +62,12 @@ namespace constexpr auto DELAY_SCHEDULE_TIME_IN_SECOND = 60ul; constexpr auto DELAY_SCHEDULE_RANDOM_TIME_IN_SECOND = 3ul; - bool needMutate(const ServerDataPartPtr & part, const TxnTimestamp & commit_ts, bool change_schema) + bool needMutate(const ServerDataPartPtr & part, const TxnTimestamp & commit_ts, bool change_schema, bool is_recluster, const TableDefinitionHash & table_definition_hash) { + if (is_recluster) + { + return !table_definition_hash.match(part->part_model().table_definition_hash()); + } /// Some mutation commands (@see MutationCommands::changeSchema()) will not change the table schema /// which means it will not update columns_commit_time. To track those mutation commands, /// we add a new field `mutation_commit_time` in part metadata. And it's set to 0 for a new part by default. @@ -79,7 +84,7 @@ namespace TxnTimestamp getFirstMutation(const ServerDataPartPtr & part, const std::vector> & mutations) { for (const auto & [commit_time, change_schema] : mutations) - if (needMutate(part, commit_time, change_schema)) + if (needMutate(part, commit_time, change_schema, false, {})) return commit_time; return TxnTimestamp::maxTS(); } @@ -133,7 +138,7 @@ namespace } /// We maintain merging_mutating_parts based on the merge task's lifecycle. 
-/// Source parts are added to merging_mutating_parts when task is created, see FutureManipulationTask::assignSourceParts. +/// Source parts are added to merging_mutating_parts when task is created, see FutureManipulationTask::tagSourceParts. /// And they are removed from merging_mutating_parts when task record is destroyed. /// As the merge txn is committed in a 2-phase style, we need to hold the task record until txn phase-2 finish (success or fail). /// * phase 1 - ::finishTask is called. We mark the task record as committing by set commit_start_time, instead of destroy the record. @@ -146,7 +151,15 @@ ManipulationTaskRecord::~ManipulationTaskRecord() { std::lock_guard lock(parent.currently_merging_mutating_parts_mutex); for (auto & part : parts) + { parent.currently_merging_mutating_parts.erase(part->name()); + auto prev_part = part->tryGetPreviousPart(); + while(prev_part) + { + parent.currently_merging_mutating_parts.erase(prev_part->name()); + prev_part = prev_part->tryGetPreviousPart(); + } + } } { @@ -161,11 +174,22 @@ ManipulationTaskRecord::~ManipulationTaskRecord() } } -Strings ManipulationTaskRecord::getSourcePartNames() const +Strings ManipulationTaskRecord::getSourcePartNames(bool flatten) const { Strings res; for (const auto & part : parts) + { res.emplace_back(part->name()); + if (likely(flatten)) + { + auto prev_part = part->tryGetPreviousPart(); + while (prev_part) + { + res.emplace_back(prev_part->name()); + prev_part = prev_part->tryGetPreviousPart(); + } + } + } return res; } @@ -177,7 +201,15 @@ FutureManipulationTask::~FutureManipulationTask() { std::lock_guard lock(parent.currently_merging_mutating_parts_mutex); for (auto & part : parts) + { parent.currently_merging_mutating_parts.erase(part->name()); + auto prev_part = part->tryGetPreviousPart(); + while(prev_part) + { + parent.currently_merging_mutating_parts.erase(prev_part->name()); + prev_part = prev_part->tryGetPreviousPart(); + } + } } } catch (...) 
@@ -186,26 +218,38 @@ FutureManipulationTask::~FutureManipulationTask() } } -FutureManipulationTask & FutureManipulationTask::assignSourceParts(ServerDataPartsVector && parts_) +/// Add source parts (include invisible parts) to merging_mutating_parts. +FutureManipulationTask & FutureManipulationTask::tagSourceParts(ServerDataPartsVector && parts_) { - for (auto & part : parts_) - { - LOG_DEBUG(&Poco::Logger::get("MergeMutateDEBUG"), "assignSourceParts part {} name {}", static_cast(part.get()), part->name()); - } - - /// flatten the parts - CnchPartsHelper::flattenPartsVector(parts_); + auto check_and_add = [&](const auto & part_name) { + if (parent.currently_merging_mutating_parts.count(part_name)) + throw Exception("Part '" + part_name + "' was already in other Task, cancel merge.", ErrorCodes::ABORTED); + parent.currently_merging_mutating_parts.emplace(part_name); + }; if (!record->try_execute) { std::lock_guard lock(parent.currently_merging_mutating_parts_mutex); - for (auto & part : parts_) - if (parent.currently_merging_mutating_parts.count(part->name())) - throw Exception("Part '" + part->name() + "' was already in other Task, cancel merge.", ErrorCodes::ABORTED); + for (const auto & p : parts_) + { + check_and_add(p->name()); + + auto prev_part = p->tryGetPreviousPart(); + while (prev_part) + { + check_and_add(prev_part->name()); + prev_part = prev_part->tryGetPreviousPart(); + } + } + } - for (auto & part : parts_) - parent.currently_merging_mutating_parts.emplace(part->name()); + if (parent.log->trace()) + { + WriteBufferFromOwnString wb; + for (const auto & p : parts_) + wb << p->name() << " "; + LOG_TRACE(parent.log, "Added parts to merging_mutating_parts: {}", wb.str()); } parts = std::move(parts_); @@ -608,16 +652,14 @@ bool CnchMergeMutateThread::tryMergeParts(StoragePtr & istorage, StorageCnchMerg submitFutureManipulationTask(storage, *future_task); } - try - { - /// TODO: catch the exception during tryMergeParts() ? 
- - writePartMergeLogElement(istorage, part_merge_log_elem, metrics); - } - catch (...) - { - tryLogCurrentException(__PRETTY_FUNCTION__); - } + // try + // { + // writePartMergeLogElement(istorage, part_merge_log_elem, metrics); + // } + // catch (...) + // { + // tryLogCurrentException(__PRETTY_FUNCTION__); + // } return result; } @@ -650,6 +692,12 @@ bool CnchMergeMutateThread::trySelectPartsToMerge(StoragePtr & istorage, Storage bool only_realtime_partition = storage_settings->cnch_merge_only_realtime_partition; auto partitions = partition_selector->selectForMerge(istorage, num_partitions, only_realtime_partition); + if (partitions.empty()) + { + LOG_TRACE(log, "Skip empty table"); + return false; + } + metrics.num_partitions = partitions.size(); partitions = removeLockedPartition(partitions); metrics.num_unlock_partitions = partitions.size(); @@ -775,7 +823,7 @@ bool CnchMergeMutateThread::trySelectPartsToMerge(StoragePtr & istorage, Storage postpone_partitions.erase(selected_parts.front()->info().partition_id); auto future_task = std::make_unique(*this, ManipulationType::Merge); - future_task->assignSourceParts(std::move(selected_parts)); + future_task->tagSourceParts(std::move(selected_parts)); merge_pending_queue.push(std::move(future_task)); } @@ -790,7 +838,7 @@ bool CnchMergeMutateThread::trySelectPartsToMerge(StoragePtr & istorage, Storage Strings CnchMergeMutateThread::removeLockedPartition(const Strings & partitions) { - constexpr UInt64 SLOW_THRESHOLD_MS = 200; + constexpr UInt64 slow_threshold_ms = 200; Stopwatch watch; auto & txn_coordinator = getContext()->getCnchTransactionCoordinator(); auto transaction = txn_coordinator.createTransaction( @@ -810,14 +858,14 @@ Strings CnchMergeMutateThread::removeLockedPartition(const Strings & partitions) auto txn_id = transaction->getTransactionID(); Strings res; std::for_each(partitions.begin(), partitions.end(), - [& res, & transaction, txn_id, this] (const String & partition) + [& res, txn_id, this] (const 
String & partition) { LockInfoPtr partition_lock = std::make_shared(txn_id); partition_lock->setMode(LockMode::X); partition_lock->setUUIDAndPrefix(getStorageID().uuid, LockInfo::task_domain); partition_lock->setPartition(partition); - auto cnch_lock = transaction->createLockHolder({std::move(partition_lock)}); + auto cnch_lock = std::make_shared(getContext(), std::move(partition_lock)); if (cnch_lock->tryLock()) { LOG_TRACE(log, "partition {} is not lock", partition); @@ -834,7 +882,7 @@ Strings CnchMergeMutateThread::removeLockedPartition(const Strings & partitions) /// And finishTransaction in the SCOPE_EXIT make sure the txn is clean by server but not DM. transaction->commitV2(); UInt64 milliseconds = watch.elapsedMilliseconds(); - if (milliseconds >= SLOW_THRESHOLD_MS) + if (milliseconds >= slow_threshold_ms) LOG_INFO(log, "removeLockedPartition took {} ms.", milliseconds); return res; } @@ -887,7 +935,7 @@ String CnchMergeMutateThread::submitFutureManipulationTask( } } - auto cnch_lock = transaction->createLockHolder({std::move(partition_lock)}); + auto cnch_lock = std::make_shared(getContext(), std::move(partition_lock)); if (type == ManipulationType::Merge || type == ManipulationType::Mutate || type == ManipulationType::Clustering) cnch_lock->lock(); @@ -945,7 +993,6 @@ String CnchMergeMutateThread::submitFutureManipulationTask( task_record.task_id = params.task_id; task_record.worker = worker_client; - task_record.result_part_name = params.new_part_names.front(); task_record.manipulation_entry = local_context->getGlobalContext()->getManipulationList().insert(params, true, getContext()); task_record.manipulation_entry->get()->related_node = worker_client->getRPCAddress(); @@ -1022,6 +1069,7 @@ String CnchMergeMutateThread::triggerPartMerge( std::map mutation_entries; std::vector> mutation_timestamps; catalog->fillMutationsByStorage(storage_id, mutation_entries); + mutation_timestamps.reserve(mutation_entries.size()); for (const auto & [_, mutation_entry] : 
mutation_entries) mutation_timestamps.emplace_back(mutation_entry.commit_time, mutation_entry.commands.changeSchema()); @@ -1081,7 +1129,7 @@ String CnchMergeMutateThread::triggerPartMerge( storage, FutureManipulationTask(*this, ManipulationType::Merge) .setTryExecute(try_execute) - .assignSourceParts(std::move(res.front())), + .tagSourceParts(std::move(res.front())), true); } @@ -1186,7 +1234,7 @@ void CnchMergeMutateThread::finishTask(const String & task_id, std::functionsecond->commit_start_time = time(nullptr); partition_id = it->second->parts.front()->info().partition_id; - source_part_names = it->second->getSourcePartNames(); + source_part_names = it->second->getSourcePartNames(/*flatten*/true); try_execute = it->second->try_execute; manipulation_submit_time_ns = it->second->submit_time_ns; } @@ -1248,10 +1296,7 @@ ClusterTaskProgress CnchMergeMutateThread::getReclusteringTaskProgress() if (partition_list.empty()) return cluster_task_progress; - if (!scheduled_mutation_partitions.empty()) - cluster_task_progress.progress = (scheduled_mutation_partitions.size() / static_cast(partition_list.size())) * 100; - else if (!finish_mutation_partitions.empty()) - cluster_task_progress.progress = (finish_mutation_partitions.size() / static_cast(partition_list.size())) * 100; + cluster_task_progress.progress = (finish_mutation_partitions.size() / static_cast(partition_list.size())) * 100; cluster_task_progress.start_time_seconds = current_mutate_entry->commit_time.toSecond(); return cluster_task_progress; } @@ -1291,6 +1336,19 @@ bool CnchMergeMutateThread::tryMutateParts(StoragePtr & istorage, StorageCnchMer std::lock_guard lock(try_mutate_parts_mutex); auto merging_mutating_parts_snapshot = copyCurrentlyMergingMutatingParts(); + auto finish_current_mutation = [this, &lock, &storage]() + { + if (!current_mutate_entry) + return; + + removeMutationEntryFromKV(*current_mutate_entry, lock); + storage.removeMutationEntry(current_mutate_entry->commit_time); + + 
scheduled_mutation_partitions.clear(); + finish_mutation_partitions.clear(); + current_mutate_entry.reset(); + }; + /// Fetch mutation entries from KV. std::map current_mutations_by_version; auto catalog = getContext()->getCnchCatalog(); @@ -1316,10 +1374,33 @@ bool CnchMergeMutateThread::tryMutateParts(StoragePtr & istorage, StorageCnchMer current_mutate_entry = std::make_optional(entry_from_catalog); } - if (current_mutate_entry->isReclusterMutation() && !getContext()->getTableReclusterTaskStatus(storage_id)) - return false; + if (current_mutate_entry->isReclusterMutation()) + { + if (!getContext()->getTableReclusterTaskStatus(storage_id)) + { + LOG_TRACE(log, "recluster task is disabled for {}", storage_id.getNameForLogs()); + return false; + } + if (current_mutate_entry->columns_commit_time < storage.commit_time) + { + /// There is newer version storage, needs to check whether `cluster by` definition changed + + /// get specific version storage + auto entry_istorage = catalog->getTableByUUID(*getContext(), toString(storage_id.uuid), current_mutate_entry->columns_commit_time); + auto & entry_cnch_table = checkAndGetCnchTable(entry_istorage); + + if (entry_cnch_table.getTableHashForClusterBy() != storage.getTableHashForClusterBy()) + { + LOG_INFO(log, "recluster task {} is canceled due to newer version cluster by", current_mutate_entry->txn_id.toString()); + finish_current_mutation(); + return false; + } + } + } bool change_schema = current_mutate_entry->commands.changeSchema(); + bool is_recluster = current_mutate_entry->isReclusterMutation(); + auto table_definition_hash = storage.getTableHashForClusterBy(); /// Function to generating new tasks. Return true if we can still generate new tasks. 
auto generate_tasks = [&](const ServerDataPartsVector & visible_parts, const NameSet & snapshot) @@ -1330,74 +1411,38 @@ bool CnchMergeMutateThread::tryMutateParts(StoragePtr & istorage, StorageCnchMer size_t curr_mutate_part_size = 0; ServerDataPartsVector alter_parts; bool remain_tasks_in_partition = false; - String command_partition_id; - if (type == ManipulationType::Clustering) + for (const auto & part : visible_parts) { - auto mutation_command = current_mutate_entry->commands[0]; - if (mutation_command.partition) - command_partition_id = storage.getPartitionIDFromQuery(mutation_command.partition, getContext()); - else if (mutation_command.predicate) + if (!needMutate(part, commit_ts, change_schema, is_recluster, table_definition_hash)) + continue; + if (snapshot.count(part->name())) { - ServerDataPartsVector parts_to_recluster; - auto table_definition_hash = storage.getTableHashForClusterBy(); - for (const auto & part : visible_parts) - { - if (!table_definition_hash.match(part->part_model().table_definition_hash())) - parts_to_recluster.push_back(part); - } - - /// TODO: (vivek, zuochuang.zema) why not filter by columns_commit_time and mutation_commit_time? 
- alter_parts = storage.getServerPartsByPredicate( - mutation_command.predicate, - [&]{ return parts_to_recluster; }, - getContext()); + remain_tasks_in_partition = true; + continue; } - } - - if (alter_parts.empty()) - { - for (const auto & part : visible_parts) + remain_tasks_in_partition = true; + alter_parts.push_back(part); + curr_mutate_part_size += part->part_model().size(); + auto p_part = part->tryGetPreviousPart(); + while (p_part) { - if (!needMutate(part, commit_ts, change_schema)) - continue; - - if (snapshot.count(part->name())) - { - remain_tasks_in_partition = true; - continue; - } - - remain_tasks_in_partition = true; - - if (type == ManipulationType::Clustering - && command_partition_id != part->partition().getID(storage)) - continue; - - alter_parts.push_back(part); - curr_mutate_part_size += part->part_model().size(); - auto p_part = part->tryGetPreviousPart(); - while (p_part) - { - curr_mutate_part_size += p_part->part_model().size(); - p_part = p_part->tryGetPreviousPart(); - } - - /// Batch n parts in one task. - if (alter_parts.size() >= storage_settings->cnch_mutate_max_parts_to_mutate - || curr_mutate_part_size >= storage_settings->cnch_mutate_max_total_bytes_to_mutate) - { - submitFutureManipulationTask( - storage, - FutureManipulationTask(*this, type) - .setMutationEntry(*current_mutate_entry) - .assignSourceParts(std::move(alter_parts))); - - alter_parts.clear(); - curr_mutate_part_size = 0; - if (running_mutation_tasks >= storage.getSettings()->max_addition_mutation_task_num) - return true; - } + curr_mutate_part_size += p_part->part_model().size(); + p_part = p_part->tryGetPreviousPart(); + } + /// Batch n parts in one task. 
+ if (alter_parts.size() >= storage_settings->cnch_mutate_max_parts_to_mutate + || curr_mutate_part_size >= storage_settings->cnch_mutate_max_total_bytes_to_mutate) + { + submitFutureManipulationTask( + storage, + FutureManipulationTask(*this, type) + .setMutationEntry(*current_mutate_entry) + .tagSourceParts(std::move(alter_parts))); + alter_parts.clear(); + curr_mutate_part_size = 0; + if (running_mutation_tasks >= storage.getSettings()->max_addition_mutation_task_num) + return true; } } @@ -1408,11 +1453,7 @@ bool CnchMergeMutateThread::tryMutateParts(StoragePtr & istorage, StorageCnchMer storage, FutureManipulationTask(*this, type) .setMutationEntry(*current_mutate_entry) - .assignSourceParts(std::move(alter_parts))); - } - else if (alter_parts.empty() && type == ManipulationType::Clustering) - { - remain_tasks_in_partition = false; + .tagSourceParts(std::move(alter_parts))); } return remain_tasks_in_partition; @@ -1424,12 +1465,12 @@ bool CnchMergeMutateThread::tryMutateParts(StoragePtr & istorage, StorageCnchMer const auto & commit_ts = current_mutate_entry->commit_time; for (const auto & part : visible_parts) { - if (needMutate(part, commit_ts, change_schema)) + if (needMutate(part, commit_ts, change_schema, is_recluster, table_definition_hash)) return false; } for (const auto & part : visible_staged_parts) { - if (needMutate(part, commit_ts, change_schema)) + if (needMutate(part, commit_ts, change_schema, is_recluster, table_definition_hash)) return false; } return true; @@ -1437,7 +1478,6 @@ bool CnchMergeMutateThread::tryMutateParts(StoragePtr & istorage, StorageCnchMer /// Step 1: generate mutations tasks for the earliest mutation entry. 
bool is_finish = true; - bool is_recluster_partition_finish = false; auto timestamp = getContext()->getTimestamp(); if (storage.getInMemoryMetadataPtr()->getPartitionKeyAST()) @@ -1470,7 +1510,7 @@ bool CnchMergeMutateThread::tryMutateParts(StoragePtr & istorage, StorageCnchMer visible_staged_parts = CnchPartsHelper::calcVisibleParts(staged_parts, false); } - if (check_all_done(visible_parts, visible_staged_parts) || is_recluster_partition_finish) + if (check_all_done(visible_parts, visible_staged_parts)) finish_mutation_partitions.emplace(partition_id); } /// Some parts are not scheduled, generate tasks for those parts @@ -1479,8 +1519,6 @@ bool CnchMergeMutateThread::tryMutateParts(StoragePtr & istorage, StorageCnchMer { LOG_TRACE(log, "No more mutation tasks for partition {}, mutation id: {}", partition_id, current_mutate_entry->txn_id); scheduled_mutation_partitions.emplace(partition_id); - if (current_mutate_entry->isReclusterMutation()) - is_recluster_partition_finish = true; } /// - if can't generate all tasks at this round, then go into next round. 
else @@ -1531,12 +1569,10 @@ bool CnchMergeMutateThread::tryMutateParts(StoragePtr & istorage, StorageCnchMer } } - removeMutationEntryFromKV(*current_mutate_entry, is_newest_cluster_mutation, lock); - storage.removeMutationEntry(commit_ts); + if (is_newest_cluster_mutation) + setTableClusterStatus(); - scheduled_mutation_partitions.clear(); - finish_mutation_partitions.clear(); - current_mutate_entry.reset(); + finish_current_mutation(); return false; } @@ -1557,6 +1593,7 @@ MergeTreeMutationStatusVector CnchMergeMutateThread::getAllMutationStatuses() TxnTimestamp curr_ts = context->tryGetTimestamp(__PRETTY_FUNCTION__); auto istorage = catalog->getTableByUUID(*context, UUIDHelpers::UUIDToString(storage_id.uuid), curr_ts); auto & storage = checkAndGetCnchTable(istorage); + auto table_definition_hash = storage.getTableHashForClusterBy(); auto all_parts = catalog->getAllServerDataParts(istorage, curr_ts, nullptr); auto visible_parts = CnchPartsHelper::calcVisibleParts(all_parts, false); @@ -1572,10 +1609,11 @@ MergeTreeMutationStatusVector CnchMergeMutateThread::getAllMutationStatuses() calcMutationPartitions(entry, istorage, storage); auto & partitions = entry.partition_ids.value(); bool change_schema = entry.commands.changeSchema(); + bool is_recluster = entry.isReclusterMutation(); for (auto & part : visible_parts) { bool partition_match = std::find(partitions.begin(), partitions.end(), part->info().partition_id) != partitions.end(); - if (partition_match && needMutate(part, commit_ts, change_schema)) + if (partition_match && needMutate(part, commit_ts, change_schema, is_recluster, table_definition_hash)) ++status.parts_to_do; } @@ -1593,7 +1631,40 @@ MergeTreeMutationStatusVector CnchMergeMutateThread::getAllMutationStatuses() return res; } -void CnchMergeMutateThread::removeMutationEntryFromKV(const CnchMergeTreeMutationEntry & entry, bool recluster_finish, std::lock_guard &) +void CnchMergeMutateThread::setTableClusterStatus() +{ + /// modify cluster status 
before removing recluster mutation entry. + LOG_DEBUG(log, "All reclusted tasks in table {} have been executed, check for cluster status", storage_id.getNameForLogs()); + + bool clustered = true; + auto istorage = getStorageFromCatalog(); + auto & storage = checkAndGetCnchTable(istorage); + auto table_definition_hash = storage.getTableHashForClusterBy(); + auto check_clustered = [&table_definition_hash](const ServerDataPartsVector & parts) + { + return std::all_of(parts.begin(), parts.end(), + [&table_definition_hash](const ServerDataPartPtr & part) { return table_definition_hash.match(part->part_model().table_definition_hash()); }); + }; + auto partition_ids = catalog->getPartitionIDs(istorage, nullptr); + for (const auto & partition_id : partition_ids) + { + auto parts = catalog->getServerDataPartsInPartitions(istorage, {partition_id}, TxnTimestamp::maxTS(), nullptr); + auto visible_parts = CnchPartsHelper::calcVisibleParts(parts, false); + ServerDataPartsVector visible_staged_parts; + if (storage.getInMemoryMetadataPtr()->hasUniqueKey()) + { + auto staged_parts = catalog->getStagedServerDataParts(istorage, TxnTimestamp::maxTS()); + visible_staged_parts = CnchPartsHelper::calcVisibleParts(staged_parts, false); + } + clustered = check_clustered(visible_parts) && check_clustered(visible_staged_parts); + if (!clustered) + break; + } + + catalog->setTableClusterStatus(storage.getStorageID().uuid, clustered, table_definition_hash); +} + +void CnchMergeMutateThread::removeMutationEntryFromKV(const CnchMergeTreeMutationEntry & entry, std::lock_guard &) { const auto & commit_time = entry.commit_time; /// FIXME: (zuochuang.zema) buggy: we don't touch active timestamp for insertion, @@ -1615,12 +1686,6 @@ void CnchMergeMutateThread::removeMutationEntryFromKV(const CnchMergeTreeMutatio return; } - /// modify cluster status before removing recluster mutation entry. 
- if (recluster_finish) - { - LOG_DEBUG(log, "Data parts are clustered in table {}.", storage_id.getNameForLogs()); - } - WriteBufferFromOwnString buf; entry.commands.writeText(buf); LOG_DEBUG(log, "Mutation {}(command: {}) has been done, will remove it from catalog.", commit_time, buf.str()); diff --git a/src/CloudServices/CnchMergeMutateThread.h b/src/CloudServices/CnchMergeMutateThread.h index ec5280bc771..5fb80b2a310 100644 --- a/src/CloudServices/CnchMergeMutateThread.h +++ b/src/CloudServices/CnchMergeMutateThread.h @@ -74,10 +74,7 @@ struct ManipulationTaskRecord CnchWorkerClientPtr worker; size_t lost_count{0}; - /// for system.part_merge_log & system.server_part_log - String result_part_name; - - Strings getSourcePartNames() const; + Strings getSourcePartNames(bool flatten = false) const; }; struct FutureManipulationTask @@ -102,7 +99,7 @@ struct FutureManipulationTask } TxnTimestamp calcColumnsCommitTime() const; - FutureManipulationTask & assignSourceParts(ServerDataPartsVector && parts); + FutureManipulationTask & tagSourceParts(ServerDataPartsVector && parts); FutureManipulationTask & prepareTransaction(); std::unique_ptr moveRecord(); @@ -199,6 +196,7 @@ class CnchMergeMutateThread : public ICnchBGThread void waitMutationFinish(UInt64 mutation_commit_time, UInt64 timeout_ms); MergeTreeMutationStatusVector getAllMutationStatuses(); ClusterTaskProgress getReclusteringTaskProgress(); + void setTableClusterStatus(); private: void preStart() override; @@ -217,7 +215,7 @@ class CnchMergeMutateThread : public ICnchBGThread String submitFutureManipulationTask(const StorageCnchMergeTree & storage, FutureManipulationTask & future_task, bool maybe_sync_task = false); // Mutate - void removeMutationEntryFromKV(const CnchMergeTreeMutationEntry & entry, bool recluster_finish, std::lock_guard &); + void removeMutationEntryFromKV(const CnchMergeTreeMutationEntry & entry, std::lock_guard &); void calcMutationPartitions(CnchMergeTreeMutationEntry & mutate_entry, 
StoragePtr & istorage, StorageCnchMergeTree & storage); bool tryMutateParts(StoragePtr & istorage, StorageCnchMergeTree & storage); diff --git a/src/CloudServices/CnchObjectColumnSchemaAssembleThread.cpp b/src/CloudServices/CnchObjectColumnSchemaAssembleThread.cpp index 12ebb8989e5..9e1c851a78b 100644 --- a/src/CloudServices/CnchObjectColumnSchemaAssembleThread.cpp +++ b/src/CloudServices/CnchObjectColumnSchemaAssembleThread.cpp @@ -111,7 +111,7 @@ void CnchObjectColumnSchemaAssembleThread::runImpl() // Step 4:update assembled schema and delete partial schema in storage cache if (auto cache_manager = getContext()->getPartCacheManager()) { - if (auto storage_in_cache = cache_manager->getStorageFromCache(table_uuid, current_topology_version)) + if (auto storage_in_cache = cache_manager->getStorageFromCache(table_uuid, current_topology_version, *getContext())) { auto & table_in_cache = checkAndGetCnchTable(storage_in_cache); table_in_cache.refreshAssembledSchema(new_assembled_schema, committed_partial_schema_txnids); diff --git a/src/CloudServices/CnchPartGCThread.cpp b/src/CloudServices/CnchPartGCThread.cpp index 0a7a1d5f02e..8aec09e62f7 100644 --- a/src/CloudServices/CnchPartGCThread.cpp +++ b/src/CloudServices/CnchPartGCThread.cpp @@ -448,10 +448,10 @@ void CnchPartGCThread::runDataRemoveTask() if (!istorage->is_dropped) { auto & storage = checkAndGetCnchTable(istorage); - size_t removed_size = doPhaseTwoGC(istorage, storage); + cleaned_items_in_a_round += doPhaseTwoGC(istorage, storage); auto storage_settings = storage.getSettings(); - if (removed_size) + if (!phase_two_start_key.empty() || cleaned_items_in_a_round) { sleep_ms = std::uniform_int_distribution(0, storage_settings->cleanup_delay_period_random_add * 1000)(rng); @@ -462,10 +462,12 @@ void CnchPartGCThread::runDataRemoveTask() { round_removing_no_data++; phase_two_continuous_hits = 0; - sleep_ms - = std::min(storage_settings->cleanup_delay_period * 1000 * std::pow(1.4, round_removing_no_data), 5 * 60 * 
1000.0); + sleep_ms = storage_settings->cleanup_delay_period_upper_bound * 1000; LOG_TRACE(log, "[p2] Removed no data for {} round(s). Delay schedule for {} ms.", round_removing_no_data, sleep_ms); } + + if (phase_two_start_key.empty()) + cleaned_items_in_a_round = 0; } } catch (...) @@ -532,9 +534,13 @@ size_t CnchPartGCThread::doPhaseTwoGC(const StoragePtr & istorage, StorageCnchMe return false; }; - size_t pool_size = std::min( - static_cast(2 * std::pow(2.0, phase_two_continuous_hits)), - static_cast(storage.getSettings()->gc_remove_part_thread_pool_size)); + /// pool_size should be at least 1. + size_t pool_size = std::max( + std::min( + /// Avoid the number get too large. + static_cast(2 * std::pow(2.0, std::min(phase_two_continuous_hits, 15ul))), + static_cast(storage.getSettings()->gc_remove_part_thread_pool_size)), + 1ul); /// If batch_size <= 1, then round-robin may never move forward. size_t batch_size = std::max(static_cast(storage.getSettings()->gc_remove_part_batch_size), static_cast(2)); LOG_TRACE( diff --git a/src/CloudServices/CnchPartGCThread.h b/src/CloudServices/CnchPartGCThread.h index f059abd10dd..fba12116003 100644 --- a/src/CloudServices/CnchPartGCThread.h +++ b/src/CloudServices/CnchPartGCThread.h @@ -122,6 +122,10 @@ class CnchPartGCThread : public ICnchBGThread std::weak_ptr merge_thread; String phase_two_start_key; + /// The total number of items got cleaned from the start key to end key. + /// Reset when a new round start. + /// Recovery rates can be more conservative if the value is too low. 
+ size_t cleaned_items_in_a_round = 0; }; diff --git a/src/CloudServices/CnchServerClient.cpp b/src/CloudServices/CnchServerClient.cpp index efbc7d43c2b..f5a5d4be66d 100644 --- a/src/CloudServices/CnchServerClient.cpp +++ b/src/CloudServices/CnchServerClient.cpp @@ -24,6 +24,7 @@ #include #include #include +#include namespace DB @@ -90,10 +91,10 @@ CnchServerClient::commitTransaction(const ICnchTransaction & txn, const StorageI return response.commit_ts(); } -void CnchServerClient::precommitTransaction(const TxnTimestamp & txn_id, const UUID & uuid) +void CnchServerClient::precommitTransaction(const ContextPtr & context, const TxnTimestamp & txn_id, const UUID & uuid) { brpc::Controller cntl; - cntl.set_timeout_ms(10 * 1000); + cntl.set_timeout_ms(context->getSettingsRef().max_dedup_execution_time.totalMilliseconds()); Protos::PrecommitTransactionReq request; Protos::PrecommitTransactionResp response; @@ -250,7 +251,8 @@ PrunedPartitions CnchServerClient::fetchPartitions( const ConstStoragePtr & table, const SelectQueryInfo & query_info, const Names & column_names, - const TxnTimestamp & txn_id) + const TxnTimestamp & txn_id, + const bool & ignore_ttl) { brpc::Controller cntl; if (const auto * storage = dynamic_cast(table.get())) @@ -272,6 +274,7 @@ PrunedPartitions CnchServerClient::fetchPartitions( request.add_column_name_filter(name); request.set_txnid(txn_id.toUInt64()); + request.set_ignore_ttl(ignore_ttl); stub->fetchPartitions(&cntl, &request, & response, nullptr); @@ -510,13 +513,11 @@ void CnchServerClient::redirectDetachAttachedS3Parts( RPCHelpers::checkResponse(response); } -void CnchServerClient::commitParts( +UInt32 CnchServerClient::commitParts( const TxnTimestamp & txn_id, ManipulationType type, MergeTreeMetaBase & storage, - const MutableMergeTreeDataPartsCNCHVector & parts, - const DeleteBitmapMetaPtrVector & delete_bitmaps, - const MutableMergeTreeDataPartsCNCHVector & staged_parts, + const DumpedData & dumped_data, const String & task_id, const 
bool from_server, const String & consumer_group, @@ -526,6 +527,10 @@ void CnchServerClient::commitParts( { /// TODO: check txn_id & start_ts + const auto & parts = dumped_data.parts; + const auto & delete_bitmaps = dumped_data.bitmaps; + const auto & staged_parts = dumped_data.staged_parts; + brpc::Controller cntl; cntl.set_timeout_ms(storage.getSettings()->cnch_meta_rpc_timeout_ms); Protos::CommitPartsReq request; @@ -605,21 +610,22 @@ void CnchServerClient::commitParts( new_bitmap->CopyFrom(*(delete_bitmap->getModel())); } + request.set_dedup_mode(static_cast(dumped_data.dedup_mode)); + stub->commitParts(&cntl, &request, &response, nullptr); assertController(cntl); RPCHelpers::checkResponse(response); + return response.has_dedup_impl_version() ? response.dedup_impl_version(): 1; } /* This method commits from worker side, it split the commit parts in multiple batches to avoid rpc timeout for too many parts. Note, it only applys to ManipulationType which supports 2pc, now we already separate txn commit from part commit */ -void CnchServerClient::precommitParts( +UInt32 CnchServerClient::precommitParts( ContextPtr context, const TxnTimestamp & txn_id, ManipulationType type, MergeTreeMetaBase & storage, - const MutableMergeTreeDataPartsCNCHVector & parts, - const DeleteBitmapMetaPtrVector & delete_bitmaps, - const MutableMergeTreeDataPartsCNCHVector & staged_parts, + const DumpedData & dumped_data, const String & task_id, const bool from_server, const String & consumer_group, @@ -628,9 +634,13 @@ void CnchServerClient::precommitParts( const UInt64 peak_memory_usage) { const UInt64 batch_size = context->getSettingsRef().catalog_max_commit_size; + const auto & parts = dumped_data.parts; + const auto & delete_bitmaps = dumped_data.bitmaps; + const auto & staged_parts = dumped_data.staged_parts; // Precommit parts in batches {batch_begin, batch_end} const size_t max_size = std::max({parts.size(), delete_bitmaps.size(), staged_parts.size()}); + UInt32 dedup_impl_version 
= 1; for (size_t batch_begin = 0; batch_begin < max_size; batch_begin += batch_size) { size_t batch_end = batch_begin + batch_size; @@ -646,7 +656,7 @@ void CnchServerClient::precommitParts( LOG_DEBUG( log, "Precommit: parts in batch: [{} ~ {}] of total: {}; delete_bitmaps in batch [{} ~ {}] of total {}; staged parts in batch [{} " - "~ {}] of total {}.", + "~ {}] of total {}; dedup mode is {}", part_batch_begin, part_batch_end, parts.size(), @@ -655,15 +665,20 @@ void CnchServerClient::precommitParts( delete_bitmaps.size(), staged_part_batch_begin, staged_part_batch_end, - staged_parts.size()); + staged_parts.size(), + typeToString(dumped_data.dedup_mode)); - commitParts( + DumpedData new_dumped_data; + new_dumped_data.parts = {parts.begin() + part_batch_begin, parts.begin() + part_batch_end}; + new_dumped_data.bitmaps = {delete_bitmaps.begin() + bitmap_batch_begin, delete_bitmaps.begin() + bitmap_batch_end}; + new_dumped_data.staged_parts = {staged_parts.begin() + staged_part_batch_begin, staged_parts.begin() + staged_part_batch_end}; + new_dumped_data.dedup_mode = dumped_data.dedup_mode; + + dedup_impl_version = commitParts( txn_id, type, storage, - {parts.begin() + part_batch_begin, parts.begin() + part_batch_end}, - {delete_bitmaps.begin() + bitmap_batch_begin, delete_bitmaps.begin() + bitmap_batch_end}, - {staged_parts.begin() + staged_part_batch_begin, staged_parts.begin() + staged_part_batch_end}, + new_dumped_data, task_id, from_server, consumer_group, @@ -671,6 +686,7 @@ void CnchServerClient::precommitParts( binlog, peak_memory_usage); } + return dedup_impl_version; } google::protobuf::RepeatedPtrField @@ -756,6 +772,21 @@ void CnchServerClient::cleanTransaction(const TransactionRecord & txn_record) RPCHelpers::checkResponse(response); } +void CnchServerClient::cleanUndoBuffers(const TransactionRecord & txn_record) +{ + brpc::Controller cntl; + Protos::CleanUndoBuffersReq request; + Protos::CleanUndoBuffersResp response; + + 
LOG_DEBUG(&Poco::Logger::get(__func__), "clean undo buffers for txn: [{}] on server: {}", txn_record.toString(), getRPCAddress()); + + request.mutable_txn_record()->CopyFrom(txn_record.pb_model); + stub->cleanUndoBuffers(&cntl, &request, &response, nullptr); + + assertController(cntl); + RPCHelpers::checkResponse(response); +} + void CnchServerClient::acquireLock(const LockInfoPtr & lock) { brpc::Controller cntl; @@ -869,6 +900,20 @@ UInt64 CnchServerClient::getServerStartTime() return response.server_start_time(); } +UInt32 CnchServerClient::getDedupImplVersion(const TxnTimestamp & txn_id, const UUID & uuid) +{ + brpc::Controller cntl; + Protos::GetDedupImplVersionReq request; + Protos::GetDedupImplVersionResp response; + request.set_txn_id(txn_id); + RPCHelpers::fillUUID(uuid, *request.mutable_uuid()); + + stub->getDedupImplVersion(&cntl, &request, &response, nullptr); + + assertController(cntl); + return response.version(); +} + bool CnchServerClient::scheduleGlobalGC(const std::vector & tables) { brpc::Controller cntl; @@ -984,18 +1029,19 @@ CnchServerClient::getBackGroundStatus(const CnchBGThreadType & type) return response.status(); } -void CnchServerClient::submitPreloadTask(const MergeTreeMetaBase & storage, const MutableMergeTreeDataPartsCNCHVector & parts, UInt64 timeout_ms) +brpc::CallId CnchServerClient::submitPreloadTask(const MergeTreeMetaBase & storage, const MutableMergeTreeDataPartsCNCHVector & parts, UInt64 timeout_ms) { + auto * cntl = new brpc::Controller(); + auto call_id = cntl->call_id(); if (parts.empty()) - return; + return call_id; - brpc::Controller cntl; Protos::SubmitPreloadTaskReq request; request.set_ts(time(nullptr)); - Protos::SubmitPreloadTaskResp response; + auto response = new Protos::SubmitPreloadTaskResp(); if (timeout_ms) - cntl.set_timeout_ms(timeout_ms); + cntl->set_timeout_ms(timeout_ms); /// prefer to get cnch table uuid from settings as multiple CloudMergeTrees cannot share a same uuid, /// thus most CloudMergeTrees have 
no uuids on the worker side @@ -1010,9 +1056,8 @@ void CnchServerClient::submitPreloadTask(const MergeTreeMetaBase & storage, cons fillPartModel(storage, *part, *new_part); } - stub->submitPreloadTask(&cntl, &request, &response, nullptr); - assertController(cntl); - RPCHelpers::checkResponse(response); + stub->submitPreloadTask(cntl, &request, response, brpc::NewCallback(RPCHelpers::onAsyncCallDone, response, cntl, std::make_shared())); + return call_id; } UInt32 CnchServerClient::reportDeduperHeartbeat(const StorageID & cnch_storage_id, const String & worker_table_name) diff --git a/src/CloudServices/CnchServerClient.h b/src/CloudServices/CnchServerClient.h index a8496d87289..3405ffa4f94 100644 --- a/src/CloudServices/CnchServerClient.h +++ b/src/CloudServices/CnchServerClient.h @@ -27,7 +27,7 @@ #include #include #include -#include "Storages/MergeTree/MarkRange.h" +#include #include namespace DB @@ -42,6 +42,7 @@ class CnchServerTransaction; using CnchServerTransactionPtr = std::shared_ptr; struct PrunedPartitions; class StorageCloudMergeTree; +struct DumpedData; class CnchServerClient : public RpcClientBase { @@ -58,7 +59,7 @@ class CnchServerClient : public RpcClientBase std::pair createTransactionForKafka(const StorageID & storage_id, const size_t consumer_index); TxnTimestamp commitTransaction( const ICnchTransaction & txn, const StorageID & kafka_storage_id = StorageID::createEmpty(), const size_t consumer_index = 0); - void precommitTransaction(const TxnTimestamp & txn_id, const UUID & uuid = UUIDHelpers::Nil); + void precommitTransaction(const ContextPtr & context, const TxnTimestamp & txn_id, const UUID & uuid = UUIDHelpers::Nil); TxnTimestamp rollbackTransaction(const TxnTimestamp & txn_id); void finishTransaction(const TxnTimestamp & txn_id); @@ -88,7 +89,8 @@ class CnchServerClient : public RpcClientBase const ConstStoragePtr & table, const SelectQueryInfo & query_info, const Names & column_names, - const TxnTimestamp & txn_id); + const TxnTimestamp & 
txn_id, + const bool & ignore_ttl); void redirectCommitParts( const StoragePtr & table, @@ -134,13 +136,11 @@ class CnchServerClient : public RpcClientBase const std::vector> & detached_bitmap_metas, const DB::Protos::DetachAttachType & type); - void commitParts( + UInt32 commitParts( const TxnTimestamp & txn_id, ManipulationType type, MergeTreeMetaBase & storage, - const MutableMergeTreeDataPartsCNCHVector & parts, - const DeleteBitmapMetaPtrVector & delete_bitmaps, - const MutableMergeTreeDataPartsCNCHVector & staged_parts, + const DumpedData & dumped_data, const String & task_id = {}, const bool from_server = false, const String & consumer_group = {}, @@ -148,14 +148,15 @@ class CnchServerClient : public RpcClientBase const MySQLBinLogInfo & binlog = {}, const UInt64 peak_memory_usage = 0); - void precommitParts( + /** + * @return UInt32 dedup impl version for unique table, current valid value is 1 or 2. 1 means old impl which will dedup in write suffix stage, 2 means new impl which will dedup in txn commit stage. + */ + UInt32 precommitParts( ContextPtr context, const TxnTimestamp & txn_id, ManipulationType type, MergeTreeMetaBase & storage, - const MutableMergeTreeDataPartsCNCHVector & parts, - const DeleteBitmapMetaPtrVector & delete_bitmaps, - const MutableMergeTreeDataPartsCNCHVector & staged_parts, + const DumpedData & dumped_data, const String & task_id = {}, const bool from_server = false, const String & consumer_group = {}, @@ -170,6 +171,12 @@ class CnchServerClient : public RpcClientBase getTableInfo(const std::vector> & tables); void controlCnchBGThread(const StorageID & storage_id, CnchBGThreadType type, CnchBGThreadAction action); void cleanTransaction(const TransactionRecord & txn_record); + /** + * @brief Clean undo buffers with the given txn (only) on target server. + * + * @param txn_record The transaction to which the Undo Buffer belongs. 
+ */ + void cleanUndoBuffers(const TransactionRecord & txn_record); std::set getDeletingTablesInGlobalGC(); bool removeMergeMutateTasksOnPartitions(const StorageID &, const std::unordered_set &); @@ -186,7 +193,7 @@ class CnchServerClient : public RpcClientBase google::protobuf::RepeatedPtrField getBackGroundStatus(const CnchBGThreadType & type); - void submitPreloadTask(const MergeTreeMetaBase & storage, const MutableMergeTreeDataPartsCNCHVector & parts, UInt64 timeout_ms); + brpc::CallId submitPreloadTask(const MergeTreeMetaBase & storage, const MutableMergeTreeDataPartsCNCHVector & parts, UInt64 timeout_ms); UInt32 reportDeduperHeartbeat(const StorageID & cnch_storage_id, const String & worker_table_name); @@ -195,6 +202,8 @@ class CnchServerClient : public RpcClientBase void executeOptimize(const StorageID & storage_id, const String & partition_id, bool enable_try, bool mutations_sync, UInt64 timeout_ms); void notifyAccessEntityChange(IAccessEntity::Type type, const String & name); + UInt32 getDedupImplVersion(const TxnTimestamp & txn_id, const UUID & uuid); + #if USE_MYSQL void submitMaterializedMySQLDDLQuery(const String & database_name, const String & sync_thread, const String & query, const MySQLBinLogInfo & binlog); void reportHeartBeatForSyncThread(const String & database_name, const String & sync_thread); diff --git a/src/CloudServices/CnchServerResource.cpp b/src/CloudServices/CnchServerResource.cpp index ac430cf34d2..aa7acf35429 100644 --- a/src/CloudServices/CnchServerResource.cpp +++ b/src/CloudServices/CnchServerResource.cpp @@ -55,14 +55,14 @@ AssignedResource::AssignedResource(const StoragePtr & storage_) : storage(storag AssignedResource::AssignedResource(AssignedResource && resource) { storage = resource.storage; - worker_table_name = resource.worker_table_name; - create_table_query = resource.create_table_query; + table_version = resource.table_version; + table_definition = resource.table_definition; sent_create_query = 
resource.sent_create_query; bucket_numbers = resource.bucket_numbers; replicated = resource.replicated; - table_version = resource.table_version; server_parts = std::move(resource.server_parts); + virtual_parts = std::move(resource.virtual_parts); hive_parts = std::move(resource.hive_parts); file_parts = std::move(resource.file_parts); part_names = resource.part_names; // don't call move here @@ -123,7 +123,7 @@ void AssignedResource::addDataParts(const FileDataPartsCNCHVector & parts) } } -void ResourceStageInfo::filterResource(std::optional resource_option) +void ResourceStageInfo::filterResource(std::optional & resource_option) { if (resource_option) { @@ -223,15 +223,47 @@ void CnchServerResource::addCreateQuery( if (it == assigned_table_resource.end()) it = assigned_table_resource.emplace(storage->getStorageUUID(), AssignedResource{storage}).first; - it->second.create_table_query = create_query; - it->second.worker_table_name = worker_table_name; + it->second.table_definition.definition = create_query; + it->second.table_definition.local_table_name = worker_table_name; + it->second.table_definition.cacheable = false; +} + +void CnchServerResource::addCacheableCreateQuery( + const StoragePtr & storage, + const String & worker_table_name, + WorkerEngineType engine_type, + String underlying_dictionary_tables) +{ + auto uuid = storage->getStorageUUID(); + if (uuid == UUIDHelpers::Nil) + throw Exception(ErrorCodes::LOGICAL_ERROR, + "Cannot add definition for {} : UUID is empty", storage->getStorageID().getNameForLogs()); + + auto lock = getLock(); + + auto it = assigned_table_resource.find(uuid); + if (it == assigned_table_resource.end()) + it = assigned_table_resource.emplace(uuid, AssignedResource{storage}).first; + + it->second.table_definition = TableDefinitionResource { + storage->getCreateTableSql(), + worker_table_name, + /*cacheable=*/ true, + engine_type, + underlying_dictionary_tables + }; } -void CnchServerResource::setTableVersion(const UUID & 
storage_uuid, const UInt64 table_version) +void CnchServerResource::setTableVersion( + const UUID & storage_uuid, const UInt64 table_version, const std::set & required_bucket_numbers) { std::lock_guard lock(mutex); auto & assigned_resource = assigned_table_resource.at(storage_uuid); - assigned_resource.table_version = table_version; + if (assigned_resource.table_version == 0) + assigned_resource.table_version = table_version; + else if (assigned_resource.table_version != table_version) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Inconsistent table version for table {}", UUIDHelpers::UUIDToString(storage_uuid)); + assigned_resource.bucket_numbers.insert(required_bucket_numbers.begin(), required_bucket_numbers.end()); } void CnchServerResource::sendResource(const ContextPtr & context, const HostWithPorts & worker) @@ -412,15 +444,31 @@ void CnchServerResource::allocateResource( const auto & server_parts = resource.server_parts; const auto & required_bucket_numbers = resource.bucket_numbers; bool replicated = resource.replicated; + bool use_bucket_assignment = false; + BucketNumbersAssignmentMap assigned_bucket_map; ServerAssignmentMap assigned_map; - VirtualPartAssignmentMap virtual_part_assigned_map; + VirtualPartAssignmentMap assigned_virtual_part_map; HivePartsAssignMap assigned_hive_map; FilePartsAssignMap assigned_file_map; ServerDataPartsVector bucket_parts; ServerDataPartsVector leftover_server_parts; + auto * cnch_table = dynamic_cast(storage.get()); + // For function : arrayToBitmapWithEncode/EncodeBitmap bool bitengine_related_table = false; - if (auto * cnch_table = dynamic_cast(storage.get())) + if (resource.table_version > 0) // query with table version instead of parts + { + use_bucket_assignment = !required_bucket_numbers.empty(); + if (use_bucket_assignment) + { + assigned_bucket_map = assignBuckets(required_bucket_numbers, worker_group->getWorkerIDVec(), replicated); + } + else + { + /// allocate table version to all workers + } + } + else if 
(cnch_table) { // NOTE: server_parts maybe moved due to splitCnchParts and cannot be used again std::tie(bucket_parts, leftover_server_parts) = splitCnchParts(context, *storage, server_parts); @@ -433,7 +481,6 @@ void CnchServerResource::allocateResource( leftover_server_parts.size()); ProfileEvents::increment(ProfileEvents::CnchPartAllocationSplits); } - // If the # of parts over vw size is not zero, // only go through hybrid allocation logic when that is smaller than a configurable ratio if ((context->getSettingsRef().enable_hybrid_allocation || cnch_table->getSettings()->enable_hybrid_allocation) @@ -446,12 +493,12 @@ void CnchServerResource::allocateResource( min_rows_per_virtual_part = cnch_table->getSettings()->min_rows_per_virtual_part; auto virtual_part_size = computeVirtualPartSize(min_rows_per_virtual_part, cnch_table->getSettings()->index_granularity); - std::tie(assigned_map, virtual_part_assigned_map) + std::tie(assigned_map, assigned_virtual_part_map) = assignCnchHybridParts(worker_group, leftover_server_parts, virtual_part_size, context); } else { - assigned_map = assignCnchParts(worker_group, leftover_server_parts, context); + assigned_map = assignCnchParts(worker_group, leftover_server_parts, context, cnch_table->getSettings()); } moveBucketTablePartsToAssignedParts( assigned_map, bucket_parts, worker_group->getWorkerIDVec(), required_bucket_numbers, replicated); @@ -483,78 +530,79 @@ void CnchServerResource::allocateResource( auto & assigned_storage_worker_indexs = assigned_storage_workers[storage->getStorageUUID()]; for (const auto & host_ports : host_ports_vec) { + std::set assigned_buckets; ServerDataPartsVector assigned_parts; ServerVirtualPartVector assigned_virtual_parts; HiveFiles assigned_hive_parts; FileDataPartsCNCHVector assigned_file_parts; - bool has_parts = false; - if (auto it = assigned_map.find(host_ports.id); it != assigned_map.end()) + + if (auto it = assigned_bucket_map.find(host_ports.id); it != assigned_bucket_map.end()) { - 
assigned_parts = std::move(it->second); - assigned_storage_worker_indexs.insert(host_ports); - has_parts = true; + assigned_buckets = std::move(it->second); LOG_TRACE( log, - "SourcePrune Send {}.{} {}'s data part to worker {}", - storage->getDatabaseName(), - storage->getTableName(), - toString(storage->getStorageUUID()), + "Allocate {} buckets from table {} to {}", + assigned_buckets.size(), + storage->getStorageID().getNameForLogs(), host_ports.toDebugString()); + } + if (auto it = assigned_map.find(host_ports.id); it != assigned_map.end()) + { + assigned_parts = std::move(it->second); CnchPartsHelper::flattenPartsVector(assigned_parts); + LOG_TRACE( + log, + "Allocate {} parts from table {} to {}", + assigned_parts.size(), + storage->getStorageID().getNameForLogs(), + host_ports.toDebugString()); } - if (auto it = virtual_part_assigned_map.find(host_ports.id); it != virtual_part_assigned_map.end()) + if (auto it = assigned_virtual_part_map.find(host_ports.id); it != assigned_virtual_part_map.end()) { assigned_virtual_parts = getVirtualPartVector(leftover_server_parts, it->second); - assigned_storage_worker_indexs.insert(host_ports); LOG_TRACE( log, - "Send {} virtual data part (hybrid_allocation) to worker {} for table {}", - assigned_parts.size(), - host_ports.toDebugString(), - storage->getStorageID().getNameForLogs()); + "Allocate {} virtual parts from table {} to {}", + assigned_virtual_parts.size(), + storage->getStorageID().getNameForLogs(), + host_ports.toDebugString()); } if (auto it = assigned_hive_map.find(host_ports.id); it != assigned_hive_map.end()) { assigned_hive_parts = std::move(it->second); - assigned_storage_worker_indexs.insert(host_ports); - has_parts = true; LOG_TRACE( log, - "SourcePrune Send Hive {}.{} {}'s data part to worker {}", - storage->getDatabaseName(), - storage->getTableName(), - toString(storage->getStorageUUID()), + "Allocate {} hive parts from table {} to {}", + assigned_hive_parts.size(), + 
storage->getStorageID().getNameForLogs(), host_ports.toDebugString()); } - if (auto it = assigned_file_map.find(host_ports.id); it != assigned_file_map.end()) { assigned_file_parts = std::move(it->second); - assigned_storage_worker_indexs.insert(host_ports); - has_parts = true; LOG_TRACE( log, - "SourcePrune Send File {}.{} {} data parts to works {}, size = {}", - storage->getDatabaseName(), - storage->getTableName(), - toString(storage->getStorageUUID()), - host_ports.toDebugString(), - assigned_file_parts.size()); + "Allocate {} file parts from table {} to {}", + assigned_file_parts.size(), + storage->getStorageID().getNameForLogs(), + host_ports.toDebugString()); } - LOG_TRACE( - log, - "Storage {} host {} prune_table {} has_parts {}", - storage->getStorageID().getNameForLogs(), - host_ports.toDebugString(), - context->getSettingsRef().enable_prune_empty_resource, - has_parts); + + bool empty = (resource.table_version == 0 || (use_bucket_assignment && assigned_buckets.empty())) + && assigned_parts.empty() + && assigned_virtual_parts.empty() + && assigned_hive_parts.empty() + && assigned_file_parts.empty(); + + if (!empty) + assigned_storage_worker_indexs.insert(host_ports); if (!context->getSettingsRef().bsp_mode && context->getSettingsRef().enable_optimizer - && context->getSettingsRef().enable_prune_empty_resource && !has_parts && !bitengine_related_table) + && context->getSettingsRef().enable_prune_empty_resource && empty && !bitengine_related_table) { LOG_TRACE( log, @@ -577,10 +625,10 @@ void CnchServerResource::allocateResource( worker_resource.addDataParts(std::move(assigned_virtual_parts)); worker_resource.addDataParts(assigned_hive_parts); worker_resource.addDataParts(assigned_file_parts); + worker_resource.bucket_numbers = use_bucket_assignment ? 
assigned_buckets : required_bucket_numbers; worker_resource.sent_create_query = resource.sent_create_query; worker_resource.table_version = resource.table_version; - worker_resource.create_table_query = resource.create_table_query; - worker_resource.worker_table_name = resource.worker_table_name; + worker_resource.table_definition = resource.table_definition; worker_resource.object_columns = resource.object_columns; } } diff --git a/src/CloudServices/CnchServerResource.h b/src/CloudServices/CnchServerResource.h index 0fb908cd133..8aff3265a13 100644 --- a/src/CloudServices/CnchServerResource.h +++ b/src/CloudServices/CnchServerResource.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -69,6 +70,17 @@ struct SendLock ServerResourceLockManager & manager; }; +struct TableDefinitionResource +{ + /// if cacheable == 0, it's the rewrited table definition for worker; + /// otherwise, it's the original definition for cnch table + String definition; + String local_table_name; + bool cacheable = false; + WorkerEngineType engine_type = WorkerEngineType::CLOUD; + String underlying_dictionary_tables; // local dictionary table names for bitengine +}; + struct AssignedResource { explicit AssignedResource(const StoragePtr & storage); @@ -77,8 +89,7 @@ struct AssignedResource StoragePtr storage; UInt64 table_version{0}; //send table version instead of parts if set - String worker_table_name; - String create_table_query; + TableDefinitionResource table_definition; bool sent_create_query{false}; bool replicated{false}; @@ -98,7 +109,15 @@ struct AssignedResource void addDataParts(const HiveFiles & parts); void addDataParts(const FileDataPartsCNCHVector & parts); - bool empty() const { return sent_create_query && server_parts.empty() && hive_parts.empty() && file_parts.empty(); } + bool empty() const + { + return sent_create_query + && table_version == 0 + && server_parts.empty() + && virtual_parts.empty() + && hive_parts.empty() + && 
file_parts.empty(); + } }; // Send resources separately by UUID @@ -110,7 +129,7 @@ struct ResourceOption struct ResourceStageInfo { std::unordered_set sent_resource; - void filterResource(std::optional resource_option); + void filterResource(std::optional & resource_option); }; class CnchServerResource { @@ -131,7 +150,16 @@ class CnchServerResource const String & worker_table_name, bool create_local_table = true); - void setTableVersion(const UUID & storage_uuid, const UInt64 table_version); + void addCacheableCreateQuery( + const StoragePtr & storage, + const String & worker_table_name, + WorkerEngineType engine_type, + String underlying_dictionary_tables); + + void setTableVersion( + const UUID & storage_uuid, + UInt64 table_version, + const std::set & required_bucket_numbers); void setAggregateWorker(HostWithPorts aggregate_worker_) { aggregate_worker = std::move(aggregate_worker_); } @@ -144,14 +172,19 @@ class CnchServerResource void skipCleanWorker() { skip_clean_worker = true; } template - void addDataParts(const UUID & storage_id, const std::vector & data_parts, const std::set & required_bucket_numbers = {}) + void addDataParts( + const UUID & storage_id, + const std::vector & data_parts, + const std::set & required_bucket_numbers = {}, + bool replicated = false) { std::lock_guard lock(mutex); auto & assigned_resource = assigned_table_resource.at(storage_id); + /// accumulate resources for a table assigned_resource.addDataParts(data_parts); - if (assigned_resource.bucket_numbers.empty() && !required_bucket_numbers.empty()) - assigned_resource.bucket_numbers = required_bucket_numbers; + assigned_resource.replicated = assigned_resource.replicated | replicated; + assigned_resource.bucket_numbers.insert(required_bucket_numbers.begin(), required_bucket_numbers.end()); } const WorkerInfoSet & getAssignedWorkers(const UUID & storage_uuid) diff --git a/src/CloudServices/CnchServerServiceImpl.cpp b/src/CloudServices/CnchServerServiceImpl.cpp index 
bf5aaef4b40..01dcec4391f 100644 --- a/src/CloudServices/CnchServerServiceImpl.cpp +++ b/src/CloudServices/CnchServerServiceImpl.cpp @@ -38,6 +38,7 @@ #include #include #include +#include #include #include #include @@ -189,14 +190,24 @@ void CnchServerServiceImpl::commitParts( CnchDataWriter cnch_writer( *cnch, rpc_context, - ManipulationType(req->type()), + static_cast(req->type()), req->task_id(), std::move(consumer_group), tpl, binlog, peak_memory_usage); - cnch_writer.commitPreparedCnchParts(DumpedData{std::move(parts), std::move(delete_bitmaps), std::move(staged_parts)}); + auto dedup_mode = static_cast(req->dedup_mode()); + cnch_writer.setDedupMode(dedup_mode); + + cnch_writer.commitPreparedCnchParts( + DumpedData{std::move(parts), std::move(delete_bitmaps), std::move(staged_parts), dedup_mode}); + + // If main table uuid is not set, set it. Otherwise, skip it + if (cnch_txn->getMainTableUUID() == UUIDHelpers::Nil) + cnch_txn->setMainTableUUID(cnch->getCnchStorageUUID()); + + rsp->set_dedup_impl_version(cnch_txn->getDedupImplVersion(rpc_context)); } catch (...) 
{ @@ -567,39 +578,40 @@ void CnchServerServiceImpl::reportTaskHeartbeat( } void CnchServerServiceImpl::reportDeduperHeartbeat( - google::protobuf::RpcController * cntl, + google::protobuf::RpcController *, const Protos::ReportDeduperHeartbeatReq * request, Protos::ReportDeduperHeartbeatResp * response, google::protobuf::Closure * done) { - brpc::ClosureGuard done_guard(done); - - try - { - auto cnch_storage_id = RPCHelpers::createStorageID(request->cnch_storage_id()); - - if (auto bg_thread = getContext()->tryGetDedupWorkerManager(cnch_storage_id)) + RPCHelpers::serviceHandler(done, response, [request = request, response = response, done = done, gc = getContext(), log = log] { + brpc::ClosureGuard done_guard(done); + try { - auto worker_table_name = request->worker_table_name(); - auto & manager = static_cast(*bg_thread); + auto cnch_storage_id = RPCHelpers::createStorageID(request->cnch_storage_id()); - auto ret = manager.reportHeartbeat(worker_table_name); + if (auto bg_thread = gc->tryGetDedupWorkerManager(cnch_storage_id)) + { + const auto & worker_table_name = request->worker_table_name(); + auto & manager = static_cast(*bg_thread); - // NOTE: here we send a response back to let the worker know the result. - response->set_code(static_cast(ret)); - return; + auto ret = manager.reportHeartbeat(worker_table_name); + + // NOTE: here we send a response back to let the worker know the result. + response->set_code(static_cast(ret)); + return; + } + else + { + LOG_WARNING(log, "Failed to get background thread"); + } } - else + catch (...) { - LOG_WARNING(log, "Failed to get background thread"); + tryLogCurrentException(log, __PRETTY_FUNCTION__); + RPCHelpers::handleException(response->mutable_exception()); } - } - catch (...) 
- { - tryLogCurrentException(log, __PRETTY_FUNCTION__); - RPCHelpers::handleException(response->mutable_exception()); - } - response->set_code(static_cast(DedupWorkerHeartbeatResult::Kill)); + response->set_code(static_cast(DedupWorkerHeartbeatResult::Kill)); + }); } void CnchServerServiceImpl::fetchDataParts( @@ -737,10 +749,51 @@ void CnchServerServiceImpl::fetchPartitions( session_context->setCurrentDatabase(request->database()); ReadBufferFromString rb(request->predicate()); ASTPtr query_ptr = deserializeAST(rb); + /// We should to add `database` into AST before calling `buildSelectQueryInfoForQuery`. + { + StoragePtr storage = gc->getCnchCatalog()->getTable(*gc, request->database(), request->table(), TxnTimestamp::maxTS()); + + auto calculated_host + = gc->getCnchTopologyMaster() + ->getTargetServer(UUIDHelpers::UUIDToString(storage->getStorageUUID()), storage->getServerVwName(), true) + .getRPCAddress(); + + if (request->remote_host() != calculated_host) + throw Exception( + "Fetch partitions failed because of inconsistent view of topology in remote server, remote_host: " + + request->remote_host() + ", calculated_host: " + calculated_host, + ErrorCodes::LOGICAL_ERROR); + + Names column_names; + for (const auto & name : request->column_name_filter()) + column_names.push_back(name); + auto session_context = Context::createCopy(gc); + session_context->setCurrentDatabase(request->database()); + ReadBufferFromString rb(request->predicate()); + ASTPtr query_ptr = deserializeAST(rb); + /// We should to add `database` into AST before calling `buildSelectQueryInfoForQuery`. 
+ { + ASTSelectQuery * select_query = query_ptr->as(); + if (!select_query) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected AST type found in buildSelectQueryInfoForQuery"); + select_query->replaceDatabaseAndTable(request->database(), request->table()); + } + SelectQueryInfo query_info = buildSelectQueryInfoForQuery(query_ptr, session_context); + + session_context->setTemporaryTransaction( + TxnTimestamp(request->has_txnid() ? request->txnid() : session_context->getTimestamp()), 0, false); + auto required_partitions = gc->getCnchCatalog()->getPartitionsByPredicate( + session_context, storage, query_info, column_names, request->has_ignore_ttl() && request->ignore_ttl()); + + response->set_total_size(required_partitions.total_partition_number); + auto & mutable_partitions = *response->mutable_partitions(); + for (auto & partition : required_partitions.partitions) + *mutable_partitions.Add() = std::move(partition); + } SelectQueryInfo query_info = buildSelectQueryInfoForQuery(query_ptr, session_context); session_context->setTemporaryTransaction(TxnTimestamp(request->has_txnid() ? 
request->txnid() : session_context->getTimestamp()), 0, false); - auto required_partitions = gc->getCnchCatalog()->getPartitionsByPredicate(session_context, storage, query_info, column_names); + auto required_partitions = gc->getCnchCatalog()->getPartitionsByPredicate(session_context, storage, query_info, column_names, request->has_ignore_ttl() && request->ignore_ttl()); response->set_total_size(required_partitions.total_partition_number); auto & mutable_partitions = *response->mutable_partitions(); @@ -777,18 +830,8 @@ void CnchServerServiceImpl::getBackgroundThreadStatus( try { - std::map res; - - auto type = CnchBGThreadType(request->type()); - if (type >= CnchBGThreadType::ServerMinType && type <= CnchBGThreadType::ServerMaxType) - { - auto threads = global_context->getCnchBGThreadsMap(type); - res = threads->getStatusMap(); - } - else - { - throw Exception("Not support type " + toString(int(request->type())), ErrorCodes::NOT_IMPLEMENTED); - } + auto type = toServerBGThreadType(request->type()); + std::map res = global_context->getCnchBGThreadsMap(type)->getStatusMap(); for (const auto & [storage_id, status] : res) { @@ -831,8 +874,9 @@ void CnchServerServiceImpl::controlCnchBGThread( StorageID storage_id = StorageID::createEmpty(); if (!request->storage_id().table().empty()) storage_id = RPCHelpers::createStorageID(request->storage_id()); - auto type = CnchBGThreadType(request->type()); - auto action = CnchBGThreadAction(request->action()); + + auto type = toServerBGThreadType(request->type()); + auto action = toCnchBGThreadAction(request->action()); auto & controller = static_cast(*cntl); LOG_DEBUG(log, "Received controlBGThread for {} type {} action {} from {}", storage_id.empty() ? 
"empty storage" : storage_id.getNameForLogs(), @@ -933,6 +977,30 @@ void CnchServerServiceImpl::cleanTransaction( } ); } +void CnchServerServiceImpl::cleanUndoBuffers( + google::protobuf::RpcController *, + const Protos::CleanUndoBuffersReq * request, + Protos::CleanUndoBuffersResp * response, + google::protobuf::Closure * done) +{ + RPCHelpers::serviceHandler(done, response, [request, response, done, gc = getContext(), this] { + brpc::ClosureGuard done_guard(done); + + auto & txn_cleaner = gc->getCnchTransactionCoordinator().getTxnCleaner(); + TransactionRecord txn_record{request->txn_record()}; + + try + { + txn_cleaner.cleanUndoBuffers(txn_record ); + } + catch (...) + { + LOG_WARNING(log, "Clean txn record {} failed.", txn_record.toString()); + tryLogCurrentException(log, __PRETTY_FUNCTION__); + RPCHelpers::handleException(response->mutable_exception()); + } + }); +} void CnchServerServiceImpl::acquireLock( google::protobuf::RpcController * cntl, const Protos::AcquireLockReq * request, @@ -1031,6 +1099,19 @@ void CnchServerServiceImpl::getServerStartTime( response->set_server_start_time(server_start_time); } +void CnchServerServiceImpl::getDedupImplVersion( + google::protobuf::RpcController *, + const Protos::GetDedupImplVersionReq * request, + Protos::GetDedupImplVersionResp * response, + google::protobuf::Closure * done) +{ + brpc::ClosureGuard done_guard(done); + auto cnch_txn = getContext()->getCnchTransactionCoordinator().getTransaction(request->txn_id()); + if (cnch_txn->getMainTableUUID() == UUIDHelpers::Nil) + cnch_txn->setMainTableUUID(RPCHelpers::createUUID(request->uuid())); + response->set_version(cnch_txn->getDedupImplVersion(getContext())); +} + // About Auto Statistics void CnchServerServiceImpl::queryUdiCounter( [[maybe_unused]] google::protobuf::RpcController* controller, @@ -1738,6 +1819,17 @@ void CnchServerServiceImpl::executeOptimize( auto & database_catalog = DatabaseCatalog::instance(); auto istorage = 
database_catalog.getTable(storage_id, global_context); + if (istorage && istorage->getInMemoryMetadataPtr()->hasDynamicSubcolumns()) + { + if (auto * cnch_table = dynamic_cast(istorage.get())) + { + LOG_TRACE( + log, + "Object schema snapshot:{}", + cnch_table->getStorageSnapshot(cnch_table->getInMemoryMetadataPtr(), nullptr)->object_columns.toString()); + } + } + auto * merge_mutate_thread = dynamic_cast(bg_thread.get()); auto task_id = merge_mutate_thread->triggerPartMerge(istorage, partition_id, false, enable_try, false); if (request->mutations_sync()) @@ -1836,7 +1928,8 @@ void CnchServerServiceImpl::notifyTableCreated( catch (...) { tryLogCurrentException(log, __PRETTY_FUNCTION__); - RPCHelpers::handleException(response->mutable_exception()); + (void)response; + //RPCHelpers::handleException(response->mutable_exception()); } }); } diff --git a/src/CloudServices/CnchServerServiceImpl.h b/src/CloudServices/CnchServerServiceImpl.h index 1473f66f0d4..a826b0020f2 100644 --- a/src/CloudServices/CnchServerServiceImpl.h +++ b/src/CloudServices/CnchServerServiceImpl.h @@ -185,6 +185,12 @@ class CnchServerServiceImpl : protected WithMutableContext, public DB::Protos::C Protos::CleanTransactionResp * response, google::protobuf::Closure * done) override; + void cleanUndoBuffers( + google::protobuf::RpcController * cntl, + const Protos::CleanUndoBuffersReq * request, + Protos::CleanUndoBuffersResp * response, + google::protobuf::Closure * done) override; + void acquireLock( google::protobuf::RpcController * cntl, const Protos::AcquireLockReq * request, @@ -379,6 +385,12 @@ class CnchServerServiceImpl : protected WithMutableContext, public DB::Protos::C Protos::notifyTableCreatedResp * response, google::protobuf::Closure * done) override; + void getDedupImplVersion( + google::protobuf::RpcController * cntl, + const Protos::GetDedupImplVersionReq * request, + Protos::GetDedupImplVersionResp * response, + google::protobuf::Closure * done) override; + private: const UInt64 
server_start_time; std::optional global_gc_manager; diff --git a/src/CloudServices/CnchWorkerClient.cpp b/src/CloudServices/CnchWorkerClient.cpp index 09e9ccc9f17..67796bf8cad 100644 --- a/src/CloudServices/CnchWorkerClient.cpp +++ b/src/CloudServices/CnchWorkerClient.cpp @@ -28,6 +28,9 @@ #include #include #include +#include +#include +#include #include #include #include "Storages/Hive/HiveFile/IHiveFile.h" @@ -218,32 +221,13 @@ void CnchWorkerClient::sendCreateQueries( for (const auto & cnch_table_create_query : cnch_table_create_queries) *request.mutable_cnch_table_create_queries()->Add() = cnch_table_create_query; + cntl.set_timeout_ms(settings.send_plan_segment_timeout_ms.totalMilliseconds()); stub->sendCreateQuery(&cntl, &request, &response, nullptr); assertController(cntl); RPCHelpers::checkResponse(response); } -brpc::CallId CnchWorkerClient::sendCnchFileDataParts( - const ContextPtr & context, - const StoragePtr & storage, - const String & local_table_name, - const DB::FileDataPartsCNCHVector & parts, - const ExceptionHandlerPtr & handler) -{ - Protos::SendCnchFileDataPartsReq request; - request.set_txn_id(context->getCurrentTransactionID()); - request.set_database_name(storage->getDatabaseName()); - request.set_table_name(local_table_name); - fillCnchFilePartsModel(parts, *request.mutable_parts()); - - auto * cntl = new brpc::Controller; - const auto call_id = cntl->call_id(); - auto * response = new Protos::SendCnchFileDataPartsResp; - stub->sendCnchFileDataParts(cntl, &request, response, brpc::NewCallback(RPCHelpers::onAsyncCallDone, response, cntl, handler)); - return call_id; -} - CheckResults CnchWorkerClient::checkDataParts( const ContextPtr & context, const IStorage & storage, @@ -302,8 +286,9 @@ brpc::CallId CnchWorkerClient::preloadDataParts( auto * response = new Protos::PreloadDataPartsResp(); /// adjust the timeout to prevent timeout if there are too many parts to send, const auto & settings = context->getSettingsRef(); - auto send_timeout = 
std::max(settings.max_execution_time.value.totalMilliseconds() >> 1, settings.brpc_data_parts_timeout_ms.totalMilliseconds()); - cntl->set_timeout_ms(send_timeout); + request.set_read_injection(settings.remote_fs_read_failed_injection); + cntl->set_timeout_ms(settings.preload_send_rpc_max_ms); + auto call_id = cntl->call_id(); stub->preloadDataParts(cntl, &request, response, brpc::NewCallback(RPCHelpers::onAsyncCallDone, response, cntl, handler)); @@ -340,49 +325,6 @@ brpc::CallId CnchWorkerClient::dropPartDiskCache( return cntl.call_id(); } -brpc::CallId CnchWorkerClient::sendQueryDataParts( - const ContextPtr & context, - const StoragePtr & storage, - const String & local_table_name, - const ServerDataPartsVector & data_parts, - const std::set & required_bucket_numbers, - const ExceptionHandlerWithFailedInfoPtr & handler, - const WorkerId & worker_id) -{ - Protos::SendDataPartsReq request; - request.set_txn_id(context->getCurrentTransactionID()); - request.set_database_name(storage->getDatabaseName()); - request.set_table_name(local_table_name); - request.set_disk_cache_mode(context->getSettingsRef().disk_cache_mode.toString()); - - fillBasePartAndDeleteBitmapModels(*storage, data_parts, *request.mutable_parts(), *request.mutable_bitmaps()); - for (const auto & bucket_num : required_bucket_numbers) - *request.mutable_bucket_numbers()->Add() = bucket_num; - - // TODO: - // auto udf_info = context.getNonSqlUdfVersionMap(); - // for (const auto & [name, version]: udf_info) - // { - // auto & new_info = *request.mutable_udf_infos()->Add(); - // new_info.set_function_name(name); - // new_info.set_version(version); - // } - - - auto * cntl = new brpc::Controller(); - auto * response = new Protos::SendDataPartsResp(); - /// adjust the timeout to prevent timeout if there are too many parts to send, - const auto & settings = context->getSettingsRef(); - auto send_timeout = std::max(settings.max_execution_time.value.totalMilliseconds() >> 1, 
settings.brpc_data_parts_timeout_ms.totalMilliseconds()); - cntl->set_timeout_ms(send_timeout); - - auto call_id = cntl->call_id(); - stub->sendQueryDataParts( - cntl, &request, response, brpc::NewCallback(RPCHelpers::onAsyncCallDoneWithFailedInfo, response, cntl, handler, worker_id)); - - return call_id; -} - brpc::CallId CnchWorkerClient::sendOffloadingInfo( // NOLINT [[maybe_unused]] const ContextPtr & context, [[maybe_unused]] const HostWithPortsVec & read_workers, @@ -412,14 +354,32 @@ brpc::CallId CnchWorkerClient::sendResources( /// so it should be larger than max_execution_time to make sure the session is not to be destroyed in advance. UInt64 recycle_timeout = max_execution_time > 0 ? max_execution_time + 60UL : 3600; request.set_timeout(recycle_timeout); + if (!settings.session_timezone.value.empty()) + request.set_session_timezone(settings.session_timezone.value); bool require_worker_info = false; for (const auto & resource: resources_to_send) { if (!resource.sent_create_query) { - request.add_create_queries(resource.create_table_query); - request.add_dynamic_object_column_schema(resource.object_columns.toString()); + const auto & def = resource.table_definition; + if (resource.table_definition.cacheable) + { + auto * cacheable = request.add_cacheable_create_queries(); + RPCHelpers::fillStorageID(resource.storage->getStorageID(), *cacheable->mutable_storage_id()); + cacheable->set_definition(def.definition); + if (!resource.object_columns.empty()) + cacheable->set_dynamic_object_column_schema(resource.object_columns.toString()); + cacheable->set_local_engine_type(static_cast(def.engine_type)); + cacheable->set_local_table_name(def.local_table_name); + if (!def.underlying_dictionary_tables.empty()) + cacheable->set_local_underlying_dictionary_tables(def.underlying_dictionary_tables); + } + else + { + request.add_create_queries(def.definition); + request.add_dynamic_object_column_schema(resource.object_columns.toString()); + } } /// parts @@ -440,7 +400,7 
@@ brpc::CallId CnchWorkerClient::sendResources( } table_data_parts.set_database(resource.storage->getDatabaseName()); - table_data_parts.set_table(resource.worker_table_name); + table_data_parts.set_table(resource.table_definition.local_table_name); if (resource.table_version) { require_worker_info = true; @@ -492,8 +452,9 @@ brpc::CallId CnchWorkerClient::sendResources( { auto current_wg = context->getCurrentWorkerGroup(); auto * worker_info = request.mutable_worker_info(); - // TODO: resource manager should gurantee the worker number and worker index are consistent - RPCHelpers::fillWorkerInfo(*worker_info, worker_id.id, current_wg->workerNum()); + worker_info->set_worker_id(worker_id.id); + worker_info->set_index(current_wg->getWorkerIndex(worker_id.id)); + worker_info->set_num_workers(current_wg->workerNum()); // worker info validation if (worker_info->num_workers() <= worker_info->index()) @@ -503,7 +464,6 @@ brpc::CallId CnchWorkerClient::sendResources( request.set_disk_cache_mode(context->getSettingsRef().disk_cache_mode.toString()); - LOG_TRACE(log, "request : {}", request.ShortDebugString()); brpc::Controller * cntl = new brpc::Controller; /// send_timeout refers to the time to send resource to worker /// If max_execution_time is not set, the send_timeout will be set to brpc_data_parts_timeout_ms @@ -516,6 +476,81 @@ brpc::CallId CnchWorkerClient::sendResources( return call_id; } +static void onDedupTaskDone(Protos::ExecuteDedupTaskResp * response, brpc::Controller * cntl, ExceptionHandlerPtr handler, std::function funcOnCallback) +{ + try + { + std::unique_ptr response_guard(response); + std::unique_ptr cntl_guard(cntl); + RPCHelpers::assertController(*cntl); + RPCHelpers::checkResponse(*response); + funcOnCallback(/*success*/ true); + } + catch (...) 
+ { + handler->setException(std::current_exception()); + funcOnCallback(/*success*/ false); + } +} + +brpc::CallId CnchWorkerClient::executeDedupTask( + const ContextPtr & context, + const TxnTimestamp & txn_id, + UInt16 rpc_port, + const IStorage & storage, + const CnchDedupHelper::DedupTask & dedup_task, + const ExceptionHandlerPtr & handler, + std::function funcOnCallback) +{ + Protos::ExecuteDedupTaskReq request; + request.set_txn_id(txn_id); + request.set_rpc_port(rpc_port); + RPCHelpers::fillUUID(dedup_task.storage_id.uuid, *request.mutable_table_uuid()); + request.set_dedup_mode(static_cast(dedup_task.dedup_mode)); + /// New parts + for (const auto & new_part : dedup_task.new_parts) + { + fillPartModel(storage, *new_part, *request.add_new_parts()); + request.add_new_parts_paths()->assign(new_part->relative_path); + } + for (const auto & delete_bitmap : dedup_task.delete_bitmaps_for_new_parts) + { + auto * new_bitmap = request.add_delete_bitmaps_for_new_parts(); + new_bitmap->CopyFrom(*(delete_bitmap->getModel())); + } + + /// Staged parts + for (const auto & staged_part : dedup_task.staged_parts) + { + fillPartModel(storage, *staged_part, *request.add_staged_parts()); + request.add_staged_parts_paths()->assign(staged_part->relative_path); + } + for (const auto & delete_bitmap : dedup_task.delete_bitmaps_for_staged_parts) + { + auto * new_bitmap = request.add_delete_bitmaps_for_staged_parts(); + new_bitmap->CopyFrom(*(delete_bitmap->getModel())); + } + + /// Visible parts + for (const auto & visible_part : dedup_task.visible_parts) + { + fillPartModel(storage, *visible_part, *request.add_visible_parts()); + request.add_visible_parts_paths()->assign(visible_part->relative_path); + } + for (const auto & delete_bitmap : dedup_task.delete_bitmaps_for_visible_parts) + { + auto * new_bitmap = request.add_delete_bitmaps_for_visible_parts(); + new_bitmap->CopyFrom(*(delete_bitmap->getModel())); + } + + auto * cntl = new brpc::Controller; + 
cntl->set_timeout_ms(context->getSettingsRef().max_dedup_execution_time.totalMilliseconds()); + const auto call_id = cntl->call_id(); + auto * response = new Protos::ExecuteDedupTaskResp; + stub->executeDedupTask(cntl, &request, response, brpc::NewCallback(onDedupTaskDone, response, cntl, handler, funcOnCallback)); + return call_id; +} + brpc::CallId CnchWorkerClient::removeWorkerResource(TxnTimestamp txn_id, ExceptionHandlerPtr handler) { brpc::Controller * cntl = new brpc::Controller; diff --git a/src/CloudServices/CnchWorkerClient.h b/src/CloudServices/CnchWorkerClient.h index 7a8b4ed29e1..8aa8e4a6f6c 100644 --- a/src/CloudServices/CnchWorkerClient.h +++ b/src/CloudServices/CnchWorkerClient.h @@ -28,8 +28,8 @@ #include #include #include -#include "Storages/Hive/HiveFile/IHiveFile_fwd.h" -#include "Storages/MergeTree/MergeTreeDataPartCNCH_fwd.h" +#include +#include #include #include @@ -50,6 +50,11 @@ namespace IngestColumnCnch struct IngestPartitionParam; } +namespace CnchDedupHelper +{ + struct DedupTask; +} + class MergeTreeMetaBase; class StorageMaterializedView; struct MarkRange; @@ -83,22 +88,6 @@ class CnchWorkerClient : public RpcClientBase /// send resource to worker async void sendCreateQueries(const ContextPtr & context, const std::vector & create_queries, std::set cnch_table_create_queries = {}); - brpc::CallId sendQueryDataParts( - const ContextPtr & context, - const StoragePtr & storage, - const String & local_table_name, - const ServerDataPartsVector & parts, - const std::set & required_bucket_numbers, - const ExceptionHandlerWithFailedInfoPtr & handler, - const WorkerId & worker_id = WorkerId{}); - - brpc::CallId sendCnchFileDataParts( - const ContextPtr & context, - const StoragePtr & storage, - const String & local_table_name, - const FileDataPartsCNCHVector & parts, - const ExceptionHandlerPtr & handler); - CheckResults checkDataParts( const ContextPtr & context, const IStorage & storage, @@ -117,14 +106,14 @@ class CnchWorkerClient : public 
RpcClientBase UInt64 parts_preload_level, UInt64 submit_ts); - brpc::CallId dropPartDiskCache( - const ContextPtr & context, - const TxnTimestamp & txn_id, - const IStorage & storage, - const String & create_local_table_query, - const ServerDataPartsVector & parts, - bool sync, - bool drop_vw_disk_cache); + brpc::CallId dropPartDiskCache( + const ContextPtr & context, + const TxnTimestamp & txn_id, + const IStorage & storage, + const String & create_local_table_query, + const ServerDataPartsVector & parts, + bool sync, + bool drop_vw_disk_cache); brpc::CallId sendOffloadingInfo( const ContextPtr & context, @@ -140,6 +129,15 @@ class CnchWorkerClient : public RpcClientBase const WorkerId & worker_id, bool with_mutations = false); + brpc::CallId executeDedupTask( + const ContextPtr & context, + const TxnTimestamp & txn_id, + UInt16 rpc_port, + const IStorage & storage, + const CnchDedupHelper::DedupTask & dedup_task, + const ExceptionHandlerPtr & handler, + std::function funcOnCallback); + brpc::CallId removeWorkerResource(TxnTimestamp txn_id, ExceptionHandlerPtr handler); void createDedupWorker(const StorageID & storage_id, const String & create_table_query, const HostWithPorts & host_ports, const size_t & deduper_index); diff --git a/src/CloudServices/CnchWorkerResource.cpp b/src/CloudServices/CnchWorkerResource.cpp index b101932f3df..4b09b3611a2 100644 --- a/src/CloudServices/CnchWorkerResource.cpp +++ b/src/CloudServices/CnchWorkerResource.cpp @@ -16,23 +16,23 @@ #include #include +#include #include #include #include #include -#include #include -#include +#include +#include #include -#include -#include +#include #include -#include +#include #include -#include #include -#include -#include +#include +#include +#include namespace DB @@ -40,122 +40,142 @@ namespace DB namespace ErrorCodes { - extern const int DUPLICATE_COLUMN; - extern const int INCORRECT_QUERY; + extern const int BAD_ARGUMENTS; extern const int TABLE_ALREADY_EXISTS; } -void 
CnchWorkerResource::executeCreateQuery(ContextMutablePtr context, const String & create_query, bool skip_if_exists, const ColumnsDescription & object_columns) +static ASTPtr parseCreateQuery(ContextMutablePtr context, const String & create_query) { - LOG_DEBUG(&Poco::Logger::get("WorkerResource"), "start create cloud table {}", create_query); const char * begin = create_query.data(); const char * end = create_query.data() + create_query.size(); ParserQueryWithOutput parser{end}; const auto & settings = context->getSettingsRef(); - ASTPtr ast_query = parseQuery(parser, begin, end, "CreateCloudTable", settings.max_query_size, settings.max_parser_depth); + return parseQuery(parser, begin, end, "CreateCloudTable", settings.max_query_size, settings.max_parser_depth); +} + +void CnchWorkerResource::executeCreateQuery(ContextMutablePtr context, const String & create_query, bool skip_if_exists, const ColumnsDescription & object_columns) +{ + LOG_DEBUG(&Poco::Logger::get("WorkerResource"), "start create cloud table {}", create_query); + auto ast_query = parseCreateQuery(context, create_query); auto & ast_create_query = ast_query->as(); /// set query settings + /// TODO: can we remove this? i.e., don't rely on create query to pass query setting if (ast_create_query.settings_ast) InterpreterSetQuery(ast_create_query.settings_ast, context).executeForCurrentContext(); + auto res = createStorageFromQuery(ast_create_query, context); + if (auto cloud_table = std::dynamic_pointer_cast(res)) + cloud_table->resetObjectColumns(object_columns); + res->startup(); + + bool throw_if_exists = !ast_create_query.if_not_exists && !skip_if_exists; const auto & database_name = ast_create_query.database; // not empty. 
const auto & table_name = ast_create_query.table; String tenant_db = formatTenantDatabaseName(database_name); - { - auto lock = getLock(); - if (cloud_tables.find({tenant_db, table_name}) != cloud_tables.end()) - { - if (ast_create_query.if_not_exists || skip_if_exists) - return; - else - throw Exception("Table " + tenant_db + "." + table_name + " already exists.", ErrorCodes::TABLE_ALREADY_EXISTS); - } - } + insertCloudTable({tenant_db, table_name}, res, context, throw_if_exists); +} - ColumnsDescription columns; - IndicesDescription indices; - ConstraintsDescription constraints; - ForeignKeysDescription foreign_keys; - UniqueNotEnforcedDescription unique_not_enforced; +void CnchWorkerResource::executeCacheableCreateQuery( + ContextMutablePtr context, + const StorageID & cnch_storage_id, + const String & definition, + const String & local_table_name, + WorkerEngineType engine_type, + const String & underlying_dictionary_tables, + const ColumnsDescription & object_columns) +{ + static auto * log = &Poco::Logger::get("WorkerResource"); - if (ast_create_query.columns_list) + std::shared_ptr cached; + if (auto cache = context->tryGetCloudTableDefinitionCache(); cache) { - if (ast_create_query.columns_list->columns) + auto load = [&]() -> std::shared_ptr { - // Set attach = true to avoid making columns nullable due to ANSI settings, because the dialect change - // should NOT affect existing tables. 
- columns = InterpreterCreateQuery::getColumnsDescription(*ast_create_query.columns_list->columns, context, /* attach= */ true); - } - - if (ast_create_query.columns_list->indices) - for (const auto & index : ast_create_query.columns_list->indices->children) - indices.push_back(IndexDescription::getIndexFromAST(index->clone(), columns, context)); - - if (ast_create_query.columns_list->constraints) - for (const auto & constraint : ast_create_query.columns_list->constraints->children) - constraints.constraints.push_back(std::dynamic_pointer_cast(constraint->clone())); - - if (ast_create_query.columns_list->foreign_keys) - for (const auto & foreign_key : ast_create_query.columns_list->foreign_keys->children) - foreign_keys.foreign_keys.push_back(std::dynamic_pointer_cast(foreign_key->clone())); - - if (ast_create_query.columns_list->unique) - for (const auto & unique : ast_create_query.columns_list->unique->children) - unique_not_enforced.unique.push_back(std::dynamic_pointer_cast(unique->clone())); + auto ast_query = parseCreateQuery(context, definition); + auto & create_query = ast_query->as(); + + replaceCnchWithCloud( + create_query.storage, + cnch_storage_id.getDatabaseName(), + cnch_storage_id.getTableName(), + engine_type); + + auto table = createStorageFromQuery(create_query, context); + if (auto cloud_table = std::dynamic_pointer_cast(table)) + return cloud_table; + return {}; + }; + + cached = cache->getOrSet(CloudTableDefinitionCache::hash(definition), std::move(load)).first; } - else - throw Exception("Incorrect CREATE query: required list of column descriptions or AS section or SELECT.", ErrorCodes::INCORRECT_QUERY); - /// Even if query has list of columns, canonicalize it (unfold Nested columns). 
- ASTPtr new_columns = InterpreterCreateQuery::formatColumns(columns, ParserSettings::valueOf(context->getSettingsRef())); - ASTPtr new_indices = InterpreterCreateQuery::formatIndices(indices); - ASTPtr new_constraints = InterpreterCreateQuery::formatConstraints(constraints); - ASTPtr new_foreign_keys = InterpreterCreateQuery::formatForeignKeys(foreign_keys); - ASTPtr new_unique_not_enforced = InterpreterCreateQuery::formatUnique(unique_not_enforced); + StoragePtr res; + if (cached) + { + LOG_DEBUG(log, "Creating cloud table {} from cached template of definition {}", local_table_name, definition); + StorageID actual_table_id = cached->getStorageID(); + actual_table_id.table_name = local_table_name; - if (ast_create_query.columns_list->columns) - ast_create_query.columns_list->replace(ast_create_query.columns_list->columns, new_columns); + std::unique_ptr new_settings = std::make_unique(*cached->getSettings()); + if (!underlying_dictionary_tables.empty()) + new_settings->underlying_dictionary_tables = underlying_dictionary_tables; - if (ast_create_query.columns_list->indices) - ast_create_query.columns_list->replace(ast_create_query.columns_list->indices, new_indices); + switch (engine_type) + { + case WorkerEngineType::CLOUD: + res = StorageCloudMergeTree::create( + actual_table_id, + cnch_storage_id.database_name, + cnch_storage_id.table_name, + *cached->getInMemoryMetadataPtr(), + context, + /*date_column_name*/ "", + cached->getMergingParams(), + std::move(new_settings)); + break; + case WorkerEngineType::DICT: + /// NOTE: StorageDictCloudMergeTree::create is broken, don't use it + res = std::make_shared( + actual_table_id, + cnch_storage_id.database_name, + cnch_storage_id.table_name, + *cached->getInMemoryMetadataPtr(), + context, + /*date_column_name*/ "", + cached->getMergingParams(), + std::move(new_settings)); + break; + default: + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown value for engine_type: {}", static_cast(engine_type)); + } - if 
(ast_create_query.columns_list->constraints) - ast_create_query.columns_list->replace(ast_create_query.columns_list->constraints, new_constraints); + } + else /// for cloud table other than CloudMergeTree. e.g., CloudS3, CloudHive, ... + { + auto ast_query = parseCreateQuery(context, definition); + auto & create_query = ast_query->as(); - if (ast_create_query.columns_list->foreign_keys) - ast_create_query.columns_list->replace(ast_create_query.columns_list->foreign_keys, new_foreign_keys); + replaceCnchWithCloud( + create_query.storage, + cnch_storage_id.getDatabaseName(), + cnch_storage_id.getTableName(), + engine_type); - if (ast_create_query.columns_list->unique) - ast_create_query.columns_list->replace(ast_create_query.columns_list->unique, new_unique_not_enforced); + create_query.table = local_table_name; + if (!underlying_dictionary_tables.empty()) + modifyOrAddSetting(create_query, "underlying_dictionary_tables", Field(underlying_dictionary_tables)); - /// Check for duplicates - std::set all_columns; - for (const auto & column : columns) - { - if (!all_columns.emplace(column.name).second) - throw Exception("Column " + backQuoteIfNeed(column.name) + " already exists", ErrorCodes::DUPLICATE_COLUMN); + LOG_DEBUG(log, "Creating cloud table {} from rewritted definition {}", local_table_name, serializeAST(create_query)); + res = createStorageFromQuery(create_query, context); } - /// Table constructing - StoragePtr res = StorageFactory::instance().get(ast_create_query, "", context, context->getGlobalContext(), columns, constraints, foreign_keys, unique_not_enforced, false); - res->startup(); - if (auto cloud_table = std::dynamic_pointer_cast(res)) cloud_table->resetObjectColumns(object_columns); + res->startup(); - { - auto lock = getLock(); - cloud_tables.emplace(std::make_pair(tenant_db, table_name), res); - auto it = memory_databases.find(tenant_db); - if (it == memory_databases.end()) - { - DatabasePtr database = std::make_shared(tenant_db, 
context->getGlobalContext()); - memory_databases.insert(std::make_pair(tenant_db, std::move(database))); - } - } - - LOG_DEBUG(&Poco::Logger::get("WorkerResource"), "Successfully create cloud table {} and database {}", res->getStorageID().getNameForLogs(), database_name); + auto res_table_id = res->getStorageID(); + insertCloudTable({res_table_id.getDatabaseName(), res_table_id.getTableName()}, res, context, /*throw_if_exists=*/ false); } StoragePtr CnchWorkerResource::getTable(const StorageID & table_id) const @@ -184,6 +204,27 @@ DatabasePtr CnchWorkerResource::getDatabase(const String & database_name) const return {}; } +void CnchWorkerResource::insertCloudTable(DatabaseAndTableName key, const StoragePtr & storage, ContextPtr context, bool throw_if_exists) +{ + auto & tenant_db = key.first; + { + auto lock = getLock(); + bool inserted = cloud_tables.emplace(key, storage).second; + if (!inserted && throw_if_exists) + throw Exception(ErrorCodes::TABLE_ALREADY_EXISTS, "Table {} already exists", storage->getStorageID().getFullTableName()); + auto it = memory_databases.find(tenant_db); + if (it == memory_databases.end()) + { + DatabasePtr database = std::make_shared(tenant_db, context->getGlobalContext()); + memory_databases.insert(std::make_pair(tenant_db, std::move(database))); + } + } + + static auto * log = &Poco::Logger::get("WorkerResource"); + LOG_DEBUG(log, "Successfully create database {} and table {} {}", + tenant_db, storage->getName(), storage->getStorageID().getNameForLogs()); +} + bool CnchWorkerResource::isCnchTableInWorker(const StorageID & table_id) const { String tenant_db = formatTenantDatabaseName(table_id.getDatabaseName()); diff --git a/src/CloudServices/CnchWorkerResource.h b/src/CloudServices/CnchWorkerResource.h index c4b422dacf6..24da35455dd 100644 --- a/src/CloudServices/CnchWorkerResource.h +++ b/src/CloudServices/CnchWorkerResource.h @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -35,11 +36,21 @@ 
class CnchWorkerResource { public: void executeCreateQuery(ContextMutablePtr context, const String & create_query, bool skip_if_exists = false, const ColumnsDescription & object_columns = {}); + + void executeCacheableCreateQuery( + ContextMutablePtr context, + const StorageID & cnch_storage_id, + const String & definition, + const String & local_table_name, + WorkerEngineType engine_type, + const String & underlying_dictionary_tables, + const ColumnsDescription & object_columns); + StoragePtr getTable(const StorageID & table_id) const; DatabasePtr getDatabase(const String & database_name) const; bool isCnchTableInWorker(const StorageID & table_id) const; - ~CnchWorkerResource() + ~CnchWorkerResource() { clearResource(); } @@ -83,6 +94,8 @@ class CnchWorkerResource TablesMap cloud_tables; std::unordered_map memory_databases; + void insertCloudTable(DatabaseAndTableName key, const StoragePtr & storage, ContextPtr context, bool throw_if_exists); + /// for offloading query TablesSet cnch_tables; std::map worker_table_names; diff --git a/src/CloudServices/CnchWorkerServiceImpl.cpp b/src/CloudServices/CnchWorkerServiceImpl.cpp index d2fc6fb2615..53cb069af0c 100644 --- a/src/CloudServices/CnchWorkerServiceImpl.cpp +++ b/src/CloudServices/CnchWorkerServiceImpl.cpp @@ -20,7 +20,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -37,6 +39,7 @@ #include #include #include +#include #include #include #include @@ -53,6 +56,9 @@ #include #include #include +#include +#include +#include #include #if USE_RDKAFKA @@ -76,7 +82,8 @@ extern const Event PreloadExecTotalOps; namespace ProfileEvents { -extern const Event PreloadExecTotalOps; + extern const Event QueryCreateTablesMicroseconds; + extern const Event QuerySendResourcesMicroseconds; } namespace DB @@ -89,6 +96,7 @@ namespace ErrorCodes extern const int PREALLOCATE_TOPOLOGY_ERROR; extern const int PREALLOCATE_QUERY_INTENT_NOT_FOUND; extern const int SESSION_NOT_FOUND; + extern const int 
ABORTED; } CnchWorkerServiceImpl::CnchWorkerServiceImpl(ContextMutablePtr context_) @@ -136,7 +144,10 @@ CnchWorkerServiceImpl::~CnchWorkerServiceImpl() RPCHelpers::handleException(response->mutable_exception()); \ } \ }; \ - THREADPOOL_SCHEDULE(_func); + Stopwatch watch; \ + THREADPOOL_SCHEDULE(_func); \ + UInt64 milliseconds = watch.elapsedMilliseconds(); \ + if (milliseconds > 100) LOG_DEBUG(log, "CnchWorkerService rpc request threadpool schedule cost : {} ", milliseconds); void CnchWorkerServiceImpl::executeSimpleQuery( @@ -264,7 +275,15 @@ void CnchWorkerServiceImpl::submitManipulationTask( if (!data) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Table {} is not CloudMergeTree", storage->getStorageID().getNameForLogs()); if (request->has_dynamic_object_column_schema()) + { + LOG_TRACE( + log, + "Received table:{}.{} with dynamic object column schema:{}.", + data->getCnchDatabase(), + data->getCnchTable(), + request->dynamic_object_column_schema()); data->resetObjectColumns(ColumnsDescription::parse(request->dynamic_object_column_schema())); + } auto params = ManipulationTaskParams(storage); params.type = static_cast(request->type()); @@ -301,6 +320,7 @@ void CnchWorkerServiceImpl::submitManipulationTask( rpc_context->initCnchServerResource(txn_id); rpc_context->setSetting("prefer_localhost_replica", false); rpc_context->setSetting("prefer_cnch_catalog", true); + rpc_context->setSetting("max_execution_time", 3600); trySetVirtualWarehouseAndWorkerGroup(data->getSettings()->cnch_vw_default.value, rpc_context); } @@ -473,83 +493,6 @@ void CnchWorkerServiceImpl::sendCreateQuery( }) } -void CnchWorkerServiceImpl::sendQueryDataParts( - google::protobuf::RpcController *, - const Protos::SendDataPartsReq * request, - Protos::SendDataPartsResp * response, - google::protobuf::Closure * done) -{ - SUBMIT_THREADPOOL({ - auto session = getContext()->acquireNamedCnchSession(request->txn_id(), {}, true); - const auto & query_context = session->context; - - auto storage = 
DatabaseCatalog::instance().getTable({request->database_name(), request->table_name()}, query_context); - auto & cloud_merge_tree = dynamic_cast(*storage); - - LOG_DEBUG( - log, - "Receiving {} parts for table {}(txn_id: {})", - request->parts_size(), - cloud_merge_tree.getStorageID().getNameForLogs(), - request->txn_id()); - - MergeTreeMutableDataPartsVector data_parts; - if (cloud_merge_tree.getInMemoryMetadataPtr()->hasUniqueKey()) - data_parts = createBasePartAndDeleteBitmapFromModelsForSend( - cloud_merge_tree, request->parts(), request->bitmaps()); - else - data_parts = createPartVectorFromModelsForSend(cloud_merge_tree, request->parts()); - - if (request->has_disk_cache_mode()) - { - SettingFieldDiskCacheMode disk_cache_mode; - disk_cache_mode.parseFromString(request->disk_cache_mode()); - if (disk_cache_mode.value != DiskCacheMode::AUTO) - { - for (auto & part : data_parts) - part->disk_cache_mode = disk_cache_mode; - } - } - cloud_merge_tree.loadDataParts(data_parts); - - LOG_DEBUG(log, "Received and loaded {} server parts.", data_parts.size()); - - std::set required_bucket_numbers; - for (const auto & bucket_number : request->bucket_numbers()) - required_bucket_numbers.insert(bucket_number); - - cloud_merge_tree.setRequiredBucketNumbers(required_bucket_numbers); - - // std::map udf_infos; - // for (const auto & udf_info: request->udf_infos()) - // udf_infos.emplace(udf_info.function_name(), udf_info.version()); - }) -} - - -void CnchWorkerServiceImpl::sendCnchFileDataParts( - google::protobuf::RpcController *, - const Protos::SendCnchFileDataPartsReq * request, - Protos::SendCnchFileDataPartsResp * response, - google::protobuf::Closure * done) -{ - SUBMIT_THREADPOOL({ - auto session = getContext()->acquireNamedCnchSession(request->txn_id(), {}, true); - const auto & query_context = session->context; - - auto storage = DatabaseCatalog::instance().getTable({request->database_name(), request->table_name()}, query_context); - auto & cnchfile_table = 
dynamic_cast(*storage); - - LOG_DEBUG(log, "Receiving parts for table {}", cnchfile_table.getStorageID().getNameForLogs()); - - auto data_parts = createCnchFileDataParts(getContext(), request->parts()); - - cnchfile_table.loadDataParts(data_parts); - - LOG_DEBUG(log, "Received and loaded {} file parts.", data_parts.size()); - }) -} - void CnchWorkerServiceImpl::checkDataParts( google::protobuf::RpcController * cntl, const Protos::CheckDataPartsReq * request, @@ -614,8 +557,6 @@ void CnchWorkerServiceImpl::preloadDataParts( google::protobuf::Closure * done) { SUBMIT_THREADPOOL({ - SCOPE_EXIT({ProfileEvents::increment(ProfileEvents::PreloadExecTotalOps, 1, Metrics::MetricType::Rate);}); - Stopwatch watch; auto rpc_context = RPCHelpers::createSessionContextForRPC(getContext(), *cntl); StoragePtr storage = createStorageFromQuery(request->create_table_query(), rpc_context); @@ -637,30 +578,43 @@ void CnchWorkerServiceImpl::preloadDataParts( || (!cloud_merge_tree.getSettings()->parts_preload_level && !cloud_merge_tree.getSettings()->enable_preload_parts)) return; - std::unique_ptr pool; - ThreadPool * pool_ptr; + auto preload_level = request->preload_level(); + auto submit_ts = request->submit_ts(); + auto read_injection = request->read_injection(); + if (request->sync()) { - pool = std::make_unique(std::min(data_parts.size(), cloud_merge_tree.getSettings()->cnch_parallel_preloading.value)); - pool_ptr = pool.get(); + auto & settings = getContext()->getSettingsRef(); + auto pool = std::make_unique(std::min(data_parts.size(), settings.cnch_parallel_preloading.value)); + for (const auto & part : data_parts) + { + pool->scheduleOrThrowOnError([part, preload_level, submit_ts, read_injection, storage] { + part->remote_fs_read_failed_injection = read_injection; + part->disk_cache_mode = DiskCacheMode::SKIP_DISK_CACHE;// avoid getCheckum & getIndex re-cache + part->preload(preload_level, submit_ts); + }); + } + pool->wait(); + LOG_DEBUG( + log, + "Finish preload tasks in {} ms, 
level: {}, sync: {}, size: {}", + watch.elapsedMilliseconds(), + preload_level, + sync, + data_parts.size()); } else - pool_ptr = &(IDiskCache::getThreadPool()); - - for (const auto & part : data_parts) { - part->preload(request->preload_level(), *pool_ptr, request->submit_ts()); + ThreadPool * preload_thread_pool = &(IDiskCache::getPreloadPool()); + for (const auto & part : data_parts) + { + preload_thread_pool->scheduleOrThrowOnError([part, preload_level, submit_ts, read_injection, storage] { + part->remote_fs_read_failed_injection = read_injection; + part->disk_cache_mode = DiskCacheMode::SKIP_DISK_CACHE;// avoid getCheckum & getIndex re-cache + part->preload(preload_level, submit_ts); + }); + } } - - if (request->sync()) - pool->wait(); - - LOG_DEBUG( - storage->getLogger(), - "Finish preload tasks in {} ms, level: {}, sync: {}", - watch.elapsedMilliseconds(), - request->preload_level(), - request->sync()); }) } @@ -728,22 +682,45 @@ void CnchWorkerServiceImpl::sendResources( auto session = rpc_context->acquireNamedCnchSession(request->txn_id(), request->timeout(), false); auto query_context = session->context; query_context->setTemporaryTransaction(request->txn_id(), request->primary_txn_id()); + if (request->has_session_timezone()) + query_context->setSetting("session_timezone", request->session_timezone()); + + CurrentThread::QueryScope query_scope(query_context); auto worker_resource = query_context->getCnchWorkerResource(); /// store cloud tables in cnch_session_resource. 
{ + Stopwatch create_timer; /// create a copy of session_context to avoid modify settings in SessionResource auto context_for_create = Context::createCopy(query_context); for (int i = 0; i < request->create_queries_size(); i++) { auto create_query = request->create_queries().at(i); - auto object_columns = request->dynamic_object_column_schema().at(i); - - worker_resource->executeCreateQuery(context_for_create, create_query, false, ColumnsDescription::parse(object_columns)); - } + ColumnsDescription object_columns; + if (i < request->dynamic_object_column_schema_size()) + object_columns = ColumnsDescription::parse(request->dynamic_object_column_schema().at(i)); - LOG_DEBUG(log, "Successfully create {} queries for Session: {}", request->create_queries_size(), request->txn_id()); + worker_resource->executeCreateQuery(context_for_create, create_query, false, object_columns); + } + for (int i = 0; i < request->cacheable_create_queries_size(); i++) + { + auto & item = request->cacheable_create_queries().at(i); + ColumnsDescription object_columns; + if (item.has_dynamic_object_column_schema()) + object_columns = ColumnsDescription::parse(item.dynamic_object_column_schema()); + worker_resource->executeCacheableCreateQuery( + context_for_create, + RPCHelpers::createStorageID(item.storage_id()), + item.definition(), + item.local_table_name(), + static_cast(item.local_engine_type()), + item.local_underlying_dictionary_tables(), + object_columns); + } + create_timer.stop(); + LOG_INFO(log, "Prepared {} tables for session {} in {} us", request->create_queries_size() + request->cacheable_create_queries_size(), request->txn_id(), create_timer.elapsedMicroseconds()); + ProfileEvents::increment(ProfileEvents::QueryCreateTablesMicroseconds, create_timer.elapsedMicroseconds()); } for (const auto & data : request->data_parts()) @@ -756,17 +733,8 @@ void CnchWorkerServiceImpl::sendResources( { WGWorkerInfoPtr worker_info = RPCHelpers::createWorkerInfo(request->worker_info()); UInt64 
version = data.table_version(); - ServerDataPartsWithDBM server_parts_with_dbms; - query_context->getGlobalDataManager()->loadDataPartsWithDBM(*cloud_merge_tree, cloud_merge_tree->getStorageUUID(), version, worker_info, server_parts_with_dbms); - size_t server_part_size = server_parts_with_dbms.first.size(); - size_t delete_bitmap_size = server_parts_with_dbms.second.size(); - cloud_merge_tree->loadServerDataPartsWithDBM(std::move(server_parts_with_dbms)); - - LOG_DEBUG( - log, - "Loaded {} server parts and {} delete bitmap for table {} with version {}", - server_part_size, - delete_bitmap_size, + cloud_merge_tree->setDataDescription(std::move(worker_info), version); + LOG_DEBUG(log, "Received table {} with data version {}", cloud_merge_tree->getStorageID().getNameForLogs(), version); } @@ -780,6 +748,8 @@ void CnchWorkerServiceImpl::sendResources( server_parts = createPartVectorFromModelsForSend(*cloud_merge_tree, data.server_parts()); + auto server_parts_size = server_parts.size(); + if (request->has_disk_cache_mode()) { auto disk_cache_mode = SettingFieldDiskCacheModeTraits::fromString(request->disk_cache_mode()); @@ -790,14 +760,24 @@ void CnchWorkerServiceImpl::sendResources( } } - cloud_merge_tree->loadDataParts(server_parts); + /// `loadDataParts` is an expensive action as it may involve remote read. + /// The worker rpc thread pool may be blocked when there are many `sendResources` requests. + /// Here we just pass the server_parts to storage. And it will do `loadDataParts` later (before reading). + /// One exception is StorageDictCloudMergeTree as it use a different read logic rather than StorageCloudMergeTree::read. 
+ bool is_dict = false; + if (auto * cloud_dict = dynamic_cast(storage.get())) + { + cloud_dict->loadDataParts(server_parts); + is_dict = true; + } + else + cloud_merge_tree->receiveDataParts(std::move(server_parts)); LOG_DEBUG( log, - "Received and loaded {} parts for table {}(txn_id: {}), disk_cache_mode {}", - data.server_parts_size(), - cloud_merge_tree->getStorageID().getNameForLogs(), - request->txn_id(), request->disk_cache_mode()); + "Received {} parts for table {}(txn_id: {}), disk_cache_mode {}, is_dict: {}", + server_parts_size, cloud_merge_tree->getStorageID().getNameForLogs(), + request->txn_id(), request->disk_cache_mode(), is_dict); } if (!data.virtual_parts().empty()) @@ -810,6 +790,8 @@ void CnchWorkerServiceImpl::sendResources( virtual_parts = createPartVectorFromModelsForSend(*cloud_merge_tree, data.virtual_parts()); + auto virtual_parts_size = virtual_parts.size(); + if (request->has_disk_cache_mode()) { auto disk_cache_mode = SettingFieldDiskCacheModeTraits::fromString(request->disk_cache_mode()); @@ -820,13 +802,20 @@ void CnchWorkerServiceImpl::sendResources( } } - cloud_merge_tree->loadDataParts(virtual_parts); + bool is_dict = false; + if (auto * cloud_dict = dynamic_cast(storage.get())) + { + cloud_dict->loadDataParts(virtual_parts); + is_dict = true; + } + else + cloud_merge_tree->receiveVirtualDataParts(std::move(virtual_parts)); LOG_DEBUG( log, - "Received and loaded {} virtual server parts for table {}", - virtual_parts.size(), - cloud_merge_tree->getStorageID().getNameForLogs()); + "Received {} virtual parts for table {}(txn_id: {}), disk_cache_mode {}, is_dict: {}", + virtual_parts_size, cloud_merge_tree->getStorageID().getNameForLogs(), + request->txn_id(), request->disk_cache_mode(), is_dict); } std::set required_bucket_numbers; @@ -862,7 +851,9 @@ void CnchWorkerServiceImpl::sendResources( throw Exception("Unknown table engine: " + storage->getName(), ErrorCodes::UNKNOWN_TABLE); } - LOG_TRACE(log, "Received all resource for 
session: {}, elapsed: {}ms.", request->txn_id(), watch.elapsedMilliseconds()); + watch.stop(); + LOG_INFO(log, "Received all resources for session {} in {} us.", request->txn_id(), watch.elapsedMicroseconds()); + ProfileEvents::increment(ProfileEvents::QuerySendResourcesMicroseconds, watch.elapsedMicroseconds()); }) } @@ -1096,6 +1087,64 @@ void CnchWorkerServiceImpl::getDedupWorkerStatus( }) } +void CnchWorkerServiceImpl::executeDedupTask( + google::protobuf::RpcController * cntl, + const Protos::ExecuteDedupTaskReq * request, + Protos::ExecuteDedupTaskResp * response, + google::protobuf::Closure * done) +{ + SUBMIT_THREADPOOL({ + auto txn_id = TxnTimestamp(request->txn_id()); + auto rpc_context = RPCHelpers::createSessionContextForRPC(getContext(), *cntl); + rpc_context->getClientInfo().rpc_port = request->rpc_port(); + auto server_client + = rpc_context->getCnchServerClient(rpc_context->getClientInfo().current_address.host().toString(), request->rpc_port()); + auto worker_txn = std::make_shared(rpc_context, txn_id, server_client); + /// This stage is in commit process, we can not finish transaction here. 
+ worker_txn->setIsInitiator(false); + rpc_context->setCurrentTransaction(worker_txn); + + auto catalog = getContext()->getCnchCatalog(); + TxnTimestamp ts = getContext()->getTimestamp(); + auto table_uuid_str = UUIDHelpers::UUIDToString(RPCHelpers::createUUID(request->table_uuid())); + auto table = catalog->tryGetTableByUUID(*getContext(), table_uuid_str, ts); + if (!table) + throw Exception(ErrorCodes::ABORTED, "Table {} has been dropped", table_uuid_str); + auto cnch_table = dynamic_pointer_cast(table); + if (!cnch_table) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Table {} is not cnch merge tree", table_uuid_str); + + auto new_parts = createPartVectorFromModels(*cnch_table, request->new_parts(), &request->new_parts_paths()); + DeleteBitmapMetaPtrVector delete_bitmaps_for_new_parts; + delete_bitmaps_for_new_parts.reserve(request->delete_bitmaps_for_new_parts_size()); + for (const auto & bitmap_model : request->delete_bitmaps_for_new_parts()) + delete_bitmaps_for_new_parts.emplace_back(createFromModel(*cnch_table, bitmap_model)); + + auto staged_parts = createPartVectorFromModels(*cnch_table, request->staged_parts(), &request->staged_parts_paths()); + DeleteBitmapMetaPtrVector delete_bitmaps_for_staged_parts; + delete_bitmaps_for_staged_parts.reserve(request->delete_bitmaps_for_staged_parts_size()); + for (const auto & bitmap_model : request->delete_bitmaps_for_staged_parts()) + delete_bitmaps_for_staged_parts.emplace_back(createFromModel(*cnch_table, bitmap_model)); + + auto visible_parts = createPartVectorFromModels(*cnch_table, request->visible_parts(), &request->visible_parts_paths()); + DeleteBitmapMetaPtrVector delete_bitmaps_for_visible_parts; + delete_bitmaps_for_visible_parts.reserve(request->delete_bitmaps_for_visible_parts_size()); + for (const auto & bitmap_model : request->delete_bitmaps_for_visible_parts()) + delete_bitmaps_for_visible_parts.emplace_back(createFromModel(*cnch_table, bitmap_model)); + + auto dedup_mode = 
static_cast(request->dedup_mode()); + auto dedup_task = std::make_shared(dedup_mode, cnch_table->getCnchStorageID()); + dedup_task->new_parts = std::move(new_parts); + dedup_task->delete_bitmaps_for_new_parts = std::move(delete_bitmaps_for_new_parts); + dedup_task->staged_parts = std::move(staged_parts); + dedup_task->delete_bitmaps_for_staged_parts = std::move(delete_bitmaps_for_staged_parts); + dedup_task->visible_parts = std::move(visible_parts); + dedup_task->delete_bitmaps_for_visible_parts = std::move(delete_bitmaps_for_visible_parts); + + CnchDedupHelper::executeDedupTask(*cnch_table, *dedup_task, txn_id, rpc_context); + }) +} + #if USE_RDKAFKA void CnchWorkerServiceImpl::submitKafkaConsumeTask( google::protobuf::RpcController * cntl, diff --git a/src/CloudServices/CnchWorkerServiceImpl.h b/src/CloudServices/CnchWorkerServiceImpl.h index 0c86f5aa70f..49578d86bca 100644 --- a/src/CloudServices/CnchWorkerServiceImpl.h +++ b/src/CloudServices/CnchWorkerServiceImpl.h @@ -133,6 +133,12 @@ class CnchWorkerServiceImpl : protected WithMutableContext, public DB::Protos::C Protos::GetDedupWorkerStatusResp * response, google::protobuf::Closure * done) override; + void executeDedupTask( + google::protobuf::RpcController *, + const Protos::ExecuteDedupTaskReq * request, + Protos::ExecuteDedupTaskResp * response, + google::protobuf::Closure * done) override; + #if USE_RDKAFKA void submitKafkaConsumeTask( google::protobuf::RpcController * cntl, @@ -179,12 +185,6 @@ class CnchWorkerServiceImpl : protected WithMutableContext, public DB::Protos::C Protos::SendCreateQueryResp * response, google::protobuf::Closure * done) override; - void sendQueryDataParts( - google::protobuf::RpcController * cntl, - const Protos::SendDataPartsReq * request, - Protos::SendDataPartsResp * response, - google::protobuf::Closure * done) override; - void sendResources( google::protobuf::RpcController * cntl, const Protos::SendResourcesReq * request, @@ -197,21 +197,6 @@ class CnchWorkerServiceImpl 
: protected WithMutableContext, public DB::Protos::C Protos::RemoveWorkerResourceResp * response, google::protobuf::Closure * done) override; - /* - void sendQueryVirtualDataParts( - google::protobuf::RpcController * cntl, - const Protos::SendVirtualDataPartsReq * request, - Protos::SendVirtualDataPartsResp * response, - google::protobuf::Closure * done) override {} - */ - - - void sendCnchFileDataParts( - google::protobuf::RpcController * cntl, - const Protos::SendCnchFileDataPartsReq * request, - Protos::SendCnchFileDataPartsResp * response, - google::protobuf::Closure * done) override; - void checkDataParts( google::protobuf::RpcController * cntl, const Protos::CheckDataPartsReq * request, diff --git a/src/CloudServices/DedupWorkerManager.cpp b/src/CloudServices/DedupWorkerManager.cpp index 3a1d06c9667..5dacbd2fce0 100644 --- a/src/CloudServices/DedupWorkerManager.cpp +++ b/src/CloudServices/DedupWorkerManager.cpp @@ -163,13 +163,14 @@ void DedupWorkerManager::initialize(StoragePtr & storage, StorageCnchMergeTree & } } -void DedupWorkerManager::createDeduperOnWorker(StoragePtr & storage, StorageCnchMergeTree & cnch_table, DeduperInfoPtr & info, std::unique_lock & info_lock) +void DedupWorkerManager::createDeduperOnWorker(StoragePtr & storage, StorageCnchMergeTree & cnch_table, DeduperInfoPtr & info, std::unique_lock & info_lock) { if (info->worker_client) return; try { - info->worker_storage_id = {storage->getStorageID().getDatabaseName(), storage->getStorageID().getTableName()}; + auto cnch_storage_id = storage->getStorageID(); + info->worker_storage_id = {cnch_storage_id.getDatabaseName(), cnch_storage_id.getTableName()}; selectDedupWorker(cnch_table, info, info_lock); /// create a unique table suffix @@ -177,14 +178,15 @@ void DedupWorkerManager::createDeduperOnWorker(StoragePtr & storage, StorageCnch info->worker_storage_id.table_name = storage_id.table_name + deduper_table_suffix; auto create_ast = getASTCreateQueryFromStorage(*storage, getContext()); - 
replaceCnchWithCloud( - *create_ast, info->worker_storage_id.table_name, storage->getStorageID().getDatabaseName(), storage->getStorageID().getTableName()); - modifyOrAddSetting(*create_ast, "cloud_enable_dedup_worker", Field(UInt64(1))); - modifyOrAddSetting(*create_ast, "allow_nullable_key", Field(UInt64(1))); + auto & create = *create_ast; + create.table = info->worker_storage_id.table_name; + replaceCnchWithCloud(create.storage, cnch_storage_id.getDatabaseName(), cnch_storage_id.getTableName()); + modifyOrAddSetting(create, "cloud_enable_dedup_worker", Field(UInt64(1))); + modifyOrAddSetting(create, "allow_nullable_key", Field(UInt64(1))); /// Set cnch uuid for CloudMergeTree to commit data on worker side - modifyOrAddSetting(*create_ast, "cnch_table_uuid", Field(static_cast(UUIDHelpers::UUIDToString(create_ast->uuid)))); + modifyOrAddSetting(create, "cnch_table_uuid", Field(static_cast(UUIDHelpers::UUIDToString(create_ast->uuid)))); /// It's not allowed to create multi tables with same uuid on Cnch-Worker side now - create_ast->uuid = UUIDHelpers::Nil; + create.uuid = UUIDHelpers::Nil; String create_query = getTableDefinitionFromCreateQuery(static_pointer_cast(create_ast), false); LOG_TRACE(log, "Create table query of dedup worker: {}", create_query); @@ -200,7 +202,7 @@ void DedupWorkerManager::createDeduperOnWorker(StoragePtr & storage, StorageCnch } } -void DedupWorkerManager::selectDedupWorker(StorageCnchMergeTree & cnch_table, DeduperInfoPtr & info, std::unique_lock & /*info_lock*/) +void DedupWorkerManager::selectDedupWorker(StorageCnchMergeTree & cnch_table, DeduperInfoPtr & info, std::unique_lock & /*info_lock*/) { auto vw_handle = getContext()->getVirtualWarehousePool().get(cnch_table.getSettings()->cnch_vw_write); HostWithPorts history_dedup_worker = dedup_scheduler->tryPickWorker(info->index); @@ -216,13 +218,13 @@ void DedupWorkerManager::selectDedupWorker(StorageCnchMergeTree & cnch_table, De } } -void 
DedupWorkerManager::markDedupWorker(DeduperInfoPtr & info, std::unique_lock & /*info_lock*/) +void DedupWorkerManager::markDedupWorker(DeduperInfoPtr & info, std::unique_lock & /*info_lock*/) { dedup_scheduler->markIndexDedupWorker(info->index, info->worker_client->getHostWithPorts()); info->is_running = true; } -void DedupWorkerManager::assignHighPriorityDedupPartition(DeduperInfoPtr & info, const Names & high_priority_partition, std::unique_lock & /*info_lock*/) +void DedupWorkerManager::assignHighPriorityDedupPartition(DeduperInfoPtr & info, const Names & high_priority_partition, std::unique_lock & /*info_lock*/) { if (!info->worker_client) return; @@ -230,13 +232,13 @@ void DedupWorkerManager::assignHighPriorityDedupPartition(DeduperInfoPtr & info, info->worker_client->assignHighPriorityDedupPartition(info->worker_storage_id, high_priority_partition); } -void DedupWorkerManager::unsetWorkerClient(DeduperInfoPtr & info, std::unique_lock & /*info_lock*/) +void DedupWorkerManager::unsetWorkerClient(DeduperInfoPtr & info, std::unique_lock & /*info_lock*/) { info->worker_client = nullptr; info->is_running = false; } -void DedupWorkerManager::assignRepairGran(DeduperInfoPtr & info, const DedupGran & dedup_gran, const UInt64 & max_event_time, std::unique_lock & /*info_lock*/) +void DedupWorkerManager::assignRepairGran(DeduperInfoPtr & info, const DedupGran & dedup_gran, const UInt64 & max_event_time, std::unique_lock & /*info_lock*/) { if (!info->worker_client) return; @@ -265,7 +267,7 @@ void DedupWorkerManager::stopDeduperWorker(DeduperInfoPtr & info) unsetWorkerClient(info, info_lock); } -String DedupWorkerManager::getDedupWorkerDebugInfo(DeduperInfoPtr & info, std::unique_lock & /*info_lock*/) +String DedupWorkerManager::getDedupWorkerDebugInfo(DeduperInfoPtr & info, std::unique_lock & /*info_lock*/) { if (!info->worker_client) return "dedup worker is not assigned."; @@ -273,7 +275,7 @@ String DedupWorkerManager::getDedupWorkerDebugInfo(DeduperInfoPtr & info, 
std::u + info->worker_client->getHostWithPorts().toDebugString(); } -bool DedupWorkerManager::checkDedupWorkerStatus(DeduperInfoPtr & info, std::unique_lock & info_lock) +bool DedupWorkerManager::checkDedupWorkerStatus(DeduperInfoPtr & info, std::unique_lock & info_lock) { if (!info->worker_client) return false; @@ -340,8 +342,8 @@ void DedupWorkerManager::dedupWithHighPriority(const ASTPtr & partition, const C DedupWorkerHeartbeatResult DedupWorkerManager::reportHeartbeat(const String & worker_table_name) { - std::lock_guard lock(deduper_infos_mutex); LOG_TRACE(log, "Report heartbeat of dedup worker: worker table name is {}", worker_table_name); + std::lock_guard lock(deduper_infos_mutex); for (const auto & info : deduper_infos) { std::lock_guard info_lock(info->mutex); diff --git a/src/CloudServices/DedupWorkerManager.h b/src/CloudServices/DedupWorkerManager.h index a98e1eee240..f9cd8acdc45 100644 --- a/src/CloudServices/DedupWorkerManager.h +++ b/src/CloudServices/DedupWorkerManager.h @@ -64,7 +64,7 @@ class DedupWorkerManager: public ICnchBGThread worker_client(other.worker_client), worker_storage_id(other.worker_storage_id) {} - mutable std::mutex mutex; + mutable bthread::Mutex mutex; bool is_running{false}; size_t index{0}; CnchWorkerClientPtr worker_client; @@ -80,20 +80,20 @@ class DedupWorkerManager: public ICnchBGThread void initialize(StoragePtr & storage, StorageCnchMergeTree & cnch_table); - void createDeduperOnWorker(StoragePtr & storage, StorageCnchMergeTree & cnch_table, DeduperInfoPtr & info, std::unique_lock & info_lock); + void createDeduperOnWorker(StoragePtr & storage, StorageCnchMergeTree & cnch_table, DeduperInfoPtr & info, std::unique_lock & info_lock); - void selectDedupWorker(StorageCnchMergeTree & cnch_table, DeduperInfoPtr & info, std::unique_lock & info_lock); + void selectDedupWorker(StorageCnchMergeTree & cnch_table, DeduperInfoPtr & info, std::unique_lock & info_lock); - void markDedupWorker(DeduperInfoPtr & info, std::unique_lock & 
info_lock); + void markDedupWorker(DeduperInfoPtr & info, std::unique_lock & info_lock); void stopDeduperWorker(DeduperInfoPtr & info); - bool checkDedupWorkerStatus(DeduperInfoPtr & info, std::unique_lock & info_lock); + bool checkDedupWorkerStatus(DeduperInfoPtr & info, std::unique_lock & info_lock); - static void assignHighPriorityDedupPartition(DeduperInfoPtr & info, const Names & high_priority_partition, std::unique_lock & info_lock); - static void unsetWorkerClient(DeduperInfoPtr & info, std::unique_lock & info_lock); - static void assignRepairGran(DeduperInfoPtr & info, const DedupGran & dedup_gran, const UInt64 & max_event_time, std::unique_lock & info_lock); - static String getDedupWorkerDebugInfo(DeduperInfoPtr & info, std::unique_lock & info_lock); + static void assignHighPriorityDedupPartition(DeduperInfoPtr & info, const Names & high_priority_partition, std::unique_lock & info_lock); + static void unsetWorkerClient(DeduperInfoPtr & info, std::unique_lock & info_lock); + static void assignRepairGran(DeduperInfoPtr & info, const DedupGran & dedup_gran, const UInt64 & max_event_time, std::unique_lock & info_lock); + static String getDedupWorkerDebugInfo(DeduperInfoPtr & info, std::unique_lock & info_lock); mutable bthread::Mutex deduper_infos_mutex; std::atomic initialized{false}; diff --git a/src/CloudServices/RpcClientBase.cpp b/src/CloudServices/RpcClientBase.cpp index e0d204e9b8a..8d7d75bc63d 100644 --- a/src/CloudServices/RpcClientBase.cpp +++ b/src/CloudServices/RpcClientBase.cpp @@ -31,6 +31,7 @@ namespace ErrorCodes extern const int BRPC_TIMEOUT; extern const int BRPC_HOST_DOWN; extern const int BRPC_CONNECT_ERROR; + extern const int BRPC_NO_METHOD; } static auto getDefaultChannelOptions() @@ -74,7 +75,7 @@ void RpcClientBase::assertController(const brpc::Controller & cntl) if (cntl.Failed()) { auto err = cntl.ErrorCode(); - const String err_prefix = "Error on connecting " + (!host_ports.id.empty() ? 
host_ports.id : host_ports.toDebugString()) + " by rpc: "; + const String err_prefix = fmt::format("Error on connecting {} by rpc {}: ", (!host_ports.id.empty() ? host_ports.id : host_ports.toDebugString()), cntl.method()->full_name()); if (err == ECONNREFUSED || err == ECONNRESET || err == ENETUNREACH) { @@ -91,6 +92,10 @@ void RpcClientBase::assertController(const brpc::Controller & cntl) { throw Exception(err_prefix + std::to_string(err) + ":" + cntl.ErrorText(), ErrorCodes::BRPC_TIMEOUT); } + else if (err == brpc::Errno::ENOMETHOD) + { + throw Exception(err_prefix + std::to_string(err) + ":" + cntl.ErrorText(), ErrorCodes::BRPC_NO_METHOD); + } else /// Should we throw exception here to cover all other errors? throw Exception(err_prefix + std::to_string(err) + ":" + cntl.ErrorText(), ErrorCodes::BRPC_EXCEPTION); } diff --git a/src/Common/Brpc/BrpcGflagsConfigHolder.cpp b/src/Common/Brpc/BrpcGflagsConfigHolder.cpp index 8d5f3ff475e..04b13aa1fc9 100644 --- a/src/Common/Brpc/BrpcGflagsConfigHolder.cpp +++ b/src/Common/Brpc/BrpcGflagsConfigHolder.cpp @@ -24,7 +24,7 @@ static std::unordered_map configurable_brpc_gflags /* {/// Number of event dispatcher {"event_dispatcher_num", "2"}, /// Max unwritten bytes in each socket, if the limit is reached - {"socket_max_unwritten_bytes", "1073741824"}, + {"socket_max_unwritten_bytes", "2147483648"}, /// Set the recv buffer size of socket if this value is positive {"socket_recv_buffer_size", ""}, /// Set send buffer size of sockets if this value is positive @@ -36,7 +36,7 @@ static std::unordered_map configurable_brpc_gflags /* /// values <= 0 disables this feature {"free_memory_to_system_interval", ""}, /// Maximum size of a single message body in all protocols - {"max_body_size", "671088640"}, + {"max_body_size", "3221225472"}, /// Print Controller.ErrorText() when server is about to respond a failed RPC {"log_error_text", ""}, {"bvar_enable_sampling", "true"}}; diff --git a/src/Common/Configurations.h 
b/src/Common/Configurations.h index 72ea2abafa8..d02b3fef715 100644 --- a/src/Common/Configurations.h +++ b/src/Common/Configurations.h @@ -128,10 +128,6 @@ struct BSPConfiguration final : public BSPConfigurationData M(UInt64, max_concurrent_insert_queries, "", 0, ConfigFlag::Default, "") \ M(UInt64, max_concurrent_system_queries, "", 0, ConfigFlag::Default, "") \ M(Float32, cache_size_to_ram_max_ratio, "", 0.5, ConfigFlag::Default, "") \ - M(UInt64, uncompressed_cache_size, "", 0, ConfigFlag::Default, "") \ - M(UInt64, mark_cache_size, "", 5368709120, ConfigFlag::Default, "") \ - M(UInt64, cnch_checksums_cache_size, "", 5368709120, ConfigFlag::Default, "") \ - M(UInt64, cnch_primary_index_cache_size, "", 5368709120, ConfigFlag::Default, "") \ M(UInt64, shutdown_wait_unfinished, "", 5, ConfigFlag::Default, "") \ M(UInt64, cnch_transaction_ts_expire_time, "", 2 * 60 * 60 * 1000, ConfigFlag::Default, "") \ M(UInt64, cnch_task_heartbeat_interval, "", 5, ConfigFlag::Default, "") \ @@ -142,11 +138,59 @@ struct BSPConfiguration final : public BSPConfigurationData M(UInt64, max_async_query_threads, "", 5000, ConfigFlag::Default, "Maximum threads that async queries use.") \ M(UInt64, async_query_status_ttl, "", 86400, ConfigFlag::Default, "TTL for async query status stored in catalog, in seconds.") \ M(UInt64, async_query_expire_time, "", 3600, ConfigFlag::Default, "Expire time for async query, in seconds.") \ - M(UInt64, async_query_status_check_period, "", 15 * 60, ConfigFlag::Default, "Cycle for checking expired async query status stored in catalog, in seconds.") \ + M(UInt64, \ + async_query_status_check_period, \ + "", \ + 15 * 60, \ + ConfigFlag::Default, \ + "Cycle for checking expired async query status stored in catalog, in seconds.") \ + M(Bool, enable_cnch_write_remote_catalog, "", true, ConfigFlag::Default, "Set to false to disable writing catalog") \ + M(Bool, enable_cnch_write_remote_disk, "", true, ConfigFlag::Default, "set to false to disable writing data") 
\ + /** + * Memory caches */ \ + M(UInt64, bitengine_memory_cache_size, "", 50UL * 1024 * 1024 * 1024, ConfigFlag::Default, "") \ + M(UInt64, checksum_cache_size, "", 5UL * 1024 * 1024 * 1024, ConfigFlag::Default, "") \ + M(UInt64, checksum_cache_bucket, "", 5000, ConfigFlag::Default, "") \ + M(UInt64, checksum_cache_shard, "", 8, ConfigFlag::Default, "") \ + M(UInt64, checksum_cache_lru_update_interval, "", 60, ConfigFlag::Default, "In seconds") \ + M(UInt64, cnch_checksums_cache_size, "", 5UL * 1024 * 1024 * 1024, ConfigFlag::Default, "") /* Checksums cache configs in cnch 1.4 */ \ + M(UInt64, cnch_primary_index_cache_size, "", 5UL * 1024 * 1024 * 1024, ConfigFlag::Default, "") \ + M(UInt64, compiled_expression_cache_size, "", 128UL * 1024 * 1024, ConfigFlag::Default, "") \ + M(UInt64, compressed_data_index_cache, "", 5UL * 1024 * 1024 * 1024, ConfigFlag::Default, "") \ + M(UInt64, delete_bitmap_cache_size, "", 1UL * 1024 * 1024 * 1024, ConfigFlag::Default, "") \ + M(UInt64, footer_cache_size, "", 3UL * 1024 * 1024 * 1024, ConfigFlag::Default, "") \ + M(UInt64, geometry_primary_index_cache_size, "", 1UL * 1024 * 1024 * 1024, ConfigFlag::Default, "") \ + M(UInt64, gin_index_filter_result_cache, "", 5UL * 1024 * 1024 * 1024, ConfigFlag::Default, "") \ + M(UInt64, ginindex_store_cache_size, "", 5UL * 1024 * 1024 * 1024, ConfigFlag::Default, "") \ + M(UInt64, ginindex_store_cache_bucket, "", 5000, ConfigFlag::Default, "") \ + M(UInt64, ginindex_store_cache_shard, "", 2, ConfigFlag::Default, "") \ + M(UInt64, ginindex_store_cache_ttl, "", 60, ConfigFlag::Default, "") \ + M(UInt64, ginindex_store_cache_lru_update_interval, "", 60, ConfigFlag::Default, "In seconds") \ + M(UInt64, intermediate_result_cache_size, "", 1UL * 1024 * 1024 * 1024, ConfigFlag::Default, "") \ + M(UInt64, mark_cache_size, "", 5UL * 1024 * 1024 * 1024, ConfigFlag::Default, "") \ + M(UInt64, mmap_cache_size, "", 1000, ConfigFlag::Default, "") \ + M(UInt64, unique_key_index_data_cache_size, "", 1UL 
* 1024 * 1024 * 1024, ConfigFlag::Default, "") \ + M(UInt64, unique_key_index_meta_cache_size, "", 1UL * 1024 * 1024 * 1024, ConfigFlag::Default, "") \ + M(UInt64, uncompressed_cache_size, "", 0, ConfigFlag::Default, "") \ + /** + * Cache default size max ratio */ \ + M(Float32, bitengine_memory_cache_size_default_max_ratio, "", 0.15, ConfigFlag::Default, "") \ + M(Float32, checksum_cache_size_default_max_ratio, "", 0.05, ConfigFlag::Default, "") \ + M(Float32, cnch_primary_index_cache_size_default_max_ratio, "", 0.05, ConfigFlag::Default, "") \ + M(Float32, compiled_expression_cache_size_default_max_ratio, "", 0.001, ConfigFlag::Default, "") \ + M(Float32, compressed_data_index_cache_default_max_ratio, "", 0.025, ConfigFlag::Default, "") \ + M(Float32, delete_bitmap_cache_size_default_max_ratio, "", 0.005, ConfigFlag::Default, "") \ + M(Float32, footer_cache_size_default_max_ratio, "", 0.015, ConfigFlag::Default, "") \ + M(Float32, gin_index_filter_result_cache_default_max_ratio, "", 0.025, ConfigFlag::Default, "") \ + M(Float32, ginindex_store_cache_size_default_max_ratio, "", 0.025, ConfigFlag::Default, "") \ + M(Float32, intermediate_result_cache_size_default_max_ratio, "", 0.005, ConfigFlag::Default, "") \ + M(Float32, mark_cache_size_default_max_ratio, "", 0.05, ConfigFlag::Default, "") \ + M(Float32, unique_key_index_data_cache_size_default_max_ratio, "", 0.005, ConfigFlag::Default, "") \ + M(Float32, unique_key_index_meta_cache_size_default_max_ratio, "", 0.01, ConfigFlag::Default, "") \ /** * Mutable */ \ M(MutableUInt64, max_server_memory_usage, "", 0, ConfigFlag::Default, "") \ - M(MutableFloat32, max_server_memory_usage_to_ram_ratio, "", 0.8, ConfigFlag::Default, "") \ + M(MutableFloat32, max_server_memory_usage_to_ram_ratio, "", 0.9, ConfigFlag::Default, "") \ M(MutableUInt64, kafka_max_partition_fetch_bytes, "", 1048576, ConfigFlag::Default, "") \ M(MutableUInt64, stream_poll_timeout_ms, "", 500, ConfigFlag::Default, "") \ M(MutableUInt64, 
debug_disable_merge_mutate_thread, "", false, ConfigFlag::Default, "") \ diff --git a/src/Common/DateLUT.cpp b/src/Common/DateLUT.cpp index 2ab457099c6..b9d51f85d2f 100644 --- a/src/Common/DateLUT.cpp +++ b/src/Common/DateLUT.cpp @@ -1,5 +1,9 @@ #include +#include +#include +#include + #include #include #include @@ -29,12 +33,12 @@ std::string determineDefaultTimeZone() { namespace fs = std::filesystem; - const char * tzdir_env_var = std::getenv("TZDIR"); + const char * tzdir_env_var = std::getenv("TZDIR"); // NOLINT(concurrency-mt-unsafe) // ok, because it does not run concurrently with other getenv calls fs::path tz_database_path = tzdir_env_var ? tzdir_env_var : "/usr/share/zoneinfo/"; fs::path tz_file_path; std::string error_prefix; - const char * tz_env_var = std::getenv("TZ"); + const char * tz_env_var = std::getenv("TZ"); // NOLINT(concurrency-mt-unsafe) // ok, because it does not run concurrently with other getenv calls /// In recent tzdata packages some files now are symlinks and canonical path resolution /// may give wrong timezone names - store the name as it is, if possible. 
@@ -138,6 +142,38 @@ std::string determineDefaultTimeZone() } +const DateLUTImpl & DateLUT::sessionInstance() +{ + const auto & date_lut = getInstance(); + + if (DB::CurrentThread::isInitialized()) + { + std::string timezone_from_context; + const DB::ContextPtr query_context = DB::CurrentThread::get().getQueryContext(); + + if (query_context) + { + timezone_from_context = extractTimezoneFromContext(query_context); + + if (!timezone_from_context.empty()) + return date_lut.getImplementation(timezone_from_context); + } + + /// On the server side, timezone is passed in query_context, + /// but on CH-client side we have no query context, + /// and each time we modify client's global context + const DB::ContextPtr global_context = DB::CurrentThread::get().getGlobalContext(); + if (global_context) + { + timezone_from_context = extractTimezoneFromContext(global_context); + + if (!timezone_from_context.empty()) + return date_lut.getImplementation(timezone_from_context); + } + } + return serverTimezoneInstance(); +} + DateLUT::DateLUT() { /// Initialize the pointer to the default DateLUTImpl. 
@@ -148,7 +184,7 @@ DateLUT::DateLUT() const DateLUTImpl & DateLUT::getImplementation(const std::string & time_zone) const { - std::lock_guard lock(mutex); + std::lock_guard lock(mutex); auto it = impls.emplace(time_zone, nullptr).first; if (!it->second) @@ -162,3 +198,8 @@ DateLUT & DateLUT::getInstance() static DateLUT ret; return ret; } + +std::string DateLUT::extractTimezoneFromContext(DB::ContextPtr query_context) +{ + return query_context->getSettingsRef().session_timezone.value; +} diff --git a/src/Common/DateLUT.h b/src/Common/DateLUT.h index bef65d8a026..ecb1cd541d1 100644 --- a/src/Common/DateLUT.h +++ b/src/Common/DateLUT.h @@ -3,6 +3,7 @@ #include #include +#include #include @@ -11,13 +12,22 @@ #include #include +namespace DB +{ +class Context; +using ContextPtr = std::shared_ptr; +} + +class DateLUTImpl; + /// This class provides lazy initialization and lookup of singleton DateLUTImpl objects for a given timezone. class DateLUT : private boost::noncopyable { public: - /// Return singleton DateLUTImpl instance for the default time zone. - static ALWAYS_INLINE const DateLUTImpl & instance() // -V1071 + /// The default instance will return singleton DateLUTImpl for the server time zone. + /// It may be set using 'timezone' server setting. + static ALWAYS_INLINE const DateLUTImpl & serverTimezoneInstance() { const auto & date_lut = getInstance(); return *date_lut.default_impl.load(std::memory_order_acquire); @@ -26,12 +36,18 @@ class DateLUT : private boost::noncopyable /// Return singleton DateLUTImpl instance for a given time zone. static ALWAYS_INLINE const DateLUTImpl & instance(const std::string & time_zone) { - const auto & date_lut = getInstance(); if (time_zone.empty()) - return *date_lut.default_impl.load(std::memory_order_acquire); + return sessionInstance(); + const auto & date_lut = getInstance(); return date_lut.getImplementation(time_zone); } + + /// Return DateLUTImpl instance for session timezone. 
+ /// session_timezone is a session-level setting. + /// If setting is not set, returns the server timezone. + static const DateLUTImpl & sessionInstance(); + static void setDefaultTimezone(const std::string & time_zone) { auto & date_lut = getInstance(); @@ -45,6 +61,8 @@ class DateLUT : private boost::noncopyable private: static DateLUT & getInstance(); + static std::string extractTimezoneFromContext(DB::ContextPtr query_context); + const DateLUTImpl & getImplementation(const std::string & time_zone) const; using DateLUTImplPtr = std::unique_ptr; diff --git a/src/Common/ErrorCodes.cpp b/src/Common/ErrorCodes.cpp index 619f5cb0faa..70da56e0a0f 100644 --- a/src/Common/ErrorCodes.cpp +++ b/src/Common/ErrorCodes.cpp @@ -780,6 +780,7 @@ M(2010, EXCHANGE_DATA_TRANS_EXCEPTION) \ M(2011, BRPC_PROTOCOL_VERSION_UNSUPPORT) \ M(2013, QUERY_WAS_CANCELLED_INTERNAL) \ + M(2014, BRPC_NO_METHOD) \ M(2900, BSP_EXCHANGE_DATA_DISK_LIMIT_EXCEEDED) \ M(2901, BSP_CLEANUP_PREVIOUS_SEGMENT_INSTANCE_FAILED) \ \ @@ -818,6 +819,8 @@ M(5025, VIRTUAL_WAREHOUSE_NOT_FOUND) \ M(5027, CNCH_SERVER_NOT_FOUND) \ M(5030, CNCH_BG_THREAD_NOT_FOUND) \ + M(5031, UNKNOWN_CNCH_BG_THREAD_TYPE) \ + M(5032, UNKNOWN_CNCH_BG_THREAD_ACTION) \ M(5035, INSERTION_LABEL_ALREADY_EXISTS) \ M(5036, FAILED_TO_PUT_INSERTION_LABEL) \ M(5037, VIRTUAL_WAREHOUSE_ALREADY_EXISTS) \ @@ -828,6 +831,8 @@ M(5044, UNKNOWN_WORKER_GROUP) \ M(5045, WORKER_NODE_INCONSISTENTCY) \ M(5046, DISK_CACHE_NOT_USED) \ + M(5047, WORKER_RESTARTED) \ + M(5048, WORKER_NODE_NOT_FOUND) \ \ M(5453, HDFS_FILE_SYSTEM_UNREGISTER) \ M(5454, BAD_HDFS_META_FILE) \ @@ -859,6 +864,8 @@ M(7112, RESOURCE_MANAGER_WRONG_COORDINATE_MODE) \ M(7113, RESOURCE_MANAGER_REMOVE_WORKER_ERROR) \ M(7114, RESOURCE_MANAGER_LEADER_NOT_WORK_WELL) \ +\ + M(7150, MERGE_BAD_PART_NAME) \ \ M(7200, UNIQUE_KEY_STRING_SIZE_LIMIT_EXCEEDED) \ M(7201, UNIQUE_TABLE_DUPLICATE_KEY_FOUND) \ @@ -897,6 +904,7 @@ M(8075, METASTORE_ACCESS_ENTITY_EXISTS_ERROR) \ M(8076, 
METASTORE_ACCESS_ENTITY_NOT_IMPLEMENTED) \ M(8077, TABLE_IS_DETACHED) \ + M(8078, METASTORE_ERROR_KEY) \ \ M(11000, CNCH_TRANSACTION_TSCACHE_CHECK_FAILED) \ M(11001, CNCH_TRANSACTION_EXECUTION_ERROR) \ diff --git a/src/Common/Exception.h b/src/Common/Exception.h index e347b76845c..aa6e2c251ea 100644 --- a/src/Common/Exception.h +++ b/src/Common/Exception.h @@ -270,7 +270,7 @@ class ExceptionHandler class ExceptionHandlerWithFailedInfo : public ExceptionHandler { using ErrorCode = int32_t; - using WorkerIdErrorCodeMap = std::unordered_map; + using WorkerIdErrorCodeMap = std::unordered_map; public: void addFailedRpc(const DB::WorkerId & worker_id, int32_t error_code) diff --git a/src/Common/FilePathMatcher.cpp b/src/Common/FilePathMatcher.cpp new file mode 100644 index 00000000000..2600b24fc84 --- /dev/null +++ b/src/Common/FilePathMatcher.cpp @@ -0,0 +1,68 @@ +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +/** + * @brief Recursive directory listing with matched paths as a result. 
+ */ +Strings FilePathMatcher::regexMatchFiles(const String & path_for_ls, const String & for_match) +{ + const size_t first_glob = for_match.find_first_of("*?{"); + + const size_t end_of_path_without_globs = for_match.substr(0, first_glob).rfind('/'); + const String suffix_with_globs = for_match.substr(end_of_path_without_globs); /// begin with '/' + String prefix_without_globs = path_for_ls + for_match.substr(1, end_of_path_without_globs); /// ends with '/' + + const size_t next_slash = suffix_with_globs.find('/', 1); + re2::RE2 matcher(makeRegexpPatternFromGlobs(suffix_with_globs.substr(0, next_slash))); + + Strings result; + FileInfos file_infos = getFileInfos(prefix_without_globs); + for (const FileInfo & file_info : file_infos) + { + const size_t last_slash = file_info.file_path.rfind('/'); + const String file_name = file_info.file_path.substr(last_slash); + const bool looking_for_directory = next_slash != std::string::npos; + /// Condition with type of current file_info means what kind of path is it in current iteration of ls + if (!file_info.is_directory && !looking_for_directory) + { + if (re2::RE2::FullMatch(file_name, matcher)) + { + result.push_back(getSchemeAndPrefix() + file_info.file_path); + } + } + else if (file_info.is_directory && looking_for_directory) + { + if (re2::RE2::FullMatch(file_name, matcher)) + { + Strings result_part + = regexMatchFiles(std::filesystem::path(file_info.file_path) / "", suffix_with_globs.substr(next_slash)); + /// Recursion depth is limited by pattern. '*' works only for depth = 1, for depth = 2 pattern path is '*/*'. So we do not need additional check. + std::move(result_part.begin(), result_part.end(), std::back_inserter(result)); + } + } + } + + return result; +} + +String FilePathMatcher::removeSchemeAndPrefix(const String & full_path) +{ + String match_path = full_path; + // remove scheme from path + Poco::URI uri(full_path); + // If there is a '?', substring after '?' 
will be recognized as a query + if (!uri.getQuery().empty()) + match_path = uri.getPathAndQuery(); + else + match_path = uri.getPath(); + + return match_path; +} +} diff --git a/src/Common/FilePathMatcher.h b/src/Common/FilePathMatcher.h new file mode 100644 index 00000000000..72f2097e1d5 --- /dev/null +++ b/src/Common/FilePathMatcher.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +namespace DB +{ + +struct FileInfo +{ + String file_path; + bool is_directory; + + FileInfo(const String & file_path_, bool is_directory_) : file_path(file_path_), is_directory(is_directory_) { } +}; + +using FileInfos = std::vector; + +class FilePathMatcher +{ +public: + virtual ~FilePathMatcher() = default; + + Strings regexMatchFiles(const String & path_for_ls, const String & for_match); + + // For regex match, we remove scheme and prefix(S3 bucket) from full path. + virtual String removeSchemeAndPrefix(const String & full_path); + +protected: + virtual FileInfos getFileInfos(const String & prefix_path) = 0; + + // For regex match, we remove scheme and prefix(S3 bucket) from full path. + // But these prefix are needed when infile from some file system, so we will add it back. 
+ virtual String getSchemeAndPrefix() { return ""; } +}; + +} diff --git a/src/Common/HDFSFilePathMatcher.cpp b/src/Common/HDFSFilePathMatcher.cpp new file mode 100644 index 00000000000..de37b049f34 --- /dev/null +++ b/src/Common/HDFSFilePathMatcher.cpp @@ -0,0 +1,27 @@ +#include +#include + +namespace DB +{ + +HDFSFilePathMatcher::HDFSFilePathMatcher(String & path, const ContextPtr & context_ptr) +{ + Poco::URI uri(path); + HDFSConnectionParams hdfs_params = context_ptr->getHdfsConnectionParams(); + HDFSBuilderPtr builder = hdfs_params.createBuilder(uri); + hdfs_fs = createHDFSFS(builder.get()); +} + +FileInfos HDFSFilePathMatcher::getFileInfos(const String & prefix_path) +{ + FileInfos file_infos; + HDFSFileInfo ls; + ls.file_info = hdfsListDirectory(hdfs_fs.get(), prefix_path.data(), &ls.length); + for (int i = 0; i < ls.length; i++) + { + file_infos.emplace_back(ls.file_info[i].mName, ls.file_info[i].mKind == kObjectKindDirectory); + } + return file_infos; +} + +} diff --git a/src/Common/HDFSFilePathMatcher.h b/src/Common/HDFSFilePathMatcher.h new file mode 100644 index 00000000000..308e85ad965 --- /dev/null +++ b/src/Common/HDFSFilePathMatcher.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include +#include + +namespace DB +{ + +class HDFSFilePathMatcher : public FilePathMatcher +{ +public: + HDFSFilePathMatcher(String & path, const ContextPtr & context_ptr); + + ~HDFSFilePathMatcher() override = default; + + FileInfos getFileInfos(const String & prefix_path) override; + +private: + HDFSFSPtr hdfs_fs; +}; + +} diff --git a/src/Common/HostWithPorts.cpp b/src/Common/HostWithPorts.cpp index 561470b8945..40056143527 100644 --- a/src/Common/HostWithPorts.cpp +++ b/src/Common/HostWithPorts.cpp @@ -22,7 +22,6 @@ #include #include #include -#include namespace DB { @@ -63,7 +62,7 @@ HostWithPorts HostWithPorts::fromRPCAddress(const std::string & s) bool HostWithPorts::isExactlySameVec(const HostWithPortsVec & lhs, const HostWithPortsVec & rhs) { - return 
std::equal(lhs.begin(), lhs.end(), rhs.begin(), rhs.end(), HostWithPorts::IsExactlySame{}); + return std::equal(lhs.begin(), lhs.end(), rhs.begin(), rhs.end(), HostWithPorts::IsSameEndpoint{}); } std::ostream & operator<<(std::ostream & os, const HostWithPorts & host_ports) diff --git a/src/Common/HostWithPorts.h b/src/Common/HostWithPorts.h index a6919cae1e9..0f0126aee42 100644 --- a/src/Common/HostWithPorts.h +++ b/src/Common/HostWithPorts.h @@ -204,12 +204,15 @@ class HostWithPorts std::string getExchangeAddress() const { return getRPCAddress(); } std::string getExchangeStatusAddress() const { return getRPCAddress(); } + bool operator<(const HostWithPorts & rhs) const { return id < rhs.getId(); } const std::string & getHost() const { return host; } uint16_t getTCPPort() const { return tcp_port; } uint16_t getHTTPPort() const { return http_port; } uint16_t getRPCPort() const { return rpc_port; } std::string toDebugString() const; + String getId() const { return id; } + static HostWithPorts fromRPCAddress(const std::string & s); /// NOTE: PLEASE DO NOT implement any comparison operator which is a kind of bad code style diff --git a/src/Common/HuAllocator.h b/src/Common/HuAllocator.h index 7f627808917..a4a9b768677 100644 --- a/src/Common/HuAllocator.h +++ b/src/Common/HuAllocator.h @@ -121,12 +121,17 @@ class HuAllocator static void InitHuAlloc(size_t cached) { - hu_check_init_w(); - pthread_t tid; - size_t use_cache = cached / 2; + static std::once_flag hualloc_init_flag; + static size_t use_cache = cached / 2; if (use_cache <= 0) use_cache = 1024 * (1ull << 20); /// If not set properly use 1G as default - pthread_create(&tid, nullptr, ReclaimThread, &use_cache); + + std::call_once(hualloc_init_flag, [&]() + { + hu_check_init_w(); + pthread_t tid; + pthread_create(&tid, nullptr, ReclaimThread, &use_cache); + }); } protected: diff --git a/src/Common/HyperLogLogWithSmallSetOptimization.h b/src/Common/HyperLogLogWithSmallSetOptimization.h index 
39c00660ebe..0df5786e513 100644 --- a/src/Common/HyperLogLogWithSmallSetOptimization.h +++ b/src/Common/HyperLogLogWithSmallSetOptimization.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -28,9 +29,10 @@ class HyperLogLogWithSmallSetOptimization : private boost::noncopyable using Small = SmallSet; using Large = HyperLogLogCounter; using LargeValueType = typename Large::value_type; + using LargePtr = std::shared_ptr; Small small; - Large * large = nullptr; + LargePtr large; bool isLarge() const { @@ -40,22 +42,18 @@ class HyperLogLogWithSmallSetOptimization : private boost::noncopyable void toLarge() { /// At the time of copying data from `tiny`, setting the value of `large` is still not possible (otherwise it will overwrite some data). - Large * tmp_large = new Large; + LargePtr tmp_large = std::make_shared(); for (const auto & x : small) tmp_large->insert(static_cast(x.getValue())); - large = tmp_large; + large = std::move(tmp_large); } public: using value_type = Key; - ~HyperLogLogWithSmallSetOptimization() - { - if (isLarge()) - delete large; - } + ~HyperLogLogWithSmallSetOptimization() = default; /// ALWAYS_INLINE is required to have better code layout for uniqHLL12 function void ALWAYS_INLINE insert(Key value) diff --git a/src/Common/JSONParsers/DummyJSONParser.h~cnch-ce-merge b/src/Common/JSONParsers/DummyJSONParser.h~cnch-ce-merge deleted file mode 100644 index 6266ed48f65..00000000000 --- a/src/Common/JSONParsers/DummyJSONParser.h~cnch-ce-merge +++ /dev/null @@ -1,111 +0,0 @@ -#pragma once - -#include -#include -#include -#include "ElementTypes.h" - - -namespace DB -{ -namespace ErrorCodes -{ - extern const int NOT_IMPLEMENTED; -} - -/// This class can be used as an argument for the template class FunctionJSON when we unable to parse JSONs. -/// It can't do anything useful and just throws an exception. 
-struct DummyJSONParser -{ - class Array; - class Object; - - /// References an element in a JSON document, representing a JSON null, boolean, string, number, - /// array or object. - class Element - { - public: - Element() = default; - static ElementType type() { return ElementType::NULL_VALUE; } - static bool isInt64() { return false; } - static bool isUInt64() { return false; } - static bool isDouble() { return false; } - static bool isString() { return false; } - static bool isArray() { return false; } - static bool isObject() { return false; } - static bool isBool() { return false; } - static bool isNull() { return false; } - - static Int64 getInt64() { return 0; } - static UInt64 getUInt64() { return 0; } - static double getDouble() { return 0; } - static bool getBool() { return false; } - static std::string_view getString() { return {}; } - static Array getArray() { return {}; } - static Object getObject() { return {}; } - - static Element getElement() { return {}; } - }; - - /// References an array in a JSON document. - class Array - { - public: - class Iterator - { - public: - Element operator*() const { return {}; } - Iterator & operator++() { return *this; } - Iterator operator++(int) { return *this; } /// NOLINT - friend bool operator==(const Iterator &, const Iterator &) { return true; } - friend bool operator!=(const Iterator &, const Iterator &) { return false; } - }; - - static Iterator begin() { return {}; } - static Iterator end() { return {}; } - static size_t size() { return 0; } - Element operator[](size_t) const { return {}; } - }; - - using KeyValuePair = std::pair; - - /// References an object in a JSON document. 
- class Object - { - public: - class Iterator - { - public: - KeyValuePair operator*() const { return {}; } - Iterator & operator++() { return *this; } - Iterator operator++(int) { return *this; } /// NOLINT - friend bool operator==(const Iterator &, const Iterator &) { return true; } - friend bool operator!=(const Iterator &, const Iterator &) { return false; } - }; - - static Iterator begin() { return {}; } - static Iterator end() { return {}; } - static size_t size() { return 0; } - bool find(std::string_view, Element &) const { return false; } /// NOLINT - -#if 0 - /// Optional: Provides access to an object's element by index. - KeyValuePair operator[](size_t) const { return {}; } -#endif - }; - - /// Parses a JSON document, returns the reference to its root element if succeeded. - bool parse(std::string_view, Element &) { throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Functions JSON* are not supported"); } /// NOLINT - -#if 0 - /// Optional: Allocates memory to parse JSON documents faster. - void reserve(size_t max_size); -#endif -}; - -inline ALWAYS_INLINE std::ostream& operator<<(std::ostream& out, DummyJSONParser::Element) -{ - return out; -} - -} diff --git a/src/Common/LRUCache.h b/src/Common/LRUCache.h index 4bfa872863c..f52475dccab 100644 --- a/src/Common/LRUCache.h +++ b/src/Common/LRUCache.h @@ -194,6 +194,9 @@ class LRUCache LRUCache(size_t max_size_, const Delay & expiration_delay_ = Delay::zero()) : max_size(std::max(static_cast(1), max_size_)), expiration_delay(expiration_delay_) {} + void setCapacity(size_t max_size_) { max_size = std::max(static_cast(1), max_size_); } + size_t getCapacity() { return max_size; } + MappedPtr get(const Key & key) { std::lock_guard lock(mutex); @@ -478,7 +481,7 @@ class LRUCache /// Total weight of values. 
size_t current_size = 0; - const size_t max_size; + size_t max_size; const Delay expiration_delay; std::atomic hits {0}; diff --git a/src/Common/LocalDate.h b/src/Common/LocalDate.h index 0b8783e27d7..7284d46ede1 100644 --- a/src/Common/LocalDate.h +++ b/src/Common/LocalDate.h @@ -46,9 +46,8 @@ class LocalDate unsigned char m_month; unsigned char m_day; - void init(time_t time) + void init(time_t time, const DateLUTImpl & date_lut) { - const auto & date_lut = DateLUT::instance(); const auto & values = date_lut.getValues(time); m_year = values.year; @@ -78,22 +77,19 @@ class LocalDate } public: - explicit LocalDate(time_t time) - { - init(time); - } + explicit LocalDate(time_t time, const DateLUTImpl & time_zone = DateLUT::serverTimezoneInstance()) { init(time, time_zone); } - LocalDate(DayNum day_num) + LocalDate(DayNum day_num, const DateLUTImpl & time_zone = DateLUT::serverTimezoneInstance()) /// NOLINT { - const auto & values = DateLUT::instance().getValues(day_num); + const auto & values = time_zone.getValues(day_num); m_year = values.year; m_month = values.month; m_day = values.day_of_month; } - explicit LocalDate(ExtendedDayNum day_num) + explicit LocalDate(ExtendedDayNum day_num, const DateLUTImpl & time_zone = DateLUT::serverTimezoneInstance()) { - const auto & values = DateLUT::instance().getValues(day_num); + const auto & values = time_zone.getValues(day_num); m_year = values.year; m_month = values.month; m_day = values.day_of_month; @@ -121,16 +117,14 @@ class LocalDate LocalDate(const LocalDate &) noexcept = default; LocalDate & operator= (const LocalDate &) noexcept = default; - DayNum getDayNum() const + DayNum getDayNum(const DateLUTImpl & lut = DateLUT::serverTimezoneInstance()) const { - const auto & lut = DateLUT::instance(); return DayNum(lut.makeDayNum(m_year, m_month, m_day).toUnderType()); } - ExtendedDayNum getExtendedDayNum() const + ExtendedDayNum getExtendedDayNum(const DateLUTImpl & lut = DateLUT::serverTimezoneInstance()) const { - const 
auto & lut = DateLUT::instance(); - return ExtendedDayNum (lut.makeDayNum(m_year, m_month, m_day).toUnderType()); + return ExtendedDayNum(lut.makeDayNum(m_year, m_month, m_day).toUnderType()); } operator DayNum() const @@ -138,10 +132,7 @@ class LocalDate return getDayNum(); } - operator time_t() const - { - return DateLUT::instance().makeDate(m_year, m_month, m_day); - } + operator time_t() const { return DateLUT::serverTimezoneInstance().makeDate(m_year, m_month, m_day); } unsigned short year() const { return m_year; } unsigned char month() const { return m_month; } diff --git a/src/Common/LocalDateTime.h b/src/Common/LocalDateTime.h index 0dc89ce11ca..54e785c8b52 100644 --- a/src/Common/LocalDateTime.h +++ b/src/Common/LocalDateTime.h @@ -81,10 +81,7 @@ class LocalDateTime } public: - explicit LocalDateTime(time_t time, const DateLUTImpl & time_zone = DateLUT::instance()) - { - init(time, time_zone); - } + explicit LocalDateTime(time_t time, const DateLUTImpl & time_zone = DateLUT::serverTimezoneInstance()) { init(time, time_zone); } LocalDateTime(unsigned short year_, unsigned char month_, unsigned char day_, unsigned char hour_, unsigned char minute_, unsigned char second_) @@ -115,9 +112,7 @@ class LocalDateTime operator time_t() const { - return m_year == 0 - ? 0 - : DateLUT::instance().makeDateTime(m_year, m_month, m_day, m_hour, m_minute, m_second); + return m_year == 0 ? 
0 : DateLUT::serverTimezoneInstance().makeDateTime(m_year, m_month, m_day, m_hour, m_minute, m_second); } unsigned short year() const { return m_year; } diff --git a/src/Common/LocalFilePathMatcher.cpp b/src/Common/LocalFilePathMatcher.cpp new file mode 100644 index 00000000000..c8c4b5b4779 --- /dev/null +++ b/src/Common/LocalFilePathMatcher.cpp @@ -0,0 +1,17 @@ +#include +#include + +namespace DB +{ + +FileInfos LocalFilePathMatcher::getFileInfos(const String & prefix_path) +{ + FileInfos file_infos; + for (const auto & entry : std::filesystem::directory_iterator(prefix_path)) + { + file_infos.emplace_back(entry.path(), entry.is_directory()); + } + return file_infos; +} + +} diff --git a/src/Common/LocalFilePathMatcher.h b/src/Common/LocalFilePathMatcher.h new file mode 100644 index 00000000000..200e88bd0f6 --- /dev/null +++ b/src/Common/LocalFilePathMatcher.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +namespace DB +{ + +class LocalFilePathMatcher : public FilePathMatcher +{ +public: + ~LocalFilePathMatcher() override = default; + + FileInfos getFileInfos(const String & prefix_path) override; +}; + +} diff --git a/src/Common/MemoryTracker.cpp b/src/Common/MemoryTracker.cpp index 9749b259bdf..364d4049ecd 100644 --- a/src/Common/MemoryTracker.cpp +++ b/src/Common/MemoryTracker.cpp @@ -211,7 +211,7 @@ void MemoryTracker::allocImpl(Int64 size, bool throw_if_memory_exceeded) #endif std::bernoulli_distribution fault(fault_probability); - if (unlikely(fault_probability && fault(thread_local_rng)) && memoryTrackerCanThrow(level, true) && throw_if_memory_exceeded) + if (unlikely(fault_probability && fault(thread_local_rng)) && throw_if_memory_exceeded && memoryTrackerCanThrow(level, true)) { ProfileEvents::increment(ProfileEvents::QueryMemoryLimitExceeded); amount.fetch_sub(size, std::memory_order_relaxed); @@ -250,7 +250,7 @@ void MemoryTracker::allocImpl(Int64 size, bool throw_if_memory_exceeded) DB::TraceCollector::collect(DB::TraceType::MemorySample, StackTrace(), 
size); } - if (unlikely(current_hard_limit && will_be > current_hard_limit) && memoryTrackerCanThrow(level, false) && throw_if_memory_exceeded) + if (unlikely(current_hard_limit && will_be > current_hard_limit) && throw_if_memory_exceeded && memoryTrackerCanThrow(level, false)) { /// Prevent recursion. Exception::ctor -> std::string -> new[] -> MemoryTracker::alloc BlockerInThread untrack_lock(VariableContext::Global); diff --git a/src/Common/ProfileEvents.cpp b/src/Common/ProfileEvents.cpp index dbac6c90e81..ba91347891b 100644 --- a/src/Common/ProfileEvents.cpp +++ b/src/Common/ProfileEvents.cpp @@ -316,6 +316,11 @@ M(PerfInequalConditionAppendMicroseconds, "") \ M(PerfJoinElapsedMicroseconds, "") \ M(PerfFilterElapsedMicroseconds, "") \ +\ + M(QueryCreateTablesMicroseconds, "") \ + M(QuerySendResourcesMicroseconds, "") \ + M(CloudTableDefinitionCacheHits, "") \ + M(CloudTableDefinitionCacheMisses, "") \ \ M(CreatedHTTPConnections, "Total amount of created HTTP connections (closed or opened).") \ \ @@ -521,6 +526,14 @@ M(GetSQLBindingsSuccess, "") \ M(RemoveSQLBindingFailed, "") \ M(RemoveSQLBindingSuccess, "") \ + M(UpdatePreparedStatementFailed, "") \ + M(UpdatePreparedStatementSuccess, "") \ + M(GetPreparedStatementFailed, "") \ + M(GetPreparedStatementSuccess, "") \ + M(GetPreparedStatementsFailed, "") \ + M(GetSPreparedStatementsSuccess, "") \ + M(RemovePreparedStatementFailed, "") \ + M(RemovePreparedStatementSuccess, "") \ M(CreateDatabaseSuccess, "") \ M(CreateDatabaseFailed, "") \ M(GetDatabaseSuccess, "") \ @@ -691,6 +704,10 @@ M(GetAllUndoBufferFailed, "") \ M(GetUndoBufferIteratorSuccess, "") \ M(GetUndoBufferIteratorFailed, "") \ + M(GetUndoBuffersWithKeysSuccess, "") \ + M(GetUndoBuffersWithKeysFailed, "") \ + M(ClearUndoBuffersByKeysSuccess, "") \ + M(ClearUndoBuffersByKeysFailed, "") \ M(GetTransactionRecordsSuccess, "") \ M(GetTransactionRecordsFailed, "") \ M(GetTransactionRecordsTxnIdsSuccess, "") \ @@ -1113,7 +1130,11 @@ M(OrcIOSharedBytes, "") 
\ M(OrcIODirectCount, "") \ M(OrcIODirectBytes, "") \ - M(PreparePartsForReadMilliseconds, "The time spend on loading CNCH part from ServerPart on worker when query with table version") + M(PreparePartsForReadMilliseconds, "The time spend on loading CNCH part from ServerPart on worker when query with table version") \ + M(LoadedServerParts, "Total server parts loaded from storage manager by version") \ + M(LoadServerPartsMilliseconds, "The time spend on loading server parts by version from storage data manager.") \ + M(LoadManifestPartsCacheHits, "Cache(disk) hit count of loading parts from manifest") \ + M(LoadManifestPartsCacheMisses, "Cache(disk) miss count of loading parts from manifest") namespace ProfileEvents { diff --git a/src/Common/S3FilePathMatcher.cpp b/src/Common/S3FilePathMatcher.cpp new file mode 100644 index 00000000000..e1e77808f6d --- /dev/null +++ b/src/Common/S3FilePathMatcher.cpp @@ -0,0 +1,64 @@ +#include +#include + +namespace DB +{ + +S3FilePathMatcher::S3FilePathMatcher(const String & path, const ContextPtr & context_ptr) +{ + const auto & settings = context_ptr->getSettingsRef(); + S3::URI s3_uri(path); + String endpoint = !s3_uri.endpoint.empty() ? 
s3_uri.endpoint : settings.s3_endpoint.toString(); + S3::S3Config s3_cfg( + endpoint, + settings.s3_region.toString(), + s3_uri.bucket, + settings.s3_ak_id.toString(), + settings.s3_ak_secret.toString(), + "", + "", + settings.s3_use_virtual_hosted_style); + const std::shared_ptr client = s3_cfg.create(); + s3_util = std::make_unique(client, s3_uri.bucket, false); +} + + +FileInfos S3FilePathMatcher::getFileInfos(const String & prefix_path) +{ + FileInfos file_infos; + + // erase '/' at first to list objects in the bucket + String prefix_without_slash = prefix_path; + size_t pos = prefix_without_slash.find_first_not_of('/'); + if (pos != std::string::npos) + prefix_without_slash.erase(0, pos); + else + prefix_without_slash.clear(); + + S3::S3Util::S3ListResult s3_list_result = s3_util->listObjectsWithDelimiter(prefix_without_slash, "/", false); + + if (s3_list_result.object_names.empty()) + return file_infos; + + int ls_length = s3_list_result.object_names.size(); + for (int i = 0; i < ls_length; i++) + { + // add '/' at first to keep the same with other file system + String file_path = std::filesystem::path("/") / s3_list_result.object_names[i]; + file_infos.emplace_back(file_path, s3_list_result.is_common_prefix[i]); + } + + return file_infos; +} + +String S3FilePathMatcher::getSchemeAndPrefix() +{ + return S3_SCHEME + s3_util->getBucket(); +} + +String S3FilePathMatcher::removeSchemeAndPrefix(const String & full_path) +{ + // remove scheme and bucket from path, add '/' at first to keep the same with other file system + return std::filesystem::path("/") / S3::URI(full_path).key; +} +} diff --git a/src/Common/S3FilePathMatcher.h b/src/Common/S3FilePathMatcher.h new file mode 100644 index 00000000000..a48e3719407 --- /dev/null +++ b/src/Common/S3FilePathMatcher.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include + +namespace DB +{ +const static String S3_SCHEME = "s3://"; + +class S3FilePathMatcher : public FilePathMatcher +{ +public: + 
S3FilePathMatcher(const String & path, const ContextPtr & context_ptr); + + ~S3FilePathMatcher() override = default; + + FileInfos getFileInfos(const String & prefix_path) override; + + String getSchemeAndPrefix() override; + + String removeSchemeAndPrefix(const String & full_path) override; + +private: + std::unique_ptr s3_util; +}; + +} diff --git a/src/Common/SettingsChanges.cpp b/src/Common/SettingsChanges.cpp index 5ab86fe2c26..0402ccbb51d 100644 --- a/src/Common/SettingsChanges.cpp +++ b/src/Common/SettingsChanges.cpp @@ -162,6 +162,7 @@ void SettingsChanges::fillFromProto(const Protos::SettingsChanges & proto) } std::unordered_set SettingsChanges::WHITELIST_SETTINGS = { + "access_table_names", "accessible_table_names", "active_role", "add_http_cors_header", @@ -217,7 +218,6 @@ std::unordered_set SettingsChanges::WHITELIST_SETTINGS = "deduce_part_eliminate", "delay_dequeue_ms", "dialect_type", - "dict_table_full_mode", "direct_forward_query_to_cnch", "disable_perfect_shard_auto_merge", "disable_remote_stream_log", @@ -243,6 +243,7 @@ std::unordered_set SettingsChanges::WHITELIST_SETTINGS = "enable_deterministic_sample_by_range", "enable_dictionary_compression", "enable_direct_insert", + "enable_distinct_remove", "enable_distributed_stages", "enable_dynamic_filter", "enable_final_for_delta", @@ -346,6 +347,15 @@ std::unordered_set SettingsChanges::WHITELIST_SETTINGS = "kafka_max_partition_fetch_bytes", "kafka_session_timeout_ms", "kms_token", + "lasfs_access_key", + "lasfs_endpoint", + "lasfs_identity_id", + "lasfs_identity_type", + "lasfs_overwrite", + "lasfs_region", + "lasfs_secret_key", + "lasfs_service_name", + "lasfs_session_token", "load_balancing_offset", "local_disk_cache_thread_pool_size", "log_id", @@ -490,6 +500,7 @@ std::unordered_set SettingsChanges::WHITELIST_SETTINGS = "preload_checksums_and_primary_index_cache", "priority", "process_list_block_time", + "profile", "query_auto_retry", "query_auto_retry_millisecond", "query_cache_min_lifetime", 
@@ -511,6 +522,27 @@ std::unordered_set SettingsChanges::WHITELIST_SETTINGS = "restore_table_expression_in_distributed", "result_overflow_mode", "rm_zknodes_while_alter_engine", + "s3_access_key_id", + "s3_access_key_secret", + "s3_ak_id", + "s3_ak_secret", + "s3_check_objects_after_upload", + "s3_endpoint", + "s3_gc_inter_partition_parallelism", + "s3_gc_intra_partition_parallelism", + "s3_max_connections", + "s3_max_list_nums", + "s3_max_redirects", + "s3_max_request_ms", + "s3_max_single_part_upload_size", + "s3_max_single_read_retries", + "s3_max_unexpected_write_error_retries", + "s3_min_upload_part_size", + "s3_region", + "s3_skip_empty_files", + "s3_upload_part_size_multiply_factor", + "s3_upload_part_size_multiply_parts_count_threshold", + "s3_use_virtual_hosted_style", "schedule_sync_thread_per_table", "select_sequential_consistency", "send_logs_level", @@ -537,6 +569,13 @@ std::unordered_set SettingsChanges::WHITELIST_SETTINGS = "tcp_keep_alive_timeout", "tealimit_order_keep", "timeout_before_checking_execution_speed", + "tos_access_key", + "tos_connection_timeout", + "tos_endpoint", + "tos_region", + "tos_request_timeout", + "tos_secret_key", + "tos_security_token", "totals_auto_threshold", "totals_mode", "underlying_dictionary_tables", diff --git a/src/Common/StatusFile.cpp b/src/Common/StatusFile.cpp index ef00bb41cb2..36b17320749 100644 --- a/src/Common/StatusFile.cpp +++ b/src/Common/StatusFile.cpp @@ -85,7 +85,18 @@ StatusFile::StatusFile(std::string path_, FillFunction fill_) /// Write information about current server instance to the file. WriteBufferFromFileDescriptor out(fd, 1024); - fill(out); + try + { + fill(out); + /// Finalize here to avoid throwing exceptions in destructor. + out.finalize(); + } + catch (...) + { + /// Finalize in case of exception to avoid throwing exceptions in destructor + out.finalize(); + throw; + } } catch (...) 
{ diff --git a/src/Common/ThreadStatus.h b/src/Common/ThreadStatus.h index 7f76cd5716d..744a34da4c1 100644 --- a/src/Common/ThreadStatus.h +++ b/src/Common/ThreadStatus.h @@ -313,6 +313,11 @@ class ThreadStatus : public boost::noncopyable return query_context.lock(); } + ContextPtr getGlobalContext() const + { + return global_context.lock(); + } + /// Starts new query and create new thread group for it, current thread becomes master thread of the query void initializeQuery(MemoryTracker * memory_tracker_ = nullptr); diff --git a/src/Common/Trace/DirectSystemLogExporter.cpp b/src/Common/Trace/DirectSystemLogExporter.cpp index cf7be55699e..177f2d6c8df 100644 --- a/src/Common/Trace/DirectSystemLogExporter.cpp +++ b/src/Common/Trace/DirectSystemLogExporter.cpp @@ -61,7 +61,7 @@ void OTELTraceLogElement::appendToBlock(MutableColumns & columns) const size_t i = 0; columns[i++]->insert(event_time); - columns[i++]->insert(DateLUT::instance().toDayNum(event_time / 1000000000).toUnderType()); + columns[i++]->insert(DateLUT::serverTimezoneInstance().toDayNum(event_time / 1000000000).toUnderType()); columns[i++]->insert(trace_id); columns[i++]->insert(span_id); columns[i++]->insert(parent_span_id); diff --git a/src/Common/WorkerId.h b/src/Common/WorkerId.h index 8fbe45b17e8..117c3c32275 100644 --- a/src/Common/WorkerId.h +++ b/src/Common/WorkerId.h @@ -3,7 +3,7 @@ namespace DB { using String = std::string; -struct WorkerId +struct WorkerId { WorkerId(const String & vw_name_, const String & wg_name_, const String & id_) : vw_name(vw_name_), wg_name(wg_name_), id(id_) { } WorkerId() = default; @@ -14,19 +14,20 @@ struct WorkerId { return vw_name + "." + wg_name + "." 
+ id; } -}; -struct WorkerIdEqual -{ - bool operator()(const WorkerId & lhs, const WorkerId & rhs) const + + inline bool operator==(WorkerId const & rhs) const { - return lhs.vw_name == rhs.vw_name && lhs.wg_name == rhs.wg_name && lhs.id == rhs.id; + return (this->vw_name == rhs.vw_name && this->wg_name == wg_name && this->id == id); } + }; -struct WorkerIdHash + +struct WorkerIdHash { std::size_t operator()(const WorkerId & worker_id) const { return std::hash()(worker_id.ToString()); } -}; +}; + } // namespace DB diff --git a/src/Common/callOnce.h b/src/Common/callOnce.h new file mode 100644 index 00000000000..402bb7365a1 --- /dev/null +++ b/src/Common/callOnce.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +namespace DB +{ + +using OnceFlag = std::once_flag; + +template +void callOnce(OnceFlag & flag, Callable && func, Args&&... args) +{ + std::call_once(flag, std::forward(func), std::forward(args)...); +} + +} diff --git a/src/Common/time.h b/src/Common/time.h index 08e27548a53..456d5817512 100644 --- a/src/Common/time.h +++ b/src/Common/time.h @@ -31,3 +31,12 @@ inline std::chrono::time_point{ std::chrono::duration_cast(timespec_to_duration(ts))}; } + +/// return duration in ms from now to timestamp_ms, if now exceeded timestamp_ms, return empty +inline std::optional duration_ms_from_now(UInt64 timestamp_ms) +{ + auto now = time_in_milliseconds(std::chrono::system_clock::now()); + if (timestamp_ms < now) + return {}; + return timestamp_ms - now; +} diff --git a/src/Core/MySQL/MySQLReplication.cpp b/src/Core/MySQL/MySQLReplication.cpp index e29958ba2e7..83e67b45726 100644 --- a/src/Core/MySQL/MySQLReplication.cpp +++ b/src/Core/MySQL/MySQLReplication.cpp @@ -430,8 +430,11 @@ namespace MySQLReplication UInt32 i24 = 0; payload.readStrict(reinterpret_cast(&i24), 3); - const DayNum date_day_number(DateLUT::instance().makeDayNum( - static_cast((i24 >> 9) & 0x7fff), static_cast((i24 >> 5) & 0xf), static_cast(i24 & 0x1f)).toUnderType()); + const DayNum date_day_number( 
+ DateLUT::serverTimezoneInstance() + .makeDayNum( + static_cast((i24 >> 9) & 0x7fff), static_cast((i24 >> 5) & 0xf), static_cast(i24 & 0x1f)) + .toUnderType()); row.push_back(Field(date_day_number.toUnderType())); break; @@ -536,10 +539,13 @@ namespace MySQLReplication readTimeFractionalPart(payload, fsp, meta); UInt32 year_month = readBits(val, 1, 17, 40); - time_t date_time = DateLUT::instance().makeDateTime( - year_month / 13, year_month % 13, readBits(val, 18, 5, 40) - , readBits(val, 23, 5, 40), readBits(val, 28, 6, 40), readBits(val, 34, 6, 40) - ); + time_t date_time = DateLUT::serverTimezoneInstance().makeDateTime( + year_month / 13, + year_month % 13, + readBits(val, 18, 5, 40), + readBits(val, 23, 5, 40), + readBits(val, 28, 6, 40), + readBits(val, 34, 6, 40)); if (!meta) // The max value of the 64 bit int flagged here exceeds the year value that is diff --git a/src/Core/MySQL/MySQLUtils.cpp b/src/Core/MySQL/MySQLUtils.cpp index 00d35ef9a4d..3940f0a4127 100644 --- a/src/Core/MySQL/MySQLUtils.cpp +++ b/src/Core/MySQL/MySQLUtils.cpp @@ -12,7 +12,7 @@ namespace MySQLProtocol namespace MySQLUtils { -DecimalUtils::DecimalComponents getNormalizedDateTime64Components(DataTypePtr data_type, ColumnPtr col, size_t row_num) +DecimalUtils::DecimalComponents getNormalizedDateTime64Components(DataTypePtr data_type, ColumnPtr col, size_t row_num, bool adapt_scale) { const auto * date_time_type = typeid_cast(data_type.get()); @@ -29,7 +29,7 @@ DecimalUtils::DecimalComponents getNormalizedDateTime64Components(Da --components.whole; } - if (components.fractional != 0) + if (components.fractional != 0 && adapt_scale) { if (scale > 6) { diff --git a/src/Core/MySQL/MySQLUtils.h b/src/Core/MySQL/MySQLUtils.h index e77e9c22ee4..23b97007e63 100644 --- a/src/Core/MySQL/MySQLUtils.h +++ b/src/Core/MySQL/MySQLUtils.h @@ -11,7 +11,7 @@ namespace MySQLUtils { /// Splits DateTime64 column data at a certain row number into whole and fractional part /// Additionally, normalizes the 
fractional part as if it was scale 6 for MySQL compatibility purposes -DecimalUtils::DecimalComponents getNormalizedDateTime64Components(DataTypePtr data_type, ColumnPtr col, size_t row_num); +DecimalUtils::DecimalComponents getNormalizedDateTime64Components(DataTypePtr data_type, ColumnPtr col, size_t row_num, bool adapt_scale = true); } } } diff --git a/src/Core/MySQL/PacketsProtocolText.cpp b/src/Core/MySQL/PacketsProtocolText.cpp index d67e9b01609..b34de90ba42 100644 --- a/src/Core/MySQL/PacketsProtocolText.cpp +++ b/src/Core/MySQL/PacketsProtocolText.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "Common/assert_cast.h" #include "Core/MySQL/IMySQLWritePacket.h" #include "DataTypes/DataTypeLowCardinality.h" @@ -41,10 +42,14 @@ ResultSetRow::ResultSetRow(const Serializations & serializations, const DataType ColumnPtr col = columns[i]->convertToFullIfNeeded(); if (col->isNullable()) col = assert_cast(*col).getNestedColumnPtr(); - auto components = MySQLUtils::getNormalizedDateTime64Components(data_type, col, row_num); + const auto * date_time_type = typeid_cast(data_type.get()); + auto context = CurrentThread::get().getQueryContext(); + bool keep_scale = context && context->getSettingsRef().datetime_format_mysql_protocol && date_time_type->getScale() < 6; + UInt32 scale = keep_scale ? 
date_time_type->getScale() : 6; + auto components = MySQLUtils::getNormalizedDateTime64Components(data_type, col, row_num, !keep_scale); writeDateTimeText<'-', ':', ' '>(LocalDateTime(components.whole, DateLUT::instance(getDateTimeTimezone(*data_type))), ostr); - ostr.write('.'); - writeDateTime64FractionalText(components.fractional, 6, ostr); + if (scale > 0) ostr.write('.'); + writeDateTime64FractionalText(components.fractional, scale, ostr); payload_size += getLengthEncodedStringSize(ostr.str()); serialized.push_back(std::move(ostr.str())); } @@ -158,68 +163,89 @@ void ColumnDefinition::writePayloadImpl(WriteBuffer & buffer) const ColumnDefinition getColumnDefinition(const String & column_name, const DataTypePtr & data_type) { ColumnType column_type; + /// max column length after serialize into text + /// however, this func is called before serializing data. + /// we therefore do not have the exact max length + /// if set to 0, power BI would treat the column as null and reports error + /// to avoid that, we return the theoretical max length based on data type + uint32_t column_length = 0; CharacterSet charset = CharacterSet::binary; int flags = 0; uint8_t decimals = 0; + if (!data_type->isNullable()) + flags = ColumnDefinitionFlags::NOT_NULL_FLAG; DataTypePtr normalized_data_type = removeLowCardinalityAndNullable(data_type); TypeIndex type_index = normalized_data_type->getTypeId(); switch (type_index) { case TypeIndex::UInt8: column_type = ColumnType::MYSQL_TYPE_TINY; + column_length = 3; // max val 255 flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; break; case TypeIndex::UInt16: column_type = ColumnType::MYSQL_TYPE_SHORT; + column_length = 5; flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; break; case TypeIndex::UInt32: column_type = ColumnType::MYSQL_TYPE_LONG; + column_length = 10; flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; break; case 
TypeIndex::UInt64: column_type = ColumnType::MYSQL_TYPE_LONGLONG; + column_length = 20; flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; break; case TypeIndex::Int8: column_type = ColumnType::MYSQL_TYPE_TINY; + column_length = 4; // min val -127 flags = ColumnDefinitionFlags::BINARY_FLAG; break; case TypeIndex::Int16: column_type = ColumnType::MYSQL_TYPE_SHORT; + column_length = 6; flags = ColumnDefinitionFlags::BINARY_FLAG; break; case TypeIndex::Int32: column_type = ColumnType::MYSQL_TYPE_LONG; + column_length = 11; flags = ColumnDefinitionFlags::BINARY_FLAG; break; case TypeIndex::Int64: column_type = ColumnType::MYSQL_TYPE_LONGLONG; + column_length = 21; flags = ColumnDefinitionFlags::BINARY_FLAG; break; case TypeIndex::Float32: column_type = ColumnType::MYSQL_TYPE_FLOAT; flags = ColumnDefinitionFlags::BINARY_FLAG; decimals = 31; + column_length = 14; break; case TypeIndex::Float64: column_type = ColumnType::MYSQL_TYPE_DOUBLE; flags = ColumnDefinitionFlags::BINARY_FLAG; decimals = 31; + column_length = 24; break; case TypeIndex::Date: case TypeIndex::Date32: column_type = ColumnType::MYSQL_TYPE_DATE; + column_length = 10; // e.g., 2020-12-12 flags = ColumnDefinitionFlags::BINARY_FLAG; break; case TypeIndex::DateTime: case TypeIndex::DateTime64: column_type = ColumnType::MYSQL_TYPE_DATETIME; + column_length = 26; // e.g., 2020-12-12 11:11:11.123456 flags = ColumnDefinitionFlags::BINARY_FLAG; break; case TypeIndex::Decimal32: case TypeIndex::Decimal64: column_type = ColumnType::MYSQL_TYPE_DECIMAL; + column_length = 20; // 18 (precision) + 1 (sign) + 1 (point) flags = ColumnDefinitionFlags::BINARY_FLAG; break; case TypeIndex::Decimal128: { @@ -237,14 +263,16 @@ ColumnDefinition getColumnDefinition(const String & column_name, const DataTypeP column_type = ColumnType::MYSQL_TYPE_DECIMAL; flags = ColumnDefinitionFlags::BINARY_FLAG; } + column_length = 67; // 65 + 1 (sign) + 1 (point) break; } default: - column_type = 
ColumnType::MYSQL_TYPE_STRING; + column_type = ColumnType::MYSQL_TYPE_VAR_STRING; + column_length = 65535; // max mysql var string len charset = CharacterSet::utf8_general_ci; break; } - return ColumnDefinition(column_name, charset, 0, column_type, flags, decimals); + return ColumnDefinition(column_name, charset, column_length, column_type, flags, decimals); } } diff --git a/src/Core/MySQL/PacketsProtocolText.h b/src/Core/MySQL/PacketsProtocolText.h index 07969a1ed93..309b07eed30 100644 --- a/src/Core/MySQL/PacketsProtocolText.h +++ b/src/Core/MySQL/PacketsProtocolText.h @@ -24,6 +24,7 @@ enum CharacterSet // https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__column__definition__flags.html enum ColumnDefinitionFlags { + NOT_NULL_FLAG = 1, UNSIGNED_FLAG = 32, BINARY_FLAG = 128 }; diff --git a/src/Core/Protocol.h b/src/Core/Protocol.h index bb24d5d9d94..ffd8bcb59c1 100644 --- a/src/Core/Protocol.h +++ b/src/Core/Protocol.h @@ -101,7 +101,8 @@ namespace Protocol /// This is such an inverted logic, where server sends requests /// And client returns back response QueryMetrics = 14, /// Query metrics in cnch worker side - MAX = QueryMetrics, + TimezoneUpdate = 15, + MAX = TimezoneUpdate, }; @@ -126,7 +127,8 @@ namespace Protocol "TableColumns", "PartUUIDs", "ReadTaskRequest", - "QueryMetrics" + "QueryMetrics", + "TimezoneUpdate" }; return packet <= MAX ? data[packet] diff --git a/src/Core/ProtocolDefines.h b/src/Core/ProtocolDefines.h index bfb0ef1adbc..f6764f099df 100644 --- a/src/Core/ProtocolDefines.h +++ b/src/Core/ProtocolDefines.h @@ -59,9 +59,11 @@ /// NOTE: DBMS_TCP_PROTOCOL_VERSION has nothing common with VERSION_REVISION, /// later is just a number for server version (one number instead of commit SHA) /// for simplicity (sometimes it may be more convenient in some use cases). 
-#define DBMS_TCP_PROTOCOL_VERSION 54450 +#define DBMS_TCP_PROTOCOL_VERSION 54451 #define DBMS_MIN_PROTOCOL_VERSION_WITH_INITIAL_QUERY_START_TIME 54449 #define DBMS_MIN_REVISION_WITH_QUERY_METRICS 54450 +static constexpr auto DBMS_MIN_PROTOCOL_VERSION_WITH_TIMEZONE_UPDATES = 54451; + diff --git a/src/Core/Settings.h b/src/Core/Settings.h index cb7f6d343d5..7a25004c3ab 100644 --- a/src/Core/Settings.h +++ b/src/Core/Settings.h @@ -377,6 +377,7 @@ enum PreloadLevelSettings : UInt64 0) \ \ M(Bool, log_queries, 1, "Log requests and write the log to the system table.", 0) \ + M(Bool, log_query_plan, 0, "Log json format query plan to the system query_log table.", 0) \ M(Bool, log_max_io_thread_queries, 1, "Log max io time thread requests and write the log to the system table", 0) \ M(LogQueriesType, \ log_queries_min_type, \ @@ -399,6 +400,7 @@ enum PreloadLevelSettings : UInt64 \ M(Bool, log_processors_profiles, false, "Log Processors profile events.", 0) \ M(Bool, log_segment_profiles, false, "Log profile of each segment info including runtime and planning information.", 0) \ + M(Bool, report_segment_profiles, false, "Report plan segment profile to coordinator.", 0)\ M(Bool, report_processors_profiles, false, "Report processors profile to coordinator.", 0) \ M(UInt64, report_processors_profiles_timeout_millseconds, 10, "Report processors profile to coordinator timeout millseconds.", 0) \ M(DistributedProductMode, \ @@ -459,89 +461,37 @@ enum PreloadLevelSettings : UInt64 M(String, count_distinct_implementation, "uniqExact", "What aggregate function to use for implementation of count(DISTINCT ...)", 0) \ \ M(Bool, add_http_cors_header, false, "Write add http CORS header.", 0) \ -\ - M(UInt64, \ - max_http_get_redirects, \ - 0, \ - "Max number of http GET redirects hops allowed. 
Make sure additional security measures are in place to prevent a malicious server " \ - "to redirect your requests to unexpected services.", \ - 0) \ -\ - M(Bool, \ - use_client_time_zone, \ - false, \ - "Use client timezone for interpreting DateTime string values, instead of adopting server timezone.", \ - 0) \ -\ - M(Bool, \ - send_progress_in_http_headers, \ - false, \ - "Send progress notifications using X-ClickHouse-Progress headers. Some clients do not support high amount of HTTP headers (Python " \ - "requests in particular), so it is disabled by default.", \ - 0) \ -\ - M(UInt64, \ - http_headers_progress_interval_ms, \ - 100, \ - "Do not send HTTP headers X-ClickHouse-Progress more frequently than at each specified interval.", \ - 0) \ -\ - M(Bool, \ - fsync_metadata, \ - 1, \ - "Do fsync after changing metadata for tables and databases (.sql files). Could be disabled in case of poor latency on server with " \ - "high load of DDL queries and high load of disk subsystem.", \ - 0) \ -\ - M(Bool, \ - join_use_nulls, \ - 1, \ - "Use NULLs for non-joined rows of outer JOINs for types that can be inside Nullable. If false, use default value of corresponding " \ - "columns data type.", \ - IMPORTANT) \ + \ + M(UInt64, max_http_get_redirects, 0, "Max number of http GET redirects hops allowed. Make sure additional security measures are in place to prevent a malicious server to redirect your requests to unexpected services.", 0) \ + \ + M(Bool, use_client_time_zone, false, "Use client timezone for interpreting DateTime string values, instead of adopting server timezone.", 0) \ + M(Timezone, session_timezone, "", "The default timezone for current session or query. The default value is server default timezone if empty.", 0) \ + \ + M(Bool, send_progress_in_http_headers, false, "Send progress notifications using X-ClickHouse-Progress headers. 
Some clients do not support high amount of HTTP headers (Python requests in particular), so it is disabled by default.", 0) \ + \ + M(UInt64, http_headers_progress_interval_ms, 100, "Do not send HTTP headers X-ClickHouse-Progress more frequently than at each specified interval.", 0) \ + \ + M(Bool, fsync_metadata, 1, "Do fsync after changing metadata for tables and databases (.sql files). Could be disabled in case of poor latency on server with high load of DDL queries and high load of disk subsystem.", 0) \ + \ + M(Bool, join_use_nulls, 1, "Use NULLs for non-joined rows of outer JOINs for types that can be inside Nullable. If false, use default value of corresponding columns data type.", IMPORTANT) \ M(Bool, join_using_null_safe, 0, "Force null safe equal comparison for USING keys except the last key of ASOF join", 0) \ \ + M(Bool, allow_return_nullable_array, 1, "For array related functions, if true, will return nullable(array)", 0) \ + \ M(JoinStrictness, join_default_strictness, JoinStrictness::ALL, "Set default strictness in JOIN query. Possible values: empty string, 'ANY', 'ALL'. If empty, query without strictness will throw exception.", 0) \ M(Bool, any_join_distinct_right_table_keys, false, "Enable old ANY JOIN logic with many-to-one left-to-right table keys mapping for all ANY JOINs. It leads to confusing not equal results for 't1 ANY LEFT JOIN t2' and 't2 ANY RIGHT JOIN t1'. ANY RIGHT JOIN needs one-to-many keys mapping to be consistent with LEFT one.", IMPORTANT) \ M(Bool, enable_join_on_1_equals_1, false, "Enable join on 1=1.", 0) \ \ M(UInt64, preferred_block_size_bytes, 1000000, "", 0) \ -\ - M(UInt64, \ - max_replica_delay_for_distributed_queries, \ - 300, \ - "If set, distributed queries of Replicated tables will choose servers with replication delay in seconds less than the specified " \ - "value (not inclusive). 
Zero means do not take delay into account.", \ - 0) \ - M(Bool, \ - fallback_to_stale_replicas_for_distributed_queries, \ - 1, \ - "Suppose max_replica_delay_for_distributed_queries is set and all replicas for the queried table are stale. If this setting is " \ - "enabled, the query will be performed anyway, otherwise the error will be reported.", \ - 0) \ - M(UInt64, \ - preferred_max_column_in_block_size_bytes, \ - 0, \ - "Limit on max column size in block while reading. Helps to decrease cache misses count. Should be close to L2 cache size.", \ - 0) \ -\ - M(Bool, \ - insert_distributed_sync, \ - false, \ - "If setting is enabled, insert query into distributed waits until data will be sent to all nodes in cluster.", \ - 0) \ - M(UInt64, \ - insert_distributed_timeout, \ - 0, \ - "Timeout for insert query into distributed. Setting is used only with insert_distributed_sync enabled. Zero value means no " \ - "timeout.", \ - 0) \ - M(Int64, \ - distributed_ddl_task_timeout, \ - 180, \ - "Timeout for DDL query responses from all hosts in cluster. If a ddl request has not been performed on all hosts, a response will " \ - "contain a timeout error and a request will be executed in an async mode. Negative value means infinite. Zero means async mode.", \ - 0) \ + \ + M(UInt64, max_replica_delay_for_distributed_queries, 300, "If set, distributed queries of Replicated tables will choose servers with replication delay in seconds less than the specified value (not inclusive). Zero means do not take delay into account.", 0) \ + M(Bool, fallback_to_stale_replicas_for_distributed_queries, 1, "Suppose max_replica_delay_for_distributed_queries is set and all replicas for the queried table are stale. If this setting is enabled, the query will be performed anyway, otherwise the error will be reported.", 0) \ + M(UInt64, preferred_max_column_in_block_size_bytes, 0, "Limit on max column size in block while reading. Helps to decrease cache misses count. 
Should be close to L2 cache size.", 0) \ + \ + M(Bool, insert_select_with_profiles, false, "If setting is enabled, return the total inserted (selected) rows.", 0) \ + M(Bool, insert_distributed_sync, false, "If setting is enabled, insert query into distributed waits until data will be sent to all nodes in cluster.", 0) \ + M(UInt64, insert_distributed_timeout, 0, "Timeout for insert query into distributed. Setting is used only with insert_distributed_sync enabled. Zero value means no timeout.", 0) \ + M(Int64, distributed_ddl_task_timeout, 180, "Timeout for DDL query responses from all hosts in cluster. If a ddl request has not been performed on all hosts, a response will contain a timeout error and a request will be executed in an async mode. Negative value means infinite. Zero means async mode.", 0) \ M(Milliseconds, stream_flush_interval_ms, 7500, "Timeout for flushing data from streaming storages.", 0) \ M(Milliseconds, stream_poll_timeout_ms, 500, "Timeout for polling data from/to streaming storages.", 0) \ \ @@ -673,6 +623,7 @@ enum PreloadLevelSettings : UInt64 M(Bool, allow_experimental_data_skipping_indices, true, "Emulate data skipping indices", 0) \ M(Bool, enable_predicate_pushdown, false, "Where to push down predicate", 0) \ M(Bool, dict_table_full_mode, false, "If encode / decode table is not bucket table, try to dispatch dict to all workers, if false, throw exception instead", 0) \ + M(UInt64, max_in_value_list_to_pushdown, 10000, "Max size of in value list in filter", 0) \ M(UInt64, pathgraph_threshold_y, 0, "maximum point number in each level", 0) \ M(Bool, to_string_extra_arguments, true, "Whether to allow an extra argument in toString Function", 0) \ \ @@ -894,6 +845,8 @@ enum PreloadLevelSettings : UInt64 "longest one.", \ 0) \ M(Bool, optimize_read_in_order, true, "Enable ORDER BY optimization for reading data in corresponding order in MergeTree tables.", 0) \ + M(Bool, optimize_read_in_partition_order, false, "In optimize_read_in_order mode, 
whether to read parts partition-by-partition if applicable, it will also delay inverted index evaluation till pipeline execution", 0) \ + M(Bool, force_read_in_partition_order, 0, "Similar to optimize_read_in_partition_order, but throw an exception if it cannot be applied to the query, mainly for testing", 0) \ M(Bool, optimize_aggregation_in_order, false, "Enable GROUP BY optimization for aggregating data in corresponding order in MergeTree tables.", 0) \ M(UInt64, read_in_order_two_level_merge_threshold, 100, "Minimal number of parts to read to run preliminary merge step during multithread reading in order of primary key.", 0) \ M(Bool, low_cardinality_allow_in_native_format, true, "Use LowCardinality type in Native format. Otherwise, convert LowCardinality columns to ordinary for select query, and convert ordinary columns to required LowCardinality for insert query.", 0) \ @@ -968,6 +921,7 @@ enum PreloadLevelSettings : UInt64 M(UInt64, mutations_sync, 0, "Wait for synchronous execution of ALTER TABLE UPDATE/DELETE queries (mutations). 0 - execute asynchronously. 1 - wait current server. 2 - wait all replicas if they exist.", 0) \ M(UInt64, mutations_wait_timeout, 0, "Maximum seconds to wait for synchronous mutations. 0 - wait unlimited time", 0) \ M(String, mutation_query_id, "", "Used to overwrite mutation's query id in tests", 0) \ + M(Bool, mutation_allow_modify_remove_nullable, false, "default not allow modify column from Nullable(xxx) to xxx", 0) \ M(Bool, system_mutations_only_basic_info, false, "Only return basic information that stored in KV. 
It avoid acquiring merge thread of tables", 0) \ M(Bool, enable_lightweight_delete, true, "Enable lightweight DELETE for mergetree tables.", 0) \ M(Bool, optimize_move_functions_out_of_any, false, "Move functions out of aggregate functions 'any', 'anyLast'.", 0) \ @@ -1096,6 +1050,7 @@ enum PreloadLevelSettings : UInt64 0) \ \ M(Bool, handle_division_by_zero, false, "If set true, return null for division by zero (MySQL Behavior)", 0) \ + M(Bool, enable_bucket_for_distribute, true, "If set true, enable distribute by keyword by replacing with distribute", 0) \ \ M(Bool, optimize_rewrite_sum_if_to_count_if, true, "Rewrite sumIf() and sum(if()) function countIf() function when logically equivalent", 0) \ M(UInt64, insert_shard_id, 0, "If non zero, when insert into a distributed table, the data will be inserted into the shard `insert_shard_id` synchronously. Possible values range from 1 to `shards_number` of corresponding distributed table", 0) \ @@ -1194,11 +1149,14 @@ enum PreloadLevelSettings : UInt64 M(TextCaseOption, text_case_option, TextCaseOption::MIXED, "Convert identifiers to lower case/upper case just like MySQL", 0) \ M(Bool, enable_implicit_arg_type_convert, false, "Eable implicit type conversion for functions", 0) \ M(Bool, exception_on_unsupported_mysql_syntax, true, "Whether throws exceptions on currently unsupported mysql syntax such as auto_increment", 0) \ + M(Bool, only_full_group_by, true, "If the ONLY_FULL_GROUP_BY is enabled (which it is by default), rejects queries for which the select list, HAVING condition, or ORDER BY list refer to nonaggregated columns that are neither named in the GROUP BY clause nor are functionally dependent on them.", 0) \ M(Bool, adaptive_type_cast, true, "Performs type cast operations adaptively, according to the value", 0) \ M(Bool, parse_literal_as_decimal, false, "Parse numeric literal as decimal instead of float", 0) \ M(Bool, formatdatetime_f_prints_single_zero, false, "Formatter '%f' in function 'formatDateTime()' 
produces a single zero instead of six zeros if the formatted value has no fractional seconds.", 0) \ M(Bool, formatdatetime_parsedatetime_m_is_month_name, false, "Formatter '%M' in functions 'formatDateTime()' and 'parseDateTime()' produces the month name instead of minutes.", 0) \ M(Bool, date_format_clickhouse, false, "use date_format as a clickhouse function instead of hive", 0) \ + M(Bool, datetime_format_mysql_protocol, false, "In mysql protocol, outputs datetime with precision similar to mysql", 0) \ + M(Bool, datetime_format_mysql_definition, false, "In mysql dialect, whether create table with timestamp/datetime uses datetime64(3)", 0) \ M(Bool, tealimit_order_keep, false, "Whether tealimit output keep order by clause", 0)\ M(UInt64, early_limit_for_map_virtual_columns, 0, "Enable early limit while quering _map_column_keys column", 0)\ M(Bool, skip_nullinput_notnull_col, false, "Skip null value in JSON for not null column", 0)\ @@ -1264,6 +1222,7 @@ enum PreloadLevelSettings : UInt64 M(Bool, rewrite_unknown_left_join_identifier, true, "Whether to rewrite unknown left join identifier, this is a deprecated feature but Aeolus SQL depends on it", 0) \ M(Bool, allow_mysql_having_name_resolution, false, "Whether to use MySQL special name resolution rules for HAVING clauses ", 0) \ M(String, access_table_names, "", "Session level restricted tables query can access", 0) \ + M(String, accessible_table_names, "", "Session level restricted tables query can access", 0) \ \ /** settings in cnch **/ \ M(Seconds, drop_range_memory_lock_timeout, 5, "The time that spend on wait for memory lock when doing drop range", 0) \ @@ -1314,7 +1273,8 @@ enum PreloadLevelSettings : UInt64 M(UInt64, cloud_task_auto_stop_timeout, 60, "We will remove this task when heartbeat can't find this task more than retries_count times.", 0)\ M(Bool, enable_local_disk_cache, 1, "enable global local disk cache", 0) \ M(UInt64, parts_preload_level, 1, "used for global preload(manual alter&table auto), 
0=close preload;1=preload meta;2=preload data;3=preload meta&data, Note: for table auto preload, 0 will disable all table preload, > 0 will use table preload setting", 0) \ - M(UInt64, parts_preload_throttler, 0, "used for max preload rpc concurrent count", 0) \ + M(MaxThreads, cnch_parallel_preloading, 0, "Max threads when worker preload parts", 0) \ + M(UInt64, preload_send_rpc_max_ms, 3000, "Max rpc ms when send preload parts reqeust", 0) \ M(DiskCacheMode, disk_cache_mode, DiskCacheMode::AUTO, "Whether to use local disk cache", 0) \ M(Bool, enable_vw_customized_setting, false, "Allow vw customized overwrite profile settings", 0) \ M(Bool, enable_async_execution, false, "Whether to enable async execution", 0) \ @@ -1341,8 +1301,13 @@ enum PreloadLevelSettings : UInt64 M(Seconds, unique_key_attach_partition_timeout, 3600, "Default timeout (seconds) for attaching partition for unique key", 0) \ M(Bool, enable_unique_table_attach_without_dedup, false, "Enable directly make attached parts visible without dedup for unique table, for example: override mode of offline loading", 0) \ M(Bool, enable_unique_table_detach_ignore_delete_bitmap, false, "Enable ignore delete bitmap info when handling detach commands for unique table, for example: delete bitmap has been broken, we can just ignore it via this parameter.", 0) \ - M(DedupKeyMode, dedup_key_mode, DedupKeyMode::REPLACE, "Handle different deduplication modes, current valid values: REPLACE, THROW, APPEND. THROW mode can only be used in non-staging area scenarios. APPEND mode will not execute dedup process, which is suitable for historical non-duplicated data import scenarios", 0) \ + M(DedupKeyMode, dedup_key_mode, DedupKeyMode::REPLACE, "Handle different deduplication modes, current valid values: REPLACE, THROW, APPEND, IGNORE. THROW mode and IGNORE mode can only be used in non-staging area scenarios. 
APPEND mode will not execute dedup process, which is suitable for historical non-duplicated data import scenarios", 0) \ M(Seconds, unique_sleep_seconds_after_acquire_lock, 0, "Only for test", 0) \ + M(Seconds, unique_acquire_write_lock_timeout, 0, "It has higher priority than table setting. Only when it's zero, use table setting", 0) \ + M(Seconds, max_dedup_execution_time, 21600, "Set default value to 6h", 0) \ + M(UInt64, max_dedup_retry_time, 1, "Dedup task retry num", 0) \ + M(Bool, insert_if_not_exists, false, "Valid for partial update using update set statements, insert will be performed when no row exists if enabled", 0) \ + M(Bool, optimize_unique_table_write, false, "Remove gather stage and support parallel insert for unique table ETL task", 0) \ \ /** Settings for Map */ \ M(Bool, optimize_map_column_serialization, false, "Construct map value columns in advance during serialization", 0) \ @@ -1430,7 +1395,8 @@ enum PreloadLevelSettings : UInt64 M(Bool, enable_prune_source_plan_segment, false, "Whether prune source plan segment", 0) \ M(Bool, enable_prune_empty_resource, false, "Whether prune resource sending", 0) \ M(Bool, enable_prune_compute_plan_segment, false, "Whether prune compute plan segment", 0) \ - M(Bool, enable_optimizer_for_create_select, false, "Whether enable query optimizer for CREATE TABLE SELECT queries", 0) \ + M(Bool, send_cacheable_table_definitions, false, "Whether to send cacheable table definitions to worker, which reduces parsing overhead and is particularly beneficial for high concurrency workload", 0) \ + M(Bool, enable_optimizer_for_create_select, true, "Whether enable query optimizer for CREATE TABLE SELECT queries", 0) \ M(Bool, log_optimizer_run_time, false, "Whether Log optimizer runtime", 0) \ M(UInt64, plan_optimizer_timeout, 600000, "Max running time of a plan rewriter optimizer in ms", 0) \ M(UInt64, plan_optimizer_rule_warning_time, 1000, "Send warning if a optimize rule optimize time exceed timeout", 0) \ @@ -1450,6 
+1416,7 @@ enum PreloadLevelSettings : UInt64 M(UInt64, global_bindings_update_time, 60*60, "Interval to update global binding cache from catalog, in seconds.", 0) \ /** */ \ M(Bool, late_materialize_aggressive_push_down, false, "When table use early materialize strategy, this setting enable aggressively moving predicates to read chain w/o considering other factor like columns size or number of columns in the query", 0) \ + M(Bool, convert_to_right_type_for_in_subquery, true, "For IN subquery, whether convert arguments to the right type", 0) \ /** Optimizer relative settings, Plan build and RBO */ \ M(Bool, enable_auto_prepared_statement, false, "Whether to enable automatic prepared statement", 0) \ M(Bool, enable_nested_loop_join, true, "Whether enable nest loop join for outer join with filter", 0)\ @@ -1503,6 +1470,8 @@ enum PreloadLevelSettings : UInt64 M(Bool, enable_subcolumn_optimization_through_union, true, "Whether enable sub column optimization through set operation.", 0) \ M(Bool, enable_buffer_for_deadlock_cte, true, "Whether to buffer data for deadlock cte", 0) \ M(UInt64, statistics_collect_debug_level, 0, "Debug level for statistics collector", 0) \ + M(Bool, enable_remove_remove_unnecessary_buffer, false, "Whether to only add buffer for cte consumer that may cause deadlock", 0) \ + M(Int64, max_buffer_size_for_deadlock_cte, 8000000000, "Inline CTE if buffer is oversized, set 0 to inline all cte, set -1 to buffer data for all cte even no stats", 0) \ M(Bool, enable_add_exchange, true, "Whether to enable AddExchange rule", 0) \ M(Bool, enable_bitmap_index_splitter, true, "Whether to enable BitMapIndexSplitter", 0) \ M(Bool, enable_column_pruning, true, "Whether to enable ColumnPruning", 0) \ @@ -1562,12 +1531,13 @@ enum PreloadLevelSettings : UInt64 M(Bool, enable_push_partial_agg_through_exchange, true, "Whether to enable PushPartialAggThroughExchange rules", 0) \ M(Bool, enable_push_partial_agg_through_union, true, "Whether to enable 
PushPartialAggThroughUnion rules", 0) \ M(Bool, enable_push_partial_sorting_through_exchange, true, "Whether to enable PushPartialSortingThroughExchange rules", 0) \ + M(Bool, enable_push_partial_sorting_through_union, true, "Whether to enable PushPartialSortingThroughUnion rules", 0) \ M(Bool, enable_push_partial_limit_through_exchange, true, "Whether to enable PushPartialLimitThroughExchange rules", 0) \ M(Bool, enable_push_partial_distinct_through_exchange, true, "Whether to enable PushPartialDistinctThroughExchange rules", 0) \ M(UInt64, max_rows_to_use_topn_filtering, 0, "The maximum N of TopN to use topn filtering optimization. Set 0 to choose this value adaptively.", 0) \ M(String, topn_filtering_algorithm_for_unsorted_stream, "SortAndLimit", "The default topn filtering algorithm for unsorted stream, can be one of: 'SortAndLimit', 'Heap'", 0) \ M(Bool, enable_create_topn_filtering_for_aggregating, false, "Whether to enable CreateTopNFilteringForAggregating rules", 0) \ - M(Bool, enable_push_topn_through_projection, true, "Whether to enable PushTopNThroughProjection rules", 0) \ + M(Bool, enable_push_topn_through_projection, false, "Whether to enable PushTopNThroughProjection rules", 0) \ M(Bool, enable_push_topn_filtering_through_projection, true, "Whether to enable PushTopNFilteringThroughProjection rules", 0) \ M(Bool, enable_push_topn_filtering_through_union, true, "Whether to enable PushTopNFilteringThroughUnion rules", 0) \ M(Bool, enable_optimize_aggregate_memory_efficient, false, "Whether to enable OptimizeMemoryEfficientAggregation rules", 0) \ @@ -1579,8 +1549,9 @@ enum PreloadLevelSettings : UInt64 M(Bool, enable_common_expression_sharing_for_prewhere, true, "Whether to share common expression between steps and PREWHERE", 0) \ M(Bool, enable_unalias_symbol_references, true, "Whether to enable unalias symbol references", 0) \ M(UInt64, common_expression_sharing_threshold, 3, "The minimal cost to share a common expression, the cost is defined by 
(complexity * (occurrence - 1))", 0) \ - M(Bool, extract_bitmap_implicit_filter, false, "Whether to extract implicit filter for bitmap functions, e.g. for bitmapCount('1 | 2 & 3')(a, b), extract 'a in (1, 2, 3)'", 0) \ + M(Bool, extract_bitmap_implicit_filter, true, "Whether to extract implicit filter for bitmap functions, e.g. for bitmapCount('1 | 2 & 3')(a, b), extract 'a in (1, 2, 3)'", 0) \ M(Bool, enable_add_local_exchange, false, "Whether to add local exchange", 0) \ + M(Bool, enable_join_using_to_join_on, false, "Whether rewrite Join Using to Join On to make reordering possible", 0) \ M(Bool, enable_ab_test, false, "Whether to open ab test for settings, If true, the settings for some queries are set in the ab_test_profile profile.", 0) \ M(Float, ab_test_traffic_factor, 0, "Proportion of queries that perform ab test, meaningful between 0 and 1", 0) \ M(String, ab_test_profile, "default", "Profile name for ab test", 0) \ @@ -1602,9 +1573,10 @@ enum PreloadLevelSettings : UInt64 M(Bool, statistics_simplify_histogram, false, "Reduce buckets of histogram with simplifying", 0) \ M(Float, statistics_simplify_histogram_ndv_density_threshold, 0.2, "Histogram simplifying threshold for ndv", 0) \ M(Float, statistics_simplify_histogram_range_density_threshold, 0.2, "Histogram simplifying threshold for range", 0) \ - M(Bool, statistics_expand_to_current, false, "Expand Date/Date32/DateTime/DateTime64 columns stats to current timestamp", 0) \ + M(Bool, statistics_expand_to_current, true, "Expand Date/Date32/DateTime/DateTime64 columns stats to current timestamp", 0) \ M(UInt64, statistics_current_timestamp, 0, "Timestamp used for statistics_expand_to_current, 0 to use now(), for testing purpose", 0) \ M(UInt64, statistics_expand_to_current_threshold_days, 31, "If abs(stats_timestamp - stats_column_max) is within this threshold, we will expand this column", 0) \ + M(Float, statistics_expand_to_current_histogram_ratio, 0.10, "For histogram, only expand last buckets 
containing rows with this ratio", 0) \ M(StatisticsCachePolicy, statistics_cache_policy, StatisticsCachePolicy::Default, "Cache policy for stats command and SQLs: (default|cache|catalog)", 0) \ M(Bool, statistics_query_cnch_parts_for_row_count, true, "Use cnch parts instead of count(*) for row count to speed up test", 0) \ /** Optimizer relative settings, cost model and estimation */ \ @@ -1643,7 +1615,6 @@ enum PreloadLevelSettings : UInt64 M(UInt64, max_replicate_shuffle_size, 50000000, "Max join build size, when enum replicate", 0) \ M(UInt64, parallel_join_threshold, 2000000, "Parallel join right source rows threshold", 0) \ M(Bool, enable_adaptive_scheduler, false, "Whether enable adaptive scheduler", 0) \ - M(Bool, enable_wait_cancel_rpc, false, "Whether wait rpcs of cancel worker to finish", 0) \ M(UInt64, parallel_join_rows_batch_threshold, 4096, "Rows that concurrent hash join wait data reach, then to build hashtable or join block", 0) \ M(Bool, add_parallel_after_join, false, "Add parallel after join", 0) \ M(Bool, enforce_round_robin, false, "Whether add round robin exchange node", 0) \ @@ -1655,6 +1626,7 @@ enum PreloadLevelSettings : UInt64 M(UInt64, max_expand_join_key_size, 3, "Whether enable using equivalences when property match", 0) \ M(UInt64, max_expand_agg_key_size, 3, "Max allowed agg/window keys number when expand powerset when property match", 0) \ M(Bool, enable_sharding_optimize, false, "Whether enable sharding optimization, eg. 
local join", 0) \ + M(Bool, enable_bucket_shuffle, false, "Whether enable bucket shuffle", 0) \ M(Bool, enable_magic_set, true, "Whether enable magic set rewriting for join aggregation", 0) \ M(Float, magic_set_filter_factor, 0.5, "The minimum filter factor of magic set, used for early pruning", 0) \ M(UInt64, magic_set_max_search_tree, 2, "The maximum table scans in magic set, used for early pruning", 0) \ @@ -1693,6 +1665,10 @@ enum PreloadLevelSettings : UInt64 M(Bool, enable_eliminate_complicated_pk_fk_join, false, "Whether to eliminate complicated join by fk optimization", 0) \ M(Bool, enable_eliminate_complicated_pk_fk_join_without_top_join, false, "Whether to allow eliminate complicated join by fk pull through pass the multi-child node even if no top join", 0) \ M(Bool, enable_filtered_pk_selectivity, 1, "Enable the selectivity of filtered pk table", 0) \ + M(Bool, execute_subquery_in_lambda, true, "Whether to execute subquery in lambda", 0) \ + M(Bool, early_execute_scalar_subquery, false, "Whether to early execute scalar subquery", 0) \ + M(Bool, early_execute_in_subquery, false, "Whether to early execute in subquery", 0) \ + M(String, prewhere_skip_functions, "", "A collection of functions which are not choosen as prewhere, use ',' to seperate", 0) \ \ /** remote disk cache*/ \ M(Bool, use_local_cache_for_remote_storage, true, "Use local cache for remote storage like HDFS or S3, it's used for remote table engine only", 0) \ @@ -1725,6 +1701,7 @@ enum PreloadLevelSettings : UInt64 M(Bool, exchange_force_use_buffer, false, "Force exchange use buffer as possible", 0) \ M(Bool, exchange_enable_node_stable_hash, false, "Force exchange use buffer as possible", 0) \ M(Bool, exchange_use_query_memory_tracker, true, "Use query-level memory tracker", 0) \ + M(String, exchange_shuffle_method_name, "cityHash64V2", "Shuffle method name used in exchange", 0) \ M(UInt64, wait_for_post_processing_timeout_ms, 1000, "Timeout for waiting post processing rpc from workers.", 
0) \ M(UInt64, distributed_query_wait_exception_ms, 2000,"Wait final planSegment exception from segmentScheduler.", 0) \ M(UInt64, distributed_max_parallel_size, false, "Max distributed execution parallel size", 0) \ @@ -1847,6 +1824,7 @@ enum PreloadLevelSettings : UInt64 M(Bool, force_manipulate_materialized_mysql_table, false, "For tables of materialized mysql engine, force to manipulate it.", 0) \ M(Bool, throw_exception_when_mysql_connection_failed, false, "For mysql database engine, whether throw exception when mysql connection failed. If it is set to true, clickhouse may shutdown during restarting due to mysql connection failure", 0) \ /** for inverted index*/ \ + M(Bool, enable_inverted_index, true, "Enable inverted index", 0) \ M(UInt64, skip_inverted_index_term_size, 512, "If term size bigger than size, do not filter with inverted index", 0) \ M(Bool, disable_str_to_array_cast, false, "disable String to Array(XXX) CAST", 0) \ /** materialized view async refresh related settings */ \ @@ -1877,6 +1855,7 @@ enum PreloadLevelSettings : UInt64 MAKE_OBSOLETE(M, UInt64, exchange_local_no_repartition_extra_threads, 32) \ MAKE_OBSOLETE(M, UInt64, filtered_ratio_to_use_skip_read, 0) \ MAKE_OBSOLETE(M, Bool, enable_two_stages_prewhere, false) \ + MAKE_OBSOLETE(M, Bool, funnel_old_rule, false) \ /** End of OBSOLETE_SETTINGS */ \ #define FORMAT_FACTORY_SETTINGS(M) \ @@ -1965,6 +1944,7 @@ enum PreloadLevelSettings : UInt64 M(Bool, input_format_parquet_coalesce_read, true, "Merge small IO ranges, See arrow::ReadRangeCache", 0) \ M(Bool, input_format_parquet_use_lazy_io_cache, true, "Lazy caching will trigger io requests when they are requested for the first time. 
See arrow::ReadRangeCache", 0) \ M(Bool, input_format_orc_filter_push_down, true, "When reading Orc files, skip whole row groups based on the WHERE/PREWHERE expressions and min/max statistics in the Parquet metadata.", 0) \ + M(DateTimeOverflowBehavior, date_time_overflow_behavior, "ignore", "Overflow mode for Date, Date32, DateTime, DateTime64 types. Possible values: 'ignore', 'throw', 'saturate'.", 0) \ \ M(Bool, input_format_orc_allow_missing_columns, false, "Allow missing columns while reading ORC input formats", 0) \ M(Bool, input_format_arrow_import_nested, false, "Allow to insert array of structs into Nested table in Arrow input format.", 0) \ @@ -2071,6 +2051,8 @@ enum PreloadLevelSettings : UInt64 M(Bool, enable_cache_reader_buffer_reuse, false, "Decpreated settings, only a place holder", 0) \ M(Bool, enable_auto_query_forwarding, true, "Auto forward query to target server when having multiple servers", 0) \ M(Bool, enable_select_query_forwarding, false, "Auto forward select query to target server when having multiple servers", 0) \ + M(Bool, enable_multiple_table_select_query_forwarding, false, "Auto forward select query with multiple tables to target server when having multiple servers", 0) \ + M(String, explicit_main_table, "", "User specified main table for query forwarding when select multiple tables", 0) \ \ M(Bool, merge_partition_stats, false, "merge all partition stats", 0) \ M(Bool, enable_three_part_identifier, true, "merge all partition stats", 0) \ @@ -2093,8 +2075,14 @@ enum PreloadLevelSettings : UInt64 M(Bool, load_dict_from_cache, true, "Read dict from cache", 0) \ M(Bool, throw_exception_if_bucket_unmatched, false, "Whether to throw exception if bucket is unmatched when send bitengine resource", 0) \ M(Bool, enable_cnch_engine_conversion, false, "Whether to converse MergeTree engine to CnchMergeTree engine", 0) \ - /** End of BitEngine related settings */ \ + M(Bool, enable_short_circuit, false, "Whether to enable topn short path", 0) \ + 
M(Bool, enable_table_scan_build_pipeline_optimization, false, "Whether to enable table scan build pipeline optimization", 0) \ + \ + /** End of gis related settings */ \ \ + M(Bool, filter_mark_ranges_with_ivt_when_exec, false, "Delay mark ranges filter with inverted index at pipeline exec", 0) \ + M(Int64, remote_fs_read_failed_injection, 0, "inject read error for remote fs, 0 means disable, -1 means return error immediately, > 0 means delay read ms", 0) \ + M(Int64, remote_fs_write_failed_injection, 0, "inject write error for remote fs, 0 means disable, -1 means return error immediately, > 0 means delay write ms", 0) \ // End of FORMAT_FACTORY_SETTINGS diff --git a/src/Core/SettingsEnums.cpp b/src/Core/SettingsEnums.cpp index 3eae3d7b9e1..29e85d43a75 100644 --- a/src/Core/SettingsEnums.cpp +++ b/src/Core/SettingsEnums.cpp @@ -223,7 +223,12 @@ IMPLEMENT_SETTING_ENUM(ShortCircuitFunctionEvaluation, ErrorCodes::BAD_ARGUMENTS IMPLEMENT_SETTING_ENUM( DedupKeyMode, ErrorCodes::BAD_ARGUMENTS, - {{"replace", DedupKeyMode::REPLACE}, {"append", DedupKeyMode::APPEND}, {"throw", DedupKeyMode::THROW}}) + {{"replace", DedupKeyMode::REPLACE}, {"append", DedupKeyMode::APPEND}, {"throw", DedupKeyMode::THROW}, {"ignore", DedupKeyMode::IGNORE}}) + +IMPLEMENT_SETTING_ENUM( + DedupImplVersion, + ErrorCodes::BAD_ARGUMENTS, + {{"dedup_in_write_suffix", DedupImplVersion::DEDUP_IN_WRITE_SUFFIX}, {"dedup_in_txn_commit", DedupImplVersion::DEDUP_IN_TXN_COMMIT}}) IMPLEMENT_SETTING_ENUM( RefreshViewTaskStatus, @@ -244,4 +249,9 @@ IMPLEMENT_SETTING_ENUM(SchemaInferenceMode, ErrorCodes::BAD_ARGUMENTS, {{"default", SchemaInferenceMode::DEFAULT}, {"union", SchemaInferenceMode::UNION}}) +IMPLEMENT_SETTING_ENUM(DateTimeOverflowBehavior, ErrorCodes::BAD_ARGUMENTS, + {{"throw", FormatSettings::DateTimeOverflowBehavior::Throw}, + {"ignore", FormatSettings::DateTimeOverflowBehavior::Ignore}, + {"saturate", FormatSettings::DateTimeOverflowBehavior::Saturate}}) + } // namespace DB diff --git 
a/src/Core/SettingsEnums.h b/src/Core/SettingsEnums.h index f8795fa669e..a6013f12411 100644 --- a/src/Core/SettingsEnums.h +++ b/src/Core/SettingsEnums.h @@ -381,10 +381,19 @@ enum class DedupKeyMode REPLACE, THROW, APPEND, + IGNORE, }; DECLARE_SETTING_ENUM(DedupKeyMode) +enum class DedupImplVersion : int8_t +{ + DEDUP_IN_WRITE_SUFFIX = 1, + DEDUP_IN_TXN_COMMIT = 2, +}; + +DECLARE_SETTING_ENUM(DedupImplVersion) + enum class RefreshViewTaskStatus : int8_t { START = 1, @@ -412,4 +421,6 @@ enum class SchemaInferenceMode DECLARE_SETTING_ENUM(SchemaInferenceMode) +DECLARE_SETTING_ENUM_WITH_RENAME(DateTimeOverflowBehavior, FormatSettings::DateTimeOverflowBehavior) + } diff --git a/src/Core/SettingsFields.cpp b/src/Core/SettingsFields.cpp index 0812fe0886c..c9f2bd9365b 100644 --- a/src/Core/SettingsFields.cpp +++ b/src/Core/SettingsFields.cpp @@ -31,6 +31,7 @@ #include #include #include +#include namespace DB @@ -327,6 +328,24 @@ String SettingFieldEnumHelpers::readBinary(ReadBuffer & in) return str; } +void SettingFieldTimezone::writeBinary(WriteBuffer & out) const +{ + writeStringBinary(value, out); +} + +void SettingFieldTimezone::readBinary(ReadBuffer & in) +{ + String str; + readStringBinary(str, in); + *this = std::move(str); +} + +void SettingFieldTimezone::validateTimezone(const std::string & tz_str) +{ + cctz::time_zone validated_tz; + if (!tz_str.empty() && !cctz::load_time_zone(tz_str, &validated_tz)) + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Invalid time zone: {}", tz_str); +} String SettingFieldCustom::toString() const { diff --git a/src/Core/SettingsFields.h b/src/Core/SettingsFields.h index 74cf852f9f5..5602a108f0b 100644 --- a/src/Core/SettingsFields.h +++ b/src/Core/SettingsFields.h @@ -460,6 +460,37 @@ void SettingFieldMultiEnum::readBinary(ReadBuffer & in) return std::initializer_list> __VA_ARGS__ .size();\ } +/// Setting field for specifying user-defined timezone. It is basically a string, but it needs validation. 
+struct SettingFieldTimezone +{ + String value; + bool changed = false; + + explicit SettingFieldTimezone(std::string_view str = {}) { validateTimezone(std::string(str)); value = str; } + explicit SettingFieldTimezone(const String & str) { validateTimezone(str); value = str; } + explicit SettingFieldTimezone(String && str) { validateTimezone(str); value = std::move(str); } + explicit SettingFieldTimezone(const char * str) { validateTimezone(str); value = str; } + explicit SettingFieldTimezone(const Field & f) { const String & str = f.safeGet(); validateTimezone(str); value = str; } + + SettingFieldTimezone & operator =(std::string_view str) { validateTimezone(std::string(str)); value = str; changed = true; return *this; } + SettingFieldTimezone & operator =(const String & str) { *this = std::string_view{str}; return *this; } + SettingFieldTimezone & operator =(String && str) { validateTimezone(str); value = std::move(str); changed = true; return *this; } + SettingFieldTimezone & operator =(const char * str) { *this = std::string_view{str}; return *this; } + SettingFieldTimezone & operator =(const Field & f) { *this = f.safeGet(); return *this; } + + operator const String &() const { return value; } /// NOLINT + explicit operator Field() const { return value; } + + const String & toString() const { return value; } + void parseFromString(const String & str) { *this = str; } + + void writeBinary(WriteBuffer & out) const; + void readBinary(ReadBuffer & in); + +private: + void validateTimezone(const std::string & tz_str); +}; + /// Can keep a value of any type. Used for user-defined settings. 
struct SettingFieldCustom { diff --git a/src/Core/SortDescription.cpp b/src/Core/SortDescription.cpp index ff5da536005..4c9d904567b 100644 --- a/src/Core/SortDescription.cpp +++ b/src/Core/SortDescription.cpp @@ -137,19 +137,4 @@ JSONBuilder::ItemPtr explainSortDescription(const SortDescription & description, return json_array; } -bool SortDescription::hasPrefix(const SortDescription & prefix) const -{ - if (prefix.empty()) - return true; - - if (prefix.size() > size()) - return false; - - for (size_t i = 0; i < prefix.size(); ++i) - { - if ((*this)[i] != prefix[i]) - return false; - } - return true; -} } diff --git a/src/Core/SortDescription.h b/src/Core/SortDescription.h index 956b312fa9a..f9cd4b73305 100644 --- a/src/Core/SortDescription.h +++ b/src/Core/SortDescription.h @@ -93,8 +93,8 @@ struct SortColumnDescription bool operator == (const SortColumnDescription & other) const { - return column_name == other.column_name && column_number == other.column_number - && direction == other.direction && (nulls_direction == other.nulls_direction || nulls_direction == 0 || other.nulls_direction == 0); + return column_name == other.column_name && column_number == other.column_number && direction == other.direction + && nulls_direction == other.nulls_direction; } bool operator != (const SortColumnDescription & other) const @@ -130,7 +130,7 @@ struct SortColumnDescription class SortDescription : public std::vector { public: - bool hasPrefix(const SortDescription & prefix) const; + using vector::vector; }; /// Outputs user-readable description into `out`. 
diff --git a/src/Core/tests/gtest_protobuf.cpp b/src/Core/tests/gtest_protobuf.cpp index 63e2e0ff53e..64cf9f84af4 100644 --- a/src/Core/tests/gtest_protobuf.cpp +++ b/src/Core/tests/gtest_protobuf.cpp @@ -1,4 +1,5 @@ #include +#include #include "Interpreters/DistributedStages/PlanSegment.h" @@ -192,7 +193,7 @@ TEST_F(ProtobufTest, AggregateDescription) { std::default_random_engine eng(42); // construct valid object - auto obj = generateAggregateDescription(eng); + auto obj = generateAggregateDescription(eng, 6); // serialize to protobuf Protos::AggregateDescription pb; obj.toProto(pb); @@ -306,6 +307,23 @@ TEST_F(ProtobufTest, PlanSegmentInput) compareProto(pb, pb2); } +TEST_F(ProtobufTest, PlanSegmentOutput) +{ + std::default_random_engine eng(42); + // construct valid step + auto output = generatePlanSegmentOutput(eng); + // serialize to protobuf + Protos::PlanSegmentOutput pb; + output->toProto(pb); + // deserialize from protobuf + auto output2 = std::make_shared(); + output2->fillFromProto(pb); + // re-serialize to protobuf + Protos::PlanSegmentOutput pb2; + output2->toProto(pb2); + compareProto(pb, pb2); +} + TEST_F(ProtobufTest, InputOrderInfo) { std::default_random_engine eng(42); diff --git a/src/Core/tests/gtest_protobuf_common.h b/src/Core/tests/gtest_protobuf_common.h index 8510d13872b..e75698161da 100644 --- a/src/Core/tests/gtest_protobuf_common.h +++ b/src/Core/tests/gtest_protobuf_common.h @@ -44,11 +44,12 @@ #include #include #include -#include "Core/NamesAndTypes.h" -#include "DataTypes/DataTypeMap.h" -#include "DataTypes/DataTypeNullable.h" -#include "IO/WriteBuffer.h" -#include "Interpreters/Context.h" +#include +#include +#include +#include +#include +#include namespace DB::UnitTest { @@ -201,8 +202,8 @@ class ProtobufTest : public testing::Test { auto is_equal = isPlanStepEqual(*a, *b); ASSERT_TRUE(is_equal); - auto ha = hashPlanStep(*a); - auto hb = hashPlanStep(*b); + auto ha = hashPlanStep(*a, true); + auto hb = hashPlanStep(*b, true); 
ASSERT_EQ(ha, hb); } @@ -261,7 +262,7 @@ class ProtobufTest : public testing::Test static Block generateBlock(std::default_random_engine & eng, bool arr = false) { - std::vector columns = {"a", "b", "c"}; + std::vector columns = {"a", "b", "c", "col_0", "col_1", "col_3"}; size_t rows = 10; size_t stride = 1; size_t start = 0; @@ -334,7 +335,7 @@ class ProtobufTest : public testing::Test auto buckets = eng() % 1000; auto enforce_round_robin = eng() % 2 == 1; auto component = static_cast(eng() % 3); - auto result = Partitioning(handle, columns, require_handle, buckets, enforce_round_robin, component); + auto result = Partitioning(handle, columns, require_handle, buckets, nullptr, enforce_round_robin, component); return result; } @@ -400,7 +401,7 @@ class ProtobufTest : public testing::Test return res; } - static AggregateDescription generateAggregateDescription(std::default_random_engine & eng) + static AggregateDescription generateAggregateDescription(std::default_random_engine & /*eng*/, int i) { AggregateDescription res; AggregateFunctionProperties properties; @@ -412,7 +413,7 @@ class ProtobufTest : public testing::Test // generate Names // for (int i = 0; i < 10; ++i) // res.argument_names.emplace_back(fmt::format("text{}", eng() % 100)); - res.column_name = std::vector{"a", "b"}[eng() % 2]; + res.column_name = "col_" + std::to_string(i); res.mask_column = res.column_name; return res; } @@ -426,7 +427,7 @@ class ProtobufTest : public testing::Test keys.emplace_back(eng() % 3); AggregateDescriptions aggregates; for (int i = 0; i < 2; ++i) - aggregates.emplace_back(generateAggregateDescription(eng)); + aggregates.emplace_back(generateAggregateDescription(eng, i)); auto overflow_row = eng() % 2 == 1; auto max_rows_to_group_by = eng() % 1000; auto group_by_overflow_mode = static_cast(eng() % 3); @@ -526,7 +527,6 @@ class ProtobufTest : public testing::Test { Block header = {ColumnWithTypeAndName(ColumnUInt8::create(), std::make_shared(), "local_exchange_test")}; 
AddressInfo local_address("localhost", 0, "test", "123456"); - PlanSegmentInputs inputs; auto input = std::make_shared(header, PlanSegmentType::EXCHANGE); input->setExchangeParallelSize(2); @@ -536,6 +536,22 @@ class ProtobufTest : public testing::Test return input; } + static std::shared_ptr generatePlanSegmentOutput(std::default_random_engine & eng) + { + Block header = {ColumnWithTypeAndName(ColumnUInt8::create(), std::make_shared(), "local_exchange_test")}; + auto output = std::make_shared(header, PlanSegmentType::EXCHANGE); + output->setExchangeParallelSize(2); + output->setExchangeId(3); + output->setPlanSegmentId(4); + output->setKeepOrder(true); + output->setShuffleFunctionName("bucket"); + Array params; + params.emplace_back(generateField(eng)); + params.emplace_back(generateField(eng)); + output->setShuffleFunctionParams(params); + return output; + } + static WindowFrame generateWindowFrame(std::default_random_engine & eng) { WindowFrame res; diff --git a/src/DaemonManager/DaemonJobServerBGThread.cpp b/src/DaemonManager/DaemonJobServerBGThread.cpp index 54f8efca4e6..9621b900066 100644 --- a/src/DaemonManager/DaemonJobServerBGThread.cpp +++ b/src/DaemonManager/DaemonJobServerBGThread.cpp @@ -94,6 +94,13 @@ std::unordered_map getUUIDsFromCatalog(DaemonJobServerBGThread auto data_models = context.getCnchCatalog()->getAllTables(); for (const auto & data_model : data_models) { + LOG_DEBUG(log, + "data model database: {}, name: {}, status: {}, definition: {}", + data_model.database(), + data_model.name(), + data_model.status(), + data_model.definition() + ); if (Status::isDetached(data_model.status()) || Status::isDeleted(data_model.status())) continue; diff --git a/src/DaemonManager/DaemonJobTxnGC.cpp b/src/DaemonManager/DaemonJobTxnGC.cpp index a841a6c490d..474347e7c8b 100644 --- a/src/DaemonManager/DaemonJobTxnGC.cpp +++ b/src/DaemonManager/DaemonJobTxnGC.cpp @@ -37,10 +37,24 @@ namespace DB::DaemonManager bool DaemonJobTxnGC::executeImpl() { - const Context 
& context = *getContext(); - auto txn_records - = context.getCnchCatalog()->getTransactionRecordsForGC(context.getConfigRef().getInt("cnch_txn_clean_batch_size", 200000)); + String last_start_key = start_key; + size_t wanted_txn_number = context.getConfigRef().getInt("cnch_txn_clean_batch_size", 100000); + /// Normally, older transaction get higher priority to be cleaned, + /// so we will always scan from the start. + /// In some (rare) cases, we want to clean transactions in the middle of the transaction lists, + /// that's where `cnch_txn_clean_round_robin` works. + /// Please note that "round robin" might not work as you expected, + /// as transactions inserted fast at the end of the lists, "another round" + /// might never come. + bool round_robin = context.getConfigRef().getBool("cnch_txn_clean_round_robin", false); + if (!round_robin) + start_key = ""; + auto txn_records = context.getCnchCatalog()->getTransactionRecordsForGC( + start_key, wanted_txn_number); + + LOG_DEBUG( + log, "start_key changed from: {} to {} (wanted {}, get {})", last_start_key, start_key, wanted_txn_number, txn_records.size()); if (!txn_records.empty()) { cleanTxnRecords(txn_records); diff --git a/src/DaemonManager/DaemonJobTxnGC.h b/src/DaemonManager/DaemonJobTxnGC.h index f82f168acbf..b670b91681e 100644 --- a/src/DaemonManager/DaemonJobTxnGC.h +++ b/src/DaemonManager/DaemonJobTxnGC.h @@ -69,6 +69,7 @@ class DaemonJobTxnGC : public DaemonJob using TransactionRecords = std::vector; private: + String start_key; void cleanTxnRecords(const TransactionRecords & records); void cleanUndoBuffers(const TransactionRecords & records); void cleanTxnRecord(const TransactionRecord & record, TxnTimestamp current_time, std::vector & cleanTxnIds, TxnGCLog & summary); diff --git a/src/DaemonManager/DaemonManager.cpp b/src/DaemonManager/DaemonManager.cpp index 818b4c3f8f3..a649cbfe35d 100644 --- a/src/DaemonManager/DaemonManager.cpp +++ b/src/DaemonManager/DaemonManager.cpp @@ -147,7 +147,7 @@ 
std::vector createLocalDaemonJobs( std::map default_config = { { "GLOBAL_GC", 5000}, { "AUTO_STATISTICS", 10000}, - { "TXN_GC", 600000} + { "TXN_GC", 5 * 60 * 1000} }; std::map config = updateConfig(std::move(default_config), app_config); diff --git a/src/DataStreams/ITTLAlgorithm.cpp b/src/DataStreams/ITTLAlgorithm.cpp index 7513e0c6ce0..854bcf53768 100644 --- a/src/DataStreams/ITTLAlgorithm.cpp +++ b/src/DataStreams/ITTLAlgorithm.cpp @@ -10,13 +10,12 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; } -ITTLAlgorithm::ITTLAlgorithm( - const TTLDescription & description_, const TTLInfo & old_ttl_info_, time_t current_time_, bool force_) +ITTLAlgorithm::ITTLAlgorithm(const TTLDescription & description_, const TTLInfo & old_ttl_info_, time_t current_time_, bool force_) : description(description_) , old_ttl_info(old_ttl_info_) , current_time(current_time_) , force(force_) - , date_lut(DateLUT::instance()) + , date_lut(DateLUT::serverTimezoneInstance()) { } diff --git a/src/DataStreams/MongoDBBlockInputStream.cpp b/src/DataStreams/MongoDBBlockInputStream.cpp index d583cb0d5b4..dc5920a1140 100644 --- a/src/DataStreams/MongoDBBlockInputStream.cpp +++ b/src/DataStreams/MongoDBBlockInputStream.cpp @@ -270,7 +270,7 @@ namespace throw Exception{"Type mismatch, expected Timestamp, got type id = " + toString(value.type()) + " for column " + name, ErrorCodes::TYPE_MISMATCH}; - assert_cast(column).getData().push_back(static_cast(DateLUT::instance().toDayNum( + assert_cast(column).getData().push_back(static_cast(DateLUT::serverTimezoneInstance().toDayNum( static_cast &>(value).value().epochTime()))); break; } diff --git a/src/DataStreams/TemporaryFileStream.h b/src/DataStreams/TemporaryFileStreamLegacy.h similarity index 87% rename from src/DataStreams/TemporaryFileStream.h rename to src/DataStreams/TemporaryFileStreamLegacy.h index ce9071801d0..f2baf6705fa 100644 --- a/src/DataStreams/TemporaryFileStream.h +++ b/src/DataStreams/TemporaryFileStreamLegacy.h @@ -13,19 +13,19 
@@ namespace DB { /// To read the data that was flushed into the temporary data file. -struct TemporaryFileStream +struct TemporaryFileStreamLegacy { ReadBufferFromFile file_in; CompressedReadBuffer compressed_in; BlockInputStreamPtr block_in; - explicit TemporaryFileStream(const std::string & path) + explicit TemporaryFileStreamLegacy(const std::string & path) : file_in(path) , compressed_in(file_in) , block_in(std::make_shared(compressed_in, DBMS_TCP_PROTOCOL_VERSION)) {} - TemporaryFileStream(const std::string & path, const Block & header_) + TemporaryFileStreamLegacy(const std::string & path, const Block & header_) : file_in(path) , compressed_in(file_in) , block_in(std::make_shared(compressed_in, header_, 0)) @@ -63,7 +63,7 @@ class TemporaryFileLazyInputStream : public IBlockInputStream return {}; if (!stream) - stream = std::make_unique(path, header); + stream = std::make_unique(path, header); auto block = stream->block_in->read(); if (!block) @@ -78,7 +78,7 @@ class TemporaryFileLazyInputStream : public IBlockInputStream const std::string path; Block header; bool done; - std::unique_ptr stream; + std::unique_ptr stream; }; } diff --git a/src/DataTypes/DataTypeDate32.h b/src/DataTypes/DataTypeDate32.h index cc2e850c970..d97c466bb48 100644 --- a/src/DataTypes/DataTypeDate32.h +++ b/src/DataTypes/DataTypeDate32.h @@ -28,10 +28,7 @@ class DataTypeDate32 final : public DataTypeNumberBase TypeIndex getTypeId() const override { return TypeIndex::Date32; } const char * getFamilyName() const override { return family_name; } - Field getDefault() const override - { - return -static_cast(DateLUT::instance().getDayNumOffsetEpoch()); - } + Field getDefault() const override { return -static_cast(DateLUT::serverTimezoneInstance().getDayNumOffsetEpoch()); } bool canBeUsedAsVersion() const override { return true; } bool canBeInsideNullable() const override { return true; } diff --git a/src/DataTypes/DataTypeDateTime.h b/src/DataTypes/DataTypeDateTime.h index 
926d529a5d8..fa44099d8af 100644 --- a/src/DataTypes/DataTypeDateTime.h +++ b/src/DataTypes/DataTypeDateTime.h @@ -43,7 +43,9 @@ class TimezoneMixin * all types with different time zones are equivalent and may be used interchangingly. * Time zone only affects parsing and displaying in text formats. * - * If time zone is not specified (example: DateTime without parameter), then default time zone is used. + * If time zone is not specified (example: DateTime without parameter), + * then `session_timezone` setting value is used. + * If `session_timezone` is not set (or empty string), server default time zone is used. * Default time zone is server time zone, if server is doing transformations * and if client is doing transformations, unless 'use_client_time_zone' setting is passed to client; * Server time zone is the time zone specified in 'timezone' parameter in configuration file, diff --git a/src/DataTypes/DataTypeString.cpp b/src/DataTypes/DataTypeString.cpp index c064a9b2845..8e80b0ccf02 100644 --- a/src/DataTypes/DataTypeString.cpp +++ b/src/DataTypes/DataTypeString.cpp @@ -113,6 +113,7 @@ void registerDataTypeString(DataTypeFactory & factory) factory.registerAlias("NCHAR", "String", DataTypeFactory::CaseInsensitive); factory.registerAlias("CHARACTER", "String", DataTypeFactory::CaseInsensitive); factory.registerAlias("VARCHAR", "String", DataTypeFactory::CaseInsensitive); + factory.registerAlias("VARBINARY", "String", DataTypeFactory::CaseInsensitive); factory.registerAlias("NVARCHAR", "String", DataTypeFactory::CaseInsensitive); factory.registerAlias("VARCHAR2", "String", DataTypeFactory::CaseInsensitive); /// Oracle factory.registerAlias("TEXT", "String", DataTypeFactory::CaseInsensitive); diff --git a/src/DataTypes/IDataType.h b/src/DataTypes/IDataType.h index 50ce606e0f7..7be66a7f936 100644 --- a/src/DataTypes/IDataType.h +++ b/src/DataTypes/IDataType.h @@ -487,6 +487,7 @@ struct WhichDataType constexpr bool isSimple() const { return isInt() || isUInt() || 
isFloat() || isString(); } constexpr bool isBitmap64() const { return idx == TypeIndex::BitMap64; } constexpr bool isLowCardinality() const { return idx == TypeIndex::LowCardinality; } + constexpr bool isSketchBinary() const { return idx == TypeIndex::SketchBinary; } }; /// IDataType helpers (alternative for IDataType virtual methods with single point of truth) diff --git a/src/DataTypes/Serializations/SerializationDate.cpp b/src/DataTypes/Serializations/SerializationDate.cpp index 942a0449323..87962d10947 100644 --- a/src/DataTypes/Serializations/SerializationDate.cpp +++ b/src/DataTypes/Serializations/SerializationDate.cpp @@ -33,7 +33,7 @@ void SerializationDate::checkDataOverflow(const FormatSettings & settings) void SerializationDate::serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const { - writeDateText(DayNum(assert_cast(column).getData()[row_num]), ostr); + writeDateText(DayNum(assert_cast(column).getData()[row_num]), ostr, time_zone); } void SerializationDate::deserializeWholeText(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const @@ -44,7 +44,7 @@ void SerializationDate::deserializeWholeText(IColumn & column, ReadBuffer & istr void SerializationDate::deserializeTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { DayNum x; - readDateText(x, istr); + readDateText(x, istr, time_zone); checkDataOverflow(settings); assert_cast(column).getData().push_back(x); } @@ -65,7 +65,7 @@ void SerializationDate::deserializeTextQuoted(IColumn & column, ReadBuffer & ist { DayNum x; assertChar('\'', istr); - readDateText(x, istr); + readDateText(x, istr, time_zone); assertChar('\'', istr); checkDataOverflow(settings); assert_cast(column).getData().push_back(x); /// It's important to do this at the end - for exception safety. 
@@ -82,7 +82,7 @@ void SerializationDate::deserializeTextJSON(IColumn & column, ReadBuffer & istr, { DayNum x; assertChar('"', istr); - readDateText(x, istr); + readDateText(x, istr, time_zone); assertChar('"', istr); checkDataOverflow(settings); assert_cast(column).getData().push_back(x); diff --git a/src/DataTypes/Serializations/SerializationDate.h b/src/DataTypes/Serializations/SerializationDate.h index 1cc21dbccbf..e76883f2649 100644 --- a/src/DataTypes/Serializations/SerializationDate.h +++ b/src/DataTypes/Serializations/SerializationDate.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace DB { @@ -8,6 +9,8 @@ namespace DB class SerializationDate final : public SerializationNumber { public: + explicit SerializationDate(const DateLUTImpl & time_zone_ = DateLUT::sessionInstance()): time_zone(time_zone_) {} + void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; void deserializeWholeText(IColumn & column, ReadBuffer & istr, const FormatSettings &) const override; void serializeTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; @@ -19,6 +22,9 @@ class SerializationDate final : public SerializationNumber void serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; void deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const override; +protected: + const DateLUTImpl & time_zone; + private: static void checkDataOverflow(const FormatSettings & settings); }; diff --git a/src/DataTypes/Serializations/SerializationDate32.cpp b/src/DataTypes/Serializations/SerializationDate32.cpp index cb45c09b25d..03977accbc2 100644 --- a/src/DataTypes/Serializations/SerializationDate32.cpp +++ b/src/DataTypes/Serializations/SerializationDate32.cpp @@ -47,7 +47,7 @@ void SerializationDate32::checkDataOverflow(const FormatSettings & settings) void 
SerializationDate32::serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const { - writeDateText(ExtendedDayNum(assert_cast(column).getData()[row_num]), ostr); + writeDateText(ExtendedDayNum(assert_cast(column).getData()[row_num]), ostr, time_zone); } void SerializationDate32::deserializeWholeText(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const @@ -60,7 +60,7 @@ void SerializationDate32::deserializeWholeText(IColumn & column, ReadBuffer & is void SerializationDate32::deserializeTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const { ExtendedDayNum x; - readDateText(x, istr); + readDateText(x, istr, time_zone); checkDataOverflow(settings); assert_cast(column).getData().push_back(x); } @@ -81,7 +81,7 @@ void SerializationDate32::deserializeTextQuoted(IColumn & column, ReadBuffer & i { ExtendedDayNum x; assertChar('\'', istr); - readDateText(x, istr); + readDateText(x, istr, time_zone); assertChar('\'', istr); checkDataOverflow(settings); assert_cast(column).getData().push_back(x); /// It's important to do this at the end - for exception safety. 
@@ -98,7 +98,7 @@ void SerializationDate32::deserializeTextJSON(IColumn & column, ReadBuffer & ist { ExtendedDayNum x; assertChar('"', istr); - readDateText(x, istr); + readDateText(x, istr, time_zone); assertChar('"', istr); checkDataOverflow(settings); assert_cast(column).getData().push_back(x); diff --git a/src/DataTypes/Serializations/SerializationDate32.h b/src/DataTypes/Serializations/SerializationDate32.h index e4fde183c4d..3063d82da3c 100644 --- a/src/DataTypes/Serializations/SerializationDate32.h +++ b/src/DataTypes/Serializations/SerializationDate32.h @@ -16,12 +16,15 @@ #pragma once #include +#include namespace DB { class SerializationDate32 final : public SerializationNumber { public: + explicit SerializationDate32(const DateLUTImpl & time_zone_ = DateLUT::sessionInstance()): time_zone(time_zone_) {} + void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; void deserializeWholeText(IColumn & column, ReadBuffer & istr, const FormatSettings &) const override; void serializeTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; @@ -33,6 +36,9 @@ class SerializationDate32 final : public SerializationNumber void serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; void deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const override; +protected: + const DateLUTImpl & time_zone; + private: static void checkDataOverflow(const FormatSettings & settings); }; diff --git a/src/Databases/DatabasesCommon.cpp b/src/Databases/DatabasesCommon.cpp index 91241a92e67..20c8680d088 100644 --- a/src/Databases/DatabasesCommon.cpp +++ b/src/Databases/DatabasesCommon.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -188,7 +189,8 @@ StoragePtr DatabaseWithOwnTablesBase::detachTableUnlocked(const String & table_n auto 
table_id = res->getStorageID(); if (table_id.hasUUID()) { - assert(database_name == DatabaseCatalog::TEMPORARY_DATABASE || getUUID() != UUIDHelpers::Nil || ((res->getName() == "CloudMergeTree") && (getEngineName() == "Memory"))); + [[maybe_unused]] bool is_cloud = dynamic_cast(res.get()) != nullptr; + assert(database_name == DatabaseCatalog::TEMPORARY_DATABASE || getUUID() != UUIDHelpers::Nil || (is_cloud && (getEngineName() == "Memory"))); DatabaseCatalog::instance().removeUUIDMapping(table_id.uuid); } diff --git a/src/Dictionaries/ClickHouseDictionarySource.cpp b/src/Dictionaries/ClickHouseDictionarySource.cpp index a1057589e5e..aa37212bfaf 100644 --- a/src/Dictionaries/ClickHouseDictionarySource.cpp +++ b/src/Dictionaries/ClickHouseDictionarySource.cpp @@ -157,7 +157,7 @@ std::string ClickHouseDictionarySource::getUpdateFieldAndDate() if (update_time != std::chrono::system_clock::from_time_t(0)) { time_t hr_time = std::chrono::system_clock::to_time_t(update_time) - configuration.update_lag; - std::string str_time = DateLUT::instance().timeToString(hr_time); + std::string str_time = DateLUT::serverTimezoneInstance().timeToString(hr_time); update_time = std::chrono::system_clock::now(); return query_builder.composeUpdateQuery(configuration.update_field, str_time); } diff --git a/src/Dictionaries/MySQLDictionarySource.cpp b/src/Dictionaries/MySQLDictionarySource.cpp index c6f6a0684fd..1d706c176d2 100644 --- a/src/Dictionaries/MySQLDictionarySource.cpp +++ b/src/Dictionaries/MySQLDictionarySource.cpp @@ -120,7 +120,7 @@ std::string MySQLDictionarySource::getUpdateFieldAndDate() if (update_time != std::chrono::system_clock::from_time_t(0)) { time_t hr_time = std::chrono::system_clock::to_time_t(update_time) - configuration.update_lag; - std::string str_time = DateLUT::instance().timeToString(hr_time); + std::string str_time = DateLUT::serverTimezoneInstance().timeToString(hr_time); update_time = std::chrono::system_clock::now(); return 
query_builder.composeUpdateQuery(configuration.update_field, str_time); } diff --git a/src/Dictionaries/PostgreSQLDictionarySource.cpp b/src/Dictionaries/PostgreSQLDictionarySource.cpp index f1ca3d4855b..ed80f67df86 100644 --- a/src/Dictionaries/PostgreSQLDictionarySource.cpp +++ b/src/Dictionaries/PostgreSQLDictionarySource.cpp @@ -145,7 +145,7 @@ std::string PostgreSQLDictionarySource::getUpdateFieldAndDate() if (update_time != std::chrono::system_clock::from_time_t(0)) { time_t hr_time = std::chrono::system_clock::to_time_t(update_time) - configuration.update_lag; - std::string str_time = DateLUT::instance().timeToString(hr_time); + std::string str_time = DateLUT::serverTimezoneInstance().timeToString(hr_time); update_time = std::chrono::system_clock::now(); return query_builder.composeUpdateQuery(configuration.update_field, str_time); } diff --git a/src/Dictionaries/XDBCDictionarySource.cpp b/src/Dictionaries/XDBCDictionarySource.cpp index 0a2c375f961..83ce0f7e5bc 100644 --- a/src/Dictionaries/XDBCDictionarySource.cpp +++ b/src/Dictionaries/XDBCDictionarySource.cpp @@ -143,7 +143,7 @@ std::string XDBCDictionarySource::getUpdateFieldAndDate() if (update_time != std::chrono::system_clock::from_time_t(0)) { time_t hr_time = std::chrono::system_clock::to_time_t(update_time) - configuration.update_lag; - std::string str_time = DateLUT::instance().timeToString(hr_time); + std::string str_time = DateLUT::serverTimezoneInstance().timeToString(hr_time); update_time = std::chrono::system_clock::now(); return query_builder.composeUpdateQuery(configuration.update_field, str_time); } diff --git a/src/Disks/DiskByteS3.cpp b/src/Disks/DiskByteS3.cpp index 72456b45467..83717419034 100644 --- a/src/Disks/DiskByteS3.cpp +++ b/src/Disks/DiskByteS3.cpp @@ -143,7 +143,8 @@ DiskByteS3::DiskByteS3(const String& name_, const String& root_prefix_, const St disk_id(next_disk_id.fetch_add(1)), name(name_), root_prefix(root_prefix_), s3_util(client_, bucket_, true), 
reader_opts(std::make_shared(client_, bucket_)), reserved_bytes(0), reservation_count(0), - min_upload_part_size(min_upload_part_size_), max_single_part_upload_size(max_single_part_upload_size_) + min_upload_part_size(min_upload_part_size_), max_single_part_upload_size(max_single_part_upload_size_), + log(&Poco::Logger::get(name)) { } @@ -196,6 +197,17 @@ void DiskByteS3::listFiles(const String& path, std::vector& file_names) std::unique_ptr DiskByteS3::readFile(const String & path, const ReadSettings & settings) const { + if (unlikely(settings.remote_fs_read_failed_injection != 0)) + { + if (settings.remote_fs_read_failed_injection == -1) + throw Exception("remote_fs_read_failed_injection is enabled and return error immediately", ErrorCodes::LOGICAL_ERROR); + else + { + LOG_TRACE(log, "remote_fs_read_failed_injection is enabled and will sleep {}ms", settings.remote_fs_read_failed_injection); + std::this_thread::sleep_for(std::chrono::milliseconds(settings.remote_fs_read_failed_injection)); + } + } + String object_key = std::filesystem::path(root_prefix) / path; if (IO::Scheduler::IOSchedulerSet::instance().enabled() && settings.enable_io_scheduler) { if (settings.enable_io_pfra) { @@ -241,9 +253,31 @@ std::unique_ptr DiskByteS3::readFile(const String & path std::unique_ptr DiskByteS3::writeFile(const String & path, const WriteSettings & settings) { - return std::make_unique(s3_util.getClient(), s3_util.getBucket(), - std::filesystem::path(root_prefix) / path, max_single_part_upload_size, - min_upload_part_size, settings.file_meta, settings.buffer_size, false, nullptr, 0, true); + if (unlikely(settings.remote_fs_write_failed_injection != 0)) + { + if (settings.remote_fs_write_failed_injection == -1) + throw Exception("remote_fs_write_failed_injection is enabled and return error immediately", ErrorCodes::LOGICAL_ERROR); + else + { + LOG_TRACE(log, "remote_fs_write_failed_injection is enabled and will sleep {}ms", settings.remote_fs_write_failed_injection); + 
std::this_thread::sleep_for(std::chrono::milliseconds(settings.remote_fs_write_failed_injection)); + } + } + + { + return std::make_unique( + s3_util.getClient(), + s3_util.getBucket(), + std::filesystem::path(root_prefix) / path, + max_single_part_upload_size, + min_upload_part_size, + settings.file_meta, + settings.buffer_size, + false, + nullptr, + 0, + true); + } } void DiskByteS3::removeFile(const String& path) diff --git a/src/Disks/DiskByteS3.h b/src/Disks/DiskByteS3.h index 6b2d3de783b..350be8c0cd0 100644 --- a/src/Disks/DiskByteS3.h +++ b/src/Disks/DiskByteS3.h @@ -109,6 +109,9 @@ class DiskByteS3: public IDisk virtual void removeRecursive(const String & path) override; + /// For S3, only need to remove the data file + virtual void removePart(const String & path) override { removeFileIfExists(fs::path(path) / "data"); } + virtual void setLastModified(const String & , const Poco::Timestamp & ) override { throw Exception("setLastModified is not implemented in DiskByteS3", ErrorCodes::NOT_IMPLEMENTED); } virtual Poco::Timestamp getLastModified(const String & ) override { throw Exception("getLastModified is not implemented in DiskByteS3", ErrorCodes::NOT_IMPLEMENTED); } @@ -147,6 +150,7 @@ class DiskByteS3: public IDisk UInt64 min_upload_part_size; UInt64 max_single_part_upload_size; + Poco::Logger * log; }; using DiskByteS3Ptr = std::shared_ptr; diff --git a/src/Disks/DiskLocal.cpp b/src/Disks/DiskLocal.cpp index ba263d7dd01..33300ec13b1 100644 --- a/src/Disks/DiskLocal.cpp +++ b/src/Disks/DiskLocal.cpp @@ -136,16 +136,33 @@ bool DiskLocal::tryReserve(UInt64 bytes) } auto available_space = getAvailableSpace(); - auto unreserved_space = available_space - DiskStats{std::min(available_space.bytes, reserved_bytes), std::min(available_space.inodes, reserved_inodes)}; - if (!unreserved_space.isEmpty()) + auto unreserved_space + = available_space - DiskStats{std::min(available_space.bytes, reserved_bytes), std::min(available_space.inodes, reserved_inodes)}; + if 
(unreserved_space.bytes >= bytes) { - LOG_DEBUG(log, "Reserving {} on disk {}, having unreserved {}({}).", - ReadableSize(bytes), backQuote(name), ReadableSize(unreserved_space.bytes), unreserved_space.inodes); + LOG_TRACE( + log, + "Reserving {} on disk {}(free {}({})), having unreserved {}({}).", + ReadableSize(bytes), + backQuote(name), + ReadableSize(available_space.bytes), + available_space.inodes, + ReadableSize(unreserved_space.bytes), + unreserved_space.inodes); ++reservation_count; reserved_bytes += bytes; - reserved_inodes += 1; return true; } + + LOG_WARNING( + log, + "Can't reserving {} on disk {}(free {}({})), having unreserved {}({}).", + ReadableSize(bytes), + backQuote(name), + ReadableSize(available_space.bytes), + available_space.inodes, + ReadableSize(unreserved_space.bytes), + unreserved_space.inodes); return false; } @@ -437,7 +454,7 @@ void registerDiskLocal(DiskFactory & factory) config.getUInt64(config_prefix + ".keep_free_space_inodes", 0), config.getUInt64("global_keep_free_space_inodes", 0)); double ratio = std::max( config.getDouble(config_prefix + ".keep_free_space_ratio", 0), - config.getDouble(config_prefix + "global_keep_free_space_ratio", 0.05)); + config.getDouble("global_keep_free_space_ratio", 0.05)); if (ratio < 0 || ratio > 1) throw Exception("'keep_free_space_ratio' have to be between 0 and 1", ErrorCodes::EXCESSIVE_ELEMENT_IN_CONFIG); diff --git a/src/Disks/DiskLocal.h b/src/Disks/DiskLocal.h index fdc6fc9cbba..2fe954bbebe 100644 --- a/src/Disks/DiskLocal.h +++ b/src/Disks/DiskLocal.h @@ -130,7 +130,7 @@ class DiskLocal : public IDisk const DiskStats keep_free_disk_stats; UInt64 reserved_bytes = 0; - UInt64 reserved_inodes = 0; + UInt64 reserved_inodes = 0; // TODO: placeholder and not implemented yet UInt64 reservation_count = 0; static std::mutex reservation_mutex; diff --git a/src/Disks/HDFS/DiskByteHDFS.cpp b/src/Disks/HDFS/DiskByteHDFS.cpp index 3cbf76dcd24..c5feefd14e8 100644 --- a/src/Disks/HDFS/DiskByteHDFS.cpp +++ 
b/src/Disks/HDFS/DiskByteHDFS.cpp @@ -88,6 +88,7 @@ DiskByteHDFS::DiskByteHDFS(const String & disk_name_, const String & hdfs_base_p { pread_reader_opts = std::make_shared(hdfs_params, true); read_reader_opts = std::make_shared(hdfs_params, false); + log = &Poco::Logger::get("DiskByteHDFS"); } const String & DiskByteHDFS::getName() const @@ -186,6 +187,18 @@ void DiskByteHDFS::listFiles(const String & path, std::vector & file_nam std::unique_ptr DiskByteHDFS::readFile(const String & path, const ReadSettings & settings) const { + if (unlikely(settings.remote_fs_read_failed_injection != 0)) + { + if (settings.remote_fs_read_failed_injection == -1) + throw Exception("remote_fs_read_failed_injection is enabled and return error immediately", ErrorCodes::LOGICAL_ERROR); + else + { + LOG_TRACE(log, "remote_fs_read_failed_injection is enabled and will sleep {}ms", settings.remote_fs_read_failed_injection); + std::this_thread::sleep_for(std::chrono::milliseconds(settings.remote_fs_read_failed_injection)); + } + } + + String file_path = absolutePath(path); if (IO::Scheduler::IOSchedulerSet::instance().enabled() && settings.enable_io_scheduler) { @@ -229,8 +242,22 @@ std::unique_ptr DiskByteHDFS::readFile(const String & pa std::unique_ptr DiskByteHDFS::writeFile(const String & path, const WriteSettings & settings) { - int write_mode = settings.mode == WriteMode::Append ? 
(O_APPEND | O_WRONLY) : O_WRONLY; - return std::make_unique(absolutePath(path), hdfs_params, settings.buffer_size, write_mode); + if (unlikely(settings.remote_fs_write_failed_injection != 0)) + { + if (settings.remote_fs_write_failed_injection == -1) + throw Exception("remote_fs_write_failed_injection is enabled and return error immediately", ErrorCodes::LOGICAL_ERROR); + else + { + LOG_TRACE(log, "remote_fs_write_failed_injection is enabled and will sleep {}ms", settings.remote_fs_write_failed_injection); + std::this_thread::sleep_for(std::chrono::milliseconds(settings.remote_fs_write_failed_injection)); + } + } + + { + int write_mode = settings.mode == WriteMode::Append ? (O_APPEND | O_WRONLY) : O_WRONLY; + return std::make_unique(absolutePath(path), hdfs_params, + settings.buffer_size, write_mode); + } } void DiskByteHDFS::removeFile(const String & path) @@ -257,6 +284,24 @@ void DiskByteHDFS::removeRecursive(const String & path) hdfs_fs.remove(absolutePath(path), true); } +void DiskByteHDFS::removePart(const String & path) +{ + try + { + removeRecursive(path); + } + catch (Poco::FileException &e) + { + /// We don't know if this exception is caused by a non-existent path, + /// so we need to determine it manually + if (!exists(path)) { + /// the part has already been deleted, exit + return; + } + throw e; + } +} + void DiskByteHDFS::setLastModified(const String & path, const Poco::Timestamp & timestamp) { hdfs_fs.setLastModifiedInSeconds(absolutePath(path), timestamp.epochTime()); diff --git a/src/Disks/HDFS/DiskByteHDFS.h b/src/Disks/HDFS/DiskByteHDFS.h index d441536d693..dac24f353f0 100644 --- a/src/Disks/HDFS/DiskByteHDFS.h +++ b/src/Disks/HDFS/DiskByteHDFS.h @@ -89,6 +89,8 @@ class DiskByteHDFS final : public IDisk virtual void removeRecursive(const String & path) override; + virtual void removePart(const String & path) override; + virtual void setLastModified(const String & path, const Poco::Timestamp & timestamp) override; virtual Poco::Timestamp 
getLastModified(const String & path) override; @@ -111,6 +113,8 @@ class DiskByteHDFS final : public IDisk std::shared_ptr read_reader_opts; HDFSFileSystem hdfs_fs; + + Poco::Logger * log; }; class DiskByteHDFSReservation: public IReservation diff --git a/src/Disks/IDisk.cpp b/src/Disks/IDisk.cpp index 9c2ad1f25d3..e819e880583 100644 --- a/src/Disks/IDisk.cpp +++ b/src/Disks/IDisk.cpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace DB { diff --git a/src/Disks/IDisk.h b/src/Disks/IDisk.h index 65476de045a..bd1ec8a8c05 100644 --- a/src/Disks/IDisk.h +++ b/src/Disks/IDisk.h @@ -268,6 +268,9 @@ class IDisk : public Space /// Second bool param is a flag to remove (true) or keep (false) shared data on S3 virtual void removeSharedFileIfExists(const String & path, bool) { removeFileIfExists(path); } + /// Remove CNCH MergeTreeDataPart, only used in ByteS3/ByteHDFS disks + virtual void removePart(const String &) { throw Exception("removePart is not implemented", ErrorCodes::NOT_IMPLEMENTED); } + /// Set last modified time to file or directory at `path`. 
virtual void setLastModified(const String & path, const Poco::Timestamp & timestamp) = 0; diff --git a/src/FormaterTool/PartMergerImpl.cpp b/src/FormaterTool/PartMergerImpl.cpp index b30b80b0fd2..bf8cabd2fef 100644 --- a/src/FormaterTool/PartMergerImpl.cpp +++ b/src/FormaterTool/PartMergerImpl.cpp @@ -47,14 +47,21 @@ void PartMergerImpl::copyPartData(const DiskPtr & from_disk, const String & from std::shared_ptr PartMergerImpl::createStorage(const String & path, const String & create_table_query) { auto context = getContext(); - auto storage = createStorageFromQuery(create_table_query, context); + auto ast = getASTCreateQueryFromString(create_table_query, context); + ASTCreateQuery & create_query = *ast; + /// CloudMergeTree checks for non-empty UUID in its constructor, + /// let's fake it (not used in part-merger anyway) + UUID fake_cnch_uuid = UUIDHelpers::generateV4(); + modifyOrAddSetting(create_query, "cnch_table_uuid", Field(UUIDHelpers::UUIDToString(fake_cnch_uuid))); + auto storage = createStorageFromQuery(create_query, context); auto merge_tree = std::dynamic_pointer_cast(storage); - merge_tree->setRelativeDataPath(IStorage::StorageLocation::MAIN, path); if (!merge_tree) { /// Must use part-merger with `ENGINE = CloudMergeTree`. 
throw Exception("Please choose `CloudMergeTree` as the engine.", ErrorCodes::INVALID_CONFIG_PARAMETER); } + /// IMPORTANT: reset table relative path to the requested value + merge_tree->setRelativeDataPath(IStorage::StorageLocation::MAIN, path); return merge_tree; } diff --git a/src/FormaterTool/PartToolkitBase.cpp b/src/FormaterTool/PartToolkitBase.cpp index 67f20f5c037..f89f93d5d1a 100644 --- a/src/FormaterTool/PartToolkitBase.cpp +++ b/src/FormaterTool/PartToolkitBase.cpp @@ -22,6 +22,7 @@ #include #include #include +#include namespace DB { @@ -167,6 +168,9 @@ StoragePtr PartToolkitBase::getTable() ForeignKeysDescription foreign_keys = InterpreterCreateQuery::getForeignKeysDescription(create.columns_list->foreign_keys); UniqueNotEnforcedDescription unique = InterpreterCreateQuery::getUniqueNotEnforcedDescription(create.columns_list->unique); + /// In PartTools, BitEngineEncode is illegal, discard + processIgnoreBitEngineEncode(columns); + StoragePtr res = StorageFactory::instance().get( create, PT_RELATIVE_LOCAL_PATH, @@ -183,6 +187,21 @@ StoragePtr PartToolkitBase::getTable() } } +void PartToolkitBase::processIgnoreBitEngineEncode(ColumnsDescription & columns) +{ + auto reset_bitengine_encode = [](auto & column) + { + if (column.type->isBitEngineEncode()) + { + auto bitmap_type = std::make_shared(); + bitmap_type->setFlags(column.type->getFlags()); + bitmap_type->resetFlags(TYPE_BITENGINE_ENCODE_FLAG); + const_cast(column).type = std::move(bitmap_type); + } + }; + std::for_each(columns.begin(), columns.end(), reset_bitengine_encode); +} + PartNamesWithDisks PartToolkitBase::collectPartsFromSource(const String & source_dirs_str, const String & dest_dir) { diff --git a/src/FormaterTool/PartToolkitBase.h b/src/FormaterTool/PartToolkitBase.h index 4d4b2b7648e..5a216982a67 100644 --- a/src/FormaterTool/PartToolkitBase.h +++ b/src/FormaterTool/PartToolkitBase.h @@ -56,6 +56,8 @@ class PartToolkitBase : public WithMutableContext StoragePtr getTable(); + void 
processIgnoreBitEngineEncode(ColumnsDescription & columns); + PartNamesWithDisks collectPartsFromSource(const String & source_dirs_str, const String & dest_dir); const ASTPtr & query_ptr; diff --git a/src/Formats/FormatFactory.cpp b/src/Formats/FormatFactory.cpp index f604cf15835..634fb158c08 100644 --- a/src/Formats/FormatFactory.cpp +++ b/src/Formats/FormatFactory.cpp @@ -167,6 +167,7 @@ FormatSettings getFormatSettings(ContextPtr context, const Settings & settings) format_settings.map.skip_null_map_value = settings.input_format_skip_null_map_value; format_settings.map.max_map_key_length = settings.input_format_max_map_key_long; format_settings.check_data_overflow = settings.check_data_overflow; + format_settings.date_time_overflow_behavior = settings.date_time_overflow_behavior; /// Validate avro_schema_registry_url with RemoteHostFilter when non-empty and in Server context if (format_settings.schema.is_server) diff --git a/src/Formats/FormatSettings.h b/src/Formats/FormatSettings.h index 389a4cdaa48..b2cde5bbf67 100644 --- a/src/Formats/FormatSettings.h +++ b/src/Formats/FormatSettings.h @@ -82,6 +82,16 @@ struct FormatSettings DateTimeOutputFormat date_time_output_format = DateTimeOutputFormat::Simple; + enum class DateTimeOverflowBehavior + { + Ignore, + Throw, + Saturate + }; + + DateTimeOverflowBehavior date_time_overflow_behavior = DateTimeOverflowBehavior::Ignore; + + UInt64 input_allow_errors_num = 0; Float32 input_allow_errors_ratio = 0; diff --git a/src/Functions/FunctionAddTime.cpp b/src/Functions/FunctionAddTime.cpp index 9b1cb06c8f7..c189fc7d9f8 100644 --- a/src/Functions/FunctionAddTime.cpp +++ b/src/Functions/FunctionAddTime.cpp @@ -238,12 +238,12 @@ class FunctionAddOrSubTime : public IFunction switch (base_type->getTypeId()) { case TypeIndex::Date: { - const auto & time_zone = DateLUT::instance(); + const auto & time_zone = DateLUT::sessionInstance(); executeInternal(base_col, delta_arg, result_col.get(), time_zone, 0); break; } case 
TypeIndex::Date32: { - const auto & time_zone = DateLUT::instance(); + const auto & time_zone = DateLUT::sessionInstance(); executeInternal(base_col, delta_arg, result_col.get(), time_zone, 0); break; } @@ -260,7 +260,7 @@ class FunctionAddOrSubTime : public IFunction } case TypeIndex::Time: { const auto & t = assert_cast(*arguments[0].type); - const auto & time_zone = DateLUT::instance(); + const auto & time_zone = DateLUT::sessionInstance(); executeInternal(base_col, delta_arg, result_col.get(), time_zone, t.getScale()); break; } diff --git a/src/Functions/FunctionBucket.cpp b/src/Functions/FunctionBucket.cpp new file mode 100644 index 00000000000..eab51171627 --- /dev/null +++ b/src/Functions/FunctionBucket.cpp @@ -0,0 +1,11 @@ +#include +#include +#include + +namespace DB +{ +REGISTER_FUNCTION(Bucket) +{ + factory.registerFunction(); +} +} diff --git a/src/Functions/FunctionBucket.h b/src/Functions/FunctionBucket.h new file mode 100644 index 00000000000..c3cc969826a --- /dev/null +++ b/src/Functions/FunctionBucket.h @@ -0,0 +1,249 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} + +template +class FunctionBucket : public IExecutableFunction +{ +public: + static constexpr auto name = "bucket"; + + explicit FunctionBucket( + ExecutableFunctionPtr hash_function_, + UInt64 bucket_size_, + UInt64 is_with_range_, + UInt64 split_number_) + : hash_function(std::move(hash_function_)) + , bucket_size(bucket_size_) + , is_with_range(is_with_range_) + , split_number(split_number_) + , split_number_argument(ColumnWithTypeAndName{}) + { + } + + explicit FunctionBucket( + ExecutableFunctionPtr hash_function_, + UInt64 bucket_size_, + UInt64 is_with_range_, + UInt64 split_number_, + ColumnWithTypeAndName 
split_number_argument_) + : hash_function(std::move(hash_function_)) + , bucket_size(bucket_size_) + , is_with_range(is_with_range_) + , split_number(split_number_) + , split_number_argument(std::move(split_number_argument_)) + { + } + + std::string getName() const override { return name; } + + bool useDefaultImplementationForConstants() const override { return true; } + bool useDefaultImplementationForNulls() const override { return false; } + bool useDefaultImplementationForNothing() const override { return false; } + bool useDefaultImplementationForLowCardinalityColumns() const override { return false; } + + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override + { + auto result = ColumnUInt64::create(input_rows_count, 0); + auto & result_data = result->getData(); + ColumnPtr hash_result; + if constexpr (ModSplitNumberInside) + { + ColumnsWithTypeAndName full_args = arguments; + full_args.emplace_back(split_number_argument); + hash_result = hash_function->execute(full_args, result_type, input_rows_count, false); + } + else + { + hash_result = hash_function->execute(arguments, result_type, input_rows_count, false); + } + + const auto * hash_result_ptr = typeid_cast *>(hash_result.get()); + if (!hash_result_ptr) + { + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Function {} return unexpected type id: {}, it should return ColumnUInt64", + hash_function->getName(), + hash_result->getDataType()); + } + auto & hash_data = (const_cast *>(hash_result_ptr))->getData(); + if constexpr (!ModSplitNumberInside) + { + if (split_number > 0) + { + for (size_t i = 0; i < input_rows_count; i++) + { + hash_data[i] = hash_data[i] % split_number; + } + } + } + + if (!is_with_range) + { + for (size_t i = 0; i < input_rows_count; i++) + { + result_data[i] = hash_data[i] % bucket_size; + } + } + else + { + auto shard_ratio = split_number / bucket_size; + shard_ratio = shard_ratio == 0 ? 
1 : shard_ratio; + for (size_t i = 0; i < input_rows_count; i++) + { + // implicit floor for shard ratio. + // split_number has no constraint to match user requirement, so a shard_ratio(0), when split_number < bucket_size , is ok for customer. + UInt64 bucket_number = hash_data[i] / shard_ratio; + bucket_number = bucket_number >= bucket_size ? bucket_size - 1 : bucket_number; + result_data[i] = bucket_number; + } + } + + return result; + } + +private: + ExecutableFunctionPtr hash_function; + UInt64 bucket_size; + bool is_with_range; + UInt64 split_number; + ColumnWithTypeAndName split_number_argument; +}; + +class BucketFunctionBase : public IFunctionBase +{ +public: + static constexpr auto name = "bucket"; + BucketFunctionBase(DataTypes argument_types_, ContextPtr context_) + : argument_types(std::move(argument_types_)), context(std::move(context_)) + { + } + + String getName() const override { return name; } + + const DataTypes & getArgumentTypes() const override { return argument_types; } + + virtual const DataTypePtr & getResultType() const override { return BucketFunctionBase::RESULT_DATA_TYPE; } + + virtual ExecutableFunctionPtr prepareWithParameters(const ColumnsWithTypeAndName & arguments, const Array & parameters) const override + { + if (parameters.size() != 4) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Function {} requires 4 parameters", getName()); + } + const String & hash_func_name = parameters[0].safeGet(); + auto bucket_size = parameters[1].safeGet(); + if (bucket_size == 0) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Function {} requires positive bucket_size", getName()); + } + auto is_with_range = parameters[2].safeGet(); + auto split_number = parameters[3].safeGet(); + + FunctionOverloadResolverPtr hash_func_builder = FunctionFactory::instance().get(hash_func_name, context); + + if (hash_func_name == "dtspartition") + { + auto split_number_column + = ColumnWithTypeAndName{ColumnInt64::create(1, split_number), 
BucketFunctionBase::SPLIT_NUMBER_TYPE, ""}; + auto full_args = arguments; + full_args.emplace_back(split_number_column); + FunctionBasePtr hash_func_base = hash_func_builder->build(full_args); + auto executable_hash_func = hash_func_base->prepare(full_args); + return std::make_unique>( + executable_hash_func, bucket_size, is_with_range, split_number, split_number_column); + } + + FunctionBasePtr hash_func_base = hash_func_builder->build(arguments); + auto executable_hash_func = hash_func_base->prepare(arguments); + return std::make_unique>(executable_hash_func, bucket_size, is_with_range, split_number); + } + + bool isDeterministic() const override { return true; } + bool isDeterministicInScopeOfQuery() const override { return true; } + + bool isSuitableForConstantFolding() const override { return false; } + + static const DataTypePtr RESULT_DATA_TYPE; + static const DataTypePtr SPLIT_NUMBER_TYPE; + +private: + DataTypes argument_types; + ContextPtr context; +}; + + +const DataTypePtr BucketFunctionBase::RESULT_DATA_TYPE = std::make_shared(); +const DataTypePtr BucketFunctionBase::SPLIT_NUMBER_TYPE = std::make_shared(); + +class FunctionBucketOverloadResolver : public IFunctionOverloadResolver +{ +public: + static constexpr auto name = "bucket"; + + explicit FunctionBucketOverloadResolver(ContextPtr context_) : context(std::move(context_)) { } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 0; } + + bool isVariadic() const override { return true; } + + static FunctionOverloadResolverPtr create(ContextPtr context_) + { + return std::make_unique(std::move(context_)); + } + + DataTypePtr getReturnTypeImpl(const DataTypes &) const override { return BucketFunctionBase::RESULT_DATA_TYPE; } + + FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const override + { + if (arguments.size() != 1 && arguments.size() != 2) + throw Exception( + "Number of arguments for function " 
+ getName() + " should be 1 or 2.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + DataTypes arguments_types; + + for (const auto & arg : arguments) + { + arguments_types.push_back(arg.type); + } + return std::make_unique(arguments_types, std::move(context)); + } + bool useDefaultImplementationForNulls() const override { return false; } + bool useDefaultImplementationForNothing() const override { return false; } + bool useDefaultImplementationForLowCardinalityColumns() const override { return false; } + +private: + ContextPtr context; +}; + +} diff --git a/src/Functions/FunctionCustomWeekToSomething.h b/src/Functions/FunctionCustomWeekToSomething.h index a62b7cce92e..e73cf771915 100644 --- a/src/Functions/FunctionCustomWeekToSomething.h +++ b/src/Functions/FunctionCustomWeekToSomething.h @@ -211,7 +211,7 @@ class FunctionCustomWeekToSomething : public IFunction } /// This method is called only if the function has one argument. Therefore, we do not care about the non-local time zone. 
- const DateLUTImpl & date_lut = DateLUT::instance(); + const DateLUTImpl & date_lut = DateLUT::sessionInstance(); if (left.isNull() || right.isNull()) return is_not_monotonic; diff --git a/src/Functions/FunctionDateOrDateTimeToSomething.h b/src/Functions/FunctionDateOrDateTimeToSomething.h index 9bff4839037..981a836775b 100644 --- a/src/Functions/FunctionDateOrDateTimeToSomething.h +++ b/src/Functions/FunctionDateOrDateTimeToSomething.h @@ -180,7 +180,7 @@ class FunctionDateOrDateTimeToSomething : public IFunctionDateOrDateTime(&type)) date_lut = &timezone->getTimeZone(); if (left.isNull() || right.isNull()) diff --git a/src/Functions/FunctionSQLJSON.cpp b/src/Functions/FunctionSQLJSON.cpp index 7fa853ae850..3be29bbaf14 100644 --- a/src/Functions/FunctionSQLJSON.cpp +++ b/src/Functions/FunctionSQLJSON.cpp @@ -7,14 +7,17 @@ namespace DB REGISTER_FUNCTION(SQLJSON) { - factory.registerFunction>(); - factory.registerFunction>(); - factory.registerFunction>(); - factory.registerFunction>(); + factory.registerFunction>(FunctionFactory::CaseInsensitive); + factory.registerFunction>(FunctionFactory::CaseInsensitive); + factory.registerFunction>(FunctionFactory::CaseInsensitive); + factory.registerFunction>(FunctionFactory::CaseInsensitive); + factory.registerFunction>(FunctionFactory::CaseInsensitive); + factory.registerFunction>(FunctionFactory::CaseInsensitive); + factory.registerAlias("JSON_SIZE", "JSON_LENGTH", FunctionFactory::CaseInsensitive); + // factory.registerAlias("JSON_ARRAY_LENGTH", "JSON_LENGTH", FunctionFactory::CaseInsensitive); factory.registerFunction>(FunctionFactory::CaseInsensitive); factory.registerFunction>(FunctionFactory::CaseInsensitive); - factory.registerFunction>(FunctionFactory::CaseInsensitive); - + factory.registerFunction>(FunctionFactory::CaseInsensitive); } } diff --git a/src/Functions/FunctionSQLJSON.h b/src/Functions/FunctionSQLJSON.h index b32aed0973c..f96c29bbcff 100644 --- a/src/Functions/FunctionSQLJSON.h +++ 
b/src/Functions/FunctionSQLJSON.h @@ -22,18 +22,20 @@ #include #include #include -#include "Common/assert_cast.h" +#include #include #include #include #include -#include "Columns/ColumnObject.h" -#include "Columns/IColumn.h" -#include "Core/ColumnsWithTypeAndName.h" -#include "DataTypes/IDataType.h" -#include "Functions/FunctionsComparison.h" -#include "Parsers/IAST_fwd.h" +#include +#include +#include +#include +#include +#include +#include #include +#include #if !defined(ARCADIA_BUILD) #include "config_functions.h" @@ -48,23 +50,145 @@ namespace ErrorCodes extern const int BAD_ARGUMENTS; } +template +class JSONUtils +{ +public: + using Element = typename JSONParser::Element; + using Object = typename JSONParser::Object; + using Array = typename JSONParser::Array; + + static bool jsonElementEqual(const Element & left, const Element & right) + { + if (left.isInt64() && right.isInt64()) + { + return left.getInt64() == right.getInt64(); + } + else if (left.isUInt64() && right.isUInt64()) + { + return left.getUInt64() == right.getUInt64(); + } + else if (left.isDouble() && right.isDouble()) + { + return left.getDouble() == right.getDouble(); + } + else if (left.isString() && right.isString()) + { + return left.getString() == right.getString(); + } + else if (left.isBool() && right.isBool()) + { + return left.getBool() == right.getBool(); + } + else if (left.isNull() && right.isNull()) + { + return true; + } + + return false; + } + + static bool jsonArrayContains(const Array & json_array, const Element & sub_element) + { + if (sub_element.isArray()) + { + const auto & sub_array = sub_element.getArray(); + for (auto it = sub_array.begin(); it != sub_array.end(); ++it) + { + if (!jsonArrayContains(json_array, *it)) + { + return false; + } + } + } + else if (sub_element.isObject()) + { + return false; + } + else + { + for (auto it = json_array.begin(); it != json_array.end(); ++it) + { + if (jsonElementEqual(*it, sub_element)) + { + return true; + } + } + + return false; 
+ } + + return true; + } + + static bool jsonObjectContains(const Object & json_object, const Element & sub_element) + { + if (sub_element.isObject()) + { + for (const auto & [key, value] : sub_element.getObject()) + { + Element temp_element; + bool contains_key = json_object.find(key, temp_element); + if (!contains_key) + return false; + + if (temp_element.isObject()) + { + if (!jsonObjectContains(temp_element.getObject(), value)) + return false; + else + continue; + } + + if (temp_element.isArray()) + { + if (!jsonArrayContains(temp_element.getArray(), value)) + return false; + else + continue; + } + + if (!jsonElementEqual(temp_element, value)) + return false; + } + } + else + { + return false; + } + + return true; + } + + static bool contains(const Element & parent_element, const Element & sub_element) + { + if (parent_element.isObject()) + { + return jsonObjectContains(parent_element.getObject(), sub_element); + } + else if (parent_element.isArray()) + { + return jsonArrayContains(parent_element.getArray(), sub_element); + } + else + { + return jsonElementEqual(parent_element, sub_element); + } + } +}; + class FunctionSQLJSONHelpers { public: template typename Impl, class JSONParser> - class Executor + class ExecutorString { public: static ColumnPtr - run(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, uint32_t parse_depth) + run(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, uint32_t parse_depth, DialectType dialect_type) { MutableColumnPtr to{result_type->createColumn()}; to->reserve(input_rows_count); - // TODO: add logic to handle single argument - if (arguments.size() < 2) - { - throw Exception{"JSONPath functions require at least 2 arguments", ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION}; - } const auto & json_column = arguments[0]; @@ -118,7 +242,13 @@ class FunctionSQLJSONHelpers const bool parse_res = parser.parse(token_iterator, res, expected); if 
(!parse_res) { - throw Exception{"Unable to parse JSONPath", ErrorCodes::BAD_ARGUMENTS}; + if (dialect_type != DialectType::MYSQL) + throw Exception{"Unable to parse JSONPath", ErrorCodes::BAD_ARGUMENTS}; + else + { + to->insertManyDefaults(input_rows_count); + return to; + } } /// Get data and offsets for 2 argument (JSON) @@ -131,7 +261,15 @@ class FunctionSQLJSONHelpers bool document_ok = false; /// Parse JSON for every row - Impl impl; + Impl> impl; + + constexpr bool has_member_prepare = requires + { + impl.prepare("", DataTypePtr{}); + }; + + if constexpr (has_member_prepare) + impl.prepare(Name::name, result_type); for (const auto i : collections::range(0, input_rows_count)) { @@ -140,9 +278,10 @@ class FunctionSQLJSONHelpers document_ok = json_parser.parse(json, document); bool added_to_column = false; + ElementIterator iterator(document); if (document_ok) { - added_to_column = impl.insertResultToColumn(*to, document, res); + added_to_column = impl.insertResultToColumn(*to, iterator, res, dialect_type); } if (!added_to_column) { @@ -152,41 +291,479 @@ class FunctionSQLJSONHelpers return to; } }; + + + template typename Impl, class JSONParser> + class ExecutorObject + { + public: + template + static ColumnPtr run(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, uint32_t parse_depth, DialectType dialect_type) + { + MutableColumnPtr to{result_type->createColumn()}; + to->reserve(input_rows_count); + + const auto & json_column = arguments[0]; + + if (!isObject(json_column.type) && !isTuple(json_column.type)) + { + throw Exception( + "JSONPath functions require first argument to be JSON of Object or Tuple, illegal type: " + json_column.type->getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + + const auto & json_path_column = arguments[1]; + + if (!isString(json_path_column.type)) + { + throw Exception( + "JSONPath functions require second argument to be JSONPath of type string, illegal type: " + + 
json_path_column.type->getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + if (!isColumnConst(*json_path_column.column)) + { + throw Exception("Second argument (JSONPath) must be constant string", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + + const ColumnPtr & arg_jsonpath = json_path_column.column; + const auto * arg_jsonpath_const = typeid_cast(arg_jsonpath.get()); + const auto * arg_jsonpath_string = typeid_cast(arg_jsonpath_const->getDataColumnPtr().get()); + + const ColumnPtr & arg_json = json_column.column; + const auto * col_json_const = typeid_cast(arg_json.get()); + const auto * col_json_object + = typeid_cast(col_json_const ? col_json_const->getDataColumnPtr().get() : arg_json.get()); + + ColumnPtr column_tuple; + DataTypePtr type_tuple; + + if constexpr (std::is_same_v) + { + std::tie(column_tuple, type_tuple) = unflattenObjectToTuple(*col_json_object); + } + else + { + column_tuple = col_json_object->getPtr(); + type_tuple = json_column.type; + } + + /// Get data and offsets for 1 argument (JSONPath) + const ColumnString::Chars & chars_path = arg_jsonpath_string->getChars(); + const ColumnString::Offsets & offsets_path = arg_jsonpath_string->getOffsets(); + + /// Prepare to parse 1 argument (JSONPath) + const char * query_begin = reinterpret_cast(&chars_path[0]); + const char * query_end = query_begin + offsets_path[0] - 1; + + /// Tokenize query + Tokens tokens(query_begin, query_end); + /// Max depth 0 indicates that depth is not limited + IParser::Pos token_iterator(tokens, parse_depth); + + /// Parse query and create AST tree + Expected expected; + ASTPtr res; + ParserJSONPath parser; + const bool parse_res = parser.parse(token_iterator, res, expected); + if (!parse_res) + { + if (dialect_type != DialectType::MYSQL) + throw Exception{"Unable to parse JSONPath", ErrorCodes::BAD_ARGUMENTS}; + else + { + to->insertManyDefaults(input_rows_count); + return to; + } + } + + // Element document; + + /// Parse JSON for every row + Impl impl; + + 
constexpr bool has_member_prepare = requires + { + impl.prepare("", DataTypePtr{}); + }; + + if constexpr (has_member_prepare) + impl.prepare(Name::name, result_type); + + for (const auto i : collections::range(0, input_rows_count)) + { + ObjectIterator iterator(type_tuple, column_tuple, col_json_const ? 0 : i); + bool added_to_column = impl.insertResultToColumn(*to, iterator, res, dialect_type); + + if (!added_to_column) + { + to->insertDefault(); + } + } + return to; + } + }; }; -template typename Impl> -class FunctionSQLJSON : public IFunction, WithConstContext +template +class ExecutableFunctionSQLJSONBase : public IExecutableFunction { + public: - static FunctionPtr create(ContextPtr context_) { return std::make_shared(context_); } - explicit FunctionSQLJSON(ContextPtr context_) : WithConstContext(context_) { } + explicit ExecutableFunctionSQLJSONBase(const NullPresence & null_presence_, const DataTypePtr & json_return_type_, uint32_t parser_depth_, DialectType dialect_type_) + : null_presence(null_presence_), json_return_type(json_return_type_), parser_depth(parser_depth_), dialect_type(dialect_type_) + { + } - static constexpr auto name = Name::name; String getName() const override { return Name::name; } - bool isVariadic() const override { return true; } - size_t getNumberOfArguments() const override { return 0; } + bool useDefaultImplementationForNulls() const override { return false; } bool useDefaultImplementationForConstants() const override { return true; } + ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; } - bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } - DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override { - return Impl::getReturnType(Name::name, arguments); + if 
(null_presence.has_null_constant) + return result_type->createColumnConstWithDefaultValue(input_rows_count); + + auto temp_arguments = null_presence.has_nullable ? createBlockWithNestedColumns(arguments) : arguments; + auto temporary_result = Derived::run(temp_arguments, json_return_type, input_rows_count, parser_depth, dialect_type); + if (null_presence.has_nullable) + return wrapInNullable(temporary_result, arguments, result_type, input_rows_count); + return temporary_result; } - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override +private: + NullPresence null_presence; + DataTypePtr json_return_type; + uint32_t parser_depth; + DialectType dialect_type; +}; + +template typename Impl, bool allow_simdjson> +class ExecutableFunctionSQLJSONString : public ExecutableFunctionSQLJSONBase> +{ +public: + using Base = ExecutableFunctionSQLJSONBase; + + ExecutableFunctionSQLJSONString(const NullPresence & null_presence_, const DataTypePtr & json_return_type_, uint32_t parser_depth_, DialectType dialect_type_) + : Base(null_presence_, json_return_type_, parser_depth_, dialect_type_) + { + } + + static ColumnPtr run(const ColumnsWithTypeAndName & arguments, const DataTypePtr & json_return_type, size_t input_rows_count, uint32_t parser_depth, const DialectType & dialect_type) + { + auto temp_arguments = arguments; + if (temp_arguments.size() < 2) + { + DataTypePtr default_path_type = std::make_shared(); + MutableColumnPtr default_path_string_column = default_path_type->createColumn(); + default_path_string_column->insert("$"); + MutableColumnPtr default_path_column = ColumnConst::create(std::move(default_path_string_column), 1); + temp_arguments.emplace_back(ColumnWithTypeAndName(std::move(default_path_column), default_path_type, "$")); + } + + return chooseAndRunJSONParser(temp_arguments, json_return_type, input_rows_count, parser_depth, dialect_type); + } + +private: + static ColumnPtr 
chooseAndRunJSONParser(const ColumnsWithTypeAndName & arguments, const DataTypePtr & json_return_type, size_t input_rows_count, uint32_t parser_depth, const DialectType & dialect_type) { - /// Choose JSONParser. - /// 1. Lexer(path) -> Tokens - /// 2. Create ASTPtr - /// 3. Parser(Tokens, ASTPtr) -> complete AST - /// 4. Execute functions: call getNextItem on generator and handle each item - uint32_t parse_depth = getContext()->getSettingsRef().max_parser_depth; #if USE_SIMDJSON - if (getContext()->getSettingsRef().allow_simdjson) - return FunctionSQLJSONHelpers::Executor::run(arguments, result_type, input_rows_count, parse_depth); + if constexpr (allow_simdjson) + return FunctionSQLJSONHelpers::ExecutorString::run(arguments, json_return_type, input_rows_count, parser_depth, dialect_type); +#endif + + return FunctionSQLJSONHelpers::ExecutorString::run(arguments, json_return_type, input_rows_count, parser_depth, dialect_type); + } +}; + +template typename Impl, bool allow_simdjson> +class ExecutableFunctionSQLJSONObject : public ExecutableFunctionSQLJSONBase> +{ +public: + using Base = ExecutableFunctionSQLJSONBase; + + ExecutableFunctionSQLJSONObject(const NullPresence & null_presence_, const DataTypePtr & json_return_type_, uint32_t parser_depth_, DialectType dialect_type_) + : Base(null_presence_, json_return_type_, parser_depth_, dialect_type_) + { + } + + static ColumnPtr run(const ColumnsWithTypeAndName & arguments, const DataTypePtr & json_return_type, size_t input_rows_count, uint32_t parser_depth, DialectType dialect_type) + { + assert(!arguments.empty()); + + auto temp_arguments = arguments; + if (temp_arguments.size() < 2) + { + DataTypePtr default_path_type = std::make_shared(); + MutableColumnPtr default_path_string_column = default_path_type->createColumn(); + default_path_string_column->insert("$"); + MutableColumnPtr default_path_column = ColumnConst::create(std::move(default_path_string_column), 1); + 
temp_arguments.emplace_back(ColumnWithTypeAndName(std::move(default_path_column), default_path_type, "$")); + } + const auto & type_object = assert_cast(*temp_arguments[0].type); + const auto & arg_object = temp_arguments[0].column; + const auto * column_const = typeid_cast(arg_object.get()); + const auto * column_object + = typeid_cast(column_const ? column_const->getDataColumnPtr().get() : arg_object.get()); + + assert(column_object); + if (column_object->hasNullableSubcolumns()) + { + auto non_nullable_object = ColumnObject::create(false); + for (const auto & entry : column_object->getSubcolumns()) + { + auto new_subcolumn = recursiveAssumeNotNullable(entry->data.getFinalizedColumnPtr()); + non_nullable_object->addSubcolumn(entry->path, new_subcolumn->assumeMutable()); + } + + temp_arguments[0].type = std::make_shared(type_object.getSchemaFormat(), false); + temp_arguments[0].column = std::move(non_nullable_object); + + if (column_const) + temp_arguments[0].column = ColumnConst::create(temp_arguments[0].column, column_const->size()); + } + +#if USE_SIMDJSON + if constexpr (allow_simdjson) + { + return FunctionSQLJSONHelpers::ExecutorObject::template run( + temp_arguments, json_return_type, input_rows_count, parser_depth, dialect_type); + } #endif - return FunctionSQLJSONHelpers::Executor::run(arguments, result_type, input_rows_count, parse_depth); + + return FunctionSQLJSONHelpers::ExecutorObject::template run( + temp_arguments, json_return_type, input_rows_count, parser_depth, dialect_type); + } +}; + +template typename Impl, bool allow_simdjson> +class ExecutableFunctionSQLJSONTuple : public ExecutableFunctionSQLJSONBase> +{ +public: + using Base = ExecutableFunctionSQLJSONBase; + + ExecutableFunctionSQLJSONTuple(const NullPresence & null_presence_, const DataTypePtr & json_return_type_, uint32_t parser_depth_, DialectType dialect_type_) + : Base(null_presence_, json_return_type_, parser_depth_, dialect_type_) + { + } + + static ColumnPtr + run(const 
ColumnsWithTypeAndName & arguments, const DataTypePtr & json_return_type, size_t input_rows_count, uint32_t parser_depth, DialectType dialect_type) + { + auto temp_arguments = arguments; + if (temp_arguments.size() < 2) + { + DataTypePtr default_path_type = std::make_shared(); + MutableColumnPtr default_path_string_column = default_path_type->createColumn(); + default_path_string_column->insert("$"); + MutableColumnPtr default_path_column = ColumnConst::create(std::move(default_path_string_column), 1); + temp_arguments.emplace_back(ColumnWithTypeAndName(std::move(default_path_column), default_path_type, "$")); + } +#if USE_SIMDJSON + if constexpr (allow_simdjson) + { + return FunctionSQLJSONHelpers::ExecutorObject::template run( + temp_arguments, json_return_type, input_rows_count, parser_depth, dialect_type); + } +#endif + + return FunctionSQLJSONHelpers::ExecutorObject::template run( + temp_arguments, json_return_type, input_rows_count, parser_depth, dialect_type); + } +}; + + +template +class FunctionBaseFunctionSQLJSON : public IFunctionBase +{ +public: + String getName() const override { return Name::name; } + + const DataTypes & getArgumentTypes() const override { return argument_types; } + + const DataTypePtr & getResultType() const override { return return_type; } + + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } + +protected: + explicit FunctionBaseFunctionSQLJSON( + const NullPresence & null_presence_, + DataTypes argument_types_, + DataTypePtr return_type_, + DataTypePtr json_return_type_, + uint32_t parser_depth_, + DialectType dialect_type_) + : null_presence(null_presence_) + , argument_types(std::move(argument_types_)) + , return_type(std::move(return_type_)) + , json_return_type(std::move(json_return_type_)) + , parser_depth(parser_depth_) + , dialect_type(dialect_type_) + { + } + + NullPresence null_presence; + bool allow_simdjson; + DataTypes argument_types; + 
DataTypePtr return_type; + DataTypePtr json_return_type; + uint32_t parser_depth; + DialectType dialect_type; +}; + +template typename Impl> +class FunctionBaseFunctionSQLJSONString : public FunctionBaseFunctionSQLJSON +{ +public: + template + explicit FunctionBaseFunctionSQLJSONString(bool allow_simdjson_, Args &&... args) + : FunctionBaseFunctionSQLJSON{std::forward(args)...} + , allow_simdjson(allow_simdjson_) + { + } + + ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName &) const override + { + if (this->allow_simdjson) + return std::make_unique>(this->null_presence, this->json_return_type, this->parser_depth, this->dialect_type); + + return std::make_unique>(this->null_presence, this->json_return_type, this->parser_depth, this->dialect_type); + } + + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } +private: + bool allow_simdjson; +}; + +template typename Impl> +class FunctionBaseFunctionSQLJSONObject : public FunctionBaseFunctionSQLJSON +{ +public: + template + explicit FunctionBaseFunctionSQLJSONObject(bool allow_simdjson_, Args &&... args) + : FunctionBaseFunctionSQLJSON{std::forward(args)...} + , allow_simdjson(allow_simdjson_) + { + } + + ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName &) const override + { + if (this->allow_simdjson) + return std::make_unique>(this->null_presence, this->json_return_type, this->parser_depth, this->dialect_type); + + return std::make_unique>(this->null_presence, this->json_return_type, this->parser_depth, this->dialect_type); + } + +private: + bool allow_simdjson; +}; + +template typename Impl> +class FunctionBaseFunctionSQLJSONTuple : public FunctionBaseFunctionSQLJSON +{ +public: + template + explicit FunctionBaseFunctionSQLJSONTuple(bool allow_simdjson_, Args &&... 
args) + : FunctionBaseFunctionSQLJSON{std::forward(args)...} + , allow_simdjson(allow_simdjson_) + { + } + + ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName &) const override + { + if (this->allow_simdjson) + return std::make_unique>(this->null_presence, this->json_return_type, this->parser_depth, this->dialect_type); + + return std::make_unique>(this->null_presence, this->json_return_type, this->parser_depth, this->dialect_type); + } +private: + bool allow_simdjson; +}; + +using ObjectIterator = FunctionJSONHelpers::ObjectIterator; +template +using ElementIterator = FunctionJSONHelpers::JSONElementIterator; + +/// We use IFunctionOverloadResolver instead of IFunction to handle non-default NULL processing. +/// Both NULL and JSON NULL should generate NULL value. If any argument is NULL, return NULL. +template typename Impl> +class SQLJSONOverloadResolver : public IFunctionOverloadResolver, WithContext +{ +public: + static constexpr auto name = Name::name; + + String getName() const override { return name; } + + static FunctionOverloadResolverPtr create(ContextPtr context_) + { + return std::make_unique(context_); + } + + explicit SQLJSONOverloadResolver(ContextPtr context_) : WithContext(context_) {} + + bool isVariadic() const override { return true; } + size_t getNumberOfArguments() const override { return 0; } + bool useDefaultImplementationForNulls() const override { return false; } + + ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; } + + FunctionBasePtr build(const ColumnsWithTypeAndName & arguments) const override + { + bool has_nothing_argument = false; + for (const auto & arg : arguments) + has_nothing_argument |= isNothing(arg.type); + + if (arguments.empty()) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Function {} requires at least one argument", Name::name); + + const auto & first_column = arguments[0]; + auto first_type_base = removeNullable(removeLowCardinality(first_column.type)); + + 
bool is_string = isString(first_type_base); + bool is_object = isObject(first_type_base); + bool is_tuple = isTuple(first_type_base); + bool is_nothing = isNothing(first_type_base); + + if (!is_string && !is_object && !is_tuple && !is_nothing) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "The first argument of function {} should be a string containing JSON or Object or Tuple, illegal type: {}", + Name::name, first_column.type->getName()); + + auto json_return_type = Impl::getReturnType(Name::name, createBlockWithNestedColumns(arguments)); + NullPresence null_presence = getNullPresense(arguments); + DataTypePtr return_type; + if (has_nothing_argument) + return_type = std::make_shared(); + else if (null_presence.has_null_constant) + return_type = makeNullable(std::make_shared()); + else if (null_presence.has_nullable) + return_type = makeNullable(json_return_type); + else + return_type = json_return_type; + + /// Top-level LowCardinality columns are processed outside JSON parser. 
+ json_return_type = removeLowCardinality(json_return_type); + + DataTypes argument_types; + argument_types.reserve(arguments.size()); + for (const auto & argument : arguments) + argument_types.emplace_back(argument.type); + + auto allow_simdjson = getContext()->getSettingsRef().allow_simdjson; + uint32_t parser_depth = getContext()->getSettingsRef().max_parser_depth; + DialectType dialect_type = getContext()->getSettingsRef().dialect_type; + if (is_string || is_nothing) + return std::make_unique>( + allow_simdjson, null_presence, argument_types, return_type, json_return_type, parser_depth, dialect_type); + else if (is_object) + return std::make_unique>( + allow_simdjson, null_presence, argument_types, return_type, json_return_type, parser_depth, dialect_type); + else + return std::make_unique>( + allow_simdjson, null_presence, argument_types, return_type, json_return_type, parser_depth, dialect_type); } }; @@ -220,25 +797,142 @@ struct NameSQLJSONContainsPath static constexpr auto name{"JSON_CONTAINS_PATH"}; }; -struct NameSQLJSONExtractPath +struct NameSQLJSONArrayContains +{ + static constexpr auto name{"JSON_ARRAY_CONTAINS"}; +}; + +struct NameSQLJSONKeys +{ + static constexpr auto name{"JSON_KEYS"}; +}; + +struct NameSQLJSONExtract { static constexpr auto name{"JSON_EXTRACT"}; }; -template +template +class SQLJSONKeysImpl +{ +public: + + static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) + { + return std::make_unique(std::make_shared()); + } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } + + static bool insertResultToColumn(IColumn & dest, Iterator & iterator, ASTPtr & query_ptr, DialectType /*dialect_type*/) requires IsElementIterator + { + using Element = typename Iterator::Element; + using JSONParser = typename Iterator::JSONParserType; + GeneratorJSONPath generator_json_path(query_ptr); + Element current_element = iterator.getElement(); + VisitorStatus 
status; + while ((status = generator_json_path.getNextItem(current_element)) != VisitorStatus::Exhausted) + { + if (status == VisitorStatus::Ok) + { + break; + } + current_element = iterator.getElement(); + } + + Iterator sub_iterator{current_element}; + return JSONExtractKeysImpl::insertResultToColumn(dest, sub_iterator); + } + + static bool insertResultToColumn(IColumn & dest, ObjectIterator & iterator, ASTPtr & query_ptr, DialectType /*dialect_type*/) requires IsObjectIterator + { + ObjectJSONGeneratorJSONPath generator_json_path(query_ptr); + VisitorStatus status; + while ((status = generator_json_path.getNextItem(iterator)) != VisitorStatus::Exhausted) + { + if (status == VisitorStatus::Ok) + { + break; + } + } + + return JSONExtractKeysImpl::insertResultToColumn(dest, iterator); + } +}; + +template +class SQLJSONExtractImpl +{ +public: + + static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) + { + return std::make_shared(); + } + + static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } + + static bool insertResultToColumn(IColumn & dest, Iterator & iterator, ASTPtr & query_ptr, DialectType dialect_type) requires IsElementIterator + { + using Element = typename Iterator::Element; + using JSONParser = typename Iterator::JSONParserType; + GeneratorJSONPath generator_json_path(query_ptr); + Element current_element = iterator.getElement(); + VisitorStatus status; + while ((status = generator_json_path.getNextItem(current_element)) != VisitorStatus::Exhausted) + { + if (status == VisitorStatus::Ok) + { + break; + } + current_element = iterator.getElement(); + } + + if (status == VisitorStatus::Exhausted) + { + return false; + } + + Iterator sub_iterator{current_element}; + return JSONExtractRawImpl::insertResultToColumn(dest, sub_iterator, dialect_type); + } + + static bool insertResultToColumn(IColumn & dest, ObjectIterator & iterator, ASTPtr & query_ptr, DialectType 
dialect_type) requires IsObjectIterator + { + ObjectJSONGeneratorJSONPath generator_json_path(query_ptr); + VisitorStatus status; + while ((status = generator_json_path.getNextItem(iterator)) != VisitorStatus::Exhausted) + { + if (status == VisitorStatus::Ok) + { + break; + } + } + + if (status == VisitorStatus::Exhausted) + { + return false; + } + + return JSONExtractRawImpl::insertResultToColumn(dest, iterator, dialect_type); + } +}; + +template class SQLJSONExistsImpl { public: - using Element = typename JSONParser::Element; static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) { return std::make_shared(); } static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool insertResultToColumn(IColumn & dest, const Element & root, ASTPtr & query_ptr) + static bool insertResultToColumn(IColumn & dest, Iterator & iterator, ASTPtr & query_ptr, DialectType /*dialect_type*/) requires IsElementIterator { + using Element = typename Iterator::Element; + using JSONParser = typename Iterator::JSONParserType; GeneratorJSONPath generator_json_path(query_ptr); - Element current_element = root; + Element current_element = iterator.getElement(); VisitorStatus status; while ((status = generator_json_path.getNextItem(current_element)) != VisitorStatus::Exhausted) { @@ -246,7 +940,33 @@ class SQLJSONExistsImpl { break; } - current_element = root; + current_element = iterator.getElement(); + } + + /// insert result, status can be either Ok (if we found the item) + /// or Exhausted (if we never found the item) + ColumnUInt8 & col_bool = assert_cast(dest); + if (status == VisitorStatus::Ok) + { + col_bool.insert(1); + } + else + { + col_bool.insert(0); + } + return true; + } + + static bool insertResultToColumn(IColumn & dest, ObjectIterator & iterator, ASTPtr & query_ptr, DialectType /*dialect_type*/) requires IsObjectIterator + { + ObjectJSONGeneratorJSONPath generator_json_path(query_ptr); + 
VisitorStatus status; + while ((status = generator_json_path.getNextItem(iterator)) != VisitorStatus::Exhausted) + { + if (status == VisitorStatus::Ok) + { + break; + } } /// insert result, status can be either Ok (if we found the item) @@ -264,20 +984,21 @@ class SQLJSONExistsImpl } }; -template +template class SQLJSONValueImpl { public: - using Element = typename JSONParser::Element; static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) { return std::make_shared(); } static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool insertResultToColumn(IColumn & dest, const Element & root, ASTPtr & query_ptr) + static bool insertResultToColumn(IColumn & dest, Iterator & iterator, ASTPtr & query_ptr, DialectType /*dialect_type*/) requires IsElementIterator { + using Element = typename Iterator::Element; + using JSONParser = typename Iterator::JSONParserType; GeneratorJSONPath generator_json_path(query_ptr); - Element current_element = root; + Element current_element = iterator.getElement(); VisitorStatus status; Element res; while ((status = generator_json_path.getNextItem(current_element)) != VisitorStatus::Exhausted) @@ -295,7 +1016,7 @@ class SQLJSONValueImpl /// Here it is possible to handle errors with ON ERROR (as described in ISO/IEC TR 19075-6), /// however this functionality is not implemented yet } - current_element = root; + current_element = iterator.getElement(); } if (status == VisitorStatus::Exhausted) @@ -310,26 +1031,56 @@ class SQLJSONValueImpl col_str.insertData(output_str.data(), output_str.size()); return true; } + + static bool insertResultToColumn(IColumn & dest, ObjectIterator & iterator, ASTPtr & query_ptr, DialectType /*dialect_type*/) requires IsObjectIterator + { + ObjectJSONGeneratorJSONPath generator_json_path(query_ptr); + VisitorStatus status; + while ((status = generator_json_path.getNextItem(iterator)) != VisitorStatus::Exhausted) + { + if 
(status == VisitorStatus::Ok) + { + const auto & element_type = iterator.getType(); + if (!(isArray(element_type) || isObject(element_type) || isTuple(element_type))) + break; + } + } + + if (status == VisitorStatus::Exhausted) + { + return false; + } + + auto row = iterator.getRow(); + if (const auto * column_string = typeid_cast(iterator.getColumn().get())) + { + dest.insertFrom(*column_string, row); + return true; + } + + return JSONExtractRawImpl::insertResultToColumn(dest, iterator); + } }; /** * Function to test jsonpath member access, will be removed in final PR * @tparam JSONParser parser */ -template +template class SQLJSONQueryImpl { public: - using Element = typename JSONParser::Element; static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) { return std::make_shared(); } static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool insertResultToColumn(IColumn & dest, const Element & root, ASTPtr & query_ptr) + static bool insertResultToColumn(IColumn & dest, Iterator & iterator, ASTPtr & query_ptr, DialectType /*dialect_type*/) requires IsElementIterator { + using Element = typename Iterator::Element; + using JSONParser = typename Iterator::JSONParserType; GeneratorJSONPath generator_json_path(query_ptr); - Element current_element = root; + Element current_element = iterator.getElement(); VisitorStatus status; std::stringstream out; // STYLE_CHECK_ALLOW_STD_STRING_STREAM /// Create json array of results: [res1, res2, ...] 
@@ -352,7 +1103,7 @@ class SQLJSONQueryImpl /// Here it is possible to handle errors with ON ERROR (as described in ISO/IEC TR 19075-6), /// however this functionality is not implemented yet } - current_element = root; + current_element = iterator.getElement(); } out << "]"; if (!success) @@ -364,13 +1115,17 @@ class SQLJSONQueryImpl col_str.insertData(output_str.data(), output_str.size()); return true; } + + static bool insertResultToColumn(IColumn & /*dest*/, ObjectIterator & /*iterator*/, ASTPtr & /*query_ptr*/, DialectType /*dialect_type*/) requires IsObjectIterator + { + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "JSON_QUERY is not implemented for Object or Tuple."); + } }; -template +template class SQLJSONLengthImpl { public: - using Element = typename JSONParser::Element; static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) { @@ -379,10 +1134,12 @@ class SQLJSONLengthImpl static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool insertResultToColumn(IColumn & dest, const Element & root, ASTPtr & query_ptr) + static bool insertResultToColumn(IColumn & dest, Iterator & iterator, ASTPtr & query_ptr, DialectType dialect_type) requires IsElementIterator { + using Element = typename Iterator::Element; + using JSONParser = typename Iterator::JSONParserType; GeneratorJSONPath generator_json_path(query_ptr); - Element current_element = root; + Element current_element = iterator.getElement(); VisitorStatus status; ColumnNullable & col = assert_cast(dest); @@ -392,7 +1149,43 @@ class SQLJSONLengthImpl { break; } - current_element = root; + current_element = iterator.getElement(); + } + + if (status == VisitorStatus::Exhausted) + { + col.insertData(nullptr, 0); + return false; + } + + size_t size; + if (current_element.isArray()) + size = current_element.getArray().size(); + else if (current_element.isObject()) + size = current_element.getObject().size(); + else + { + if 
(dialect_type == DialectType::MYSQL) + size = 0; + else + size = 1; + } + + col.insert(size); + return true; + } + + static bool insertResultToColumn(IColumn & dest, ObjectIterator & iterator, ASTPtr & query_ptr, DialectType dialect_type) requires IsObjectIterator + { + ColumnNullable & col = assert_cast(dest); + ObjectJSONGeneratorJSONPath generator_json_path(query_ptr); + VisitorStatus status; + while ((status = generator_json_path.getNextItem(iterator)) != VisitorStatus::Exhausted) + { + if (status == VisitorStatus::Ok) + { + break; + } } if (status == VisitorStatus::Exhausted) @@ -400,16 +1193,37 @@ class SQLJSONLengthImpl col.insertData(nullptr, 0); return false; } - - size_t size; - if (current_element.isArray()) - size = current_element.getArray().size(); - else if (current_element.isObject()) - size = current_element.getObject().size(); + + const auto * column_array = typeid_cast(iterator.getColumn().get()); + if (column_array) + { + const auto & offsets = column_array->getOffsets(); + auto row = iterator.getRow(); + UInt64 size = offsets[row] - offsets[row - 1]; + col.insert(size); + return true; + } + + const auto * column_tuple = typeid_cast(iterator.getColumn().get()); + if (column_tuple) + { + if (isDummyTuple(*iterator.getType())) + return false; + + UInt64 size = column_tuple->getColumns().size(); + col.insert(size); + return true; + } + + if (dialect_type == DialectType::MYSQL) + { + col.insert(0); + } else - size = 1; + { + col.insert(1); + } - col.insert(size); return true; } }; @@ -566,6 +1380,51 @@ class FunctionSQLJSONContains : public IFunction, WithConstContext return true; } + bool + insertResultToColumn(IColumn & dest, ElementIterator & iterator, const ColumnWithTypeAndName & candidate, ASTPtr & query_ptr) const + { + ColumnUInt8 & col_bool = assert_cast(dest); + + auto current_element = iterator.getElement(); + if (query_ptr) + { + GeneratorJSONPath generator_json_path(query_ptr); + VisitorStatus status; + while ((status = 
generator_json_path.getNextItem(current_element)) != VisitorStatus::Exhausted) + { + if (status == VisitorStatus::Ok) + { + break; + } + current_element = iterator.getElement(); + } + + if (status == VisitorStatus::Exhausted) + { + return false; + } + } + + const auto & candidate_json_column = candidate.column; + const auto * candidate_json_const = typeid_cast(candidate_json_column.get()); + const auto * candidate_json_string = typeid_cast( + candidate_json_const ? candidate_json_const->getDataColumnPtr().get() : candidate_json_column.get()); + + std::string_view json{candidate_json_string ? candidate_json_string->getDataAt(0) : ""}; + SimdJSONParser json_parser; + using Element = typename SimdJSONParser::Element; + Element sub_document; + const bool parse_ok = json_parser.parse(json, sub_document); + + if (parse_ok) + { + bool contains = JSONUtils::contains(current_element, sub_document); + col_bool.insert(contains ? 1 : 0); + } + + return parse_ok; + } + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override { //Only support Object JSON @@ -585,13 +1444,14 @@ class FunctionSQLJSONContains : public IFunction, WithConstContext const auto & json_column = arguments[0]; auto first_type_base = removeNullable(removeLowCardinality(json_column.type)); + bool is_string = isString(first_type_base); bool is_object = isObject(first_type_base); bool is_tuple = isTuple(first_type_base); - if (!is_object && !is_tuple) + if (!is_string && !is_object && !is_tuple) throw Exception( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "The first argument of function {} should be a string containing Object or Tuple, illegal type: {}", + "The first argument of function {} should be a string containing Object or Tuple or JSON, illegal type: {}", Name::name, json_column.type->getName()); @@ -602,27 +1462,6 @@ class FunctionSQLJSONContains : public IFunction, WithConstContext MutableColumnPtr 
to{result_type->createColumn()}; to->reserve(input_rows_count); - const auto & arg_json = json_column.column; - const auto * col_json_const = typeid_cast(arg_json.get()); - const auto * col_json_object - = typeid_cast(col_json_const ? col_json_const->getDataColumnPtr().get() : arg_json.get()); - - if (!col_json_object) - throw Exception{ErrorCodes::ILLEGAL_COLUMN, "Illegal column {}", arg_json->getName()}; - - ColumnPtr column_tuple; - DataTypePtr type_tuple; - - if (is_object) - { - std::tie(column_tuple, type_tuple) = unflattenObjectToTuple(*col_json_object); - } - else - { - column_tuple = col_json_object->getPtr(); - type_tuple = json_column.type; - } - ASTPtr res; if (arguments.size() == 3) { @@ -669,14 +1508,69 @@ class FunctionSQLJSONContains : public IFunction, WithConstContext } } - for (const auto i : collections::range(0, input_rows_count)) + const auto & arg_json = json_column.column; + const auto * col_json_const = typeid_cast(arg_json.get()); + + if (is_object || is_tuple) { - ObjectIterator iterator(type_tuple, column_tuple, col_json_const ? 0 : i); + const auto * col_json_object + = typeid_cast(col_json_const ? col_json_const->getDataColumnPtr().get() : arg_json.get()); - bool added_to_column = this->insertResultToColumn(*to, iterator, candidate, res); - if (!added_to_column) + if (!col_json_object) + throw Exception{ErrorCodes::ILLEGAL_COLUMN, "Illegal column {}", arg_json->getName()}; + + ColumnPtr column_tuple; + DataTypePtr type_tuple; + + if (is_object) { - to->insertDefault(); + std::tie(column_tuple, type_tuple) = unflattenObjectToTuple(*col_json_object); + } + else + { + column_tuple = col_json_object->getPtr(); + type_tuple = json_column.type; + } + for (const auto i : collections::range(0, input_rows_count)) + { + ObjectIterator iterator(type_tuple, column_tuple, col_json_const ? 
0 : i); + + bool added_to_column = this->insertResultToColumn(*to, iterator, candidate, res); + if (!added_to_column) + { + to->insertDefault(); + } + } + } + else + { + const auto * col_json_string + = typeid_cast(col_json_const ? col_json_const->getDataColumnPtr().get() : arg_json.get()); + + const ColumnString::Chars & chars_json = col_json_string->getChars(); + const ColumnString::Offsets & offsets_json = col_json_string->getOffsets(); + + SimdJSONParser json_parser; + using Element = typename SimdJSONParser::Element; + Element document; + bool document_ok = false; + + for (const auto i : collections::range(0, input_rows_count)) + { + std::string_view json{ + reinterpret_cast(&chars_json[offsets_json[i - 1]]), offsets_json[i] - offsets_json[i - 1] - 1}; + document_ok = json_parser.parse(json, document); + + bool added_to_column = false; + ElementIterator iterator(document); + if (document_ok) + { + added_to_column = this->insertResultToColumn(*to, iterator, candidate, res); + } + if (!added_to_column) + { + to->insertDefault(); + } } } @@ -713,16 +1607,50 @@ class FunctionSQLJSONContainsPath : public IFunction, WithConstContext for (const auto & query_ptr : query_ptrs) { + auto temp_iterator = iterator; ObjectJSONGeneratorJSONPath generator_json_path(query_ptr); VisitorStatus status; - while ((status = generator_json_path.getNextItem(iterator)) != VisitorStatus::Exhausted) + while ((status = generator_json_path.getNextItem(temp_iterator)) != VisitorStatus::Exhausted) + { + if (status == VisitorStatus::Ok) + { + contains = true; + break; + } + } + + if (status == VisitorStatus::Exhausted) + { + if (contains_all) + return false; + else + continue; + } + } + + col_bool.insert(contains ? 
1 : 0); + return true; + } + + bool insertResultToColumn(IColumn & dest, ElementIterator & iterator, ASTs & query_ptrs, bool contains_all) const + { + ColumnUInt8 & col_bool = assert_cast(dest); + + bool contains = false; + for (const auto & query_ptr : query_ptrs) + { + GeneratorJSONPath generator_json_path(query_ptr); + auto current_element = iterator.getElement(); + VisitorStatus status; + + while ((status = generator_json_path.getNextItem(current_element)) != VisitorStatus::Exhausted) { if (status == VisitorStatus::Ok) { - if (!contains_all) - contains = true; + contains = true; break; } + current_element = iterator.getElement(); } if (status == VisitorStatus::Exhausted) @@ -757,44 +1685,21 @@ class FunctionSQLJSONContainsPath : public IFunction, WithConstContext const auto & json_column = arguments[0]; auto first_type_base = removeNullable(removeLowCardinality(json_column.type)); + bool is_string = isString(first_type_base); bool is_object = isObject(first_type_base); bool is_tuple = isTuple(first_type_base); - if (!is_object && !is_tuple) + if (!is_string && !is_object && !is_tuple) throw Exception( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "The first argument of function {} should be a string containing Object or Tuple, illegal type: {}", + "The first argument of function {} should be a string containing Object or Tuple or JSON, illegal type: {}", Name::name, json_column.type->getName()); const auto & any_or_all_column = arguments[1]; if (!isString(any_or_all_column.type) || !isColumnConst(*any_or_all_column.column)) throw Exception("Second argument (any or all) must be constant string", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - - MutableColumnPtr to{result_type->createColumn()}; - to->reserve(input_rows_count); - - const auto & arg_json = json_column.column; - const auto * col_json_const = typeid_cast(arg_json.get()); - const auto * col_json_object - = typeid_cast(col_json_const ? 
col_json_const->getDataColumnPtr().get() : arg_json.get()); - - if (!col_json_object) - throw Exception{ErrorCodes::ILLEGAL_COLUMN, "Illegal column {}", arg_json->getName()}; - - ColumnPtr column_tuple; - DataTypePtr type_tuple; - if (is_object) - { - std::tie(column_tuple, type_tuple) = unflattenObjectToTuple(*col_json_object); - } - else - { - column_tuple = col_json_object->getPtr(); - type_tuple = json_column.type; - } - const ColumnPtr & arg_any_or_all = any_or_all_column.column; const auto * arg_any_or_all_const = typeid_cast(arg_any_or_all.get()); const auto * arg_any_or_all_string = typeid_cast(arg_any_or_all_const->getDataColumnPtr().get()); @@ -850,14 +1755,74 @@ class FunctionSQLJSONContainsPath : public IFunction, WithConstContext json_path_ast_ptrs.emplace_back(res); } - for (const auto i : collections::range(0, input_rows_count)) + MutableColumnPtr to{result_type->createColumn()}; + to->reserve(input_rows_count); + + const auto & arg_json = json_column.column; + const auto * col_json_const = typeid_cast(arg_json.get()); + + if (is_object || is_tuple) { - ObjectIterator iterator(type_tuple, column_tuple, col_json_const ? 0 : i); + const auto * col_json_object + = typeid_cast(col_json_const ? col_json_const->getDataColumnPtr().get() : arg_json.get()); + + if (!col_json_object) + throw Exception{ErrorCodes::ILLEGAL_COLUMN, "Illegal column {}", arg_json->getName()}; - bool added_to_column = this->insertResultToColumn(*to, iterator, json_path_ast_ptrs, contains_all); - if (!added_to_column) + ColumnPtr column_tuple; + DataTypePtr type_tuple; + + if (is_object) { - to->insertDefault(); + std::tie(column_tuple, type_tuple) = unflattenObjectToTuple(*col_json_object); + } + else + { + column_tuple = col_json_object->getPtr(); + type_tuple = json_column.type; + } + + + for (const auto i : collections::range(0, input_rows_count)) + { + ObjectIterator iterator(type_tuple, column_tuple, col_json_const ? 
0 : i); + + bool added_to_column = this->insertResultToColumn(*to, iterator, json_path_ast_ptrs, contains_all); + if (!added_to_column) + { + to->insertDefault(); + } + } + } + else + { + const auto * col_json_string + = typeid_cast(col_json_const ? col_json_const->getDataColumnPtr().get() : arg_json.get()); + + const ColumnString::Chars & chars_json = col_json_string->getChars(); + const ColumnString::Offsets & offsets_json = col_json_string->getOffsets(); + + SimdJSONParser json_parser; + using Element = typename SimdJSONParser::Element; + Element document; + bool document_ok = false; + + for (const auto i : collections::range(0, input_rows_count)) + { + std::string_view json{ + reinterpret_cast(&chars_json[offsets_json[i - 1]]), offsets_json[i] - offsets_json[i - 1] - 1}; + document_ok = json_parser.parse(json, document); + + bool added_to_column = false; + ElementIterator iterator(document); + if (document_ok) + { + added_to_column = this->insertResultToColumn(*to, iterator, json_path_ast_ptrs, contains_all); + } + if (!added_to_column) + { + to->insertDefault(); + } } } @@ -865,37 +1830,129 @@ class FunctionSQLJSONContainsPath : public IFunction, WithConstContext } }; -template -class SQLJSONExtractPathImpl +template +class FunctionSQLJSONArrayContains : public IFunction, WithConstContext { public: - using Element = typename JSONParser::Element; + static FunctionPtr create(ContextPtr context_) { return std::make_shared(context_); } + explicit FunctionSQLJSONArrayContains(ContextPtr context_) : WithConstContext(context_) + {} - static DataTypePtr getReturnType(const char *, const ColumnsWithTypeAndName &) { return std::make_shared(); } + static constexpr auto name = Name::name; + String getName() const override { return Name::name; } + bool isVariadic() const override { return true; } + size_t getNumberOfArguments() const override { return 0; } + bool useDefaultImplementationForConstants() const override { return true; } + ColumnNumbers 
getArgumentsThatAreAlwaysConstant() const override { return {1}; } + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } - static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override + { + return std::make_shared(); + } - static bool insertResultToColumn(IColumn & dest, const Element & root, ASTPtr & query_ptr) + template + ColumnPtr internalExecuteImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, uint32_t /*parse_depth*/) const { - GeneratorJSONPath generator_json_path(query_ptr); - Element current_element = root; - VisitorStatus status; - while ((status = generator_json_path.getNextItem(current_element)) != VisitorStatus::Exhausted) + if (arguments.size() != 2) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", Name::name); + + const auto & json_column = arguments[0]; + auto first_type_base = removeNullable(removeLowCardinality(json_column.type)); + + bool is_string = isString(first_type_base); + + if (!is_string) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "The first argument of function {} should be a string, illegal type: {}", + Name::name, + json_column.type->getName()); + + + const ColumnPtr & arg_json = json_column.column; + const auto * col_json_const = typeid_cast(arg_json.get()); + const auto * col_json_string + = typeid_cast(col_json_const ? 
col_json_const->getDataColumnPtr().get() : arg_json.get()); + /// Get data and offsets for 2 argument (JSON) + const ColumnString::Chars & chars_json = col_json_string->getChars(); + const ColumnString::Offsets & offsets_json = col_json_string->getOffsets(); + + const auto & target_value_column = arguments[1]; + + JSONParser json_parser; + using Element = typename JSONParser::Element; + Element document; + bool document_ok = false; + + MutableColumnPtr to{result_type->createColumn()}; + to->reserve(input_rows_count); + + auto compare = [&target_value_column](const Element & json_array_element) { - if (status == VisitorStatus::Ok) + if (isString(target_value_column.type) && json_array_element.isString()) { - break; + return target_value_column.column->getDataAt(0).toString() == json_array_element.getString(); } - current_element = root; - } - if (status == VisitorStatus::Exhausted) + if (isNumber(target_value_column.type) && json_array_element.isInt64()) + { + return target_value_column.column->getInt(0) == json_array_element.getInt64(); + } + + if (isBool(target_value_column.type) && json_array_element.isBool()) + return target_value_column.column->getBool(0) == json_array_element.getBool(); + + return false; + }; + + for (const auto i : collections::range(0, input_rows_count)) { - return false; + + std::string_view json{ + reinterpret_cast(&chars_json[offsets_json[i - 1]]), offsets_json[i] - offsets_json[i - 1] - 1}; + document_ok = json_parser.parse(json, document); + if (!document_ok) + { + to->insertDefault(); + continue; + } + + if (document.isArray()) + { + const auto & json_array = document.getArray(); + for (auto it = json_array.begin(); it != json_array.end(); ++it) + { + if (compare(*it)) + { + to->insert(1); + break; + } + } + to->insertDefault(); + } + else + { + to->insertDefault(); + } } - ElementIterator iterator(current_element); - return JSONExtractRawImpl>::insertResultToColumn(dest, iterator); + return to; + } + + ColumnPtr executeImpl(const 
ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override + { + /// Choose JSONParser. + /// 1. Lexer(path) -> Tokens + /// 2. Create ASTPtr + /// 3. Parser(Tokens, ASTPtr) -> complete AST + /// 4. Execute functions: call getNextItem on generator and handle each item + uint32_t parse_depth = getContext()->getSettingsRef().max_parser_depth; +#if USE_SIMDJSON + if (getContext()->getSettingsRef().allow_simdjson) + return this->template internalExecuteImpl(arguments, result_type, input_rows_count, parse_depth); +#endif + return this->template internalExecuteImpl(arguments, result_type, input_rows_count, parse_depth); } }; diff --git a/src/Functions/FunctionSipHashBuiltin.cpp b/src/Functions/FunctionSipHashBuiltin.cpp new file mode 100644 index 00000000000..98765ed6ef6 --- /dev/null +++ b/src/Functions/FunctionSipHashBuiltin.cpp @@ -0,0 +1,11 @@ +#include +#include +#include + +namespace DB +{ +REGISTER_FUNCTION(SipHashBuiltin) +{ + factory.registerFunction(FunctionSipHashBuiltin::name, FunctionFactory::CaseSensitive); +} +} diff --git a/src/Functions/FunctionSipHashBuiltin.h b/src/Functions/FunctionSipHashBuiltin.h new file mode 100644 index 00000000000..4a499f00800 --- /dev/null +++ b/src/Functions/FunctionSipHashBuiltin.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include +#include +#include + + +namespace DB { + + +class FunctionSipHashBuiltin : public IFunction +{ +public: + static constexpr auto name = "sipHashBuitin"; + static FunctionPtr create(ContextPtr) { return std::make_shared(); } + + String getName() const override { return name; } + + bool isVariadic() const override { return true; } + size_t getNumberOfArguments() const override { return 0; } + bool useDefaultImplementationForConstants() const override { return false; } + + bool useDefaultImplementationForNulls() const override { return false; } + bool useDefaultImplementationForNothing() const override { return false; } + bool 
useDefaultImplementationForLowCardinalityColumns() const override { return false; } + + virtual DataTypePtr getReturnTypeImpl(const DataTypes & ) const override + { + return std::make_shared(); + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { + auto result_column = ColumnUInt64::create(input_rows_count, 0); + auto & result_date = result_column->getData(); + for (size_t i = 0; i < input_rows_count; i++) + { + SipHash hash; + for (const auto & argument : arguments) + { + argument.column->updateHashWithValue(i, hash); + } + result_date[i] = hash.get64(); + } + return result_column; + } +}; + +} + diff --git a/src/Functions/FunctionSketch.h b/src/Functions/FunctionSketch.h index 81f0fcd2e8e..5920d63ea73 100644 --- a/src/Functions/FunctionSketch.h +++ b/src/Functions/FunctionSketch.h @@ -176,13 +176,14 @@ class FunctionHLLSketch : public IFunction return name; } - size_t getNumberOfArguments() const override { return 1; } + bool isVariadic() const override { return true; } + size_t getNumberOfArguments() const override { return 0; } bool useDefaultImplementationForNulls() const override { return false; } bool useDefaultImplementationForConstants() const override { return true; } DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { - if (arguments.size() != 1) + if (arguments.size() != 1 && arguments.size() != 2) throw Exception("Illegal argument size of function " + getName(), ErrorCodes::BAD_ARGUMENTS); @@ -207,6 +208,7 @@ class FunctionHLLSketch : public IFunction ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override { + bool use_composite_estimate = arguments.size() > 1 ? 
true : false; if (arguments[0].column->isNullable()) { auto result_type = std::make_shared(std::make_shared()); @@ -222,7 +224,10 @@ class FunctionHLLSketch : public IFunction else { datasketches::hll_sketch hll_sketch_data = datasketches::hll_sketch::deserialize(nullable_sketch.getDataAt(i).data, nullable_sketch.getDataAt(i).size, AggregateFunctionHllSketchAllocator()); - result_column->insert(hll_sketch_data.get_estimate()); + if (use_composite_estimate) + result_column->insert(hll_sketch_data.get_composite_estimate()); + else + result_column->insert(hll_sketch_data.get_estimate()); } } return result_column; @@ -238,7 +243,10 @@ class FunctionHLLSketch : public IFunction { auto value = value_column.getDataAt(i); datasketches::hll_sketch hll_sketch_data = datasketches::hll_sketch::deserialize(value.data, value.size, AggregateFunctionHllSketchAllocator()); - dst_data[i] = hll_sketch_data.get_estimate(); + if (use_composite_estimate) + dst_data[i] = hll_sketch_data.get_composite_estimate(); + else + dst_data[i] = hll_sketch_data.get_estimate(); } return result_column; diff --git a/src/Functions/FunctionsBitEngineHelper.h b/src/Functions/FunctionsBitEngineHelper.h deleted file mode 100644 index 10fd8b5e0e7..00000000000 --- a/src/Functions/FunctionsBitEngineHelper.h +++ /dev/null @@ -1,455 +0,0 @@ -#pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace DB -{ -namespace ErrorCodes -{ - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int BITENGINE_DICT_EXCEPTION; -} - -static NameSet bitmap_aggregate_functions{ - "bitmapextract", "bitmapextractbysplit", "bitmapcolumnor", "bitmapcolumnxor", "bitmapcolumnand", "bitmapextractv2"}; - -struct BitEngineDictionaryEncodeInfo -{ - String database; - String table; - String dict_name; - bool add_new; - bool tolerant_loss; -}; - -String tryGetDictName(const String & column_name, StorageCloudMergeTree * merge_tree) -{ - String 
name_lowercase = Poco::toLower(column_name); - size_t pos = name_lowercase.find_first_of('('); - - // if we enter tryGetDictName, that means string column_name is not - // a legal filed of table, and if there's no '(' in it. It may be a - // wrong dict_name, return empty string to throw exception. - if (pos == std::string::npos) - return ""; - - String inner_func_name = name_lowercase.substr(0, name_lowercase.find_first_of('(')); - - if (!bitmap_aggregate_functions.count(inner_func_name) && name_lowercase.find("bitmap") == std::string::npos) - return ""; - - auto check = [&](auto & storage_bitengine_cloud, auto & name) { - if (auto dict_manager = storage_bitengine_cloud->getBitEngineDictionaryManager()) - return dict_manager->hasBitEngineDictionary(name); - return false; - }; - - size_t pos_begin = column_name.find_last_of('('); - size_t pos_end = column_name.find_last_of(')'); - if (pos_begin == std::string::npos || pos_end == std::string::npos) - return ""; - - Names values; - size_t pre_index = pos_end; - for (size_t i = pos_begin + 1; i <= pos_end; ++i) - { - char c = column_name[i]; - if (c == ' ' || c == ',' || c == '(' || c == ')') - { - if (pre_index != pos_end) - { - values.push_back(column_name.substr(pre_index, i - pre_index)); - pre_index = pos_end; - } - } - else if (pre_index == pos_end) - pre_index = i; - } - - String res; - for (const auto & value : values) - { - String tmp_name; - if (merge_tree && check(merge_tree, value)) - tmp_name = value; - if (!res.empty() && !tmp_name.empty()) - throw Exception("BitEngine cannot decode multiple column in single function", ErrorCodes::BITENGINE_DICT_EXCEPTION); - else if (!tmp_name.empty()) - res = tmp_name; - } - return res; -} - -inline bool checkDataTypeForBitEngineDecode(const DataTypePtr & data_type) -{ - return isBitmap64(data_type) || WhichDataType(data_type).isNativeUInt(); -} - -bool checkDataTypeForBitEngineEncode(const DataTypePtr & data_type) -{ - return isBitmap64(data_type) || 
isNativeInteger(data_type) || isArrayOfString(data_type) || isArrayOfUInt64(data_type) - || isString(data_type); -} - -StorageCloudMergeTree * loadDictsForCnchServer( - StorageCnchMergeTree * storage_bitengine_cnch, - const String & encode_database, - const String & encode_table, - [[maybe_unused]] MemoryDictMode mode, - const ContextPtr & local_context) -{ - if (!storage_bitengine_cnch->isBitEngineTable()) - throw Exception( - fmt::format("Table <`{}`.`{}`> is not a BitEngine table", encode_database, encode_table), ErrorCodes::BITENGINE_DICT_EXCEPTION); - - auto query_context = local_context->getQueryContext(); - /// init and get WorkerResource, then create cloud table for BitEngine, as well as load parts - if (!query_context->tryGetCnchWorkerResource()) - query_context->initCnchWorkerResource(); - auto worker_resource = query_context->getCnchWorkerResource(); - - auto lock = worker_resource->getBitEngineDictLoadLock(); - - auto bitengine_tables_in_query = query_context->getBitEngineTables(); - if (!bitengine_tables_in_query) - throw Exception( - ErrorCodes::BITENGINE_DICT_EXCEPTION, - "Not found bitengine tables info in context! 
node type: {}", - query_context->getServerTypeString()); - - auto it = bitengine_tables_in_query->find(storage_bitengine_cnch->getStorageUUID()); - - BitEngineDictionaryTableMapping underlying_dictionary_table_cloud; - auto cloud_table_name = storage_bitengine_cnch->getCloudTableName(local_context); - if (!it->second.cloud_table_created_on_server) - { - /// create cloud table for underlying dict tables - /// NOTE: load parts are note implemented yet - const auto & dicts_mapping = storage_bitengine_cnch->getUnderlyDictionaryTables(); - for (const auto & entry : dicts_mapping) - { - auto storage_underlying_dict - = DatabaseCatalog::instance().tryGetTable(StorageID{entry.second.first, entry.second.second}, local_context); - StorageCnchMergeTree * storage_underlying_dict_cnch = dynamic_cast(storage_underlying_dict.get()); - if (storage_underlying_dict_cnch) - { - auto dict_table_name_cloud = storage_underlying_dict_cnch->getCloudTableName(local_context); - underlying_dictionary_table_cloud.emplace(entry.first, std::make_pair(entry.second.first, dict_table_name_cloud)); - - auto create_table_query = storage_underlying_dict_cnch->getCreateQueryForCloudTable( - storage_underlying_dict_cnch->getCreateTableSql(), - dict_table_name_cloud, - local_context, - false, - std::nullopt, - {}, - {}, - WorkerEngineType::DICT); - - /// try find dict_cloud_table first, maybe it's created already, like insert into select DecodeBitmap() - auto storage_underlying_dict_cloud - = worker_resource->getTable(StorageID{storage_bitengine_cnch->getDatabaseName(), dict_table_name_cloud}); - - bool dict_cloud_already_exists{true}; - if (!storage_underlying_dict_cloud) - { - worker_resource->executeCreateQuery(local_context->getQueryContext(), create_table_query, /* skip_if_exists */ true); - dict_cloud_already_exists = false; - } - - /// after dict_cloud_table created, now get and load parts - storage_underlying_dict_cloud - = 
worker_resource->getTable(StorageID{storage_bitengine_cnch->getDatabaseName(), dict_table_name_cloud}); - auto * underlying_dict_cloud_table = dynamic_cast(storage_underlying_dict_cloud.get()); - - if (!underlying_dict_cloud_table) - { - throw Exception( - fmt::format( - "In decoding, cannot get DictCloudMergeTree for table:<`{}`.`{}`>", - storage_bitengine_cnch->getDatabaseName(), - dict_table_name_cloud), - ErrorCodes::UNKNOWN_TABLE); - } - - if (!dict_cloud_already_exists) - { - auto server_parts = storage_underlying_dict_cnch->getAllPartsWithDBM(local_context).first; - MergeTreeDataPartTypeHelper::MutableDataPartsVector parts; - for (auto & part : server_parts) - { - parts.emplace_back(part->toCNCHDataPart(*storage_underlying_dict_cnch)); - } - - underlying_dict_cloud_table->loadDataParts(parts); - } - } - } - - /// create cloud table for bitengine table - auto create_table_query = storage_bitengine_cnch->getCreateQueryForCloudTable( - storage_bitengine_cnch->getCreateTableSql(), - cloud_table_name, - local_context, - false, - std::nullopt, - {}, - {}, - WorkerEngineType::CLOUD, - underlying_dictionary_table_cloud); - - worker_resource->executeCreateQuery(local_context->getQueryContext(), create_table_query, /* skip_if_exists */ true); - LOG_DEBUG( - &Poco::Logger::get("FunctionsBitEngineHelper"), - "Created table {} on {}", - storage_bitengine_cnch->getStorageID().getFullTableName(), - local_context->getServerTypeString()); - it->second.cloud_table_created_on_server = true; - it->second.dict_loaded_on_server = true; - } - - auto storage_bitengine = worker_resource->getTable(StorageID{storage_bitengine_cnch->getDatabaseName(), cloud_table_name}); - - auto * cloud_table = dynamic_cast(storage_bitengine.get()); - - if (!cloud_table) - { - throw Exception( - fmt::format( - "Cannot parse and get a CloudMergeTree from the argument of decode function, " - "which is <`{}`.`{}`>, and the cloud table name is <`{}`.`{}`>", - encode_database, - encode_table, - 
storage_bitengine_cnch->getDatabaseName(), - cloud_table_name), - ErrorCodes::UNKNOWN_TABLE); - } - - return cloud_table; -} - -inline StorageCloudMergeTree * -getCloudTable(const String & database, const String & table, [[maybe_unused]] MemoryDictMode mode, const ContextPtr & local_context) -{ - StoragePtr storage = DatabaseCatalog::instance().getTable(StorageID{database, table}, local_context); - auto * cloud_table = dynamic_cast(storage.get()); - if (!cloud_table) - { - auto * storage_bitengine_cnch = dynamic_cast(storage.get()); - if (!storage_bitengine_cnch) - { - throw Exception( - fmt::format("`{}`.`{}` is not a StorageCnchMergeTree table, check the arguments", database, table), - ErrorCodes::UNKNOWN_TABLE); - } - - cloud_table = loadDictsForCnchServer(storage_bitengine_cnch, database, table, mode, local_context); - - if (!cloud_table) - throw Exception( - fmt::format( - "BitEngine dict manager is not initialized for table `{}`.`{}`" - ", or it's not a BitEngine table", - cloud_table->getDatabaseName(), - cloud_table->getTableName()), - ErrorCodes::UNKNOWN_TABLE); - } - - if (!cloud_table->isBitEngineMode()) - throw Exception( - fmt::format( - "BitEngine dict manager is not initialized for table `{}`.`{}`" - ", or it's not a BitEngine table", - cloud_table->getDatabaseName(), - cloud_table->getTableName()), - ErrorCodes::BITENGINE_DICT_EXCEPTION); - - return cloud_table; -} - -inline String getValidDictName(StorageCloudMergeTree * storage_bitengine_cloud, String & input_dict_name) -{ - String res = input_dict_name; - - if (!storage_bitengine_cloud->isBitEngineEncodeColumn(input_dict_name)) - { - res = tryGetDictName(input_dict_name, storage_bitengine_cloud); - if (res.empty()) - throw Exception( - "In the first argument: " + input_dict_name + ", we cannot find dict name(which is a bitmap field from table, same name). " - + "You should make sure a physical bitmap field exists in the first argument. 
Or you can try DecodeBitmap", - ErrorCodes::BITENGINE_DICT_EXCEPTION); - } - - auto name_and_type = storage_bitengine_cloud->getInMemoryMetadataPtr()->getColumns().tryGetPhysical(res); - if (!name_and_type.has_value() || !name_and_type.value().type->isBitEngineEncode()) - throw Exception( - "The type of column " + res + " (type: " + name_and_type.value().type->getName() + ") is not BitMap64", - ErrorCodes::BITENGINE_DICT_EXCEPTION); - - return res; -} - -void checkDictionaryPrerequisiteTypes(const DataTypes & arguments, const String & func_name) -{ - /// first check dictionary info: db, tbl, bitmap_filed - /// they are all String - for (size_t i = 1; i < 4; ++i) - { - if (!checkAndGetDataType(arguments[i].get())) - throw Exception( - fmt::format( - "The {} argument type of function {} is not right, received={}, expected={}", - argPositionToSequence(i + 1), - func_name, - arguments[i]->getName(), - "String"), - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - } - - /// Until now, Bool switches are not used in CNCH, do not throw exception - /// when enabling optimizer (1 will be assigned Int8, not UInt8) - for (size_t i = 4; i < arguments.size(); ++i) - { - if (!checkAndGetDataType(arguments[i].get())) - throw Exception( - fmt::format( - "The {} argument type of function {} is not right, recived={}, expected={}", - argPositionToSequence(i + 1), - func_name, - arguments[i]->getName(), - "UInt8"), - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - } -} - -BitEngineDictionaryEncodeInfo getDictionaryPrerequisite(const ColumnsWithTypeAndName & arguments, const String & func_name) -{ - const auto * database_column = checkAndGetColumnEvenIfConst(arguments[1].column.get()); - const auto * table_column = checkAndGetColumnEvenIfConst(arguments[2].column.get()); - if (!database_column || !table_column) - throw Exception("Function " + func_name + " has illegal argument of database or table", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - - String database = database_column->getDataAt(0).toString(); - 
String table = table_column->getDataAt(0).toString(); - - String dict_name; - bool add_new{false}; - bool tolerant_loss{false}; - if (arguments.size() >= 4) - { - const auto * dict_column = checkAndGetColumnEvenIfConst(arguments[3].column.get()); - if (!dict_column) - throw Exception("Function " + func_name + " has illegal argument of dictionary", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - dict_name = dict_column->getDataAt(0).toString(); - } - if (arguments.size() >= 5) - { - const auto * setting_column = checkAndGetColumnEvenIfConst(arguments[4].column.get()); - add_new = setting_column->getBool(0); - } - if (arguments.size() >= 6) - { - const auto * tolerant_column = checkAndGetColumnEvenIfConst(arguments[5].column.get()); - tolerant_loss = tolerant_column->getBool(0); - } - - return {database, table, dict_name, add_new, tolerant_loss}; -} - -void checkDataTypeAndDictKeyType(const DataTypePtr & data_type) -{ - if (!checkDataTypeForBitEngineEncode(data_type)) - { - throw Exception( - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "BitEngine cannot encode column, it's type: {}. BitEngine can encode type: " - "UInt64, BitMap64, Array(Integer), Array(String), String", - data_type->getName()); - } -} - -/// `DISCARD` here means those keys not-found are discarded by BitEngineDictionary -/// As for user-input data, there are two ways: -/// 1. for bitmap, those keys not-found is discard in the result, so the result bitmap size may be smaller -/// 2. 
for non BitEngine column (like UInt64), those keys are replaced by a default value in result column -/// In both two ways we keep the column size unchanged -ColumnPtr -encodeColumnDiscardUnknown(ColumnWithTypeAndName & column, BitEngineDictionaryEncodeInfo & encode_info, const ContextPtr & local_context) -{ - checkDataTypeAndDictKeyType(column.type); - - auto * storage_bitengine_cloud = getCloudTable(encode_info.database, encode_info.table, MemoryDictMode::ENCODE, local_context); - - BitEngineEncodeSettings encode_settings - = BitEngineEncodeSettings(local_context->getSettingsRef(), storage_bitengine_cloud->getSettings()) - .bitengineEncodeWithoutLock(true) - .bitengineTolerantLoss(encode_info.tolerant_loss); - - ColumnWithTypeAndName column_encoded = storage_bitengine_cloud->getBitEngineDictionaryManager()->encodeColumn( - column, encode_info.dict_name, local_context, encode_settings); - - return column_encoded.column; -} - -/// `ADD` means keys not found in dict will be encoded -ColumnPtr -encodeColumnAddUnknown(ColumnWithTypeAndName & column, BitEngineDictionaryEncodeInfo & encode_info, const ContextPtr & local_context) -{ - checkDataTypeAndDictKeyType(column.type); - - auto * storage_bitengine_cloud = getCloudTable(encode_info.database, encode_info.table, MemoryDictMode::ENCODE, local_context); - - auto column_encoded - = storage_bitengine_cloud->getBitEngineDictionaryManager()->encodeColumn(column, encode_info.dict_name, local_context, {}); - - return column_encoded.column; -} - -template -ColumnPtr decodeColumn( - ColumnWithTypeAndName & column, - BitEngineDictionaryEncodeInfo & encode_info, - bool try_parse_dict_name, - const KeyType & expected_key_type, - const ContextPtr & local_context) -{ - if (!checkDataTypeForBitEngineDecode(column.type)) - throw Exception( - "BitEngine cannot decode column " + column.name + ", since it's type is not UInt64 or Bitmap64", - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - - auto * storage_bitengine_cloud = 
getCloudTable(encode_info.database, encode_info.table, MemoryDictMode::ENCODE, local_context); - - String valid_dict_name = try_parse_dict_name ? getValidDictName(storage_bitengine_cloud, encode_info.dict_name) : encode_info.dict_name; - - auto dict_manager = storage_bitengine_cloud->getBitEngineDictionaryManager(); - - if (expected_key_type == KeyType::KEY_STRING && !dict_manager->isKeyStringDictionary(valid_dict_name)) - { - throw Exception( - "You expect BitEngine KeyType::STRING, but the dict " + valid_dict_name + " is not this type.", - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - } - else if (expected_key_type == KeyType::KEY_INTEGER && dict_manager->isKeyStringDictionary(valid_dict_name)) - { - throw Exception( - "You expect BitEngine KeyType::INTEGER, but the dict " + valid_dict_name + " is not this type.", - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - } - - ColumnPtr column_decoded = dict_manager->decodeColumn(column, valid_dict_name, local_context, {}); - return column_decoded; -} - -} diff --git a/src/Functions/FunctionsConversion.h b/src/Functions/FunctionsConversion.h index 6b40cf4c4fb..46990618e95 100644 --- a/src/Functions/FunctionsConversion.h +++ b/src/Functions/FunctionsConversion.h @@ -499,9 +499,9 @@ struct ConvertImpl && !std { static ColumnPtr execute(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count, UInt16 = 0) { - const auto & time_zone = DateLUT::instance(); + const auto & time_zone = DateLUT::sessionInstance(); const auto today = time_zone.toDayNum(time(nullptr)); - auto date_time = DateLUT::instance().fromDayNum(today); + auto date_time = DateLUT::sessionInstance().fromDayNum(today); if constexpr (std::is_same_v || std::is_same_v) { @@ -516,7 +516,7 @@ struct ConvertImpl && !std auto time_scale = sources->getScale(); auto scale_multiplier = DecimalUtils::scaleMultiplier(time_scale); // result - auto mutable_result_col = result_type->createColumn(); + auto mutable_result_col = 
removeNullable(result_type)->createColumn(); auto * col_to = assert_cast(mutable_result_col.get()); auto & col_to_data = col_to->getData(); col_to_data.resize(input_rows_count); @@ -550,7 +550,8 @@ struct ConvertImpl && !std UInt32(date_time) + components.whole, components.fractional, dt_scale_multiplier); } } - + if (result_type->isNullable()) + return ColumnNullable::create(std::move(mutable_result_col), ColumnUInt8::create(input_rows_count, 0)); return mutable_result_col; } }; @@ -709,7 +710,7 @@ struct ToDate32Transform32Or64Signed static inline NO_SANITIZE_UNDEFINED ToType execute(const FromType & from, const DateLUTImpl & time_zone) { - static const Int32 daynum_min_offset = -static_cast(DateLUT::instance().getDayNumOffsetEpoch()); + static const Int32 daynum_min_offset = -static_cast(time_zone.getDayNumOffsetEpoch()); if (from < daynum_min_offset) return daynum_min_offset; return (from < DATE_LUT_MAX_EXTEND_DAY_NUM) @@ -1020,18 +1021,18 @@ struct FormatImpl template <> struct FormatImpl { - static void execute(const DataTypeDate::FieldType x, WriteBuffer & wb, const DataTypeDate *, const DateLUTImpl *) + static void execute(const DataTypeDate::FieldType x, WriteBuffer & wb, const DataTypeDate *, const DateLUTImpl * time_zone) { - writeDateText(DayNum(x), wb); + writeDateText(DayNum(x), wb, *time_zone); } }; template <> struct FormatImpl { - static void execute(const DataTypeDate32::FieldType x, WriteBuffer & wb, const DataTypeDate32 *, const DateLUTImpl *) + static void execute(const DataTypeDate32::FieldType x, WriteBuffer & wb, const DataTypeDate32 *, const DateLUTImpl * time_zone) { - writeDateText(ExtendedDayNum(x), wb); + writeDateText(ExtendedDayNum(x), wb, *time_zone); } }; @@ -1109,7 +1110,9 @@ struct ConvertImplconvertToFullColumnIfConst(); const DateLUTImpl * time_zone = nullptr; - /// For argument of DateTime type, second argument with time zone could be specified. 
+ if constexpr (std::is_same_v || std::is_same_v) + time_zone = &DateLUT::sessionInstance(); + /// For argument of Date or DateTime type, second argument with time zone could be specified. if constexpr (std::is_same_v || std::is_same_v) time_zone = &extractTimeZoneFromFunctionArguments(arguments, 1, 0); @@ -1264,18 +1267,18 @@ void parseImpl(typename DataType::FieldType & x, ReadBuffer & rb, const DateLUTI } template <> -inline void parseImpl(DataTypeDate::FieldType & x, ReadBuffer & rb, const DateLUTImpl *) +inline void parseImpl(DataTypeDate::FieldType & x, ReadBuffer & rb, const DateLUTImpl * time_zone) { DayNum tmp(0); - readDateText(tmp, rb); + readDateText(tmp, rb, *time_zone); x = tmp; } template <> -inline void parseImpl(DataTypeDate32::FieldType & x, ReadBuffer & rb, const DateLUTImpl *) +inline void parseImpl(DataTypeDate32::FieldType & x, ReadBuffer & rb, const DateLUTImpl * time_zone) { ExtendedDayNum tmp(0); - readDateText(tmp, rb); + readDateText(tmp, rb, *time_zone); x = tmp; } @@ -1322,20 +1325,20 @@ bool tryParseImpl(typename DataType::FieldType & x, ReadBuffer & rb, const DateL } template <> -inline bool tryParseImpl(DataTypeDate::FieldType & x, ReadBuffer & rb, const DateLUTImpl *) +inline bool tryParseImpl(DataTypeDate::FieldType & x, ReadBuffer & rb, const DateLUTImpl * time_zone) { DayNum tmp(0); - if (!tryReadDateText(tmp, rb)) + if (!tryReadDateText(tmp, rb, *time_zone)) return false; x = tmp; return true; } template <> -inline bool tryParseImpl(DataTypeDate32::FieldType & x, ReadBuffer & rb, const DateLUTImpl *) +inline bool tryParseImpl(DataTypeDate32::FieldType & x, ReadBuffer & rb, const DateLUTImpl * time_zone) { ExtendedDayNum tmp(0); - if (!tryReadDateText(tmp, rb)) + if (!tryReadDateText(tmp, rb, *time_zone)) return false; // ExtendedDayNum is int32 and DataTypeData32::FieldType is also int32 // coverity[store_truncates_time_t] @@ -1549,7 +1552,7 @@ struct ConvertThroughParsing const DateLUTImpl * local_time_zone [[maybe_unused]] = 
nullptr; const DateLUTImpl * utc_time_zone [[maybe_unused]] = nullptr; - /// For conversion to DateTime type, second argument with time zone could be specified. + /// For conversion to Date or DateTime type, second argument with time zone could be specified. if constexpr (std::is_same_v || to_datetime64) { const auto result_type = removeNullable(res_type); @@ -1564,6 +1567,12 @@ struct ConvertThroughParsing if constexpr (parsing_mode == ConvertFromStringParsingMode::BestEffort || parsing_mode == ConvertFromStringParsingMode::BestEffortUS) utc_time_zone = &DateLUT::instance("UTC"); } + else if constexpr (std::is_same_v || std::is_same_v) + { + // Timezone is more or less dummy when parsing Date/Date32 from string. + local_time_zone = &DateLUT::sessionInstance(); + utc_time_zone = &DateLUT::instance("UTC"); + } const IColumn * col_from = arguments[0].column.get(); const ColumnString * col_from_string = checkAndGetColumn(col_from); @@ -1807,7 +1816,7 @@ struct ConvertThroughParsing { if constexpr (std::is_same_v) { - vec_to[i] = -static_cast(DateLUT::instance().getDayNumOffsetEpoch()); + vec_to[i] = -static_cast(DateLUT::sessionInstance().getDayNumOffsetEpoch()); } else { @@ -2230,7 +2239,7 @@ class FunctionConvert : public IFunction || std::is_same_v // toDate(value[, timezone : String]) || std::is_same_v // TODO: shall we allow timestamp argument for toDate? DateTime knows nothing about timezones and this argument is ignored below. 
- // toDate(value[, timezone : String]) + // toDate32(value[, timezone : String]) || std::is_same_v // toDateTime(value[, timezone: String]) || std::is_same_v diff --git a/src/Functions/FunctionsHashing.h b/src/Functions/FunctionsHashing.h index 22c3ce0259c..4fea62d135c 100644 --- a/src/Functions/FunctionsHashing.h +++ b/src/Functions/FunctionsHashing.h @@ -1610,13 +1610,39 @@ class FunctionAnyHash : public IFunction if (const ColumnNullable * nullable = typeid_cast(column)) { const IColumn * nullable_column = &nullable->getNestedColumn(); - executeAny(key, nullable_type, nullable_column, vec_to); const auto & null_map_data = nullable->getNullMapData(); auto s = nullable_column->size(); - /// Use fixed data for nulls. - for (size_t row = 0; row < s; ++row) - if (null_map_data[row]) - vec_to[row] = value; + if (first) + { + executeAny(key, nullable_type, nullable_column, vec_to); + /// Use fixed data for nulls. + for (size_t row = 0; row < s; ++row) + if (null_map_data[row]) + vec_to[row] = value; + } + else + { + std::vector null_list; + std::vector null_list_value_before; + null_list.reserve(s); + null_list_value_before.reserve(s); + for (size_t row = 0; row < s; ++row) + { + if (null_map_data[row]) + { + null_list.push_back(row); + null_list_value_before.push_back(vec_to[row]); + } + } + executeAny(key, nullable_type, nullable_column, vec_to); + + for (size_t i = 0; i < null_list.size(); ++i) + { + size_t row = null_list[i]; + ToType value_before = null_list_value_before[i]; + vec_to[row] = value_before; + } + } } // else if (const ColumnNullable * nullable_const = checkAndGetColumnConstData(column)) // { diff --git a/src/Functions/FunctionsJSON.cpp b/src/Functions/FunctionsJSON.cpp index aa2a116c0ba..0dc9f3488f3 100644 --- a/src/Functions/FunctionsJSON.cpp +++ b/src/Functions/FunctionsJSON.cpp @@ -177,6 +177,7 @@ REGISTER_FUNCTION(JSON) factory.registerFunction>(); factory.registerFunction>(); 
factory.registerFunction>(FunctionFactory::CaseSensitiveness::CaseInsensitive); + factory.registerAlias("JSON_UNQUOTE", "JSONUnquote", FunctionFactory::CaseInsensitive); } } diff --git a/src/Functions/FunctionsJSON.h b/src/Functions/FunctionsJSON.h index b5a7cf6158d..074715c792c 100644 --- a/src/Functions/FunctionsJSON.h +++ b/src/Functions/FunctionsJSON.h @@ -34,6 +34,7 @@ #include #include #include +#include "Core/SettingsEnums.h" #include #include #include @@ -107,11 +108,13 @@ class FunctionJSONHelpers throw Exception{"Function " + String(Name::name) + " requires at least one argument", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH}; const auto & first_column = arguments[0]; - if (!isString(first_column.type)) + auto first_type_base = removeNullable(removeLowCardinality(first_column.type)); + + if (!isString(first_type_base)) throw Exception{"The first argument of function " + String(Name::name) + " should be a string containing JSON, illegal type: " + first_column.type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; - - const ColumnPtr & arg_json = first_column.column; + + const ColumnPtr & arg_json = recursiveAssumeNotNullable(first_column.column); const auto * col_json_const = typeid_cast(arg_json.get()); const auto * col_json_string = typeid_cast(col_json_const ? 
col_json_const->getDataColumnPtr().get() : arg_json.get()); @@ -316,6 +319,8 @@ class FunctionJSONHelpers const Element & getElement() const { return element; } std::string_view getLastKey() const { return last_key; } + using JSONParserType = JSONParser; + private: Element element; std::string_view last_key; @@ -618,7 +623,6 @@ class ExecutableFunctionJSONTuple : public ExecutableFunctionJSONBase class FunctionBaseFunctionJSON : public IFunctionBase { @@ -939,7 +943,7 @@ class JSONHasImpl static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool insertResultToColumn(IColumn & dest, const Iterator &) + static bool insertResultToColumn(IColumn & dest, Iterator &) { auto & col_vec = assert_cast &>(dest); col_vec.insertValue(1); @@ -967,7 +971,7 @@ class IsValidJSONImpl static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName &) { return 0; } - static bool insertResultToColumn(IColumn & dest, const Iterator &) + static bool insertResultToColumn(IColumn & dest, Iterator &) { /// This function is called only if JSON is valid. /// If JSON isn't valid then `FunctionJSON::Executor::run()` adds default value (=zero) to `dest` without calling this function. 
@@ -990,7 +994,7 @@ class JSONLengthImpl static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsElementIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsElementIterator { const auto & element = iterator.getElement(); @@ -1007,7 +1011,7 @@ class JSONLengthImpl return true; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsObjectIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsObjectIterator { auto & to_vec = assert_cast &>(dest); @@ -1049,7 +1053,7 @@ class JSONKeyImpl static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) { auto last_key = iterator.getLastKey(); if (last_key.empty()) @@ -1083,7 +1087,7 @@ class JSONTypeImpl static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsElementIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsElementIterator { const auto & element = iterator.getElement(); UInt8 type; @@ -1111,7 +1115,7 @@ class JSONTypeImpl return true; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsObjectIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsObjectIterator { WhichDataType which(iterator.getType()); UInt8 type; @@ -1149,7 +1153,7 @@ class JSONExtractNumericImpl static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool 
insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsElementIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsElementIterator { const auto & element = iterator.getElement(); NumberType value; @@ -1188,7 +1192,7 @@ class JSONExtractNumericImpl return true; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsObjectIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsObjectIterator { const auto & from = iterator.getColumn(); UInt64 row = iterator.getRow(); @@ -1230,7 +1234,7 @@ class JSONExtractBoolImpl static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsElementIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsElementIterator { const auto & element = iterator.getElement(); if (!element.isBool()) @@ -1241,7 +1245,7 @@ class JSONExtractBoolImpl return true; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsObjectIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsObjectIterator { return JSONExtractUInt8Impl::insertResultToColumn(dest, iterator); } @@ -1258,20 +1262,24 @@ class JSONExtractRawImpl static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsElementIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator, DialectType dialect_type = DialectType::CLICKHOUSE) requires IsElementIterator { const auto & element = iterator.getElement(); auto & col_str = assert_cast(dest); auto & chars = col_str.getChars(); WriteBufferFromVector buf(chars, AppendModeTag()); - Traverse::traverse(element, buf); 
+ if (dialect_type == DialectType::MYSQL) + Traverse::traverse(element, buf, true); + else + Traverse::traverse(element, buf); + buf.finalize(); chars.push_back(0); col_str.getOffsets().push_back(chars.size()); return true; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsObjectIterator + static bool insertResultToColumn(IColumn & dest, ObjectIterator & iterator, DialectType dialect_type = DialectType::CLICKHOUSE) requires IsObjectIterator { const auto & type = iterator.getType(); const auto & column = iterator.getColumn(); @@ -1284,10 +1292,13 @@ class JSONExtractRawImpl WriteBufferFromVector buf(to_chars, AppendModeTag{}); + const auto & format_setting + = dialect_type == DialectType::MYSQL ? Traverse::unquote_format_settings() : Traverse::format_settings(); + if (isDummyTuple(*type)) writeString("{}", buf); else - serialization->serializeTextJSON(*column, row, buf, Traverse::format_settings()); + serialization->serializeTextJSON(*column, row, buf, format_setting); writeChar(0, buf); buf.finalize(); @@ -1314,7 +1325,7 @@ class JSONUnquoteImpl col_str.insertData(json_data.data(), json_data.size()); } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsElementIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsElementIterator { const auto & element = iterator.getElement(); auto & col_str = assert_cast(dest); @@ -1322,7 +1333,7 @@ class JSONUnquoteImpl return true; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsObjectIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsObjectIterator { const auto & type = iterator.getType(); const auto & column = iterator.getColumn(); @@ -1359,7 +1370,7 @@ class JSONExtractStringImpl static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool 
insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsElementIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsElementIterator { const auto & element = iterator.getElement(); if (!element.isString()) @@ -1371,7 +1382,7 @@ class JSONExtractStringImpl return true; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsObjectIterator + static bool insertResultToColumn(IColumn & dest, ObjectIterator & iterator) requires IsObjectIterator { const auto & column = iterator.getColumn(); UInt64 row = iterator.getRow(); @@ -1396,14 +1407,14 @@ struct JSONExtractTree public: Node() = default; virtual ~Node() = default; - virtual bool insertResultToColumn(IColumn &, const Iterator &) = 0; + virtual bool insertResultToColumn(IColumn &, Iterator &) = 0; }; template class NumericNode : public Node { public: - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { return JSONExtractNumericImpl::insertResultToColumn(dest, iterator); } @@ -1417,7 +1428,7 @@ struct JSONExtractTree { } - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { auto from_col = dictionary_type->createColumn(); if (impl->insertResultToColumn(*from_col, iterator)) @@ -1436,7 +1447,7 @@ struct JSONExtractTree class UUIDNodeString : public Node { public: - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { const auto & element = iterator.getElement(); if (!element.isString()) @@ -1451,7 +1462,7 @@ struct JSONExtractTree class UUIDNodeObject : public Node { public: - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { 
if (!isString(iterator.getType())) return false; @@ -1473,7 +1484,7 @@ struct JSONExtractTree public: explicit DecimalNodeString(DataTypePtr data_type_) : data_type(std::move(data_type_)) {} - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { const auto & element = iterator.getElement(); const auto * type = assert_cast *>(data_type.get()); @@ -1503,7 +1514,7 @@ struct JSONExtractTree public: explicit DecimalNodeObject(DataTypePtr data_type_) : data_type(std::move(data_type_)) {} - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { const auto * decimal_type = assert_cast *>(data_type.get()); const auto & from = iterator.getColumn(); @@ -1533,7 +1544,7 @@ struct JSONExtractTree class StringNodeString : public Node { public: - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { const auto & element = iterator.getElement(); if (element.isString()) @@ -1548,7 +1559,7 @@ struct JSONExtractTree class StringNodeObject : public Node { public: - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { return JSONExtractStringImpl::insertResultToColumn(dest, iterator); } @@ -1559,7 +1570,7 @@ struct JSONExtractTree class FixedStringNodeString : public Node { public: - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { const auto & element = iterator.getElement(); if (!element.isString()) @@ -1576,7 +1587,7 @@ struct JSONExtractTree class FixedStringNodeObject : public Node { public: - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool 
insertResultToColumn(IColumn & dest, Iterator & iterator) override { if (!isString(iterator.getType())) return false; @@ -1622,7 +1633,7 @@ struct JSONExtractTree { } - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { const auto & element = iterator.getElement(); auto & col_vec = assert_cast &>(dest); @@ -1667,7 +1678,7 @@ struct JSONExtractTree { } - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { auto & to_vec = assert_cast &>(dest); const auto & from = iterator.getColumn(); @@ -1711,7 +1722,7 @@ struct JSONExtractTree public: explicit NullableNode(std::unique_ptr nested_) : nested(std::move(nested_)) {} - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { auto & col_null = assert_cast(dest); if (!nested->insertResultToColumn(col_null.getNestedColumn(), iterator)) @@ -1729,7 +1740,7 @@ struct JSONExtractTree public: explicit ArrayNodeString(std::unique_ptr nested_) : nested(std::move(nested_)) {} - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { const auto & element = iterator.getElement(); if (!element.isArray()) @@ -1744,7 +1755,8 @@ struct JSONExtractTree for (auto value : array) { - if (nested->insertResultToColumn(data, Iterator{value})) + auto temp_it = Iterator{value}; + if (nested->insertResultToColumn(data, temp_it)) were_valid_elements = true; else data.insertDefault(); @@ -1769,7 +1781,7 @@ struct JSONExtractTree public: explicit ArrayNodeObject(std::unique_ptr nested_) : nested(std::move(nested_)) {} - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & 
iterator) override { const auto * from_array = typeid_cast(iterator.getColumn().get()); if (!from_array) @@ -1834,7 +1846,7 @@ struct JSONExtractTree { } - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { ColumnTuple & tuple = assert_cast(dest); size_t old_size = dest.size(); @@ -1865,7 +1877,8 @@ struct JSONExtractTree for (size_t index = 0; (index != this->nested.size()) && (it != array.end()); ++index) { - if (this->nested[index]->insertResultToColumn(tuple.getColumn(index), Iterator{*it++})) + auto temp_it = Iterator{*it++}; + if (this->nested[index]->insertResultToColumn(tuple.getColumn(index), temp_it)) were_valid_elements = true; else tuple.getColumn(index).insertDefault(); @@ -1882,7 +1895,8 @@ struct JSONExtractTree auto it = object.begin(); for (size_t index = 0; (index != this->nested.size()) && (it != object.end()); ++index) { - if (this->nested[index]->insertResultToColumn(tuple.getColumn(index), Iterator{(*it++).second})) + auto temp_it = Iterator{(*it++).second}; + if (this->nested[index]->insertResultToColumn(tuple.getColumn(index), temp_it)) were_valid_elements = true; else tuple.getColumn(index).insertDefault(); @@ -1895,7 +1909,8 @@ struct JSONExtractTree auto index = this->name_to_index_map.find(key); if (index != this->name_to_index_map.end()) { - if (this->nested[index->second]->insertResultToColumn(tuple.getColumn(index->second), Iterator{value})) + auto temp_it = Iterator{value}; + if (this->nested[index->second]->insertResultToColumn(tuple.getColumn(index->second), temp_it)) were_valid_elements = true; } } @@ -1916,7 +1931,7 @@ struct JSONExtractTree { } - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { auto & to_tuple = assert_cast(dest); size_t old_size = dest.size(); @@ -1984,7 +1999,7 @@ struct JSONExtractTree public: 
MapNodeString(std::unique_ptr key_, std::unique_ptr value_) : MapNodeBase(std::move(key_), std::move(value_)) { } - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { const auto & element = iterator.getElement(); if (!element.isObject()) @@ -2006,7 +2021,8 @@ struct JSONExtractTree key_col.insertData(pair.first.data(), pair.first.size()); /// Insert value - if (!this->value->insertResultToColumn(value_col, Iterator{pair.second})) + auto temp_it = Iterator{pair.second}; + if (!this->value->insertResultToColumn(value_col, temp_it)) value_col.insertDefault(); } @@ -2020,7 +2036,7 @@ struct JSONExtractTree public: MapNodeObject(std::unique_ptr key_, std::unique_ptr value_) : MapNodeBase(std::move(key_), std::move(value_)) { } - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) override + bool insertResultToColumn(IColumn & dest, Iterator & iterator) override { const auto * from_tuple_type = typeid_cast(iterator.getType().get()); if (!from_tuple_type) @@ -2152,7 +2168,7 @@ class JSONExtractImpl extract_tree = JSONExtractTree::build(function_name, result_type); } - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) + bool insertResultToColumn(IColumn & dest, Iterator & iterator) { return extract_tree->insertResultToColumn(dest, iterator); } @@ -2195,7 +2211,7 @@ class JSONExtractKeysAndValuesImpl extract_tree = JSONExtractTree::build(function_name, value_type); } - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsElementIterator + bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsElementIterator { const auto & element = iterator.getElement(); if (!element.isObject()) @@ -2223,7 +2239,7 @@ class JSONExtractKeysAndValuesImpl return true; } - bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsObjectIterator + bool insertResultToColumn(IColumn & dest, 
Iterator & iterator) requires IsObjectIterator { const auto * from_tuple_type = typeid_cast(iterator.getType().get()); if (!from_tuple_type || !from_tuple_type->haveExplicitNames()) @@ -2273,7 +2289,7 @@ class GetJsonObjectImpl static size_t getResolveArgumentIndex(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsElementIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsElementIterator { const auto & element = iterator.getElement(); ColumnString & col_str = assert_cast(dest); @@ -2286,7 +2302,7 @@ class GetJsonObjectImpl return true; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsObjectIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsObjectIterator { const auto & type = iterator.getType(); const auto & column = iterator.getColumn(); @@ -2324,7 +2340,7 @@ class JSONExtractArrayRawImpl static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsElementIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsElementIterator { const auto & element = iterator.getElement(); if (!element.isArray()) @@ -2384,7 +2400,7 @@ class JSONExtractKeysAndValuesRawImpl static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsElementIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsElementIterator { const auto & element = iterator.getElement(); if (!element.isObject()) @@ 
-2450,7 +2466,7 @@ class JSONExtractKeysImpl static size_t getNumberOfIndexArguments(const ColumnsWithTypeAndName & arguments) { return arguments.size() - 1; } - static bool insertResultToColumn(IColumn & dest, const Iterator & iterator) requires IsElementIterator + static bool insertResultToColumn(IColumn & dest, Iterator & iterator) requires IsElementIterator { const auto & element = iterator.getElement(); if (!element.isObject()) @@ -2470,7 +2486,7 @@ class JSONExtractKeysImpl return true; } - bool insertResultToColumn(IColumn & dest, const ObjectIterator & iterator) requires IsObjectIterator + static bool insertResultToColumn(IColumn & dest, ObjectIterator & iterator) requires IsObjectIterator { const auto * type_tuple = typeid_cast(iterator.getType().get()); if (!type_tuple || !type_tuple->haveExplicitNames()) diff --git a/src/Functions/HasTokenImpl.h b/src/Functions/HasTokenImpl.h index ac5303e361e..498cb73dcbd 100644 --- a/src/Functions/HasTokenImpl.h +++ b/src/Functions/HasTokenImpl.h @@ -18,7 +18,7 @@ namespace ErrorCodes /** Token search the string, means that needle must be surrounded by some separator chars, like whitespace or puctuation. 
*/ -template +template struct HasTokenImpl { using ResultType = UInt8; @@ -47,7 +47,8 @@ struct HasTokenImpl const UInt8 * const end = haystack_data.data() + haystack_data.size(); const UInt8 * pos = begin; - if (const auto has_separator = std::any_of(pattern.cbegin(), pattern.cend(), isTokenSeparator); has_separator || pattern.empty()) + if (const auto has_separator = std::any_of(pattern.cbegin(), pattern.cend(), isTokenSeparator); + (has_separator && !enable_separator_inside) || pattern.empty()) { if (res_null) { @@ -74,8 +75,8 @@ struct HasTokenImpl while (pos < end && end != (pos = searcher.search(pos, end - pos))) { /// The found substring is a token - if ((pos == begin || isTokenSeparator(pos[-1])) - && (pos + pattern_size == end || isTokenSeparator(pos[pattern_size]))) + if (enable_separator_inside + || ((pos == begin || isTokenSeparator(pos[-1])) && (pos + pattern_size == end || isTokenSeparator(pos[pattern_size])))) { /// Let's determine which index it refers to. while (begin + haystack_offsets[i] <= pos) diff --git a/src/Functions/IFunction.h b/src/Functions/IFunction.h index 95a938d03fc..fd866d393e9 100644 --- a/src/Functions/IFunction.h +++ b/src/Functions/IFunction.h @@ -163,7 +163,16 @@ class IFunctionBase /// Do preparations and return executable. /// sample_columns should contain data types of arguments and values of constants, if relevant. - virtual ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName & arguments) const = 0; + virtual ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName & /*arguments*/) const + { + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "prepare is not implemented for function {}", getName()); + } + + /// Do preparations with extra parameters and return executable. 
+ virtual ExecutableFunctionPtr prepareWithParameters(const ColumnsWithTypeAndName & /*arguments*/, const Array & /*parameters*/) const + { + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "prepare with extra parameters is not implemented for function {}", getName()); + } #if USE_EMBEDDED_COMPILER diff --git a/src/Functions/IFunctionCustomWeek.h b/src/Functions/IFunctionCustomWeek.h index f1408acf73e..9f81410fdfa 100644 --- a/src/Functions/IFunctionCustomWeek.h +++ b/src/Functions/IFunctionCustomWeek.h @@ -55,7 +55,7 @@ class IFunctionCustomWeek : public IFunction const IFunction::Monotonicity is_not_monotonic; /// This method is called only if the function has one argument. Therefore, we do not care about the non-local time zone. - const DateLUTImpl & date_lut = DateLUT::instance(); + const DateLUTImpl & date_lut = DateLUT::sessionInstance(); if (left.isNull() || right.isNull()) return {}; diff --git a/src/Functions/IFunctionDateOrDateTime.h b/src/Functions/IFunctionDateOrDateTime.h index 600faba068c..b40705cbfad 100644 --- a/src/Functions/IFunctionDateOrDateTime.h +++ b/src/Functions/IFunctionDateOrDateTime.h @@ -63,7 +63,7 @@ class IFunctionDateOrDateTime : public IFunction const IFunction::Monotonicity is_monotonic(/* is_monotonic */ true, /* is_positive */ true, /* is_always_monotonic */false); const IFunction::Monotonicity is_not_monotonic; - const DateLUTImpl * date_lut = &DateLUT::instance(); + const DateLUTImpl * date_lut = &DateLUT::sessionInstance(); if (const auto * timezone = dynamic_cast(&type)) date_lut = &timezone->getTimeZone(); diff --git a/src/Functions/JSONPath/Generator/ObjectJSONVisitorJSONPathMemberAccess.h b/src/Functions/JSONPath/Generator/ObjectJSONVisitorJSONPathMemberAccess.h index 0caa7f38c8b..0dc8e82930f 100644 --- a/src/Functions/JSONPath/Generator/ObjectJSONVisitorJSONPathMemberAccess.h +++ b/src/Functions/JSONPath/Generator/ObjectJSONVisitorJSONPathMemberAccess.h @@ -17,7 +17,11 @@ class ObjectJSONVisitorJSONPathMemberAccess : public 
IObjectJSONVisitor VisitorStatus apply(ObjectIterator & iterator) override { const auto * type_tuple = typeid_cast(iterator.getType().get()); + if (!type_tuple || !type_tuple->haveExplicitNames()) + return VisitorStatus::Error; auto pos = type_tuple->tryGetPositionByName(member_access_ptr->member_name); + if (!pos) + return VisitorStatus::Error; const auto& type = type_tuple->getElement(*pos); auto subcolumn = assert_cast(*iterator.getColumn()).getColumnPtr(*pos); iterator.setType(type); diff --git a/src/Functions/JSONPath/Parsers/ParserJSONPathArrayIndex.cpp b/src/Functions/JSONPath/Parsers/ParserJSONPathArrayIndex.cpp new file mode 100644 index 00000000000..2a533984d64 --- /dev/null +++ b/src/Functions/JSONPath/Parsers/ParserJSONPathArrayIndex.cpp @@ -0,0 +1,51 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +/** + * + * @param pos token iterator + * @param node node of ParserJSONPathArrayIndex + * @param expected stuff for logging + * @return was parse successful + * '$.a.1' -> is_start_with_dot = true + * '$.1' -> is_start_with_dot = false + */ +bool ParserJSONPathArrayIndex::parseImpl(Pos & pos, ASTPtr & node, Expected & /*expected*/) +{ + bool is_start_with_dot = false; + if (pos->type == TokenType::Dot) + { + is_start_with_dot = true; + ++pos; + } + + if (pos->type != TokenType::Number) + return false; + + auto range = std::make_shared(); + node = range; + + std::pair range_indices; + + std::string number_str; + number_str.assign(is_start_with_dot ? 
pos->begin : pos->begin + 1, pos->end); + UInt32 index; + if (!Poco::NumberParser::tryParseUnsigned(number_str, index)) + return false; + range_indices.first = index; + range_indices.second = range_indices.first + 1; + range->ranges.push_back(std::move(range_indices)); + + ++pos; + + return !range->ranges.empty(); +} + +} diff --git a/src/Functions/JSONPath/Parsers/ParserJSONPathArrayIndex.h b/src/Functions/JSONPath/Parsers/ParserJSONPathArrayIndex.h new file mode 100644 index 00000000000..7b0836f54d8 --- /dev/null +++ b/src/Functions/JSONPath/Parsers/ParserJSONPathArrayIndex.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace DB +{ +class ParserJSONPathArrayIndex : public IParserBase +{ + const char * getName() const override { return "ParserJSONPathArrayIndex"; } + + bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; +}; + +} diff --git a/src/Functions/JSONPath/Parsers/ParserJSONPathQuery.cpp b/src/Functions/JSONPath/Parsers/ParserJSONPathQuery.cpp index c18b2ad9b31..b43063ed1f4 100644 --- a/src/Functions/JSONPath/Parsers/ParserJSONPathQuery.cpp +++ b/src/Functions/JSONPath/Parsers/ParserJSONPathQuery.cpp @@ -4,6 +4,7 @@ #include #include #include +#include namespace DB @@ -22,6 +23,7 @@ bool ParserJSONPathQuery::parseImpl(Pos & pos, ASTPtr & query, Expected & expect ParserJSONPathRange parser_jsonpath_range; ParserJSONPathStar parser_jsonpath_star; ParserJSONPathRoot parser_jsonpath_root; + ParserJSONPathArrayIndex parser_jsonpath_array_index; ASTPtr path_root; if (!parser_jsonpath_root.parse(pos, path_root, expected)) @@ -33,7 +35,8 @@ bool ParserJSONPathQuery::parseImpl(Pos & pos, ASTPtr & query, Expected & expect ASTPtr accessor; while (parser_jsonpath_member_access.parse(pos, accessor, expected) || parser_jsonpath_range.parse(pos, accessor, expected) - || parser_jsonpath_star.parse(pos, accessor, expected)) + || parser_jsonpath_star.parse(pos, accessor, expected) + || parser_jsonpath_array_index.parse(pos, accessor, expected)) { if 
(accessor) { diff --git a/src/Functions/JSONPath/Parsers/ParserJSONPathRange.cpp b/src/Functions/JSONPath/Parsers/ParserJSONPathRange.cpp index b51e59ac2f3..aeb7197747e 100644 --- a/src/Functions/JSONPath/Parsers/ParserJSONPathRange.cpp +++ b/src/Functions/JSONPath/Parsers/ParserJSONPathRange.cpp @@ -43,6 +43,10 @@ namespace ErrorCodes */ bool ParserJSONPathRange::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { + if (pos->type == TokenType::Dot) + { + ++pos; + } if (pos->type != TokenType::OpeningSquareBracket) { diff --git a/src/Functions/array/arrayIndex.h b/src/Functions/array/arrayIndex.h index 0eae4030f27..eb5c8b9e875 100644 --- a/src/Functions/array/arrayIndex.h +++ b/src/Functions/array/arrayIndex.h @@ -365,12 +365,13 @@ class FunctionArrayIndex : public IFunction static constexpr auto name = Name::name; static FunctionPtr create(ContextPtr context) { + bool return_nullable_array = !context || context->getSettingsRef().allow_return_nullable_array; if (context && context->getSettingsRef().enable_implicit_arg_type_convert) - return std::make_shared(std::make_unique(context)); - return std::make_shared(context); + return std::make_shared(std::make_unique(context, return_nullable_array)); + return std::make_shared(context, return_nullable_array); } - explicit FunctionArrayIndex(ContextPtr context_) : context(context_) {} + explicit FunctionArrayIndex(ContextPtr context_, bool return_nullable_array_ = true) : context(context_), allow_return_nullable_array(return_nullable_array_) {} ArgType getArgumentsType() const override { return ArgType::ARRAY_FIRST; } @@ -400,7 +401,7 @@ class FunctionArrayIndex : public IFunction "numeric types, or Enum and numeric type. 
Passed: {} and {}.", getName(), arguments[0]->getName(), arguments[1]->getName()); - if (nullable_type || (context && context->getSettingsRef().enable_implicit_arg_type_convert && arguments[1]->isNullable())) + if (allow_return_nullable_array && (nullable_type || (context && context->getSettingsRef().enable_implicit_arg_type_convert && arguments[1]->isNullable()))) return std::make_shared(std::make_shared>()); return std::make_shared>(); @@ -420,7 +421,8 @@ class FunctionArrayIndex : public IFunction { ColumnsWithTypeAndName tmp_args = {{nullable_col->getNestedColumnPtr(), removeNullable(arguments[0].type), arguments[0].name}, arguments[1]}; auto res = executeInternalImpl(tmp_args, removeNullable(result_type), input_rows_count); - return wrapInNullable(res, arguments, result_type, input_rows_count); + return allow_return_nullable_array ? wrapInNullable(res, arguments, result_type, input_rows_count) + : res; } return executeInternalImpl(arguments, result_type, input_rows_count); @@ -521,6 +523,7 @@ class FunctionArrayIndex : public IFunction private: ContextPtr context; + bool allow_return_nullable_array; using ResultType = typename ConcreteAction::ResultType; using ResultColumnType = ColumnVector; using ResultColumnPtr = decltype(ResultColumnType::create()); diff --git a/src/Functions/currentTime.cpp b/src/Functions/currentTime.cpp index ec60c3e3285..3871d04b343 100644 --- a/src/Functions/currentTime.cpp +++ b/src/Functions/currentTime.cpp @@ -130,7 +130,7 @@ class CurrentTimeOverloadResolver : public IFunctionOverloadResolver } DateTime64 dt64 = DB::nowSubsecondDt64(scale); ToTimeTransform transformer(scale, scale); - Decimal64::NativeType t = transformer.execute(dt64, intExp10(scale), DateLUT::instance()); + Decimal64::NativeType t = transformer.execute(dt64, intExp10(scale), DateLUT::sessionInstance()); return std::make_unique(t, scale, std::make_shared(scale)); } }; diff --git a/src/Functions/dateDiff.cpp b/src/Functions/dateDiff.cpp index 
49a54658c70..e5c0ee12fdb 100644 --- a/src/Functions/dateDiff.cpp +++ b/src/Functions/dateDiff.cpp @@ -444,7 +444,7 @@ class DateDiffImpl */ // Note: It is impossible to take the civil-time diff of 2 different timezones to mysql's timestampdiff. // Because, mysql will convert both absolute times to the same timezone given by the session time_zone variable. - const DateLUTImpl & date_lut = DateLUT::instance(); + const DateLUTImpl & date_lut = DateLUT::sessionInstance(); bool should_swap = seconds_x > seconds_y; struct DateTimeComponents { const DateLUTImpl::Values & values; @@ -803,13 +803,13 @@ class FunctionTimeDiff : public IFunction { auto res = ColumnTime::create(rows, 0); - impl.dispatchForColumns(x, y, DateLUT::instance(), DateLUT::instance(), res->getData()); + impl.dispatchForColumns(x, y, DateLUT::sessionInstance(), DateLUT::sessionInstance(), res->getData()); return res; } else { auto res = ColumnInt64::create(rows); - impl.dispatchForColumns(x, y, DateLUT::instance(), DateLUT::instance(), res->getData()); + impl.dispatchForColumns(x, y, DateLUT::sessionInstance(), DateLUT::sessionInstance(), res->getData()); return res; } } diff --git a/src/Functions/dateName.cpp b/src/Functions/dateName.cpp index ae3a4edb32e..b55cc9e01b0 100644 --- a/src/Functions/dateName.cpp +++ b/src/Functions/dateName.cpp @@ -146,7 +146,7 @@ class FunctionDateNameImpl : public IFunction if (std::is_same_v || std::is_same_v) time_zone_tmp = &extractTimeZoneFromFunctionArguments(arguments, 2, 1); else - time_zone_tmp = &DateLUT::instance(); + time_zone_tmp = &DateLUT::sessionInstance(); const auto & times_data = times->getData(); const DateLUTImpl & time_zone = *time_zone_tmp; diff --git a/src/Functions/extractTimeZoneFromFunctionArguments.cpp b/src/Functions/extractTimeZoneFromFunctionArguments.cpp index b6646c6d252..fdc4c930f64 100644 --- a/src/Functions/extractTimeZoneFromFunctionArguments.cpp +++ b/src/Functions/extractTimeZoneFromFunctionArguments.cpp @@ -66,7 +66,7 @@ const 
DateLUTImpl & extractTimeZoneFromFunctionArguments(const ColumnsWithTypeAn else { if (arguments.size() <= datetime_arg_num) - return DateLUT::instance(); + return DateLUT::sessionInstance(); const auto & dt_arg = arguments[datetime_arg_num].type.get(); /// If time zone is attached to an argument of type DateTime. @@ -75,7 +75,7 @@ const DateLUTImpl & extractTimeZoneFromFunctionArguments(const ColumnsWithTypeAn if (const auto * type = checkAndGetDataType(dt_arg)) return type->getTimeZone(); - return DateLUT::instance(); + return DateLUT::sessionInstance(); } } diff --git a/src/Functions/formatDateTime.cpp b/src/Functions/formatDateTime.cpp index e1f33899e6d..676ee02f638 100644 --- a/src/Functions/formatDateTime.cpp +++ b/src/Functions/formatDateTime.cpp @@ -1320,7 +1320,7 @@ namespace else if (std::is_same_v || std::is_same_v) time_zone_tmp = &extractTimeZoneFromFunctionArguments(arguments, 2, 0); else - time_zone_tmp = &DateLUT::instance(); + time_zone_tmp = &DateLUT::sessionInstance(); const DateLUTImpl & time_zone = *time_zone_tmp; const auto & vec = times->getData(); diff --git a/src/Functions/fromDaysAndToDays.cpp b/src/Functions/fromDaysAndToDays.cpp index d125d7caebf..8676f8602e5 100644 --- a/src/Functions/fromDaysAndToDays.cpp +++ b/src/Functions/fromDaysAndToDays.cpp @@ -109,7 +109,7 @@ namespace { const auto * col_from = checkAndGetColumn(column); - static const Int32 daynum_min_offset = -static_cast(DateLUT::instance().getDayNumOffsetEpoch()); + static const Int32 daynum_min_offset = -static_cast(DateLUT::sessionInstance().getDayNumOffsetEpoch()); MutableColumnPtr res = DataTypeDate32().createColumn(); auto & res_data = dynamic_cast *>(res.get())->getData(); @@ -239,7 +239,7 @@ namespace if (col == nullptr) throw Exception("Column type does not match to the data type", ErrorCodes::ILLEGAL_COLUMN); - const auto & timezone = DateLUT::instance(); + const auto & timezone = DateLUT::sessionInstance(); auto & data = col->getData(); const auto row_size = 
data.size(); res_data.resize(row_size); diff --git a/src/Functions/hasToken.cpp b/src/Functions/hasToken.cpp index 2ef4fffb3a9..77fded692ae 100644 --- a/src/Functions/hasToken.cpp +++ b/src/Functions/hasToken.cpp @@ -12,6 +12,11 @@ struct NameHasToken static constexpr auto name = "hasToken"; }; +struct NameHasTokens +{ + static constexpr auto name = "hasTokens"; +}; + struct NameHasTokenOrNull { static constexpr auto name = "hasTokenOrNull"; @@ -19,13 +24,15 @@ struct NameHasTokenOrNull using FunctionHasToken = FunctionsStringSearch>; +using FunctionHasTokens + = FunctionsStringSearch>; using FunctionHasTokenOrNull = FunctionsStringSearch, ExecutionErrorPolicy::Null>; REGISTER_FUNCTION(HasToken) { factory.registerFunction(FunctionFactory::CaseSensitive); - + factory.registerFunction(FunctionFactory::CaseSensitive); factory.registerFunction(FunctionFactory::CaseSensitive); } diff --git a/src/Functions/makeDate.cpp b/src/Functions/makeDate.cpp index 5efeb6ad822..20925291efa 100644 --- a/src/Functions/makeDate.cpp +++ b/src/Functions/makeDate.cpp @@ -131,7 +131,7 @@ namespace auto res_column = Traits::ReturnDataType::ColumnType::create(input_rows_count); auto & result_data = res_column->getData(); - const auto & date_lut = DateLUT::instance(); + const auto & date_lut = DateLUT::sessionInstance(); const Int32 max_days_since_epoch = date_lut.makeDayNum(Traits::MAX_DATE[0], Traits::MAX_DATE[1], Traits::MAX_DATE[2]); if (is_year_month_day_variant) @@ -573,7 +573,7 @@ namespace const auto & minute_data = typeid_cast(*converted_arguments[1]).getData(); const auto & second_data = typeid_cast(*converted_arguments[2]).getData(); - const auto & date_lut = DateLUT::instance(); + const auto & date_lut = DateLUT::sessionInstance(); const auto max_fraction = pow(10, precision) - 1; const auto min_time = minTime(date_lut); const auto max_time = maxTime(date_lut); diff --git a/src/Functions/modulo.cpp b/src/Functions/modulo.cpp index a4b001d46a5..e26224e671a 100644 --- 
a/src/Functions/modulo.cpp +++ b/src/Functions/modulo.cpp @@ -197,13 +197,13 @@ REGISTER_FUNCTION(ModuloLegacy) factory.registerFunction(); } -struct NameBucket {static constexpr auto name = "bucket"; }; -using FunctionBucket = BinaryArithmeticOverloadResolver; +// struct NameBucket {static constexpr auto name = "bucket"; }; +// using FunctionBucket = BinaryArithmeticOverloadResolver; -REGISTER_FUNCTION(Bucket) -{ - factory.registerFunction(); -} +// REGISTER_FUNCTION(Bucket) +// { +// factory.registerFunction(); +// } struct NamePositiveModulo { diff --git a/src/Functions/multiIf.cpp b/src/Functions/multiIf.cpp index 1aee5dd15cf..e80f0884d18 100644 --- a/src/Functions/multiIf.cpp +++ b/src/Functions/multiIf.cpp @@ -124,8 +124,8 @@ class FunctionMultiIf final : public FunctionIfBase */ struct Instruction { - const IColumn * condition = nullptr; - const IColumn * source = nullptr; + IColumn::Ptr condition = nullptr; + IColumn::Ptr source = nullptr; bool condition_always_true = false; bool condition_is_nullable = false; @@ -160,15 +160,15 @@ class FunctionMultiIf final : public FunctionIfBase } else { - const ColumnWithTypeAndName & cond_col = arguments[i]; + IColumn::Ptr cond_col = arguments[i].column->convertToFullColumnIfLowCardinality(); /// We skip branches that are always false. /// If we encounter a branch that is always true, we can finish. 
- if (cond_col.column->onlyNull()) + if (cond_col->onlyNull()) continue; - if (const auto * column_const = checkAndGetColumn(*cond_col.column)) + if (const auto * column_const = checkAndGetColumn(*cond_col)) { Field value = column_const->getField(); @@ -181,26 +181,24 @@ class FunctionMultiIf final : public FunctionIfBase } else { - if (isColumnNullable(*cond_col.column)) - instruction.condition_is_nullable = true; - - instruction.condition = cond_col.column.get(); + instruction.condition = cond_col; + instruction.condition_is_nullable = instruction.condition->isNullable(); } - instruction.condition_is_short = cond_col.column->size() < arguments[0].column->size(); + instruction.condition_is_short = cond_col->size() < arguments[0].column->size(); } const ColumnWithTypeAndName & source_col = arguments[source_idx]; instruction.source_is_short = source_col.column->size() < arguments[0].column->size(); if (source_col.type->equals(*return_type)) { - instruction.source = source_col.column.get(); + instruction.source = source_col.column; } else { /// Cast all columns to result type. 
converted_columns_holder.emplace_back(castColumn(source_col, return_type)); - instruction.source = converted_columns_holder.back().get(); + instruction.source = converted_columns_holder.back(); } if (instruction.source && isColumnConst(*instruction.source)) diff --git a/src/Functions/now64.cpp b/src/Functions/now64.cpp index ba24ab256f2..4623ddf5d37 100644 --- a/src/Functions/now64.cpp +++ b/src/Functions/now64.cpp @@ -155,6 +155,7 @@ class Now64OverloadResolver : public IFunctionOverloadResolver REGISTER_FUNCTION(Now64) { factory.registerFunction(FunctionFactory::CaseInsensitive); + factory.registerAlias("CURRENT_TIMESTAMP", Now64OverloadResolver::name, FunctionFactory::CaseInsensitive); } } diff --git a/src/Functions/parseDateTime.cpp b/src/Functions/parseDateTime.cpp index 5beebe2a6ea..de23335dbfd 100644 --- a/src/Functions/parseDateTime.cpp +++ b/src/Functions/parseDateTime.cpp @@ -1970,7 +1970,7 @@ namespace const DateLUTImpl & getTimeZone(const ColumnsWithTypeAndName & arguments) const { if (arguments.size() < 3) - return DateLUT::instance(); + return DateLUT::sessionInstance(); const auto * col = checkAndGetColumnConst(arguments[2].column.get()); if (!col) diff --git a/src/Functions/partitionId.cpp b/src/Functions/partitionId.cpp index 3d4f588c1b0..9db64ded8ca 100644 --- a/src/Functions/partitionId.cpp +++ b/src/Functions/partitionId.cpp @@ -57,7 +57,8 @@ class FunctionPartitionId : public IFunction for (size_t i = 0; i < size; ++i) arguments[i].column->get(j, row[i]); MergeTreePartition partition(std::move(row)); - result_column->insert(partition.getID(sample_block)); + /// TODO: (zuochuang.zema) how to get extract_nullable_date_value + result_column->insert(partition.getID(sample_block, false)); } return result_column; } diff --git a/src/Functions/serverConstants.cpp b/src/Functions/serverConstants.cpp index 13a9a1cbab6..d03ed92242b 100644 --- a/src/Functions/serverConstants.cpp +++ b/src/Functions/serverConstants.cpp @@ -59,16 +59,26 @@ namespace explicit 
FunctionTcpPort(ContextPtr context) : FunctionConstantBase(context->getTCPPort(), context->isDistributed()) {} }; - - /// Returns the server time zone. + /// Returns timezone for current session. class FunctionTimezone : public FunctionConstantBase { public: static constexpr auto name = "timezone"; static FunctionPtr create(ContextPtr context) { return std::make_shared(context); } - explicit FunctionTimezone(ContextPtr context) : FunctionConstantBase(String{DateLUT::instance().getTimeZone()}, context->isDistributed()) {} + explicit FunctionTimezone(ContextPtr context) : FunctionConstantBase(DateLUT::sessionInstance().getTimeZone(), context->isDistributed()) {} }; + /// Returns the server time zone (timezone in which server runs). + class FunctionServerTimezone : public FunctionConstantBase + { + public: + static constexpr auto name = "serverTimezone"; + static FunctionPtr create(ContextPtr context) { return std::make_shared(context); } + explicit FunctionServerTimezone(ContextPtr context) + : FunctionConstantBase(DateLUT::serverTimezoneInstance().getTimeZone(), context->isDistributed()) + { + } + }; /// Returns server uptime in seconds. 
class FunctionUptime : public FunctionConstantBase @@ -146,6 +156,12 @@ REGISTER_FUNCTION(Timezone) factory.registerAlias("timeZone", "timezone"); } +REGISTER_FUNCTION(ServerTimezone) +{ + factory.registerFunction(); + factory.registerAlias("serverTimeZone", "serverTimezone"); +} + REGISTER_FUNCTION(Uptime) { factory.registerFunction(); diff --git a/src/Functions/timestamp.cpp b/src/Functions/timestamp.cpp index a32356ca763..b7604b6ede4 100644 --- a/src/Functions/timestamp.cpp +++ b/src/Functions/timestamp.cpp @@ -60,7 +60,7 @@ class FunctionTimestamp : public IFunction ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override { - const DateLUTImpl * local_time_zone = &DateLUT::instance(); + const DateLUTImpl * local_time_zone = &DateLUT::sessionInstance(); auto col_result = ColumnDateTime64::create(input_rows_count, DATETIME_SCALE); ColumnDateTime64::Container & vec_result = col_result->getData(); diff --git a/src/Functions/timezone.cpp b/src/Functions/timezone.cpp index 07f0d03a57b..40db95862dc 100644 --- a/src/Functions/timezone.cpp +++ b/src/Functions/timezone.cpp @@ -41,7 +41,7 @@ class FunctionTimezone : public IFunction ColumnPtr executeImpl(const ColumnsWithTypeAndName &, const DataTypePtr &, size_t input_rows_count) const override { - return DataTypeString().createColumnConst(input_rows_count, DateLUT::instance().getTimeZone()); + return DataTypeString().createColumnConst(input_rows_count, DateLUT::sessionInstance().getTimeZone()); } }; diff --git a/src/Functions/today.cpp b/src/Functions/today.cpp index fb73603fde6..f9f6940255f 100644 --- a/src/Functions/today.cpp +++ b/src/Functions/today.cpp @@ -101,7 +101,7 @@ class TodayOverloadResolver : public IFunctionOverloadResolver FunctionBasePtr buildImpl(const ColumnsWithTypeAndName &, const DataTypePtr &) const override { - return std::make_unique(DayNum(DateLUT::instance().toDayNum(time(nullptr)).toUnderType())); + return 
std::make_unique(DayNum(DateLUT::sessionInstance().toDayNum(time(nullptr)).toUnderType())); } }; diff --git a/src/Functions/yesterday.cpp b/src/Functions/yesterday.cpp index fd1701b3335..314745207e9 100644 --- a/src/Functions/yesterday.cpp +++ b/src/Functions/yesterday.cpp @@ -78,7 +78,7 @@ class YesterdayOverloadResolver : public IFunctionOverloadResolver FunctionBasePtr buildImpl(const ColumnsWithTypeAndName &, const DataTypePtr &) const override { - auto day_num = DateLUT::instance().toDayNum(time(nullptr)) - 1; + auto day_num = DateLUT::sessionInstance().toDayNum(time(nullptr)) - 1; return std::make_unique(static_cast(day_num)); } }; diff --git a/src/IO/CompressionMethod.cpp b/src/IO/CompressionMethod.cpp index ce2f4026353..e078d839ce0 100644 --- a/src/IO/CompressionMethod.cpp +++ b/src/IO/CompressionMethod.cpp @@ -49,6 +49,27 @@ std::string toContentEncodingName(CompressionMethod method) __builtin_unreachable(); } +std::string getFileSuffix(CompressionMethod method) +{ + switch (method) + { + case CompressionMethod::Gzip: + return "gz"; + case CompressionMethod::Zlib: + return "deflate"; + case CompressionMethod::Brotli: + return "br"; + case CompressionMethod::Xz: + return "xz"; + case CompressionMethod::Zstd: + return "zstd"; + case CompressionMethod::Snappy: + return "snappy"; + case CompressionMethod::None: + return ""; + } +} + CompressionMethod chooseCompressionMethod(const std::string & path, const std::string & hint) { std::string file_extension; diff --git a/src/IO/CompressionMethod.h b/src/IO/CompressionMethod.h index 105f8baae1c..a95e520b455 100644 --- a/src/IO/CompressionMethod.h +++ b/src/IO/CompressionMethod.h @@ -38,6 +38,8 @@ enum class CompressionMethod /// How the compression method is named in HTTP. std::string toContentEncodingName(CompressionMethod method); +std::string getFileSuffix(CompressionMethod method); + /** Choose compression method from path and hint. 
* if hint is "auto" or empty string, then path is analyzed, * otherwise path parameter is ignored and hint is used as compression method name. diff --git a/src/IO/OutfileCommon.cpp b/src/IO/OutfileCommon.cpp index ce193bb21bb..b963ebee188 100644 --- a/src/IO/OutfileCommon.cpp +++ b/src/IO/OutfileCommon.cpp @@ -69,7 +69,7 @@ String getFullOutPath(String & format_name, String & path, int serial_no, Compre } if (compression_method != CompressionMethod::None) - out_path += "." + toContentEncodingName(compression_method); + out_path += "." + getFileSuffix(compression_method); return out_path; } @@ -145,9 +145,10 @@ void OutfileTarget::getRawBuffer() else if (scheme == "tos") { if (out_uri.getQueryParameters().empty()) - { throw Exception("Missing access key, please check configuration.", ErrorCodes::BAD_ARGUMENTS); - } + if (compression_method != CompressionMethod::None) + throw Exception("Compression is not supported for tos outfile", ErrorCodes::BAD_ARGUMENTS); + out_buf_raw = std::make_unique(); } #if USE_HDFS @@ -291,7 +292,10 @@ void OutfileTarget::flushFile() ConnectionTimeouts timeouts(settings.http_connection_timeout, settings.http_send_timeout, settings.http_receive_timeout); HTTPSender http_sender(tos_uri, Poco::Net::HTTPRequest::HTTP_PUT, timeouts, http_headers); - http_sender.send((*out_tos_buf).str()); + String res = (*out_tos_buf).str(); + if (res.empty()) + res = "\n"; + http_sender.send(res); http_sender.handleResponse(); } catch (...) 
diff --git a/src/IO/ReadHelpers.h b/src/IO/ReadHelpers.h index ae7d78fa437..6984b705079 100644 --- a/src/IO/ReadHelpers.h +++ b/src/IO/ReadHelpers.h @@ -738,7 +738,7 @@ inline void convertToDayNum(DayNum & date, ExtendedDayNum & from) } template -inline ReturnType readDateTextImpl(DayNum & date, ReadBuffer & buf) +inline ReturnType readDateTextImpl(DayNum & date, ReadBuffer & buf, const DateLUTImpl & date_lut) { static constexpr bool throw_exception = std::is_same_v; @@ -749,13 +749,13 @@ inline ReturnType readDateTextImpl(DayNum & date, ReadBuffer & buf) else if (!readDateTextImpl(local_date, buf)) return false; - ExtendedDayNum ret = DateLUT::instance().makeDayNum(local_date.year(), local_date.month(), local_date.day()); - convertToDayNum(date,ret); + ExtendedDayNum ret = date_lut.makeDayNum(local_date.year(), local_date.month(), local_date.day()); + convertToDayNum(date, ret); return ReturnType(true); } template -inline ReturnType readDateTextImpl(ExtendedDayNum & date, ReadBuffer & buf) +inline ReturnType readDateTextImpl(ExtendedDayNum & date, ReadBuffer & buf, const DateLUTImpl & date_lut) { static constexpr bool throw_exception = std::is_same_v; @@ -765,8 +765,8 @@ inline ReturnType readDateTextImpl(ExtendedDayNum & date, ReadBuffer & buf) readDateTextImpl(local_date, buf); else if (!readDateTextImpl(local_date, buf)) return false; - /// When the parameter is out of rule or out of range, Date32 uses 1925-01-01 as the default value (-DateLUT::instance().getDayNumOffsetEpoch(), -16436) and Date uses 1970-01-01. - date = DateLUT::instance().makeDayNum(local_date.year(), local_date.month(), local_date.day(), -static_cast(DateLUT::instance().getDayNumOffsetEpoch())); + /// When the parameter is out of rule or out of range, Date32 uses 1925-01-01 as the default value (-DateLUT::serverTimezoneInstance().getDayNumOffsetEpoch(), -16436) and Date uses 1970-01-01. 
+ date = date_lut.makeDayNum(local_date.year(), local_date.month(), local_date.day(), -static_cast(date_lut.getDayNumOffsetEpoch())); return ReturnType(true); } @@ -776,14 +776,14 @@ inline void readDateText(LocalDate & date, ReadBuffer & buf) readDateTextImpl(date, buf); } -inline void readDateText(DayNum & date, ReadBuffer & buf) +inline void readDateText(DayNum & date, ReadBuffer & buf, const DateLUTImpl & date_lut = DateLUT::serverTimezoneInstance()) { - readDateTextImpl(date, buf); + readDateTextImpl(date, buf, date_lut); } -inline void readDateText(ExtendedDayNum & date, ReadBuffer & buf) +inline void readDateText(ExtendedDayNum & date, ReadBuffer & buf, const DateLUTImpl & date_lut = DateLUT::serverTimezoneInstance()) { - readDateTextImpl(date, buf); + readDateTextImpl(date, buf, date_lut); } inline bool tryReadDateText(LocalDate & date, ReadBuffer & buf) @@ -791,14 +791,14 @@ inline bool tryReadDateText(LocalDate & date, ReadBuffer & buf) return readDateTextImpl(date, buf); } -inline bool tryReadDateText(DayNum & date, ReadBuffer & buf) +inline bool tryReadDateText(DayNum & date, ReadBuffer & buf, const DateLUTImpl & date_lut = DateLUT::serverTimezoneInstance()) { - return readDateTextImpl(date, buf); + return readDateTextImpl(date, buf, date_lut); } -inline bool tryReadDateText(ExtendedDayNum & date, ReadBuffer & buf) +inline bool tryReadDateText(ExtendedDayNum & date, ReadBuffer & buf, const DateLUTImpl & date_lut = DateLUT::serverTimezoneInstance()) { - return readDateTextImpl(date, buf); + return readDateTextImpl(date, buf, date_lut); } template @@ -1277,12 +1277,13 @@ inline ReturnType readTimeTextImpl(Decimal64 & time, UInt32 scale, ReadBuffer & return ReturnType(true); } -inline void readDateTimeText(time_t & datetime, ReadBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) +inline void readDateTimeText(time_t & datetime, ReadBuffer & buf, const DateLUTImpl & time_zone = DateLUT::serverTimezoneInstance()) { 
readDateTimeTextImpl(datetime, buf, time_zone); } -inline void readDateTime64Text(DateTime64 & datetime64, UInt32 scale, ReadBuffer & buf, const DateLUTImpl & date_lut = DateLUT::instance()) +inline void readDateTime64Text( + DateTime64 & datetime64, UInt32 scale, ReadBuffer & buf, const DateLUTImpl & date_lut = DateLUT::serverTimezoneInstance()) { readDateTimeTextImpl(datetime64, scale, buf, date_lut); } @@ -1297,12 +1298,13 @@ inline bool tryReadTimeText(Decimal64 & time, UInt32 scale, ReadBuffer & buf) return readTimeTextImpl(time, scale, buf); } -inline bool tryReadDateTimeText(time_t & datetime, ReadBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) +inline bool tryReadDateTimeText(time_t & datetime, ReadBuffer & buf, const DateLUTImpl & time_zone = DateLUT::serverTimezoneInstance()) { return readDateTimeTextImpl(datetime, buf, time_zone); } -inline bool tryReadDateTime64Text(DateTime64 & datetime64, UInt32 scale, ReadBuffer & buf, const DateLUTImpl & date_lut = DateLUT::instance()) +inline bool tryReadDateTime64Text( + DateTime64 & datetime64, UInt32 scale, ReadBuffer & buf, const DateLUTImpl & date_lut = DateLUT::serverTimezoneInstance()) { return readDateTimeTextImpl(datetime64, scale, buf, date_lut); } @@ -1505,8 +1507,14 @@ tryReadText(T & x, ReadBuffer & buf) { return tryReadFloatText(x, buf); } inline void readText(bool & x, ReadBuffer & buf) { readBoolText(x, buf); } inline void readText(String & x, ReadBuffer & buf) { readEscapedString(x, buf); } -inline void readText(DayNum & x, ReadBuffer & buf) { readDateText(x, buf); } -inline void readText(ExtendedDayNum & x, ReadBuffer & buf) { readDateText(x, buf); } +inline void readText(DayNum & x, ReadBuffer & buf, const DateLUTImpl & time_zone = DateLUT::serverTimezoneInstance()) +{ + readDateText(x, buf, time_zone); +} +inline void readText(ExtendedDayNum & x, ReadBuffer & buf, const DateLUTImpl & time_zone = DateLUT::serverTimezoneInstance()) +{ + readDateText(x, buf, time_zone); +} inline 
void readText(LocalDate & x, ReadBuffer & buf) { readDateText(x, buf); } inline void readText(LocalDateTime & x, ReadBuffer & buf) { readDateTimeText(x, buf); } inline void readText(UUID & x, ReadBuffer & buf) { readUUIDText(x, buf); } diff --git a/src/IO/ReadSettings.h b/src/IO/ReadSettings.h index b43c69b3ce1..150d6fc0640 100644 --- a/src/IO/ReadSettings.h +++ b/src/IO/ReadSettings.h @@ -18,6 +18,7 @@ #include #include #include +#include #include namespace DB @@ -121,6 +122,8 @@ struct ReadSettings /// Monitoring bool for_disk_s3 = false; // to choose which profile events should be incremented + Int64 remote_fs_read_failed_injection = 0; + void adjustBufferSize(size_t size) { local_fs_buffer_size = std::min(size, local_fs_buffer_size); diff --git a/src/IO/S3Common.cpp b/src/IO/S3Common.cpp index 0e75e5091c8..e8a23bd5a51 100644 --- a/src/IO/S3Common.cpp +++ b/src/IO/S3Common.cpp @@ -355,7 +355,10 @@ namespace S3 validateBucket(bucket, uri); if (uri.getPath().length() <= 1) throw Exception("Invalid S3 URI: no key: " + uri.toString(), ErrorCodes::BAD_ARGUMENTS); - key = uri.getPath().substr(1); + if (!uri.getQuery().empty()) + key = uri.getPathAndQuery().substr(1); + else + key = uri.getPath().substr(1); is_virtual_hosted_style = false; return; } @@ -693,6 +696,59 @@ namespace S3 } } + S3Util::S3ListResult S3Util::listObjectsWithDelimiter(const String & prefix, String delimiter, bool include_delimiter) const + { + ProfileEvents::increment(ProfileEvents::S3ListObjects); + Aws::S3::Model::ListObjectsV2Request request; + request.SetBucket(bucket); + request.SetPrefix(prefix); + request.SetDelimiter(delimiter); + + S3Util::S3ListResult result; + + while (result.has_more) + { + if (result.token) + request.SetContinuationToken(result.token.value()); + + Aws::S3::Model::ListObjectsV2Outcome outcome = client->ListObjectsV2(request); + + if (outcome.IsSuccess()) + { + const auto & list_result = outcome.GetResult(); + result.has_more = outcome.GetResult().GetIsTruncated(); + 
result.token = outcome.GetResult().GetNextContinuationToken(); + + size_t reserver_size = result.object_names.size() + list_result.GetContents().size() + list_result.GetCommonPrefixes().size(); + result.object_names.reserve(reserver_size); + result.object_sizes.reserve(reserver_size); + result.is_common_prefix.reserve(reserver_size); + for (const auto & content : list_result.GetContents()) + { + result.object_names.push_back(content.GetKey()); + result.object_sizes.push_back(content.GetSize()); + result.is_common_prefix.push_back(false); + } + for (const auto & common_prefix : list_result.GetCommonPrefixes()) + { + String prefix_path = common_prefix.GetPrefix(); + if (!include_delimiter) + prefix_path.erase(prefix_path.find_last_of(delimiter), delimiter.size()); + + result.object_names.push_back(prefix_path); + result.object_sizes.push_back(0); + result.is_common_prefix.push_back(true); + } + return result; + } + else + { + throw S3Exception(outcome.GetError(), fmt::format("Could not list objects in bucket {} with prefix {}", bucket, prefix)); + } + } + return result; + } + S3Util::S3ListResult S3Util::listObjectsWithPrefix(const String & prefix, const std::optional & token, int limit) const { ProfileEvents::increment(ProfileEvents::S3ListObjects); diff --git a/src/IO/S3Common.h b/src/IO/S3Common.h index de34aa26855..50ba91e8cc8 100644 --- a/src/IO/S3Common.h +++ b/src/IO/S3Common.h @@ -226,7 +226,9 @@ class S3Util std::optional token; Strings object_names; std::vector object_sizes; + std::vector is_common_prefix; }; + S3ListResult listObjectsWithDelimiter(const String & prefix, String delimiter = "/", bool include_delimiter = false) const; S3ListResult listObjectsWithPrefix(const String & prefix, const std::optional & token, int limit = 1000) const; // Write object diff --git a/src/IO/WriteBufferFromFile.cpp b/src/IO/WriteBufferFromFile.cpp index 67cd7ba27d6..a8937ffa0ac 100644 --- a/src/IO/WriteBufferFromFile.cpp +++ b/src/IO/WriteBufferFromFile.cpp @@ -31,8 
+31,9 @@ WriteBufferFromFile::WriteBufferFromFile( int flags, mode_t mode, char * existing_memory, - size_t alignment) - : WriteBufferFromFileDescriptor(-1, buf_size, existing_memory, alignment), file_name(file_name_) + size_t alignment, + ThrottlerPtr throttler_) + : WriteBufferFromFileDescriptor(-1, buf_size, existing_memory, alignment, throttler_), file_name(file_name_) { ProfileEvents::increment(ProfileEvents::FileOpen); diff --git a/src/IO/WriteBufferFromFile.h b/src/IO/WriteBufferFromFile.h index 8c535e5461f..23fa1a3e900 100644 --- a/src/IO/WriteBufferFromFile.h +++ b/src/IO/WriteBufferFromFile.h @@ -2,6 +2,7 @@ #include +#include #include #include @@ -35,7 +36,8 @@ class WriteBufferFromFile : public WriteBufferFromFileDescriptor int flags = -1, mode_t mode = 0666, char * existing_memory = nullptr, - size_t alignment = 0); + size_t alignment = 0, + ThrottlerPtr throttler = nullptr); /// Use pre-opened file descriptor. WriteBufferFromFile( diff --git a/src/IO/WriteBufferFromFileDescriptor.cpp b/src/IO/WriteBufferFromFileDescriptor.cpp index 6b6f2034f6c..baef0522af1 100644 --- a/src/IO/WriteBufferFromFileDescriptor.cpp +++ b/src/IO/WriteBufferFromFileDescriptor.cpp @@ -68,6 +68,9 @@ void WriteBufferFromFileDescriptor::nextImpl() Stopwatch watch; + if (throttler) + throttler->add(offset()); + size_t bytes_written = 0; while (bytes_written != offset()) { @@ -106,8 +109,9 @@ WriteBufferFromFileDescriptor::WriteBufferFromFileDescriptor( int fd_, size_t buf_size, char * existing_memory, - size_t alignment) - : WriteBufferFromFileBase(buf_size, existing_memory, alignment), fd(fd_) {} + size_t alignment, + ThrottlerPtr throttler_) + : WriteBufferFromFileBase(buf_size, existing_memory, alignment), fd(fd_), throttler(throttler_) {} WriteBufferFromFileDescriptor::~WriteBufferFromFileDescriptor() diff --git a/src/IO/WriteBufferFromFileDescriptor.h b/src/IO/WriteBufferFromFileDescriptor.h index 1d341852ead..38af396a2b9 100644 --- a/src/IO/WriteBufferFromFileDescriptor.h 
+++ b/src/IO/WriteBufferFromFileDescriptor.h @@ -22,6 +22,7 @@ #pragma once #include +#include namespace DB @@ -33,6 +34,7 @@ class WriteBufferFromFileDescriptor : public WriteBufferFromFileBase { protected: int fd; + ThrottlerPtr throttler; void nextImpl() override; @@ -44,7 +46,8 @@ class WriteBufferFromFileDescriptor : public WriteBufferFromFileBase int fd_ = -1, size_t buf_size = DBMS_DEFAULT_BUFFER_SIZE, char * existing_memory = nullptr, - size_t alignment = 0); + size_t alignment = 0, + ThrottlerPtr throttler = nullptr); /** Could be used before initialization if needed 'fd' was not passed to constructor. * It's not possible to change 'fd' during work. diff --git a/src/IO/WriteHelpers.h b/src/IO/WriteHelpers.h index 8e62c3d8943..64d04a57903 100644 --- a/src/IO/WriteHelpers.h +++ b/src/IO/WriteHelpers.h @@ -757,15 +757,15 @@ inline void writeDateText(const LocalDate & date, WriteBuffer & buf) } template -inline void writeDateText(DayNum date, WriteBuffer & buf) +inline void writeDateText(DayNum date, WriteBuffer & buf, const DateLUTImpl & time_zone = DateLUT::serverTimezoneInstance()) { - writeDateText(LocalDate(date), buf); + writeDateText(LocalDate(date, time_zone), buf); } template -inline void writeDateText(ExtendedDayNum date, WriteBuffer & buf) +inline void writeDateText(ExtendedDayNum date, WriteBuffer & buf, const DateLUTImpl & time_zone = DateLUT::serverTimezoneInstance()) { - writeDateText(LocalDate(date), buf); + writeDateText(LocalDate(date, time_zone), buf); } /// In the format YYYY-MM-DD HH:MM:SS @@ -818,14 +818,19 @@ inline void writeDateTimeText(const LocalDateTime & datetime, WriteBuffer & buf) /// In the format YYYY-MM-DD HH:MM:SS, according to the specified time zone. 
template -inline void writeDateTimeText(time_t datetime, WriteBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) +inline void writeDateTimeText(time_t datetime, WriteBuffer & buf, const DateLUTImpl & time_zone = DateLUT::serverTimezoneInstance()) { writeDateTimeText(LocalDateTime(datetime, time_zone), buf); } /// In the format YYYY-MM-DD HH:MM:SS.NNNNNNNNN, according to the specified time zone. -template -inline void writeDateTimeText(DateTime64 datetime64, UInt32 scale, WriteBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) +template < + char date_delimeter = '-', + char time_delimeter = ':', + char between_date_time_delimiter = ' ', + char fractional_time_delimiter = '.'> +inline void +writeDateTimeText(DateTime64 datetime64, UInt32 scale, WriteBuffer & buf, const DateLUTImpl & time_zone = DateLUT::serverTimezoneInstance()) { static constexpr UInt32 MaxScale = DecimalUtils::max_precision; scale = scale > MaxScale ? MaxScale : scale; @@ -892,7 +897,7 @@ inline void writeTimeText(Decimal64 time, UInt32 scale, WriteBuffer & buf) /// In the RFC 1123 format: "Tue, 03 Dec 2019 00:11:50 GMT". You must provide GMT DateLUT. /// This is needed for HTTP requests. -inline void writeDateTimeTextRFC1123(time_t datetime, WriteBuffer & buf, const DateLUTImpl & time_zone = DateLUT::instance()) +inline void writeDateTimeTextRFC1123(time_t datetime, WriteBuffer & buf, const DateLUTImpl & time_zone = DateLUT::serverTimezoneInstance()) { const auto & values = time_zone.getValues(datetime); @@ -996,8 +1001,10 @@ template <> inline void writeText(const bool & x, WriteBuffer & buf) { wri /// assumes here that `x` is a null-terminated string. 
inline void writeText(const char * x, WriteBuffer & buf) { writeCString(x, buf); } inline void writeText(const char * x, size_t size, WriteBuffer & buf) { writeString(x, size, buf); } - -inline void writeText(const DayNum & x, WriteBuffer & buf) { writeDateText(LocalDate(x), buf); } +inline void writeText(const DayNum & x, WriteBuffer & buf, const DateLUTImpl & time_zone = DateLUT::serverTimezoneInstance()) +{ + writeDateText(LocalDate(x, time_zone), buf); +} inline void writeText(const LocalDate & x, WriteBuffer & buf) { writeDateText(x, buf); } inline void writeText(const LocalDateTime & x, WriteBuffer & buf) { writeDateTimeText(x, buf); } inline void writeText(const UUID & x, WriteBuffer & buf) { writeUUIDText(x, buf); } diff --git a/src/IO/WriteSettings.h b/src/IO/WriteSettings.h index 4ebb71b7a2e..50f665756dd 100644 --- a/src/IO/WriteSettings.h +++ b/src/IO/WriteSettings.h @@ -19,6 +19,7 @@ #include #include #include +#include namespace DB { @@ -39,6 +40,7 @@ struct WriteSettings WriteMode mode = WriteMode::Rewrite; std::map file_meta = {}; bool allow_overwrite_file = false; + Int64 remote_fs_write_failed_injection = 0; }; } diff --git a/src/IO/examples/parse_date_time_best_effort.cpp b/src/IO/examples/parse_date_time_best_effort.cpp index fc5755f1f95..f238125a142 100644 --- a/src/IO/examples/parse_date_time_best_effort.cpp +++ b/src/IO/examples/parse_date_time_best_effort.cpp @@ -12,7 +12,7 @@ using namespace DB; int main(int, char **) try { - const DateLUTImpl & local_time_zone = DateLUT::instance(); + const DateLUTImpl & local_time_zone = DateLUT::serverTimezoneInstance(); const DateLUTImpl & utc_time_zone = DateLUT::instance("UTC"); ReadBufferFromFileDescriptor in(STDIN_FILENO); diff --git a/src/Interpreters/ActionsVisitor.cpp b/src/Interpreters/ActionsVisitor.cpp index 06b962b4bf5..bff950e0bbb 100644 --- a/src/Interpreters/ActionsVisitor.cpp +++ b/src/Interpreters/ActionsVisitor.cpp @@ -1039,6 +1039,7 @@ void ActionsMatcher::visit(const ASTFunction & 
node, const ASTPtr & ast, Data & if (make_set) { // check if current column type really has bitmap index + bool has_valid_identifier = false; for (size_t i = 0; i < arg_size; i += 2) { ASTPtr arg_col = node.arguments->children.at(i); @@ -1052,9 +1053,11 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data & should_update_bitmap_index_info = false; break; } + has_valid_identifier = true; } } } + should_update_bitmap_index_info &= has_valid_identifier; auto col_name = node.getColumnName(); if (should_update_bitmap_index_info) diff --git a/src/Interpreters/Aggregator.cpp b/src/Interpreters/Aggregator.cpp index c0f13bb836f..601e91bb1e2 100644 --- a/src/Interpreters/Aggregator.cpp +++ b/src/Interpreters/Aggregator.cpp @@ -2179,7 +2179,7 @@ void NO_INLINE Aggregator::mergeDataOnlyExistingKeysImpl( void NO_INLINE Aggregator::mergeWithoutKeyDataImpl( - ManyAggregatedDataVariants & non_empty_data) const + ManyAggregatedDataVariants & non_empty_data, std::atomic & is_cancelled) const { ThreadPool thread_pool{params.max_threads}; @@ -2197,6 +2197,7 @@ void NO_INLINE Aggregator::mergeWithoutKeyDataImpl( res_data + offsets_of_aggregate_states[i], current_data + offsets_of_aggregate_states[i], thread_pool, + is_cancelled, res->aggregates_pool); else aggregate_functions[i]->merge( @@ -2456,7 +2457,8 @@ void NO_INLINE Aggregator::mergeStreamsImpl( void NO_INLINE Aggregator::mergeWithoutKeyStreamsImpl( Block & block, - AggregatedDataVariants & result) const + AggregatedDataVariants & result, + std::atomic & is_cancelled [[maybe_unused]]) const { AggregateColumnsConstData aggregate_columns(params.aggregates_size); @@ -2486,7 +2488,8 @@ void NO_INLINE Aggregator::mergeWithoutKeyStreamsImpl( block.clear(); } -bool Aggregator::mergeOnBlock(Block block, AggregatedDataVariants & result, bool & no_more_keys) const + +bool Aggregator::mergeOnBlock(Block block, AggregatedDataVariants & result, bool & no_more_keys, std::atomic & is_cancelled) const { /// `result` 
will destroy the states of aggregate functions in the destructor result.aggregator = this; @@ -2501,7 +2504,7 @@ bool Aggregator::mergeOnBlock(Block block, AggregatedDataVariants & result, bool } if (result.type == AggregatedDataVariants::Type::without_key || block.info.is_overflows) - mergeWithoutKeyStreamsImpl(block, result); + mergeWithoutKeyStreamsImpl(block, result, is_cancelled); #define M(NAME, IS_TWO_LEVEL) \ else if (result.type == AggregatedDataVariants::Type::NAME) \ @@ -2589,7 +2592,7 @@ bool Aggregator::mergeOnBlock(Block block, AggregatedDataVariants & result, bool } -void Aggregator::mergeBlocks(BucketToBlocks bucket_to_blocks, AggregatedDataVariants & result, size_t max_threads) +void Aggregator::mergeBlocks(BucketToBlocks bucket_to_blocks, AggregatedDataVariants & result, size_t max_threads, std::atomic & is_cancelled) { if (bucket_to_blocks.empty()) return; @@ -2636,7 +2639,7 @@ void Aggregator::mergeBlocks(BucketToBlocks bucket_to_blocks, AggregatedDataVari LOG_TRACE(log, "Merging partially aggregated two-level data."); - auto merge_bucket = [&bucket_to_blocks, &result, this](Int32 bucket, Arena * aggregates_pool, ThreadGroupStatusPtr thread_group) + auto merge_bucket = [&bucket_to_blocks, &result, this, &is_cancelled](Int32 bucket, Arena * aggregates_pool, ThreadGroupStatusPtr thread_group) { if (thread_group) CurrentThread::attachToIfDetached(thread_group); @@ -2645,6 +2648,8 @@ void Aggregator::mergeBlocks(BucketToBlocks bucket_to_blocks, AggregatedDataVari { for (Block & block : bucket_to_blocks[bucket]) { + if (is_cancelled.load(std::memory_order_seq_cst)) + break; #define M(NAME, IS_TWO_LEVEL) \ else if (result.type == AggregatedDataVariants::Type::NAME) \ mergeStreamsImpl(block, aggregates_pool, *result.NAME, result.NAME->data, nullptr, false); @@ -2660,6 +2665,8 @@ void Aggregator::mergeBlocks(BucketToBlocks bucket_to_blocks, AggregatedDataVari { for (Block & block : bucket_to_blocks[bucket]) { + if 
(is_cancelled.load(std::memory_order_seq_cst)) + break; #define M(NAME) \ else if (result.type == AggregatedDataVariants::Type::NAME) \ mergeStreamsImpl(block, aggregates_pool, *result.NAME, result.NAME->data.impls[bucket], nullptr, false); @@ -2713,9 +2720,10 @@ void Aggregator::mergeBlocks(BucketToBlocks bucket_to_blocks, AggregatedDataVari { if (!checkLimits(result.sizeWithoutOverflowRow(), no_more_keys)) break; - + if (is_cancelled.load(std::memory_order_seq_cst)) + break; if (result.type == AggregatedDataVariants::Type::without_key || block.info.is_overflows) - mergeWithoutKeyStreamsImpl(block, result); + mergeWithoutKeyStreamsImpl(block, result, is_cancelled); #define M(NAME, IS_TWO_LEVEL) \ else if (result.type == AggregatedDataVariants::Type::NAME) \ @@ -2732,7 +2740,7 @@ void Aggregator::mergeBlocks(BucketToBlocks bucket_to_blocks, AggregatedDataVari } -Block Aggregator::mergeBlocks(BlocksList & blocks, bool final) +Block Aggregator::mergeBlocks(BlocksList & blocks, bool final, std::atomic & is_cancelled) { if (blocks.empty()) return {}; @@ -2780,9 +2788,10 @@ Block Aggregator::mergeBlocks(BlocksList & blocks, bool final) { if (bucket_num >= 0 && block.info.bucket_num != bucket_num) bucket_num = -1; - + if (is_cancelled.load(std::memory_order_seq_cst)) + break; if (result.type == AggregatedDataVariants::Type::without_key || is_overflows) - mergeWithoutKeyStreamsImpl(block, result); + mergeWithoutKeyStreamsImpl(block, result, is_cancelled); #define M(NAME, IS_TWO_LEVEL) \ else if (result.type == AggregatedDataVariants::Type::NAME) \ diff --git a/src/Interpreters/Aggregator.h b/src/Interpreters/Aggregator.h index 752442edada..b3b792b50f6 100644 --- a/src/Interpreters/Aggregator.h +++ b/src/Interpreters/Aggregator.h @@ -1111,7 +1111,8 @@ class Aggregator final bool & no_more_keys) const; /// Used for aggregate projection. 
- bool mergeOnBlock(Block block, AggregatedDataVariants & result, bool & no_more_keys) const; + bool mergeOnBlock(Block block, AggregatedDataVariants & result, bool & no_more_keys, std::atomic & is_cancelled) const; + /** Convert the aggregation data structure into a block. * If overflow_row = true, then aggregates for rows that are not included in max_rows_to_group_by are put in the first block. @@ -1126,13 +1127,13 @@ class Aggregator final using BucketToBlocks = std::map; /// Merge partially aggregated blocks separated to buckets into one data structure. - void mergeBlocks(BucketToBlocks bucket_to_blocks, AggregatedDataVariants & result, size_t max_threads); + void mergeBlocks(BucketToBlocks bucket_to_blocks, AggregatedDataVariants & result, size_t max_threads, std::atomic & is_cancelled); /// Merge several partially aggregated blocks into one. /// Precondition: for all blocks block.info.is_overflows flag must be the same. /// (either all blocks are from overflow data or none blocks are). /// The resulting block has the same value of is_overflows flag. - Block mergeBlocks(BlocksList & blocks, bool final); + Block mergeBlocks(BlocksList & blocks, bool final, std::atomic & is_cancelled); /** Split block with partially-aggregated data to many blocks, as if two-level method of aggregation was used. * This is needed to simplify merging of that data with other results, that are already two-level. 
@@ -1346,7 +1347,7 @@ class Aggregator final Arena * arena) const; void mergeWithoutKeyDataImpl( - ManyAggregatedDataVariants & non_empty_data) const; + ManyAggregatedDataVariants & non_empty_data, std::atomic & is_cancelled) const; template void mergeSingleLevelDataImpl( @@ -1435,7 +1436,8 @@ class Aggregator final void mergeWithoutKeyStreamsImpl( Block & block, - AggregatedDataVariants & result) const; + AggregatedDataVariants & result, + std::atomic & is_cancelled) const; template void mergeBucketImpl( diff --git a/src/Interpreters/AsynchronousMetricLog.cpp b/src/Interpreters/AsynchronousMetricLog.cpp index 09345ecca7c..e7110b8286b 100644 --- a/src/Interpreters/AsynchronousMetricLog.cpp +++ b/src/Interpreters/AsynchronousMetricLog.cpp @@ -44,7 +44,7 @@ void AsynchronousMetricLog::addValues(const AsynchronousMetricValues & values) const auto now = std::chrono::system_clock::now(); element.event_time = time_in_seconds(now); element.event_time_microseconds = time_in_microseconds(now); - element.event_date = DateLUT::instance().toDayNum(element.event_time); + element.event_date = DateLUT::serverTimezoneInstance().toDayNum(element.event_time); for (const auto & [key, value] : values) { diff --git a/src/Interpreters/AsynchronousMetrics.cpp b/src/Interpreters/AsynchronousMetrics.cpp index e98b8493eda..cab46bbb432 100644 --- a/src/Interpreters/AsynchronousMetrics.cpp +++ b/src/Interpreters/AsynchronousMetrics.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -574,6 +575,13 @@ void AsynchronousMetrics::update(std::chrono::system_clock::time_point update_ti /// This is also a good indicator of system responsiveness. 
new_values["Jitter"] = std::chrono::duration_cast(current_time - update_time).count() / 1e9; + { + if (auto cloud_table_definition_cache = getContext()->tryGetCloudTableDefinitionCache()) + { + new_values["CloudTableDefinitionCacheCells"] = cloud_table_definition_cache->count(); + } + } + { if (auto mark_cache = getContext()->getMarkCache()) { @@ -714,6 +722,14 @@ void AsynchronousMetrics::update(std::chrono::system_clock::time_point update_ti } } + { + if (auto gin_store_cache = getContext()->getGinIndexStoreFactory()) + { + new_values["GinStoreCacheCount"] = gin_store_cache->count(); + new_values["GinStoreCacheWeight"] = gin_store_cache->weight(); + } + } + #if USE_EMBEDDED_COMPILER { if (auto * compiled_expression_cache = CompiledExpressionCacheFactory::instance().tryGetCache()) diff --git a/src/Interpreters/ClusterProxy/executeQuery.cpp b/src/Interpreters/ClusterProxy/executeQuery.cpp index 3c9791da74b..3d8f2c60284 100644 --- a/src/Interpreters/ClusterProxy/executeQuery.cpp +++ b/src/Interpreters/ClusterProxy/executeQuery.cpp @@ -126,6 +126,9 @@ ContextMutablePtr updateSettingsForCluster(const Cluster & cluster, ContextPtr c if (!settings.access_table_names.toString().empty()) new_settings.setString("access_table_names", ""); + if (!settings.accessible_table_names.toString().empty()) + new_settings.setString("accessible_table_names", ""); + auto new_context = Context::createCopy(context); new_context->setSettings(new_settings); return new_context; @@ -166,28 +169,6 @@ ContextMutablePtr removeUserRestrictionsFromSettings(ContextPtr context, const S return new_context; } -// For distributed query, rewrite sample ast by dividing sample_size. -// We assume data is evenly distributed and it is reasonable to divided sample_size into several parts. 
-ASTPtr rewriteSampleForDistributedTable(const ASTPtr & query_ast, size_t shard_size) -{ - ASTPtr rewrite_ast = query_ast->clone(); - ASTSelectQuery * select = rewrite_ast->as(); - if (select && select->sampleSize()) - { - ASTSampleRatio * sample = select->sampleSize()->as(); - if (!sample) - return rewrite_ast; - - ASTSampleRatio::BigNum numerator = sample->ratio.numerator; - ASTSampleRatio::BigNum denominator = sample->ratio.denominator; - if (numerator <= 1 || denominator > 1) - return rewrite_ast; - - sample->ratio.numerator = (sample->ratio.numerator + 1) / shard_size; - } - return rewrite_ast; -} - void executeQuery( QueryPlan & query_plan, IStreamFactory & stream_factory, Poco::Logger * log, diff --git a/src/Interpreters/CollectJoinOnKeysVisitor.cpp b/src/Interpreters/CollectJoinOnKeysVisitor.cpp index ee319fbac3f..5c70a4bc078 100644 --- a/src/Interpreters/CollectJoinOnKeysVisitor.cpp +++ b/src/Interpreters/CollectJoinOnKeysVisitor.cpp @@ -92,7 +92,11 @@ void CollectJoinOnKeysMatcher::visit(const ASTFunction & func, const ASTPtr & as return; /// go into children if (func.name == "or") + { + if (!data.check_function_type_in_join_on_condition) + return; throw Exception("JOIN ON does not support OR. 
Unexpected '" + queryToString(ast) + "'", ErrorCodes::NOT_IMPLEMENTED); + } ASOF::Inequality inequality = ASOF::getInequality(func.name); if (func.name == "equals" || func.name == "bitEquals" || func.name == "notEquals" || func.name == "bitNotEquals" || inequality != ASOF::Inequality::None) @@ -101,6 +105,8 @@ void CollectJoinOnKeysMatcher::visit(const ASTFunction & func, const ASTPtr & as throw Exception("Function " + func.name + " takes two arguments, got '" + func.formatForErrorMessage() + "' instead", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); } + else if (!data.check_function_type_in_join_on_condition) + return; else throw Exception("Expected equality or inequality, got '" + queryToString(ast) + "'", ErrorCodes::INVALID_JOIN_ON_EXPRESSION); @@ -149,7 +155,7 @@ void CollectJoinOnKeysMatcher::visit(const ASTFunction & func, const ASTPtr & as { data.inequal_conditions.push_back(ast); } - else + else if (data.check_function_type_in_join_on_condition) { throw Exception(fmt::format("JOIN ON condition {} is not support", queryToString(ast)), ErrorCodes::INVALID_JOIN_ON_EXPRESSION); } diff --git a/src/Interpreters/CollectJoinOnKeysVisitor.h b/src/Interpreters/CollectJoinOnKeysVisitor.h index c39c1dff214..4db2dc8991b 100644 --- a/src/Interpreters/CollectJoinOnKeysVisitor.h +++ b/src/Interpreters/CollectJoinOnKeysVisitor.h @@ -60,6 +60,7 @@ class CollectJoinOnKeysMatcher bool ignore_array_join_check_in_join_on_condition{false}; ContextPtr context{nullptr}; ASTs inequal_conditions {}; + bool check_function_type_in_join_on_condition{true}; void addJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, const std::pair & table_no, bool null_safe_equal); void addAsofJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, const std::pair & table_no, diff --git a/src/Interpreters/Context.cpp b/src/Interpreters/Context.cpp index f2b3ea4a227..9a16d81a48d 100644 --- a/src/Interpreters/Context.cpp +++ b/src/Interpreters/Context.cpp @@ -39,6 +39,7 @@ #include #include 
#include +#include #include #include #include @@ -126,6 +127,7 @@ #include #include #include +#include #include #include #include @@ -163,6 +165,7 @@ #include #include #include +#include #include #include #include @@ -193,7 +196,6 @@ #include #include #include -#include #include #include @@ -264,6 +266,7 @@ namespace ErrorCodes extern const int NOT_A_LEADER; extern const int INVALID_SETTING_VALUE; extern const int DATABASE_ACCESS_DENIED; + extern const int QUERY_WAS_CANCELLED; } /** Set of known objects (environment), that could be used in query. @@ -358,6 +361,10 @@ struct ContextSharedPart mutable IntermediateResultCachePtr intermediate_result_cache; /// part cache of queries' results. mutable MMappedFileCachePtr mmap_cache; /// Cache of mmapped files to avoid frequent open/map/unmap/close and to reuse from several threads. + mutable OnceFlag cloud_table_definition_cache_initialized; + /// Cache of CloudMergeTree objects to speed up table creation during query execution. + /// Used when send_cacheable_table_definitions is enabled + mutable CloudTableDefinitionCachePtr cloud_table_definition_cache; ProcessList process_list; /// Executing queries at the moment. SegmentSchedulerPtr segment_scheduler; ExchangeStatusTrackerPtr exchange_data_tracker; @@ -372,21 +379,31 @@ struct ContextSharedPart InterserverIOHandler interserver_io_handler; /// Handler for interserver communication. mutable std::optional buffer_flush_schedule_pool; /// A thread pool that can do background flush for Buffer tables. 
+ mutable OnceFlag schedule_pool_initialized; mutable std::optional schedule_pool; /// A thread pool that can run different jobs in background (used in replicated tables) mutable std::optional distributed_schedule_pool; /// A thread pool that can run different jobs in background (used for distributed sends) mutable std::optional message_broker_schedule_pool; /// A thread pool that can run different jobs in background (used for message brokers, like RabbitMQ and Kafka) + mutable OnceFlag readers_initialized; mutable AsynchronousReaderPtr asynchronous_remote_fs_reader; - mutable ThrottlerPtr disk_cache_throttler; - - mutable std::array, SchedulePool::Size> extra_schedule_pools; + struct ExtraSchedulePool + { + OnceFlag is_initialized; + std::unique_ptr pool; + }; + mutable std::array extra_schedule_pools; + std::optional vector_index_loading_thread_pool; + mutable OnceFlag disk_cache_throttler_initialized; + mutable ThrottlerPtr disk_cache_throttler; + mutable OnceFlag preload_throttler_initialized; + mutable ThrottlerPtr preload_throttler; /// may be nullptr + mutable OnceFlag replicated_fetches_throttler_initialized; mutable ThrottlerPtr replicated_fetches_throttler; /// A server-wide throttler for replicated fetches + mutable OnceFlag replicated_sends_throttler_initialized; mutable ThrottlerPtr replicated_sends_throttler; /// A server-wide throttler for replicated sends - mutable ThrottlerPtr preload_throttler; - MultiVersion macros; /// Substitutions extracted from config. std::unique_ptr ddl_worker; /// Process ddl commands from zk. /// Rules for selecting the compression settings, depending on the size of the part. 
@@ -409,6 +426,7 @@ struct ContextSharedPart mutable CnchTopologyMasterPtr topology_master; mutable ResourceManagerClientPtr rm_client; mutable std::unique_ptr vw_pool; + mutable OnceFlag global_txn_committer_initialized; mutable GlobalTxnCommitterPtr global_txn_committer; mutable GlobalDataManagerPtr global_data_manager; @@ -553,7 +571,11 @@ struct ContextSharedPart cnch_bg_threads_array->shutdown(); if (cnch_txn_coordinator) + { cnch_txn_coordinator->shutdown(); + /// Need to reset cnch_txn_coordinator before schedule_pool reset, otherwise it may core. + cnch_txn_coordinator.reset(); + } if (server_manager) server_manager->shutDown(); @@ -573,6 +595,8 @@ struct ContextSharedPart if (worker_status_manager) worker_status_manager->shutdown(); + if (cnch_catalog) + cnch_catalog->shutDown(); std::unique_ptr delete_system_logs; std::unique_ptr delete_cnch_system_logs; @@ -618,7 +642,7 @@ struct ContextSharedPart distributed_schedule_pool.reset(); message_broker_schedule_pool.reset(); for (auto & p : extra_schedule_pools) - p.reset(); + p.pool.reset(); ddl_worker.reset(); /// Stop trace collector if any @@ -651,10 +675,11 @@ struct ContextSharedPart } }; +ContextData::ContextData() = default; +ContextData::ContextData(const ContextData &) = default; Context::Context() = default; -Context::Context(const Context &) = default; -Context & Context::operator=(const Context &) = default; +Context::Context(const Context & rhs) : ContextData(rhs), std::enable_shared_from_this(rhs) {} SharedContextHolder::SharedContextHolder(SharedContextHolder &&) noexcept = default; SharedContextHolder & SharedContextHolder::operator=(SharedContextHolder &&) = default; @@ -669,10 +694,10 @@ void SharedContextHolder::reset() shared.reset(); } -ContextMutablePtr Context::createGlobal(ContextSharedPart * shared) +ContextMutablePtr Context::createGlobal(ContextSharedPart * shared_part) { auto res = std::shared_ptr(new Context); - res->shared = shared; + res->shared = shared_part; return res; } @@ 
-690,7 +715,7 @@ SharedContextHolder Context::createShared() void Context::addSessionView(StorageID view_table_id, StoragePtr view_storage) { - auto lock = getLock(); + auto lock = getLocalLock(); if (session_views_cache.find(view_table_id) != session_views_cache.end()) return; session_views_cache.emplace(view_table_id, view_storage); @@ -698,21 +723,24 @@ void Context::addSessionView(StorageID view_table_id, StoragePtr view_storage) StoragePtr Context::getSessionView(StorageID view_table_id) { - auto lock = getLock(); - auto it = session_views_cache.find(view_table_id); - if (it != session_views_cache.end()) - return it->second; - else { - StoragePtr view_storage = DatabaseCatalog::instance().tryGetTable(view_table_id, shared_from_this()); - if (view_storage) - session_views_cache.emplace(view_table_id, view_storage); - return view_storage; + auto lock = getLocalSharedLock(); + auto it = session_views_cache.find(view_table_id); + if (it != session_views_cache.end()) + return it->second; } + + /// should be done outside the context lock, otherwise may deadlock + StoragePtr view_storage = DatabaseCatalog::instance().tryGetTable(view_table_id, shared_from_this()); + + if (view_storage) + addSessionView(view_table_id, view_storage); + return view_storage; } ContextMutablePtr Context::createCopy(const ContextPtr & other) { + auto lock = other->getLocalSharedLock(); return std::shared_ptr(new Context(*other)); } @@ -729,16 +757,11 @@ ContextMutablePtr Context::createCopy(const ContextMutablePtr & other) return createCopy(std::const_pointer_cast(other)); } -void Context::copyFrom(const ContextPtr & other) -{ - *this = *other; -} - Context::~Context() = default; WorkerStatusManagerPtr Context::getWorkerStatusManager() { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->worker_status_manager) shared->worker_status_manager = std::make_shared(global_context); return shared->worker_status_manager; @@ -751,7 +774,7 @@ void 
Context::updateAdaptiveSchdulerConfig() WorkerStatusManagerPtr Context::getWorkerStatusManager() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->worker_status_manager) shared->worker_status_manager = std::make_shared(global_context); return shared->worker_status_manager; @@ -797,6 +820,9 @@ InterserverIOHandler & Context::getInterserverIOHandler() ReadSettings Context::getReadSettings() const { ReadSettings res; + + res.remote_fs_read_failed_injection = settings.remote_fs_read_failed_injection; + res.remote_fs_prefetch = settings.remote_filesystem_read_prefetch; res.local_fs_prefetch = settings.local_filesystem_read_prefetch; res.remote_read_log = settings.enable_remote_read_log ? getRemoteReadLog().get() : nullptr; @@ -826,6 +852,22 @@ std::unique_lock Context::getLock() const return std::unique_lock(shared->mutex); } +/// NOTE: it's an non-recursive lock, caller should be aware of the deadlock risk +std::unique_lock Context::getLocalLock() const +{ + ProfileEvents::increment(ProfileEvents::ContextLock); + CurrentMetrics::Increment increment{CurrentMetrics::ContextLockWait}; + return std::unique_lock(mutex); +} + +/// NOTE: it's an non-recursive lock, caller should be aware of the deadlock risk +std::shared_lock Context::getLocalSharedLock() const +{ + ProfileEvents::increment(ProfileEvents::ContextLock); + CurrentMetrics::Increment increment{CurrentMetrics::ContextLockWait}; + return std::shared_lock(mutex); +} + ProcessList & Context::getProcessList() { return shared->process_list; @@ -869,7 +911,7 @@ const ReplicatedFetchList & Context::getReplicatedFetchList() const SegmentSchedulerPtr Context::getSegmentScheduler() { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->segment_scheduler) shared->segment_scheduler = std::make_shared(); return shared->segment_scheduler; @@ -877,7 +919,7 @@ SegmentSchedulerPtr Context::getSegmentScheduler() SegmentSchedulerPtr Context::getSegmentScheduler() const { - auto 
lock = getLock(); + auto lock = getLock(); // checked if (!shared->segment_scheduler) shared->segment_scheduler = std::make_shared(); return shared->segment_scheduler; @@ -885,13 +927,13 @@ SegmentSchedulerPtr Context::getSegmentScheduler() const void Context::setMockExchangeDataTracker(ExchangeStatusTrackerPtr exchange_data_tracker) { - auto lock = getLock(); + auto lock = getLock(); // checked shared->exchange_data_tracker = exchange_data_tracker; } ExchangeStatusTrackerPtr Context::getExchangeDataTracker() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->exchange_data_tracker) { if (shared->server_type == ServerType::cnch_server) @@ -913,7 +955,7 @@ void Context::initDiskExchangeDataManager() const DiskExchangeDataManagerPtr Context::getDiskExchangeDataManager() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->disk_exchange_data_manager) { const auto & bsp_conf = getRootConfig().bulk_synchronous_parallel; @@ -944,13 +986,13 @@ DiskExchangeDataManagerPtr Context::getDiskExchangeDataManager() const void Context::setMockDiskExchangeDataManager(DiskExchangeDataManagerPtr disk_exchange_data_manager) { - auto lock = getLock(); + auto lock = getLock(); // checked shared->disk_exchange_data_manager = disk_exchange_data_manager; } BindingCacheManagerPtr Context::getGlobalBindingCacheManager() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (this->shared->global_binding_cache_manager) return this->shared->global_binding_cache_manager; return nullptr; @@ -958,7 +1000,7 @@ BindingCacheManagerPtr Context::getGlobalBindingCacheManager() const BindingCacheManagerPtr Context::getGlobalBindingCacheManager() { - auto lock = getLock(); + auto lock = getLock(); // checked if (this->shared->global_binding_cache_manager) return this->shared->global_binding_cache_manager; return nullptr; @@ -966,7 +1008,7 @@ BindingCacheManagerPtr Context::getGlobalBindingCacheManager() void 
Context::setGlobalBindingCacheManager(std::shared_ptr && manager) { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->global_binding_cache_manager) throw Exception("Global binding cache has been already created.", ErrorCodes::LOGICAL_ERROR); shared->global_binding_cache_manager = std::move(manager); @@ -974,7 +1016,7 @@ void Context::setGlobalBindingCacheManager(std::shared_ptr std::shared_ptr Context::getSessionBindingCacheManager() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!this->session_binding_cache_manager) { this->session_binding_cache_manager = std::make_shared(); @@ -985,7 +1027,7 @@ std::shared_ptr Context::getSessionBindingCacheManager() co QueueManagerPtr Context::getQueueManager() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->queue_manager) shared->queue_manager = std::make_shared(global_context); return shared->queue_manager; @@ -993,7 +1035,7 @@ QueueManagerPtr Context::getQueueManager() const AsyncQueryManagerPtr Context::getAsyncQueryManager() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->async_query_manager) shared->async_query_manager = std::make_shared(global_context); return shared->async_query_manager; @@ -1071,18 +1113,20 @@ CnchWorkerResourcePtr Context::tryGetCnchWorkerResource() const void Context::initCnchWorkerResource() { - worker_resource = std::make_shared(); + auto lock = getLocalLock(); + if (!worker_resource) + worker_resource = std::make_shared(); } void Context::setExtendedProfileInfo(const ExtendedProfileInfo & source) const { - auto lock = getLock(); + auto lock = getLocalLock(); extended_profile_info = source; } ExtendedProfileInfo Context::getExtendedProfileInfo() const { - auto lock = getLock(); + auto lock = getLocalSharedLock(); return extended_profile_info; } @@ -1096,57 +1140,60 @@ String Context::resolveDatabase(const String & database_name) const String Context::getPath() const { - auto lock = 
getLock(); + auto lock = getLock(); // checked return shared->path; } String Context::getFlagsPath() const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->flags_path; } String Context::getUserFilesPath() const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->user_files_path; } String Context::getDictionariesLibPath() const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->dictionaries_lib_path; } String Context::getMetastorePath() const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->metastore_path; } VolumePtr Context::getTemporaryVolume() const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->tmp_volume; } TemporaryDataOnDiskScopePtr Context::getTempDataOnDisk() const { - auto lock = getLock(); - if (this->temp_data_on_disk) - return this->temp_data_on_disk; + { + auto lock = getLocalSharedLock(); + if (this->temp_data_on_disk) + return this->temp_data_on_disk; + } + auto lock = getLock(); // checked return shared->temp_data_on_disk; } void Context::setTempDataOnDisk(TemporaryDataOnDiskScopePtr temp_data_on_disk_) { - auto lock = getLock(); + auto lock = getLocalLock(); this->temp_data_on_disk = std::move(temp_data_on_disk_); } void Context::setPath(const String & path) { - auto lock = getLock(); + auto lock = getLock(); // checked shared->path = path; @@ -1331,38 +1378,43 @@ void Context::setTemporaryStoragePolicy(const String & policy_name, size_t max_s void Context::setFlagsPath(const String & path) { - auto lock = getLock(); + auto lock = getLock(); // checked shared->flags_path = path; } void Context::setUserFilesPath(const String & path) { - auto lock = getLock(); + auto lock = getLock(); // checked shared->user_files_path = path; } void Context::setDictionariesLibPath(const String & path) { - auto lock = getLock(); + auto lock = getLock(); // checked shared->dictionaries_lib_path = path; } void 
Context::setMetastorePath(const String & path) { - auto lock = getLock(); + auto lock = getLock(); // checked shared->metastore_path = path; } void Context::setConfig(const ConfigurationPtr & config) { - auto lock = getLock(); + auto lock = getLock(); // checked shared->config = config; shared->access_control_manager.setExternalAuthenticatorsConfig(*shared->config); } const Poco::Util::AbstractConfiguration & Context::getConfigRef() const { - auto lock = getLock(); + auto lock = getLock(); // checked + return shared->config ? *shared->config : Poco::Util::Application::instance().config(); +} + +const Poco::Util::AbstractConfiguration & Context::getConfigRefWithLock(const std::unique_lock &) const +{ return shared->config ? *shared->config : Poco::Util::Application::instance().config(); } @@ -1412,13 +1464,13 @@ const AccessControlManager & Context::getAccessControlManager() const void Context::setExternalAuthenticatorsConfig(const Poco::Util::AbstractConfiguration & config) { - auto lock = getLock(); + auto lock = getLock(); // checked shared->access_control_manager.setExternalAuthenticatorsConfig(config); } std::unique_ptr Context::makeGSSAcceptorContext() const { - auto lock = getLock(); + auto lock = getLock(); // checked return std::make_unique(shared->access_control_manager.getExternalAuthenticators().getKerberosParams()); } @@ -1442,7 +1494,7 @@ void Context::updateAdditionalServices(const Poco::Util::AbstractConfiguration & void Context::setUsersConfig(const ConfigurationPtr & config) { - auto lock = getLock(); + auto lock = getLock(); // checked shared->users_config = config; shared->access_control_manager.setUsersConfig(*shared->users_config); if (getServerType() == ServerType::cnch_server || getServerType() == ServerType::cnch_worker) @@ -1457,7 +1509,7 @@ void Context::setUsersConfig(const ConfigurationPtr & config) ConfigurationPtr Context::getUsersConfig() { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->users_config; } @@ 
-1483,7 +1535,7 @@ void Context::setVWCustomizedSettings(VWCustomizedSettingsPtr vw_customized_sett } -void Context::initResourceGroupManager(const ConfigurationPtr & config) +void Context::initResourceGroupManager(const ConfigurationPtr & ) { LOG_DEBUG(shared->log, "Skip initialize resource group"); @@ -1514,10 +1566,14 @@ void Context::initResourceGroupManager(const ConfigurationPtr & config) void Context::setResourceGroup(const IAST * ast) { - if (auto lock = getLock(); shared->resource_group_manager && shared->resource_group_manager->isInUse()) - resource_group = shared->resource_group_manager->selectGroup(*this, ast); - else - resource_group = nullptr; + IResourceGroup * group = nullptr; + { + auto lock = getLock(); // checked + if (shared->resource_group_manager && shared->resource_group_manager->isInUse()) + group = shared->resource_group_manager->selectGroup(*this, ast); + } + auto lock = getLocalLock(); + resource_group = group; } IResourceGroup * Context::tryGetResourceGroup() const @@ -1550,26 +1606,35 @@ void Context::stopResourceGroup() void Context::setUser(const Credentials & credentials, const Poco::Net::SocketAddress & address) { - client_info.current_user = credentials.getUserName(); - client_info.current_address = address; - - //#if defined(ARCADIA_BUILD) - /// This is harmful field that is used only in foreign "Arcadia" build. - client_info.current_password.clear(); - if (const auto * basic_credentials = dynamic_cast(&credentials)) - client_info.current_password = basic_credentials->getPassword(); - //#endif - /// Find a user with such name and check the credentials. /// NOTE: getAccessControlManager().login and other AccessControl's functions may require some IO work, /// so Context::getLock() must be unlocked while we're doing this. 
auto new_user_id = getAccessControlManager().login(credentials, address.host()); - auto new_access = getAccessControlManager().getContextAccess( - new_user_id, /* current_roles = */ {}, /* use_default_roles = */ true, settings, current_database, client_info, - has_tenant_id_in_username ? tenant_id : "", - getServerType() != ServerType::cnch_server); - auto lock = getLock(); + ContextAccessParams params; + { + auto lock = getLocalLock(); + client_info.current_user = credentials.getUserName(); + client_info.current_address = address; + + //#if defined(ARCADIA_BUILD) + /// This is harmful field that is used only in foreign "Arcadia" build. + client_info.current_password.clear(); + if (const auto * basic_credentials = dynamic_cast(&credentials)) + client_info.current_password = basic_credentials->getPassword(); + //#endif + + String tenant = getTenantId(); + params = getAccessControlManager().getContextAccessParams( + new_user_id, /* current_roles = */ {}, /* use_default_roles = */ true, settings, current_database, client_info, + tenant, + has_tenant_id_in_username, + getServerType() != ServerType::cnch_server); + } + + auto new_access = getAccessControlManager().getContextAccess(params); + + auto lock = getLocalLock(); user_id = new_user_id; access = std::move(new_access); @@ -1578,7 +1643,7 @@ void Context::setUser(const Credentials & credentials, const Poco::Net::SocketAd current_roles.clear(); use_default_roles = true; - applySettingsChanges(default_profile_info->settings); + applySettingsChangesWithLock(default_profile_info->settings, /*internal*/ true, lock); } String Context::formatUserName(const String & name) @@ -1621,7 +1686,7 @@ std::shared_ptr Context::getUser() const void Context::setQuotaKey(String quota_key_) { - auto lock = getLock(); + auto lock = getLocalLock(); client_info.quota_key = std::move(quota_key_); } @@ -1632,29 +1697,28 @@ String Context::getUserName() const std::optional Context::getUserID() const { - auto lock = getLock(); + auto lock = 
getLocalSharedLock(); return user_id; } - void Context::setCurrentRoles(const std::vector & current_roles_) { - auto lock = getLock(); + auto lock = getLocalLock(); if (current_roles == current_roles_ && !use_default_roles) return; current_roles = current_roles_; use_default_roles = false; - calculateAccessRights(); + calculateAccessRightsWithLock(lock); } void Context::setCurrentRolesDefault() { - auto lock = getLock(); + auto lock = getLocalLock(); if (use_default_roles) return; current_roles.clear(); use_default_roles = true; - calculateAccessRights(); + calculateAccessRightsWithLock(lock); } boost::container::flat_set Context::getCurrentRoles() const @@ -1673,13 +1737,15 @@ std::shared_ptr Context::getRolesInfo() const } -void Context::calculateAccessRights() +void Context::calculateAccessRightsWithLock(const std::unique_lock &) { - auto lock = getLock(); if (user_id) - access = getAccessControlManager().getContextAccess( + { + auto params = getAccessControlManager().getContextAccessParams( *user_id, current_roles, use_default_roles, settings, current_database, client_info, - has_tenant_id_in_username ? tenant_id : "", false); + tenant_id, has_tenant_id_in_username, false); + access = getAccessControlManager().getContextAccess(params); + } } @@ -1747,33 +1813,29 @@ void Context::checkAccess(const AccessRightsElements & elements) const void Context::grantAllAccess() { - auto lock = getLock(); + auto lock = getLocalLock(); access = ContextAccess::getFullAccess(); } std::shared_ptr Context::getAccess() const { - auto lock = getLock(); // If its a worker node and prefer_cnch_catalog is false, this is a query from server // and access check has already been done in server. We can return full access. if (getServerType() == ServerType::cnch_worker && !getSettingsRef().prefer_cnch_catalog) return ContextAccess::getFullAccess(); + + auto lock = getLocalSharedLock(); return access ? 
access : ContextAccess::getFullAccess(); } void Context::checkAeolusTableAccess(const String & database_name, const String & table_name) const { - String table_names = this->getSettingsRef().access_table_names; - if (table_names.empty()) - return; - std::vector tables; - boost::split(tables, table_names, boost::is_any_of(" ,")); - /// avoid check temporary table. - if (database_name == DatabaseCatalog::TEMPORARY_DATABASE) + /// avoid check temporary and system table. + if (database_name == DatabaseCatalog::TEMPORARY_DATABASE || database_name == DatabaseCatalog::SYSTEM_DATABASE) return; String full_table_name = database_name.empty() ? table_name : database_name+"."+table_name; - if (std::find(tables.begin(), tables.end(), full_table_name) == tables.end()) + if (!aeolusCheck(*this, full_table_name)) { throw Exception("Access denied to " + full_table_name , ErrorCodes::DATABASE_ACCESS_DENIED); } @@ -1781,18 +1843,25 @@ void Context::checkAeolusTableAccess(const String & database_name, const String ASTPtr Context::getRowPolicyCondition(const String & database, const String & table_name, RowPolicy::ConditionType type) const { - auto lock = getLock(); - auto initial_condition = initial_row_policy ? initial_row_policy->getCondition(database, table_name, type) : nullptr; - return getAccess()->getRowPolicyCondition(database, table_name, type, initial_condition); + ASTPtr condition; + { + auto lock = getLocalSharedLock(); + condition = initial_row_policy ? 
initial_row_policy->getCondition(database, table_name, type) : nullptr; + } + return getAccess()->getRowPolicyCondition(database, table_name, type, condition); } void Context::setInitialRowPolicy() { - auto lock = getLock(); - auto initial_user_id = getAccessControlManager().find(client_info.initial_user); - initial_row_policy = nullptr; - if (initial_user_id) - initial_row_policy = getAccessControlManager().getEnabledRowPolicies(*initial_user_id, {}); + String initial_user_copy; + { + auto lock = getLocalLock(); + initial_user_copy = client_info.initial_user; + } + auto initial_user_id = getAccessControlManager().find(initial_user_copy); + auto initial_row_policy_local = initial_user_id ? getAccessControlManager().getEnabledRowPolicies(*initial_user_id, {}) : nullptr; + auto lock = getLocalLock(); + initial_row_policy = initial_row_policy_local; } @@ -1807,13 +1876,12 @@ std::optional Context::getQuotaUsage() const return getAccess()->getQuotaUsage(); } -void Context::setCurrentProfile(const String & profile_name) +void Context::setCurrentProfileWithLock(const String & profile_name, const std::unique_lock & lock) { - auto lock = getLock(); try { UUID profile_id = getAccessControlManager().getID(profile_name); - setCurrentProfile(profile_id); + setCurrentProfileWithLock(profile_id, lock); } catch (Exception & e) { @@ -1822,25 +1890,40 @@ void Context::setCurrentProfile(const String & profile_name) } } -void Context::setCurrentProfile(const UUID & profile_id) +void Context::setCurrentProfileWithLock(const UUID & profile_id, const std::unique_lock & lock) { - auto lock = getLock(); auto profile_info = getAccessControlManager().getSettingsProfileInfo(profile_id); - checkSettingsConstraints(profile_info->settings); - applySettingsChanges(profile_info->settings); - settings_constraints_and_current_profiles = profile_info->getConstraintsAndProfileIDs(settings_constraints_and_current_profiles); + setCurrentProfileWithLock(*profile_info, lock); +} + +void 
Context::setCurrentProfileWithLock(const SettingsProfilesInfo & profiles_info, const std::unique_lock & lock) +{ + checkSettingsConstraintsWithLock(profiles_info.settings); + applySettingsChangesWithLock(profiles_info.settings, true, lock); + settings_constraints_and_current_profiles = profiles_info.getConstraintsAndProfileIDs(settings_constraints_and_current_profiles); } +void Context::setCurrentProfile(const String & profile_name) +{ + auto lock = getLocalLock(); + setCurrentProfileWithLock(profile_name, lock); +} + +void Context::setCurrentProfile(const UUID & profile_id) +{ + auto lock = getLocalLock(); + setCurrentProfileWithLock(profile_id, lock); +} std::vector Context::getCurrentProfiles() const { - auto lock = getLock(); + auto lock = getLocalSharedLock(); return settings_constraints_and_current_profiles->current_profiles; } std::vector Context::getEnabledProfiles() const { - auto lock = getLock(); + auto lock = getLocalSharedLock(); return settings_constraints_and_current_profiles->enabled_profiles; } @@ -1866,7 +1949,7 @@ const Block & Context::getScalar(const String & name) const Tables Context::getExternalTables() const { assert(!isGlobalContext() || getApplicationType() == ApplicationType::LOCAL); - auto lock = getLock(); + auto lock = getLocalSharedLock(); Tables res; for (const auto & table : external_tables_mapping) @@ -1891,7 +1974,7 @@ Tables Context::getExternalTables() const void Context::addExternalTable(const String & table_name, TemporaryTableHolder && temporary_table) { assert(!isGlobalContext() || getApplicationType() == ApplicationType::LOCAL); - auto lock = getLock(); + auto lock = getLocalLock(); if (external_tables_mapping.end() != external_tables_mapping.find(table_name)) throw Exception("Temporary table " + backQuoteIfNeed(table_name) + " already exists.", ErrorCodes::TABLE_ALREADY_EXISTS); external_tables_mapping.emplace(table_name, std::make_shared(std::move(temporary_table))); @@ -1903,7 +1986,7 @@ std::shared_ptr 
Context::removeExternalTable(const String assert(!isGlobalContext() || getApplicationType() == ApplicationType::LOCAL); std::shared_ptr holder; { - auto lock = getLock(); + auto lock = getLocalLock(); auto iter = external_tables_mapping.find(table_name); if (iter == external_tables_mapping.end()) return {}; @@ -1945,7 +2028,7 @@ void Context::addQueryAccessInfo( void Context::addQueryFactoriesInfo(QueryLogFactories factory_type, const String & created_object) const { assert(!isGlobalContext() || getApplicationType() == ApplicationType::LOCAL); - auto lock = getLock(); + auto lock = getLocalLock(); switch (factory_type) { @@ -2014,16 +2097,41 @@ StoragePtr Context::getViewSource() const return view_source; } +void Context::setSettingWithLock(const StringRef & name, const String & value, const std::unique_lock & lock) +{ + if (name == "profile") + { + setCurrentProfileWithLock(value, lock); + return; + } + settings.set(std::string_view{name}, value); + + if (ContextAccessParams::dependsOnSettingName(name.toView())) + calculateAccessRightsWithLock(lock); +} + +void Context::setSettingWithLock(const StringRef & name, const Field & value, const std::unique_lock & lock) +{ + if (name == "profile") + { + setCurrentProfileWithLock(value.safeGet(), lock); + return; + } + settings.set(std::string_view{name}, value); + + if (ContextAccessParams::dependsOnSettingName(name.toView())) + calculateAccessRightsWithLock(lock); +} + Settings Context::getSettings() const { - auto lock = getLock(); + auto lock = getLocalSharedLock(); return settings; } - void Context::setSettings(const Settings & settings_) { - auto lock = getLock(); + auto lock = getLocalLock(); auto old_readonly = settings.readonly; auto old_allow_ddl = settings.allow_ddl; auto old_allow_introspection_functions = settings.allow_introspection_functions; @@ -2032,50 +2140,43 @@ void Context::setSettings(const Settings & settings_) if ((settings.readonly != old_readonly) || (settings.allow_ddl != old_allow_ddl) || 
(settings.allow_introspection_functions != old_allow_introspection_functions)) - calculateAccessRights(); + calculateAccessRightsWithLock(lock); } - void Context::setSetting(const StringRef & name, const String & value) { - auto lock = getLock(); - if (name == "profile") - { - setCurrentProfile(value); - return; - } - settings.set(std::string_view{name}, value); - - if (name == "readonly" || name == "allow_ddl" || name == "allow_introspection_functions") - calculateAccessRights(); + auto lock = getLocalLock(); + setSettingWithLock(name, value, lock); } - void Context::setSetting(const StringRef & name, const Field & value) { - auto lock = getLock(); - if (name == "profile") + auto lock = getLocalLock(); + setSettingWithLock(name, value, lock); +} + +void Context::applySettingChangeWithLock(const SettingChange & change, const std::unique_lock & lock) +{ + try { - setCurrentProfile(value.safeGet()); - return; + setSettingWithLock(change.name, change.value, lock); + } + catch (Exception & e) + { + e.addMessage(fmt::format( + "in attempt to set the value of setting '{}' to {}", change.name, applyVisitor(FieldVisitorToString(), change.value))); + throw; } - settings.set(std::string_view{name}, value); - - if (name == "readonly" || name == "allow_ddl" || name == "allow_introspection_functions") - calculateAccessRights(); } -void Context::applySettingsChanges(const JSON & changes) +void Context::applySettingsChangesWithLock(const SettingsChanges & changes, bool internal, const std::unique_lock & lock) { - auto lock = getLock(); - // set ansi related settings first, as they may be overwritten explicitly later std::optional dialect_type_opt; - std::function find_dialect_type_if_any = [&](const SettingsChanges & setting_changes) - { - for (const auto & change: setting_changes) + std::function find_dialect_type_if_any = [&](const SettingsChanges & setting_changes) { + for (const auto & change : setting_changes) { - if (change.name == "profile") + if (change.name == "profile" 
&& getClientInfo().query_kind != ClientInfo::QueryKind::INITIAL_QUERY) { UUID profile_id = getAccessControlManager().getID(change.value.safeGet()); auto profile_info = getAccessControlManager().getSettingsProfileInfo(profile_id); @@ -2094,51 +2195,39 @@ void Context::applySettingsChanges(const JSON & changes) } } }; + find_dialect_type_if_any(changes); - for (JSON::iterator it = changes.begin(); it != changes.end(); ++it) + // NOTE: tenanted users connect to server using tenant id given in connection info. + // allow only whitelisted settings for tenanted users + if (is_tenant_user() && !internal && !isInternalQuery () && getIsRestrictSettingsToWhitelist() && (getClientInfo().query_kind == ClientInfo::QueryKind::INITIAL_QUERY || session_context.lock().get() == this) && !getCurrentTenantId().empty()) { - auto name = it.getRawName().toView(); - auto value = it.getValue().getRawString().toView(); - Field value_field(value); - auto value_str = value_field.safeGet(); - UUID profile_id = getAccessControlManager().getID(value_str); - auto profile_info = getAccessControlManager().getSettingsProfileInfo(profile_id); - checkSettingsConstraints(profile_info->settings); - if (name == "profile") - { - find_dialect_type_if_any(profile_info->settings); - } - - if (name == "dialect_type") - { - if (!dialect_type_opt) - dialect_type_opt = value; - else if (*dialect_type_opt != value) - throw Exception(ErrorCodes::INVALID_SETTING_VALUE, "Multiple dialect_type value found"); - } - - try - { - setSetting(StringRef(name), value_field); - } - catch (Exception & e) + for (const auto & change : changes) { - e.addMessage(fmt::format("in attempt to set the value of setting '{}' to {}", - name, applyVisitor(FieldVisitorToString(), value_field))); - throw; + if (!SettingsChanges::WHITELIST_SETTINGS.contains(change.name)) + throw Exception(ErrorCodes::UNKNOWN_SETTING, "Unknown or disabled setting " + change.name + + "for tenant user. 
Contact the admin about whether it is needed to add it to tenant_whitelist_settings" + " in configuration"); } } // skip if a previous setting change is in process - bool apply_ansi_related_settings = dialect_type_opt && !settings.dialect_type.pending; + // skip if current and target are same + bool apply_ansi_related_settings = dialect_type_opt && !settings.dialect_type.pending + && settings.dialect_type.value != SettingFieldDialectTypeTraits::fromString(*dialect_type_opt); if (apply_ansi_related_settings) { - setSetting("dialect_type", *dialect_type_opt); + setSettingWithLock("dialect_type", *dialect_type_opt, lock); ANSI::onSettingChanged(&settings); settings.dialect_type.pending = true; } + for (const SettingChange & change : changes) + { + if (change.name == "profile" && getClientInfo().query_kind == ClientInfo::QueryKind::INITIAL_QUERY) + continue; + applySettingChangeWithLock(change, lock); + } applySettingsQuirks(settings); if (apply_ansi_related_settings) @@ -2162,102 +2251,73 @@ void Context::applySettingChange(const SettingChange & change) void Context::applySettingsChanges(const SettingsChanges & changes, bool internal) { - auto lock = getLock(); - - // set ansi related settings first, as they may be overwritten explicitly later - std::optional dialect_type_opt; - std::function find_dialect_type_if_any = [&](const SettingsChanges & setting_changes) { - for (const auto & change : setting_changes) - { - if (change.name == "profile") - { - UUID profile_id = getAccessControlManager().getID(change.value.safeGet()); - auto profile_info = getAccessControlManager().getSettingsProfileInfo(profile_id); - - find_dialect_type_if_any(profile_info->settings); - } - - if (change.name == "dialect_type") - { - auto value_str = change.value.safeGet(); - - if (!dialect_type_opt) - dialect_type_opt = value_str; - else if (*dialect_type_opt != value_str) - throw Exception(ErrorCodes::INVALID_SETTING_VALUE, "Multiple dialect_type value found"); - } - } - }; - - // NOTE: 
tenanted users connect to server using tenant id given in connection info. - // allow only whitelisted settings for tenanted users - if (is_tenant_user() && !internal && !isInternalQuery () && getIsRestrictSettingsToWhitelist() && getClientInfo().query_kind == ClientInfo::QueryKind::INITIAL_QUERY && !getCurrentTenantId().empty()) - { - for (const auto & change : changes) - { - if (!SettingsChanges::WHITELIST_SETTINGS.contains(change.name)) - throw Exception(ErrorCodes::UNKNOWN_SETTING, "Unknown or disabled setting " + change.name + - "for tenant user. Contact the admin about whether it is needed to add it to tenant_whitelist_settings" - " in configuration"); - } - } - - find_dialect_type_if_any(changes); - - // skip if a previous setting change is in process - bool apply_ansi_related_settings = dialect_type_opt && !settings.dialect_type.pending; + auto lock = getLocalLock(); + applySettingsChangesWithLock(changes, internal, lock); +} - if (apply_ansi_related_settings) - { - setSetting("dialect_type", *dialect_type_opt); - ANSI::onSettingChanged(&settings); - settings.dialect_type.pending = true; - } +void Context::checkSettingsConstraintsWithLock(const SettingChange & change) const +{ + getSettingsConstraintsAndCurrentProfilesWithLock()->constraints.check(settings, change); +} - for (const SettingChange & change : changes) - applySettingChange(change); - applySettingsQuirks(settings); +void Context::checkSettingsConstraintsWithLock(const SettingsChanges & changes) const +{ + getSettingsConstraintsAndCurrentProfilesWithLock()->constraints.check(settings, changes); +} - if (apply_ansi_related_settings) - settings.dialect_type.pending = false; +void Context::checkSettingsConstraintsWithLock(SettingsChanges & changes) const +{ + getSettingsConstraintsAndCurrentProfilesWithLock()->constraints.check(settings, changes); } +void Context::clampToSettingsConstraintsWithLock(SettingsChanges & changes) const +{ + 
getSettingsConstraintsAndCurrentProfilesWithLock()->constraints.clamp(settings, changes); +} void Context::checkSettingsConstraints(const SettingChange & change) const { - getSettingsConstraintsAndCurrentProfiles()->constraints.check(settings, change); + auto lock = getLocalSharedLock(); + checkSettingsConstraintsWithLock(change); } void Context::checkSettingsConstraints(const SettingsChanges & changes) const { - getSettingsConstraintsAndCurrentProfiles()->constraints.check(settings, changes); + auto lock = getLocalSharedLock(); + checkSettingsConstraintsWithLock(changes); } void Context::checkSettingsConstraints(SettingsChanges & changes) const { - getSettingsConstraintsAndCurrentProfiles()->constraints.check(settings, changes); + auto lock = getLocalSharedLock(); + checkSettingsConstraintsWithLock(changes); } void Context::clampToSettingsConstraints(SettingsChanges & changes) const { - getSettingsConstraintsAndCurrentProfiles()->constraints.clamp(settings, changes); + auto lock = getLocalSharedLock(); + clampToSettingsConstraintsWithLock(changes); } -std::shared_ptr Context::getSettingsConstraintsAndCurrentProfiles() const +std::shared_ptr Context::getSettingsConstraintsAndCurrentProfilesWithLock() const { - auto lock = getLock(); if (settings_constraints_and_current_profiles) return settings_constraints_and_current_profiles; static auto no_constraints_or_profiles = std::make_shared(getAccessControlManager()); return no_constraints_or_profiles; } +std::shared_ptr Context::getSettingsConstraintsAndCurrentProfiles() const +{ + auto lock = getLocalSharedLock(); + return getSettingsConstraintsAndCurrentProfilesWithLock(); +} String Context::getCurrentDatabase() const { String tenant_db; { - auto lock = getLock(); + auto lock = getLocalLock(); tenant_db = current_database; } @@ -2312,7 +2372,7 @@ void Context::setCurrentDatabaseNameInGlobalContext(const String & name) throw Exception( "Cannot set current database for non global context, this method should be used 
during server initialization", ErrorCodes::LOGICAL_ERROR); - auto lock = getLock(); + auto lock = getLocalLock(); if (!current_database.empty()) throw Exception("Default database name cannot be changed in global context without server restart", ErrorCodes::LOGICAL_ERROR); @@ -2322,10 +2382,10 @@ void Context::setCurrentDatabaseNameInGlobalContext(const String & name) void Context::setCurrentDatabase(const String & name) { - DatabaseCatalog::instance().assertDatabaseExists(name, hasQueryContext() ? getQueryContext() : shared_from_this()); - auto lock = getLock(); + DatabaseCatalog::instance().assertDatabaseExists(name, hasQueryContext() ? getQueryContext(): shared_from_this()); + auto lock = getLocalLock(); current_database = name; - calculateAccessRights(); + calculateAccessRightsWithLock(lock); } void Context::setCurrentDatabase(const String & name, ContextPtr local_context) @@ -2351,7 +2411,7 @@ void Context::setCurrentDatabase(const String & name, ContextPtr local_context) } auto db_name_with_tenant_id = appendTenantIdOnly(database_opt.value()); - auto lock = getLock(); + auto lock = getLocalLock(); if(use_cnch_catalog){ current_catalog = ""; current_database = db_name_with_tenant_id; @@ -2361,14 +2421,14 @@ void Context::setCurrentDatabase(const String & name, ContextPtr local_context) current_database = database_opt.value(); LOG_TRACE(shared->log, "use external catalog, catalog_name: {}, db_name: {}", current_catalog, current_database); } - calculateAccessRights(); + calculateAccessRightsWithLock(lock); } void Context::setCurrentCatalog(const String & catalog_name) { if (catalog_name == "" || catalog_name == "cnch") { - auto lock = getLock(); + auto lock = getLocalLock(); current_catalog = ""; current_database = ""; return; @@ -2378,7 +2438,7 @@ void Context::setCurrentCatalog(const String & catalog_name) { throw Exception(ErrorCodes::BAD_ARGUMENTS, "catalog {} does not exist", catalog_name); } - auto lock = getLock(); + auto lock = getLocalLock(); 
current_catalog = catalog_name; current_database = "default"; } @@ -2450,6 +2510,9 @@ void Context::killCurrentQuery() { process_list_elem->cancelQuery(true, false); } + getSegmentScheduler()->cancelPlanSegmentsFromCoordinator( + client_info.initial_query_id, ErrorCodes::QUERY_WAS_CANCELLED, "Cancelled by Client.", shared_from_this()); + getPlanSegmentProcessList().tryCancelPlanSegmentGroup(client_info.initial_query_id); }; String Context::getDefaultFormat() const @@ -2619,7 +2682,7 @@ void Context::loadDictionaries(const Poco::Util::AbstractConfiguration & config) SynonymsExtensions & Context::getSynonymsExtensions() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->synonyms_extensions) shared->synonyms_extensions.emplace(getConfigRef()); @@ -2629,7 +2692,7 @@ void Context::loadDictionaries(const Poco::Util::AbstractConfiguration & config) Lemmatizers & Context::getLemmatizers() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->lemmatizers) shared->lemmatizers.emplace(getConfigRef()); @@ -2649,6 +2712,16 @@ ProgressCallback Context::getProgressCallback() const return progress_callback; } +void Context::setSendTCPProgress(std::function callback) +{ + send_tcp_progress = callback; +} + +std::function Context::getSendTCPProgress() const +{ + return send_tcp_progress; +} + void Context::setProcessListEntry(std::shared_ptr process_list_entry_) { process_list_entry = process_list_entry_; @@ -2676,13 +2749,13 @@ std::weak_ptr Context::getPlanSegmentProcessListEnt void Context::setProcessorProfileElementConsumer( std::shared_ptr> processor_log_element_consumer_) { - auto lock = getLock(); + auto lock = getLock(); // checked shared->processor_log_element_consumer = processor_log_element_consumer_; } std::shared_ptr> Context::getProcessorProfileElementConsumer() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->processor_log_element_consumer) return {}; @@ -2712,7 +2785,7 @@ 
QueryStatus * Context::getProcessListElement() const void Context::setNvmCache(const Poco::Util::AbstractConfiguration &config) { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->nvm_cache) throw Exception("Nvmcache cache has been already created.", ErrorCodes::LOGICAL_ERROR); @@ -2749,27 +2822,27 @@ void Context::setNvmCache(const Poco::Util::AbstractConfiguration &config) NvmCachePtr Context::getNvmCache() const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->nvm_cache; } void Context::dropNvmCache() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->nvm_cache) shared->nvm_cache->reset(); } void Context::setFooterCache(size_t max_size_in_bytes) { - auto lock = getLock(); + auto lock = getLock(); // checked if (max_size_in_bytes) ArrowFooterCache::initialize(max_size_in_bytes); } void Context::setUncompressedCache(size_t max_size_in_bytes) { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->uncompressed_cache) throw Exception("Uncompressed cache has been already created.", ErrorCodes::LOGICAL_ERROR); @@ -2780,14 +2853,14 @@ void Context::setUncompressedCache(size_t max_size_in_bytes) UncompressedCachePtr Context::getUncompressedCache() const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->uncompressed_cache; } void Context::dropUncompressedCache() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->uncompressed_cache) shared->uncompressed_cache->reset(); } @@ -2795,7 +2868,7 @@ void Context::dropUncompressedCache() const void Context::setMarkCache(size_t cache_size_in_bytes) { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->mark_cache) throw Exception("Mark cache has been already created.", ErrorCodes::LOGICAL_ERROR); @@ -2805,20 +2878,33 @@ void Context::setMarkCache(size_t cache_size_in_bytes) MarkCachePtr Context::getMarkCache() const { - auto lock = getLock(); + auto lock 
= getLock(); // checked return shared->mark_cache; } void Context::dropMarkCache() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->mark_cache) shared->mark_cache->reset(); } +std::shared_ptr Context::tryGetCloudTableDefinitionCache() const +{ + if (hasSessionTimeZone()) + return nullptr; + callOnce(shared->cloud_table_definition_cache_initialized, [&] { + const Poco::Util::AbstractConfiguration & config = getConfigRef(); + auto cache_size = config.getUInt(".cloud_table_definition_cache_size", 50000); + if (getServerType() == ServerType::cnch_worker && cache_size) + shared->cloud_table_definition_cache = std::make_shared(cache_size); + }); + return shared->cloud_table_definition_cache; +} + void Context::setQueryCache(const Poco::Util::AbstractConfiguration & config) { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->query_cache) throw Exception(ErrorCodes::LOGICAL_ERROR, "Query cache has been already created."); @@ -2829,27 +2915,27 @@ void Context::setQueryCache(const Poco::Util::AbstractConfiguration & config) void Context::updateQueryCacheConfiguration(const Poco::Util::AbstractConfiguration & config) { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->query_cache) shared->query_cache->updateConfiguration(config); } QueryCachePtr Context::getQueryCache() const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->query_cache; } void Context::dropQueryCache() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->query_cache) shared->query_cache->reset(); } void Context::setIntermediateResultCache(size_t cache_size_in_bytes) { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->intermediate_result_cache) throw Exception("Intermediate result cache has been already created.", ErrorCodes::LOGICAL_ERROR); @@ -2859,20 +2945,20 @@ void Context::setIntermediateResultCache(size_t cache_size_in_bytes) 
IntermediateResultCachePtr Context::getIntermediateResultCache() const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->intermediate_result_cache; } void Context::dropIntermediateResultCache() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->intermediate_result_cache) shared->intermediate_result_cache->reset(); } void Context::setMMappedFileCache(size_t cache_size_in_num_entries) { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->mmap_cache) throw Exception("Mapped file cache has been already created.", ErrorCodes::LOGICAL_ERROR); @@ -2882,13 +2968,13 @@ void Context::setMMappedFileCache(size_t cache_size_in_num_entries) MMappedFileCachePtr Context::getMMappedFileCache() const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->mmap_cache; } void Context::dropMMappedFileCache() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->mmap_cache) shared->mmap_cache->reset(); } @@ -2896,7 +2982,7 @@ void Context::dropMMappedFileCache() const void Context::dropCaches() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->uncompressed_cache) shared->uncompressed_cache->reset(); @@ -2923,7 +3009,7 @@ void Context::setMergeSchedulerSettings(const Poco::Util::AbstractConfiguration BackgroundSchedulePool & Context::getBufferFlushSchedulePool() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->buffer_flush_schedule_pool) shared->buffer_flush_schedule_pool.emplace( settings.background_buffer_flush_schedule_pool_size, CurrentMetrics::BackgroundBufferFlushSchedulePoolTask, "BgBufSchPool"); @@ -2974,15 +3060,18 @@ BackgroundTaskSchedulingSettings Context::getBackgroundMoveTaskSchedulingSetting BackgroundSchedulePool & Context::getSchedulePool() const { - auto lock = getLock(); - if (!shared->schedule_pool) - shared->schedule_pool.emplace(settings.background_schedule_pool_size, 
CurrentMetrics::BackgroundSchedulePoolTask, "BgSchPool"); + callOnce(shared->schedule_pool_initialized, [&]{ + shared->schedule_pool.emplace( + settings.background_schedule_pool_size, + CurrentMetrics::BackgroundSchedulePoolTask, + "BgSchPool"); + }); return *shared->schedule_pool; } BackgroundSchedulePool & Context::getDistributedSchedulePool() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->distributed_schedule_pool) shared->distributed_schedule_pool.emplace( settings.background_distributed_schedule_pool_size, CurrentMetrics::BackgroundDistributedSchedulePoolTask, "BgDistSchPool"); @@ -2991,7 +3080,7 @@ BackgroundSchedulePool & Context::getDistributedSchedulePool() const BackgroundSchedulePool & Context::getMessageBrokerSchedulePool() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->message_broker_schedule_pool) shared->message_broker_schedule_pool.emplace( settings.background_message_broker_schedule_pool_size, CurrentMetrics::BackgroundMessageBrokerSchedulePoolTask, "BgMBSchPool"); @@ -3000,159 +3089,121 @@ BackgroundSchedulePool & Context::getMessageBrokerSchedulePool() const BackgroundSchedulePool & Context::getConsumeSchedulePool() const { - auto lock = getLock(); - LOG_DEBUG(&Poco::Logger::get("BackgroundSchedulePool"), "getConsumeSchedulePool"); - if (!shared->extra_schedule_pools[SchedulePool::Consume]) - { + auto & item = shared->extra_schedule_pools[SchedulePool::Consume]; + callOnce(item.is_initialized, [&] { CpuSetPtr cpu_set; if (auto & cgroup_manager = CGroupManagerFactory::instance(); cgroup_manager.isInit()) { cpu_set = cgroup_manager.getCpuSet("hakafka"); } - shared->extra_schedule_pools[SchedulePool::Consume].emplace( + item.pool = std::make_unique( settings.background_consume_schedule_pool_size, CurrentMetrics::BackgroundConsumeSchedulePoolTask, "BgConsumePool", std::move(cpu_set)); - } - - return *shared->extra_schedule_pools[SchedulePool::Consume]; -} - -BackgroundSchedulePool 
& Context::getRestartSchedulePool() const -{ - auto lock = getLock(); - if (!shared->extra_schedule_pools[SchedulePool::Restart]) - shared->extra_schedule_pools[SchedulePool::Restart].emplace( - settings.background_schedule_pool_size, CurrentMetrics::BackgroundRestartSchedulePoolTask, "BgRestartPool"); - return *shared->extra_schedule_pools[SchedulePool::Restart]; -} -BackgroundSchedulePool & Context::getHaLogSchedulePool() const -{ - auto lock = getLock(); - if (!shared->extra_schedule_pools[SchedulePool::HaLog]) - shared->extra_schedule_pools[SchedulePool::HaLog].emplace( - settings.background_schedule_pool_size, CurrentMetrics::BackgroundHaLogSchedulePoolTask, "BgHaLogPool"); - return *shared->extra_schedule_pools[SchedulePool::HaLog]; -} - -BackgroundSchedulePool & Context::getMutationSchedulePool() const -{ - auto lock = getLock(); - if (!shared->extra_schedule_pools[SchedulePool::Mutation]) - shared->extra_schedule_pools[SchedulePool::Mutation].emplace( - settings.background_schedule_pool_size, CurrentMetrics::BackgroundMutationSchedulePoolTask, "BgMutatePool"); - return *shared->extra_schedule_pools[SchedulePool::Mutation]; + }); + return *item.pool; } BackgroundSchedulePool & Context::getLocalSchedulePool() const { - auto lock = getLock(); - if (!shared->extra_schedule_pools[SchedulePool::Local]) - shared->extra_schedule_pools[SchedulePool::Local].emplace( - settings.background_local_schedule_pool_size, CurrentMetrics::BackgroundLocalSchedulePoolTask, "BgLocalPool"); - return *shared->extra_schedule_pools[SchedulePool::Local]; + auto & item = shared->extra_schedule_pools[SchedulePool::Local]; + callOnce(item.is_initialized, [&] { + item.pool = std::make_unique( + settings.background_local_schedule_pool_size, + CurrentMetrics::BackgroundLocalSchedulePoolTask, + "BgLocalPool" + ); + }); + return *item.pool; } BackgroundSchedulePool & Context::getMergeSelectSchedulePool() const { - auto lock = getLock(); - if 
(!shared->extra_schedule_pools[SchedulePool::MergeSelect]) - shared->extra_schedule_pools[SchedulePool::MergeSelect].emplace( - settings.background_schedule_pool_size, CurrentMetrics::BackgroundMergeSelectSchedulePoolTask, "BgMSelectPool"); - return *shared->extra_schedule_pools[SchedulePool::MergeSelect]; + auto & item = shared->extra_schedule_pools[SchedulePool::MergeSelect]; + callOnce(item.is_initialized, [&] { + item.pool = std::make_unique( + settings.background_schedule_pool_size, + CurrentMetrics::BackgroundMergeSelectSchedulePoolTask, + "BgMSelectPool"); + }); + return *item.pool; } BackgroundSchedulePool & Context::getUniqueTableSchedulePool() const { - auto lock = getLock(); - if (!shared->extra_schedule_pools[SchedulePool::UniqueTable]) - shared->extra_schedule_pools[SchedulePool::UniqueTable].emplace( - settings.background_unique_table_schedule_pool_size, CurrentMetrics::BackgroundUniqueTableSchedulePoolTask, "BgUniqPool"); - return *shared->extra_schedule_pools[SchedulePool::UniqueTable]; -} - -BackgroundSchedulePool & Context::getMemoryTableSchedulePool() const -{ - auto lock = getLock(); - if (!shared->extra_schedule_pools[SchedulePool::MemoryTable]) - shared->extra_schedule_pools[SchedulePool::MemoryTable].emplace( - settings.background_memory_table_schedule_pool_size, CurrentMetrics::BackgroundMemoryTableSchedulePoolTask, "BgMemTblPool"); - return *shared->extra_schedule_pools[SchedulePool::MemoryTable]; + auto & item = shared->extra_schedule_pools[SchedulePool::UniqueTable]; + callOnce(item.is_initialized, [&] { + item.pool = std::make_unique( + settings.background_unique_table_schedule_pool_size, + CurrentMetrics::BackgroundUniqueTableSchedulePoolTask, + "BgUniqPool"); + }); + return *item.pool; } BackgroundSchedulePool & Context::getTopologySchedulePool() const { - auto lock = getLock(); - if (!shared->extra_schedule_pools[SchedulePool::CNCHTopology]) - shared->extra_schedule_pools[SchedulePool::CNCHTopology].emplace( - 
settings.background_topology_thread_pool_size, CurrentMetrics::BackgroundCNCHTopologySchedulePoolTask, "CNCHTopoPool"); - return *shared->extra_schedule_pools[SchedulePool::CNCHTopology]; + auto & item = shared->extra_schedule_pools[SchedulePool::CNCHTopology]; + callOnce(item.is_initialized, [&] { + item.pool = std::make_unique( + settings.background_topology_thread_pool_size, + CurrentMetrics::BackgroundCNCHTopologySchedulePoolTask, + "CNCHTopoPool"); + }); + return *item.pool; } BackgroundSchedulePool & Context::getMetricsRecalculationSchedulePool() const { - auto lock = getLock(); - if (!shared->extra_schedule_pools[SchedulePool::PartsMetrics]) - shared->extra_schedule_pools[SchedulePool::PartsMetrics].emplace( + auto & item = shared->extra_schedule_pools[SchedulePool::PartsMetrics]; + callOnce(item.is_initialized, [&] { + item.pool = std::make_unique( settings.background_metrics_recalculation_schedule_pool_size, CurrentMetrics::BackgroundPartsMetricsSchedulePoolTask, "PtMetricsPool"); - return *shared->extra_schedule_pools[SchedulePool::PartsMetrics]; + }); + return *item.pool; } BackgroundSchedulePool & Context::getExtraSchedulePool( SchedulePool::Type pool_type, SettingFieldUInt64 pool_size, CurrentMetrics::Metric metric, const char * name) const { - auto lock = getLock(); - if (!shared->extra_schedule_pools[pool_type]) - shared->extra_schedule_pools[pool_type].emplace(pool_size, metric, name); - return *shared->extra_schedule_pools[pool_type]; + auto & item = shared->extra_schedule_pools[pool_type]; + callOnce(item.is_initialized, [&] { + item.pool = std::make_unique( pool_size, metric, name); + }); + return *item.pool; } ThrottlerPtr Context::getDiskCacheThrottler() const { - auto lock = getLock(); - if (!shared->disk_cache_throttler) - { + callOnce(shared->disk_cache_throttler_initialized, [&] { shared->disk_cache_throttler = std::make_shared(settings.max_bandwidth_for_disk_cache); - } - + }); return shared->disk_cache_throttler; } ThrottlerPtr 
Context::getReplicatedSendsThrottler() const { - auto lock = getLock(); - if (!shared->replicated_sends_throttler) - shared->replicated_sends_throttler = std::make_shared(settings.max_replicated_sends_network_bandwidth_for_server); - + callOnce(shared->replicated_sends_throttler_initialized, [&] { + shared->replicated_sends_throttler = std::make_shared( + settings.max_replicated_sends_network_bandwidth_for_server); + }); return shared->replicated_sends_throttler; } ThrottlerPtr Context::getReplicatedFetchesThrottler() const { - auto lock = getLock(); - if (!shared->replicated_fetches_throttler) - shared->replicated_fetches_throttler = std::make_shared(settings.max_replicated_fetches_network_bandwidth_for_server); - + callOnce(shared->replicated_fetches_throttler_initialized, [&] { + shared->replicated_fetches_throttler = std::make_shared( + settings.max_replicated_fetches_network_bandwidth_for_server); + }); return shared->replicated_fetches_throttler; } -void Context::initPreloadThrottler() -{ - auto lock = getLock(); - shared->preload_throttler = settings.parts_preload_throttler == 0 ? 
nullptr : std::make_shared(settings.parts_preload_throttler); -} - -ThrottlerPtr Context::tryGetPreloadThrottler() const -{ - auto lock = getLock(); - return shared->preload_throttler; -} - bool Context::hasDistributedDDL() const { return getConfigRef().has("distributed_ddl"); @@ -3160,7 +3211,7 @@ bool Context::hasDistributedDDL() const void Context::setDDLWorker(std::unique_ptr ddl_worker) { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->ddl_worker) throw Exception("DDL background thread has already been initialized", ErrorCodes::LOGICAL_ERROR); ddl_worker->startup(); @@ -3169,7 +3220,7 @@ void Context::setDDLWorker(std::unique_ptr ddl_worker) DDLWorker & Context::getDDLWorker() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->ddl_worker) { if (!hasZooKeeper()) @@ -3468,8 +3519,8 @@ InterserverCredentialsPtr Context::getInterserverCredentials() std::pair Context::getCnchInterserverCredentials() const { - auto lock = getLock(); String user_name = getSettingsRef().username_for_internal_communication.toString(); + auto lock = getLock(); // checked auto password = shared->users_config->getString("users." 
+ user_name + ".password", ""); return {user_name, password}; @@ -3576,8 +3627,6 @@ UInt16 Context::getTCPPort() const if (auto env_port = getPortFromEnvForConsul("PORT0")) return env_port; - auto lock = getLock(); - const auto & config = getConfigRef(); return config.getInt("tcp_port", DBMS_DEFAULT_PORT); } @@ -3599,8 +3648,6 @@ UInt16 Context::getTCPPort(const String & host, UInt16 rpc_port) const std::optional Context::getTCPPortSecure() const { - auto lock = getLock(); - const auto & config = getConfigRef(); if (config.has("tcp_port_secure")) return config.getInt("tcp_port_secure"); @@ -3623,7 +3670,6 @@ UInt16 Context::getServerPort(const String & port_name) const UInt16 Context::getHaTCPPort() const { - auto lock = getLock(); const auto & config = getConfigRef(); return config.getInt("ha_tcp_port"); } @@ -3720,7 +3766,7 @@ void Context::setCluster(const String & cluster_name, const std::shared_ptrsystem_logs = std::make_unique(getGlobalContext(), getConfigRef()); } @@ -3747,7 +3793,7 @@ PartitionSelectorPtr Context::getBGPartitionSelector() const std::shared_ptr Context::getQueryLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->system_logs) return {}; @@ -3758,7 +3804,7 @@ std::shared_ptr Context::getQueryLog() const std::shared_ptr Context::getQueryThreadLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->system_logs) return {}; @@ -3769,7 +3815,7 @@ std::shared_ptr Context::getQueryThreadLog() const std::shared_ptr Context::getQueryExchangeLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->system_logs) return {}; @@ -3780,7 +3826,7 @@ std::shared_ptr Context::getQueryExchangeLog() const std::shared_ptr Context::getPartLog(const String & part_database) const { - auto lock = getLock(); + auto lock = getLock(); // checked /// No part log or system logs are shutting down. 
if (!shared->system_logs) @@ -3798,7 +3844,7 @@ std::shared_ptr Context::getPartLog(const String & part_database) const std::shared_ptr Context::getPartMergeLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->system_logs || !shared->system_logs->part_merge_log) return {}; @@ -3809,7 +3855,7 @@ std::shared_ptr Context::getPartMergeLog() const std::shared_ptr Context::getServerPartLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->system_logs || !shared->system_logs->server_part_log) return {}; @@ -3821,7 +3867,7 @@ void Context::initializeCnchSystemLogs() { if ((shared->server_type != ServerType::cnch_server) && (shared->server_type != ServerType::cnch_worker)) return; - auto lock = getLock(); + auto lock = getLock(); // checked shared->cnch_system_logs = std::make_unique(getGlobalContext()); } @@ -3836,7 +3882,7 @@ void Context::insertViewRefreshTaskLog(const ViewRefreshTaskLogElement & element std::shared_ptr Context::getCnchQueryLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->cnch_system_logs) return {}; @@ -3846,7 +3892,7 @@ std::shared_ptr Context::getCnchQueryLog() const std::shared_ptr Context::getViewRefreshTaskLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->cnch_system_logs) return {}; @@ -3856,7 +3902,7 @@ std::shared_ptr Context::getViewRefreshTaskLog() const std::shared_ptr Context::getTraceLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->system_logs) return {}; @@ -3867,7 +3913,7 @@ std::shared_ptr Context::getTraceLog() const std::shared_ptr Context::getTextLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->system_logs) return {}; @@ -3878,7 +3924,7 @@ std::shared_ptr Context::getTextLog() const std::shared_ptr Context::getMetricLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->system_logs) 
return {}; @@ -3889,7 +3935,7 @@ std::shared_ptr Context::getMetricLog() const std::shared_ptr Context::getAsynchronousMetricLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->system_logs) return {}; @@ -3900,7 +3946,7 @@ std::shared_ptr Context::getAsynchronousMetricLog() const std::shared_ptr Context::getOpenTelemetrySpanLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->system_logs) return {}; @@ -3910,7 +3956,7 @@ std::shared_ptr Context::getOpenTelemetrySpanLog() const std::shared_ptr Context::getKafkaLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->system_logs) return {}; @@ -3920,7 +3966,7 @@ std::shared_ptr Context::getKafkaLog() const std::shared_ptr Context::getCloudKafkaLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->cnch_system_logs) return {}; @@ -3929,7 +3975,7 @@ std::shared_ptr Context::getCloudKafkaLog() const std::shared_ptr Context::getCloudMaterializedMySQLLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->cnch_system_logs) return {}; @@ -3938,7 +3984,7 @@ std::shared_ptr Context::getCloudMaterializedMySQLLog std::shared_ptr Context::getCloudUniqueTableLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->cnch_system_logs) return {}; @@ -3947,7 +3993,7 @@ std::shared_ptr Context::getCloudUniqueTableLog() const std::shared_ptr Context::getMutationLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->system_logs) return {}; @@ -3958,7 +4004,7 @@ std::shared_ptr Context::getMutationLog() const std::shared_ptr Context::getProcessorsProfileLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->system_logs) return {}; @@ -3968,7 +4014,7 @@ std::shared_ptr Context::getProcessorsProfileLog() const std::shared_ptr Context::getRemoteReadLog() const { - auto lock = 
getLock(); + auto lock = getLock(); // checked if (!shared->system_logs) return {}; @@ -3978,7 +4024,7 @@ std::shared_ptr Context::getRemoteReadLog() const std::shared_ptr Context::getZooKeeperLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->system_logs) return {}; @@ -3988,7 +4034,7 @@ std::shared_ptr Context::getZooKeeperLog() const std::shared_ptr Context::getAutoStatsTaskLog() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->system_logs) return {}; @@ -3998,12 +4044,12 @@ std::shared_ptr Context::getAutoStatsTaskLog() const CompressionCodecPtr Context::chooseCompressionCodec(size_t part_size, double part_size_ratio) const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->compression_codec_selector) { constexpr auto config_name = "compression"; - const auto & config = getConfigRef(); + const auto & config = getConfigRefWithLock(lock); if (config.has(config_name)) shared->compression_codec_selector = std::make_unique(config, "compression"); @@ -4104,11 +4150,11 @@ void Context::updateStorageConfiguration(Poco::Util::AbstractConfiguration & con const CnchHiveSettings & Context::getCnchHiveSettings() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->cnchhive_settings) { - const auto & config = getConfigRef(); + const auto & config = getConfigRefWithLock(lock); CnchHiveSettings cnchhive_settings; cnchhive_settings.loadFromConfig("hive", config); shared->cnchhive_settings.emplace(cnchhive_settings); @@ -4119,11 +4165,11 @@ const CnchHiveSettings & Context::getCnchHiveSettings() const const CnchHiveSettings & Context::getCnchLasSettings() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->las_settings) { - const auto & config = getConfigRef(); + const auto & config = getConfigRefWithLock(lock); CnchHiveSettings las_settings; las_settings.loadFromConfig("las", config); 
shared->las_settings.emplace(las_settings); @@ -4133,11 +4179,11 @@ const CnchHiveSettings & Context::getCnchLasSettings() const const MergeTreeSettings & Context::getMergeTreeSettings(bool skip_unknown_settings) const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->merge_tree_settings) { - const auto & config = getConfigRef(); + const auto & config = getConfigRefWithLock(lock); MergeTreeSettings mt_settings; mt_settings.loadFromConfig("merge_tree", config, skip_unknown_settings); shared->merge_tree_settings.emplace(mt_settings); @@ -4148,11 +4194,11 @@ const MergeTreeSettings & Context::getMergeTreeSettings(bool skip_unknown_settin const CnchFileSettings & Context::getCnchFileSettings() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->cnch_file_settings) { - auto & config = getConfigRef(); + const auto & config = getConfigRefWithLock(lock); shared->cnch_file_settings.emplace(); shared->cnch_file_settings->loadFromConfig("cnch_file", config); } @@ -4162,11 +4208,11 @@ const CnchFileSettings & Context::getCnchFileSettings() const const MergeTreeSettings & Context::getReplicatedMergeTreeSettings() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->replicated_merge_tree_settings) { - const auto & config = getConfigRef(); + const auto & config = getConfigRefWithLock(lock); MergeTreeSettings mt_settings; mt_settings.loadFromConfig("merge_tree", config); mt_settings.loadFromConfig("replicated_merge_tree", config); @@ -4179,11 +4225,11 @@ const MergeTreeSettings & Context::getReplicatedMergeTreeSettings() const const StorageS3Settings & Context::getStorageS3Settings() const { #if !defined(ARCADIA_BUILD) - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->storage_s3_settings) { - const auto & config = getConfigRef(); + const auto & config = getConfigRefWithLock(lock); shared->storage_s3_settings.emplace().loadFromConfig("s3", config, getSettingsRef()); } @@ 
-4326,7 +4372,7 @@ OutputFormatPtr Context::getOutputFormat(const String & name, WriteBuffer & buf, time_t Context::getUptimeSeconds() const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->uptime_watch.elapsedSeconds(); } @@ -4471,7 +4517,7 @@ void Context::setQueryParameter(const String & name, const String & value) void Context::addBridgeCommand(std::unique_ptr cmd) const { - auto lock = getLock(); + auto lock = getLock(); // checked shared->bridge_commands.emplace_back(std::move(cmd)); } @@ -4490,7 +4536,7 @@ const IHostContextPtr & Context::getHostContext() const std::shared_ptr Context::getActionLocksManager() { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->action_locks_manager) shared->action_locks_manager = std::make_shared(shared_from_this()); @@ -4583,7 +4629,7 @@ StorageID Context::resolveStorageID(StorageID storage_id, StorageNamespace where StorageID resolved = StorageID::createEmpty(); std::optional exc; { - auto lock = getLock(); + auto lock = getLock(); // checked resolved = resolveStorageIDImpl(std::move(storage_id), where, &exc); } if (exc) @@ -4604,7 +4650,7 @@ StorageID Context::tryResolveStorageID(StorageID storage_id, StorageNamespace wh StorageID resolved = StorageID::createEmpty(); { - auto lock = getLock(); + auto lock = getLock(); // checked resolved = resolveStorageIDImpl(std::move(storage_id), where, nullptr); } if (resolved && !resolved.hasUUID() && resolved.database_name != DatabaseCatalog::TEMPORARY_DATABASE) @@ -4721,7 +4767,7 @@ ZooKeeperMetadataTransactionPtr Context::getZooKeeperMetadataTransaction() const PartUUIDsPtr Context::getPartUUIDs() const { - auto lock = getLock(); + auto lock = getLocalLock(); // checked if (!part_uuids) /// For context itself, only this initialization is not const. /// We could have done in constructor. 
@@ -4747,7 +4793,7 @@ void Context::setReadTaskCallback(ReadTaskCallback && callback) PartUUIDsPtr Context::getIgnoredPartUUIDs() const { - auto lock = getLock(); + auto lock = getLocalLock(); // checked if (!ignored_part_uuids) const_cast(ignored_part_uuids) = std::make_shared(); @@ -4805,13 +4851,13 @@ void Context::setLasfsConnectionParams(const Poco::Util::AbstractConfiguration & void Context::setVETosConnectParams(const VETosConnectionParams & connect_params) { - auto lock = getLock(); + auto lock = getLock(); // checked shared->vetos_connection_params = connect_params; } const VETosConnectionParams & Context::getVETosConnectParams() const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->vetos_connection_params; } @@ -4827,7 +4873,7 @@ const OSSConnectionParams & Context::getOSSConnectParams() const void Context::setUniqueKeyIndexBlockCache(size_t cache_size_in_bytes) { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->unique_key_index_block_cache) throw Exception("Unique key index block cache has been already created", ErrorCodes::LOGICAL_ERROR); shared->unique_key_index_block_cache = IndexFile::NewLRUCache(cache_size_in_bytes); @@ -4835,13 +4881,13 @@ void Context::setUniqueKeyIndexBlockCache(size_t cache_size_in_bytes) UniqueKeyIndexBlockCachePtr Context::getUniqueKeyIndexBlockCache() const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->unique_key_index_block_cache; } void Context::setUniqueKeyIndexFileCache(size_t cache_size_in_bytes) { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->unique_key_index_file_cache) throw Exception("Unique key index file cache has been already created", ErrorCodes::LOGICAL_ERROR); shared->unique_key_index_file_cache = std::make_shared(*this, cache_size_in_bytes); @@ -4849,13 +4895,13 @@ void Context::setUniqueKeyIndexFileCache(size_t cache_size_in_bytes) UniqueKeyIndexFileCachePtr Context::getUniqueKeyIndexFileCache() 
const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->unique_key_index_file_cache; } void Context::setUniqueKeyIndexCache(size_t cache_size_in_bytes) { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->unique_key_index_cache) throw Exception("Unique key index cache has been already created", ErrorCodes::LOGICAL_ERROR); shared->unique_key_index_cache = std::make_shared(cache_size_in_bytes); @@ -4863,13 +4909,13 @@ void Context::setUniqueKeyIndexCache(size_t cache_size_in_bytes) std::shared_ptr Context::getUniqueKeyIndexCache() const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->unique_key_index_cache; } void Context::setDeleteBitmapCache(size_t cache_size_in_bytes) { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->delete_bitmap_cache) throw Exception("Delete bitmap cache has been already created", ErrorCodes::LOGICAL_ERROR); shared->delete_bitmap_cache = std::make_shared(cache_size_in_bytes); @@ -4877,7 +4923,7 @@ void Context::setDeleteBitmapCache(size_t cache_size_in_bytes) DeleteBitmapCachePtr Context::getDeleteBitmapCache() const { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->delete_bitmap_cache; } @@ -4952,12 +4998,12 @@ void Context::setMetaCheckerStatus(bool stop) shared->stop_sync = stop; } -void Context::setChecksumsCache(const ChecksumsCacheSettings & settings) +void Context::setChecksumsCache(const ChecksumsCacheSettings & settings_) { if (shared->checksums_cache) throw Exception("Checksums cache has been already created.", ErrorCodes::LOGICAL_ERROR); - shared->checksums_cache = std::make_shared(settings); + shared->checksums_cache = std::make_shared(settings_); } std::shared_ptr Context::getChecksumsCache() const @@ -4965,12 +5011,12 @@ std::shared_ptr Context::getChecksumsCache() const return shared->checksums_cache; } -void Context::setGinIndexStoreFactory(const GinIndexStoreCacheSettings & settings) +void 
Context::setGinIndexStoreFactory(const GinIndexStoreCacheSettings & settings_) { if (shared->ginindex_store_factory) throw Exception("ginindex_store_factory has been already created.", ErrorCodes::LOGICAL_ERROR); - shared->ginindex_store_factory = std::make_shared(settings); + shared->ginindex_store_factory = std::make_shared(settings_); } std::shared_ptr Context::getGinIndexStoreFactory() const @@ -5091,7 +5137,7 @@ UInt64 Context::getPhysicalTimestamp() const void Context::setPartCacheManager() { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->cache_manager) throw Exception("Part cache manager has been already created.", ErrorCodes::LOGICAL_ERROR); @@ -5101,7 +5147,8 @@ void Context::setPartCacheManager() PartCacheManagerPtr Context::getPartCacheManager() const { - auto lock = getLock(); + /// no need to lock because PartCacheManager is initialized during server start up, + /// there is no concurrent setPartCacheManager and getPartCacheManager usage. return shared->cache_manager; } @@ -5138,7 +5185,7 @@ DaemonManagerClientPtr Context::getDaemonManagerClient() const void Context::setCnchServerManager(const Poco::Util::AbstractConfiguration & config) { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->server_manager) throw Exception("Server manager has been already created.", ErrorCodes::LOGICAL_ERROR); @@ -5147,7 +5194,7 @@ void Context::setCnchServerManager(const Poco::Util::AbstractConfiguration & con std::shared_ptr Context::getCnchServerManager() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->server_manager) throw Exception("Server manager is not initiailized.", ErrorCodes::LOGICAL_ERROR); @@ -5158,7 +5205,7 @@ void Context::updateServerVirtualWarehouses(const ConfigurationPtr & config) { std::shared_ptr server_manager; { - auto lock = getLock(); + auto lock = getLock(); // checked server_manager = shared->server_manager; } if (server_manager) @@ -5167,7 +5214,7 @@ void 
Context::updateServerVirtualWarehouses(const ConfigurationPtr & config) void Context::setCnchTopologyMaster() { - auto lock = getLock(); + auto lock = getLock(); // checked if (shared->topology_master) throw Exception("Topology master has been already created.", ErrorCodes::LOGICAL_ERROR); @@ -5176,7 +5223,7 @@ void Context::setCnchTopologyMaster() std::shared_ptr Context::getCnchTopologyMaster() const { - auto lock = getLock(); + auto lock = getLock(); // checked if (!shared->topology_master) throw Exception("Topology master is not initialized.", ErrorCodes::LOGICAL_ERROR); @@ -5185,9 +5232,9 @@ std::shared_ptr Context::getCnchTopologyMaster() const GlobalTxnCommitterPtr Context::getGlobalTxnCommitter() const { - auto lock = getLock(); - if (!shared->global_txn_committer) - shared->global_txn_committer = std::make_shared(shared_from_this()); + callOnce(shared->global_txn_committer_initialized, [&] { + shared->global_txn_committer = std::make_shared(getGlobalContext()); + }); return shared->global_txn_committer; } @@ -5427,7 +5474,7 @@ void Context::initResourceManagerClient() String host_port; try { - auto lock = getLock(); + auto lock = getLock(); // checked shared->rm_client = std::make_shared(getGlobalContext()); LOG_DEBUG(shared->log, "Initialised Resource Manager Client on try: {}", retry_count); return; @@ -5449,7 +5496,7 @@ ResourceManagerClientPtr Context::getResourceManagerClient() const void Context::initCnchBGThreads() { - auto lock = getLock(); + auto lock = getLock(); // checked shared->cnch_bg_threads_array = std::make_unique(shared_from_this()); } @@ -5606,14 +5653,13 @@ std::multimap Context::collectMutationStatus void Context::initCnchTransactionCoordinator() { - auto lock = getLock(); - + auto lock = getLock(); // checked shared->cnch_txn_coordinator = std::make_unique(shared_from_this()); } TransactionCoordinatorRcCnch & Context::getCnchTransactionCoordinator() const { - auto lock = getLock(); + auto lock = getLock(); // checked return 
*shared->cnch_txn_coordinator; } @@ -5621,7 +5667,7 @@ void Context::setCurrentTransaction(TransactionCnchPtr txn, bool finish_txn) { TransactionCnchPtr prev_txn; { - auto lock = getLock(); + auto lock = getLocalSharedLock(); prev_txn = current_cnch_txn; } @@ -5631,7 +5677,7 @@ void Context::setCurrentTransaction(TransactionCnchPtr txn, bool finish_txn) if (current_thread && txn) CurrentThread::get().setTransactionId(txn->getTransactionID()); - auto lock = getLock(); + auto lock = getLocalLock(); current_cnch_txn = std::move(txn); } @@ -5654,28 +5700,26 @@ TransactionCnchPtr Context::setTemporaryTransaction(const TxnTimestamp & txn_id, else cnch_txn = std::make_shared(getGlobalContext(), txn_id, primary_txn_id); - auto lock = getLock(); + auto lock = getLocalLock(); std::swap(current_cnch_txn, cnch_txn); return current_cnch_txn; } TransactionCnchPtr Context::getCurrentTransaction() const { - auto lock = getLock(); - + auto lock = getLocalSharedLock(); return current_cnch_txn; } TxnTimestamp Context::tryGetCurrentTransactionID() const { - auto lock = getLock(); - + auto lock = getLocalSharedLock(); return current_cnch_txn ? 
current_cnch_txn->getTransactionID() : TxnTimestamp{}; } TxnTimestamp Context::getCurrentTransactionID() const { - auto lock = getLock(); + auto lock = getLocalSharedLock(); if (!current_cnch_txn) throw Exception("Transaction is not set (empty)", ErrorCodes::LOGICAL_ERROR); @@ -5689,11 +5733,9 @@ TxnTimestamp Context::getCurrentTransactionID() const TxnTimestamp Context::getCurrentCnchStartTime() const { - auto lock = getLock(); - + auto lock = getLocalSharedLock(); if (!current_cnch_txn) throw Exception("Transaction is not set", ErrorCodes::LOGICAL_ERROR); - return current_cnch_txn->getStartTime(); } @@ -5738,32 +5780,12 @@ std::vector> Context::getAllWorkerResou return shared->named_cnch_sessions->getAllWorkerResources(); } -Context::PartAllocator Context::getPartAllocationAlgo() const +Context::PartAllocator Context::getPartAllocationAlgo(MergeTreeSettingsPtr table_settings) const { - /// we prefer the config setting first - if (getConfigRef().has("part_allocation_algorithm")) - { - LOG_DEBUG( - shared->log, - "Using part allocation algorithm from config: {}.", - getConfigRef().getInt("part_allocation_algorithm")); - switch (getConfigRef().getInt("part_allocation_algorithm")) - { - case 0: - return PartAllocator::JUMP_CONSISTENT_HASH; - case 1: - return PartAllocator::RING_CONSISTENT_HASH; - case 2: - return PartAllocator::STRICT_RING_CONSISTENT_HASH; - case 3: - return PartAllocator::SIMPLE_HASH; - default: - return PartAllocator::JUMP_CONSISTENT_HASH; - } - } + auto algorithm = table_settings->cnch_part_allocation_algorithm >= 0 ? 
table_settings->cnch_part_allocation_algorithm : settings.cnch_part_allocation_algorithm; + LOG_DEBUG(shared->log, "Send query with cnch_part_allocation_algorithm = {}, system setting = {}, table setting = {}", algorithm, settings.cnch_part_allocation_algorithm, table_settings->cnch_part_allocation_algorithm); - /// if not set, we use the query settings - switch (settings.cnch_part_allocation_algorithm) + switch (algorithm) { case 0: return PartAllocator::JUMP_CONSISTENT_HASH; @@ -5772,7 +5794,7 @@ Context::PartAllocator Context::getPartAllocationAlgo() const case 2: return PartAllocator::STRICT_RING_CONSISTENT_HASH; case 3: - return PartAllocator::SIMPLE_HASH; + return PartAllocator::DISK_CACHE_STEALING_DEBUG; default: return PartAllocator::JUMP_CONSISTENT_HASH; } @@ -5791,6 +5813,11 @@ Context::HybridPartAllocator Context::getHybridPartAllocationAlgo() const } } +bool Context::hasSessionTimeZone() const +{ + return !settings.session_timezone.value.empty(); +} + void Context::createPlanNodeIdAllocator(int max_id) { id_allocator = std::make_shared(max_id); @@ -5808,7 +5835,7 @@ void Context::createOptimizerMetrics() std::shared_ptr Context::getStatisticsMemoryStore() { - auto lock = getLock(); + auto lock = getLocalLock(); if (!this->stats_memory_store) { this->stats_memory_store = std::make_shared(); @@ -5880,25 +5907,25 @@ void Context::waitReadFromClientFinished() const void Context::setPlanCacheManager(std::unique_ptr && manager) { - auto lock = getLock(); + auto lock = getLock(); // checked shared->plan_cache_manager = std::move(manager); } PlanCacheManager* Context::getPlanCacheManager() { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->plan_cache_manager ? 
shared->plan_cache_manager.get() : nullptr; } void Context::setPreparedStatementManager(std::unique_ptr && manager) { - auto lock = getLock(); + auto lock = getLock(); // checked shared->prepared_statement_manager = std::move(manager); } PreparedStatementManager * Context::getPreparedStatementManager() { - auto lock = getLock(); + auto lock = getLock(); // checked return shared->prepared_statement_manager ? shared->prepared_statement_manager.get() : nullptr; } @@ -5927,16 +5954,12 @@ void Context::setQueryExpirationTimeStamp() AsynchronousReaderPtr Context::getThreadPoolReader() const { - auto lock = getLock(); - - if (!shared->asynchronous_remote_fs_reader) - { + callOnce(shared->readers_initialized, [&] { const Poco::Util::AbstractConfiguration & config = getConfigRef(); auto pool_size = config.getUInt(".threadpool_remote_fs_reader_pool_size", 250); auto queue_size = config.getUInt(".threadpool_remote_fs_reader_queue_size", 1000000); shared->asynchronous_remote_fs_reader = std::make_shared(pool_size, queue_size); - } - + }); return shared->asynchronous_remote_fs_reader; } } diff --git a/src/Interpreters/Context.h b/src/Interpreters/Context.h index 66901b9b3d0..3eeb35fdd98 100644 --- a/src/Interpreters/Context.h +++ b/src/Interpreters/Context.h @@ -47,13 +47,16 @@ #include #include #include +#include #include #include #include #include +#include #include #include #include +#include #include #if !defined(ARCADIA_BUILD) # include @@ -79,6 +82,10 @@ namespace DB::Statistics { struct StatisticsMemoryStore; } +namespace DB::Statistics::AutoStats +{ +class AutoStatisticsManager; +} namespace zkutil { class ZooKeeper; @@ -121,9 +128,11 @@ class ManipulationList; class ReplicatedFetchList; class Cluster; class Compiler; +class CloudTableDefinitionCache; class MarkCache; class MMappedFileCache; class UncompressedCache; +class GinIdxFilterResultCache; class PrimaryIndexCache; class ProcessList; class ProcessListEntry; @@ -183,6 +192,7 @@ class VWResourceGroupManager; class 
Credentials; class GSSAcceptorContext; struct SettingsConstraintsAndProfileIDs; +struct SettingsProfilesInfo; class RemoteHostFilter; struct StorageID; class IDisk; @@ -202,6 +212,7 @@ class KeeperDispatcher; class SegmentScheduler; using SegmentSchedulerPtr = std::shared_ptr; class ChecksumsCache; +class CompressedDataIndexCache; class PrimaryIndexCache; struct ChecksumsCacheSettings; template @@ -421,15 +432,13 @@ class CopyableAtomic : public std::atomic } }; -/** A set of known objects that can be used in the query. - * Consists of a shared part (always common to all sessions and queries) - * and copied part (which can be its own for each session or query). - * - * Everything is encapsulated for all sorts of checks and locks. - */ -class Context : public std::enable_shared_from_this +class ContextData { -private: +protected: + /// Use copy constructor or createGlobal() instead + ContextData(); + ContextData(const ContextData &); + ContextSharedPart * shared; ClientInfo client_info; @@ -447,9 +456,11 @@ class Context : public std::enable_shared_from_this CopyableAtomic resource_group{nullptr}; /// Current resource group. String current_database; Settings settings; /// Setting for query execution. + SettingsChanges settings_changes; // query level or session level settings changes using ProgressCallback = std::function; ProgressCallback progress_callback; /// Callback for tracking progress of query execution. + std::function send_tcp_progress{nullptr}; using FileProgressCallback = std::function; FileProgressCallback file_progress_callback; /// Callback for tracking progress of file loading. @@ -579,13 +590,12 @@ class Context : public std::enable_shared_from_this bool enable_worker_fault_tolerance = false; timespec query_expiration_timestamp{}; + public: // Top-level OpenTelemetry trace context for the query. Makes sense only for a query context. 
OpenTelemetryTraceContext query_trace_context; -private: - friend struct NamedCnchSession; - +protected: using SampleBlockCache = std::unordered_map; mutable SampleBlockCache sample_block_cache; @@ -627,16 +637,38 @@ class Context : public std::enable_shared_from_this ExceptionHandlerPtr plan_segment_ex_handler = nullptr; bool read_from_client_finished = false; - bool is_explain_query = false; + int step_id = 2000; + int rule_id = 3000; + String graphviz_sub_query_path; + int sub_query_id = 0; + bool has_tenant_id_in_username = false; + String tenant_id; + String current_catalog; +}; + +/** A set of known objects that can be used in the query. + * Consists of a shared part (always common to all sessions and queries) + * and copied part (which can be its own for each session or query). + * + * Everything is encapsulated for all sorts of checks and locks. + */ +class Context : public ContextData, public std::enable_shared_from_this +{ +private: + /// ContextData mutex + mutable SharedMutex mutex; + + String query_plan; Context(); Context(const Context &); - Context & operator=(const Context &); public: + friend struct NamedCnchSession; + /// Create initial Context with ContextShared and etc. - static ContextMutablePtr createGlobal(ContextSharedPart * shared); + static ContextMutablePtr createGlobal(ContextSharedPart * shared_part); static ContextMutablePtr createCopy(const ContextWeakPtr & other); static ContextMutablePtr createCopy(const ContextMutablePtr & other); static ContextMutablePtr createCopy(const ContextPtr & other); @@ -645,8 +677,6 @@ class Context : public std::enable_shared_from_this void addSessionView(StorageID view_table_id, StoragePtr view_storage); StoragePtr getSessionView(StorageID view_table_id); - void copyFrom(const ContextPtr & other); - ~Context(); void setExtendedProfileInfo(const ExtendedProfileInfo & source) const; @@ -710,6 +740,7 @@ class Context : public std::enable_shared_from_this /// Global application configuration settings. 
void setConfig(const ConfigurationPtr & config); const Poco::Util::AbstractConfiguration & getConfigRef() const; + const Poco::Util::AbstractConfiguration & getConfigRefWithLock(const std::unique_lock &) const; void initRootConfig(const Poco::Util::AbstractConfiguration & poco_config); const RootConfiguration & getRootConfig() const; @@ -935,13 +966,15 @@ class Context : public std::enable_shared_from_this Settings getSettings() const; void setSettings(const Settings & settings_); + void setSessionSettingsChanges(const SettingsChanges & settings_changes_) const { getSessionContext()->settings_changes = settings_changes_; } + void applySessionSettingsChanges() { applySettingsChanges(getSessionContext()->settings_changes); } + void clearSessionSettingsChanges() const { getSessionContext()->settings_changes.clear(); } /// Set settings by name. void setSetting(const StringRef & name, const String & value); void setSetting(const StringRef & name, const Field & value); void applySettingChange(const SettingChange & change); void applySettingsChanges(const SettingsChanges & changes, bool internal = true); - void applySettingsChanges(const JSON & changes); /// Checks the constraints. void checkSettingsConstraints(const SettingChange & change) const; @@ -1083,6 +1116,9 @@ class Context : public std::enable_shared_from_this void setProgressCallback(ProgressCallback callback); /// Used in InterpreterSelectQuery to pass it to the IBlockInputStream. ProgressCallback getProgressCallback() const; + void setSendTCPProgress(std::function callback); + /// Used in InterpreterSelectQuery to pass it to the IBlockInputStream. 
+ std::function getSendTCPProgress() const; void setFileProgressCallback(FileProgressCallback && callback) { file_progress_callback = callback; } FileProgressCallback getFileProgressCallback() const { return file_progress_callback; } @@ -1170,6 +1206,13 @@ class Context : public std::enable_shared_from_this UInt32 getZooKeeperSessionUptime() const; + void addQueryPlanInfo(String & query_plan_) + { + this->query_plan = query_plan_; + } + + String getQueryPlan() {return query_plan;} + #if USE_NURAFT std::shared_ptr & getKeeperDispatcher() const; #endif @@ -1208,6 +1251,9 @@ class Context : public std::enable_shared_from_this std::shared_ptr getMarkCache() const; void dropMarkCache() const; + /// result maybe nullptr + std::shared_ptr tryGetCloudTableDefinitionCache() const; + /// Create a cache of mapped files to avoid frequent open/map/unmap/close and to reuse from several threads. void setMMappedFileCache(size_t cache_size_in_num_entries); std::shared_ptr getMMappedFileCache() const; @@ -1245,13 +1291,9 @@ class Context : public std::enable_shared_from_this BackgroundSchedulePool & getDistributedSchedulePool() const; BackgroundSchedulePool & getConsumeSchedulePool() const; - BackgroundSchedulePool & getRestartSchedulePool() const; - BackgroundSchedulePool & getHaLogSchedulePool() const; - BackgroundSchedulePool & getMutationSchedulePool() const; BackgroundSchedulePool & getLocalSchedulePool() const; BackgroundSchedulePool & getMergeSelectSchedulePool() const; BackgroundSchedulePool & getUniqueTableSchedulePool() const; - BackgroundSchedulePool & getMemoryTableSchedulePool() const; BackgroundSchedulePool & getTopologySchedulePool() const; BackgroundSchedulePool & getMetricsRecalculationSchedulePool() const; /// no more get pool method, use getExtraSchedulePool @@ -1259,13 +1301,9 @@ class Context : public std::enable_shared_from_this SchedulePool::Type pool_type, SettingFieldUInt64 pool_size, CurrentMetrics::Metric metric, const char * name) const; ThrottlerPtr 
getDiskCacheThrottler() const; - ThrottlerPtr getReplicatedFetchesThrottler() const; ThrottlerPtr getReplicatedSendsThrottler() const; - void initPreloadThrottler(); - ThrottlerPtr tryGetPreloadThrottler() const; - /// Has distributed_ddl configuration or not. bool hasDistributedDDL() const; void setDDLWorker(std::unique_ptr ddl_worker); @@ -1440,17 +1478,14 @@ class Context : public std::enable_shared_from_this UInt32 nextNodeId() { return id_allocator->nextId(); } void createPlanNodeIdAllocator(int max_id = 1); - int step_id = 2000; int getStepId() const { return step_id; } void setStepId(int step_id_) { step_id = step_id_; } int getAndIncStepId() { return ++step_id; } - int rule_id = 3000; int getRuleId() const { return rule_id; } void setRuleId(int rule_id_) { rule_id = rule_id_; } void incRuleId() { ++rule_id; } - String graphviz_sub_query_path; void setExecuteSubQueryPath(String path) { graphviz_sub_query_path = std::move(path); } String getExecuteSubQueryPath() const { @@ -1461,7 +1496,6 @@ class Context : public std::enable_shared_from_this graphviz_sub_query_path = ""; } - int sub_query_id = 0; int incAndGetSubQueryId() { return ++sub_query_id; } const SymbolAllocatorPtr & getSymbolAllocator() { return symbol_allocator; } @@ -1525,10 +1559,10 @@ class Context : public std::enable_shared_from_this return settings.default_catalog.toString(); } - void setChecksumsCache(const ChecksumsCacheSettings & settings); + void setChecksumsCache(const ChecksumsCacheSettings & settings_); std::shared_ptr getChecksumsCache() const; - void setGinIndexStoreFactory(const GinIndexStoreCacheSettings & settings); + void setGinIndexStoreFactory(const GinIndexStoreCacheSettings & settings_); std::shared_ptr getGinIndexStoreFactory() const; void setPrimaryIndexCache(size_t cache_size_in_bytes); @@ -1652,10 +1686,10 @@ class Context : public std::enable_shared_from_this JUMP_CONSISTENT_HASH = 0, RING_CONSISTENT_HASH = 1, STRICT_RING_CONSISTENT_HASH = 2, - SIMPLE_HASH = 3,//Note: 
Now just used for test disk cache stealing so not used for online + DISK_CACHE_STEALING_DEBUG = 3,//Note: Now just used for test disk cache stealing so not used for online }; - PartAllocator getPartAllocationAlgo() const; + PartAllocator getPartAllocationAlgo(MergeTreeSettingsPtr settings) const; /// Consistent hash algorithm for hybrid part allocation enum HybridPartAllocator : int @@ -1668,6 +1702,10 @@ class Context : public std::enable_shared_from_this }; HybridPartAllocator getHybridPartAllocationAlgo() const; + // If session timezone is specified, some cache which involves creating table/storage can't be used. + // Because it may use wrong timezone for DateTime column, which leads to incorrect result. + bool hasSessionTimeZone() const; + String getDefaultCnchPolicyName() const; String getCnchAuxilityPolicyName() const; @@ -1704,16 +1742,28 @@ class Context : public std::enable_shared_from_this bool is_tenant_user() const { return has_tenant_id_in_username; } private: - bool has_tenant_id_in_username = false; - String tenant_id; - String current_catalog; std::unique_lock getLock() const; + std::unique_lock getLocalLock() const; + std::shared_lock getLocalSharedLock() const; void initGlobal(); /// Compute and set actual user settings, client_info.current_user should be set - void calculateAccessRights(); + void calculateAccessRightsWithLock(const std::unique_lock &); + + void setCurrentProfileWithLock(const String & profile_name, const std::unique_lock & lock); + void setCurrentProfileWithLock(const UUID & profile_id, const std::unique_lock & lock); + void setCurrentProfileWithLock(const SettingsProfilesInfo & profiles_info, const std::unique_lock & lock); + void setSettingWithLock(const StringRef & name, const String & value, const std::unique_lock & lock); + void setSettingWithLock(const StringRef & name, const Field & value, const std::unique_lock & lock); + void applySettingChangeWithLock(const SettingChange & change, const std::unique_lock & lock); + void 
applySettingsChangesWithLock(const SettingsChanges & changes, bool internal, const std::unique_lock & lock); + std::shared_ptr getSettingsConstraintsAndCurrentProfilesWithLock() const; + void checkSettingsConstraintsWithLock(const SettingChange & change) const; + void checkSettingsConstraintsWithLock(const SettingsChanges & changes) const; + void checkSettingsConstraintsWithLock(SettingsChanges & changes) const; + void clampToSettingsConstraintsWithLock(SettingsChanges & changes) const; template void checkAccessImpl(const Args &... args) const; diff --git a/src/Interpreters/Context_fwd.h b/src/Interpreters/Context_fwd.h index 4d1a56cd4fa..7674de6804c 100644 --- a/src/Interpreters/Context_fwd.h +++ b/src/Interpreters/Context_fwd.h @@ -13,13 +13,9 @@ namespace SchedulePool enum Type { Consume, - Restart, - HaLog, - Mutation, Local, MergeSelect, UniqueTable, - MemoryTable, CNCHTopology, PartsMetrics, BspGC, diff --git a/src/Interpreters/CrashLog.cpp b/src/Interpreters/CrashLog.cpp index a9da804f1d2..3b695547cd9 100644 --- a/src/Interpreters/CrashLog.cpp +++ b/src/Interpreters/CrashLog.cpp @@ -40,7 +40,7 @@ void CrashLogElement::appendToBlock(MutableColumns & columns) const { size_t i = 0; - columns[i++]->insert(DateLUT::instance().toDayNum(event_time).toUnderType()); + columns[i++]->insert(DateLUT::serverTimezoneInstance().toDayNum(event_time).toUnderType()); columns[i++]->insert(event_time); columns[i++]->insert(timestamp_ns); columns[i++]->insert(signal); diff --git a/src/Interpreters/DAGGraph.cpp b/src/Interpreters/DAGGraph.cpp index a39e4a513e7..2702b5c3404 100644 --- a/src/Interpreters/DAGGraph.cpp +++ b/src/Interpreters/DAGGraph.cpp @@ -107,6 +107,13 @@ void SourcePruner::generateUnprunableSegments() unprunable_plan_segments.insert(segment_output->getPlanSegmentId()); } } + for (const auto & segment_input : node.getPlanSegment()->getPlanSegmentInputs()) + { + if (segment_input->isStable()) + { + 
unprunable_plan_segments.insert(node.getPlanSegment()->getPlanSegmentId()); + } + } } } diff --git a/src/Interpreters/DAGGraph.h b/src/Interpreters/DAGGraph.h index 54896169d49..fb8880a7392 100644 --- a/src/Interpreters/DAGGraph.h +++ b/src/Interpreters/DAGGraph.h @@ -30,7 +30,7 @@ struct PlanSegmentsStatus }; using PlanSegmentsStatusPtr = std::shared_ptr; -using Source = std::unordered_set; +using SegmentIds = std::unordered_set; using WorkerInfoSet = std::unordered_set, HostWithPorts::IsSameEndpoint>; using PlanSegmentId = size_t; using StorageUnions = std::vector>; @@ -77,10 +77,9 @@ struct DAGGraph return source_pruner; } - /// all segments containing only table scan - Source sources; - /// all segments containing at least one table scan - Source any_tables; + SegmentIds leaf_segments; + /// all segments contain at least table scan + SegmentIds segments_has_table_scan; size_t final = std::numeric_limits::max(); std::set scheduled_segments; std::unordered_map id_to_segment; diff --git a/src/Interpreters/DatabaseCatalog.cpp b/src/Interpreters/DatabaseCatalog.cpp index 4cc828c7f3f..137ca5248ec 100644 --- a/src/Interpreters/DatabaseCatalog.cpp +++ b/src/Interpreters/DatabaseCatalog.cpp @@ -302,9 +302,6 @@ DatabaseAndTable DatabaseCatalog::getTableImpl( } } - if (context_->getServerType() == ServerType::cnch_server) - context_->checkAeolusTableAccess(table_id.database_name, table_id.table_name); - if (table_id.hasUUID() && table_id.database_name == TEMPORARY_DATABASE) { /// Shortcut for tables which have persistent UUID @@ -333,6 +330,7 @@ DatabaseAndTable DatabaseCatalog::getTableImpl( db_and_table.second = std::make_shared(std::move(db_and_table.second), db_and_table.first.get()); } #endif + return db_and_table; } @@ -1204,7 +1202,7 @@ DatabasePtr DatabaseCatalog::tryGetDatabaseCnch(const String & database_name, Co return res; res = getDatabaseFromCnchOrHiveCatalog( database_name, - getContext(), + local_context, txn ? 
txn->getStartTime() : TxnTimestamp::maxTS(), local_context->getSettingsRef().enable_three_part_identifier); if (res && txn) diff --git a/src/Interpreters/DistributedStages/BSPScheduler.cpp b/src/Interpreters/DistributedStages/BSPScheduler.cpp index 8c5605bb84b..513761df670 100644 --- a/src/Interpreters/DistributedStages/BSPScheduler.cpp +++ b/src/Interpreters/DistributedStages/BSPScheduler.cpp @@ -52,13 +52,13 @@ void BSPScheduler::submitTasks(PlanSegment * plan_segment_ptr, const SegmentTask else { pending_task_instances.for_nodes[selector_info.worker_nodes[i].address].emplace(task.task_id, i); - if (task.is_source) + if (task.has_table_scan) { source_task_count_on_workers[selector_info.worker_nodes[i].address] += 1; } } } - if (task.is_source) + if (task.has_table_scan) { std::unordered_map source_task_index_on_workers; for (size_t i = 0; i < selector_info.worker_nodes.size(); i++) @@ -70,7 +70,7 @@ void BSPScheduler::submitTasks(PlanSegment * plan_segment_ptr, const SegmentTask source_task_index_on_workers[addr]++; } } - triggerDispatch(cluster_nodes.rank_workers); + triggerDispatch(cluster_nodes.all_workers); } void BSPScheduler::onSegmentFinished(const size_t & segment_id, bool is_succeed, bool /*is_canceled*/) @@ -153,19 +153,6 @@ void BSPScheduler::updateSegmentStatusCounter(size_t segment_id, UInt64 parallel std::unique_lock lk(nodes_alloc_mutex); auto failed_worker = segment_parallel_locations[segment_id][parallel_index]; failed_workers[segment_id].insert(failed_worker); - auto iter = pending_task_instances.for_nodes[failed_worker].begin(); - while (iter != pending_task_instances.for_nodes[failed_worker].end()) - { - if (iter->task_id == segment_id) - { - pending_task_instances.no_prefs.insert({iter->task_id, iter->parallel_index}); - iter = pending_task_instances.for_nodes[failed_worker].erase(iter); - } - else - { - iter++; - } - } } } @@ -194,13 +181,14 @@ bool BSPScheduler::retryTaskIfPossible(size_t segment_id, UInt64 parallel_index) { if (auto step = 
std::dynamic_pointer_cast(node.step)) { - if (auto cnch_table = step->getTarget()->getStorage()) + if (auto cnch_table = std::dynamic_pointer_cast(step->getTarget()->getStorage())) { - // unique table can't support retry - if (cnch_table->getInMemoryMetadataPtr()->hasUniqueKey()) + auto txn = query_context->getCurrentTransaction(); + /// Unique table with can't support retry in non-append write mode when dedup in write suffix stage + if (cnch_table->commitTxnInWriteSuffixStage(txn->getDedupImplVersion(query_context), query_context)) return false; - is_table_write = true; } + is_table_write = true; } else if (node.step->getType() == IQueryPlanStep::Type::TableFinish) return false; @@ -250,11 +238,11 @@ bool BSPScheduler::retryTaskIfPossible(size_t segment_id, UInt64 parallel_index) } { std::unique_lock lk(nodes_alloc_mutex); - if (dag_graph_ptr->any_tables.contains(segment_id) || + if (dag_graph_ptr->segments_has_table_scan.contains(segment_id) || // for local no repartion and local may no repartition, schedule to original node NodeSelector::tryGetLocalInput(dag_graph_ptr->getPlanSegmentPtr(segment_id)) || // in case all workers except servers are occupied, simply retry at last node - failed_workers[segment_id].size() == cluster_nodes.rank_workers.size()) + failed_workers[segment_id].size() == cluster_nodes.all_workers.size()) { auto available_worker = segment_parallel_locations[segment_id][parallel_index]; occupied_workers[segment_id].erase(available_worker); @@ -266,7 +254,7 @@ bool BSPScheduler::retryTaskIfPossible(size_t segment_id, UInt64 parallel_index) { pending_task_instances.no_prefs.insert({segment_id, parallel_index}); lk.unlock(); - triggerDispatch(cluster_nodes.rank_workers); + triggerDispatch(cluster_nodes.all_workers); } } return true; diff --git a/src/Interpreters/DistributedStages/MPPQueryCoordinator.cpp b/src/Interpreters/DistributedStages/MPPQueryCoordinator.cpp index 66c50ff06b0..0646c99ec59 100644 --- 
a/src/Interpreters/DistributedStages/MPPQueryCoordinator.cpp +++ b/src/Interpreters/DistributedStages/MPPQueryCoordinator.cpp @@ -6,17 +6,18 @@ #include #include #include -#include -#include -#include #include #include #include +#include +#include +#include +#include #include #include -#include -#include #include +#include +#include "Interpreters/DistributedStages/ProgressManager.h" #include @@ -208,7 +209,13 @@ BlockIO MPPQueryCoordinator::execute() process_list_elem_ptr->get().updateProgressIn(p); }); - scheduler_status = query_context->getSegmentScheduler()->insertPlanSegments(query_id, plan_segment_tree.get(), query_context); + { + /// only send progress before executing final plan segment, + /// working thread will join when this tcp progress sender is destroyed + auto sender = std::make_unique( + query_context->getSendTCPProgress(), query_context->getSettingsRef().interactive_delay / 1000); + scheduler_status = query_context->getSegmentScheduler()->insertPlanSegments(query_id, plan_segment_tree.get(), query_context); + } if (scheduler_status && !scheduler_status->exception.empty()) { diff --git a/src/Interpreters/DistributedStages/PlanSegment.cpp b/src/Interpreters/DistributedStages/PlanSegment.cpp index 09ef1c4f792..9ecbfb01c76 100644 --- a/src/Interpreters/DistributedStages/PlanSegment.cpp +++ b/src/Interpreters/DistributedStages/PlanSegment.cpp @@ -211,6 +211,7 @@ String PlanSegmentInput::toString(size_t indent) const ostr << indent_str << "keep_order: " << keep_order << "\n"; ostr << indent_str << "storage_id: " << (type == PlanSegmentType::SOURCE && storage_id.has_value() ? 
storage_id->getNameForLogs() : "") << "\n"; ostr << indent_str << "source_addresses: " << "\n"; + ostr << indent_str << "isStable: " << isStable() << "\n"; for (auto & address : source_addresses) ostr << indent_str << indent_str << address.toString() << "\n"; @@ -239,6 +240,8 @@ void PlanSegmentOutput::toProto(Protos::PlanSegmentOutput & proto) proto.set_shuffle_hash_function(shuffle_function_name); proto.set_parallel_size(parallel_size); proto.set_keep_order(keep_order); + if(!shuffle_func_params.empty()) + serializeFieldVectorToProto(shuffle_func_params, *proto.mutable_shuffle_function_parameters()); } void PlanSegmentOutput::fillFromProto(const Protos::PlanSegmentOutput & proto) @@ -247,6 +250,8 @@ void PlanSegmentOutput::fillFromProto(const Protos::PlanSegmentOutput & proto) shuffle_function_name = proto.shuffle_hash_function(); parallel_size = proto.parallel_size(); keep_order = proto.keep_order(); + if (proto.has_shuffle_function_parameters()) + shuffle_func_params = deserializeFieldVectorFromProto(proto.shuffle_function_parameters()); } String PlanSegmentOutput::toString(size_t indent) const @@ -256,6 +261,15 @@ String PlanSegmentOutput::toString(size_t indent) const ostr << IPlanSegment::toString(indent) << "\n"; ostr << indent_str << "shuffle_function_name: " << shuffle_function_name << "\n"; + if (!shuffle_func_params.empty()) + { + ostr << indent_str << "shuffle_parameters: "; + for (auto & field : shuffle_func_params) + { + ostr << field.toString() << " "; + } + ostr << "\n"; + } ostr << indent_str << "parallel_size: " << parallel_size << "\n"; ostr << indent_str << "keep_order: " << keep_order; @@ -386,6 +400,7 @@ void PlanSegment::toProto(Protos::PlanSegment & plan_segment_proto) for (const auto & id : runtime_filters) plan_segment_proto.add_runtime_filter_id(id); + plan_segment_proto.set_profile_type(ReportProfileTypeConverter::toProto(profile_type)); } void PlanSegment::fillFromProto(const Protos::PlanSegment & proto, ContextMutablePtr context_) @@ 
-411,7 +426,10 @@ void PlanSegment::fillFromProto(const Protos::PlanSegment & proto, ContextMutabl runtime_filters.emplace(runtime_filter_id); } + if (proto.has_profile_type()) + profile_type = ReportProfileTypeConverter::fromProto(proto.profile_type()); } + /** * update plansegemnt if * 1. a segment is deserialized diff --git a/src/Interpreters/DistributedStages/PlanSegment.h b/src/Interpreters/DistributedStages/PlanSegment.h index 76f6e2b24be..beb2f22beba 100644 --- a/src/Interpreters/DistributedStages/PlanSegment.h +++ b/src/Interpreters/DistributedStages/PlanSegment.h @@ -35,6 +35,8 @@ namespace DB { using RuntimeFilterId = UInt32; +ENUM_WITH_PROTO_CONVERTER(ReportProfileType, Protos::ReportProfileType, (Unspecified, 0), (QueryPlan, 1), (QueryPipeline, 2)); + /** * SOURCE means the plan is the leaf of a plan segment tree, i.g. TableScan Node. * EXCHANGE always marking the plan that need to repartiton the data. @@ -172,11 +174,15 @@ class PlanSegmentInput : public IPlanSegment void setStorageID(const StorageID & storage_id_) { storage_id = storage_id_;} + void setStable(bool stable_) { stable = stable_; } + bool isStable() const { return stable; } + private: size_t parallel_index = std::numeric_limits::max(); /// no longer used bool keep_order = false; AddressInfos source_addresses; std::optional storage_id; + bool stable = false; }; using PlanSegmentInputPtr = std::shared_ptr; @@ -210,10 +216,18 @@ class PlanSegmentOutput : public IPlanSegment String toString(size_t indent = 0) const override; + void setShuffleFunctionName(const String & shuffle_function_name_) { shuffle_function_name = shuffle_function_name_; } + + const String & getShuffleFunctionName() { return shuffle_function_name; } + + void setShuffleFunctionParams(const Array & shuffle_func_params_) { shuffle_func_params = shuffle_func_params_; } + const Array & getShuffleFunctionParams() { return shuffle_func_params; } + private: String shuffle_function_name = "cityHash64"; size_t parallel_size; bool 
keep_order = false; + Array shuffle_func_params; }; using PlanSegmentOutputPtr = std::shared_ptr; @@ -310,6 +324,10 @@ class PlanSegment const std::unordered_set & getRuntimeFilters() const { return runtime_filters; } static void getRemoteSegmentId(const QueryPlan::Node * node, std::unordered_map & exchange_to_segment); + + void setProfileType(const ReportProfileType & type) { profile_type = type; } + + ReportProfileType getProfileType() const { return profile_type; } private: size_t segment_id; String query_id; @@ -325,6 +343,8 @@ class PlanSegment size_t exchange_parallel_size; std::unordered_set runtime_filters; + + ReportProfileType profile_type = ReportProfileType::Unspecified; }; class PlanSegmentTree diff --git a/src/Interpreters/DistributedStages/PlanSegmentExecutor.cpp b/src/Interpreters/DistributedStages/PlanSegmentExecutor.cpp index 0dad062d1b0..435dc34db81 100644 --- a/src/Interpreters/DistributedStages/PlanSegmentExecutor.cpp +++ b/src/Interpreters/DistributedStages/PlanSegmentExecutor.cpp @@ -67,6 +67,7 @@ #include #include #include +#include #include #include #include @@ -76,6 +77,7 @@ #include #include #include +#include "Interpreters/sendPlanSegment.h" namespace ProfileEvents { @@ -173,7 +175,7 @@ std::optional PlanSegmentExecutor::execute query_log_element->event_time_microseconds = time_in_microseconds(finish_time); return convertSuccessPlanSegmentStatusToResult( - context, plan_segment_instance->info.execution_address, final_progress, sender_metrics, plan_segment_outputs); + context, plan_segment_instance->info.execution_address, final_progress, sender_metrics, plan_segment_outputs, segment_profile); } catch (...) 
{ @@ -229,6 +231,7 @@ std::optional PlanSegmentExecutor::execute BlockIO PlanSegmentExecutor::lazyExecute(bool /*add_output_processors*/) { + LOG_DEBUG(&Poco::Logger::get("PlanSegmentExecutor"), "lazyExecute: {}", plan_segment->getPlanSegmentId()); BlockIO res; // Will run as master query and already initialized if (!CurrentThread::get().getQueryContext() || CurrentThread::get().getQueryContext().get() != context.get()) @@ -273,17 +276,67 @@ void PlanSegmentExecutor::collectSegmentQueryRuntimeMetric(const QueryStatus * q query_log_element->query_tables = query_access_info.tables; } -StepAggregatedOperatorProfiles collectStepRuntimeProfiles(int segment_id, const QueryPipelinePtr & pipeline) +StepProfiles collectStepRuntimeProfiles(const QueryPipelinePtr & pipeline) { ProcessorProfiles profiles; for (const auto & processor : pipeline->getProcessors()) profiles.push_back(std::make_shared(processor.get())); GroupedProcessorProfilePtr grouped_profiles = GroupedProcessorProfile::getGroupedProfiles(profiles); + auto step_profile = GroupedProcessorProfile::aggregateOperatorProfileToStepLevel(grouped_profiles); + AddressToStepProfile addr_to_step_profile; + addr_to_step_profile["localhost"] = step_profile; + return ProfileMetric::aggregateStepProfileBetweenWorkers(addr_to_step_profile); +} + +void fillPlanSegmentProfile( + PlanSegmentProfilePtr & segment_profile, + const QueryPipelinePtr & pipeline, + ReportProfileType type, + const QueryStatus * query_status, + ContextPtr context, + PlanSegment * plan_segment) +{ + AddressInfo current_address = getLocalAddress(*context); + segment_profile->worker_address = extractExchangeHostPort(current_address); + if (query_status) + { + auto query_status_info = query_status->getInfo(true, context->getSettingsRef().log_profile_events); + segment_profile->read_bytes = query_status_info.read_bytes; + segment_profile->read_rows = query_status_info.read_rows; + segment_profile->query_duration_ms = query_status_info.elapsed_seconds * 1000; + 
segment_profile->io_wait_ms = query_status_info.max_io_time_thread_ms; + } - std::unordered_map> segment_grouped_profile; - segment_grouped_profile[segment_id].emplace_back(grouped_profiles); - auto step_profile = StepOperatorProfile::aggregateOperatorProfileToStepLevel(segment_grouped_profile); - return AggregatedStepOperatorProfile::aggregateStepOperatorProfileBetweenWorkers(step_profile); + if (type == ReportProfileType::Unspecified) + return; + ProcessorProfiles profiles; + for (const auto & processor : pipeline->getProcessors()) + profiles.push_back(std::make_shared(processor.get())); + GroupedProcessorProfilePtr grouped_profiles = GroupedProcessorProfile::getGroupedProfiles(profiles); + if (type == ReportProfileType::QueryPipeline) + { + auto output_root = GroupedProcessorProfile::getOutputRoot(grouped_profiles); + segment_profile->profile_root_id = output_root->id; + segment_profile->profiles = GroupedProcessorProfile::getProfileMetricsFromOutputRoot(output_root); + } + else if (type == ReportProfileType::QueryPlan) + { + auto step_profile = GroupedProcessorProfile::aggregateOperatorProfileToStepLevel(grouped_profiles); + for (auto & [step_id, profile] : step_profile) + segment_profile->profiles.emplace(step_id, profile); + auto & plan = plan_segment->getQueryPlan(); + for (auto & node : plan.getNodes()) + { + if (!node.step->getAttributeDescriptions().empty() && segment_profile->profiles.contains(node.id)) + { + for (auto & att : node.step->getAttributeDescriptions()) + { + auto attribute_ptr = std::make_shared(att.second); + segment_profile->profiles.at(node.id)->attributes.emplace(att.first, attribute_ptr); + } + } + } + } } void PlanSegmentExecutor::doExecute() @@ -436,7 +489,13 @@ void PlanSegmentExecutor::doExecute() query_log_element->segment_profiles = std::make_shared>(); query_log_element->segment_profiles->emplace_back( PlanSegmentDescription::getPlanSegmentDescription(plan_segment_instance->plan_segment, true) - 
->jsonPlanSegmentDescriptionAsString(collectStepRuntimeProfiles(plan_segment->getPlanSegmentId(), pipeline))); + ->jsonPlanSegmentDescriptionAsString(collectStepRuntimeProfiles(pipeline))); + } + if (context->getSettingsRef().report_segment_profiles && plan_segment) + { + segment_profile = std::make_shared(query_log_element->client_info.initial_query_id, plan_segment->getPlanSegmentId()); + fillPlanSegmentProfile( + segment_profile, pipeline, plan_segment->getProfileType(), &process_plan_segment_entry->get(), context, plan_segment); } if (context->getSettingsRef().log_processors_profiles) @@ -829,7 +888,11 @@ Processors PlanSegmentExecutor::buildRepartitionExchangeSink( arguments.emplace_back(plan_segment_outputs[output_index]->getHeader().getByName(column_name)); argument_numbers.emplace_back(plan_segment_outputs[output_index]->getHeader().getPositionByName(column_name)); } - auto repartition_func = RepartitionTransform::getDefaultRepartitionFunction(arguments, context); + auto repartition_func = RepartitionTransform::getRepartitionHashFunction( + plan_segment_outputs[output_index]->getShuffleFunctionName(), + arguments, + context, + plan_segment_outputs[output_index]->getShuffleFunctionParams()); size_t partition_num = senders.size(); if (keep_order && context->getSettingsRef().exchange_enable_keep_order_parallel_shuffle && partition_num > 1) diff --git a/src/Interpreters/DistributedStages/PlanSegmentExecutor.h b/src/Interpreters/DistributedStages/PlanSegmentExecutor.h index d1ef116f7d3..0cb3b4f31a6 100644 --- a/src/Interpreters/DistributedStages/PlanSegmentExecutor.h +++ b/src/Interpreters/DistributedStages/PlanSegmentExecutor.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -59,6 +60,7 @@ class PlanSegmentExecutor : private boost::noncopyable AddressInfo coordinator_address; RuntimeSegmentsStatus runtime_segment_status; Protos::SenderMetrics sender_metrics; + PlanSegmentProfilePtr segment_profile; }; std::optional 
execute(); BlockIO lazyExecute(bool add_output_processors = false); @@ -86,6 +88,7 @@ class PlanSegmentExecutor : private boost::noncopyable SenderMetrics sender_metrics; Progress progress; Progress final_progress; + PlanSegmentProfilePtr segment_profile; Processors buildRepartitionExchangeSink(BroadcastSenderPtrs & senders, bool keep_order, size_t output_index, const Block &header, OutputPortRawPtrs &ports); diff --git a/src/Interpreters/DistributedStages/PlanSegmentManagerRpcService.cpp b/src/Interpreters/DistributedStages/PlanSegmentManagerRpcService.cpp index f0d2b327cf3..b1a3118b7c6 100644 --- a/src/Interpreters/DistributedStages/PlanSegmentManagerRpcService.cpp +++ b/src/Interpreters/DistributedStages/PlanSegmentManagerRpcService.cpp @@ -13,6 +13,7 @@ * limitations under the License. */ +#include #include #include #include @@ -39,6 +40,7 @@ namespace ErrorCodes extern const int BRPC_PROTOCOL_VERSION_UNSUPPORT; extern const int QUERY_WAS_CANCELLED; extern const int QUERY_WAS_CANCELLED_INTERNAL; + extern const int TIMEOUT_EXCEEDED; } WorkerNodeResourceData ResourceMonitorTimer::getResourceData() const { @@ -359,7 +361,21 @@ void PlanSegmentManagerRpcService::submitPlanSegment( /// Create session context for worker if (context->getServerType() == ServerType::cnch_worker) { - auto named_session = context->acquireNamedCnchSession(txn_id, {}, query_common->check_session()); + size_t max_execution_time_ms = 0; + if (query_common->has_query_expiration_timestamp()) + { + auto duration_ms = duration_ms_from_now(query_common->query_expiration_timestamp()); + if (!duration_ms) + throw Exception( + ErrorCodes::TIMEOUT_EXCEEDED, + "Max execution time exceeded before submit plan segment, try increase max_execution_time, current timestamp:{} " + "expires at:{}", + time_in_milliseconds(std::chrono::system_clock::now()), + query_common->query_expiration_timestamp()); + max_execution_time_ms = duration_ms.value(); + } + auto named_session + = 
context->acquireNamedCnchSession(txn_id, (max_execution_time_ms / 1000) + 1, query_common->check_session()); query_context = Context::createCopy(named_session->context); query_context->setSessionContext(query_context); query_context->setTemporaryTransaction(txn_id, primary_txn_id); @@ -454,8 +470,8 @@ void PlanSegmentManagerRpcService::submitPlanSegment( if (!settings_io_buf->empty()) { ReadBufferFromBrpcBuf settings_read_buf(*settings_io_buf); - /// Sets an extra row policy based on `client_info.initial_user` - query_context->setInitialRowPolicy(); + /// Sets an extra row policy based on `client_info.initial_user`, problematic for now + // query_context->setInitialRowPolicy(); /// apply settings changed const size_t MIN_MINOR_VERSION_ENABLE_STRINGS_WITH_FLAGS = 4; if (query_common->brpc_protocol_minor_revision() >= MIN_MINOR_VERSION_ENABLE_STRINGS_WITH_FLAGS) @@ -523,4 +539,16 @@ void PlanSegmentManagerRpcService::submitPlanSegment( LOG_ERROR(log, "executeQuery failed: {}", error_msg); } } + +void PlanSegmentManagerRpcService::sendPlanSegmentProfile( + ::google::protobuf::RpcController * /*controller*/, + const ::DB::Protos::PlanSegmentProfileRequest * request, + ::DB::Protos::PlanSegmentProfileResponse * /*response*/, + ::google::protobuf::Closure * done) +{ + brpc::ClosureGuard done_guard(done); + PlanSegmentProfilePtr profile = PlanSegmentProfile::fromProto(*request); + const SegmentSchedulerPtr & scheduler = context->getSegmentScheduler(); + scheduler->updateSegmentProfile(profile); +} } diff --git a/src/Interpreters/DistributedStages/PlanSegmentManagerRpcService.h b/src/Interpreters/DistributedStages/PlanSegmentManagerRpcService.h index 6b2df21872a..583bd2ecc56 100644 --- a/src/Interpreters/DistributedStages/PlanSegmentManagerRpcService.h +++ b/src/Interpreters/DistributedStages/PlanSegmentManagerRpcService.h @@ -125,8 +125,14 @@ class PlanSegmentManagerRpcService : public Protos::PlanSegmentManagerService ::DB::Protos::SendProgressResponse * response, 
::google::protobuf::Closure * done) override; + void sendPlanSegmentProfile( + ::google::protobuf::RpcController * /*controller*/, + const ::DB::Protos::PlanSegmentProfileRequest * request, + ::DB::Protos::PlanSegmentProfileResponse * /*response*/, + ::google::protobuf::Closure * done) override; private: + ContextMutablePtr context; std::unique_ptr report_metrics_timer; Poco::Logger * log; diff --git a/src/Interpreters/DistributedStages/PlanSegmentProcessList.cpp b/src/Interpreters/DistributedStages/PlanSegmentProcessList.cpp index e64fd888ee6..a904d671043 100644 --- a/src/Interpreters/DistributedStages/PlanSegmentProcessList.cpp +++ b/src/Interpreters/DistributedStages/PlanSegmentProcessList.cpp @@ -286,7 +286,16 @@ CancellationCode PlanSegmentProcessList::tryCancelPlanSegmentGroup(const String if (segment_group.get()) { if (coordinator_address.empty() || segment_group->coordinator_address == coordinator_address) + { found = segment_group->tryCancel(true); + LOG_DEBUG( + logger, + "Try cancel for distributed query[{}@{}@{}] from PlanSegmentProcessList, result is {}", + initial_query_id, + coordinator_address, + segment_group->initial_query_start_time_ms, + found); + } else { LOG_WARNING( diff --git a/src/Interpreters/DistributedStages/PlanSegmentReport.cpp b/src/Interpreters/DistributedStages/PlanSegmentReport.cpp index 5389d442cd9..5e4ca6f3b50 100644 --- a/src/Interpreters/DistributedStages/PlanSegmentReport.cpp +++ b/src/Interpreters/DistributedStages/PlanSegmentReport.cpp @@ -16,6 +16,8 @@ void reportExecutionResult(const PlanSegmentExecutor::ExecutionResult & result) static auto * logger = &Poco::Logger::get("PlanSegmentExecutor"); try { + if (result.segment_profile) + reportSuccessPlanSegmentProfile(result); auto address = extractExchangeHostPort(result.coordinator_address); const auto & status = result.runtime_segment_status; @@ -119,7 +121,8 @@ PlanSegmentExecutor::ExecutionResult convertSuccessPlanSegmentStatusToResult( const AddressInfo & 
execution_address, Progress & final_progress, SenderMetrics & sender_metrics, - PlanSegmentOutputs & plan_segment_outputs) + PlanSegmentOutputs & plan_segment_outputs, + PlanSegmentProfilePtr & segment_profile) { PlanSegmentExecutor::ExecutionResult result; @@ -133,8 +136,34 @@ PlanSegmentExecutor::ExecutionResult convertSuccessPlanSegmentStatusToResult( result.runtime_segment_status.message = "execute success"; result.runtime_segment_status.metrics.final_progress = final_progress.toProto(); result.sender_metrics = senderMetricsToProto(plan_segment_outputs, sender_metrics, execution_address); + if (query_context->getSettingsRef().report_segment_profiles && segment_profile) + result.segment_profile = segment_profile; return result; } +void reportSuccessPlanSegmentProfile(const PlanSegmentExecutor::ExecutionResult & result) +{ + static auto * logger = &Poco::Logger::get("PlanSegmentExecutor"); + try + { + std::shared_ptr rpc_client = RpcChannelPool::getInstance().getClient( + extractExchangeHostPort(result.coordinator_address), BrpcChannelPoolOptions::DEFAULT_CONFIG_KEY); + Protos::PlanSegmentManagerService_Stub manager(&rpc_client->getChannel()); + brpc::Controller cntl; + Protos::PlanSegmentProfileRequest request; + Protos::PlanSegmentProfileResponse response; + result.segment_profile->is_succeed = true; + result.segment_profile->toProto(request); + + manager.sendPlanSegmentProfile(&cntl, &request, &response, nullptr); + rpc_client->assertController(cntl); + LOG_TRACE( + logger, "PlanSegment-{} send profile to coordinator successfully, query id-{}.", request.segment_id(), request.query_id()); + } + catch (...) 
+ { + tryLogCurrentException(__PRETTY_FUNCTION__); + } +} } diff --git a/src/Interpreters/DistributedStages/PlanSegmentReport.h b/src/Interpreters/DistributedStages/PlanSegmentReport.h index 9480bb6aca2..e1f6a044539 100644 --- a/src/Interpreters/DistributedStages/PlanSegmentReport.h +++ b/src/Interpreters/DistributedStages/PlanSegmentReport.h @@ -23,10 +23,14 @@ PlanSegmentExecutor::ExecutionResult convertFailurePlanSegmentStatusToResult( SenderMetrics sender_metrics = {}, PlanSegmentOutputs plan_segment_outputs = {}); + PlanSegmentExecutor::ExecutionResult convertSuccessPlanSegmentStatusToResult( ContextPtr query_context, const AddressInfo & execution_address, Progress & final_progress, SenderMetrics & sender_metrics, - PlanSegmentOutputs & plan_segment_outputs); + PlanSegmentOutputs & plan_segment_outputs, + PlanSegmentProfilePtr & segment_profile); + +[[maybe_unused]] static void reportSuccessPlanSegmentProfile(const PlanSegmentExecutor::ExecutionResult & result); } diff --git a/src/Interpreters/DistributedStages/PlanSegmentSplitter.cpp b/src/Interpreters/DistributedStages/PlanSegmentSplitter.cpp index f88872ea5c4..e68549f10c0 100644 --- a/src/Interpreters/DistributedStages/PlanSegmentSplitter.cpp +++ b/src/Interpreters/DistributedStages/PlanSegmentSplitter.cpp @@ -13,6 +13,7 @@ * limitations under the License. 
*/ +#include #include #include @@ -26,7 +27,9 @@ #include #include #include +#include #include +#include namespace DB { @@ -37,7 +40,8 @@ void PlanSegmentSplitter::split(QueryPlan & query_plan, PlanSegmentContext & pla if (plan_segment_context.context->getSettingsRef().distributed_max_parallel_size != 0) SetScalable::setScalable(query_plan.getRoot(), query_plan.getCTENodes(), *plan_segment_context.context); size_t exchange_id = 0; - PlanSegmentVisitorContext split_context{{}, {}, exchange_id}; + PlanSegmentVisitorContext split_context{ + {}, {}, exchange_id, plan_segment_context.context->getSettingsRef().exchange_shuffle_method_name}; visitor.createPlanSegment(query_plan.getRoot(), split_context); std::unordered_map plan_mapping; @@ -99,7 +103,6 @@ void PlanSegmentSplitter::split(QueryPlan & query_plan, PlanSegmentContext & pla } } } - } ParallelSizeChecker checker; @@ -119,8 +122,11 @@ void PlanSegmentSplitter::split(QueryPlan & query_plan, PlanSegmentContext & pla auto first = sizes[0]; for (auto size : sizes) { - if (size != first) - throw Exception("Segment parallel size not match", ErrorCodes::LOGICAL_ERROR); + // TODO(wangtao.vip): check with @JingPeng whether it is right to skip (error in tpcds 05/08). 
+ if (size != first && !plan_segment_context.context->getSettingsRef().bsp_mode) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Segment parallel size not match {} and {}", size, first); + } } } } @@ -152,8 +158,14 @@ PlanSegmentResult PlanSegmentVisitor::visitExchangeNode(QueryPlan::Node * node, bool is_add_extremes = false; for (auto & child : node->children) { - PlanSegmentVisitorContext child_context{{}, {}, split_context.exchange_id}; - auto plan_segment = createPlanSegment(child, child_context); + String hash_func = plan_segment_context.context->getSettingsRef().exchange_shuffle_method_name; + PlanSegmentVisitorContext child_context{ + {}, + {}, + split_context.exchange_id, + step->getSchema().getHashFunc(hash_func), + step->getSchema().getParams()}; + PlanSegment * plan_segment = createPlanSegment(child, child_context); is_add_totals |= child_context.is_add_totals; is_add_extremes |= child_context.is_add_extremes; @@ -164,6 +176,7 @@ PlanSegmentResult PlanSegmentVisitor::visitExchangeNode(QueryPlan::Node * node, // TODO: Not support one ExchangeStep with multi children yet(multi children can't share one exchange id), we may need to support it later. 
input->setExchangeId(plan_segment->getPlanSegmentOutputs().back()->getExchangeId()); input->setKeepOrder(step->needKeepOrder()); + input->setStable(step->getSchema().getBucketExpr() != nullptr); inputs.push_back(input); if (auto * output = dynamic_cast(plan_segment->getPlanSegmentOutput().get())) @@ -196,7 +209,7 @@ PlanSegmentResult PlanSegmentVisitor::visitCTERefNode(QueryPlan::Node * node, Pl if (cte_node->step->getType() == IQueryPlanStep::Type::Exchange) { exchange_step = dynamic_cast(cte_node->step.get()); - PlanSegmentVisitorContext child_context{{}, {}, split_context.exchange_id, split_context.is_add_extremes, split_context.is_add_totals, exchange_step->isScalable()}; + PlanSegmentVisitorContext child_context{{}, {}, split_context.exchange_id, split_context.hash_func, split_context.params, split_context.is_add_extremes, split_context.is_add_totals, exchange_step->isScalable()}; plan_segment = createPlanSegment(cte_node->children[0], child_context); } else @@ -270,11 +283,15 @@ PlanSegment * PlanSegmentVisitor::createPlanSegment(QueryPlan::Node * node, size auto plan_segment = std::make_unique(segment_id, plan_segment_context.query_id, cluster_name); plan_segment->setQueryPlan(std::move(sub_plan)); - plan_segment->setExchangeParallelSize(plan_segment_context.context->getSettingsRef().exchange_parallel_size); + auto exchange_parallel_size = plan_segment_context.context->getSettingsRef().exchange_parallel_size; + plan_segment->setExchangeParallelSize(exchange_parallel_size); PlanSegmentType output_type = segment_id == 0 ? 
PlanSegmentType::OUTPUT : PlanSegmentType::EXCHANGE; auto output = std::make_shared(plan_segment->getQueryPlan().getRoot()->step->getOutputStream().header, output_type); + + output->setShuffleFunctionName(split_context.hash_func); + output->setShuffleFunctionParams(split_context.params); if (output_type == PlanSegmentType::OUTPUT) { plan_segment->setParallelSize(1); @@ -288,15 +305,36 @@ PlanSegment * PlanSegmentVisitor::createPlanSegment(QueryPlan::Node * node, size else output->setParallelSize(parallel); } - output->setExchangeParallelSize(plan_segment_context.context->getSettingsRef().exchange_parallel_size); + output->setExchangeParallelSize(exchange_parallel_size); output->setExchangeId(split_context.exchange_id++); plan_segment->appendPlanSegmentOutput(output); auto inputs = findInputs(plan_segment->getQueryPlan().getRoot()); if (inputs.empty()) inputs.push_back(std::make_shared(Block(), PlanSegmentType::UNKNOWN)); - for (auto & input : inputs) - input->setExchangeParallelSize(plan_segment_context.context->getSettingsRef().exchange_parallel_size); + if (unlikely(exchange_parallel_size > 1)) + { + for (auto & input : inputs) + { + if (input->isStable()) + { + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "exchange_parallel_size can't be {} when input is stable for segment {} ", + exchange_parallel_size, + plan_segment->getPlanSegmentId()); + } + input->setExchangeParallelSize(exchange_parallel_size); + } + } + else + { + for (auto & input : inputs) + { + input->setExchangeParallelSize(exchange_parallel_size); + } + } + if (inputs[0]->getExchangeMode() == ExchangeMode::GATHER) plan_segment->setParallelSize(1); diff --git a/src/Interpreters/DistributedStages/PlanSegmentSplitter.h b/src/Interpreters/DistributedStages/PlanSegmentSplitter.h index 4b86718b658..80212483e34 100644 --- a/src/Interpreters/DistributedStages/PlanSegmentSplitter.h +++ b/src/Interpreters/DistributedStages/PlanSegmentSplitter.h @@ -56,6 +56,8 @@ struct PlanSegmentVisitorContext 
PlanSegmentInputs inputs; std::vector children; size_t & exchange_id; + String hash_func; + Array params = Array(); bool is_add_totals = false; bool is_add_extremes = false; bool scalable = true; diff --git a/src/Interpreters/DistributedStages/ProgressManager.cpp b/src/Interpreters/DistributedStages/ProgressManager.cpp index e5f8ad336d5..f0ebb1eebbb 100644 --- a/src/Interpreters/DistributedStages/ProgressManager.cpp +++ b/src/Interpreters/DistributedStages/ProgressManager.cpp @@ -1,8 +1,45 @@ +#include +#include +#include #include +#include +#include #include +#include namespace DB { + +TCPProgressSender::TCPProgressSender(std::function send_tcp_progress_, size_t interval_) + : logger(&Poco::Logger::get("ProgressManager")), send_tcp_progress(send_tcp_progress_), interval(interval_) +{ + if (send_tcp_progress && interval) + { + LOG_TRACE(logger, "TCPProgressSender started"); + thread = std::make_unique([&]() { + while (true) + { + std::unique_lock lock(mu); + var.wait_for(lock, std::chrono::milliseconds(this->interval), [&]() { return this->shutdown.load(); }); + if (shutdown) + { + LOG_TRACE(logger, "TCPProgressSender shutdown"); + break; + } + this->send_tcp_progress(); + } + }); + } +} + +TCPProgressSender::~TCPProgressSender() +{ + shutdown = true; + var.notify_all(); + if (thread && thread->joinable()) + thread->join(); +} + void ProgressManager::onProgress(UInt32 segment_id, UInt32 parallel_index, const Progress & progress_) { std::unique_lock lock(segment_progress_mutex); diff --git a/src/Interpreters/DistributedStages/ProgressManager.h b/src/Interpreters/DistributedStages/ProgressManager.h index e07bd18c29f..9f97bcdcae8 100644 --- a/src/Interpreters/DistributedStages/ProgressManager.h +++ b/src/Interpreters/DistributedStages/ProgressManager.h @@ -1,12 +1,33 @@ #pragma once +#include +#include #include #include #include #include +#include namespace DB { + +// send progress repeatedly +class TCPProgressSender +{ +public: + TCPProgressSender(std::function 
send_tcp_progress_, size_t interval_); + ~TCPProgressSender(); + +private: + Poco::Logger * logger; + std::atomic_bool shutdown = {false}; + std::mutex mu; + std::condition_variable var; + std::function send_tcp_progress; + std::unique_ptr thread; + size_t interval; +}; + class ProgressManager { public: diff --git a/src/Interpreters/DistributedStages/Scheduler.cpp b/src/Interpreters/DistributedStages/Scheduler.cpp index 09f7fa4bb1b..69ce416c121 100644 --- a/src/Interpreters/DistributedStages/Scheduler.cpp +++ b/src/Interpreters/DistributedStages/Scheduler.cpp @@ -34,7 +34,11 @@ bool Scheduler::addBatchTask(BatchTaskPtr batch_task) bool Scheduler::getBatchTaskToSchedule(BatchTaskPtr & task) { - return queue.tryPop(task, query_expiration_ms); + auto now = time_in_milliseconds(std::chrono::system_clock::now()); + if (query_expiration_ms <= now) + return false; + else + return queue.tryPop(task, query_expiration_ms - now); } void Scheduler::dispatchTask(PlanSegment * plan_segment_ptr, const SegmentTask & task, const size_t idx) @@ -89,7 +93,7 @@ TaskResult Scheduler::scheduleTask(PlanSegment * plan_segment_ptr, const Segment NodeSelectorResult selector_info; { std::unique_lock lock(node_selector_result_mutex); - auto selector_result = node_selector_result.emplace(task.task_id, node_selector.select(plan_segment_ptr, task.is_source)); + auto selector_result = node_selector_result.emplace(task.task_id, node_selector.select(plan_segment_ptr, task.has_table_scan)); selector_info = selector_result.first->second; } prepareTask(plan_segment_ptr, selector_info.worker_nodes.size()); @@ -136,7 +140,7 @@ void Scheduler::schedule() { Stopwatch sw; genTopology(); - genSourceTasks(); + genLeafTasks(); /// Leave final segment alone. 
while (!dag_graph_ptr->plan_segment_status_ptr->is_final_stage_start) @@ -176,20 +180,20 @@ void Scheduler::schedule() LOG_DEBUG(log, "Scheduling takes {} ms", sw.elapsedMilliseconds()); } -void Scheduler::genSourceTasks() +void Scheduler::genLeafTasks() { - LOG_DEBUG(log, "Begin generate source tasks"); + LOG_DEBUG(log, "Begin generate leaf tasks"); auto batch_task = std::make_shared(); - batch_task->reserve(dag_graph_ptr->sources.size()); - for (auto source_id : dag_graph_ptr->sources) + batch_task->reserve(dag_graph_ptr->leaf_segments.size()); + for (auto leaf_id : dag_graph_ptr->leaf_segments) { - LOG_TRACE(log, "Generate task for source segment {}", source_id); - if (source_id == dag_graph_ptr->final) + LOG_TRACE(log, "Generate task for leaf segment {}", leaf_id); + if (leaf_id == dag_graph_ptr->final) continue; - batch_task->emplace_back(source_id, true); - plansegment_topology.erase(source_id); - LOG_TRACE(log, "Task for source segment {} generated", source_id); + batch_task->emplace_back(leaf_id, true); + plansegment_topology.erase(leaf_id); + LOG_TRACE(log, "Task for leaf segment {} generated", leaf_id); } addBatchTask(std::move(batch_task)); } @@ -256,7 +260,7 @@ void Scheduler::removeDepsAndEnqueueTask(const SegmentTask & task) LOG_INFO(log, "Erase dependency {} for segment {}", task_id, id); if (dependencies.empty()) { - batch_task->emplace_back(id); + batch_task->emplace_back(id, dag_graph_ptr->segments_has_table_scan.contains(id)); } } for (const auto & t : *batch_task) diff --git a/src/Interpreters/DistributedStages/Scheduler.h b/src/Interpreters/DistributedStages/Scheduler.h index a908dbc3e74..947382bcf55 100644 --- a/src/Interpreters/DistributedStages/Scheduler.h +++ b/src/Interpreters/DistributedStages/Scheduler.h @@ -33,12 +33,12 @@ enum class TaskStatus : uint8_t /// Indicates a plan segment. 
struct SegmentTask { - explicit SegmentTask(size_t task_id_, bool is_source_ = false) : task_id(task_id_), is_source(is_source_) + explicit SegmentTask(size_t task_id_, bool has_table_scan_ = false) : task_id(task_id_), has_table_scan(has_table_scan_) { } // plan segment id. size_t task_id; - bool is_source; + bool has_table_scan; }; /// Indicates a plan segment instance. @@ -110,7 +110,7 @@ class Scheduler , local_address(getLocalAddress(*query_context)) , log(&Poco::Logger::get("Scheduler")) { - cluster_nodes.rank_workers.emplace_back(local_address, NodeType::Local, ""); + cluster_nodes.all_workers.emplace_back(local_address, NodeType::Local, ""); timespec query_expiration_ts = query_context->getQueryExpirationTimeStamp(); query_expiration_ms = query_expiration_ts.tv_sec * 1000 + query_expiration_ts.tv_nsec / 1000000; } @@ -153,7 +153,7 @@ class Scheduler std::atomic stopped{false}; void genTopology(); - void genSourceTasks(); + void genLeafTasks(); bool getBatchTaskToSchedule(BatchTaskPtr & task); virtual void sendResources(PlanSegment * plan_segment_ptr) { diff --git a/src/Interpreters/DistributedStages/executePlanSegment.cpp b/src/Interpreters/DistributedStages/executePlanSegment.cpp index 15506bf7a20..f6a3bb9d553 100644 --- a/src/Interpreters/DistributedStages/executePlanSegment.cpp +++ b/src/Interpreters/DistributedStages/executePlanSegment.cpp @@ -114,7 +114,7 @@ static void OnSendPlanSegmentCallback( { LOG_ERROR( &Poco::Logger::get("executePlanSegment"), - "send plansegment to {} failed, error: {}, msg: {}", + "Send plansegment to {} failed, error: {}, msg: {}", butil::endpoint2str(cntl->remote_side()).c_str(), cntl->ErrorText(), response->message()); @@ -127,7 +127,7 @@ static void OnSendPlanSegmentCallback( else { LOG_TRACE( - &Poco::Logger::get("executePlanSegment"), "send plansegment to {} success", butil::endpoint2str(cntl->remote_side()).c_str()); + &Poco::Logger::get("executePlanSegment"), "Send plansegment to {} success", 
butil::endpoint2str(cntl->remote_side()).c_str()); async_context->asyncComplete(cntl->call_id(), result); } } @@ -163,6 +163,8 @@ void prepareQueryCommonBuf( query_common.set_check_session(!context->getSettingsRef().bsp_mode && !context->getSettingsRef().enable_prune_empty_resource); query_common.set_txn_id(context->getCurrentTransactionID().toUInt64()); query_common.set_primary_txn_id(context->getCurrentTransaction()->getPrimaryTransactionID().toUInt64()); + auto query_expiration_ts = context->getQueryExpirationTimeStamp(); + query_common.set_query_expiration_timestamp(query_expiration_ts.tv_sec * 1000 + query_expiration_ts.tv_nsec / 1000000); const String & quota_key = client_info.quota_key; if (!client_info.quota_key.empty()) query_common.set_quota(quota_key); diff --git a/src/Interpreters/ExplainSettings.h b/src/Interpreters/ExplainSettings.h index e2740eaec68..e9418e87cbf 100644 --- a/src/Interpreters/ExplainSettings.h +++ b/src/Interpreters/ExplainSettings.h @@ -1,4 +1,5 @@ #include +#include "common/types.h" namespace DB @@ -19,6 +20,7 @@ struct QueryMetadataSettings bool lineage = false; bool format_json = false; bool lineage_use_optimizer = false; + bool ignore_format = false; constexpr static char name[] = "METADATA"; @@ -27,8 +29,11 @@ struct QueryMetadataSettings {"json", json}, {"lineage", lineage}, {"lineage_use_optimizer", lineage_use_optimizer}, - {"format_json", format_json} + {"format_json", format_json}, + {"ignore_format", ignore_format} }; + + std::unordered_map> uint_settings = {}; }; struct QueryPlanSettings @@ -43,7 +48,12 @@ struct QueryPlanSettings bool pb_json = false; bool verbose = true; bool add_whitespace = true; // used to pretty print json - bool aggregate_profiles = true; + bool aggregate_profiles = true; + bool pretty_num = true; + bool selected_parts = false; + bool segment_profile = false; + + UInt64 segment_id = UINT64_MAX; constexpr static char name[] = "PLAN"; @@ -60,7 +70,11 @@ struct QueryPlanSettings {"add_whitespace", 
add_whitespace}, {"aggregate_profiles", aggregate_profiles}, {"verbose", verbose}, - }; + {"pretty_num", pretty_num}, + {"selected_parts", selected_parts}, + {"segment_profile", segment_profile}}; + + std::unordered_map> uint_settings = {{"segment_id", segment_id}}; }; struct QueryPipelineSettings @@ -79,25 +93,39 @@ struct QueryPipelineSettings {"compact", compact}, {"stats", stats} }; + + std::unordered_map> uint_settings = {}; }; template struct ExplainSettings : public Settings { using Settings::boolean_settings; + using Settings::uint_settings; bool has(const std::string & name_) const { - return boolean_settings.count(name_) > 0; + return boolean_settings.count(name_) > 0 || uint_settings.count(name_) > 0; } - void setBooleanSetting(const std::string & name_, bool value) + bool setBooleanSetting(const std::string & name_, bool value) { auto it = boolean_settings.find(name_); if (it == boolean_settings.end()) - throw Exception("Unknown setting for ExplainSettings: " + name_, ErrorCodes::LOGICAL_ERROR); + return false; + + it->second.get() = value; + return true; + } + + bool setUIntSetting(const std::string & name_, UInt64 value) + { + auto it = uint_settings.find(name_); + if (it == uint_settings.end()) + return false; it->second.get() = value; + return true; } std::string getSettingsList() const @@ -111,6 +139,13 @@ struct ExplainSettings : public Settings res += setting.first; } + for (const auto & setting : uint_settings) + { + if (!res.empty()) + res += ", "; + + res += setting.first; + } return res; } }; @@ -131,15 +166,24 @@ ExplainSettings checkAndGetSettings(const ASTPtr & ast_settings) "Supported settings: " + settings.getSettingsList(), ErrorCodes::UNKNOWN_SETTING); if (change.value.getType() != Field::Types::UInt64) - throw Exception("Invalid type " + std::string(change.value.getTypeName()) + " for setting \"" + change.name + - "\" only boolean settings are supported", ErrorCodes::INVALID_SETTING_VALUE); + throw Exception( + "Invalid type " + 
std::string(change.value.getTypeName()) + " for setting \"" + change.name + + "\" only boolean and UInt64 settings are supported", + ErrorCodes::INVALID_SETTING_VALUE); auto value = change.value.get(); - if (value > 1) - throw Exception("Invalid value " + std::to_string(value) + " for setting \"" + change.name + - "\". Only boolean settings are supported", ErrorCodes::INVALID_SETTING_VALUE); - settings.setBooleanSetting(change.name, value); + bool has_setting = false; + if (value < 2) + has_setting |= settings.setBooleanSetting(change.name, value); + + if (!has_setting) + has_setting |= settings.setUIntSetting(change.name, value); + + if (!has_setting) + throw Exception("Unknown setting for ExplainSettings: " + change.name, ErrorCodes::LOGICAL_ERROR); + // throw Exception("Invalid value " + std::to_string(value) + " for setting \"" + change.name + + // "\". Only boolean settings are supported", ErrorCodes::INVALID_SETTING_VALUE); } return settings; diff --git a/src/Interpreters/ExpressionAnalyzer.cpp b/src/Interpreters/ExpressionAnalyzer.cpp index a005c56d481..bb32567201b 100644 --- a/src/Interpreters/ExpressionAnalyzer.cpp +++ b/src/Interpreters/ExpressionAnalyzer.cpp @@ -789,7 +789,7 @@ void ExpressionAnalyzer::getRootActionsWithOwnBitmapInfo(const ASTPtr & ast, boo LogAST log; ActionsVisitor::Data visitor_data(getContext(), settings.size_limits_for_set, subquery_depth, sourceColumns(), std::move(actions), prepared_sets, subqueries_for_sets, - no_subqueries, false, only_consts, !isRemoteStorage(), getAggregationKeysInfo(), false, own_index_context); + no_subqueries, false, only_consts, !isRemoteStorage(), getAggregationKeysInfo(), false, own_index_context, metadata_snapshot); ActionsVisitor(visitor_data, log.stream()).visit(ast); actions = visitor_data.getActions(); } diff --git a/src/Interpreters/GinFilter.cpp b/src/Interpreters/GinFilter.cpp index 238f8b8c650..18b066504b0 100644 --- a/src/Interpreters/GinFilter.cpp +++ b/src/Interpreters/GinFilter.cpp @@ -200,4 
+200,14 @@ bool GinFilter::match(const GinPostingsCache & postings_cache , roaring::Roaring return false; } +String GinFilter::getTermsInString() const +{ + String result; + for (const String & term : terms) + { + result += " " + term; + } + return result; +} + } diff --git a/src/Interpreters/GinFilter.h b/src/Interpreters/GinFilter.h index 128cddbb5fc..6a5261690fd 100644 --- a/src/Interpreters/GinFilter.h +++ b/src/Interpreters/GinFilter.h @@ -78,6 +78,9 @@ class GinFilter void filpWithRange(roaring::Roaring & result) const; size_t getAllRangeSize() const; + // for log trace + String getTermsInString() const; + private: /// Filter parameters [[__maybe_unused__]] const GinFilterParameters & params; diff --git a/src/Interpreters/HashJoin.cpp b/src/Interpreters/HashJoin.cpp index b6dc5384706..d90b4bd539e 100644 --- a/src/Interpreters/HashJoin.cpp +++ b/src/Interpreters/HashJoin.cpp @@ -935,20 +935,20 @@ class AddedColumns bool need_filter = false; IColumn::Filter filter; - void reserve(bool need_replicate) + void reserve(bool /*need_replicate*/) { - if (!max_joined_block_rows) - return; + // if (!max_joined_block_rows) + // return; - /// Do not allow big allocations when user set max_joined_block_rows to huge value - size_t reserve_size = std::min(max_joined_block_rows, kMaxAllowedJoinedBlockRows); + // /// Do not allow big allocations when user set max_joined_block_rows to huge value + // size_t reserve_size = std::min(max_joined_block_rows, kMaxAllowedJoinedBlockRows); - if (need_replicate) - /// Reserve 10% more space for columns, because some rows can be repeated - reserve_size = static_cast(1.1 * reserve_size); + // if (need_replicate) + // /// Reserve 10% more space for columns, because some rows can be repeated + // reserve_size = static_cast(1.1 * reserve_size); - for (auto & column : columns) - column->reserve(reserve_size); + // for (auto & column : columns) + // column->reserve(reserve_size); } std::vector> right_anti_index; diff --git 
a/src/Interpreters/ITokenExtractor.cpp b/src/Interpreters/ITokenExtractor.cpp index 26355996f37..a4635b801b6 100644 --- a/src/Interpreters/ITokenExtractor.cpp +++ b/src/Interpreters/ITokenExtractor.cpp @@ -200,14 +200,119 @@ bool SplitTokenExtractor::nextInStringPadded(const char * data, size_t length, s bool SplitTokenExtractor::nextInStringLike(const char * data, size_t length, size_t * pos, String & token) const { token.clear(); - bool bad_token = false; // % or _ before token + bool escaped = false; + while (*pos < length) + { + if (escaped && (data[*pos] == '%' || data[*pos] == '_' || data[*pos] == '\\')) + { + // Append escaped characters directly to the token + token += data[*pos]; + escaped = false; + ++*pos; + } + else if (!escaped && (data[*pos] == '%' || data[*pos] == '_')) + { + if (!token.empty()) + { + return true; // Return the valid token + } + // Wildcard: reset token, continue to next character + token.clear(); + ++*pos; + } + else if (!escaped && data[*pos] == '\\') + { + // Escape character, set escape mode + escaped = true; + ++*pos; + } + else if (!escaped && isASCII(data[*pos]) && !isAlphaNumericASCII(data[*pos])) + { + if (!token.empty()) + return true; + + token.clear(); + ++*pos; + } + else + { + const size_t sz = UTF8::seqLength(static_cast(data[*pos])); + for (size_t j = 0; j < sz; ++j) + { + token += data[*pos]; + ++*pos; + } + escaped = false; + } + } + + + return !token.empty(); +} + +bool StandardTokenExtractor::nextInString( + const char * data, size_t length, size_t * __restrict pos, size_t * __restrict token_start, size_t * __restrict token_length) const +{ + *token_start = *pos; + *token_length = 0; + + while (*pos < length) + { + if (isASCII(data[*pos])) + { + if (isAlphaNumericASCII(data[*pos])) + { + // if is Alpha or Numeric just continue + ++*pos; + ++*token_length; + continue; + } + + /// Finish current token if have any + if (*token_length > 0) + { + return true; + } + else + { + // skip current and split continue + 
*token_start = ++*pos; + } + } + else // UTF-8 case + { + // Finish current token if have any + if (*token_length > 0) + return true; + + // get length and return + const size_t sz = UTF8::seqLength(static_cast(data[*pos])); + (*pos)+=sz; + (*token_length)+=sz; + // submit utf-8 token + if(*token_length > 0) + { + return true; + } + + } + } + + return *token_length > 0; +} + + +bool StandardTokenExtractor::nextInStringLike(const char * data, size_t length, size_t * pos, std::string & token) const +{ + token.clear(); bool escaped = false; while (*pos < length) { if (!escaped && (data[*pos] == '%' || data[*pos] == '_')) { + if (!token.empty()) + return true; token.clear(); - bad_token = true; ++*pos; } else if (!escaped && data[*pos] == '\\') @@ -215,19 +320,18 @@ bool SplitTokenExtractor::nextInStringLike(const char * data, size_t length, siz escaped = true; ++*pos; } - else if (isASCII(data[*pos]) && !isAlphaNumericASCII(data[*pos])) + else if (!escaped && isASCII(data[*pos]) && !isAlphaNumericASCII(data[*pos])) { - if (!bad_token && !token.empty()) + if (!token.empty()) return true; token.clear(); - bad_token = false; escaped = false; ++*pos; } else { - const size_t sz = UTF8::seqLength(static_cast(data[*pos])); + const size_t sz = UTF8::seqLength(static_cast(data[*pos])); for (size_t j = 0; j < sz; ++j) { token += data[*pos]; @@ -237,7 +341,68 @@ bool SplitTokenExtractor::nextInStringLike(const char * data, size_t length, siz } } - return !bad_token && !token.empty(); + return !token.empty(); } +// bool StandardTokenExtractor::nextInStringLike( +// const char * data, size_t length, size_t * pos, String & out) const +// { +// out.clear(); +// bool escaped = false; + +// while (*pos < length) +// { +// if (!escaped && (data[*pos] == '%' || data[*pos] == '_')) +// { +// if (!out.empty()) +// return true; +// out.clear(); +// ++*pos; +// } +// else if (!escaped && data[*pos] == '\\') +// { +// escaped = true; +// ++*pos; +// } +// else if (isASCII(data[*pos])) +// { 
+// if(isAlphaNumericASCII(data[*pos])) +// { +// out += data[*pos]; +// ++*pos; +// escaped = false; +// continue; +// } + +// if (!bad_token && !out.empty()) +// return true; + +// out.clear(); +// bad_token = false; +// escaped = false; +// ++*pos; +// } +// else +// { +// // before cut utf-8 submit token if has any +// if (!bad_token && !out.empty()) +// return true; + +// out.clear(); +// bad_token = false; +// escaped = false; + +// const size_t sz = UTF8::seqLength(static_cast(data[*pos])); + +// out.append((data + *pos), sz); +// (*pos) += sz; + +// // submit token after cut utf-8 +// if (!out.empty()) +// return true; +// } +// } +// return !bad_token && !out.empty(); +// } + } diff --git a/src/Interpreters/ITokenExtractor.h b/src/Interpreters/ITokenExtractor.h index 7b335e02d16..3e12e85c946 100644 --- a/src/Interpreters/ITokenExtractor.h +++ b/src/Interpreters/ITokenExtractor.h @@ -104,7 +104,18 @@ struct SplitTokenExtractor final : public ITokenExtractor }; +class StandardTokenExtractor final : public ITokenExtractor +{ +public: + static const char * getName() { return "standard"; } + + bool nextInString(const char * data, size_t length, size_t * __restrict pos, size_t * __restrict token_start, size_t * __restrict token_length) const override; + + bool nextInStringLike(const char * data, size_t length, size_t * pos, String & out) const override; +}; + } + diff --git a/src/Interpreters/InterpreterAlterDiskCacheQuery.cpp b/src/Interpreters/InterpreterAlterDiskCacheQuery.cpp index d8b8b9fb24e..b20a6e6ccec 100644 --- a/src/Interpreters/InterpreterAlterDiskCacheQuery.cpp +++ b/src/Interpreters/InterpreterAlterDiskCacheQuery.cpp @@ -44,7 +44,7 @@ BlockIO InterpreterAlterDiskCacheQuery::execute() if (query.type == ASTAlterDiskCacheQuery::Type::PRELOAD) { - storage->sendPreloadTasks(getContext(), std::move(parts), query.sync, getContext()->getSettings().parts_preload_level); + storage->sendPreloadTasks(getContext(), std::move(parts), query.sync, 
getContext()->getSettings().parts_preload_level, time(nullptr)); } else if (query.type == ASTAlterDiskCacheQuery::Type::DROP) { diff --git a/src/Interpreters/InterpreterAlterQuery.cpp b/src/Interpreters/InterpreterAlterQuery.cpp index 6f17b7744d1..be1966b8bde 100644 --- a/src/Interpreters/InterpreterAlterQuery.cpp +++ b/src/Interpreters/InterpreterAlterQuery.cpp @@ -158,6 +158,19 @@ BlockIO InterpreterAlterQuery::executeToTable(const ASTAlterQuery & alter) if (!mutation_commands.empty()) { + /// TODO: zuochuang.zema, zhangsiqi.awesome we need a detailed description about alter commands and mutation commands. + /// Precheck: if some alter commands will be converted to mutation commands later, it means the txn will generate two mutation entries. + /// Just reject such queries. + if (!alter_commands.empty()) + { + auto metadata = table->getInMemoryMetadataPtr(); + auto mutation_commands_in_alter = alter_commands.getMutationCommands(*metadata, false, getContext()); + if (!mutation_commands_in_alter.empty()) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Cannot execute mutation commands and alter commands in single query"); + } + } + table->checkMutationIsPossible(mutation_commands, getContext()->getSettingsRef()); MutationsInterpreter(table, metadata_snapshot, mutation_commands, getContext(), false).validate(); table->mutate(mutation_commands, getContext()); @@ -167,7 +180,7 @@ BlockIO InterpreterAlterQuery::executeToTable(const ASTAlterQuery & alter) if (!partition_commands.empty()) { table->checkAlterPartitionIsPossible(partition_commands, metadata_snapshot, getContext()->getSettingsRef()); - auto partition_commands_pipe = table->alterPartition(metadata_snapshot, partition_commands, getContext()); + auto partition_commands_pipe = table->alterPartition(metadata_snapshot, partition_commands, getContext(), query_ptr); if (!partition_commands_pipe.empty()) res.pipeline.init(std::move(partition_commands_pipe)); table->setUpdateTimeNow(); @@ -408,7 +421,7 @@ 
AccessRightsElements InterpreterAlterQuery::getRequiredAccessForCommand(const AS } case ASTAlterCommand::ATTACH_PARTITION: case ASTAlterCommand::ATTACH_DETACHED_PARTITION: - { + case ASTAlterCommand::PREATTACH_PARTITION: { required_access.emplace_back(AccessType::INSERT, database, table); break; } diff --git a/src/Interpreters/InterpreterCreateQuery.cpp b/src/Interpreters/InterpreterCreateQuery.cpp index 9547ba35054..647f63a071e 100644 --- a/src/Interpreters/InterpreterCreateQuery.cpp +++ b/src/Interpreters/InterpreterCreateQuery.cpp @@ -99,6 +99,8 @@ #include #include #include +#include +#include #include #include #include @@ -1406,7 +1408,8 @@ BlockIO InterpreterCreateQuery::createTable(ASTCreateQuery & create) return doCreateOrReplaceTable(create, properties); /// when create materialized view and tenant id is not empty add setting tenant_id to select query - if (create.is_materialized_view && !getCurrentTenantId().empty()) + if (create.is_materialized_view && create.refresh_strategy && (create.refresh_strategy->schedule_kind == RefreshScheduleKind::ASYNC || + create.refresh_strategy->schedule_kind == RefreshScheduleKind::MANUAL) && !getCurrentTenantId().empty()) { ASTPtr settings = std::make_shared(); settings->as().is_standalone = false; @@ -1572,23 +1575,6 @@ bool InterpreterCreateQuery::doCreateTable(ASTCreateQuery & create, try { res->checkColumnsValidity(properties.columns); - if (auto * view = dynamic_cast(res.get())) - { - // if (view->async() && getContext()->getSettingsRef().enable_non_partitioned_base_refresh_throw_exception) - // view->validatePartitionBased(getContext()); - - if (view->tryGetTargetTable() && !view->hasInnerTable()) - { - StoragePtr target_table = view->tryGetTargetTable(); - if (!target_table->getInMemoryMetadataPtr()->getColumns().getMaterialized().empty()) - throw Exception( - ErrorCodes::ILLEGAL_COLUMN, - "Cannot create materialized view {} to target table {} with materialized columns {}", - view->getStorageID().getNameForLogs(), 
- target_table->getStorageID().getNameForLogs(), - target_table->getInMemoryMetadataPtr()->getColumns().getMaterialized().toString()); - } - } } catch (...) { @@ -1687,7 +1673,7 @@ BlockIO InterpreterCreateQuery::doCreateOrReplaceTable(ASTCreateQuery & create, BlockIO InterpreterCreateQuery::fillTableIfNeeded(const ASTCreateQuery & create) { - /// If the query is a CREATE SELECT, insert the data into the table. + /// If the query is a CREATE SELECT, insert the data into the table via INSERT INTO ... SELECT if (create.select && !create.attach && !create.is_ordinary_view && !create.is_live_view && (!create.is_materialized_view || create.is_populate)) { @@ -1706,22 +1692,46 @@ BlockIO InterpreterCreateQuery::fillTableIfNeeded(const ASTCreateQuery & create) } else { - /// Just run it as new INSET INTO ... SELECT FROM - /// Cannot directly use InterpreterInsertQuery here, because Cnch requires some resouce initialization (txn, vw, session resource) - /// all done in executeQuery now. Directly initialization didn't work. 
- auto insert_context = Context::createCopy(getContext()->getSessionContext()); + String insert_query_id = fmt::format("{}_insert", getContext()->getCurrentQueryId()); + auto insert_context = Context::createCopy(getContext()); + insert_context->makeSessionContext(); insert_context->makeQueryContext(); - insert_context->setSettings(getContext()->getSettingsRef()); + insert_context->setCurrentTransaction(nullptr, false); + insert_context->setCurrentVW(nullptr); + insert_context->setCurrentWorkerGroup(nullptr); + insert_context->setCurrentQueryId(insert_query_id); + if (!insert_context->getSettingsRef().enable_optimizer_for_create_select) + insert_context->setSetting("enable_optimizer", false); + + /// Execute INSERT in a separate thread so that it has a clean ThreadStatus attached to its local context + /// Pros: a) supports running INSERT in both optimizer and non-optimizer mode (not all inserts are supported by optimizer) + /// b) ThreadStatus for the outer query won't get polutated + /// Cons: a) user facing query (CTAS) cannot get progress update + /// b) cancel CTAS won't affect INSERT + /// c) we'll have two process list items (CTAS & INSERT) for the query which might be confusing + std::exception_ptr exception; + auto thread = ThreadFromGlobalPool([insert_context = std::move(insert_context), &exception, &insert]() { + try + { + CurrentThread::QueryScope query_scope {insert_context}; + ReadBufferFromOwnString in(insert->formatForErrorMessage()); + NullWriteBuffer out; + executeQuery(in, out, /*allow_into_outfile=*/false, insert_context, /*set_result_details=*/{}); + } + catch (...) 
+ { + exception = std::current_exception(); + } - // TODO @wangtao.2077: review this when internal queries are fully supported by optimizer - if (insert_context->getSettingsRef().enable_optimizer && insert_context->getSettingsRef().enable_optimizer_for_create_select) - { - insert_context->setCurrentQueryId(""); - CurrentThread::attachQueryContext(insert_context); - return executeQuery(insert->formatForErrorMessage(), insert_context, /*internal=*/false); - } + }); + thread.join(); - return executeQuery(insert->formatForErrorMessage(), insert_context, /*internal=*/true); + if (exception) + std::rethrow_exception(exception); + + /// NOTE: cannot return BlockIO from the inner query because fields like process_list_entry + /// and finish_callback will be overwrite by the outer executeQuery, which leads to subtle bugs + return {}; } } @@ -1785,7 +1795,21 @@ BlockIO InterpreterCreateQuery::execute() return executeDDLQueryOnCluster(query_ptr, getContext(), getRequiredAccess()); } - getContext()->checkAccess(getRequiredAccess()); + auto context = getContext(); + context->checkAccess(getRequiredAccess()); + if (create.is_dictionary && context->is_tenant_user()) + { + if (create.dictionary) + { + auto& refresh_query = create.dictionary->clickhouse_query; + auto& invalidate_query = create.dictionary->clickhouse_invalidate_query; + if (!refresh_query.empty()) + executeQuery("explain " + refresh_query, context, true); + if (!invalidate_query.empty()) + executeQuery(invalidate_query, context, true); + } + + } ASTQueryWithOutput::resetOutputASTIfExist(create); @@ -1817,6 +1841,17 @@ AccessRightsElements InterpreterCreateQuery::getRequiredAccess() const else if (create.is_dictionary) { required_access.emplace_back(AccessType::CREATE_DICTIONARY, create.database, create.table); + if (create.dictionary) + { + auto & db = create.dictionary->clickhouse_db; + auto & tb = create.dictionary->clickhouse_tb; + if (!db.empty() && !tb.empty()) + { + auto context = getContext(); + if 
(context->is_tenant_user()) + required_access.emplace_back(AccessType::SELECT, db, tb); + } + } } else if (create.isView()) { diff --git a/src/Interpreters/InterpreterDeleteQuery.cpp b/src/Interpreters/InterpreterDeleteQuery.cpp index 85eeb84289c..2da012c58eb 100644 --- a/src/Interpreters/InterpreterDeleteQuery.cpp +++ b/src/Interpreters/InterpreterDeleteQuery.cpp @@ -87,7 +87,14 @@ BlockIO InterpreterDeleteQuery::execute() 0, DBMS_DEFAULT_MAX_PARSER_DEPTH); - InterpreterInsertQuery insert_interpreter(insert_ast, getContext()); + InterpreterInsertQuery insert_interpreter( + insert_ast, + getContext(), + /*allow_materialized_*/false, + /*no_squash_*/false, + /*no_destination_*/false, + AccessType::ALTER_DELETE); + return insert_interpreter.execute(); } else diff --git a/src/Interpreters/InterpreterDropPreparedStatementQuery.cpp b/src/Interpreters/InterpreterDropPreparedStatementQuery.cpp index b47dfb17ae0..32026fb8417 100644 --- a/src/Interpreters/InterpreterDropPreparedStatementQuery.cpp +++ b/src/Interpreters/InterpreterDropPreparedStatementQuery.cpp @@ -14,11 +14,15 @@ namespace ErrorCodes BlockIO InterpreterDropPreparedStatementQuery::execute() { + auto current_context = getContext(); + AccessRightsElements access_rights_elements; + access_rights_elements.emplace_back(AccessType::DROP_PREPARED_STATEMENT); + current_context->checkAccess(access_rights_elements); + const auto * drop = query_ptr->as(); if (!drop || drop->name.empty()) throw Exception("Drop Prepare logical error", ErrorCodes::LOGICAL_ERROR); - auto current_context = getContext(); // if (!drop->cluster.empty()) // return executeDDLQueryOnCluster(query_ptr, current_context); @@ -28,7 +32,7 @@ BlockIO InterpreterDropPreparedStatementQuery::execute() if (!prepared_manager) throw Exception("Prepare cache has to be initialized", ErrorCodes::LOGICAL_ERROR); - prepared_manager->remove(drop->name, !drop->if_exists); + prepared_manager->remove(drop->name, !drop->if_exists, current_context); return {}; } } diff 
--git a/src/Interpreters/InterpreterExplainQuery.cpp b/src/Interpreters/InterpreterExplainQuery.cpp index 35dc831201e..c3993f7be8a 100644 --- a/src/Interpreters/InterpreterExplainQuery.cpp +++ b/src/Interpreters/InterpreterExplainQuery.cpp @@ -721,11 +721,8 @@ BlockIO InterpreterExplainQuery::explainAnalyze() { context_ptr->setSetting("log_processors_profiles", true); context_ptr->setSetting("report_processors_profiles", true); + context_ptr->setSetting("report_segment_profiles", true); } - std::shared_ptr> consumer - = std::make_shared(context_ptr->getCurrentQueryId()); - ProfileLogHub::getInstance().initLogChannel(context_ptr->getCurrentQueryId(), consumer); - context_ptr->setProcessorProfileElementConsumer(consumer); context_ptr->setIsExplainQuery(true); try { @@ -734,8 +731,6 @@ BlockIO InterpreterExplainQuery::explainAnalyze() } catch (...) { - if (context_ptr->getProcessorProfileElementConsumer()) - context_ptr->getProcessorProfileElementConsumer()->stop(); throw; } @@ -810,6 +805,7 @@ BlockInputStreamPtr InterpreterExplainQuery::explainMetaData() { InterpreterSelectQueryUseOptimizer interpreter(query_ptr, contxt, SelectQueryOptions()); interpreter.buildQueryPlan(query_plan, analysis, !metadata_settings.lineage_use_optimizer); + query_ptr = interpreter.getQuery(); } catch (...) 
{ diff --git a/src/Interpreters/InterpreterGrantQuery.cpp b/src/Interpreters/InterpreterGrantQuery.cpp index c733cb0c992..36524902c7f 100644 --- a/src/Interpreters/InterpreterGrantQuery.cpp +++ b/src/Interpreters/InterpreterGrantQuery.cpp @@ -26,19 +26,34 @@ namespace void updateFromQueryTemplate( T & grantee, const ASTGrantQuery & query, - const std::vector & roles_to_grant_or_revoke) + const std::vector & roles_to_grant_or_revoke, + bool sensitive_tenant) { if (!query.access_rights_elements.empty()) { if (query.is_revoke) { - grantee.access.revoke(query.access_rights_elements); - grantee.sensitive_access.revoke(query.access_rights_elements); + if (query.if_exists) + { + if (!query.is_sensitive) + grantee.access.tryRevoke(query.access_rights_elements); + if (sensitive_tenant) + grantee.sensitive_access.tryRevoke(query.access_rights_elements); + } + else + { + if (!query.is_sensitive) + grantee.access.revoke(query.access_rights_elements); + if (sensitive_tenant) + grantee.sensitive_access.revoke(query.access_rights_elements); + } } else { - grantee.access.grant(query.access_rights_elements); - grantee.sensitive_access.grant(query.access_rights_elements); + if (!query.is_sensitive) + grantee.access.grant(query.access_rights_elements); + if (sensitive_tenant) + grantee.sensitive_access.grant(query.access_rights_elements); } } @@ -64,12 +79,13 @@ namespace void updateFromQueryImpl( IAccessEntity & grantee, const ASTGrantQuery & query, - const std::vector & roles_to_grant_or_revoke) + const std::vector & roles_to_grant_or_revoke, + bool sensitive_tenant) { if (auto * user = typeid_cast(&grantee)) - updateFromQueryTemplate(*user, query, roles_to_grant_or_revoke); + updateFromQueryTemplate(*user, query, roles_to_grant_or_revoke, sensitive_tenant); else if (auto * role = typeid_cast(&grantee)) - updateFromQueryTemplate(*role, query, roles_to_grant_or_revoke); + updateFromQueryTemplate(*role, query, roles_to_grant_or_revoke, sensitive_tenant); } void 
checkGranteeIsAllowed(const ContextAccess & access, const UUID & grantee_id, const IAccessEntity & grantee) @@ -99,24 +115,28 @@ namespace const AccessControlManager & access_control, const ContextAccess & access, const ASTGrantQuery & query, - const std::vector & grantees_from_query) + const std::vector & grantees_from_query, + bool & need_check_grantees_are_allowed) { const auto & elements = query.access_rights_elements; + need_check_grantees_are_allowed = true; if (elements.empty()) + { + /// No access rights to grant or revoke. + need_check_grantees_are_allowed = false; return; + } - /// To execute the command GRANT the current user needs to have the access granted - /// with GRANT OPTION. if (!query.is_revoke) { + /// To execute the command GRANT the current user needs to have the access granted with GRANT OPTION. access.checkGrantOption(elements); - checkGranteesAreAllowed(access_control, access, grantees_from_query); return; } if (access.hasGrantOption(elements)) { - checkGranteesAreAllowed(access_control, access, grantees_from_query); + /// Simple case: the current user has the grant option for all the access rights specified for REVOKE. return; } @@ -143,6 +163,7 @@ namespace all_granted_access.makeUnion(user->access); } } + need_check_grantees_are_allowed = false; /// already checked AccessRights required_access; if (elements[0].is_partial_revoke) @@ -164,21 +185,28 @@ namespace } } - std::vector getRoleIDsAndCheckAdminOption( const AccessControlManager & access_control, const ContextAccess & access, const ASTGrantQuery & query, const RolesOrUsersSet & roles_from_query, - const std::vector & grantees_from_query) + const std::vector & grantees_from_query, + bool & need_check_grantees_are_allowed) { - std::vector matching_ids; + need_check_grantees_are_allowed = true; + if (roles_from_query.empty()) + { + /// No roles to grant or revoke. 
+ need_check_grantees_are_allowed = false; + return {}; + } + std::vector matching_ids; if (!query.is_revoke) { + /// To execute the command GRANT the current user needs to have the roles granted with ADMIN OPTION. matching_ids = roles_from_query.getMatchingIDs(access_control); access.checkAdminOption(matching_ids); - checkGranteesAreAllowed(access_control, access, grantees_from_query); return matching_ids; } @@ -187,7 +215,7 @@ namespace matching_ids = roles_from_query.getMatchingIDs(); if (access.hasAdminOption(matching_ids)) { - checkGranteesAreAllowed(access_control, access, grantees_from_query); + /// Simple case: the current user has the admin option for all the roles specified for REVOKE. return matching_ids; } } @@ -215,6 +243,7 @@ namespace all_granted_roles.makeUnion(user->granted_roles); } } + need_check_grantees_are_allowed = false; /// already checked const auto & all_granted_roles_set = query.admin_option ? all_granted_roles.getGrantedWithAdminOption() : all_granted_roles.getGranted(); if (roles_from_query.all) @@ -224,6 +253,33 @@ namespace access.checkAdminOption(matching_ids); return matching_ids; } + + void checkGrantOptionAndGrantees( + const AccessControlManager & access_control, + const ContextAccess & access, + const ASTGrantQuery & query, + const std::vector & grantees_from_query) + { + bool need_check_grantees_are_allowed = true; + checkGrantOption(access_control, access, query, grantees_from_query, need_check_grantees_are_allowed); + if (need_check_grantees_are_allowed) + checkGranteesAreAllowed(access_control, access, grantees_from_query); + } + + std::vector getRoleIDsAndCheckAdminOptionAndGrantees( + const AccessControlManager & access_control, + const ContextAccess & access, + const ASTGrantQuery & query, + const RolesOrUsersSet & roles_from_query, + const std::vector & grantees_from_query) + { + bool need_check_grantees_are_allowed = true; + auto role_ids = getRoleIDsAndCheckAdminOption( + access_control, access, query, 
roles_from_query, grantees_from_query, need_check_grantees_are_allowed); + if (need_check_grantees_are_allowed) + checkGranteesAreAllowed(access_control, access, grantees_from_query); + return role_ids; + } } @@ -249,7 +305,7 @@ BlockIO InterpreterGrantQuery::execute() /// Check if the current user has corresponding roles granted with admin option. std::vector roles; if (roles_set) - roles = getRoleIDsAndCheckAdminOption(access_control, *getContext()->getAccess(), query, *roles_set, grantees); + roles = getRoleIDsAndCheckAdminOptionAndGrantees(access_control, *getContext()->getAccess(), query, *roles_set, grantees); // if (!query.cluster.empty()) // { @@ -264,13 +320,14 @@ BlockIO InterpreterGrantQuery::execute() /// Check if the current user has corresponding access rights with grant option. if (!query.access_rights_elements.empty()) - checkGrantOption(access_control, *getContext()->getAccess(), query, grantees); + checkGrantOptionAndGrantees(access_control, *getContext()->getAccess(), query, grantees); /// Update roles and users listed in `grantees`. 
- auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr + auto update_func = [&, ctx = getContext()](const AccessEntityPtr & entity) -> AccessEntityPtr { auto clone = entity->clone(); - updateFromQueryImpl(*clone, query, roles); + bool sensitive_tenant = ctx->getAccessControlManager().isSensitiveGrantee(clone->getName()); + updateFromQueryImpl(*clone, query, roles, sensitive_tenant); return clone; }; @@ -280,21 +337,21 @@ BlockIO InterpreterGrantQuery::execute() } -void InterpreterGrantQuery::updateUserFromQuery(User & user, const ASTGrantQuery & query) +void InterpreterGrantQuery::updateUserFromQuery(User & user, const ASTGrantQuery & query, bool sensitive_tenant) { std::vector roles_to_grant_or_revoke; if (query.roles) roles_to_grant_or_revoke = RolesOrUsersSet{*query.roles}.getMatchingIDs(); - updateFromQueryImpl(user, query, roles_to_grant_or_revoke); + updateFromQueryImpl(user, query, roles_to_grant_or_revoke, sensitive_tenant); } -void InterpreterGrantQuery::updateRoleFromQuery(Role & role, const ASTGrantQuery & query) +void InterpreterGrantQuery::updateRoleFromQuery(Role & role, const ASTGrantQuery & query, bool sensitive_tenant) { std::vector roles_to_grant_or_revoke; if (query.roles) roles_to_grant_or_revoke = RolesOrUsersSet{*query.roles}.getMatchingIDs(); - updateFromQueryImpl(role, query, roles_to_grant_or_revoke); + updateFromQueryImpl(role, query, roles_to_grant_or_revoke, sensitive_tenant); } void InterpreterGrantQuery::extendQueryLogElemImpl(QueryLogElement & elem, const ASTPtr & /*ast*/, ContextPtr) const diff --git a/src/Interpreters/InterpreterGrantQuery.h b/src/Interpreters/InterpreterGrantQuery.h index abaddcc599b..fe24361a005 100644 --- a/src/Interpreters/InterpreterGrantQuery.h +++ b/src/Interpreters/InterpreterGrantQuery.h @@ -19,8 +19,8 @@ class InterpreterGrantQuery : public IInterpreter, WithMutableContext BlockIO execute() override; - static void updateUserFromQuery(User & user, const ASTGrantQuery & query); - static 
void updateRoleFromQuery(Role & role, const ASTGrantQuery & query); + static void updateUserFromQuery(User & user, const ASTGrantQuery & query, bool sensitive_tenant); + static void updateRoleFromQuery(Role & role, const ASTGrantQuery & query, bool sensitive_tenant); void extendQueryLogElemImpl(QueryLogElement &, const ASTPtr &, ContextPtr) const override; private: diff --git a/src/Interpreters/InterpreterInsertQuery.cpp b/src/Interpreters/InterpreterInsertQuery.cpp index c5ffe56af39..30ceab1ade8 100644 --- a/src/Interpreters/InterpreterInsertQuery.cpp +++ b/src/Interpreters/InterpreterInsertQuery.cpp @@ -19,11 +19,13 @@ * All Bytedance's Modifications are Copyright (2023) Bytedance Ltd. and/or its affiliates. */ +#include #include #include #include #include +#include #include #include #include @@ -60,6 +62,7 @@ #include #include #include +#include #include #include #include @@ -68,6 +71,11 @@ #include #include #include +#include +#include +#include +#include +#include #include #include "Interpreters/Context_fwd.h" #include @@ -87,12 +95,13 @@ namespace ErrorCodes } InterpreterInsertQuery::InterpreterInsertQuery( - const ASTPtr & query_ptr_, ContextPtr context_, bool allow_materialized_, bool no_squash_, bool no_destination_) + const ASTPtr & query_ptr_, ContextPtr context_, bool allow_materialized_, bool no_squash_, bool no_destination_, AccessType access_type_) : WithContext(context_) , query_ptr(query_ptr_) , allow_materialized(allow_materialized_) , no_squash(no_squash_) , no_destination(no_destination_) + , access_type(access_type_) { checkStackSize(); } @@ -211,8 +220,37 @@ StoragePtr InterpreterInsertQuery::getTable(ASTInsertQuery & query) else { query.table_id = getContext()->resolveStorageID(query.table_id); - return DatabaseCatalog::instance().getTable(query.table_id, getContext()); + auto storage = DatabaseCatalog::instance().tryGetTable(query.table_id, getContext()); + if (storage) + return storage; + + storage = 
tryGetTableInWorkerResource(query.table_id); + if (storage) + return storage; + throw Exception(ErrorCodes::UNKNOWN_TABLE, "Cannot find table {} in server", query.table_id.getNameForLogs()); + } +} + +StoragePtr InterpreterInsertQuery::tryGetTableInWorkerResource(const StorageID & table_id) +{ + /// in some case of bitengine, server will write data into dict table and the targe table + /// can only be found in worker_resource + auto try_get_table_from_worker_resource = [&table_id](const auto & context) -> StoragePtr { + if (auto worker_resource = context->tryGetCnchWorkerResource(); worker_resource) + return worker_resource->getTable(table_id); + else + return nullptr; + }; + auto storage = try_get_table_from_worker_resource(getContext()); + if (storage) + return storage; + else if (auto query_context = getContext()->getQueryContext()) + { + storage = try_get_table_from_worker_resource(query_context); + if (storage) + return storage; } + return nullptr; } Block InterpreterInsertQuery::getSampleBlock( @@ -310,7 +348,7 @@ BlockIO InterpreterInsertQuery::execute() auto query_sample_block = getSampleBlock(insert_query, table, metadata_snapshot); if (!insert_query.table_function) - getContext()->checkAccess(AccessType::INSERT, insert_query.table_id, query_sample_block.getNames()); + getContext()->checkAccess(access_type, insert_query.table_id, query_sample_block.getNames()); bool is_distributed_insert_select = false; @@ -341,18 +379,18 @@ BlockIO InterpreterInsertQuery::execute() /// Handle the insert commit for insert select/infile case in cnch server. 
BlockInputStreamPtr in = cnch_merge_tree->writeInWorker(query_ptr, metadata_snapshot, getContext()); + auto txn = getContext()->getCurrentTransaction(); + txn->setMainTableUUID(table->getStorageUUID()); + if (const auto * cnch_table = dynamic_cast(table.get()); - cnch_table && cnch_table->commitTxnFromWorkerSide(metadata_snapshot, getContext())) + cnch_table && cnch_table->commitTxnInWriteSuffixStage(txn->getDedupImplVersion(getContext()), getContext())) { /// for unique table, insert select|infile is committed from worker side res.in = std::move(in); } else - { - auto txn = getContext()->getCurrentTransaction(); - txn->setMainTableUUID(table->getStorageUUID()); res.in = std::make_shared(in, std::move(txn)); - } + if (insert_query.is_overwrite && !lock_holders.empty()) { /// Make sure lock is release after txn commit @@ -438,7 +476,7 @@ BlockIO InterpreterInsertQuery::execute() res.pipeline.dropTotalsAndExtremes(); - if (table->supportsParallelInsert() && settings.max_insert_threads > 1) + if (table->supportsParallelInsert(getContext()) && settings.max_insert_threads > 1) out_streams_size = std::min(size_t(settings.max_insert_threads), res.pipeline.getNumStreams()); res.pipeline.resize(out_streams_size); @@ -571,15 +609,32 @@ BlockIO InterpreterInsertQuery::execute() res.pipeline.addSimpleTransform( [&](const Block & in_header) -> ProcessorPtr { return std::make_shared(in_header, actions); }); - res.pipeline.setSinks([&](const Block &, QueryPipeline::StreamType type) -> ProcessorPtr { - if (type != QueryPipeline::StreamType::Main) - return nullptr; + if (settings.insert_select_with_profiles) + { + res.pipeline.addSimpleTransform([&](const Block &, QueryPipeline::StreamType type) -> ProcessorPtr + { + if (type != QueryPipeline::StreamType::Main) + return nullptr; - auto stream = std::move(out_streams.back()); - out_streams.pop_back(); + auto stream = std::move(out_streams.back()); + out_streams.pop_back(); + + return std::make_shared(std::move(stream)); + }); + } + 
else + { + res.pipeline.setSinks([&](const Block &, QueryPipeline::StreamType type) -> ProcessorPtr + { + if (type != QueryPipeline::StreamType::Main) + return nullptr; - return std::make_shared(std::move(stream)); - }); + auto stream = std::move(out_streams.back()); + out_streams.pop_back(); + + return std::make_shared(std::move(stream)); + }); + } if (!allow_materialized) { @@ -633,6 +688,72 @@ void InterpreterInsertQuery::extendQueryLogElemImpl(QueryLogElement & elem, cons } } +void parseFuzzyName(const ContextPtr & context_ptr, std::vector & file_path_list, const String & source_uri, const String & scheme) +{ + // Assume no query and fragment in uri, todo, add sanity check + String fuzzy_file_name; + String uri_prefix = source_uri.substr(0, source_uri.find_last_of('/')); + if (uri_prefix.length() == source_uri.length()) + { + fuzzy_file_name = source_uri; + uri_prefix.clear(); + } + else + { + uri_prefix += "/"; + fuzzy_file_name = source_uri.substr(uri_prefix.length()); + } + + auto max_files = context_ptr->getSettingsRef().fuzzy_max_files; + std::vector parent_list = parseDescription(fuzzy_file_name, 0, fuzzy_file_name.length(), ',', max_files); + for (const auto & fuzzy_name : parent_list) + { + std::vector child_list = parseDescription(fuzzy_name, 0, fuzzy_name.length(), '|', max_files); + for (const auto & star_name : child_list) + { + String full_path = uri_prefix + star_name; + if (star_name.find_first_of("*?{") == std::string::npos) + { + file_path_list.emplace_back(full_path); + continue; + } + + std::shared_ptr matcher; + if (scheme.empty() || scheme == "file") + { + matcher = std::make_shared(); + } +#if USE_HDFS + else if (DB::isHdfsOrCfsScheme(scheme)) + { + matcher = std::make_shared(full_path, context_ptr); + } +#endif +#if USE_AWS_S3 + else if (isS3URIScheme(scheme)) + { + matcher = std::make_shared(full_path, context_ptr); + } +#endif + else + { + file_path_list.emplace_back(full_path); + } + + if (matcher) + { + // match files + String 
match_path = matcher->removeSchemeAndPrefix(full_path); + Strings match_file_list = matcher->regexMatchFiles("/", match_path); + file_path_list.insert(file_path_list.end(), match_file_list.begin(), match_file_list.end()); + } + + if (file_path_list.size() > max_files) + throw Exception(uri_prefix + fuzzy_file_name + " generates too many files, please modify the value of fuzzy_max_files.", ErrorCodes::BAD_ARGUMENTS); + } + } +} + BlockInputStreamPtr InterpreterInsertQuery::buildInputStreamFromSource( const ContextPtr context_ptr, const ColumnsDescription & columns, @@ -643,83 +764,71 @@ BlockInputStreamPtr InterpreterInsertQuery::buildInputStreamFromSource( bool is_enable_squash, const String & compression_method) { - // Assume no query and fragment in uri, todo, add sanity check - String fuzzyFileNames; - String uriPrefix = source_uri.substr(0, source_uri.find_last_of('/')); - if (uriPrefix.length() == source_uri.length()) - { - fuzzyFileNames = source_uri; - uriPrefix.clear(); - } - else - { - uriPrefix += "/"; - fuzzyFileNames = source_uri.substr(uriPrefix.length()); - } - - Poco::URI uri(uriPrefix); + Poco::URI uri(source_uri); const String & scheme = uri.getScheme(); BlockInputStreams inputs; { - auto max_files = context_ptr->getSettingsRef().fuzzy_max_files; - std::vector fuzzyNameList = parseDescription(fuzzyFileNames, 0, fuzzyFileNames.length(), ',' , max_files); - std::vector > fileNames; - for (auto fuzzyName : fuzzyNameList) - fileNames.push_back(parseDescription(fuzzyName, 0, fuzzyName.length(), '|', max_files)); + std::vector file_path_list; + parseFuzzyName(context_ptr, file_path_list, source_uri, scheme); - for (auto & vecNames : fileNames) + for (auto & file_path : file_path_list) { - for (auto & name : vecNames) - { - std::unique_ptr read_buf = nullptr; + std::unique_ptr read_buf = nullptr; - if (scheme.empty() || scheme == "file") - { - read_buf = std::make_unique(Poco::URI(uriPrefix + name).getPath()); - } + if (scheme.empty() || scheme == "file") 
+ { + read_buf = std::make_unique(Poco::URI(file_path).getPath()); + } #if USE_HDFS - else if (DB::isHdfsOrCfsScheme(scheme)) - { - ReadSettings read_settings; - read_settings.remote_throttler = context_ptr->getProcessList().getHDFSDownloadThrottler(); - read_buf = std::make_unique(uriPrefix + name, context_ptr->getHdfsConnectionParams(), read_settings); - } + else if (DB::isHdfsOrCfsScheme(scheme)) + { + ReadSettings read_settings; + read_settings.remote_throttler = context_ptr->getProcessList().getHDFSDownloadThrottler(); + read_buf = std::make_unique(file_path, context_ptr->getHdfsConnectionParams(), read_settings); + } #endif #if USE_AWS_S3 - else if (isS3URIScheme(scheme)) - { - S3::URI s3_uri(Poco::URI(uriPrefix + name)); - String endpoint = s3_uri.endpoint.empty() ? context_ptr->getSettingsRef().s3_endpoint.toString() : s3_uri.endpoint; - String bucket = s3_uri.bucket; - String key = s3_uri.key; - S3::S3Config s3_cfg(endpoint, context_ptr->getSettingsRef().s3_region.toString(), bucket, - context_ptr->getSettingsRef().s3_ak_id.toString(), context_ptr->getSettingsRef().s3_ak_secret.toString(), - "", "", context_ptr->getSettingsRef().s3_use_virtual_hosted_style); - const std::shared_ptr client = s3_cfg.create(); - read_buf = std::make_unique(client, bucket, key, context_ptr->getReadSettings()); - } + else if (isS3URIScheme(scheme)) + { + S3::URI s3_uri(file_path); + String endpoint = s3_uri.endpoint.empty() ? 
context_ptr->getSettingsRef().s3_endpoint.toString() : s3_uri.endpoint; + String bucket = s3_uri.bucket; + String key = s3_uri.key; + S3::S3Config s3_cfg( + endpoint, + context_ptr->getSettingsRef().s3_region.toString(), + bucket, + context_ptr->getSettingsRef().s3_ak_id.toString(), + context_ptr->getSettingsRef().s3_ak_secret.toString(), + "", + "", + context_ptr->getSettingsRef().s3_use_virtual_hosted_style); + const std::shared_ptr client = s3_cfg.create(); + read_buf = std::make_unique(client, bucket, key, context_ptr->getReadSettings()); + } #endif - else - { - throw Exception("URI scheme " + scheme + " is not supported with insert statement yet", ErrorCodes::NOT_IMPLEMENTED); - } + else + { + throw Exception("URI scheme " + scheme + " is not supported with insert statement yet", ErrorCodes::NOT_IMPLEMENTED); + } - read_buf = wrapReadBufferWithCompressionMethod(std::move(read_buf), chooseCompressionMethod(name, compression_method), settings.snappy_format_blocked); + read_buf = wrapReadBufferWithCompressionMethod( + std::move(read_buf), chooseCompressionMethod(file_path, compression_method), settings.snappy_format_blocked); - inputs.emplace_back( - std::make_shared>( - context_ptr->getInputStreamByFormatNameAndBuffer(format, *read_buf, - sample, // sample_block - settings.max_insert_block_size, - columns), + inputs.emplace_back(std::make_shared>( + context_ptr->getInputStreamByFormatNameAndBuffer( + format, + *read_buf, + sample, // sample_block + settings.max_insert_block_size, + columns), std::move(read_buf))); - } } } - if (inputs.size() == 0) - throw Exception("Inputs interpreter error", ErrorCodes::LOGICAL_ERROR); + if (inputs.empty()) + throw Exception("Input files is empty.", ErrorCodes::LOGICAL_ERROR); auto stream = inputs[0]; if (inputs.size() > 1) diff --git a/src/Interpreters/InterpreterInsertQuery.h b/src/Interpreters/InterpreterInsertQuery.h index c5e9848acfd..2feb09e96cf 100644 --- a/src/Interpreters/InterpreterInsertQuery.h +++ 
b/src/Interpreters/InterpreterInsertQuery.h @@ -21,6 +21,7 @@ #pragma once +#include #include #include #include @@ -40,7 +41,8 @@ class InterpreterInsertQuery : public IInterpreter, WithContext ContextPtr context_, bool allow_materialized_ = false, bool no_squash_ = false, - bool no_destination_ = false); + bool no_destination_ = false, + AccessType access_type_ = AccessType::INSERT); /** Prepare a request for execution. Return block streams * - the stream into which you can write data to execute the query, if INSERT; @@ -65,12 +67,14 @@ class InterpreterInsertQuery : public IInterpreter, WithContext private: StoragePtr getTable(ASTInsertQuery & query); + StoragePtr tryGetTableInWorkerResource(const StorageID & table_id); Block getSampleBlock(const ASTInsertQuery & query, const StoragePtr & table, const StorageMetadataPtr & metadata_snapshot) const; ASTPtr query_ptr; const bool allow_materialized; const bool no_squash; const bool no_destination; + AccessType access_type{AccessType::INSERT}; }; diff --git a/src/Interpreters/InterpreterKillQueryQuery.cpp b/src/Interpreters/InterpreterKillQueryQuery.cpp index 106b7eae603..05065290968 100644 --- a/src/Interpreters/InterpreterKillQueryQuery.cpp +++ b/src/Interpreters/InterpreterKillQueryQuery.cpp @@ -45,6 +45,7 @@ #include #include #include +#include namespace DB @@ -121,24 +122,26 @@ static QueryDescriptors extractQueriesExceptMeAndCheckAccess(const Block & proce }; String query_user; + const auto & tenant_id = getCurrentTenantId(); + const auto & nontenant_client_user = getOriginalEntityName(my_client.current_user, tenant_id); for (size_t i = 0; i < num_processes; ++i) { if ((my_client.current_query_id == query_id_col.getDataAt(i).toString()) - && (my_client.current_user == user_col.getDataAt(i).toString())) + && (nontenant_client_user == user_col.getDataAt(i).toString())) continue; auto query_id = query_id_col.getDataAt(i).toString(); query_user = user_col.getDataAt(i).toString(); - if ((my_client.current_user != 
query_user) && !is_kill_query_granted()) + if ((nontenant_client_user != query_user) && !is_kill_query_granted()) continue; res.emplace_back(std::move(query_id), query_user, i, false); } if (res.empty() && access_denied) - throw Exception("User " + my_client.current_user + " attempts to kill query created by " + query_user, ErrorCodes::ACCESS_DENIED); + throw Exception("User " + nontenant_client_user + " attempts to kill query created by " + query_user, ErrorCodes::ACCESS_DENIED); return res; } @@ -225,6 +228,7 @@ BlockIO InterpreterKillQueryQuery::execute() { case ASTKillQueryQuery::Type::Query: { + auto context = getContext(); auto where_clause = DB::collectWhereORClausePredicate(query.where_expression, getContext()); String query_id; std::for_each(where_clause.begin(), where_clause.end(), [&query_id](const std::map & wheres) { @@ -233,13 +237,13 @@ BlockIO InterpreterKillQueryQuery::execute() query_id = iter->second.get(); }); if (!query_id.empty()) - getContext()->getQueueManager()->cancel(query_id); + context->getQueueManager()->cancel(query_id); Block processes_block = getSelectResult("query_id, user, query", "system.processes"); if (!processes_block) return res_io; - ProcessList & process_list = getContext()->getProcessList(); - QueryDescriptors queries_to_stop = extractQueriesExceptMeAndCheckAccess(processes_block, getContext()); + ProcessList & process_list = context->getProcessList(); + QueryDescriptors queries_to_stop = extractQueriesExceptMeAndCheckAccess(processes_block, context); auto header = processes_block.cloneEmpty(); header.insert(0, {ColumnString::create(), std::make_shared(), "kill_status"}); @@ -250,7 +254,7 @@ BlockIO InterpreterKillQueryQuery::execute() for (const auto & query_desc : queries_to_stop) { auto code = (query.test) ? CancellationCode::Unknown - : process_list.sendCancelToQuery(query_desc.query_id, query_desc.user, true); + : process_list.sendCancelToQuery(query_desc.query_id, (context->is_tenant_user() ? 
formatTenantEntityName(query_desc.user) : query_desc.user), true); insertResultRow(query_desc.source_num, code, processes_block, header, res_columns); } diff --git a/src/Interpreters/InterpreterSelectQuery.cpp b/src/Interpreters/InterpreterSelectQuery.cpp index 2b88fcd8d62..42ddf0958c0 100644 --- a/src/Interpreters/InterpreterSelectQuery.cpp +++ b/src/Interpreters/InterpreterSelectQuery.cpp @@ -282,6 +282,10 @@ static void checkAccessRightsForSelect( const StorageMetadataPtr & table_metadata, const TreeRewriterResult & syntax_analyzer_result) { + // firstly, check aeolus access + if (context->getServerType() == ServerType::cnch_server) + context->checkAeolusTableAccess(table_id.database_name, table_id.table_name); + if (!syntax_analyzer_result.has_explicit_columns && table_metadata && !table_metadata->getColumns().empty()) { /// For a trivial query like "SELECT count() FROM table" access is granted if at least @@ -487,6 +491,11 @@ InterpreterSelectQuery::InterpreterSelectQuery( if (storage && query.where() && !query.prewhere()) { + if (auto * merge_tree_data = dynamic_cast(storage.get())) + { + merge_tree_data->prepareDataPartsForRead(); + } + /// PREWHERE optimization: transfer some condition from WHERE to PREWHERE if enabled and viable if (const auto & column_sizes = storage->getColumnSizes(); !column_sizes.empty()) { @@ -620,7 +629,8 @@ InterpreterSelectQuery::InterpreterSelectQuery( required_columns = syntax_analyzer_result->requiredSourceColumns(); // disable map column access if not explcit set to avoid "select *" query - if (storage && storage->supportsMapImplicitColumn() && !settings.allow_map_access_without_key && query_analyzer->hasByteMapColumn()) + if (storage && storage->supportsMapImplicitColumn() && (!settings.allow_map_access_without_key && !settings.enable_optimizer) + && query_analyzer->hasByteMapColumn()) throw Exception("Map column access without key is not allowed for ByteMap", ErrorCodes::NOT_IMPLEMENTED); if (storage) diff --git 
a/src/Interpreters/InterpreterSelectQueryUseOptimizer.cpp b/src/Interpreters/InterpreterSelectQueryUseOptimizer.cpp index ef8a6422f8b..95a315c92dc 100644 --- a/src/Interpreters/InterpreterSelectQueryUseOptimizer.cpp +++ b/src/Interpreters/InterpreterSelectQueryUseOptimizer.cpp @@ -16,6 +16,10 @@ #include #include +#include +#include +#include +#include #include #include #include @@ -38,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -52,11 +57,15 @@ #include #include #include +#include +#include +#include "common/defines.h" #include #include #include +#include "Interpreters/ClientInfo.h" +#include "Interpreters/Context_fwd.h" -#include namespace ProfileEvents { @@ -545,7 +554,7 @@ BlockIO InterpreterSelectQueryUseOptimizer::readFromQueryCache(ContextPtr local_ BlockIO InterpreterSelectQueryUseOptimizer::execute() { - if (auto * create_prepared = query_ptr->as()) + if (query_ptr && query_ptr->as()) { // if (!create_prepared->cluster.empty()) // return executeDDLQueryOnCluster(query_ptr, context); @@ -603,7 +612,6 @@ void InterpreterSelectQueryUseOptimizer::resetFinalSampleSize(PlanSegmentTreePtr size_t sample_size = (sample->getSampleSize() + 1) / plan_segment.getPlanSegment()->getParallelSize(); sample->setSampleSize(sample_size); } - } } } @@ -618,6 +626,9 @@ void InterpreterSelectQueryUseOptimizer::fillContextQueryAccessInfo(ContextPtr c { Names required_columns; auto storage_id = storage_analysis.storage->getStorageID(); + // check aeolus access + if (context->getServerType() == ServerType::cnch_server) + context->checkAeolusTableAccess(storage_id.database_name, storage_id.table_name); if (auto it = used_columns_map.find(storage_analysis.storage->getStorageID()); it != used_columns_map.end()) { for (const auto & column : it->second) @@ -650,11 +661,26 @@ void InterpreterSelectQueryUseOptimizer::setUnsupportedSettings(ContextMutablePt return; SettingsChanges setting_changes; - 
setting_changes.emplace_back("distributed_aggregation_memory_efficient", false); - context->applySettingsChanges(setting_changes); } +void InterpreterSelectQueryUseOptimizer::fillQueryPlan(ContextPtr context, QueryPlan & query_plan) +{ + WriteBufferFromOwnString buffer; + Protos::QueryPlan plan_pb; + query_plan.toProto(plan_pb); + String json_msg; + google::protobuf::util::JsonPrintOptions pb_options; + pb_options.preserve_proto_field_names = true; + pb_options.always_print_primitive_fields = true; + pb_options.add_whitespace = false; + + google::protobuf::util::MessageToJsonString(plan_pb, &json_msg, pb_options); + buffer << json_msg; + + context->getQueryContext()->addQueryPlanInfo(buffer.str()); +} + void InterpreterSelectQueryUseOptimizer::buildQueryPlan(QueryPlanPtr & query_plan, AnalysisPtr & analysis, bool skip_optimize) { context->createPlanNodeIdAllocator(); @@ -685,6 +711,12 @@ void InterpreterSelectQueryUseOptimizer::buildQueryPlan(QueryPlanPtr & query_pla { stage_watch.restart(); PlanOptimizer::optimize(*query_plan, context); + + if (context->getSettingsRef().log_query_plan) + { + fillQueryPlan(context, *query_plan); + } + context->logOptimizerProfile( log, "Optimizer stage run time: ", "Optimizer", std::to_string(stage_watch.elapsedMillisecondsAsDouble()) + "ms"); ProfileEvents::increment(ProfileEvents::QueryOptimizerTime, stage_watch.elapsedMilliseconds()); @@ -693,6 +725,14 @@ void InterpreterSelectQueryUseOptimizer::buildQueryPlan(QueryPlanPtr & query_pla BlockIO InterpreterSelectQueryUseOptimizer::executeCreatePreparedStatementQuery() { + const auto & prepare = query_ptr->as(); + AccessRightsElements access_rights_elements; + access_rights_elements.emplace_back(AccessType::CREATE_PREPARED_STATEMENT); + + if (prepare.or_replace) + access_rights_elements.emplace_back(AccessType::DROP_PREPARED_STATEMENT); + context->checkAccess(access_rights_elements); + auto * prep_stat_manager = context->getPreparedStatementManager(); if (!prep_stat_manager) throw 
Exception("Prepare cache has to be initialized", ErrorCodes::LOGICAL_ERROR); @@ -703,10 +743,9 @@ BlockIO InterpreterSelectQueryUseOptimizer::executeCreatePreparedStatementQuery( String name; String formatted_query; SettingsChanges settings_changes; - const auto & prepare = query_ptr->as(); + ASTPtr prepare_ast = query_ptr->clone(); { name = prepare.getName(); - formatted_query = prepare.formatForErrorMessage(); settings_changes = InterpreterSetQuery::extractSettingsFromQuery(query_ptr, context); } @@ -716,16 +755,7 @@ BlockIO InterpreterSelectQueryUseOptimizer::executeCreatePreparedStatementQuery( CollectPreparedParams prepared_params_collector; CollectPreparedParamsVisitor(prepared_params_collector).visit(query_ptr); prep_stat_manager->addPlanToCache( - name, - formatted_query, - settings_changes, - query_plan, - analysis, - std::move(prepared_params_collector.prepared_params), - context, - !prepare.if_not_exists, - prepare.or_replace, - prepare.is_permanent); + name, prepare_ast, settings_changes, query_plan, analysis, std::move(prepared_params_collector.prepared_params), context); return {}; } @@ -870,12 +900,22 @@ std::optional ClusterInfoFinder::visitCTERefNode(CTERefNode void ExplainAnalyzeVisitor::visitExplainAnalyzeNode(QueryPlan::Node * node, PlanSegmentTree::Nodes & nodes) { auto * explain = dynamic_cast(node->step.get()); - if (explain->getKind() != ASTExplainQuery::ExplainKind::DistributedAnalyze && explain->getKind() != ASTExplainQuery::ExplainKind::PipelineAnalyze) - return; PlanSegmentDescriptions plan_segment_descriptions; bool record_plan_detail = explain->getSetting().json && (explain->getKind() != ASTExplainQuery::ExplainKind::PipelineAnalyze); for (auto & segment_node : nodes) - plan_segment_descriptions.emplace_back(PlanSegmentDescription::getPlanSegmentDescription(segment_node.plan_segment, record_plan_detail)); + { + if (explain->getKind() == ASTExplainQuery::ExplainKind::DistributedAnalyze + || explain->getKind() == 
ASTExplainQuery::ExplainKind::LogicalAnalyze) + segment_node.plan_segment->setProfileType(ReportProfileType::QueryPlan); + else if (explain->getKind() == ASTExplainQuery::ExplainKind::PipelineAnalyze) + segment_node.plan_segment->setProfileType(ReportProfileType::QueryPipeline); + + if (explain->getKind() == ASTExplainQuery::ExplainKind::DistributedAnalyze + || explain->getKind() == ASTExplainQuery::ExplainKind::PipelineAnalyze) + plan_segment_descriptions.emplace_back( + PlanSegmentDescription::getPlanSegmentDescription(segment_node.plan_segment, record_plan_detail)); + } + explain->setPlanSegmentDescriptions(plan_segment_descriptions); } diff --git a/src/Interpreters/InterpreterSelectQueryUseOptimizer.h b/src/Interpreters/InterpreterSelectQueryUseOptimizer.h index f536ab6591a..ab68c8fe070 100644 --- a/src/Interpreters/InterpreterSelectQueryUseOptimizer.h +++ b/src/Interpreters/InterpreterSelectQueryUseOptimizer.h @@ -26,6 +26,7 @@ #include #include #include +#include "Parsers/IAST_fwd.h" namespace Poco { @@ -79,6 +80,8 @@ class InterpreterSelectQueryUseOptimizer : public IInterpreter static void fillContextQueryAccessInfo(ContextPtr context, AnalysisPtr & analysis); + static void fillQueryPlan(ContextPtr context, QueryPlan & query_plan); + Block getSampleBlock(); static void setUnsupportedSettings(ContextMutablePtr & context); @@ -87,6 +90,8 @@ class InterpreterSelectQueryUseOptimizer : public IInterpreter BlockIO executeCreatePreparedStatementQuery(); bool isCreatePreparedStatement(); + ASTPtr & getQuery() { return query_ptr; } + private: ASTPtr query_ptr; PlanNodePtr sub_plan_ptr; diff --git a/src/Interpreters/InterpreterSetSensitiveQuery.cpp b/src/Interpreters/InterpreterSetSensitiveQuery.cpp index 44a950b1e10..233d785701b 100644 --- a/src/Interpreters/InterpreterSetSensitiveQuery.cpp +++ b/src/Interpreters/InterpreterSetSensitiveQuery.cpp @@ -2,6 +2,7 @@ #include #include #include +#include namespace DB { @@ -10,7 +11,16 @@ namespace DB BlockIO 
InterpreterSetSensitiveQuery::execute() { const auto & ast = query_ptr->as(); - getContext()->getCnchCatalog()->putSensitiveResource(ast.database, ast.table, ast.column, ast.target, ast.value); + const auto ctx = getContext(); + + if (ast.target == "DATABASE") + ctx->checkAccess(AccessType::SET_SENSITIVE, ast.database); + else if (ast.target == "TABLE") + ctx->checkAccess(AccessType::SET_SENSITIVE, ast.database, ast.table); + else if (ast.target == "COLUMN") + ctx->checkAccess(AccessType::SET_SENSITIVE, ast.database, ast.table, ast.column); + + ctx->getCnchCatalog()->putSensitiveResource(ast.database, ast.table, ast.column, ast.target, ast.value); return {}; } diff --git a/src/Interpreters/InterpreterShowAccessQuery.cpp b/src/Interpreters/InterpreterShowAccessQuery.cpp index 5533e07c415..c96e950c5f4 100644 --- a/src/Interpreters/InterpreterShowAccessQuery.cpp +++ b/src/Interpreters/InterpreterShowAccessQuery.cpp @@ -83,7 +83,10 @@ ASTs InterpreterShowAccessQuery::getCreateAndGrantQueries() const { create_queries.push_back(InterpreterShowCreateAccessEntityQuery::getCreateQuery(*entity, access_control)); if (entity->isTypeOf(EntityType::USER) || entity->isTypeOf(EntityType::ROLE)) - boost::range::push_back(grant_queries, InterpreterShowGrantsQuery::getGrantQueries(*entity, access_control)); + { + boost::range::push_back(grant_queries, InterpreterShowGrantsQuery::getGrantQueries(*entity, access_control, true)); + boost::range::push_back(grant_queries, InterpreterShowGrantsQuery::getGrantQueries(*entity, access_control, false)); + } } ASTs result = std::move(create_queries); diff --git a/src/Interpreters/InterpreterShowGrantsQuery.cpp b/src/Interpreters/InterpreterShowGrantsQuery.cpp index b274e2775fa..c23ced7157d 100644 --- a/src/Interpreters/InterpreterShowGrantsQuery.cpp +++ b/src/Interpreters/InterpreterShowGrantsQuery.cpp @@ -56,6 +56,7 @@ namespace if (!current_query) { current_query = std::make_shared(); + current_query->is_sensitive = sensitive_mode; 
current_query->grantees = grantees; current_query->attach_mode = attach_mode; if (element.is_partial_revoke) @@ -66,6 +67,9 @@ namespace current_query->access_rights_elements.emplace_back(std::move(element)); } + if (sensitive_mode) + return res; + for (const auto & element : grantee.granted_roles.getElements()) { if (element.empty()) @@ -167,15 +171,18 @@ ASTs InterpreterShowGrantsQuery::getGrantQueries() const ASTs grant_queries; for (const auto & entity : entities) - boost::range::push_back(grant_queries, getGrantQueries(*entity, access_control)); + { + boost::range::push_back(grant_queries, getGrantQueries(*entity, access_control, true)); + boost::range::push_back(grant_queries, getGrantQueries(*entity, access_control, false)); + } return grant_queries; } -ASTs InterpreterShowGrantsQuery::getGrantQueries(const IAccessEntity & user_or_role, const AccessControlManager & access_control) +ASTs InterpreterShowGrantsQuery::getGrantQueries(const IAccessEntity & user_or_role, const AccessControlManager & access_control, bool sensitive_mode) { - return getGrantQueriesImpl(user_or_role, &access_control, false); + return getGrantQueriesImpl(user_or_role, &access_control, false, sensitive_mode); } diff --git a/src/Interpreters/InterpreterShowGrantsQuery.h b/src/Interpreters/InterpreterShowGrantsQuery.h index 9550bbdf387..8113f15a3ea 100644 --- a/src/Interpreters/InterpreterShowGrantsQuery.h +++ b/src/Interpreters/InterpreterShowGrantsQuery.h @@ -20,8 +20,8 @@ class InterpreterShowGrantsQuery : public IInterpreter, WithContext BlockIO execute() override; - static ASTs getGrantQueries(const IAccessEntity & user_or_role, const AccessControlManager & access_control); - static ASTs getAttachGrantQueries(const IAccessEntity & user_or_role, bool sensitive_mode = false); + static ASTs getGrantQueries(const IAccessEntity & user_or_role, const AccessControlManager & access_control, bool sensitive_mode); + static ASTs getAttachGrantQueries(const IAccessEntity & user_or_role, bool 
sensitive_mode); bool ignoreQuota() const override { return true; } bool ignoreLimits() const override { return true; } diff --git a/src/Interpreters/InterpreterShowPreparedStatementQuery.cpp b/src/Interpreters/InterpreterShowPreparedStatementQuery.cpp index 387ad197656..dc8dfeea232 100644 --- a/src/Interpreters/InterpreterShowPreparedStatementQuery.cpp +++ b/src/Interpreters/InterpreterShowPreparedStatementQuery.cpp @@ -7,6 +7,7 @@ #include #include #include +#include namespace DB { @@ -26,7 +27,22 @@ BlockIO InterpreterShowPreparedStatementQuery::execute() if (show_prepared->show_create) { auto prepared_object = prepared_manager->getObject(show_prepared->name); - out << prepared_object.query; + if (!prepared_object.query) + out << "Null"; + if (auto * create_prep_stat = prepared_object.query->as()) + { + + if (context->getTenantId().empty()) + out << create_prep_stat->formatForErrorMessage(); + else + { + ASTPtr ast = create_prep_stat->clone(); + auto * new_create = ast->as(); + new_create->rewriteNamesWithoutTenant(); + out << new_create->formatForErrorMessage(); + } + } + result_column_name = "Create Statement"; } else if (show_prepared->show_explain) @@ -58,7 +74,10 @@ BlockIO InterpreterShowPreparedStatementQuery::execute() { auto name_list = prepared_manager->getNames(); for (auto & name : name_list) - out << name << "\n"; + { + if (context->getTenantId().empty() || isTenantMatchedEntityName(name)) + out << getOriginalEntityName(name) << "\n"; + } result_column_name = "Prepared Statement List"; } diff --git a/src/Interpreters/InterpreterShowTablesQuery.cpp b/src/Interpreters/InterpreterShowTablesQuery.cpp index c1ad9774f4c..1b74507645d 100644 --- a/src/Interpreters/InterpreterShowTablesQuery.cpp +++ b/src/Interpreters/InterpreterShowTablesQuery.cpp @@ -207,7 +207,7 @@ String InterpreterShowTablesQuery::getRewrittenQueryImpl() rewritten_query << "is_temporary"; } else - rewritten_query << "database = " << DB::quote << ((query.dictionaries || query.snapshots) 
? database : getOriginalDatabaseName(database)); + rewritten_query << "database = " << DB::quote << ((query.snapshots) ? database : getOriginalDatabaseName(database)); if (!query.like.empty()) rewritten_query << " AND name " << (query.not_like ? "NOT " : "") << (query.case_insensitive_like ? "ILIKE " : "LIKE ") << DB::quote diff --git a/src/Interpreters/InterpreterSystemQuery.cpp b/src/Interpreters/InterpreterSystemQuery.cpp index 46c3fde9d09..962ac2b04d9 100644 --- a/src/Interpreters/InterpreterSystemQuery.cpp +++ b/src/Interpreters/InterpreterSystemQuery.cpp @@ -1818,7 +1818,7 @@ void InterpreterSystemQuery::lockMemoryLock(const ASTSystemQuery & query, const Stopwatch lock_watch; - auto cnch_lock = transaction->createLockHolder({std::move(partition_lock)}); + auto cnch_lock = std::make_shared(local_context, std::move(partition_lock)); cnch_lock->lock(); LOG_DEBUG(log, "Acquired lock in {} ms", lock_watch.elapsedMilliseconds()); sleepForSeconds(query.seconds); diff --git a/src/Interpreters/InterpreterUpdateQuery.cpp b/src/Interpreters/InterpreterUpdateQuery.cpp index 9469ac4fec9..a71a93ca586 100644 --- a/src/Interpreters/InterpreterUpdateQuery.cpp +++ b/src/Interpreters/InterpreterUpdateQuery.cpp @@ -58,6 +58,34 @@ BlockIO InterpreterUpdateQuery::execute() } +static ASTTableExpression * getFirstTableExpression(const ASTUpdateQuery & update) +{ + if (!update.tables) + return {}; + + auto & tables_in_update_query = update.tables->as(); + if (tables_in_update_query.children.empty()) + return {}; + + auto & tables_element = tables_in_update_query.children[0]->as(); + if (!tables_element.table_expression) + return {}; + + return tables_element.table_expression->as(); +} + +static String getTableExpressionAlias(const ASTTableExpression * table_expression) +{ + if (table_expression->subquery) + return table_expression->subquery->tryGetAlias(); + else if (table_expression->table_function) + return table_expression->table_function->tryGetAlias(); + else if 
(table_expression->database_and_table_name) + return table_expression->database_and_table_name->tryGetAlias(); + + return String(); +} + ASTPtr InterpreterUpdateQuery::prepareInterpreterSelectQuery(const StoragePtr & storage) { auto res = std::make_shared(); @@ -78,20 +106,28 @@ ASTPtr InterpreterUpdateQuery::prepareInterpreterSelectQuery(const StoragePtr & //collect assignments std::unordered_map assignments; + String update_table_alias; for (const auto & child : ast_update.assignment_list->children) { - if (const ASTAssignment * assignment = child->as()) + const ASTAssignment * assignment = child->as(); + if (!assignment) + throw Exception("Syntax error in update statement. " + child->getID(), ErrorCodes::SYNTAX_ERROR); + + if (const auto & t = assignment->table_name; !t.empty()) { - if (immutable_columns.count(assignment->column_name)) - throw Exception("Updating partition/unique keys is not allowed.", ErrorCodes::BAD_ARGUMENTS); + if (update_table_alias.empty()) + update_table_alias = t; + else if (update_table_alias != t) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "UPDATE multi tables is not supported. Tables: {}, {}", update_table_alias, t); + } - if (!ordinary_columns.count(assignment->column_name)) - throw Exception("There is no column named " + assignment->column_name, ErrorCodes::BAD_ARGUMENTS); + if (immutable_columns.count(assignment->column_name)) + throw Exception("Updating partition/unique keys is not allowed.", ErrorCodes::BAD_ARGUMENTS); - assignments.emplace(assignment->column_name, assignment->expression()->clone()); - } - else - throw Exception("Syntax error in update statement. 
" + child->getID(), ErrorCodes::SYNTAX_ERROR); + if (!ordinary_columns.count(assignment->column_name)) + throw Exception("There is no column named " + assignment->column_name, ErrorCodes::BAD_ARGUMENTS); + + assignments.emplace(assignment->column_name, assignment->expression()->clone()); } auto select_list = std::make_shared(); @@ -119,6 +155,8 @@ ASTPtr InterpreterUpdateQuery::prepareInterpreterSelectQuery(const StoragePtr & if (ast_update.single_table) { + if (!update_table_alias.empty()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "No table alias found: {}", update_table_alias); res->setExpression(ASTSelectQuery::Expression::TABLES, std::make_shared()); auto tables = res->tables(); auto tables_elem = std::make_shared(); @@ -131,6 +169,23 @@ ASTPtr InterpreterUpdateQuery::prepareInterpreterSelectQuery(const StoragePtr & } else { + const auto & first_table = getFirstTableExpression(ast_update); + auto first_table_alias = getTableExpressionAlias(first_table); + + /// Check that only the first table is updated. + if (update_table_alias.empty()) + { + /// By default, if update_table is empty, it means the first table is updated. + } + else + { + if (first_table_alias.empty()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "SET with table alias but there is no table alias. 
{}, {}", update_table_alias); + else if (first_table_alias != update_table_alias) + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "It's only allowed to update the first table `{}`, but `{}` is given.", first_table_alias, update_table_alias); + } + res->setExpression(ASTSelectQuery::Expression::TABLES, ast_update.tables->clone()); } diff --git a/src/Interpreters/KafkaLog.cpp b/src/Interpreters/KafkaLog.cpp index 42e9c3c6246..356fefcfb87 100644 --- a/src/Interpreters/KafkaLog.cpp +++ b/src/Interpreters/KafkaLog.cpp @@ -67,7 +67,7 @@ void KafkaLogElement::appendToBlock(MutableColumns & columns) const size_t i = 0; columns[i++]->insert(UInt64(event_type)); - columns[i++]->insert(UInt64(DateLUT::instance().toDayNum(event_time))); + columns[i++]->insert(UInt64(DateLUT::serverTimezoneInstance().toDayNum(event_time))); columns[i++]->insert(UInt64(event_time)); columns[i++]->insert(UInt64(duration_ms)); diff --git a/src/Interpreters/LogicalExpressionsOptimizer.cpp b/src/Interpreters/LogicalExpressionsOptimizer.cpp index 936ed0149d2..50d24b4a7ae 100644 --- a/src/Interpreters/LogicalExpressionsOptimizer.cpp +++ b/src/Interpreters/LogicalExpressionsOptimizer.cpp @@ -112,7 +112,9 @@ void LogicalExpressionsOptimizer::collectDisjunctiveEqualityChains() bool found_chain = false; auto * function = to_node->as(); - if (function && function->name == "or" && function->children.size() == 1) + /// Optimization does not respect aliases properly, which can lead to MULTIPLE_EXPRESSION_FOR_ALIAS error. + /// Disable it if an expression has an alias. Proper implementation is done with the new analyzer. 
+ if (function && function->alias.empty() && function->name == "or" && function->children.size() == 1) { const auto * expression_list = function->children[0]->as(); if (expression_list) @@ -121,14 +123,14 @@ void LogicalExpressionsOptimizer::collectDisjunctiveEqualityChains() for (const auto & child : expression_list->children) { auto * equals = child->as(); - if (equals && equals->name == "equals" && equals->children.size() == 1) + if (equals && equals->alias.empty() && equals->name == "equals" && equals->children.size() == 1) { const auto * equals_expression_list = equals->children[0]->as(); if (equals_expression_list && equals_expression_list->children.size() == 2) { /// Equality expr = xN. const auto * literal = equals_expression_list->children[1]->as(); - if (literal) + if (literal && literal->alias.empty()) { auto expr_lhs = equals_expression_list->children[0]->getTreeHash(); OrWithExpression or_with_expression{function, expr_lhs, function->tryGetAlias()}; @@ -199,6 +201,9 @@ bool LogicalExpressionsOptimizer::mayOptimizeDisjunctiveEqualityChain(const Disj const auto & equalities = chain.second; const auto & equality_functions = equalities.functions; + if (settings.optimize_min_equality_disjunction_chain_length == 0) + return false; + /// We eliminate too short chains. 
if (equality_functions.size() < settings.optimize_min_equality_disjunction_chain_length) return false; diff --git a/src/Interpreters/MaterializedMySQLLog.cpp b/src/Interpreters/MaterializedMySQLLog.cpp index 97054b34bf3..36145d11d81 100644 --- a/src/Interpreters/MaterializedMySQLLog.cpp +++ b/src/Interpreters/MaterializedMySQLLog.cpp @@ -62,7 +62,7 @@ void MaterializedMySQLLogElement::appendToBlock(MutableColumns & columns) const columns[i++]->insert(nameset_2_array(tables)); columns[i++]->insert(type); - columns[i++]->insert(DateLUT::instance().toDayNum(event_time).toUnderType()); + columns[i++]->insert(DateLUT::serverTimezoneInstance().toDayNum(event_time).toUnderType()); columns[i++]->insert(event_time); columns[i++]->insert(resync_table); diff --git a/src/Interpreters/MergeJoin.cpp b/src/Interpreters/MergeJoin.cpp index 5a06ccc982b..d0ab59a69d9 100644 --- a/src/Interpreters/MergeJoin.cpp +++ b/src/Interpreters/MergeJoin.cpp @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include #include @@ -983,7 +983,7 @@ std::shared_ptr MergeJoin::loadRightBlock(size_t pos) const { auto load_func = [&]() -> std::shared_ptr { - TemporaryFileStream input(flushed_right_blocks[pos]->path(), materializeBlock(right_sample_block)); + TemporaryFileStreamLegacy input(flushed_right_blocks[pos]->path(), materializeBlock(right_sample_block)); return std::make_shared(input.block_in->read()); }; diff --git a/src/Interpreters/MetricLog.cpp b/src/Interpreters/MetricLog.cpp index ae0b85d4d8d..c625a9c0731 100644 --- a/src/Interpreters/MetricLog.cpp +++ b/src/Interpreters/MetricLog.cpp @@ -42,7 +42,7 @@ void MetricLogElement::appendToBlock(MutableColumns & columns) const { size_t column_idx = 0; - columns[column_idx++]->insert(DateLUT::instance().toDayNum(event_time).toUnderType()); + columns[column_idx++]->insert(DateLUT::serverTimezoneInstance().toDayNum(event_time).toUnderType()); columns[column_idx++]->insert(event_time); 
columns[column_idx++]->insert(event_time_microseconds); columns[column_idx++]->insert(milliseconds); diff --git a/src/Interpreters/MonotonicityCheckVisitor.h b/src/Interpreters/MonotonicityCheckVisitor.h index a26f62a0829..2d47a7d673d 100644 --- a/src/Interpreters/MonotonicityCheckVisitor.h +++ b/src/Interpreters/MonotonicityCheckVisitor.h @@ -69,6 +69,12 @@ class MonotonicityCheckMatcher if (!pos) return false; + /// It is possible that tables list is empty. + /// IdentifierSemantic get the position from AST, and it can be not valid to use it. + /// One example is `fetchPartitions` from 02147_order_by_optimizations.sql + if (*pos >= tables.size()) + return false; + if (auto data_type_and_name = tables[*pos].columns.tryGetByName(identifier->shortName())) { arg_data_type = data_type_and_name->type; diff --git a/src/Interpreters/MutationLog.cpp b/src/Interpreters/MutationLog.cpp index 87d071cc672..790999c06c2 100644 --- a/src/Interpreters/MutationLog.cpp +++ b/src/Interpreters/MutationLog.cpp @@ -56,7 +56,7 @@ void MutationLogElement::appendToBlock(MutableColumns & columns) const size_t i = 0; columns[i++]->insert(UInt64(event_type)); - columns[i++]->insert(UInt64(DateLUT::instance().toDayNum(event_time))); + columns[i++]->insert(UInt64(DateLUT::serverTimezoneInstance().toDayNum(event_time))); columns[i++]->insert(UInt64(event_time)); columns[i++]->insert(database_name); diff --git a/src/Interpreters/MySQL/InterpretersAnalyticalMySQLDDLQuery.cpp b/src/Interpreters/MySQL/InterpretersAnalyticalMySQLDDLQuery.cpp index c77edce2094..f84825c4c4a 100644 --- a/src/Interpreters/MySQL/InterpretersAnalyticalMySQLDDLQuery.cpp +++ b/src/Interpreters/MySQL/InterpretersAnalyticalMySQLDDLQuery.cpp @@ -56,7 +56,7 @@ namespace ErrorCodes namespace MySQLInterpreter { -static NamesAndTypesList getColumnsList(const ASTExpressionList * columns_definition) +static NamesAndTypesList getColumnsList(const ASTExpressionList * columns_definition, ContextPtr context) { NamesAndTypesList 
columns_name_and_type; for (const auto & declare_column_ast : columns_definition->children) @@ -77,19 +77,12 @@ static NamesAndTypesList getColumnsList(const ASTExpressionList * columns_defini if (type_name_upper.find("TIMESTAMP") != String::npos || type_name_upper.find("DATETIME") != String::npos) { - if (type_name_upper.find("DATETIME64") == String::npos) + if (!data_type_function->arguments || data_type_function->arguments->children.empty()) { data_type_function->name = "DATETIME64"; auto arguments = std::make_shared(); - arguments->children.push_back(std::make_shared(UInt8(DataTypeDateTime64::default_scale))); - if (data_type_function->arguments && !data_type_function->arguments->children.empty()) { - auto &args = data_type_function->arguments->children; - arguments->children.insert(arguments->children.end(), args.begin(), args.end()); - } - if (type_name_upper.find("DateTimeWithoutTz") != String::npos) - { - arguments->children.push_back(std::make_shared("UTC")); - } + UInt8 scale = context->getSettingsRef().datetime_format_mysql_definition ? 
0 : DataTypeDateTime64::default_scale; + arguments->children.push_back(std::make_shared(scale)); data_type_function->arguments = arguments; } } @@ -283,61 +276,6 @@ static std::tuple ASTPtr - { - if (type_max_size <= 1000) - return std::make_shared(column_name); - - return makeASTFunction("intDiv", std::make_shared(column_name), - std::make_shared(UInt64(type_max_size / 1000))); - }; - - ASTPtr best_partition; - size_t best_size = 0; - for (const auto & primary_key : primary_keys) - { - DataTypePtr type = primary_key.type; - WhichDataType which(type); - - if (which.isNullable()) - throw Exception("LOGICAL ERROR: MySQL primary key must be not null, it is a bug.", ErrorCodes::LOGICAL_ERROR); - - if (which.isDate() || which.isDate32() || which.isDateTime() || which.isDateTime64()) - { - /// In any case, date or datetime is always the best partitioning key - return makeASTFunction("toYYYYMM", std::make_shared(primary_key.name)); - } - - if (type->haveMaximumSizeOfValue() && (!best_size || type->getSizeOfValueInMemory() < best_size)) - { - if (which.isInt8() || which.isUInt8()) - { - best_size = type->getSizeOfValueInMemory(); - best_partition = numbers_partition(primary_key.name, std::numeric_limits::max()); - } - else if (which.isInt16() || which.isUInt16()) - { - best_size = type->getSizeOfValueInMemory(); - best_partition = numbers_partition(primary_key.name, std::numeric_limits::max()); - } - else if (which.isInt32() || which.isUInt32()) - { - best_size = type->getSizeOfValueInMemory(); - best_partition = numbers_partition(primary_key.name, std::numeric_limits::max()); - } - else if (which.isInt64() || which.isUInt64()) - { - best_size = type->getSizeOfValueInMemory(); - best_partition = numbers_partition(primary_key.name, std::numeric_limits::max()); - } - } - } - - return best_partition; -} - static ASTPtr getOrderByPolicy( const NamesAndTypesList & primary_keys, const NamesAndTypesList & keys = NamesAndTypesList(), const NamesAndTypesList & cluster_keys = 
NamesAndTypesList()) { @@ -383,6 +321,23 @@ static ASTPtr getOrderByPolicy( return order_by_expression; } +static ASTPtr getPartitionPolicy(const NamesAndTypesList & partition_keys) +{ + for (const auto & partition_key : partition_keys) + { + DataTypePtr type = partition_key.type; + WhichDataType which(type); + + if (which.isDate() || which.isDate32() || which.isDateTime() || which.isDateTime64()) + { + /// In any case, date or datetime is always the best partitioning key + return makeASTFunction("toYYYYMM", std::make_shared(partition_key.name)); + } + } + + return nullptr; +} + namespace { @@ -551,7 +506,6 @@ void InterpreterCreateAnalyticMySQLImpl::validate(const InterpreterCreateAnalyti validateTTLExpression(mysql_storage->ttl_table->ptr()); } - if (mysql_storage->engine) { auto upper_name = Poco::toUpper(mysql_storage->engine->name); @@ -615,7 +569,7 @@ ASTPtr InterpreterCreateAnalyticMySQLImpl::getRewrittenQuery( const TQuery & cre engine_name = Poco::toUpper(mysql_storage->mysql_engine->as()->value.get()); if (engine_names.find(engine_name) == engine_names.end()) { - throw Exception ("Unsupported String Engine Name, please remove quotes", ErrorCodes::MYSQL_SYNTAX_ERROR); + throw Exception ("Unsupported Engine Name", ErrorCodes::MYSQL_SYNTAX_ERROR); } } @@ -625,23 +579,15 @@ ASTPtr InterpreterCreateAnalyticMySQLImpl::getRewrittenQuery( const TQuery & cre return query; } + // table if (has_table_definition) { - NamesAndTypesList columns_name_and_type = getColumnsList(create_defines->columns); + NamesAndTypesList columns_name_and_type = getColumnsList(create_defines->columns, context); const auto & [primary_keys, unique_keys, keys, cluster_keys] = getKeys(create_defines->columns, create_defines->mysql_indices, context, columns_name_and_type); setNotNullModifier(create_defines->columns, primary_keys); convertDecimal(create_defines->columns, primary_keys); - /// The `partition by` expression must use primary keys, otherwise the primary keys will not be merge. 
- if (mysql_storage->mysql_partition_by) - { - storage->set(storage->partition_by, mysql_storage->mysql_partition_by->clone()); - } - else if (ASTPtr partition_expression = getPartitionPolicy(primary_keys)) - storage->set(storage->partition_by, partition_expression); - - /// The `order by` expression must use primary keys, otherwise the primary keys will not be merge. if (ASTPtr order_by_expression = getOrderByPolicy(primary_keys, keys, cluster_keys)) { auto & list = order_by_expression->as()->arguments; @@ -656,50 +602,48 @@ ASTPtr InterpreterCreateAnalyticMySQLImpl::getRewrittenQuery( const TQuery & cre storage->set(storage->unique_key, unique_key_expression); } + if (ASTPtr partition_expression = getPartitionPolicy(primary_keys)) + storage->set(storage->partition_by, partition_expression); + rewritten_query->set(rewritten_query->columns_list, create_query.columns_list->clone()); rewritten_query->columns_list->mysql_indices = nullptr; } - if (!storage->engine || engine_names.find(Poco::toUpper(storage->engine->name)) != engine_names.end()) - storage->set(storage->engine, makeASTFunction("CnchMergeTree")); - if (!storage->order_by) - storage->set(storage->order_by, makeASTFunction("tuple")); - if (!storage->unique_key) + // storage { - if (storage->primary_key) + if (!storage->engine || engine_names.find(Poco::toUpper(storage->engine->name)) != engine_names.end()) + storage->set(storage->engine, makeASTFunction("CnchMergeTree")); + + if (!storage->order_by) + storage->set(storage->order_by, makeASTFunction("tuple")); + + if (!storage->unique_key) { - storage->set(storage->unique_key, storage->primary_key->clone()); - storage->primary_key = nullptr; + // clickhouse syntax for primary key + if (storage->primary_key) + { + storage->set(storage->unique_key, storage->primary_key->clone()); + storage->primary_key = nullptr; + } + else + { + storage->set(storage->unique_key, makeASTFunction("tuple")); + } } - else + + if (mysql_storage->mysql_partition_by) { - 
storage->set(storage->unique_key, makeASTFunction("tuple")); + storage->set(storage->partition_by, mysql_storage->mysql_partition_by->clone()); } - } - - if (mysql_storage->distributed_by) - { - // distributed by hash(col) -> cluster by col - const String vw_name = "vw_default"; - auto vw = context->getVirtualWarehousePool().get(vw_name); - // context->setCurrentVW(std::move(vw_handle)); - // auto vw = context->tryGetCurrentVW(); - int total_bucket_number = vw ? vw->getNumWorkers() : 1; - auto cluster_by_ast = std::make_shared(mysql_storage->distributed_by->clone(), std::make_shared(total_bucket_number), 0, false, false); - storage->set(storage->cluster_by, cluster_by_ast); - } - else if (mysql_storage->cluster_by) - { - storage->set(storage->cluster_by, mysql_storage->cluster_by->clone()); - } - { // settings ASTPtr settings = std::make_shared(); auto *settings_ast = settings->as(); settings_ast->is_standalone = false; bool has_index_granularity_setting = false; bool has_partition_level_unique_keys_setting = false; + bool has_enable_bucket_level_unique_keys = false; + bool has_enable_bucket_for_distribute = context->getSettingsRef().enable_bucket_for_distribute; if (auto *const mysql_settings = mysql_storage->settings->as()) { for (const auto & change: mysql_settings->changes) @@ -708,16 +652,39 @@ ASTPtr InterpreterCreateAnalyticMySQLImpl::getRewrittenQuery( const TQuery & cre has_index_granularity_setting = true; if (change.name == "partition_level_unique_keys") has_partition_level_unique_keys_setting = true; + if (change.name == "enable_bucket_level_unique_keys") + has_enable_bucket_level_unique_keys = true; } } - // It's not recommended to mix mysql and clickhosue dialects - // but we have to provide this in case of fall back + + // block_size -> index_granularity if (mysql_storage->block_size && !has_index_granularity_setting) - // block_size -> index_granularity settings_ast->changes.push_back({"index_granularity", 
mysql_storage->block_size->as()->value.get()}); + // distributed by hash(col) -> cluster by col + if (mysql_storage->distributed_by && has_enable_bucket_for_distribute) + { + const String vw_name = "vw_default"; + auto vw = context->getVirtualWarehousePool().get(vw_name); + + int total_bucket_number = vw ? vw->getNumWorkers() : 1; + auto cluster_by_ast = std::make_shared(mysql_storage->distributed_by->clone(), std::make_shared(total_bucket_number), 0, false, false); + storage->set(storage->cluster_by, cluster_by_ast); + + // distribute by must contain unique key + if (!has_enable_bucket_level_unique_keys) + settings_ast->changes.push_back({"enable_bucket_level_unique_keys", 1}); + } + else if (mysql_storage->cluster_by) + { + // clickhouse cluster by syntax + storage->set(storage->cluster_by, mysql_storage->cluster_by->clone()); + } + + // storage settings for mysql behavior if (!has_partition_level_unique_keys_setting) settings_ast->changes.push_back({"partition_level_unique_keys", 0}); + if (const auto mysql_settings = mysql_storage->settings->as()) settings_ast->changes.insert(settings_ast->changes.end(), mysql_settings->changes.begin(), mysql_settings->changes.end()); diff --git a/src/Interpreters/NestedLoopJoin.cpp b/src/Interpreters/NestedLoopJoin.cpp index 71119491827..1070969ef24 100644 --- a/src/Interpreters/NestedLoopJoin.cpp +++ b/src/Interpreters/NestedLoopJoin.cpp @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/Interpreters/NodeSelector.cpp b/src/Interpreters/NodeSelector.cpp index 3515d7cd8e2..126fc5c9c75 100644 --- a/src/Interpreters/NodeSelector.cpp +++ b/src/Interpreters/NodeSelector.cpp @@ -6,6 +6,8 @@ #include #include #include +#include +#include #include #include #include @@ -119,10 +121,11 @@ NodeSelectorResult LocalNodeSelector::select(PlanSegment *, ContextPtr query_con NodeSelectorResult SourceNodeSelector::select(PlanSegment * plan_segment_ptr, ContextPtr query_context, DAGGraph 
* dag_graph_ptr) { checkClusterInfo(plan_segment_ptr); + bool need_stable_schedule = needStableSchedule(plan_segment_ptr); NodeSelectorResult result; // The one worker excluded is server itself. - const auto worker_number = cluster_nodes.rank_workers.size() - 1; - if (plan_segment_ptr->getParallelSize() > worker_number && !query_context->getSettingsRef().bsp_mode) + const auto worker_number = cluster_nodes.all_workers.size() - 1; + if (plan_segment_ptr->getParallelSize() > worker_number && (!query_context->getSettingsRef().bsp_mode || need_stable_schedule)) { throw Exception( ErrorCodes::BAD_QUERY_PARAMETER, @@ -130,6 +133,7 @@ NodeSelectorResult SourceNodeSelector::select(PlanSegment * plan_segment_ptr, Co plan_segment_ptr->getParallelSize(), worker_number); } + // If parallelism is greater than the worker number, we split the parts according to the input size. if (plan_segment_ptr->getParallelSize() > worker_number) { @@ -156,7 +160,7 @@ NodeSelectorResult SourceNodeSelector::select(PlanSegment * plan_segment_ptr, Co { sum += current_size; } - size_t avg = sum / plan_segment_ptr->getParallelSize(); + size_t avg = sum / plan_segment_ptr->getParallelSize() + 1; if (sum < plan_segment_ptr->getParallelSize()) sum = 0; if (sum > 0) @@ -218,14 +222,28 @@ NodeSelectorResult SourceNodeSelector::select(PlanSegment * plan_segment_ptr, Co } else { - size_t parallel_index = 0; - for (const auto & worker : cluster_nodes.rank_workers) + if (need_stable_schedule) { - parallel_index++; - if (parallel_index > plan_segment_ptr->getParallelSize()) - break; - if (worker.address != local_address) - result.worker_nodes.emplace_back(worker); + LOG_TRACE(log, "use stable schedule for segment:{} with {} nodes", plan_segment_ptr->getPlanSegmentId(), worker_number); + if (plan_segment_ptr->getParallelSize() != worker_number) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + " Source plan segment parallel size {} is not equal to worker number {}.", + plan_segment_ptr->getParallelSize(), + 
worker_number); + for (size_t parallel_index = 0; parallel_index < worker_number; parallel_index++) + { + result.worker_nodes.emplace_back(cluster_nodes.all_workers[parallel_index]); + } + } + else + { + for (size_t parallel_index = 0; parallel_index < plan_segment_ptr->getParallelSize(); parallel_index++) + { + if (parallel_index > plan_segment_ptr->getParallelSize()) + break; + result.worker_nodes.emplace_back(cluster_nodes.all_workers[cluster_nodes.rank_worker_ids[parallel_index]]); + } } } } @@ -245,17 +263,33 @@ NodeSelectorResult ComputeNodeSelector::select(PlanSegment * plan_segment_ptr, C } else { - size_t parallel_index = 0; - for (const auto & worker : cluster_nodes.rank_workers) + bool need_stable_schedule = needStableSchedule(plan_segment_ptr); + if (need_stable_schedule) { - parallel_index++; - if (parallel_index > plan_segment_ptr->getParallelSize()) - break; - if (worker.address != local_address) - result.worker_nodes.emplace_back(worker); + const auto worker_number = cluster_nodes.all_workers.size() - 1; + LOG_TRACE(log, "use stable schedule for segment:{} with {} nodes", plan_segment_ptr->getPlanSegmentId(), worker_number); + if (plan_segment_ptr->getParallelSize() != worker_number) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Plan segment {} need stable schedule, but parallel size {} is not equal to worker number {}.", + plan_segment_ptr->getPlanSegmentId(), + plan_segment_ptr->getParallelSize(), + worker_number); + for (size_t parallel_index = 0; parallel_index < worker_number; parallel_index++) + { + result.worker_nodes.emplace_back(cluster_nodes.all_workers[parallel_index]); + } + } + else + { + for (size_t parallel_index = 0; parallel_index < plan_segment_ptr->getParallelSize(); parallel_index++) + { + if (parallel_index > plan_segment_ptr->getParallelSize()) + break; + result.worker_nodes.emplace_back(cluster_nodes.all_workers[cluster_nodes.rank_worker_ids[parallel_index]]); + } } } - return result; } @@ -304,16 +338,16 @@ 
NodeSelectorResult LocalityNodeSelector::select(PlanSegment * plan_segment_ptr, return result; } -NodeSelectorResult NodeSelector::select(PlanSegment * plan_segment_ptr, bool is_source) +NodeSelectorResult NodeSelector::select(PlanSegment * plan_segment_ptr, bool has_table_scan) { NodeSelectorResult result; auto segment_id = plan_segment_ptr->getPlanSegmentId(); - LOG_TRACE(log, "Begin to select nodes for segment, id: {}, is_source: {}", segment_id, is_source); + LOG_TRACE(log, "Begin to select nodes for segment, id: {}, has table scan: {}", segment_id, has_table_scan); if (isLocal(plan_segment_ptr)) { result = local_node_selector.select(plan_segment_ptr, query_context); } - else if (is_source) + else if (has_table_scan) { result = source_node_selector.select(plan_segment_ptr, query_context, dag_graph_ptr); } diff --git a/src/Interpreters/NodeSelector.h b/src/Interpreters/NodeSelector.h index 6b1c55e8c48..5794a404b1b 100644 --- a/src/Interpreters/NodeSelector.h +++ b/src/Interpreters/NodeSelector.h @@ -65,18 +65,18 @@ struct ClusterNodes const auto & worker_group = query_context->tryGetCurrentWorkerGroup(); if (worker_group) { - for (auto i : rank_worker_ids) + for (size_t i = 0; i < rank_worker_ids.size(); i++) { const auto & worker_endpoint = worker_group->getHostWithPortsVec()[i]; auto worker_address = getRemoteAddress(worker_endpoint, query_context); - rank_workers.emplace_back(worker_address, NodeType::Remote, worker_endpoint.id); - rank_hosts.emplace_back(worker_endpoint); + all_workers.emplace_back(worker_address, NodeType::Remote, worker_endpoint.id); + all_hosts.emplace_back(worker_endpoint); } } } std::vector rank_worker_ids; - std::vector rank_workers; - HostWithPortsVec rank_hosts; + std::vector all_workers; + HostWithPortsVec all_hosts; }; struct NodeSelectorResult @@ -142,6 +142,12 @@ class CommonNodeSelector ErrorCodes::LOGICAL_ERROR); } } + + bool needStableSchedule(PlanSegment * plan_segment_ptr) + { + const auto & inputs = 
plan_segment_ptr->getPlanSegmentInputs(); + return std::any_of(inputs.begin(), inputs.end(), [](const auto & input) { return input->isStable(); }); + } void selectPrunedWorkers(DAGGraph * dag_graph_ptr, PlanSegment * plan_segment_ptr, NodeSelectorResult & result, AddressInfo & local_address) { @@ -149,7 +155,7 @@ class CommonNodeSelector if (target_hosts.empty()) { LOG_DEBUG(log, "SourcePrune plan segment {} select first worker.", plan_segment_ptr->getPlanSegmentId()); - for (const auto & worker : cluster_nodes.rank_workers) + for (const auto & worker : cluster_nodes.all_workers) { if (worker.address != local_address) { @@ -163,8 +169,8 @@ class CommonNodeSelector LOG_DEBUG(log, "SourcePrune plan segment {} select workers after source prune.", plan_segment_ptr->getPlanSegmentId()); for (size_t idx = 0; idx < cluster_nodes.rank_worker_ids.size(); idx++) { - if (target_hosts.contains(cluster_nodes.rank_hosts[idx])) - result.worker_nodes.emplace_back(cluster_nodes.rank_workers[idx]); + if (target_hosts.contains(cluster_nodes.all_hosts[idx])) + result.worker_nodes.emplace_back(cluster_nodes.all_workers[idx]); } } } @@ -219,7 +225,7 @@ class NodeSelector { } - NodeSelectorResult select(PlanSegment * plan_segment_ptr, bool is_source); + NodeSelectorResult select(PlanSegment * plan_segment_ptr, bool has_table_scan); void setParallelIndexAndSourceAddrs(PlanSegment * plan_segment_ptr, NodeSelectorResult * result); static PlanSegmentInputPtr tryGetLocalInput(PlanSegment * plan_segment_ptr); diff --git a/src/Interpreters/OpenTelemetrySpanLog.cpp b/src/Interpreters/OpenTelemetrySpanLog.cpp index 46c67a8e4e7..058558b21e7 100644 --- a/src/Interpreters/OpenTelemetrySpanLog.cpp +++ b/src/Interpreters/OpenTelemetrySpanLog.cpp @@ -58,7 +58,7 @@ void OpenTelemetrySpanLogElement::appendToBlock(MutableColumns & columns) const columns[i++]->insert(operation_name); columns[i++]->insert(start_time_us); columns[i++]->insert(finish_time_us); - 
columns[i++]->insert(DateLUT::instance().toDayNum(finish_time_us / 1000000).toUnderType()); + columns[i++]->insert(DateLUT::serverTimezoneInstance().toDayNum(finish_time_us / 1000000).toUnderType()); // The user might add some ints values, and we will have Int Field, and the // insert will fail because the column requires Strings. Convert the fields // here, because it's hard to remember to convert them in all other places. diff --git a/src/Interpreters/PartLog.cpp b/src/Interpreters/PartLog.cpp index 55c6c6839b9..dfe58cb2f24 100644 --- a/src/Interpreters/PartLog.cpp +++ b/src/Interpreters/PartLog.cpp @@ -1,5 +1,7 @@ +#include #include #include +#include #include #include #include @@ -11,12 +13,37 @@ #include #include +#include #include #include +#include namespace DB { +void dumpToMapColumn(const std::unordered_map & map, DB::IColumn * column) +{ + auto * column_map = column ? &typeid_cast(*column) : nullptr; + if (!column_map) + return; + + auto & offsets = column_map->getOffsets(); + auto & key_column = column_map->getKey(); + auto & value_column = column_map->getValue(); + + size_t size = 0; + for (auto & entry : map) + { + UInt64 value = entry.second; + + key_column.insertData(entry.first.c_str(), strlen(entry.first.c_str())); + value_column.insert(value); + size++; + } + + offsets.push_back((offsets.size() == 0 ? 
0 : offsets.back()) + size); +} + NamesAndTypesList PartLogElement::getNamesAndTypes() { auto event_type_datatype = std::make_shared( @@ -55,6 +82,8 @@ NamesAndTypesList PartLogElement::getNamesAndTypes() {"rows", std::make_shared()}, {"segments", std::make_shared()}, + {"segments_map", std::make_shared(std::make_shared(), std::make_shared())}, + {"preload_level", std::make_shared()}, {"size_in_bytes", std::make_shared()}, // On disk /// Merge-specific info @@ -76,7 +105,7 @@ void PartLogElement::appendToBlock(MutableColumns & columns) const columns[i++]->insert(query_id); columns[i++]->insert(event_type); - columns[i++]->insert(DateLUT::instance().toDayNum(event_time).toUnderType()); + columns[i++]->insert(DateLUT::serverTimezoneInstance().toDayNum(event_time).toUnderType()); columns[i++]->insert(start_time); columns[i++]->insert(event_time); columns[i++]->insert(event_time_microseconds); @@ -90,7 +119,15 @@ void PartLogElement::appendToBlock(MutableColumns & columns) const columns[i++]->insert(path_on_disk); columns[i++]->insert(rows); - columns[i++]->insert(segments); + columns[i++]->insert(segments_count); + + auto * column = columns[i++].get(); + if (segments.size() > 0) + dumpToMapColumn(segments, column); + else + column->insertDefault(); + + columns[i++]->insert(preload_level); columns[i++]->insert(bytes_compressed_on_disk); Array source_part_names_array; @@ -177,7 +214,7 @@ bool PartLog::addNewParts( return true; } -PartLogElement PartLog::createElement(PartLogElement::Type event_type, const IMergeTreeDataPartPtr & part, UInt64 elapsed_ns, const String & exception, UInt64 submit_ts, UInt64 segments) +PartLogElement PartLog::createElement(PartLogElement::Type event_type, const IMergeTreeDataPartPtr & part, UInt64 elapsed_ns, const String & exception, UInt64 submit_ts, UInt64 segments_count, std::unordered_map segments, UInt64 preload_level) { PartLogElement elem; @@ -192,7 +229,9 @@ PartLogElement PartLog::createElement(PartLogElement::Type event_type, 
const IMe elem.part_name = part->name; elem.rows = part->rows_count; + elem.segments_count = segments_count; elem.segments = segments; + elem.preload_level = preload_level; elem.bytes_compressed_on_disk = part->bytes_on_disk; elem.exception = exception; diff --git a/src/Interpreters/PartLog.h b/src/Interpreters/PartLog.h index 8cc7c98ae3f..dc1997604c1 100644 --- a/src/Interpreters/PartLog.h +++ b/src/Interpreters/PartLog.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include "common/types.h" @@ -39,7 +40,9 @@ struct PartLogElement /// Size of the part UInt64 rows = 0; - UInt64 segments = 0; + UInt64 segments_count = 0; + std::unordered_map segments; + UInt64 preload_level = 0; /// Size of files in filesystem UInt64 bytes_compressed_on_disk = 0; @@ -81,7 +84,7 @@ class PartLog : public SystemLog static bool addNewParts(ContextPtr context, const MutableDataPartsVector & parts, UInt64 elapsed_ns, const ExecutionStatus & execution_status = {}); static PartLogElement createElement(PartLogElement::Type event_type, const IMergeTreeDataPartPtr & part, - UInt64 elapsed_ns = 0, const String & exception = "", UInt64 submit_ts = 0, UInt64 segments = 0); + UInt64 elapsed_ns = 0, const String & exception = "", UInt64 submit_ts = 0, UInt64 segments_count = 0, std::unordered_map segments = {}, UInt64 preload_level = 0); }; } diff --git a/src/Interpreters/PartMergeLog.cpp b/src/Interpreters/PartMergeLog.cpp index feb3e1721f3..c7e9dcaacb2 100644 --- a/src/Interpreters/PartMergeLog.cpp +++ b/src/Interpreters/PartMergeLog.cpp @@ -64,7 +64,7 @@ void PartMergeLogElement::appendToBlock(MutableColumns & columns) const size_t i = 0; columns[i++]->insert(UInt64(event_type)); - columns[i++]->insert(DateLUT::instance().toDayNum(event_time).toUnderType()); + columns[i++]->insert(DateLUT::serverTimezoneInstance().toDayNum(event_time).toUnderType()); columns[i++]->insert(UInt64(event_time)); columns[i++]->insert(database); diff --git 
a/src/Interpreters/PreparedStatement/PreparedStatementCatalog.cpp b/src/Interpreters/PreparedStatement/PreparedStatementCatalog.cpp new file mode 100644 index 00000000000..85a5e92fd0e --- /dev/null +++ b/src/Interpreters/PreparedStatement/PreparedStatementCatalog.cpp @@ -0,0 +1,27 @@ +#include +#include + +namespace DB +{ +void PreparedStatementCatalogManager::updatePreparedStatement(PreparedStatementItemPtr data) +{ + catalog->updatePreparedStatement(data); +} + +PreparedStatements PreparedStatementCatalogManager::getPreparedStatements() +{ + return catalog->getPreparedStatements(); +} + +PreparedStatementItemPtr PreparedStatementCatalogManager::getPreparedStatement(const String & name) +{ + return catalog->getPreparedStatement(name); +} + +void PreparedStatementCatalogManager::removePreparedStatement(const String & name) +{ + catalog->removePreparedStatement(name); +} + + +} diff --git a/src/Interpreters/PreparedStatement/PreparedStatementCatalog.h b/src/Interpreters/PreparedStatement/PreparedStatementCatalog.h new file mode 100644 index 00000000000..2711ffdb742 --- /dev/null +++ b/src/Interpreters/PreparedStatement/PreparedStatementCatalog.h @@ -0,0 +1,46 @@ +#pragma once +#include +#include +#include + +namespace DB +{ + +class PreparedStatementItem +{ + +public: + PreparedStatementItem(String name_, String create_statement_) + : name(std::move(name_)) + , create_statement(std::move(create_statement_)) + { + } + + String name; + String create_statement; +}; + +using PreparedStatementItemPtr = std::shared_ptr; +using PreparedStatements = std::vector; + +class PreparedStatementCatalogManager +{ +public: + explicit PreparedStatementCatalogManager(const ContextPtr & context) + { + catalog = context->getCnchCatalog(); + } + + void updatePreparedStatement(PreparedStatementItemPtr); + + PreparedStatements getPreparedStatements(); + + PreparedStatementItemPtr getPreparedStatement(const String & name); + + void removePreparedStatement(const String & name); + +private: + 
std::shared_ptr catalog; +}; + +} diff --git a/src/Interpreters/PreparedStatement/PreparedStatementManager.cpp b/src/Interpreters/PreparedStatement/PreparedStatementManager.cpp index f4cb8bfe945..31c4ecc365b 100644 --- a/src/Interpreters/PreparedStatement/PreparedStatementManager.cpp +++ b/src/Interpreters/PreparedStatement/PreparedStatementManager.cpp @@ -1,10 +1,14 @@ #include #include +#include #include #include #include #include #include +#include +#include "Interpreters/PreparedStatement/PreparedStatementCatalog.h" +#include "Parsers/IAST_fwd.h" #include @@ -18,7 +22,7 @@ namespace ErrorCodes } void PreparedObject::toProto(Protos::PreparedStatement & proto) const { - proto.set_query(query); + proto.set_query(query->formatForErrorMessage()); } void PreparedStatementManager::initialize(ContextMutablePtr context) @@ -26,17 +30,13 @@ void PreparedStatementManager::initialize(ContextMutablePtr context) if (!context->getPreparedStatementManager()) { auto manager_instance = std::make_unique(); - const auto & config = context->getConfigRef(); - String default_path = fs::path{context->getPath()} / "prepared_statement/"; - String path = config.getString("prepared_statement_path", default_path); - manager_instance->prepared_statement_loader = std::make_unique(path); context->setPreparedStatementManager(std::move(manager_instance)); - loadStatementsFromDisk(context); + loadStatementsFromCatalog(context); } } void PreparedStatementManager::set( - const String & name, PreparedObject prepared_object, bool throw_if_exists, bool or_replace, bool is_persistent) + const String & name, PreparedObject prepared_object, bool throw_if_exists, bool or_replace, bool is_persistent, ContextMutablePtr context) { std::unique_lock lock(mutex); @@ -44,10 +44,12 @@ void PreparedStatementManager::set( { if (is_persistent) { + if (!context) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Context is nullptr"); Protos::PreparedStatement proto; - prepared_object.toProto(proto); - throw_if_exists = 
throw_if_exists && !or_replace; - prepared_statement_loader->storeObject(name, proto, throw_if_exists, or_replace); + PreparedStatementCatalogManager catalog(context); + PreparedStatementItemPtr prepared = std::make_shared(name, prepared_object.query->formatForErrorMessage()); + catalog.updatePreparedStatement(prepared); } cache[name] = std::move(prepared_object); } @@ -67,11 +69,15 @@ SettingsChanges PreparedStatementManager::getSettings(const String & name) const return getUnsafe(name).settings_changes; } -void PreparedStatementManager::remove(const String & name, bool throw_if_not_exists) +void PreparedStatementManager::remove(const String & name, bool throw_if_not_exists, ContextMutablePtr context) { std::unique_lock lock(mutex); + if (context) + { + PreparedStatementCatalogManager catalog(context); + catalog.removePreparedStatement(name); + } - prepared_statement_loader->removeObject(name, false); if (hasUnsafe(name)) cache.erase(name); else if (throw_if_not_exists) @@ -94,7 +100,11 @@ PreparedStatementManager::CacheResultType PreparedStatementManager::getPlanFromC for (auto & [database, table_info] : prepared_object.query_detail->query_access_info) { for (auto & [table, columns] : table_info) - context->addQueryAccessInfo(database, table, columns); + { + auto storage_id = context->tryResolveStorageID(StorageID{database, table}); + context->checkAccess(AccessType::SELECT, storage_id, columns); + context->addQueryAccessInfo(backQuoteIfNeed(storage_id.getDatabaseName()), storage_id.getFullTableName(), columns); + } } } @@ -105,15 +115,12 @@ PreparedStatementManager::CacheResultType PreparedStatementManager::getPlanFromC void PreparedStatementManager::addPlanToCache( const String & name, - const String & query, + ASTPtr & query, SettingsChanges settings_changes, QueryPlanPtr & plan, AnalysisPtr analysis, PreparedParameterSet prepared_params, - ContextMutablePtr & context, - bool throw_if_exists, - bool or_replace, - bool is_persistent) + ContextMutablePtr & 
context) { PlanNodeId max_id; PreparedObject prepared_object{}; @@ -136,12 +143,12 @@ void PreparedStatementManager::addPlanToCache( { for (const auto & column : it->second) prepared_object.query_detail - ->query_access_info[backQuoteIfNeed(storage_id.getDatabaseName())][storage_id.getFullTableName()] + ->query_access_info[storage_id.getDatabaseName()][storage_id.getTableName()] .emplace_back(column); } } - - set(name, std::move(prepared_object), throw_if_exists, or_replace, is_persistent); + const auto & prepare = query->as(); + set(name, std::move(prepared_object), !prepare.if_not_exists, prepare.or_replace, prepare.is_permanent, context); } PlanNodePtr PreparedStatementManager::getNewPlanNode(PlanNodePtr node, ContextMutablePtr & context, bool cache_plan, PlanNodeId & max_id) @@ -205,39 +212,39 @@ void PreparedStatementManager::clearCache() cache.clear(); } -NamesAndPreparedStatements PreparedStatementManager::getAllStatementsFromDisk(ContextMutablePtr & context) -{ - return prepared_statement_loader->getAllObjects(context); -} - -void PreparedStatementManager::loadStatementsFromDisk(ContextMutablePtr & context) +void PreparedStatementManager::loadStatementsFromCatalog(ContextMutablePtr & context) { if (!context->getPreparedStatementManager()) throw Exception("PreparedStatement cache has to be initialized", ErrorCodes::LOGICAL_ERROR); auto * manager = context->getPreparedStatementManager(); manager->clearCache(); - auto statements = manager->getAllStatementsFromDisk(context); + PreparedStatementCatalogManager catalog(context); + auto statements = catalog.getPreparedStatements(); for (auto & statement : statements) { try { ParserCreatePreparedStatementQuery parser(ParserSettings::valueOf(context->getSettingsRef())); - auto ast = parseQuery(parser, statement.second.query(), "", 0, context->getSettings().max_parser_depth); + auto ast = parseQuery(parser, statement->create_statement, "", 0, context->getSettings().max_parser_depth); auto * create_prep_stat = 
ast->as(); if (!create_prep_stat) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid prepread statement query: {}", statement.second.query()); + throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid prepread statement query: {}", statement->create_statement); create_prep_stat->is_permanent = false; - InterpreterSelectQueryUseOptimizer interpreter{ast, context, {}}; + auto query_context = Context::createCopy(context); + query_context->setQueryContext(query_context); + SettingsChanges settings_changes = InterpreterSetQuery::extractSettingsFromQuery(ast, query_context); + query_context->applySettingsChanges(settings_changes); + InterpreterSelectQueryUseOptimizer interpreter{ast, query_context, {}}; interpreter.executeCreatePreparedStatementQuery(); } catch (...) { tryLogWarningCurrentException( &Poco::Logger::get("PreparedStatementManager"), - fmt::format("while build prepared statement {} plan", backQuote(statement.first))); + fmt::format("while build prepared statement {} plan", backQuote(statement->name))); continue; } } diff --git a/src/Interpreters/PreparedStatement/PreparedStatementManager.h b/src/Interpreters/PreparedStatement/PreparedStatementManager.h index 4f0e7f12df0..840a644d928 100644 --- a/src/Interpreters/PreparedStatement/PreparedStatementManager.h +++ b/src/Interpreters/PreparedStatement/PreparedStatementManager.h @@ -29,7 +29,7 @@ struct PreparedObject std::unordered_map>> query_access_info; }; - String query; + ASTPtr query; SettingsChanges settings_changes; PreparedParameterSet prepared_params; std::shared_ptr query_detail; @@ -52,15 +52,14 @@ class PreparedStatementManager PreparedObject prepared_object, bool throw_if_exists = true, bool or_replace = false, - bool is_persistent = true); + bool is_persistent = true, + ContextMutablePtr context = nullptr); PreparedObject getObject(const String & name) const; SettingsChanges getSettings(const String & name) const; - void remove(const String & name, bool throw_if_not_exists); + void remove(const 
String & name, bool throw_if_not_exists, ContextMutablePtr context = nullptr); void clearCache(); Strings getNames() const; bool has(const String & name) const; - NamesAndPreparedStatements getAllStatementsFromDisk(ContextMutablePtr & context); - struct CacheResultType { QueryPlanPtr plan; @@ -72,24 +71,19 @@ class PreparedStatementManager // TODO @wangtao: extract common logic with InterpreterSelectQueryUseOptimizer::addPlanToCache void addPlanToCache( const String & name, - const String & query, + ASTPtr & query, SettingsChanges settings_changes, QueryPlanPtr & plan, AnalysisPtr analysis, PreparedParameterSet prepared_params, - ContextMutablePtr & context, - bool throw_if_exists, - bool or_replace, - bool is_persistent); + ContextMutablePtr & context); - static void loadStatementsFromDisk(ContextMutablePtr & context); + static void loadStatementsFromCatalog(ContextMutablePtr & context); private: CacheType cache; mutable std::shared_mutex mutex; - std::unique_ptr prepared_statement_loader; - bool hasUnsafe(const String & name) const { return cache.contains(name); diff --git a/src/Interpreters/ProcessList.cpp b/src/Interpreters/ProcessList.cpp index e567160154e..e404ac2f9c7 100644 --- a/src/Interpreters/ProcessList.cpp +++ b/src/Interpreters/ProcessList.cpp @@ -168,7 +168,7 @@ static CurrentMetrics::Metric getQueryTypeMetric(ProcessListQueryType & query_ty static bool isMonitoredCnchQuery(const IAST * ast) { if (!ast) - return false; + return true; if (const auto * create_ast = ast->as(); create_ast && create_ast->select) { @@ -178,8 +178,7 @@ static bool isMonitoredCnchQuery(const IAST * ast) return isMonitoredCnchTable(ast_update->database); else if (const auto * ast_delete = ast->as(); ast_delete && !ast_delete->database.empty() && !ast_delete->table.empty()) return isMonitoredCnchTable(ast_delete->database); - else if (const auto * ast_insert = ast->as(); ast_insert && (ast_insert->select || ast_insert->in_file) - && !ast_insert->table_id.database_name.empty() 
&& !ast_insert->table_id.database_name.empty()) + else if (const auto * ast_insert = ast->as(); ast_insert && !ast_insert->table_id.database_name.empty() && !ast_insert->table_id.database_name.empty()) return isMonitoredCnchTable(ast_insert->table_id.database_name); else if (const auto * ast_select = ast->as()) { @@ -378,7 +377,7 @@ void ProcessList::checkRunningQuery(ContextPtr query_context, bool is_unlimited_ * like SELECT count() FROM remote('127.0.0.{1,2}', system.numbers) * so they must have different query_ids. */ - + const ClientInfo & client_info = query_context->getClientInfo(); const Settings & settings = query_context->getSettingsRef(); std::unique_lock lock(mutex); diff --git a/src/Interpreters/ProcessorProfile.cpp b/src/Interpreters/ProcessorProfile.cpp index 7d9f4144ac7..813dd3da7c3 100644 --- a/src/Interpreters/ProcessorProfile.cpp +++ b/src/Interpreters/ProcessorProfile.cpp @@ -1,6 +1,11 @@ +#include +#include #include #include +#include #include +#include +#include "common/types.h" #include namespace DB @@ -124,21 +129,21 @@ void GroupedProcessorProfile::add(ProcessorId processor_id, const ProcessorProfi step_id = profile->step_id; processor_ids.emplace(processor_id); parallel_size += 1; - grouped_elapsed_us = std::max(grouped_elapsed_us, profile->elapsed_us); - grouped_input_wait_elapsed_us = std::max(grouped_input_wait_elapsed_us, profile->input_wait_elapsed_us); - grouped_output_wait_elapsed_us = std::max(grouped_output_wait_elapsed_us, profile->output_wait_elapsed_us); + sum_grouped_elapsed_us += profile->elapsed_us; + sum_grouped_input_wait_elapsed_us += profile->input_wait_elapsed_us; + sum_grouped_output_wait_elapsed_us += profile->output_wait_elapsed_us; grouped_input_rows += profile->input_rows; grouped_input_bytes += profile->input_bytes; grouped_output_rows += profile->output_rows; grouped_output_bytes += profile->output_bytes; worker_cnt = 1; - max_grouped_elapsed_us = grouped_elapsed_us; - min_grouped_elapsed_us = grouped_elapsed_us; 
- max_grouped_input_wait_elapsed_us = grouped_input_wait_elapsed_us; - min_grouped_input_wait_elapsed_us = grouped_input_wait_elapsed_us; - max_grouped_output_wait_elapsed_us = grouped_output_wait_elapsed_us; - min_grouped_output_wait_elapsed_us = grouped_output_wait_elapsed_us; + max_grouped_elapsed_us = std::max(max_grouped_elapsed_us, profile->elapsed_us); + min_grouped_elapsed_us = std::min(min_grouped_elapsed_us, profile->elapsed_us); + max_grouped_input_wait_elapsed_us = std::max(max_grouped_input_wait_elapsed_us, profile->input_wait_elapsed_us); + min_grouped_input_wait_elapsed_us = std::min(min_grouped_input_wait_elapsed_us, profile->input_wait_elapsed_us); + max_grouped_output_wait_elapsed_us = std::max(max_grouped_output_wait_elapsed_us, profile->output_wait_elapsed_us); + min_grouped_output_wait_elapsed_us = std::min(min_grouped_output_wait_elapsed_us, profile->output_wait_elapsed_us); } std::set GroupedProcessorProfile::fillChildren(GroupedProcessorProfilePtr & input_processor, std::set & visited) @@ -162,6 +167,8 @@ std::set GroupedProcessorProfile::fillChildren(Group GroupedProcessorProfilePtr GroupedProcessorProfile::getOutputRoot(GroupedProcessorProfilePtr & input_root) { + if (input_root->processor_name == "output_root") + return input_root; std::set visited; std::set outputs; outputs = fillChildren(input_root, visited); @@ -170,26 +177,61 @@ GroupedProcessorProfilePtr GroupedProcessorProfile::getOutputRoot(GroupedProcess return output_root; } -SegmentAndWorkerToGroupedProfile GroupedProcessorProfile::aggregateProfileBetweenWorkers(SegmentAndWorkerToGroupedProfile & worker_grouped_profiles) +UInt128 GroupedProcessorProfile::getPipelineProfilehash(GroupedProcessorProfilePtr & node) { - SegmentAndWorkerToGroupedProfile res; + SipHash hash; + UInt128 key{}; + GroupedProcessorProfiles profiles; + profiles.push_back(node); + while (!profiles.empty()) + { + auto profile = profiles.back(); + profiles.pop_back(); + if 
(profile->processor_name.starts_with("MergeTree")) + hash.update("MergeTree"); + else if (profile->processor_name.starts_with("MultiPathReceiver")) + hash.update("MultiPathReceiver"); + else + hash.update(profile->processor_name); + for (auto & child : profile->children) + profiles.push_back(child); + } + hash.get128(key); + return key; +} + +SegIdAndAddrToPipelineProfile +GroupedProcessorProfile::aggregatePipelineProfileBetweenWorkers(SegIdAndAddrToPipelineProfile & worker_grouped_profiles) +{ + SegIdAndAddrToPipelineProfile res; for (auto [segment, woker_profile_map] : worker_grouped_profiles) { - String workers_ip_list_str = "["; - GroupedProcessorProfilePtr aggregate_profile = nullptr; + std::unordered_map> hash_to_profiles; + for (auto & [worker_ip, profile] : woker_profile_map) { - if (!aggregate_profile) + auto pipeline_hash = getPipelineProfilehash(profile); + hash_to_profiles[pipeline_hash].emplace(worker_ip, profile); + } + + for (auto [hash, woker_profile] : hash_to_profiles) + { + String workers_ip_list_str = "["; + GroupedProcessorProfilePtr aggregate_profile = nullptr; + for (auto & [worker_ip, profile] : woker_profile) { - workers_ip_list_str = workers_ip_list_str + worker_ip; - aggregate_profile = profile; - continue; + if (!aggregate_profile) + { + workers_ip_list_str = workers_ip_list_str + worker_ip; + aggregate_profile = profile; + continue; + } + workers_ip_list_str = workers_ip_list_str + "," + worker_ip; + aggregate_profile->addProfileRecursively(profile); } - workers_ip_list_str = workers_ip_list_str + "," + worker_ip; - aggregate_profile->addProfileRecursively(profile); + workers_ip_list_str = workers_ip_list_str + "]"; + res[segment][workers_ip_list_str] = aggregate_profile; } - workers_ip_list_str = workers_ip_list_str + "]"; - res[segment][workers_ip_list_str] = aggregate_profile; } return res; } @@ -201,15 +243,15 @@ void GroupedProcessorProfile::addProfileRecursively(GroupedProcessorProfilePtr & parallel_size += profile->parallel_size; 
worker_cnt++; - grouped_elapsed_us += profile->grouped_elapsed_us; - max_grouped_elapsed_us = std::max(max_grouped_elapsed_us, profile->grouped_elapsed_us); - min_grouped_elapsed_us = std::min(min_grouped_elapsed_us, profile->grouped_elapsed_us); - grouped_input_wait_elapsed_us += profile->grouped_input_wait_elapsed_us; - max_grouped_input_wait_elapsed_us = std::max(max_grouped_input_wait_elapsed_us, profile->grouped_input_wait_elapsed_us); - min_grouped_input_wait_elapsed_us = std::min(min_grouped_input_wait_elapsed_us, profile->grouped_input_wait_elapsed_us); - grouped_output_wait_elapsed_us += profile->grouped_output_wait_elapsed_us; - max_grouped_output_wait_elapsed_us = std::max(max_grouped_output_wait_elapsed_us, profile->grouped_output_wait_elapsed_us); - min_grouped_output_wait_elapsed_us = std::min(min_grouped_output_wait_elapsed_us, profile->grouped_output_wait_elapsed_us); + sum_grouped_elapsed_us += profile->sum_grouped_elapsed_us; + max_grouped_elapsed_us = std::max(max_grouped_elapsed_us, profile->max_grouped_elapsed_us); + min_grouped_elapsed_us = std::min(min_grouped_elapsed_us, profile->min_grouped_elapsed_us); + sum_grouped_input_wait_elapsed_us += profile->sum_grouped_input_wait_elapsed_us; + max_grouped_input_wait_elapsed_us = std::max(max_grouped_input_wait_elapsed_us, profile->max_grouped_input_wait_elapsed_us); + min_grouped_input_wait_elapsed_us = std::min(min_grouped_input_wait_elapsed_us, profile->min_grouped_input_wait_elapsed_us); + sum_grouped_output_wait_elapsed_us += profile->sum_grouped_output_wait_elapsed_us; + max_grouped_output_wait_elapsed_us = std::max(max_grouped_output_wait_elapsed_us, profile->max_grouped_output_wait_elapsed_us); + min_grouped_output_wait_elapsed_us = std::min(min_grouped_output_wait_elapsed_us, profile->min_grouped_output_wait_elapsed_us); grouped_input_rows += profile->grouped_input_rows; grouped_input_bytes += profile->grouped_input_bytes; grouped_output_rows += profile->grouped_output_rows; @@ -228,13 
+270,13 @@ Poco::JSON::Object::Ptr GroupedProcessorProfile::getJsonProfiles() json->set("ProcessorName", processor_name); json->set("StepId", step_id); json->set("ParallelSize", parallel_size); - json->set("ElapsedUs", UInt64(grouped_elapsed_us/worker_cnt)); + json->set("ElapsedUs", UInt64(sum_grouped_elapsed_us / parallel_size)); json->set("MaxElapsedUs", max_grouped_elapsed_us); json->set("MinElapsedUs", min_grouped_elapsed_us); - json->set("InputWaitElapsedUs", UInt64(grouped_input_wait_elapsed_us/worker_cnt)); + json->set("InputWaitElapsedUs", UInt64(sum_grouped_input_wait_elapsed_us / parallel_size)); json->set("MaxInputWaitElapsedUs", max_grouped_input_wait_elapsed_us); json->set("MinInputWaitElapsedUs", min_grouped_input_wait_elapsed_us); - json->set("OutputWaitElapsedUs", UInt64(grouped_output_wait_elapsed_us/worker_cnt)); + json->set("OutputWaitElapsedUs", UInt64(sum_grouped_output_wait_elapsed_us / parallel_size)); json->set("MaxOutputWaitElapsedUs", max_grouped_output_wait_elapsed_us); json->set("MinOutputWaitElapsedUs", min_grouped_output_wait_elapsed_us); json->set("InputRows", grouped_input_rows); @@ -250,9 +292,102 @@ Poco::JSON::Object::Ptr GroupedProcessorProfile::getJsonProfiles() return json; } -StepsOperatorProfiles StepOperatorProfile::aggregateOperatorProfileToStepLevel(std::unordered_map> & segment_profile_tree) +std::unordered_map +GroupedProcessorProfile::getProfileMetricsFromOutputRoot(GroupedProcessorProfilePtr & output_root) { - StepsOperatorProfiles res; + if (output_root->processor_name == "input_root") + output_root = GroupedProcessorProfile::getOutputRoot(output_root); + std::unordered_map res; + GroupedProcessorProfiles grouped_profiles; + grouped_profiles.push_back(output_root); + while (!grouped_profiles.empty()) + { + auto node = grouped_profiles.back(); + grouped_profiles.pop_back(); + ProfileMetricPtr profile = std::make_shared(); + profile->id = node->id; + profile->name = node->processor_name; + profile->parallel_size = 
node->parallel_size; + for (auto & child : node->children) + { + profile->children_ids.emplace_back(child->id); + grouped_profiles.push_back(child); + } + + profile->sum_elapsed_us = node->sum_grouped_elapsed_us; + profile->max_elapsed_us = node->max_grouped_elapsed_us; + profile->min_elapsed_us = node->min_grouped_elapsed_us; + profile->output_rows = node->grouped_output_rows; + profile->output_bytes = node->grouped_output_bytes; + profile->output_wait_sum_elapsed_us = node->sum_grouped_output_wait_elapsed_us; + profile->output_wait_max_elapsed_us = node->max_grouped_output_wait_elapsed_us; + profile->output_wait_min_elapsed_us = node->min_grouped_output_wait_elapsed_us; + InputProfileMetric input; + input.id = 0; + input.input_rows = node->grouped_input_rows; + input.input_bytes = node->grouped_input_bytes; + input.input_wait_sum_elapsed_us = node->sum_grouped_input_wait_elapsed_us; + input.input_wait_max_elapsed_us = node->max_grouped_input_wait_elapsed_us; + input.input_wait_min_elapsed_us = node->min_grouped_input_wait_elapsed_us; + profile->inputs.emplace(0, input); + + res[profile->id] = profile; + } + return res; +} + +GroupedProcessorProfilePtr +GroupedProcessorProfile::getGroupedProfileFromMetrics(std::unordered_map & profile_map, UInt64 root_id) +{ + if (!profile_map.contains(root_id)) + return nullptr; + + auto profile = profile_map.at(root_id); + GroupedProcessorProfilePtr node = std::make_shared(); + node->id = root_id; + node->processor_name = profile->name; + node->parallel_size = profile->parallel_size; + node->grouped_output_rows = profile->output_rows; + node->grouped_output_bytes = profile->output_bytes; + node->sum_grouped_elapsed_us = profile->sum_elapsed_us; + node->max_grouped_elapsed_us = profile->max_elapsed_us; + node->min_grouped_elapsed_us = profile->min_elapsed_us; + + node->sum_grouped_output_wait_elapsed_us = profile->output_wait_sum_elapsed_us; + node->max_grouped_output_wait_elapsed_us = profile->output_wait_max_elapsed_us; + 
node->min_grouped_output_wait_elapsed_us = profile->output_wait_min_elapsed_us; + if (!profile->inputs.empty()) + { + auto & input_profile = profile->inputs[0]; + node->id = input_profile.id; + node->grouped_input_rows = input_profile.input_rows; + node->grouped_input_bytes = input_profile.input_bytes; + node->sum_grouped_input_wait_elapsed_us = input_profile.input_wait_sum_elapsed_us; + node->max_grouped_input_wait_elapsed_us = input_profile.input_wait_max_elapsed_us; + node->min_grouped_input_wait_elapsed_us = input_profile.input_wait_min_elapsed_us; + } + + node->worker_cnt = 1; + + for (auto & child_id : profile->children_ids) + { + if (profile_map.contains(child_id)) + { + auto child = getGroupedProfileFromMetrics(profile_map, child_id); + child->parents.emplace(node->processor_name, node); + node->children.emplace_back(child); + } + } + return node; +} + + +StepProfiles GroupedProcessorProfile::aggregateOperatorProfileToStepLevel(GroupedProcessorProfilePtr & processor_profile_root) +{ + if (processor_profile_root->processor_name == "input_root") + processor_profile_root = GroupedProcessorProfile::getOutputRoot(processor_profile_root); + + StepProfiles res; struct ProfilesList { @@ -262,128 +397,89 @@ StepsOperatorProfiles StepOperatorProfile::aggregateOperatorProfileToStepLevel(s std::unordered_map profiles_at_each_level; }; - for (auto & [segment_id, processor_profile_roots] : segment_profile_tree) + /// step_id -> map + std::unordered_map step_processor_profiles_at_each_level; + + size_t level = 0; + std::queue q; + std::unordered_set id_set; + q.push(processor_profile_root); + id_set.emplace(processor_profile_root->id); + while (!q.empty()) { - for (auto & processor_profile_root : processor_profile_roots) + size_t size = q.size(); + for (size_t i = 0; i < size; i++) { - /// step_id -> map - std::unordered_map step_processor_profiles_at_each_level; - - size_t level = 0; - std::queue q; - std::unordered_set id_set; - q.push(processor_profile_root); - 
id_set.emplace(processor_profile_root->id); - while (!q.empty()) - { - size_t size = q.size(); - for (size_t i = 0; i < size; i++) - { - auto processor_profile = q.front(); - q.pop(); - auto & current_step_id = processor_profile->step_id; - auto & inputs = processor_profile->children; - auto & outputs = processor_profile->parents; + auto processor_profile = q.front(); + q.pop(); + auto & current_step_id = processor_profile->step_id; + auto & inputs = processor_profile->children; + auto & outputs = processor_profile->parents; - if (current_step_id == -1 && !outputs.empty() && processor_profile->processor_name != "output_root") - current_step_id = outputs.begin()->second->step_id; + if (current_step_id == -1 && !outputs.empty() && processor_profile->processor_name != "output_root") + current_step_id = outputs.begin()->second->step_id; - step_processor_profiles_at_each_level[current_step_id].profiles_at_each_level[level].push_back(processor_profile); + step_processor_profiles_at_each_level[current_step_id].profiles_at_each_level[level].push_back(processor_profile); - if (outputs.empty()) - step_processor_profiles_at_each_level[current_step_id].output_profiles.push_back(processor_profile); - - if (inputs.empty()) - step_processor_profiles_at_each_level[current_step_id].input_profiles[current_step_id] = processor_profile; + if (outputs.empty()) + step_processor_profiles_at_each_level[current_step_id].output_profiles.push_back(processor_profile); - for (auto & input_profile : inputs) - { - if (input_profile->step_id != -1 && current_step_id != input_profile->step_id) - { - step_processor_profiles_at_each_level[current_step_id].input_profiles[input_profile->step_id] = processor_profile; - step_processor_profiles_at_each_level[input_profile->step_id].output_profiles.push_back(input_profile); - } - if (!id_set.contains(input_profile->id)) - { - q.push(input_profile); - id_set.emplace(input_profile->id); - } - } - } - level++; - } + if (inputs.empty()) + 
step_processor_profiles_at_each_level[current_step_id].input_profiles[current_step_id] = processor_profile; - for (auto & [step_id, profiles_list] : step_processor_profiles_at_each_level) + for (auto & input_profile : inputs) { - auto step_profile = std::make_shared(); - - for (auto & output_profile : profiles_list.output_profiles) - { - step_profile->output_bytes += output_profile->grouped_output_bytes; - step_profile->output_rows += output_profile->grouped_output_rows; - step_profile->output_wait_elapsed_us = std::max(step_profile->output_wait_elapsed_us, output_profile->grouped_output_wait_elapsed_us); - } - - for (auto & [input_step_id, input_profile] : profiles_list.input_profiles) + if (input_profile->step_id != -1 && current_step_id != input_profile->step_id) { - step_profile->inputs_profile[input_step_id].input_rows = input_profile->grouped_input_rows; - step_profile->inputs_profile[input_step_id].input_bytes = input_profile->grouped_input_bytes; - step_profile->inputs_profile[input_step_id].input_wait_elapsed_us = input_profile->grouped_input_wait_elapsed_us; + step_processor_profiles_at_each_level[current_step_id].input_profiles[input_profile->step_id] = processor_profile; + step_processor_profiles_at_each_level[input_profile->step_id].output_profiles.push_back(input_profile); } - - for (auto & [_, level_profiles] : profiles_list.profiles_at_each_level) + if (!id_set.contains(input_profile->id)) { - UInt64 sum_elapsed_us = 0; - for (auto & profile : level_profiles) - sum_elapsed_us = std::max(sum_elapsed_us, profile->grouped_elapsed_us); - step_profile->sum_elapsed_us += sum_elapsed_us; + q.push(input_profile); + id_set.emplace(input_profile->id); } - res[step_id].push_back(step_profile); } } + level++; } - return res; -} - -StepAggregatedOperatorProfiles -AggregatedStepOperatorProfile::aggregateStepOperatorProfileBetweenWorkers(StepsOperatorProfiles & steps_operator_profiles) -{ - StepAggregatedOperatorProfiles res; - for (auto & [step_id, step_profiles] 
: steps_operator_profiles) + for (auto & [step_id, profiles_list] : step_processor_profiles_at_each_level) { - if (step_profiles.empty()) + if (step_id == -1) continue; - auto agg_profile_ptr = std::make_shared(); - agg_profile_ptr->step_id = step_id; - std::unordered_map inputs_profile; - for (auto & input : step_profiles[0]->inputs_profile) - inputs_profile[input.first] = {}; + auto step_profile = std::make_shared(); - for (auto & step_profile : step_profiles) + for (auto & output_profile : profiles_list.output_profiles) { - agg_profile_ptr->max_elapsed_us = std::max(agg_profile_ptr->max_elapsed_us, step_profile->sum_elapsed_us); - agg_profile_ptr->min_elapsed_us = std::min(agg_profile_ptr->min_elapsed_us, step_profile->sum_elapsed_us); - agg_profile_ptr->sum_elapsed_us += step_profile->sum_elapsed_us; - agg_profile_ptr->worker_cnt++; - agg_profile_ptr->max_output_wait_elapsed_us = std::max(agg_profile_ptr->max_output_wait_elapsed_us, step_profile->output_wait_elapsed_us); - agg_profile_ptr->min_output_wait_elapsed_us = std::min(agg_profile_ptr->min_output_wait_elapsed_us, step_profile->output_wait_elapsed_us); - agg_profile_ptr->sum_output_wait_elapsed_us += step_profile->output_wait_elapsed_us; - agg_profile_ptr->output_rows += step_profile->output_rows; - agg_profile_ptr->output_bytes += step_profile->output_bytes; - - for (auto & [id, input_profile] : step_profile->inputs_profile) - { - inputs_profile[id].input_wait_elapsed_us += input_profile.input_wait_elapsed_us; - inputs_profile[id].max_input_wait_elapsed_us = std::max(inputs_profile[id].max_input_wait_elapsed_us, input_profile.input_wait_elapsed_us); - inputs_profile[id].min_input_wait_elapsed_us = std::min(inputs_profile[id].min_input_wait_elapsed_us, input_profile.input_wait_elapsed_us); - inputs_profile[id].input_rows += input_profile.input_rows; - inputs_profile[id].input_bytes += input_profile.input_bytes; - } + step_profile->output_bytes += output_profile->grouped_output_bytes; + 
step_profile->output_rows += output_profile->grouped_output_rows; + step_profile->output_wait_sum_elapsed_us += output_profile->max_grouped_output_wait_elapsed_us; + step_profile->output_wait_max_elapsed_us + = std::max(step_profile->output_wait_max_elapsed_us, output_profile->max_grouped_output_wait_elapsed_us); + step_profile->output_wait_min_elapsed_us + = std::min(step_profile->output_wait_min_elapsed_us, output_profile->min_grouped_output_wait_elapsed_us); } - agg_profile_ptr->inputs_profile = std::move(inputs_profile); - res[step_id] = agg_profile_ptr; + for (auto & [input_step_id, input_profile] : profiles_list.input_profiles) + { + step_profile->inputs[input_step_id].id = input_step_id; + step_profile->inputs[input_step_id].input_rows = input_profile->grouped_input_rows; + step_profile->inputs[input_step_id].input_bytes = input_profile->grouped_input_bytes; + step_profile->inputs[input_step_id].input_wait_sum_elapsed_us += input_profile->max_grouped_output_wait_elapsed_us; + step_profile->inputs[input_step_id].input_wait_max_elapsed_us = std::max( + step_profile->inputs[input_step_id].input_wait_max_elapsed_us, input_profile->max_grouped_output_wait_elapsed_us); + step_profile->inputs[input_step_id].input_wait_min_elapsed_us = std::min( + step_profile->inputs[input_step_id].input_wait_min_elapsed_us, input_profile->min_grouped_output_wait_elapsed_us); + } + + for (auto & [_, level_profiles] : profiles_list.profiles_at_each_level) + { + for (auto & profile : level_profiles) + step_profile->sum_elapsed_us += profile->max_grouped_elapsed_us; + } + step_profile->id = step_id; + res[step_id] = step_profile; } return res; } diff --git a/src/Interpreters/ProcessorProfile.h b/src/Interpreters/ProcessorProfile.h index 8f2006035df..4133ef303f3 100644 --- a/src/Interpreters/ProcessorProfile.h +++ b/src/Interpreters/ProcessorProfile.h @@ -1,12 +1,17 @@ #pragma once -#include -#include #include +#include +#include #include +#include +#include +#include #include #include 
+#include +#include namespace DB { @@ -48,7 +53,13 @@ struct ProcessorProfile struct GroupedProcessorProfile; using GroupedProcessorProfilePtr = std::shared_ptr; using GroupedProcessorProfiles = std::vector; -using SegmentAndWorkerToGroupedProfile = std::unordered_map>; +using SegIdAndAddrToPipelineProfile = std::unordered_map>; +struct ProfileMetric; +using ProfileMetricPtr = std::shared_ptr; +using ProfileMetrics = std::vector; +using IdToProfileMetrics = std::unordered_map; +using StepProfiles = std::unordered_map; // step_id -> aggregated profile +using AddressToStepProfile = std::unordered_map>; // step_id -> aggregated profile struct GroupedProcessorProfile { @@ -56,9 +67,9 @@ struct GroupedProcessorProfile String processor_name; int64_t step_id = -1; - UInt64 grouped_elapsed_us{}; - UInt64 grouped_input_wait_elapsed_us{}; - UInt64 grouped_output_wait_elapsed_us{}; + UInt64 sum_grouped_elapsed_us{}; + UInt64 sum_grouped_input_wait_elapsed_us{}; + UInt64 sum_grouped_output_wait_elapsed_us{}; UInt64 grouped_input_rows{}; UInt64 grouped_input_bytes{}; @@ -86,64 +97,16 @@ struct GroupedProcessorProfile static GroupedProcessorProfilePtr getOutputRoot(GroupedProcessorProfilePtr & input_root); void add(ProcessorId processor_id, const ProcessorProfilePtr & profile); - static SegmentAndWorkerToGroupedProfile aggregateProfileBetweenWorkers(SegmentAndWorkerToGroupedProfile & worker_grouped_profiles); + static SegIdAndAddrToPipelineProfile aggregatePipelineProfileBetweenWorkers(SegIdAndAddrToPipelineProfile & worker_grouped_profiles); + static UInt128 getPipelineProfilehash(GroupedProcessorProfilePtr & node); void addProfileRecursively(GroupedProcessorProfilePtr & profile); - Poco::JSON::Object::Ptr getJsonProfiles(); -}; - - -struct StepOperatorProfile; -using StepOperatorProfilePtr = std::shared_ptr; -using StepOperatorProfiles = std::vector; -using StepsOperatorProfiles = std::unordered_map; // step_id -> step_operator_profile - -struct InputProfile -{ - UInt64 
input_wait_elapsed_us; // sum_input_wait_elapsed_us in AggregatedStepOperatorProfile - UInt64 max_input_wait_elapsed_us{}; - UInt64 min_input_wait_elapsed_us{UINT64_MAX}; - UInt64 input_rows; - UInt64 input_bytes; -}; - -struct StepOperatorProfile -{ - int64_t step_id = -1; - UInt64 sum_elapsed_us{}; - - /// input step_id -> (input_wait_elapsed_us, input_rows, input_bytes) - std::unordered_map inputs_profile; - - UInt64 output_wait_elapsed_us{}; - UInt64 output_rows{}; - UInt64 output_bytes{}; + static std::unordered_map getProfileMetricsFromOutputRoot(GroupedProcessorProfilePtr & output_root); + static GroupedProcessorProfilePtr + getGroupedProfileFromMetrics(std::unordered_map & profile_map, UInt64 root_id); - static StepsOperatorProfiles aggregateOperatorProfileToStepLevel(std::unordered_map> & segment_profile_tree); -}; - -struct AggregatedStepOperatorProfile; -using AggregatedStepOperatorProfilePtr = std::shared_ptr; -using StepAggregatedOperatorProfiles = std::unordered_map; // step_id -> aggregated profile - -struct AggregatedStepOperatorProfile -{ - size_t step_id{}; - UInt64 max_elapsed_us{}; - UInt64 min_elapsed_us{UINT64_MAX}; - UInt64 sum_elapsed_us{}; - UInt64 worker_cnt = 0; - - std::unordered_map inputs_profile; - - UInt64 max_output_wait_elapsed_us{}; - UInt64 min_output_wait_elapsed_us{UINT64_MAX}; - UInt64 sum_output_wait_elapsed_us{}; - UInt64 output_rows{}; - UInt64 output_bytes{}; - - static StepAggregatedOperatorProfiles aggregateStepOperatorProfileBetweenWorkers(StepsOperatorProfiles & steps_operator_profiles); - String toJSONString(size_t indent = 0) const; + static StepProfiles aggregateOperatorProfileToStepLevel(GroupedProcessorProfilePtr & processor_profile_root); + Poco::JSON::Object::Ptr getJsonProfiles(); }; } diff --git a/src/Interpreters/ProcessorsProfileLog.cpp b/src/Interpreters/ProcessorsProfileLog.cpp index 5d44b61e064..5fa5b4a1047 100644 --- a/src/Interpreters/ProcessorsProfileLog.cpp +++ 
b/src/Interpreters/ProcessorsProfileLog.cpp @@ -60,7 +60,7 @@ void ProcessorProfileLogElement::appendToBlock(MutableColumns & columns) const { size_t i = 0; - columns[i++]->insert(DateLUT::instance().toDayNum(event_time).toUnderType()); + columns[i++]->insert(DateLUT::serverTimezoneInstance().toDayNum(event_time).toUnderType()); columns[i++]->insert(event_time); columns[i++]->insert(event_time_microseconds); diff --git a/src/Interpreters/QueryExchangeLog.cpp b/src/Interpreters/QueryExchangeLog.cpp index 97a4d636ed0..5281d88bf60 100644 --- a/src/Interpreters/QueryExchangeLog.cpp +++ b/src/Interpreters/QueryExchangeLog.cpp @@ -85,7 +85,7 @@ void QueryExchangeLogElement::appendToBlock(MutableColumns & columns) const { size_t i = 0; columns[i++]->insert(initial_query_id); - columns[i++]->insert(DateLUT::instance().toDayNum(event_time).toUnderType()); + columns[i++]->insert(DateLUT::serverTimezoneInstance().toDayNum(event_time).toUnderType()); columns[i++]->insert(event_time); columns[i++]->insertData(type.data(), type.size()); columns[i++]->insert(exchange_id); diff --git a/src/Interpreters/QueryLog.cpp b/src/Interpreters/QueryLog.cpp index 7fc7ac8ce3d..40aa8b93f3e 100644 --- a/src/Interpreters/QueryLog.cpp +++ b/src/Interpreters/QueryLog.cpp @@ -150,7 +150,8 @@ NamesAndTypesList QueryLogElement::getNamesAndTypes() {"fallback_reason", std::make_shared()}, {"segment_profiles", std::make_shared(std::make_shared())}, {"virtual_warehouse", std::make_shared()}, - {"worker_group", std::make_shared()} + {"worker_group", std::make_shared()}, + {"query_plan", std::make_shared()} }; } @@ -173,7 +174,7 @@ void QueryLogElement::appendToBlock(MutableColumns & columns) const size_t i = 0; columns[i++]->insert(type); - columns[i++]->insert(DateLUT::instance().toDayNum(event_time).toUnderType()); + columns[i++]->insert(DateLUT::serverTimezoneInstance().toDayNum(event_time).toUnderType()); columns[i++]->insert(event_time); columns[i++]->insert(event_time_microseconds); 
columns[i++]->insert(query_start_time); @@ -353,6 +354,7 @@ void QueryLogElement::appendToBlock(MutableColumns & columns) const columns[i++]->insert(virtual_warehouse); columns[i++]->insert(worker_group); + columns[i++]->insert(query_plan); } void QueryLogElement::appendClientInfo(const ClientInfo & client_info, MutableColumns & columns, size_t & i) diff --git a/src/Interpreters/QueryLog.h b/src/Interpreters/QueryLog.h index 457551a180b..f3abe3cab7e 100644 --- a/src/Interpreters/QueryLog.h +++ b/src/Interpreters/QueryLog.h @@ -104,6 +104,7 @@ struct QueryLogElement String virtual_warehouse; String worker_group; + String query_plan; static std::string name() { return "QueryLog"; } diff --git a/src/Interpreters/QueryThreadLog.cpp b/src/Interpreters/QueryThreadLog.cpp index 647a1d89c7a..55d369e4ebf 100644 --- a/src/Interpreters/QueryThreadLog.cpp +++ b/src/Interpreters/QueryThreadLog.cpp @@ -109,7 +109,7 @@ void QueryThreadLogElement::appendToBlock(MutableColumns & columns) const { size_t i = 0; - columns[i++]->insert(DateLUT::instance().toDayNum(event_time).toUnderType()); + columns[i++]->insert(DateLUT::serverTimezoneInstance().toDayNum(event_time).toUnderType()); columns[i++]->insert(event_time); columns[i++]->insert(event_time_microseconds); columns[i++]->insert(query_start_time); diff --git a/src/Interpreters/RemoteReadLog.cpp b/src/Interpreters/RemoteReadLog.cpp index 7b812b98358..21a7a025117 100644 --- a/src/Interpreters/RemoteReadLog.cpp +++ b/src/Interpreters/RemoteReadLog.cpp @@ -35,7 +35,7 @@ void RemoteReadLogElement::appendToBlock(MutableColumns & columns) const { size_t i = 0; - columns[i++]->insert(DateLUT::instance().toDayNum(event_time).toUnderType()); + columns[i++]->insert(DateLUT::sessionInstance().toDayNum(event_time).toUnderType()); columns[i++]->insert(event_time); columns[i++]->insert(request_time_microseconds); columns[i++]->insert(context); diff --git a/src/Interpreters/RequiredSourceColumnsVisitor.cpp 
b/src/Interpreters/RequiredSourceColumnsVisitor.cpp index c0a9ed7b427..82626ded122 100644 --- a/src/Interpreters/RequiredSourceColumnsVisitor.cpp +++ b/src/Interpreters/RequiredSourceColumnsVisitor.cpp @@ -266,7 +266,7 @@ bool NoBitmapIndexRequiredSourceColumnsMatcher::needChildVisit(ASTPtr & node, co { /// "indexHint" is a special function for index analysis. Everything that is inside it is not calculated. @sa KeyCondition /// "lambda" visit children itself. - if (f->name == "indexHint" || f->name == "lambda" || BitmapIndexHelper::isArraySetFunctions(f->name)) + if (f->name == "indexHint" || f->name == "lambda" || BitmapIndexHelper::isValidBitMapFunctions(node)) return false; } diff --git a/src/Interpreters/SegmentScheduler.cpp b/src/Interpreters/SegmentScheduler.cpp index 87f1da5ae9d..5bf0a7e4f11 100644 --- a/src/Interpreters/SegmentScheduler.cpp +++ b/src/Interpreters/SegmentScheduler.cpp @@ -14,6 +14,7 @@ */ #include +#include #include #include #include @@ -31,7 +32,7 @@ #include #include #include -#include +#include "Interpreters/ProcessorProfile.h" namespace ProfileEvents { @@ -92,7 +93,14 @@ SegmentScheduler::insertPlanSegments(const String & query_id, PlanSegmentTree * { server_resource->sendResources(query_context); } - + + } + { + if (query_context->isExplainQuery() && query_context->getSettingsRef().report_segment_profiles) + { + std::unique_lock lock(segment_profile_mutex); + segment_profile_map[query_id]; + } } auto * final_segment = plan_segments_ptr->getRoot()->getPlanSegment(); @@ -132,6 +140,30 @@ SegmentScheduler::insertPlanSegments(const String & query_id, PlanSegmentTree * return dag_ptr->plan_segment_status_ptr; } +static void OnCancelQueryCallback( + Protos::CancelQueryResponse * response, brpc::Controller * cntl, std::shared_ptr rpc_client, String query_id) +{ + static auto * log = &Poco::Logger::get("SegmentScheduler"); + + std::unique_ptr cntl_guard(cntl); + std::unique_ptr response_guard(response); + + 
rpc_client->checkAliveWithController(*cntl); + if (cntl->Failed()) + { + LOG_TRACE( + log, + "Send cancel query with id {} to {} failed, error: {}, msg: {}", + query_id, + butil::endpoint2str(cntl->remote_side()).c_str(), + cntl->ErrorText(), + response->message()); + } + else + { + LOG_TRACE(log, "Send cancel query with id {} to {} success", query_id, butil::endpoint2str(cntl->remote_side()).c_str()); + } +} CancellationCode SegmentScheduler::cancelPlanSegmentsFromCoordinator( const String & query_id, const Int32 & code, const String & exception, ContextPtr query_context) @@ -189,13 +221,16 @@ void SegmentScheduler::cancelWorkerPlanSegments(const String & query_id, const D { String coordinator_addr = query_context->getHostWithPorts().getExchangeAddress(); std::vector call_ids; - call_ids.reserve(dag_ptr->plan_send_addresses.size()); - auto handler = std::make_shared(); + std::set plan_send_addresses; + { + std::unique_lock lock(dag_ptr->status_mutex); + plan_send_addresses = dag_ptr->plan_send_addresses; + } Protos::CancelQueryRequest request; request.set_query_id(query_id); request.set_coordinator_address(coordinator_addr); - for (const auto & addr : dag_ptr->plan_send_addresses) + for (const auto & addr : plan_send_addresses) { auto address = extractExchangeHostPort(addr); std::shared_ptr rpc_client = RpcChannelPool::getInstance().getClient(address, BrpcChannelPoolOptions::DEFAULT_CONFIG_KEY); @@ -205,29 +240,9 @@ void SegmentScheduler::cancelWorkerPlanSegments(const String & query_id, const D Protos::CancelQueryResponse * response = new Protos::CancelQueryResponse(); request.set_query_id(query_id); request.set_coordinator_address(coordinator_addr); - manager.cancelQuery(cntl, &request, response, brpc::NewCallback(RPCHelpers::onAsyncCallDone, response, cntl, handler)); - LOG_INFO( - log, - "Cancel plan segment query_id-{} on host-{}", - query_id, - extractExchangeHostPort(addr)); - } - - if (query_context->getSettingsRef().enable_wait_cancel_rpc) - { - for (auto 
& call_id : call_ids) - brpc::Join(call_id); - - try - { - handler->throwIfException(); - } - catch (...) - { - tryLogCurrentException(log, "cancelWorkerPlanSegments"); - } + manager.cancelQuery(cntl, &request, response, brpc::NewCallback(OnCancelQueryCallback, response, cntl, rpc_client, query_id)); + LOG_INFO(log, "Cancel plan segment query_id-{} on host-{}", query_id, extractExchangeHostPort(addr)); } - } bool SegmentScheduler::finishPlanSegments(const String & query_id) @@ -244,8 +259,15 @@ bool SegmentScheduler::finishPlanSegments(const String & query_id) } { - std::unique_lock lock(segment_status_mutex); + std::unique_lock lock(segment_profile_mutex); + auto seg_profile_map_ite = segment_profile_map.find(query_id); + if (seg_profile_map_ite != segment_profile_map.end()) + segment_profile_map.erase(seg_profile_map_ite); + } + + { + std::unique_lock lock(segment_status_mutex); auto seg_status_map_ite = segment_status_map.find(query_id); if (seg_status_map_ite != segment_status_map.end()) segment_status_map.erase(seg_status_map_ite); @@ -322,6 +344,31 @@ void SegmentScheduler::updateSegmentStatus(const RuntimeSegmentsStatus & segment status->code = segment_status.code; } + +void SegmentScheduler::updateSegmentProfile(PlanSegmentProfilePtr & segment_profile) +{ + std::unique_lock lock(segment_profile_mutex); + auto segment_profile_iter = segment_profile_map.find(segment_profile->query_id); + if (segment_profile_iter == segment_profile_map.end()) + return; + + PlanSegmentProfilePtr profile = segment_profile; + segment_profile_iter->second[segment_profile->segment_id].emplace_back(profile); +} + +std::unordered_map SegmentScheduler::getSegmentsProfile(const String & query_id) +{ + std::unordered_map res; + { + std::unique_lock lock(segment_profile_mutex); + auto segment_profile_iter = segment_profile_map.find(query_id); + if (segment_profile_iter == segment_profile_map.end()) + return res; + res = segment_profile_iter->second; + } + return res; +} + void 
SegmentScheduler::checkQueryCpuTime(const String & query_id) { UInt64 max_cpu_seconds = 0; @@ -435,6 +482,30 @@ void SegmentScheduler::updateReceivedSegmentStatusCounter( } } +bool SegmentScheduler::alreadyReceivedAllSegmentStatus(const String & query_id) +{ + std::unique_lock lock(segment_status_mutex); + auto all_segments_iterator = query_map.find(query_id); + auto received_status_segments_counter_iterator = query_status_received_counter_map.find(query_id); + if (received_status_segments_counter_iterator == query_status_received_counter_map.end() && all_segments_iterator == query_map.end()) + return true; + if (received_status_segments_counter_iterator == query_status_received_counter_map.end()) + return false; + if (all_segments_iterator == query_map.end()) + return true; + auto dag_ptr = all_segments_iterator->second; + auto received_status_segments_counter = received_status_segments_counter_iterator->second; + for (auto & parallel : dag_ptr->segment_parallel_size_map) + { + if (parallel.first == 0) + continue; + + if (query_status_received_counter_map[query_id][parallel.first].size() < parallel.second) + return false; + } + return true; +} + void SegmentScheduler::onSegmentFinished(const RuntimeSegmentsStatus & status) { std::unique_lock lock(bsp_scheduler_map_mutex); @@ -499,8 +570,8 @@ void SegmentScheduler::buildDAGGraph(PlanSegmentTree * plan_segments_ptr, std::s // value, readnothing, system table if (plan_segment_ptr->getPlanSegmentInputs().empty()) { - graph_ptr->sources.insert(plan_segment_ptr->getPlanSegmentId()); - graph_ptr->any_tables.insert(plan_segment_ptr->getPlanSegmentId()); + graph_ptr->leaf_segments.insert(plan_segment_ptr->getPlanSegmentId()); + // graph_ptr->segments_has_table_scan.insert(plan_segment_ptr->getPlanSegmentId()); } // source if (!plan_segment_ptr->getPlanSegmentInputs().empty()) @@ -515,9 +586,9 @@ void SegmentScheduler::buildDAGGraph(PlanSegmentTree * plan_segments_ptr, std::s any_tables = true; } if (all_tables) - 
graph_ptr->sources.insert(plan_segment_ptr->getPlanSegmentId()); + graph_ptr->leaf_segments.insert(plan_segment_ptr->getPlanSegmentId()); if (any_tables) - graph_ptr->any_tables.insert(plan_segment_ptr->getPlanSegmentId()); + graph_ptr->segments_has_table_scan.insert(plan_segment_ptr->getPlanSegmentId()); } // final stage if (plan_segment_ptr->getPlanSegmentOutput()->getPlanSegmentType() == PlanSegmentType::OUTPUT) @@ -549,9 +620,9 @@ void SegmentScheduler::buildDAGGraph(PlanSegmentTree * plan_segments_ptr, std::s } } // do some check - // 1. check source or final is empty - if (graph_ptr->sources.empty()) - throw Exception("Logical error: source is empty", ErrorCodes::LOGICAL_ERROR); + // 1. check if leaf segments or the final is empty + if (graph_ptr->leaf_segments.empty()) + throw Exception("Logical error: no leaf segment", ErrorCodes::LOGICAL_ERROR); if (graph_ptr->final == std::numeric_limits::max()) throw Exception("Logical error: final is empty", ErrorCodes::LOGICAL_ERROR); @@ -656,7 +727,7 @@ PlanSegmentSet SegmentScheduler::getIOPlanSegmentInstanceIDs(const String & quer throw Exception("query_id-" + query_id + " does not exist in scheduler query map", ErrorCodes::LOGICAL_ERROR); const auto & dag_ptr = iter->second; PlanSegmentSet res; - for (auto && segment_id : dag_ptr->any_tables) + for (auto && segment_id : dag_ptr->segments_has_table_scan) { /// wont wait for final segment, because it is already logged in progress_callback if (segment_id != dag_ptr->final) diff --git a/src/Interpreters/SegmentScheduler.h b/src/Interpreters/SegmentScheduler.h index fb75643a4c2..b77b953fbd2 100644 --- a/src/Interpreters/SegmentScheduler.h +++ b/src/Interpreters/SegmentScheduler.h @@ -54,14 +54,24 @@ struct ExceptionWithCode }; using RuntimeSegmentsStatusPtr = std::shared_ptr; +struct PlanSegmentProfile; +using PlanSegmentProfilePtr = std::shared_ptr; +using PlanSegmentProfiles = std::vector; +using PlanSegmentsStatusPtr = std::shared_ptr; using PlanSegmentsPtr = 
std::vector; // > using RuntimeSegmentsStatusCounter = std::unordered_map>; // > using SegmentStatusMap = std::unordered_map>; +using SegmentProfilesMap = std::unordered_map>; using BspSchedulerMap = std::unordered_map>; enum class OverflowMode; +struct SegmentSchedulerOptions +{ + std::function send_progress_callback; +}; + class SegmentScheduler { public: @@ -91,10 +101,14 @@ class SegmentScheduler void updateSegmentStatus(const RuntimeSegmentsStatus & segment_status); void updateQueryStatus(const RuntimeSegmentsStatus & segment_status); + void updateSegmentProfile(PlanSegmentProfilePtr & segment_profile); + std::unordered_map getSegmentsProfile(const String & query_id); + void updateReceivedSegmentStatusCounter( const String & query_id, const size_t & segment_id, const UInt64 & parallel_index, const RuntimeSegmentsStatus & status); // Return true if only the query runs in bsp mode and all statuses of specified segment has been received. bool bspQueryReceivedAllStatusOfSegment(const String & query_id, const size_t & segment_id) const; + bool alreadyReceivedAllSegmentStatus(const String & query_id); void onSegmentFinished(const RuntimeSegmentsStatus & status); std::shared_ptr getBSPScheduler(const String & query_id); @@ -107,7 +121,9 @@ class SegmentScheduler // Protect maps below. 
mutable bthread::Mutex segment_status_mutex; + mutable bthread::Mutex segment_profile_mutex; mutable SegmentStatusMap segment_status_map; + mutable SegmentProfilesMap segment_profile_map; mutable std::unordered_map query_status_map; // record exception when exception occurred ConcurrentShardMap query_to_exception_with_code; diff --git a/src/Interpreters/ServerPartLog.cpp b/src/Interpreters/ServerPartLog.cpp index de6d1ce6460..210ee797710 100644 --- a/src/Interpreters/ServerPartLog.cpp +++ b/src/Interpreters/ServerPartLog.cpp @@ -82,7 +82,7 @@ void ServerPartLogElement::appendToBlock(MutableColumns & columns) const size_t i = 0; columns[i++]->insert(static_cast(event_type)); - columns[i++]->insert(DateLUT::instance().toDayNum(event_time).toUnderType()); + columns[i++]->insert(DateLUT::serverTimezoneInstance().toDayNum(event_time).toUnderType()); columns[i++]->insert(event_time); columns[i++]->insert(txn_id); diff --git a/src/Interpreters/SortedBlocksWriter.cpp b/src/Interpreters/SortedBlocksWriter.cpp index b12616dba1e..eee92cf45ad 100644 --- a/src/Interpreters/SortedBlocksWriter.cpp +++ b/src/Interpreters/SortedBlocksWriter.cpp @@ -2,7 +2,7 @@ #include #include #include -#include +#include #include #include @@ -22,7 +22,7 @@ std::unique_ptr flushToFile(const String & tmp_path, const Block auto tmp_file = createTemporaryFile(tmp_path); std::atomic is_cancelled{false}; - TemporaryFileStream::write(tmp_file->path(), header, stream, &is_cancelled, codec); + TemporaryFileStreamLegacy::write(tmp_file->path(), header, stream, &is_cancelled, codec); if (is_cancelled) throw Exception("Cannot flush MergeJoin data on disk. 
No space at " + tmp_path, ErrorCodes::NOT_ENOUGH_SPACE); diff --git a/src/Interpreters/SystemLog.h b/src/Interpreters/SystemLog.h index cc3836200e8..d7f59aae92c 100644 --- a/src/Interpreters/SystemLog.h +++ b/src/Interpreters/SystemLog.h @@ -199,6 +199,12 @@ class SystemLog : public ISystemLog, protected boost::noncopyable, protected Wit const String & storage_def_, size_t flush_interval_milliseconds_); + /// destructor is necessary to stop flush thread before deleting member variable `saving_thread` + ~SystemLog() override + { + shutdown(); + } + /** Append a record into log. * Writing to table will be done asynchronously and in case of failure, record could be lost. */ diff --git a/src/Interpreters/TextLog.cpp b/src/Interpreters/TextLog.cpp index baf98b6771d..b129c06b2c5 100644 --- a/src/Interpreters/TextLog.cpp +++ b/src/Interpreters/TextLog.cpp @@ -55,7 +55,7 @@ void TextLogElement::appendToBlock(MutableColumns & columns) const { size_t i = 0; - columns[i++]->insert(DateLUT::instance().toDayNum(event_time).toUnderType()); + columns[i++]->insert(DateLUT::serverTimezoneInstance().toDayNum(event_time).toUnderType()); columns[i++]->insert(event_time); columns[i++]->insert(event_time_microseconds); columns[i++]->insert(microseconds); diff --git a/src/Interpreters/TraceLog.cpp b/src/Interpreters/TraceLog.cpp index dac27aebe58..742c275acab 100644 --- a/src/Interpreters/TraceLog.cpp +++ b/src/Interpreters/TraceLog.cpp @@ -42,7 +42,7 @@ void TraceLogElement::appendToBlock(MutableColumns & columns) const { size_t i = 0; - columns[i++]->insert(DateLUT::instance().toDayNum(event_time).toUnderType()); + columns[i++]->insert(DateLUT::serverTimezoneInstance().toDayNum(event_time).toUnderType()); columns[i++]->insert(event_time); columns[i++]->insert(event_time_microseconds); columns[i++]->insert(timestamp_ns); diff --git a/src/Interpreters/TreeRewriter.cpp b/src/Interpreters/TreeRewriter.cpp index 1ef57894e7b..f112641f82c 100644 --- a/src/Interpreters/TreeRewriter.cpp +++ 
b/src/Interpreters/TreeRewriter.cpp @@ -768,11 +768,11 @@ void collectJoinedColumns(TableJoin & analyzed_join, const ASTTableJoin & table_ { bool is_asof = (table_join.strictness == ASTTableJoin::Strictness::Asof); - CollectJoinOnKeysVisitor::Data data{analyzed_join, tables[0], tables[1], aliases, is_asof, false, enable_join_on_1_equals_1, {}, {}, false, ignore_array_join_check_in_join_on_condition, context}; + CollectJoinOnKeysVisitor::Data data{analyzed_join, tables[0], tables[1], aliases, is_asof, false, enable_join_on_1_equals_1, {}, {}, false, ignore_array_join_check_in_join_on_condition, context, {}, !context->getSettings().enable_optimizer}; CollectJoinOnKeysVisitor(data).visit(table_join.on_expression); CollectJoinOnKeysMatcher::analyzeJoinOnConditions(data, table_join.kind); - if (!data.has_some && !data.is_nest_loop_join) + if (!data.has_some && !data.is_nest_loop_join && !context->getSettings().enable_optimizer) throw Exception("Cannot get JOIN keys from JOIN ON section: " + queryToString(table_join.on_expression), ErrorCodes::INVALID_JOIN_ON_EXPRESSION); if (is_asof) diff --git a/src/Interpreters/UniqueTableLog.cpp b/src/Interpreters/UniqueTableLog.cpp index eec6e295083..b2b8eddd5d7 100644 --- a/src/Interpreters/UniqueTableLog.cpp +++ b/src/Interpreters/UniqueTableLog.cpp @@ -44,7 +44,7 @@ void UniqueTableLogElement::appendToBlock(MutableColumns & columns) const columns[i++]->insert(table); columns[i++]->insert(type); - columns[i++]->insert(DateLUT::instance().toDayNum(event_time).toUnderType()); + columns[i++]->insert(DateLUT::sessionInstance().toDayNum(event_time).toUnderType()); columns[i++]->insert(event_time); columns[i++]->insert(txn_id); diff --git a/src/Interpreters/WorkerGroupHandle.cpp b/src/Interpreters/WorkerGroupHandle.cpp index 437d4dbfb19..7afe470527a 100644 --- a/src/Interpreters/WorkerGroupHandle.cpp +++ b/src/Interpreters/WorkerGroupHandle.cpp @@ -36,6 +36,7 @@ namespace ErrorCodes extern const int NO_SUCH_SERVICE; extern const int 
VIRTUAL_WAREHOUSE_NOT_FOUND; extern const int RESOURCE_MANAGER_NO_AVAILABLE_WORKER; + extern const int WORKER_NODE_NOT_FOUND; } WorkerGroupHandle WorkerGroupHandleImpl::mockWorkerGroupHandle(const String & worker_id_prefix_, UInt64 worker_number_, const ContextPtr & context_) @@ -80,6 +81,15 @@ WorkerGroupHandleImpl::WorkerGroupHandleImpl( , metrics(metrics_) , worker_num(hosts.size()) { + /// some allocation algorithm (such as jump consistent hash) work best when + /// 1) the index of existing workers don't change + /// 2) new workers are added to the end with larger index + /// we achieve this by + /// 1) let k8s assign sequential worker id "{WG_NAME}_{IDX}" to each worker + /// 2) make sure workers are sorted in worker id's order + /// TODO: sort in numeric order rather than lexicographic order + std::sort(hosts.begin(), hosts.end()); + auto current_context = getContext(); const auto & settings = current_context->getSettingsRef(); @@ -121,7 +131,7 @@ WorkerGroupHandleImpl::WorkerGroupHandleImpl( shards_info.emplace_back(std::move(info)); } - + ring = buildRing(this->shards_info, current_context); LOG_DEBUG(&Poco::Logger::get("WorkerGroupHandleImpl"), "Success built ring with {} nodes\n", ring->size()); @@ -246,6 +256,14 @@ CnchWorkerClientPtr WorkerGroupHandleImpl::doGetWorkerClient(const HostWithPorts return getContext()->getCnchWorkerClientPools().getWorker(host_ports); } +size_t WorkerGroupHandleImpl::getWorkerIndex(const String & worker_id) const +{ + for (size_t i = 0; i < hosts.size(); ++i) + if (worker_id == hosts[i].id) + return i; + throw Exception(ErrorCodes::WORKER_NODE_NOT_FOUND, "worker '{}' not found in worker group {}", worker_id, id); +} + bool WorkerGroupHandleImpl::isSame(const WorkerGroupData & data) const { if (id != data.id || vw_name != data.vw_name) diff --git a/src/Interpreters/WorkerGroupHandle.h b/src/Interpreters/WorkerGroupHandle.h index 806659a9176..f29703b5fcb 100644 --- a/src/Interpreters/WorkerGroupHandle.h +++ 
b/src/Interpreters/WorkerGroupHandle.h @@ -113,6 +113,8 @@ class WorkerGroupHandleImpl : private boost::noncopyable, public WithContext CnchWorkerClientPtr getWorkerClient(const HostWithPorts & host_ports) const; CnchWorkerClientPtr doGetWorkerClient(const HostWithPorts & host_ports) const; + size_t getWorkerIndex(const String & worker_id) const; + std::optional indexOf(const HostWithPorts & host_ports) const { for (size_t i = 0, size = hosts.size(); i < size; i++) diff --git a/src/Interpreters/WorkerStatusManager.h b/src/Interpreters/WorkerStatusManager.h index 495e1e8bae4..2175daf2fe4 100644 --- a/src/Interpreters/WorkerStatusManager.h +++ b/src/Interpreters/WorkerStatusManager.h @@ -218,9 +218,9 @@ struct WorkerStatus : public DB::WorkerNodeResourceData double scheduler_score{0}; }; -using WorkerNodeSet = std::unordered_set; +using WorkerNodeSet = std::unordered_set; using WorkerStatusPtr = std::shared_ptr; -using WorkerNodeStatusContainer = std::unordered_map; +using WorkerNodeStatusContainer = std::unordered_map; struct WorkerStatusExtra { @@ -283,7 +283,7 @@ class WorkerGroupStatus }; using WorkerGroupStatusPtr = std::shared_ptr; -using UnhealthWorkerStatusMap = std::unordered_map; +using UnhealthWorkerStatusMap = std::unordered_map; class WorkerStatusManager : public WithContext { @@ -365,8 +365,8 @@ class WorkerStatusManager : public WithContext void shutdown(); private: - ThreadSafeMap global_extra_workers_status; - ThreadSafeMap unhealth_workers_status; + ThreadSafeMap global_extra_workers_status; + ThreadSafeMap unhealth_workers_status; ThreadSafeMap vw_worker_list_map; AdaptiveSchedulerConfig adaptive_scheduler_config; diff --git a/src/Interpreters/ZooKeeperLog.cpp b/src/Interpreters/ZooKeeperLog.cpp index d24fa78ea17..f3d86d6453a 100644 --- a/src/Interpreters/ZooKeeperLog.cpp +++ b/src/Interpreters/ZooKeeperLog.cpp @@ -172,7 +172,7 @@ void ZooKeeperLogElement::appendToBlock(MutableColumns & columns) const columns[i++]->insert(type); auto 
event_time_seconds = event_time / 1000000; - columns[i++]->insert(DateLUT::instance().toDayNum(event_time_seconds).toUnderType()); + columns[i++]->insert(DateLUT::serverTimezoneInstance().toDayNum(event_time_seconds).toUnderType()); columns[i++]->insert(event_time); columns[i++]->insertData(IPv6ToBinary(address.host()).data(), 16); columns[i++]->insert(address.port()); diff --git a/src/Interpreters/convertFieldToType.cpp b/src/Interpreters/convertFieldToType.cpp index fd3e0c37973..61b06c160fe 100644 --- a/src/Interpreters/convertFieldToType.cpp +++ b/src/Interpreters/convertFieldToType.cpp @@ -399,29 +399,32 @@ Field convertFieldToTypeImpl(const Field & src, const IDataType & type, const ID } else if (const DataTypeMap * type_map = typeid_cast(&type)) { - const auto & key_type = *type_map->getKeyType(); - const auto & value_type = *type_map->getValueType(); + if (src.getType() == Field::Types::Map) + { + const auto & key_type = *type_map->getKeyType(); + const auto & value_type = *type_map->getValueType(); - const auto & map = src.get(); - size_t map_size = map.size(); + const auto & map = src.get(); + size_t map_size = map.size(); - Map res(map_size); + Map res(map_size); - bool have_unconvertible_element = false; - for (size_t i = 0; i < map_size; ++i) - { - const auto & key = map[i].first; - const auto & value = map[i].second; + bool have_unconvertible_element = false; + for (size_t i = 0; i < map_size; ++i) + { + const auto & key = map[i].first; + const auto & value = map[i].second; - res[i] = {convertFieldToType(key, key_type), convertFieldToType(value, value_type)}; - if (res[i].first.isNull() && !key_type.isNullable()) - have_unconvertible_element = true; + res[i] = {convertFieldToType(key, key_type), convertFieldToType(value, value_type)}; + if (res[i].first.isNull() && !key_type.isNullable()) + have_unconvertible_element = true; - if (res[i].second.isNull() && !value_type.isNullable()) - have_unconvertible_element = true; - } + if (res[i].second.isNull() 
&& !value_type.isNullable()) + have_unconvertible_element = true; + } - return have_unconvertible_element ? Field(Null()) : Field(res); + return have_unconvertible_element ? Field(Null()) : Field(res); + } } else if (const DataTypeAggregateFunction * agg_func_type = typeid_cast(&type)) { diff --git a/src/Interpreters/executeQuery.cpp b/src/Interpreters/executeQuery.cpp index f0f98bcc7bf..7fd8db382f4 100644 --- a/src/Interpreters/executeQuery.cpp +++ b/src/Interpreters/executeQuery.cpp @@ -747,8 +747,12 @@ static TransactionCnchPtr prepareCnchTransaction(ContextMutablePtr context, [[ma return {}; } -void interpretSettings(ASTPtr ast, ContextMutablePtr context) +void interpretSettings(ASTPtr query, ContextMutablePtr context) { + auto & ast = query; + if (auto * explain_select_query = ast->as()) + ast = explain_select_query->getExplainedQuery(); + if (const auto * select_query = ast->as()) { if (auto new_settings = select_query->settings()) @@ -768,6 +772,16 @@ void interpretSettings(ASTPtr ast, ContextMutablePtr context) } } } + else if (const auto * create_select_query = ast->as(); create_select_query && create_select_query->select) + { + const auto * select_in_query = create_select_query->select->as(); + if (select_in_query && !select_in_query->list_of_selects->children.empty()) + { + const auto * last_select = select_in_query->list_of_selects->children.back()->as(); + if (last_select && last_select->settings()) + InterpreterSetQuery(last_select->settings(), context).executeForCurrentContext(); + } + } else if (const auto * query_with_output = dynamic_cast(ast.get())) { if (query_with_output->settings_ast) @@ -846,6 +860,9 @@ static std::tuple executeQueryImpl( { context_ptr->getVWCustomizedSettings()->overwriteDefaultSettings(vw_name, context_ptr); } + + if (context_ptr->hasSessionContext()) + context_ptr->applySessionSettingsChanges(); } }; @@ -919,8 +936,7 @@ static std::tuple executeQueryImpl( /// Interpret SETTINGS clauses as early as possible (before invoking 
the corresponding interpreter), /// to allow settings to take effect. - if (input_ast == nullptr) - InterpreterSetQuery::applySettingsFromQuery(ast, context); + InterpreterSetQuery::applySettingsFromQuery(ast, context); if (context->getServerType() == ServerType::cnch_server && context->hasQueryContext()) { @@ -1026,6 +1042,9 @@ static std::tuple executeQueryImpl( } } + if (context->hasSessionContext()) + context->clearSessionSettingsChanges(); + /// Copy query into string. It will be written to log and presented in processlist. If an INSERT query, string will not include data to insertion. String query(begin, query_end); @@ -1440,6 +1459,11 @@ static std::tuple executeQueryImpl( } } + if (settings.log_query_plan) + { + elem.query_plan = context->getQueryContext()->getQueryPlan(); + } + interpreter->extendQueryLogElem(elem, ast, context, query_database, query_table); if (settings.log_query_settings) @@ -2120,7 +2144,7 @@ void executeQuery( if (set_result_details) set_result_details( - context->getClientInfo().current_query_id, out->getContentType(), format_name, DateLUT::instance().getTimeZone(), streams.coordinator); + context->getClientInfo().current_query_id, out->getContentType(), format_name, DateLUT::serverTimezoneInstance().getTimeZone(), streams.coordinator); copyData( *streams.in, *out, []() { return false; }, [&out](const Block &) { out->flush(); }); @@ -2177,7 +2201,7 @@ void executeQuery( if (set_result_details) set_result_details( - context->getClientInfo().current_query_id, out->getContentType(), format_name, DateLUT::instance().getTimeZone(), streams.coordinator); + context->getClientInfo().current_query_id, out->getContentType(), format_name, DateLUT::serverTimezoneInstance().getTimeZone(), streams.coordinator); pipeline.setOutputFormat(std::move(out)); } @@ -2324,7 +2348,7 @@ void executeHttpQueryInAsyncMode( query.data(), query.data() + query.size(), ast, context, false, QueryProcessingStage::Complete, has_query_tail, istr); auto & pipeline = 
streams.pipeline; if (set_result_details_cp) - set_result_details_cp(query_id, "text/plain; charset=UTF-8", format_name1_cp, DateLUT::instance().getTimeZone(), streams.coordinator); + set_result_details_cp(query_id, "text/plain; charset=UTF-8", format_name1_cp, DateLUT::serverTimezoneInstance().getTimeZone(), streams.coordinator); if (streams.in) { const auto * ast_query_with_output = dynamic_cast(ast.get()); @@ -2425,54 +2449,5 @@ void executeHttpQueryInAsyncMode( }); } -void adjustAccessTablesIfNeeded(ContextMutablePtr & context) -{ - // In case access_table_names is set, this query will be readonly and - // access right will be propagated to remote tables - String access_table_names = context->getSettingsRef().access_table_names; - if (!access_table_names.empty()) - { - auto add_access_table_name = [&](const String & db, const String & tbl) - { - access_table_names.append(",").append(db).append(".").append(tbl); - context->setSetting("access_table_names", access_table_names); - }; - std::vector tables; - boost::split(tables, access_table_names, boost::is_any_of(" ,")); - - for (auto & table : tables) - { - char * begin = table.data(); - char * end = begin + table.size(); - Tokens tokens(begin, end); - IParser::Pos token_iterator(tokens, context->getSettingsRef().max_parser_depth); - auto pos = token_iterator; - Expected expected; - String database_name, table_name; - if (!parseDatabaseAndTableName(pos, expected, database_name, table_name)) - continue; - - StorageID table_id{database_name, table_name}; - /// tryGetTable below requires resolved table id - StorageID resolved = context->tryResolveStorageID(table_id); - if (!resolved) - continue; - - // continue if current table is temporary table. 
- if (resolved.database_name == DatabaseCatalog::TEMPORARY_DATABASE) - continue; - - /// access_table_names need to have resolved name, otherwise tryGetTable below will fail - if (table_id.database_name.empty() && !resolved.database_name.empty()) - add_access_table_name(resolved.getDatabaseName(), resolved.getTableName()); - - // auto storage_ptr = DatabaseCatalog::instance().tryGetTable(resolved, context); - // auto * distributed = dynamic_cast(storage_ptr.get()); - // if (distributed && !distributed->getRemoteTableName().empty()) - // add_access_table_name(distributed->getRemoteDatabaseName(), distributed->getRemoteTableName()); - } - } -} - } diff --git a/src/Interpreters/executeQuery.h b/src/Interpreters/executeQuery.h index d69482346ba..542e8b85649 100644 --- a/src/Interpreters/executeQuery.h +++ b/src/Interpreters/executeQuery.h @@ -122,5 +122,4 @@ void updateAsyncQueryStatus( void interpretSettings(ASTPtr ast, ContextMutablePtr context); -void adjustAccessTablesIfNeeded(ContextMutablePtr & context); } diff --git a/src/Interpreters/executeQueryHelper.cpp b/src/Interpreters/executeQueryHelper.cpp index 8e8afae1467..e9608469eb1 100644 --- a/src/Interpreters/executeQueryHelper.cpp +++ b/src/Interpreters/executeQueryHelper.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -32,77 +33,161 @@ namespace DB HostWithPorts getTargetServer(ContextPtr context, ASTPtr & ast) { - /// Only get target server for main table - String database, table; - bool is_alter_database = false; + auto get_target_server_for_table = [&] (const String & database_name, const String & table_name) -> HostWithPorts + { + if (database_name == "system") + return {}; + + DatabaseAndTable db_and_tb = DatabaseCatalog::instance().tryGetDatabaseAndTable(StorageID(database_name, table_name), context); + DatabasePtr db_ptr = std::move(db_and_tb.first); + StoragePtr storage_ptr = std::move(db_and_tb.second); + if (!db_ptr || !storage_ptr || db_ptr->getEngineName() != 
"Cnch") + return {}; + + auto topology_master = context->getCnchTopologyMaster(); + + return topology_master->getTargetServer( + UUIDHelpers::UUIDToString(storage_ptr->getStorageUUID()), storage_ptr->getServerVwName(), context->getTimestamp(), true); + }; if (const auto * alter = ast->as()) { - database = alter->database; - table = alter->table; - is_alter_database = (alter->alter_object == ASTAlterQuery::AlterObjectType::DATABASE); + if (alter->alter_object == ASTAlterQuery::AlterObjectType::DATABASE) + return {}; + return get_target_server_for_table(alter->database.empty() ? context->getCurrentDatabase() : alter->database, alter->table); } else if (const auto * alter_mysql = ast->as()) { - database = alter_mysql->database; - table = alter_mysql->table; - is_alter_database = (alter_mysql->alter_object == ASTAlterQuery::AlterObjectType::DATABASE); + if(alter_mysql->alter_object == ASTAlterQuery::AlterObjectType::DATABASE) + return {}; + return get_target_server_for_table(alter_mysql->database.empty() ? context->getCurrentDatabase() : alter_mysql->database, alter_mysql->table); } else if (const auto * select = ast->as()) { - if (!context->getSettingsRef().enable_select_query_forwarding) + if (!context->getSettingsRef().enable_select_query_forwarding && !context->getSettingsRef().enable_multiple_table_select_query_forwarding) return {}; ASTs tables; bool has_table_func = false; ASTSelectQuery::collectAllTables(ast.get(), tables, has_table_func); - // when query inlcudes multiple tables, it is better to just keep existing host since cannot guarantee all tables are in the same host. 
- if (!has_table_func && !tables.empty() && tables.size() == 1) + + if (tables.empty() || has_table_func) + return {}; + + String current_db = context->getCurrentDatabase(); + if (tables.size() == 1) { - // simplily use the first table if there are multiple tables used - DatabaseAndTableWithAlias db_and_table(tables[0]); + DatabaseAndTableWithAlias db_and_table(tables[0], current_db); LOG_DEBUG( &Poco::Logger::get("executeQuery"), - "Extract db and table {}.{} from the query.", + "Get main table `{}.{}` for current select query.", db_and_table.database, db_and_table.table); - database = db_and_table.database; - table = db_and_table.table; + return get_target_server_for_table(db_and_table.database, db_and_table.table); } else + { + if (!context->getSettingsRef().enable_multiple_table_select_query_forwarding) + return {}; + + std::vector db_and_tables; + for (const auto & table_ast : tables) + db_and_tables.emplace_back(DatabaseAndTableWithAlias(table_ast, current_db)); + + /// For multiple table select, we forward the query with the following policy: + /// 1. If the main table is explicitly set, use the user defined main table. + /// 2. If one of the table is set to be main table in the table settings, use that table + /// 3. Pick up the first one table shows up in the query as main table. + String explicit_main_table = context->getSettingsRef().explicit_main_table; + if (!explicit_main_table.empty()) + { + char * begin = explicit_main_table.data(); + char * end = begin + explicit_main_table.size(); + Tokens tokens(begin, end); + IParser::Pos token_iterator(tokens, context->getSettingsRef().max_parser_depth); + auto pos = token_iterator; + Expected expected; + String database_name, table_name; + if (parseDatabaseAndTableName(pos, expected, database_name, table_name)) + { + if (database_name.empty()) + database_name = current_db; + + // Only if the specified main table shows up in select we can forward the query to its host server. 
+ if (std::any_of(db_and_tables.begin(), db_and_tables.end(), [&](const auto & ele){return ele.database == database_name && ele.table == table_name;})) + { + LOG_DEBUG( + &Poco::Logger::get("executeQuery"), + "Get explicit main table `{}.{}` for current select query from settings.", + database_name, + table_name); + return get_target_server_for_table(database_name, table_name); + } + else + { + LOG_WARNING( + &Poco::Logger::get("executeQuery"), + "Ignore main table settings because `{}.{}` is not in the select query.", + database_name, + table_name); + } + } + } + + StoragePtr main_storage; + for (const auto & db_and_table : db_and_tables) + { + if (db_and_table.database == "system") + continue; + + DatabaseAndTable db_and_tb = DatabaseCatalog::instance().tryGetDatabaseAndTable(StorageID(db_and_table.database, db_and_table.table), context); + DatabasePtr db_ptr = std::move(db_and_tb.first); + StoragePtr storage_ptr = std::move(db_and_tb.second); + if (!db_ptr || !storage_ptr || db_ptr->getEngineName() != "Cnch") + continue; + + if (!main_storage) + main_storage = std::move(storage_ptr); + else + { + auto * cnch = dynamic_cast(storage_ptr.get()); + if (!cnch) + continue; + + // If multiple tables are set to be main table, pick up the first one. 
+ if (cnch->getSettings()->as_main_table) + { + main_storage = std::move(storage_ptr); + break; + } + } + } + + if (main_storage) + { + LOG_DEBUG( + &Poco::Logger::get("executeQuery"), + "Get main table `{}.{}` for current select query.", + main_storage->getDatabaseName(), + main_storage->getTableName()); + auto topology_master = context->getCnchTopologyMaster(); + return topology_master->getTargetServer( + UUIDHelpers::UUIDToString(main_storage->getStorageUUID()), main_storage->getServerVwName(), context->getTimestamp(), true); + } + return {}; + } } else if (const auto * rename = ast->as()) { - if (!rename->database) - { - database = rename->elements.at(0).from.database; - table = rename->elements.at(0).from.table; - } - else + if (rename->database) return {}; + + return get_target_server_for_table(rename->elements.at(0).from.database.empty() ? context->getCurrentDatabase() : rename->elements.at(0).from.database, + rename->elements.at(0).from.table); } else return {}; - - if (database.empty()) - database = context->getCurrentDatabase(); - - if (database == "system" || is_alter_database) - return {}; - - DatabaseAndTable db_and_tb = DatabaseCatalog::instance().tryGetDatabaseAndTable(StorageID(database, table), context); - DatabasePtr db_ptr = std::move(db_and_tb.first); - StoragePtr storage_ptr = std::move(db_and_tb.second); - if (!db_ptr || !storage_ptr) - return {}; - if (db_ptr->getEngineName() != "Cnch") - return {}; - - auto topology_master = context->getCnchTopologyMaster(); - - return topology_master->getTargetServer( - UUIDHelpers::UUIDToString(storage_ptr->getStorageUUID()), storage_ptr->getServerVwName(), context->getTimestamp(), true); } void executeQueryByProxy(ContextMutablePtr context, const HostWithPorts & server, const ASTPtr & ast, BlockIO & res, bool in_interactive_txn, const String & query) diff --git a/src/Interpreters/loadMetadata.cpp b/src/Interpreters/loadMetadata.cpp index af095d2b89a..cfded4748f1 100644 --- 
a/src/Interpreters/loadMetadata.cpp +++ b/src/Interpreters/loadMetadata.cpp @@ -246,7 +246,7 @@ void reloadFormatSchema(ContextMutablePtr context, String remote_format_schema_p remote_format_schema_path += "/"; // add it by default // try download files from remote_format_schema_path to format_schema_path Poco::URI remote_uri(remote_format_schema_path); - if (remote_uri.getScheme() == "hdfs") + if (isHdfsOrCfsScheme(remote_uri.getScheme())) { HDFSBuilderPtr builder = context->getHdfsConnectionParams().createBuilder(remote_uri); HDFSFSPtr fs = createHDFSFS(builder.get()); @@ -289,7 +289,7 @@ void reloadFormatSchema(ContextMutablePtr context, String remote_format_schema_p } else { - if(log) {LOG_ERROR(log, "remote_format_schema_path only support hdfs");} + if(log) {LOG_ERROR(log, "remote_format_schema_path only support hdfs and cfs");} } } #endif diff --git a/src/Interpreters/profile/PlanSegmentProfile.cpp b/src/Interpreters/profile/PlanSegmentProfile.cpp new file mode 100644 index 00000000000..b3584f47a25 --- /dev/null +++ b/src/Interpreters/profile/PlanSegmentProfile.cpp @@ -0,0 +1,192 @@ +#include +#include +#include +#include + + +namespace DB +{ +void InputProfileMetric::fillFromProto(const Protos::InputProfileMetric & proto) +{ + id = proto.id(); + input_rows = proto.input_rows(); + input_bytes = proto.input_bytes(); + input_wait_sum_elapsed_us = proto.input_wait_sum_elapsed_us(); + input_wait_max_elapsed_us = proto.input_wait_max_elapsed_us(); + input_wait_min_elapsed_us = proto.input_wait_min_elapsed_us(); +} +void InputProfileMetric::toProto(Protos::InputProfileMetric & proto) const +{ + proto.set_id(id); + proto.set_input_rows(input_rows); + proto.set_input_bytes(input_bytes); + proto.set_input_wait_sum_elapsed_us(input_wait_sum_elapsed_us); + proto.set_input_wait_max_elapsed_us(input_wait_max_elapsed_us); + proto.set_input_wait_min_elapsed_us(input_wait_min_elapsed_us); +} + +ProfileMetricPtr ProfileMetric::fromProto(const Protos::ProfileMetric & proto) +{ 
+ ProfileMetricPtr profile = std::make_shared(); + profile->id = proto.id(); + profile->name = proto.name(); + + if (!proto.children_ids().empty()) + { + for (const auto & id : proto.children_ids()) + profile->children_ids.emplace_back(id); + } + profile->parallel_size = proto.parallel_size(); + profile->min_elapsed_us = proto.min_elapsed_us(); + profile->sum_elapsed_us = proto.sum_elapsed_us(); + profile->max_elapsed_us = proto.max_elapsed_us(); + profile->output_rows = proto.output_rows(); + profile->output_bytes = proto.output_bytes(); + profile->output_wait_sum_elapsed_us = proto.output_wait_sum_elapsed_us(); + profile->output_wait_max_elapsed_us = proto.output_wait_max_elapsed_us(); + profile->output_wait_min_elapsed_us = proto.output_wait_min_elapsed_us(); + + for (const auto & proto_input : proto.inputs()) + { + InputProfileMetric input_profile; + input_profile.fillFromProto(proto_input); + profile->inputs.emplace(input_profile.id, input_profile); + } + + for (const auto & [attribute_type, attribute] : proto.attributes()) + { + AttributeInfoPtr info = std::make_shared(); + info->fillFromProto(attribute); + profile->attributes.emplace(attribute_type, info); + } + return profile; +} + +void ProfileMetric::toProto(Protos::ProfileMetric & proto) +{ + proto.set_id(id); + proto.set_name(name); + for (auto & child_id : children_ids) + proto.add_children_ids(child_id); + proto.set_parallel_size(parallel_size); + + proto.set_sum_elapsed_us(sum_elapsed_us); + proto.set_min_elapsed_us(min_elapsed_us); + proto.set_max_elapsed_us(max_elapsed_us); + + proto.set_output_rows(output_rows); + proto.set_output_bytes(output_bytes); + proto.set_output_wait_sum_elapsed_us(output_wait_sum_elapsed_us); + proto.set_output_wait_max_elapsed_us(output_wait_max_elapsed_us); + proto.set_output_wait_min_elapsed_us(output_wait_min_elapsed_us); + + for (auto & input : inputs) + input.second.toProto(*proto.add_inputs()); + + for (auto & att : attributes) + { + auto * att_proto = 
&(*proto.mutable_attributes())[att.first]; + att.second->toProto(*att_proto); + } +} + +StepProfiles ProfileMetric::aggregateStepProfileBetweenWorkers(AddressToStepProfile & addr_to_step_profile) +{ + StepProfiles res; + for (auto & [address, stepid_to_profile] : addr_to_step_profile) + { + for (auto & [step_id, step_profile] : stepid_to_profile) + { + if (!res.contains(step_id)) + { + step_profile->worker_cnt = 1; + if (!step_profile->attributes.empty()) + step_profile->address_to_attributes[address] = step_profile->attributes; + res[step_id] = step_profile; + } + else + { + auto & profile_ptr = res.at(step_id); + profile_ptr->max_elapsed_us = std::max(profile_ptr->max_elapsed_us, step_profile->sum_elapsed_us); + profile_ptr->min_elapsed_us = std::min(profile_ptr->min_elapsed_us, step_profile->sum_elapsed_us); + profile_ptr->sum_elapsed_us += step_profile->sum_elapsed_us; + profile_ptr->worker_cnt++; + profile_ptr->output_wait_max_elapsed_us + = std::max(profile_ptr->output_wait_max_elapsed_us, step_profile->output_wait_max_elapsed_us); + profile_ptr->output_wait_min_elapsed_us + = std::min(profile_ptr->output_wait_min_elapsed_us, step_profile->output_wait_max_elapsed_us); + profile_ptr->output_wait_sum_elapsed_us += step_profile->output_wait_max_elapsed_us; + profile_ptr->output_rows += step_profile->output_rows; + profile_ptr->output_bytes += step_profile->output_bytes; + + for (auto & [id, input_profile] : step_profile->inputs) + { + profile_ptr->inputs[id].input_wait_sum_elapsed_us += input_profile.input_wait_max_elapsed_us; + profile_ptr->inputs[id].input_wait_max_elapsed_us + = std::max(profile_ptr->inputs[id].input_wait_max_elapsed_us, input_profile.input_wait_max_elapsed_us); + profile_ptr->inputs[id].input_wait_min_elapsed_us + = std::min(profile_ptr->inputs[id].input_wait_min_elapsed_us, input_profile.input_wait_max_elapsed_us); + profile_ptr->inputs[id].input_rows += input_profile.input_rows; + profile_ptr->inputs[id].input_bytes += 
input_profile.input_bytes; + } + if (!step_profile->attributes.empty()) + profile_ptr->address_to_attributes[address] = step_profile->attributes; + } + } + } + return res; +} + +PlanSegmentProfilePtr PlanSegmentProfile::fromProto(const Protos::PlanSegmentProfileRequest & proto) +{ + PlanSegmentProfilePtr segment_profile = std::make_shared(); + segment_profile->query_id = proto.query_id(); + segment_profile->segment_id = proto.segment_id(); + segment_profile->is_succeed = proto.is_succeed(); + segment_profile->worker_address = proto.worker_address(); + if (proto.has_profile_root_id()) + segment_profile->profile_root_id = proto.profile_root_id(); + + for (const auto & [profile_id, profile_proto] : proto.profiles()) + { + auto profile = ProfileMetric::fromProto(profile_proto); + segment_profile->profiles.emplace(profile_id, profile); + } + if (proto.has_read_rows()) + segment_profile->read_rows = proto.read_rows(); + if (proto.has_read_bytes()) + segment_profile->read_bytes = proto.read_bytes(); + if (proto.has_total_cpu_ms()) + segment_profile->total_cpu_ms = proto.total_cpu_ms(); + if (proto.has_query_duration_ms()) + segment_profile->query_duration_ms = proto.query_duration_ms(); + if (proto.has_io_wait_ms()) + segment_profile->io_wait_ms = proto.io_wait_ms(); + if (proto.has_error_message()) + segment_profile->error_message = proto.error_message(); + return segment_profile; +} + +void PlanSegmentProfile::toProto(Protos::PlanSegmentProfileRequest & proto) +{ + proto.set_query_id(query_id); + proto.set_segment_id(segment_id); + proto.set_is_succeed(is_succeed); + proto.set_worker_address(worker_address); + + proto.set_profile_root_id(profile_root_id); + + for (auto & p : profiles) + { + auto * profile_proto = &(*proto.mutable_profiles())[p.first]; + p.second->toProto(*profile_proto); + } + proto.set_read_rows(read_rows); + proto.set_read_bytes(read_bytes); + proto.set_total_cpu_ms(total_cpu_ms); + proto.set_query_duration_ms(query_duration_ms); + 
proto.set_io_wait_ms(io_wait_ms); + proto.set_error_message(error_message); +} + +} diff --git a/src/Interpreters/profile/PlanSegmentProfile.h b/src/Interpreters/profile/PlanSegmentProfile.h new file mode 100644 index 00000000000..29d9bf65e45 --- /dev/null +++ b/src/Interpreters/profile/PlanSegmentProfile.h @@ -0,0 +1,92 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +struct RuntimeAttributeDescription; +using AttributeInfoPtr = std::shared_ptr; + +struct InputProfileMetric +{ + UInt64 id; + UInt64 input_rows; + UInt64 input_bytes; + UInt64 input_wait_sum_elapsed_us = 0; + UInt64 input_wait_max_elapsed_us = 0; + UInt64 input_wait_min_elapsed_us{UINT64_MAX}; + void fillFromProto(const Protos::InputProfileMetric & proto); + void toProto(Protos::InputProfileMetric & proto) const; +}; + +struct ProfileMetric; +using ProfileMetricPtr = std::shared_ptr; +using ProfileMetrics = std::vector; + +struct ProfileMetric +{ + UInt64 id; + String name; // only use for pipeline profile + std::vector children_ids; + UInt32 parallel_size; // only use for pipeline profile + UInt32 worker_cnt = 0; + + UInt64 sum_elapsed_us; + UInt64 max_elapsed_us; + UInt64 min_elapsed_us; + + UInt64 output_rows; + UInt64 output_bytes; + UInt64 output_wait_sum_elapsed_us = 0; + UInt64 output_wait_max_elapsed_us = 0; + UInt64 output_wait_min_elapsed_us{UINT64_MAX}; + std::unordered_map inputs; + + std::unordered_map attributes; // only use for plan profile + + std::unordered_map> address_to_attributes; + + static ProfileMetricPtr fromProto(const Protos::ProfileMetric & proto); + void toProto(Protos::ProfileMetric & proto); + static StepProfiles aggregateStepProfileBetweenWorkers(AddressToStepProfile & addr_to_step_profile); +}; + +struct PlanSegmentProfile; +using PlanSegmentProfilePtr = std::shared_ptr; +using PlanSegmentProfiles = std::vector; +struct PlanSegmentProfile +{ + String query_id; + UInt64 segment_id; + bool is_succeed; + 
String worker_address; + + UInt64 profile_root_id; + std::unordered_map profiles; + + std::unordered_map attributes; + + UInt64 read_rows; + UInt64 read_bytes; + UInt64 total_cpu_ms{}; + UInt64 query_duration_ms{}; + UInt64 io_wait_ms{}; + + String error_message; + +public: + PlanSegmentProfile() = default; + explicit PlanSegmentProfile(String query_id_, UInt64 segment_id_) : query_id(query_id_), segment_id(segment_id_) + { + } + static PlanSegmentProfilePtr fromProto(const Protos::PlanSegmentProfileRequest & proto); + void toProto(Protos::PlanSegmentProfileRequest & proto); +}; + +} diff --git a/src/Interpreters/tests/gtest_ansi_setting.cpp b/src/Interpreters/tests/gtest_ansi_setting.cpp new file mode 100644 index 00000000000..e6b2d0748f7 --- /dev/null +++ b/src/Interpreters/tests/gtest_ansi_setting.cpp @@ -0,0 +1,18 @@ +#include +#include + +#include + +using namespace DB; + +TEST(AnsiSettings, TestApplyOnlyIfDialectChange) +{ + auto context = Context::createCopy(getContext().context); + context->applySettingsChanges(SettingsChanges{{"dialect_type", "CLICKHOUSE"}}); + context->applySettingsChanges(SettingsChanges{{"cast_keep_nullable", "1"}}); + context->applySettingsChanges(SettingsChanges{{"dialect_type", "CLICKHOUSE"}}); + EXPECT_EQ(context->getSettingsRef().cast_keep_nullable.value, bool{1}); + context->applySettingsChanges(SettingsChanges{{"dialect_type", "ANSI"}}); + context->applySettingsChanges(SettingsChanges{{"dialect_type", "CLICKHOUSE"}}); + EXPECT_EQ(context->getSettingsRef().cast_keep_nullable.value, bool{0}); +} diff --git a/src/Interpreters/tests/gtest_standard_token.cpp b/src/Interpreters/tests/gtest_standard_token.cpp new file mode 100644 index 00000000000..81e578c9d2f --- /dev/null +++ b/src/Interpreters/tests/gtest_standard_token.cpp @@ -0,0 +1,118 @@ +#include +#include +#include +#include +#include +#include + +using namespace DB; + + +TEST(TokenExtractor, StandardToken) +{ + size_t pos = 0; + StandardTokenExtractor tokenizer; + + size_t 
token_start = 0; + size_t token_length = 0; + + + size_t index = 0; + std::string test_str_1 = "ByConity是分布式的云原生SQL数仓引擎"; + std::vector test_token_1 = {"ByConity","是","分","布","式","的","云","原","生","SQL","数","仓","引","擎"}; + while(tokenizer.nextInString( + test_str_1.data(), test_str_1.length(), &pos, &token_start, &token_length)) + { + auto tmp_str = std::string(test_str_1.data()+token_start, token_length); + ASSERT_EQ(tmp_str, test_token_1[index]); + ++index; + } + + pos = 0; + index = 0; + token_start = 0; + token_length = 0; + std::string test_str_2 = "StandardToken:分词器,可以跳过ASCII符号.,/!@#$@()-空格等并整块切分english token,123456789和单个切分中文"; + std::vector test_token_2 = { + "StandardToken","分","词","器","可","以","跳","过","ASCII", + "符","号","空","格","等","并","整","块","切","分", + "english","token","123456789","和","单","个","切","分","中","文" + }; + + while(tokenizer.nextInString( + test_str_2.data(), test_str_2.length(), &pos, &token_start, &token_length)) + { + auto tmp_str = std::string(test_str_2.data()+token_start, token_length); + ASSERT_EQ(tmp_str, test_token_2[index]); + ++index; + } +} + +TEST(TokenExtractor, StandardTokenLike) +{ + size_t pos = 0; + StandardTokenExtractor tokenizer; + std::string tmp_token; + size_t index = 0; + + pos = 0; + std::string test_str_1 = "%NOTOKEN%"; + while(tokenizer.nextInStringLike(test_str_1.data(), test_str_1.length(), &pos, tmp_token)) + { + if(!tmp_token.empty()) + { + throw std::runtime_error("should no token here"); + } + } + + pos = 0; + std::string test_str_2 = "%NOTOKEN"; + while(tokenizer.nextInStringLike(test_str_2.data(), test_str_2.length(), &pos, tmp_token)) + { + if(!tmp_token.empty()) + { + throw std::runtime_error("should no token here"); + } + } + + pos = 0; + std::string test_str_3 = "NOTOKEN%"; + while(tokenizer.nextInStringLike(test_str_3.data(), test_str_3.length(), &pos, tmp_token)) + { + if(!tmp_token.empty()) + { + throw std::runtime_error("should no token here"); + } + } + + + pos = 0; + std::string test_str_4 = "NO_TOKEN"; 
+ while(tokenizer.nextInStringLike(test_str_4.data(), test_str_4.length(), &pos, tmp_token)) + { + if(!tmp_token.empty()) + { + throw std::runtime_error("should no token here"); + } + } + + index = 0; + pos = 0; + std::string test_str_5 = "%这里_会有中文token%"; + std::vector test_tokens_5 = {"这","里","会","有","中","文"}; + while(tokenizer.nextInStringLike(test_str_5.data(), test_str_5.length(), &pos, tmp_token)) + { + ASSERT_EQ(tmp_token, test_tokens_5[index]); + index++; + } + + index = 0; + pos = 0; + std::string test_str_6 = "%这里_,english %Token也有%"; + std::vector test_tokens_6 = {"这","里","english","也","有"}; + while(tokenizer.nextInStringLike(test_str_6.data(), test_str_6.length(), &pos, tmp_token)) + { + ASSERT_EQ(tmp_token, test_tokens_6[index]); + index++; + } +} diff --git a/src/Interpreters/trySetVirtualWarehouse.cpp b/src/Interpreters/trySetVirtualWarehouse.cpp index ee64dbd2330..31b10baa4a3 100644 --- a/src/Interpreters/trySetVirtualWarehouse.cpp +++ b/src/Interpreters/trySetVirtualWarehouse.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -68,8 +69,9 @@ static bool trySetVirtualWarehouseFromStorageID(const StorageID table_id, Contex LOG_DEBUG( &Poco::Logger::get("trySetVirtualWarehouse"), - "try get warehouse from {}, type is WRITE {}", + "set vw to {} from cnch table {}, type is WRITE {}", vw_name, + table_id.getNameForLogs(), VirtualWarehouseType::Write == vw_type); setVirtualWarehouseByName(vw_name, context); return true; @@ -111,8 +113,9 @@ static bool trySetVirtualWarehouseFromStorageID(const StorageID table_id, Contex LOG_DEBUG( &Poco::Logger::get("trySetVirtualWarehouse"), - "try get warehouse from {}, type is WRITE {}", + "set vw to {} from nested cnch table {}, type is WRITE {}", nested_vw_name, + nested_table->getStorageID().getNameForLogs(), VirtualWarehouseType::Write == vw_type); setVirtualWarehouseByName(nested_vw_name, context); return true; @@ -294,6 +297,12 @@ static bool trySetVirtualWarehouseFromAST(const 
ASTPtr & ast, ContextMutablePtr if (trySetVirtualWarehouseFromTable(database, refresh_mv->table, context)) return true; } + else if (auto * create = ast->as()) + { + /// No need to set vw for create query. + /// For CTAS, the data filling work is implemented as ASTInsertQuery (insert select) + return false; + } } while (false); diff --git a/src/MergeTreeCommon/CnchStorageCommon.cpp b/src/MergeTreeCommon/CnchStorageCommon.cpp index 49d97236cc8..e0800c45d3d 100644 --- a/src/MergeTreeCommon/CnchStorageCommon.cpp +++ b/src/MergeTreeCommon/CnchStorageCommon.cpp @@ -18,7 +18,6 @@ #include #include -#include #include #include #include @@ -303,7 +302,7 @@ void CnchStorageCommonHelper::filterCondition( String CnchStorageCommonHelper::getCreateQueryForCloudTable( const String & query, const String & local_table_name, - const ContextPtr & context, + const ContextPtr & /*context*/, bool enable_staging_area, const std::optional & cnch_storage_id, const Strings & engine_args, @@ -318,63 +317,19 @@ String CnchStorageCommonHelper::getCreateQueryForCloudTable( if (!local_database_name.empty()) create_query.database = local_database_name; - auto * storage = create_query.storage; + replaceCnchWithCloud( + create_query.storage, + cnch_storage_id.value_or(table_id).getDatabaseName(), + cnch_storage_id.value_or(table_id).getTableName(), + engine_type, + engine_args); - auto engine = std::make_shared(); - engine->name = storage->engine->name.replace(0, strlen("Cnch"), "Cloud"); - engine->arguments = std::make_shared(); - engine->arguments->children.emplace_back(std::make_shared(cnch_storage_id.value_or(table_id).getDatabaseName())); - engine->arguments->children.emplace_back(std::make_shared(cnch_storage_id.value_or(table_id).getTableName())); - if (!engine_args.empty()) - { - for (const auto & arg : engine_args) - { - engine->arguments->children.emplace_back(std::make_shared(arg)); - } - } - else if (storage->engine->arguments) - { - for (const auto & arg : 
storage->engine->arguments->children) - { - engine->arguments->children.push_back(arg); - } - } - - storage->set(storage->engine, engine); - - if (startsWith(engine->name, "Cloud")) /// table settings for *MergeTree engines - { - modifyOrAddSetting(create_query, "cnch_temporary_table", Field(UInt64(1))); - - if (enable_staging_area) - modifyOrAddSetting(create_query, "cloud_enable_staging_area", Field(UInt64(1))); - } - else if(engine->name == "CnchHive" || engine->name == "CnchHDFS" || engine->name == "CnchS3") - { - modifyOrAddSetting(create_query, "cnch_temporary_table", Field(UInt64(1))); - } - - /// query settings - auto query_settings = std::make_shared(); - query_settings->is_standalone = false; - - if (context) - query_settings->changes = context->getSettingsRef().getChangedSettings(); - - if (create_query.settings_ast) - { - auto & settings_ast = create_query.settings_ast->as(); - if (!query_settings->changes.empty()) - { - for (const auto & change: settings_ast.changes) - modifyOrAddSetting(*query_settings, change.name, std::move(change.value)); - } - else - query_settings->changes = std::move(settings_ast.changes); - } + // perhaps better to enable if_not_exists by default + if (engine_type == WorkerEngineType::DICT) + create_query.if_not_exists = true; - if (!query_settings->changes.empty()) - create_query.setOrReplaceAST(create_query.settings_ast, query_settings); + if (enable_staging_area) + modifyOrAddSetting(create_query, "cloud_enable_staging_area", Field(UInt64(1))); WriteBufferFromOwnString statement_buf; formatAST(create_query, statement_buf, false); diff --git a/src/MergeTreeCommon/CnchStorageCommon.h b/src/MergeTreeCommon/CnchStorageCommon.h index c610fde086d..59542bc252e 100644 --- a/src/MergeTreeCommon/CnchStorageCommon.h +++ b/src/MergeTreeCommon/CnchStorageCommon.h @@ -15,6 +15,7 @@ #pragma once +#include #include #include #include @@ -69,23 +70,6 @@ enum class CNCHStorageMediumType String toStr(CNCHStorageMediumType tp); 
CNCHStorageMediumType fromStr(const String & type_str); -enum class WorkerEngineType : uint8_t -{ - CLOUD, - DICT, -}; - -inline static String toString(WorkerEngineType type) -{ - switch (type) - { - case WorkerEngineType::CLOUD: - return "Cloud"; - case WorkerEngineType::DICT: - return "DictCloud"; - } -} - class CnchStorageCommonHelper { public: @@ -121,6 +105,8 @@ class CnchStorageCommonHelper // when move these conditions from where to implicit_where. static ASTs getConditions(const ASTPtr & ast); + // TODO: too many arguments, try remove `enable_staging_area', `cnch_storage_id', `engine_args', `local_database_name'. + // check StorageCnchMergeTree::genViewDependencyCreateQueries to see whether it's possible String getCreateQueryForCloudTable( const String & query, const String & local_table_name, diff --git a/src/MergeTreeCommon/GlobalDataManager.cpp b/src/MergeTreeCommon/GlobalDataManager.cpp index 574d47ca695..e73839f5b25 100644 --- a/src/MergeTreeCommon/GlobalDataManager.cpp +++ b/src/MergeTreeCommon/GlobalDataManager.cpp @@ -13,11 +13,11 @@ void GlobalDataManager::loadDataPartsWithDBM( const UUID & storage_uuid, const UInt64 table_version, const WGWorkerInfoPtr & runtime_worker_info, - ServerDataPartsWithDBM & server_parts) + std::unordered_map & server_parts, + std::vector> & partitions) { auto storage_manager = getStorageDataManager(storage_uuid, runtime_worker_info); - - return storage_manager->loadDataPartsWithDBM(storage, table_version, server_parts); + return storage_manager->loadDataPartsWithDBM(storage, table_version, server_parts, partitions); } StorageDataManagerPtr GlobalDataManager::getStorageDataManager(const UUID & storage_uuid, const WGWorkerInfoPtr & runtime_worker_info) diff --git a/src/MergeTreeCommon/GlobalDataManager.h b/src/MergeTreeCommon/GlobalDataManager.h index 00f9abf6c9a..e22aa2eb43d 100644 --- a/src/MergeTreeCommon/GlobalDataManager.h +++ b/src/MergeTreeCommon/GlobalDataManager.h @@ -15,7 +15,8 @@ class GlobalDataManager : 
public WithContext const UUID & storage_uuid, const UInt64 table_version, const WGWorkerInfoPtr & runtime_worker_info, - ServerDataPartsWithDBM & server_parts); + std::unordered_map & server_parts, + std::vector> & partitions); private: diff --git a/src/MergeTreeCommon/MergeTreeDataDeduper.cpp b/src/MergeTreeCommon/MergeTreeDataDeduper.cpp index ad4d16956be..4dd987cb7dd 100644 --- a/src/MergeTreeCommon/MergeTreeDataDeduper.cpp +++ b/src/MergeTreeCommon/MergeTreeDataDeduper.cpp @@ -22,6 +22,7 @@ #include #include #include +#include namespace DB { @@ -37,8 +38,9 @@ namespace ErrorCodes using IndexFileIteratorPtr = std::unique_ptr; using IndexFileIterators = std::vector; -MergeTreeDataDeduper::MergeTreeDataDeduper(const MergeTreeMetaBase & data_, ContextPtr context_) - : data(data_), context(context_), log(&Poco::Logger::get(data_.getLogName() + " (Deduper)")) +MergeTreeDataDeduper::MergeTreeDataDeduper( + const MergeTreeMetaBase & data_, ContextPtr context_, const CnchDedupHelper::DedupMode & dedup_mode_) + : data(data_), context(context_), log(&Poco::Logger::get(data_.getLogName() + " (Deduper)")), dedup_mode(dedup_mode_) { if (data.merging_params.hasExplicitVersionColumn()) version_mode = VersionMode::ExplicitVersion; @@ -46,6 +48,12 @@ MergeTreeDataDeduper::MergeTreeDataDeduper(const MergeTreeMetaBase & data_, Cont version_mode = VersionMode::PartitionValueAsVersion; else version_mode = VersionMode::NoVersion; + + if (dedup_mode == CnchDedupHelper::DedupMode::APPEND) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Dedup mode in dedup process is APPEND for table {}, it's a bug!", + data.getCnchStorageID().getNameForLogs()); } namespace @@ -108,8 +116,7 @@ void MergeTreeDataDeduper::dedupKeysWithParts( const IMergeTreeDataPartsVector & parts, DeleteBitmapVector & delta_bitmaps, DedupTaskProgressReporter reporter, - DedupTaskPtr & dedup_task, - DedupKeyMode dedup_key_mode) + DedupTaskPtr & dedup_task) { const IndexFile::Comparator * comparator = 
IndexFile::BytewiseComparator(); @@ -209,7 +216,7 @@ void MergeTreeDataDeduper::dedupKeysWithParts( else { exact_match = true; - if (dedup_key_mode == DedupKeyMode::THROW) + if (dedup_mode == CnchDedupHelper::DedupMode::THROW) throw Exception("Found duplication when insert with setting dedup_key_mode=DedupKeyMode::THROW", ErrorCodes::INCORRECT_DATA); } @@ -217,7 +224,7 @@ void MergeTreeDataDeduper::dedupKeysWithParts( { RowPos lhs = ReplacingSortedKeysIterator::decodeCurrentRowPos(base_iter, version_mode, parts, base_implicit_versions); const RowPos & rhs = keys->CurrentRowPos(); - if (keys->IsCurrentLowPriority()) + if (keys->IsCurrentLowPriority() || dedup_mode == CnchDedupHelper::DedupMode::IGNORE) addRowIdToBitmap(delta_bitmaps[rhs.child + parts.size()], rhs.rowid); else { @@ -350,8 +357,7 @@ LocalDeleteBitmaps MergeTreeDataDeduper::dedupParts( txn_id.toUInt64()); if (base_bitmap) { - UInt64 bitmap_version = new_parts[i - visible_parts.size()]->getDeleteBitmapVersion(); - if (bitmap_version == txn_id.toUInt64()) + if (new_parts[i - visible_parts.size()]->delete_flag) { LOG_TRACE( log, @@ -442,20 +448,22 @@ LocalDeleteBitmaps MergeTreeDataDeduper::dedupParts( size_t num_bitmaps_to_dump = prepare_bitmaps_to_dump(visible_parts, new_parts, bitmaps); LOG_DEBUG( log, - "Dedup {} in {} ms, visible parts={}, new parts={}, result bitmaps={}, txn_id: {}", + "Dedup {} in {} ms, visible parts={}, new parts={}, result bitmaps={}, txn_id: {}, dedup mode: {}", dedup_task_local->getDedupLevelInfo(), sub_task_watch.elapsedMilliseconds(), visible_parts.size(), new_parts.size(), num_bitmaps_to_dump, - txn_id.toUInt64()); + txn_id.toUInt64(), + CnchDedupHelper::typeToString(dedup_mode)); }); } dedup_pool.wait(); LOG_DEBUG( log, - "Dedup {} tasks in {} ms, thread pool={}, visible parts={}, staged parts={}, uncommitted_parts = {}, result bitmaps={}, txn_id: {}", + "Dedup {} tasks in {} ms, thread pool={}, visible parts={}, staged parts={}, uncommitted_parts = {}, result bitmaps={}, 
txn_id: " + "{}, dedup mode: {}", dedup_tasks.size(), watch.elapsedMilliseconds(), dedup_pool_size, @@ -463,7 +471,8 @@ LocalDeleteBitmaps MergeTreeDataDeduper::dedupParts( all_staged_parts.size(), all_uncommitted_parts.size(), res.size(), - txn_id.toUInt64()); + txn_id.toUInt64(), + CnchDedupHelper::typeToString(dedup_mode)); return res; } @@ -722,7 +731,7 @@ MergeTreeDataDeduper::dedupImpl(const IMergeTreeDataPartsVector & visible_parts, return os.str(); }; - dedupKeysWithParts(dedup_task->iter, visible_parts, res, task_progress_reporter, dedup_task, context->getSettings().dedup_key_mode); + dedupKeysWithParts(dedup_task->iter, visible_parts, res, task_progress_reporter, dedup_task); return res; } diff --git a/src/MergeTreeCommon/MergeTreeDataDeduper.h b/src/MergeTreeCommon/MergeTreeDataDeduper.h index 857464180e8..d57ad9c8b01 100644 --- a/src/MergeTreeCommon/MergeTreeDataDeduper.h +++ b/src/MergeTreeCommon/MergeTreeDataDeduper.h @@ -19,6 +19,7 @@ #include #include #include +#include namespace DB { @@ -34,7 +35,10 @@ class MergeTreeDataDeduper using RowPos = ReplacingSortedKeysIterator::RowPos; using DeleteCallback = ReplacingSortedKeysIterator::DeleteCallback; - MergeTreeDataDeduper(const MergeTreeMetaBase & data_, ContextPtr context_); + MergeTreeDataDeduper( + const MergeTreeMetaBase & data_, + ContextPtr context_, + const CnchDedupHelper::DedupMode & dedup_mode_); /// Remove duplicate keys among visible, staged, and uncommitted parts. /// Assumes that @@ -102,8 +106,7 @@ class MergeTreeDataDeduper const IMergeTreeDataPartsVector & parts, DeleteBitmapVector & delta_bitmaps, DedupTaskProgressReporter reporter, - DedupTaskPtr & dedup_task, - DedupKeyMode dedup_key_mode = DedupKeyMode::REPLACE); + DedupTaskPtr & dedup_task); /// Convert dedup task into multiple sub dedup tasks. If valid_bucket_table is true, it will split dedup task into bucket granule. 
DedupTasks convertIntoSubDedupTasks( @@ -124,6 +127,7 @@ class MergeTreeDataDeduper ContextPtr context; Poco::Logger * log; VersionMode version_mode; + CnchDedupHelper::DedupMode dedup_mode; }; } diff --git a/src/MergeTreeCommon/MergeTreeMetaBase.cpp b/src/MergeTreeCommon/MergeTreeMetaBase.cpp index 262fbd69d8f..51d33f22812 100644 --- a/src/MergeTreeCommon/MergeTreeMetaBase.cpp +++ b/src/MergeTreeCommon/MergeTreeMetaBase.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -44,10 +45,11 @@ #include #include #include +#include #include #include #include -#include +#include #include #include #include @@ -141,9 +143,7 @@ MergeTreeMetaBase::MergeTreeMetaBase( { try { - checkPartitionKeyAndInitMinMax(metadata_.partition_key); - setProperties(metadata_, metadata_, false); if (minmax_idx_date_column_pos == -1) throw Exception("Could not find Date column", ErrorCodes::BAD_TYPE_OF_FIELD); } @@ -158,39 +158,18 @@ MergeTreeMetaBase::MergeTreeMetaBase( { is_custom_partitioned = true; checkPartitionKeyAndInitMinMax(metadata_.partition_key); - setProperties(metadata_, metadata_, false); } - format_version = MERGE_TREE_DATA_MIN_FORMAT_VERSION_WITH_CUSTOM_PARTITIONING; + storage_address = fmt::format("{}", fmt::ptr(this)); /// NOTE: using the same columns list as is read when performing actual merges. 
- merging_params.check(metadata_, metadata_.hasUniqueKey()); - - if (merging_params.partitionValueAsVersion()) - { - if (metadata_.partition_key.sample_block.columns() == 0) - throw Exception("Table is not partitioned, can't use partition value as version", ErrorCodes::BAD_ARGUMENTS); - if (metadata_.partition_key.sample_block.columns() > 1) - throw Exception("Partition key contains more than one column, can't use it as version", ErrorCodes::BAD_ARGUMENTS); - auto partition_key_type = metadata_.partition_key.sample_block.getDataTypes()[0]; - if (!partition_key_type->canBeUsedAsVersion()) - throw Exception("Partition key has type " + partition_key_type->getName() + ", can't be used as version", ErrorCodes::BAD_ARGUMENTS); - } - - if (metadata_.hasUniqueKey() && !attach_) - checkVersionColumnConstraint(); + merging_params.check(metadata_, attach_); if (metadata_.sampling_key.definition_ast != nullptr) { /// This is for backward compatibility. checkSampleExpression(metadata_, getSettings()->compatibility_allow_sampling_expression_not_in_primary_key); } - - checkTTLExpressions(metadata_, metadata_); - - storage_address = fmt::format("{}", fmt::ptr(this)); - - setServerVwName(getSettings()->cnch_server_vw); } StoragePolicyPtr MergeTreeMetaBase::getStoragePolicy(StorageLocation location) const @@ -213,7 +192,7 @@ const String& MergeTreeMetaBase::getRelativeDataPath(StorageLocation location) c return relative_data_path; } -void MergeTreeMetaBase::setRelativeDataPath(StorageLocation location, const String& rel_path) +void MergeTreeMetaBase::setRelativeDataPath(StorageLocation location, const String & rel_path) { if (unlikely(location == StorageLocation::AUXILITY)) { @@ -692,6 +671,16 @@ String MergeTreeMetaBase::getFullPathOnDisk(StorageLocation location, const Disk return disk->getPath() + getRelativeDataPath(location); } +bool MergeTreeMetaBase::supportsParallelInsert(ContextPtr local_context) const +{ + if (!getInMemoryMetadataPtr()->hasUniqueKey()) + return true; + + 
if (!local_context->getSettingsRef().optimize_unique_table_write) + return false; + return getSettings()->dedup_impl_version.value == DedupImplVersion::DEDUP_IN_TXN_COMMIT; +} + NamesAndTypesList MergeTreeMetaBase::getVirtuals() const { /// Array(Tuple(String, String)) @@ -707,6 +696,7 @@ NamesAndTypesList MergeTreeMetaBase::getVirtuals() const NameAndTypePair("_partition_id", std::make_shared()), NameAndTypePair("_partition_value", getPartitionValueType()), NameAndTypePair("_sample_factor", std::make_shared()), + NameAndTypePair("_part_offset", std::make_shared()), NameAndTypePair("_part_row_number", std::make_shared()), NameAndTypePair("_bucket_number", std::make_shared()), RowExistsColumn::ROW_EXISTS_COLUMN, @@ -818,27 +808,34 @@ Block MergeTreeMetaBase::getBlockWithVirtualPartColumns(const DataPartsVector & return block; } -Block MergeTreeMetaBase::getBlockWithVirtualPartitionColumns( +Block MergeTreeMetaBase::getPartitionBlockWithVirtualColumns( const std::vector> & partition_list) const { + auto block = getInMemoryMetadataPtr()->partition_key.sample_block; DataTypePtr partition_value_type = getPartitionValueType(); - bool has_partition_value = typeid_cast(partition_value_type.get()); - Block block{ - ColumnWithTypeAndName(ColumnString::create(), std::make_shared(), "_partition_id"), - ColumnWithTypeAndName(partition_value_type->createColumn(), partition_value_type, "_partition_value")}; - + block.insert(ColumnWithTypeAndName(ColumnString::create(), std::make_shared(), "_partition_id")); + block.insert(ColumnWithTypeAndName(partition_value_type->createColumn(), partition_value_type, "_partition_value")); MutableColumns columns = block.mutateColumns(); - auto & partition_id_column = columns[0]; - auto & partition_value_column = columns[1]; + bool has_partition_value = typeid_cast(partition_value_type.get()); + auto block_size = block.columns(); + + auto & partition_id_column = columns[block_size-2]; + auto & partition_value_column = columns[block_size-1]; + + 
std::for_each(columns.begin(), columns.end(), [&](auto & column) { column->reserve(partition_list.size()); }); for (const auto & partition : partition_list) { partition_id_column->insert(partition->getID(*this)); - Tuple tuple(partition->value.begin(), partition->value.end()); if (has_partition_value) + { + for (size_t i = 0; i < partition->value.size(); i++) + columns[i]->insert(partition->value[i]); + Tuple tuple(partition->value.begin(), partition->value.end()); partition_value_column->insert(std::move(tuple)); + } } block.setColumns(std::move(columns)); if (!has_partition_value) @@ -1070,33 +1067,6 @@ MergeTreeMetaBase::getDataPartsVectorInPartition(MergeTreeMetaBase::DataPartStat data_parts_by_state_and_info.lower_bound(state_with_partition), data_parts_by_state_and_info.upper_bound(state_with_partition)); } -ServerDataPartsVector MergeTreeMetaBase::getServerDataPartsInPartitions(const Strings & required_partitions) -{ - ServerDataPartsVector server_parts; - DeleteBitmapMetaPtrVector delete_bitmaps; - { - auto lock = lockPartsRead(); - for (const String & partition_id : required_partitions) - { - const auto & parts_with_dbm = server_data_parts[partition_id]; - server_parts.insert(server_parts.end(), parts_with_dbm.first.begin(), parts_with_dbm.first.end()); - delete_bitmaps.insert(delete_bitmaps.end(), parts_with_dbm.second.begin(), parts_with_dbm.second.end()); - } - } - auto visible_server_parts = CnchPartsHelper::calcVisibleParts(server_parts, false, CnchPartsHelper::LoggingOption::DisableLogging, true); - - if (getInMemoryMetadataPtr()->hasUniqueKey() && !visible_server_parts.empty()) - getDeleteBitmapMetaForServerParts(visible_server_parts, delete_bitmaps); - - return visible_server_parts; -} - -MergeTreeMetaBase::MergeTreePartitions MergeTreeMetaBase::getAllPartitions() const -{ - auto lock = lockPartsRead(); - return data_partitions; -} - MergeTreeMetaBase::DataParts MergeTreeMetaBase::getDataParts() const { return 
getDataParts({DataPartState::Committed}); @@ -1432,22 +1402,9 @@ MergeTreeMetaBase::DataPartPtr MergeTreeMetaBase::getAnyPartInPartition( return nullptr; } -void MergeTreeMetaBase::checkVersionColumnConstraint() -{ - if (merging_params.partitionValueAsVersion()) - { - auto partition_types = getInMemoryMetadataPtr()->partition_key.sample_block.getDataTypes(); - if (partition_types.size() >= 1) - { - auto & type = partition_types[0]; - if (TypeIndex::UInt64 < type->getTypeId() && type->getTypeId() <= TypeIndex::Int256) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "The type of version column is {}, it is not compatible with UInt64", type->getName()); - } - } -} - -void MergeTreeMetaBase::MergingParams::check(const StorageInMemoryMetadata & metadata, bool has_unique_key) const +void MergeTreeMetaBase::MergingParams::check(const StorageInMemoryMetadata & metadata, bool attach) const { + const bool has_unique_key = metadata.hasUniqueKey(); const auto columns = metadata.getColumns().getAllPhysical(); if (!sign_column.empty() && mode != MergingParams::Collapsing && mode != MergingParams::VersionedCollapsing) @@ -1551,8 +1508,31 @@ void MergeTreeMetaBase::MergingParams::check(const StorageInMemoryMetadata & met } } - if (has_unique_key && !partitionValueAsVersion()) - check_version_column(true, "Unique Key"); + + if (has_unique_key) + { + if (partitionValueAsVersion()) + { + if (metadata.partition_key.sample_block.columns() == 0) + throw Exception("Table is not partitioned, can't use partition value as version", ErrorCodes::BAD_ARGUMENTS); + if (metadata.partition_key.sample_block.columns() > 1) + throw Exception("Partition key contains more than one column, can't use it as version", ErrorCodes::BAD_ARGUMENTS); + auto partition_key_type = metadata.partition_key.sample_block.getDataTypes()[0]; + if (!partition_key_type->canBeUsedAsVersion()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Partition key has type {}, can't be used as version", partition_key_type->getName()); + 
// signed integer and types larger than 64 bits are not supported currently + if (!attach && TypeIndex::UInt64 < partition_key_type->getTypeId() && partition_key_type->getTypeId() <= TypeIndex::Int256) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Partition key has type {}, can't be used as version", partition_key_type->getName()); + } + else + { + check_version_column(true, "Unique Key"); + } + } + else if (partitionValueAsVersion()) + { + throw Exception("Table doesn't have UNIQUE KEY, can't use partition value as version", ErrorCodes::BAD_ARGUMENTS); + } if (mode == MergingParams::Replacing) check_version_column(true, "ReplacingMergeTree"); @@ -1919,15 +1899,33 @@ void MergeTreeMetaBase::checkColumnsValidity(const ColumnsDescription & columns, } } -bool MergeTreeMetaBase::commitTxnFromWorkerSide(const StorageMetadataPtr & metadata_snapshot, ContextPtr query_context) const +bool MergeTreeMetaBase::commitTxnInWriteSuffixStage(const UInt32 & deup_impl_version, ContextPtr query_context) const { - if (!metadata_snapshot->hasUniqueKey()) + if (!getInMemoryMetadataPtr()->hasUniqueKey() || static_cast(deup_impl_version) == DedupImplVersion::DEDUP_IN_TXN_COMMIT) return false; bool enable_staging_area = query_context->getSettingsRef().enable_staging_area_for_write || getSettings()->cloud_enable_staging_area; bool enable_append_mode = query_context->getSettingsRef().dedup_key_mode == DedupKeyMode::APPEND; return !enable_append_mode && !enable_staging_area; } +bool MergeTreeMetaBase::supportsWriteInWorkers(const Context & query_context) const +{ + if (!getInMemoryMetadataPtr()->hasUniqueKey()) + return true; + const auto & query_settings = query_context.getSettingsRef(); + const auto & table_settings = getSettings(); + if (query_settings.optimize_unique_table_write) + { + if (table_settings->dedup_impl_version.value == DedupImplVersion::DEDUP_IN_TXN_COMMIT) + return true; + LOG_DEBUG(log, "Can not write in workers due to dedup impl version is {}", 
table_settings->dedup_impl_version.value); + } + bool enable_staging_area = query_context.getSettingsRef().enable_staging_area_for_write || getSettings()->cloud_enable_staging_area; + bool enable_append_mode = query_context.getSettingsRef().dedup_key_mode == DedupKeyMode::APPEND; + return enable_append_mode || enable_staging_area; +} + + ColumnSize MergeTreeMetaBase::getMapColumnSizes(const DataPartPtr & part, const String & map_implicit_column_name) const { auto part_checksums = part->getChecksums(); @@ -2100,12 +2098,18 @@ ASTPtr MergeTreeMetaBase::applyFilter( return PredicateUtils::combineConjuncts(conjuncts); } -void MergeTreeMetaBase::filterPartitionByTTL(std::vector> & partition_list, ContextPtr local_context) const +bool MergeTreeMetaBase::canFilterPartitionByTTL() const { auto metadata_snapshot = getInMemoryMetadataPtr(); TTLTableDescription table_ttl = metadata_snapshot->getTableTTLs(); - if (metadata_snapshot->hasPartitionLevelTTL() && table_ttl.definition_ast && local_context->getCurrentTransaction()) + return metadata_snapshot->hasPartitionLevelTTL() && table_ttl.definition_ast; +} + +void MergeTreeMetaBase::filterPartitionByTTL(std::vector> & partition_list, time_t query_time) const +{ + if (canFilterPartitionByTTL()) { + const auto & metadata_snapshot = getInMemoryMetadataPtr(); if (!metadata_snapshot->hasRowsTTL()) return; @@ -2142,8 +2146,9 @@ void MergeTreeMetaBase::filterPartitionByTTL(std::vectorsize() != partition_list.size()) throw Exception("Calculated TTL column size cannot match input partitions column size.", ErrorCodes::LOGICAL_ERROR); - TxnTimestamp start_ts = local_context->getCurrentTransactionID(); - time_t query_time = start_ts.toSecond(); + if (query_time == 0) + query_time = std::time(nullptr); + std::vector> filtered_result; if (column->isNullable()) @@ -2151,7 +2156,7 @@ void MergeTreeMetaBase::filterPartitionByTTL(std::vector(column)) { - const auto & date_lut = DateLUT::instance(); + const auto & date_lut = 
DateLUT::serverTimezoneInstance(); for (size_t index = 0; index < column->size(); index++) { auto ttl_value = date_lut.fromDayNum(DayNum(column_date->getElement(index))); @@ -2176,7 +2181,7 @@ void MergeTreeMetaBase::filterPartitionByTTL(std::vector(column)) // { - // const auto & date_lut = DateLUT::instance(); + // const auto & date_lut = DateLUT::serverTimezoneInstance(); // ttl_value = date_lut.fromDayNum(DayNum(column_date->getElement(index))); // } // else if (const ColumnUInt32 * column_date_time = typeid_cast(column)) @@ -2200,7 +2205,8 @@ Strings MergeTreeMetaBase::selectPartitionsByPredicate( const SelectQueryInfo & query_info, std::vector> & partition_list, const Names & column_names_to_return, - ContextPtr local_context) const + ContextPtr local_context, + const bool & ignore_ttl) const { /// Coarse grained partition pruner: filter out the partition which will definitely not satisfy the query predicate. The benefit /// is 2-folded: (1) we can prune data parts and (2) we can reduce numbers of calls to catalog to get parts 's metadata. 
@@ -2213,12 +2219,13 @@ Strings MergeTreeMetaBase::selectPartitionsByPredicate( /// (3) `_partition_id` or `_partition_value` if they're in predicate /// (1) Prune partition by partition level TTL - filterPartitionByTTL(partition_list, local_context); + if (!ignore_ttl) + filterPartitionByTTL(partition_list, local_context->tryGetCurrentTransactionID().toSecond()); const auto partition_key = MergeTreePartition::adjustPartitionKey(getInMemoryMetadataPtr(), local_context); const auto & partition_key_expr = partition_key.expression; const auto & partition_key_sample = partition_key.sample_block; - if (local_context->getSettingsRef().enable_partition_prune && partition_key_sample.columns() > 0) + if (partition_key_sample.columns() > 0) { /// (2) Prune partitions if there's a column in predicate that exactly match the partition key Names partition_key_columns; @@ -2258,7 +2265,7 @@ Strings MergeTreeMetaBase::selectPartitionsByPredicate( if (has_partition_column && !partition_list.empty()) { - Block partition_block = getBlockWithVirtualPartitionColumns(partition_list); + Block partition_block = getPartitionBlockWithVirtualColumns(partition_list); ASTPtr expression_ast; /// Generate valid expressions for filtering @@ -2268,6 +2275,7 @@ Strings MergeTreeMetaBase::selectPartitionsByPredicate( NameSet partition_ids; if (expression_ast) { + replace_func_with_known_column(expression_ast, NameSet{partition_key_columns.begin(), partition_key_columns.end()}); VirtualColumnUtils::filterBlockWithQuery(query_info.query, partition_block, local_context, expression_ast); partition_ids = VirtualColumnUtils::extractSingleValueFromBlock(partition_block, "_partition_id"); /// Prunning @@ -2290,7 +2298,7 @@ Strings MergeTreeMetaBase::selectPartitionsByPredicate( return res_partitions; } -void MergeTreeMetaBase::getDeleteBitmapMetaForServerParts(const ServerDataPartsVector & parts, DeleteBitmapMetaPtrVector & all_bitmaps) const +void MergeTreeMetaBase::getDeleteBitmapMetaForServerParts(const 
ServerDataPartsVector & parts, DeleteBitmapMetaPtrVector & all_bitmaps, bool force_found) const { DeleteBitmapMetaPtrVector bitmaps; CnchPartsHelper::calcVisibleDeleteBitmaps(all_bitmaps, bitmaps); @@ -2299,50 +2307,98 @@ void MergeTreeMetaBase::getDeleteBitmapMetaForServerParts(const ServerDataPartsV auto bitmap_it = bitmaps.begin(); for (const auto & part : parts) { - /// search for the first bitmap - while (bitmap_it != bitmaps.end() && !(*bitmap_it)->sameBlock(part->info())) - bitmap_it++; - - if (bitmap_it == bitmaps.end()) + if (force_found) { - if (auto unique_table_log = getContext()->getCloudUniqueTableLog()) + /// search for the first bitmap + while (bitmap_it != bitmaps.end() && !(*bitmap_it)->sameBlock(part->info())) + bitmap_it++; + + if (bitmap_it == bitmaps.end()) { - auto current_log = UniqueTable::createUniqueTableLog(UniqueTableLogElement::ERROR, getCnchStorageID()); - current_log.metric = ErrorCodes::LOGICAL_ERROR; - current_log.event_msg = "Delete bitmap metadata of " + part->name() + " is not found"; - unique_table_log->add(current_log); + if (auto unique_table_log = getContext()->getCloudUniqueTableLog()) + { + auto current_log = UniqueTable::createUniqueTableLog(UniqueTableLogElement::ERROR, getCnchStorageID()); + current_log.metric = ErrorCodes::LOGICAL_ERROR; + current_log.event_msg = "Delete bitmap metadata of " + part->name() + " is not found"; + unique_table_log->add(current_log); + } + throw Exception("Delete bitmap metadata of " + part->name() + " is not found", ErrorCodes::LOGICAL_ERROR); } - throw Exception("Delete bitmap metadata of " + part->name() + " is not found", ErrorCodes::LOGICAL_ERROR); - } - /// add all visible bitmaps (from new to old) part - bool found_base = false; - auto list_it = part->delete_bitmap_metas.before_begin(); - for (auto bitmap_meta = *bitmap_it; bitmap_meta; bitmap_meta = bitmap_meta->tryGetPrevious()) - { - list_it = part->delete_bitmap_metas.insert_after(list_it, bitmap_meta->getModel()); - if 
(bitmap_meta->getType() == DeleteBitmapMetaType::Base) + /// add all visible bitmaps (from new to old) part + bool found_base = false; + auto list_it = part->delete_bitmap_metas.before_begin(); + for (auto bitmap_meta = *bitmap_it; bitmap_meta; bitmap_meta = bitmap_meta->tryGetPrevious()) { - found_base = true; - break; + list_it = part->delete_bitmap_metas.insert_after(list_it, bitmap_meta->getModel()); + if (bitmap_meta->getType() == DeleteBitmapMetaType::Base) + { + found_base = true; + break; + } + } + if (!found_base) + { + if (auto unique_table_log = getContext()->getCloudUniqueTableLog()) + { + auto current_log = UniqueTable::createUniqueTableLog(UniqueTableLogElement::ERROR, getCnchStorageID()); + current_log.metric = ErrorCodes::LOGICAL_ERROR; + current_log.event_msg = "Base delete bitmap of " + part->name() + " is not found"; + unique_table_log->add(current_log); + } + throw Exception("Base delete bitmap of " + part->name() + " is not found", ErrorCodes::LOGICAL_ERROR); } + + bitmap_it++; } - if (!found_base) + else { - if (auto unique_table_log = getContext()->getCloudUniqueTableLog()) + while (bitmap_it != bitmaps.end() && (*(*bitmap_it)) <= part->info()) { - auto current_log = UniqueTable::createUniqueTableLog(UniqueTableLogElement::ERROR, getCnchStorageID()); - current_log.metric = ErrorCodes::LOGICAL_ERROR; - current_log.event_msg = "Base delete bitmap of " + part->name() + " is not found"; - unique_table_log->add(current_log); + if (!(*bitmap_it)->sameBlock(part->info())) + bitmap_it++; + else + { + /// add all visible bitmaps (from new to old) part + bool found_base = false; + auto list_it = part->delete_bitmap_metas.before_begin(); + for (auto bitmap_meta = *bitmap_it; bitmap_meta; bitmap_meta = bitmap_meta->tryGetPrevious()) + { + list_it = part->delete_bitmap_metas.insert_after(list_it, bitmap_meta->getModel()); + if (bitmap_meta->getType() == DeleteBitmapMetaType::Base) + { + found_base = true; + break; + } + } + if (!found_base) + { + if (auto 
unique_table_log = getContext()->getCloudUniqueTableLog()) + { + auto current_log = UniqueTable::createUniqueTableLog(UniqueTableLogElement::ERROR, getCnchStorageID()); + current_log.metric = ErrorCodes::LOGICAL_ERROR; + current_log.event_msg = "Base delete bitmap of " + part->name() + " is not found"; + unique_table_log->add(current_log); + } + throw Exception("Base delete bitmap of " + part->name() + " is not found", ErrorCodes::LOGICAL_ERROR); + } + bitmap_it++; + } } - throw Exception("Base delete bitmap of " + part->name() + " is not found", ErrorCodes::LOGICAL_ERROR); } - bitmap_it++; } } +void MergeTreeMetaBase::getDeleteBitmapMetaForCnchParts(MutableMergeTreeDataPartsCNCHVector & parts, DeleteBitmapMetaPtrVector & all_bitmaps, bool force_found) +{ + MergeTreeDataPartsCNCHVector cnch_parts; + cnch_parts.reserve(parts.size()); + for (auto & part : parts) + cnch_parts.emplace_back(const_pointer_cast(part)); + getDeleteBitmapMetaForCnchParts(cnch_parts, all_bitmaps, force_found); +} + void MergeTreeMetaBase::getDeleteBitmapMetaForCnchParts(const MergeTreeDataPartsCNCHVector & parts, DeleteBitmapMetaPtrVector & all_bitmaps, bool force_found) { DeleteBitmapMetaPtrVector bitmaps; diff --git a/src/MergeTreeCommon/MergeTreeMetaBase.h b/src/MergeTreeCommon/MergeTreeMetaBase.h index f946aecb0d1..106dbd295b4 100644 --- a/src/MergeTreeCommon/MergeTreeMetaBase.h +++ b/src/MergeTreeCommon/MergeTreeMetaBase.h @@ -54,9 +54,6 @@ class MergeTreeMetaBase : public IStorage, public WithMutableContext, public Mer using MetaStorePtr = std::shared_ptr; - using MergeTreePartitions = std::vector>; - using ServerDataParts = std::unordered_map; - /// Alter conversions which should be applied on-fly for part. Build from of /// the most recent mutation commands for part. 
Now we have only rename_map /// here (from ALTER_RENAME) command, because for all other type of alters @@ -145,7 +142,7 @@ class MergeTreeMetaBase : public IStorage, public WithMutableContext, public Mer Graphite::Params graphite_params; /// Check that needed columns are present and have correct types. - void check(const StorageInMemoryMetadata & metadata, bool has_unique_key) const; + void check(const StorageInMemoryMetadata & metadata, bool attach) const; String getModeName() const; @@ -172,7 +169,7 @@ class MergeTreeMetaBase : public IStorage, public WithMutableContext, public Mer StoragePolicyPtr getStoragePolicy(StorageLocation location) const override; virtual const String& getRelativeDataPath(StorageLocation location) const; - virtual void setRelativeDataPath(StorageLocation location, const String& rel_path); + void setRelativeDataPath(StorageLocation location, const String & rel_path); bool supportsFinal() const override { @@ -189,6 +186,7 @@ class MergeTreeMetaBase : public IStorage, public WithMutableContext, public Mer bool supportsSampling() const override { return true; } bool supportsIndexForIn() const override { return true; } bool supportsMapImplicitColumn() const override { return true; } + bool supportsParallelInsert(ContextPtr local_context) const override; NamesAndTypesList getVirtuals() const override; @@ -204,6 +202,8 @@ class MergeTreeMetaBase : public IStorage, public WithMutableContext, public Mer /// If uuid is empty, throw exception UUID getCnchStorageUUID() const; + const MergingParams & getMergingParams() const { return merging_params; } + //// Data parts /// Returns a copy of the list so that the caller shouldn't worry about locks. 
DataParts getDataParts(const DataPartStates & affordable_states) const; @@ -221,14 +221,10 @@ class MergeTreeMetaBase : public IStorage, public WithMutableContext, public Mer /// Returns all parts in specified partition DataPartsVector getDataPartsVectorInPartition(DataPartState /*state*/, const String & /*partition_id*/) const; - MergeTreePartitions getAllPartitions() const; - /// Returns Committed parts DataParts getDataParts() const; DataPartsVector getDataPartsVector() const; - ServerDataPartsVector getServerDataPartsInPartitions(const Strings & required_partitions); - /// Returns the part with the given name and state or nullptr if no such part. DataPartPtr getPartIfExists(const String & part_name, const DataPartStates & valid_states); DataPartPtr getPartIfExistsWithoutLock(const String & part_name, const DataPartStates & valid_states); @@ -284,6 +280,12 @@ class MergeTreeMetaBase : public IStorage, public WithMutableContext, public Mer /// For ATTACH/DETACH/DROP PARTITION. String getPartitionIDFromQuery(const ASTPtr & ast, ContextPtr context) const; + bool extractNullableForPartitionID() const + { + const auto & settings = getSettings(); + return settings->allow_nullable_key && settings->extract_partition_nullable_date; + } + MutableDataPartPtr cloneAndLoadDataPartOnSameDisk(const DataPartPtr & src_part, const String & tmp_part_prefix, const MergeTreePartInfo & dst_part_info, const StorageMetadataPtr & metadata_snapshot); @@ -394,7 +396,7 @@ class MergeTreeMetaBase : public IStorage, public WithMutableContext, public Mer Block getSampleBlockWithVirtualColumns() const; - Block getBlockWithVirtualPartitionColumns(const std::vector> & partition_list) const; + Block getPartitionBlockWithVirtualColumns(const std::vector> & partition_list) const; /// Construct a block consisting only of possible virtual columns for part pruning. /// If one_part is true, fill in at most one part. 
@@ -435,9 +437,12 @@ class MergeTreeMetaBase : public IStorage, public WithMutableContext, public Mer virtual bool supportsOptimizer() const override { return true; } - bool commitTxnFromWorkerSide(const StorageMetadataPtr & metadata_snapshot, ContextPtr query_context) const; virtual bool supportIntermedicateResultCache() const override { return true; } + /// Just compatible with old impl for unique table + bool commitTxnInWriteSuffixStage(const UInt32 & deup_impl_version, ContextPtr query_context) const; + bool supportsWriteInWorkers(const Context & query_context) const; + ColumnSize calculateMapColumnSizesImpl(const String & map_implicit_column_name) const; void resetObjectColumns(const ColumnsDescription & object_columns_) { object_columns = object_columns_; } @@ -452,18 +457,21 @@ class MergeTreeMetaBase : public IStorage, public WithMutableContext, public Mer /// partition filters /// TODO: make partition_list constant - void filterPartitionByTTL(std::vector> & partition_list, ContextPtr local_context) const; + void filterPartitionByTTL(std::vector> & partition_list, time_t query_time) const; + bool canFilterPartitionByTTL() const; Strings selectPartitionsByPredicate( const SelectQueryInfo & query_info, std::vector> & partition_list, const Names & column_names_to_return, - ContextPtr local_context) const; + ContextPtr local_context, + const bool & ignore_ttl = false) const; /** * @param parts input parts, must be sorted in PartComparator order */ - void getDeleteBitmapMetaForServerParts(const ServerDataPartsVector & parts, DeleteBitmapMetaPtrVector & delete_bitmap_metas) const; + void getDeleteBitmapMetaForServerParts(const ServerDataPartsVector & parts, DeleteBitmapMetaPtrVector & delete_bitmap_metas, bool force_found = true) const; + void getDeleteBitmapMetaForCnchParts(MutableMergeTreeDataPartsCNCHVector & parts, DeleteBitmapMetaPtrVector & delete_bitmap_metas, bool force_found = true); void getDeleteBitmapMetaForCnchParts(const MergeTreeDataPartsCNCHVector 
& parts, DeleteBitmapMetaPtrVector & delete_bitmap_metas, bool force_found = true); void getDeleteBitmapMetaForParts(IMergeTreeDataPartsVector & parts, DeleteBitmapMetaPtrVector & delete_bitmap_metas, bool force_found = true); void getDeleteBitmapMetaForStagedParts(const MergeTreeDataPartsCNCHVector & parts, ContextPtr context, TxnTimestamp start_time); @@ -594,9 +602,6 @@ class MergeTreeMetaBase : public IStorage, public WithMutableContext, public Mer void checkProperties(const StorageInMemoryMetadata & new_metadata, const StorageInMemoryMetadata & old_metadata, bool attach = false) const; - /// Check version column constrains when create table - void checkVersionColumnConstraint(); - void setProperties(const StorageInMemoryMetadata & new_metadata, const StorageInMemoryMetadata & old_metadata, bool attach = false); void checkPartitionKeyAndInitMinMax(const KeyDescription & new_partition_key); @@ -622,14 +627,6 @@ class MergeTreeMetaBase : public IStorage, public WithMutableContext, public Mer /// Returns default settings for storage with possible changes from global config. virtual std::unique_ptr getDefaultSettings() const = 0; - /// track runtime server parts by partition id. Used when query by table version - MergeTreePartitions data_partitions; - // Server dataparts with delete bitmap. should be protected by data part lock - ServerDataParts server_data_parts; - - mutable std::mutex server_data_mutex; - mutable std::atomic has_server_part_to_load{false}; - private: // Record all query ids which access the table. It's guarded by `query_id_set_mutex` and is always mutable. 
mutable std::set query_id_set; diff --git a/src/MergeTreeCommon/StorageDataManager.cpp b/src/MergeTreeCommon/StorageDataManager.cpp index 7996a008b81..9723d36363f 100644 --- a/src/MergeTreeCommon/StorageDataManager.cpp +++ b/src/MergeTreeCommon/StorageDataManager.cpp @@ -2,9 +2,19 @@ #include +namespace ProfileEvents +{ + extern const Event LoadedServerParts; +} + namespace DB { +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + StorageDataManager::StorageDataManager(const ContextPtr context_,const UUID & uuid_, const WGWorkerInfoPtr & worker_info_ ) : WithContext(context_), storage_uuid(uuid_), @@ -12,20 +22,47 @@ StorageDataManager::StorageDataManager(const ContextPtr context_,const UUID & uu { } -void StorageDataManager::loadDataPartsWithDBM(const MergeTreeMetaBase & storage, const UInt64 & version, ServerDataPartsWithDBM & server_parts) +void StorageDataManager::loadDataPartsWithDBM( + const MergeTreeMetaBase & storage, + const UInt64 & version, + std::unordered_map & res_server_parts, + std::vector> & res_partitions) { auto table_versions_ptr = getRequiredTableVersions(version); - LOG_TRACE(&Poco::Logger::get("StorageDataManager"), "Get required table versions from {} to {}", + LOG_TRACE(log, "Get required table versions from {} to {}", table_versions_ptr.back()->getVersion(), table_versions_ptr.front()->getVersion()); + size_t loaded_parts_count = 0; for (auto it=table_versions_ptr.begin(); itgetAllPartsWithDBM(storage); - server_parts.first.insert(server_parts.first.end(), parts_with_dbm.first.begin(), parts_with_dbm.first.end()); - server_parts.second.insert(server_parts.second.end(), parts_with_dbm.second.begin(), parts_with_dbm.second.end()); + + for (auto & server_part : parts_with_dbm.first) + { + const String & partition_id = server_part->info().partition_id; + auto inner_it = res_server_parts.find(partition_id); + if (inner_it == res_server_parts.end()) + { + // add to result partition list + 
res_partitions.emplace_back(server_part->part_model_wrapper->partition); + } + res_server_parts[partition_id].first.emplace_back(std::move(server_part)); + loaded_parts_count++; + } + + for (auto & delete_bitmap : parts_with_dbm.second) + { + const String & partition_id = delete_bitmap->getModel()->partition_id(); + if (res_server_parts.find(partition_id) == res_server_parts.end()) + throw Exception("Load delete bitmap mismatch server data part. Its a logic error. ", ErrorCodes::LOGICAL_ERROR); + + res_server_parts[partition_id].second.emplace_back(std::move(delete_bitmap)); + } } + + ProfileEvents::increment(ProfileEvents::LoadedServerParts, loaded_parts_count); } UInt64 StorageDataManager::getLatestVersion() @@ -41,7 +78,7 @@ std::vector StorageDataManager::getRequiredTableVersions(const UInt64 latest_version = getLatestVersion(); if (latest_version < required_version) { - LOG_TRACE(&Poco::Logger::get("StorageDataManager"), "Latest version {} less than required version {}. Will reload table versions.", + LOG_TRACE(log, "Latest version {} less than required version {}. 
Will reload table versions.", latest_version, required_version); reloadTableVersions(); } diff --git a/src/MergeTreeCommon/StorageDataManager.h b/src/MergeTreeCommon/StorageDataManager.h index 08cc72e34c4..f9e8bc918bc 100644 --- a/src/MergeTreeCommon/StorageDataManager.h +++ b/src/MergeTreeCommon/StorageDataManager.h @@ -13,7 +13,11 @@ class StorageDataManager : public WithContext public: StorageDataManager(const ContextPtr context, const UUID & uuid_, const WGWorkerInfoPtr & worker_info_); - void loadDataPartsWithDBM(const MergeTreeMetaBase & storage, const UInt64 & version, ServerDataPartsWithDBM & server_parts); + void loadDataPartsWithDBM( + const MergeTreeMetaBase & storage, + const UInt64 & version, + std::unordered_map & server_parts, + std::vector> & partitions); WGWorkerInfoPtr getWorkerInfo() const { return worker_info; } @@ -33,6 +37,8 @@ class StorageDataManager : public WithContext WGWorkerInfoPtr worker_info; std::shared_mutex mutex; std::map versions; + + Poco::Logger * log = &Poco::Logger::get("StorageDataManager"); }; using StorageDataManagerPtr = std::shared_ptr; diff --git a/src/MergeTreeCommon/TableVersion.cpp b/src/MergeTreeCommon/TableVersion.cpp index 218c2ffea3b..66ce1330bf3 100644 --- a/src/MergeTreeCommon/TableVersion.cpp +++ b/src/MergeTreeCommon/TableVersion.cpp @@ -10,6 +10,13 @@ #include #include + +namespace ProfileEvents +{ + extern const Event LoadManifestPartsCacheHits; + extern const Event LoadManifestPartsCacheMisses; +} + namespace DB { @@ -22,7 +29,7 @@ class ManifestDiskCacheSegment : public IDiskCacheSegment { public: explicit ManifestDiskCacheSegment(TableVersionPtr version_) - : IDiskCacheSegment(0, 0), + : IDiskCacheSegment(0, 0, SegmentType::MANIFEST), version_ptr(version_) { } @@ -127,8 +134,8 @@ void TableVersion::fileterDataByWorkerInfo(const MergeTreeMetaBase & storage, st String worker_id_prefix = worker_id.substr(0, worker_id.find_last_of('-') + 1); WorkerGroupHandle mock_wg = 
WorkerGroupHandleImpl::mockWorkerGroupHandle(worker_id_prefix, worker_info->num_workers, getContext()); - // Use the same allocation algorithm as preaload. can work with parts as well as delete bitmap. - auto allocate_res = assignCnchParts(mock_wg, data_vector, getContext()); + // Use consistent hash to make sure the parts with the same basic name are always allocated to the same worker + auto allocate_res = assignCnchParts(mock_wg, data_vector, getContext(), storage.getSettings(), Context::PartAllocator::JUMP_CONSISTENT_HASH); // only get the allocated data which belongs to current worker worker_hold_data = std::move(allocate_res[worker_id]); @@ -179,15 +186,19 @@ void TableVersion::loadManifestData(const MergeTreeMetaBase & storage) { data_parts.swap(loaded_parts); delete_bitmaps.swap(loaded_dbm); - loaded_from_manifest = true; } - - LOG_TRACE(&Poco::Logger::get("TableVersion"), "Loaded {} data parts and {} delete bitmaps from manifest disk cache {}.", - data_parts.size(), - delete_bitmaps.size(), - manifest_seg->getSegmentName()); - return; } + + // Disk may be empty if no server parts assigned to this worker. Then, nothin will be loaded. + LOG_TRACE(log, "Loaded {} data parts and {} delete bitmaps from manifest disk cache {}. 
Path : {}", + data_parts.size(), + delete_bitmaps.size(), + manifest_seg->getSegmentName(), + segment_path); + + loaded_from_manifest = true; + ProfileEvents::increment(ProfileEvents::LoadManifestPartsCacheHits); + return; } } @@ -197,7 +208,7 @@ void TableVersion::loadManifestData(const MergeTreeMetaBase & storage) String checkpoint_file_path = joinPaths({getCheckpointRelativePath(storage), toString(version)}); if (!remote_disk->exists(checkpoint_file_path)) throw Exception("Cannot find checkpoint " + toString(version) + " for table " + storage.getStorageID().getFullTableName(), ErrorCodes::LOGICAL_ERROR); - + auto read_buffer = remote_disk->readFile(checkpoint_file_path); do { @@ -216,6 +227,7 @@ void TableVersion::loadManifestData(const MergeTreeMetaBase & storage) loaded_dbm = catalog->getDeleteBitmapsFromManifest(storage, txn_list); } + ProfileEvents::increment(ProfileEvents::LoadManifestPartsCacheMisses); // filter parts by worker info. if (worker_info) { @@ -231,7 +243,7 @@ void TableVersion::loadManifestData(const MergeTreeMetaBase & storage) loaded_from_manifest = true; } - LOG_TRACE(&Poco::Logger::get("TableVersion"), "Loaded {} parts and {} delete bitmap in table version {} from {}.", + LOG_TRACE(log, "Loaded {} parts and {} delete bitmap in table version {} from {}.", data_parts.size(), delete_bitmaps.size(), version, diff --git a/src/MergeTreeCommon/TableVersion.h b/src/MergeTreeCommon/TableVersion.h index a6b2c95590e..bb0da390876 100644 --- a/src/MergeTreeCommon/TableVersion.h +++ b/src/MergeTreeCommon/TableVersion.h @@ -59,6 +59,8 @@ class TableVersion : public std::enable_shared_from_this, public W std::shared_mutex mutex; DataModelPartWrapperVector data_parts; DeleteBitmapMetaPtrVector delete_bitmaps; + + Poco::Logger * log = &Poco::Logger::get("TableVersion"); }; using TableVersionPtr = std::shared_ptr; diff --git a/src/MergeTreeCommon/assignCnchParts.cpp b/src/MergeTreeCommon/assignCnchParts.cpp index 889dad7f0d9..6bfa728faf8 100644 --- 
a/src/MergeTreeCommon/assignCnchParts.cpp +++ b/src/MergeTreeCommon/assignCnchParts.cpp @@ -67,20 +67,16 @@ inline void reportStats(Poco::Logger * log, const M & map, const String & name, } /// explicit instantiation for server part and cnch data part. -template ServerAssignmentMap assignCnchParts(const WorkerGroupHandle & worker_group, const ServerDataPartsVector & parts, const ContextPtr & query_context); -template AssignmentMap assignCnchParts(const WorkerGroupHandle & worker_group, const MergeTreeDataPartsCNCHVector & parts, const ContextPtr & query_context); -template std::unordered_map assignCnchParts(const WorkerGroupHandle & worker_group, const DataModelPartWrapperVector &, const ContextPtr & query_context); -template std::unordered_map assignCnchParts(const WorkerGroupHandle & worker_group, const DeleteBitmapMetaPtrVector &, const ContextPtr & query_context); +template ServerAssignmentMap assignCnchParts(const WorkerGroupHandle & worker_group, const ServerDataPartsVector & parts, const ContextPtr & query_context, MergeTreeSettingsPtr settings, std::optional allocator = std::nullopt); +template AssignmentMap assignCnchParts(const WorkerGroupHandle & worker_group, const MergeTreeDataPartsCNCHVector & parts, const ContextPtr & query_context, MergeTreeSettingsPtr settings, std::optional allocator = std::nullopt); +template std::unordered_map assignCnchParts(const WorkerGroupHandle & worker_group, const DataModelPartWrapperVector &, const ContextPtr & query_context, MergeTreeSettingsPtr settings, std::optional allocator = std::nullopt); +template std::unordered_map assignCnchParts(const WorkerGroupHandle & worker_group, const DeleteBitmapMetaPtrVector &, const ContextPtr & query_context, MergeTreeSettingsPtr settings, std::optional allocator = std::nullopt); template -std::unordered_map assignCnchParts(const WorkerGroupHandle & worker_group, const DataPartsCnchVector & parts, const ContextPtr & query_context) +std::unordered_map assignCnchParts(const 
WorkerGroupHandle & worker_group, const DataPartsCnchVector & parts, const ContextPtr & query_context, MergeTreeSettingsPtr settings, std::optional allocator) { static auto * log = &Poco::Logger::get("assignCnchParts"); - Context::PartAllocator part_allocation_algorithm; - if (query_context->getSettingsRef().cnch_part_allocation_algorithm.changed) - part_allocation_algorithm = query_context->getPartAllocationAlgo(); - else - part_allocation_algorithm = worker_group->getContext()->getPartAllocationAlgo(); + Context::PartAllocator part_allocation_algorithm = allocator.value_or(query_context->getPartAllocationAlgo(settings)); switch (part_allocation_algorithm) { @@ -112,17 +108,17 @@ std::unordered_map assignCnchParts(const WorkerGrou reportStats(log, ret, "Strict Consistent Hash", worker_group->getRing().size()); return ret; } - case Context::PartAllocator::SIMPLE_HASH: //Note: Now just used for test disk cache stealing so not used for online + case Context::PartAllocator::DISK_CACHE_STEALING_DEBUG: //Note: Now just used for test disk cache stealing so not used for online { - auto ret = assignCnchPartsWithSimpleHash(worker_group->getWorkerIDVec(), worker_group->getIdHostPortsMap(), parts); - reportStats(log, ret, "Simple Hash", worker_group->getWorkerIDVec().size()); + auto ret = assignCnchPartsWithStealingCache(worker_group->getWorkerIDVec(), worker_group->getIdHostPortsMap(), parts); + reportStats(log, ret, "disk cache stealing debug", worker_group->getWorkerIDVec().size()); return ret; } } } template -std::unordered_map assignCnchPartsWithSimpleHash(WorkerList worker_ids, const std::unordered_map & worker_hosts, const DataPartsCnchVector & parts) +std::unordered_map assignCnchPartsWithStealingCache(WorkerList worker_ids, const std::unordered_map & worker_hosts, const DataPartsCnchVector & parts) { std::unordered_map ret; /// we don't know the order of workers returned from consul so sort then explicitly now @@ -148,12 +144,12 @@ std::unordered_map 
assignCnchPartsWithSimpleHash(Wo return ret; } +/// worker_ids should be sorted template -std::unordered_map assignCnchPartsWithJump(WorkerList worker_ids, const std::unordered_map & worker_hosts, const DataPartsCnchVector & parts) +std::unordered_map assignCnchPartsWithJump( + const WorkerList & worker_ids, const std::unordered_map & worker_hosts, const DataPartsCnchVector & parts) { std::unordered_map ret; - /// we don't know the order of workers returned from consul so sort then explicitly now - sort(worker_ids.begin(), worker_ids.end()); auto num_workers = worker_ids.size(); for (const auto & part : parts) @@ -331,6 +327,22 @@ void moveBucketTablePartsToAssignedParts( } } +BucketNumbersAssignmentMap assignBuckets(const std::set & required_bucket_numbers, const WorkerList & workers, bool replicated) +{ + BucketNumbersAssignmentMap assignment; + if (replicated) + { + for (const auto & worker : workers) + assignment[worker] = required_bucket_numbers; + } + else + { + for (auto bucket : required_bucket_numbers) + assignment[workers[bucket % workers.size()]].insert(bucket); + } + return assignment; +} + BucketNumberAndServerPartsAssignment assignCnchPartsForBucketTable( const ServerDataPartsVector & parts, WorkerList workers, std::set required_bucket_numbers, bool replicated) { diff --git a/src/MergeTreeCommon/assignCnchParts.h b/src/MergeTreeCommon/assignCnchParts.h index 3f410284fb7..8b143026a85 100644 --- a/src/MergeTreeCommon/assignCnchParts.h +++ b/src/MergeTreeCommon/assignCnchParts.h @@ -48,7 +48,9 @@ FilePartsAssignMap assignCnchFileParts(const WorkerGroupHandle & worker_group, c HivePartsAssignMap assignCnchHiveParts(const WorkerGroupHandle & worker_group, const HiveFiles & parts); template -std::unordered_map assignCnchParts(const WorkerGroupHandle & worker_group, const DataPartsCnchVector & parts, const ContextPtr & query_context); +std::unordered_map assignCnchParts(const WorkerGroupHandle & worker_group, const DataPartsCnchVector & parts, const 
ContextPtr & context, MergeTreeSettingsPtr settings, std::optional allocator = std::nullopt); + +BucketNumbersAssignmentMap assignBuckets(const std::set & required_bucket_numbers, const WorkerList & workers, bool replicated); /** * splitCnchParts will split server parts into bucketed parts and leftover server parts. diff --git a/src/MergeTreeCommon/tests/gtest_create_query_for_cloud_table.cpp b/src/MergeTreeCommon/tests/gtest_create_query_for_cloud_table.cpp index 0ae38b147f3..581ed395600 100644 --- a/src/MergeTreeCommon/tests/gtest_create_query_for_cloud_table.cpp +++ b/src/MergeTreeCommon/tests/gtest_create_query_for_cloud_table.cpp @@ -30,7 +30,6 @@ PARTITION BY toDate(event_time) PRIMARY KEY s ORDER BY (s, id) UNIQUE KEY id -SETTINGS cnch_temporary_table = 1 )#"; EXPECT_EQ(res, expected); } @@ -55,7 +54,6 @@ PARTITION BY toDate(event_time) PRIMARY KEY s ORDER BY (s, id) UNIQUE KEY id -SETTINGS cnch_temporary_table = 1 )#"; EXPECT_EQ(res, expected); } @@ -76,7 +74,6 @@ TEST(test_create_query_for_cloud_table, version_collapse) ) ENGINE = CloudVersionedCollapsingMergeTree(db1, tb1, Sign, Version) ORDER BY UserID -SETTINGS cnch_temporary_table = 1 )#"; EXPECT_EQ(res, expected); } @@ -94,7 +91,6 @@ TEST(test_create_query_for_cloud_table, s3) `age` String ) ENGINE = CloudS3(db1, tb1, `http://some_link/some_path/some_file.csv`, CSV, none, AKkkkkkkkkk, sKkkkkkkkkkkkkkkkkkkk) -SETTINGS cnch_temporary_table = 1 )#"; EXPECT_EQ(res, expected); } @@ -118,7 +114,6 @@ ENGINE = CloudMergeTree(db1, tb1) PARTITION BY toDate(event_time) PRIMARY KEY s ORDER BY (s, id) -SETTINGS cnch_temporary_table = 1 )#"; EXPECT_EQ(res, expected); } diff --git a/src/MergeTreeCommon/tests/gtest_topology.cpp b/src/MergeTreeCommon/tests/gtest_topology.cpp new file mode 100644 index 00000000000..2d2c4b48601 --- /dev/null +++ b/src/MergeTreeCommon/tests/gtest_topology.cpp @@ -0,0 +1,60 @@ +#include +#include +#include "Protos/DataModelHelpers.h" + + + +using namespace DB; + +namespace GTEST_TOPOLOGY 
+{ +TEST(CnchServerTopology, DiffTopology) +{ + // Create a Empty Topology + auto topo = CnchServerTopology(); + EXPECT_TRUE(topo.isSameTopologyWith(topo)); + + auto topo2 = CnchServerTopology(); + + HostWithPorts host1 = HostWithPorts("host1", 1, 2, 3, 4, 5, "id"); + HostWithPorts host2 = HostWithPorts("host2", 1, 2, 3, 4, 5, "id"); + HostWithPorts host1_v2 = HostWithPorts("host1", 1, 2, 3, 4, 5, "id"); + host1_v2.exchange_status_port = 333; + HostWithPorts host2_v2 = HostWithPorts("host2", 1, 2, 3, 4, 5, "id"); + host2_v2.exchange_port = 6; + + topo2.addServer(host1); + topo2.addServer(host2); + EXPECT_FALSE(topo.isSameTopologyWith(topo2)); + + topo.addServer(host1); + EXPECT_FALSE(topo.isSameTopologyWith(topo2)); + + topo.addServer(host2); + EXPECT_TRUE(topo.isSameTopologyWith(topo2)); + + auto topo3 = CnchServerTopology(); + topo3.addServer(host1_v2); + topo3.addServer(host2_v2); + /// exchange_port or exchange_status_port will be ignored. + EXPECT_TRUE(topo.isSameTopologyWith(topo3)); +} + +TEST(CnchServerTopology, Serialization) +{ + auto topo = CnchServerTopology(); + EXPECT_TRUE(topo.isSameTopologyWith(topo)); + HostWithPorts host1 = HostWithPorts("host1", 1, 2, 3, 4, 5, "id"); + HostWithPorts host2 = HostWithPorts("host2", 1, 2, 3, 4, 5, "id"); + topo.addServer(host1); + topo.addServer(host2); + + pb::RepeatedPtrField topology_versions; + + fillTopologyVersions({topo}, topology_versions); + auto new_topo = createTopologyVersionsFromModel(topology_versions); + + EXPECT_EQ(new_topo.size(), 1); + EXPECT_TRUE(topo.isSameTopologyWith(new_topo.front())); +} +} diff --git a/src/Optimizer/CardinalityEstimate/FilterEstimator.cpp b/src/Optimizer/CardinalityEstimate/FilterEstimator.cpp index 75f482a52e8..ebc4cd95898 100644 --- a/src/Optimizer/CardinalityEstimate/FilterEstimator.cpp +++ b/src/Optimizer/CardinalityEstimate/FilterEstimator.cpp @@ -53,14 +53,14 @@ PlanNodeStatisticsPtr FilterEstimator::estimate( if (!is_on_base_table) { // Prefer default selectivity when 
is_on_base_table flag is false. - UInt64 row_count = filter_stats->getRowCount() * default_selectivity; + UInt64 row_count = std::round(filter_stats->getRowCount() * default_selectivity); // make row count at least 1. row_count = row_count > 1 ? row_count : 1; filter_stats->updateRowCount(row_count); for (auto & symbol_stats : filter_stats->getSymbolStatistics()) { - symbol_stats.second = symbol_stats.second->applySelectivity(default_selectivity); + symbol_stats.second = symbol_stats.second->applySelectivity(default_selectivity, symbol_stats.second->getNdv() > opt_child_stats->getRowCount() * 0.8 ? default_selectivity : 1); // NDV must less or equals to row count symbol_stats.second->setNdv(std::min(filter_stats->getRowCount(), symbol_stats.second->getNdv())); } @@ -87,7 +87,7 @@ PlanNodeStatisticsPtr FilterEstimator::estimate( selectivity = 0; } - UInt64 filtered_row_count = filter_stats->getRowCount() * selectivity; + UInt64 filtered_row_count = std::round(filter_stats->getRowCount() * selectivity); // make row count at least 1. filter_stats->updateRowCount(filtered_row_count > 0 ? filtered_row_count : std::min(UInt64(1), opt_child_stats->getRowCount())); std::unordered_map & symbol_statistics_in_filter = result.second; @@ -100,7 +100,7 @@ PlanNodeStatisticsPtr FilterEstimator::estimate( } else { - symbol_statistics.second = symbol_statistics.second->applySelectivity(selectivity); + symbol_statistics.second = symbol_statistics.second->applySelectivity(selectivity, symbol_statistics.second->getNdv() > opt_child_stats->getRowCount() * 0.8 ? 
selectivity : 1); // NDV must less or equals to row count symbol_statistics.second->setNdv(std::min(filter_stats->getRowCount(), symbol_statistics.second->getNdv())); symbol_statistics.second->getHistogram().clear(); diff --git a/src/Optimizer/CardinalityEstimate/JoinEstimator.cpp b/src/Optimizer/CardinalityEstimate/JoinEstimator.cpp index 384a59e3d93..56977025932 100644 --- a/src/Optimizer/CardinalityEstimate/JoinEstimator.cpp +++ b/src/Optimizer/CardinalityEstimate/JoinEstimator.cpp @@ -95,7 +95,8 @@ PlanNodeStatisticsPtr JoinEstimator::estimate( } } - UInt64 filtered_row_count = res->getRowCount() * selectivity; + auto before_filter_row_count = res->getRowCount(); + UInt64 filtered_row_count = std::round(res->getRowCount() * selectivity); // make row count at least 1. res->updateRowCount(filtered_row_count > 1 ? filtered_row_count : 1); for (auto & symbol_statistics : res->getSymbolStatistics()) @@ -108,7 +109,8 @@ PlanNodeStatisticsPtr JoinEstimator::estimate( } else { - symbol_statistics.second = symbol_statistics.second->applySelectivity(selectivity); + symbol_statistics.second = symbol_statistics.second->applySelectivity( + selectivity, symbol_statistics.second->getNdv() > before_filter_row_count * 0.8 ? selectivity : 1); // NDV must less or equals to row count symbol_statistics.second->setNdv(std::min(res->getRowCount(), symbol_statistics.second->getNdv())); symbol_statistics.second->getHistogram().clear(); diff --git a/src/Optimizer/Cascades/Task.cpp b/src/Optimizer/Cascades/Task.cpp index e668f38ef98..b75c5c31a75 100644 --- a/src/Optimizer/Cascades/Task.cpp +++ b/src/Optimizer/Cascades/Task.cpp @@ -29,6 +29,7 @@ #include #include #include +#include "Interpreters/Context_fwd.h" #include "QueryPlan/IQueryPlanStep.h" #include @@ -282,7 +283,7 @@ void OptimizeInput::execute() // 1. We can init input cost using non-zero value for pruning // 2. We can calculate the current operator cost if we have maintain // logical properties in group (e.g. 
stats, schema, cardinality) - + // Compute the cost of the root operator // 1. Collect stats needed and cache them in the group // 2. Calculate cost based on children's stats and cache it in the group expression @@ -351,7 +352,7 @@ void OptimizeInput::execute() break; } auto & input_props = input_properties[cur_prop_pair_idx]; - + // initial total cost if (cur_child_idx == 0) { @@ -420,7 +421,8 @@ void OptimizeInput::execute() single_count++; } - if (group_expr->getStep()->getType() == IQueryPlanStep::Type::Union && single_count > 0 && single_count < group_expr->getChildrenGroups().size()) + if (group_expr->getStep()->getType() == IQueryPlanStep::Type::Union && single_count > 0 + && single_count < group_expr->getChildrenGroups().size()) { auto new_child_requires = input_props; for (auto & new_child : new_child_requires) @@ -576,6 +578,44 @@ void OptimizeInput::addInputPropertiesForCTE(CTEId cte_id, CTEDescription cte_de input_properties.insert(input_properties.end(), new_properties.begin(), new_properties.end()); } +static PropertySets makeHandleSame(const PropertySet & input_props, const PropertySet & actual_props, const ContextPtr & context) +{ + PropertySets result; + auto new_child_requires = input_props; + for (auto & new_child : new_child_requires) + { + new_child.getNodePartitioningRef().setRequireHandle(true); + } + result.emplace_back(new_child_requires); + + if (actual_props[0].getNodePartitioning().isExchangeSchema(context->getSettingsRef().enable_bucket_shuffle) + && actual_props[0].getNodePartitioning().getHandle() == Partitioning::Handle::BUCKET_TABLE) + { + auto other_new_child_requires = new_child_requires; + for (auto & new_child : other_new_child_requires) + { + new_child.getNodePartitioningRef().setHandle(Partitioning::Handle::BUCKET_TABLE); + new_child.getNodePartitioningRef().setBuckets(actual_props[0].getNodePartitioning().getBuckets()); + new_child.getNodePartitioningRef().setBucketExpr(actual_props[0].getNodePartitioning().getBucketExpr()); 
+ } + result.emplace_back(other_new_child_requires); + } + + if (actual_props[1].getNodePartitioning().isExchangeSchema(context->getSettingsRef().enable_bucket_shuffle) + && actual_props[1].getNodePartitioning().getHandle() == Partitioning::Handle::BUCKET_TABLE) + { + auto other_new_child_requires = new_child_requires; + for (auto & new_child : other_new_child_requires) + { + new_child.getNodePartitioningRef().setHandle(Partitioning::Handle::BUCKET_TABLE); + new_child.getNodePartitioningRef().setBuckets(actual_props[1].getNodePartitioning().getBuckets()); + new_child.getNodePartitioningRef().setBucketExpr(actual_props[1].getNodePartitioning().getBucketExpr()); + } + result.emplace_back(other_new_child_requires); + } + return result; +} + bool OptimizeInput::checkJoinInputProperties(const PropertySet & requried_input_props, const PropertySet & actual_input_props) { bool all_fix_hash = std::all_of(requried_input_props.begin(), requried_input_props.end(), [](const auto & i_prop) { @@ -618,6 +658,7 @@ bool OptimizeInput::checkJoinInputProperties(const PropertySet & requried_input_ auto first_handle = first_props.getNodePartitioning().getHandle(); auto first_bucket_count = first_props.getNodePartitioning().getBuckets(); + auto first_sharding_expr = first_props.getNodePartitioning().getBucketExpr(); auto first_partition_column = first_props.getNodePartitioning().normalize(*left_equivalences).getColumns(); for (size_t actual_prop_index = 1; actual_prop_index < actual_input_props.size(); ++actual_prop_index) @@ -625,7 +666,8 @@ bool OptimizeInput::checkJoinInputProperties(const PropertySet & requried_input_ auto before_transformed_partition_cols = actual_input_props[actual_prop_index].getNodePartitioning().getColumns(); auto translated_prop = actual_input_props[actual_prop_index].normalize(*right_equivalences); if (translated_prop.getNodePartitioning().getHandle() != first_handle - || translated_prop.getNodePartitioning().getBuckets() != first_bucket_count) + || 
translated_prop.getNodePartitioning().getBuckets() != first_bucket_count + || !ASTEquality::compareTree(translated_prop.getNodePartitioning().getBucketExpr(), first_sharding_expr)) { match = false; break; @@ -649,14 +691,13 @@ bool OptimizeInput::checkJoinInputProperties(const PropertySet & requried_input_ if (!match) { - auto new_child_requires = requried_input_props; - for (auto & new_child : new_child_requires) + for (auto & new_child_requires : makeHandleSame(requried_input_props, actual_input_props, context->getOptimizerContext().getContext())) { - new_child.getNodePartitioningRef().setRequireHandle(true); + input_properties.emplace_back(new_child_requires); } - input_properties.emplace_back(new_child_requires); } + return match; } @@ -738,7 +779,8 @@ void OptimizeInput::enforcePropertyAndUpdateWinner( // increase cost if the cte exists both join side. disable q11 & q74 cte for tpcds. if (!it.second && group_expr->getStep()->getType() == IQueryPlanStep::Type::Join) { - auto coefficient = opt_context->getOptimizerContext().getContext()->getSettingsRef().cost_calculator_cte_weight_for_join_build_side; + auto coefficient + = opt_context->getOptimizerContext().getContext()->getSettingsRef().cost_calculator_cte_weight_for_join_build_side; it.first->second.second = std::max(it.first->second.second, cte_prop.second.second) * coefficient; } } @@ -848,7 +890,7 @@ void OptimizeCTE::execute() if (context->getOptimizerContext().isEnableTrace()) context->getOptimizerContext().trace("OptimizeCTE", group_expr->getGroupId(), group_expr->getProduceRule(), elapsed_ns); } - + OptimizerTask::OptimizerTask(OptContextPtr context_) : context(std::move(context_)), log(context->getOptimizerContext().getLog()) { } diff --git a/src/Optimizer/CostModel/ExchangeCost.cpp b/src/Optimizer/CostModel/ExchangeCost.cpp index 813ae024e9f..e927a1ca7e2 100644 --- a/src/Optimizer/CostModel/ExchangeCost.cpp +++ b/src/Optimizer/CostModel/ExchangeCost.cpp @@ -30,7 +30,7 @@ PlanNodeCost 
ExchangeCost::calculate(const ExchangeStep & step, CostContext & co if (!step.getSchema().getColumns().empty() && (step.getSchema().getHandle() == Partitioning::Handle::FIXED_HASH || step.getSchema().getHandle() == Partitioning::Handle::BUCKET_TABLE)) - base_cost += 1.0 / step.getSchema().getColumns().size(); + base_cost += 1.0 / (step.getSchema().getColumns().size() + 1); if (step.getSchema().getHandle() == Partitioning::Handle::BUCKET_TABLE) base_cost *= 1.1; diff --git a/src/Optimizer/Dump/DumpUtils.h b/src/Optimizer/Dump/DumpUtils.h index 229247846b5..0c827988fec 100644 --- a/src/Optimizer/Dump/DumpUtils.h +++ b/src/Optimizer/Dump/DumpUtils.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -46,6 +47,7 @@ struct DumpSettings {"compress_directory", compress_directory}, {"version", version}, }; + std::unordered_map> uint_settings = {}; }; } diff --git a/src/Optimizer/ExpressionInterpreter.cpp b/src/Optimizer/ExpressionInterpreter.cpp index 47c6f2b3d5c..1db1a1929ec 100644 --- a/src/Optimizer/ExpressionInterpreter.cpp +++ b/src/Optimizer/ExpressionInterpreter.cpp @@ -402,10 +402,10 @@ std::pair ExpressionInterpreter::optimizeExpression(const C ASTPtr ExpressionInterpreter::optimizePredicate(const ConstASTPtr & expression) const { auto result = evaluate(expression); - Utils::checkState(isBoolCompatibleType(result.type)); + // other rules(e.g. CommonPredicateRewriteRule) may generate predicate with invalid types, cast them to UInt8 if (result.isAST()) - return result.ast; + return isBoolCompatibleType(result.type) ? 
result.ast : makeASTFunction("toBool", result.ast); const auto & field = result.value; UInt8 x = !field.isNull() && applyVisitor(FieldVisitorConvertToNumber(), field); @@ -710,6 +710,16 @@ InterpretIMResult ExpressionInterpreter::visitInFunction(const ASTFunction & fun if (left_arg_result.isAST() && !setting.enable_function_simplify) return {getType(rewritten_in_func), rewritten_in_func}; + if (const auto * ast_prepared_param = right_arg->as()) + { + auto riget_arg_result = visitASTPreparedParameter(*ast_prepared_param, right_arg); + ColumnsWithTypeAndName columns_with_types; + columns_with_types.emplace_back(left_arg_result.value, left_arg_result.type, ""); + columns_with_types.emplace_back(riget_arg_result.value, riget_arg_result.type, ""); + auto overload_resolver = FunctionFactory::instance().tryGet(function.name, context); + return {overload_resolver->getReturnType(columns_with_types), rewritten_in_func}; + } + // build set for IN statement(see also ActionsVisitor) SetPtr set; { diff --git a/src/Optimizer/ExpressionRewriter.cpp b/src/Optimizer/ExpressionRewriter.cpp index 57f481c4df0..efa8c8b4c06 100644 --- a/src/Optimizer/ExpressionRewriter.cpp +++ b/src/Optimizer/ExpressionRewriter.cpp @@ -44,7 +44,8 @@ bool FunctionIsInjective::isInjective(const ConstASTPtr & expr, ContextMutablePt } Scope scope(Scope::ScopeType::RELATION, nullptr, true, fields); ExprAnalyzerOptions options; - ExprAnalyzer::analyze(std::const_pointer_cast(expr), &scope, context, analysis, options); + ASTPtr tmp_expr = std::const_pointer_cast(expr); + ExprAnalyzer::analyze(tmp_expr, &scope, context, analysis, options); FunctionIsInjectiveVisitor visitor{context, analysis.getExpressionColumnWithTypes()}; NameSet remind_partition_cols = partition_cols; return ASTVisitorUtil::accept(expr, visitor, remind_partition_cols) && remind_partition_cols.empty(); @@ -78,7 +79,7 @@ bool FunctionIsInjectiveVisitor::visitASTFunction(const ConstASTPtr & node, Name 
processed_arguments.emplace_back(col_type.column, col_type.type, arg->getColumnName()); } auto function_base = function_builder->build(processed_arguments); - bool is_injective = function_base->isInjective(processed_arguments); + bool is_injective = function_base->isInjective(processed_arguments); if (is_injective) { visitNode(node, c); diff --git a/src/Optimizer/IntermediateResult/CacheParamBuilder.cpp b/src/Optimizer/IntermediateResult/CacheParamBuilder.cpp index 35d9e804947..ee4ec52e901 100644 --- a/src/Optimizer/IntermediateResult/CacheParamBuilder.cpp +++ b/src/Optimizer/IntermediateResult/CacheParamBuilder.cpp @@ -81,6 +81,7 @@ size_t CacheParamBuilder::computeJoinHash(std::shared_ptr join_step) join_step->getKeepLeftReadInOrder(), join_step->getLeftKeys(), join_step->getRightKeys(), + join_step->getKeyIdsNullSafe(), join_step->getFilter(), join_step->isHasUsing(), join_step->getRequireRightKeys(), diff --git a/src/Optimizer/PlanCheck.cpp b/src/Optimizer/PlanCheck.cpp index cb13674bf6e..a6939fcced7 100644 --- a/src/Optimizer/PlanCheck.cpp +++ b/src/Optimizer/PlanCheck.cpp @@ -28,6 +28,7 @@ void PlanCheck::checkInitPlan(QueryPlan & plan, ContextMutablePtr context) void PlanCheck::checkFinalPlan(QueryPlan & plan, ContextMutablePtr context) { SymbolChecker::check(plan, context, true); + TableScanChecker::check(plan, context); } void ReadNothingChecker::check(PlanNodePtr plan) @@ -103,4 +104,34 @@ Void SymbolChecker::visitFilterNode(FilterNode & node, ContextMutablePtr & conte return {}; } +void TableScanChecker::check(QueryPlan & plan, ContextMutablePtr & context) +{ + TableScanChecker tablescan_check; + VisitorUtil::accept(plan.getPlanNode(), tablescan_check, context); +} + +Void TableScanChecker::visitPlanNode(PlanNodeBase & node, ContextMutablePtr & context) +{ + for (const auto & item : node.getChildren()) + { + VisitorUtil::accept(*item, *this, context); + } + return {}; +} + +Void TableScanChecker::visitTableScanNode(TableScanNode & node, ContextMutablePtr 
& context) +{ + auto & step = node.getStep(); + auto storage = step->getStorage(); + if (!context->getSettingsRef().allow_map_access_without_key && storage && storage->supportsMapImplicitColumn()) + { + Block header = storage->getStorageSnapshot(storage->getInMemoryMetadataPtr(), context)->getSampleBlockForColumns(step->getRequiredColumns()); + for (auto & col : header) + { + if (col.type->isByteMap()) + throw Exception("Map column access without key is not allowed for ByteMap", ErrorCodes::NOT_IMPLEMENTED); + } + } + return {}; +} } diff --git a/src/Optimizer/PlanCheck.h b/src/Optimizer/PlanCheck.h index a5b0ae6d125..52576251951 100644 --- a/src/Optimizer/PlanCheck.h +++ b/src/Optimizer/PlanCheck.h @@ -51,4 +51,12 @@ class SymbolChecker : public PlanNodeVisitor bool check_filter; }; +class TableScanChecker : public PlanNodeVisitor +{ +public: + static void check(QueryPlan & plan, ContextMutablePtr & context); + + Void visitPlanNode(PlanNodeBase &, ContextMutablePtr &) override; + Void visitTableScanNode(TableScanNode &, ContextMutablePtr &) override; +}; } diff --git a/src/Optimizer/PlanOptimizer.cpp b/src/Optimizer/PlanOptimizer.cpp index d655bdc7c11..71d9c38aae6 100644 --- a/src/Optimizer/PlanOptimizer.cpp +++ b/src/Optimizer/PlanOptimizer.cpp @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -121,6 +122,7 @@ const Rewriters & PlanOptimizer::getSimpleRewriters() // add exchange std::make_shared(false), + std::make_shared(), std::make_shared(Rules::pushPartialStepRules(), "PushPartialStep"), std::make_shared(Rules::optimizeAggregateRules(), "OptimizeAggregate"), std::make_shared(), @@ -138,7 +140,6 @@ const Rewriters & PlanOptimizer::getSimpleRewriters() std::make_shared(), std::make_shared(), /* some rules generates incorrect column ptr for DataStream, e.g. 
use a non-nullable column ptr for a nullable column */ - std::make_shared(), std::make_shared(Rules::pushTableScanEmbeddedStepRules(), "PushTableScanEmbeddedStepRules"), std::make_shared(), @@ -155,6 +156,7 @@ const Rewriters & PlanOptimizer::getFullRewriters() std::make_shared(), std::make_shared(), + std::make_shared(Rules::joinUsingToJoinOn(), "JoinUsingToJoinOn"), std::make_shared(Rules::sumIfToCountIf(), "SumIfToCountIf"), // remove subquery rely on specific pattern @@ -280,6 +282,10 @@ const Rewriters & PlanOptimizer::getFullRewriters() // Cost-based optimizer std::make_shared(), + // remove not inlined CTEs + std::make_shared(), + std::make_shared(), + // add runtime filters std::make_shared(), @@ -313,7 +319,6 @@ const Rewriters & PlanOptimizer::getFullRewriters() std::make_shared(), std::make_shared(), /* some rules generates incorrect column ptr for DataStream, e.g. use a non-nullable column ptr for a nullable column */ - std::make_shared(), std::make_shared(Rules::pushTableScanEmbeddedStepRules(), "PushTableScanEmbeddedStepRules"), std::make_shared(), std::make_shared(), @@ -324,6 +329,18 @@ const Rewriters & PlanOptimizer::getFullRewriters() return full_rewrites; } +const Rewriters & PlanOptimizer::getShortCircuitRewriters() +{ + static Rewriters short_circuit_rewriters = { + std::make_shared(), + std::make_shared(Rules::pushDownLimitRules(), "PushDownLimit"), + std::make_shared(Rules::removeRedundantRules(), "RemoveRedundant"), + std::make_shared(Rules::pushIntoTableScanRules(), "PushIntoTableScan"), + std::make_shared(Rules::explainAnalyzeRules(), "ExplainAnalyze"), + }; + return short_circuit_rewriters; +} + void PlanOptimizer::optimize(QueryPlan & plan, ContextMutablePtr context) { int i = GraphvizPrinter::PRINT_PLAN_OPTIMIZE_INDEX; @@ -335,7 +352,13 @@ void PlanOptimizer::optimize(QueryPlan & plan, ContextMutablePtr context) Stopwatch rule_watch, total_watch; total_watch.start(); - if (PlanPattern::isSimpleQuery(plan)) + if 
(ShortCircuitPlanner::isShortCircuitPlan(plan, context)) + { + plan.setShortCircuit(true); + optimize(plan, context, getShortCircuitRewriters()); + ShortCircuitPlanner::addExchangeIfNeeded(plan, context); + } + else if (PlanPattern::isSimpleQuery(plan)) { optimize(plan, context, getSimpleRewriters()); } diff --git a/src/Optimizer/PlanOptimizer.h b/src/Optimizer/PlanOptimizer.h index 3ea2401877a..e7b55fc485a 100644 --- a/src/Optimizer/PlanOptimizer.h +++ b/src/Optimizer/PlanOptimizer.h @@ -28,6 +28,7 @@ class PlanOptimizer static void optimize(QueryPlan & plan, ContextMutablePtr context, const Rewriters & rewriters); static const Rewriters & getSimpleRewriters(); static const Rewriters & getFullRewriters(); + static const Rewriters & getShortCircuitRewriters(); }; } diff --git a/src/Optimizer/Property/Constants.cpp b/src/Optimizer/Property/Constants.cpp index 13f6f8f4fb0..1ab851453af 100644 --- a/src/Optimizer/Property/Constants.cpp +++ b/src/Optimizer/Property/Constants.cpp @@ -27,4 +27,14 @@ Constants Constants::normalize(const SymbolEquivalences & symbol_equivalences) c return translate(mapping); } +String Constants::toString() const +{ + std::stringstream output; + output << "{"; + for (const auto & item : values) + output << " " << item.first << "=" << item.second.value.toString(); + output << "}"; + return output.str(); +} + } diff --git a/src/Optimizer/Property/Constants.h b/src/Optimizer/Property/Constants.h index cd588263ec1..4328a657ccb 100644 --- a/src/Optimizer/Property/Constants.h +++ b/src/Optimizer/Property/Constants.h @@ -26,6 +26,7 @@ class Constants Constants translate(const std::unordered_map & identities) const; Constants normalize(const SymbolEquivalences & symbol_equivalences) const; + String toString() const; private: std::map values{}; diff --git a/src/Optimizer/Property/Equivalences.h b/src/Optimizer/Property/Equivalences.h index b676f3c4f26..f681bd5d1b9 100644 --- a/src/Optimizer/Property/Equivalences.h +++ 
b/src/Optimizer/Property/Equivalences.h @@ -106,60 +106,6 @@ class Equivalences bool isEqual(T first, T second) const { return union_find.isConnected(first, second); } - Ptr translate(std::unordered_map & identities) const - { - auto result = std::make_shared(); - TMap> str_to_set; - for (auto & item : union_find.parent) - { - if (identities.contains(item.first)) - { - str_to_set[item.second].insert(identities[item.first]); - } - } - - for (auto & item : str_to_set) - { - auto & set = item.second; - if (set.size() > 1) - { - auto first = *set.begin(); - for (auto iter = set.begin()++; iter != set.end(); iter++) - { - result->add(first, *iter); - } - } - } - return result; - } - - Ptr translate(std::unordered_set & identities) const - { - auto result = std::make_shared(); - std::unordered_map> str_to_set; - for (auto & item : union_find.parent) - { - if (identities.contains(item.first)) - { - str_to_set[item.second].insert(item.first); - } - } - - for (auto & item : str_to_set) - { - auto & set = item.second; - if (set.size() > 1) - { - auto first = *set.begin(); - for (auto iter = set.begin()++; iter != set.end(); iter++) - { - result->add(first, *iter); - } - } - } - return result; - } - Map representMap() const { if (map) diff --git a/src/Optimizer/Property/Property.cpp b/src/Optimizer/Property/Property.cpp index a27dae81f64..92586a37d73 100644 --- a/src/Optimizer/Property/Property.cpp +++ b/src/Optimizer/Property/Property.cpp @@ -18,14 +18,17 @@ #include #include #include -#include -#include -#include #include #include +#include +#include #include +#include #include #include +#include +#include +#include "Core/Field.h" namespace DB { @@ -45,7 +48,7 @@ bool Partitioning::satisfy(const Partitioning & requirement, const Constants & c { if (requirement.require_handle) return getHandle() == requirement.getHandle() && getBuckets() == requirement.getBuckets() - && getColumns() == requirement.getColumns(); + && getColumns() == requirement.getColumns() && 
ASTEquality::compareTree(bucket_expr, requirement.bucket_expr); switch (requirement.component) { @@ -74,7 +77,7 @@ bool Partitioning::satisfy(const Partitioning & requirement, const Constants & c || (!requirement.isExactlyMatch() && this->isPartitionOn(requirement, constants)); default: return getHandle() == requirement.getHandle() && getBuckets() == requirement.getBuckets() - && getColumns() == requirement.getColumns(); + && getColumns() == requirement.getColumns() && ASTEquality::compareTree(bucket_expr, requirement.bucket_expr); } } @@ -106,6 +109,103 @@ bool Partitioning::isPartitionOn(const Partitioning & requirement, const Constan return true; } +bool Partitioning::isExchangeSchema(bool support_bucket_shuffle) const +{ + if (handle == Handle::BUCKET_TABLE) + { + if (support_bucket_shuffle && bucket_expr) + { + if (auto * cluster_by_ast_element = bucket_expr->as()) + { + if (cluster_by_ast_element->is_user_defined_expression) + { + if (!cluster_by_ast_element->getColumns()->as()) + return false; + } + + auto expression = extractKeyExpressionList(cluster_by_ast_element->getColumns()); + + if (auto * expr_list = expression->as()) + { + if (expr_list->children.size() != columns.size()) + return false; + for (const auto & col : expr_list->children) + { + if (auto * id = col->as()) + { + if (!id->name().starts_with("$")) + return false; + } + else + { + return false; + } + } + } + else + { + return false; + } + } + } + else + { + return false; + } + } + + return true; +} + +String Partitioning::getHashFunc(String default_func) const +{ + if (handle == Handle::BUCKET_TABLE) + { + if (bucket_expr) + { + if (auto * cluster_by_ast_element = bucket_expr->as()) + { + if (cluster_by_ast_element->is_user_defined_expression) + return "toUInt64"; + return "bucket"; + } + } + } + + return default_func; +} + + +// bucket(function_name,bucket_num,with_range,split_number)(bucket_column) +Array Partitioning::getParams() const +{ + Array result; + if (handle == 
Handle::BUCKET_TABLE) + { + if (bucket_expr) + { + if (auto * cluster_by_ast_element = bucket_expr->as()) + { + if (cluster_by_ast_element->is_user_defined_expression) + return result; + if (cluster_by_ast_element->split_number > 0 && columns.size() == 1) + { + result.emplace_back(Field("dtspartition")); + } + else + { + result.emplace_back(Field("sipHashBuitin")); + } + result.emplace_back(buckets); + result.emplace_back(Field(cluster_by_ast_element->is_with_range)); + result.emplace_back(Field(static_cast(cluster_by_ast_element->split_number))); + } + } + } + + return result; +} + Partitioning Partitioning::normalize(const SymbolEquivalences & symbol_equivalences) const { auto mapping = symbol_equivalences.representMap(); @@ -134,8 +234,8 @@ Partitioning Partitioning::translate(const std::unordered_map & else // note: don't discard column translate_columns.emplace_back(column); } - auto result - = Partitioning{handle, translate_columns, require_handle, buckets, enforce_round_robin, component, exactly_match, satisfy_worker}; + auto result = Partitioning{ + handle, translate_columns, require_handle, buckets, bucket_expr, enforce_round_robin, component, exactly_match, satisfy_worker}; result.setPreferred(preferred); return result; } @@ -151,6 +251,7 @@ void Partitioning::toProto(Protos::Partitioning & proto) const proto.set_enforce_round_robin(enforce_round_robin); proto.set_component(Partitioning::ComponentConverter::toProto(component)); proto.set_exactly_match(exactly_match); + serializeASTToProto(bucket_expr, *proto.mutable_bucket_expr()); } Partitioning Partitioning::fromProto(const Protos::Partitioning & proto) @@ -164,7 +265,10 @@ Partitioning Partitioning::fromProto(const Protos::Partitioning & proto) auto enforce_round_robin = proto.enforce_round_robin(); auto component = Partitioning::ComponentConverter::fromProto(proto.component()); auto exactly_match = proto.exactly_match(); - return Partitioning(handle, columns, require_handle, buckets, 
enforce_round_robin, component, exactly_match); + ASTPtr bucket_expr = nullptr; + if (proto.has_bucket_expr()) + bucket_expr = deserializeASTFromProto(proto.bucket_expr()); + return Partitioning(handle, columns, require_handle, buckets, bucket_expr, enforce_round_robin, component, exactly_match); } String Partitioning::toString() const @@ -213,7 +317,7 @@ String Partitioning::toString() const columns[0], [](String a, const String & b) { return std::move(a) + ", " + b; }) + "]"; - result += " BUCKETS " + std::to_string(getBuckets()); + result += " " + queryToString(bucket_expr); if (require_handle) result += " H"; if (preferred) @@ -229,6 +333,30 @@ String Partitioning::toString() const } } +SortOrder SortColumn::toReverseOrder(SortOrder sort_order) +{ + switch (sort_order) + { + case SortOrder::ASC_NULLS_FIRST: + return SortOrder::DESC_NULLS_LAST; + case SortOrder::ASC_NULLS_LAST: + return SortOrder::DESC_NULLS_FIRST; + case SortOrder::ASC_ANY: + return SortOrder::DESC_ANY; + case SortOrder::DESC_NULLS_FIRST: + return SortOrder::ASC_NULLS_LAST; + case SortOrder::DESC_NULLS_LAST: + return SortOrder::ASC_NULLS_FIRST; + case SortOrder::DESC_ANY: + return SortOrder::ASC_ANY; + case SortOrder::ANY: + return SortOrder::ANY; + case SortOrder::UNKNOWN: + return SortOrder::UNKNOWN; + } + throw Exception(ErrorCodes::LOGICAL_ERROR, "unknown sort order"); +} + size_t SortColumn::hash() const { size_t hash = MurmurHash3Impl64::apply(name.c_str(), name.size()); @@ -260,6 +388,15 @@ String SortColumn::toString() const return "unknown"; } +Sorting Sorting::toReverseOrder() const +{ + Sorting ret; + ret.reserve(size()); + for (const SortColumn & sort_column : *this) + ret.emplace_back(sort_column.toReverseOrder()); + return ret; +} + size_t Sorting::hash() const { size_t hash = IntHash64Impl::apply(this->size()); @@ -299,8 +436,9 @@ Sorting Sorting::normalize(const SymbolEquivalences & symbol_equivalences) const String Sorting::toString() const { - return empty() ? 
"" : std::accumulate( - std::next(begin()), end(), front().toString(), [](std::string a, const auto & b) { return std::move(a) + '-' + b.toString(); }); + return empty() ? "" : std::accumulate(std::next(begin()), end(), front().toString(), [](std::string a, const auto & b) { + return std::move(a) + '-' + b.toString(); + }); } size_t CTEDescriptions::hash() const @@ -309,7 +447,7 @@ size_t CTEDescriptions::hash() const for (const auto & item : *this) { hash = MurmurHash3Impl64::combineHashes(hash, IntHash64Impl::apply(item.first)); - hash = MurmurHash3Impl64::combineHashes(hash, item.second.hash()); + hash = MurmurHash3Impl64::combineHashes(hash, item.second.hash()); } return hash; } @@ -377,7 +515,7 @@ Property Property::normalize(const SymbolEquivalences & symbol_equivalences) con node_partitioning.normalize(symbol_equivalences), stream_partitioning.normalize(symbol_equivalences), sorting.normalize(symbol_equivalences)}; - result.setCTEDescriptions(cte_descriptions); + result.setCTEDescriptions(cte_descriptions); return result; } @@ -428,7 +566,7 @@ CTEDescription CTEDescription::inlined() } CTEDescription CTEDescription::from(const Property & property) - { +{ return CTEDescription(false, property.getNodePartitioning()); } diff --git a/src/Optimizer/Property/Property.h b/src/Optimizer/Property/Property.h index ec9b77aca0e..f28233899c2 100644 --- a/src/Optimizer/Property/Property.h +++ b/src/Optimizer/Property/Property.h @@ -84,6 +84,7 @@ class Partitioning Names columns_ = {}, bool require_handle_ = false, UInt64 buckets_ = 0, + ASTPtr bucket_expr_ = nullptr, bool enforce_round_robin_ = true, Component component_ = Component::ANY, bool exactly_match_ = false, @@ -92,6 +93,7 @@ class Partitioning , columns(std::move(columns_)) , require_handle(require_handle_) , buckets(buckets_) + , bucket_expr(bucket_expr_) , enforce_round_robin(enforce_round_robin_) , component(component_) , exactly_match(exactly_match_) @@ -115,6 +117,25 @@ class Partitioning void 
setComponent(Component component_) { component = component_; } bool isExactlyMatch() const { return exactly_match; } + bool isPartitionHandle() const { return handle == Handle::BUCKET_TABLE || handle == Handle::FIXED_HASH; } + + bool isExchangeSchema(bool support_bucket_shuffle) const; + + String getHashFunc(String default_func) const; + Array getParams() const; + + void resetIfPartitionHandle() + { + if (!isPartitionHandle()) + { + return; + } + this->columns = {}; + this->handle = Handle::UNKNOWN; + this->bucket_expr = nullptr; + this->buckets = 0; + } + bool isSatisfyWorker() const { return satisfy_worker; @@ -137,8 +158,12 @@ class Partitioning bool operator==(const Partitioning & other) const { return preferred == other.preferred && handle == other.handle && columns == other.columns && require_handle == other.require_handle && buckets == other.buckets - && enforce_round_robin == other.enforce_round_robin; + && enforce_round_robin == other.enforce_round_robin && ASTEquality::compareTree(bucket_expr, other.bucket_expr); } + + ASTPtr getBucketExpr() const { return bucket_expr; } + void setBucketExpr(const ASTPtr & bucket_expr_) { bucket_expr = bucket_expr_; } + String toString() const; void toProto(Protos::Partitioning & proto) const; @@ -149,6 +174,7 @@ class Partitioning Names columns; bool require_handle; UInt64 buckets; + ASTPtr bucket_expr; bool enforce_round_robin; Component component; bool exactly_match; @@ -179,7 +205,8 @@ class SortColumn return SortOrder::ASC_NULLS_LAST; else if (nulls_direction == -1) return SortOrder::ASC_NULLS_FIRST; - // else if (nulls_direction == 0) // no need, this case should return ASC_NULLS_LAST. 
+ else if (nulls_direction == 0) + return SortOrder::ASC_ANY; } else if (direction == -1) { @@ -187,9 +214,14 @@ class SortColumn return SortOrder::DESC_NULLS_LAST; else if (nulls_direction == -1) return SortOrder::DESC_NULLS_FIRST; + else if (nulls_direction == 0) + return SortOrder::DESC_ANY; } + else if (direction == 0 && nulls_direction == 0) + return SortOrder::ANY; return SortOrder::UNKNOWN; } + static SortOrder toReverseOrder(SortOrder sort_order); SortColumn(String name_, SortOrder order_) : name(std::move(name_)), order(order_) { } explicit SortColumn(const SortColumnDescription & sort_column_description) : name(sort_column_description.column_name) @@ -199,6 +231,7 @@ class SortColumn const String & getName() const { return name; } SortOrder getOrder() const { return order; } + SortColumn toReverseOrder() const { return SortColumn{name, toReverseOrder(order)}; } SortColumnDescription toSortColumnDesc() const { @@ -281,6 +314,8 @@ class Sorting : public std::vector return res; } + Sorting toReverseOrder() const; + size_t hash() const; String toString() const; }; @@ -383,6 +418,7 @@ class Property const Partitioning & getNodePartitioning() const { return node_partitioning; } Partitioning & getNodePartitioningRef() { return node_partitioning; } const Partitioning & getStreamPartitioning() const { return stream_partitioning; } + Partitioning & getStreamPartitioningRef() { return stream_partitioning; } const Sorting & getSorting() const { return sorting; } const CTEDescriptions & getCTEDescriptions() const { return cte_descriptions; } CTEDescriptions & getCTEDescriptions() { return cte_descriptions; } diff --git a/src/Optimizer/Property/PropertyDeriver.cpp b/src/Optimizer/Property/PropertyDeriver.cpp index 9726b2177f4..a6e242ea7b3 100644 --- a/src/Optimizer/Property/PropertyDeriver.cpp +++ b/src/Optimizer/Property/PropertyDeriver.cpp @@ -19,11 +19,14 @@ #include #include +#include #include #include +#include #include #include #include +#include #include 
#include #include @@ -31,12 +34,9 @@ #include #include #include -#include -#include namespace DB { - namespace ErrorCodes { extern const int OPTIMIZER_NONSUPPORT; @@ -89,7 +89,7 @@ static String getClusterByHint(const StoragePtr & storage) return ""; } -Property PropertyDeriver::deriveStorageProperty(const StoragePtr & storage, const Property &, ContextMutablePtr & context) +Property PropertyDeriver::deriveStorageProperty(const StoragePtr & storage, const Property & required, ContextMutablePtr & context) { if (storage->getDatabaseName() == "system") { @@ -108,9 +108,40 @@ Property PropertyDeriver::deriveStorageProperty(const StoragePtr & storage, cons sorting.emplace_back(SortColumn(descs.column_names[i], SortOrder::ASC_NULLS_FIRST)); } + bool use_reverse_sorting = !required.getSorting().empty() + && (required.getSorting()[0].getOrder() == SortOrder::DESC_ANY || required.getSorting()[0].getOrder() == SortOrder::DESC_NULLS_FIRST + || required.getSorting()[0].getOrder() == SortOrder::DESC_NULLS_LAST); + if (use_reverse_sorting) + sorting = sorting.toReverseOrder(); + auto metadata = storage->getInMemoryMetadataPtr(); Names cluster_by; UInt64 buckets = 0; + + auto normalize_ast = [&](ASTPtr sharding_key) -> std::pair { + SymbolVisitor visitor; + Names partition_keys; + SymbolVisitorContext symbol_context; + ASTVisitorUtil::accept(sharding_key, visitor, symbol_context); + + ConstASTMap expression_map; + size_t index = 0; + for (auto symbol : symbol_context.result) + { + ASTPtr name = std::make_shared(symbol); + ASTPtr id = std::make_shared("$" + std::to_string(index)); + if (!expression_map.contains(name)) + { + expression_map[name] = ConstHashAST::make(id); + partition_keys.emplace_back(symbol); + index++; + } + } + + return {partition_keys, ExpressionRewriter::rewrite(sharding_key, expression_map)}; + }; + + ASTPtr ast; if (storage->isBucketTable()) { bool clustered = storage->isTableClustered(context); @@ -125,7 +156,9 @@ Property 
PropertyDeriver::deriveStorageProperty(const StoragePtr & storage, cons } else { - cluster_by = metadata->cluster_by_key.column_names; + auto [columns, rewritten] = normalize_ast(metadata->cluster_by_key.definition_ast); + cluster_by = columns; + ast = rewritten; } buckets = metadata->getBucketNumberFromClusterByKey(); } @@ -143,7 +176,16 @@ Property PropertyDeriver::deriveStorageProperty(const StoragePtr & storage, cons } #endif return Property{ - Partitioning{Partitioning::Handle::BUCKET_TABLE, cluster_by, true, buckets, true, Partitioning::Component::ANY, false, satisfyBucketWorkerRelation(storage, *context)}, + Partitioning{ + Partitioning::Handle::BUCKET_TABLE, + cluster_by, + true, + buckets, + ast, + true, + Partitioning::Component::ANY, + false, + satisfyBucketWorkerRelation(storage, *context)}, Partitioning{}, sorting}; } @@ -169,9 +211,11 @@ Property PropertyDeriver::deriveStoragePropertyWhatIfMode( Names cluster_by{what_if_table_partitioning.getPartitionKey().column}; // the bucket number is only used for matching, can be set to anything UInt64 buckets = (actual_storage_property.getNodePartitioning().getHandle() == Partitioning::Handle::BUCKET_TABLE) - ? actual_storage_property.getNodePartitioning().getBuckets() : context->getSettingsRef().memory_catalog_worker_size; + ? 
actual_storage_property.getNodePartitioning().getBuckets() + : context->getSettingsRef().memory_catalog_worker_size; - Partitioning new_partitioning{Partitioning::Handle::BUCKET_TABLE, cluster_by, true, buckets, true, Partitioning::Component::ANY}; + Partitioning new_partitioning{ + Partitioning::Handle::BUCKET_TABLE, cluster_by, true, buckets, nullptr, true, Partitioning::Component::ANY}; actual_storage_property.setNodePartitioning(new_partitioning); return actual_storage_property; @@ -368,7 +412,7 @@ Property DeriverVisitor::visitAggregatingStep(const AggregatingStep &, DeriverCo Property DeriverVisitor::visitMarkDistinctStep(const MarkDistinctStep &, DeriverContext & context) { - return context.getInput()[0].clearSorting(); + return context.getInput()[0].clearSorting(); } Property DeriverVisitor::visitMergingAggregatedStep(const MergingAggregatedStep &, DeriverContext & context) @@ -453,6 +497,7 @@ Property DeriverVisitor::visitUnionStep(const UnionStep & step, DeriverContext & output_keys, true, first_child_property.getNodePartitioning().getBuckets(), + first_child_property.getNodePartitioning().getBucketExpr(), first_child_property.getNodePartitioning().isEnforceRoundRobin(), first_child_property.getNodePartitioning().getComponent(), false, @@ -466,6 +511,7 @@ Property DeriverVisitor::visitUnionStep(const UnionStep & step, DeriverContext & output_keys, true, first_child_property.getNodePartitioning().getBuckets(), + first_child_property.getNodePartitioning().getBucketExpr(), first_child_property.getNodePartitioning().isEnforceRoundRobin(), first_child_property.getNodePartitioning().getComponent(), false, @@ -536,7 +582,8 @@ Property DeriverVisitor::visitTableScanStep(const TableScanStep & step, DeriverC translation.emplace(item.first, item.second); if (!context.getRequire().getTableLayout().empty()) - return PropertyDeriver::deriveStoragePropertyWhatIfMode(step.getStorage(), context.getContext(), context.getRequire()).translate(translation); + return 
PropertyDeriver::deriveStoragePropertyWhatIfMode(step.getStorage(), context.getContext(), context.getRequire()) + .translate(translation); return PropertyDeriver::deriveStorageProperty(step.getStorage(), context.getRequire(), context.getContext()).translate(translation); } @@ -691,9 +738,12 @@ Property DeriverVisitor::visitMultiJoinStep(const MultiJoinStep &, DeriverContex return context.getInput()[0]; } -Property DeriverVisitor::visitExpandStep(const ExpandStep&, DeriverContext & context) +Property DeriverVisitor::visitExpandStep(const ExpandStep &, DeriverContext & context) { - return context.getInput()[0]; + auto prop = context.getInput()[0].clearSorting(); + prop.getNodePartitioningRef().resetIfPartitionHandle(); + prop.getStreamPartitioningRef().resetIfPartitionHandle(); + return prop; } } diff --git a/src/Optimizer/Property/PropertyDeterminer.cpp b/src/Optimizer/Property/PropertyDeterminer.cpp index 11730f5486d..e617551c67f 100644 --- a/src/Optimizer/Property/PropertyDeterminer.cpp +++ b/src/Optimizer/Property/PropertyDeterminer.cpp @@ -138,8 +138,8 @@ PropertySets DeterminerVisitor::visitJoinStep(const JoinStep & step, DeterminerC Partitioning left_stream{Partitioning::Handle::FIXED_HASH, left_keys_asof}; Partitioning right_stream{Partitioning::Handle::FIXED_HASH, right_keys_asof}; - Property left{Partitioning{Partitioning::Handle::FIXED_HASH, left_keys_asof, false, 0, enforce_round_robine}, left_stream}; - Property right{Partitioning{Partitioning::Handle::FIXED_HASH, right_keys_asof, false, 0, false}, right_stream}; + Property left{Partitioning{Partitioning::Handle::FIXED_HASH, left_keys_asof, false, 0, nullptr, enforce_round_robine}, left_stream}; + Property right{Partitioning{Partitioning::Handle::FIXED_HASH, right_keys_asof, false, 0, nullptr, false}, right_stream}; PropertySet set; set.emplace_back(left); set.emplace_back(right); @@ -184,8 +184,8 @@ PropertySets DeterminerVisitor::visitJoinStep(const JoinStep & step, DeterminerC Partitioning 
left_stream{Partitioning::Handle::FIXED_HASH, sub_left_keys}; Partitioning right_stream{Partitioning::Handle::FIXED_HASH, sub_right_keys}; - Property left{Partitioning{Partitioning::Handle::FIXED_HASH, sub_left_keys, false, 0, enforce_round_robine}, left_stream}; - Property right{Partitioning{Partitioning::Handle::FIXED_HASH, sub_right_keys, false, 0, false}, right_stream}; + Property left{Partitioning{Partitioning::Handle::FIXED_HASH, sub_left_keys, false, 0, nullptr, enforce_round_robine}, left_stream}; + Property right{Partitioning{Partitioning::Handle::FIXED_HASH, sub_right_keys, false, 0, nullptr, false}, right_stream}; PropertySet prop_set; prop_set.emplace_back(left); prop_set.emplace_back(right); @@ -196,8 +196,8 @@ PropertySets DeterminerVisitor::visitJoinStep(const JoinStep & step, DeterminerC { Partitioning left_stream{Partitioning::Handle::FIXED_HASH, left_keys}; Partitioning right_stream{Partitioning::Handle::FIXED_HASH, right_keys}; - Property left{Partitioning{Partitioning::Handle::FIXED_HASH, left_keys, false, 0, enforce_round_robine}, left_stream}; - Property right{Partitioning{Partitioning::Handle::FIXED_HASH, right_keys, false, 0, false}, right_stream}; + Property left{Partitioning{Partitioning::Handle::FIXED_HASH, left_keys, false, 0, nullptr, enforce_round_robine}, left_stream}; + Property right{Partitioning{Partitioning::Handle::FIXED_HASH, right_keys, false, 0, nullptr, false}, right_stream}; PropertySet prop_set; prop_set.emplace_back(left); prop_set.emplace_back(right); @@ -284,7 +284,7 @@ PropertySets DeterminerVisitor::visitAggregatingStep(const AggregatingStep & ste { keys.emplace_back("__grouping_set"); return {PropertySet{ - Property{Partitioning{Partitioning::Handle::FIXED_HASH, keys, false, 0, true, Partitioning::Component::ANY, true}}}}; + Property{Partitioning{Partitioning::Handle::FIXED_HASH, keys, false, 0, nullptr, true, Partitioning::Component::ANY, true}}}}; } return sets; @@ -549,13 +549,11 @@ PropertySets 
DeterminerVisitor::visitFillingStep(const FillingStep &, Determiner PropertySets DeterminerVisitor::visitTableWriteStep(const TableWriteStep & step, DeterminerContext & context) { auto node = Partitioning{Partitioning::Handle::FIXED_ARBITRARY}; - if (const auto * cnch_table = dynamic_cast(step.getTarget()->getStorage().get())) + const auto * cnch_table = dynamic_cast(step.getTarget()->getStorage().get()); + if (cnch_table && !cnch_table->supportsWriteInWorkers(context.getContext())) { // unique table can't support do TableWrite in many workers. - if (cnch_table->getInMemoryMetadataPtr()->hasUniqueKey() && !context.getContext().getSettingsRef().enable_staging_area_for_write) - { - node = Partitioning{Partitioning::Handle::SINGLE}; - } + node = Partitioning{Partitioning::Handle::SINGLE}; } node.setComponent(Partitioning::Component::WORKER); return {{Property{node}}}; diff --git a/src/Optimizer/Property/PropertyEnforcer.cpp b/src/Optimizer/Property/PropertyEnforcer.cpp index ee2a39f3ffb..d96de3c5548 100644 --- a/src/Optimizer/Property/PropertyEnforcer.cpp +++ b/src/Optimizer/Property/PropertyEnforcer.cpp @@ -86,8 +86,6 @@ QueryPlanStepPtr PropertyEnforcer::enforceNodePartitioning( { case Partitioning::Handle::SINGLE: return std::make_unique(streams, ExchangeMode::GATHER, partitioning, keep_order); - case Partitioning::Handle::FIXED_HASH: - return std::make_unique(streams, ExchangeMode::REPARTITION, partitioning, keep_order); case Partitioning::Handle::FIXED_BROADCAST: return std::make_unique(streams, ExchangeMode::BROADCAST, partitioning, keep_order); case Partitioning::Handle::FIXED_ARBITRARY: @@ -99,7 +97,9 @@ QueryPlanStepPtr PropertyEnforcer::enforceNodePartitioning( return std::make_unique(streams, ExchangeMode::LOCAL_NO_NEED_REPARTITION, partitioning, keep_order); case Partitioning::Handle::ARBITRARY: return nullptr; + case Partitioning::Handle::FIXED_HASH: case Partitioning::Handle::BUCKET_TABLE: + return std::make_unique(streams, ExchangeMode::REPARTITION, 
partitioning, keep_order); default: throw Exception("Property Enforce error", ErrorCodes::ILLEGAL_ENFORCE); } diff --git a/src/Optimizer/Property/PropertyMatcher.cpp b/src/Optimizer/Property/PropertyMatcher.cpp index 38d0d81d2cb..c6c7bfbadf3 100644 --- a/src/Optimizer/Property/PropertyMatcher.cpp +++ b/src/Optimizer/Property/PropertyMatcher.cpp @@ -68,9 +68,9 @@ bool PropertyMatcher::matchStreamPartitioning( } Sorting PropertyMatcher::matchSorting( - const Context & context, const Sorting & required, const Sorting & actual, const SymbolEquivalences & equivalences) + const Context & context, const Sorting & required, const Sorting & actual, const SymbolEquivalences & equivalences, const Constants & constants) { - return matchSorting(context, required.toSortDesc(), actual, equivalences); + return matchSorting(context, required.toSortDesc(), actual, equivalences, constants); } /// Optimize in case of exact match with order key element @@ -129,15 +129,10 @@ SortOrder matchSortDescription(const SortColumnDescription & require, const Sort return SortOrder::UNKNOWN; } -Sorting PropertyMatcher::matchSorting(const Context &, const SortDescription & required, const Sorting & actual, const SymbolEquivalences &) +Sorting PropertyMatcher::matchSorting(const Context &, const SortDescription & required, const Sorting & actual, const SymbolEquivalences &, const Constants & constants) { if (!actual.empty()) { - SortOrder read_direction = SortOrder::UNKNOWN; - - // todo@jingpeng.mt constant - // auto fixed_sorting_columns = getFixedSortingColumns(query, sorting_key_columns, context); - SortDescription sort_description_for_merging; sort_description_for_merging.reserve(required.size()); @@ -147,24 +142,26 @@ Sorting PropertyMatcher::matchSorting(const Context &, const SortDescription & r while (desc_pos < required.size() && key_pos < actual.size()) { auto match = matchSortDescription(required[desc_pos], actual[key_pos].toSortColumnDesc()); - bool is_matched = match != 
SortOrder::UNKNOWN && (desc_pos == 0 || match == read_direction); - + bool is_matched = match != SortOrder::UNKNOWN; if (!is_matched) { /// If one of the sorting columns is constant after filtering, /// skip it, because it won't affect order anymore. - // if (fixed_sorting_columns.contains(sorting_key_columns[key_pos])) - // { - // ++key_pos; - // continue; - // } + if (constants.contains(actual[key_pos].getName())) + { + ++key_pos; + continue; + } + else if (constants.contains(required[desc_pos].column_name)) + { + sort_description_for_merging.push_back(required[desc_pos]); + ++desc_pos; + continue; + } break; } - if (desc_pos == 0) - read_direction = match; - sort_description_for_merging.push_back(required[desc_pos]); ++desc_pos; diff --git a/src/Optimizer/Property/PropertyMatcher.h b/src/Optimizer/Property/PropertyMatcher.h index c76dded39e8..2908a746b53 100644 --- a/src/Optimizer/Property/PropertyMatcher.h +++ b/src/Optimizer/Property/PropertyMatcher.h @@ -32,10 +32,10 @@ class PropertyMatcher const Context & context, const Partitioning & required, const Partitioning & actual, const SymbolEquivalences & equivalences = {}, const Constants & constants = {}, bool match_local_exchange = true); static Sorting - matchSorting(const Context & context, const Sorting & required, const Sorting & actual, const SymbolEquivalences & equivalences = {}); + matchSorting(const Context & context, const Sorting & required, const Sorting & actual, const SymbolEquivalences & equivalences = {}, const Constants & constants = {}); static Sorting matchSorting( - const Context & context, const SortDescription & required, const Sorting & actual, const SymbolEquivalences & equivalences = {}); + const Context & context, const SortDescription & required, const Sorting & actual, const SymbolEquivalences & equivalences = {}, const Constants & constants = {}); static Property compatibleCommonRequiredProperty(const std::unordered_set & properties); }; diff --git 
a/src/Optimizer/Property/SymbolEquivalencesDeriver.cpp b/src/Optimizer/Property/SymbolEquivalencesDeriver.cpp index 0870d09d4ab..d12c557c56c 100644 --- a/src/Optimizer/Property/SymbolEquivalencesDeriver.cpp +++ b/src/Optimizer/Property/SymbolEquivalencesDeriver.cpp @@ -69,26 +69,15 @@ SymbolEquivalencesDeriverVisitor::visitProjectionStep(const ProjectionStep & ste { const auto & assignments = step.getAssignments(); std::unordered_map identities = Utils::computeIdentityTranslations(assignments); - std::unordered_map revert_identifies; - - for (auto & item : identities) - { - revert_identifies[item.second] = item.first; - } - - auto equivalences = context[0]->translate(revert_identifies); for (auto & item : identities) - { - equivalences->add(item.second, item.first); - } - return equivalences; + context[0]->add(item.second, item.first); + return context[0]; } SymbolEquivalencesPtr -SymbolEquivalencesDeriverVisitor::visitAggregatingStep(const AggregatingStep & step, std::vector & context) +SymbolEquivalencesDeriverVisitor::visitAggregatingStep(const AggregatingStep &, std::vector & context) { - NameSet set{step.getKeys().begin(), step.getKeys().end()}; - return context[0]->translate(set); + return context[0]; } SymbolEquivalencesPtr SymbolEquivalencesDeriverVisitor::visitExchangeStep(const ExchangeStep &, std::vector & context) @@ -99,10 +88,12 @@ SymbolEquivalencesDeriverVisitor::visitExchangeStep(const ExchangeStep &, std::v SymbolEquivalencesPtr SymbolEquivalencesDeriverVisitor::visitCTERefStep(const CTERefStep & step, std::vector & context) { - auto mapping = step.getReverseOutputColumns(); if (!context.empty() && context[0]) { - context[0]->translate(mapping); + auto mappings = step.getOutputColumns(); + for (const auto & mapping : mappings) + context[0]->add(mapping.first, mapping.second); + return context[0]; } return std::make_shared(); } diff --git a/src/Optimizer/PushProjectionThroughJoin.cpp b/src/Optimizer/PushProjectionThroughJoin.cpp index 
25ce43fffa7..00a391532cd 100644 --- a/src/Optimizer/PushProjectionThroughJoin.cpp +++ b/src/Optimizer/PushProjectionThroughJoin.cpp @@ -216,6 +216,7 @@ std::optional PushProjectionThroughJoin::pushProjectionThroughJoin( join_step.getKeepLeftReadInOrder(), join_step.getLeftKeys(), join_step.getRightKeys(), + join_step.getKeyIdsNullSafe(), join_step.getFilter(), join_step.isHasUsing(), join_step.getRequireRightKeys(), diff --git a/src/Optimizer/QueryUseOptimizerChecker.cpp b/src/Optimizer/QueryUseOptimizerChecker.cpp index 8acf59fa875..e703c42b21f 100644 --- a/src/Optimizer/QueryUseOptimizerChecker.cpp +++ b/src/Optimizer/QueryUseOptimizerChecker.cpp @@ -192,7 +192,7 @@ bool QueryUseOptimizerChecker::check(ASTPtr node, ContextMutablePtr context, boo if (!checkDatabaseAndTable(database, insert_query->table_id.getTableName(), context, {})) { - reason = "unsupported storage"; + reason = "unsupported storage, database: " + database + ", table: " + insert_query->table_id.getTableName(); support = false; } } @@ -247,6 +247,12 @@ bool QueryUseOptimizerVisitor::visitASTSelectQuery(ASTPtr & node, QueryUseOptimi return false; } + if (context.disallow_subquery) + { + reason = "nullIn/globalNullIn/notNullIn/globalNotNullIn function with subquery not implemented"; + return false; + } + if (select->group_by_with_totals && context.disallow_with_totals) { reason = "group by with totals only supports with totals at outmost select"; @@ -261,7 +267,7 @@ bool QueryUseOptimizerVisitor::visitASTSelectQuery(ASTPtr & node, QueryUseOptimi { if (!checkDatabaseAndTable(*table_expression, child_context.context, child_context.ctes)) { - reason = "unsupported storage"; + reason = "unsupported storage: " + table_expression->formatForErrorMessage(); return false; } if (table_expression->table_function) @@ -297,7 +303,7 @@ bool QueryUseOptimizerVisitor::visitASTFunction(ASTPtr & node, QueryUseOptimizer auto & fun = node->as(); if (fun.name == "untuple") { - reason = "unsupported function"; + reason 
= "unsupported untuple function"; return false; } @@ -311,13 +317,17 @@ bool QueryUseOptimizerVisitor::visitASTFunction(ASTPtr & node, QueryUseOptimizer table_expression.database_and_table_name = table; if (!checkDatabaseAndTable(table_expression, context.context, context.ctes)) { - reason = "unsupported storage"; + reason = "unsupported storage: " + table_expression.formatForErrorMessage(); return false; } } } } - return visitNode(node, context); + bool disallow_subquery = context.disallow_subquery; + context.disallow_subquery = disallow_subquery || (fun.name == "nullIn" || fun.name == "globalNullIn" || fun.name == "notNullIn" || fun.name == "globalNotNullIn"); + bool support = visitNode(node, context); + context.disallow_subquery = disallow_subquery; + return support; } bool QueryUseOptimizerVisitor::visitASTQuantifiedComparison(ASTPtr & node, QueryUseOptimizerContext & context) diff --git a/src/Optimizer/QueryUseOptimizerChecker.h b/src/Optimizer/QueryUseOptimizerChecker.h index 141c7d0e41d..9ecd3ee1e5f 100644 --- a/src/Optimizer/QueryUseOptimizerChecker.h +++ b/src/Optimizer/QueryUseOptimizerChecker.h @@ -39,6 +39,7 @@ struct QueryUseOptimizerContext NameSet ctes; Tables external_tables; bool disallow_with_totals = false; + bool disallow_subquery = false; }; class QueryUseOptimizerVisitor : public ASTVisitor diff --git a/src/Optimizer/Rewriter/AddBufferForDeadlockCTE.cpp b/src/Optimizer/Rewriter/AddBufferForDeadlockCTE.cpp index 360efef1030..961d03a2fbf 100644 --- a/src/Optimizer/Rewriter/AddBufferForDeadlockCTE.cpp +++ b/src/Optimizer/Rewriter/AddBufferForDeadlockCTE.cpp @@ -1,11 +1,18 @@ +#include #include #include +#include +#include +#include #include +#include +#include #include #include #include #include #include +#include #include #include #include @@ -24,6 +31,10 @@ namespace class UpdateParentExecuteOrderVisitor; // visitor to update execute_order from a join node to left descendant node class UpdateLeftExecuteOrderVisitor; + + // visitor to find cte 
ref exists both sides of join build and probe to add buffer + class FindAllCTEIfExistsOnJoinBuildVisitor; + // visitor to add BufferStep for deadlock CTEs class AddBufferVisitor; @@ -35,7 +46,14 @@ namespace using VisitEntry = std::pair; using VisitPath = std::vector; - using ExecuteOrderMap = std::unordered_map>; + struct CTEExecuteOrder + { + PlanNodeId node_id; + CTEId cte_id; + int execute_order; + }; + + using ExecuteOrders = std::vector; class FindDirectRightVisitor : public PlanNodeVisitor { @@ -50,7 +68,7 @@ namespace CTEInfo & cte_info; Poco::Logger * logger; - std::unordered_set deadlock_ctes; + std::unordered_set deadlock_ctes; VisitPath visit_path; }; @@ -67,14 +85,14 @@ namespace CTEInfo & cte_info; VisitPath visit_path; - ExecuteOrderMap execute_orders; + ExecuteOrders execute_orders; int cur_execute_order = 0; }; class UpdateLeftExecuteOrderVisitor : public PlanNodeVisitor { public: - UpdateLeftExecuteOrderVisitor(CTEInfo & cte_info_, ExecuteOrderMap & execute_orders_, int execute_order_) + UpdateLeftExecuteOrderVisitor(CTEInfo & cte_info_, ExecuteOrders & execute_orders_, int execute_order_) : cte_info(cte_info_), execute_orders(execute_orders_), execute_order(execute_order_) { } @@ -84,21 +102,23 @@ namespace void visitJoinNode(JoinNode & node, const Void &) override; CTEInfo & cte_info; - ExecuteOrderMap & execute_orders; + ExecuteOrders & execute_orders; const int execute_order = 0; }; class AddBufferVisitor : public SimplePlanRewriter { public: - AddBufferVisitor(const std::unordered_set & deadlock_ctes_, ContextMutablePtr context_, CTEInfo & cte_info_) - : SimplePlanRewriter(std::move(context_), cte_info_), deadlock_ctes(deadlock_ctes_) + AddBufferVisitor( + const std::unordered_set & deadlock_ctes_, ContextMutablePtr context_, CTEInfo & cte_info_, Poco::Logger * logger_) + : SimplePlanRewriter(std::move(context_), cte_info_), deadlock_ctes(deadlock_ctes_), logger(logger_) { } PlanNodePtr visitCTERefNode(CTERefNode & node, const Void & c) 
override; - const std::unordered_set & deadlock_ctes; + const std::unordered_set & deadlock_ctes; + Poco::Logger * logger; }; void FindDirectRightVisitor::visitPlanNode(PlanNodeBase & node, const JoinPath & join_path) @@ -111,19 +131,31 @@ namespace VisitorUtil::accept(node, update_parent_order_visitor, {}); const auto & execute_orders = update_parent_order_visitor.execute_orders; - if (logger && logger->is(Poco::Message::PRIO_TRACE)) + LOG_TRACE(logger, "FindDirectRightVisitor visit on node {}", node.getId()); + + std::unordered_map cte_min_execute_orders; + for (const auto & execute_order : execute_orders) { - std::ostringstream os; - for (const auto & [node_id, node_orders] : execute_orders) - os << node_id << "->" << fmt::format("({})", fmt::join(node_orders, ",")) << " "; - LOG_TRACE(logger, "Direct right node id: {}, calculated execute order: {}", node.getId(), os.str()); + auto it = cte_min_execute_orders.find(execute_order.cte_id); + if (it == cte_min_execute_orders.end()) + cte_min_execute_orders.emplace(execute_order.cte_id, execute_order.execute_order); + else + it->second = std::min(it->second, execute_order.execute_order); } - for (const auto & [cte_id, cte_def_node] : cte_info.getCTEs()) + // find deadlock ctes + for (const auto & execute_order : execute_orders) { - auto cte_def_node_id = cte_def_node->getId(); - if (execute_orders.count(cte_def_node_id) && execute_orders.at(cte_def_node_id).size() > 1) - deadlock_ctes.emplace(cte_id); + LOG_TRACE( + logger, + "Direct right node id: {}, cte_id: {}, execute order: {}, cte min execute order: {}", + execute_order.node_id, + execute_order.cte_id, + execute_order.execute_order, + cte_min_execute_orders[execute_order.cte_id]); + + if (execute_order.execute_order > cte_min_execute_orders[execute_order.cte_id]) + deadlock_ctes.emplace(execute_order.node_id); } } @@ -158,7 +190,6 @@ namespace { assert(visit_path.back().second == JoinPath::RIGHT); visit_path.pop_back(); - 
execute_orders[node.getId()].emplace(cur_execute_order); if (!visit_path.empty()) { @@ -179,7 +210,8 @@ namespace { auto join_path = visit_path.back().second; visit_path.pop_back(); - execute_orders[node.getId()].emplace(++cur_execute_order); + + ++cur_execute_order; // update execute_order for left tree of join node UpdateLeftExecuteOrderVisitor update_left_order_visitor{cte_info, execute_orders, cur_execute_order}; @@ -194,7 +226,6 @@ namespace void UpdateLeftExecuteOrderVisitor::visitPlanNode(PlanNodeBase & node, const Void & ctx) { - execute_orders[node.getId()].emplace(execute_order); for (auto & child : node.getChildren()) VisitorUtil::accept(*child, *this, ctx); @@ -202,16 +233,14 @@ namespace void UpdateLeftExecuteOrderVisitor::visitCTERefNode(CTERefNode & node, const Void & ctx) { - execute_orders[node.getId()].emplace(execute_order); auto cte_id = node.getStep()->getId(); + execute_orders.emplace_back(CTEExecuteOrder{node.getId(), cte_id, execute_order}); VisitorUtil::accept(*cte_info.getCTEDef(cte_id), *this, ctx); } void UpdateLeftExecuteOrderVisitor::visitJoinNode(JoinNode & node, const Void & ctx) { - execute_orders[node.getId()].emplace(execute_order); - VisitorUtil::accept(*node.getChildren().at(0), *this, ctx); } @@ -220,41 +249,160 @@ namespace SimplePlanRewriter::visitCTERefNode(node, c); auto cte_id = node.getStep()->getId(); - if (!deadlock_ctes.count(cte_id)) - { + if (!deadlock_ctes.count(node.getId())) return node.shared_from_this(); + + /** + * if buffer size exceed max_buffer_size_for_deadlock_cte, we inline cte instead of add buffer. 
+ * + * note: max buffer size for tpcds 1t is 7994883314, so we set max_buffer_size_for_deadlock_cte + * 8'000'000'000 bytes (8Gb) by default for tpcds 1T + */ + Int64 max_buffer_size = context->getSettingsRef().max_buffer_size_for_deadlock_cte; + if (max_buffer_size == 0) + { + LOG_TRACE(logger, "Inline CTE {} because max_buffer_size_for_deadlock_cte=0", cte_id); + return node.getStep()->toInlinedPlanNode(cte_helper.getCTEInfo(), context); } - else + + if (max_buffer_size > 0) { - QueryPlanStepPtr buffer_step = std::make_shared(node.getCurrentDataStream()); - PlanNodePtr buffer_node = PlanNodeBase::createPlanNode( - context->nextNodeId(), std::move(buffer_step), {node.shared_from_this()}, node.getStatistics()); - return buffer_node; + auto stats = CardinalityEstimator::estimate(node, cte_helper.getCTEInfo(), context); + if (!stats) + { + LOG_TRACE(logger, "Inline CTE {} because estimates stats failed", cte_id); + return node.getStep()->toInlinedPlanNode(cte_helper.getCTEInfo(), context); + } + + Int64 buffer_size = (*stats)->getOutputSizeInBytes(); + LOG_TRACE(logger, "CTE {} estimated buffer size {}", cte_id, (*stats)->getOutputSizeInBytes()); + if (buffer_size > max_buffer_size) + { + LOG_TRACE( + logger, + "Inline CTE {} because estimates buffer size {} is bigger than max_buffer_size_for_deadlock_cte({})", + cte_id, + buffer_size, + max_buffer_size); + return node.getStep()->toInlinedPlanNode(cte_helper.getCTEInfo(), context); + } } + + QueryPlanStepPtr buffer_step = std::make_shared(node.getCurrentDataStream()); + PlanNodePtr buffer_node + = PlanNodeBase::createPlanNode(context->nextNodeId(), std::move(buffer_step), {node.shared_from_this()}, node.getStatistics()); + return buffer_node; } + + class FindAllCTEIfExistsOnJoinBuildVisitor : public PlanNodeVisitor, Void> + { + public: + explicit FindAllCTEIfExistsOnJoinBuildVisitor(CTEInfo & cte_info) : cte_helper(cte_info) + { + } + + std::unordered_set visitPlanNode(PlanNodeBase & node, Void & c) override + { + 
std::unordered_set ctes; + for (const auto & child : node.getChildren()) + { + auto child_ctes = VisitorUtil::accept(*child, *this, c); + ctes.insert(child_ctes.begin(), child_ctes.end()); + } + return ctes; + } + + std::unordered_set visitCTERefNode(CTERefNode & node, Void & c) override + { + const auto * cte_step = dynamic_cast(node.getStep().get()); + auto cte_id = cte_step->getId(); + cte_refs[cte_id].emplace_back(node.getId()); + + auto ctes = cte_helper.accept(cte_id, *this, c); + ctes.emplace(cte_id); + + return ctes; + } + + std::unordered_set visitJoinNode(JoinNode & node, Void & c) override + { + auto left_ctes = VisitorUtil::accept(*node.getChildren()[0], *this, c); + auto right_ctes = VisitorUtil::accept(*node.getChildren()[1], *this, c); + for (const auto & cte_id : right_ctes) + { + for (const auto & node_id : cte_refs[cte_id]) + deadlock_ctes.emplace(node_id); + } + + left_ctes.insert(right_ctes.begin(), left_ctes.end()); + return left_ctes; + } + + SimpleCTEVisitHelper> cte_helper; + std::unordered_map> cte_refs; + + std::unordered_set deadlock_ctes; + }; } void AddBufferForDeadlockCTE::rewrite(QueryPlan & plan, ContextMutablePtr context) const { static auto * logger = &Poco::Logger::get("AddBufferForDeadlockCTE"); - FindDirectRightVisitor find_deadlock_cte_visitor{plan.getCTEInfo(), logger}; - VisitorUtil::accept(plan.getPlanNode(), find_deadlock_cte_visitor, JoinPath::RIGHT); + if (plan.getCTEInfo().empty()) + return; + + std::unordered_set deadlock_ctes; - if (logger && logger->is(Poco::Message::PRIO_DEBUG) && !find_deadlock_cte_visitor.deadlock_ctes.empty()) + // fixme: fix deadlock algorithm to enable this settings + if (context->getSettings().enable_remove_remove_unnecessary_buffer) + { + FindDirectRightVisitor find_deadlock_cte_visitor{plan.getCTEInfo(), logger}; + VisitorUtil::accept(plan.getPlanNode(), find_deadlock_cte_visitor, JoinPath::RIGHT); + deadlock_ctes = std::move(find_deadlock_cte_visitor.deadlock_ctes); + } + else + { + 
FindAllCTEIfExistsOnJoinBuildVisitor find_all_cte_ref_visitor{plan.getCTEInfo()}; + Void c; + VisitorUtil::accept(plan.getPlanNode(), find_all_cte_ref_visitor, c); + deadlock_ctes = std::move(find_all_cte_ref_visitor.deadlock_ctes); + } + + if (deadlock_ctes.empty()) + return; + + if (logger->debug()) { std::ostringstream os; - for (const auto & cte_id : find_deadlock_cte_visitor.deadlock_ctes) - os << cte_id << '#' << plan.getCTEInfo().getCTEDef(cte_id)->getId() << ", "; - LOG_DEBUG(logger, "Detected deadlock ctes(cte_id#plan_node_id): {}", os.str()); + for (const auto & cte_ref_id : deadlock_ctes) + os << cte_ref_id << ", "; + LOG_DEBUG(logger, "Detected deadlock ctes(cte_ref_id): {}", os.str()); } - AddBufferVisitor add_buffer_visitor{find_deadlock_cte_visitor.deadlock_ctes, context, plan.getCTEInfo()}; - VisitorUtil::accept(plan.getPlanNode(), add_buffer_visitor, {}); + AddBufferVisitor add_buffer_visitor{deadlock_ctes, context, plan.getCTEInfo(), logger}; + plan.update(VisitorUtil::accept(plan.getPlanNode(), add_buffer_visitor, {})); RewriterPtr push_limit_through_buffer = std::make_shared( std::vector{std::make_shared()}, "PushDownLimitThroughBuffer"); push_limit_through_buffer->rewritePlan(plan, context); + + if (context->getSettingsRef().max_buffer_size_for_deadlock_cte >= 0) + { + static Rewriters rewriters + = {std::make_shared(), + std::make_shared(), + std::make_shared(false, true), + std::make_shared(Rules::inlineProjectionRules(), "InlineProjection"), + std::make_shared(), + std::make_shared(Rules::normalizeExpressionRules(), "NormalizeExpression"), + std::make_shared(Rules::swapPredicateRules(), "SwapPredicate"), + std::make_shared(Rules::simplifyExpressionRules(), "SimplifyExpression"), + std::make_shared(Rules::removeRedundantRules(), "RemoveRedundant")}; + + for (auto & rewriter : rewriters) + rewriter->rewritePlan(plan, context); + } } } diff --git a/src/Optimizer/Rewriter/AddBufferForDeadlockCTE.h 
b/src/Optimizer/Rewriter/AddBufferForDeadlockCTE.h index d4adb0cbd54..5638ae3ccba 100644 --- a/src/Optimizer/Rewriter/AddBufferForDeadlockCTE.h +++ b/src/Optimizer/Rewriter/AddBufferForDeadlockCTE.h @@ -43,10 +43,9 @@ namespace DB /// the rewriter will output: /// Join /// / \ -/// Buffer Buffer -/// | | -/// CTERef[0] CTERef[0] -/// TODO: add buffer step only on left table side +/// Buffer CTERef[0] +/// | +/// CTERef[0] /// /// /// Currently the algorithm will add buffer step aggresively to solve cyclic deadlock ctes. diff --git a/src/Optimizer/Rewriter/AddRuntimeFilters.cpp b/src/Optimizer/Rewriter/AddRuntimeFilters.cpp index e787a88277d..362b43341e7 100644 --- a/src/Optimizer/Rewriter/AddRuntimeFilters.cpp +++ b/src/Optimizer/Rewriter/AddRuntimeFilters.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -161,6 +162,19 @@ PlanPropEquivalences AddRuntimeFilters::AddRuntimeFilterRewriter::visitJoinNode( if (!is_broadcast && isFixedHashShuffleOrBucketTableShuffle(left.property)) { partition_columns = left.property.getNodePartitioning().getColumns(); + bool all_contains = std::all_of( + partition_columns.begin(), partition_columns.end(), [&](const auto & column) { + return left.plan->getStep()->getOutputStream().header.has(column); + }); + if (!all_contains) + { + LOG_WARNING( + logger, + "partition columns not found in AddRuntimeFilteres, required: {}, left output: {}", + fmt::join(partition_columns, ", "), + fmt::join(left.plan->getOutputNames(), ", ")); + break; + } } double selectivity; @@ -196,6 +210,7 @@ PlanPropEquivalences AddRuntimeFilters::AddRuntimeFilterRewriter::visitJoinNode( join.getKeepLeftReadInOrder(), join.getLeftKeys(), join.getRightKeys(), + join.getKeyIdsNullSafe(), join.getFilter(), join.isHasUsing(), join.getRequireRightKeys(), @@ -530,6 +545,7 @@ PlanNodePtr AddRuntimeFilters::RemoveUnusedRuntimeFilterProbRewriter::visitJoinN join_step->getKeepLeftReadInOrder(), join_step->getLeftKeys(), join_step->getRightKeys(), 
+ join_step->getKeyIdsNullSafe(), filters, join_step->isHasUsing(), join_step->getRequireRightKeys(), @@ -615,6 +631,7 @@ PlanNodePtr AddRuntimeFilters::RemoveUnusedRuntimeFilterBuildRewriter::visitJoin join_step->getKeepLeftReadInOrder(), join_step->getLeftKeys(), join_step->getRightKeys(), + join_step->getKeyIdsNullSafe(), join_step->getFilter(), join_step->isHasUsing(), join_step->getRequireRightKeys(), @@ -669,20 +686,27 @@ PlanNodePtr AddRuntimeFilters::AddExchange::visitJoinNode(JoinNode & node, std:: } // fixme: fix buffer step to remove this method -PlanNodePtr AddRuntimeFilters::AddExchange::visitCTERefNode(CTERefNode & node, std::unordered_set & need_exchange) +PlanNodePtr AddRuntimeFilters::AddExchange::visitBufferNode(BufferNode & node, std::unordered_set & need_exchange) { + auto res = SimplePlanRewriter::visitBufferNode(node, need_exchange); if (need_exchange.empty()) - return SimplePlanRewriter::visitPlanNode(node, need_exchange); + return res; return PlanNodeBase::createPlanNode( context->nextNodeId(), std::make_unique( - DataStreams{node.getCurrentDataStream()}, + DataStreams{res->getCurrentDataStream()}, ExchangeMode::LOCAL_NO_NEED_REPARTITION, Partitioning{Partitioning::Handle::FIXED_ARBITRARY}, context->getSettingsRef().enable_shuffle_with_order), - PlanNodes{node.shared_from_this()}, - node.getStatistics()); + PlanNodes{res}, + res->getStatistics()); +} + +PlanNodePtr AddRuntimeFilters::AddExchange::visitCTERefNode(CTERefNode & node, std::unordered_set &) +{ + std::unordered_set need_exchange; + return SimplePlanRewriter::visitCTERefNode(node, need_exchange); } PlanNodePtr AddRuntimeFilters::AddExchange::visitFilterNode(FilterNode & node, std::unordered_set & need_exchange) diff --git a/src/Optimizer/Rewriter/AddRuntimeFilters.h b/src/Optimizer/Rewriter/AddRuntimeFilters.h index d405cca689b..671fb02dfa5 100644 --- a/src/Optimizer/Rewriter/AddRuntimeFilters.h +++ b/src/Optimizer/Rewriter/AddRuntimeFilters.h @@ -70,6 +70,7 @@ class 
AddRuntimeFilters::AddRuntimeFilterRewriter : public PlanNodeVisitor cte_helper; + Poco::Logger * logger = &Poco::Logger::get("AddRuntimeFilters"); }; struct RuntimeFilterContext @@ -164,6 +165,7 @@ class AddRuntimeFilters::AddExchange : public SimplePlanRewriter &) override; PlanNodePtr visitJoinNode(JoinNode & node, std::unordered_set &) override; PlanNodePtr visitCTERefNode(CTERefNode & node, std::unordered_set &) override; + PlanNodePtr visitBufferNode(BufferNode & node, std::unordered_set &) override; }; } diff --git a/src/Optimizer/Rewriter/ColumnPruning.cpp b/src/Optimizer/Rewriter/ColumnPruning.cpp index 67f3d4ab52c..0802eed9f2f 100644 --- a/src/Optimizer/Rewriter/ColumnPruning.cpp +++ b/src/Optimizer/Rewriter/ColumnPruning.cpp @@ -161,9 +161,16 @@ PlanNodePtr ColumnPruningVisitor::visitOffsetNode(OffsetNode & node, ColumnPruni return visitDefault(node, column_pruning_context); } -PlanNodePtr ColumnPruningVisitor::visitTableFinishNode(TableFinishNode & node, ColumnPruningContext & column_pruning_context) +PlanNodePtr ColumnPruningVisitor::visitTableFinishNode(TableFinishNode & node, ColumnPruningContext &) { - return visitPlanNode(node, column_pruning_context); + NameSet require; + PlanNodePtr child = node.getChildren()[0]; + for (const auto & item : child->getCurrentDataStream().header) + require.insert(item.name); + ColumnPruningContext child_column_pruning_context{.name_set = require}; + PlanNodePtr new_child = VisitorUtil::accept(*child, *this, child_column_pruning_context); + node.replaceChildren({new_child}); + return node.shared_from_this(); } PlanNodePtr ColumnPruningVisitor::visitOutfileFinishNode(OutfileFinishNode & node, ColumnPruningContext & column_pruning_context) @@ -462,10 +469,19 @@ PlanNodePtr ColumnPruningVisitor::visitExpandNode(ExpandNode & node, ColumnPruni ColumnPruningContext child_column_pruning_context{.name_set = child_require}; auto child = VisitorUtil::accept(node.getChildren()[0], *this, child_column_pruning_context); + 
Assignments assignments; + for (const auto & assignment : step->getAssignments()) + if (child_require.contains(assignment.first)) + assignments.emplace_back(assignment.first, assignment.second); + NameToType name_to_type; + for (const auto & item : step->getNameToType()) + if (child_require.contains(item.first)) + name_to_type.emplace(item.first, item.second); + auto expr_step = std::make_shared( child->getStep()->getOutputStream(), - step->getAssignments(), - step->getNameToType(), + assignments, + name_to_type, step->getGroupIdSymbol(), step->getGroupIdValue(), step->getGroupIdNonNullSymbol()); @@ -833,6 +849,7 @@ PlanNodePtr ColumnPruningVisitor::visitJoinNode(JoinNode & node, ColumnPruningCo step->getKeepLeftReadInOrder(), step->getLeftKeys(), step->getRightKeys(), + step->getKeyIdsNullSafe(), step->getFilter(), step->isHasUsing(), step->getRequireRightKeys(), @@ -1507,6 +1524,20 @@ String ColumnPruningVisitor::selectColumnWithMinSize(NamesAndTypesList source_co { source_columns.remove(column); } + + // tmp fix for 40113_lowcard_nullable_subcolumn + auto metadata_snapshot = storage->getInMemoryMetadataPtr(); + const auto & columns_desc = metadata_snapshot->getColumns(); + source_columns.erase( + std::remove_if( + source_columns.begin(), + source_columns.end(), + [&](const auto & type_and_name) { + auto column_opt = columns_desc.tryGetColumnOrSubcolumn(GetColumnsOptions::Ordinary, type_and_name.name); + return column_opt && column_opt->isSubcolumn() + && !!(typeid_cast(column_opt->getTypeInStorage().get())); + }), + source_columns.end()); } /// If we have no information about columns sizes, choose a column of minimum size of its data type. 
return ExpressionActions::getSmallestColumn(source_columns); diff --git a/src/Optimizer/Rewriter/EliminateJoinByForeignKey.cpp b/src/Optimizer/Rewriter/EliminateJoinByForeignKey.cpp index aa28d3f9048..a43e06df113 100644 --- a/src/Optimizer/Rewriter/EliminateJoinByForeignKey.cpp +++ b/src/Optimizer/Rewriter/EliminateJoinByForeignKey.cpp @@ -151,6 +151,9 @@ FPKeysAndOrdinaryKeys EliminateJoinByFK::Rewriter::visitPlanNode(PlanNodeBase & FPKeysAndOrdinaryKeys EliminateJoinByFK::Rewriter::visitJoinNode(JoinNode & node, JoinInfo & join_info) { + if (node.getStep()->hasKeyIdNullSafe()) + return {}; + std::vector input_keys; ForeignKeyOrPrimaryKeys old_common_fp_keys; // only for bottom join. @@ -1031,6 +1034,7 @@ PlanNodePtr EliminateJoinByFK::Eliminator::visitJoinNode(JoinNode & node, JoinEl step->getKeepLeftReadInOrder(), left_keys, right_keys, + step->getKeyIdsNullSafe(), step->getFilter(), step->isHasUsing(), step->getRequireRightKeys(), diff --git a/src/Optimizer/Rewriter/EliminateJoinByForeignKey.h b/src/Optimizer/Rewriter/EliminateJoinByForeignKey.h index 87170a88c4c..2b32f5c3851 100644 --- a/src/Optimizer/Rewriter/EliminateJoinByForeignKey.h +++ b/src/Optimizer/Rewriter/EliminateJoinByForeignKey.h @@ -68,7 +68,7 @@ class EliminateJoinByFK : public Rewriter private: bool isEnabled(ContextMutablePtr context) const override { - return context->getSettingsRef().enable_eliminate_join_by_fk; + return context->getSettingsRef().enable_eliminate_join_by_fk && !context->getSettingsRef().join_using_null_safe; } void rewrite(QueryPlan & plan, ContextMutablePtr context) const override; diff --git a/src/Optimizer/Rewriter/GroupByKeysPruning.cpp b/src/Optimizer/Rewriter/GroupByKeysPruning.cpp index fc10bc8cebe..790acab5df4 100644 --- a/src/Optimizer/Rewriter/GroupByKeysPruning.cpp +++ b/src/Optimizer/Rewriter/GroupByKeysPruning.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -12,6 +13,7 @@ #include #include #include +#include namespace DB { 
@@ -178,7 +180,10 @@ PlanAndDataDependencyWithConstants GroupByKeysPruning::Rewriter::visitAggregatin } for (const auto & [name, literal] : constants_values) { - new_assignments.emplace(name, std::make_shared(literal.value)); + // date/datetime should make a cast function + // but nullable(UInt64) shouldn't make a cast function + auto literal_ast = LiteralEncoder::encodeForComparisonExpr(literal.value, literal.type, context); + new_assignments.emplace(name, std::move(literal_ast)); new_name_to_type[name] = literal.type; } diff --git a/src/Optimizer/Rewriter/PredicatePushdown.cpp b/src/Optimizer/Rewriter/PredicatePushdown.cpp index 7138cb4cdf8..7f6138cdce3 100644 --- a/src/Optimizer/Rewriter/PredicatePushdown.cpp +++ b/src/Optimizer/Rewriter/PredicatePushdown.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -35,8 +36,10 @@ #include #include #include +#include #include #include +#include #include #include #include @@ -154,10 +157,14 @@ PlanNodePtr PredicateVisitor::visitProjectionNode(ProjectionNode & node, Predica auto pushdown_predicate = PredicateUtils::combineConjuncts(inlined_deterministic_conjuncts); LOG_DEBUG( - &Poco::Logger::get("Debugger"), "node {}, pushdown_predicate : {}", node.getId(), pushdown_predicate->formatForErrorMessage()); + &Poco::Logger::get("PredicateVisitor"), + "project node {}, pushdown_predicate : {}", + node.getId(), + pushdown_predicate->formatForErrorMessage()); if (!pushdown_predicate->as()) - pushdown_predicate = ExpressionInterpreter::optimizePredicate(pushdown_predicate, step.getInputStreams()[0].getNamesToTypes(), context); + pushdown_predicate + = ExpressionInterpreter::optimizePredicate(pushdown_predicate, step.getInputStreams()[0].getNamesToTypes(), context); PredicateContext expression_context{ .predicate = pushdown_predicate, .extra_predicate_for_simplify_outer_join @@ -184,12 +191,29 @@ PlanNodePtr PredicateVisitor::visitProjectionNode(ProjectionNode & node, Predica PlanNodePtr 
PredicateVisitor::visitFilterNode(FilterNode & node, PredicateContext & predicate_context) { const auto & step = *node.getStep(); - auto predicates = std::vector{step.getFilter(), predicate_context.predicate}; + + // handle in function has large value list + UInt64 limit = predicate_context.context->getSettingsRef().max_in_value_list_to_pushdown; + std::pair split_in_filter = FilterStep::splitLargeInValueList(step.getFilter(), limit); + + LOG_DEBUG( + &Poco::Logger::get("PredicateVisitor"), + "filter node {}, split_in_filter.first : {}, split_in_filter.second : {}", + node.getId(), + split_in_filter.first->formatForErrorMessage(), + split_in_filter.second->formatForErrorMessage() + ); + + auto predicates = std::vector{split_in_filter.first, predicate_context.predicate}; ConstASTPtr predicate = PredicateUtils::combineConjuncts(predicates); + if (simplify_common_filter) { predicate = CommonPredicatesRewriter::rewrite(predicate, context); } + + LOG_DEBUG(&Poco::Logger::get("PredicateVisitor"), "filter node {}, pushdown_predicate : {}", node.getId(), predicate->formatForErrorMessage()); + PredicateContext filter_context{ .predicate = predicate, .extra_predicate_for_simplify_outer_join = predicate_context.extra_predicate_for_simplify_outer_join, @@ -198,6 +222,11 @@ PlanNodePtr PredicateVisitor::visitFilterNode(FilterNode & node, PredicateContex if (rewritten->getStep()->getType() != IQueryPlanStep::Type::Filter) { + if (!PredicateUtils::isTruePredicate(split_in_filter.second)) + { + auto filter_step = std::make_shared(rewritten->getStep()->getOutputStream(), split_in_filter.second); + return std::make_shared(context->nextNodeId(), std::move(filter_step), PlanNodes{rewritten}); + } return rewritten; } @@ -205,6 +234,11 @@ PlanNodePtr PredicateVisitor::visitFilterNode(FilterNode & node, PredicateContex { if (rewritten->getChildren()[0] != node.getChildren()[0]) { + if (!PredicateUtils::isTruePredicate(split_in_filter.second)) + { + auto filter_step = 
std::make_shared(rewritten->getStep()->getOutputStream(), split_in_filter.second); + return std::make_shared(context->nextNodeId(), std::move(filter_step), PlanNodes{rewritten}); + } return rewritten; } auto rewritten_step_ptr = rewritten->getStep(); @@ -214,6 +248,11 @@ PlanNodePtr PredicateVisitor::visitFilterNode(FilterNode & node, PredicateContex // see ExpressionEquivalence if (step.getFilter() != rewritten_step.getFilter()) { + if (!PredicateUtils::isTruePredicate(split_in_filter.second)) + { + auto filter_step = std::make_shared(rewritten->getStep()->getOutputStream(), split_in_filter.second); + return std::make_shared(context->nextNodeId(), std::move(filter_step), PlanNodes{rewritten}); + } return rewritten; } } @@ -225,13 +264,17 @@ PlanNodePtr PredicateVisitor::visitAggregatingNode(AggregatingNode & node, Predi const auto & step = *node.getStep(); const auto & keys = step.getKeys(); - // TODO: in case of grouping sets, we should be able to push the filters over grouping keys below the aggregation - // and also preserve the filter above the aggregation if it has an empty grouping set if (keys.empty()) { return visitPlanNode(node, predicate_context); } + // never push predicate through grouping sets agg + if (step.isGroupingSet()) + { + return visitPlanNode(node, predicate_context); + } + ConstASTPtr inherited_predicate = predicate_context.predicate; EqualityInference equality_inference = EqualityInference::newInstance(inherited_predicate, context); @@ -310,7 +353,7 @@ PlanNodePtr PredicateVisitor::visitJoinNode(JoinNode & node, PredicateContext & auto step = node.getStep(); // RequireRightKeys is clickhouse sql only, we don't process this kind of join. 
- if (step->getRequireRightKeys().has_value()) + if (step->getRequireRightKeys().has_value() || step->hasKeyIdNullSafe()) { return visitPlanNode(node, predicate_context); } @@ -325,6 +368,7 @@ PlanNodePtr PredicateVisitor::visitJoinNode(JoinNode & node, PredicateContext & PlanNodePtr & right = node.getChildren()[1]; ConstASTPtr left_effective_predicate = EffectivePredicateExtractor::extract(left, context); ConstASTPtr right_effective_predicate = EffectivePredicateExtractor::extract(right, context); + ConstASTPtr join_predicate = PredicateUtils::extractJoinPredicate(node); std::set left_symbols; @@ -351,8 +395,8 @@ PlanNodePtr PredicateVisitor::visitJoinNode(JoinNode & node, PredicateContext & ASTTableJoin::Kind kind = step->getKind(); - LOG_TRACE( - logger, + LOG_DEBUG( + &Poco::Logger::get("PredicateVisitor"), "join node {}, inherited_predicate : {}, left effective predicate: {} , right effective predicate: {}, join_predicate : {}", node.getId(), inherited_predicate->formatForErrorMessage(), @@ -566,6 +610,7 @@ PlanNodePtr PredicateVisitor::visitJoinNode(JoinNode & node, PredicateContext & ProjectionPlanner left_planner(left_source_expression_node, context); ProjectionPlanner right_planner(right_source_expression_node, context); const bool allow_extended_type_conversion = context->getSettingsRef().allow_extended_type_conversion; + const bool enable_implicit_arg_type_convert = context->getSettingsRef().enable_implicit_arg_type_convert; bool need_project = false; for (const auto & clause : join_clauses) @@ -577,7 +622,19 @@ PlanNodePtr PredicateVisitor::visitJoinNode(JoinNode & node, PredicateContext & if (!JoinCommon::isJoinCompatibleTypes(left_type, right_type)) { - auto common_type = getLeastSupertype(DataTypes{left_type, right_type}, allow_extended_type_conversion); + DataTypePtr common_type; + try + { + common_type + = getCommonType(DataTypes{left_type, right_type}, allow_extended_type_conversion, enable_implicit_arg_type_convert); + } + catch (DB::Exception & 
ex) + { + throw Exception( + "Type mismatch of columns to JOIN by: " + left_type->getName() + " at left, " + right_type->getName() + + " at right. " + "Can't get supertype: " + ex.message(), + ErrorCodes::TYPE_MISMATCH); + } left_key = left_planner.addColumn(makeCastFunction(std::make_shared(left_key), common_type)).first; right_key = right_planner.addColumn(makeCastFunction(std::make_shared(right_key), common_type)).first; need_project = true; @@ -620,8 +677,9 @@ PlanNodePtr PredicateVisitor::visitJoinNode(JoinNode & node, PredicateContext & ASTTableJoin::Strictness::All, step->getMaxStreams(), step->getKeepLeftReadInOrder(), - std::move(left_keys), - std::move(right_keys), + left_keys, + right_keys, + std::vector{}, new_join_filter, step->isHasUsing(), step->getRequireRightKeys(), @@ -643,8 +701,9 @@ PlanNodePtr PredicateVisitor::visitJoinNode(JoinNode & node, PredicateContext & step->getStrictness(), step->getMaxStreams(), step->getKeepLeftReadInOrder(), - std::move(left_keys), - std::move(right_keys), + left_keys, + right_keys, + std::vector{}, new_join_filter, step->isHasUsing(), step->getRequireRightKeys(), @@ -676,6 +735,7 @@ PlanNodePtr PredicateVisitor::visitJoinNode(JoinNode & node, PredicateContext & join_step->getKeepLeftReadInOrder(), join_step->getRightKeys(), join_step->getLeftKeys(), + join_step->getKeyIdsNullSafe(), join_step->getFilter(), join_step->isHasUsing(), join_step->getRequireRightKeys(), @@ -723,6 +783,14 @@ PlanNodePtr PredicateVisitor::visitJoinNode(JoinNode & node, PredicateContext & return output_node; } +DataTypePtr PredicateVisitor::getCommonType(const DataTypes & types, bool allow_extended_type_conversion, bool enable_implicit_arg_type_convert) +{ + if (enable_implicit_arg_type_convert) + return getLeastSupertype(types, true); + else + return getLeastSupertype(types, allow_extended_type_conversion); +} + PlanNodePtr PredicateVisitor::visitArrayJoinNode(ArrayJoinNode & node, PredicateContext & predicate_context) { const auto & step 
= *node.getStep(); @@ -1362,7 +1430,8 @@ void PredicateVisitor::tryNormalizeOuterToInnerJoin(JoinNode & node, const Const return; // TODO: ANTI JOINs also can be optimized - if (strictness != Strictness::All && strictness != Strictness::Any) + // left any join CANNOT be converted to inner any join + if (strictness != Strictness::All) return; auto column_types = step.getOutputStream().header.getNamesToTypes(); @@ -1564,10 +1633,10 @@ ASTPtr EffectivePredicateVisitor::visitFilterNode(FilterNode & node, ContextMuta } /** - * Disable extract predicate with inconsistent type. - * + * Disable extract predicate with inconsistent type. + * * for predicate : expr#toDate('2023-06-29') = expr#'2023-06-29', - * left argument and right argument both ASTIdentifier, but they + * left argument and right argument both ASTIdentifier, but they * have different type, left type is Date, right type is String. */ const NameToType & name_types = step.getOutputStream().getNamesToTypes(); @@ -1598,9 +1667,12 @@ ASTPtr EffectivePredicateVisitor::visitFilterNode(FilterNode & node, ContextMuta removed_inconsistent_type_filters.emplace_back(ptr); } + std::vector removed_large_in_value_list + = FilterStep::removeLargeInValueList(removed_inconsistent_type_filters, context->getSettingsRef().max_in_value_list_to_pushdown); + // Adds on underlying_predicate - removed_inconsistent_type_filters.emplace_back(underlying_predicate); - return PredicateUtils::combineConjuncts(removed_inconsistent_type_filters); + removed_large_in_value_list.emplace_back(underlying_predicate); + return PredicateUtils::combineConjuncts(removed_large_in_value_list); } ASTPtr EffectivePredicateVisitor::visitAggregatingNode(AggregatingNode & node, ContextMutablePtr & context) diff --git a/src/Optimizer/Rewriter/PredicatePushdown.h b/src/Optimizer/Rewriter/PredicatePushdown.h index 0136abc6e20..2fc8c8f2e90 100644 --- a/src/Optimizer/Rewriter/PredicatePushdown.h +++ b/src/Optimizer/Rewriter/PredicatePushdown.h @@ -120,6 +120,7 @@ 
class PredicateVisitor : public PlanNodeVisitor static ASTTableJoin::Kind useInnerForLeftSide(ASTTableJoin::Kind kind); static ASTTableJoin::Kind useInnerForRightSide(ASTTableJoin::Kind kind); static bool isRegularJoin(const JoinStep & step); + static DataTypePtr getCommonType(const DataTypes & types, bool allow_extended_type_conversion, bool enable_implicit_arg_type_convert); }; struct InnerJoinResult diff --git a/src/Optimizer/Rewriter/RemoveApply.cpp b/src/Optimizer/Rewriter/RemoveApply.cpp index 58f67217e26..fbcd9043dbc 100644 --- a/src/Optimizer/Rewriter/RemoveApply.cpp +++ b/src/Optimizer/Rewriter/RemoveApply.cpp @@ -163,6 +163,7 @@ PlanNodePtr CorrelatedScalarSubqueryVisitor::visitApplyNode(ApplyNode & node, Vo context->getSettingsRef().optimize_read_in_order, key_pairs.first, key_pairs.second, + std::vector{}, PredicateConst::TRUE_VALUE, false, std::nullopt, @@ -293,6 +294,7 @@ PlanNodePtr CorrelatedScalarSubqueryVisitor::visitApplyNode(ApplyNode & node, Vo context->getSettingsRef().optimize_read_in_order, key_pairs.first, key_pairs.second, + std::vector{}, PredicateConst::TRUE_VALUE, false, std::nullopt, @@ -516,6 +518,7 @@ PlanNodePtr UnCorrelatedScalarSubqueryVisitor::visitApplyNode(ApplyNode & node, context->getSettingsRef().optimize_read_in_order, Names{}, Names{}, + std::vector{}, PredicateConst::TRUE_VALUE, false, std::nullopt, @@ -646,6 +649,7 @@ PlanNodePtr CorrelatedInSubqueryVisitor::visitApplyNode(ApplyNode & node, Void & context->getSettingsRef().optimize_read_in_order, correlation_predicate.first, correlation_predicate.second, + std::vector{}, combine_filter, false, std::nullopt, @@ -931,6 +935,7 @@ PlanNodePtr UnCorrelatedInSubqueryVisitor::visitApplyNode(ApplyNode & node, Void context->getSettingsRef().optimize_read_in_order, in_left, in_right, + std::vector{}, PredicateConst::TRUE_VALUE, false, std::nullopt, @@ -1138,6 +1143,7 @@ PlanNodePtr CorrelatedExistsSubqueryVisitor::visitApplyNode(ApplyNode & node, Vo 
context->getSettingsRef().optimize_read_in_order, key_pairs.first, key_pairs.second, + std::vector{}, PredicateConst::TRUE_VALUE, false, std::nullopt, @@ -1241,6 +1247,7 @@ PlanNodePtr CorrelatedExistsSubqueryVisitor::visitApplyNode(ApplyNode & node, Vo context->getSettingsRef().optimize_read_in_order, key_pairs.first, key_pairs.second, + std::vector{}, PredicateConst::TRUE_VALUE, false, std::nullopt, @@ -1511,6 +1518,7 @@ PlanNodePtr UnCorrelatedExistsSubqueryVisitor::visitApplyNode(ApplyNode & node, context->getSettingsRef().optimize_read_in_order, Names{}, Names{}, + std::vector{}, PredicateConst::TRUE_VALUE, false, std::nullopt, @@ -1633,6 +1641,7 @@ PlanNodePtr UnCorrelatedQuantifiedComparisonSubqueryVisitor::visitApplyNode(Appl context->getSettingsRef().optimize_read_in_order, Names{}, Names{}, + std::vector{}, PredicateConst::TRUE_VALUE, false, std::nullopt, @@ -1806,6 +1815,7 @@ PlanNodePtr CorrelatedQuantifiedComparisonSubqueryVisitor::visitApplyNode(ApplyN context->getSettingsRef().optimize_read_in_order, correlation_predicate.first, correlation_predicate.second, + std::vector{}, join_filter, false, std::nullopt, @@ -2476,6 +2486,7 @@ TransformResult ExistsToSemiJoin::transformImpl(PlanNodePtr node, const Captures context->getSettingsRef().optimize_read_in_order, key_pairs.first, key_pairs.second, + std::vector{}, PredicateUtils::combineConjuncts(filter), false, std::nullopt, @@ -2592,6 +2603,7 @@ TransformResult InToSemiJoin::transformImpl(PlanNodePtr node, const Captures &, context->getSettingsRef().optimize_read_in_order, in_left, in_right, + std::vector{}, PredicateConst::TRUE_VALUE, false, std::nullopt, diff --git a/src/Optimizer/Rewriter/ShareCommonPlanNode.cpp b/src/Optimizer/Rewriter/ShareCommonPlanNode.cpp index 8809a2d9cf7..a9e204a395b 100644 --- a/src/Optimizer/Rewriter/ShareCommonPlanNode.cpp +++ b/src/Optimizer/Rewriter/ShareCommonPlanNode.cpp @@ -57,12 +57,12 @@ class ShareCommonPlanNode::Rewriter : public SimplePlanRewriter auto cte_id = 
cte.first; auto forward_order = plan_signature_output_orders.at(node_ptr); - auto reserve_order = plan_signature_output_orders.at(cte.second); + auto reverse_order = plan_signature_output_orders.at(cte.second); std::unordered_map output_columns; for (const auto & output : node.getOutputNames()) { - auto input_column = reserve_order.getByPosition(forward_order.getPositionByName(output)).name; + auto input_column = reverse_order.getByPosition(forward_order.getPositionByName(output)).name; output_columns.emplace(output, input_column); } return PlanNodeBase::createPlanNode( diff --git a/src/Optimizer/Rewriter/UnifyJoinOutputs.cpp b/src/Optimizer/Rewriter/UnifyJoinOutputs.cpp index 26d3ef51a7c..c3d22d2ebd8 100644 --- a/src/Optimizer/Rewriter/UnifyJoinOutputs.cpp +++ b/src/Optimizer/Rewriter/UnifyJoinOutputs.cpp @@ -160,6 +160,7 @@ PlanNodePtr UnifyJoinOutputs::Rewriter::visitJoinNode(JoinNode & node, std::set< step->getKeepLeftReadInOrder(), std::move(left_keys), std::move(right_keys), + std::vector{}, step->getFilter(), step->isHasUsing(), step->getRequireRightKeys(), diff --git a/src/Optimizer/Rewriter/UnifyNullableType.cpp b/src/Optimizer/Rewriter/UnifyNullableType.cpp index cdabe23c27b..5774269c22f 100644 --- a/src/Optimizer/Rewriter/UnifyNullableType.cpp +++ b/src/Optimizer/Rewriter/UnifyNullableType.cpp @@ -357,6 +357,7 @@ PlanNodePtr UnifyNullableVisitor::visitJoinNode(JoinNode & node, ContextMutableP join_step.getKeepLeftReadInOrder(), join_step.getLeftKeys(), join_step.getRightKeys(), + join_step.getKeyIdsNullSafe(), join_step.getFilter(), join_step.isHasUsing(), join_step.getRequireRightKeys(), diff --git a/src/Optimizer/Rewriter/UseSortingProperty.cpp b/src/Optimizer/Rewriter/UseSortingProperty.cpp index ee82915d820..4c5c1160ec6 100644 --- a/src/Optimizer/Rewriter/UseSortingProperty.cpp +++ b/src/Optimizer/Rewriter/UseSortingProperty.cpp @@ -1,62 +1,73 @@ +#include +#include #include +#include +#include #include #include #include +#include #include +#include 
#include +#include #include -#include +#include +#include namespace DB { void SortingOrderedSource::rewrite(QueryPlan & plan, ContextMutablePtr context) const { SortingOrderedSource::Rewriter rewriter{context, plan.getCTEInfo()}; - Void require; - auto result = VisitorUtil::accept(plan.getPlanNode(), rewriter, require); + SortDescription required; + auto result = VisitorUtil::accept(plan.getPlanNode(), rewriter, required); - PushSortingInfoRewriter push_rewriter{context, plan.getCTEInfo()}; + PruneSortingInfoRewriter push_rewriter{context, plan.getCTEInfo()}; SortInfo sort_info; auto plan_node = VisitorUtil::accept(result.plan, push_rewriter, sort_info); plan.update(plan_node); } -PlanAndProp SortingOrderedSource::Rewriter::visitPlanNode(PlanNodeBase & node, Void &) +PlanAndPropConstants SortingOrderedSource::Rewriter::visitPlanNode(PlanNodeBase & node, SortDescription &) { PlanNodes children; - Void require; + SortDescription required; PropertySet input_properties; + ConstantsSet input_constants; for (const auto & child : node.getChildren()) { - auto result = VisitorUtil::accept(child, *this, require); + auto result = VisitorUtil::accept(child, *this, required); children.emplace_back(result.plan); input_properties.emplace_back(result.property); + input_constants.emplace_back(result.constants); } node.replaceChildren(children); Property any_prop; Property prop = PropertyDeriver::deriveProperty(node.getStep(), input_properties, any_prop, context); - return {node.shared_from_this(), prop}; + Constants constants = ConstantsDeriver::deriveConstants(node.getStep(), input_constants, cte_helper.getCTEInfo(), context); + return {node.shared_from_this(), prop, constants}; } -PlanAndProp SortingOrderedSource::Rewriter::visitSortingNode(SortingNode & node, Void & v) +PlanAndPropConstants SortingOrderedSource::Rewriter::visitSortingNode(SortingNode & node, SortDescription &) { - auto result = VisitorUtil::accept(node.getChildren()[0], *this, v); - auto step = node.getStep(); - 
auto prefix_sorting = PropertyMatcher::matchSorting(*context, step->getSortDescription(), result.property.getSorting()); + auto required = step->getSortDescription(); + auto result = VisitorUtil::accept(node.getChildren()[0], *this, required); + + Constants constants = ConstantsDeriver::deriveConstants(node.getStep(), {result.constants}, cte_helper.getCTEInfo(), context); + auto prefix_sorting = PropertyMatcher::matchSorting(*context, step->getSortDescription(), result.property.getSorting(), {}, constants); step->setPrefixDescription(prefix_sorting.toSortDesc()); Property any_prop; Property prop = PropertyDeriver::deriveProperty(node.getStep(), {result.property}, any_prop, context); - return {node.shared_from_this(), prop}; + return {node.shared_from_this(), prop, constants}; } -PlanAndProp SortingOrderedSource::Rewriter::visitAggregatingNode(AggregatingNode & node, Void & v) +PlanAndPropConstants SortingOrderedSource::Rewriter::visitAggregatingNode(AggregatingNode & node, SortDescription & required) { - auto result = VisitorUtil::accept(node.getChildren()[0], *this, v); const auto & settings = context->getSettingsRef(); - if (settings.optimize_aggregation_in_order /* && !settings.optimize_aggregate_function_type */) { auto step = node.getStep(); @@ -68,19 +79,21 @@ PlanAndProp SortingOrderedSource::Rewriter::visitAggregatingNode(AggregatingNode order_descr.emplace_back(name, 1, 1); } + PlanAndPropConstants result = VisitorUtil::accept(node.getChildren()[0], *this, order_descr); - auto prefix_sorting = PropertyMatcher::matchSorting(*context, order_descr, result.property.getSorting()); + Constants constants = ConstantsDeriver::deriveConstants(node.getStep(), {result.constants}, cte_helper.getCTEInfo(), context); + auto prefix_sorting = PropertyMatcher::matchSorting(*context, order_descr, result.property.getSorting(), {}, constants); step->setGroupBySortDescription(prefix_sorting.toSortDesc()); + + Property any_prop; + Property prop = 
PropertyDeriver::deriveProperty(node.getStep(), {result.property}, any_prop, context); + return {node.shared_from_this(), prop, constants}; } - Property any_prop; - Property prop = PropertyDeriver::deriveProperty(node.getStep(), {result.property}, any_prop, context); - return {node.shared_from_this(), prop}; + return visitPlanNode(node, required); } -PlanAndProp SortingOrderedSource::Rewriter::visitWindowNode(WindowNode & node, Void & v) +PlanAndPropConstants SortingOrderedSource::Rewriter::visitWindowNode(WindowNode & node, SortDescription & required) { - auto result = VisitorUtil::accept(node.getChildren()[0], *this, v); - #if 0 if (context->getSettingsRef().optimize_read_in_window_order) { @@ -93,65 +106,178 @@ PlanAndProp SortingOrderedSource::Rewriter::visitWindowNode(WindowNode & node, V const auto & order_by = step->getWindowDescription().order_by; order_descr.insert(order_descr.end(), order_by.begin(), order_by.end()); - auto prefix_sorting = PropertyMatcher::matchSorting(*context, order_descr, result.property.getSorting()); + auto result = VisitorUtil::accept(node.getChildren()[0], *this, order_descr); + Constants constants = ConstantsDeriver::deriveConstants(node.getStep(), {result.constants}, cte_helper.getCTEInfo(), context); + auto prefix_sorting = PropertyMatcher::matchSorting(*context, order_descr, result.property.getSorting(), {}, constants); step->setPrefixDescription(prefix_sorting.toSortDesc()); + + Property any_prop; + Property prop = PropertyDeriver::deriveProperty(node.getStep(), {result.property}, any_prop, context); + return {node.shared_from_this(), prop, constants}; } #endif + + return visitPlanNode(node, required); +} + +PlanAndPropConstants SortingOrderedSource::Rewriter::visitCTERefNode(CTERefNode & node, SortDescription &) +{ + const auto * step = node.getStep().get(); + SortDescription required; + auto cte_plan = cte_helper.acceptAndUpdate(step->getId(), *this, required, [](auto & result) { return result.plan; }); + return 
{node.shared_from_this(), Property{}, cte_plan.constants}; +} + +PlanAndPropConstants SortingOrderedSource::Rewriter::visitTopNFilteringNode(TopNFilteringNode & node, SortDescription &) +{ + auto & topn_filtering = node.getStep(); + auto required_sorting = topn_filtering->getSortDescription(); + + auto result = VisitorUtil::accept(node.getChildren()[0], *this, required_sorting); + + Constants constants = ConstantsDeriver::deriveConstants(node.getStep(), {result.constants}, cte_helper.getCTEInfo(), context); + auto prefix_sorting = PropertyMatcher::matchSorting(*context, required_sorting, result.property.getSorting(), {}, constants); + if (prefix_sorting.size() == required_sorting.size()) + topn_filtering->setAlgorithm(TopNFilteringAlgorithm::Limit); + Property any_prop; Property prop = PropertyDeriver::deriveProperty(node.getStep(), {result.property}, any_prop, context); - return {node.shared_from_this(), prop}; + return {node.shared_from_this(), prop, constants}; } -PlanAndProp SortingOrderedSource::Rewriter::visitCTERefNode(CTERefNode & node, Void & v) +PlanAndPropConstants SortingOrderedSource::Rewriter::visitTableScanNode(TableScanNode & node, SortDescription & required) { - const auto * step = node.getStep().get(); + auto & step = node.getStep(); - auto cte_plan = cte_helper.acceptAndUpdate(step->getId(), *this, v, [](auto & result) { return result.plan; }); - return {node.shared_from_this(), Property{}}; + Property any_prop; + any_prop.setSorting(Sorting{required}); + Property prop = PropertyDeriver::deriveProperty(step, context, any_prop); + step->setReadOrder(prop.getSorting().translate(node.getStep()->getAliasToColumnMap()).toSortDesc()); + Constants constants = ConstantsDeriver::deriveConstants(step, cte_helper.getCTEInfo(), context); + return {node.shared_from_this(), prop, constants}; } -PlanAndProp SortingOrderedSource::Rewriter::visitTopNFilteringNode(TopNFilteringNode & node, Void & ctx) +PlanAndPropConstants 
SortingOrderedSource::Rewriter::visitFilterNode(FilterNode & node, SortDescription & required) { - auto result = VisitorUtil::accept(node.getChildren()[0], *this, ctx); - auto actual_sorting = result.property.getSorting().toSortDesc(); + auto result = VisitorUtil::accept(node.getChildren()[0], *this, required); + Property any_prop; + Property prop = PropertyDeriver::deriveProperty(node.getStep(), {result.property}, any_prop, context); + Constants constants = ConstantsDeriver::deriveConstants(node.getStep(), {result.constants}, cte_helper.getCTEInfo(), context); + return {node.shared_from_this(), prop, constants}; +} - auto & topn_filtering = dynamic_cast(*node.getStep()); - const auto & required_sorting = topn_filtering.getSortDescription(); +PlanAndPropConstants SortingOrderedSource::Rewriter::visitProjectionNode(ProjectionNode & node, SortDescription & require) +{ + auto mappings = Utils::computeIdentityTranslations(node.getStep()->getAssignments()); - if (actual_sorting.hasPrefix(required_sorting)) - topn_filtering.setAlgorithm(TopNFilteringAlgorithm::Limit); + SortDescription push_down_sort_description; + for (const auto & column : require) + { + if (!mappings.contains(column.column_name)) + break; + push_down_sort_description.emplace_back( + SortColumnDescription{mappings.at(column.column_name), column.direction, column.nulls_direction}); + } + auto result = VisitorUtil::accept(node.getChildren()[0], *this, push_down_sort_description); + Property any_prop; Property prop = PropertyDeriver::deriveProperty(node.getStep(), {result.property}, any_prop, context); - return {node.shared_from_this(), prop}; + Constants constants = ConstantsDeriver::deriveConstants(node.getStep(), {result.constants}, cte_helper.getCTEInfo(), context); + return {node.shared_from_this(), prop, constants}; +} + +PlanNodePtr PruneSortingInfoRewriter::visitPlanNode(PlanNodeBase & node, SortInfo & required) +{ + SortInfo s{required.sort_desc, size_t{0}}; + return 
SimplePlanRewriter::visitPlanNode(node, s); } -PlanNodePtr PushSortingInfoRewriter::visitSortingNode(SortingNode & node, SortInfo &) +PlanNodePtr PruneSortingInfoRewriter::visitSortingNode(SortingNode & node, SortInfo &) { auto prefix_desc = node.getStep()->getPrefixDescription(); SortInfo s{prefix_desc, node.getStep()->getLimit()}; - return visitPlanNode(node, s); + return SimplePlanRewriter::visitPlanNode(node, s); } -PlanNodePtr PushSortingInfoRewriter::visitAggregatingNode(AggregatingNode & node, SortInfo &) +PlanNodePtr PruneSortingInfoRewriter::visitAggregatingNode(AggregatingNode & node, SortInfo &) { auto prefix_desc = node.getStep()->getGroupBySortDescription(); SortInfo s{prefix_desc, size_t{0}}; - return visitPlanNode(node, s); + return SimplePlanRewriter::visitPlanNode(node, s); } -PlanNodePtr PushSortingInfoRewriter::visitWindowNode(WindowNode & node, SortInfo &) +PlanNodePtr PruneSortingInfoRewriter::visitWindowNode(WindowNode & node, SortInfo &) { auto prefix_desc = node.getStep()->getPrefixDescription(); SortInfo s{prefix_desc, size_t{0}}; - return visitPlanNode(node, s); + return SimplePlanRewriter::visitPlanNode(node, s); +} + +PlanNodePtr PruneSortingInfoRewriter::visitTopNFilteringNode(TopNFilteringNode & node, SortInfo &) +{ + auto prefix_desc = node.getStep()->getSortDescription(); + SortInfo s{prefix_desc, size_t{0}}; + return SimplePlanRewriter::visitPlanNode(node, s); } -PlanNodePtr PushSortingInfoRewriter::visitTableScanNode(TableScanNode & node, SortInfo & s) +PlanNodePtr PruneSortingInfoRewriter::visitTableScanNode(TableScanNode & node, SortInfo & required) { - node.getStep()->setReadOrder(s.sort_desc); + auto & step = node.getStep(); + + NameSet required_columns; + auto mappings = step->getAliasToColumnMap(); + for (const auto & column : required.sort_desc) + { + auto column_name = mappings.contains(column.column_name) ? 
mappings.at(column.column_name) : column.column_name; + required_columns.emplace(column_name); + } + + // prune unused read order columns + // eg, select * from table(order by a,b,c) where a = 'x' and d = 'y' order by b,d + // required sort columns may be: b,d; read order columns should be a,b + auto read_order = step->getReadOrder(); + auto it = std::find_if(read_order.rbegin(), read_order.rend(), [&](const SortColumnDescription & sort_column) { + return required_columns.contains(sort_column.column_name); + }); + + SortDescription pruned_read_order(read_order.begin(), read_order.begin() + std::distance(it, read_order.rend())); + + if (!required.sort_desc.empty() && pruned_read_order.empty()) + { + // do nothing if all columns in required don't exist in table + if (logger->error()) + { + Names names; + for (const auto & desc : required.sort_desc) + names.emplace_back(desc.column_name); + LOG_WARNING(logger, "unkown required sorting: {}", fmt::format("{}", fmt::join(names, ", "))); + } + } + else + { + node.getStep()->setReadOrder(pruned_read_order); + } + return node.shared_from_this(); } +PlanNodePtr PruneSortingInfoRewriter::visitProjectionNode(ProjectionNode & node, SortInfo & required) +{ + auto mappings = Utils::computeIdentityTranslations(node.getStep()->getAssignments()); + + SortDescription push_down_sort_description; + for (const auto & column : required.sort_desc) + { + if (!mappings.contains(column.column_name)) + break; + push_down_sort_description.emplace_back( + SortColumnDescription{mappings.at(column.column_name), column.direction, column.nulls_direction}); + } + + SortInfo child_required{push_down_sort_description, Utils::canChangeOutputRows(*node.getStep(), context) ? 
required.limit : 0ul}; + return SimplePlanRewriter::visitPlanNode(node, child_required); +} + } diff --git a/src/Optimizer/Rewriter/UseSortingProperty.h b/src/Optimizer/Rewriter/UseSortingProperty.h index c6143119acd..2156728c87e 100644 --- a/src/Optimizer/Rewriter/UseSortingProperty.h +++ b/src/Optimizer/Rewriter/UseSortingProperty.h @@ -5,13 +5,23 @@ #include #include #include +#include #include #include #include #include +#include namespace DB { + +struct PlanAndPropConstants +{ + PlanNodePtr plan; + Property property; + Constants constants; +}; + class SortingOrderedSource : public Rewriter { public: @@ -20,44 +30,56 @@ class SortingOrderedSource : public Rewriter void rewrite(QueryPlan & plan, ContextMutablePtr context) const override; bool isEnabled(ContextMutablePtr context) const override { - return context->getSettingsRef().enable_sorting_property; + return context->getSettingsRef().enable_sorting_property && context->getSettingsRef().optimize_read_in_order; } class Rewriter; }; -class SortingOrderedSource::Rewriter : public PlanNodeVisitor +class SortingOrderedSource::Rewriter : public PlanNodeVisitor { public: Rewriter(ContextMutablePtr context_, CTEInfo & cte_info_) : context(context_), cte_helper(cte_info_) { } - PlanAndProp visitPlanNode(PlanNodeBase &, Void &) override; - PlanAndProp visitSortingNode(SortingNode &, Void &) override; - PlanAndProp visitAggregatingNode(AggregatingNode &, Void &) override; - PlanAndProp visitWindowNode(WindowNode &, Void &) override; - PlanAndProp visitCTERefNode(CTERefNode & node, Void &) override; - PlanAndProp visitTopNFilteringNode(TopNFilteringNode & node, Void &) override; + + PlanAndPropConstants visitPlanNode(PlanNodeBase &, SortDescription & required) override; + PlanAndPropConstants visitSortingNode(SortingNode &, SortDescription & required) override; + PlanAndPropConstants visitAggregatingNode(AggregatingNode &, SortDescription & required) override; + PlanAndPropConstants visitWindowNode(WindowNode &, 
SortDescription & required) override; + PlanAndPropConstants visitTopNFilteringNode(TopNFilteringNode & node, SortDescription & required) override; + + PlanAndPropConstants visitCTERefNode(CTERefNode & node, SortDescription & required) override; + PlanAndPropConstants visitProjectionNode(ProjectionNode & node, SortDescription & required) override; + PlanAndPropConstants visitFilterNode(FilterNode & node, SortDescription & required) override; + PlanAndPropConstants visitTableScanNode(TableScanNode & node, SortDescription & required) override; private: ContextMutablePtr context; - SimpleCTEVisitHelper cte_helper; + SimpleCTEVisitHelper cte_helper; }; - struct SortInfo { SortDescription sort_desc; - SizeOrVariable limit; + SizeOrVariable limit = 0ul; }; -class PushSortingInfoRewriter : public SimplePlanRewriter +class PruneSortingInfoRewriter : public SimplePlanRewriter { public: - PushSortingInfoRewriter(ContextMutablePtr context_, CTEInfo & cte_info_) : SimplePlanRewriter(context_, cte_info_) + PruneSortingInfoRewriter(ContextMutablePtr context_, CTEInfo & cte_info_) + : SimplePlanRewriter(context_, cte_info_), logger(&Poco::Logger::get("PruneSortingInfoRewriter")) { } + + PlanNodePtr visitPlanNode(PlanNodeBase & node, SortInfo & required) override; PlanNodePtr visitSortingNode(SortingNode &, SortInfo &) override; PlanNodePtr visitAggregatingNode(AggregatingNode &, SortInfo &) override; PlanNodePtr visitWindowNode(WindowNode &, SortInfo &) override; - PlanNodePtr visitTableScanNode(TableScanNode &, SortInfo &) override; + PlanNodePtr visitTopNFilteringNode(TopNFilteringNode & node, SortInfo &) override; + PlanNodePtr visitProjectionNode(ProjectionNode & node, SortInfo & required) override; + PlanNodePtr visitTableScanNode(TableScanNode &, SortInfo & required) override; + +private: + Poco::Logger * logger; }; } diff --git a/src/Optimizer/Rule/Rewrite/CrossJoinToUnion.cpp b/src/Optimizer/Rule/Rewrite/CrossJoinToUnion.cpp index a809521218a..967f70a173b 100644 --- 
a/src/Optimizer/Rule/Rewrite/CrossJoinToUnion.cpp +++ b/src/Optimizer/Rule/Rewrite/CrossJoinToUnion.cpp @@ -93,6 +93,7 @@ TransformResult CrossJoinToUnion::transformImpl(PlanNodePtr node, const Captures context->getSettingsRef().optimize_read_in_order, child_step.getLeftKeys(), child_step.getRightKeys(), + child_step.getKeyIdsNullSafe(), PredicateConst::TRUE_VALUE, child_step.isHasUsing(), child_step.getRequireRightKeys(), @@ -133,6 +134,7 @@ TransformResult CrossJoinToUnion::transformImpl(PlanNodePtr node, const Captures context->getSettingsRef().optimize_read_in_order, child_step.getLeftKeys(), child_step.getRightKeys(), + child_step.getKeyIdsNullSafe(), PredicateConst::TRUE_VALUE, child_step.isHasUsing(), child_step.getRequireRightKeys(), diff --git a/src/Optimizer/Rule/Rewrite/EagerAggregation.cpp b/src/Optimizer/Rule/Rewrite/EagerAggregation.cpp index 060ff1cfc92..45e7da9fbb5 100644 --- a/src/Optimizer/Rule/Rewrite/EagerAggregation.cpp +++ b/src/Optimizer/Rule/Rewrite/EagerAggregation.cpp @@ -586,6 +586,7 @@ PlanNodePtr insertLocalAggregate( join_step.getKeepLeftReadInOrder(), join_step.getLeftKeys(), join_step.getRightKeys(), + join_step.getKeyIdsNullSafe(), join_step.getFilter(), join_step.isHasUsing(), join_step.getRequireRightKeys(), diff --git a/src/Optimizer/Rule/Rewrite/ExtractBitmapImplicitFilter.cpp b/src/Optimizer/Rule/Rewrite/ExtractBitmapImplicitFilter.cpp index b3184f3693a..04706d0afe2 100644 --- a/src/Optimizer/Rule/Rewrite/ExtractBitmapImplicitFilter.cpp +++ b/src/Optimizer/Rule/Rewrite/ExtractBitmapImplicitFilter.cpp @@ -102,7 +102,8 @@ TransformResult ExtractBitmapImplicitFilter::transformImpl(PlanNodePtr node, con for (const auto & parameter : parameters_map) { auto [in_ast, elem_size] = createInFunctionForBitMapParameter(parameter.first, parameter.second); - functions.push_back(in_ast); + if (in_ast) + functions.push_back(in_ast); total_in_elems += elem_size; } diff --git a/src/Optimizer/Rule/Rewrite/InlineProjections.cpp 
b/src/Optimizer/Rule/Rewrite/InlineProjections.cpp index cca3a2b1e55..057687b1ab9 100644 --- a/src/Optimizer/Rule/Rewrite/InlineProjections.cpp +++ b/src/Optimizer/Rule/Rewrite/InlineProjections.cpp @@ -323,6 +323,7 @@ TransformResult InlineProjectionIntoJoin::transformImpl(PlanNodePtr node, const join_step.getKeepLeftReadInOrder(), join_step.getLeftKeys(), join_step.getRightKeys(), + join_step.getKeyIdsNullSafe(), join_step.getFilter(), join_step.isHasUsing(), join_step.getRequireRightKeys(), @@ -361,6 +362,7 @@ TransformResult InlineProjectionOnJoinIntoJoin::transformImpl(PlanNodePtr node, join_step->getKeepLeftReadInOrder(), join_step->getLeftKeys(), join_step->getRightKeys(), + join_step->getKeyIdsNullSafe(), join_step->getFilter(), join_step->isHasUsing(), join_step->getRequireRightKeys(), diff --git a/src/Optimizer/Rule/Rewrite/JoinUsingToJoinOn.cpp b/src/Optimizer/Rule/Rewrite/JoinUsingToJoinOn.cpp new file mode 100644 index 00000000000..915c68411c0 --- /dev/null +++ b/src/Optimizer/Rule/Rewrite/JoinUsingToJoinOn.cpp @@ -0,0 +1,45 @@ +#include + +#include +#include "QueryPlan/JoinStep.h" +#include "QueryPlan/PlanNode.h" +#include "QueryPlan/ProjectionStep.h" + +namespace DB +{ +ConstRefPatternPtr JoinUsingToJoinOn::getPattern() const +{ + static auto pattern = Patterns::join().matchingStep([](const JoinStep & s) { return s.isHasUsing(); }).result(); + return pattern; +} + +TransformResult JoinUsingToJoinOn::transformImpl(PlanNodePtr node, const Captures &, RuleContext & rule_context) +{ + // currently optimizer not support 'join on null-safe equals' + // due to performance issues, we don't rewrite null-safe equals to join on + if (rule_context.context->getSettingsRef().join_using_null_safe) + return {}; + + auto * join_node = dynamic_cast(node.get()); + if (!join_node) + return {}; + + auto join_step = join_node->getStep(); + auto join_kind = join_step->getKind(); + + // when inner/left join, no matter what data in require_right_keys, + // the using column 
is ALWAYS same as the left column + // so we just reset using flag and clear require_right_keys + if (join_kind == ASTTableJoin::Kind::Inner || join_kind == ASTTableJoin::Kind::Left) + { + auto new_join_step = std::static_pointer_cast(join_step->copy(rule_context.context)); + new_join_step->resetUsing(); + auto new_join_node + = PlanNodeBase::createPlanNode(node->getId(), std::move(new_join_step), node->getChildren(), node->getStatistics()); + return new_join_node; + } + + // TODO: rewrite Outer and Right Join + return {}; +} +} diff --git a/src/Optimizer/Rule/Rewrite/JoinUsingToJoinOn.h b/src/Optimizer/Rule/Rewrite/JoinUsingToJoinOn.h new file mode 100644 index 00000000000..47ccc37ad79 --- /dev/null +++ b/src/Optimizer/Rule/Rewrite/JoinUsingToJoinOn.h @@ -0,0 +1,22 @@ +#pragma once +#include + +namespace DB +{ +class JoinUsingToJoinOn : public Rule +{ +public: + RuleType getType() const override { return RuleType::DISTINCT_TO_AGGREGATE; } + String getName() const override { return "JOIN_USING_TO_JOIN_ON"; } + // TODO add context settings + bool isEnabled(ContextPtr context) const override + { + return context->getSettingsRef().enable_join_using_to_join_on; + } + ConstRefPatternPtr getPattern() const override; + +protected: + TransformResult transformImpl(PlanNodePtr node, const Captures & captures, RuleContext & context) override; +}; + +} diff --git a/src/Optimizer/Rule/Rewrite/MultipleDistinctAggregationToExpandAggregate.cpp b/src/Optimizer/Rule/Rewrite/MultipleDistinctAggregationToExpandAggregate.cpp index 2711b30f800..ae7ac49b013 100644 --- a/src/Optimizer/Rule/Rewrite/MultipleDistinctAggregationToExpandAggregate.cpp +++ b/src/Optimizer/Rule/Rewrite/MultipleDistinctAggregationToExpandAggregate.cpp @@ -8,6 +8,8 @@ #include #include #include +#include +#include #include #include #include @@ -17,7 +19,6 @@ #include #include #include -#include "Interpreters/join_common.h" namespace DB { @@ -28,20 +29,20 @@ const std::set 
MultipleDistinctAggregationToExpandAggregate::distinct_fu const std::set MultipleDistinctAggregationToExpandAggregate::distinct_func_with_if{ "uniqexactif", "countdistinctif", "avgdistinctif", "maxdistinctif", "mindistinctif", "sumdistinctif"}; +const std::set MultipleDistinctAggregationToExpandAggregate::non_distinct_func_with_if{ + "sumif", "countif", "avgif", "maxif", "minif"}; + const std::unordered_map MultipleDistinctAggregationToExpandAggregate::distinct_func_normal_func{ {"uniqexact", "countIf"}, {"countdistinct", "countIf"}, {"avgdistinct", "avgIf"}, {"maxdistinct", "maxIf"}, {"mindistinct", "minIf"}, - {"sumdistinct", "sumIf"}, - {"count", "anyIf"}, - {"max", "anyIf"}, - {"min", "anyIf"}, - {"avg", "anyIf"}, - {"sum", "anyIf"}}; - -bool MultipleDistinctAggregationToExpandAggregate::hasNoDistinctWithFilterOrMask(const AggregatingStep & step) + {"sumdistinct", "sumIf"}}; + +const std::set MultipleDistinctAggregationToExpandAggregate::un_supported_func{"hllsketchestimate"}; + +bool MultipleDistinctAggregationToExpandAggregate::hasNoFilterOrMask(const AggregatingStep & step) { const AggregateDescriptions & agg_descs = step.getAggregates(); for (const auto & agg_desc : agg_descs) @@ -51,6 +52,9 @@ bool MultipleDistinctAggregationToExpandAggregate::hasNoDistinctWithFilterOrMask if (distinct_func_with_if.contains(Poco::toLower(agg_desc.function->getName()))) return false; + + if (non_distinct_func_with_if.contains(Poco::toLower(agg_desc.function->getName()))) + return false; } return true; } @@ -111,15 +115,15 @@ bool MultipleDistinctAggregationToExpandAggregate::hasUniqueArgument(const Aggre return true; } -bool MultipleDistinctAggregationToExpandAggregate::allCountHasAtMostOneArguments(const AggregatingStep & s) +bool MultipleDistinctAggregationToExpandAggregate::hasNoUnSupportedFunc(const AggregatingStep & step) { - for (const auto & agg : s.getAggregates()) + const AggregateDescriptions & agg_descs = step.getAggregates(); + for (const auto & agg_desc : 
agg_descs) { - if (Poco::toLower(agg.function->getName()) == "uniqexact" || Poco::toLower(agg.function->getName()) == "countdistinct") - { - if (agg.argument_names.size() > 1) - return false; - } + if (un_supported_func.contains(Poco::toLower(agg_desc.function->getName()))) + return false; + if (!distinct_func.contains(Poco::toLower(agg_desc.function->getName())) && agg_desc.arguments.size() > 1) + return false; } return true; } @@ -127,10 +131,11 @@ bool MultipleDistinctAggregationToExpandAggregate::allCountHasAtMostOneArguments ConstRefPatternPtr MultipleDistinctAggregationToExpandAggregate::getPattern() const { static auto pattern = Patterns::aggregating() - .matchingStep([](const AggregatingStep & s) { - return hasNoDistinctWithFilterOrMask(s) && (hasMultipleDistincts(s) || hasMixedDistinctAndNonDistincts(s)) && hasUniqueArgument(s) && allCountHasAtMostOneArguments(s); - }) - .result(); + .matchingStep([](const AggregatingStep & s) { + return hasNoFilterOrMask(s) && (hasMultipleDistincts(s) || hasMixedDistinctAndNonDistincts(s)) + && hasUniqueArgument(s) && hasNoUnSupportedFunc(s); + }) + .result(); return pattern; } @@ -149,11 +154,11 @@ TransformResult MultipleDistinctAggregationToExpandAggregate::transformImpl(Plan for (const auto & input_column : input) { DataTypePtr type = input_column.type; - type = JoinCommon::tryConvertTypeToNullable(type); + // type = JoinCommon::tryConvertTypeToNullable(type); name_type[input_column.name] = type; assignments.emplace( input_column.name, - makeASTFunction("cast", std::make_shared(Field()), std::make_shared(type->getName()))); + makeASTFunction("cast", std::make_shared(type->getDefault()), std::make_shared(type->getName()))); } /// append a extra mark field : group_id. 
@@ -184,6 +189,9 @@ TransformResult MultipleDistinctAggregationToExpandAggregate::transformImpl(Plan AggregateDescriptions aggs_with_mask; String non_distinct_agg_group_id_mask; + + Assignments new_argument_assignments; + for (const auto & agg_desc : agg_descs) { String group_id_mask; @@ -207,7 +215,7 @@ TransformResult MultipleDistinctAggregationToExpandAggregate::transformImpl(Plan makeASTFunction( "equals", std::make_shared(group_id_symbol), std::make_shared(distinct_group_id))); - aggs_with_mask.emplace_back(distinctAggWithMask(agg_desc, group_id_mask)); + aggs_with_mask.emplace_back(distinctAggWithMask(agg_desc, group_id_mask, new_argument_assignments, rule_context.context)); distinct_group_id++; } else @@ -268,10 +276,20 @@ TransformResult MultipleDistinctAggregationToExpandAggregate::transformImpl(Plan } // step 2 : add pre-compute aggregate + std::set keyset; + for (const String & key : step.getKeys()) + { + keyset.insert(key); + } + keyset.insert(group_id_symbol); + for (const String & distinct : distinct_arguments) + { + keyset.insert(distinct); + } + + // make sure keys remove duplicated value. 
Names keys; - keys.insert(keys.end(), step.getKeys().begin(), step.getKeys().end()); - keys.emplace_back(group_id_symbol); - keys.insert(keys.end(), distinct_arguments.begin(), distinct_arguments.end()); + keys.insert(keys.end(), keyset.begin(), keyset.end()); auto pre_agg_step = std::make_shared( expand_node->getStep()->getOutputStream(), @@ -304,11 +322,27 @@ TransformResult MultipleDistinctAggregationToExpandAggregate::transformImpl(Plan } auto mask_step = std::make_shared(pre_agg_node->getStep()->getOutputStream(), mask_assignments, mask_null_name_to_type); - auto mask_node = std::make_shared(rule_context.context->nextNodeId(), std::move(mask_step), PlanNodes{pre_agg_node}); + child = PlanNodeBase::createPlanNode(rule_context.context->nextNodeId(), std::move(mask_step), PlanNodes{pre_agg_node}); + + if (!new_argument_assignments.empty()) + { + NameToType name_to_type; + for (const auto & assignment : new_argument_assignments) + name_to_type.emplace(assignment.first, std::make_shared()); + + for (const auto & input_column : child->getStep()->getOutputStream().header) + { + new_argument_assignments.emplace(input_column.name, makeASTIdentifier(input_column.name)); + name_to_type.emplace(input_column.name, input_column.type); + } + auto new_argument_projection_step + = std::make_shared(child->getStep()->getOutputStream(), new_argument_assignments, name_to_type); + child = PlanNodeBase::createPlanNode(rule_context.context->nextNodeId(), std::move(new_argument_projection_step), {child}); + } // step 4 : final aggregate auto count_agg_step = std::make_shared( - mask_node->getStep()->getOutputStream(), + child->getStep()->getOutputStream(), step.getKeys(), step.getKeysNotHashed(), aggs_with_mask, @@ -321,25 +355,46 @@ TransformResult MultipleDistinctAggregationToExpandAggregate::transformImpl(Plan step.isNoShuffle(), step.isStreamingForCache(), step.getHints()); - auto count_agg_node = PlanNodeBase::createPlanNode(rule_context.context->nextNodeId(), 
std::move(count_agg_step), {mask_node}); + auto count_agg_node = PlanNodeBase::createPlanNode(rule_context.context->nextNodeId(), std::move(count_agg_step), {child}); return count_agg_node; } -AggregateDescription -MultipleDistinctAggregationToExpandAggregate::distinctAggWithMask(const AggregateDescription & agg_desc, String & mask_column) +AggregateDescription MultipleDistinctAggregationToExpandAggregate::distinctAggWithMask( + const AggregateDescription & agg_desc, String & mask_column, Assignments & new_argument_assignments, ContextMutablePtr context) { - DataTypes data_types = agg_desc.function->getArgumentTypes(); + String fun_remove_distinct = distinct_func_normal_func.at(Poco::toLower(agg_desc.function->getName())); + Names argument_names; + DataTypes data_types; + if (fun_remove_distinct == "countIf" && agg_desc.argument_names.size() > 1) + { + // countDistinct(arg1, arg2) cannot convert to count(arg1, arg2), because clickhousedon't support count multi arguments. + // As an alternative we can rewrite it to count(IF(arg1 is null, null, arg2 is null, null, 1)), + // or sum((arg1 is not null) AND (arg2 is not null)) + fun_remove_distinct = "sumIf"; + + ASTs argument_functions; + for (const auto & argument : agg_desc.argument_names) + argument_functions.emplace_back(makeASTFunction("isNotNull", makeASTIdentifier(argument))); + auto new_argument = PredicateUtils::combineConjuncts(argument_functions); + auto new_argument_name = context->getSymbolAllocator()->newSymbol(new_argument); + new_argument_assignments.emplace_back(new_argument_name, new_argument); + + argument_names.emplace_back(new_argument_name); + data_types.emplace_back(std::make_shared()); + } + else + { + argument_names = agg_desc.argument_names; + data_types = agg_desc.function->getArgumentTypes(); + } + + argument_names.emplace_back(mask_column); data_types.emplace_back(std::make_shared()); Array parameters = agg_desc.function->getParameters(); AggregateFunctionProperties properties; - - String 
fun_remove_distinct = distinct_func_normal_func.at(Poco::toLower(agg_desc.function->getName())); AggregateFunctionPtr new_agg_fun = AggregateFunctionFactory::instance().get(fun_remove_distinct, data_types, parameters, properties); - Names argument_names = agg_desc.argument_names; - - argument_names.emplace_back(mask_column); AggregateDescription agg_with_mask; @@ -357,13 +412,14 @@ MultipleDistinctAggregationToExpandAggregate::distinctAggWithMask(const Aggregat AggregateDescription MultipleDistinctAggregationToExpandAggregate::nonDistinctAggWithMask(const AggregateDescription & agg_desc, String & mask_column) { - DataTypes data_types = agg_desc.function->getArgumentTypes(); + DataTypes data_types; + data_types.emplace_back(agg_desc.function->getReturnType()); data_types.emplace_back(std::make_shared()); - Array parameters = agg_desc.function->getParameters(); + Array parameters; AggregateFunctionProperties properties; - String fun_remove_distinct = distinct_func_normal_func.at(Poco::toLower(agg_desc.function->getName())); + String fun = "anyIf"; /// in case count(*), agg_desc.function->getArgumentTypes() returns empty. 
/// anyIf requires 2 arguments @@ -372,7 +428,7 @@ MultipleDistinctAggregationToExpandAggregate::nonDistinctAggWithMask(const Aggre data_types.emplace_back(std::make_shared()); } - AggregateFunctionPtr new_agg_fun = AggregateFunctionFactory::instance().get(fun_remove_distinct, data_types, parameters, properties); + AggregateFunctionPtr new_agg_fun = AggregateFunctionFactory::instance().get(fun, data_types, parameters, properties); Names argument_names; argument_names.emplace_back(agg_desc.column_name); argument_names.emplace_back(mask_column); @@ -381,11 +437,9 @@ MultipleDistinctAggregationToExpandAggregate::nonDistinctAggWithMask(const Aggre agg_with_mask.mask_column = mask_column; agg_with_mask.function = new_agg_fun; - agg_with_mask.parameters = agg_desc.parameters; + agg_with_mask.parameters = parameters; agg_with_mask.column_name = agg_desc.column_name; agg_with_mask.argument_names = argument_names; - agg_with_mask.parameters = agg_desc.parameters; - agg_with_mask.arguments = agg_desc.arguments; return agg_with_mask; } diff --git a/src/Optimizer/Rule/Rewrite/MultipleDistinctAggregationToExpandAggregate.h b/src/Optimizer/Rule/Rewrite/MultipleDistinctAggregationToExpandAggregate.h index f8a67ef0349..48e8cfb87e9 100644 --- a/src/Optimizer/Rule/Rewrite/MultipleDistinctAggregationToExpandAggregate.h +++ b/src/Optimizer/Rule/Rewrite/MultipleDistinctAggregationToExpandAggregate.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include "Interpreters/Context_fwd.h" #include "Optimizer/Rule/Pattern.h" namespace DB @@ -93,10 +94,13 @@ class MultipleDistinctAggregationToExpandAggregate : public Rule private: static const std::set distinct_func; static const std::set distinct_func_with_if; + static const std::set non_distinct_func_with_if; + static const std::set un_supported_func; static const std::unordered_map distinct_func_normal_func; - static bool hasNoDistinctWithFilterOrMask(const AggregatingStep & step); + static bool hasNoFilterOrMask(const AggregatingStep & 
step); static bool hasMultipleDistincts(const AggregatingStep & step); static bool hasMixedDistinctAndNonDistincts(const AggregatingStep & step); + static bool hasNoUnSupportedFunc(const AggregatingStep & step); /** * Distinct/Non-distinct aggregate function's arguments must unique. @@ -107,10 +111,9 @@ class MultipleDistinctAggregationToExpandAggregate : public Rule */ static bool hasUniqueArgument(const AggregatingStep & step); - // All Count Aggregate Functions must have at most one argument. - static bool allCountHasAtMostOneArguments(const AggregatingStep & step); + static AggregateDescription distinctAggWithMask( + const AggregateDescription & agg_desc, String & mask_column, Assignments & new_argument_assignments, ContextMutablePtr context); - static AggregateDescription distinctAggWithMask(const AggregateDescription & agg_desc, String & mask_column); static AggregateDescription nonDistinctAggWithMask(const AggregateDescription & agg_desc, String & mask_column); static PlanNodePtr makeUnionNode( diff --git a/src/Optimizer/Rule/Rewrite/MultipleDistinctAggregationToMarkDistinct.cpp b/src/Optimizer/Rule/Rewrite/MultipleDistinctAggregationToMarkDistinct.cpp index 9b75bc1bc56..5ea703b644a 100644 --- a/src/Optimizer/Rule/Rewrite/MultipleDistinctAggregationToMarkDistinct.cpp +++ b/src/Optimizer/Rule/Rewrite/MultipleDistinctAggregationToMarkDistinct.cpp @@ -1,6 +1,13 @@ +#include #include +#include +#include +#include +#include #include #include +#include +#include #include #include @@ -75,6 +82,8 @@ TransformResult MultipleDistinctAggregationToMarkDistinct::transformImpl(PlanNod AggregateDescriptions new_agg_descs; PlanNodePtr child = node->getChildren()[0]; + Assignments new_argument_assignments; + for (const auto & agg_desc : agg_descs) { if (distinct_func.contains(Poco::toLower(agg_desc.function->getName())) && agg_desc.mask_column.empty()) @@ -114,18 +123,40 @@ TransformResult MultipleDistinctAggregationToMarkDistinct::transformImpl(PlanNod child = 
PlanNodeBase::createPlanNode(rule_context.context->nextNodeId(), std::move(mark_distinct_step), PlanNodes{child}); } - DataTypes data_types = agg_desc.function->getArgumentTypes(); + // remove the distinct flag and set the distinct marker + String fun_remove_distinct = distinct_func_normal_func.at(Poco::toLower(agg_desc.function->getName())); + Names argument_names; + DataTypes data_types; + if (fun_remove_distinct == "countIf" && agg_desc.argument_names.size() > 1) + { + // countDistinct(arg1, arg2) cannot convert to count(arg1, arg2), because clickhousedon't support count multi arguments. + // As an alternative we can rewrite it to count(IF(arg1 is null, null, arg2 is null, null, 1)), + // or sum((arg1 is not null) AND (arg2 is not null)) + fun_remove_distinct = "sumIf"; + + ASTs argument_functions; + for (const auto & argument : agg_desc.argument_names) + argument_functions.emplace_back(makeASTFunction("isNotNull", makeASTIdentifier(argument))); + auto new_argument = PredicateUtils::combineConjuncts(argument_functions); + auto new_argument_name = rule_context.context->getSymbolAllocator()->newSymbol(new_argument); + new_argument_assignments.emplace_back(new_argument_name, new_argument); + + argument_names.emplace_back(new_argument_name); + data_types.emplace_back(std::make_shared()); + } + else + { + argument_names = agg_desc.argument_names; + data_types = agg_desc.function->getArgumentTypes(); + } + + argument_names.emplace_back(marker); data_types.emplace_back(std::make_shared()); Array parameters = agg_desc.function->getParameters(); AggregateFunctionProperties properties; - - // remove the distinct flag and set the distinct marker - String fun_remove_distinct = distinct_func_normal_func.at(Poco::toLower(agg_desc.function->getName())); AggregateFunctionPtr new_agg_fun = AggregateFunctionFactory::instance().get(fun_remove_distinct, data_types, parameters, properties); - Names argument_names = agg_desc.argument_names; - argument_names.emplace_back(marker); 
AggregateDescription new_agg_desc; @@ -145,6 +176,22 @@ TransformResult MultipleDistinctAggregationToMarkDistinct::transformImpl(PlanNod } } + if (!new_argument_assignments.empty()) + { + NameToType name_to_type; + for (const auto & assignment : new_argument_assignments) + name_to_type.emplace(assignment.first, std::make_shared()); + + for (const auto & input : child->getStep()->getOutputStream().header) + { + new_argument_assignments.emplace(input.name, makeASTIdentifier(input.name)); + name_to_type.emplace(input.name, input.type); + } + auto new_argument_projection_step + = std::make_shared(child->getStep()->getOutputStream(), new_argument_assignments, name_to_type); + child = PlanNodeBase::createPlanNode(rule_context.context->nextNodeId(), std::move(new_argument_projection_step), {child}); + } + auto count_agg_step = std::make_shared( child->getStep()->getOutputStream(), step.getKeys(), diff --git a/src/Optimizer/Rule/Rewrite/PushAggThroughJoinRules.cpp b/src/Optimizer/Rule/Rewrite/PushAggThroughJoinRules.cpp index b47b09efcaf..75a2b69ce41 100644 --- a/src/Optimizer/Rule/Rewrite/PushAggThroughJoinRules.cpp +++ b/src/Optimizer/Rule/Rewrite/PushAggThroughJoinRules.cpp @@ -347,6 +347,7 @@ TransformResult PushAggThroughOuterJoin::transformImpl(PlanNodePtr aggregation, join_step->getKeepLeftReadInOrder(), join_step->getLeftKeys(), join_step->getRightKeys(), + join_step->getKeyIdsNullSafe(), join_step->getFilter(), join_step->isHasUsing(), join_step->getRequireRightKeys(), @@ -378,6 +379,7 @@ TransformResult PushAggThroughOuterJoin::transformImpl(PlanNodePtr aggregation, join_step->getKeepLeftReadInOrder(), join_step->getLeftKeys(), join_step->getRightKeys(), + join_step->getKeyIdsNullSafe(), join_step->getFilter(), join_step->isHasUsing(), join_step->getRequireRightKeys(), @@ -522,6 +524,7 @@ TransformResult PushAggThroughInnerJoin::transformImpl(PlanNodePtr aggregation, join_step->getKeepLeftReadInOrder(), join_step->getLeftKeys(), join_step->getRightKeys(), + 
join_step->getKeyIdsNullSafe(), join_step->getFilter(), join_step->isHasUsing(), join_step->getRequireRightKeys(), diff --git a/src/Optimizer/Rule/Rewrite/PushDownApplyRules.cpp b/src/Optimizer/Rule/Rewrite/PushDownApplyRules.cpp index 3aed2891119..90ab9a2c387 100644 --- a/src/Optimizer/Rule/Rewrite/PushDownApplyRules.cpp +++ b/src/Optimizer/Rule/Rewrite/PushDownApplyRules.cpp @@ -71,6 +71,7 @@ TransformResult PushDownApplyThroughJoin::transformImpl(PlanNodePtr node, const step->getKeepLeftReadInOrder(), step->getLeftKeys(), step->getRightKeys(), + step->getKeyIdsNullSafe(), step->getFilter(), step->isHasUsing(), step->getRequireRightKeys(), diff --git a/src/Optimizer/Rule/Rewrite/PushPartialStepThroughExchangeRules.cpp b/src/Optimizer/Rule/Rewrite/PushPartialStepThroughExchangeRules.cpp index 5cf70e7c222..a43e344b730 100644 --- a/src/Optimizer/Rule/Rewrite/PushPartialStepThroughExchangeRules.cpp +++ b/src/Optimizer/Rule/Rewrite/PushPartialStepThroughExchangeRules.cpp @@ -31,17 +31,26 @@ #include #include #include +#include namespace DB { NameSet PushPartialAggThroughExchange::BLOCK_AGGS{ - "pathCount", - "attributionAnalysis", - "attributionCorrelationFuse", + "pathcount", + "attributionanalysis", + "attributioncorrelationfuse", "attribution", - "attributionCorrelation", -}; + "attributioncorrelation", + "bitmapjoinandcard", + "bitmapjoinandcard2", + "bitmapjoin", + "bitmapcount", + "bitmapextract", + "bitmapmulticount", + "bitmapmulticountwithdate", + "bitmapmaxlevel", + "bitmapcolumndiff"}; static std::pair canPushPartialWithHint(const AggregatingStep * step) { @@ -200,13 +209,7 @@ TransformResult PushPartialAggThroughExchange::transformImpl(PlanNodePtr node, c for (const auto & agg : step->getAggregates()) { - if (BLOCK_AGGS.count(agg.function->getName())) - { - return {}; - } - - // fixme: remove bitmap* if correctness problem fixed - if (Poco::toLower(agg.function->getName()).starts_with("bitmap")) + if 
(BLOCK_AGGS.count(Poco::toLower(agg.function->getName()))) { return {}; } @@ -295,8 +298,12 @@ TransformResult PushPartialAggThroughUnion::transformImpl(PlanNodePtr node, cons ConstRefPatternPtr PushPartialSortingThroughExchange::getPattern() const { - static auto pattern = Patterns::sorting().withSingle(Patterns::exchange().matchingStep( - [](const ExchangeStep & step) { return step.getExchangeMode() == ExchangeMode::GATHER; })).result(); + static auto pattern + = Patterns::sorting() + .matchingStep([](const SortingStep & step) { return step.getStage() == SortingStep::Stage::FULL; }) + .withSingle(Patterns::exchange().matchingStep( + [](const ExchangeStep & step) { return step.getExchangeMode() == ExchangeMode::GATHER; })) + .result(); return pattern; } @@ -347,6 +354,58 @@ TransformResult PushPartialSortingThroughExchange::transformImpl(PlanNodePtr nod return final_sort_node; } +ConstRefPatternPtr PushPartialSortingThroughUnion::getPattern() const +{ + static auto pattern + = Patterns::sorting() + .matchingStep([](const SortingStep & step) { return step.getStage() == SortingStep::Stage::PARTIAL; }) + .withSingle(Patterns::unionn()) + .result(); + return pattern; +} + +TransformResult PushPartialSortingThroughUnion::transformImpl(PlanNodePtr node, const Captures &, RuleContext & context) +{ + const auto * step = dynamic_cast(node->getStep().get()); + auto union_node = node->getChildren()[0]; + const auto * union_step = dynamic_cast(union_node->getStep().get()); + + PlanNodes union_inputs; + for (size_t index = 0; index < union_node->getChildren().size(); index++) + { + auto exchange_child = union_node->getChildren()[index]; + if (dynamic_cast(exchange_child.get())) + return {}; + + SortDescription new_sort_desc; + for (const auto & desc : step->getSortDescription()) + { + auto new_desc = desc; + const auto & out_to_inputs = union_step->getOutToInputs(); + if (!out_to_inputs.contains(desc.column_name) || out_to_inputs.at(desc.column_name).size() <= index) + throw 
Exception( + ErrorCodes::LOGICAL_ERROR, "PushPartialSortingThroughUnion: Can not find {} in out_to_inputs.", desc.column_name); + new_desc.column_name = union_step->getOutToInputs().at(desc.column_name).at(index); + new_sort_desc.emplace_back(new_desc); + } + + auto partial_sorting = std::make_unique( + exchange_child->getStep()->getOutputStream(), new_sort_desc, step->getLimit(), SortingStep::Stage::PARTIAL_NO_MERGE, SortDescription{}); + PlanNodes children{exchange_child}; + auto before_exchange_sort_node + = PlanNodeBase::createPlanNode(context.context->nextNodeId(), std::move(partial_sorting), children, node->getStatistics()); + union_inputs.emplace_back(before_exchange_sort_node); + } + + auto merging_sorted = std::make_unique( + step->getOutputStream(), step->getSortDescription(), step->getLimit(), SortingStep::Stage::MERGE, SortDescription{}); + + return PlanNodeBase::createPlanNode( + context.context->nextNodeId(), + std::move(merging_sorted), + {PlanNodeBase::createPlanNode(context.context->nextNodeId(), union_node->getStep(), union_inputs)}); +} + static bool isLimitNeeded(const LimitStep & limit, const PlanNodePtr & node) { auto range = PlanNodeCardinality::extractCardinality(*node); diff --git a/src/Optimizer/Rule/Rewrite/PushPartialStepThroughExchangeRules.h b/src/Optimizer/Rule/Rewrite/PushPartialStepThroughExchangeRules.h index f2973558e9f..648cdf381ab 100644 --- a/src/Optimizer/Rule/Rewrite/PushPartialStepThroughExchangeRules.h +++ b/src/Optimizer/Rule/Rewrite/PushPartialStepThroughExchangeRules.h @@ -62,6 +62,20 @@ class PushPartialSortingThroughExchange : public Rule TransformResult transformImpl(PlanNodePtr node, const Captures & captures, RuleContext & context) override; }; +class PushPartialSortingThroughUnion : public Rule +{ +public: + RuleType getType() const override { return RuleType::PUSH_PARTIAL_SORTING_THROUGH_UNION; } + String getName() const override { return "PUSH_PARTIAL_SORTING_THROUGH_UNION"; } + bool isEnabled(ContextPtr context) 
const override + { + return context->getSettingsRef().enable_push_partial_sorting_through_union; + } + ConstRefPatternPtr getPattern() const override; + + TransformResult transformImpl(PlanNodePtr node, const Captures & captures, RuleContext & context) override; +}; + class PushPartialLimitThroughExchange : public Rule { public: diff --git a/src/Optimizer/Rule/Rewrite/SimplifyExpressionRules.cpp b/src/Optimizer/Rule/Rewrite/SimplifyExpressionRules.cpp index e0314b72010..17d0cb37cfc 100644 --- a/src/Optimizer/Rule/Rewrite/SimplifyExpressionRules.cpp +++ b/src/Optimizer/Rule/Rewrite/SimplifyExpressionRules.cpp @@ -79,7 +79,7 @@ TransformResult CommonJoinFilterRewriteRule::transformImpl(PlanNodePtr node, con } QueryPlanStepPtr join_step = std::make_shared( - step.getInputStreams(), + step.getInputStreams(), step.getOutputStream(), step.getKind(), step.getStrictness(), @@ -87,6 +87,7 @@ TransformResult CommonJoinFilterRewriteRule::transformImpl(PlanNodePtr node, con step.getKeepLeftReadInOrder(), step.getLeftKeys(), step.getRightKeys(), + step.getKeyIdsNullSafe(), rewritten, step.isHasUsing(), step.getRequireRightKeys(), @@ -272,6 +273,7 @@ TransformResult SimplifyJoinFilterRewriteRule::transformImpl(PlanNodePtr node, c step.getKeepLeftReadInOrder(), step.getLeftKeys(), step.getRightKeys(), + step.getKeyIdsNullSafe(), rewritten, step.isHasUsing(), step.getRequireRightKeys(), diff --git a/src/Optimizer/Rule/Rewrite/SingleDistinctAggregationToGroupBy.cpp b/src/Optimizer/Rule/Rewrite/SingleDistinctAggregationToGroupBy.cpp index 657ea828568..deb86faa6d5 100644 --- a/src/Optimizer/Rule/Rewrite/SingleDistinctAggregationToGroupBy.cpp +++ b/src/Optimizer/Rule/Rewrite/SingleDistinctAggregationToGroupBy.cpp @@ -82,14 +82,19 @@ TransformResult SingleDistinctAggregationToGroupBy::transformImpl(PlanNodePtr no const auto & step = dynamic_cast(*step_ptr); // insert a extra Group-by Aggregate, perform distinct operation - auto symbols = step.getAggregates()[0].argument_names; - auto 
group_by = step.getKeys(); - symbols.insert(symbols.begin(), group_by.begin(), group_by.end()); + NameSet distinct_keys; + Names keys; + for (const auto & symbol : step.getKeys()) + if (distinct_keys.emplace(symbol).second) + keys.emplace_back(symbol); + for (const auto & symbol : step.getAggregates()[0].argument_names) + if (distinct_keys.emplace(symbol).second) + keys.emplace_back(symbol); AggregateDescriptions aggregate_descriptions; auto group_by_step = std::make_shared( node->getChildren()[0]->getStep()->getOutputStream(), - symbols, + keys, step.getKeysNotHashed(), aggregate_descriptions, GroupingSetsParamsList{}, @@ -110,7 +115,7 @@ TransformResult SingleDistinctAggregationToGroupBy::transformImpl(PlanNodePtr no auto remove_distinct_agg_step = std::make_shared( group_by_node->getStep()->getOutputStream(), - group_by, + step.getKeys(), step.getKeysNotHashed(), remove_distinct_agg_descs, GroupingSetsParamsList{}, diff --git a/src/Optimizer/Rule/Rule.h b/src/Optimizer/Rule/Rule.h index 8045f40cf2d..0cfef52c85a 100644 --- a/src/Optimizer/Rule/Rule.h +++ b/src/Optimizer/Rule/Rule.h @@ -65,6 +65,7 @@ enum class RuleType : UInt32 PUSH_PARTIAL_AGG_THROUGH_EXCHANGE, PUSH_PARTIAL_AGG_THROUGH_UNION, PUSH_PARTIAL_SORTING_THROUGH_EXCHANGE, + PUSH_PARTIAL_SORTING_THROUGH_UNION, PUSH_PARTIAL_LIMIT_THROUGH_EXCHANGE, PUSH_PARTIAL_DISTINCT_THROUGH_EXCHANGE, diff --git a/src/Optimizer/Rule/Rules.cpp b/src/Optimizer/Rule/Rules.cpp index 9e5ccb588cc..c50f580e7af 100644 --- a/src/Optimizer/Rule/Rules.cpp +++ b/src/Optimizer/Rule/Rules.cpp @@ -43,6 +43,7 @@ #include #include #include +#include namespace DB { @@ -97,6 +98,7 @@ std::vector Rules::pushPartialStepRules() std::make_shared(), std::make_shared(), std::make_shared(), + std::make_shared(), std::make_shared(), std::make_shared(), std::make_shared()}; @@ -237,4 +239,10 @@ std::vector Rules::extractBitmapImplicitFilterRules() { return {std::make_shared()}; } + +std::vector Rules::joinUsingToJoinOn() +{ + return 
{std::make_shared()}; +} + } diff --git a/src/Optimizer/Rule/Rules.h b/src/Optimizer/Rule/Rules.h index 2410e3083bb..129d3066a90 100644 --- a/src/Optimizer/Rule/Rules.h +++ b/src/Optimizer/Rule/Rules.h @@ -50,6 +50,7 @@ class Rules static std::vector crossJoinToUnion(); static std::vector sumIfToCountIf(); static std::vector extractBitmapImplicitFilterRules(); + static std::vector joinUsingToJoinOn(); }; } diff --git a/src/Optimizer/Rule/Transformation/InnerJoinAssociate.cpp b/src/Optimizer/Rule/Transformation/InnerJoinAssociate.cpp index 24fc2bf57b5..2af60350773 100644 --- a/src/Optimizer/Rule/Transformation/InnerJoinAssociate.cpp +++ b/src/Optimizer/Rule/Transformation/InnerJoinAssociate.cpp @@ -57,6 +57,10 @@ TransformResult InnerJoinAssociate::transformImpl(PlanNodePtr node, const Captur auto * left_join_node = dynamic_cast(node->getChildren()[0].get()); auto left_join_step = left_join_node->getStep(); + // TODO + if (join_step->hasKeyIdNullSafe() || left_join_step->hasKeyIdNullSafe()) + return {}; + auto a = left_join_node->getChildren()[0]; auto b = left_join_node->getChildren()[1]; auto c = node->getChildren()[1]; @@ -200,6 +204,7 @@ TransformResult InnerJoinAssociate::transformImpl(PlanNodePtr node, const Captur rule_context.context->getSettingsRef().optimize_read_in_order, bc_left_keys, bc_right_keys, + std::vector{}, bc_join_filter); auto bc_node = PlanNodeBase::createPlanNode(rule_context.context->nextNodeId(), bc_join_step, {b, c}); @@ -213,6 +218,7 @@ TransformResult InnerJoinAssociate::transformImpl(PlanNodePtr node, const Captur rule_context.context->getSettingsRef().optimize_read_in_order, top_left_keys, top_right_keys, + std::vector{}, top_join_filter); auto top_node = PlanNodeBase::createPlanNode(rule_context.context->nextNodeId(), top_join_step, {a, bc_node}); diff --git a/src/Optimizer/Rule/Transformation/InnerJoinCommutation.cpp b/src/Optimizer/Rule/Transformation/InnerJoinCommutation.cpp index 3f399ac4808..ee7a5242a46 100644 --- 
a/src/Optimizer/Rule/Transformation/InnerJoinCommutation.cpp +++ b/src/Optimizer/Rule/Transformation/InnerJoinCommutation.cpp @@ -50,6 +50,7 @@ PlanNodePtr InnerJoinCommutation::swap(JoinNode & node, RuleContext & rule_conte step.getKeepLeftReadInOrder(), step.getRightKeys(), step.getLeftKeys(), + step.getKeyIdsNullSafe(), step.getFilter(), step.isHasUsing(), step.getRequireRightKeys(), diff --git a/src/Optimizer/Rule/Transformation/JoinEnumOnGraph.cpp b/src/Optimizer/Rule/Transformation/JoinEnumOnGraph.cpp index 78568f28ca4..18c9ce1e803 100644 --- a/src/Optimizer/Rule/Transformation/JoinEnumOnGraph.cpp +++ b/src/Optimizer/Rule/Transformation/JoinEnumOnGraph.cpp @@ -138,6 +138,7 @@ static PlanNodePtr createJoinNode( context->getOptimizerContext().getContext()->getSettingsRef().optimize_read_in_order, join_keys.first, join_keys.second, + std::vector{}, filter, false, std::nullopt, diff --git a/src/Optimizer/Rule/Transformation/JoinReorderUtils.cpp b/src/Optimizer/Rule/Transformation/JoinReorderUtils.cpp index 1323994fe92..4fe325775d8 100644 --- a/src/Optimizer/Rule/Transformation/JoinReorderUtils.cpp +++ b/src/Optimizer/Rule/Transformation/JoinReorderUtils.cpp @@ -133,6 +133,7 @@ namespace JoinReorderUtils rule_context.context->getSettingsRef().optimize_read_in_order, join_keys.first, join_keys.second, + std::vector{}, join_filter); return PlanNodeBase::createPlanNode(rule_context.context->nextNodeId(), std::move(join_step), {left, right}); diff --git a/src/Optimizer/Rule/Transformation/LeftJoinToRightJoin.cpp b/src/Optimizer/Rule/Transformation/LeftJoinToRightJoin.cpp index ba73933ada9..3edfeef7c37 100644 --- a/src/Optimizer/Rule/Transformation/LeftJoinToRightJoin.cpp +++ b/src/Optimizer/Rule/Transformation/LeftJoinToRightJoin.cpp @@ -45,6 +45,7 @@ TransformResult LeftJoinToRightJoin::transformImpl(PlanNodePtr node, const Captu step.getKeepLeftReadInOrder(), step.getRightKeys(), step.getLeftKeys(), + step.getKeyIdsNullSafe(), step.getFilter(), step.isHasUsing(), 
step.getRequireRightKeys(), diff --git a/src/Optimizer/Rule/Transformation/PullOuterJoin.cpp b/src/Optimizer/Rule/Transformation/PullOuterJoin.cpp index 5d83cd984a3..cda211edc34 100644 --- a/src/Optimizer/Rule/Transformation/PullOuterJoin.cpp +++ b/src/Optimizer/Rule/Transformation/PullOuterJoin.cpp @@ -99,7 +99,8 @@ static std::optional createNewJoin( context.getSettingsRef().max_threads, context.getSettingsRef().optimize_read_in_order, inner_join->getLeftKeys(), - inner_join->getRightKeys()); + inner_join->getRightKeys(), + inner_join->getKeyIdsNullSafe()); new_left->setOrdered(inner_join->isOrdered()); new_left->setSimpleReordered(inner_join->isSimpleReordered()); new_left->setHints(inner_join->getHints()); @@ -127,6 +128,7 @@ static std::optional createNewJoin( left_join->getKeepLeftReadInOrder(), left_join->getLeftKeys(), left_join->getRightKeys(), + left_join->getKeyIdsNullSafe(), PredicateConst::TRUE_VALUE, left_join->isHasUsing(), left_join->getRequireRightKeys(), diff --git a/src/Optimizer/Rule/Transformation/SemiJoinPushDown.cpp b/src/Optimizer/Rule/Transformation/SemiJoinPushDown.cpp index 9ea6f8a1709..0a16a09f3d1 100644 --- a/src/Optimizer/Rule/Transformation/SemiJoinPushDown.cpp +++ b/src/Optimizer/Rule/Transformation/SemiJoinPushDown.cpp @@ -86,6 +86,7 @@ TransformResult SemiJoinPushDown::transformImpl(PlanNodePtr node, const Captures step.getKeepLeftReadInOrder(), step.getLeftKeys(), step.getRightKeys(), + step.getKeyIdsNullSafe(), step.getFilter(), step.isHasUsing(), step.getRequireRightKeys(), @@ -113,6 +114,7 @@ TransformResult SemiJoinPushDown::transformImpl(PlanNodePtr node, const Captures step.getKeepLeftReadInOrder(), step.getLeftKeys(), step.getRightKeys(), + step.getKeyIdsNullSafe(), step.getFilter(), step.isHasUsing(), step.getRequireRightKeys(), @@ -181,6 +183,7 @@ TransformResult SemiJoinPushDownProjection::transformImpl(PlanNodePtr node, cons join_step->getKeepLeftReadInOrder(), mapper.map(join_step->getLeftKeys()), 
join_step->getRightKeys(), + join_step->getKeyIdsNullSafe(), mapper.map(join_step->getFilter()), join_step->isHasUsing(), join_step->getRequireRightKeys(), @@ -193,7 +196,7 @@ TransformResult SemiJoinPushDownProjection::transformImpl(PlanNodePtr node, cons join_step->getRuntimeFilterBuilders(), join_step->getHints()); - + // create new projection (add remaining symbols form right side if join is any join) Assignments assignments; NameToType name_to_type; @@ -271,6 +274,7 @@ TransformResult SemiJoinPushDownAggregate::transformImpl(PlanNodePtr node, const join_step->getKeepLeftReadInOrder(), join_step->getLeftKeys(), join_step->getRightKeys(), + join_step->getKeyIdsNullSafe(), join_step->getFilter(), join_step->isHasUsing(), join_step->getRequireRightKeys(), diff --git a/src/Optimizer/ShortCircuitPlanner.cpp b/src/Optimizer/ShortCircuitPlanner.cpp new file mode 100644 index 00000000000..af68a4e1ad1 --- /dev/null +++ b/src/Optimizer/ShortCircuitPlanner.cpp @@ -0,0 +1,120 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +class ShortCircuitPlanner::ShortCircuitPlanVisitor : public DB::PlanNodeVisitor +{ +public: + explicit ShortCircuitPlanVisitor(ContextPtr context_) : context(context_) + { + } + + bool visitPlanNode(PlanNodeBase & plan, Void & c) override + { + if (plan.getChildren().size() == 1) + return VisitorUtil::accept(plan.getChildren()[0], *this, c); + return false; + } + + bool visitLimitNode(LimitNode & plan, Void & v) override + { + return VisitorUtil::accept(plan.getChildren()[0], *this, v); + } + + bool visitProjectionNode(ProjectionNode & plan, Void & v) override + { + return VisitorUtil::accept(plan.getChildren()[0], *this, v); + } + + bool visitFilterNode(FilterNode & plan, Void &) override + { + if (plan.getChildren()[0]->getType() == IQueryPlanStep::Type::TableScan) + return 
checkTableScan(dynamic_cast(*plan.getChildren()[0]->getStep()), plan.getStep()->getFilter()); + return false; + } + + static bool checkTableScan(TableScanStep & table_scan, ConstASTPtr filter) + { + auto constraints = extractConstraints(filter); + auto metadata = table_scan.getStorage()->getInMemoryMetadataPtr(); + return isPointScan(metadata->getUniqueKey(), constraints); + } + + static std::unordered_set extractConstraints(ConstASTPtr filter) + { + std::unordered_set constraints; + for (const auto & conjunct : PredicateUtils::extractConjuncts(filter)) + { + const auto * func = conjunct->as(); + if (!func || func->name != "equals") + continue; + const auto * column = func->arguments->children[0]->as(); + if (!column) + continue; + if (func->arguments->children[1]->getType() != ASTType::ASTLiteral + && func->arguments->children[1]->getType() != ASTType::ASTPreparedParameter) + continue; + constraints.emplace(column->name()); + } + return constraints; + } + + /** + * Check filter constains all unique. 
+ */ + static bool isPointScan(const KeyDescription & primary_key, const std::unordered_set & constraints) + { + return std::all_of( + primary_key.column_names.begin(), primary_key.column_names.end(), [&](const auto & key) { return constraints.contains(key); }); + } + +private: + ContextPtr context; +}; + +bool ShortCircuitPlanner::isShortCircuitPlan(QueryPlan & query_plan, ContextPtr context) +{ + if (!context->getSettingsRef().enable_short_circuit) + return false; + + ShortCircuitPlanVisitor visitor{context}; + Void v; + return query_plan.getCTEInfo().empty() && VisitorUtil::accept(query_plan.getPlanNode(), visitor, v); +} + +void ShortCircuitPlanner::addExchangeIfNeeded(QueryPlan & query_plan, ContextMutablePtr context) +{ + // todo: analyze optimized cluster + auto output = query_plan.getPlanNode(); + if (output->getType() != IQueryPlanStep::Type::Projection) + throw Exception(ErrorCodes::LOGICAL_ERROR, "output node is expected a project"); + auto child = output->getChildren()[0]; + auto gather = PlanNodeBase::createPlanNode( + query_plan.getIdAllocator()->nextId(), + std::make_unique( + DataStreams{child->getStep()->getOutputStream()}, + ExchangeMode::GATHER, + Partitioning(Names{}), + context->getSettingsRef().enable_shuffle_with_order), + {child}); + output->replaceChildren({gather}); +} +} diff --git a/src/Optimizer/ShortCircuitPlanner.h b/src/Optimizer/ShortCircuitPlanner.h new file mode 100644 index 00000000000..cc45f5702e5 --- /dev/null +++ b/src/Optimizer/ShortCircuitPlanner.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include + +namespace DB +{ +class ShortCircuitPlanner +{ +public: + static bool isShortCircuitPlan(QueryPlan & query_plan, ContextPtr context); + static void addExchangeIfNeeded(QueryPlan & query_plan, ContextMutablePtr context); + +private: + class ShortCircuitPlanVisitor; +}; +} diff --git a/src/Optimizer/Signature/PlanSignature.cpp b/src/Optimizer/Signature/PlanSignature.cpp index 5359ffdeafb..d0e3702f3fa 100644 --- 
a/src/Optimizer/Signature/PlanSignature.cpp +++ b/src/Optimizer/Signature/PlanSignature.cpp @@ -10,7 +10,10 @@ namespace DB { // for test assert, see more 10102_intermediate_result_cache -static const std::unordered_set IGNORED_SETTINGS{"max_bytes_to_read", "max_rows_to_read"}; +static const std::unordered_set IGNORED_SETTINGS{ + "max_bytes_to_read", "max_rows_to_read", "load_balancing", "prefer_localhost_replica", "send_logs_level", "max_execution_time"}; + +#define CHECK_IGNORED_SETTINGS(name) ((name).find("timeout") != std::string::npos || IGNORED_SETTINGS.contains(name)) size_t PlanSignatureProvider::combine(const std::vector & hashes) { @@ -35,7 +38,7 @@ PlanSignature PlanSignatureProvider::combineSettings(PlanSignature signature, co size_t size = 0; for (const auto & item : settings) { - if (IGNORED_SETTINGS.contains(item.name)) + if (CHECK_IGNORED_SETTINGS(item.name)) continue; size += 1; } @@ -43,7 +46,7 @@ PlanSignature PlanSignatureProvider::combineSettings(PlanSignature signature, co hash.update(size); for (const auto & item : settings) { - if (IGNORED_SETTINGS.contains(item.name)) + if (CHECK_IGNORED_SETTINGS(item.name)) continue; hash.update(sipHash64(item.name)); applyVisitor(FieldVisitorHash(hash), item.value); diff --git a/src/Optimizer/Signature/PlanSignature.h b/src/Optimizer/Signature/PlanSignature.h index a19e99e3c33..518fd633ffd 100644 --- a/src/Optimizer/Signature/PlanSignature.h +++ b/src/Optimizer/Signature/PlanSignature.h @@ -56,10 +56,7 @@ class PlanSignatureProvider } protected: - virtual PlanSignature computeStepHash(PlanNodePtr node) - { - return normalizer.computeNormalStep(node)->hash(); - } + virtual PlanSignature computeStepHash(PlanNodePtr node) { return normalizer.computeNormalStep(node)->hash(false); } static size_t combine(const std::vector & hashes); diff --git a/src/Optimizer/Signature/StepNormalizer.cpp b/src/Optimizer/Signature/StepNormalizer.cpp index 624b0b59326..d6f863b45e8 100644 --- 
a/src/Optimizer/Signature/StepNormalizer.cpp +++ b/src/Optimizer/Signature/StepNormalizer.cpp @@ -418,33 +418,13 @@ StepAndOutputOrder StepNormalizer::visitJoinStep(const JoinStep & step, StepsAnd DataStreams normal_input_streams = processInputStreams(step.getInputStreams(), inputs, symbol_mapping, cumulative_pos); createOutputSymbolMapping(step.getOutputStream().header, symbol_mapping, cumulative_pos); SymbolMapper symbol_mapper = SymbolMapper::simpleMapper(symbol_mapping); - QueryPlanStepPtr normal_step = symbol_mapper.map(step); + auto normal_step = symbol_mapper.map(step); auto output_header = normal_step->getOutputStream().header.getColumnsWithTypeAndName(); ExpressionReorderNormalizer::reorder(output_header); // replace the input_stream & output_stream because of reordering - normal_step = std::make_shared( - normal_input_streams, - DataStream{output_header}, - step.getKind(), - step.getStrictness(), - step.getMaxStreams(), - step.getKeepLeftReadInOrder(), - step.getLeftKeys(), - step.getRightKeys(), - step.getFilter(), - step.isHasUsing(), - step.getRequireRightKeys(), - step.getAsofInequality(), - step.getDistributionType(), - step.getJoinAlgorithm(), - step.isMagic(), - step.isOrdered(), - step.isSimpleReordered(), - // step.isParallel(), - // step.isBucket(), - step.getRuntimeFilterBuilders(), - step.getHints()); + normal_step->setInputStreams(normal_input_streams); + normal_step->setOutputStream(DataStream{output_header}); Block output_order = getOutputOrder(step, *normal_step, symbol_mapper); return StepAndOutputOrder{normal_step, std::move(output_order)}; diff --git a/src/Optimizer/Signature/StepNormalizer.h b/src/Optimizer/Signature/StepNormalizer.h index fc39886ef1c..04bef58bbe3 100644 --- a/src/Optimizer/Signature/StepNormalizer.h +++ b/src/Optimizer/Signature/StepNormalizer.h @@ -39,9 +39,9 @@ class StepNormalizer; /** * @class StepAndOutputOrder is the outcome of normalizing a step * - * @param normal_step is the normalized step for each original 
step, which is in index_ref and can be used to calculate hash etc. + * normal_step is the normalized step for each original step, which is in index_ref and can be used to calculate hash etc. * All normal_steps implicitly formulate a tree, whose structure is implied by the original plan. Handling this is left to PlanNormalizer - * @param reordered_header is an reordering of the original output header. + * reordered_header is an reordering of the original output header. * It is the same as the header of original step if no reordering take place. * It is different from normal_step->getOutputStream().header, as the symbols in reordered_header are still the original symbols. * The parent must use this information to normalize. diff --git a/src/Optimizer/SymbolTransformMap.cpp b/src/Optimizer/SymbolTransformMap.cpp index 7b274719253..50c5b27cbfb 100644 --- a/src/Optimizer/SymbolTransformMap.cpp +++ b/src/Optimizer/SymbolTransformMap.cpp @@ -181,7 +181,7 @@ String SymbolTransformMap::toString() const bool SymbolTransformMap::addSymbolMapping(const String & symbol, ConstASTPtr expr) { for (const auto & symbol_in_expr : SymbolsExtractor::extract(expr)) - if (symbol_to_expressions.contains(symbol_in_expr)) + if (symbol == symbol_in_expr || symbol_to_expressions.contains(symbol_in_expr)) return false; return symbol_to_expressions.emplace(symbol, std::move(expr)).second; } diff --git a/src/Optimizer/tests/gtest_base_plan_test.cpp b/src/Optimizer/tests/gtest_base_plan_test.cpp index f6f2a250995..d1d0a09b56b 100644 --- a/src/Optimizer/tests/gtest_base_plan_test.cpp +++ b/src/Optimizer/tests/gtest_base_plan_test.cpp @@ -76,7 +76,7 @@ BasePlanTest::BasePlanTest(const String & database_name_, const std::unordered_m setting_changes.emplace_back("enable_optimizer", true); setting_changes.emplace_back("enable_memory_catalog", true); - setting_changes.emplace_back("dialect_type", "ANSI"s); + setting_changes.emplace_back("dialect_type", "ANSI"); 
setting_changes.emplace_back("data_type_default_nullable", false); for (const auto & item : session_settings) diff --git a/src/Optimizer/tests/gtest_cascades.cpp b/src/Optimizer/tests/gtest_cascades.cpp index 48818d1328b..15b1a115d6b 100644 --- a/src/Optimizer/tests/gtest_cascades.cpp +++ b/src/Optimizer/tests/gtest_cascades.cpp @@ -71,6 +71,7 @@ PlanNodePtr join(const PlanNodePtr & left, const PlanNodePtr & right, const Name false, left_keys, right_keys, + std::vector{}, PredicateConst::TRUE_VALUE); return PlanNodeBase::createPlanNode(1, std::move(join_step), {left, right}); } diff --git a/src/Optimizer/tests/gtest_plan_signature.cpp b/src/Optimizer/tests/gtest_plan_signature.cpp index 03317a35745..569e77f9dd4 100644 --- a/src/Optimizer/tests/gtest_plan_signature.cpp +++ b/src/Optimizer/tests/gtest_plan_signature.cpp @@ -216,7 +216,7 @@ TEST_F(PlanSignatureTest, testTpcdsAllSignaturesWithoutRuntimeFilter) } std::sort( sorted_by_freq.begin(), sorted_by_freq.end(), [](const auto & left, const auto & right) { return left.size() > right.size(); }); - EXPECT_EQ(sorted_by_freq.size(), 11); + EXPECT_EQ(sorted_by_freq.size(), 12); // all binary mappings EXPECT_EQ(sorted_by_freq[0].size(), 2); // std::unordered_map query_mapping; diff --git a/src/Parsers/ASTAssignment.h b/src/Parsers/ASTAssignment.h index 3da54d717b4..389a56e008e 100644 --- a/src/Parsers/ASTAssignment.h +++ b/src/Parsers/ASTAssignment.h @@ -26,10 +26,11 @@ namespace DB { -/// Part of the ALTER UPDATE statement of the form: column = expr +/// Part of the ALTER UPDATE statement of the form: column = expr or tbl_alias.col = expr class ASTAssignment : public IAST { public: + String table_name; String column_name; ASTPtr expression() const @@ -37,7 +38,9 @@ class ASTAssignment : public IAST return children.at(0); } - String getID(char delim) const override { return "Assignment" + (delim + column_name); } + String tablePrefix() const { return table_name.empty() ? 
"" : table_name + "."; } + + String getID(char delim) const override { return "Assignment" + (delim + tablePrefix() + column_name); } ASTType getType() const override { return ASTType::ASTAssignment; } @@ -50,11 +53,13 @@ class ASTAssignment : public IAST void toLowerCase() override { + boost::to_lower(table_name); boost::to_lower(column_name); } void toUpperCase() override { + boost::to_upper(table_name); boost::to_upper(column_name); } @@ -62,6 +67,11 @@ class ASTAssignment : public IAST void formatImpl(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const override { settings.ostr << (settings.hilite ? hilite_identifier : ""); + if (!table_name.empty()) + { + settings.writeIdentifier(table_name); + settings.ostr << "."; + } settings.writeIdentifier(column_name); settings.ostr << (settings.hilite ? hilite_none : ""); diff --git a/src/Parsers/ASTClusterByElement.cpp b/src/Parsers/ASTClusterByElement.cpp index 4793b380efb..cfe739ec637 100644 --- a/src/Parsers/ASTClusterByElement.cpp +++ b/src/Parsers/ASTClusterByElement.cpp @@ -15,6 +15,7 @@ #include #include +#include #include @@ -56,5 +57,28 @@ ASTPtr ASTClusterByElement::clone() const return clone; } +void ASTClusterByElement::serialize(WriteBuffer & buf) const +{ + writeBinary(split_number, buf); + writeBinary(is_with_range, buf); + writeBinary(is_user_defined_expression, buf); + serializeASTs(children, buf); +} + +void ASTClusterByElement::deserializeImpl(ReadBuffer & buf) +{ + readBinary(split_number, buf); + readBinary(is_with_range, buf); + readBinary(is_user_defined_expression, buf); + children = deserializeASTs(buf); +} + +ASTPtr ASTClusterByElement::deserialize(ReadBuffer & buf) +{ + auto element = std::make_shared(); + element->deserializeImpl(buf); + return element; +} + } diff --git a/src/Parsers/ASTClusterByElement.h b/src/Parsers/ASTClusterByElement.h index 45766f2247f..80b06d99fab 100644 --- a/src/Parsers/ASTClusterByElement.h +++ b/src/Parsers/ASTClusterByElement.h 
@@ -42,9 +42,15 @@ class ASTClusterByElement : public IAST const ASTPtr & getColumns() const { return children.front(); } const ASTPtr & getTotalBucketNumber() const { return children.back(); } + ASTType getType() const override { return ASTType::ASTClusterByElement; } + String getID(char) const override { return "ClusterByElement"; } ASTPtr clone() const override; + void serialize(WriteBuffer & buf) const override; + void deserializeImpl(ReadBuffer & buf) override; + static ASTPtr deserialize(ReadBuffer & buf); + protected: void formatImpl(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const override; }; diff --git a/src/Parsers/ASTDictionary.h b/src/Parsers/ASTDictionary.h index 0936a96ec48..05623afd2d8 100644 --- a/src/Parsers/ASTDictionary.h +++ b/src/Parsers/ASTDictionary.h @@ -133,6 +133,12 @@ class ASTDictionary : public IAST ASTPtr clone() const override; void formatImpl(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const override; + + String clickhouse_db; + String clickhouse_tb; + String clickhouse_query; + String clickhouse_invalidate_query; + }; } diff --git a/src/Parsers/ASTGrantQuery.cpp b/src/Parsers/ASTGrantQuery.cpp index baa5045704e..cd547e29a7b 100644 --- a/src/Parsers/ASTGrantQuery.cpp +++ b/src/Parsers/ASTGrantQuery.cpp @@ -108,6 +108,12 @@ void ASTGrantQuery::formatImpl(const FormatSettings & settings, FormatState &, F settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << (attach_mode ? "ATTACH " : "") << (is_revoke ? "REVOKE" : "GRANT") << (settings.hilite ? IAST::hilite_none : ""); + if (is_sensitive) + { + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << (" SENSITIVE") + << (settings.hilite ? 
IAST::hilite_none : ""); + } + if (!access_rights_elements.sameOptions()) throw Exception("Elements of an ASTGrantQuery are expected to have the same options", ErrorCodes::LOGICAL_ERROR); if (!access_rights_elements.empty() && access_rights_elements[0].is_partial_revoke && !is_revoke) @@ -118,6 +124,8 @@ void ASTGrantQuery::formatImpl(const FormatSettings & settings, FormatState &, F if (is_revoke) { + if (if_exists) + settings.ostr << (settings.hilite ? hilite_keyword : "") << " IF EXISTS" << (settings.hilite ? hilite_none : ""); if (grant_option) settings.ostr << (settings.hilite ? hilite_keyword : "") << " GRANT OPTION FOR" << (settings.hilite ? hilite_none : ""); else if (admin_option) @@ -180,7 +188,7 @@ void ASTGrantQuery::rewriteNamesWithTenant(const Context *) } tenant_rewritten = true; - } + } } void ASTGrantQuery::rewriteNamesWithoutTenant(const Context *) diff --git a/src/Parsers/ASTGrantQuery.h b/src/Parsers/ASTGrantQuery.h index 4964e12074b..a36dbaf77d9 100644 --- a/src/Parsers/ASTGrantQuery.h +++ b/src/Parsers/ASTGrantQuery.h @@ -41,7 +41,9 @@ class ASTGrantQuery : public IAST { public: bool attach_mode = false; + bool is_sensitive = false; bool is_revoke = false; + bool if_exists = false; AccessRightsElements access_rights_elements; std::shared_ptr roles; bool admin_option = false; diff --git a/src/Parsers/ASTIdentifier.cpp b/src/Parsers/ASTIdentifier.cpp index d4cb34eba5f..591283307d3 100644 --- a/src/Parsers/ASTIdentifier.cpp +++ b/src/Parsers/ASTIdentifier.cpp @@ -307,7 +307,7 @@ void ASTIdentifier::appendCatalogName(const std::string & catalog_name) cnch_append_catalog = true; } -void ASTIdentifier::appendTenantId(const Context * context) +void ASTIdentifier::appendTenantId(const Context* context, bool is_datbase_name) { if (!context) return; @@ -315,7 +315,7 @@ void ASTIdentifier::appendTenantId(const Context * context) { /// Only catalogname case 1: - name_parts[0] = appendTenantIdOnly(name_parts[0]); + name_parts[0] = 
appendTenantIdOnly(name_parts[0], is_datbase_name); resetFullName(); break; default: @@ -463,7 +463,7 @@ void ASTTableIdentifier::appendCatalogName(const std::string & catalog_name) cnch_append_catalog = true; } -void ASTTableIdentifier::appendTenantId([[maybe_unused]] const Context * context) +void ASTTableIdentifier::appendTenantId([[maybe_unused]]const Context* context, bool /*is_datbase_name*/) { // this function shall not be called on TableIdentifier. throw Exception(ErrorCodes::LOGICAL_ERROR, "this function shall not be called on TableIdentifier."); @@ -535,9 +535,8 @@ void tryRewriteHiveCatalogName(ASTPtr & ast_catalog, const Context * context) return; if (auto * c = dynamic_cast(ast_catalog.get())) { - if (c->name() == "cnch") - return; - c->appendTenantId(context); + if(c->name() == "cnch") return; + c->appendTenantId(context, true); } } diff --git a/src/Parsers/ASTIdentifier.h b/src/Parsers/ASTIdentifier.h index eb85ca1b3e6..f47a9ba875d 100644 --- a/src/Parsers/ASTIdentifier.h +++ b/src/Parsers/ASTIdentifier.h @@ -91,7 +91,7 @@ class ASTIdentifier : public ASTWithAlias virtual void appendCatalogName(const std::string& catalog_name); - virtual void appendTenantId(const Context * context); + virtual void appendTenantId(const Context * context, bool is_datbase_name); String full_name; std::vector name_parts; @@ -143,7 +143,7 @@ class ASTTableIdentifier : public ASTIdentifier // void rewriteCnchDatabaseOrCatalog(const Context *context) override; void rewriteCnchDatabaseName(const Context * context = nullptr) override; virtual void appendCatalogName(const std::string& catalog_name) override; - virtual void appendTenantId(const Context * context) override; + virtual void appendTenantId(const Context * context, bool is_datbase_name) override; }; diff --git a/src/Parsers/ASTPreparedStatement.cpp b/src/Parsers/ASTPreparedStatement.cpp index 3df791a866b..dceb694c1c0 100644 --- a/src/Parsers/ASTPreparedStatement.cpp +++ b/src/Parsers/ASTPreparedStatement.cpp @@ -1,4 
+1,6 @@ #include +#include +#include namespace DB { @@ -6,8 +8,10 @@ namespace DB ASTPtr ASTCreatePreparedStatementQuery::clone() const { auto res = std::make_shared(*this); + res->name_ast = name_ast->clone(); res->query = query->clone(); res->children.clear(); + res->children.push_back(res->name_ast); res->children.push_back(res->query); return res; } @@ -24,7 +28,7 @@ void ASTCreatePreparedStatementQuery::formatImpl(const FormatSettings & settings else if (or_replace) settings.ostr << (settings.hilite ? hilite_keyword : "") << "OR REPLACE " << (settings.hilite ? hilite_none : ""); - settings.ostr << (settings.hilite ? hilite_identifier : "") << name << (settings.hilite ? hilite_none : ""); + name_ast->formatImpl(settings, state, frame); formatOnCluster(settings); settings.ostr << (settings.hilite ? hilite_keyword : "") << " AS" << (settings.hilite ? hilite_none : ""); @@ -35,6 +39,33 @@ void ASTCreatePreparedStatementQuery::formatImpl(const FormatSettings & settings } } +void ASTCreatePreparedStatementQuery::rewriteNamesWithTenant(const Context* context) +{ + if (!context) + { + String new_name = formatTenantName(name); + if (new_name != name) + { + auto tenant_id = getCurrentTenantId(); + std::vector name_part = {tenant_id, name}; + name_ast = std::make_shared(std::move(name_part), false); + name = new_name; + } + } + + if (auto * identifier = name_ast->as()) + { + identifier->appendTenantId(context, false); + name = identifier->name(); + } +} + +void ASTCreatePreparedStatementQuery::rewriteNamesWithoutTenant() +{ + name = getOriginalEntityName(name); + name_ast = std::make_shared(name); +} + ASTPtr ASTExecutePreparedStatementQuery::clone() const { auto res = std::make_shared(*this); @@ -46,7 +77,6 @@ ASTPtr ASTExecutePreparedStatementQuery::clone() const return res; } - void ASTExecutePreparedStatementQuery::formatQueryImpl(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const { settings.ostr << (settings.hilite ? 
hilite_keyword : "") << "EXECUTE PREPARED STATEMENT " << (settings.hilite ? hilite_identifier : "") @@ -61,6 +91,16 @@ void ASTExecutePreparedStatementQuery::formatQueryImpl(const FormatSettings & se } } +void ASTExecutePreparedStatementQuery::rewriteNamesWithTenant(const Context* /*context*/) +{ + name = formatTenantName(name); +} + +void ASTExecutePreparedStatementQuery::rewriteNamesWithoutTenant() +{ + name = getOriginalEntityName(name); +} + ASTPtr ASTShowPreparedStatementQuery::clone() const { auto res = std::make_shared(*this); @@ -86,6 +126,20 @@ void ASTShowPreparedStatementQuery::formatQueryImpl(const FormatSettings & setti settings.ostr << (settings.hilite ? hilite_keyword : "") << "PREPARED STATEMENTS" << (settings.hilite ? hilite_none : ""); } +void ASTShowPreparedStatementQuery::rewriteNamesWithTenant(const Context* /*context*/) +{ + if (name.empty()) + return; + name = formatTenantName(name); +} + +void ASTShowPreparedStatementQuery::rewriteNamesWithoutTenant() +{ + if (name.empty()) + return; + name = getOriginalEntityName(name); +} + ASTPtr ASTDropPreparedStatementQuery::clone() const { auto res = std::make_shared(*this); @@ -102,5 +156,15 @@ void ASTDropPreparedStatementQuery::formatImpl(const FormatSettings & settings, formatOnCluster(settings); } +void ASTDropPreparedStatementQuery::rewriteNamesWithTenant(const Context* /*context*/) +{ + name = formatTenantName(name); +} + +void ASTDropPreparedStatementQuery::rewriteNamesWithoutTenant() +{ + name = getOriginalEntityName(name); +} + } diff --git a/src/Parsers/ASTPreparedStatement.h b/src/Parsers/ASTPreparedStatement.h index 6d815e1149c..067666fcf49 100644 --- a/src/Parsers/ASTPreparedStatement.h +++ b/src/Parsers/ASTPreparedStatement.h @@ -12,6 +12,7 @@ class ASTCreatePreparedStatementQuery : public IAST, public ASTQueryWithOnCluste { public: String name; + ASTPtr name_ast; ASTPtr query; bool if_not_exists = false; @@ -45,6 +46,9 @@ class ASTCreatePreparedStatementQuery : public IAST, public 
ASTQueryWithOnCluste return removeOnCluster(clone()); } + void rewriteNamesWithTenant(const Context* context = nullptr); + void rewriteNamesWithoutTenant(); + protected: void formatImpl(const FormatSettings & format, FormatState &, FormatStateStacked) const override; }; @@ -78,6 +82,9 @@ class ASTExecutePreparedStatementQuery : public ASTQueryWithOutput return values; } + void rewriteNamesWithTenant(const Context* context = nullptr); + void rewriteNamesWithoutTenant(); + protected: void formatQueryImpl(const FormatSettings & format, FormatState &, FormatStateStacked) const override; }; @@ -103,6 +110,9 @@ class ASTShowPreparedStatementQuery : public ASTQueryWithOutput ASTPtr clone() const override; + void rewriteNamesWithTenant(const Context* context = nullptr); + void rewriteNamesWithoutTenant(); + protected: void formatQueryImpl(const FormatSettings & format, FormatState &, FormatStateStacked) const override; }; @@ -132,6 +142,9 @@ class ASTDropPreparedStatementQuery : public IAST, public ASTQueryWithOnCluster return removeOnCluster(clone()); } + void rewriteNamesWithTenant(const Context* context = nullptr); + void rewriteNamesWithoutTenant(); + protected: void formatImpl(const FormatSettings & format, FormatState &, FormatStateStacked) const override; }; diff --git a/src/Parsers/ASTQueryWithOutput.h b/src/Parsers/ASTQueryWithOutput.h index c60ec79c7e9..f4c33ac8e21 100644 --- a/src/Parsers/ASTQueryWithOutput.h +++ b/src/Parsers/ASTQueryWithOutput.h @@ -39,6 +39,7 @@ class ASTQueryWithOutput : public IAST ASTPtr settings_ast; ASTPtr compression_method; ASTPtr compression_level; + bool ignore_format = false; void formatOutput(const FormatSettings & s, FormatState & state, FormatStateStacked frame) const; diff --git a/src/Parsers/ASTSerDerHelper.cpp b/src/Parsers/ASTSerDerHelper.cpp index 839c324a7ef..074c98a0cda 100644 --- a/src/Parsers/ASTSerDerHelper.cpp +++ b/src/Parsers/ASTSerDerHelper.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include 
#include @@ -31,8 +32,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -112,7 +113,6 @@ namespace DB { - ASTPtr createWithASTType(ASTType type, ReadBuffer & buf) { switch (type) diff --git a/src/Parsers/ASTUtils.h b/src/Parsers/ASTUtils.h new file mode 100644 index 00000000000..d9101e6eb19 --- /dev/null +++ b/src/Parsers/ASTUtils.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +namespace DB +{ + +inline void replace_func_with_known_column(ASTPtr & definition_ast, const NameSet & columns) +{ + if (!definition_ast) + return; + + if (ASTFunction * func = definition_ast->as()) + { + String column_name = func->getColumnName(); + if (columns.count(column_name)) + { + auto identifier = std::make_shared(column_name); + definition_ast = identifier; + } + else + { + for (auto & child : func->arguments->children) + replace_func_with_known_column(child, columns); + } + } + else + { + for (auto & child : definition_ast->children) + replace_func_with_known_column(child, columns); + } +} + +} diff --git a/src/Parsers/ExpressionElementParsers.cpp b/src/Parsers/ExpressionElementParsers.cpp index c25f7d1ed8f..830f8ac25d3 100644 --- a/src/Parsers/ExpressionElementParsers.cpp +++ b/src/Parsers/ExpressionElementParsers.cpp @@ -68,6 +68,7 @@ #include #include #include "ASTColumnsMatcher.h" +#include "Parsers/parseDatabaseAndTableName.h" #include #include @@ -181,15 +182,18 @@ bool ParserSubquery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) bool ParserIdentifier::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { - /// Identifier in backquotes or in double quotes - if (pos->type == TokenType::BackQuotedIdentifier || pos->type == TokenType::DoubleQuotedIdentifier) + /// Identifier in backquotes or in double quotes or single quotes + if (pos->type == TokenType::BackQuotedIdentifier || pos->type == TokenType::DoubleQuotedIdentifier + || (allow_single_quoted_identifier && pos->type == TokenType::StringLiteral)) { 
ReadBufferFromMemory buf(pos->begin, pos->size()); String s; if (*pos->begin == '`') readBackQuotedStringWithSQLStyle(s, buf); - else + else if (*pos->begin == '\'') + readQuotedStringWithSQLStyle(s, buf); + else if (*pos->begin == '"') readDoubleQuotedStringWithSQLStyle(s, buf); if (s.empty()) /// Identifiers "empty string" are not allowed. @@ -1079,7 +1083,7 @@ bool ParserCastOperator::parseImpl(Pos & pos, ASTPtr & node, Expected & expected if (!isOneOf(last_token)) return false; } - else if (isOneOf(pos->type)) + else if (isOneOf(pos->type)) { if (!isOneOf(last_token)) return false; @@ -2272,7 +2276,7 @@ const char * ParserAlias::restricted_keywords[] = bool ParserAlias::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { ParserKeyword s_as("AS"); - ParserIdentifier id_p; + ParserIdentifier id_p(false, allow_single_quoted_identifier); bool has_as_word = s_as.ignore(pos, expected); if (!allow_alias_without_as_keyword && !has_as_word) @@ -2707,7 +2711,7 @@ bool ParserWithOptionalAlias::parseImpl(Pos & pos, ASTPtr & node, Expected & exp allow_alias_without_as_keyword_now = false; ASTPtr alias_node; - if (ParserAlias(allow_alias_without_as_keyword_now).parse(pos, alias_node, expected)) + if (ParserAlias(allow_alias_without_as_keyword_now, dt.parse_mysql_ddl).parse(pos, alias_node, expected)) { /// FIXME: try to prettify this cast using `as<>()` if (auto * ast_with_alias = dynamic_cast(node.get())) @@ -3130,6 +3134,32 @@ bool ParserAssignment::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) return true; } +/// a.col = _expression_ or col = _expression_ +/// Reuse `parseDatabaseAndTableName` for extracting table alias and column name. +bool ParserAssignmentWithAlias::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) +{ + auto assignment = std::make_shared(); + node = assignment; + + ParserToken s_equals(TokenType::Equals); + ParserExpression p_expression(dt); + + /// Reuse `parseDatabaseAndTableName` for extracting table alias and column name. 
Need to ignore tenant_id. + parseDatabaseAndTableName(pos, expected, assignment->table_name, assignment->column_name, /*rewrite_db*/false); + + if (!s_equals.ignore(pos, expected)) + return false; + + ASTPtr expression; + if (!p_expression.parse(pos, expression, expected)) + return false; + + if (expression) + assignment->children.push_back(expression); + + return true; +} + bool ParserEscapeExpression::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { ParserStringLiteral escape_exp(dt); @@ -3149,4 +3179,11 @@ bool ParserEscapeExpression::parseImpl(Pos & pos, ASTPtr & node, Expected & expe return true; } +bool ParserExecuteValue::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) +{ + return ParserTupleOfLiterals(dt).parse(pos, node, expected) + || ParserArrayOfLiterals(dt).parse(pos, node, expected) + || ParserLiteral(dt).parse(pos, node, expected); +} + } diff --git a/src/Parsers/ExpressionElementParsers.h b/src/Parsers/ExpressionElementParsers.h index 248882225fb..63b5590e0ef 100644 --- a/src/Parsers/ExpressionElementParsers.h +++ b/src/Parsers/ExpressionElementParsers.h @@ -72,12 +72,13 @@ class ParserSubquery : public IParserDialectBase class ParserIdentifier : public IParserBase { public: - explicit ParserIdentifier(bool allow_query_parameter_ = false) : allow_query_parameter(allow_query_parameter_) {} + explicit ParserIdentifier(bool allow_query_parameter_ = false, bool allow_single_quoted_identifier_ = false) : allow_query_parameter(allow_query_parameter_), allow_single_quoted_identifier(allow_single_quoted_identifier_) {} protected: const char * getName() const override { return "identifier"; } bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; bool allow_query_parameter; + bool allow_single_quoted_identifier; }; @@ -480,12 +481,14 @@ class ParserLiteral : public IParserDialectBase class ParserAlias : public IParserBase { public: - explicit ParserAlias(bool allow_alias_without_as_keyword_) : 
allow_alias_without_as_keyword(allow_alias_without_as_keyword_) { } + explicit ParserAlias(bool allow_alias_without_as_keyword_, bool allow_single_quoted_identifier_ = false) : allow_alias_without_as_keyword(allow_alias_without_as_keyword_), allow_single_quoted_identifier(allow_single_quoted_identifier_) { } private: static const char * restricted_keywords[]; bool allow_alias_without_as_keyword; + /// default false; set to true for mysql, which allows: select 123 as 'offset' + bool allow_single_quoted_identifier; const char * getName() const override { return "alias"; } bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; @@ -631,6 +634,15 @@ class ParserAssignment : public IParserDialectBase using IParserDialectBase::IParserDialectBase; }; +class ParserAssignmentWithAlias : public IParserDialectBase +{ +protected: + const char * getName() const override{ return "column assignment with alias"; } + bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; +public: + using IParserDialectBase::IParserDialectBase; +}; + class ParserEscapeExpression : public IParserDialectBase { const char * getName() const override { return "ESCAPE clause"; } @@ -639,4 +651,15 @@ class ParserEscapeExpression : public IParserDialectBase using IParserDialectBase::IParserDialectBase; }; +/** The Execute Value is one of: an expression in parentheses, an array of literals, a literal, a function. 
+ */ +class ParserExecuteValue : public IParserDialectBase +{ +protected: + const char * getName() const override { return "element of execute value"; } + bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; +public: + using IParserDialectBase::IParserDialectBase; +}; + } diff --git a/src/Parsers/IAST.h b/src/Parsers/IAST.h index a6545971562..6336fcbd8c2 100644 --- a/src/Parsers/IAST.h +++ b/src/Parsers/IAST.h @@ -162,7 +162,8 @@ class ReadBuffer; M(ASTDropPreparedStatementQuery) \ M(ASTBitEngineConstraintDeclaration) \ M(ASTStorageAnalyticalMySQL) \ - M(ASTCreateQueryAnalyticalMySQL) + M(ASTCreateQueryAnalyticalMySQL) \ + M(ASTClusterByElement) #define ENUM_TYPE(ITEM) ITEM, enum class ASTType : UInt8 diff --git a/src/Parsers/ParserDictionary.cpp b/src/Parsers/ParserDictionary.cpp index 2095876c964..25b219757f1 100644 --- a/src/Parsers/ParserDictionary.cpp +++ b/src/Parsers/ParserDictionary.cpp @@ -208,11 +208,30 @@ bool ParserDictionary::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) ASTPtr ast_range; ASTPtr ast_settings; + String db_name; + String tb_name; + String refresh_query; + String invalidate_query; + /// Primary is required to be the first in dictionary definition if (primary_key_keyword.ignore(pos) && !expression_list_p.parse(pos, primary_key, expected)) return false; /// Loop is used to avoid strict order of dictionary properties + auto get_value = [](ASTPair* kv_pair) + { + auto tb = kv_pair->second->as(); + auto tb_identifier = kv_pair->second->as(); + if (!tb) + { + if (tb_identifier) + return tb_identifier->name(); + } + else + return tb->value.get(); + return String(); + }; + while (true) { if (!ast_source && source_keyword.ignore(pos, expected)) @@ -232,7 +251,7 @@ bool ParserDictionary::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) for (auto &kv : ele->children) { auto kv_pair = kv->as(); - if (kv_pair->first == "user") + if (kv_pair->first == "user" || kv_pair->first == "USER") { String user_name; auto 
user = kv_pair->second->as(); @@ -256,6 +275,24 @@ bool ParserDictionary::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) else if (user_identifier) user_identifier->setShortName(user_name); } + else if (kv_pair->first == "db" || kv_pair->first == "DB") + { + auto db = get_value(kv_pair); + if (!db.empty()) + db_name = formatTenantEntityName(db); + } + else if (kv_pair->first == "table" || kv_pair->first == "TABLE") + { + tb_name = get_value(kv_pair); + } + else if (kv_pair->first == "query" || kv_pair->first == "QUERY") + { + refresh_query = get_value(kv_pair); + } + else if (kv_pair->first == "invalidate_query" || kv_pair->first == "INVALIDATE_QUERY") + { + invalidate_query = get_value(kv_pair); + } } } @@ -326,6 +363,10 @@ bool ParserDictionary::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) auto query = std::make_shared(); node = query; + query->clickhouse_db = db_name; + query->clickhouse_tb = tb_name; + query->clickhouse_query = refresh_query; + query->clickhouse_invalidate_query = invalidate_query; if (primary_key) query->set(query->primary_key, primary_key); diff --git a/src/Parsers/ParserDumpQuery.cpp b/src/Parsers/ParserDumpQuery.cpp index 90f39181a44..b1265bb75fc 100644 --- a/src/Parsers/ParserDumpQuery.cpp +++ b/src/Parsers/ParserDumpQuery.cpp @@ -3,6 +3,7 @@ #include #include #include +#include namespace DB { @@ -48,6 +49,7 @@ bool ParserDumpQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) ASTPtr where_expression; ASTPtr dump_path; ASTPtr settings; + bool ignore_format = false; /// DUMP if (!s_dump.ignore(pos, expected)) @@ -58,6 +60,14 @@ bool ParserDumpQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) if (!parser_settings.parse(pos, settings, expected)) pos = begin; + if (settings) + { + auto & settings_ast = settings->as(); + auto * ignore_format_setting = settings_ast.changes.tryGet("ignore_format"); + if (ignore_format_setting && ignore_format_setting->toString() == "1") + ignore_format = true; + } + 
/// DDL if (s_ddl.ignore(pos, expected)) { @@ -149,6 +159,7 @@ bool ParserDumpQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) dump_query->setExpression(ASTDumpQuery::Expression::CLUSTER, std::move(cluster_name)); dump_query->setExpression(ASTDumpQuery::Expression::DUMP_PATH, std::move(dump_path)); dump_query->setExpression(ASTDumpQuery::Expression::SETTING, std::move(settings)); + dump_query->ignore_format = ignore_format; if (output_client) { diff --git a/src/Parsers/ParserExplainQuery.cpp b/src/Parsers/ParserExplainQuery.cpp index ee5984082ab..033d8a52893 100644 --- a/src/Parsers/ParserExplainQuery.cpp +++ b/src/Parsers/ParserExplainQuery.cpp @@ -110,6 +110,10 @@ bool ParserExplainQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected explain_query->format = std::make_shared("JSON"); setIdentifierSpecial(explain_query->format); } + + auto * ignore_format = settings_ast.changes.tryGet("ignore_format"); + if (ignore_format && ignore_format->toString() == "1") + explain_query->ignore_format = true; explain_query->setSettings(std::move(settings)); } else diff --git a/src/Parsers/ParserGrantQuery.cpp b/src/Parsers/ParserGrantQuery.cpp index e6fb940ca04..51eb4bcca15 100644 --- a/src/Parsers/ParserGrantQuery.cpp +++ b/src/Parsers/ParserGrantQuery.cpp @@ -243,11 +243,17 @@ bool ParserGrantQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) // String cluster; // parseOnCluster(pos, expected, cluster); + bool is_sensitive = false; + if (ParserKeyword{"SENSITIVE"}.ignore(pos, expected)) + is_sensitive = true; + bool if_exists = false; bool grant_option = false; bool admin_option = false; if (is_revoke) { + if (ParserKeyword{"IF EXISTS"}.ignore(pos, expected)) + if_exists = true; if (ParserKeyword{"GRANT OPTION FOR"}.ignore(pos, expected)) grant_option = true; else if (ParserKeyword{"ADMIN OPTION FOR"}.ignore(pos, expected)) @@ -297,6 +303,8 @@ bool ParserGrantQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) auto query = 
std::make_shared(); node = query; + query->is_sensitive = is_sensitive; + query->if_exists = if_exists; query->is_revoke = is_revoke; query->attach_mode = attach_mode; // query->cluster = std::move(cluster); diff --git a/src/Parsers/ParserPreparedParameter.cpp b/src/Parsers/ParserPreparedParameter.cpp index 4746007b55c..bfac5dc8636 100644 --- a/src/Parsers/ParserPreparedParameter.cpp +++ b/src/Parsers/ParserPreparedParameter.cpp @@ -7,6 +7,7 @@ #include #include #include +#include "Parsers/queryToString.h" namespace DB { @@ -27,14 +28,14 @@ bool ParserPreparedParameter::parseImpl(Pos & pos, ASTPtr & node, Expected & exp if (!ParserToken(TokenType::Colon).ignore(pos, expected)) return false; - if (!name_p.parse(pos, type_node, expected)) - return false; + ParserDataType type_parser(dt); + type_parser.parse(pos, type_node, expected); if (!ParserToken(TokenType::ClosingSquareBracket).ignore(pos, expected)) return false; tryGetIdentifierNameInto(identifier, prepared_parameter->name); - tryGetIdentifierNameInto(type_node, prepared_parameter->type); + prepared_parameter->type = queryToString(type_node); node = std::move(prepared_parameter); return true; } diff --git a/src/Parsers/ParserPreparedStatement.cpp b/src/Parsers/ParserPreparedStatement.cpp index 7919a6f8366..7daf0370eab 100644 --- a/src/Parsers/ParserPreparedStatement.cpp +++ b/src/Parsers/ParserPreparedStatement.cpp @@ -7,9 +7,11 @@ #include #include #include +#include +#include namespace DB -{ +{ bool ParserCreatePreparedStatementQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { ParserKeyword s_create("CREATE"); @@ -39,7 +41,14 @@ bool ParserCreatePreparedStatementQuery::parseImpl(Pos & pos, ASTPtr & node, Exp ParserCompoundIdentifier name_p; ASTPtr identifier; - if (!name_p.parse(pos, identifier, expected)) + if (name_p.parse(pos, identifier, expected)) + { + auto * name_node = identifier->as(); + if (name_node->nameParts().size() > 2 + || (name_node->nameParts().size() == 2 && 
(!getCurrentTenantId().empty() || getCurrentTenantId() == name_node->nameParts()[0]))) + return false; + } + else return false; String cluster_str; @@ -49,7 +58,6 @@ bool ParserCreatePreparedStatementQuery::parseImpl(Pos & pos, ASTPtr & node, Exp return false; } - if (!s_as.parse(pos, identifier, expected)) return false; @@ -66,7 +74,10 @@ bool ParserCreatePreparedStatementQuery::parseImpl(Pos & pos, ASTPtr & node, Exp prepare->if_not_exists = if_not_exists; prepare->or_replace = or_replace; prepare->query = query; + prepare->name_ast = identifier; + prepare->children.push_back(prepare->name_ast); prepare->children.push_back(prepare->query); + prepare->rewriteNamesWithTenant(pos.getContext()); node = prepare; return true; @@ -76,6 +87,7 @@ bool ParserExecutePreparedStatementQuery::parseImpl(Pos & pos, ASTPtr & node, Ex { ParserKeyword s_execute("EXECUTE PREPARED STATEMENT"); ParserKeyword s_using("USING"); + ParserToken s_comma(TokenType::Comma); if (!s_execute.ignore(pos, expected)) return false; @@ -83,16 +95,50 @@ bool ParserExecutePreparedStatementQuery::parseImpl(Pos & pos, ASTPtr & node, Ex ParserCompoundIdentifier name_p; ASTPtr identifier; - if (!name_p.parse(pos, identifier, expected)) + if (name_p.parse(pos, identifier, expected)) + { + auto * name_node = identifier->as(); + if (name_node->nameParts().size() > 2 + || (name_node->nameParts().size() == 2 && (!getCurrentTenantId().empty() || getCurrentTenantId() == name_node->nameParts()[0]))) + return false; + } + else return false; ASTPtr settings; if (s_using.ignore(pos, expected)) { - ParserSetQuery parser_settings(true); - if (!parser_settings.parse(pos, settings, expected)) - return false; + SettingsChanges changes; + ParserExecuteValue value_p(ParserSettings::CLICKHOUSE); + ParserToken s_eq(TokenType::Equals); + while (true) + { + if (!changes.empty() && !s_comma.ignore(pos)) + break; + + changes.push_back(SettingChange{}); + ASTPtr name; + ASTPtr value; + + if (!name_p.parse(pos, name, expected)) + 
return false; + + if (!s_eq.ignore(pos, expected)) + return false; + + if (!value_p.parse(pos, value, expected)) + return false; + + if (!value->as()) + return false; + + tryGetIdentifierNameInto(name, changes.back().name); + changes.back().value = value->as().value; + } + auto set_ast = std::make_shared(); + settings = set_ast; + set_ast->changes = std::move(changes); } else { @@ -104,6 +150,7 @@ bool ParserExecutePreparedStatementQuery::parseImpl(Pos & pos, ASTPtr & node, Ex tryGetIdentifierNameInto(identifier, execute->name); execute->values = settings; + execute->rewriteNamesWithTenant(); node = execute; return true; } @@ -127,9 +174,16 @@ bool ParserShowPreparedStatementQuery::parseImpl(Pos & pos, ASTPtr & node, Expec create = true; if (!s_prepared_statement.ignore(pos, expected)) return false; - if (!name_p.parse(pos, identifier, expected)) + if (name_p.parse(pos, identifier, expected)) + { + auto * name_node = identifier->as(); + if (name_node->nameParts().size() > 2 + || (name_node->nameParts().size() == 2 && (!getCurrentTenantId().empty() || getCurrentTenantId() == name_node->nameParts()[0]))) return false; } + else + return false; + } else if (s_prepared_statements.ignore(pos, expected)) { } @@ -141,7 +195,14 @@ bool ParserShowPreparedStatementQuery::parseImpl(Pos & pos, ASTPtr & node, Expec explain = true; if (!s_prepared_statement.ignore(pos, expected)) return false; - if (!name_p.parse(pos, identifier, expected)) + if (name_p.parse(pos, identifier, expected)) + { + auto * name_node = identifier->as(); + if (name_node->nameParts().size() > 2 + || (name_node->nameParts().size() == 2 && (!getCurrentTenantId().empty() || getCurrentTenantId() == name_node->nameParts()[0]))) + return false; + } + else return false; } else @@ -152,6 +213,7 @@ bool ParserShowPreparedStatementQuery::parseImpl(Pos & pos, ASTPtr & node, Expec tryGetIdentifierNameInto(identifier, show_prepare->name); show_prepare->show_create = create; show_prepare->show_explain = explain; + 
show_prepare->rewriteNamesWithTenant(); node = show_prepare; return true; } @@ -172,7 +234,14 @@ bool ParserDropPreparedStatementQuery::parseImpl(Pos & pos, ASTPtr & node, Expec ParserCompoundIdentifier name_p; ASTPtr identifier; - if (!name_p.parse(pos, identifier, expected)) + if (name_p.parse(pos, identifier, expected)) + { + auto * name_node = identifier->as(); + if (name_node->nameParts().size() > 2 + || (name_node->nameParts().size() == 2 && (!getCurrentTenantId().empty() || getCurrentTenantId() == name_node->nameParts()[0]))) + return false; + } + else return false; String cluster_str; @@ -186,6 +255,7 @@ bool ParserDropPreparedStatementQuery::parseImpl(Pos & pos, ASTPtr & node, Expec tryGetIdentifierNameInto(identifier, drop->name); drop->cluster = std::move(cluster_str); drop->if_exists = if_exists; + drop->rewriteNamesWithTenant(); node = drop; return true; } diff --git a/src/Parsers/ParserQueryWithOutput.cpp b/src/Parsers/ParserQueryWithOutput.cpp index afb0fd2d75a..055c4599ad3 100644 --- a/src/Parsers/ParserQueryWithOutput.cpp +++ b/src/Parsers/ParserQueryWithOutput.cpp @@ -180,12 +180,16 @@ bool ParserQueryWithOutput::parseImpl(Pos & pos, ASTPtr & node, Expected & expec if (s_format.ignore(pos, expected)) { ParserIdentifier format_p; + ASTPtr format; - if (!format_p.parse(pos, query_with_output.format, expected)) + if (!format_p.parse(pos, format, expected)) return false; - setIdentifierSpecial(query_with_output.format); - - query_with_output.children.push_back(query_with_output.format); + if (!query_with_output.ignore_format) + { + query_with_output.format = format; + setIdentifierSpecial(query_with_output.format); + query_with_output.children.push_back(query_with_output.format); + } } return true; diff --git a/src/Parsers/ParserUpdateQuery.cpp b/src/Parsers/ParserUpdateQuery.cpp index 7b8960fd870..012d92f3fea 100644 --- a/src/Parsers/ParserUpdateQuery.cpp +++ b/src/Parsers/ParserUpdateQuery.cpp @@ -26,7 +26,7 @@ bool ParserUpdateQuery::parseImpl(Pos 
& pos, ASTPtr & node, Expected & expected) ParserKeyword s_limit("LIMIT"); ParserKeyword s_settings("SETTINGS"); - ParserList parser_assignment_list(std::make_unique(dt), std::make_unique(TokenType::Comma), false); + ParserList parser_assignment_list(std::make_unique(dt), std::make_unique(TokenType::Comma), false); ParserExpression exp_list(dt); ParserOrderByExpressionList order_list(dt); ParserExpressionWithOptionalAlias exp_elem(false, dt); diff --git a/src/Parsers/formatTenantDatabaseName.cpp b/src/Parsers/formatTenantDatabaseName.cpp index 5fc06f9e0fe..d6d9cb8afca 100644 --- a/src/Parsers/formatTenantDatabaseName.cpp +++ b/src/Parsers/formatTenantDatabaseName.cpp @@ -61,12 +61,27 @@ static bool isInternalDatabaseName(const String & database_name) return false; } +//Format pattern {tenant_id}.{name} +String formatTenantName(const String & name, char separator) +{ + auto tenant_id = getCurrentTenantId(); + if (!tenant_id.empty() && + (!name.starts_with(tenant_id) || name.size() == tenant_id.size() || name[tenant_id.size()] != separator)) + { + String result = tenant_id; + result += separator; + result += name; + return result; + } + return name; +} + //Format pattern {tenant_id}.{database_name} static String formatTenantDatabaseNameImpl(const String & database_name, char separator = '.') { auto tenant_id = getCurrentTenantId(); if (!tenant_id.empty() && !isInternalDatabaseName(database_name) && - (database_name.find(tenant_id) != 0 || database_name.size() == tenant_id.size() || database_name[tenant_id.size()] != separator)) + (!database_name.starts_with(tenant_id) || database_name.size() == tenant_id.size() || database_name[tenant_id.size()] != separator)) { String result = tenant_id; result += separator; @@ -81,7 +96,7 @@ static String formatTenantUserNameImpl(const String & user_name, char separator { auto tenant_id = getCurrentTenantId(); if (!tenant_id.empty() && - (user_name.find(tenant_id) != 0 || user_name.size() == tenant_id.size() || 
user_name[tenant_id.size()] != separator)) + (!user_name.starts_with(tenant_id) || user_name.size() == tenant_id.size() || user_name[tenant_id.size()] != separator)) { String result = tenant_id; result += separator; @@ -106,14 +121,16 @@ String formatTenantDatabaseName(const String & database_name) } } -String appendTenantIdOnly(const String & name) +String appendTenantIdOnly(const String& name, bool is_datbase_name) { + if (!is_datbase_name) + return formatTenantName(name); return formatTenantDatabaseNameImpl(name); } String formatTenantDatabaseNameWithTenantId(const String & database_name, const String & tenant_id, char separator) { - if (!tenant_id.empty() && !isInternalDatabaseName(database_name) && database_name.find(tenant_id) != 0) + if (!tenant_id.empty() && !isInternalDatabaseName(database_name) && !database_name.starts_with(tenant_id)) { String result = tenant_id; result += separator; diff --git a/src/Parsers/formatTenantDatabaseName.h b/src/Parsers/formatTenantDatabaseName.h index 3cbe9a115d9..21c0bccedfb 100644 --- a/src/Parsers/formatTenantDatabaseName.h +++ b/src/Parsers/formatTenantDatabaseName.h @@ -14,7 +14,7 @@ String formatTenantDatabaseName(const String & database_name); // name -> tenant_id.name // no catalog information will be attached. 
-String appendTenantIdOnly(const String & name); +String appendTenantIdOnly(const String & name, bool is_datbase_name = true); String formatTenantConnectDefaultDatabaseName(const String & database_name); @@ -32,6 +32,8 @@ String getOriginalDatabaseName(const String & tenant_database_name); String getOriginalDatabaseName(const String & tenant_database_name, const String & tenant_id); +String formatTenantName(const String & name, char separator = '.'); + void pushTenantId(const String &tenant_id); void popTenantId(); diff --git a/src/Parsers/parseDatabaseAndTableName.cpp b/src/Parsers/parseDatabaseAndTableName.cpp index 4dc68089f81..bab970b2527 100644 --- a/src/Parsers/parseDatabaseAndTableName.cpp +++ b/src/Parsers/parseDatabaseAndTableName.cpp @@ -7,7 +7,7 @@ namespace DB { -bool parseDatabaseAndTableName(IParser::Pos & pos, Expected & expected, String & database_str, String & table_str) +bool parseDatabaseAndTableName(IParser::Pos & pos, Expected & expected, String & database_str, String & table_str, bool rewrite_db) { ParserToken s_dot(TokenType::Dot); ParserIdentifier table_parser; @@ -28,7 +28,9 @@ bool parseDatabaseAndTableName(IParser::Pos & pos, Expected & expected, String & database_str = ""; return false; } - tryRewriteCnchDatabaseName(database, pos.getContext()); + if (rewrite_db) + tryRewriteCnchDatabaseName(database, pos.getContext()); + tryGetIdentifierNameInto(database, database_str); tryGetIdentifierNameInto(table, table_str); } diff --git a/src/Parsers/parseDatabaseAndTableName.h b/src/Parsers/parseDatabaseAndTableName.h index dc435ca047e..24aca3aa8bc 100644 --- a/src/Parsers/parseDatabaseAndTableName.h +++ b/src/Parsers/parseDatabaseAndTableName.h @@ -5,7 +5,7 @@ namespace DB { /// Parses [db.]name -bool parseDatabaseAndTableName(IParser::Pos & pos, Expected & expected, String & database_str, String & table_str); +bool parseDatabaseAndTableName(IParser::Pos & pos, Expected & expected, String & database_str, String & table_str, bool rewrite_db = 
true); /// Parses [db.]name or [db.]* or [*.]* bool parseDatabaseAndTableNameOrAsterisks(IParser::Pos & pos, Expected & expected, String & database, bool & any_database, String & table, bool & any_table); diff --git a/src/Processors/Exchange/DataTrans/Brpc/BrpcRemoteBroadcastReceiver.cpp b/src/Processors/Exchange/DataTrans/Brpc/BrpcRemoteBroadcastReceiver.cpp index ad0d89b7090..0b0cbb56f79 100644 --- a/src/Processors/Exchange/DataTrans/Brpc/BrpcRemoteBroadcastReceiver.cpp +++ b/src/Processors/Exchange/DataTrans/Brpc/BrpcRemoteBroadcastReceiver.cpp @@ -184,20 +184,22 @@ void BrpcRemoteBroadcastReceiver::pushReceiveQueue(MultiPathDataPacket packet) return; } throw Exception( - "Push exchange data to receiver for " + getName() + " timeout from " + DateLUT::instance().timeToString(context->getClientInfo().initial_query_start_time) + - " to " + DateLUT::instance().timeToString(context->getQueryExpirationTimeStamp().tv_sec), + "Push exchange data to receiver for " + getName() + " timeout from " + + DateLUT::serverTimezoneInstance().timeToString(context->getClientInfo().initial_query_start_time) + " to " + + DateLUT::serverTimezoneInstance().timeToString(context->getQueryExpirationTimeStamp().tv_sec), ErrorCodes::DISTRIBUTE_STAGE_QUERY_EXCEPTION); } } -RecvDataPacket BrpcRemoteBroadcastReceiver::recv(timespec timeout_ts) noexcept +RecvDataPacket BrpcRemoteBroadcastReceiver::recv(timespec timeout_ts) { Stopwatch s; MultiPathDataPacket data_packet; if (!queue->tryPopUntil(data_packet, timeout_ts)) { - const auto error_msg = "Try pop receive queue for " + getName() + " timeout, from " + - DateLUT::instance().timeToString(context->getClientInfo().initial_query_start_time) + " to " + DateLUT::instance().timeToString(timeout_ts.tv_sec); + const auto error_msg = "Try pop receive queue for " + getName() + " timeout, from " + + DateLUT::serverTimezoneInstance().timeToString(context->getClientInfo().initial_query_start_time) + " to " + + 
DateLUT::serverTimezoneInstance().timeToString(timeout_ts.tv_sec); BroadcastStatus current_status = finish(BroadcastStatusCode::RECV_TIMEOUT, error_msg); return std::move(current_status); } diff --git a/src/Processors/Exchange/DataTrans/Brpc/BrpcRemoteBroadcastReceiver.h b/src/Processors/Exchange/DataTrans/Brpc/BrpcRemoteBroadcastReceiver.h index ac99f136b77..165fb3af97a 100644 --- a/src/Processors/Exchange/DataTrans/Brpc/BrpcRemoteBroadcastReceiver.h +++ b/src/Processors/Exchange/DataTrans/Brpc/BrpcRemoteBroadcastReceiver.h @@ -62,7 +62,7 @@ class BrpcRemoteBroadcastReceiver : public std::enable_shared_from_thisclosed(); String error_msg = "Try pop receive collector for " + name; - error_msg.append(collector_closed ? " interrupted" : " timeout at " + DateLUT::instance().timeToString(timeout_ts.tv_sec)); + error_msg.append( + collector_closed ? " interrupted" : " timeout at " + DateLUT::serverTimezoneInstance().timeToString(timeout_ts.tv_sec)); BroadcastStatus current_status = finish(collector_closed ? 
BroadcastStatusCode::RECV_UNKNOWN_ERROR : BroadcastStatusCode::RECV_TIMEOUT, error_msg); diff --git a/src/Processors/Exchange/DataTrans/RpcClient.cpp b/src/Processors/Exchange/DataTrans/RpcClient.cpp index 9eddf1674c9..c0445077248 100644 --- a/src/Processors/Exchange/DataTrans/RpcClient.cpp +++ b/src/Processors/Exchange/DataTrans/RpcClient.cpp @@ -55,9 +55,9 @@ void RpcClient::assertController(const brpc::Controller & cntl, int error_code) if (cntl.Failed()) { auto err = cntl.ErrorCode(); - if (err == ECONNREFUSED || err == ECONNRESET || err == ENOTCONN) + if (err == ECONNREFUSED || err == ECONNRESET) setOk(false); - else if (err == EHOSTDOWN || err == ENETUNREACH) + else if (err == EHOSTDOWN || err == ENETUNREACH || err == ENOTCONN) reportError(); throw Exception( fmt::format("Fail to call {}, error code: {}, msg: {}", cntl.method()->full_name(), err, cntl.ErrorText()), error_code); diff --git a/src/Processors/Exchange/RepartitionTransform.cpp b/src/Processors/Exchange/RepartitionTransform.cpp index 86c5cf5a539..8ed5c77e06f 100644 --- a/src/Processors/Exchange/RepartitionTransform.cpp +++ b/src/Processors/Exchange/RepartitionTransform.cpp @@ -118,6 +118,13 @@ ExecutableFunctionPtr RepartitionTransform::getDefaultRepartitionFunction(const return function_base->prepare(arguments); } +ExecutableFunctionPtr RepartitionTransform::getRepartitionHashFunction(const String & func_name, const ColumnsWithTypeAndName & arguments, ContextPtr context, const Array & params) +{ + FunctionOverloadResolverPtr func_builder = FunctionFactory::instance().get(func_name, context); + FunctionBasePtr function_base = func_builder->build(arguments); + return params.empty() ? 
function_base->prepare(arguments) : function_base->prepareWithParameters(arguments, params); +} + const DataTypePtr RepartitionTransform::REPARTITION_FUNC_RESULT_TYPE = std::make_shared(); const DataTypePtr RepartitionTransform::REPARTITION_FUNC_NULLABLE_RESULT_TYPE = std::make_shared(RepartitionTransform::REPARTITION_FUNC_RESULT_TYPE); } diff --git a/src/Processors/Exchange/RepartitionTransform.h b/src/Processors/Exchange/RepartitionTransform.h index 02e2c46197b..6072948590f 100644 --- a/src/Processors/Exchange/RepartitionTransform.h +++ b/src/Processors/Exchange/RepartitionTransform.h @@ -25,6 +25,7 @@ #include #include #include +#include #include namespace DB @@ -65,6 +66,8 @@ class RepartitionTransform : public ISimpleTransform static ExecutableFunctionPtr getDefaultRepartitionFunction(const ColumnsWithTypeAndName & arguments, ContextPtr context); + static ExecutableFunctionPtr getRepartitionHashFunction(const String & func_name, const ColumnsWithTypeAndName & arguments, ContextPtr context, const Array & params = {}); + protected: void transform(Chunk & chunk) override; diff --git a/src/Processors/Formats/Impl/ArrowBlockInputFormat.cpp b/src/Processors/Formats/Impl/ArrowBlockInputFormat.cpp index 28b136c6526..f9e2cb9717f 100644 --- a/src/Processors/Formats/Impl/ArrowBlockInputFormat.cpp +++ b/src/Processors/Formats/Impl/ArrowBlockInputFormat.cpp @@ -108,6 +108,7 @@ void ArrowBlockInputFormat::prepareReader() format_settings.arrow.import_nested, format_settings.arrow.allow_missing_columns, format_settings.null_as_default, + format_settings.date_time_overflow_behavior, format_settings.arrow.case_insensitive_column_matching); if (stream) diff --git a/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp b/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp index 3d652fc06ff..9827cf9ae8a 100644 --- a/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp +++ b/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp @@ -102,7 +102,7 @@ namespace ErrorCodes /// 
Inserts numeric data right into internal column data to reduce an overhead template > -static ColumnWithTypeAndName readColumnWithNumericData(std::shared_ptr & arrow_column, const String & column_name) +static ColumnWithTypeAndName readColumnWithNumericData(const std::shared_ptr & arrow_column, const String & column_name) { auto internal_type = std::make_shared>(); auto internal_column = internal_type->createColumn(); @@ -127,7 +127,7 @@ static ColumnWithTypeAndName readColumnWithNumericData(std::shared_ptr -static ColumnWithTypeAndName readColumnWithStringData(std::shared_ptr & arrow_column, const String & column_name) +static ColumnWithTypeAndName readColumnWithStringData(const std::shared_ptr & arrow_column, const String & column_name) { auto internal_type = std::make_shared(); auto internal_column = internal_type->createColumn(); @@ -171,7 +171,7 @@ static ColumnWithTypeAndName readColumnWithStringData(std::shared_ptr & arrow_column, const String & column_name) +static ColumnWithTypeAndName readColumnWithFixedStringData(const std::shared_ptr & arrow_column, const String & column_name) { const auto * fixed_type = assert_cast(arrow_column->type().get()); size_t fixed_len = fixed_type->byte_width(); @@ -190,7 +190,7 @@ static ColumnWithTypeAndName readColumnWithFixedStringData(std::shared_ptr -static ColumnWithTypeAndName readColumnWithBigIntegerFromFixedBinaryData(std::shared_ptr & arrow_column, const String & column_name, const DataTypePtr & column_type) +static ColumnWithTypeAndName readColumnWithBigIntegerFromFixedBinaryData(const std::shared_ptr & arrow_column, const String & column_name, const DataTypePtr & column_type) { const auto * fixed_type = assert_cast(arrow_column->type().get()); size_t fixed_len = fixed_type->byte_width(); @@ -218,7 +218,7 @@ static ColumnWithTypeAndName readColumnWithBigIntegerFromFixedBinaryData(std::sh } template -static ColumnWithTypeAndName readColumnWithBigNumberFromBinaryData(std::shared_ptr & arrow_column, const String & 
column_name, const DataTypePtr & column_type) +static ColumnWithTypeAndName readColumnWithBigNumberFromBinaryData(const std::shared_ptr & arrow_column, const String & column_name, const DataTypePtr & column_type) { size_t total_size = 0; for (int chunk_i = 0, num_chunks = arrow_column->num_chunks(); chunk_i < num_chunks; ++chunk_i) @@ -259,7 +259,7 @@ static ColumnWithTypeAndName readColumnWithBigNumberFromBinaryData(std::shared_p return {std::move(internal_column), column_type, column_name}; } -static ColumnWithTypeAndName readColumnWithBooleanData(std::shared_ptr & arrow_column, const String & column_name) +static ColumnWithTypeAndName readColumnWithBooleanData(const std::shared_ptr & arrow_column, const String & column_name) { auto internal_type = DataTypeFactory::instance().get("Bool"); auto internal_column = internal_type->createColumn(); @@ -278,7 +278,8 @@ static ColumnWithTypeAndName readColumnWithBooleanData(std::shared_ptr & arrow_column, const String & column_name, const DataTypePtr & type_hint) +static ColumnWithTypeAndName readColumnWithDate32Data(const std::shared_ptr & arrow_column, const String & column_name, + const DataTypePtr & type_hint, FormatSettings::DateTimeOverflowBehavior date_time_overflow_behavior) { DataTypePtr internal_type; bool check_date_range = false; @@ -310,11 +311,21 @@ static ColumnWithTypeAndName readColumnWithDate32Data(std::shared_ptr(chunk.Value(value_i)); if (days_num > DATE_LUT_MAX_EXTEND_DAY_NUM || days_num < -DAYNUM_OFFSET_EPOCH) { - throw Exception{ErrorCodes::VALUE_IS_OUT_OF_RANGE_OF_DATA_TYPE, - "Input value {} of a column \"{}\" is out of allowed Date32 range, which is [{}, {}]", days_num, column_name, DAYNUM_OFFSET_EPOCH, DATE_LUT_MAX_EXTEND_DAY_NUM}; + switch (date_time_overflow_behavior) + { + case FormatSettings::DateTimeOverflowBehavior::Saturate: + days_num = (days_num < -DAYNUM_OFFSET_EPOCH) ? 
-DAYNUM_OFFSET_EPOCH : DATE_LUT_MAX_EXTEND_DAY_NUM; + break; + default: + /// Prior to introducing `date_time_overflow_behavior`, this function threw an error in case value was out of range. + /// In order to leave this behavior as default, we also throw when `date_time_overflow_mode == ignore`, as it is the setting's default value + /// (As we want to make this backwards compatible, not break any workflows.) + throw Exception{ErrorCodes::VALUE_IS_OUT_OF_RANGE_OF_DATA_TYPE, + "Input value {} of a column \"{}\" is out of allowed Date32 range, which is [{}, {}]", + days_num,column_name, -DAYNUM_OFFSET_EPOCH, DATE_LUT_MAX_EXTEND_DAY_NUM}; + } } - else - column_data.emplace_back(days_num); + column_data.emplace_back(days_num); } } else @@ -328,7 +339,7 @@ static ColumnWithTypeAndName readColumnWithDate32Data(std::shared_ptr & arrow_column, const String & column_name) +static ColumnWithTypeAndName readColumnWithDate64Data(const std::shared_ptr & arrow_column, const String & column_name) { auto internal_type = std::make_shared(); auto internal_column = internal_type->createColumn(); @@ -347,7 +358,7 @@ static ColumnWithTypeAndName readColumnWithDate64Data(std::shared_ptr & arrow_column, const String & column_name) +static ColumnWithTypeAndName readColumnWithTimestampData(const std::shared_ptr & arrow_column, const String & column_name) { const auto & arrow_type = static_cast(*(arrow_column->type())); const UInt8 scale = arrow_type.unit() * 3; @@ -368,7 +379,7 @@ static ColumnWithTypeAndName readColumnWithTimestampData(std::shared_ptr -static ColumnWithTypeAndName readColumnWithTimeData(std::shared_ptr & arrow_column, const String & column_name) +static ColumnWithTypeAndName readColumnWithTimeData(const std::shared_ptr & arrow_column, const String & column_name) { const auto & arrow_type = static_cast(*(arrow_column->type())); const UInt8 scale = arrow_type.unit() * 3; @@ -391,18 +402,18 @@ static ColumnWithTypeAndName readColumnWithTimeData(std::shared_ptr & 
arrow_column, const String & column_name) +static ColumnWithTypeAndName readColumnWithTime32Data(const std::shared_ptr & arrow_column, const String & column_name) { return readColumnWithTimeData(arrow_column, column_name); } -static ColumnWithTypeAndName readColumnWithTime64Data(std::shared_ptr & arrow_column, const String & column_name) +static ColumnWithTypeAndName readColumnWithTime64Data(const std::shared_ptr & arrow_column, const String & column_name) { return readColumnWithTimeData(arrow_column, column_name); } template -static ColumnWithTypeAndName readColumnWithDecimalDataImpl(std::shared_ptr & arrow_column, const String & column_name, DataTypePtr internal_type) +static ColumnWithTypeAndName readColumnWithDecimalDataImpl(const std::shared_ptr & arrow_column, const String & column_name, DataTypePtr internal_type) { auto internal_column = internal_type->createColumn(); auto & column = assert_cast &>(*internal_column); @@ -421,7 +432,7 @@ static ColumnWithTypeAndName readColumnWithDecimalDataImpl(std::shared_ptr -static ColumnWithTypeAndName readColumnWithDecimalData(std::shared_ptr & arrow_column, const String & column_name) +static ColumnWithTypeAndName readColumnWithDecimalData(const std::shared_ptr & arrow_column, const String & column_name) { const auto * arrow_decimal_type = static_cast(arrow_column->type().get()); size_t precision = arrow_decimal_type->precision(); @@ -436,7 +447,7 @@ static ColumnWithTypeAndName readColumnWithDecimalData(std::shared_ptr & arrow_column) +static ColumnPtr readByteMapFromArrowColumn(const std::shared_ptr & arrow_column) { if (!arrow_column->null_count()) return ColumnUInt8::create(arrow_column->length(), 0); @@ -455,7 +466,7 @@ static ColumnPtr readByteMapFromArrowColumn(std::shared_ptr return nullmap_column; } -static ColumnPtr readOffsetsFromArrowListColumn(std::shared_ptr & arrow_column) +static ColumnPtr readOffsetsFromArrowListColumn(const std::shared_ptr & arrow_column) { auto offsets_column = 
ColumnUInt64::create(); ColumnArray::Offsets & offsets_data = assert_cast &>(*offsets_column).getData(); @@ -502,7 +513,8 @@ static ColumnPtr readOffsetsFromArrowListColumn(std::shared_ptr> -static ColumnWithTypeAndName readColumnWithIndexesDataImpl(std::shared_ptr & arrow_column, const String & column_name, Int64 default_value_index, NumericType dict_size, bool is_nullable) +static ColumnWithTypeAndName readColumnWithIndexesDataImpl(const std::shared_ptr & arrow_column, + const String & column_name, Int64 default_value_index, NumericType dict_size, bool is_nullable) { auto internal_type = std::make_shared>(); auto internal_column = internal_type->createColumn(); @@ -600,7 +612,7 @@ static ColumnWithTypeAndName readColumnWithIndexesDataImpl(std::shared_ptr & arrow_column, Int64 default_value_index, UInt64 dict_size, bool is_nullable) +static ColumnPtr readColumnWithIndexesData(const std::shared_ptr & arrow_column, Int64 default_value_index, UInt64 dict_size, bool is_nullable) { switch (arrow_column->type()->id()) { @@ -617,7 +629,7 @@ static ColumnPtr readColumnWithIndexesData(std::shared_ptr } } -static std::shared_ptr getNestedArrowColumn(std::shared_ptr & arrow_column) +static std::shared_ptr getNestedArrowColumn(const std::shared_ptr & arrow_column) { arrow::ArrayVector array_vector; array_vector.reserve(arrow_column->num_chunks()); @@ -702,34 +714,84 @@ static std::shared_ptr getNestedArrowColumn(std::shared_ptr // return {std::move(internal_column), std::move(internal_type), column_name}; // } +struct ReadColumnFromArrowColumnSettings +{ + std::string format_name; + FormatSettings::DateTimeOverflowBehavior date_time_overflow_behavior; + bool allow_arrow_null_type; + bool skip_columns_with_unsupported_types; +}; + +static ColumnWithTypeAndName readNonNullableColumnFromArrowColumn( + const std::shared_ptr & arrow_column, + std::string column_name, + std::unordered_map dictionary_infos, + DataTypePtr type_hint, + bool is_map_nested_column, + const 
ReadColumnFromArrowColumnSettings & settings); + static ColumnWithTypeAndName readColumnFromArrowColumn( - std::shared_ptr & arrow_column, - const std::string & column_name, - const std::string & format_name, - bool is_nullable, - std::unordered_map & dictionary_infos, - bool allow_null_type, - bool skip_columns_with_unsupported_types, - bool & skipped, - DataTypePtr type_hint = nullptr, - bool is_map_nested = false) + const std::shared_ptr & arrow_column, + std::string column_name, + std::unordered_map dictionary_infos, + DataTypePtr type_hint, + bool is_nullable_column, + bool is_map_nested_column, + const ReadColumnFromArrowColumnSettings & settings) { - if (!is_nullable && (arrow_column->null_count() || (type_hint && type_hint->isNullable())) && arrow_column->type()->id() != arrow::Type::LIST - && arrow_column->type()->id() != arrow::Type::MAP && arrow_column->type()->id() != arrow::Type::STRUCT && + /// read as Nullable (only in basic data type): + /// case 1: arrow column has null vaules, but clickhouse schema is not nullable + /// step 1: read column as Nullable(xxx) + /// step 2: clickhouse column is Bitmap column + /// castBitmapColumn: Array(Nullable(int)) / Array(Nullable(String)) -> Array(int) / Array(String) -> BitMap + /// castColumn: Nullable(xxx) -> xxx + /// case 2: arrow column has null values, clickhouse schema is Nullable(xxx) + /// step 1: read column as Nullable(xxx) + bool read_as_nullable_column = arrow_column->null_count() || is_nullable_column || (type_hint && type_hint->isNullable()); + if (read_as_nullable_column && + arrow_column->type()->id() != arrow::Type::LIST && + arrow_column->type()->id() != arrow::Type::LARGE_LIST && + arrow_column->type()->id() != arrow::Type::MAP && + arrow_column->type()->id() != arrow::Type::STRUCT && arrow_column->type()->id() != arrow::Type::DICTIONARY) { DataTypePtr nested_type_hint; if (type_hint) nested_type_hint = removeNullable(type_hint); - auto nested_column = readColumnFromArrowColumn(arrow_column, 
column_name, format_name, true, dictionary_infos, allow_null_type, skip_columns_with_unsupported_types, skipped, nested_type_hint); - if (skipped) + + auto nested_column = readNonNullableColumnFromArrowColumn(arrow_column, + column_name, + dictionary_infos, + nested_type_hint, + is_map_nested_column, + settings); + + if (!nested_column.column) return {}; + auto nullmap_column = readByteMapFromArrowColumn(arrow_column); auto nullable_type = std::make_shared(std::move(nested_column.type)); auto nullable_column = ColumnNullable::create(nested_column.column, nullmap_column); + return {std::move(nullable_column), std::move(nullable_type), column_name}; } + return readNonNullableColumnFromArrowColumn(arrow_column, + column_name, + dictionary_infos, + type_hint, + is_map_nested_column, + settings); +} + +static ColumnWithTypeAndName readNonNullableColumnFromArrowColumn( + const std::shared_ptr & arrow_column, + std::string column_name, + std::unordered_map dictionary_infos, + DataTypePtr type_hint, + bool is_map_nested_column, + const ReadColumnFromArrowColumnSettings & settings) +{ switch (arrow_column->type()->id()) { case arrow::Type::STRING: @@ -784,7 +846,7 @@ static ColumnWithTypeAndName readColumnFromArrowColumn( case arrow::Type::BOOL: return readColumnWithBooleanData(arrow_column, column_name); case arrow::Type::DATE32: - return readColumnWithDate32Data(arrow_column, column_name, type_hint); + return readColumnWithDate32Data(arrow_column, column_name, type_hint, settings.date_time_overflow_behavior); case arrow::Type::DATE64: return readColumnWithDate64Data(arrow_column, column_name); // ClickHouse writes Date as arrow UINT16 and DateTime as arrow UINT32, @@ -832,8 +894,15 @@ static ColumnWithTypeAndName readColumnFromArrowColumn( } } auto arrow_nested_column = getNestedArrowColumn(arrow_column); - auto nested_column = readColumnFromArrowColumn(arrow_nested_column, column_name, format_name, false, dictionary_infos, allow_null_type, 
skip_columns_with_unsupported_types, skipped, nested_type_hint, true); - if (skipped) + auto nested_column = readColumnFromArrowColumn( + arrow_nested_column, + column_name, + dictionary_infos, + nested_type_hint, + /*is_nullable_column*/ false, + /*is_map_nested_column*/ true, + settings); + if (!nested_column.column) return {}; auto offsets_column = readOffsetsFromArrowListColumn(arrow_column); @@ -866,10 +935,20 @@ static ColumnWithTypeAndName readColumnFromArrowColumn( if (array_type_hint) nested_type_hint = array_type_hint->getNestedType(); } + auto * arrow_list_type = assert_cast(arrow_column->type().get()); + bool is_nested_nullable_column = arrow_list_type->value_field()->nullable(); auto arrow_nested_column = getNestedArrowColumn(arrow_column); - auto nested_column = readColumnFromArrowColumn(arrow_nested_column, column_name, format_name, false, dictionary_infos, allow_null_type, skip_columns_with_unsupported_types, skipped, nested_type_hint); - if (skipped) + auto nested_column = readColumnFromArrowColumn( + arrow_nested_column, + column_name, + dictionary_infos, + nested_type_hint, + is_nested_nullable_column, + false /*is_map_nested_column*/, + settings); + if (!nested_column.column) return {}; + auto offsets_column = readOffsetsFromArrowListColumn(arrow_column); auto array_column = ColumnArray::create(nested_column.column, offsets_column); auto array_type = std::make_shared(nested_column.type); @@ -894,11 +973,12 @@ static ColumnWithTypeAndName readColumnFromArrowColumn( for (int i = 0; i != arrow_struct_type->num_fields(); ++i) { - auto field_name = arrow_struct_type->field(i)->name(); + auto field = arrow_struct_type->field(i); + auto field_name = field->name(); DataTypePtr nested_type_hint; if (tuple_type_hint) { - if (tuple_type_hint->haveExplicitNames() && !is_map_nested) + if (tuple_type_hint->haveExplicitNames() && !is_map_nested_column) { auto pos = tuple_type_hint->tryGetPositionByName(field_name); if (pos) @@ -908,9 +988,17 @@ static 
ColumnWithTypeAndName readColumnFromArrowColumn( nested_type_hint = tuple_type_hint->getElement(i); } auto nested_arrow_column = std::make_shared(nested_arrow_columns[i]); - auto element = readColumnFromArrowColumn(nested_arrow_column, field_name, format_name, false, dictionary_infos, allow_null_type, skip_columns_with_unsupported_types, skipped, nested_type_hint); - if (skipped) + auto element = readColumnFromArrowColumn( + nested_arrow_column, + field_name, + dictionary_infos, + nested_type_hint, + field->nullable(), + false /*is_map_nested_column*/, + settings); + if (!element.column) return {}; + tuple_elements.emplace_back(std::move(element.column)); tuple_types.emplace_back(std::move(element.type)); tuple_names.emplace_back(std::move(element.name)); @@ -935,7 +1023,17 @@ static ColumnWithTypeAndName readColumnFromArrowColumn( dict_array.emplace_back(dict_chunk.dictionary()); } auto arrow_dict_column = std::make_shared(dict_array); - auto dict_column = readColumnFromArrowColumn(arrow_dict_column, column_name, format_name, false, dictionary_infos, allow_null_type, skip_columns_with_unsupported_types, skipped); + auto dict_column = readColumnFromArrowColumn(arrow_dict_column, + column_name, + dictionary_infos, + nullptr /*nested_type_hint*/, + false /*is_nullable_column*/, + false /*is_map_nested_column*/, + settings); + + if (!dict_column.column) + return {}; + for (size_t i = 0; i != dict_column.column->size(); ++i) { if (dict_column.column->isDefaultAt(i)) @@ -983,7 +1081,7 @@ static ColumnWithTypeAndName readColumnFromArrowColumn( // TODO: read UUID as a string? 
case arrow::Type::NA: { - if (allow_null_type) + if (settings.allow_arrow_null_type) { auto type = std::make_shared(); auto column = ColumnNothing::create(arrow_column->length()); @@ -993,11 +1091,8 @@ static ColumnWithTypeAndName readColumnFromArrowColumn( } default: { - if (skip_columns_with_unsupported_types) - { - skipped = true; + if (settings.skip_columns_with_unsupported_types) return {}; - } throw Exception( ErrorCodes::UNKNOWN_TYPE, @@ -1005,10 +1100,10 @@ static ColumnWithTypeAndName readColumnFromArrowColumn( "If it happens during schema inference and you want to skip columns with " "unsupported types, you can enable setting input_format_{}" "_skip_columns_with_unsupported_types_in_schema_inference", - format_name, + settings.format_name, arrow_column->type()->name(), column_name, - boost::algorithm::to_lower_copy(format_name)); + boost::algorithm::to_lower_copy(settings.format_name)); } } } @@ -1026,6 +1121,14 @@ static void checkStatus(const arrow::Status & status, const String & column_name Block ArrowColumnToCHColumn::arrowSchemaToCHHeader( const arrow::Schema & schema, const std::string & format_name, bool skip_columns_with_unsupported_types, const Block * hint_header, bool ignore_case) { + ReadColumnFromArrowColumnSettings settings + { + .format_name = format_name, + .date_time_overflow_behavior = FormatSettings::DateTimeOverflowBehavior::Ignore, + .allow_arrow_null_type = false, + .skip_columns_with_unsupported_types = skip_columns_with_unsupported_types + }; + ColumnsWithTypeAndName sample_columns; std::unordered_set nested_table_names; if (hint_header) @@ -1050,13 +1153,19 @@ Block ArrowColumnToCHColumn::arrowSchemaToCHHeader( arrow::ArrayVector array_vector = {arrow_array}; auto arrow_column = std::make_shared(array_vector); std::unordered_map dict_infos; - bool skipped = false; - bool allow_null_type = false; if (hint_header && hint_header->has(field->name()) && hint_header->getByName(field->name()).type->isNullable()) - allow_null_type = 
true; - ColumnWithTypeAndName sample_column = readColumnFromArrowColumn( - arrow_column, field->name(), format_name, false, dict_infos, allow_null_type, skip_columns_with_unsupported_types, skipped); - if (!skipped) + settings.allow_arrow_null_type = true; + + auto sample_column = readColumnFromArrowColumn( + arrow_column, + field->name(), + dict_infos, + nullptr /*nested_type_hint*/, + field->nullable() /*is_nullable_column*/, + false /*is_map_nested_column*/, + settings); + + if (sample_column.column) sample_columns.emplace_back(std::move(sample_column)); } return Block(std::move(sample_columns)); @@ -1068,40 +1177,51 @@ ArrowColumnToCHColumn::ArrowColumnToCHColumn( bool import_nested_, bool allow_missing_columns_, bool null_as_default_, + FormatSettings::DateTimeOverflowBehavior date_time_overflow_behavior_, bool case_insensitive_matching_) : header(header_) , format_name(format_name_) , import_nested(import_nested_) , allow_missing_columns(allow_missing_columns_) , null_as_default(null_as_default_) + , date_time_overflow_behavior(date_time_overflow_behavior_) , case_insensitive_matching(case_insensitive_matching_) { } void ArrowColumnToCHColumn::arrowTableToCHChunk(Chunk & res, std::shared_ptr & table, size_t num_rows, BlockMissingValues * block_missing_values) { - NameToColumnPtr name_to_column_ptr; + NameToArrowColumn name_to_arrow_column; for (auto column_name : table->ColumnNames()) { std::shared_ptr arrow_column = table->GetColumnByName(column_name); if (!arrow_column) throw Exception(ErrorCodes::DUPLICATE_COLUMN, "Column '{}' is duplicated", column_name); + auto arrow_field = table->schema()->GetFieldByName(column_name); + if (case_insensitive_matching) boost::to_lower(column_name); - name_to_column_ptr[std::move(column_name)] = arrow_column; + name_to_arrow_column[std::move(column_name)] = {std::move(arrow_column), std::move(arrow_field)}; } - arrowColumnsToCHChunk(res, name_to_column_ptr, num_rows, block_missing_values); + arrowColumnsToCHChunk(res, 
name_to_arrow_column, num_rows, block_missing_values); } -void ArrowColumnToCHColumn::arrowColumnsToCHChunk(Chunk & res, NameToColumnPtr & name_to_column_ptr, size_t num_rows, BlockMissingValues * block_missing_values) +void ArrowColumnToCHColumn::arrowColumnsToCHChunk(Chunk & res, const NameToArrowColumn & name_to_arrow_column, size_t num_rows, BlockMissingValues * block_missing_values) { + ReadColumnFromArrowColumnSettings settings + { + .format_name = format_name, + .date_time_overflow_behavior = date_time_overflow_behavior, + .allow_arrow_null_type = true, + .skip_columns_with_unsupported_types = false + }; + Columns columns_list; columns_list.reserve(header.columns()); std::unordered_map>> nested_tables; - bool skipped = false; for (size_t column_i = 0, columns = header.columns(); column_i < columns; ++column_i) { const ColumnWithTypeAndName & header_column = header.getByPosition(column_i); @@ -1111,7 +1231,7 @@ void ArrowColumnToCHColumn::arrowColumnsToCHChunk(Chunk & res, NameToColumnPtr & boost::to_lower(search_column_name); ColumnWithTypeAndName column; - if (!name_to_column_ptr.contains(search_column_name)) + if (!name_to_arrow_column.contains(search_column_name)) { bool read_from_nested = false; /// Check if it's a column from nested table. 
@@ -1121,7 +1241,7 @@ void ArrowColumnToCHColumn::arrowColumnsToCHChunk(Chunk & res, NameToColumnPtr & String search_nested_table_name = nested_table_name; if (case_insensitive_matching) boost::to_lower(search_nested_table_name); - if (name_to_column_ptr.contains(search_nested_table_name)) + if (name_to_arrow_column.contains(search_nested_table_name)) { if (!nested_tables.contains(search_nested_table_name)) { @@ -1133,9 +1253,17 @@ void ArrowColumnToCHColumn::arrowColumnsToCHChunk(Chunk & res, NameToColumnPtr & } auto nested_table_type = Nested::collect(nested_columns).front().type; - std::shared_ptr arrow_column = name_to_column_ptr[search_nested_table_name]; - ColumnsWithTypeAndName cols = {readColumnFromArrowColumn( - arrow_column, nested_table_name, format_name, false, dictionary_infos, true, false, skipped, nested_table_type)}; + const auto & arrow_column = name_to_arrow_column.find(search_nested_table_name)->second; + ColumnsWithTypeAndName cols = + { + readColumnFromArrowColumn(arrow_column.column, + nested_table_name, + dictionary_infos, + nested_table_type, + arrow_column.field->nullable() /*is_nullable_column*/, + false /*is_map_nested_column*/, + settings) + }; BlockPtr block_ptr = std::make_shared(cols); auto column_extractor = std::make_shared(*block_ptr, case_insensitive_matching); nested_tables[search_nested_table_name] = {block_ptr, column_extractor}; @@ -1170,9 +1298,15 @@ void ArrowColumnToCHColumn::arrowColumnsToCHChunk(Chunk & res, NameToColumnPtr & } else { - auto arrow_column = name_to_column_ptr[search_column_name]; + const auto & arrow_column = name_to_arrow_column.find(search_column_name)->second; column = readColumnFromArrowColumn( - arrow_column, header_column.name, format_name, false, dictionary_infos, true, false, skipped, header_column.type); + arrow_column.column, + header_column.name, + dictionary_infos, + header_column.type, + arrow_column.field->nullable(), + /*is_map_nested_column*/ false, + settings); } if (null_as_default) @@ 
-1214,7 +1348,16 @@ ColumnPtr ArrowColumnToCHColumn::castArrayColumnToBitmapColumn(ColumnWithTypeAnd "ClickHouse BitMap64 can only be converted from Array, but column {} is {}", column.name, column.type->getName()); - DataTypePtr internal_nested = array->getNestedType(); + + if (array->getNestedType()->isNullable()) + { + DataTypePtr adapter_type = std::make_shared(removeNullable(array->getNestedType())); + column.column = castColumn(column, adapter_type); + column.type = adapter_type; + } + + DataTypePtr internal_nested = checkAndGetDataType(column.type.get())->getNestedType(); + if (isString(internal_nested)) { throw Exception("String list to Bitmap is not support In cnch", ErrorCodes::NOT_IMPLEMENTED); @@ -1227,6 +1370,7 @@ ColumnPtr ArrowColumnToCHColumn::castArrayColumnToBitmapColumn(ColumnWithTypeAnd column.column = castColumn(column, adapter_type); column.type = adapter_type; } + return castToBitmap64Column(column, target_type); } } diff --git a/src/Processors/Formats/Impl/ArrowColumnToCHColumn.h b/src/Processors/Formats/Impl/ArrowColumnToCHColumn.h index 0d5b23b8583..30dadbc4774 100644 --- a/src/Processors/Formats/Impl/ArrowColumnToCHColumn.h +++ b/src/Processors/Formats/Impl/ArrowColumnToCHColumn.h @@ -53,6 +53,7 @@ class ArrowColumnToCHColumn bool import_nested_, bool allow_missing_columns_, bool null_as_default_, + FormatSettings::DateTimeOverflowBehavior date_time_overflow_behavior_, bool case_insensitive_matching_ = false); void arrowTableToCHChunk(Chunk & res, std::shared_ptr & table, size_t num_rows, BlockMissingValues * block_missing_values = nullptr); @@ -76,6 +77,15 @@ class ArrowColumnToCHColumn }; private: + struct ArrowColumn + { + std::shared_ptr column; + std::shared_ptr field; + }; + + using NameToArrowColumn = std::unordered_map; + void arrowColumnsToCHChunk(Chunk & res, const NameToArrowColumn & name_to_arrow_column, size_t num_rows, BlockMissingValues * block_missing_values = nullptr); + static ColumnPtr 
castArrayColumnToBitmapColumn(ColumnWithTypeAndName & column, const DataTypePtr & target_type); const Block & header; @@ -84,6 +94,7 @@ class ArrowColumnToCHColumn /// If false, throw exception if some columns in header not exists in arrow table. bool allow_missing_columns; bool null_as_default; + FormatSettings::DateTimeOverflowBehavior date_time_overflow_behavior; bool case_insensitive_matching; /// Map {column name : dictionary column}. diff --git a/src/Processors/Formats/Impl/ORCBlockInputFormat.cpp b/src/Processors/Formats/Impl/ORCBlockInputFormat.cpp index 657a624c1be..ce1de06a9de 100644 --- a/src/Processors/Formats/Impl/ORCBlockInputFormat.cpp +++ b/src/Processors/Formats/Impl/ORCBlockInputFormat.cpp @@ -169,6 +169,7 @@ void ORCBlockInputFormat::prepareReader() format_settings.orc.import_nested, format_settings.orc.allow_missing_columns, format_settings.null_as_default, + format_settings.date_time_overflow_behavior, format_settings.orc.case_insensitive_column_matching); include_indices = getColumnIndices(schema, getPort().getHeader(), format_settings.orc.case_insensitive_column_matching, format_settings.orc.import_nested); diff --git a/src/Processors/Formats/Impl/OrcChunkReader.cpp b/src/Processors/Formats/Impl/OrcChunkReader.cpp index 61622fdf15d..ae5355bd4d4 100644 --- a/src/Processors/Formats/Impl/OrcChunkReader.cpp +++ b/src/Processors/Formats/Impl/OrcChunkReader.cpp @@ -514,14 +514,15 @@ Status OrcChunkReader::initBlock() bool allow_missing_columns = format_settings.orc.allow_missing_columns; bool null_as_default = format_settings.null_as_default; bool case_insenstive = format_settings.orc.case_insensitive_column_matching; + bool allow_out_of_range = format_settings.date_time_overflow_behavior == FormatSettings::DateTimeOverflowBehavior::Saturate ? true : false; //TODO fix this. 
active_orc_column_to_ch_column - = std::make_unique(active_block, allow_missing_columns, null_as_default, case_insenstive); + = std::make_unique(active_block, allow_missing_columns, null_as_default, case_insenstive, allow_out_of_range); lazy_orc_column_to_ch_column - = std::make_unique(lazy_block, allow_missing_columns, null_as_default, case_insenstive); + = std::make_unique(lazy_block, allow_missing_columns, null_as_default, case_insenstive, allow_out_of_range); orc_column_to_ch_column - = std::make_unique(chunk_reader_params.header, allow_missing_columns, null_as_default, case_insenstive); + = std::make_unique(chunk_reader_params.header, allow_missing_columns, null_as_default, case_insenstive, allow_out_of_range); return Status::OK(); } diff --git a/src/Processors/Formats/Impl/OrcCommon.cpp b/src/Processors/Formats/Impl/OrcCommon.cpp index 301a177b3a5..df76ff0693c 100644 --- a/src/Processors/Formats/Impl/OrcCommon.cpp +++ b/src/Processors/Formats/Impl/OrcCommon.cpp @@ -647,9 +647,6 @@ static void buildORCSearchArgumentImpl( break; } -#if USE_GIS == 1 - case KeyCondition::RPNElement::FUNCTION_GEOMETRY: -#endif case KeyCondition::RPNElement::FUNCTION_UNKNOWN: { builder.literal(orc::TruthValue::YES_NO_NULL); rpn_stack.pop_back(); diff --git a/src/Processors/Formats/Impl/Parquet/ParquetArrowColReader.cpp b/src/Processors/Formats/Impl/Parquet/ParquetArrowColReader.cpp index 6597c67b952..4474896ef37 100644 --- a/src/Processors/Formats/Impl/Parquet/ParquetArrowColReader.cpp +++ b/src/Processors/Formats/Impl/Parquet/ParquetArrowColReader.cpp @@ -52,6 +52,7 @@ ParquetArrowColReader::ParquetArrowColReader( format_settings.parquet.import_nested, format_settings.parquet.allow_missing_columns, format_settings.null_as_default, + format_settings.date_time_overflow_behavior, format_settings.parquet.case_insensitive_column_matching); } diff --git a/src/Processors/Formats/Impl/ParquetBlockInputFormat.cpp b/src/Processors/Formats/Impl/ParquetBlockInputFormat.cpp index 
927d12ded51..887a7249815 100644 --- a/src/Processors/Formats/Impl/ParquetBlockInputFormat.cpp +++ b/src/Processors/Formats/Impl/ParquetBlockInputFormat.cpp @@ -568,6 +568,7 @@ void ParquetBlockInputFormat::initializeRowGroupReaderIfNeeded(size_t row_group_ format_settings.parquet.import_nested, format_settings.parquet.allow_missing_columns, format_settings.null_as_default, + format_settings.date_time_overflow_behavior, format_settings.parquet.case_insensitive_column_matching); // if (auto context = getContext()) diff --git a/src/Processors/IProcessor.h b/src/Processors/IProcessor.h index 7854c6dda48..577fd922ab1 100644 --- a/src/Processors/IProcessor.h +++ b/src/Processors/IProcessor.h @@ -357,6 +357,7 @@ class IProcessor protected: virtual void onCancel() {} + std::atomic is_cancelled{false}; private: /// For: @@ -364,7 +365,7 @@ class IProcessor /// - input_wait_elapsed_us /// - output_wait_elapsed_us friend class PipelineExecutor; - std::atomic is_cancelled{false}; + friend class ExecutingGraph; std::string processor_description; diff --git a/src/Processors/IntermediateResult/CacheManager.cpp b/src/Processors/IntermediateResult/CacheManager.cpp index babf58d351b..00acd49c001 100644 --- a/src/Processors/IntermediateResult/CacheManager.cpp +++ b/src/Processors/IntermediateResult/CacheManager.cpp @@ -169,9 +169,7 @@ void CacheManager::setComplete(const CacheKey & key) { auto empty_key = key.cloneWithoutOwnerInfo(); value = tryGetUncompletedCache(empty_key); - if (value) - eraseUncompletedCache(empty_key); - else + if (!value) modifyKeyStateToRefused(key); } if (value) diff --git a/src/Processors/IntermediateResult/CacheManager.h b/src/Processors/IntermediateResult/CacheManager.h index fe6eb52f122..0f6484fea3e 100644 --- a/src/Processors/IntermediateResult/CacheManager.h +++ b/src/Processors/IntermediateResult/CacheManager.h @@ -155,6 +155,8 @@ struct CacheHolder bool all_part_in_cache = false; // if ture, the original pipeline will be generated bool 
all_part_in_storage = false; + // ensure cache is fully written + bool early_finish = false; }; using CacheHolderPtr = std::shared_ptr; diff --git a/src/Processors/Merges/Algorithms/FinishAggregatingInOrderAlgorithm.cpp b/src/Processors/Merges/Algorithms/FinishAggregatingInOrderAlgorithm.cpp index eaffe394ee4..81daec057f2 100644 --- a/src/Processors/Merges/Algorithms/FinishAggregatingInOrderAlgorithm.cpp +++ b/src/Processors/Merges/Algorithms/FinishAggregatingInOrderAlgorithm.cpp @@ -119,7 +119,7 @@ IMergingAlgorithm::Status FinishAggregatingInOrderAlgorithm::merge() Chunk FinishAggregatingInOrderAlgorithm::aggregate() { - auto aggregated = params->aggregator.mergeBlocks(blocks, false); + auto aggregated = params->aggregator.mergeBlocks(blocks, false, is_cancelled); blocks.clear(); accumulated_rows = 0; return {aggregated.getColumns(), aggregated.rows()}; diff --git a/src/Processors/Merges/Algorithms/FinishAggregatingInOrderAlgorithm.h b/src/Processors/Merges/Algorithms/FinishAggregatingInOrderAlgorithm.h index 119aefb0ab0..02aff0e0784 100644 --- a/src/Processors/Merges/Algorithms/FinishAggregatingInOrderAlgorithm.h +++ b/src/Processors/Merges/Algorithms/FinishAggregatingInOrderAlgorithm.h @@ -75,6 +75,7 @@ class FinishAggregatingInOrderAlgorithm final : public IMergingAlgorithm std::vector inputs_to_update; BlocksList blocks; size_t accumulated_rows = 0; + std::atomic is_cancelled{false}; }; } diff --git a/src/Processors/Merges/Algorithms/GraphiteRollupSortedAlgorithm.cpp b/src/Processors/Merges/Algorithms/GraphiteRollupSortedAlgorithm.cpp index c987b26b6be..d115480ea87 100644 --- a/src/Processors/Merges/Algorithms/GraphiteRollupSortedAlgorithm.cpp +++ b/src/Processors/Merges/Algorithms/GraphiteRollupSortedAlgorithm.cpp @@ -145,7 +145,7 @@ static time_t roundTimeToPrecision(const DateLUTImpl & date_lut, time_t time, UI IMergingAlgorithm::Status GraphiteRollupSortedAlgorithm::merge() { - const DateLUTImpl & date_lut = DateLUT::instance(); + const DateLUTImpl & 
date_lut = DateLUT::serverTimezoneInstance(); /// Take rows in needed order and put them into `merged_data` until we get `max_block_size` rows. /// diff --git a/src/Processors/QueryPipeline.cpp b/src/Processors/QueryPipeline.cpp index 4f1af41d9c0..3ca004b0fba 100644 --- a/src/Processors/QueryPipeline.cpp +++ b/src/Processors/QueryPipeline.cpp @@ -49,6 +49,7 @@ #include #include #include +#include "Processors/Transforms/TableFinishTransform.h" #include @@ -578,6 +579,8 @@ void QueryPipeline::setProgressCallback(const ProgressCallback & callback) { if (auto * source = dynamic_cast(processor.get())) source->setProgressCallback(callback); + if (auto * finish_transform = dynamic_cast(processor.get())) + finish_transform->setProgressCallback(callback); } } @@ -603,6 +606,8 @@ void QueryPipeline::setProcessListElement(QueryStatus * elem) { if (auto * source = dynamic_cast(processor.get())) source->setProcessListElement(elem); + if (auto * finish_transform = dynamic_cast(processor.get())) + finish_transform->setProcessListElement(elem); } } @@ -779,9 +784,12 @@ void QueryPipeline::setWriteCacheComplete(const ContextPtr & context) if (!pipe.holder.cache_holder) return; + if (pipe.holder.cache_holder->early_finish) + return; + auto cache = context->getIntermediateResultCache(); auto & write_cache = pipe.holder.cache_holder->write_cache; - for (auto cache_key : write_cache) + for (const auto & cache_key : write_cache) cache->setComplete(cache_key); write_cache.clear(); } @@ -793,7 +801,7 @@ void QueryPipeline::clearUncompletedCache(const ContextPtr & context) auto cache = context->getIntermediateResultCache(); auto & write_cache = pipe.holder.cache_holder->write_cache; - for (auto cache_key : write_cache) + for (const auto & cache_key : write_cache) cache->eraseUncompletedCache(cache_key); } diff --git a/src/Processors/ResizeProcessor.cpp b/src/Processors/ResizeProcessor.cpp index d652a342150..02394998001 100644 --- a/src/Processors/ResizeProcessor.cpp +++ 
b/src/Processors/ResizeProcessor.cpp @@ -1,4 +1,5 @@ #include +#include #include namespace DB @@ -263,6 +264,8 @@ IProcessor::Status ResizeProcessor::prepare(const PortNumbers & updated_inputs, IProcessor::Status StrictResizeProcessor::prepare(const PortNumbers & updated_inputs, const PortNumbers & updated_outputs) { + static auto * logger = &Poco::Logger::get("StrictResizeProcessor"); + if (!initialized) { initialized = true; @@ -320,8 +323,15 @@ IProcessor::Status StrictResizeProcessor::prepare(const PortNumbers & updated_in { input.status = InputStatus::Finished; ++num_finished_inputs; - - waiting_outputs.push(input.waiting_output); + /// Avoid pushing data to outputs which already have data or are finished + auto & output = output_ports[input.waiting_output]; + if (!output.port->isFinished() && output.port->canPush()) + { + /// reset status to avoid error: Invalid status NotActive for associated output + /// for example, if output with NotActive status is pushed to waiting_outputs and then assigned to another input. + output.status = OutputStatus::NeedData; + waiting_outputs.push(input.waiting_output); + } } continue; } @@ -347,10 +357,8 @@ IProcessor::Status StrictResizeProcessor::prepare(const PortNumbers & updated_in auto & waiting_output = output_ports[input_with_data.waiting_output]; - if (waiting_output.status == OutputStatus::NotActive) - throw Exception("Invalid status NotActive for associated output.", ErrorCodes::LOGICAL_ERROR); - - if (waiting_output.status != OutputStatus::Finished) + /// Output status could be NotActive when abandoned_chunks are pushed to it.
+ if (waiting_output.status == OutputStatus::NeedData) { waiting_output.port->pushData(input_with_data.port->pullData(/* set_not_needed = */ true)); waiting_output.status = OutputStatus::NotActive; @@ -367,7 +375,8 @@ IProcessor::Status StrictResizeProcessor::prepare(const PortNumbers & updated_in disabled_input_ports.push(input_number); } - if (num_finished_inputs == inputs.size()) + /// Abandoned chunks would be lost if we did not also check abandoned_chunks.empty() here. + if (num_finished_inputs == inputs.size() && abandoned_chunks.empty()) { for (auto & output : outputs) output.finish(); @@ -380,11 +389,17 @@ IProcessor::Status StrictResizeProcessor::prepare(const PortNumbers & updated_in { auto & waiting_output = output_ports[waiting_outputs.front()]; waiting_outputs.pop(); - - waiting_output.port->pushData(std::move(abandoned_chunks.back())); - abandoned_chunks.pop_back(); - - waiting_output.status = OutputStatus::NotActive; + // pushing a chunk to a finished port would lose it + if (waiting_output.status == OutputStatus::NeedData) + { + waiting_output.port->pushData(std::move(abandoned_chunks.back())); + abandoned_chunks.pop_back(); + waiting_output.status = OutputStatus::NotActive; + } + else + { + LOG_WARNING(logger, "One output in waiting_outputs is finished"); + } } /// Enable more inputs if needed.
@@ -406,9 +421,19 @@ IProcessor::Status StrictResizeProcessor::prepare(const PortNumbers & updated_in auto & output = output_ports[waiting_outputs.front()]; waiting_outputs.pop(); + if (output.status != OutputStatus::Finished) + ++num_finished_outputs; + output.status = OutputStatus::Finished; output.port->finish(); - ++num_finished_outputs; + } + + if (num_finished_outputs == outputs.size()) + { + for (auto & input : inputs) + input.close(); + + return Status::Finished; } if (disabled_input_ports.empty()) diff --git a/src/Processors/Sources/SourceFromIntermediateResultCache.h b/src/Processors/Sources/SourceFromIntermediateResultCache.h index 19f36628d5d..c0f8f6b8860 100644 --- a/src/Processors/Sources/SourceFromIntermediateResultCache.h +++ b/src/Processors/Sources/SourceFromIntermediateResultCache.h @@ -30,9 +30,12 @@ class SourceFromIntermediateResultCache : public ISource if (!chunk.empty()) { size_t num_columns = chunk.getNumColumns(); - auto columns = chunk.detachColumns(); + size_t num_rows = chunk.getNumRows(); + auto cache_columns = chunk.detachColumns(); + Columns output_columns(num_columns); for (size_t i = 0; i < num_columns; ++i) - chunk.addColumn(std::move(columns[cache_pos_to_output_pos[i]])); + output_columns[cache_pos_to_output_pos[i]] = std::move(cache_columns[i]); + chunk.setColumns(std::move(output_columns), num_rows); return chunk; } else diff --git a/src/Processors/Sources/SourceWithProgress.cpp b/src/Processors/Sources/SourceWithProgress.cpp index 60219109ecb..81d6b70bcad 100644 --- a/src/Processors/Sources/SourceWithProgress.cpp +++ b/src/Processors/Sources/SourceWithProgress.cpp @@ -86,6 +86,18 @@ void SourceWithProgress::work() } } +void SourceWithProgress::updateProgress(const Progress & value) +{ + if (progress_callback) + progress_callback(value); + + if (process_list_elem) + { + if (!process_list_elem->updateProgressIn(value)) + cancel(); + } +} + /// Aggregated copy-paste from IBlockInputStream::progressImpl. 
/// Most of this must be done in PipelineExecutor outside. Now it's done for compatibility with IBlockInputStream. void SourceWithProgress::progress(const Progress & value) diff --git a/src/Processors/Sources/SourceWithProgress.h b/src/Processors/Sources/SourceWithProgress.h index 8bbf571c459..b6d9b100a82 100644 --- a/src/Processors/Sources/SourceWithProgress.h +++ b/src/Processors/Sources/SourceWithProgress.h @@ -55,7 +55,7 @@ class SourceWithProgress : public ISourceWithProgress void setProcessListElement(QueryStatus * elem) final; void setProgressCallback(const ProgressCallback & callback) final { progress_callback = callback; } void addTotalRowsApprox(size_t value) final { total_rows_approx += value; } - + void updateProgress(const Progress & value); protected: /// Call this method to provide information about progress. void progress(const Progress & value); @@ -66,8 +66,8 @@ class SourceWithProgress : public ISourceWithProgress StreamLocalLimits limits; SizeLimits leaf_limits; std::shared_ptr quota; - ProgressCallback progress_callback; QueryStatus * process_list_elem = nullptr; + ProgressCallback progress_callback; /// The approximate total number of rows to read. For progress bar. 
size_t total_rows_approx = 0; diff --git a/src/Processors/Transforms/AggregatingStreamingTransform.cpp b/src/Processors/Transforms/AggregatingStreamingTransform.cpp index 129d2ac9897..66ebaaf19f1 100644 --- a/src/Processors/Transforms/AggregatingStreamingTransform.cpp +++ b/src/Processors/Transforms/AggregatingStreamingTransform.cpp @@ -68,7 +68,7 @@ ISimpleTransform::Status AggregatingStreamingTransform::prepare() if (has_left && start_generated) { - output_data.chunk = std::move(chunks[chunk_idx++]); + output_data.chunk = std::move(fetchNewChunk()); output.pushData(std::move(output_data)); has_left = chunk_idx != chunks.size(); return Status::PortFull; @@ -110,7 +110,8 @@ ISimpleTransform::Status AggregatingStreamingTransform::prepare() /// To do this, we pass a block with zero rows to aggregate. if (params->params.keys_size == 0 && !params->params.empty_result_for_aggregation_by_empty_set) { - params->aggregator.executeOnBlock(getInputs().front().getHeader(), variants, key_columns, aggregate_columns, no_more_keys); + params->aggregator.executeOnBlock( + getInputs().front().getHeader(), variants, key_columns, aggregate_columns, no_more_keys); has_left = true; start_generated = true; return Status::Ready; @@ -145,18 +146,8 @@ void AggregatingStreamingTransform::work() return; } - try - { - transform(input_data.chunk); - output_data.chunk.swap(input_data.chunk); - } - catch (DB::Exception &) - { - output_data.exception = std::current_exception(); - has_output = false; - has_input = false; - return; - } + transform(input_data.chunk); + output_data.chunk.swap(input_data.chunk); if (output_data.chunk.hasRows()) has_output = true; @@ -172,12 +163,12 @@ void AggregatingStreamingTransform::transform(DB::Chunk & chunk) if (!is_generated) { generate(chunk); - is_generated = true; + if (!is_without_key) + is_generated = true; return; } - output_data.chunk = std::move(chunks[chunk_idx++]); - output.pushData(std::move(output_data)); + chunk = std::move(fetchNewChunk()); 
has_left = chunk_idx != chunks.size(); } return; @@ -195,7 +186,7 @@ void AggregatingStreamingTransform::transform(DB::Chunk & chunk) { input_rows += num_rows; auto block = inputs.front().getHeader().cloneWithColumns(chunk.getColumns()); - if (!params->aggregator.mergeOnBlock(block, variants, no_more_keys)) + if (!params->aggregator.mergeOnBlock(block, variants, no_more_keys, is_cancelled)) { start_generated = true; no_more_data_needed = true; @@ -275,8 +266,10 @@ void AggregatingStreamingTransform::generate(DB::Chunk & chunk) chunks.emplace_back(std::move(tmp_chunk)); } - chunk = std::move(chunks[chunk_idx++]); + chunk = std::move(fetchNewChunk()); rows_returned += chunk.getNumRows(); has_left = chunk_idx != chunks.size(); + LOG_TRACE( + log, "{} blocks generate, {} chunks remain, {} rows return", blocks_list.size(), chunks.size() - chunk_idx, chunk.getNumRows()); } } diff --git a/src/Processors/Transforms/AggregatingStreamingTransform.h b/src/Processors/Transforms/AggregatingStreamingTransform.h index 166c8765fb4..9c166f70c27 100644 --- a/src/Processors/Transforms/AggregatingStreamingTransform.h +++ b/src/Processors/Transforms/AggregatingStreamingTransform.h @@ -11,6 +11,11 @@ namespace DB { +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + /** Aggregates streaming transform */ class AggregatingStreamingTransform : public ISimpleTransform @@ -45,6 +50,21 @@ class AggregatingStreamingTransform : public ISimpleTransform return (aggregation_ratio > 0) && variants.size() / (input_rows * 1.0) < aggregation_ratio; } + ALWAYS_INLINE Chunk & fetchNewChunk() + { + if (chunk_idx >= chunks.size()) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "No chunk can be fetch, chunk_idx:{}, chunks_size:{}, has_left:{}, is_generated:{}, start_generated:{}, is_two_level:{}", + chunk_idx, + chunks.size(), + has_left, + is_generated, + start_generated, + is_two_level); + return chunks[chunk_idx++]; + } + // bool continueLocalAgg(size_t chunk_rows); /// To read the data 
that was flushed into the temporary data file. Processors processors; diff --git a/src/Processors/Transforms/AggregatingTransform.cpp b/src/Processors/Transforms/AggregatingTransform.cpp index f21bf93e4f2..492a57e455b 100644 --- a/src/Processors/Transforms/AggregatingTransform.cpp +++ b/src/Processors/Transforms/AggregatingTransform.cpp @@ -180,7 +180,9 @@ class ConvertingAggregatedToChunksTransform : public IProcessor public: ConvertingAggregatedToChunksTransform(AggregatingTransformParamsPtr params_, ManyAggregatedDataVariantsPtr data_, size_t num_threads_) : IProcessor({}, {params_->getHeader()}) - , params(std::move(params_)), data(std::move(data_)), num_threads(num_threads_) {} + , params(std::move(params_)), data(std::move(data_)) + , shared_data(std::make_shared()) + , num_threads(num_threads_) {} String getName() const override { return "ConvertingAggregatedToChunksTransform"; } @@ -239,8 +241,7 @@ class ConvertingAggregatedToChunksTransform : public IProcessor for (auto & input : inputs) input.close(); - if (shared_data) - shared_data->is_cancelled.store(true); + shared_data->is_cancelled.store(true); return Status::Finished; } @@ -265,6 +266,11 @@ class ConvertingAggregatedToChunksTransform : public IProcessor return prepareTwoLevel(); } + void onCancel() override + { + shared_data->is_cancelled.store(true, std::memory_order_seq_cst); + } + private: IProcessor::Status preparePushToOutput() { @@ -359,7 +365,7 @@ class ConvertingAggregatedToChunksTransform : public IProcessor if (first->type == AggregatedDataVariants::Type::without_key || params->params.overflow_row) { - params->aggregator.mergeWithoutKeyDataImpl(*data); + params->aggregator.mergeWithoutKeyDataImpl(*data, shared_data->is_cancelled); auto block = params->aggregator.prepareBlockAndFillWithoutKey( *first, params->final, first->type != AggregatedDataVariants::Type::without_key); @@ -399,7 +405,7 @@ class ConvertingAggregatedToChunksTransform : public IProcessor if (num_threads == 0) throw 
Exception("num_threads can't be zero when use two-level aggregation", ErrorCodes::LOGICAL_ERROR); AggregatedDataVariantsPtr & first = data->at(0); - shared_data = std::make_shared(); + processors.reserve(num_threads); for (size_t thread = 0; thread < num_threads; ++thread) { @@ -570,7 +576,7 @@ void AggregatingTransform::consume(Chunk chunk) { auto block = getInputs().front().getHeader().cloneWithColumns(chunk.detachColumns()); block = materializeBlock(block); - if (!params->aggregator.mergeOnBlock(block, variants, no_more_keys)) + if (!params->aggregator.mergeOnBlock(block, variants, no_more_keys, is_cancelled)) is_consume_finished = true; } else @@ -592,7 +598,7 @@ void AggregatingTransform::initGenerate() if (variants.empty() && params->params.keys_size == 0 && !params->params.empty_result_for_aggregation_by_empty_set) { if (params->only_merge) - params->aggregator.mergeOnBlock(getInputs().front().getHeader(), variants, no_more_keys); + params->aggregator.mergeOnBlock(getInputs().front().getHeader(), variants, no_more_keys, is_cancelled); else params->aggregator.executeOnBlock(getInputs().front().getHeader(), variants, key_columns, aggregate_columns, no_more_keys); } diff --git a/src/Processors/Transforms/CubeTransform.cpp b/src/Processors/Transforms/CubeTransform.cpp index 0bd0e5a5eca..9278a1bf7d7 100644 --- a/src/Processors/Transforms/CubeTransform.cpp +++ b/src/Processors/Transforms/CubeTransform.cpp @@ -45,7 +45,7 @@ Chunk CubeTransform::merge(Chunks && chunks, bool final) for (auto & chunk : chunks) rollup_blocks.emplace_back(getInputPort().getHeader().cloneWithColumns(chunk.detachColumns())); - auto rollup_block = params->aggregator.mergeBlocks(rollup_blocks, final); + auto rollup_block = params->aggregator.mergeBlocks(rollup_blocks, final, is_cancelled); auto num_rows = rollup_block.rows(); return Chunk(rollup_block.getColumns(), num_rows); } diff --git a/src/Processors/Transforms/ExplainAnalyzeTransform.cpp 
b/src/Processors/Transforms/ExplainAnalyzeTransform.cpp index 0c462512134..a0c35ee4e5f 100644 --- a/src/Processors/Transforms/ExplainAnalyzeTransform.cpp +++ b/src/Processors/Transforms/ExplainAnalyzeTransform.cpp @@ -1,12 +1,13 @@ #include -#include #include #include #include -#include #include #include +#include #include +#include +#include "Interpreters/ProcessorProfile.h" namespace DB { @@ -24,7 +25,9 @@ ExplainAnalyzeTransform::ExplainAnalyzeTransform( , query_plan_ptr(std::move(query_plan_ptr_)) , segment_descriptions(segment_descriptions_) , settings(settings_) -{} +{ + coordinator_address = extractExchangeHostPort(context->getQueryContext()->getCoordinatorAddress()); +} void ExplainAnalyzeTransform::transform(Chunk & chunk) { @@ -32,69 +35,91 @@ void ExplainAnalyzeTransform::transform(Chunk & chunk) if (!input.isFinished()) return; - ///segment_id, worker_address -> profiles - std::unordered_map> segment_profiles; - // If the information of segment0 cannot be accepted ProcessorsSet processors_set; ProcessorProfiles profiles; getProcessorProfiles(processors_set, profiles, this); - for (auto & profile : profiles) - segment_profiles[profile->segment_id][profile->worker_address].push_back(profile); + auto segment0_profile = GroupedProcessorProfile::getGroupedProfiles(profiles); - getRemoteProcessorProfiles(segment_profiles); - - ///segment_id -> grouped_profile_tree - std::unordered_map> segment_grouped_profile; - SegmentAndWorkerToGroupedProfile worker_grouped_profiles; - for (auto & [segment_id, segment_profile_in_worker] : segment_profiles) + auto scheduler = context->getSegmentScheduler(); + UInt64 time_out = context->getSettingsRef().operator_profile_receive_timeout; + auto time_start = std::chrono::system_clock::now(); + while (!scheduler->alreadyReceivedAllSegmentStatus(context->getCurrentQueryId())) { - for (auto & [address, segment_profile] : segment_profile_in_worker) - { - auto input_profile_root = 
GroupedProcessorProfile::getGroupedProfiles(segment_profile); - auto output = GroupedProcessorProfile::getOutputRoot(input_profile_root); - if (kind == ASTExplainQuery::ExplainKind::PipelineAnalyze && !output->children.empty()) - worker_grouped_profiles[segment_id][address] = output->children[0]; - else - segment_grouped_profile[segment_id].emplace_back(output); - } + auto now = std::chrono::system_clock::now(); + UInt64 elapsed = std::chrono::duration_cast(now - time_start).count(); + if (elapsed >= time_out) + break; } + auto profiles_map = scheduler->getSegmentsProfile(context->getCurrentQueryId()); String explain; - if ((kind == ASTExplainQuery::ExplainKind::LogicalAnalyze || kind == ASTExplainQuery::ExplainKind::DistributedAnalyze) && !segment_grouped_profile.empty()) + if ((kind == ASTExplainQuery::ExplainKind::LogicalAnalyze || kind == ASTExplainQuery::ExplainKind::DistributedAnalyze)) { - auto steps_profiles = StepOperatorProfile::aggregateOperatorProfileToStepLevel(segment_grouped_profile); - auto step_agg_operator_profiles = AggregatedStepOperatorProfile::aggregateStepOperatorProfileBetweenWorkers(steps_profiles); + AddressToStepProfile addr_to_step_profile; + for (auto & [segment_id, segment_profiles] : profiles_map) + { + for (auto & segment_profile : segment_profiles) + { + for (auto & [step_id, profile] : segment_profile->profiles) + addr_to_step_profile[segment_profile->worker_address][step_id] = profile; + } + } + + auto segment0_steps_profiles = GroupedProcessorProfile::aggregateOperatorProfileToStepLevel(segment0_profile); + for (auto & [step_id, profile] : segment0_steps_profiles) + addr_to_step_profile[coordinator_address][step_id] = profile; CardinalityEstimator::estimate(*query_plan_ptr, context); std::unordered_map costs = CostCalculator::calculate(*query_plan_ptr, *context); - if (settings.json) + auto step_agg_operator_profiles = ProfileMetric::aggregateStepProfileBetweenWorkers(addr_to_step_profile); + if (kind == 
ASTExplainQuery::ExplainKind::LogicalAnalyze) { - if (kind == ASTExplainQuery::ExplainKind::LogicalAnalyze) + if (settings.json) { auto plan_cost = CostCalculator::calculatePlanCost(*query_plan_ptr, *context); explain = PlanPrinter::jsonLogicalPlan(*query_plan_ptr, plan_cost, step_agg_operator_profiles, costs, settings); } - else if (kind == ASTExplainQuery::ExplainKind::DistributedAnalyze && !segment_descriptions.empty()) - explain = PlanPrinter::jsonDistributedPlan(segment_descriptions, step_agg_operator_profiles); + else + explain = PlanPrinter::textLogicalPlan(*query_plan_ptr, context, costs, step_agg_operator_profiles, settings); } - else + else if (kind == ASTExplainQuery::ExplainKind::DistributedAnalyze && !segment_descriptions.empty()) { - if (kind == ASTExplainQuery::ExplainKind::LogicalAnalyze) - explain = PlanPrinter::textLogicalPlan(*query_plan_ptr, context, costs, step_agg_operator_profiles, settings); - else if (kind == ASTExplainQuery::ExplainKind::DistributedAnalyze && !segment_descriptions.empty()) - explain = PlanPrinter::textDistributedPlan(segment_descriptions, context, costs, step_agg_operator_profiles, *query_plan_ptr, settings); + if (settings.json) + explain = PlanPrinter::jsonDistributedPlan(segment_descriptions, step_agg_operator_profiles); + else + explain = PlanPrinter::textDistributedPlan( + segment_descriptions, context, costs, step_agg_operator_profiles, *query_plan_ptr, settings, profiles_map); } GraphvizPrinter::printLogicalPlan(*query_plan_ptr, context, "5999_explain_analyze", step_agg_operator_profiles); } - else if (kind == ASTExplainQuery::ExplainKind::PipelineAnalyze && !worker_grouped_profiles.empty()) + else if (kind == ASTExplainQuery::ExplainKind::PipelineAnalyze) { + SegIdAndAddrToPipelineProfile worker_grouped_profiles; + segment0_profile = GroupedProcessorProfile::getOutputRoot(segment0_profile); + if (segment0_profile->processor_name == "output_root" && !segment0_profile->children.empty()) + segment0_profile = 
segment0_profile->children[0]; + worker_grouped_profiles[0][coordinator_address] = segment0_profile; + for (auto & [segment_id, segment_profiles] : profiles_map) + { + for (auto & segment_profile : segment_profiles) + { + if (segment_profile->profiles.empty()) + continue; + auto profile + = GroupedProcessorProfile::getGroupedProfileFromMetrics(segment_profile->profiles, segment_profile->profile_root_id); + if (profile->processor_name == "output_root" && !profile->children.empty()) + profile = profile->children[0]; + worker_grouped_profiles[segment_profile->segment_id][segment_profile->worker_address] = std::move(profile); + } + } + if (settings.aggregate_profiles) - worker_grouped_profiles = GroupedProcessorProfile::aggregateProfileBetweenWorkers(worker_grouped_profiles); + worker_grouped_profiles = GroupedProcessorProfile::aggregatePipelineProfileBetweenWorkers(worker_grouped_profiles); if (settings.json) explain = PlanPrinter::jsonPipelineProfile(segment_descriptions, worker_grouped_profiles); else - explain = PlanPrinter::textPipelineProfile(segment_descriptions, worker_grouped_profiles); + explain = PlanPrinter::textPipelineProfile(segment_descriptions, worker_grouped_profiles, settings, profiles_map); } MutableColumns cols(1); @@ -244,7 +269,7 @@ void ExplainAnalyzeTransform::getProcessorProfiles(ProcessorsSet & processors_se child->input_bytes = from->getProcessorDataStats().input_bytes; child->output_rows = from->getProcessorDataStats().output_rows; child->output_bytes = from->getProcessorDataStats().output_bytes; - child->worker_address = "localhost:0"; + child->worker_address = coordinator_address; processors_set.insert(from); profiles.emplace_back(child); getProcessorProfiles(processors_set, profiles, from); diff --git a/src/Processors/Transforms/ExplainAnalyzeTransform.h b/src/Processors/Transforms/ExplainAnalyzeTransform.h index d2b2d1f52df..3fd077d0b5f 100644 --- a/src/Processors/Transforms/ExplainAnalyzeTransform.h +++ 
b/src/Processors/Transforms/ExplainAnalyzeTransform.h @@ -26,7 +26,7 @@ class ExplainAnalyzeTransform : public ISimpleTransform void transform(Chunk & chunk) override; ISimpleTransform::Status prepare() override; - static void getProcessorProfiles(ProcessorsSet & processors_set, ProcessorProfiles & profiles, const IProcessor * processor); + void getProcessorProfiles(ProcessorsSet & processors_set, ProcessorProfiles & profiles, const IProcessor * processor); void getRemoteProcessorProfiles(std::unordered_map> & segment_profiles); private: ASTExplainQuery::ExplainKind kind; @@ -35,5 +35,6 @@ class ExplainAnalyzeTransform : public ISimpleTransform PlanSegmentDescriptions segment_descriptions; bool has_final_transform = true; QueryPlanSettings settings; + String coordinator_address; }; } diff --git a/src/Processors/Transforms/IntermediateResultCacheTransform.cpp b/src/Processors/Transforms/IntermediateResultCacheTransform.cpp index ef1b11db845..9f46e744c5b 100644 --- a/src/Processors/Transforms/IntermediateResultCacheTransform.cpp +++ b/src/Processors/Transforms/IntermediateResultCacheTransform.cpp @@ -18,22 +18,28 @@ IntermediateResultCacheTransform::IntermediateResultCacheTransform( CacheParam & cache_param_, UInt64 cache_max_bytes_, UInt64 cache_max_rows_, - bool all_part_in_cache_) + CacheHolderPtr cache_holder_) : ISimpleTransform(header_, header_, false) , cache(std::move(cache_)) , cache_param(cache_param_) , cache_max_bytes(cache_max_bytes_) , cache_max_rows(cache_max_rows_) - , all_part_in_cache(all_part_in_cache_) + , cache_holder(std::move(cache_holder_)) , log(&Poco::Logger::get("IntermediateResultCacheTransform")) { } IProcessor::Status IntermediateResultCacheTransform::prepare() { - if (all_part_in_cache) + if (cache_holder->all_part_in_cache) stopReading(); + if (!cache_holder->early_finish && output.isFinished() && !input.isFinished()) + { + cache_holder->early_finish = true; + LOG_DEBUG(log, "Cache {} generate was early finish", cache_param.digest); + 
} + return ISimpleTransform::prepare(); } @@ -52,6 +58,9 @@ void IntermediateResultCacheTransform::transform(DB::Chunk & chunk) if (!owner_info.empty()) { CacheKey key{cache_param.digest, cache_param.cached_table.getFullTableName(), owner_info}; + if (!cache_holder->write_cache.contains(key)) + return; + auto it = uncompleted_cache.find(key); if (it != uncompleted_cache.end()) value = it->second; @@ -79,9 +88,12 @@ void IntermediateResultCacheTransform::transform(DB::Chunk & chunk) auto cache_chunk = chunk.clone(); size_t num_columns = cache_chunk.getNumColumns(); - auto columns = cache_chunk.detachColumns(); + size_t num_rows = cache_chunk.getNumRows(); + auto output_columns = cache_chunk.detachColumns(); + Columns cache_columns(num_columns); for (size_t i = 0; i < num_columns; ++i) - cache_chunk.addColumn(std::move(columns[cache_param.output_pos_to_cache_pos[i]])); + cache_columns[cache_param.output_pos_to_cache_pos[i]] = std::move(output_columns[i]); + cache_chunk.setColumns(std::move(cache_columns), num_rows); if (value) value->addChunk(cache_chunk); diff --git a/src/Processors/Transforms/IntermediateResultCacheTransform.h b/src/Processors/Transforms/IntermediateResultCacheTransform.h index d639b4b8557..82c47310214 100644 --- a/src/Processors/Transforms/IntermediateResultCacheTransform.h +++ b/src/Processors/Transforms/IntermediateResultCacheTransform.h @@ -31,7 +31,7 @@ class IntermediateResultCacheTransform : public ISimpleTransform CacheParam & cache_param_, UInt64 cache_max_bytes_, UInt64 cache_max_rows_, - bool all_part_in_cache_); + CacheHolderPtr cache_holder_); String getName() const override { @@ -47,7 +47,7 @@ class IntermediateResultCacheTransform : public ISimpleTransform CacheParam cache_param; UInt64 cache_max_bytes = 0; UInt64 cache_max_rows = 0; - bool all_part_in_cache = false; + CacheHolderPtr cache_holder; std::unordered_map uncompleted_cache; Poco::Logger * log; }; diff --git 
a/src/Processors/Transforms/MergingAggregatedMemoryEfficientTransform.cpp b/src/Processors/Transforms/MergingAggregatedMemoryEfficientTransform.cpp index df2ea4b03f0..1dcf5de4a4a 100644 --- a/src/Processors/Transforms/MergingAggregatedMemoryEfficientTransform.cpp +++ b/src/Processors/Transforms/MergingAggregatedMemoryEfficientTransform.cpp @@ -342,7 +342,7 @@ void MergingAggregatedBucketTransform::transform(Chunk & chunk) res_info->bucket_num = chunks_to_merge->bucket_num; chunk.setChunkInfo(std::move(res_info)); - auto block = params->aggregator.mergeBlocks(blocks_list, params->final); + auto block = params->aggregator.mergeBlocks(blocks_list, params->final, is_cancelled); size_t num_rows = block.rows(); chunk.setColumns(block.getColumns(), num_rows); } diff --git a/src/Processors/Transforms/MergingAggregatedTransform.cpp b/src/Processors/Transforms/MergingAggregatedTransform.cpp index ddc58d830da..fc49e2b03c4 100644 --- a/src/Processors/Transforms/MergingAggregatedTransform.cpp +++ b/src/Processors/Transforms/MergingAggregatedTransform.cpp @@ -58,7 +58,7 @@ Chunk MergingAggregatedTransform::generate() next_block = blocks.begin(); /// TODO: this operation can be made async. Add async for IAccumulatingTransform. 
- params->aggregator.mergeBlocks(std::move(bucket_to_blocks), data_variants, max_threads); + params->aggregator.mergeBlocks(std::move(bucket_to_blocks), data_variants, max_threads, is_cancelled); blocks = params->aggregator.convertToBlocks(data_variants, params->final, max_threads); next_block = blocks.begin(); } diff --git a/src/Processors/Transforms/ProcessorToOutputStream.cpp b/src/Processors/Transforms/ProcessorToOutputStream.cpp new file mode 100644 index 00000000000..d1fc64bc479 --- /dev/null +++ b/src/Processors/Transforms/ProcessorToOutputStream.cpp @@ -0,0 +1,72 @@ +#include +#include +#include +#include + +namespace DB +{ + +Block ProcessorToOutputStream::newHeader() +{ + return {ColumnWithTypeAndName(ColumnUInt64::create(), std::make_shared(), "inserted_rows")}; +} + +ProcessorToOutputStream::ProcessorToOutputStream(BlockOutputStreamPtr stream_) + : IProcessor({stream_->getHeader()}, {newHeader()}) + , input(inputs.front()) + , output(outputs.front()) + , stream(std::move(stream_)) +{ + total_rows = 0; + stream->writePrefix(); +} + +void ProcessorToOutputStream::consume(Chunk chunk) +{ + total_rows += chunk.getNumRows(); + stream->write(getInputPort().getHeader().cloneWithColumns(chunk.detachColumns())); +} + +Chunk ProcessorToOutputStream::getReturnChunk() +{ + auto total_rows_column = DataTypeUInt64().createColumnConst(1, total_rows); + return Chunk({total_rows_column}, 1); +} + +void ProcessorToOutputStream::onFinish() +{ + stream->writeSuffix(); + + auto return_chunk = getReturnChunk(); + output_data.chunk = std::move(return_chunk); + output.pushData(std::move(output_data)); +} + +ProcessorToOutputStream::Status ProcessorToOutputStream::prepare() +{ + if (has_input) + return Status::Ready; + + if (input.isFinished()) + { + onFinish(); + output.finish(); + return Status::Finished; + } + + input.setNeeded(); + if (!input.hasData()) + return Status::NeedData; + + current_chunk = input.pull(true); + has_input = true; + return Status::Ready; +} + +void 
ProcessorToOutputStream::work() +{ + consume(std::move(current_chunk)); + has_input = false; +} + +} diff --git a/src/Processors/Transforms/ProcessorToOutputStream.h b/src/Processors/Transforms/ProcessorToOutputStream.h new file mode 100644 index 00000000000..c72f11555cd --- /dev/null +++ b/src/Processors/Transforms/ProcessorToOutputStream.h @@ -0,0 +1,41 @@ +#pragma once +#include +#include + +namespace DB +{ + +class ProcessorToOutputStream : public IProcessor +{ +public: + explicit ProcessorToOutputStream(BlockOutputStreamPtr stream_); + + String getName() const override { return "ProcessorToOutputStream"; } + + static Block newHeader(); + Chunk getReturnChunk(); + + Status prepare() override; + void work() override; + + InputPort & getInputPort() { return input; } + OutputPort & getOutputPort() { return output; } + +protected: + InputPort & input; + OutputPort & output; + + Chunk current_chunk; + Port::Data output_data; + bool has_input = false; + + void consume(Chunk chunk); + void onFinish(); + +private: + BlockOutputStreamPtr stream; + size_t total_rows; + +}; + +} diff --git a/src/Processors/Transforms/ReverseTransform.cpp b/src/Processors/Transforms/ReverseTransform.cpp index 98f2bf54aa5..e6b2404b9b1 100644 --- a/src/Processors/Transforms/ReverseTransform.cpp +++ b/src/Processors/Transforms/ReverseTransform.cpp @@ -17,6 +17,15 @@ void ReverseTransform::transform(Chunk & chunk) for (auto & column : columns) column = column->permute(permutation, 0); + if (auto * side_block = chunk.getSideBlock()) + { + for (size_t i = 0; i < side_block->columns(); ++i) + { + auto & side_column = side_block->getByPosition(i).column; + side_column = side_column->permute(permutation, 0); + } + } + chunk.setColumns(std::move(columns), num_rows); } diff --git a/src/Processors/Transforms/RollupTransform.cpp b/src/Processors/Transforms/RollupTransform.cpp index d36d1737aeb..b269efd0766 100644 --- a/src/Processors/Transforms/RollupTransform.cpp +++ 
b/src/Processors/Transforms/RollupTransform.cpp @@ -46,7 +46,7 @@ Chunk RollupTransform::merge(Chunks && chunks, bool final) for (auto & chunk : chunks) rollup_blocks.emplace_back(getInputPort().getHeader().cloneWithColumns(chunk.detachColumns())); - auto rollup_block = params->aggregator.mergeBlocks(rollup_blocks, final); + auto rollup_block = params->aggregator.mergeBlocks(rollup_blocks, final, is_cancelled); auto num_rows = rollup_block.rows(); return Chunk(rollup_block.getColumns(), num_rows); } diff --git a/src/Processors/Transforms/RollupWithGroupingTransform.cpp b/src/Processors/Transforms/RollupWithGroupingTransform.cpp index 0b12e0efc5b..462383ec76d 100644 --- a/src/Processors/Transforms/RollupWithGroupingTransform.cpp +++ b/src/Processors/Transforms/RollupWithGroupingTransform.cpp @@ -53,7 +53,7 @@ Chunk RollupWithGroupingTransform::merge(Chunks && chunks, bool final) for (auto & chunk : chunks) rollup_blocks.emplace_back(getOutputPort().getHeader().cloneWithColumns(chunk.detachColumns())); - auto rollup_block = params->aggregator.mergeBlocks(rollup_blocks, final); + auto rollup_block = params->aggregator.mergeBlocks(rollup_blocks, final, is_cancelled); auto num_rows = rollup_block.rows(); return Chunk(rollup_block.getColumns(), num_rows); } diff --git a/src/Processors/Transforms/TableFinishTransform.cpp b/src/Processors/Transforms/TableFinishTransform.cpp index e7d1b688eb1..54b4442b65f 100644 --- a/src/Processors/Transforms/TableFinishTransform.cpp +++ b/src/Processors/Transforms/TableFinishTransform.cpp @@ -1,24 +1,40 @@ #include #include +#include #include #include #include #include #include -#include +#include "Common/StackTrace.h" #include +#include "Interpreters/ActionsVisitor.h" +#include "Interpreters/ProcessList.h" + +namespace ProfileEvents +{ +extern const Event InsertedRows; +extern const Event InsertedBytes; +} namespace DB { -TableFinishTransform::TableFinishTransform(const Block & header_, const StoragePtr & storage_, - const ContextPtr & 
context_, ASTPtr & query_) - : IProcessor({header_}, {header_}), input(inputs.front()) +TableFinishTransform::TableFinishTransform( + const Block & header_, const StoragePtr & storage_, const ContextPtr & context_, ASTPtr & query_, bool insert_select_with_profiles_) + : IProcessor({header_}, {header_}) + , input(inputs.front()) , output(outputs.front()) , storage(storage_) , context(context_) , query(query_) + , insert_select_with_profiles(insert_select_with_profiles_) +{ +} + +void TableFinishTransform::setProcessListElement(QueryStatus * elem) { + process_list_elem = elem; } Block TableFinishTransform::getHeader() @@ -80,37 +96,57 @@ TableFinishTransform::Status TableFinishTransform::prepare() if (!input.hasData()) return Status::NeedData; - current_chunk = input.pull(true); + current_output_chunk = input.pull(true); has_input = true; return Status::Ready; } void TableFinishTransform::work() { - consume(std::move(current_chunk)); + consume(std::move(current_output_chunk)); has_input = false; } void TableFinishTransform::consume(Chunk chunk) { output_chunk = std::move(chunk); + + if (insert_select_with_profiles && !output_chunk.empty()) + { + auto & column = output_chunk.getColumns()[0]; + + ReadProgress local_progress(column->get64(0), 0); + + ProfileEvents::increment(ProfileEvents::InsertedRows, local_progress.read_rows); + ProfileEvents::increment(ProfileEvents::InsertedBytes, local_progress.read_bytes); + + if (process_list_elem) + { + process_list_elem->updateProgressOut(Progress(local_progress)); + } + + if (progress_callback) + { + progress_callback(Progress(local_progress)); + } + } + has_output = true; } void TableFinishTransform::onFinish() { + TransactionCnchPtr txn = context->getCurrentTransaction(); + txn->setMainTableUUID(storage->getStorageUUID()); + if (const auto * cnch_table = dynamic_cast(storage.get()); - cnch_table && cnch_table->commitTxnFromWorkerSide(cnch_table->getInMemoryMetadataPtr(), context)) + cnch_table && 
cnch_table->commitTxnInWriteSuffixStage(txn->getDedupImplVersion(context), context)) { /// for unique table, insert select|infile is committed from worker side - /// TODO: should also commit in server side } else - { - TransactionCnchPtr txn = context->getCurrentTransaction(); - txn->setMainTableUUID(storage->getStorageUUID()); txn->commitV2(); - } + /// Make sure locks are release after transaction commit if (!lock_holders.empty()) lock_holders.clear(); diff --git a/src/Processors/Transforms/TableFinishTransform.h b/src/Processors/Transforms/TableFinishTransform.h index c0bcfb44052..a4da03b286e 100644 --- a/src/Processors/Transforms/TableFinishTransform.h +++ b/src/Processors/Transforms/TableFinishTransform.h @@ -4,6 +4,7 @@ #include #include #include +#include "Processors/Sources/SourceWithProgress.h" namespace DB { @@ -11,7 +12,12 @@ namespace DB class TableFinishTransform : public IProcessor { public: - TableFinishTransform(const Block & header_, const StoragePtr & storage_, const ContextPtr & context_, ASTPtr & query_); + TableFinishTransform( + const Block & header_, + const StoragePtr & storage_, + const ContextPtr & context_, + ASTPtr & query_, + bool insert_select_with_profiles_ = false); String getName() const override { @@ -30,6 +36,12 @@ class TableFinishTransform : public IProcessor return output; } + void setProcessListElement(QueryStatus * elem); + void setProgressCallback(const ProgressCallback & callback) + { + progress_callback = callback; + } + private: void consume(Chunk block); void onFinish(); @@ -40,14 +52,18 @@ class TableFinishTransform : public IProcessor Block header; - Chunk current_chunk; + Chunk current_output_chunk; Chunk output_chunk; bool has_input = false; bool has_output = false; + ProgressCallback progress_callback; + QueryStatus * process_list_elem = nullptr; + StoragePtr storage; ContextPtr context; ASTPtr query; + bool insert_select_with_profiles; CnchLockHolderPtrs lock_holders; }; diff --git 
a/src/Processors/tests/gtest_bucket_shuffle.cpp b/src/Processors/tests/gtest_bucket_shuffle.cpp new file mode 100644 index 00000000000..d11777fc018 --- /dev/null +++ b/src/Processors/tests/gtest_bucket_shuffle.cpp @@ -0,0 +1,125 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +using namespace DB; + + +Block generateBlockWithTwoColumns(size_t total_rows) +{ + auto col_uint64 = ColumnUInt64::create(total_rows, 0); + auto & col_uint64_data = col_uint64->getData(); + auto col_string = ColumnString::create(); + for (size_t i = 0; i < total_rows; i++) + { + col_uint64_data[i] = i; + String str = "bucket_" + std::to_string(i) ; + col_string->insertData(str.data(), str.size()); + } + ColumnWithTypeAndName column_1{std::move(col_uint64), std::make_shared(), "column_1"}; + ColumnWithTypeAndName column_2{std::move(col_string), std::make_shared(), "column_2"}; + + + ColumnsWithTypeAndName columns; + columns.emplace_back(std::move(column_1)); + columns.emplace_back(std::move(column_2)); + return Block(columns); +} + + +bool comparePrepareBucketColumnWithBucketFunction(Block & expected, ColumnPtr result) +{ + auto expected_col = expected.getByName(COLUMN_BUCKET_NUMBER).column; + + if (expected_col->size() != result->size()) + return false; + for(size_t i = 0; i < expected_col->size(); i++) + { + if(expected_col->getUInt(i) != result->getUInt(i)) + return false; + } + return true; +} + +ColumnPtr executeBucketFunction(Block & block, const Names & bucket_columns, const Int64 & split_number, const bool is_with_range, const Int64 total_shard_num, ContextPtr context) +{ + String func_name = "sipHashBuitin"; + if(split_number && bucket_columns.size() == 1) + func_name = "dtspartition"; + Array params; + params.emplace_back(Field(func_name)); + params.emplace_back(Field(static_cast(total_shard_num))); + params.emplace_back(Field(is_with_range)); + 
params.emplace_back(Field(static_cast(split_number))); + + + ColumnsWithTypeAndName arguments; + for (const auto & name: bucket_columns) + { + arguments.push_back(block.getByName(name)); + } + + auto func = RepartitionTransform::getRepartitionHashFunction("bucket", arguments, context, params); + return func->execute(arguments, RepartitionTransform::REPARTITION_FUNC_RESULT_TYPE, block.rows(), false); +} + + +bool executeAndComparePrepareBucketColumnWithBucketFunction( + Block & block, + Names bucket_columns, + const Int64 & split_number, + const bool is_with_range, + const Int64 total_shard_num, + ContextPtr context) +{ + auto expected = block; + prepareBucketColumn(expected, bucket_columns, split_number, is_with_range, total_shard_num, context, false); + auto result = executeBucketFunction(block, bucket_columns, split_number, is_with_range, total_shard_num, context); + return comparePrepareBucketColumnWithBucketFunction(expected, result); +} + +TEST(BucketShuffleTest, BucketFunctionTest) +{ + tryRegisterFunctions(); + auto block = generateBlockWithTwoColumns(5); + auto local_context = Context::createCopy(getContext().context); + + + // sipHashBuitin + ASSERT_TRUE(executeAndComparePrepareBucketColumnWithBucketFunction(block, {"column_1"}, 0, false, 33, local_context)); + ASSERT_TRUE(executeAndComparePrepareBucketColumnWithBucketFunction(block, {"column_2"}, 0, false, 100, local_context)); + ASSERT_TRUE(executeAndComparePrepareBucketColumnWithBucketFunction(block, {"column_1", "column_2"}, 300, false, 100, local_context)); + ASSERT_TRUE(executeAndComparePrepareBucketColumnWithBucketFunction(block, {"column_1", "column_2"}, 50, true, 100, local_context)); + ASSERT_TRUE(executeAndComparePrepareBucketColumnWithBucketFunction(block, {"column_1", "column_2"}, 50, true, 300, local_context)); + ASSERT_TRUE(executeAndComparePrepareBucketColumnWithBucketFunction(block, {"column_1", "column_2"}, 50, true, 1, local_context)); + 
ASSERT_TRUE(executeAndComparePrepareBucketColumnWithBucketFunction(block, {"column_1", "column_2"}, 1, true, 50, local_context)); + + + // dtspartition + ASSERT_TRUE(executeAndComparePrepareBucketColumnWithBucketFunction(block, {"column_1"}, 300, false, 100, local_context)); + ASSERT_TRUE(executeAndComparePrepareBucketColumnWithBucketFunction(block, {"column_1"}, 400, true, 99, local_context)); + ASSERT_TRUE(executeAndComparePrepareBucketColumnWithBucketFunction(block, {"column_2"}, 400, true, 99, local_context)); + ASSERT_TRUE(executeAndComparePrepareBucketColumnWithBucketFunction(block, {"column_2"}, 99, true, 101, local_context)); + ASSERT_TRUE(executeAndComparePrepareBucketColumnWithBucketFunction(block, {"column_2"}, 99, true, 1, local_context)); + ASSERT_TRUE(executeAndComparePrepareBucketColumnWithBucketFunction(block, {"column_2"}, 1, true, 102, local_context)); + + // fail test + auto expected = block; + prepareBucketColumn(expected, {"column_1"}, 0, false, 33, local_context, false); + auto result = executeBucketFunction(block, {"column_1"}, 0, false, 37, local_context); + ASSERT_FALSE(comparePrepareBucketColumnWithBucketFunction(expected, result)); +} + + diff --git a/src/Protos/DataModelHelpers.cpp b/src/Protos/DataModelHelpers.cpp index 8b96c9d6ef2..7f4fcdba5b4 100644 --- a/src/Protos/DataModelHelpers.cpp +++ b/src/Protos/DataModelHelpers.cpp @@ -35,6 +35,7 @@ #include "common/logger_useful.h" #include #include +#include #include #include #include @@ -533,6 +534,15 @@ ServerDataPartsVector createServerPartsFromDataParts(const MergeTreeMetaBase & s return res; } +ServerDataPartsVector createServerPartsFromDataParts(const MergeTreeMetaBase & storage, const MutableMergeTreeDataPartsCNCHVector & parts) +{ + ServerDataPartsVector res; + res.reserve(parts.size()); + for (const auto & part : parts) + res.push_back(createServerPartFromDataPart(storage, part)); + return res; +} + IMergeTreeDataPartsVector createPartVectorFromServerParts( const MergeTreeMetaBase 
& storage, const ServerDataPartsVector & parts) { diff --git a/src/Protos/DataModelHelpers.h b/src/Protos/DataModelHelpers.h index 1b8ff6a1f90..ca205411abf 100644 --- a/src/Protos/DataModelHelpers.h +++ b/src/Protos/DataModelHelpers.h @@ -349,6 +349,7 @@ ServerDataPartsVector createServerPartsFromModels(const MergeTreeMetaBase & storage, const pb::RepeatedPtrField & parts_model); ServerDataPartsVector createServerPartsFromDataParts(const MergeTreeMetaBase & storage, const MergeTreeDataPartsCNCHVector & parts); +ServerDataPartsVector createServerPartsFromDataParts(const MergeTreeMetaBase & storage, const MutableMergeTreeDataPartsCNCHVector & parts); IMergeTreeDataPartsVector createPartVectorFromServerParts( const MergeTreeMetaBase & storage, diff --git a/src/Protos/RPCHelpers.h b/src/Protos/RPCHelpers.h index 44b0c9ac8a2..b274572cc57 100644 --- a/src/Protos/RPCHelpers.h +++ b/src/Protos/RPCHelpers.h @@ -243,20 +243,4 @@ namespace DB::RPCHelpers { return std::make_shared(worekr_info_data.worker_id(), worekr_info_data.num_workers(), worekr_info_data.index()); } - - inline void fillWorkerInfo(Protos::WorkerInfo & worekr_info_data, const String & worker_id, UInt64 num_workers) - { - /// TODO: Since worker IDs have the same format {commonprefix}-{index}, we can have a specific function to resolve worker index - if (auto pos = worker_id.find_last_of('-'); pos != String::npos) - { - worekr_info_data.set_index(std::stoul(worker_id.substr(pos + 1))); - } - else - { - // set an invalid index if cannot parse index from workerID - worekr_info_data.set_index(num_workers); - } - worekr_info_data.set_worker_id(worker_id); - worekr_info_data.set_num_workers(num_workers); - } } diff --git a/src/Protos/cnch_server_rpc.proto b/src/Protos/cnch_server_rpc.proto index 51ba12d1b0b..20acba0b13f 100644 --- a/src/Protos/cnch_server_rpc.proto +++ b/src/Protos/cnch_server_rpc.proto @@ -114,11 +114,13 @@ message CommitPartsReq { // Binlog will be committed with parts while sync 
MaterializedMySQL optional MySQLBinlogModel binlog = 14; optional uint64 peak_memory_usage = 15; + optional uint32 dedup_mode = 16; }; message CommitPartsResp { optional string exception = 1; optional uint64 commit_timestamp = 2; + optional uint32 dedup_impl_version = 3; } message SubmitPreloadTaskReq { @@ -180,6 +182,7 @@ message FetchPartitionsReq { optional string predicate = 4; repeated string column_name_filter = 5; optional uint64 txnid = 6; + optional bool ignore_ttl = 7; } message FetchPartitionsResp { @@ -281,6 +284,13 @@ message CleanTransactionReq { message CleanTransactionResp { optional string exception = 1; } +/// Called by CleanTransaction. +message CleanUndoBuffersReq { + required DataModelTransactionRecord txn_record = 1; +}; + +message CleanUndoBuffersResp { optional string exception = 1; }; + message GetServerStartTimeReq {} message GetServerStartTimeResp @@ -515,6 +525,18 @@ message handleRefreshTaskOnFinishResp optional string exception = 1; } +message GetDedupImplVersionReq +{ + required uint64 txn_id = 1; + required UUID uuid = 2; +} + +message GetDedupImplVersionResp +{ + required uint32 version = 1; + optional string exception = 2; +} + service CnchServerService { rpc reportTaskHeartbeat(ReportTaskHeartbeatReq) returns(ReportTaskHeartbeatResp); @@ -579,6 +601,7 @@ service CnchServerService { returns(CommitWorkerRPCByKeyResp); rpc cleanTransaction(CleanTransactionReq) returns(CleanTransactionResp); + rpc cleanUndoBuffers(CleanUndoBuffersReq) returns(CleanUndoBuffersResp); rpc acquireLock(AcquireLockReq) returns(AcquireLockResp); @@ -654,4 +677,7 @@ service CnchServerService { rpc handleRefreshTaskOnFinish(handleRefreshTaskOnFinishReq) returns (handleRefreshTaskOnFinishResp); + + rpc getDedupImplVersion(GetDedupImplVersionReq) + returns (GetDedupImplVersionResp); }; diff --git a/src/Protos/cnch_worker_rpc.proto b/src/Protos/cnch_worker_rpc.proto index 99cd8a50906..71da188d569 100644 --- a/src/Protos/cnch_worker_rpc.proto +++ 
b/src/Protos/cnch_worker_rpc.proto @@ -217,6 +217,7 @@ message PreloadDataPartsReq optional bool sync = 4; optional uint64 preload_level = 5; optional uint64 submit_ts = 6; + optional int64 read_injection = 7; } message PreloadDataPartsResp @@ -416,28 +417,6 @@ message SendDataPartsResp optional string exception = 1; } -message SendCnchHiveDataPartsReq -{ - -} - -message SendCnchHiveDataPartsResp -{ -} - -message SendCnchFileDataPartsReq -{ - required uint64 txn_id = 1; - required string database_name = 2; - required string table_name = 3; - repeated CnchFilePartModel parts = 4; -} - -message SendCnchFileDataPartsResp -{ - optional string exception = 1; -} - message CheckDataPartsReq { required uint64 txn_id = 1; @@ -490,19 +469,39 @@ message TableDataParts optional uint64 table_version = 11; } +// Send original (cnch) table definition and override to worker, in order to +// 1. remove server's parsing & formatting overhead +// before +// server: parse(create query) -> rewrite(ast) -> format(ast) -> send(new create query) +// worker: parse(new create query) -> create table(ast) +// after +// server: send(create query, override) +// worker: parse(create query) -> rewrite(ast) -> create table(ast, override) +// 2. 
be able to cache table template at worker +message CacheableTableDefinition +{ + required StorageID storage_id = 1; + required string definition = 2; + optional string dynamic_object_column_schema = 3; // present if not empty + required uint32 local_engine_type = 4; // WorkerEngineType + required string local_table_name = 5; + optional string local_underlying_dictionary_tables = 6; // for bitengine +} + message SendResourcesReq { required uint64 txn_id = 1; required uint64 primary_txn_id = 2; required uint64 timeout = 3; - /// create queries repeated string create_queries = 4; - /// data parts repeated TableDataParts data_parts = 5; optional string disk_cache_mode = 6; repeated UDFInfo udf_infos = 7; repeated string dynamic_object_column_schema = 8; optional WorkerInfo worker_info = 9; + // can coexist with `create_queries' + repeated CacheableTableDefinition cacheable_create_queries = 10; + optional string session_timezone = 11; } message SendResourcesResp @@ -551,6 +550,28 @@ message CheckMySQLSyncThreadStatusResp optional bool is_running = 2; } +message ExecuteDedupTaskReq +{ + required uint64 txn_id = 1; + required uint32 rpc_port = 2; + required UUID table_uuid = 3; + repeated DataModelPart new_parts = 4; + repeated string new_parts_paths = 5; + repeated DataModelDeleteBitmap delete_bitmaps_for_new_parts = 6; + repeated DataModelPart staged_parts = 7; + repeated string staged_parts_paths = 8; + repeated DataModelDeleteBitmap delete_bitmaps_for_staged_parts = 9; + repeated DataModelPart visible_parts = 10; + repeated string visible_parts_paths = 11; + repeated DataModelDeleteBitmap delete_bitmaps_for_visible_parts = 12; + required uint32 dedup_mode = 13; +} + +message ExecuteDedupTaskResp +{ + optional string exception = 1; +} + service CnchWorkerService { rpc executeSimpleQuery(ExecuteSimpleQueryReq) returns (ExecuteSimpleQueryResp); @@ -587,13 +608,12 @@ service CnchWorkerService rpc getDedupWorkerStatus(GetDedupWorkerStatusReq) returns 
(GetDedupWorkerStatusResp); rpc sendCreateQuery(SendCreateQueryReq) returns (SendCreateQueryResp); - rpc sendQueryDataParts(SendDataPartsReq) returns (SendDataPartsResp); - rpc sendCnchHiveDataParts(SendCnchHiveDataPartsReq) returns (SendCnchHiveDataPartsResp); - rpc sendCnchFileDataParts(SendCnchFileDataPartsReq) returns (SendCnchFileDataPartsResp); rpc checkDataParts(CheckDataPartsReq) returns (CheckDataPartsResp); rpc sendOffloading(SendOffloadingReq) returns (SendOffloadingResp); rpc sendResources(SendResourcesReq) returns (SendResourcesResp); rpc removeWorkerResource(RemoveWorkerResourceReq) returns (RemoveWorkerResourceResp); rpc preloadDataParts(PreloadDataPartsReq) returns (PreloadDataPartsResp); rpc dropPartDiskCache(DropPartDiskCacheReq) returns (DropPartDiskCacheResp); + + rpc executeDedupTask(ExecuteDedupTaskReq) returns (ExecuteDedupTaskResp); } diff --git a/src/Protos/data_models.proto b/src/Protos/data_models.proto index a22533755e7..25291a3337f 100644 --- a/src/Protos/data_models.proto +++ b/src/Protos/data_models.proto @@ -393,6 +393,11 @@ message SQLBinding { required string tenant_id = 6; } +message PreparedStatementItem { + required string name = 1; + required string create_statement = 2; +} + message VirtualWarehouseSettings { /// basic information /// // {READ, WRITE, TASK} @@ -766,6 +771,7 @@ message DataModelAccessEntity optional string create_sql = 2; // contains CREATE and GRANT queries of the entity optional uint64 commit_time = 3; optional UUID uuid = 4; + optional string sensitive_sql = 5; // contains GRANT queries that explicitly state the resource } message DataModelSensitiveDatabase @@ -811,3 +817,10 @@ message ManifestListModel repeated uint64 txn_ids = 2; optional bool checkpoint = 3; } + +message DataModelLargeKVMeta +{ + required bytes uuid = 1; //uuid of the large KV + required uint64 subkv_number = 2; + optional uint64 value_size = 3; // record the value size of the large KV +} diff --git a/src/Protos/enum.proto 
b/src/Protos/enum.proto index fc290264c4b..9567e54320d 100644 --- a/src/Protos/enum.proto +++ b/src/Protos/enum.proto @@ -182,3 +182,11 @@ message TopNFilteringAlgorithm { Heap = 3; } } + +message ReportProfileType { + enum Enum { + Unspecified = 0; + QueryPlan = 1; + QueryPipeline = 2; + } +} diff --git a/src/Protos/plan_node.proto b/src/Protos/plan_node.proto index 21e14a875b7..3b233aca89e 100644 --- a/src/Protos/plan_node.proto +++ b/src/Protos/plan_node.proto @@ -89,6 +89,7 @@ message SortingStep { FULL = 0; MERGE = 1; PARTIAL = 2; + PARTIAL_NO_MERGE = 3; } } @@ -207,6 +208,7 @@ message TableWriteStep { required ITransformingStep query_plan_base = 1; required Target target = 2; + optional bool insert_select_with_profiles = 3; } message TableFinishStep { @@ -269,6 +271,7 @@ message JoinStep { required bool is_magic = 17; required bool is_ordered = 18; repeated RuntimeFilterBuilders runtime_filter_builders = 19; + repeated bool key_ids_null_safe = 25; } message MergeSortingStep { diff --git a/src/Protos/plan_node_utils.proto b/src/Protos/plan_node_utils.proto index 5bcfde77164..b7209f7a300 100644 --- a/src/Protos/plan_node_utils.proto +++ b/src/Protos/plan_node_utils.proto @@ -109,6 +109,7 @@ message Partitioning { required bool enforce_round_robin = 5; required Component.Enum component = 6; required bool exactly_match = 7; + optional AST bucket_expr = 8; } // possibly nullptr @@ -264,6 +265,7 @@ message PlanSegmentOutput { optional string shuffle_hash_function = 2; optional uint32 parallel_size = 3; optional bool keep_order = 4; + optional FieldVector shuffle_function_parameters = 5; } message WindowFrame { @@ -346,3 +348,20 @@ message SettingChange { message SettingsChanges { repeated SettingChange settings_changes = 1; } + +message RuntimeAttributeDescription { + required string description = 1; + repeated NameWithAliasPair details = 2; + + // If the attribute information is complex, can use json + optional string additional = 3; +} + +message 
InputProfileMetric{ + required uint64 id = 1; + required uint64 input_rows = 2; + required uint64 input_bytes = 3; + required uint32 input_wait_sum_elapsed_us = 4; + required uint32 input_wait_max_elapsed_us = 5; + required uint32 input_wait_min_elapsed_us = 6; +} diff --git a/src/Protos/plan_segment_manager.proto b/src/Protos/plan_segment_manager.proto index da78625b4c3..dcc6f979acb 100644 --- a/src/Protos/plan_segment_manager.proto +++ b/src/Protos/plan_segment_manager.proto @@ -1,9 +1,11 @@ syntax = "proto2"; + package DB.Protos; import "data_models.proto"; import "plan_node_utils.proto"; import "plan_node.proto"; +import "enum.proto"; option cc_generic_services = true; @@ -17,6 +19,7 @@ message PlanSegment { optional uint32 segment_id = 7; optional string query_id = 8; optional AddressInfo coordinator_address = 9; + optional ReportProfileType.Enum profile_type = 10; } message CancelQueryRequest { @@ -116,6 +119,7 @@ message QueryCommon { optional TraceMeta trace_meta = 12; optional bool check_session = 13; + optional uint64 query_expiration_timestamp = 14; } message SubmitPlanSegmentRequest { @@ -192,6 +196,48 @@ message SendProgressRequest { optional Progress progress = 4; } +message ProfileMetric { + required uint64 id = 1; + required string name = 2; + repeated uint64 children_ids = 3; + required uint32 parallel_size = 4; + + required uint32 sum_elapsed_us = 5; + required uint32 max_elapsed_us = 6; + required uint32 min_elapsed_us = 7; + + required uint64 output_rows = 8; + required uint64 output_bytes = 9; + required uint32 output_wait_sum_elapsed_us = 10; + required uint32 output_wait_max_elapsed_us = 11; + required uint32 output_wait_min_elapsed_us = 12; + + repeated InputProfileMetric inputs = 13; + + map attributes = 14; +} + +message PlanSegmentProfileRequest{ + required string query_id = 1; + required uint32 segment_id = 2; + required bool is_succeed = 3; + required string worker_address = 4; + + optional uint64 profile_root_id = 6; + map profiles = 
7; + + optional uint64 read_rows = 8; + optional uint64 read_bytes = 9; + optional uint64 total_cpu_ms = 10; + optional uint64 query_duration_ms = 11; + optional uint64 io_wait_ms = 12; + optional string error_message = 13; +} + +message PlanSegmentProfileResponse { + optional string message = 1; +} + message SendProgressResponse { optional string message = 1; } @@ -211,4 +257,6 @@ service PlanSegmentManagerService { rpc batchReportProcessorProfileMetrics(BatchReportProcessorProfileMetricRequest) returns (ReportProcessorProfileMetricResponse); rpc sendProgress(SendProgressRequest) returns (SendProgressResponse); + + rpc sendPlanSegmentProfile(PlanSegmentProfileRequest) returns (PlanSegmentProfileResponse); }; diff --git a/src/QueryPlan/AggregatingStep.cpp b/src/QueryPlan/AggregatingStep.cpp index d58c3760444..4ba584d78e2 100644 --- a/src/QueryPlan/AggregatingStep.cpp +++ b/src/QueryPlan/AggregatingStep.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -41,11 +42,12 @@ #include #include #include -#include -#include "Core/SettingsEnums.h" -#include +#include #include -#include +#include + +#include +#include namespace DB { @@ -232,11 +234,11 @@ AggregatingStep::createParams(Block header_before_aggregation, AggregateDescript return Aggregator::Params( - header_before_aggregation, keys, aggregates, overflow_row, 0, OverflowMode::THROW, - 0, - 0, - 0, - false, + header_before_aggregation, keys, aggregates, overflow_row, 0, OverflowMode::THROW, + 0, + 0, + 0, + false, 10485760, false, nullptr, 0, 0, false, 0); } @@ -322,6 +324,15 @@ AggregatingStep::AggregatingStep( , no_shuffle(no_shuffle_) { + NameSet output_names; + for (const auto & key : keys) + if (!output_names.emplace(key).second) + throw Exception(ErrorCodes::LOGICAL_ERROR, "duplicate group by key: {}", key); + + for (const auto & aggregate : params.aggregates) + if (!output_names.emplace(aggregate.column_name).second) + throw Exception(ErrorCodes::LOGICAL_ERROR, "duplicate 
aggreagte function output name: {}", aggregate.column_name); + // final = final && !totals && !cube & !rollup; setInputStreams(input_streams); } diff --git a/src/QueryPlan/ExpandStep.cpp b/src/QueryPlan/ExpandStep.cpp index 5e43cc6f009..d3969849507 100644 --- a/src/QueryPlan/ExpandStep.cpp +++ b/src/QueryPlan/ExpandStep.cpp @@ -42,6 +42,8 @@ ExpandStep::ExpandStep( { if (unlikely(!name_to_type[item.first])) throw Exception(ErrorCodes::LOGICAL_ERROR, "ExpandStep miss type info for column " + item.first); + if (unlikely(!input_stream_.header.has(item.first))) + throw Exception(ErrorCodes::LOGICAL_ERROR, "ExpandStep miss input column " + item.first); output_stream->header.insert(ColumnWithTypeAndName{name_to_type[item.first], item.first}); } auto group_id_symbol_type = std::make_shared(); @@ -54,9 +56,7 @@ void ExpandStep::setInputStreams(const DataStreams & input_streams_) Block block; for (auto & input : input_streams[0].header) - { - block.insert(ColumnWithTypeAndName{JoinCommon::tryConvertTypeToNullable(input.type), input.name}); - } + block.insert(ColumnWithTypeAndName{input.type, input.name}); output_stream->header = block; auto group_id_symbol_type = std::make_shared(); output_stream->header.insert(ColumnWithTypeAndName{group_id_symbol_type, group_id_symbol}); diff --git a/src/QueryPlan/FilterStep.cpp b/src/QueryPlan/FilterStep.cpp index f410b489d05..dea55faf738 100644 --- a/src/QueryPlan/FilterStep.cpp +++ b/src/QueryPlan/FilterStep.cpp @@ -13,6 +13,7 @@ * limitations under the License. 
*/ +#include #include #include @@ -30,6 +31,7 @@ #include #include #include +#include namespace DB { @@ -210,4 +212,66 @@ void FilterStep::prepare(const PreparedStatementContext & prepared_context) { prepared_context.prepare(filter); } + +std::pair FilterStep::splitLargeInValueList(const ConstASTPtr & filter, UInt64 limit) +{ + std::vector removed_large_in_value_list; + std::vector large_in_value_list; + for (auto & predicate : PredicateUtils::extractConjuncts(filter)) + { + LOG_DEBUG(&Poco::Logger::get("FilterStep"), " predicate : {}", predicate->formatForErrorMessage()); + + if (predicate->as() && + (predicate->as().name == "in" || + predicate->as().name == "globalIn" || + predicate->as().name == "notIn" || + predicate->as().name == "globalNotIn")) + { + const auto & function = predicate->as(); + if (function.arguments->getChildren()[1]->as()) + { + ASTFunction & tuple = function.arguments->getChildren()[1]->as(); + size_t size = tuple.arguments->getChildren().size(); + if (size > limit) + { + large_in_value_list.emplace_back(predicate); + continue; + } + } + } + removed_large_in_value_list.emplace_back(predicate); + } + + return std::make_pair( + PredicateUtils::combineConjuncts(removed_large_in_value_list), PredicateUtils::combineConjuncts(large_in_value_list)); +} + +std::vector FilterStep::removeLargeInValueList(const std::vector & filters, UInt64 limit) +{ + std::vector removed_large_in_value_list; + for (const auto & predicate : filters) + { + if (predicate->as() && + (predicate->as().name == "in" || + predicate->as().name == "globalIn" || + predicate->as().name == "notIn" || + predicate->as().name == "globalNotIn") + ) + { + const auto & function = predicate->as(); + if (function.arguments->getChildren()[1]->as()) + { + ASTFunction & tuple = function.arguments->getChildren()[1]->as(); + size_t size = tuple.arguments->getChildren().size(); + if (size > limit) + { + continue; + } + } + } + removed_large_in_value_list.emplace_back(predicate); + } + return 
removed_large_in_value_list; +} + } diff --git a/src/QueryPlan/FilterStep.h b/src/QueryPlan/FilterStep.h index bca2a1fada1..8fc9a5a8cf0 100644 --- a/src/QueryPlan/FilterStep.h +++ b/src/QueryPlan/FilterStep.h @@ -15,6 +15,7 @@ #pragma once #include +#include namespace DB { @@ -61,6 +62,8 @@ class FilterStep : public ITransformingStep void prepare(const PreparedStatementContext & prepared_context) override; + static std::pair splitLargeInValueList(const ConstASTPtr & filter, UInt64 limit); + static std::vector removeLargeInValueList(const std::vector & filters, UInt64 limit); private: ActionsDAGPtr actions_dag; ConstASTPtr filter; diff --git a/src/QueryPlan/GraphvizPrinter.cpp b/src/QueryPlan/GraphvizPrinter.cpp index 0fae1de6f5c..d49045e1e8c 100644 --- a/src/QueryPlan/GraphvizPrinter.cpp +++ b/src/QueryPlan/GraphvizPrinter.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -51,6 +52,7 @@ #include #include #include +#include #include #include @@ -579,15 +581,18 @@ void PlanNodePrinter::printNode( out << "Actual Stats \\n"; out << "Output: " << PlanPrinter::TextPrinter::prettyNum(profile->output_rows) << " rows(" << PlanPrinter::TextPrinter::prettyBytes(profile->output_bytes) << "). 
" - << " Wait Time: " << PlanPrinter::TextPrinter::prettySeconds(profile->max_output_wait_elapsed_us) + << " Wait Time: " << PlanPrinter::TextPrinter::prettySeconds(profile->output_wait_max_elapsed_us) << " Wall Time: " << PlanPrinter::TextPrinter::prettySeconds(profile->max_elapsed_us) << " \\n"; - if (!node.getChildren().empty() && profile->inputs_profile.contains(node.getChildren()[0]->getId())) + if (!node.getChildren().empty() && profile->inputs.contains(node.getChildren()[0]->getId())) { if (node.getChildren().size() == 1) { out << "Input: "; - out << PlanPrinter::TextPrinter::prettyNum(profile->inputs_profile[node.getChildren()[0]->getId()].input_rows) - << " rows \\n"; + out << PlanPrinter::TextPrinter::prettyNum(profile->inputs[node.getChildren()[0]->getId()].input_rows) << " rows(" + << PlanPrinter::TextPrinter::prettyBytes(profile->inputs[node.getChildren()[0]->getId()].input_bytes) << "). " + << "Wait Time: " + << PlanPrinter::TextPrinter::prettySeconds(profile->inputs[node.getChildren()[0]->getId()].input_wait_max_elapsed_us) + << " \\n"; } else { @@ -595,9 +600,10 @@ void PlanNodePrinter::printNode( out << "Input: \\n"; for (const auto & child : node.getChildren()) { - auto input_profile = profile->inputs_profile[child->getId()]; - out << "source [" << num << "] : "; - out << PlanPrinter::TextPrinter::prettyNum(input_profile.input_rows) << " rows \\n"; + auto input_profile = profile->inputs[child->getId()]; + out << "source [" << num << "] : " << PlanPrinter::TextPrinter::prettyNum(input_profile.input_rows) << " rows(" + << PlanPrinter::TextPrinter::prettyBytes(input_profile.input_bytes) << "). 
" + << "Wait Time: " << PlanPrinter::TextPrinter::prettySeconds(input_profile.input_wait_max_elapsed_us) << " \\n"; ++num; } } @@ -1234,7 +1240,8 @@ String StepPrinter::printStep(const IQueryPlanStep & step, bool include_output) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } } return details.str(); @@ -1314,7 +1321,7 @@ String StepPrinter::printExpandStep(const ExpandStep & step, bool) } std::string result = ss.str(); - details << step.getGroupIdSymbol() << "[" << result << "]"; + details << step.getGroupIdSymbol() << "[" << result << "]"; details << "|"; details << "Groups"; details << "|"; @@ -1350,7 +1357,8 @@ String StepPrinter::printFilterStep(const FilterStep & step, bool include_output for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } } @@ -1437,7 +1445,7 @@ String StepPrinter::printJoinStep(const JoinStep & step) details << "JoinKeys\\n"; for (int i = 0; i < static_cast(left.size()); ++i) { - details << left.at(i) << "=" << right.at(i) << "\\n"; + details << left.at(i) << "=" << right.at(i) << (step.getKeyIdNullSafe(i) ? "(null aware)" : "") << "\\n"; } details << "|"; if (!PredicateUtils::isTruePredicate(step.getFilter())) @@ -1469,6 +1477,13 @@ String StepPrinter::printJoinStep(const JoinStep & step) details << "isOrdered:" << step.isOrdered() << "|"; } + if (step.isHasUsing()) + { + auto require_right_keys = step.getRequireRightKeys(); + auto using_str = require_right_keys ? 
fmt::format("{}", fmt::join(*require_right_keys, ",")) : "nullopt"; + details << "hasUsing:" << using_str << "|"; + } + if (!step.getRuntimeFilterBuilders().empty()) { details << "Runtime Filters \\n"; @@ -1482,7 +1497,8 @@ String StepPrinter::printJoinStep(const JoinStep & step) for (const auto & item : step.getOutputStream().header) { details << item.name << ":"; - details << item.type->getName() << "\\n"; + details << item.type->getName() << " "; + details << (item.column ? item.column->getName() : "") << "\\n"; } return details.str(); } @@ -1500,7 +1516,8 @@ String StepPrinter::printArrayJoinStep(const ArrayJoinStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -1591,7 +1608,8 @@ String StepPrinter::printAggregatingStep(const AggregatingStep & step, bool incl for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } } @@ -1635,7 +1653,8 @@ String StepPrinter::printMarkDistinctStep(const MarkDistinctStep & step, bool /* for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -1741,7 +1760,8 @@ String StepPrinter::printUnionStep(const UnionStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? 
column.column->getName() : "") << "\\n"; } return details.str(); } @@ -1755,7 +1775,8 @@ String StepPrinter::printIntersectOrExceptStep(const IntersectOrExceptStep & ste for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -1771,7 +1792,8 @@ String StepPrinter::printIntersectStep(const IntersectStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -1785,7 +1807,8 @@ String StepPrinter::printExceptStep(const ExceptStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -1813,6 +1836,8 @@ String StepPrinter::printExchangeStep(const ExchangeStep & step) } }; details << f(step.getExchangeMode()); + details << "|"; + details << step.getSchema().toString(); if (step.needKeepOrder()) { @@ -1830,7 +1855,8 @@ String StepPrinter::printExchangeStep(const ExchangeStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? 
column.column->getName() : "") << "\\n"; } return details.str(); } @@ -1857,7 +1883,8 @@ String StepPrinter::printRemoteExchangeSourceStep(const RemoteExchangeSourceStep for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -1900,7 +1927,8 @@ String StepPrinter::printTableFinishStep(const TableFinishStep & step) for (auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -1989,11 +2017,11 @@ String StepPrinter::printTableScanStep(const TableScanStep & step) { ASTSampleRatio * sample = query->sampleSize()->as(); details << "Sample : \\n"; - details << "Sample Size : " << ASTSampleRatio::toString(sample->ratio)<< "\\n"; + details << "Sample Size : " << ASTSampleRatio::toString(sample->ratio) << "\\n"; if (query->sampleOffset()) { ASTSampleRatio * offset = query->sampleOffset()->as(); - details << "Sample Offset : " << ASTSampleRatio::toString(offset->ratio)<< "\\n"; + details << "Sample Offset : " << ASTSampleRatio::toString(offset->ratio) << "\\n"; } details << "|"; } @@ -2046,7 +2074,8 @@ String StepPrinter::printTableScanStep(const TableScanStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? 
column.column->getName() : "") << "\\n"; } return details.str(); @@ -2138,7 +2167,8 @@ String StepPrinter::printLimitStep(const LimitStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } if (step.isPartial()) details << "|" @@ -2156,7 +2186,8 @@ String StepPrinter::printOffsetStep(const OffsetStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -2174,7 +2205,8 @@ String StepPrinter::printLimitByStep(const LimitByStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -2195,7 +2227,8 @@ String StepPrinter::printMergeSortingStep(const MergeSortingStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? 
column.column->getName() : "") << "\\n"; } return details.str(); } @@ -2236,12 +2269,18 @@ String StepPrinter::printSortingStep(const SortingStep & step) details << "|"; details << "partial"; } + if (step.getStage() == SortingStep::Stage::PARTIAL_NO_MERGE) + { + details << "|"; + details << "partial no merge"; + } details << "|"; details << "Output |"; for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -2263,7 +2302,8 @@ String StepPrinter::printPartialSortingStep(const PartialSortingStep & step) for (auto & column : step_ptr->getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } */ return details.str(); @@ -2287,7 +2327,8 @@ String StepPrinter::printMergingSortedStep(const MergingSortedStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -2308,7 +2349,8 @@ String StepPrinter::printDistinctStep(const DistinctStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? 
column.column->getName() : "") << "\\n"; } return details.str(); } @@ -2372,7 +2414,8 @@ String StepPrinter::printApplyStep(const ApplyStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -2384,7 +2427,8 @@ String StepPrinter::printEnforceSingleRowStep(const EnforceSingleRowStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -2395,7 +2439,8 @@ String StepPrinter::printAssignUniqueIdStep(const AssignUniqueIdStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -2414,7 +2459,8 @@ String StepPrinter::printCTERefStep(const CTERefStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); @@ -2447,7 +2493,8 @@ String StepPrinter::printPartitionTopNStep(const PartitionTopNStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? 
column.column->getName() : "") << "\\n"; } return details.str(); } @@ -2504,7 +2551,8 @@ String StepPrinter::printWindowStep(const WindowStep & step) for (auto & column : step_ptr->getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } */ return details.str(); @@ -2547,7 +2595,8 @@ String StepPrinter::printExplainAnalyzeStep(const ExplainAnalyzeStep & step) for (auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -2571,7 +2620,8 @@ String StepPrinter::printTopNFilteringStep(const TopNFilteringStep & step) for (auto & column : stepPtr->getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } */ return details.str(); @@ -2598,7 +2648,8 @@ String StepPrinter::printFillingStep(const FillingStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -2613,7 +2664,8 @@ String StepPrinter::printTotalsHavingStep(const TotalsHavingStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? 
column.column->getName() : "") << "\\n"; } return details.str(); } @@ -2625,7 +2677,8 @@ String StepPrinter::printExtremesStep(const ExtremesStep & step) for (const auto & column : step.getOutputStream().header) { details << column.name << ":"; - details << column.type->getName() << "\\n"; + details << column.type->getName() << " "; + details << (column.column ? column.column->getName() : "") << "\\n"; } return details.str(); } @@ -2821,12 +2874,11 @@ void GraphvizPrinter::printLogicalPlan(PlanNodeBase & root, ContextMutablePtr & } } -void GraphvizPrinter::printLogicalPlan( - QueryPlan & plan, ContextMutablePtr & context, const String & name, StepAggregatedOperatorProfiles /*profiles*/) +void GraphvizPrinter::printLogicalPlan(QueryPlan & plan, ContextMutablePtr & context, const String & name, StepProfiles profiles) { if (context->getSettingsRef().print_graphviz) { - auto const graphviz = GraphvizPrinter::printLogicalPlan(*plan.getPlanNode(), &plan.getCTEInfo()); + auto const graphviz = GraphvizPrinter::printLogicalPlan(*plan.getPlanNode(), &plan.getCTEInfo(), profiles); cleanDotFiles(context); std::stringstream path; @@ -2909,7 +2961,12 @@ String GraphvizPrinter::printPipeline(const Processors & processors, const Execu buffer << "\\n" << "input: " << input_rows << " rows"; } - + buffer << "\\n" + << "input wait time: " << node->processor->getInputWaitElapsedUs() << " us."; + buffer << "\\n" + << "output wait time: " << node->processor->getOutputWaitElapsedUs() << " us."; + buffer << "\\n" + << "elapsed time: " << node->processor->getElapsedUs() << " us."; buffer << "\\n" << "execution time: " << node->execution_time_ns / 1e9 << " sec."; buffer << "\\n" @@ -3430,7 +3487,7 @@ void GraphvizPrinter::addID(ASTPtr & ast, std::unordered_map & a } } -String GraphvizPrinter::printLogicalPlan(PlanNodeBase & node, CTEInfo * cte_info, StepAggregatedOperatorProfiles profiles) +String GraphvizPrinter::printLogicalPlan(PlanNodeBase & node, CTEInfo * cte_info, StepProfiles 
profiles) { std::stringstream out; out << "digraph logical_plan {\n rankdir=\"BT\" \n"; @@ -3562,6 +3619,11 @@ void GraphvizPrinter::appendPlanSegmentNode(std::stringstream & out, const PlanS out << "keeporder "; } + if (input->isStable()) + { + out << "stable "; + } + out << "\n"; } out << "\n"; @@ -3578,6 +3640,14 @@ void GraphvizPrinter::appendPlanSegmentNode(std::stringstream & out, const PlanS { out << "keeporder "; } + out << "hash_func:" << input->getShuffleFunctionName(); + + auto visitor = FieldVisitorToString(); + out << " params:"; + for (auto item : input->getShuffleFunctionParams()) + { + out << " " << applyVisitor(visitor, item); + } out << "\n"; } out << "\n"; @@ -3706,7 +3776,6 @@ String GraphvizPrinter::printGroup(const Group & group, const std::unordered_map head_step = group.getLogicalExpressions()[0]->getStep().get(); auto fold = [](std::string a, GroupId b) { return std::move(a) + ", " + std::to_string(b); }; - auto fold_string = [](String a, const String & b) { return std::move(a) + ", " + b; }; auto expr_to_str = [&](const GroupExprPtr & expr) { if (!expr) @@ -3848,55 +3917,9 @@ String GraphvizPrinter::printGroup(const Group & group, const std::unordered_map // winners - auto partition_str = [&](const Partitioning & partitioning) { - auto component_str = " ANY"; - if (partitioning.getComponent() == Partitioning::Component::COORDINATOR) - component_str = " COORDINATOR"; - else if (partitioning.getComponent() == Partitioning::Component::WORKER) - component_str = " WORKER"; - - if (partitioning.getHandle() == Partitioning::Handle::SINGLE) - return String("SINGLE") + component_str; - else if (partitioning.getHandle() == Partitioning::Handle::FIXED_BROADCAST) - return String("BROADCAST") + component_str; - else if (partitioning.getHandle() == Partitioning::Handle::ARBITRARY) - return String("ARBITRARY") + component_str; - else if (partitioning.getHandle() == Partitioning::Handle::BUCKET_TABLE) - return String("BUCKET_TABLE") + component_str; - 
else if (partitioning.getHandle() == Partitioning::Handle::FIXED_ARBITRARY) - return String("FIXED_ARBITRARY") + component_str; - else if (partitioning.getHandle() == Partitioning::Handle::FIXED_HASH) - { - if (partitioning.getColumns().empty()) - { - return String("FIXED_HASH[]") + component_str; - } - else - { - auto result = String("FIXED_HASH[") - + std::accumulate( - std::next(partitioning.getColumns().begin()), - partitioning.getColumns().end(), - partitioning.getColumns()[0], - fold_string) - + "]"; - if (partitioning.isEnforceRoundRobin()) - { - result += " RoundR"; - } - if (partitioning.isRequireHandle()) - { - result += " handle"; - } - return result + component_str; - } - } - else - return String("UNKNOWN") + component_str; - }; auto property_str = [&](const Property & property) { std::stringstream ss; - ss << partition_str(property.getNodePartitioning()); + ss << property.getNodePartitioning().toString(); ss << " "; ss << property.getCTEDescriptions().toString(); return ss.str(); @@ -3936,7 +3959,7 @@ String GraphvizPrinter::printGroup(const Group & group, const std::unordered_map if (auto exchange_step = dynamic_cast(winner->getRemoteExchange()->getStep().get())) { out << "enforce: "; - out << partition_str(exchange_step->getSchema()); + out << exchange_step->getSchema().toString(); out << "
"; } } diff --git a/src/QueryPlan/GraphvizPrinter.h b/src/QueryPlan/GraphvizPrinter.h index df12702e602..9fcdb1b9493 100644 --- a/src/QueryPlan/GraphvizPrinter.h +++ b/src/QueryPlan/GraphvizPrinter.h @@ -42,7 +42,7 @@ class PlanNodePrinter : public PlanNodeVisitor bool with_id_ = false, CTEInfo * cte_info = nullptr, PlanCostMap plan_cost_map_ = {}, - StepAggregatedOperatorProfiles profiles_ = {}) + StepProfiles profiles_ = {}) : out(out_) , cte_helper(cte_info ? std::make_optional>(*cte_info) : std::nullopt) , with_id(with_id_) @@ -103,7 +103,7 @@ class PlanNodePrinter : public PlanNodeVisitor std::optional> cte_helper; bool with_id; PlanCostMap plan_cost_map; - StepAggregatedOperatorProfiles profiles; + StepProfiles profiles; void printNode(const PlanNodeBase & node, const String & label, const String & details, const String & color, PrinterContext & context); Void visitChildren(PlanNodeBase &, PrinterContext &); void printHints(const PlanNodeBase & node); @@ -256,7 +256,7 @@ class GraphvizPrinter static void printAST(const ASTPtr &, ContextMutablePtr & context, const String & visitor); static void printLogicalPlan(PlanNodeBase &, ContextMutablePtr &, const String & name); - static void printLogicalPlan(QueryPlan &, ContextMutablePtr &, const String & name, StepAggregatedOperatorProfiles profiles = {}); + static void printLogicalPlan(QueryPlan &, ContextMutablePtr &, const String & name, StepProfiles profiles = {}); static void printMemo(const Memo & memo, const ContextMutablePtr & context, const String & name); static void printMemo(const Memo & memo, GroupId root_id, const ContextMutablePtr & context, const String & name); static void printPlanSegment(const PlanSegmentTreePtr &, const ContextMutablePtr &); @@ -269,7 +269,7 @@ class GraphvizPrinter static String printAST(ASTPtr); static void addID(ASTPtr & ast, std::unordered_map & asts, std::shared_ptr> & max_node_id); - static String printLogicalPlan(PlanNodeBase &, CTEInfo * cte_info = nullptr, 
StepAggregatedOperatorProfiles profiles = {}); + static String printLogicalPlan(PlanNodeBase &, CTEInfo * cte_info = nullptr, StepProfiles profiles = {}); static String printPlanSegmentNodes(const PlanSegmentTreePtr &, const ContextMutablePtr &); static void appendPlanSegmentNodes( std::stringstream & out, diff --git a/src/QueryPlan/Hints/ImplementJoinOperationHints.cpp b/src/QueryPlan/Hints/ImplementJoinOperationHints.cpp index ad3820709c2..4225b6f0410 100644 --- a/src/QueryPlan/Hints/ImplementJoinOperationHints.cpp +++ b/src/QueryPlan/Hints/ImplementJoinOperationHints.cpp @@ -93,6 +93,7 @@ void JoinOperationHintsVisitor::visitJoinNode(JoinNode & node, Void & v) step.getKeepLeftReadInOrder(), step.getRightKeys(), step.getLeftKeys(), + step.getKeyIdsNullSafe(), step.getFilter(), step.isHasUsing(), step.getRequireRightKeys(), @@ -142,6 +143,7 @@ void JoinOperationHintsVisitor::visitJoinNode(JoinNode & node, Void & v) step.getKeepLeftReadInOrder(), step.getRightKeys(), step.getLeftKeys(), + step.getKeyIdsNullSafe(), step.getFilter(), step.isHasUsing(), step.getRequireRightKeys(), diff --git a/src/QueryPlan/Hints/ImplementJoinOrderHints.cpp b/src/QueryPlan/Hints/ImplementJoinOrderHints.cpp index a8ff4809ea6..69bd0185f2b 100644 --- a/src/QueryPlan/Hints/ImplementJoinOrderHints.cpp +++ b/src/QueryPlan/Hints/ImplementJoinOrderHints.cpp @@ -93,7 +93,7 @@ PlanNodePtr JoinOrderHintsVisitor::swapJoinOrder(PlanNodePtr node, SwapOrderPtr kind = ASTTableJoin::Kind::Right; else if (step->getKind() == ASTTableJoin::Kind::Right) kind = ASTTableJoin::Kind::Left; - + DataStreams streams = {step->getInputStreams()[1], step->getInputStreams()[0]}; auto join_step = std::make_shared( streams, @@ -104,6 +104,7 @@ PlanNodePtr JoinOrderHintsVisitor::swapJoinOrder(PlanNodePtr node, SwapOrderPtr step->getKeepLeftReadInOrder(), step->getRightKeys(), step->getLeftKeys(), + step->getKeyIdsNullSafe(), step->getFilter(), step->isHasUsing(), step->getRequireRightKeys(), diff --git 
a/src/QueryPlan/IQueryPlanStep.cpp b/src/QueryPlan/IQueryPlanStep.cpp index ea815d6c3d3..0905bce891b 100644 --- a/src/QueryPlan/IQueryPlanStep.cpp +++ b/src/QueryPlan/IQueryPlanStep.cpp @@ -64,6 +64,31 @@ void DataStream::fillFromProto(const Protos::DataStream & proto) sort_mode = DataStream::SortModeConverter::fromProto(proto.sort_mode()); } +void RuntimeAttributeDescription::fillFromProto(const Protos::RuntimeAttributeDescription & proto) +{ + description = proto.description(); + for (const auto & proto_element : proto.details()) + { + auto name = proto_element.name(); + auto alias = proto_element.alias(); + name_and_detail.emplace_back(name, alias); + } + if (proto.has_additional()) + additional = proto.additional(); +} + +void RuntimeAttributeDescription::toProto(Protos::RuntimeAttributeDescription & proto) const +{ + proto.set_description(description); + for (const auto & [name, detail] : name_and_detail) + { + auto * proto_element = proto.add_details(); + proto_element->set_name(name); + proto_element->set_alias(detail); + } + proto.set_additional(additional); +} + const DataStream & IQueryPlanStep::getOutputStream() const { if (!hasOutputStream()) @@ -282,9 +307,9 @@ String IQueryPlanStep::toString(Type type) return "Unknown"; } -size_t IQueryPlanStep::hash() const +size_t IQueryPlanStep::hash(bool ignore_output_stream) const { - return hashPlanStep(*this); + return hashPlanStep(*this, ignore_output_stream); } } diff --git a/src/QueryPlan/IQueryPlanStep.h b/src/QueryPlan/IQueryPlanStep.h index 65b4b5f2e9a..1425adaafc0 100644 --- a/src/QueryPlan/IQueryPlanStep.h +++ b/src/QueryPlan/IQueryPlanStep.h @@ -14,6 +14,7 @@ */ #pragma once +#include #include #include #include @@ -117,6 +118,19 @@ using ContextPtr = std::shared_ptr; using PlanHints = std::vector; +struct RuntimeAttributeDescription +{ + String description; + std::vector> name_and_detail; + + // If the attribute information is complex, can use json + String additional; + + void fillFromProto(const 
Protos::RuntimeAttributeDescription & proto); + void toProto(Protos::RuntimeAttributeDescription & proto) const; +}; + + /// Single step of query plan. class IQueryPlanStep { @@ -291,7 +305,7 @@ class IQueryPlanStep virtual std::shared_ptr copy(ContextPtr) const = 0; - size_t hash() const; + size_t hash(bool ignore_output_stream = true) const; bool operator==(const IQueryPlanStep & r) const { @@ -302,6 +316,12 @@ class IQueryPlanStep virtual void prepare(const PreparedStatementContext &) { } + + std::unordered_map & getAttributeDescriptions() + { + return attribute_descriptions; + } + protected: DataStreams input_streams; std::optional output_stream; @@ -310,6 +330,9 @@ class IQueryPlanStep std::string step_description; PlanHints hints; + /// Text description of runtime attributes + std::unordered_map attribute_descriptions; + static void describePipeline(const Processors & processors, FormatSettings & settings); }; diff --git a/src/QueryPlan/IntermediateResultCacheStep.cpp b/src/QueryPlan/IntermediateResultCacheStep.cpp index 3d2c97033f6..7710a1c29d1 100644 --- a/src/QueryPlan/IntermediateResultCacheStep.cpp +++ b/src/QueryPlan/IntermediateResultCacheStep.cpp @@ -42,6 +42,14 @@ QueryPipelinePtr IntermediateResultCacheStep::processCacheTransform( if (!cache) return std::move(pipelines[0]); + LOG_DEBUG( + log, + "process cache transform for digest:{}, write:{}, read:{}, all_part_in_cache:{}", + cache_param.digest, + cache_holder->write_cache.size(), + cache_holder->read_cache.size(), + cache_holder->all_part_in_cache); + const auto & settings = build_settings.context->getSettingsRef(); // write cache or skip pipeline if (!cache_holder->write_cache.empty() || cache_holder->all_part_in_cache) @@ -50,7 +58,7 @@ QueryPipelinePtr IntermediateResultCacheStep::processCacheTransform( auto cache_max_bytes = settings.intermediate_result_cache_max_bytes; auto cache_max_rows = settings.intermediate_result_cache_max_rows; return std::make_shared( - header, cache, cache_param, 
cache_max_bytes, cache_max_rows, cache_holder->all_part_in_cache); + header, cache, cache_param, cache_max_bytes, cache_max_rows, cache_holder); }); } diff --git a/src/QueryPlan/JoinStep.cpp b/src/QueryPlan/JoinStep.cpp index 8ec65b4f996..0687d2cb157 100644 --- a/src/QueryPlan/JoinStep.cpp +++ b/src/QueryPlan/JoinStep.cpp @@ -36,6 +36,7 @@ #include #include #include +#include namespace DB { @@ -89,9 +90,9 @@ JoinPtr JoinStep::makeJoin( } else { - table_join->addOnKeys(left, right, false); - // const String fn = null_safe_columns && (*null_safe_columns)[i] ? "bitEquals" : "equals"; - const String fn = "equals"; + bool null_safe = getKeyIdNullSafe(index); + table_join->addOnKeys(left, right, null_safe); + const String fn = null_safe ? "bitEquals" : "equals"; on_ast_terms.emplace_back(makeASTFunction(fn, left, right)); } } @@ -181,7 +182,6 @@ JoinPtr JoinStep::makeJoin( if (consumer) consumer->fixParallel(ConcurrentHashJoin::toPowerOfTwo(std::min(num_streams, 256))); return std::make_shared(table_join, num_streams, context->getSettings().parallel_join_rows_batch_threshold, r_sample_block); - } else if (join_algorithm == JoinAlgorithm::GRACE_HASH && GraceHashJoin::isSupported(table_join) && allow_grace_hash_join) { @@ -251,6 +251,7 @@ JoinStep::JoinStep( bool keep_left_read_in_order_, Names left_keys_, Names right_keys_, + std::vector key_ids_null_safe_, ConstASTPtr filter_, bool has_using_, std::optional> require_right_keys_, @@ -268,6 +269,7 @@ JoinStep::JoinStep( , keep_left_read_in_order(keep_left_read_in_order_) , left_keys(std::move(left_keys_)) , right_keys(std::move(right_keys_)) + , key_ids_null_safe(std::move(key_ids_null_safe_)) , filter(std::move(filter_)) , has_using(has_using_) , require_right_keys(std::move(require_right_keys_)) @@ -289,6 +291,18 @@ JoinStep::JoinStep( hints = std::move(hints_); } +bool JoinStep::hasKeyIdNullSafe() const +{ + return std::any_of(key_ids_null_safe.begin(), key_ids_null_safe.end(), [](auto x) { return x; }); +} + +bool 
JoinStep::getKeyIdNullSafe(size_t key_index) const +{ + if (key_index >= key_ids_null_safe.size()) + return false; + return key_ids_null_safe.at(key_index); +} + void JoinStep::setInputStreams(const DataStreams & input_streams_) { input_streams = input_streams_; @@ -312,9 +326,35 @@ QueryPipelinePtr JoinStep::updatePipeline(QueryPipelines pipelines, const BuildQ if (filter && !PredicateUtils::isTruePredicate(filter)) { Names output; - auto header = input_streams[0].header; + + bool has_outer_join_semantic = settings.context->getSettingsRef().join_use_nulls && + (isAny(getStrictness()) || isAll(getStrictness()) || getStrictness() == ASTTableJoin::Strictness::RightAny || isAsof(getStrictness())); + bool make_nullable_for_left = has_outer_join_semantic && isRightOrFull(getKind()); + bool make_nullable_for_right = has_outer_join_semantic && isLeftOrFull(getKind()); + + Block header; + for (const auto & col : input_streams[0].header) + { + if (make_nullable_for_left && JoinCommon::canBecomeNullable(col.type)) + { + header.insert(ColumnWithTypeAndName{col.column, JoinCommon::convertTypeToNullable(col.type), col.name}); + } + else + { + header.insert(col); + } + } for (const auto & col : input_streams[1].header) - header.insert(col); + { + if (make_nullable_for_right && JoinCommon::canBecomeNullable(col.type)) + { + header.insert(ColumnWithTypeAndName{col.column, JoinCommon::convertTypeToNullable(col.type), col.name}); + } + else + { + header.insert(col); + } + } for (const auto & item : header) output.emplace_back(item.name); output.emplace_back(filter->getColumnName()); @@ -399,6 +439,9 @@ bool JoinStep::supportReorder(bool support_filter, bool support_cross) const if (require_right_keys || has_using) return false; + if (hasKeyIdNullSafe()) + return false; + if (strictness != ASTTableJoin::Strictness::Unspecified && strictness != ASTTableJoin::Strictness::All) return false; @@ -441,6 +484,8 @@ void JoinStep::toProto(Protos::JoinStep & proto, bool for_hash_equals) const 
proto.add_left_keys(element); for (const auto & element : right_keys) proto.add_right_keys(element); + for (bool element : key_ids_null_safe) + proto.add_key_ids_null_safe(element); serializeASTToProto(filter, *proto.mutable_filter()); proto.set_has_using(has_using); proto.set_flag_require_right_keys(require_right_keys.has_value()); @@ -485,6 +530,9 @@ std::shared_ptr JoinStep::fromProto(const Protos::JoinStep & proto, Co std::vector right_keys; for (const auto & element : proto.right_keys()) right_keys.emplace_back(element); + std::vector key_ids_null_safe; + for (const auto & null_safe : proto.key_ids_null_safe()) + key_ids_null_safe.emplace_back(null_safe); auto filter = deserializeASTFromProto(proto.filter()); auto has_using = proto.has_using(); std::optional> require_right_keys; @@ -512,6 +560,7 @@ std::shared_ptr JoinStep::fromProto(const Protos::JoinStep & proto, Co keep_left_read_in_order, left_keys, right_keys, + key_ids_null_safe, filter, has_using, require_right_keys, @@ -585,6 +634,7 @@ std::shared_ptr JoinStep::copy(ContextPtr) const keep_left_read_in_order, left_keys, right_keys, + std::move(key_ids_null_safe), filter, has_using, require_right_keys, diff --git a/src/QueryPlan/JoinStep.h b/src/QueryPlan/JoinStep.h index 8161728a488..4a531a44351 100644 --- a/src/QueryPlan/JoinStep.h +++ b/src/QueryPlan/JoinStep.h @@ -59,6 +59,7 @@ class JoinStep : public IQueryPlanStep bool keep_left_read_in_order_ = false, Names left_keys_ = {}, Names right_keys_ = {}, + std::vector key_ids_null_safe_ = {}, ConstASTPtr filter_ = PredicateConst::TRUE_VALUE, bool has_using_ = false, std::optional> require_right_keys_ = std::nullopt, @@ -91,8 +92,16 @@ class JoinStep : public IQueryPlanStep const Names & getLeftKeys() const { return left_keys; } const Names & getRightKeys() const { return right_keys; } + const std::vector & getKeyIdsNullSafe() const { return key_ids_null_safe; } + bool hasKeyIdNullSafe() const; + bool getKeyIdNullSafe(size_t key_index) const; const 
ConstASTPtr & getFilter() const { return filter; } bool isHasUsing() const { return has_using; } + void resetUsing() + { + has_using = false; + require_right_keys = std::nullopt; + } std::optional> getRequireRightKeys() const { return require_right_keys; } ASOF::Inequality getAsofInequality() const { return asof_inequality; } DistributionType getDistributionType() const { return distribution_type; } @@ -100,7 +109,7 @@ class JoinStep : public IQueryPlanStep bool isCrossJoin() const { return kind == ASTTableJoin::Kind::Cross || (kind == ASTTableJoin::Kind::Inner && left_keys.empty()); } - bool isInnerJoin() const {return kind == ASTTableJoin::Kind::Inner; } + bool isInnerJoin() const { return kind == ASTTableJoin::Kind::Inner; } bool isOuterJoin() const { @@ -205,6 +214,7 @@ class JoinStep : public IQueryPlanStep Names left_keys; Names right_keys; + std::vector key_ids_null_safe; /** * Non-equals predicate diff --git a/src/QueryPlan/MergingAggregatedStep.cpp b/src/QueryPlan/MergingAggregatedStep.cpp index 7154e3b2550..f1c9fee281d 100644 --- a/src/QueryPlan/MergingAggregatedStep.cpp +++ b/src/QueryPlan/MergingAggregatedStep.cpp @@ -86,6 +86,15 @@ MergingAggregatedStep::MergingAggregatedStep( , memory_efficient_merge_threads(memory_efficient_merge_threads_) , should_produce_results_in_order_of_bucket_number(!(params->final) && memory_efficient_aggregation) { + NameSet output_names; + for (const auto & key : keys) + if (!output_names.emplace(key).second) + throw Exception(ErrorCodes::LOGICAL_ERROR, "duplicate group by key: {}", key); + + for (const auto & aggregate : params->params.aggregates) + if (!output_names.emplace(aggregate.column_name).second) + throw Exception(ErrorCodes::LOGICAL_ERROR, "duplicate aggreagte function output name: {}", aggregate.column_name); + /// Aggregation keys are distinct for (auto key : params->params.keys) output_stream->distinct_columns.insert(params->params.intermediate_header.getByPosition(key).name); diff --git 
a/src/QueryPlan/PlanPrinter.cpp b/src/QueryPlan/PlanPrinter.cpp index 429a490feda..9bda33ccf4c 100644 --- a/src/QueryPlan/PlanPrinter.cpp +++ b/src/QueryPlan/PlanPrinter.cpp @@ -18,25 +18,29 @@ #include #include #include -#include #include #include +#include #include #include #include #include #include +#include #include #include +#include #include #include #include #include +#include #include -#include + #include #include +#include namespace DB { @@ -84,18 +88,14 @@ namespace String PlanPrinter::textPlanNode(PlanNodeBase & node) { PlanCostMap costs; - StepAggregatedOperatorProfiles profiles; + StepProfiles profiles; TextPrinter printer{costs}; bool has_children = node.getChildren().empty(); return printer.printLogicalPlan(node, TextPrinterIntent{0, has_children}, profiles); } String PlanPrinter::textLogicalPlan( - QueryPlan & plan, - ContextMutablePtr context, - PlanCostMap costs, - const StepAggregatedOperatorProfiles & profiles, - const QueryPlanSettings & settings) + QueryPlan & plan, ContextMutablePtr context, PlanCostMap costs, const StepProfiles & profiles, const QueryPlanSettings & settings) { TextPrinter printer{costs, context, false, {}, settings, context->getSettingsRef().max_predicate_text_length}; bool has_children = !plan.getPlanNode()->getChildren().empty(); @@ -143,20 +143,31 @@ String PlanPrinter::textLogicalPlan( auto & optimizer_metrics = context->getOptimizerMetrics(); if (optimizer_metrics && !optimizer_metrics->getUsedMaterializedViews().empty()) { + auto tenant_id = context->getTenantId(); output += "note: Materialized Views is applied for " + std::to_string(optimizer_metrics->getUsedMaterializedViews().size()) + " times: "; const auto & views = optimizer_metrics->getUsedMaterializedViews(); auto it = views.begin(); - output += it->getDatabaseName() + "." + it->getTableName(); + output += getOriginalDatabaseName(it->getDatabaseName(), tenant_id) + "." 
+ it->getTableName(); for (++it; it != views.end(); ++it) - output += ", " + it->getDatabaseName() + "." + it->getTableName(); + output += ", " + getOriginalDatabaseName(it->getDatabaseName() + "." + it->getTableName(), tenant_id); output += "."; } + if (plan.isShortCircuit()) + { + output += "note: Short Circuit is applied.\n"; + } + return output; } -String PlanPrinter::jsonLogicalPlan(QueryPlan & plan, std::optional plan_cost, const StepAggregatedOperatorProfiles & profiles, const PlanCostMap & costs, const QueryPlanSettings & settings) +String PlanPrinter::jsonLogicalPlan( + QueryPlan & plan, + std::optional plan_cost, + const StepProfiles & profiles, + const PlanCostMap & costs, + const QueryPlanSettings & settings) { std::ostringstream os; Poco::JSON::Object::Ptr json = new Poco::JSON::Object(true); @@ -187,7 +198,8 @@ String PlanPrinter::jsonLogicalPlan(QueryPlan & plan, std::optional & segment_profile) { auto f = [](ExchangeMode mode) { switch (mode) @@ -231,7 +243,7 @@ String PlanPrinter::getPlanSegmentHeaderText(PlanSegmentDescriptionPtr & segment { if (!first) os << "\n "; - os << "(SegmentId:" << output->segment_id + os << "(SegmentId:" << output->segment_id << " ExchangeId:" << output->exchange_id << " ExchangeMode:" << magic_enum::enum_name(output->mode) << " ParallelSize:" << output->parallel_size @@ -249,15 +261,24 @@ String PlanPrinter::getPlanSegmentHeaderText(PlanSegmentDescriptionPtr & segment { if (!first) os << "\n "; - os << "(SegmentId:" << input->segment_id + os << "(SegmentId:" << input->segment_id << " ExchangeId:" << input->exchange_id << " ExchangeMode:" << magic_enum::enum_name(input->mode) << " ExchangeParallelSize:" << input->exchange_parallel_size - << " KeepOrder:" << input->keep_order << ")"; + << " KeepOrder:" << input->keep_order + << (input->stable ? 
" Stable" : "") << ")"; first = false; } os << "]\n"; } + if (print_profile && !segment_profile.empty() && segment_profile.contains(segment_id)) + { + const auto & profiles = segment_profile.at(segment_id); + for (const auto & profile : profiles) + os << " " << profile->worker_address << " ReadRows: " << profile->read_rows + << " QueryDurationTime: " << profile->query_duration_ms << "ms." + << " IOWaitTime: " << profile->io_wait_ms << "ms.\n"; + } return os.str(); } @@ -265,9 +286,10 @@ String PlanPrinter::textDistributedPlan( PlanSegmentDescriptions & segments_desc, ContextMutablePtr context, const std::unordered_map & costs, - const StepAggregatedOperatorProfiles & profiles, + const StepProfiles & profiles, const QueryPlan & query_plan, - const QueryPlanSettings & settings) + const QueryPlanSettings & settings, + const std::unordered_map & segment_profile) { auto id_to_node = getPlanNodeMap(query_plan); for (auto & segment_desc : segments_desc) @@ -299,8 +321,12 @@ String PlanPrinter::textDistributedPlan( for (auto & segment_ptr : segments_desc) { - os << getPlanSegmentHeaderText(segment_ptr); - if (!segment_ptr->plan_node) + if (settings.segment_id != UINT64_MAX && segment_ptr->segment_id != settings.segment_id) + continue; + + os << getPlanSegmentHeaderText(segment_ptr, settings.segment_profile, segment_profile); + + if (!segment_ptr->plan_node) continue; auto analyze_node = PlanNodeSearcher::searchFrom(segment_ptr->plan_node) @@ -339,7 +365,11 @@ String PlanPrinter::textDistributedPlan( } -String PlanPrinter::textPipelineProfile(PlanSegmentDescriptions & segment_descs, SegmentAndWorkerToGroupedProfile & worker_grouped_profiles) +String PlanPrinter::textPipelineProfile( + PlanSegmentDescriptions & segment_descs, + SegIdAndAddrToPipelineProfile & worker_grouped_profiles, + const QueryPlanSettings & settings, + const std::unordered_map & segment_profile) { std::ostringstream os; @@ -349,19 +379,21 @@ String PlanPrinter::textPipelineProfile(PlanSegmentDescriptions 
& segment_descs, for (auto & segment_ptr : segment_descs) { size_t segment_id = segment_ptr->segment_id; - os << getPlanSegmentHeaderText(segment_ptr); + if (settings.segment_id != UINT64_MAX && segment_id != settings.segment_id) + continue; + os << getPlanSegmentHeaderText(segment_ptr, settings.segment_profile, segment_profile); if (!worker_grouped_profiles.contains(segment_id) || worker_grouped_profiles.at(segment_id).empty()) continue; - for (auto & [address, segment_profile] : worker_grouped_profiles.at(segment_id)) + for (auto & [address, profile] : worker_grouped_profiles.at(segment_id)) { - if (!segment_profile) + if (!profile) continue; TextPrinterIntent print{3, false}; os << print.print() << address << "\n"; - TextPrinter printer{{}}; - bool has_children = !segment_profile->children.empty(); - auto output = printer.printPipelineProfile(segment_profile, TextPrinterIntent{3, has_children}); + TextPrinter printer{{}, nullptr, true, {}, settings}; + bool has_children = !profile->children.empty(); + auto output = printer.printPipelineProfile(profile, TextPrinterIntent{3, has_children}); os << output; } os << "\n"; @@ -370,7 +402,7 @@ String PlanPrinter::textPipelineProfile(PlanSegmentDescriptions & segment_descs, return os.str(); } -String PlanPrinter::jsonPipelineProfile(PlanSegmentDescriptions & segment_descs, SegmentAndWorkerToGroupedProfile & worker_grouped_profiles) +String PlanPrinter::jsonPipelineProfile(PlanSegmentDescriptions & segment_descs, SegIdAndAddrToPipelineProfile & worker_grouped_profiles) { Poco::JSON::Object::Ptr distributed_plan = new Poco::JSON::Object(true); Poco::JSON::Array segments; @@ -490,7 +522,7 @@ String TextPrinterIntent::detailIntent() const } String PlanPrinter::TextPrinter::printLogicalPlan( - PlanNodeBase & plan, const TextPrinterIntent & intent, const StepAggregatedOperatorProfiles & profiles) // NOLINT(misc-no-recursion) + PlanNodeBase & plan, const TextPrinterIntent & intent, const StepProfiles & profiles) // 
NOLINT(misc-no-recursion) { std::stringstream out; @@ -512,9 +544,8 @@ String PlanPrinter::TextPrinter::printLogicalPlan( if (settings.stats) out << intent.detailIntent() << printStatistics(plan, intent); if (settings.profile && profiles.count(plan.getId())) - out << printOperatorProfiles(plan, intent, profiles) << intent.detailIntent() << printQError(plan, profiles); - out << printDetail(plan.getStep(), intent) << "\n"; - + out << printStepProfiles(plan, intent, profiles) << intent.detailIntent() << printQError(plan, profiles); + out << printDetail(plan.getStep(), intent) << printAttributes(plan, intent, profiles) << "\n"; } if ((step->getType() == IQueryPlanStep::Type::CTERef || step->getType() == IQueryPlanStep::Type::Exchange) && is_distributed) @@ -538,7 +569,7 @@ String PlanPrinter::TextPrinter::printLogicalPlan( String PlanPrinter::TextPrinter::printPipelineProfile(GroupedProcessorProfilePtr & input_root, const TextPrinterIntent & intent) { std::stringstream out; - out << intent.print() << printProcessorDetail(input_root, intent) << "\n"; + out << intent.print() << printPipelineProfileDetail(input_root, intent) << "\n"; for (auto it = input_root->children.begin(); it != input_root->children.end();) { @@ -550,20 +581,26 @@ String PlanPrinter::TextPrinter::printPipelineProfile(GroupedProcessorProfilePtr return out.str(); } -String PlanPrinter::TextPrinter::printProcessorDetail(GroupedProcessorProfilePtr profile, const TextPrinterIntent & intent) +String PlanPrinter::TextPrinter::printPipelineProfileDetail(GroupedProcessorProfilePtr profile, const TextPrinterIntent & intent) { std::stringstream out; - out << profile->processor_name <<" x" << profile->parallel_size << " ElapsedTime:" << prettySeconds(profile->grouped_elapsed_us/profile->worker_cnt); - if (profile->worker_cnt > 1) + out << profile->processor_name << " x" << profile->parallel_size + << " ElapsedTime:" << prettySeconds(profile->sum_grouped_elapsed_us / profile->parallel_size); + if 
(profile->parallel_size > 1) out<< "[max=" << prettySeconds(profile->max_grouped_elapsed_us) << ", min=" << prettySeconds(profile->min_grouped_elapsed_us) << "]"; - out << intent.detailIntent() << "Input: WaitTime:" << prettySeconds(profile->grouped_input_wait_elapsed_us/profile->worker_cnt); - if (profile->worker_cnt > 1) - out<< "[max=" << prettySeconds(profile->max_grouped_input_wait_elapsed_us) << ", min=" << prettySeconds(profile->min_grouped_input_wait_elapsed_us) << "]"; - out << " Rows:" << prettyNum(profile->grouped_input_rows) << " (" << prettyBytes(profile->grouped_input_bytes) << ")"; - out << intent.detailIntent() << " Output: WaitTime:" << prettySeconds(profile->grouped_output_wait_elapsed_us/profile->worker_cnt); - if (profile->worker_cnt > 1) + out << intent.detailIntent() << "Output: Rows:" << prettyNum(profile->grouped_output_rows, settings.pretty_num) << " (" + << prettyBytes(profile->grouped_output_bytes) << ")"; + out << " WaitTime:" << prettySeconds(profile->sum_grouped_output_wait_elapsed_us / profile->parallel_size); + if (profile->parallel_size > 1) out<< "[max=" << prettySeconds(profile->max_grouped_output_wait_elapsed_us) << ", min=" << prettySeconds(profile->min_grouped_output_wait_elapsed_us) << "]"; - out << " Rows:" << prettyNum(profile->grouped_output_rows) << " (" << prettyBytes(profile->grouped_output_bytes) << ")"; + + out << intent.detailIntent() << "Input: Rows:" << prettyNum(profile->grouped_input_rows, settings.pretty_num) << " (" + << prettyBytes(profile->grouped_input_bytes) << ")"; + out << " WaitTime:" << prettySeconds(profile->sum_grouped_input_wait_elapsed_us / profile->parallel_size); + if (profile->parallel_size > 1) + out << "[max=" << prettySeconds(profile->max_grouped_input_wait_elapsed_us) + << ", min=" << prettySeconds(profile->min_grouped_input_wait_elapsed_us) << "]"; + return out.str(); } @@ -579,8 +616,7 @@ String PlanPrinter::TextPrinter::printStatistics(const PlanNodeBase & plan, cons return out.str(); } 
-String PlanPrinter::TextPrinter::printOperatorProfiles( - PlanNodeBase & plan, const TextPrinterIntent & intent, const StepAggregatedOperatorProfiles & profiles) +String PlanPrinter::TextPrinter::printStepProfiles(PlanNodeBase & plan, const TextPrinterIntent & intent, const StepProfiles & profiles) { size_t step_id = plan.getId(); if (profiles.count(step_id)) @@ -590,17 +626,19 @@ String PlanPrinter::TextPrinter::printOperatorProfiles( out << intent.detailIntent() << "Act. WallTime: " << prettySeconds(profile->sum_elapsed_us/profile->worker_cnt); if (profile->worker_cnt > 1) out << "[max= " << prettySeconds(profile->max_elapsed_us) << ", min=" << prettySeconds(profile->min_elapsed_us) << "]"; - out << intent.detailIntent() << " Output: " << prettyNum(profile->output_rows) << " rows(" << prettyBytes(profile->output_bytes) << ")"; - out << ", WaitTime: " << prettySeconds(profile->sum_output_wait_elapsed_us/profile->worker_cnt); + out << intent.detailIntent() << " Output: " << prettyNum(profile->output_rows, settings.pretty_num) << " rows(" + << prettyBytes(profile->output_bytes) << ")"; + out << ", WaitTime: " << prettySeconds(profile->output_wait_sum_elapsed_us / profile->worker_cnt); if (profile->worker_cnt > 1) - out << "[max=" << prettySeconds(profile->max_output_wait_elapsed_us) << ", min=" << prettySeconds(profile->min_output_wait_elapsed_us) << "]"; + out << "[max=" << prettySeconds(profile->output_wait_max_elapsed_us) + << ", min=" << prettySeconds(profile->output_wait_min_elapsed_us) << "]"; int num = 1; - if (!plan.getChildren().empty() && profile->inputs_profile.contains(plan.getChildren()[0]->getId())) + if (!plan.getChildren().empty() && profile->inputs.contains(plan.getChildren()[0]->getId())) { for (auto & child : plan.getChildren()) { - auto input_profile = profile->inputs_profile[child->getId()]; + auto input_profile = profile->inputs[child->getId()]; if (num == 1) out << intent.detailIntent() << " Input: "; else @@ -609,16 +647,18 @@ String 
PlanPrinter::TextPrinter::printOperatorProfiles( if (plan.getChildren().size() > 1) out << "source[" << num << "] : "; - out << prettyNum(input_profile.input_rows) << " rows(" << prettyBytes(input_profile.input_bytes) << ")"; - out << ", WaitTime: " << prettySeconds(input_profile.input_wait_elapsed_us/profile->worker_cnt); + out << prettyNum(input_profile.input_rows, settings.pretty_num) << " rows(" << prettyBytes(input_profile.input_bytes) + << ")"; + out << ", WaitTime: " << prettySeconds(input_profile.input_wait_sum_elapsed_us / profile->worker_cnt); if (profile->worker_cnt > 1) - out << "[max=" << prettySeconds(input_profile.max_input_wait_elapsed_us) << ", min=" << prettySeconds(input_profile.min_input_wait_elapsed_us) << "]"; + out << "[max=" << prettySeconds(input_profile.input_wait_max_elapsed_us) + << ", min=" << prettySeconds(input_profile.input_wait_min_elapsed_us) << "]"; ++num; } } else { - for (auto & [id, input_metrics] : profile->inputs_profile) + for (auto & [id, input_metrics] : profile->inputs) { if (num == 1) out << intent.detailIntent() << " Input: "; @@ -628,9 +668,10 @@ String PlanPrinter::TextPrinter::printOperatorProfiles( if (plan.getChildren().size() > 1) out << "source [" << num << "] : "; - out << "WaitTime: " << prettySeconds(input_metrics.input_wait_elapsed_us/profile->worker_cnt); + out << "WaitTime: " << prettySeconds(input_metrics.input_wait_sum_elapsed_us / profile->worker_cnt); if (profile->worker_cnt > 1) - out << "[max=" << prettySeconds(input_metrics.max_input_wait_elapsed_us) << ", min=" << prettySeconds(input_metrics.min_input_wait_elapsed_us) << "]"; + out << "[max=" << prettySeconds(input_metrics.input_wait_max_elapsed_us) + << ", min=" << prettySeconds(input_metrics.input_wait_min_elapsed_us) << "]"; ++num; } } @@ -640,15 +681,57 @@ String PlanPrinter::TextPrinter::printOperatorProfiles( return ""; } -String PlanPrinter::TextPrinter::prettyNum(size_t num) +String PlanPrinter::TextPrinter::printAttributes(PlanNodeBase & 
plan, const TextPrinterIntent & intent, const StepProfiles & profiles) const +{ + size_t step_id = plan.getId(); + if (!profiles.contains(step_id) || profiles.at(step_id)->address_to_attributes.empty()) + return ""; + if (!settings.query_plan_options.indexes && !settings.selected_parts) + return ""; + std::stringstream out; + const auto & address_to_attributes = profiles.at(step_id)->address_to_attributes; + if (plan.getStep()->getType() == IQueryPlanStep::Type::TableScan) + { + String space; + for (const auto & [address, attribute] : address_to_attributes) + { + if (address_to_attributes.size() > 1) + { + out << intent.detailIntent() << address; + space = " "; + } + if (settings.query_plan_options.indexes && attribute.contains("Indexes")) + { + out << intent.detailIntent() << space << "Indexes:"; + auto index_desc = attribute.at("Indexes"); + for (const auto & desc : index_desc->name_and_detail) + out << intent.detailIntent() << space << " " << desc.second; + } + if (settings.selected_parts) + { + if (attribute.contains("SelectParts")) + out << intent.detailIntent() << space << attribute.at("SelectParts")->description; + if (attribute.contains("TableScanDescription")) + out << intent.detailIntent() << space << attribute.at("TableScanDescription")->description; + } + } + return out.str(); + } + return ""; +} + +String PlanPrinter::TextPrinter::prettyNum(size_t num, bool pretty_num) { std::vector suffixes{"", "K", "M", "B", "T"}; size_t idx = 0; auto count = static_cast(num); - while (count >= 1000 && idx < suffixes.size() - 1) + if (pretty_num) { - idx++; - count /= static_cast(1000); + while (count >= 1000 && idx < suffixes.size() - 1) + { + idx++; + count /= static_cast(1000); + } } std::stringstream out; @@ -691,7 +774,7 @@ String PlanPrinter::TextPrinter::prettyBytes(size_t bytes) return out.str(); } -String PlanPrinter::TextPrinter::printQError(const PlanNodeBase & plan, const StepAggregatedOperatorProfiles & profiles) +String 
PlanPrinter::TextPrinter::printQError(const PlanNodeBase & plan, const StepProfiles & profiles) { const auto & stats = plan.getStatistics(); std::stringstream out; @@ -829,6 +912,7 @@ String PlanPrinter::TextPrinter::printPrefix(PlanNodeBase & plan) return ""; } + String PlanPrinter::TextPrinter::printSuffix(PlanNodeBase & plan) { std::stringstream out; @@ -838,8 +922,9 @@ String PlanPrinter::TextPrinter::printSuffix(PlanNodeBase & plan) if (plan.getStep()->getType() == IQueryPlanStep::Type::TableScan) { + auto tenant_id = context->getTenantId(); const auto * table_scan = dynamic_cast(plan.getStep().get()); - out << " " << table_scan->getDatabase() << "." << table_scan->getOriginalTable(); + out << " " << getOriginalDatabaseName(table_scan->getDatabase(), tenant_id) << "." << table_scan->getOriginalTable(); } else if (plan.getStep()->getType() == IQueryPlanStep::Type::Exchange && segment_id != -1) { @@ -882,9 +967,11 @@ String PlanPrinter::TextPrinter::printDetail(QueryPlanStepPtr plan, const TextPr const auto * join_step = dynamic_cast(plan.get()); out << intent.detailIntent() << "Condition: "; if (!join_step->getLeftKeys().empty()) - out << join_step->getLeftKeys()[0] << " == " << join_step->getRightKeys()[0]; + out << join_step->getLeftKeys()[0] << " == " << join_step->getRightKeys()[0] + << (join_step->getKeyIdNullSafe(0) ? "(null aware)" : ""); for (size_t i = 1; i < join_step->getLeftKeys().size(); i++) - out << ", " << join_step->getLeftKeys()[i] << " == " << join_step->getRightKeys()[i]; + out << ", " << join_step->getLeftKeys()[i] << " == " << join_step->getRightKeys()[i] + << (join_step->getKeyIdNullSafe(i) ? 
"(null aware)" : ""); if (!ASTEquality::compareTree(join_step->getFilter(), PredicateConst::TRUE_VALUE)) { @@ -908,6 +995,14 @@ String PlanPrinter::TextPrinter::printDetail(QueryPlanStepPtr plan, const TextPr sort_columns.emplace_back(desc.format()); out << intent.detailIntent() << "Order by: " << join(sort_columns, ", ", "{", "}"); + if (!sort->getPrefixDescription().empty()) + { + std::vector prefix_sort_columns; + for (const auto & desc : sort->getPrefixDescription()) + prefix_sort_columns.emplace_back(desc.column_name); + out << intent.detailIntent() << "Prefix Order: " << join(prefix_sort_columns, ", ", "{", "}"); + } + std::visit( overloaded{ [&](size_t x) { @@ -1043,7 +1138,7 @@ String PlanPrinter::TextPrinter::printDetail(QueryPlanStepPtr plan, const TextPr { out << intent.detailIntent() << "Partition filter: " << printFilter(query_info.partition_filter, max_predicate_text_length); } - + if (query_info.input_order_info) { out << intent.detailIntent(); @@ -1121,7 +1216,7 @@ String PlanPrinter::TextPrinter::printDetail(QueryPlanStepPtr plan, const TextPr { const auto * table_write = dynamic_cast(plan.get()); if (table_write->getTarget()) - out << intent.detailIntent() << table_write->getTarget()->toString(); + out << intent.detailIntent() << table_write->getTarget()->toString(context->getTenantId()); } if (plan->getType() == IQueryPlanStep::Type::TotalsHaving) @@ -1511,7 +1606,8 @@ void NodeDescription::setStepStatistic(PlanNodePtr node) } } -Poco::JSON::Object::Ptr NodeDescription::jsonNodeDescription(const StepAggregatedOperatorProfiles & node_profiles, bool print_stats, const PlanCostMap & costs) +Poco::JSON::Object::Ptr +NodeDescription::jsonNodeDescription(const StepProfiles & node_profiles, bool print_stats, const PlanCostMap & costs) { Poco::JSON::Object::Ptr json = new Poco::JSON::Object(true); json->set("NodeId", node_id); @@ -1537,41 +1633,41 @@ Poco::JSON::Object::Ptr NodeDescription::jsonNodeDescription(const StepAggregate { const auto & 
profile_detail = node_profiles.at(node_id); Poco::JSON::Object::Ptr profiles = new Poco::JSON::Object(true); - profiles->set("WallTimeMs", profile_detail->sum_elapsed_us/profile_detail->worker_cnt/1000); - profiles->set("MaxWallTimeMs", profile_detail->max_elapsed_us/1000); - profiles->set("MinWallTimeMs", profile_detail->min_elapsed_us/1000); + profiles->set("WallTimeMs", float(profile_detail->sum_elapsed_us)/profile_detail->worker_cnt/1000); + profiles->set("MaxWallTimeMs", float(profile_detail->max_elapsed_us)/1000); + profiles->set("MinWallTimeMs", float(profile_detail->min_elapsed_us)/1000); profiles->set("OutputRows", profile_detail->output_rows); profiles->set("OutputBytes", profile_detail->output_bytes); - profiles->set("OutputWaitTimeMs", profile_detail->sum_output_wait_elapsed_us/profile_detail->worker_cnt/1000); - profiles->set("MaxOutputWaitTimeMs", profile_detail->max_output_wait_elapsed_us/1000); - profiles->set("MinOutputWaitTimeMs", profile_detail->min_output_wait_elapsed_us/1000); + profiles->set("OutputWaitTimeMs", float(profile_detail->output_wait_sum_elapsed_us)/profile_detail->worker_cnt/1000); + profiles->set("MaxOutputWaitTimeMs", float(profile_detail->output_wait_max_elapsed_us)/1000); + profiles->set("MinOutputWaitTimeMs", float(profile_detail->output_wait_min_elapsed_us)/1000); Poco::JSON::Array inputs_profile; - if (!children.empty() && profile_detail->inputs_profile.contains(children[0]->node_id)) + if (!children.empty() && profile_detail->inputs.contains(children[0]->node_id)) { for (auto & child : children) { - auto input_profile = profile_detail->inputs_profile[child->node_id]; + auto input_profile = profile_detail->inputs[child->node_id]; Poco::JSON::Object::Ptr input = new Poco::JSON::Object(true); input->set("InputNodeId", child->node_id); input->set("InputRows", input_profile.input_rows); input->set("InputBytes", input_profile.input_bytes); - input->set("InputWaitTimeMs", 
input_profile.input_wait_elapsed_us/profile_detail->worker_cnt/1000); - input->set("MaxInputWaitTimeMs", input_profile.max_input_wait_elapsed_us/1000); - input->set("MinInputWaitTimeMs", input_profile.min_input_wait_elapsed_us/1000); + input->set("InputWaitTimeMs", float(input_profile.input_wait_sum_elapsed_us)/profile_detail->worker_cnt/1000); + input->set("MaxInputWaitTimeMs", float(input_profile.input_wait_max_elapsed_us)/1000); + input->set("MinInputWaitTimeMs", float(input_profile.input_wait_min_elapsed_us)/1000); inputs_profile.add(input); } } else { - for (auto input_profile : profile_detail->inputs_profile) + for (auto input_profile : profile_detail->inputs) { Poco::JSON::Object::Ptr input = new Poco::JSON::Object(true); input->set("InputNodeId", input_profile.first); input->set("InputRows", input_profile.second.input_rows); input->set("InputBytes", input_profile.second.input_bytes); - input->set("InputWaitTimeMs", input_profile.second.input_wait_elapsed_us/profile_detail->worker_cnt/1000); - input->set("MaxInputWaitTimeMs", input_profile.second.max_input_wait_elapsed_us/1000); - input->set("MinInputWaitTimeMs", input_profile.second.min_input_wait_elapsed_us/1000); + input->set("InputWaitTimeMs", float(input_profile.second.input_wait_sum_elapsed_us)/profile_detail->worker_cnt/1000); + input->set("MaxInputWaitTimeMs", float(input_profile.second.input_wait_max_elapsed_us)/1000); + input->set("MinInputWaitTimeMs", float(input_profile.second.input_wait_min_elapsed_us)/1000); inputs_profile.add(input); } } @@ -1649,7 +1745,7 @@ NodeDescriptionPtr NodeDescription::getPlanDescription(PlanNodePtr node) return description; } -String PlanSegmentDescription::jsonPlanSegmentDescriptionAsString(const StepAggregatedOperatorProfiles & profiles) +String PlanSegmentDescription::jsonPlanSegmentDescriptionAsString(const StepProfiles & profiles) { auto json = jsonPlanSegmentDescription(profiles); std::ostringstream os; @@ -1657,7 +1753,7 @@ String 
PlanSegmentDescription::jsonPlanSegmentDescriptionAsString(const StepAggr return os.str(); } -Poco::JSON::Object::Ptr PlanSegmentDescription::jsonPlanSegmentDescription(const StepAggregatedOperatorProfiles & profiles, bool is_pipeline) +Poco::JSON::Object::Ptr PlanSegmentDescription::jsonPlanSegmentDescription(const StepProfiles & profiles, bool is_pipeline) { Poco::JSON::Object::Ptr json = new Poco::JSON::Object(true); @@ -1791,6 +1887,7 @@ PlanSegmentDescriptionPtr PlanSegmentDescription::getPlanSegmentDescription(Plan input_desc.mode = input->getExchangeMode(); input_desc.exchange_parallel_size = input->getExchangeParallelSize(); input_desc.keep_order = input->needKeepOrder(); + input_desc.stable = input->isStable(); auto input_desc_ptr = std::make_shared(input_desc); plan_segment_desc->inputs_desc.emplace_back(input_desc_ptr); } @@ -1808,7 +1905,7 @@ PlanSegmentDescriptionPtr PlanSegmentDescription::getPlanSegmentDescription(Plan return plan_segment_desc; } -String PlanPrinter::jsonDistributedPlan(PlanSegmentDescriptions & segment_descs, const StepAggregatedOperatorProfiles & profiles) +String PlanPrinter::jsonDistributedPlan(PlanSegmentDescriptions & segment_descs, const StepProfiles & profiles) { Poco::JSON::Object::Ptr distributed_plan = new Poco::JSON::Object(true); Poco::JSON::Array segments; @@ -1862,6 +1959,18 @@ String PlanPrinter::jsonMetaData( query_used_settings->set(setting.name, setting.value.toString()); metadata_json->set("UsedSettings", query_used_settings); + Poco::JSON::Array output_descs; + ASTPtr & select_ast = query; + if (auto * insert_query = query->as()) + select_ast = insert_query->select; + + if (analysis->hasOutputDescription(*select_ast)) + { + for (const auto & desc : analysis->getOutputDescription(*select_ast)) + output_descs.add(desc.name); + } + metadata_json->set("OutputDescriptions", output_descs); + // get InsertInfo Poco::JSON::Object::Ptr insert_table_info = new Poco::JSON::Object(true); if (analysis->getInsert()) diff --git 
a/src/QueryPlan/PlanPrinter.h b/src/QueryPlan/PlanPrinter.h index 593754ae5ca..8b12e2ff103 100644 --- a/src/QueryPlan/PlanPrinter.h +++ b/src/QueryPlan/PlanPrinter.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -29,6 +30,9 @@ using PlanCostMap = std::unordered_map; struct PlanSegmentDescription; using PlanSegmentDescriptionPtr = std::shared_ptr; using PlanSegmentDescriptions = std::vector; +struct PlanSegmentProfile; +using PlanSegmentProfilePtr = std::shared_ptr; +using PlanSegmentProfiles = std::vector; struct Analysis; using AnalysisPtr = std::shared_ptr; @@ -47,28 +51,36 @@ class PlanPrinter QueryPlan & plan, ContextMutablePtr context, PlanCostMap costs = {}, - const StepAggregatedOperatorProfiles & profiles = {}, + const StepProfiles & profiles = {}, const QueryPlanSettings & settings = {}); static String jsonLogicalPlan( QueryPlan & plan, std::optional plan_cost, - const StepAggregatedOperatorProfiles & profiles = {}, + const StepProfiles & profiles = {}, const PlanCostMap & costs = {}, const QueryPlanSettings & settings = {}); - static String jsonDistributedPlan(PlanSegmentDescriptions & segment_descs, const StepAggregatedOperatorProfiles & profiles); + static String jsonDistributedPlan(PlanSegmentDescriptions & segment_descs, const StepProfiles & profiles); static String textDistributedPlan( PlanSegmentDescriptions & segments_desc, ContextMutablePtr context, const std::unordered_map & costs = {}, - const StepAggregatedOperatorProfiles & profiles = {}, + const StepProfiles & profiles = {}, const QueryPlan & query_plan = {}, - const QueryPlanSettings & settings = {}); - static String textPipelineProfile(PlanSegmentDescriptions & segment_descs, SegmentAndWorkerToGroupedProfile & worker_grouped_profiles); - static String jsonPipelineProfile(PlanSegmentDescriptions & segment_descs, SegmentAndWorkerToGroupedProfile & worker_grouped_profiles); + const QueryPlanSettings & settings = {}, + const std::unordered_map & segment_profile = 
{}); + static String textPipelineProfile( + PlanSegmentDescriptions & segment_descs, + SegIdAndAddrToPipelineProfile & worker_grouped_profiles, + const QueryPlanSettings & settings = {}, + const std::unordered_map & segment_profile = {}); + static String jsonPipelineProfile(PlanSegmentDescriptions & segment_descs, SegIdAndAddrToPipelineProfile & worker_grouped_profiles); static void getPlanNodes(const PlanNodePtr & parent, std::unordered_map & id_to_node); static std::unordered_map getPlanNodeMap(const QueryPlan & query_plan); static void getRemoteSegmentId(const QueryPlan::Node * node, std::unordered_map & exchange_to_segment); - static String getPlanSegmentHeaderText(PlanSegmentDescriptionPtr & segment_desc); + static String getPlanSegmentHeaderText( + PlanSegmentDescriptionPtr & segment_desc, + bool print_profile = false, + const std::unordered_map & segment_profile = {}); static String jsonMetaData( ASTPtr & query, AnalysisPtr analysis, ContextMutablePtr context, QueryPlanPtr & plan, const QueryMetadataSettings & settings = {}); @@ -122,21 +134,22 @@ class PlanPrinter::TextPrinter , max_predicate_text_length(max_predicate_text_length_) {} static String printOutputColumns(PlanNodeBase & plan_node, const TextPrinterIntent & intent = {}); - String printLogicalPlan(PlanNodeBase & plan, const TextPrinterIntent & intent = {}, const StepAggregatedOperatorProfiles & profiles = {}); + String printLogicalPlan(PlanNodeBase & plan, const TextPrinterIntent & intent = {}, const StepProfiles & profiles = {}); String printPipelineProfile(GroupedProcessorProfilePtr & input_root, const TextPrinterIntent & intent = {}); - static String prettyNum(size_t num); + static String prettyNum(size_t num, bool pretty_num = true); static String prettyBytes(size_t bytes); static String prettySeconds(size_t seconds); static String printPrefix(PlanNodeBase & plan); String printSuffix(PlanNodeBase & plan); - static String printQError(const PlanNodeBase & plan, const 
StepAggregatedOperatorProfiles & profiles); + static String printQError(const PlanNodeBase & plan, const StepProfiles & profiles); static String printFilter(ConstASTPtr filter, size_t max_text_length = 10000); private: String printDetail(QueryPlanStepPtr plan, const TextPrinterIntent & intent) const; - static String printProcessorDetail(GroupedProcessorProfilePtr profile, const TextPrinterIntent & intent); + String printPipelineProfileDetail(GroupedProcessorProfilePtr profile, const TextPrinterIntent & intent); String printStatistics(const PlanNodeBase & plan, const TextPrinterIntent & intent = {}) const; - static String printOperatorProfiles(PlanNodeBase & plan, const TextPrinterIntent & intent = {}, const StepAggregatedOperatorProfiles & profiles = {}) ; + String printStepProfiles(PlanNodeBase & plan, const TextPrinterIntent & intent = {}, const StepProfiles & profiles = {}); + String printAttributes(PlanNodeBase & plan, const TextPrinterIntent & intent, const StepProfiles & profiles = {}) const; const std::unordered_map & costs; bool is_distributed; @@ -170,7 +183,7 @@ class NodeDescription void setStepStatistic(PlanNodePtr node); void setStepDetail(QueryPlanStepPtr step); - Poco::JSON::Object::Ptr jsonNodeDescription(const StepAggregatedOperatorProfiles & profiles, bool print_stats, const PlanCostMap & costs = {}); + Poco::JSON::Object::Ptr jsonNodeDescription(const StepProfiles & node_profiles, bool print_stats, const PlanCostMap & costs = {}); static NodeDescriptionPtr getPlanDescription(QueryPlan::Node * node); static NodeDescriptionPtr getPlanDescription(PlanNodePtr node); }; @@ -193,6 +206,7 @@ struct PlanSegmentDescription size_t exchange_id; size_t exchange_parallel_size; bool keep_order; + bool stable; }; size_t segment_id; String segment_type; @@ -217,8 +231,8 @@ struct PlanSegmentDescription NodeDescriptionPtr node_description; - Poco::JSON::Object::Ptr jsonPlanSegmentDescription(const StepAggregatedOperatorProfiles & profiles, bool is_pipeline = 
false); - String jsonPlanSegmentDescriptionAsString(const StepAggregatedOperatorProfiles & profiles); + Poco::JSON::Object::Ptr jsonPlanSegmentDescription(const StepProfiles & profiles, bool is_pipeline = false); + String jsonPlanSegmentDescriptionAsString(const StepProfiles & profiles); static PlanSegmentDescriptionPtr getPlanSegmentDescription(PlanSegmentPtr & segment, bool record_plan_detail = false); }; diff --git a/src/QueryPlan/PlanSerDerHelper.cpp b/src/QueryPlan/PlanSerDerHelper.cpp index ed8724109fb..dc48ad1e53f 100644 --- a/src/QueryPlan/PlanSerDerHelper.cpp +++ b/src/QueryPlan/PlanSerDerHelper.cpp @@ -367,23 +367,23 @@ bool isPlanStepEqual(const IQueryPlanStep & a, const IQueryPlanStep & b) } template -UInt64 hashPlanStepImpl(const IQueryPlanStep & raw_step) +UInt64 hashPlanStepImpl(const IQueryPlanStep & raw_step, bool ignore_output_stream) { const auto & step = reinterpret_cast(raw_step); ProtoType proto; - step.toProto(proto, true); + step.toProto(proto, ignore_output_stream); auto res = sipHash64Protobuf(proto); return res; } -UInt64 hashPlanStep(const IQueryPlanStep & step) +UInt64 hashPlanStep(const IQueryPlanStep & step, bool ignore_output_stream) { switch (step.getType()) { #define CASE_DEF(TYPE, VAR_NAME) \ case IQueryPlanStep::Type::TYPE: { \ - return hashPlanStepImpl(step); \ + return hashPlanStepImpl(step, ignore_output_stream); \ } APPLY_STEP_PROTOBUF_TYPES_AND_NAMES(CASE_DEF) diff --git a/src/QueryPlan/PlanSerDerHelper.h b/src/QueryPlan/PlanSerDerHelper.h index 0eb5d890ac1..15c033a809e 100644 --- a/src/QueryPlan/PlanSerDerHelper.h +++ b/src/QueryPlan/PlanSerDerHelper.h @@ -270,5 +270,5 @@ void serializeQueryPlanStepToProto(const QueryPlanStepPtr & step, Protos::QueryP QueryPlanStepPtr deserializeQueryPlanStepFromProto(const Protos::QueryPlanStep & proto, ContextPtr context); bool isPlanStepEqual(const IQueryPlanStep & a, const IQueryPlanStep & b); -UInt64 hashPlanStep(const IQueryPlanStep & step); +UInt64 hashPlanStep(const IQueryPlanStep & 
step, bool ignore_output_stream); } diff --git a/src/QueryPlan/QueryPlan.h b/src/QueryPlan/QueryPlan.h index 74967235f0b..9f2643e3a4c 100644 --- a/src/QueryPlan/QueryPlan.h +++ b/src/QueryPlan/QueryPlan.h @@ -132,6 +132,15 @@ class QueryPlan void setMaxThreads(size_t max_threads_) { max_threads = max_threads_; } size_t getMaxThreads() const { return max_threads; } + void setShortCircuit(bool short_circuit_) + { + short_circuit = short_circuit_; + } + bool isShortCircuit() const + { + return short_circuit; + } + void addInterpreterContext(std::shared_ptr context); void serialize(WriteBuffer & buffer) const; @@ -217,6 +226,8 @@ class QueryPlan std::shared_ptr max_node_id; //Whether reset step id in serialize(),use for explain analyze. bool reset_step_id = true; + + bool short_circuit = false; }; std::string debugExplainStep(const IQueryPlanStep & step); diff --git a/src/QueryPlan/QueryPlanner.cpp b/src/QueryPlan/QueryPlanner.cpp index 051c696ef88..9a63e244c53 100644 --- a/src/QueryPlan/QueryPlanner.cpp +++ b/src/QueryPlan/QueryPlanner.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -144,7 +145,8 @@ class QueryPlannerVisitor : public ASTVisitor void planJoinUsing(ASTTableJoin & table_join, PlanBuilder & left_builder, PlanBuilder & right_builder); void planJoinOn(ASTTableJoin & table_join, PlanBuilder & left_builder, PlanBuilder & right_builder); std::pair prepareJoinUsingKeys(ASTTableJoin & table_join, PlanBuilder & left_builder, PlanBuilder & right_builder); - std::pair prepareJoinOnKeys(ASTTableJoin & table_join, PlanBuilder & left_builder, PlanBuilder & right_builder); + std::tuple> + prepareJoinOnKeys(ASTTableJoin & table_join, PlanBuilder & left_builder, PlanBuilder & right_builder); static DataStream getJoinOutputStream(ASTTableJoin & table_join, PlanBuilder & left_builder, PlanBuilder & right_builder); RelationPlan planReadFromStorage(IAST & table_ast, ScopePtr table_scope, ASTSelectQuery & origin_query, SqlHints & 
hints, bool is_table_function = false); @@ -319,12 +321,13 @@ RelationPlan QueryPlannerVisitor::visitASTInsertQuery(ASTPtr & node, const Void auto & insert = *analysis.getInsert(); auto select_plan = process(insert_query.select); select_plan.withNewRoot(planOutput(select_plan, insert_query.select, analysis, context)); + auto insert_select_with_profiles = context->getSettingsRef().insert_select_with_profiles; auto target = std::make_shared(insert.storage, insert.storage_id, insert.columns, node); auto insert_node = select_plan.getRoot()->addStep( context->nextNodeId(), - std::make_shared(select_plan.getRoot()->getCurrentDataStream(), target), + std::make_shared(select_plan.getRoot()->getCurrentDataStream(), target, insert_select_with_profiles), {select_plan.getRoot()}); auto total_affected_row_count_symbol = context->getSymbolAllocator()->newSymbol("rows"); @@ -335,7 +338,7 @@ RelationPlan QueryPlannerVisitor::visitASTInsertQuery(ASTPtr & node, const Void auto return_node = PlanNodeBase::createPlanNode( context->nextNodeId(), - std::make_shared(insert_node->getCurrentDataStream(), target, total_affected_row_count_symbol, node), + std::make_shared(insert_node->getCurrentDataStream(), target, total_affected_row_count_symbol, node, insert_select_with_profiles), {insert_node}); PRINT_PLAN(return_node, plan_insert); @@ -657,6 +660,7 @@ void QueryPlannerVisitor::planJoinUsing(ASTTableJoin & table_join, PlanBuilder & context->getSettingsRef().optimize_read_in_order, left_keys, right_keys, + std::vector{}, PredicateConst::TRUE_VALUE, true, use_ansi_semantic ? std::nullopt : std::make_optional(join_analysis.require_right_keys)); @@ -741,7 +745,7 @@ void QueryPlannerVisitor::planJoinOn(ASTTableJoin & table_join, PlanBuilder & le right_builder.withScope(joined_scope, joined_field_symbols); // 2. 
prepare join keys - auto [left_keys, right_keys] = prepareJoinOnKeys(table_join, left_builder, right_builder); + auto [left_keys, right_keys, key_ids_null_safe] = prepareJoinOnKeys(table_join, left_builder, right_builder); // 3. build join filter ASTPtr join_filter = PredicateConst::TRUE_VALUE; @@ -782,6 +786,7 @@ void QueryPlannerVisitor::planJoinOn(ASTTableJoin & table_join, PlanBuilder & le context->getSettingsRef().optimize_read_in_order, left_keys, right_keys, + key_ids_null_safe, isNormalInnerJoin(table_join) ? PredicateConst::TRUE_VALUE : join_filter, false, std::nullopt, @@ -855,13 +860,14 @@ QueryPlannerVisitor::prepareJoinUsingKeys(ASTTableJoin & table_join, PlanBuilder return {left_keys, right_keys}; } -std::pair +std::tuple> QueryPlannerVisitor::prepareJoinOnKeys(ASTTableJoin & table_join, PlanBuilder & left_builder, PlanBuilder & right_builder) { auto & join_analysis = analysis.getJoinOnAnalysis(table_join); ExpressionsAndTypes left_conditions; ExpressionsAndTypes right_conditions; + std::vector key_ids_null_safe; // for asof join, equality exprs & inequality exprs forms the join keys // for other joins, equality exprs forms the join keys, inequality exprs & complex exprs forms the join filter @@ -869,6 +875,7 @@ QueryPlannerVisitor::prepareJoinOnKeys(ASTTableJoin & table_join, PlanBuilder & { left_conditions.emplace_back(condition.left_ast, condition.left_coercion); right_conditions.emplace_back(condition.right_ast, condition.right_coercion); + key_ids_null_safe.emplace_back(condition.null_safe); } if (isAsofJoin(table_join)) @@ -884,7 +891,7 @@ QueryPlannerVisitor::prepareJoinOnKeys(ASTTableJoin & table_join, PlanBuilder & Names left_symbols = left_builder.projectExpressionsWithCoercion(left_conditions); Names right_symbols = right_builder.projectExpressionsWithCoercion(right_conditions); - return {left_symbols, right_symbols}; + return {left_symbols, right_symbols, key_ids_null_safe}; } DataStream 
QueryPlannerVisitor::getJoinOutputStream(ASTTableJoin &, PlanBuilder & left_builder, PlanBuilder & right_builder) @@ -2250,7 +2257,13 @@ RelationPlan QueryPlannerVisitor::planSetOperation(ASTs & selects, ASTSelectWith if (sub_plans.size() == 1) return sub_plans.front(); + std::vector sub_plan_types; + for (const auto & sub_plan : sub_plans) + sub_plan_types.emplace_back(sub_plan.getRoot()->getOutputNamesToTypes()); + FieldSubColumnIDs sub_column_positions; + DataTypes sub_column_type_coercions; + // compute common subcolumns for each field if (enable_subcolumn_optimization_through_union) { @@ -2273,6 +2286,24 @@ RelationPlan QueryPlannerVisitor::planSetOperation(ASTs & selects, ASTSelectWith for (const auto & sub_col_id : common_sub_col_ids) sub_column_positions.emplace_back(field_id, sub_col_id); } + + if (enable_implicit_type_conversion) + { + for (const auto & field_sub_col_id : sub_column_positions) + { + DataTypes sub_col_types; + for (size_t select_id = 0; select_id < selects.size(); ++select_id) + { + auto sub_col_symbol = sub_plans[select_id] + .getFieldSymbolInfos() + .at(field_sub_col_id.first) + .sub_column_symbols.at(field_sub_col_id.second); + sub_col_types.push_back(sub_plan_types[select_id].at(sub_col_symbol)); + } + sub_column_type_coercions.emplace_back( + getLeastSupertype(sub_col_types, context->getSettingsRef().allow_extended_type_conversion)); + } + } } // 2. 
prepare sub plan & collect input info @@ -2283,6 +2314,7 @@ RelationPlan QueryPlannerVisitor::planSetOperation(ASTs & selects, ASTSelectWith { auto & select = selects[select_id]; auto & sub_plan = sub_plans[select_id]; + const auto & name_to_type = sub_plan_types[select_id]; // prune invisible columns, copy duplicated columns, sort columns by a specific order(primary columns + sub columns) sub_plan = projectFieldSymbols(sub_plan, sub_column_positions); @@ -2290,20 +2322,37 @@ RelationPlan QueryPlannerVisitor::planSetOperation(ASTs & selects, ASTSelectWith auto column_names1 = sub_plan.getRoot()->getOutputNames(); #endif // coerce to common type - if (enable_implicit_type_conversion && analysis.hasRelationTypeCoercion(*select)) + if (enable_implicit_type_conversion) { - auto field_symbol_infos = sub_plan.getFieldSymbolInfos(); - const auto & target_types = analysis.getRelationTypeCoercion(*select); - assert(target_types.size() == field_symbol_infos.size()); NameToType symbols_and_types; + auto field_symbol_infos = sub_plan.getFieldSymbolInfos(); - for (size_t i = 0; i < target_types.size(); ++i) + if (analysis.hasRelationTypeCoercion(*select)) { - auto target_type = target_types[i]; - if (target_type) - symbols_and_types.emplace(field_symbol_infos[i].getPrimarySymbol(), target_type); + const auto & target_types = analysis.getRelationTypeCoercion(*select); + assert(target_types.size() == field_symbol_infos.size()); + + for (size_t i = 0; i < target_types.size(); ++i) + { + auto target_type = target_types[i]; + if (target_type) + symbols_and_types.emplace(field_symbol_infos[i].getPrimarySymbol(), target_type); + } } + if (!sub_column_type_coercions.empty()) + { + for (size_t pos = 0; pos < sub_column_type_coercions.size(); ++pos) + { + const auto & field_sub_col_id = sub_column_positions.at(pos); + const auto & sub_col_symbol + = field_symbol_infos.at(field_sub_col_id.first).sub_column_symbols.at(field_sub_col_id.second); + auto sub_col_type = 
name_to_type.at(sub_col_symbol); + auto target_type = sub_column_type_coercions[pos]; + if (!target_type->equals(*sub_col_type)) + symbols_and_types.emplace(sub_col_symbol, target_type); + } + } auto coercion_result = coerceTypesForSymbols(sub_plan.getRoot(), symbols_and_types, true); mapFieldSymbolInfos(field_symbol_infos, coercion_result.mappings, false); sub_plan = RelationPlan{coercion_result.plan, field_symbol_infos}; diff --git a/src/QueryPlan/ReadFromMergeTree.cpp b/src/QueryPlan/ReadFromMergeTree.cpp index b3f236742a8..cd3934c832d 100644 --- a/src/QueryPlan/ReadFromMergeTree.cpp +++ b/src/QueryPlan/ReadFromMergeTree.cpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -163,6 +164,98 @@ static Array extractMapColumnKeys(const MergeTreeMetaBase & data, const MergeTre return res; } +static bool isSamePartition(const RangesInDataPart & lhs, const RangesInDataPart & rhs) +{ + return lhs.data_part->partition.value == rhs.data_part->partition.value; +} + +static bool canReadInPartitionOrder( + const StorageInMemoryMetadata & metadata, + const InputOrderInfo & input_order_info, + const ASTSelectQuery & select) +{ + if (!metadata.isPartitionKeyDefined() || !metadata.isSortingKeyDefined()) + return false; + + const auto & partition_key = metadata.getPartitionKey(); + Names minmax_columns = partition_key.expression->getRequiredColumns(); + /// for simplicity, only support table with one partition key + if (partition_key.column_names.size() != 1 || minmax_columns.size() != 1) + return false; + + String partition_column = minmax_columns[0]; + Names sorting_columns = metadata.getSortingKeyColumns(); + chassert(sorting_columns.size() >= input_order_info.order_key_prefix_descr.size()); + /// optimizer guarantees that order_key_prefix is a prefix of sorting columns + sorting_columns.resize(input_order_info.order_key_prefix_descr.size()); + + /// sorting columns should contain partition column + auto partition_column_it = 
std::find(sorting_columns.begin(), sorting_columns.end(), partition_column); + if (partition_column_it == sorting_columns.end()) + return false; + + /// Allow table "partition by c order by (a, b, c)" for query "where a={} and b={} order by c", + /// where all sorting columns before partition column match single value, + /// note that in this case, input order is (a, b, c) + if (partition_column_it != sorting_columns.begin()) + { + NameSet single_value_columns; + auto collect = [&](const ASTPtr & filter) + { + if (!filter) + return; + + for (const auto & conjunct : PredicateUtils::extractConjuncts(filter->clone())) + { + const auto * func = conjunct->as(); + if (!func || func->name != "equals") + continue; + const auto * column = func->arguments->children[0]->as(); + const auto * literal = func->arguments->children[1]->as(); + if (column && literal) + single_value_columns.insert(column->name()); + } + }; + collect(select.where()); + collect(select.prewhere()); + auto match_single_value = [&](const String & name) { return single_value_columns.count(name); }; + if (!std::all_of(sorting_columns.begin(), partition_column_it, match_single_value)) + return false; + } + + /// fast path for: order by sort_column partition by sort_column + if (partition_key.column_names.front() == *partition_column_it) + return true; + + /// Allow "partition by func(x) order by (x)" where func is monotonic nondecreasing + IFunction::Monotonicity monotonicity; + for (const auto & action : partition_key.expression->getActions()) + { + if (action.node->type != ActionsDAG::ActionType::FUNCTION) + { + continue; + } + + /// Allow only one simple monotonic functions with one argument + if (monotonicity.is_monotonic) + { + monotonicity.is_monotonic = false; + break; + } + + if (action.node->children.size() != 1 || action.node->children.at(0)->result_name != *partition_column_it) + break; + + const auto & func = *action.node->function_base; + if (!func.hasInformationAboutMonotonicity()) + break; + + 
monotonicity = func.getMonotonicityForRange(*func.getArgumentTypes().at(0), {}, {}); + } + + return monotonicity.is_monotonic && monotonicity.is_positive; +} + ReadFromMergeTree::ReadFromMergeTree( MergeTreeMetaBase::DataPartsVector parts_, MergeTreeMetaBase::DeleteBitmapGetter delete_bitmap_getter_, @@ -325,18 +418,17 @@ template ProcessorPtr ReadFromMergeTree::createSource( const RangesInDataPart & part, const Names & required_columns, - const MergeTreeStreamSettings & stream_settings) + const MergeTreeStreamSettings & stream_settings, + const MarkRangesFilterCallback & range_filter_callback) { return std::make_shared( data, storage_snapshot, part.data_part, std::move(combineFilterBitmap(part, delete_bitmap_getter)), required_columns, part.ranges, - query_info, true, stream_settings, virt_column_names, part.part_index_in_query); + query_info, true, stream_settings, virt_column_names, part.part_index_in_query, range_filter_callback); } -Pipe ReadFromMergeTree::readInOrder( - RangesInDataParts parts_with_range, - Names required_columns, - ReadType read_type, - bool use_uncompressed_cache) +Pipe ReadFromMergeTree::readInOrder(RangesInDataParts parts_with_range, + Names required_columns, ReadType read_type, bool use_uncompressed_cache, + const std::shared_ptr & delayed_index) { Pipes pipes; MergeTreeStreamSettings stream_settings{ @@ -348,13 +440,22 @@ Pipe ReadFromMergeTree::readInOrder( .reader_settings = reader_settings }; + MarkRangesFilterCallback filter_callback; + if (delayed_index != nullptr) + { + filter_callback = [reader_settings = this->reader_settings, ctx = this->context, delayed_index](const MergeTreeDataPartPtr& part, const MarkRanges& mark_ranges) { + return MergeTreeDataSelectExecutor::filterMarkRangesForPartByInvertedIndex( + part, mark_ranges, delayed_index, ctx, reader_settings); + }; + } + if (!query_info.atomic_predicates.empty()) { for (const auto & part : parts_with_range) { auto source = read_type == ReadType::InReverseOrder - ? 
createSource(part, required_columns, stream_settings) - : createSource(part, required_columns, stream_settings); + ? createSource(part, required_columns, stream_settings, filter_callback) + : createSource(part, required_columns, stream_settings, filter_callback); pipes.emplace_back(std::move(source)); } @@ -364,8 +465,8 @@ Pipe ReadFromMergeTree::readInOrder( for (const auto & part : parts_with_range) { auto source = read_type == ReadType::InReverseOrder - ? createSource(part, required_columns, stream_settings) - : createSource(part, required_columns, stream_settings); + ? createSource(part, required_columns, stream_settings, filter_callback) + : createSource(part, required_columns, stream_settings, filter_callback); pipes.emplace_back(std::move(source)); } @@ -385,14 +486,25 @@ Pipe ReadFromMergeTree::readInOrder( } Pipe ReadFromMergeTree::read( - RangesInDataParts parts_with_range, Names required_columns, ReadType read_type, - size_t max_streams, size_t min_marks_for_concurrent_read, bool use_uncompressed_cache) + RangesInDataParts parts_with_range, + Names required_columns, + ReadType read_type, + size_t max_streams, + size_t min_marks_for_concurrent_read, + bool use_uncompressed_cache, + const std::shared_ptr & delayed_index) { if (read_type == ReadType::Default && max_streams > 1) - return readFromPool(parts_with_range, required_columns, max_streams, - min_marks_for_concurrent_read, use_uncompressed_cache); + { + if (unlikely(delayed_index != nullptr)) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Some skip index is delayed to pipeline execution stage"); + } + return readFromPool(parts_with_range, required_columns, max_streams, min_marks_for_concurrent_read, use_uncompressed_cache); + } - auto pipe = readInOrder(parts_with_range, required_columns, read_type, use_uncompressed_cache); + auto pipe = readInOrder(parts_with_range, required_columns, + read_type, use_uncompressed_cache, delayed_index); /// Use ConcatProcessor to concat sources together. 
/// It is needed to read in parts order (and so in PK order) if single thread is used. @@ -483,8 +595,14 @@ Pipe ReadFromMergeTree::spreadMarkRangesAmongStreams( } } - return read(std::move(parts_with_ranges), column_names, ReadType::Default, - num_streams, info.min_marks_for_concurrent_read, info.use_uncompressed_cache); + return read( + std::move(parts_with_ranges), + column_names, + ReadType::Default, + num_streams, + info.min_marks_for_concurrent_read, + info.use_uncompressed_cache, + nullptr); } static ActionsDAGPtr createProjection(const Block & header) @@ -495,12 +613,78 @@ static ActionsDAGPtr createProjection(const Block & header) return projection; } +namespace +{ +template +struct PartitionValueComparator +{ + bool operator()(const RangesInDataPart & lhs, const RangesInDataPart & rhs) const + { + const auto & l = lhs.data_part->partition.value[0]; + const auto & r = rhs.data_part->partition.value[0]; + if constexpr (ascend) + return l < r; + else + return l > r; + } +}; +} // anonymouse namespace + +Pipe ReadFromMergeTree::spreadMarkRangesAmongStreamsWithPartitionOrder( + RangesInDataParts && parts_with_ranges, + const Names & column_names, + const ActionsDAGPtr & sorting_key_prefix_expr, + ActionsDAGPtr & out_projection, + const InputOrderInfoPtr & input_order_info, + const std::shared_ptr & delayed_index) +{ + chassert(!parts_with_ranges.empty()); + + /// sort parts by partition value + if (input_order_info->direction > 0) + std::sort(parts_with_ranges.begin(), parts_with_ranges.end(), PartitionValueComparator{}); + else + std::sort(parts_with_ranges.begin(), parts_with_ranges.end(), PartitionValueComparator{}); + + Pipes pipes; + auto prev = parts_with_ranges.begin(); + auto end = parts_with_ranges.end(); + + while (prev != end) + { + auto curr = std::next(prev); + while (curr != end && isSamePartition(*prev, *curr)) + ++curr; + + auto pipe = spreadMarkRangesAmongStreamsWithOrder( + {std::make_move_iterator(prev), std::make_move_iterator(curr)}, + 
column_names, + sorting_key_prefix_expr, + out_projection, + input_order_info, + // for the result pipe to output ordered tuples for this partition + 1 /*num_streams*/, true /*need_preliminary_merge*/, + delayed_index); + + pipes.emplace_back(std::move(pipe)); + prev = curr; + } + + auto res = Pipe::unitePipes(std::move(pipes)); + if (res.numOutputPorts() > 1) + res.addTransform(std::make_shared(res.getHeader(), res.numOutputPorts())); + return res; +} + Pipe ReadFromMergeTree::spreadMarkRangesAmongStreamsWithOrder( RangesInDataParts && parts_with_ranges, const Names & column_names, const ActionsDAGPtr & sorting_key_prefix_expr, ActionsDAGPtr & out_projection, - const InputOrderInfoPtr & input_order_info) + const InputOrderInfoPtr & input_order_info, + size_t num_streams, + bool need_preliminary_merge, + const std::shared_ptr & delayed_index) { const auto & settings = context->getSettingsRef(); const auto data_settings = data.getSettings(); @@ -558,12 +742,11 @@ Pipe ReadFromMergeTree::spreadMarkRangesAmongStreamsWithOrder( return new_ranges; }; - const size_t min_marks_per_stream = (info.sum_marks - 1) / requested_num_streams + 1; - bool need_preliminary_merge = (parts_with_ranges.size() > settings.read_in_order_two_level_merge_threshold); + const size_t min_marks_per_stream = (info.sum_marks - 1) / num_streams + 1; Pipes pipes; - for (size_t i = 0; i < requested_num_streams && !parts_with_ranges.empty(); ++i) + for (size_t i = 0; i < num_streams && !parts_with_ranges.empty(); ++i) { size_t need_marks = min_marks_per_stream; RangesInDataParts new_parts; @@ -630,10 +813,11 @@ Pipe ReadFromMergeTree::spreadMarkRangesAmongStreamsWithOrder( : ReadFromMergeTree::ReadType::InReverseOrder; pipes.emplace_back(read(std::move(new_parts), column_names, read_type, - requested_num_streams, info.min_marks_for_concurrent_read, info.use_uncompressed_cache)); + num_streams, info.min_marks_for_concurrent_read, info.use_uncompressed_cache, + delayed_index)); } - if 
(need_preliminary_merge) + if (need_preliminary_merge && !pipes.empty()) { SortDescription sort_description; for (size_t j = 0; j < input_order_info->order_key_prefix_descr.size(); ++j) @@ -644,9 +828,6 @@ Pipe ReadFromMergeTree::spreadMarkRangesAmongStreamsWithOrder( for (auto & pipe : pipes) { - /// Drop temporary columns, added by 'sorting_key_prefix_expr' - out_projection = createProjection(pipe.getHeader()); - pipe.addSimpleTransform([sorting_key_expr](const Block & header) { return std::make_shared(header, sorting_key_expr); @@ -663,6 +844,12 @@ Pipe ReadFromMergeTree::spreadMarkRangesAmongStreamsWithOrder( pipe.addTransform(std::move(transform)); } } + + if (!out_projection) + { + /// Drop temporary columns, added by 'sorting_key_prefix_expr' + out_projection = createProjection(pipes.front().getHeader()); + } } return Pipe::unitePipes(std::move(pipes)); @@ -854,7 +1041,7 @@ Pipe ReadFromMergeTree::spreadMarkRangesAmongStreamsFinal( continue; pipe = read(std::move(new_parts), column_names, ReadFromMergeTree::ReadType::InOrder, - num_streams, 0, info.use_uncompressed_cache); + num_streams, 0, info.use_uncompressed_cache, nullptr); /// Drop temporary columns, added by 'sorting_key_expr' if (!out_projection) @@ -917,7 +1104,8 @@ Pipe ReadFromMergeTree::spreadMarkRangesAmongStreamsFinal( num_streams_for_lonely_parts = std::max((sum_marks_in_lonely_parts + min_marks_for_concurrent_read - 1) / min_marks_for_concurrent_read, lonely_parts.size()); auto pipe = read(std::move(lonely_parts), column_names, ReadFromMergeTree::ReadType::Default, - num_streams_for_lonely_parts, min_marks_for_concurrent_read, info.use_uncompressed_cache); + num_streams_for_lonely_parts, min_marks_for_concurrent_read, info.use_uncompressed_cache, + nullptr); /// Drop temporary columns, added by 'sorting_key_expr' if (!out_projection) @@ -1115,6 +1303,7 @@ MergeTreeDataSelectAnalysisResultPtr ReadFromMergeTree::selectRangesToRead( log, num_streams, result.index_stats, + 
*(result.delayed_indices), settings.enable_skip_index, data, result.sampling.use_sampling, @@ -1154,12 +1343,14 @@ MergeTreeDataSelectAnalysisResultPtr ReadFromMergeTree::selectRangesToRead( size_t sum_marks = 0; size_t sum_ranges = 0; size_t sum_rows = 0; + NameSet partition_ids; for (const auto & part : result.parts_with_ranges) { sum_ranges += part.ranges.size(); sum_marks += part.getMarksCount(); sum_rows += part.getRowsCount(); + partition_ids.insert(part.data_part->info.partition_id); } result.total_parts = total_parts; @@ -1170,6 +1361,7 @@ MergeTreeDataSelectAnalysisResultPtr ReadFromMergeTree::selectRangesToRead( result.selected_marks_pk = sum_marks_pk; result.total_marks_pk = total_marks_pk; result.selected_rows = sum_rows; + result.selected_partitions = partition_ids.size(); const auto & input_order_info = query_info.input_order_info ? query_info.input_order_info @@ -1205,6 +1397,9 @@ void ReadFromMergeTree::initializePipeline(QueryPipeline & pipeline, const Build result.selected_marks, result.selected_ranges); + if (context->getSettingsRef().report_segment_profiles) + fillRuntimeAttributeDescriptions(result); + ProfileEvents::increment(ProfileEvents::SelectedParts, result.selected_parts); ProfileEvents::increment(ProfileEvents::SelectedRanges, result.selected_ranges); ProfileEvents::increment(ProfileEvents::SelectedMarks, result.selected_marks); @@ -1249,6 +1444,7 @@ void ReadFromMergeTree::initializePipeline(QueryPipeline & pipeline, const Build Pipe pipe; const auto & settings = context->getSettingsRef(); + bool can_read_in_partition_order = false; if (select.final()) { @@ -1278,12 +1474,32 @@ void ReadFromMergeTree::initializePipeline(QueryPipeline & pipeline, const Build auto syntax_result = TreeRewriter(context).analyze(order_key_prefix_ast, metadata_for_reading->getColumns().getAllPhysical()); auto sorting_key_prefix_expr = ExpressionAnalyzer(order_key_prefix_ast, syntax_result, context).getActionsDAG(false); - pipe = 
spreadMarkRangesAmongStreamsWithOrder( - std::move(result.parts_with_ranges), - column_names_to_read, - sorting_key_prefix_expr, - result_projection, - input_order_info); + can_read_in_partition_order = (settings.optimize_read_in_partition_order || settings.force_read_in_partition_order) + && canReadInPartitionOrder(*metadata_for_reading, *input_order_info, query_info.query->as()); + + if (can_read_in_partition_order && result.selected_partitions > 1) + { + pipe = spreadMarkRangesAmongStreamsWithPartitionOrder( + std::move(result.parts_with_ranges), + column_names_to_read, + sorting_key_prefix_expr, + result_projection, + input_order_info, + result.delayed_indices); + } + else + { + bool need_preliminary_merge = (result.parts_with_ranges.size() > settings.read_in_order_two_level_merge_threshold); + pipe = spreadMarkRangesAmongStreamsWithOrder( + std::move(result.parts_with_ranges), + column_names_to_read, + sorting_key_prefix_expr, + result_projection, + input_order_info, + requested_num_streams, + need_preliminary_merge, + result.delayed_indices); + } } else { @@ -1292,6 +1508,9 @@ void ReadFromMergeTree::initializePipeline(QueryPipeline & pipeline, const Build column_names_to_read); } + if (settings.force_read_in_partition_order && !can_read_in_partition_order) + throw Exception(ErrorCodes::INDEX_NOT_USED, "Cannot read in partition order but 'force_read_in_partition_order' is set"); + if (pipe.empty()) { pipeline.init(Pipe(std::make_shared(getOutputStream().header))); @@ -1538,6 +1757,58 @@ std::shared_ptr ReadFromMergeTree::copy(ContextPtr) const throw Exception("ReadFromMergeTree can not copy", ErrorCodes::NOT_IMPLEMENTED); } +void ReadFromMergeTree::fillRuntimeAttributeDescriptions(const ReadFromMergeTree::AnalysisResult & result) +{ + auto index_stats = result.index_stats; + if (!result.index_stats.empty()) + { + RuntimeAttributeDescription index_desc; + for (size_t i = 0; i < index_stats.size(); ++i) + { + const auto & stat = index_stats[i]; + if (stat.type 
== IndexType::None) + continue; + std::stringstream out; + out << "Type: " << indexTypeToString(stat.type) << ";"; + if (!stat.name.empty()) + out << " Name: " << stat.name << ";"; + if (!stat.description.empty()) + out << " Description: " << stat.description << ";"; + if (!stat.used_keys.empty()) + { + String keys = fmt::format("{}", fmt::join(stat.used_keys, ",")); + out << " Keys: " << keys << ";"; + } + if (!stat.condition.empty()) + out << " Condition: " << stat.condition << ";"; + out << " Parts: " << stat.num_parts_after; + if (i) + out << '/' << index_stats[i - 1].num_parts_after; + out << ";"; + out << " Granules: " << stat.num_granules_after; + if (i) + out << '/' << index_stats[i - 1].num_granules_after; + out << ";"; + index_desc.name_and_detail.emplace_back(indexTypeToString(stat.type), out.str()); + } + index_desc.description = "Indexes"; + attribute_descriptions.emplace(index_desc.description, std::move(index_desc)); + } + + RuntimeAttributeDescription parts_desc; + String selected_parts_info = fmt::format( + "Selected {}/{} parts by partition key, {} parts by primary key, {}/{} marks by primary key, {} marks to read from {} ranges", + result.parts_before_pk, + result.total_parts, + result.selected_parts, + result.selected_marks_pk, + result.total_marks_pk, + result.selected_marks, + result.selected_ranges); + parts_desc.description = selected_parts_info; + attribute_descriptions.emplace("SelectParts", std::move(parts_desc)); +} + bool MergeTreeDataSelectAnalysisResult::error() const { return std::holds_alternative(result); diff --git a/src/QueryPlan/ReadFromMergeTree.h b/src/QueryPlan/ReadFromMergeTree.h index bcba9d5f286..947754eaf25 100644 --- a/src/QueryPlan/ReadFromMergeTree.h +++ b/src/QueryPlan/ReadFromMergeTree.h @@ -42,6 +42,15 @@ using MergeTreeDataSelectAnalysisResultPtr = std::shared_ptr; +/// Contains delayed skip index information which should execute +/// on pipeline execution stage +struct DelayedSkipIndex +{ + std::unordered_map> 
indices; +}; +using MarkRangesFilterCallback = std::function; + + /// This step is created to read from MergeTree* table. /// For now, it takes a list of parts and creates source from it. class ReadFromMergeTree final : public ISourceStep @@ -96,6 +105,7 @@ class ReadFromMergeTree final : public ISourceStep IndexStats index_stats; Names column_names_to_read; ReadFromMergeTree::ReadType read_type = ReadFromMergeTree::ReadType::Default; + std::shared_ptr delayed_indices = std::make_shared(); UInt64 total_parts = 0; UInt64 parts_before_pk = 0; UInt64 selected_parts = 0; @@ -104,6 +114,7 @@ class ReadFromMergeTree final : public ISourceStep UInt64 selected_marks_pk = 0; UInt64 total_marks_pk = 0; UInt64 selected_rows = 0; + UInt64 selected_partitions = 0; }; ReadFromMergeTree( @@ -137,6 +148,8 @@ class ReadFromMergeTree final : public ISourceStep void describeIndexes(JSONBuilder::JSONMap & map) const override; std::shared_ptr copy(ContextPtr ptr) const override; + void fillRuntimeAttributeDescriptions(const ReadFromMergeTree::AnalysisResult & result); + StorageID getStorageID() const { return data.getStorageID(); } UInt64 getSelectedParts() const { return selected_parts; } UInt64 getSelectedRows() const { return selected_rows; } @@ -193,12 +206,12 @@ class ReadFromMergeTree final : public ISourceStep UInt64 selected_rows = 0; UInt64 selected_marks = 0; - Pipe read(RangesInDataParts parts_with_range, Names required_columns, ReadType read_type, size_t max_streams, size_t min_marks_for_concurrent_read, bool use_uncompressed_cache); + Pipe read(RangesInDataParts parts_with_range, Names required_columns, ReadType read_type, size_t max_streams, size_t min_marks_for_concurrent_read, bool use_uncompressed_cache, const std::shared_ptr & delayed_index); Pipe readFromPool(RangesInDataParts parts_with_ranges, Names required_columns, size_t max_streams, size_t min_marks_for_concurrent_read, bool use_uncompressed_cache); - Pipe readInOrder(RangesInDataParts parts_with_range, Names 
required_columns, ReadType read_type, bool use_uncompressed_cache); + Pipe readInOrder(RangesInDataParts parts_with_range, Names required_columns, ReadType read_type, bool use_uncompressed_cache, const std::shared_ptr & delayed_index); template - ProcessorPtr createSource(const RangesInDataPart & part, const Names & required_columns,const MergeTreeStreamSettings & stream_settings); + ProcessorPtr createSource(const RangesInDataPart & part, const Names & required_columns,const MergeTreeStreamSettings & stream_settings, const MarkRangesFilterCallback & range_filter_callback); Pipe spreadMarkRangesAmongStreams( RangesInDataParts && parts_with_ranges, @@ -209,7 +222,18 @@ class ReadFromMergeTree final : public ISourceStep const Names & column_names, const ActionsDAGPtr & sorting_key_prefix_expr, ActionsDAGPtr & out_projection, - const InputOrderInfoPtr & input_order_info); + const InputOrderInfoPtr & input_order_info, + size_t num_streams, + bool need_preliminary_merge, + const std::shared_ptr & delayed_index); + + Pipe spreadMarkRangesAmongStreamsWithPartitionOrder( + RangesInDataParts && parts_with_ranges, + const Names & column_names, + const ActionsDAGPtr & sorting_key_prefix_expr, + ActionsDAGPtr & out_projection, + const InputOrderInfoPtr & input_order_info, + const std::shared_ptr & delayed_index); Pipe spreadMarkRangesAmongStreamsFinal( RangesInDataParts && parts, diff --git a/src/QueryPlan/SortingStep.cpp b/src/QueryPlan/SortingStep.cpp index 50aa9f80a53..e789ea5261e 100644 --- a/src/QueryPlan/SortingStep.cpp +++ b/src/QueryPlan/SortingStep.cpp @@ -13,19 +13,20 @@ * limitations under the License. 
*/ +#include #include #include +#include #include #include #include #include #include #include -#include #include +#include +#include #include -#include "Core/SettingsEnums.h" -#include "QueryPlan/PlanSerDerHelper.h" namespace DB { @@ -51,7 +52,7 @@ SortingStep::SortingStep( Stage stage_, SortDescription prefix_description_, bool enable_adaptive_spill_) - : ITransformingStep(input_stream_, input_stream_.header, getTraits(limit_, stage_ != Stage::PARTIAL)) + : ITransformingStep(input_stream_, input_stream_.header, getTraits(limit_, stage_ != Stage::PARTIAL && stage_ != Stage::PARTIAL_NO_MERGE)) , result_description(result_description_) , limit(limit_) , stage(stage_) @@ -62,7 +63,7 @@ SortingStep::SortingStep( /// TODO: support mannual/auto spill output_stream->sort_description = result_description; output_stream->sort_mode - = (input_stream_.has_single_port || stage_ != Stage::PARTIAL) ? DataStream::SortMode::Stream : DataStream::SortMode::Port; + = (input_stream_.has_single_port || (stage_ != Stage::PARTIAL && stage_ != Stage::PARTIAL_NO_MERGE)) ? 
DataStream::SortMode::Stream : DataStream::SortMode::Port; } void SortingStep::setInputStreams(const DataStreams & input_streams_) @@ -87,36 +88,52 @@ void SortingStep::transformPipeline(QueryPipeline & pipeline, const BuildQueryPi auto desc_copy = result_description; - if (stage == Stage::FULL || stage == Stage::PARTIAL) + if (stage == Stage::FULL || stage == Stage::PARTIAL || stage == Stage::PARTIAL_NO_MERGE) { // finish sorting if (!prefix_description.empty()) { bool need_finish_sorting = (prefix_description.size() < result_description.size()); + + if (!need_finish_sorting) + { + if (pipeline.getNumStreams() > 1 && stage != Stage::PARTIAL_NO_MERGE) + { + auto transform = std::make_shared( + pipeline.getHeader(), pipeline.getNumStreams(), prefix_description, local_settings.max_block_size, getLimitValue()); + + pipeline.addTransform(std::move(transform)); + } + if (getLimitValue() > 0) + { + auto transform = std::make_shared( + pipeline.getHeader(), getLimitValue(), 0, pipeline.getNumStreams(), false, false, result_description); + pipeline.addTransform(std::move(transform)); + } + return; + } + if (pipeline.getNumStreams() > 1) { - UInt64 limit_for_merging = (need_finish_sorting ? 
0 : getLimitValue()); + UInt64 limit_for_merging = 0; // need_finish_sorting auto transform = std::make_shared( pipeline.getHeader(), pipeline.getNumStreams(), prefix_description, local_settings.max_block_size, limit_for_merging); pipeline.addTransform(std::move(transform)); } - if (need_finish_sorting) - { - pipeline.addSimpleTransform([&](const Block & header, QueryPipeline::StreamType stream_type) -> ProcessorPtr { - if (stream_type != QueryPipeline::StreamType::Main) - return nullptr; - - return std::make_shared(header, result_description, getLimitValue()); - }); - - /// NOTE limits are not applied to the size of temporary sets in FinishSortingTransform - pipeline.addSimpleTransform([&](const Block & header) -> ProcessorPtr { - return std::make_shared( - header, prefix_description, result_description, local_settings.max_block_size, getLimitValue()); - }); - } + pipeline.addSimpleTransform([&](const Block & header, QueryPipeline::StreamType stream_type) -> ProcessorPtr { + if (stream_type != QueryPipeline::StreamType::Main) + return nullptr; + + return std::make_shared(header, result_description, getLimitValue()); + }); + + /// NOTE limits are not applied to the size of temporary sets in FinishSortingTransform + pipeline.addSimpleTransform([&](const Block & header) -> ProcessorPtr { + return std::make_shared( + header, prefix_description, result_description, local_settings.max_block_size, getLimitValue()); + }); return; } @@ -155,6 +172,16 @@ void SortingStep::transformPipeline(QueryPipeline & pipeline, const BuildQueryPi local_settings.min_free_disk_space_for_temporary_data, local_settings.spill_mode == SpillMode::AUTO); }); + + /// If there are several streams, then we merge them into one + if (pipeline.getNumStreams() > 1 && stage != Stage::PARTIAL_NO_MERGE) + { + auto transform = std::make_shared( + pipeline.getHeader(), pipeline.getNumStreams(), desc_copy, local_settings.max_block_size, getLimitValue()); + + pipeline.addTransform(std::move(transform)); + } 
+ return; } /// If there are several streams, then we merge them into one diff --git a/src/QueryPlan/SortingStep.h b/src/QueryPlan/SortingStep.h index 3af210406cd..4454a2b1cf4 100644 --- a/src/QueryPlan/SortingStep.h +++ b/src/QueryPlan/SortingStep.h @@ -31,7 +31,8 @@ class SortingStep : public ITransformingStep Protos::SortingStep::Stage, // proto enum message (FULL), (MERGE), - (PARTIAL) + (PARTIAL), + (PARTIAL_NO_MERGE) ); explicit SortingStep(const DataStream & input_stream, SortDescription description_, SizeOrVariable limit_, Stage stage_, SortDescription prefix_description_ = {}, bool enable_adaptive_spill_ = false); diff --git a/src/QueryPlan/SymbolMapper.cpp b/src/QueryPlan/SymbolMapper.cpp index 0e6ce57b2ae..d49eafc3f8b 100644 --- a/src/QueryPlan/SymbolMapper.cpp +++ b/src/QueryPlan/SymbolMapper.cpp @@ -64,7 +64,7 @@ SymbolMapper SymbolMapper::symbolMapper(std::unordered_map & map while (it != mapping.end() && it->second != symbol) { if (++lookup > MAX_LOOKUP_TIMES) - throw Exception("endless loop in SymbolMapper", ErrorCodes::LOGICAL_ERROR); + throw Exception("endless loop in SymbolMapper", ErrorCodes::LOGICAL_ERROR); symbol = it->second; it = mapping.find(symbol); } @@ -82,7 +82,7 @@ SymbolMapper SymbolMapper::symbolReallocator(std::unordered_map while (it != mapping.end() && it->second != symbol) { if (++lookup > MAX_LOOKUP_TIMES) - throw Exception("endless loop in SymbolMapper", ErrorCodes::LOGICAL_ERROR); + throw Exception("endless loop in SymbolMapper", ErrorCodes::LOGICAL_ERROR); symbol = it->second; it = mapping.find(symbol); } @@ -214,6 +214,7 @@ Partitioning SymbolMapper::map(const Partitioning & partition) map(partition.getColumns()), partition.isRequireHandle(), partition.getBuckets(), + partition.getBucketExpr(), partition.isEnforceRoundRobin(), partition.getComponent()}; } @@ -229,6 +230,7 @@ std::shared_ptr SymbolMapper::map(const JoinStep & join) join.getKeepLeftReadInOrder(), map(join.getLeftKeys()), map(join.getRightKeys()), + 
join.getKeyIdsNullSafe(), map(join.getFilter()), join.isHasUsing(), join.getRequireRightKeys(), @@ -404,10 +406,10 @@ SortDescription SymbolMapper::map(const SortDescription & sort_desc) return res; } -std::map SymbolMapper::map(const std::map & group_id_non_null_symbol) +std::map SymbolMapper::map(const std::map & group_id_non_null_symbol) { std::map res; - for(const auto & entry : group_id_non_null_symbol) + for (const auto & entry : group_id_non_null_symbol) { res[entry.first] = map(entry.second); } @@ -506,11 +508,11 @@ std::shared_ptr SymbolMapper::map(const FinalSampleStep & final std::shared_ptr SymbolMapper::map(const FinishSortingStep & finish_sorting) { return std::make_shared( - map(finish_sorting.getInputStreams()[0]), - SortDescription{map(finish_sorting.getPrefixDescription())}, - SortDescription{map(finish_sorting.getResultDescription())}, - finish_sorting.getMaxBlockSize(), - finish_sorting.getLimit()); + map(finish_sorting.getInputStreams()[0]), + SortDescription{map(finish_sorting.getPrefixDescription())}, + SortDescription{map(finish_sorting.getResultDescription())}, + finish_sorting.getMaxBlockSize(), + finish_sorting.getLimit()); } std::shared_ptr SymbolMapper::map(const IntersectStep & intersect) @@ -545,7 +547,8 @@ std::shared_ptr SymbolMapper::map(const TableScanStep & scan) // order matters as symbol mapper should traverse plan nodes bottom-up std::shared_ptr mapped_filter = scan.getPushdownFilterCast() ? map(*scan.getPushdownFilterCast()) : nullptr; std::shared_ptr mapped_projection = scan.getPushdownProjectionCast() ? map(*scan.getPushdownProjectionCast()) : nullptr; - std::shared_ptr mapped_aggregation = scan.getPushdownAggregationCast() ? map(*scan.getPushdownAggregationCast()) : nullptr; + std::shared_ptr mapped_aggregation + = scan.getPushdownAggregationCast() ? 
map(*scan.getPushdownAggregationCast()) : nullptr; auto mapped_scan = std::make_shared( std::move(mapped_output_stream), @@ -690,7 +693,11 @@ std::shared_ptr SymbolMapper::map(const ReadNothingStep & read_ std::shared_ptr SymbolMapper::map(const RemoteExchangeSourceStep & remote_exchange) { - return std::make_shared(remote_exchange.getInput(), map(remote_exchange.getInputStreams()[0]), remote_exchange.isAddTotals(), remote_exchange.isAddExtremes()); + return std::make_shared( + remote_exchange.getInput(), + map(remote_exchange.getInputStreams()[0]), + remote_exchange.isAddTotals(), + remote_exchange.isAddExtremes()); } @@ -754,7 +761,12 @@ std::shared_ptr SymbolMapper::map(const CTERefStep & cte_ref) std::shared_ptr SymbolMapper::map(const ExplainAnalyzeStep & step) { return std::make_shared( - map(step.getInputStreams()[0]), map(step.getOutputName()), step.getKind(), step.getContext(), step.getQueryPlan(), step.getSetting()); + map(step.getInputStreams()[0]), + map(step.getOutputName()), + step.getKind(), + step.getContext(), + step.getQueryPlan(), + step.getSetting()); } std::shared_ptr SymbolMapper::map(const LocalExchangeStep & step) @@ -764,7 +776,7 @@ std::shared_ptr SymbolMapper::map(const LocalExchangeStep & s std::shared_ptr SymbolMapper::map(const TableWriteStep & step) { - return std::make_shared(map(step.getInputStreams()[0]), step.getTarget()); + return std::make_shared(map(step.getInputStreams()[0]), step.getTarget(), step.isOutputProfiles()); } std::shared_ptr SymbolMapper::map(const OutfileWriteStep & step) @@ -790,7 +802,8 @@ std::shared_ptr SymbolMapper::map(const BufferStep & step) std::shared_ptr SymbolMapper::map(const TableFinishStep & step) { - return std::make_shared(map(step.getInputStreams()[0]), step.getTarget(), step.getOutputAffectedRowCountSymbol(), step.getQuery()); + return std::make_shared( + map(step.getInputStreams()[0]), step.getTarget(), step.getOutputAffectedRowCountSymbol(), step.getQuery(), step.isOutputProfiles()); } 
std::shared_ptr SymbolMapper::map(const IntermediateResultCacheStep & step) @@ -805,29 +818,31 @@ std::shared_ptr SymbolMapper::map(const MultiJoinStep & step) std::shared_ptr SymbolMapper::map(const TotalsHavingStep & step) { - return std::make_shared(map(step.getInputStreams()[0]), step.isOverflowRow(), map(step.getHavingFilter()), step.getTotalsMode(), step.getAutoIncludeThreshols(), step.isFinal()); + return std::make_shared( + map(step.getInputStreams()[0]), + step.isOverflowRow(), + map(step.getHavingFilter()), + step.getTotalsMode(), + step.getAutoIncludeThreshols(), + step.isFinal()); } std::shared_ptr SymbolMapper::map(const ExpandStep & step) { return std::make_shared( - map(step.getOutputStream()), + map(step.getOutputStream()), map(step.getAssignments()), - map(step.getNameToType()), + map(step.getNameToType()), map(step.getGroupIdSymbol()), step.getGroupIdValue(), - map(step.getGroupIdNonNullSymbol()) - ); + map(step.getGroupIdNonNullSymbol())); } class SymbolMapper::SymbolMapperVisitor : public StepVisitor { protected: #define VISITOR_DEF(TYPE) \ - QueryPlanStepPtr visit##TYPE##Step(const TYPE##Step & step, SymbolMapper & mapper) override \ - { \ - return mapper.map(step); \ - } + QueryPlanStepPtr visit##TYPE##Step(const TYPE##Step & step, SymbolMapper & mapper) override { return mapper.map(step); } APPLY_STEP_TYPES(VISITOR_DEF) #undef VISITOR_DEF }; diff --git a/src/QueryPlan/TableFinishStep.cpp b/src/QueryPlan/TableFinishStep.cpp index 5b2fd2b9f7e..70688132f6a 100644 --- a/src/QueryPlan/TableFinishStep.cpp +++ b/src/QueryPlan/TableFinishStep.cpp @@ -20,24 +20,33 @@ static ITransformingStep::Traits getTraits() TableFinishStep::TableFinishStep( const DataStream & input_stream_, TableWriteStep::TargetPtr target_, - String output_affected_row_count_symbol_, ASTPtr query_) - : ITransformingStep(input_stream_, input_stream_.header, getTraits()) + String output_affected_row_count_symbol_, ASTPtr query_, bool insert_select_with_profiles_) + : 
ITransformingStep(input_stream_, {}, getTraits()) , target(std::move(target_)) , output_affected_row_count_symbol(std::move(output_affected_row_count_symbol_)) , query(query_) + , insert_select_with_profiles(insert_select_with_profiles_) , log(&Poco::Logger::get("TableFinishStep")) { + if (insert_select_with_profiles) + { + Block new_header = {ColumnWithTypeAndName(ColumnUInt64::create(), std::make_shared(), "inserted_rows")}; + output_stream = DataStream{.header = std::move(new_header)}; + } + else + output_stream = {input_stream_.header}; } std::shared_ptr TableFinishStep::copy(ContextPtr) const { - return std::make_shared(input_streams[0], target, output_affected_row_count_symbol, query); + return std::make_shared(input_streams[0], target, output_affected_row_count_symbol, query, insert_select_with_profiles); } void TableFinishStep::transformPipeline(QueryPipeline & pipeline, const BuildQueryPipelineSettings & settings) { pipeline.resize(1); - pipeline.addTransform(std::make_shared(getInputStreams()[0].header, target->getStorage(), settings.context, query)); + pipeline.addTransform(std::make_shared( + getInputStreams()[0].header, target->getStorage(), settings.context, query, insert_select_with_profiles)); } void TableFinishStep::toProto(Protos::TableFinishStep & proto, bool) const diff --git a/src/QueryPlan/TableFinishStep.h b/src/QueryPlan/TableFinishStep.h index 1aa8714883e..435de564733 100644 --- a/src/QueryPlan/TableFinishStep.h +++ b/src/QueryPlan/TableFinishStep.h @@ -3,13 +3,15 @@ #include #include #include +#include +#include namespace DB { class TableFinishStep : public ITransformingStep { public: - TableFinishStep(const DataStream & input_stream_, TableWriteStep::TargetPtr target_, String output_affected_row_count_symbol_, ASTPtr query_); + TableFinishStep(const DataStream & input_stream_, TableWriteStep::TargetPtr target_, String output_affected_row_count_symbol_, ASTPtr query_, bool insert_select_with_profiles_ = false); String getName() const 
override { @@ -25,7 +27,13 @@ class TableFinishStep : public ITransformingStep void setInputStreams(const DataStreams & input_streams_) override { input_streams = input_streams_; - output_stream = DataStream{.header = std::move((input_streams_[0].header))}; + if (insert_select_with_profiles) + { + Block new_header = {ColumnWithTypeAndName(ColumnUInt64::create(), std::make_shared(), "inserted_rows")}; + output_stream = DataStream{.header = std::move(new_header)}; + } + else + output_stream = DataStream{.header = std::move((input_streams_[0].header))}; } TableWriteStep::TargetPtr getTarget() const @@ -38,6 +46,8 @@ class TableFinishStep : public ITransformingStep void setQuery(const ASTPtr & query_) { query = query_; } ASTPtr getQuery() const { return query; } + bool isOutputProfiles() const { return insert_select_with_profiles; } + void toProto(Protos::TableFinishStep & proto, bool for_hash_equals = false) const; static std::shared_ptr fromProto(const Protos::TableFinishStep & proto, ContextPtr context); @@ -45,6 +55,7 @@ class TableFinishStep : public ITransformingStep TableWriteStep::TargetPtr target; String output_affected_row_count_symbol; ASTPtr query; + bool insert_select_with_profiles; Poco::Logger * log; }; } diff --git a/src/QueryPlan/TableScanStep.cpp b/src/QueryPlan/TableScanStep.cpp index 606c6c23be3..1d164347395 100644 --- a/src/QueryPlan/TableScanStep.cpp +++ b/src/QueryPlan/TableScanStep.cpp @@ -58,6 +58,7 @@ #include #include #include +#include #include #include #include @@ -758,10 +759,12 @@ void TableScanStep::makeSetsForIndex(const ASTPtr & node, ContextPtr context, Pr } else { - auto input = storage->getInMemoryMetadataPtr()->getColumns().getAll(); + Block header = storage->getStorageSnapshot(storage->getInMemoryMetadataPtr(), context) + ->getSampleBlockForColumns(getRequiredColumns()); + Names output; output.emplace_back(left_in_operand->getColumnName()); - auto temp_actions = createExpressionActions(context, input, output, left_in_operand); + 
auto temp_actions = createExpressionActions(context, header.getNamesAndTypesList(), output, left_in_operand); if (temp_actions->tryFindInOutputs(left_in_operand->getColumnName())) { makeExplicitSet(func, *temp_actions, true, context, size_limits_for_set, prepared_sets); @@ -1223,31 +1226,37 @@ void TableScanStep::initializePipeline(QueryPipeline & pipeline, const BuildQuer options.ignoreProjections(); stage_watch.restart(); - ASTPtr partition_filter; - auto mutable_context = Context::createCopy(build_context.context); - if (query_info.partition_filter) - partition_filter = query_info.partition_filter->clone(); - // FIXME: It is used to work around partition keys being chosen as PREWHERE. In long term, we should rely on - // enable_partition_filter_push_down = 1 to do the stuff - if (mutable_context->getSettingsRef().remove_partition_filter_on_worker) - mutable_context->setSetting("enable_partition_filter_push_down", 1U); - - options.cache_info = query_info.cache_info; - auto interpreter = std::make_shared(query_info.query, mutable_context, options); - interpreter->execute(true); - auto backup_input_order_info = query_info.input_order_info; - query_info = interpreter->getQueryInfo(); - query_info = fillQueryInfo(build_context.context); - query_info.input_order_info = backup_input_order_info; + if (build_context.context->getSettingsRef().enable_table_scan_build_pipeline_optimization) + { + fillQueryInfoV2(build_context.context); + } + else + { + ASTPtr partition_filter; + auto mutable_context = Context::createCopy(build_context.context); + if (query_info.partition_filter) + partition_filter = query_info.partition_filter->clone(); + // FIXME: It is used to work around partition keys being chosen as PREWHERE. 
In long term, we should rely on + // enable_partition_filter_push_down = 1 to do the stuff + if (mutable_context->getSettingsRef().remove_partition_filter_on_worker) + mutable_context->setSetting("enable_partition_filter_push_down", 1U); + + options.cache_info = query_info.cache_info; + auto interpreter = std::make_shared(query_info.query, mutable_context, options); + interpreter->execute(true); + auto backup_input_order_info = query_info.input_order_info; + query_info = interpreter->getQueryInfo(); + query_info = fillQueryInfo(build_context.context); + query_info.input_order_info = backup_input_order_info; + if (partition_filter) + query_info.partition_filter = partition_filter; + } LOG_DEBUG(log, "init pipeline stage run time: make up query info, {} ms", stage_watch.elapsedMillisecondsAsDouble()); // always do filter underneath, as WHERE filter won't reuse PREWHERE result in optimizer mode if (query_info.prewhere_info) query_info.prewhere_info->need_filter = true; - if (partition_filter) - query_info.partition_filter = partition_filter; - if (use_projection_index) { auto storage_snapshot = storage->getStorageSnapshot(storage->getInMemoryMetadataPtr(), build_context.context); @@ -1318,8 +1327,33 @@ void TableScanStep::initializePipeline(QueryPipeline & pipeline, const BuildQuer } } // flag = Output - auto pipe = storage->read( - interpreter->getRequiredColumns(), storage_snapshot, query_info, build_context.context, QueryProcessingStage::Enum::FetchColumns, max_block_size, max_streams); + QueryPlan storage_plan; + storage->read( + storage_plan, + getRequiredColumns(), + storage_snapshot, + query_info, + build_context.context, + QueryProcessingStage::Enum::FetchColumns, + max_block_size, + max_streams); + auto pipe = storage_plan.convertToPipe( + QueryPlanOptimizationSettings::fromContext(build_context.context), + BuildQueryPipelineSettings::fromContext(build_context.context)); + + { + for (auto & node : storage_plan.getNodes()) + { + auto & att_descs = 
node.step->getAttributeDescriptions(); + if (att_descs.empty()) + continue; + for (auto & desc : att_descs) + { + if (!attribute_descriptions.contains(desc.first)) + attribute_descriptions.emplace(desc.first, desc.second); + } + } + } if (pipe.getCacheHolder()) pipeline.addCacheHolder(pipe.getCacheHolder()); @@ -1609,6 +1643,9 @@ void TableScanStep::initializePipeline(QueryPipeline & pipeline, const BuildQuer step_desc << plan_element.part_group.partsNum() << " parts from raw data"; } setStepDescription(step_desc.str()); + RuntimeAttributeDescription tablescan_desc; + tablescan_desc.description = step_desc.str(); + attribute_descriptions.emplace("TableScanDescription", tablescan_desc); LOG_DEBUG(log, "init pipeline total run time: {} ms, table scan descriptiion: {}", total_watch.elapsedMillisecondsAsDouble(), step_desc.str()); } @@ -1753,6 +1790,13 @@ void TableScanStep::allocate(ContextPtr context) query_info = fillQueryInfo(context); original_table = storage_id.table_name; storage_id = storage->prepareTableRead(getRequiredColumns(), query_info, context); + size_t shards = context->tryGetCurrentWorkerGroup() ? 
context->getCurrentWorkerGroup()->getShardsInfo().size() : 1; + if (shards > 1 && !context->getSettingsRef().enable_final_sample) + { + ASTSelectQuery * select = query_info.query->as(); + if (select && select->sampleSize()) + query_info.query = rewriteSampleForDistributedTable(query_info.query, shards); + } // update query info if (query_info.query) @@ -1920,9 +1964,16 @@ void TableScanStep::setQuotaAndLimits(QueryPipeline & pipeline, const SelectQuer void TableScanStep::setReadOrder(SortDescription read_order) { if (!read_order.empty()) - { query_info.input_order_info = std::make_shared(read_order, read_order[0].direction); - } + else + query_info.input_order_info = nullptr; +} + +SortDescription TableScanStep::getReadOrder() const +{ + if (query_info.input_order_info) + return query_info.input_order_info->order_key_prefix_descr; + return SortDescription{}; } Names TableScanStep::getRequiredColumns(GetFlags flags) const @@ -2002,4 +2053,37 @@ bool TableScanStep::hasFunctionCanUseBitmapIndex() const } return false; } + +void TableScanStep::fillQueryInfoV2(ContextPtr context) +{ + assert(storage); + auto required_columns = getRequiredColumns(); + auto metadata_snapshot = storage->getStorageSnapshot(storage->getInMemoryMetadataPtr(), context); + auto block = metadata_snapshot->getSampleBlockForColumns(required_columns); + + /// 1. build tree rewriter result + auto syntax_analyzer_result = std::make_shared(block.getNamesAndTypesList(), storage, metadata_snapshot); + syntax_analyzer_result->analyzed_join = std::make_shared(); + query_info.syntax_analyzer_result = syntax_analyzer_result; + + /// 2. 
build prepared sets + if (auto where = query_info.getSelectQuery()->where()) + makeSetsForIndex(where, context, query_info.sets); + if (auto prewhere = query_info.getSelectQuery()->prewhere()) + makeSetsForIndex(prewhere, context, query_info.sets); + // TODO: atomic_predicates_expr + if (query_info.partition_filter) + makeSetsForIndex(query_info.partition_filter, context, query_info.sets); + + /// 3. build prewhere info + if (auto prewhere = query_info.getSelectQuery()->prewhere()) + { + auto prewhere_action = IQueryPlanStep::createFilterExpressionActions(context, prewhere, block); + query_info.prewhere_info = std::make_shared(prewhere_action, prewhere->getColumnName()); + } + + /// 4. build index context + query_info.index_context = std::make_shared(); +} + } diff --git a/src/QueryPlan/TableScanStep.h b/src/QueryPlan/TableScanStep.h index 764fc390ad8..ce71efae9ae 100644 --- a/src/QueryPlan/TableScanStep.h +++ b/src/QueryPlan/TableScanStep.h @@ -138,6 +138,7 @@ class TableScanStep : public ISourceStep } void setReadOrder(SortDescription read_order); + SortDescription getReadOrder() const; void formatOutputStream(ContextPtr context); @@ -151,6 +152,7 @@ class TableScanStep : public ISourceStep SelectQueryInfo fillQueryInfo(ContextPtr context); void fillPrewhereInfo(ContextPtr context); void makeSetsForIndex(const ASTPtr & node, ContextPtr context, PreparedSets & prepared_sets) const; + void fillQueryInfoV2(ContextPtr context); void allocate(ContextPtr context); Int32 getUniqueId() const { return unique_id; } diff --git a/src/QueryPlan/TableWriteStep.cpp b/src/QueryPlan/TableWriteStep.cpp index 2a62c248d6d..6c8ad314c2a 100644 --- a/src/QueryPlan/TableWriteStep.cpp +++ b/src/QueryPlan/TableWriteStep.cpp @@ -1,20 +1,34 @@ #include +#include +#include #include #include #include #include #include +#include +#include #include +#include +#include +#include #include #include +#include +#include #include +#include #include #include #include +#include 
"QueryPlan/IQueryPlanStep.h" #include #include #include +#include +#include +#include namespace DB { @@ -29,9 +43,18 @@ static ITransformingStep::Traits getTraits() {.preserves_number_of_rows = true}}; } -TableWriteStep::TableWriteStep(const DataStream & input_stream_, TargetPtr target_) - : ITransformingStep(input_stream_, input_stream_.header, getTraits()), target(target_) +TableWriteStep::TableWriteStep(const DataStream & input_stream_, TargetPtr target_, bool insert_select_with_profiles_) + : ITransformingStep(input_stream_, {}, getTraits()) + , target(target_) + , insert_select_with_profiles(insert_select_with_profiles_) { + if (insert_select_with_profiles) + { + Block new_header = {ColumnWithTypeAndName(ColumnUInt64::create(), std::make_shared(), "inserted_rows")}; + output_stream = DataStream{.header = std::move(new_header)}; + } + else + output_stream = {input_stream_.header}; } Block TableWriteStep::getHeader(const NamesAndTypes & input_columns) @@ -53,7 +76,7 @@ BlockOutputStreams TableWriteStep::createOutputStream( BlockOutputStreams out_streams; size_t out_streams_size = 1; auto query_settings = settings.context->getSettingsRef(); - if (target_table->supportsParallelInsert() && query_settings.max_insert_threads > 1) + if (target_table->supportsParallelInsert(settings.context) && query_settings.max_insert_threads > 1) { LOG_INFO(&Poco::Logger::get("TableWriteStep"), fmt::format("createOutputStream support parallel insert, max threads:{}, max insert threads.size:{}", max_threads, query_settings.max_insert_threads)); @@ -158,13 +181,29 @@ void TableWriteStep::transformPipeline(QueryPipeline & pipeline, const BuildQuer pipeline.resize(out_streams.size()); LOG_INFO(&Poco::Logger::get("TableWriteStep"), fmt::format("pipeline size: {}, out streams size {}", pipeline.getNumStreams(), out_streams.size())); - pipeline.addSimpleTransform( - [&]([[maybe_unused]] const Block & in_header) -> ProcessorPtr { + if (insert_select_with_profiles) + { + 
pipeline.addSimpleTransform([&](const Block &, QueryPipeline::StreamType type) -> ProcessorPtr + { + if (type != QueryPipeline::StreamType::Main) + return nullptr; + auto stream = std::move(out_streams.back()); out_streams.pop_back(); - return std::make_shared(stream, insert_target_header, insert_target->getStorage(), settings.context);} - ); - break; + + return std::make_shared(std::move(stream)); + }); + } + else + { + pipeline.addSimpleTransform( + [&]([[maybe_unused]] const Block & in_header) -> ProcessorPtr { + auto stream = std::move(out_streams.back()); + out_streams.pop_back(); + return std::make_shared(stream, insert_target_header, insert_target->getStorage(), settings.context);} + ); + break; + } } } } @@ -172,12 +211,18 @@ void TableWriteStep::transformPipeline(QueryPipeline & pipeline, const BuildQuer void TableWriteStep::setInputStreams(const DataStreams & input_streams_) { input_streams = input_streams_; - output_stream = DataStream{.header = std::move((input_streams_[0].header))}; + if (insert_select_with_profiles) + { + Block new_header = {ColumnWithTypeAndName(ColumnUInt64::create(), std::make_shared(), "inserted_rows")}; + output_stream = DataStream{.header = std::move(new_header)}; + } + else + output_stream = DataStream{.header = std::move((input_streams_[0].header))}; } std::shared_ptr TableWriteStep::copy(ContextPtr) const { - return std::make_shared(input_streams[0], target); + return std::make_shared(input_streams[0], target, insert_select_with_profiles); } void TableWriteStep::toProto(Protos::TableWriteStep & proto, bool) const @@ -187,13 +232,15 @@ void TableWriteStep::toProto(Protos::TableWriteStep & proto, bool) const if (!target) throw Exception("Target cannot be nullptr", ErrorCodes::LOGICAL_ERROR); target->toProto(*proto.mutable_target()); + proto.set_insert_select_with_profiles(insert_select_with_profiles); } std::shared_ptr TableWriteStep::fromProto(const Protos::TableWriteStep & proto, ContextPtr context) { auto [step_description, 
base_input_stream] = ITransformingStep::deserializeFromProtoBase(proto.query_plan_base()); auto target = TableWriteStep::Target::fromProto(proto.target(), context); - auto step = std::make_shared(base_input_stream, target); + bool insert_select_with_profiles = proto.has_insert_select_with_profiles() ? proto.insert_select_with_profiles() : context->getSettingsRef().insert_select_with_profiles; + auto step = std::make_shared(base_input_stream, target, insert_select_with_profiles); step->setStepDescription(step_description); return step; } @@ -262,9 +309,11 @@ TableWriteStep::InsertTarget::createFromProtoImpl(const Protos::TableWriteStep:: return step; } -String TableWriteStep::InsertTarget::toString() const +String TableWriteStep::InsertTarget::toString(const String & remove_tenant_id) const { - return "Insert " + storage_id.getNameForLogs(); + auto tmp_id = storage_id; + tmp_id.database_name = getOriginalDatabaseName(tmp_id.database_name, remove_tenant_id); + return "Insert " + tmp_id.getNameForLogs(); } NameToNameMap TableWriteStep::InsertTarget::getTableColumnToInputColumnMap(const Names & input_columns) const diff --git a/src/QueryPlan/TableWriteStep.h b/src/QueryPlan/TableWriteStep.h index f4ffc21257f..97d0dc41c29 100644 --- a/src/QueryPlan/TableWriteStep.h +++ b/src/QueryPlan/TableWriteStep.h @@ -18,7 +18,7 @@ class TableWriteStep : public ITransformingStep INSERT, }; - TableWriteStep(const DataStream & input_stream_, TargetPtr target_); + TableWriteStep(const DataStream & input_stream_, TargetPtr target_, bool insert_select_with_profiles_ = false); String getName() const override { @@ -43,6 +43,8 @@ class TableWriteStep : public ITransformingStep void allocate(const ContextPtr & context); + bool isOutputProfiles() const { return insert_select_with_profiles; } + void toProto(Protos::TableWriteStep & proto, bool for_hash_equals = false) const; static std::shared_ptr fromProto(const Protos::TableWriteStep & proto, ContextPtr context); @@ -58,6 +60,7 @@ class 
TableWriteStep : public ITransformingStep ASTPtr query); TargetPtr target; + bool insert_select_with_profiles; }; class TableWriteStep::Target @@ -65,7 +68,11 @@ class TableWriteStep::Target public: virtual ~Target() = default; virtual TargetType getTargetType() const = 0; - virtual String toString() const = 0; + String toString() const + { + return toString({}); + } + virtual String toString(const String & remove_tenant_id) const = 0; virtual StoragePtr getStorage() const = 0; virtual NameToNameMap getTableColumnToInputColumnMap(const Names & input_columns) const = 0; @@ -82,7 +89,7 @@ class TableWriteStep::InsertTarget : public TableWriteStep::Target } TargetType getTargetType() const override { return TargetType::INSERT; } - String toString() const override; + String toString(const String & remove_tenant_id) const override; StoragePtr getStorage() const override { return storage; diff --git a/src/QueryPlan/tests/gtest_protobuf.cpp b/src/QueryPlan/tests/gtest_protobuf.cpp index d9f64922928..7ef4a02a5fd 100644 --- a/src/QueryPlan/tests/gtest_protobuf.cpp +++ b/src/QueryPlan/tests/gtest_protobuf.cpp @@ -674,7 +674,7 @@ TEST_F(ProtobufTest, TableWriteStep) std::string step_description = fmt::format("description {}", eng() % 100); auto base_input_stream = generateDataStream(eng); auto target = generateTableWriteStepInsertTarget(eng); - auto s = std::make_shared(base_input_stream, target); + auto s = std::make_shared(base_input_stream, target, false); s->setStepDescription(step_description); return s; }(); @@ -700,7 +700,7 @@ TEST_F(ProtobufTest, TableFinishStep) auto base_input_stream = generateDataStream(eng); auto target = generateTableWriteStepInsertTarget(eng); auto output_affected_row_count_symbol = fmt::format("text{}", eng() % 100); - auto s = std::make_shared(base_input_stream, target, output_affected_row_count_symbol, nullptr); + auto s = std::make_shared(base_input_stream, target, output_affected_row_count_symbol, nullptr, false); 
s->setStepDescription(step_description); return s; }(); @@ -796,6 +796,9 @@ TEST_F(ProtobufTest, JoinStep) Names right_keys; for (int i = 0; i < 10; ++i) right_keys.emplace_back(fmt::format("text{}", eng() % 100)); + std::vector key_ids_null_safe; + for (size_t i = 0; i < left_keys.size(); ++i) + key_ids_null_safe.emplace_back(eng() % 2 == 0); auto filter = generateAST(eng); auto has_using = eng() % 2 == 1; std::optional> require_right_keys; @@ -819,6 +822,7 @@ TEST_F(ProtobufTest, JoinStep) keep_left_read_in_order, left_keys, right_keys, + key_ids_null_safe, filter, has_using, require_right_keys, @@ -896,9 +900,10 @@ TEST_F(ProtobufTest, MergingAggregatedStep) auto step = [&eng] { std::string step_description = fmt::format("description {}", eng() % 100); auto base_input_stream = generateDataStream(eng); - Names keys; + NameSet distinct_keys; for (int i = 0; i < 10; ++i) - keys.emplace_back(fmt::format("text{}", eng() % 100)); + distinct_keys.emplace(fmt::format("text{}", eng() % 100)); + Names keys{distinct_keys.begin(), distinct_keys.end()}; GroupingSetsParamsList grouping_sets_params; for (int i = 0; i < 2; ++i) grouping_sets_params.emplace_back(generateGroupingSetsParams(eng)); @@ -941,9 +946,10 @@ TEST_F(ProtobufTest, AggregatingStep) auto step = [&eng] { std::string step_description = fmt::format("description {}", eng() % 100); auto base_input_stream = generateDataStream(eng); - Names keys; + NameSet distinct_keys; for (int i = 0; i < 10; ++i) - keys.emplace_back(fmt::format("text{}", eng() % 100)); + distinct_keys.emplace(fmt::format("text{}", eng() % 100)); + Names keys{distinct_keys.begin(), distinct_keys.end()}; NameSet keys_not_hashed; for (int i = 0; i < 10; ++i) keys_not_hashed.emplace(fmt::format("text{}", eng() % 100)); @@ -1122,7 +1128,7 @@ TEST_F(ProtobufTest, ReadStorageRowCountStep) auto base_output_header = generateBlock(eng); auto storage_id = test_storage_ids[eng() % 3]; auto query = generateAST(eng); - auto agg_desc = 
generateAggregateDescription(eng); + auto agg_desc = generateAggregateDescription(eng, 0); auto num_rows = eng() % 1000; auto is_final_agg = false; auto s = std::make_shared(base_output_header, query, agg_desc, num_rows, is_final_agg); diff --git a/src/Server/APIRequestHandler.cpp b/src/Server/APIRequestHandler.cpp index 3d1f1491e8b..747c9f7aaac 100644 --- a/src/Server/APIRequestHandler.cpp +++ b/src/Server/APIRequestHandler.cpp @@ -14,6 +14,8 @@ #include #include #include +#include +#include #include #include #include @@ -101,7 +103,7 @@ APIRequestHandler::APIRequestHandler(IServer & server) server_display_name = getContext()->getConfigRef().getString("display_name", getFQDNOrHostName()); } -[[maybe_unused]] static Poco::Net::HTTPResponse::HTTPStatus exceptionCodeToHTTPStatus(int exception_code) +static Poco::Net::HTTPResponse::HTTPStatus exceptionCodeToHTTPStatus(int exception_code) { using namespace Poco::Net; @@ -273,4 +275,11 @@ void APIRequestHandler::onResourceReportAction( sendResult(res, response); } +HTTPRequestHandlerFactoryPtr createAPIRequestHandlerFactory(IServer & server, const std::string & config_prefix) +{ + auto factory = std::make_shared>(server); + factory->addFiltersFromConfig(server.config(), config_prefix); + return factory; +} + } diff --git a/src/Server/HTTPHandler.cpp b/src/Server/HTTPHandler.cpp index 6d1b5646812..6c88fa3ce46 100644 --- a/src/Server/HTTPHandler.cpp +++ b/src/Server/HTTPHandler.cpp @@ -474,6 +474,42 @@ void HTTPHandler::processQuery( using namespace Poco::Net; LOG_TRACE(log, "Request URI: {}", request.getURI()); + std::string tenant_id = params.getParsed("tenant_id", ""); + std::string database = request.get("X-ClickHouse-Database", ""); + if (database.empty()) + database = params.getParsed("database", ""); + + if (auto pos = database.find('`'); pos != String::npos) + { + //CNCH multi-tenant default database pattern from gateway client: {tenant_id}`{default_database} + //Even this is a GET request or with "readonly=1" 
setting, we force to apply the tenant_id setting change. + auto tenant_id_from_db = String(database.c_str(), pos); + if (tenant_id.empty()) + tenant_id = tenant_id_from_db; + else if (tenant_id != tenant_id_from_db && !tenant_id_from_db.empty()) + throw Exception("tenant id " + tenant_id + " from setting doesn't match tenant id from database " + tenant_id_from_db, ErrorCodes::UNKNOWN_USER); + + ///multi-tenant default database storage pattern: {tenant_id}.{database} + if (pos + 1 != database.size()) + { + auto sub_str = database.substr(pos + 1); + if (sub_str == "default" || sub_str == "system") + database = std::move(sub_str); + else + database[pos] = '.'; + } + else /// {tenant_id}` + database.clear(); + } + + if (!database.empty()) + context->setCurrentDatabase(database); + + if (!tenant_id.empty()) + { + context->setSetting("tenant_id", tenant_id); + context->setTenantId(tenant_id); + } if (!authenticateUser(context, request, params, response)) return; // '401 Unauthorized' response with 'Negotiate' has been sent at this point. @@ -494,7 +530,9 @@ void HTTPHandler::processQuery( session = context->acquireNamedSession(session_id, session_timeout, session_check == "1"); - context->copyFrom(session->context); /// FIXME: maybe move this part to HandleRequest(), copyFrom() is used only here. 
+ /// FIXME: maybe move this part to HandleRequest() + /// see also https://github.com/ClickHouse/ClickHouse/pull/26864 + context = Context::createCopy(session->context); context->setSessionContext(session->context); } @@ -700,28 +738,20 @@ void HTTPHandler::processQuery( reserved_param_suffixes.emplace_back("_structure"); } - std::string database = request.get("X-ClickHouse-Database", ""); std::string default_format = request.get("X-ClickHouse-Format", ""); SettingsChanges settings_changes; for (const auto & [key, value] : params) { if (key == "database") - { - if (database.empty()) - database = value; - } + continue; else if (key == "default_format") { if (default_format.empty()) default_format = value; } else if (key == "tenant_id") - { - //Even this is a GET request or with "readonly=1" setting, we force to apply the tenant_id setting change. - context->setSetting("tenant_id", value); - context->setTenantId(value); - } + continue; else if (param_could_be_skipped(key)) { } @@ -733,31 +763,6 @@ void HTTPHandler::processQuery( } } - if (!database.empty()) - { - auto &default_database = database; - auto &connection_context = context; - //CNCH multi-tenant default database pattern from gateway client: {tenant_id}`{default_database} - if (auto pos = default_database.find('`'); pos != String::npos) - { - //Even this is a GET request or with "readonly=1" setting, we force to apply the tenant_id setting change. 
- connection_context->setSetting("tenant_id", String(default_database.c_str(), pos)); - connection_context->setTenantId(String(default_database.c_str(), pos)); - if (pos + 1 != default_database.size()) ///multi-tenant default database storage pattern: {tenant_id}.{default_database} - { - auto sub_str = default_database.substr(pos + 1); - if (sub_str == "default" || sub_str == "system") - default_database = std::move(sub_str); - else - default_database[pos] = '.'; - } - else /// {tenant_id}` - default_database.clear(); - } - if (!default_database.empty()) - connection_context->setCurrentDatabase(default_database); - } - if (!default_format.empty()) context->setDefaultFormat(default_format); @@ -799,8 +804,6 @@ void HTTPHandler::processQuery( }); }; - adjustAccessTablesIfNeeded(context); - /// While still no data has been sent, we will report about query execution progress by sending HTTP headers. if (settings.send_progress_in_http_headers) append_callback([&used_output] (const Progress & progress) { used_output.out->onProgress(progress); }); @@ -844,6 +847,7 @@ void HTTPHandler::processQuery( /// Send HTTP headers with code 200 if no exception happened and the data is still not sent to /// the client. + used_output.out_maybe_compressed->finalize(); used_output.out->finalize(); } diff --git a/src/Server/HTTPHandlerFactory.cpp b/src/Server/HTTPHandlerFactory.cpp index 6dc15258799..3aab7ad942f 100644 --- a/src/Server/HTTPHandlerFactory.cpp +++ b/src/Server/HTTPHandlerFactory.cpp @@ -95,6 +95,8 @@ static inline auto createHandlersFactoryFromConfig( main_handler_factory->addHandler(createPrometheusHandlerFactory(server, async_metrics, prefix + "." + key, context)); else if (handler_type == "replicas_status") main_handler_factory->addHandler(createReplicasStatusHandlerFactory(server, prefix + "." + key)); + else if (handler_type == "api") + main_handler_factory->addHandler(createAPIRequestHandlerFactory(server, prefix + "." 
+ key)); else throw Exception("Unknown handler type '" + handler_type + "' in config here: " + prefix + "." + key + ".handler.type", ErrorCodes::INVALID_CONFIG_PARAMETER); @@ -176,7 +178,7 @@ void addCommonDefaultHandlersFactory(HTTPRequestHandlerFactoryMain & factory, IS factory.addHandler(ping_handler); auto api_handler = std::make_shared>(server); - api_handler->attachStrictPath("/api"); + api_handler->attachNonStrictPath("/api"); api_handler->allowGetAndHeadRequest(); factory.addHandler(api_handler); diff --git a/src/Server/HTTPHandlerFactory.h b/src/Server/HTTPHandlerFactory.h index fb4315f5713..703982e1128 100644 --- a/src/Server/HTTPHandlerFactory.h +++ b/src/Server/HTTPHandlerFactory.h @@ -132,6 +132,8 @@ HTTPRequestHandlerFactoryPtr createPredefinedHandlerFactory(IServer & server, co HTTPRequestHandlerFactoryPtr createReplicasStatusHandlerFactory(IServer & server, const std::string & config_prefix); +HTTPRequestHandlerFactoryPtr createAPIRequestHandlerFactory(IServer & server, const std::string & config_prefix); + HTTPRequestHandlerFactoryPtr createPrometheusHandlerFactory(IServer & server, AsynchronousMetrics & async_metrics, const std::string & config_prefix, ContextMutablePtr context); diff --git a/src/Server/MySQLHandler.cpp b/src/Server/MySQLHandler.cpp index 17d72747ecf..dc75b3ff923 100644 --- a/src/Server/MySQLHandler.cpp +++ b/src/Server/MySQLHandler.cpp @@ -175,7 +175,9 @@ MySQLHandler::MySQLHandler(IServer & server_, TCPServer & tcp_server_, const Poc server_capabilities |= CLIENT_SSL; static constexpr const char SHOW_CHARSET[] = "SELECT 'utf8mb4' AS charset, 'UTF-8 Unicode' AS Description, 'utf8mb4_0900_ai_ci' AS `Default collation`, 4 AS Maxlen"; - static constexpr const char SHOW_COLLATION[] = "SELECT 'utf8mb4_0900_ai_ci' AS collation, 'utf8mb4' AS Charset, '255' AS Id, 'Yes' AS Default, 'Yes' AS Compiled, 0 AS Sortlen, 'NO PAD' AS Pad_attribute"; + static constexpr const char SHOW_COLLATION[] = "SELECT 'utf8_general_ci' AS collation, 'utf8' 
AS charset, 33 AS id, 'Yes' AS default, 'Yes' AS Compiled, 1 AS Sortlen, 'NO PAD' AS Pad_attribute " + "UNION SELECT 'binary' AS collation, 'binary' AS charset, 63 AS id, 'Yes' AS default, 'Yes' AS Compiled, 1 AS Sortlen, 'NO PAD' AS Pad_attribute " + "UNION SELECT 'utf8mb4_0900_ai_ci' AS collation, 'utf8mb4' AS Charset, '255' AS Id, 'Yes' AS Default, 'Yes' AS Compiled, 0 AS Sortlen, 'NO PAD' AS Pad_attribute"; static constexpr const char SHOW_ENGINES[] = "SELECT name AS Engine, 'Yes' AS Support, concat(name, ' engine') AS Comment, 'NO' AS Transcations, 'NO' AS XA, 'NO' AS Savepoints FROM system.table_engines"; static constexpr const char SHOW_PRIVILEGES[] = "SELECT '' AS Privilege, '' AS Context, '' AS Comment"; @@ -201,6 +203,7 @@ MySQLHandler::MySQLHandler(IServer & server_, TCPServer & tcp_server_, const Poc queries_replacements.emplace_back("SHOW GLOBAL VARIABLES", showVariableReplacementQuery); queries_replacements.emplace_back("SHOW INDEXES", showIndexReplacementQuery); queries_replacements.emplace_back("SHOW INDEX", showIndexReplacementQuery); + queries_replacements.emplace_back("SHOW KEYS", showIndexReplacementQuery); queries_replacements.emplace_back("SHOW PLUGINS", selectEmptyReplacementQuery); queries_replacements.emplace_back("SHOW PRIVILEGES", ReplaceWith::fn); queries_replacements.emplace_back("SHOW PROCEDURE STATUS", selectEmptySetQuery); @@ -261,11 +264,6 @@ void MySQLHandler::run() if (!(client_capabilities & CLIENT_PROTOCOL_41)) throw Exception("Required capability: CLIENT_PROTOCOL_41.", ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES); - handshake_response.username = connection_context->formatUserName(handshake_response.username); - authenticate(handshake_response.username, handshake_response.auth_plugin_name, handshake_response.auth_response); - - connection_context->getClientInfo().initial_user = handshake_response.username; - try { auto &default_database = handshake_response.database; @@ -302,6 +300,11 @@ void MySQLHandler::run() 
packet_endpoint->sendPacket(ERRPacket(exc.code(), "HY000", exc.message()), true); } + handshake_response.username = connection_context->formatUserName(handshake_response.username); + authenticate(handshake_response.username, handshake_response.auth_plugin_name, handshake_response.auth_response); + + connection_context->getClientInfo().initial_user = handshake_response.username; + OKPacket ok_packet(0, handshake_response.capability_flags, 0, 0, 0); packet_endpoint->sendPacket(ok_packet, true); @@ -559,6 +562,8 @@ void MySQLHandler::comQuery(ReadBuffer & payload, bool binary_protocol) query_context->setSetting("mysql_map_fixed_string_to_text_in_show_columns", 1); /// TODO(fredwang) change it to a smaller threshold? query_context->setSetting("max_execution_time", 18000); + /// required by quickbi, otherwise it would fail to get table info + query_context->setSetting("allow_mysql_having_name_resolution", 1); CurrentThread::QueryScope query_scope{query_context}; std::atomic affected_rows {0}; diff --git a/src/Server/ServerPrometheusMetricsWriter.cpp b/src/Server/ServerPrometheusMetricsWriter.cpp index c71eecf333c..16503b45882 100644 --- a/src/Server/ServerPrometheusMetricsWriter.cpp +++ b/src/Server/ServerPrometheusMetricsWriter.cpp @@ -233,13 +233,9 @@ void ServerPrometheusMetricsWriter::writeLabelledMetrics(WriteBuffer & wb) { String metric_name { LabelledMetrics::getSnakeName(metric) }; String metric_doc { LabelledMetrics::getDocumentation(metric) }; - LabelledMetrics::LabelledCounter labelled_counter = LabelledMetrics::getCounter(metric); - for (const auto & item : labelled_counter) - { - String key; - MetricLabels labels = item.first; - LabelledMetrics::Count counter = item.second; + auto write_metric = [&](MetricLabels labels, LabelledMetrics::Count counter) { + String key; if (metric == LabelledMetrics::VwQuery || metric == LabelledMetrics::UnlimitedQuery) { labels.insert({"resource_type", metric == LabelledMetrics::VwQuery ? 
"vw" : "unlimited"}); @@ -254,15 +250,31 @@ void ServerPrometheusMetricsWriter::writeLabelledMetrics(WriteBuffer & wb) } else { - continue; + return; } String key_label = key + getLabel(labels); writeOutLine(wb, "# HELP", key, metric_doc); writeOutLine(wb, "# TYPE", key, COUNTER_TYPE); writeOutLine(wb, key_label, counter); - } + }; + LabelledMetrics::LabelledCounter labelled_counter = LabelledMetrics::getCounter(metric); + if (!labelled_counter.empty()) + { + for (const auto & item : labelled_counter) + { + + MetricLabels labels = item.first; + LabelledMetrics::Count counter = item.second; + + write_metric(labels, counter); + } + } + else + { + write_metric({}, 0); + } } } @@ -439,7 +451,7 @@ void ServerPrometheusMetricsWriter::writePartMetrics(WriteBuffer & wb) continue; Catalog::PartitionMap partitions; - cnch_catalog->getPartitionsFromMetastore(*cnch_table, partitions); + cnch_catalog->getPartitionsFromMetastore(*cnch_table, partitions, nullptr); for (auto & partition : partitions) { diff --git a/src/Server/ServerPrometheusMetricsWriter.h b/src/Server/ServerPrometheusMetricsWriter.h index 937c1c4b901..ead674ee323 100644 --- a/src/Server/ServerPrometheusMetricsWriter.h +++ b/src/Server/ServerPrometheusMetricsWriter.h @@ -44,9 +44,19 @@ namespace ProfileEvents const extern Event ReadBufferFromHdfsRead; const extern Event ReadBufferFromHdfsReadBytes; const extern Event ReadBufferFromHdfsReadFailed; + const extern Event ReadBufferFromFileDescriptorRead; + const extern Event ReadBufferFromFileDescriptorReadFailed; + const extern Event ReadBufferFromFileDescriptorReadBytes; + const extern Event WriteBufferFromHdfsWrite; const extern Event WriteBufferFromHdfsWriteBytes; const extern Event WriteBufferFromHdfsWriteFailed; + const extern Event WriteBufferFromFileDescriptorWrite; + const extern Event WriteBufferFromFileDescriptorWriteFailed; + const extern Event WriteBufferFromFileDescriptorWriteBytes; + + const extern Event DiskReadElapsedMicroseconds; + const extern 
Event DiskWriteElapsedMicroseconds; /// SD const extern Event SDRequest; @@ -84,6 +94,8 @@ namespace ProfileEvents const extern Event TsCacheUpdateElapsedMilliseconds; extern const Event TSORequest; extern const Event TSOError; + + /// Disk cache extern const Event DiskCacheGetMetaMicroSeconds; extern const Event DiskCacheGetTotalOps; extern const Event DiskCacheSetTotalOps; @@ -1120,15 +1132,23 @@ class ServerPrometheusMetricsWriter : public IPrometheusMetricsWriter ///About HDFS ProfileEvents::ReadBufferFromHdfsRead, ProfileEvents::ReadBufferFromHdfsReadFailed, + ProfileEvents::ReadBufferFromHdfsReadBytes, + ProfileEvents::ReadBufferFromFileDescriptorRead, + ProfileEvents::ReadBufferFromFileDescriptorReadFailed, + ProfileEvents::ReadBufferFromFileDescriptorReadBytes, ProfileEvents::WriteBufferFromHdfsWrite, ProfileEvents::WriteBufferFromHdfsWriteFailed, + ProfileEvents::WriteBufferFromHdfsWriteBytes, + ProfileEvents::WriteBufferFromFileDescriptorWrite, + ProfileEvents::WriteBufferFromFileDescriptorWriteFailed, + ProfileEvents::WriteBufferFromFileDescriptorWriteBytes, // ProfileEvents::CnchReadRowsFromDiskCache, // ProfileEvents::CnchReadRowsFromRemote, - ProfileEvents::ReadBufferFromHdfsReadBytes, ProfileEvents::HDFSReadElapsedMicroseconds, // ProfileEvents::HDFSReadElapsedCpuMilliseconds, - ProfileEvents::WriteBufferFromHdfsWriteBytes, ProfileEvents::HDFSWriteElapsedMicroseconds, + ProfileEvents::DiskReadElapsedMicroseconds, + ProfileEvents::DiskWriteElapsedMicroseconds, ///About SD ProfileEvents::SDRequest, ProfileEvents::SDRequestFailed, @@ -1164,6 +1184,7 @@ class ServerPrometheusMetricsWriter : public IPrometheusMetricsWriter ProfileEvents::TsCacheUpdateElapsedMilliseconds, ProfileEvents::TSORequest, ProfileEvents::TSOError, + /// About disk cache ProfileEvents::DiskCacheGetMetaMicroSeconds, ProfileEvents::DiskCacheGetTotalOps, ProfileEvents::DiskCacheSetTotalOps, diff --git a/src/Server/TCPHandler.cpp b/src/Server/TCPHandler.cpp index 
45ea71a41d3..2270329ead5 100644 --- a/src/Server/TCPHandler.cpp +++ b/src/Server/TCPHandler.cpp @@ -221,32 +221,14 @@ void TCPHandler::runImpl() /// When connecting, the default database can be specified. if (!default_database.empty()) { - //CNCH multi-tenant default database pattern from gateway client: {tenant_id}`{default_database} - if (auto pos = default_database.find('`'); pos != String::npos) - { - connection_context->setSetting("tenant_id", String(default_database.c_str(), pos)); /// {tenant_id}`* - connection_context->setTenantId(String(default_database.c_str(), pos)); - if (pos + 1 != default_database.size()) ///multi-tenant default database storage pattern: {tenant_id}.{default_database} - { - auto sub_str = default_database.substr(pos + 1); - if (sub_str == "default" || sub_str == "system") - default_database = std::move(sub_str); - else - default_database[pos] = '.'; - } - else /// {tenant_id}` - default_database.clear(); - } - - if ((!default_database.empty()) && (!DatabaseCatalog::instance().isDatabaseExist(default_database, connection_context))) + if (!DatabaseCatalog::instance().isDatabaseExist(default_database, connection_context)) { Exception e("Database " + backQuote(default_database) + " doesn't exist", ErrorCodes::UNKNOWN_DATABASE); LOG_ERROR(log, "Code: {}, e.displayText() = {}, Stack trace:\n\n{}", e.code(), e.displayText(), e.getStackTraceString()); sendException(e, connection_context->getSettingsRef().calculate_text_stack_trace); return; } - if (!default_database.empty()) - connection_context->setCurrentDatabase(default_database); + connection_context->setCurrentDatabase(default_database); } UInt64 idle_connection_timeout = connection_settings.idle_connection_timeout; @@ -367,6 +349,7 @@ void TCPHandler::runImpl() /// Send block to the client - input storage structure. 
state.input_header = metadata_snapshot->getSampleBlock(); sendData(state.input_header); + sendTimezone(); }); query_context->setInputBlocksReaderCallback([&connection_settings, this](ContextPtr context) -> Block { @@ -414,6 +397,12 @@ void TCPHandler::runImpl() interpretSettings(ast, query_context); } + if (query_context->getSettingsRef().bsp_mode) + { + /// for bsp mode, progress needs to be sent during scheduling. + query_context->setSendTCPProgress([&]() { this->sendProgress(); }); + } + auto * insert_query = ast->as(); if (!(insert_query && insert_query->data) && query_context->isAsyncMode()) { @@ -504,9 +493,9 @@ void TCPHandler::runImpl() if (!state.plan_segment) { state.io = executeQuery(state.query, query_context, false, state.stage, may_have_embedded_data); - + if (OutfileTarget::checkOutfileWithTcpOnServer(query_context)) - { + { sendEndOfStream(); return; // all data already outfile in executequery() } @@ -898,6 +887,10 @@ void TCPHandler::processOrdinaryQueryWithProcessors() { /// A packet was received requesting to stop execution of the request. 
executor.cancel(); + if (state.io.coordinator && state.is_cancelled) + { + throw Exception("Cancelled by client.", ErrorCodes::QUERY_WAS_CANCELLED); + } break; } @@ -1060,6 +1053,19 @@ void TCPHandler::sendExtremes(const Block & extremes) } } +void TCPHandler::sendTimezone() +{ + if (client_tcp_protocol_version < DBMS_MIN_PROTOCOL_VERSION_WITH_TIMEZONE_UPDATES) + return; + + const String & tz = query_context->getSettingsRef().session_timezone.value; + + LOG_DEBUG(log, "TCPHandler::sendTimezone(): {}", tz); + writeVarUInt(Protocol::Server::TimezoneUpdate, *out); + writeStringBinary(tz, *out); + out->next(); +} + bool TCPHandler::receiveProxyHeader() { if (in->eof()) @@ -1175,26 +1181,52 @@ void TCPHandler::receiveHello() throw NetException("Unexpected packet from client (no user in Hello package)", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); String tenant_id_from_db; - String tenant_id_from_user; - if (auto pos = user.find('`'); pos != String::npos) - tenant_id_from_user = String(user.c_str(), pos); - if (!default_database.empty()) + /* need to support below cases: + tenanted user + tenanted db: gateway 2.0 user access + tenanted user + db: internal dictionary access + user + tenanted db: 2.0 server backwards compatible user access + user + db: devops/developers + */ + + if (auto pos = default_database.find('`'); pos != String::npos) { - if (auto pos = default_database.find('`'); pos != String::npos) - tenant_id_from_db = String(default_database.c_str(), pos); + tenant_id_from_db = String(default_database.c_str(), pos); + connection_context->setSetting("tenant_id", tenant_id_from_db); /// {tenant_id}`* + connection_context->setTenantId(tenant_id_from_db); + ///multi-tenant default database storage pattern: {tenant_id}.{default_database} + if (pos + 1 != default_database.size()) + { + auto sub_str = default_database.substr(pos + 1); + if (sub_str == "default" || sub_str == "system") + default_database = std::move(sub_str); + else + default_database[pos] = '.'; + } + 
else /// {tenant_id}` + { + default_database.clear(); + } } - if (!tenant_id_from_user.empty() && tenant_id_from_db.empty()) + if (auto pos = user.find('`'); pos != String::npos) { - default_database = formatTenantDatabaseNameWithTenantId(default_database, tenant_id_from_user, '`'); - if (auto pos = user.find('`'); pos != String::npos) // remove tenant id for server and worker communication + String tenant_id_from_user = String(user.c_str(), pos); + + if (tenant_id_from_db.empty()) + { + /// internal dictionary access user = user.substr(pos + 1); + + if (!default_database.empty()) + default_database = formatTenantDatabaseNameWithTenantId(default_database, tenant_id_from_user, '`'); + } + else + { + if (!tenant_id_from_user.empty() && tenant_id_from_user != tenant_id_from_db) + throw NetException("Tenant ID of user and default database are not matching", ErrorCodes::LOGICAL_ERROR); + } } - // else if (tenant_id_from_user.empty() && !tenant_id_from_db.empty()) - // user = tenant_id_from_db + '`' + user; - else if (!tenant_id_from_user.empty() && !tenant_id_from_db.empty() && tenant_id_from_user != tenant_id_from_db) - throw NetException("Tenant ID of user and default database are not matching", ErrorCodes::LOGICAL_ERROR); LOG_DEBUG( log, @@ -1244,7 +1276,7 @@ void TCPHandler::sendHello() writeVarUInt(VERSION_MINOR, *out); writeVarUInt(DBMS_TCP_PROTOCOL_VERSION, *out); if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE) - writeStringBinary(DateLUT::instance().getTimeZone(), *out); + writeStringBinary(DateLUT::serverTimezoneInstance().getTimeZone(), *out); if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME) writeStringBinary(server_display_name, *out); if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_VERSION_PATCH) @@ -1403,23 +1435,6 @@ void TCPHandler::receiveQuery() /// Set fields, that are known apriori. 
client_info.interface = ClientInfo::Interface::TCP; - /** @aeolus - * Make the current_user and initial_user identical so that any INITIAL_QUERY - * could cancel a SECONDARY_QUERY when their initial_query_id(s) are same. - * - * What if current_user is not modified? The current_user will still be `default` - * due to multiplexed connections (or global connection pool). And a initial_query - * requested by a specific user couldn't cancel the queries of `default` user. - * # See the ProcessList.cpp about query cancellation. - * - * NOTE: query_context will be restored to connection_context before received - * a new query. The only thing modified is the current_user of this (current) query. - */ - if (client_info.query_kind == ClientInfo::QueryKind::SECONDARY_QUERY) - { - client_info.current_user = client_info.initial_user; - } - /// Per query settings are also passed via TCP. /// We need to check them before applying due to they can violate the settings constraints. auto settings_format = (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS) @@ -1428,8 +1443,6 @@ void TCPHandler::receiveQuery() Settings passed_settings; passed_settings.read(*in, settings_format); - adjustAccessTablesIfNeeded(query_context); - /// Interserver secret. 
std::string received_hash; if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_INTERSERVER_SECRET) @@ -1498,6 +1511,7 @@ void TCPHandler::receiveQuery() query_context->clampToSettingsConstraints(settings_changes); } query_context->applySettingsChanges(settings_changes, false); + query_context->setSessionSettingsChanges(settings_changes); /// Disable function name normalization when it's a secondary query, because queries are either /// already normalized on initiator node, or not normalized and should remain unnormalized for @@ -2098,12 +2112,16 @@ void TCPHandler::updateProgress(const Progress & value) void TCPHandler::sendProgress() { auto increment = state.progress.fetchAndResetPiecewiseAtomically(); - if (!increment.empty()) - { - writeVarUInt(Protocol::Server::Progress, *out); - increment.write(*out, client_tcp_protocol_version); - out->next(); - } + LOG_DEBUG( + &Poco::Logger::get("debug"), + fmt::format( + "send progress rows:{} bytes:{} total_rows_to_read:{}", + increment.read_rows, + increment.read_bytes, + increment.total_rows_to_read)); + writeVarUInt(Protocol::Server::Progress, *out); + increment.write(*out, client_tcp_protocol_version); + out->next(); } void TCPHandler::sendLogs() diff --git a/src/Server/TCPHandler.h b/src/Server/TCPHandler.h index e0bfae14568..fbbd1be7dc1 100644 --- a/src/Server/TCPHandler.h +++ b/src/Server/TCPHandler.h @@ -242,6 +242,7 @@ class TCPHandler : public Poco::Net::TCPServerConnection void sendProfileInfo(const BlockStreamProfileInfo & info); void sendTotals(const Block & totals); void sendExtremes(const Block & extremes); + void sendTimezone(); /// Creates state.block_in/block_out for blocks read/write, depending on whether compression is enabled. 
void initBlockInput(); diff --git a/src/Server/TSOPrometheusMetricsWriter.h b/src/Server/TSOPrometheusMetricsWriter.h index fa06288b880..e95b4c04caa 100644 --- a/src/Server/TSOPrometheusMetricsWriter.h +++ b/src/Server/TSOPrometheusMetricsWriter.h @@ -44,6 +44,7 @@ class TSOPrometheusMetricsWriter : public IPrometheusMetricsWriter { {YIELD_LEADERSHIP_KEY, "Number of times leadership was yielded by this TSO node."}, {IS_LEADER_KEY, "Denotes if this TSO node is a leader."}, + {UPDATE_TS_STOPPED_KEY, "Number of times TSO update timestamp stopped functioning."}, }; const bool send_metrics; diff --git a/src/Server/ya.make b/src/Server/ya.make index 6a6a442fce8..9076c534759 100644 --- a/src/Server/ya.make +++ b/src/Server/ya.make @@ -32,6 +32,7 @@ SRCS( PrometheusRequestHandler.cpp ProtocolServerAdapter.cpp ReplicasStatusHandler.cpp + APIRequestHandler.cpp StaticRequestHandler.cpp TCPHandler.cpp WebUIRequestHandler.cpp diff --git a/src/ServiceDiscovery/ServiceDiscoveryLocal.cpp b/src/ServiceDiscovery/ServiceDiscoveryLocal.cpp index 082064033e3..8fe750be799 100644 --- a/src/ServiceDiscovery/ServiceDiscoveryLocal.cpp +++ b/src/ServiceDiscovery/ServiceDiscoveryLocal.cpp @@ -16,12 +16,12 @@ #include #include +#include +#include +#include #include #include #include -#include -#include -#include #include @@ -64,6 +64,20 @@ HostWithPortsVec ServiceDiscoveryLocal::lookup(const String & psm_name, Componen if (type == ComponentType::WORKER && !vw_name.empty() && ep.virtual_warehouse != vw_name) continue; + // We need to resolve hostname to ip, + // because we will use ip to set `TransactionRecord`. + try + { + auto ip = DNSResolver::instance().resolveHost(ep.host).toString(); + if (!ip.empty()) + { + ep.host = ip; + } + } + catch (...) 
+ { + } + HostWithPorts host_with_ports{ep.host}; host_with_ports.id = ep.hostname; if (ep.ports.count("PORT0")) diff --git a/src/Statistics/ASTHelpers.cpp b/src/Statistics/ASTHelpers.cpp new file mode 100644 index 00000000000..0687af3a211 --- /dev/null +++ b/src/Statistics/ASTHelpers.cpp @@ -0,0 +1,94 @@ +#include + +#include +#include +#include +#include + +namespace DB::Statistics +{ +std::vector getTablesFromScope(ContextPtr context, const StatisticsScope & scope) +{ + std::vector tables; + auto catalog = createCatalogAdaptor(context); + + if (!scope.database) + { + const auto access = context->getAccess(); + const bool check_access_for_databases = !access->isGranted(AccessType::SHOW_DATABASES); + const String tenant_id = context->getTenantId(); + for (const auto & [database_name, db] : DatabaseCatalog::instance().getDatabases(context)) + { + String database_strip_tenantid = database_name; + if (!tenant_id.empty()) + { + if (startsWith(database_name, tenant_id + ".")) + database_strip_tenantid = getOriginalDatabaseName(database_name, tenant_id); + // Will skip database of other tenants and default user (without tenantid prefix) + else if (database_name.find('.') != std::string::npos || !DatabaseCatalog::isDefaultVisibleSystemDatabase(database_name)) + continue; + } + + if (check_access_for_databases && !access->isGranted(AccessType::SHOW_DATABASES, database_name)) + continue; + + if (database_name == DatabaseCatalog::TEMPORARY_DATABASE) + continue; /// We don't want to show the internal database for temporary tables in system.databases + + auto new_tables = catalog->getAllTablesID(database_name); + tables.insert(tables.end(), new_tables.begin(), new_tables.end()); + } + } + else + { + auto db = context->resolveDatabase(scope.database.value()); + if (!scope.table) + { + tables = catalog->getAllTablesID(db); + } + else + { + auto table = scope.table.value(); + auto table_info_opt = catalog->getTableIdByName(db, table); + if (!table_info_opt) + { + auto msg = 
"Unknown Table (" + table + ") in database (" + db + ")"; + throw Exception(msg, ErrorCodes::UNKNOWN_TABLE); + } + tables.emplace_back(table_info_opt.value()); + } + } + + // ensure table is unique + std::unordered_set table_set; + std::vector result; + // show materialized view as target table + for (auto table : tables) + { + auto storage = catalog->getStorageByTableId(table); + if (const auto * mv = dynamic_cast(storage.get())) + { + auto table_opt = catalog->getTableIdByName(mv->getTargetDatabaseName(), mv->getTargetTableName()); + if (!table_opt.has_value()) + { + auto err_msg = fmt::format( + FMT_STRING("mv {}.{} has invalid target table {}.{}"), + mv->getDatabaseName(), + mv->getTableName(), + mv->getTargetDatabaseName(), + mv->getTargetTableName()); + LOG_WARNING(&Poco::Logger::get("ShowStats"), err_msg); + continue; + } + table = table_opt.value(); + } + if (table_set.count(table)) + { + continue; + } + table_set.insert(table); + result.emplace_back(table); + } + return result; +} +} diff --git a/src/Statistics/ASTHelpers.h b/src/Statistics/ASTHelpers.h index 03dbe30eb6e..8e180d31f68 100644 --- a/src/Statistics/ASTHelpers.h +++ b/src/Statistics/ASTHelpers.h @@ -1,77 +1,41 @@ +#pragma once +#include #include #include -#include -#include #include -#include + namespace DB::Statistics { // use any_database, any_table, database, table in query // to construct visited tables + +struct StatisticsScope +{ + std::optional database; // nullopt for all + std::optional table; // nullopt for all +}; + +std::vector getTablesFromScope(ContextPtr context, const StatisticsScope & scope); + template -inline auto getTablesFromAST(ContextPtr context, const QueryType * query) +StatisticsScope scopeFromAST(ContextPtr context, const QueryType * query) { - std::vector tables; - auto catalog = createCatalogAdaptor(context); if (query->any_database) - { - for (const auto & db : DatabaseCatalog::instance().getDatabases(context)) - { - auto new_tables = 
catalog->getAllTablesID(db.first); - tables.insert(tables.end(), new_tables.begin(), new_tables.end()); - } - } - else - { - auto db = context->resolveDatabase(query->database); - if (query->any_table) - { - tables = catalog->getAllTablesID(db); - } - else - { - auto table_info_opt = catalog->getTableIdByName(db, query->table); - if (!table_info_opt) - { - auto msg = "Unknown Table (" + query->table + ") in database (" + db + ")"; - throw Exception(msg, ErrorCodes::UNKNOWN_TABLE); - } - tables.emplace_back(table_info_opt.value()); - } - } + return StatisticsScope{}; + auto database = context->resolveDatabase(query->database); + if (query->any_table) + return StatisticsScope{database, std::nullopt}; + auto table = query->table; + return StatisticsScope{database, table}; +} - // ensure table is unique - std::unordered_set table_set; - std::vector result; - // show materialized view as target table - for (auto table : tables) - { - auto storage = catalog->getStorageByTableId(table); - if (const auto * mv = dynamic_cast(storage.get())) - { - auto table_opt = catalog->getTableIdByName(mv->getTargetDatabaseName(), mv->getTargetTableName()); - if (!table_opt.has_value()) - { - auto err_msg = fmt::format( - FMT_STRING("mv {}.{} has invalid target table {}.{}"), - mv->getDatabaseName(), - mv->getTableName(), - mv->getTargetDatabaseName(), - mv->getTargetTableName()); - LOG_WARNING(&Poco::Logger::get("ShowStats"), err_msg); - continue; - } - table = table_opt.value(); - } - if (table_set.count(table)) - { - continue; - } - table_set.insert(table); - result.emplace_back(table); - } - return result; +template +inline auto getTablesFromAST(ContextPtr context, const QueryType * query) +{ + auto scope = scopeFromAST(context, query); + return getTablesFromScope(context, scope); } + } diff --git a/src/Statistics/AutoStatisticsHelper.cpp b/src/Statistics/AutoStatisticsHelper.cpp index ae4e31d33c5..e580d7e6358 100644 --- a/src/Statistics/AutoStatisticsHelper.cpp +++ 
b/src/Statistics/AutoStatisticsHelper.cpp @@ -86,7 +86,7 @@ TimePoint nowTimePoint() ExtendedDayNum convertToDate(DateTime64 time) { time_t ts = time.value / DecimalUtils::scaleMultiplier(DataTypeDateTime64::default_scale); - auto date = DateLUT::instance().toDayNum(ts); + auto date = DateLUT::serverTimezoneInstance().toDayNum(ts); return date; } @@ -107,7 +107,7 @@ std::optional