# diff --git a/python/pyarrow-stubs/pyarrow/_cuda.pyi b/python/pyarrow-stubs/pyarrow/_cuda.pyi
# new file mode 100644
# index 000000000000..d484fc5cf5f3
# --- /dev/null
# +++ b/python/pyarrow-stubs/pyarrow/_cuda.pyi
# @@ -0,0 +1,158 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any

# NOTE(review): no annotation below refers to this module any more (see
# ``read_message``); the import is kept only so the file's import surface is
# unchanged. Consider dropping it in a follow-up.
import cuda  # type: ignore[import-not-found]

from numba.cuda.cudadrv import driver as _numba_driver  # type: ignore[import-untyped, import-not-found]  # noqa: E501

from . import lib
from ._stubs_typing import ArrayLike


class Context(lib._Weakrefable):
    """Typed surface of the CUDA context object.

    Declares device/handle accessors, buffer factories and numba interop.
    """

    def __init__(self, device_number: int = 0, handle: int | None = None) -> None: ...

    @staticmethod
    def from_numba(context: _numba_driver.Context | None = None) -> Context: ...

    def to_numba(self) -> _numba_driver.Context: ...

    @staticmethod
    def get_num_devices() -> int: ...

    @property
    def device_number(self) -> int: ...

    @property
    def handle(self) -> int: ...

    def synchronize(self) -> None: ...

    @property
    def bytes_allocated(self) -> int: ...

    def get_device_address(self, address: int) -> int: ...

    def new_buffer(self, nbytes: int) -> CudaBuffer: ...

    @property
    def memory_manager(self) -> lib.MemoryManager: ...

    @property
    def device(self) -> lib.Device: ...

    def foreign_buffer(
        self, address: int, size: int, base: Any | None = None
    ) -> CudaBuffer: ...

    def open_ipc_buffer(self, ipc_handle: IpcMemHandle) -> CudaBuffer: ...

    def buffer_from_data(
        self,
        data: CudaBuffer | HostBuffer | lib.Buffer | ArrayLike,
        offset: int = 0,
        size: int = -1,
    ) -> CudaBuffer: ...

    def buffer_from_object(self, obj: Any) -> CudaBuffer: ...


class IpcMemHandle(lib._Weakrefable):
    """Typed surface of an IPC memory handle (build from / serialize to a Buffer)."""

    @staticmethod
    def from_buffer(opaque_handle: lib.Buffer) -> IpcMemHandle: ...

    def serialize(self, pool: lib.MemoryPool | None = None) -> lib.Buffer: ...


class CudaBuffer(lib.Buffer):
    """Typed surface of a ``lib.Buffer`` subclass with host/device copy helpers."""

    @staticmethod
    def from_buffer(buf: lib.Buffer) -> CudaBuffer: ...

    @staticmethod
    def from_numba(mem: _numba_driver.MemoryPointer) -> CudaBuffer: ...

    def to_numba(self) -> _numba_driver.MemoryPointer: ...

    def copy_to_host(
        self,
        position: int = 0,
        nbytes: int = -1,
        buf: lib.Buffer | None = None,
        memory_pool: lib.MemoryPool | None = None,
        resizable: bool = False,
    ) -> lib.Buffer: ...

    def copy_from_host(
        self, data: lib.Buffer | ArrayLike, position: int = 0, nbytes: int = -1
    ) -> int: ...

    def copy_from_device(
        self, buf: CudaBuffer, position: int = 0, nbytes: int = -1
    ) -> int: ...

    def export_for_ipc(self) -> IpcMemHandle: ...

    @property
    def context(self) -> Context: ...

    def slice(self, offset: int = 0, length: int | None = None) -> CudaBuffer: ...

    def to_pybytes(self) -> bytes: ...


class HostBuffer(lib.Buffer):
    """Typed surface of the buffer type returned by ``new_host_buffer``."""

    @property
    def size(self) -> int: ...


class BufferReader(lib.NativeFile):
    """Typed surface of a ``NativeFile`` reader constructed over a ``CudaBuffer``."""

    def __init__(self, obj: CudaBuffer) -> None: ...
    def read_buffer(self, nbytes: int | None = None) -> CudaBuffer: ...


class BufferWriter(lib.NativeFile):
    """Typed surface of a ``NativeFile`` writer constructed over a ``CudaBuffer``."""

    def __init__(self, obj: CudaBuffer) -> None: ...
    def writeat(self, position: int, data: ArrayLike) -> None: ...

    @property
    def buffer_size(self) -> int: ...

    # Setter now carries an explicit ``-> None`` like every other mutator here.
    @buffer_size.setter
    def buffer_size(self, buffer_size: int) -> None: ...

    @property
    def num_bytes_buffered(self) -> int: ...


def new_host_buffer(size: int, device: int = 0) -> HostBuffer: ...


def serialize_record_batch(batch: lib.RecordBatch, ctx: Context) -> CudaBuffer: ...


# ``source`` previously referenced ``cuda.BufferReader``, i.e. an attribute of
# the unrelated top-level ``cuda`` package; the reader meant here is the
# ``BufferReader`` class declared in this module. ``pool`` is a memory pool
# (``lib.MemoryPool``), not a ``lib.MemoryManager``.
def read_message(
    source: CudaBuffer | BufferReader, pool: lib.MemoryPool | None = None
) -> lib.Message: ...


# The second positional parameter is named ``schema`` at runtime; the stub
# previously called it ``object``, which made ``schema=...`` keyword calls fail
# type checking.
def read_record_batch(
    buffer: lib.Buffer,
    schema: lib.Schema,
    *,
    dictionary_memo: lib.DictionaryMemo | None = None,
    pool: lib.MemoryPool | None = None,
) -> lib.RecordBatch: ...

# diff --git a/python/pyarrow-stubs/pyarrow/_flight.pyi b/python/pyarrow-stubs/pyarrow/_flight.pyi
# new file mode 100644
# index 000000000000..03d6c6580ab0
# --- /dev/null
# +++ b/python/pyarrow-stubs/pyarrow/_flight.pyi
# @@ -0,0 +1,660 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import asyncio
import enum
import sys

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self
from collections.abc import Generator, Iterable, Iterator, Sequence
from typing import Any, Generic, NamedTuple, TypeVar
from datetime import datetime
from typing_extensions import deprecated

from .ipc import _ReadPandasMixin, ReadStats
from .lib import (
    ArrowCancelled,
    ArrowException,
    ArrowInvalid,
    Buffer,
    IpcReadOptions,
    IpcWriteOptions,
    RecordBatch,
    RecordBatchReader,
    Scalar,
    Schema,
    Table,
    _CRecordBatchWriter,
    _Weakrefable,
)

_T = TypeVar("_T")


class FlightCallOptions(_Weakrefable):
    """Per-call options: timeout, IPC read/write options and extra headers."""

    def __init__(
        self,
        timeout: float | None = None,
        write_options: IpcWriteOptions | None = None,
        headers: list[tuple[str | bytes, str | bytes]] | None = None,
        read_options: IpcReadOptions | None = None,
    ) -> None: ...


class CertKeyPair(NamedTuple):
    """Named tuple pairing a certificate with its key."""

    cert: str | bytes | None
    key: str | bytes | None


class FlightError(Exception):
    """Base of the Flight exception classes declared below."""

    extra_info: bytes


class FlightInternalError(FlightError, ArrowException):
    ...


class FlightTimedOutError(FlightError, ArrowException):
    ...


class FlightCancelledError(FlightError, ArrowCancelled):
    def __init__(self, message: str, *, extra_info: bytes | None = None) -> None: ...


class FlightServerError(FlightError, ArrowException):
    ...


class FlightUnauthenticatedError(FlightError, ArrowException):
    ...


class FlightUnauthorizedError(FlightError, ArrowException):
    ...


class FlightUnavailableError(FlightError, ArrowException):
    ...


class FlightWriteSizeExceededError(ArrowInvalid):
    """``ArrowInvalid`` subclass carrying ``limit`` and ``actual`` sizes."""

    limit: int
    actual: int


class Action(_Weakrefable):
    """An action with a ``type`` string and a ``body`` buffer."""

    def __init__(
        self, action_type: bytes | str, buf: Buffer | bytes | None) -> None: ...

    @property
    def type(self) -> str: ...

    @property
    def body(self) -> Buffer: ...

    def serialize(self) -> bytes: ...

    @classmethod
    def deserialize(cls, serialized: bytes) -> Self: ...


class ActionType(NamedTuple):
    """Named tuple of an action's ``type`` and ``description``."""

    type: str
    description: str

    def make_action(self, buf: Buffer | bytes) -> Action: ...


class Result(_Weakrefable):
    """A result wrapping a ``body`` buffer."""

    def __init__(self, buf: Buffer | bytes) -> None: ...

    @property
    def body(self) -> Buffer: ...

    def serialize(self) -> bytes: ...

    @classmethod
    def deserialize(cls, serialized: bytes) -> Self: ...


class BasicAuth(_Weakrefable):
    """Username/password container with (de)serialization helpers."""

    def __init__(
        self, username: str | bytes | None = None, password: str | bytes | None = None
    ) -> None: ...

    @property
    def username(self) -> bytes: ...

    @property
    def password(self) -> bytes: ...

    def serialize(self) -> str: ...

    @staticmethod
    def deserialize(serialized: str | bytes) -> BasicAuth: ...


class DescriptorType(enum.Enum):
    """Kinds of ``FlightDescriptor``."""

    UNKNOWN = 0
    PATH = 1
    CMD = 2


class FlightMethod(enum.Enum):
    """Flight RPC methods, as reported through ``CallInfo.method``."""

    INVALID = 0
    HANDSHAKE = 1
    LIST_FLIGHTS = 2
    GET_FLIGHT_INFO = 3
    GET_SCHEMA = 4
    DO_GET = 5
    DO_PUT = 6
    DO_ACTION = 7
    LIST_ACTIONS = 8
    DO_EXCHANGE = 9


class FlightDescriptor(_Weakrefable):
    """Descriptor built either from a path or from a command."""

    @staticmethod
    def for_path(*path: str | bytes) -> FlightDescriptor: ...

    @staticmethod
    def for_command(command: str | bytes) -> FlightDescriptor: ...

    @property
    def descriptor_type(self) -> DescriptorType: ...

    @property
    def path(self) -> list[bytes] | None: ...

    @property
    def command(self) -> bytes | None: ...

    def serialize(self) -> bytes: ...

    @classmethod
    def deserialize(cls, serialized: bytes) -> Self: ...


class Ticket(_Weakrefable):
    """Ticket wrapping an opaque ``bytes`` payload."""

    def __init__(self, ticket: str | bytes) -> None: ...

    @property
    def ticket(self) -> bytes: ...

    def serialize(self) -> bytes: ...

    @classmethod
    def deserialize(cls, serialized: bytes) -> Self: ...


class Location(_Weakrefable):
    """A location identified by a URI, with gRPC-specific constructors."""

    def __init__(self, uri: str | bytes) -> None: ...

    @property
    def uri(self) -> bytes: ...

    def equals(self, other: Location) -> bool: ...

    @staticmethod
    def for_grpc_tcp(host: str | bytes, port: int) -> Location: ...

    @staticmethod
    def for_grpc_tls(host: str | bytes, port: int) -> Location: ...

    @staticmethod
    def for_grpc_unix(path: str | bytes) -> Location: ...


class FlightEndpoint(_Weakrefable):
    """An endpoint: a ticket plus the locations it can be redeemed at."""

    def __init__(
        self,
        ticket: Ticket | str | bytes | object,
        locations: list[str | bytes | Location | object],
        expiration_time: Scalar[Any] | str | datetime | None = ...,
        app_metadata: bytes | str | object = ...,
    ) -> None: ...

    @property
    def ticket(self) -> Ticket: ...

    @property
    def locations(self) -> list[Location]: ...

    def serialize(self) -> bytes: ...

    @property
    def expiration_time(self) -> Scalar[Any] | None: ...

    @property
    def app_metadata(self) -> bytes | str: ...

    @classmethod
    def deserialize(cls, serialized: bytes) -> Self: ...


class SchemaResult(_Weakrefable):
    """Wrapper around a ``Schema`` with (de)serialization helpers."""

    def __init__(self, schema: Schema) -> None: ...

    @property
    def schema(self) -> Schema: ...

    def serialize(self) -> bytes: ...

    @classmethod
    def deserialize(cls, serialized: bytes) -> Self: ...


class FlightInfo(_Weakrefable):
    """Flight metadata: schema, descriptor, endpoints and size information."""

    def __init__(
        self,
        schema: Schema | None,
        descriptor: FlightDescriptor,
        endpoints: list[FlightEndpoint],
        total_records: int | None = ...,
        total_bytes: int | None = ...,
        ordered: bool = ...,
        app_metadata: bytes | str = ...,
    ) -> None: ...

    @property
    def schema(self) -> Schema | None: ...

    @property
    def descriptor(self) -> FlightDescriptor: ...

    @property
    def endpoints(self) -> list[FlightEndpoint]: ...

    @property
    def total_records(self) -> int: ...

    @property
    def total_bytes(self) -> int: ...

    @property
    def ordered(self) -> bool: ...

    @property
    def app_metadata(self) -> bytes | str: ...

    def serialize(self) -> bytes: ...

    @classmethod
    def deserialize(cls, serialized: bytes) -> Self: ...


class FlightStreamChunk(_Weakrefable):
    """One stream element: an optional record batch plus optional metadata."""

    @property
    def data(self) -> RecordBatch | None: ...

    @property
    def app_metadata(self) -> Buffer | None: ...

    # Previously unannotated. Iteration yields the ``data`` and
    # ``app_metadata`` values, so the element type is the union of the two
    # property types above.
    def __iter__(self) -> Iterator[RecordBatch | Buffer | None]: ...


class _MetadataRecordBatchReader(_Weakrefable, _ReadPandasMixin):
    # Needs to be separate class so the "real" class can subclass the
    # pure-Python mixin class

    def __iter__(self) -> Self: ...

    def __next__(self) -> FlightStreamChunk: ...

    @property
    def schema(self) -> Schema: ...

    def read_all(self) -> Table: ...

    def read_chunk(self) -> FlightStreamChunk: ...

    def to_reader(self) -> RecordBatchReader: ...


class MetadataRecordBatchReader(_MetadataRecordBatchReader):
    """Public reader type; adds ``stats`` to the private base."""

    @property
    def stats(self) -> ReadStats: ...


class FlightStreamReader(MetadataRecordBatchReader):
    """Reader for a client-side stream; adds ``cancel`` and ``read``."""

    @property
    def stats(self) -> ReadStats: ...

    def cancel(self) -> None: ...

    def read_all(self) -> Table: ...

    def read(self) -> RecordBatch | None: ...


class MetadataRecordBatchWriter(_CRecordBatchWriter):
    """Record batch writer that can also send application metadata."""

    def begin(self, schema: Schema, options: IpcWriteOptions | None = None) -> None: ...

    def write_metadata(self, buf: Buffer | bytes) -> None: ...

    def write_batch(self, batch: RecordBatch) -> None: ...  # type: ignore[override]

    def write_table(
        self, table: Table, max_chunksize: int | None = None, **kwargs
    ) -> None: ...

    def close(self) -> None: ...

    def write_with_metadata(self, batch: RecordBatch, buf: Buffer | bytes) -> None: ...


class FlightStreamWriter(MetadataRecordBatchWriter):
    """Writer for a client-side stream; adds ``done_writing``."""

    def done_writing(self) -> None: ...


class FlightMetadataReader(_Weakrefable):
    def read(self) -> Buffer | None: ...


class FlightMetadataWriter(_Weakrefable):
    def write(self, message: Buffer) -> None: ...


class AsyncioCall(Generic[_T]):
    """Generic holder of an ``asyncio.Future`` for one asynchronous call."""

    _future: asyncio.Future[_T]

    def as_awaitable(self) -> asyncio.Future[_T]: ...

    def wakeup(self, result_or_exception: BaseException | _T) -> None: ...


class AsyncioFlightClient:
    """Asyncio-flavoured client wrapper built from a ``FlightClient``."""

    def __init__(self, client: FlightClient) -> None: ...

    # Return annotation added; mirrors the synchronous
    # ``FlightClient.get_flight_info``.
    async def get_flight_info(
        self,
        descriptor: FlightDescriptor,
        *,
        options: FlightCallOptions | None = None,
    ) -> FlightInfo: ...


class FlightClient(_Weakrefable):
    """Typed surface of the synchronous Flight client."""

    # ``generic_options`` is typed as ``Sequence`` to agree with the
    # module-level ``connect`` function; every ``list`` argument still
    # type-checks.
    def __init__(
        self,
        location: str | tuple[str, int] | Location,
        *,
        tls_root_certs: str | None = None,
        cert_chain: str | None = None,
        private_key: str | None = None,
        override_hostname: str | None = None,
        middleware: list[ClientMiddlewareFactory] | None = None,
        write_size_limit_bytes: int | None = None,
        disable_server_verification: bool = False,
        generic_options: Sequence[tuple[str, int | str]] | None = None,
    ) -> None: ...

    @property
    def supports_async(self) -> bool: ...

    def as_async(self) -> AsyncioFlightClient: ...

    def wait_for_available(self, timeout: int = 5) -> None: ...

    @classmethod
    @deprecated(
        "Use the ``FlightClient`` constructor or "
        "``pyarrow.flight.connect`` function instead."
    )
    def connect(
        cls,
        location: str | tuple[str, int] | Location,
        tls_root_certs: str | None = None,
        cert_chain: str | None = None,
        private_key: str | None = None,
        override_hostname: str | None = None,
        disable_server_verification: bool = False,
    ) -> FlightClient: ...

    def authenticate(
        self, auth_handler: ClientAuthHandler, options: FlightCallOptions | None = None
    ) -> None: ...

    def authenticate_basic_token(
        self, username: str | bytes, password: str | bytes,
        options: FlightCallOptions | None = None
    ) -> tuple[str, str]: ...

    def list_actions(
        self, options: FlightCallOptions | None = None
    ) -> list[Action]: ...

    def do_action(
        self, action: Action | tuple[bytes | str, bytes | str] | str,
        options: FlightCallOptions | None = None
    ) -> Iterator[Result]: ...

    def list_flights(
        self, criteria: str | bytes | None = None,
        options: FlightCallOptions | None = None
    ) -> Generator[FlightInfo, None, None]: ...

    def get_flight_info(
        self, descriptor: FlightDescriptor, options: FlightCallOptions | None = None
    ) -> FlightInfo: ...

    def get_schema(
        self, descriptor: FlightDescriptor, options: FlightCallOptions | None = None
    ) -> SchemaResult: ...

    def do_get(
        self, ticket: Ticket, options: FlightCallOptions | None = None
    ) -> FlightStreamReader: ...

    def do_put(
        self,
        descriptor: FlightDescriptor,
        schema: Schema | None,
        options: FlightCallOptions | None = None,
    ) -> tuple[FlightStreamWriter, FlightStreamReader]: ...

    def do_exchange(
        self, descriptor: FlightDescriptor, options: FlightCallOptions | None = None
    ) -> tuple[FlightStreamWriter, FlightStreamReader]: ...

    def close(self) -> None: ...

    def __enter__(self) -> Self: ...

    # Parameters typed as ``object`` to match ``FlightServerBase.__exit__``.
    def __exit__(
        self, exc_type: object, exc_value: object, traceback: object) -> None: ...


class FlightDataStream(_Weakrefable):
    ...


class RecordBatchStream(FlightDataStream):
    """Data stream built from a ``RecordBatchReader`` or a ``Table``."""

    def __init__(
        self,
        data_source: RecordBatchReader | Table | None = None,
        options: IpcWriteOptions | None = None,
    ) -> None: ...


class GeneratorStream(FlightDataStream):
    """Data stream built from a schema and an iterable of batch-like items."""

    def __init__(
        self,
        schema: Schema,
        generator: Iterable[
            FlightDataStream
            | Table
            | RecordBatch
            | RecordBatchReader
            | tuple[RecordBatch, bytes]
        ],
        options: IpcWriteOptions | None = None,
    ) -> None: ...


class ServerCallContext(_Weakrefable):
    """Per-call context handed to server-side handlers."""

    def peer_identity(self) -> bytes: ...

    # Set safe=True as gRPC on Windows sometimes gives garbage bytes
    # NOTE(review): this comment was carried over from the implementation and
    # appears to describe how ``peer`` decodes its value; it previously sat
    # above ``is_cancelled`` -- confirm against the .pyx source.
    def peer(self) -> str: ...

    def is_cancelled(self) -> bool: ...

    def add_header(self, key: str, value: str) -> None: ...

    def add_trailer(self, key: str, value: str) -> None: ...

    def get_middleware(self, key: str) -> ServerMiddleware | None: ...


class ServerAuthReader(_Weakrefable):
    def read(self) -> str: ...


class ServerAuthSender(_Weakrefable):
    def write(self, message: str) -> None: ...


class ClientAuthReader(_Weakrefable):
    def read(self) -> str: ...


class ClientAuthSender(_Weakrefable):
    def write(self, message: str) -> None: ...


class ServerAuthHandler(_Weakrefable):
    """Server-side authentication hooks (handshake and token validation)."""

    def authenticate(
        self, outgoing: ServerAuthSender, incoming: ServerAuthReader
    ) -> None: ...

    def is_valid(self, token: str) -> bool: ...


class ClientAuthHandler(_Weakrefable):
    """Client-side authentication hooks (handshake and token retrieval)."""

    def authenticate(
        self, outgoing: ClientAuthSender, incoming: ClientAuthReader
    ) -> None: ...

    def get_token(self) -> str: ...


class CallInfo(NamedTuple):
    """Named tuple describing a call; currently only its ``method``."""

    method: FlightMethod


class ClientMiddlewareFactory(_Weakrefable):
    def start_call(self, info: CallInfo) -> ClientMiddleware | None: ...


class ClientMiddleware(_Weakrefable):
    """Client middleware hooks around one call."""

    def sending_headers(self) -> dict[str, list[str] | list[bytes]]: ...

    def received_headers(
        self, headers: dict[str, list[str] | list[bytes]]
    ) -> None: ...

    def call_completed(self, exception: ArrowException) -> None: ...


class ServerMiddlewareFactory(_Weakrefable):
    def start_call(
        self, info: CallInfo, headers: dict[str, list[str] | list[bytes]]
    ) -> ServerMiddleware | None: ...


class TracingServerMiddlewareFactory(ServerMiddlewareFactory):
    ...


class ServerMiddleware(_Weakrefable):
    """Server middleware hooks around one call."""

    def sending_headers(self) -> dict[str, list[str] | list[bytes]]: ...

    def call_completed(self, exception: ArrowException) -> None: ...

    @property
    def trace_context(self) -> dict: ...


class TracingServerMiddleware(ServerMiddleware):
    trace_context: dict
    def __init__(self, trace_context: dict) -> None: ...


class _ServerMiddlewareFactoryWrapper(ServerMiddlewareFactory):
    """Private factory that fans a call out to several named factories."""

    def __init__(self, factories: dict[str, ServerMiddlewareFactory]) -> None: ...

    # The wrapper factory produces a ``_ServerMiddlewareWrapper`` (a
    # ``ServerMiddleware`` subclass), not another factory wrapper. With the
    # corrected return type the override is compatible with the base method, so
    # the former ``# type: ignore[override]`` is no longer needed.
    def start_call(
        self, info: CallInfo, headers: dict[str, list[str] | list[bytes]]
    ) -> _ServerMiddlewareWrapper | None: ...


class _ServerMiddlewareWrapper(ServerMiddleware):
    """Private middleware that delegates to several named middleware."""

    def __init__(self, middleware: dict[str, ServerMiddleware]) -> None: ...

    # Named ``sending_headers`` (not ``send_headers``) so that it overrides
    # ``ServerMiddleware.sending_headers``; it returns the same header-mapping
    # shape as the base method.
    def sending_headers(self) -> dict[str, list[str] | list[bytes]]: ...

    def call_completed(self, exception: ArrowException) -> None: ...


class _FlightServerFinalizer(_Weakrefable):

    def finalize(self) -> None: ...


class FlightServerBase(_Weakrefable):
    """Typed surface of the Flight server base class.

    The RPC handler methods (``list_flights`` ... ``do_action``) take a
    ``ServerCallContext`` as their first argument.
    """

    def __init__(
        self,
        location: str | tuple[str, int] | Location | None = None,
        auth_handler: ServerAuthHandler | None = None,
        tls_certificates: list[tuple[str, str]] | None = None,
        verify_client: bool = False,
        root_certificates: str | None = None,
        middleware: dict[str, ServerMiddlewareFactory] | None = None,
    ) -> None: ...

    @property
    def port(self) -> int: ...

    def list_flights(
        self, context: ServerCallContext, criteria: str
    ) -> Iterator[FlightInfo]: ...

    def get_flight_info(
        self, context: ServerCallContext, descriptor: FlightDescriptor
    ) -> FlightInfo: ...

    # The server machinery requires a ``SchemaResult`` from this handler (the
    # same type ``FlightClient.get_schema`` hands back to callers); the stub
    # previously advertised a bare ``Schema``.
    def get_schema(
        self, context: ServerCallContext, descriptor: FlightDescriptor
    ) -> SchemaResult: ...

    def do_put(
        self,
        context: ServerCallContext,
        descriptor: FlightDescriptor,
        reader: MetadataRecordBatchReader,
        writer: FlightMetadataWriter,
    ) -> None: ...

    def do_get(
        self, context: ServerCallContext, ticket: Ticket
    ) -> FlightDataStream: ...

    def do_exchange(
        self,
        context: ServerCallContext,
        descriptor: FlightDescriptor,
        reader: MetadataRecordBatchReader,
        writer: MetadataRecordBatchWriter,
    ) -> None: ...

    def list_actions(self, context: ServerCallContext) -> Iterable[Action]: ...

    def do_action(
        self, context: ServerCallContext, action: Action
    ) -> Iterable[bytes]: ...

    def serve(self) -> None: ...

    def run(self) -> None: ...

    def shutdown(self) -> None: ...

    def wait(self) -> None: ...

    def __enter__(self) -> Self: ...

    def __exit__(
        self, exc_type: object, exc_value: object, traceback: object) -> None: ...


# Module-level counterpart of the ``FlightClient`` constructor: same
# parameters, returns the client.
def connect(
    location: str | tuple[str, int] | Location,
    *,
    tls_root_certs: str | None = None,
    cert_chain: str | None = None,
    private_key: str | None = None,
    override_hostname: str | None = None,
    middleware: list[ClientMiddlewareFactory] | None = None,
    write_size_limit_bytes: int | None = None,
    disable_server_verification: bool = False,
    generic_options: Sequence[tuple[str, int | str]] | None = None,
) -> FlightClient: ...

# diff --git a/python/pyarrow-stubs/pyarrow/_substrait.pyi b/python/pyarrow-stubs/pyarrow/_substrait.pyi
# new file mode 100644
# index 000000000000..6818d9822ab0
# --- /dev/null
# +++ b/python/pyarrow-stubs/pyarrow/_substrait.pyi
# @@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from collections.abc import Callable
from typing import Any

from ._compute import Expression
from .lib import Buffer, RecordBatchReader, Schema, Table, _Weakrefable


# ``plan`` is a serialized plan, given either as a ``Buffer`` or as raw
# ``bytes``; the previous ``Buffer | int`` annotation rejected ``bytes`` and
# accepted ``int``, which the runtime does not.
def run_query(
    plan: Buffer | bytes,
    *,
    table_provider: Callable[[list[str], Schema], Table] | None = None,
    use_threads: bool = True,
) -> RecordBatchReader: ...


def _parse_json_plan(plan: bytes) -> Buffer: ...


class SubstraitSchema:
    """Pair of serialized payloads: ``schema`` and ``expression`` (both bytes)."""

    schema: bytes
    expression: bytes

    def __init__(self, schema: bytes, expression: bytes) -> None: ...

    def to_pysubstrait(self) -> Any: ...


# Schema <-> Substrait conversion helpers.
def serialize_schema(schema: Schema) -> SubstraitSchema: ...


def deserialize_schema(buf: Buffer | bytes | SubstraitSchema) -> Schema: ...


# Expressions are serialized together with their names and the schema they are
# bound to; ``allow_arrow_extensions`` is keyword-only.
def serialize_expressions(
    exprs: list[Expression],
    names: list[str],
    schema: Schema,
    *,
    allow_arrow_extensions: bool = False,
) -> Buffer: ...


class BoundExpressions(_Weakrefable):
    """A schema together with a name -> ``Expression`` mapping."""

    @property
    def schema(self) -> Schema: ...

    @property
    def expressions(self) -> dict[str, Expression]: ...

    @classmethod
    def from_substrait(cls, message: Buffer | bytes | Any) -> BoundExpressions: ...


def deserialize_expressions(buf: Buffer | bytes) -> BoundExpressions: ...


def get_supported_functions() -> list[str]: ...

# diff --git a/python/pyarrow-stubs/pyarrow/cffi.pyi b/python/pyarrow-stubs/pyarrow/cffi.pyi
# new file mode 100644
# index 000000000000..e4f077d7155b
# --- /dev/null
# +++ b/python/pyarrow-stubs/pyarrow/cffi.pyi
# @@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import cffi

# Module-level names exposed by ``pyarrow.cffi``.
c_source: str
ffi: cffi.FFI

# diff --git a/python/pyarrow-stubs/pyarrow/compat.pyi b/python/pyarrow-stubs/pyarrow/compat.pyi
# new file mode 100644
# index 000000000000..30e3ec13e0dd
# --- /dev/null
# +++ b/python/pyarrow-stubs/pyarrow/compat.pyi
# @@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

def encode_file_path(path: str | bytes) -> bytes: ...


def tobytes(o: str | bytes) -> bytes: ...


# Inverse direction of ``tobytes``: takes ``bytes`` and yields text. The return
# annotation was previously missing, which made every result implicitly ``Any``.
def frombytes(o: bytes, *, safe: bool = False) -> str: ...


__all__ = ["encode_file_path", "tobytes", "frombytes"]

# diff --git a/python/pyarrow-stubs/pyarrow/cuda.pyi b/python/pyarrow-stubs/pyarrow/cuda.pyi
# new file mode 100644
# index 000000000000..0394965bb738
# --- /dev/null
# +++ b/python/pyarrow-stubs/pyarrow/cuda.pyi
# @@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Public re-exports of the names declared in ``pyarrow._cuda``.
from pyarrow._cuda import (
    BufferReader,
    BufferWriter,
    Context,
    CudaBuffer,
    HostBuffer,
    IpcMemHandle,
    new_host_buffer,
    read_message,
    read_record_batch,
    serialize_record_batch,
)

__all__ = [
    "BufferReader",
    "BufferWriter",
    "Context",
    "CudaBuffer",
    "HostBuffer",
    "IpcMemHandle",
    "new_host_buffer",
    "read_message",
    "read_record_batch",
    "serialize_record_batch",
]

# diff --git a/python/pyarrow-stubs/pyarrow/flight.pyi b/python/pyarrow-stubs/pyarrow/flight.pyi
# new file mode 100644
# index 000000000000..dcc6ee2244b3
# --- /dev/null
# +++ b/python/pyarrow-stubs/pyarrow/flight.pyi
# @@ -0,0 +1,112 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Public re-exports of the names declared in ``pyarrow._flight``.
from pyarrow._flight import (
    Action,
    ActionType,
    BasicAuth,
    CallInfo,
    CertKeyPair,
    ClientAuthHandler,
    ClientMiddleware,
    ClientMiddlewareFactory,
    DescriptorType,
    FlightCallOptions,
    FlightCancelledError,
    FlightClient,
    FlightDataStream,
    FlightDescriptor,
    FlightEndpoint,
    FlightError,
    FlightInfo,
    FlightInternalError,
    FlightMetadataReader,
    FlightMetadataWriter,
    FlightMethod,
    FlightServerBase,
    FlightServerError,
    FlightStreamChunk,
    FlightStreamReader,
    FlightStreamWriter,
    FlightTimedOutError,
    FlightUnauthenticatedError,
    FlightUnauthorizedError,
    FlightUnavailableError,
    FlightWriteSizeExceededError,
    GeneratorStream,
    Location,
    MetadataRecordBatchReader,
    MetadataRecordBatchWriter,
    RecordBatchStream,
    Result,
    SchemaResult,
    ServerAuthHandler,
    ServerCallContext,
    ServerMiddleware,
    ServerMiddlewareFactory,
    Ticket,
    TracingServerMiddlewareFactory,
    connect,
)

__all__ = [
    "Action",
    "ActionType",
    "BasicAuth",
    "CallInfo",
    "CertKeyPair",
    "ClientAuthHandler",
    "ClientMiddleware",
    "ClientMiddlewareFactory",
    "DescriptorType",
    "FlightCallOptions",
    "FlightCancelledError",
    "FlightClient",
    "FlightDataStream",
    "FlightDescriptor",
    "FlightEndpoint",
    "FlightError",
    "FlightInfo",
    "FlightInternalError",
    "FlightMetadataReader",
    "FlightMetadataWriter",
    "FlightMethod",
    "FlightServerBase",
    "FlightServerError",
    "FlightStreamChunk",
    "FlightStreamReader",
    "FlightStreamWriter",
    "FlightTimedOutError",
    "FlightUnauthenticatedError",
    "FlightUnauthorizedError",
    "FlightUnavailableError",
    "FlightWriteSizeExceededError",
    "GeneratorStream",
    "Location",
    "MetadataRecordBatchReader",
    "MetadataRecordBatchWriter",
    "RecordBatchStream",
    "Result",
    "SchemaResult",
    "ServerAuthHandler",
    "ServerCallContext",
    "ServerMiddleware",
    "ServerMiddlewareFactory",
    "Ticket",
    "TracingServerMiddlewareFactory",
    "connect",
]

# diff --git a/python/pyarrow-stubs/pyarrow/gandiva.pyi b/python/pyarrow-stubs/pyarrow/gandiva.pyi
# new file mode 100644
# index 000000000000..7e129d3ed1de
# --- /dev/null
# +++ b/python/pyarrow-stubs/pyarrow/gandiva.pyi
# @@ -0,0 +1,110 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from collections.abc import Iterable
from typing import Literal

from .lib import Array, DataType, Field, MemoryPool, RecordBatch, Schema, _Weakrefable


class Node(_Weakrefable):
    """Expression-tree node; exposes its return type."""

    def return_type(self) -> DataType: ...


class Expression(_Weakrefable):
    """Pairs a root ``Node`` with a result ``Field``."""

    def root(self) -> Node: ...
    def result(self) -> Field: ...


class Condition(_Weakrefable):
    """Same declared surface as ``Expression``; consumed by ``make_filter``."""

    def root(self) -> Node: ...
    def result(self) -> Field: ...


class SelectionVector(_Weakrefable):
    """Result of ``Filter.evaluate``; convertible to an ``Array``."""

    def to_array(self) -> Array: ...


class Projector(_Weakrefable):
    """Evaluates a record batch (optionally with a selection) into arrays."""

    # Return annotation added (was implicitly ``Any``): the generated IR is
    # exposed as text.
    @property
    def llvm_ir(self) -> str: ...

    def evaluate(
        self, batch: RecordBatch, selection: SelectionVector | None = None
    ) -> list[Array]: ...


class Filter(_Weakrefable):
    """Evaluates a record batch into a ``SelectionVector``."""

    # Return annotation added to match ``Projector.llvm_ir``.
    @property
    def llvm_ir(self) -> str: ...

    def evaluate(
        self, batch: RecordBatch, pool: MemoryPool, dtype: DataType | str = "int32"
    ) -> SelectionVector: ...
+ + +class TreeExprBuilder(_Weakrefable): + def make_literal(self, value: float | str | bytes | + bool, dtype: DataType | str | None) -> Node: ... + + def make_expression( + self, root_node: Node | None, return_field: Field) -> Expression: ... + + def make_function( + self, name: str, children: list[Node | None], + return_type: DataType) -> Node: ... + + def make_field(self, field: Field | None) -> Node: ... + + def make_if( + self, condition: Node, this_node: Node | None, + else_node: Node | None, return_type: DataType | None + ) -> Node: ... + def make_and(self, children: list[Node | None]) -> Node: ... + def make_or(self, children: list[Node | None]) -> Node: ... + def make_in_expression(self, node: Node | None, values: Iterable, + dtype: DataType) -> Node: ... + + def make_condition(self, condition: Node | None) -> Condition: ... + + +class Configuration(_Weakrefable): + def __init__(self, optimize: bool = True, dump_ir: bool = False) -> None: ... + + +def make_projector( + schema: Schema, + children: list[Expression | None], + pool: MemoryPool | None = None, + selection_mode: Literal["NONE", "UINT16", "UINT32", "UINT64"] = "NONE", + configuration: Configuration | None = None, +) -> Projector: ... + + +def make_filter( + schema: Schema, condition: Condition | None, + configuration: Configuration | None = None +) -> Filter: ... + + +class FunctionSignature(_Weakrefable): + def return_type(self) -> DataType: ... + def param_types(self) -> list[DataType]: ... + def name(self) -> str: ... + + +def get_registered_function_signatures() -> list[FunctionSignature]: ... diff --git a/python/pyarrow-stubs/pyarrow/interchange/__init__.pyi b/python/pyarrow-stubs/pyarrow/interchange/__init__.pyi new file mode 100644 index 000000000000..fd5ae83c5692 --- /dev/null +++ b/python/pyarrow-stubs/pyarrow/interchange/__init__.pyi @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .from_dataframe import from_dataframe as from_dataframe + +__all__ = ["from_dataframe"] diff --git a/python/pyarrow-stubs/pyarrow/interchange/buffer.pyi b/python/pyarrow-stubs/pyarrow/interchange/buffer.pyi new file mode 100644 index 000000000000..e1d8ae949c90 --- /dev/null +++ b/python/pyarrow-stubs/pyarrow/interchange/buffer.pyi @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import enum + +from pyarrow.lib import Buffer + + +class DlpackDeviceType(enum.IntEnum): + CPU = 1 + CUDA = 2 + CPU_PINNED = 3 + OPENCL = 4 + VULKAN = 7 + METAL = 8 + VPI = 9 + ROCM = 10 + + +class _PyArrowBuffer: + def __init__(self, x: Buffer, allow_copy: bool = True) -> None: ... + @property + def bufsize(self) -> int: ... + @property + def ptr(self) -> int: ... + def __dlpack__(self): ... + def __dlpack_device__(self) -> tuple[DlpackDeviceType, int | None]: ... diff --git a/python/pyarrow-stubs/pyarrow/interchange/column.pyi b/python/pyarrow-stubs/pyarrow/interchange/column.pyi new file mode 100644 index 000000000000..67508ac0689c --- /dev/null +++ b/python/pyarrow-stubs/pyarrow/interchange/column.pyi @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import enum + +from collections.abc import Iterable +from typing import Any, TypeAlias, TypedDict + +from pyarrow.lib import Array, ChunkedArray + +from .buffer import _PyArrowBuffer + + +class DtypeKind(enum.IntEnum): + INT = 0 + UINT = 1 + FLOAT = 2 + BOOL = 20 + STRING = 21 # UTF-8 + DATETIME = 22 + CATEGORICAL = 23 + + +Dtype: TypeAlias = tuple[DtypeKind, int, str, str] + + +class ColumnNullType(enum.IntEnum): + NON_NULLABLE = 0 + USE_NAN = 1 + USE_SENTINEL = 2 + USE_BITMASK = 3 + USE_BYTEMASK = 4 + + +class ColumnBuffers(TypedDict): + data: tuple[_PyArrowBuffer, Dtype] + validity: tuple[_PyArrowBuffer, Dtype] | None + offsets: tuple[_PyArrowBuffer, Dtype] | None + + +class CategoricalDescription(TypedDict): + is_ordered: bool + is_dictionary: bool + categories: _PyArrowColumn | None + + +class Endianness(enum.Enum): + LITTLE = "<" + BIG = ">" + NATIVE = "=" + NA = "|" + + +class NoBufferPresent(Exception): + ... + + +class _PyArrowColumn: + _col: Array | ChunkedArray + + def __init__(self, column: Array | ChunkedArray, + allow_copy: bool = True) -> None: ... + + def size(self) -> int: ... + @property + def offset(self) -> int: ... + @property + def dtype(self) -> tuple[DtypeKind, int, str, str]: ... + @property + def describe_categorical(self) -> CategoricalDescription: ... + @property + def describe_null(self) -> tuple[ColumnNullType, Any]: ... + @property + def null_count(self) -> int: ... + @property + def metadata(self) -> dict[str, Any]: ... + def num_chunks(self) -> int: ... + def get_chunks(self, n_chunks: int | None = None) -> Iterable[_PyArrowColumn]: ... + def get_buffers(self) -> ColumnBuffers: ... 
diff --git a/python/pyarrow-stubs/pyarrow/interchange/dataframe.pyi b/python/pyarrow-stubs/pyarrow/interchange/dataframe.pyi new file mode 100644 index 000000000000..419b3e2cdb33 --- /dev/null +++ b/python/pyarrow-stubs/pyarrow/interchange/dataframe.pyi @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self +from collections.abc import Iterable, Sequence +from typing import Any + +from pyarrow.interchange.column import _PyArrowColumn +from pyarrow.lib import RecordBatch, Table + + +class _PyArrowDataFrame: + def __init__( + self, + df: Table | RecordBatch, + nan_as_null: bool = False, + allow_copy: bool = True) -> None: ... + + def __dataframe__( + self, nan_as_null: bool = False, allow_copy: bool = True + ) -> _PyArrowDataFrame: ... + @property + def metadata(self) -> dict[str, Any]: ... + def num_columns(self) -> int: ... + def num_rows(self) -> int: ... + def num_chunks(self) -> int: ... + def column_names(self) -> Iterable[str]: ... + def get_column(self, i: int) -> _PyArrowColumn: ... + def get_column_by_name(self, name: str) -> _PyArrowColumn: ... 
+ def get_columns(self) -> Iterable[_PyArrowColumn]: ... + def select_columns(self, indices: Sequence[int]) -> Self: ... + def select_columns_by_name(self, names: Sequence[str]) -> Self: ... + def get_chunks(self, n_chunks: int | None = None) -> Iterable[Self]: ... diff --git a/python/pyarrow-stubs/pyarrow/interchange/from_dataframe.pyi b/python/pyarrow-stubs/pyarrow/interchange/from_dataframe.pyi new file mode 100644 index 000000000000..d6ad272dfc69 --- /dev/null +++ b/python/pyarrow-stubs/pyarrow/interchange/from_dataframe.pyi @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Protocol, TypeAlias + +from pyarrow.lib import Array, Buffer, DataType, DictionaryArray, RecordBatch, Table + +from .column import ( + ColumnBuffers, + ColumnNullType, + Dtype, + DtypeKind, +) + + +class DataFrameObject(Protocol): + def __dataframe__(self, nan_as_null: bool = False, + allow_copy: bool = True) -> Any: ... + + +ColumnObject: TypeAlias = Any + + +def from_dataframe(df: DataFrameObject, allow_copy: bool = True) -> Table: ... + + +def _from_dataframe(df: DataFrameObject, allow_copy: bool = True) -> Table: ... + + +def protocol_df_chunk_to_pyarrow( + df: DataFrameObject, allow_copy: bool = True) -> RecordBatch: ...
+ + +def column_to_array(col: ColumnObject, allow_copy: bool = True) -> Array: ... + + +def bool_column_to_array(col: ColumnObject, allow_copy: bool = True) -> Array: ... + + +def categorical_column_to_dictionary( + col: ColumnObject, allow_copy: bool = True +) -> DictionaryArray: ... + + +def parse_datetime_format_str(format_str: str) -> tuple[str, str]: ... + + +def map_date_type(data_type: tuple[DtypeKind, int, str, str]) -> DataType: ... + + +def buffers_to_array( + buffers: ColumnBuffers, + data_type: tuple[DtypeKind, int, str, str], + length: int, + describe_null: ColumnNullType, + offset: int = 0, + allow_copy: bool = True, +) -> Array: ... + + +def validity_buffer_from_mask( + validity_buff: Buffer, + validity_dtype: Dtype, + describe_null: ColumnNullType, + length: int, + offset: int = 0, + allow_copy: bool = True, +) -> Buffer: ... + + +def validity_buffer_nan_sentinel( + data_pa_buffer: Buffer, + data_type: Dtype, + describe_null: ColumnNullType, + length: int, + offset: int = 0, + allow_copy: bool = True, +) -> Buffer: ... diff --git a/python/pyarrow-stubs/pyarrow/pandas_compat.pyi b/python/pyarrow-stubs/pyarrow/pandas_compat.pyi new file mode 100644 index 000000000000..4e614c58a3fd --- /dev/null +++ b/python/pyarrow-stubs/pyarrow/pandas_compat.pyi @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, TypedDict, TypeVar + +import numpy as np +import pandas as pd + +from pandas import DatetimeTZDtype + +from .lib import Array, DataType, Schema, Table, _pandas_api + +_T = TypeVar("_T") + + +def get_logical_type_map() -> dict[int, str]: ... +def get_logical_type(arrow_type: DataType) -> str: ... +def get_numpy_logical_type_map() -> dict[type[np.generic], str]: ... +def get_logical_type_from_numpy(pandas_collection) -> str: ... +def get_extension_dtype_info(column) -> tuple[str, dict[str, Any]]: ... + + +class _ColumnMetadata(TypedDict): + name: str + field_name: str + pandas_type: str + numpy_type: str + metadata: dict | None + + +def get_column_metadata( + column: pd.Series | pd.Index, name: str, arrow_type: DataType, field_name: str +) -> _ColumnMetadata: ... + + +def construct_metadata( + columns_to_convert: list[pd.Series], + df: pd.DataFrame, + column_names: list[str], + index_levels: list[pd.Index], + index_descriptors: list[dict], + preserve_index: bool, + types: list[DataType], + column_field_names: list[str] = ..., +) -> dict[bytes, bytes]: ... + + +def dataframe_to_types( + df: pd.DataFrame, preserve_index: bool | None, columns: list[str] | None = None +) -> tuple[list[str], list[DataType], dict[bytes, bytes]]: ... + + +def dataframe_to_arrays( + df: pd.DataFrame, + schema: Schema, + preserve_index: bool | None, + nthreads: int = 1, + columns: list[str] | None = None, + safe: bool = True, +) -> tuple[list[Array], Schema, int]: ... +def get_datetimetz_type(values: _T, dtype, type_) -> tuple[_T, DataType]: ... +def make_datetimetz(unit: str, tz: str) -> DatetimeTZDtype: ... + + +def table_to_dataframe( + options, + table: Table, + categories=None, + ignore_metadata: bool = False, + types_mapper=None) -> pd.DataFrame: ... + + +def make_tz_aware(series: pd.Series, tz: str) -> pd.Series: ...
+ + +__all__ = [ + "_pandas_api", +] diff --git a/python/pyarrow-stubs/pyarrow/pandas_shim.pyi b/python/pyarrow-stubs/pyarrow/pandas_shim.pyi new file mode 100644 index 000000000000..181d78e7a0c9 --- /dev/null +++ b/python/pyarrow-stubs/pyarrow/pandas_shim.pyi @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import types as stdlib_types +from collections.abc import Iterable +from typing import Any, TypeGuard + +from pandas import Categorical, DatetimeTZDtype, Index, Series, DataFrame + +from numpy import dtype +from pandas.core.dtypes.base import ExtensionDtype + + +class _PandasAPIShim: + has_sparse: bool + + def series(self, *args, **kwargs) -> Series: ... + def data_frame(self, *args, **kwargs) -> DataFrame: ... + @property + def have_pandas(self) -> bool: ... + @property + def compat(self) -> stdlib_types.ModuleType: ... + @property + def pd(self) -> stdlib_types.ModuleType: ... + def infer_dtype(self, obj: Iterable) -> str: ... + def pandas_dtype(self, dtype: str) -> dtype: ... + @property + def loose_version(self) -> Any: ... + @property + def version(self) -> str: ... + def is_v1(self) -> bool: ... + def is_ge_v21(self) -> bool: ... + def is_ge_v23(self) -> bool: ... + def is_ge_v3(self) -> bool: ... 
+ def uses_string_dtype(self) -> bool: ... + @property + def categorical_type(self) -> type[Categorical]: ... + @property + def datetimetz_type(self) -> type[DatetimeTZDtype]: ... + @property + def extension_dtype(self) -> type[ExtensionDtype]: ... + + def is_array_like( + self, obj: Any + ) -> TypeGuard[Series | Index | Categorical | ExtensionDtype]: ... + def is_categorical(self, obj: Any) -> TypeGuard[Categorical]: ... + def is_datetimetz(self, obj: Any) -> TypeGuard[DatetimeTZDtype]: ... + def is_extension_array_dtype(self, obj: Any) -> TypeGuard[ExtensionDtype]: ... + def is_sparse(self, obj: Any) -> bool: ... + def is_data_frame(self, obj: Any) -> TypeGuard[DataFrame]: ... + def is_series(self, obj: Any) -> TypeGuard[Series]: ... + def is_index(self, obj: Any) -> TypeGuard[Index]: ... + def get_values(self, obj: Any) -> Any: ... + def get_rangeindex_attribute(self, level, name): ... + + +_pandas_api: _PandasAPIShim + +__all__ = ["_PandasAPIShim", "_pandas_api"] diff --git a/python/pyarrow-stubs/pyarrow/substrait.pyi b/python/pyarrow-stubs/pyarrow/substrait.pyi new file mode 100644 index 000000000000..b78bbd8aebd7 --- /dev/null +++ b/python/pyarrow-stubs/pyarrow/substrait.pyi @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.
See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyarrow._substrait import ( + BoundExpressions, + SubstraitSchema, + deserialize_expressions, + deserialize_schema, + get_supported_functions, + run_query, + serialize_expressions, + serialize_schema, +) + +__all__ = [ + "BoundExpressions", + "get_supported_functions", + "run_query", + "deserialize_expressions", + "serialize_expressions", + "deserialize_schema", + "serialize_schema", + "SubstraitSchema", +] diff --git a/python/pyarrow-stubs/pyarrow/util.pyi b/python/pyarrow-stubs/pyarrow/util.pyi new file mode 100644 index 000000000000..c3317960c81c --- /dev/null +++ b/python/pyarrow-stubs/pyarrow/util.pyi @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from collections.abc import Callable, Sequence +from os import PathLike +from typing import Any, Protocol, TypeVar + +_F = TypeVar("_F", bound=Callable) +_N = TypeVar("_N") + + +class _DocStringComponents(Protocol): + _docstring_components: list[str] + + +def doc( + *docstrings: str | _DocStringComponents | Callable | None, **params: Any +) -> Callable[[_F], _F]: ... +def _is_iterable(obj) -> bool: ... +def _is_path_like(path) -> bool: ... 
+def _stringify_path(path: str | PathLike) -> str: ... +def product(seq: Sequence[_N]) -> _N: ... + + +def get_contiguous_span( + shape: tuple[int, ...], strides: tuple[int, ...], itemsize: int +) -> tuple[int, int]: ... +def find_free_port() -> int: ... +def guid() -> str: ... +def _download_urllib(url, out_path) -> None: ... +def _download_requests(url, out_path) -> None: ... +def download_tzdata_on_windows() -> None: ... +def _deprecate_api(old_name, new_name, api, next_version, type=...): ... +def _deprecate_class(old_name, new_class, next_version, instancecheck=True): ... +def _break_traceback_cycle_from_frame(frame) -> None: ... diff --git a/python/pyarrow/benchmark.py b/python/pyarrow/benchmark.py index 25ee1141f08d..0ee9063a9a76 100644 --- a/python/pyarrow/benchmark.py +++ b/python/pyarrow/benchmark.py @@ -18,4 +18,4 @@ # flake8: noqa -from pyarrow.lib import benchmark_PandasObjectIsNull +from pyarrow.lib import benchmark_PandasObjectIsNull # type: ignore[attr-defined] diff --git a/python/pyarrow/cffi.py b/python/pyarrow/cffi.py index 1da1a9169140..e5a1c9c1d072 100644 --- a/python/pyarrow/cffi.py +++ b/python/pyarrow/cffi.py @@ -16,8 +16,15 @@ # under the License. 
from __future__ import absolute_import +from typing import TYPE_CHECKING -import cffi +if TYPE_CHECKING: + import cffi +else: + try: + import cffi + except ImportError: + pass c_source = """ struct ArrowSchema { diff --git a/python/pyarrow/cuda.py b/python/pyarrow/cuda.py index 18c530d4afe4..eeb637f0ab41 100644 --- a/python/pyarrow/cuda.py +++ b/python/pyarrow/cuda.py @@ -18,7 +18,7 @@ # flake8: noqa -from pyarrow._cuda import (Context, IpcMemHandle, CudaBuffer, +from pyarrow._cuda import (Context, IpcMemHandle, CudaBuffer, # type: ignore[reportMissingModuleSource] HostBuffer, BufferReader, BufferWriter, new_host_buffer, serialize_record_batch, read_message, diff --git a/python/pyarrow/flight.py b/python/pyarrow/flight.py index b1836907c674..ba5008c9ecf7 100644 --- a/python/pyarrow/flight.py +++ b/python/pyarrow/flight.py @@ -16,7 +16,7 @@ # under the License. try: - from pyarrow._flight import ( # noqa:F401 + from pyarrow._flight import ( # type: ignore[import-not-found] # noqa:F401 connect, Action, ActionType, diff --git a/python/pyarrow/pandas_compat.py b/python/pyarrow/pandas_compat.py index dfca59cbf5f9..b9086ce4e86b 100644 --- a/python/pyarrow/pandas_compat.py +++ b/python/pyarrow/pandas_compat.py @@ -33,18 +33,18 @@ try: import numpy as np except ImportError: - np = None + pass import pyarrow as pa from pyarrow.lib import _pandas_api, frombytes, is_threading_enabled # noqa -_logical_type_map = {} -_numpy_logical_type_map = {} -_pandas_logical_type_map = {} +_logical_type_map: dict[int, str] = {} +_numpy_logical_type_map: dict[type, str] = {} +_pandas_logical_type_map: dict[str, str] = {} def get_logical_type_map(): - global _logical_type_map + global _logical_type_map # noqa: F824 if not _logical_type_map: _logical_type_map.update({ @@ -90,9 +90,9 @@ def get_logical_type(arrow_type): def get_numpy_logical_type_map(): - global _numpy_logical_type_map + global _numpy_logical_type_map # noqa: F824 if not _numpy_logical_type_map: -
_numpy_logical_type_map.update({ + _numpy_logical_type_map.update({ # type: ignore[reportCallIssue] np.bool_: 'bool', np.int8: 'int8', np.int16: 'int16', @@ -704,7 +704,7 @@ def get_datetimetz_type(values, dtype, type_): # If no user type passed, construct a tz-aware timestamp type tz = dtype.tz unit = dtype.unit - type_ = pa.timestamp(unit, tz) + type_ = pa.timestamp(unit, tz) # type: ignore[reportArgumentType] elif type_ is None: # Trust the NumPy dtype type_ = pa.from_numpy_dtype(values.dtype) @@ -743,7 +743,7 @@ def _reconstruct_block(item, columns=None, extension_columns=None, return_block= pandas Block """ - import pandas.core.internals as _int + import pandas.core.internals as _int # type: ignore[import-not-found] block_arr = item.get('block', None) placement = item['placement'] @@ -769,6 +769,8 @@ def _reconstruct_block(item, columns=None, extension_columns=None, return_block= # create ExtensionBlock arr = item['py_array'] assert len(placement) == 1 + assert isinstance(columns, list) + assert isinstance(extension_columns, dict) name = columns[placement[0]] pandas_dtype = extension_columns[name] if not hasattr(pandas_dtype, '__from_arrow__'): @@ -788,7 +790,7 @@ def make_datetimetz(unit, tz): if _pandas_api.is_v1(): unit = 'ns' # ARROW-3789: Coerce date/timestamp types to datetime64[ns] tz = pa.lib.string_to_tzinfo(tz) - return _pandas_api.datetimetz_type(unit, tz=tz) + return _pandas_api.datetimetz_type(unit, tz=tz) # type: ignore[reportArgumentType] def table_to_dataframe( @@ -822,7 +824,8 @@ def table_to_dataframe( result = pa.lib.table_to_blocks(options, table, categories, list(ext_columns_dtypes.keys())) if _pandas_api.is_ge_v3(): - from pandas.api.internals import create_dataframe_from_blocks + from pandas.api.internals import ( # type: ignore[import-not-found] + create_dataframe_from_blocks) blocks = [ _reconstruct_block( @@ -834,7 +837,8 @@ def table_to_dataframe( return df else: - from pandas.core.internals import BlockManager + from 
pandas.core.internals import ( # type: ignore[reportMissingImports] + BlockManager) from pandas import DataFrame blocks = [ @@ -844,7 +848,8 @@ def table_to_dataframe( axes = [columns, index] mgr = BlockManager(blocks, axes) if _pandas_api.is_ge_v21(): - df = DataFrame._from_mgr(mgr, mgr.axes) + df = DataFrame._from_mgr( # type: ignore[reportAttributeAccessIssue] + mgr, mgr.axes) else: df = DataFrame(mgr) @@ -1092,10 +1097,10 @@ def _is_generated_index_name(name): def get_pandas_logical_type_map(): - global _pandas_logical_type_map + global _pandas_logical_type_map # noqa: F824 if not _pandas_logical_type_map: - _pandas_logical_type_map.update({ + _pandas_logical_type_map.update({ # type: ignore[reportCallIssue] 'date': 'datetime64[D]', 'datetime': 'datetime64[ns]', 'datetimetz': 'datetime64[ns]', @@ -1162,12 +1167,14 @@ def _reconstruct_columns_from_metadata(columns, column_indexes): labels = getattr(columns, 'codes', None) or [None] # Convert each level to the dtype provided in the metadata - levels_dtypes = [ - (level, col_index.get('pandas_type', str(level.dtype)), - col_index.get('numpy_type', None)) + levels_dtypes = [(level, col_index.get( + 'pandas_type', + str(level.dtype) # type: ignore[reportAttributeAccessIssue] + ), + col_index.get('numpy_type', None)) for level, col_index in zip_longest( levels, column_indexes, fillvalue={} - ) + ) ] new_levels = [] @@ -1179,7 +1186,7 @@ def _reconstruct_columns_from_metadata(columns, column_indexes): # bytes into unicode strings when json.loads-ing them. We need to # convert them back to bytes to preserve metadata. 
if dtype == np.bytes_: - level = level.map(encoder) + level = level.map(encoder) # type: ignore[reportAttributeAccessIssue] # ARROW-13756: if index is timezone aware DataTimeIndex elif pandas_dtype == "datetimetz": tz = pa.lib.string_to_tzinfo( @@ -1193,7 +1200,8 @@ def _reconstruct_columns_from_metadata(columns, column_indexes): elif pandas_dtype == "decimal": level = _pandas_api.pd.Index([decimal.Decimal(i) for i in level]) elif ( - level.dtype == "str" and numpy_dtype == "object" + level.dtype == "str" # type: ignore[reportAttributeAccessIssue] + and numpy_dtype == "object" and ("mixed" in pandas_dtype or pandas_dtype in ["unicode", "string"]) ): # the metadata indicate that the original dataframe used object dtype, @@ -1206,11 +1214,12 @@ def _reconstruct_columns_from_metadata(columns, column_indexes): # for pandas >= 3 we want to use the default string dtype for .columns new_levels.append(level) continue - elif level.dtype != dtype: - level = level.astype(dtype) + elif level.dtype != dtype: # type: ignore[reportAttributeAccessIssue] + level = level.astype(dtype) # type: ignore[reportAttributeAccessIssue] # ARROW-9096: if original DataFrame was upcast we keep that if level.dtype != numpy_dtype and pandas_dtype != "datetimetz": - level = level.astype(numpy_dtype) + level = level.astype( # type: ignore[reportAttributeAccessIssue] + numpy_dtype) new_levels.append(level) diff --git a/python/pyarrow/tests/interchange/test_conversion.py b/python/pyarrow/tests/interchange/test_conversion.py index 50da6693afff..62da25f0af32 100644 --- a/python/pyarrow/tests/interchange/test_conversion.py +++ b/python/pyarrow/tests/interchange/test_conversion.py @@ -23,7 +23,7 @@ try: import numpy as np except ImportError: - np = None + pass import pyarrow.interchange as pi from pyarrow.interchange.column import ( @@ -163,8 +163,8 @@ def test_pandas_roundtrip_string(): result = pi.from_dataframe(pandas_df) assert result["a"].to_pylist() == table["a"].to_pylist() - assert 
pa.types.is_string(table["a"].type) - assert pa.types.is_large_string(result["a"].type) + assert pa.types.is_string(table.column("a").type) + assert pa.types.is_large_string(result.column("a").type) table_protocol = table.__dataframe__() result_protocol = result.__dataframe__() @@ -193,8 +193,8 @@ def test_pandas_roundtrip_large_string(): result = pi.from_dataframe(pandas_df) assert result["a_large"].to_pylist() == table["a_large"].to_pylist() - assert pa.types.is_large_string(table["a_large"].type) - assert pa.types.is_large_string(result["a_large"].type) + assert pa.types.is_large_string(table.column("a_large").type) + assert pa.types.is_large_string(result.column("a_large").type) table_protocol = table.__dataframe__() result_protocol = result.__dataframe__() @@ -231,12 +231,12 @@ def test_pandas_roundtrip_string_with_missing(): result = pi.from_dataframe(pandas_df) assert result["a"].to_pylist() == table["a"].to_pylist() - assert pa.types.is_string(table["a"].type) - assert pa.types.is_large_string(result["a"].type) + assert pa.types.is_string(table.column("a").type) + assert pa.types.is_large_string(result.column("a").type) assert result["a_large"].to_pylist() == table["a_large"].to_pylist() - assert pa.types.is_large_string(table["a_large"].type) - assert pa.types.is_large_string(result["a_large"].type) + assert pa.types.is_large_string(table.column("a_large").type) + assert pa.types.is_large_string(result.column("a_large").type) else: # older versions of pandas do not have bitmask support # https://github.com/pandas-dev/pandas/issues/49888 @@ -261,12 +261,16 @@ def test_pandas_roundtrip_categorical(): result = pi.from_dataframe(pandas_df) assert result["weekday"].to_pylist() == table["weekday"].to_pylist() - assert pa.types.is_dictionary(table["weekday"].type) - assert pa.types.is_dictionary(result["weekday"].type) - assert pa.types.is_string(table["weekday"].chunk(0).dictionary.type) - assert 
pa.types.is_large_string(result["weekday"].chunk(0).dictionary.type) - assert pa.types.is_int32(table["weekday"].chunk(0).indices.type) - assert pa.types.is_int8(result["weekday"].chunk(0).indices.type) + assert pa.types.is_dictionary(table.column("weekday").type) + assert pa.types.is_dictionary(result.column("weekday").type) + table_chunk_0 = table.column("weekday").chunk(0) + result_chunk_0 = result.column("weekday").chunk(0) + assert isinstance(table_chunk_0, pa.DictionaryArray) + assert isinstance(result_chunk_0, pa.DictionaryArray) + assert pa.types.is_string(table_chunk_0.dictionary.type) + assert pa.types.is_large_string(result_chunk_0.dictionary.type) + assert pa.types.is_int32(table_chunk_0.indices.type) + assert pa.types.is_int8(result_chunk_0.indices.type) table_protocol = table.__dataframe__() result_protocol = result.__dataframe__() @@ -289,6 +293,7 @@ def test_pandas_roundtrip_categorical(): assert desc_cat_table["is_ordered"] == desc_cat_result["is_ordered"] assert desc_cat_table["is_dictionary"] == desc_cat_result["is_dictionary"] + assert desc_cat_result["categories"] is not None assert isinstance(desc_cat_result["categories"]._col, pa.Array) @@ -450,6 +455,7 @@ def test_pyarrow_roundtrip_categorical(offset, length): assert desc_cat_table["is_ordered"] == desc_cat_result["is_ordered"] assert desc_cat_table["is_dictionary"] == desc_cat_result["is_dictionary"] + assert desc_cat_result["categories"] is not None assert isinstance(desc_cat_result["categories"]._col, pa.Array) @@ -464,8 +470,8 @@ def test_pyarrow_roundtrip_large_string(): col = result.__dataframe__().get_column(0) assert col.size() == 3*1024**2 - assert pa.types.is_large_string(table[0].type) - assert pa.types.is_large_string(result[0].type) + assert pa.types.is_large_string(table.column(0).type) + assert pa.types.is_large_string(result.column(0).type) assert table.equals(result) diff --git a/python/pyarrow/tests/interchange/test_interchange_spec.py 
b/python/pyarrow/tests/interchange/test_interchange_spec.py index cea694d1c1ee..3208b56c42df 100644 --- a/python/pyarrow/tests/interchange/test_interchange_spec.py +++ b/python/pyarrow/tests/interchange/test_interchange_spec.py @@ -23,7 +23,7 @@ try: import numpy as np except ImportError: - np = None + pass import pyarrow as pa import pyarrow.tests.strategies as past diff --git a/python/pyarrow/tests/test_cffi.py b/python/pyarrow/tests/test_cffi.py index 481c387d5337..f8abec902694 100644 --- a/python/pyarrow/tests/test_cffi.py +++ b/python/pyarrow/tests/test_cffi.py @@ -24,7 +24,7 @@ try: from pyarrow.cffi import ffi except ImportError: - ffi = None + pass import pytest @@ -32,7 +32,7 @@ import pandas as pd import pandas.testing as tm except ImportError: - pd = tm = None + pd = None # type: ignore[assignment] needs_cffi = pytest.mark.skipif(ffi is None, @@ -148,7 +148,7 @@ def test_export_import_type(): # Invalid format string pa.int32()._export_to_c(ptr_schema) bad_format = ffi.new("char[]", b"zzz") - c_schema.format = bad_format + c_schema.format = bad_format # type: ignore[attr-defined] with pytest.raises(ValueError, match="Invalid or unsupported format string"): pa.DataType._import_from_c(ptr_schema) @@ -248,9 +248,9 @@ def test_export_import_device_array(): arr = pa.array([[1], [2, 42]], type=pa.list_(pa.int32())) arr._export_to_c_device(ptr_array) - assert c_array.device_type == 1 # ARROW_DEVICE_CPU 1 - assert c_array.device_id == -1 - assert c_array.array.length == 2 + assert c_array.device_type == 1 # type: ignore[attr-defined] # ARROW_DEVICE_CPU 1 + assert c_array.device_id == -1 # type: ignore[attr-defined] + assert c_array.array.length == 2 # type: ignore[attr-defined] def check_export_import_schema(schema_factory, expected_schema_factory=None): @@ -310,9 +310,10 @@ def test_export_import_schema_float_pointer(): match = "Passing a pointer value as a float is unsafe" with pytest.warns(UserWarning, match=match): - 
make_schema()._export_to_c(float(ptr_schema)) + make_schema()._export_to_c(float(ptr_schema)) # type: ignore[arg-type] with pytest.warns(UserWarning, match=match): - schema_new = pa.Schema._import_from_c(float(ptr_schema)) + schema_new = pa.Schema._import_from_c( + float(ptr_schema)) # type: ignore[arg-type] assert schema_new == make_schema() @@ -405,9 +406,9 @@ def test_export_import_device_batch(): ptr_array = int(ffi.cast("uintptr_t", c_array)) batch = make_batch() batch._export_to_c_device(ptr_array) - assert c_array.device_type == 1 # ARROW_DEVICE_CPU 1 - assert c_array.device_id == -1 - assert c_array.array.length == 2 + assert c_array.device_type == 1 # type: ignore[attr-defined] # ARROW_DEVICE_CPU 1 + assert c_array.device_id == -1 # type: ignore[attr-defined] + assert c_array.array.length == 2 # type: ignore[attr-defined] def _export_import_batch_reader(ptr_stream, reader_factory): @@ -764,7 +765,7 @@ def test_import_device_no_cuda(): # patch the device type of the struct, this results in an invalid ArrowDeviceArray # but this is just to test we raise am error before actually importing buffers - c_array.device_type = 2 # ARROW_DEVICE_CUDA + c_array.device_type = 2 # type: ignore[attr-defined] # ARROW_DEVICE_CUDA with pytest.raises(ImportError, match="Trying to import data on a CUDA device"): pa.Array._import_from_c_device(ptr_array, arr.type) diff --git a/python/pyarrow/tests/test_cuda.py b/python/pyarrow/tests/test_cuda.py index e06f479987cb..9d03a3bbff2f 100644 --- a/python/pyarrow/tests/test_cuda.py +++ b/python/pyarrow/tests/test_cuda.py @@ -103,6 +103,7 @@ def make_random_buffer(size, target='host'): assert size >= 0 buf = pa.allocate_buffer(size) assert buf.size == size + assert isinstance(buf, pa.Buffer) arr = np.frombuffer(buf, dtype=np.uint8) assert arr.size == size arr[:] = np.random.randint(low=1, high=255, size=size, dtype=np.uint8) @@ -194,12 +195,14 @@ def test_context_device_buffer(size): np.testing.assert_equal(arr[soffset:soffset + ssize], 
arr2) # Creating a device buffer from a slice of an array - cudabuf = global_context.buffer_from_data(arr, offset=soffset, size=ssize) + cudabuf = global_context.buffer_from_data( + arr, offset=soffset, size=ssize) assert cudabuf.size == ssize arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) np.testing.assert_equal(arr[soffset:soffset + ssize], arr2) - cudabuf = global_context.buffer_from_data(arr[soffset:soffset+ssize]) + cudabuf = global_context.buffer_from_data( + arr[soffset:soffset+ssize]) assert cudabuf.size == ssize arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) np.testing.assert_equal(arr[soffset:soffset + ssize], arr2) @@ -235,7 +238,8 @@ def test_context_device_buffer(size): # Creating device buffer from HostBuffer slice - cudabuf = global_context.buffer_from_data(buf, offset=soffset, size=ssize) + cudabuf = global_context.buffer_from_data( + buf, offset=soffset, size=ssize) assert cudabuf.size == ssize arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) np.testing.assert_equal(arr[soffset:soffset+ssize], arr2) @@ -384,7 +388,8 @@ def test_copy_from_to_host(size): device_buffer.copy_from_host(buf, position=0, nbytes=nbytes) # Copy back to host and compare contents - buf2 = device_buffer.copy_to_host(position=0, nbytes=nbytes) + buf2 = device_buffer.copy_to_host( + position=0, nbytes=nbytes) arr2 = np.frombuffer(buf2, dtype=dt) np.testing.assert_equal(arr, arr2) @@ -395,7 +400,8 @@ def test_copy_to_host(size): buf = dbuf.copy_to_host() assert buf.is_cpu - np.testing.assert_equal(arr, np.frombuffer(buf, dtype=np.uint8)) + np.testing.assert_equal(arr, np.frombuffer( + buf, dtype=np.uint8)) buf = dbuf.copy_to_host(position=size//4) assert buf.is_cpu @@ -437,11 +443,13 @@ def test_copy_to_host(size): np.frombuffer(buf, dtype=np.uint8)) dbuf.copy_to_host(buf=buf, nbytes=12) - np.testing.assert_equal(arr[:12], np.frombuffer(buf, dtype=np.uint8)[:12]) + np.testing.assert_equal(arr[:12], np.frombuffer( + buf, 
dtype=np.uint8)[:12]) dbuf.copy_to_host(buf=buf, nbytes=12, position=6) - np.testing.assert_equal(arr[6:6+12], - np.frombuffer(buf, dtype=np.uint8)[:12]) + np.testing.assert_equal( + arr[6:6+12], np.frombuffer(buf, dtype=np.uint8)[:12] + ) for (position, nbytes) in [ (0, size+10), (10, size-5), @@ -450,7 +458,8 @@ def test_copy_to_host(size): with pytest.raises(ValueError, match=('requested copy does not ' 'fit into host buffer')): - dbuf.copy_to_host(buf=buf, position=position, nbytes=nbytes) + dbuf.copy_to_host( + buf=buf, position=position, nbytes=nbytes) @pytest.mark.parametrize("dest_ctx", ['same', 'another']) @@ -460,7 +469,9 @@ def test_copy_from_device(dest_ctx, size): lst = arr.tolist() if dest_ctx == 'another': dest_ctx = global_context1 - if buf.context.device_number == dest_ctx.device_number: + if ( + buf.context.device_number == dest_ctx.device_number + ): pytest.skip("not a multi-GPU system") else: dest_ctx = buf.context @@ -563,7 +574,10 @@ def test_buffer_device(): _, buf = make_random_buffer(size=10, target='device') assert buf.device_type == pa.DeviceAllocationType.CUDA assert isinstance(buf.device, pa.Device) - assert buf.device == global_context.memory_manager.device + assert ( + buf.device == + global_context.memory_manager.device + ) assert isinstance(buf.memory_manager, pa.MemoryManager) assert not buf.is_cpu assert not buf.device.is_cpu @@ -807,8 +821,9 @@ def test_create_table_with_device_buffers(): def other_process_for_test_IPC(handle_buffer, expected_arr): - other_context = pa.cuda.Context(0) - ipc_handle = pa.cuda.IpcMemHandle.from_buffer(handle_buffer) + other_context = cuda.Context(0) + ipc_handle = cuda.IpcMemHandle.from_buffer( + handle_buffer) ipc_buf = other_context.open_ipc_buffer(ipc_handle) ipc_buf.context.synchronize() buf = ipc_buf.copy_to_host() @@ -848,7 +863,8 @@ def test_copy_to(): batch = pa.record_batch({"col": arr}) batch_cuda = batch.copy_to(dest) - buf_cuda = batch_cuda["col"].buffers()[1] + buf_cuda = 
batch_cuda.column("col").buffers()[1] + assert buf_cuda is not None assert not buf_cuda.is_cpu assert buf_cuda.device_type == pa.DeviceAllocationType.CUDA assert buf_cuda.device == mm_cuda.device @@ -949,7 +965,8 @@ def test_device_interface_batch_array(): cbatch._export_to_c_device(ptr_array, ptr_schema) # Delete and recreate C++ objects from exported pointers del cbatch - cbatch_new = pa.RecordBatch._import_from_c_device(ptr_array, ptr_schema) + cbatch_new = pa.RecordBatch._import_from_c_device( + ptr_array, ptr_schema) assert cbatch_new.schema == schema batch_new = cbatch_new.copy_to(pa.default_cpu_memory_manager()) assert batch_new.equals(batch) @@ -957,13 +974,15 @@ def test_device_interface_batch_array(): del cbatch_new # Now released with pytest.raises(ValueError, match="Cannot import released ArrowSchema"): - pa.RecordBatch._import_from_c_device(ptr_array, ptr_schema) + pa.RecordBatch._import_from_c_device( + ptr_array, ptr_schema) # Not a struct type pa.int32()._export_to_c(ptr_schema) with pytest.raises(ValueError, match="ArrowSchema describes non-struct type"): - pa.RecordBatch._import_from_c_device(ptr_array, ptr_schema) + pa.RecordBatch._import_from_c_device( + ptr_array, ptr_schema) def test_print_array(): diff --git a/python/pyarrow/tests/test_cuda_numba_interop.py b/python/pyarrow/tests/test_cuda_numba_interop.py index 876f3c7f761c..4a5bc7975333 100644 --- a/python/pyarrow/tests/test_cuda_numba_interop.py +++ b/python/pyarrow/tests/test_cuda_numba_interop.py @@ -28,7 +28,6 @@ from numba.cuda.cudadrv.devicearray import DeviceNDArray # noqa: E402 - context_choices = None context_choice_ids = ['pyarrow.cuda', 'numba.cuda'] @@ -62,17 +61,19 @@ def test_context(c): def make_random_buffer(size, target='host', dtype='uint8', ctx=None): """Return a host or device buffer with random data. 
""" - dtype = np.dtype(dtype) + assert np is not None + dtype_obj = np.dtype(dtype) if target == 'host': assert size >= 0 - buf = pa.allocate_buffer(size*dtype.itemsize) - arr = np.frombuffer(buf, dtype=dtype) + buf = pa.allocate_buffer(size*dtype_obj.itemsize) + arr = np.frombuffer(buf, dtype=dtype_obj) arr[:] = np.random.randint(low=0, high=255, size=size, dtype=np.uint8) return arr, buf elif target == 'device': arr, buf = make_random_buffer(size, target='host', dtype=dtype) - dbuf = ctx.new_buffer(size * dtype.itemsize) + assert ctx is not None + dbuf = ctx.new_buffer(size * dtype_obj.itemsize) dbuf.copy_from_host(buf, position=0, nbytes=buf.size) return arr, dbuf raise ValueError('invalid target value') @@ -161,8 +162,8 @@ def __cuda_array_interface__(self): ids=context_choice_ids) @pytest.mark.parametrize("dtype", dtypes, ids=dtypes) def test_numba_memalloc(c, dtype): + assert np is not None ctx, nb_ctx = context_choices[c] - dtype = np.dtype(dtype) # Allocate memory using numba context # Warning: this will not be reflected in pyarrow context manager # (e.g bytes_allocated does not change) @@ -198,6 +199,7 @@ def test_pyarrow_memalloc(c, dtype): ids=context_choice_ids) @pytest.mark.parametrize("dtype", dtypes, ids=dtypes) def test_numba_context(c, dtype): + assert np is not None ctx, nb_ctx = context_choices[c] size = 10 with nb_cuda.gpus[0]: @@ -209,7 +211,10 @@ def test_numba_context(c, dtype): np.testing.assert_equal(darr.copy_to_host(), arr) darr[0] = 99 cbuf.context.synchronize() - arr2 = np.frombuffer(cbuf.copy_to_host(), dtype=dtype) + arr2 = np.frombuffer( + cbuf.copy_to_host(), + dtype=np.dtype(dtype) + ) assert arr2[0] == 99 @@ -217,6 +222,7 @@ def test_numba_context(c, dtype): ids=context_choice_ids) @pytest.mark.parametrize("dtype", dtypes, ids=dtypes) def test_pyarrow_jit(c, dtype): + assert np is not None ctx, nb_ctx = context_choices[c] @nb_cuda.jit @@ -234,5 +240,8 @@ def increment_by_one(an_array): darr = DeviceNDArray(arr.shape, arr.strides, 
arr.dtype, gpu_data=mem) increment_by_one[blockspergrid, threadsperblock](darr) cbuf.context.synchronize() - arr1 = np.frombuffer(cbuf.copy_to_host(), dtype=arr.dtype) + arr1 = np.frombuffer( + cbuf.copy_to_host(), + dtype=arr.dtype + ) np.testing.assert_equal(arr1, arr + 1) diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index 9e7bb312398f..1294e681be45 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -28,19 +28,21 @@ import traceback import json from datetime import datetime +from typing import Any try: import numpy as np except ImportError: - np = None + pass import pytest import pyarrow as pa from pyarrow.lib import IpcReadOptions, ReadStats, tobytes from pyarrow.util import find_free_port from pyarrow.tests import util +from typing import TYPE_CHECKING -try: +if TYPE_CHECKING: from pyarrow import flight from pyarrow.flight import ( FlightClient, FlightServerBase, @@ -49,13 +51,26 @@ ClientMiddleware, ClientMiddlewareFactory, FlightCallOptions, ) -except ImportError: - flight = None - FlightClient, FlightServerBase = object, object - ServerAuthHandler, ClientAuthHandler = object, object - ServerMiddleware, ServerMiddlewareFactory = object, object - ClientMiddleware, ClientMiddlewareFactory = object, object - FlightCallOptions = object +else: + try: + from pyarrow import flight + from pyarrow.flight import ( + FlightClient, FlightServerBase, + ServerAuthHandler, ClientAuthHandler, + ServerMiddleware, ServerMiddlewareFactory, + ClientMiddleware, ClientMiddlewareFactory, + FlightCallOptions, + ) + except ImportError: + flight = None # type: ignore[assignment] + FlightClient, FlightServerBase = object, object + ServerAuthHandler, ClientAuthHandler = ( # type: ignore[misc] + object, object) # type: ignore[assignment] + ServerMiddleware, ServerMiddlewareFactory = ( # type: ignore[misc] + object, object) # type: ignore[assignment] + ClientMiddleware, ClientMiddlewareFactory = ( # type: 
ignore[misc] + object, object) # type: ignore[assignment] + # FlightCallOptions = object # type: ignore[assignment, misc] # Marks all of the tests in this module # Ignore these with pytest ... -m 'not flight' @@ -196,7 +211,7 @@ def do_put(self, context, descriptor, reader, writer): assert buf is not None client_counter, = struct.unpack(' 0, - 'datetime[s]': np.arange("2016-01-01T00:00:00.001", size, - dtype='datetime64[s]'), - 'datetime[ms]': np.arange("2016-01-01T00:00:00.001", size, - dtype='datetime64[ms]'), - 'datetime[us]': np.arange("2016-01-01T00:00:00.001", size, - dtype='datetime64[us]'), - 'datetime[ns]': np.arange("2016-01-01T00:00:00.001", size, - dtype='datetime64[ns]'), + 'datetime[s]': pd.date_range("2016-01-01T00:00:00.001", periods=size, freq='s').values, + 'datetime[ms]': pd.date_range("2016-01-01T00:00:00.001", periods=size, freq='ms').values, + 'datetime[us]': pd.date_range("2016-01-01T00:00:00.001", periods=size, freq='us').values, + 'datetime[ns]': pd.date_range("2016-01-01T00:00:00.001", periods=size, freq='ns').values, 'timedelta64[s]': np.arange(0, size, dtype='timedelta64[s]'), 'timedelta64[ms]': np.arange(0, size, dtype='timedelta64[ms]'), 'timedelta64[us]': np.arange(0, size, dtype='timedelta64[us]'), @@ -98,7 +91,7 @@ def _alltypes_example(size=100): def _check_pandas_roundtrip(df, expected=None, use_threads=False, expected_schema=None, check_dtype=True, schema=None, - preserve_index=False, + preserve_index: bool | None = False, as_batch=False): klass = pa.RecordBatch if as_batch else pa.Table table = klass.from_pandas(df, schema=schema, @@ -723,7 +716,7 @@ def test_mismatch_metadata_schema(self): # OPTION 1: casting after conversion table = pa.Table.from_pandas(df) # cast the "datetime" column to be tz-aware - new_col = table["datetime"].cast(pa.timestamp('ns', tz="UTC")) + new_col = table.column(0).cast(pa.timestamp('ns', tz="UTC")) new_table1 = table.set_column( 0, pa.field("datetime", new_col.type), new_col ) @@ -991,7 +984,7 @@ 
def test_float_with_null_as_integer(self): schema = pa.schema([pa.field('has_nulls', ty)]) result = pa.Table.from_pandas(df, schema=schema, preserve_index=False) - assert result[0].chunk(0).equals(expected) + assert result.column(0).chunk(0).equals(expected) def test_int_object_nulls(self): arr = np.array([None, 1, np.int64(3)] * 5, dtype=object) @@ -1153,7 +1146,7 @@ def test_python_datetime(self): }) table = pa.Table.from_pandas(df) - assert isinstance(table[0].chunk(0), pa.TimestampArray) + assert isinstance(table.column(0).chunk(0), pa.TimestampArray) result = table.to_pandas() # Pandas v2 defaults to [ns], but Arrow defaults to [us] time units @@ -1210,7 +1203,7 @@ class MyDatetime(datetime): df = pd.DataFrame({"datetime": pd.Series(date_array, dtype=object)}) table = pa.Table.from_pandas(df) - assert isinstance(table[0].chunk(0), pa.TimestampArray) + assert isinstance(table.column(0).chunk(0), pa.TimestampArray) result = table.to_pandas() @@ -1234,7 +1227,7 @@ class MyDate(date): df = pd.DataFrame({"date": pd.Series(date_array, dtype=object)}) table = pa.Table.from_pandas(df) - assert isinstance(table[0].chunk(0), pa.Date32Array) + assert isinstance(table.column(0).chunk(0), pa.Date32Array) result = table.to_pandas() expected_df = pd.DataFrame( @@ -1746,7 +1739,7 @@ def test_bytes_to_binary(self): df = pd.DataFrame({'strings': values}) table = pa.Table.from_pandas(df) - assert table[0].type == pa.binary() + assert table.column(0).type == pa.binary() values2 = [b'qux', b'foo', None, b'barz', b'qux', None] expected = pd.DataFrame({'strings': values2}) @@ -1767,7 +1760,7 @@ def test_bytes_exceed_2gb(self): arr = None table = pa.Table.from_pandas(df) - assert table[0].num_chunks == 2 + assert table.column(0).num_chunks == 2 @pytest.mark.large_memory @pytest.mark.parametrize('char', ['x', b'x']) @@ -1909,13 +1902,13 @@ def test_table_str_to_categorical_without_na(self, string_type): zero_copy_only=True) # chunked array - result = 
table["strings"].to_pandas(strings_to_categorical=True) + result = table.column("strings").to_pandas(strings_to_categorical=True) expected = pd.Series(pd.Categorical(values), name="strings") tm.assert_series_equal(result, expected) with pytest.raises(pa.ArrowInvalid): - table["strings"].to_pandas(strings_to_categorical=True, - zero_copy_only=True) + table.column("strings").to_pandas(strings_to_categorical=True, + zero_copy_only=True) @pytest.mark.parametrize( "string_type", [pa.string(), pa.large_string(), pa.string_view()] @@ -1936,13 +1929,13 @@ def test_table_str_to_categorical_with_na(self, string_type): zero_copy_only=True) # chunked array - result = table["strings"].to_pandas(strings_to_categorical=True) + result = table.column("strings").to_pandas(strings_to_categorical=True) expected = pd.Series(pd.Categorical(values), name="strings") tm.assert_series_equal(result, expected) with pytest.raises(pa.ArrowInvalid): - table["strings"].to_pandas(strings_to_categorical=True, - zero_copy_only=True) + table.column("strings").to_pandas(strings_to_categorical=True, + zero_copy_only=True) # Regression test for ARROW-2101 def test_array_of_bytes_to_strings(self): @@ -2524,7 +2517,7 @@ def test_auto_chunking_on_list_overflow(self): table = pa.Table.from_pandas(df) table.validate(full=True) - column_a = table[0] + column_a = table.column(0) assert column_a.num_chunks == 2 assert len(column_a.chunk(0)) == 2**21 - 1 assert len(column_a.chunk(1)) == 1 @@ -3168,9 +3161,8 @@ def test_strided_data_import(self): boolean_objects[5] = None cases.append(boolean_objects) - cases.append(np.arange("2016-01-01T00:00:00.001", N * K, - dtype='datetime64[ms]') - .reshape(N, K).copy()) + cases.append(pd.date_range("2016-01-01T00:00:00.001", periods=N * K, freq='ms') + .values.reshape(N, K).copy()) strided_mask = (random_numbers > 0).astype(bool)[:, 0] @@ -3776,8 +3768,8 @@ def test_recordbatchlist_to_pandas(): def test_recordbatch_table_pass_name_to_pandas(): rb = 
pa.record_batch([pa.array([1, 2, 3, 4])], names=['a0']) t = pa.table([pa.array([1, 2, 3, 4])], names=['a0']) - assert rb[0].to_pandas().name == 'a0' - assert t[0].to_pandas().name == 'a0' + assert rb.column(0).to_pandas().name == 'a0' + assert t.column(0).to_pandas().name == 'a0' # ---------------------------------------------------------------------- @@ -4331,13 +4323,13 @@ def test_array_protocol(): # default conversion result = pa.table(df) expected = pa.array([1, 2, None], pa.int64()) - assert result[0].chunk(0).equals(expected) + assert result.column(0).chunk(0).equals(expected) # with specifying schema schema = pa.schema([('a', pa.float64())]) result = pa.table(df, schema=schema) expected2 = pa.array([1, 2, None], pa.float64()) - assert result[0].chunk(0).equals(expected2) + assert result.column(0).chunk(0).equals(expected2) # pass Series to pa.array result = pa.array(df['a']) @@ -4467,7 +4459,7 @@ def __init__(self): def __arrow_ext_serialize__(self): return b'' - def to_pandas_dtype(self): + def to_pandas_dtype(self): # type: ignore[override] return pd.Int64Dtype() @@ -4567,7 +4559,7 @@ def test_array_to_pandas(): expected = pd.Series(arr) tm.assert_series_equal(result, expected) - result = pa.table({"col": arr})["col"].to_pandas() + result = pa.table({"col": arr}).column("col").to_pandas() expected = pd.Series(arr, name="col") tm.assert_series_equal(result, expected) @@ -4626,7 +4618,6 @@ def test_array_to_pandas_types_mapper(): assert result.dtype == np.dtype("int64") -@pytest.mark.pandas def test_chunked_array_to_pandas_types_mapper(): # https://issues.apache.org/jira/browse/ARROW-9664 if Version(pd.__version__) < Version("1.2.0"): @@ -5117,7 +5108,7 @@ def test_roundtrip_nested_map_array_with_pydicts_sliced(): ty = pa.list_(pa.map_(pa.string(), pa.list_(pa.string()))) - def assert_roundtrip(series: pd.Series, data) -> None: + def assert_roundtrip(series, data): array_roundtrip = pa.chunked_array(pa.Array.from_pandas(series, type=ty)) 
array_roundtrip.validate(full=True) assert data.equals(array_roundtrip) diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index fcd1c8d48c5f..9ad65f0738d9 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -25,13 +25,10 @@ from pyarrow.lib import tobytes from pyarrow.lib import ArrowInvalid, ArrowNotImplementedError -try: - import pyarrow.substrait as substrait -except ImportError: - substrait = None - # Marks all of the tests in this module # Ignore these with pytest ... -m 'not substrait' +substrait = pytest.importorskip('pyarrow.substrait') +_substrait = pytest.importorskip('pyarrow._substrait') pytestmark = pytest.mark.substrait @@ -85,7 +82,7 @@ def test_run_serialized_query(tmpdir, use_threads): query = tobytes(substrait_query.replace( "FILENAME_PLACEHOLDER", pathlib.Path(path).as_uri())) - buf = pa._substrait._parse_json_plan(query) + buf = _substrait._parse_json_plan(query) reader = substrait.run_query(buf, use_threads=use_threads) res_tb = reader.read_all() @@ -116,7 +113,7 @@ def test_invalid_plan(): ] } """ - buf = pa._substrait._parse_json_plan(tobytes(query)) + buf = _substrait._parse_json_plan(tobytes(query)) exec_message = "Plan has no relations" with pytest.raises(ArrowInvalid, match=exec_message): substrait.run_query(buf) @@ -162,7 +159,7 @@ def test_binary_conversion_with_json_options(tmpdir, use_threads): path = _write_dummy_data_to_disk(tmpdir, file_name, table) query = tobytes(substrait_query.replace( "FILENAME_PLACEHOLDER", pathlib.Path(path).as_uri())) - buf = pa._substrait._parse_json_plan(tobytes(query)) + buf = _substrait._parse_json_plan(tobytes(query)) reader = substrait.run_query(buf, use_threads=use_threads) res_tb = reader.read_all() @@ -181,7 +178,7 @@ def has_function(fns, ext_file, fn_name): def test_get_supported_functions(): - supported_functions = pa._substrait.get_supported_functions() + supported_functions = 
_substrait.get_supported_functions() # It probably doesn't make sense to exhaustively verify this list but # we can check a sample aggregate and a sample non-aggregate entry assert has_function(supported_functions, @@ -232,8 +229,8 @@ def table_provider(names, schema): } """ - buf = pa._substrait._parse_json_plan(tobytes(substrait_query)) - reader = pa.substrait.run_query( + buf = _substrait._parse_json_plan(tobytes(substrait_query)) + reader = substrait.run_query( buf, table_provider=table_provider, use_threads=use_threads) res_tb = reader.read_all() assert res_tb == test_table_1 @@ -275,7 +272,7 @@ def table_provider(names, _): } """ - buf = pa._substrait._parse_json_plan(tobytes(substrait_query)) + buf = _substrait._parse_json_plan(tobytes(substrait_query)) exec_message = "Invalid NamedTable Source" with pytest.raises(ArrowInvalid, match=exec_message): substrait.run_query(buf, table_provider=table_provider) @@ -317,7 +314,7 @@ def table_provider(names, _): } """ query = tobytes(substrait_query) - buf = pa._substrait._parse_json_plan(tobytes(query)) + buf = _substrait._parse_json_plan(tobytes(query)) exec_message = "names for NamedTable not provided" with pytest.raises(ArrowInvalid, match=exec_message): substrait.run_query(buf, table_provider=table_provider) @@ -436,8 +433,8 @@ def table_provider(names, _): } """ - buf = pa._substrait._parse_json_plan(substrait_query) - reader = pa.substrait.run_query( + buf = _substrait._parse_json_plan(substrait_query) + reader = substrait.run_query( buf, table_provider=table_provider, use_threads=use_threads) res_tb = reader.read_all() @@ -559,9 +556,9 @@ def table_provider(names, _): } """ - buf = pa._substrait._parse_json_plan(substrait_query) + buf = _substrait._parse_json_plan(substrait_query) with pytest.raises(pa.ArrowKeyError) as excinfo: - pa.substrait.run_query(buf, table_provider=table_provider) + substrait.run_query(buf, table_provider=table_provider) assert "No function registered" in str(excinfo.value) @@ -598,8 
+595,8 @@ def table_provider(names, schema): } """ - buf = pa._substrait._parse_json_plan(tobytes(substrait_query)) - reader = pa.substrait.run_query( + buf = _substrait._parse_json_plan(tobytes(substrait_query)) + reader = substrait.run_query( buf, table_provider=table_provider, use_threads=use_threads) res_tb = reader.read_all() @@ -744,8 +741,8 @@ def table_provider(names, _): ], } """ - buf = pa._substrait._parse_json_plan(substrait_query) - reader = pa.substrait.run_query( + buf = _substrait._parse_json_plan(substrait_query) + reader = substrait.run_query( buf, table_provider=table_provider, use_threads=False) res_tb = reader.read_all() @@ -913,8 +910,8 @@ def table_provider(names, _): ], } """ - buf = pa._substrait._parse_json_plan(substrait_query) - reader = pa.substrait.run_query( + buf = _substrait._parse_json_plan(substrait_query) + reader = substrait.run_query( buf, table_provider=table_provider, use_threads=False) res_tb = reader.read_all() @@ -929,8 +926,8 @@ def table_provider(names, _): @pytest.mark.parametrize("expr", [ - pc.equal(pc.field("x"), 7), - pc.equal(pc.field("x"), pc.field("y")), + pc.equal(pc.field("x"), 7), # type: ignore[attr-defined] + pc.equal(pc.field("x"), pc.field("y")), # type: ignore[attr-defined] pc.field("x") > 50 ]) def test_serializing_expressions(expr): @@ -939,8 +936,8 @@ def test_serializing_expressions(expr): pa.field("y", pa.int32()) ]) - buf = pa.substrait.serialize_expressions([expr], ["test_expr"], schema) - returned = pa.substrait.deserialize_expressions(buf) + buf = substrait.serialize_expressions([expr], ["test_expr"], schema) + returned = substrait.deserialize_expressions(buf) assert schema == returned.schema assert len(returned.expressions) == 1 assert "test_expr" in returned.expressions @@ -958,8 +955,8 @@ def test_arrow_specific_types(): schema = pa.schema([pa.field(name, typ) for name, (typ, _) in fields.items()]) def check_round_trip(expr): - buf = pa.substrait.serialize_expressions([expr], ["test_expr"], 
schema) - returned = pa.substrait.deserialize_expressions(buf) + buf = substrait.serialize_expressions([expr], ["test_expr"], schema) + returned = substrait.deserialize_expressions(buf) assert schema == returned.schema for name, (typ, val) in fields.items(): @@ -986,8 +983,8 @@ def test_arrow_one_way_types(): def check_one_way(field): expr = pc.is_null(pc.field(field.name)) - buf = pa.substrait.serialize_expressions([expr], ["test_expr"], schema) - returned = pa.substrait.deserialize_expressions(buf) + buf = substrait.serialize_expressions([expr], ["test_expr"], schema) + returned = substrait.deserialize_expressions(buf) assert alt_schema == returned.schema for field in schema: @@ -1003,14 +1000,14 @@ def test_invalid_expression_ser_des(): bad_expr = pc.equal(pc.field("z"), 7) # Invalid number of names with pytest.raises(ValueError) as excinfo: - pa.substrait.serialize_expressions([expr], [], schema) + substrait.serialize_expressions([expr], [], schema) assert 'need to have the same length' in str(excinfo.value) with pytest.raises(ValueError) as excinfo: - pa.substrait.serialize_expressions([expr], ["foo", "bar"], schema) + substrait.serialize_expressions([expr], ["foo", "bar"], schema) assert 'need to have the same length' in str(excinfo.value) # Expression doesn't match schema with pytest.raises(ValueError) as excinfo: - pa.substrait.serialize_expressions([bad_expr], ["expr"], schema) + substrait.serialize_expressions([bad_expr], ["expr"], schema) assert 'No match for FieldRef' in str(excinfo.value) @@ -1020,8 +1017,8 @@ def test_serializing_multiple_expressions(): pa.field("y", pa.int32()) ]) exprs = [pc.equal(pc.field("x"), 7), pc.equal(pc.field("x"), pc.field("y"))] - buf = pa.substrait.serialize_expressions(exprs, ["first", "second"], schema) - returned = pa.substrait.deserialize_expressions(buf) + buf = substrait.serialize_expressions(exprs, ["first", "second"], schema) + returned = substrait.deserialize_expressions(buf) assert schema == returned.schema 
assert len(returned.expressions) == 2 @@ -1037,8 +1034,8 @@ def test_serializing_with_compute(): ]) expr = pc.equal(pc.field("x"), 7) expr_norm = pc.equal(pc.field(0), 7) - buf = expr.to_substrait(schema) - returned = pa.substrait.deserialize_expressions(buf) + buf = expr.to_substrait(schema) # type: ignore[union-attr] + returned = substrait.deserialize_expressions(buf) assert schema == returned.schema assert len(returned.expressions) == 1 @@ -1046,13 +1043,13 @@ def test_serializing_with_compute(): assert str(returned.expressions["expression"]) == str(expr_norm) # Compute can't deserialize messages with multiple expressions - buf = pa.substrait.serialize_expressions([expr, expr], ["first", "second"], schema) + buf = substrait.serialize_expressions([expr, expr], ["first", "second"], schema) with pytest.raises(ValueError) as excinfo: pc.Expression.from_substrait(buf) assert 'contained multiple expressions' in str(excinfo.value) # Deserialization should be possible regardless of the expression name - buf = pa.substrait.serialize_expressions([expr], ["weirdname"], schema) + buf = substrait.serialize_expressions([expr], ["weirdname"], schema) expr2 = pc.Expression.from_substrait(buf) assert str(expr2) == str(expr_norm) @@ -1069,11 +1066,11 @@ def test_serializing_udfs(): exprs = [pc.shift_left(a, b)] with pytest.raises(ArrowNotImplementedError): - pa.substrait.serialize_expressions(exprs, ["expr"], schema) + substrait.serialize_expressions(exprs, ["expr"], schema) - buf = pa.substrait.serialize_expressions( + buf = substrait.serialize_expressions( exprs, ["expr"], schema, allow_arrow_extensions=True) - returned = pa.substrait.deserialize_expressions(buf) + returned = substrait.deserialize_expressions(buf) assert schema == returned.schema assert len(returned.expressions) == 1 assert str(returned.expressions["expr"]) == str(exprs[0]) @@ -1085,19 +1082,19 @@ def test_serializing_schema(): pa.field("x", pa.int32()), pa.field("y", pa.string()) ]) - returned = 
pa.substrait.deserialize_schema(substrait_schema) + returned = substrait.deserialize_schema(substrait_schema) assert expected_schema == returned - arrow_substrait_schema = pa.substrait.serialize_schema(returned) + arrow_substrait_schema = substrait.serialize_schema(returned) assert arrow_substrait_schema.schema == substrait_schema - returned = pa.substrait.deserialize_schema(arrow_substrait_schema) + returned = substrait.deserialize_schema(arrow_substrait_schema) assert expected_schema == returned - returned = pa.substrait.deserialize_schema(arrow_substrait_schema.schema) + returned = substrait.deserialize_schema(arrow_substrait_schema.schema) assert expected_schema == returned - returned = pa.substrait.deserialize_expressions(arrow_substrait_schema.expression) + returned = substrait.deserialize_expressions(arrow_substrait_schema.expression) assert returned.schema == expected_schema @@ -1114,7 +1111,7 @@ def SerializeToString(self): b'\x1a\x19\n\x06\x12\x04\n\x02\x12\x00\x1a\x0fproject_version' b'"0\n\x0fproject_version\n\x0fproject_release' b'\x12\x0c\n\x04:\x02\x10\x01\n\x04b\x02\x10\x01') - exprs = pa.substrait.BoundExpressions.from_substrait(FakeMessage(message)) + exprs = substrait.BoundExpressions.from_substrait(FakeMessage(message)) assert len(exprs.expressions) == 2 assert 'project_release' in exprs.expressions assert 'project_version' in exprs.expressions