From 3288afd8d933e6f899aae25c2028a05d0ffe84c1 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Thu, 20 Nov 2025 10:42:08 -0700 Subject: [PATCH] feat(typing): make types more precise for many common Values --- ibis/expr/types/generic.py | 82 +++++++- ibis/expr/types/logical.py | 40 +++- ibis/expr/types/numeric.py | 371 ++++++++++++++++++++++++++++++++----- ibis/expr/types/strings.py | 166 +++++++++++++++-- 4 files changed, 576 insertions(+), 83 deletions(-) diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 7257feab5834..c1c4fde2c396 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -42,7 +42,7 @@ class Value(Expr): """Base class for a data generating expression having a known type.""" - def name(self, name: str, /) -> Value: + def name(self, name: str, /) -> Self: """Rename an expression to `name`. Parameters @@ -121,6 +121,10 @@ def type(self) -> dt.DataType: """ return self.op().dtype + @overload + def hash(self: Scalar) -> ir.IntegerScalar: ... + @overload + def hash(self: Column) -> ir.IntegerColumn: ... def hash(self) -> ir.IntegerValue: """Compute an integer hash value. @@ -297,6 +301,10 @@ def cast(self, target_type: Any, /) -> Value: return op.to_expr() + @overload + def try_cast(self: Scalar, target_type: Any, /) -> Scalar: ... + @overload + def try_cast(self: Column, target_type: Any, /) -> Column: ... def try_cast(self, target_type: Any, /) -> Value: """Try cast expression to indicated data type. @@ -383,6 +391,10 @@ def coalesce(self, /, *args: Value) -> Value: """ return ops.Coalesce((self, *args)).to_expr() + @overload + def typeof(self: Scalar) -> ir.StringScalar: ... + @overload + def typeof(self: Column) -> ir.StringColumn: ... def typeof(self) -> ir.StringValue: """Return the string name of the datatype of self. @@ -923,6 +935,10 @@ def bind(table): except com.IbisInputError: return bind(_) + @overload + def isnull(self: Scalar) -> ir.BooleanScalar: ... + @overload + def isnull(self: Column) -> ir.BooleanColumn: ... def isnull(self) -> ir.BooleanValue: """Whether this expression is `NULL`. Does NOT detect `NaN` and `inf` values. @@ -960,6 +976,10 @@ def isnull(self) -> ir.BooleanValue: """ return ops.IsNull(self).to_expr() + @overload + def notnull(self: Scalar) -> ir.BooleanScalar: ... + @overload + def notnull(self: Column) -> ir.BooleanColumn: ... def notnull(self) -> ir.BooleanValue: """Return whether this expression is not NULL. @@ -1291,6 +1311,14 @@ def group_concat( def __hash__(self) -> int: return super().__hash__() + @overload + def __eq__(self: Scalar, other: Scalar) -> ir.BooleanScalar: ... + @overload + def __eq__(self: Scalar, other: Column) -> ir.BooleanColumn: ... + @overload + def __eq__(self: Column, other: Scalar) -> ir.BooleanColumn: ... + @overload + def __eq__(self: Column, other: Column) -> ir.BooleanColumn: ... def __eq__(self, other: Value) -> ir.BooleanValue: if _is_null_literal(other): return self.isnull() @@ -1298,6 +1326,14 @@ def __eq__(self, other: Value) -> ir.BooleanValue: return other.isnull() return _binop(ops.Equals, self, other) + @overload + def __ne__(self: Scalar, other: Scalar) -> ir.BooleanScalar: ... + @overload + def __ne__(self: Scalar, other: Column) -> ir.BooleanColumn: ... + @overload + def __ne__(self: Column, other: Scalar) -> ir.BooleanColumn: ... + @overload + def __ne__(self: Column, other: Column) -> ir.BooleanColumn: ... def __ne__(self, other: Value) -> ir.BooleanValue: if _is_null_literal(other): return self.notnull() @@ -1305,15 +1341,47 @@ def __ne__(self, other: Value) -> ir.BooleanValue: return other.notnull() return _binop(ops.NotEquals, self, other) + @overload + def __ge__(self: Scalar, other: Scalar) -> ir.BooleanScalar: ... + @overload + def __ge__(self: Scalar, other: Column) -> ir.BooleanColumn: ... + @overload + def __ge__(self: Column, other: Scalar) -> ir.BooleanColumn: ... + @overload + def __ge__(self: Column, other: Column) -> ir.BooleanColumn: ... def __ge__(self, other: Value) -> ir.BooleanValue: return _binop(ops.GreaterEqual, self, other) + @overload + def __gt__(self: Scalar, other: Scalar) -> ir.BooleanScalar: ... + @overload + def __gt__(self: Scalar, other: Column) -> ir.BooleanColumn: ... + @overload + def __gt__(self: Column, other: Scalar) -> ir.BooleanColumn: ... + @overload + def __gt__(self: Column, other: Column) -> ir.BooleanColumn: ... def __gt__(self, other: Value) -> ir.BooleanValue: return _binop(ops.Greater, self, other) + @overload + def __le__(self: Scalar, other: Scalar) -> ir.BooleanScalar: ... + @overload + def __le__(self: Scalar, other: Column) -> ir.BooleanColumn: ... + @overload + def __le__(self: Column, other: Scalar) -> ir.BooleanColumn: ... + @overload + def __le__(self: Column, other: Column) -> ir.BooleanColumn: ... def __le__(self, other: Value) -> ir.BooleanValue: return _binop(ops.LessEqual, self, other) + @overload + def __lt__(self: Scalar, other: Scalar) -> ir.BooleanScalar: ... + @overload + def __lt__(self: Scalar, other: Column) -> ir.BooleanColumn: ... + @overload + def __lt__(self: Column, other: Scalar) -> ir.BooleanColumn: ... + @overload + def __lt__(self: Column, other: Column) -> ir.BooleanColumn: ... def __lt__(self, other: Value) -> ir.BooleanValue: return _binop(ops.Less, self, other) @@ -1484,7 +1552,7 @@ def __polars_result__(self, df: pl.DataFrame) -> Any: return PolarsData.convert_scalar(df, self.type()) - def as_scalar(self): + def as_scalar(self) -> Self: """Inform ibis that the expression should be treated as a scalar. If the expression is a literal, it will be returned as is. If it depends @@ -2461,7 +2529,7 @@ def first( where: ir.BooleanValue | None = None, order_by: Any = None, include_null: bool = False, - ) -> Value: + ) -> Scalar: """Return the first value of a column. Parameters @@ -2515,7 +2583,7 @@ def last( where: ir.BooleanValue | None = None, order_by: Any = None, include_null: bool = False, - ) -> Value: + ) -> Scalar: """Return the last value of a column. Parameters @@ -2627,7 +2695,7 @@ def dense_rank(self) -> ir.IntegerColumn: """ return ibis.dense_rank().over(order_by=self) - def percent_rank(self) -> Column: + def percent_rank(self) -> ir.IntegerColumn: """Return the relative rank of the values in the column. Examples @@ -2651,7 +2719,7 @@ def percent_rank(self) -> Column: """ return ibis.percent_rank().over(order_by=self) - def cume_dist(self) -> Column: + def cume_dist(self) -> ir.FloatingColumn: """Return the cumulative distribution over a window. Examples @@ -2995,7 +3063,7 @@ class NullColumn(Column, NullValue): @public @deferrable -def null(type: dt.DataType | str | None = None, /) -> Value: +def null(type: dt.DataType | str | None = None, /) -> Scalar: """Create a NULL scalar. `NULL`s with an unspecified type are castable and comparable to values, diff --git a/ibis/expr/types/logical.py b/ibis/expr/types/logical.py index b921621187c2..b40141067d14 100644 --- a/ibis/expr/types/logical.py +++ b/ibis/expr/types/logical.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, overload from public import public @@ -11,11 +11,19 @@ from ibis.expr.types.numeric import NumericColumn, NumericScalar, NumericValue if TYPE_CHECKING: + from typing_extensions import Self + import ibis.expr.types as ir @public class BooleanValue(NumericValue): + @overload + def ifelse( + self: ir.BooleanScalar, true_expr: ir.Scalar, false_expr: ir.Scalar, / + ) -> ir.Scalar: ... + @overload + def ifelse(self, true_expr: ir.Value, false_expr: ir.Value, /) -> ir.Column: ... def ifelse(self, true_expr: ir.Value, false_expr: ir.Value, /) -> ir.Value: """Construct a ternary conditional expression. @@ -53,7 +61,11 @@ def ifelse(self, true_expr: ir.Value, false_expr: ir.Value, /) -> ir.Value: # must be used. return ops.IfElse(self, true_expr, false_expr).to_expr() - def __and__(self, other: BooleanValue) -> BooleanValue: + @overload + def __and__(self: BooleanScalar, other: bool | BooleanScalar) -> BooleanScalar: ... + @overload + def __and__(self, other: bool | BooleanValue) -> BooleanColumn: ... + def __and__(self, other: bool | BooleanValue) -> BooleanValue: """Construct a binary AND conditional expression with `self` and `other`. Parameters @@ -101,7 +113,11 @@ def __and__(self, other: BooleanValue) -> BooleanValue: __rand__ = __and__ - def __or__(self, other: BooleanValue) -> BooleanValue: + @overload + def __or__(self: BooleanScalar, other: bool | BooleanScalar) -> BooleanScalar: ... + @overload + def __or__(self, other: bool | BooleanValue) -> BooleanColumn: ... + def __or__(self, other: bool | BooleanValue) -> BooleanValue: """Construct a binary OR conditional expression with `self` and `other`. Parameters @@ -137,7 +153,11 @@ def __or__(self, other: BooleanValue) -> BooleanValue: __ror__ = __or__ - def __xor__(self, other: BooleanValue) -> BooleanValue: + @overload + def __xor__(self: BooleanScalar, other: bool | BooleanScalar) -> BooleanScalar: ... + @overload + def __xor__(self, other: bool | BooleanValue) -> BooleanColumn: ... + def __xor__(self, other: bool | BooleanValue) -> BooleanValue: """Construct a binary XOR conditional expression with `self` and `other`. Parameters @@ -198,7 +218,7 @@ def __xor__(self, other: BooleanValue) -> BooleanValue: __rxor__ = __xor__ - def __invert__(self) -> BooleanValue: + def __invert__(self) -> Self: """Construct a unary NOT conditional expression with `self`. Parameters @@ -230,7 +250,7 @@ def __invert__(self) -> BooleanValue: """ return ops.Not(self).to_expr() - def negate(self) -> BooleanValue: + def negate(self) -> Self: """DEPRECATED.""" util.warn_deprecated( "`-bool_val`/`bool_val.negate()`", @@ -247,7 +267,7 @@ class BooleanScalar(NumericScalar, BooleanValue): @public class BooleanColumn(NumericColumn, BooleanValue): - def any(self, *, where: BooleanValue | None = None) -> BooleanValue: + def any(self, *, where: bool | BooleanValue | None = None) -> BooleanScalar: """Return whether at least one element is `True`. If the expression does not reference any foreign tables, the result @@ -339,7 +359,7 @@ def resolve_exists_subquery(outer): return op.to_expr() - def notany(self, *, where: BooleanValue | None = None) -> BooleanValue: + def notany(self, *, where: bool | BooleanValue | None = None) -> BooleanScalar: """Return whether no elements are `True`. Parameters @@ -373,7 +393,7 @@ def notany(self, *, where: BooleanValue | None = None) -> BooleanValue: """ return ~self.any(where=where) - def all(self, *, where: BooleanValue | None = None) -> BooleanScalar: + def all(self, *, where: bool | BooleanValue | None = None) -> BooleanScalar: """Return whether all elements are `True`. Parameters @@ -410,7 +430,7 @@ def all(self, *, where: BooleanValue | None = None) -> BooleanScalar: """ return ops.All(self, where=self._bind_to_parent_table(where)).to_expr() - def notall(self, *, where: BooleanValue | None = None) -> BooleanScalar: + def notall(self, *, where: bool | BooleanValue | None = None) -> BooleanScalar: """Return whether not all elements are `True`. Parameters diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py index 44d68297cb72..c0ef35954a2a 100644 --- a/ibis/expr/types/numeric.py +++ b/ibis/expr/types/numeric.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, overload from public import public @@ -11,12 +11,19 @@ from ibis.expr.types.generic import Column, Scalar, Value, _binop if TYPE_CHECKING: + from decimal import Decimal + from typing import Union + + from typing_extensions import Self + import ibis.expr.types as ir + Number = Union[int, float, Decimal] + @public class NumericValue(Value): - def negate(self) -> NumericValue: + def negate(self) -> Self: """Negate a numeric expression. Returns @@ -42,7 +49,7 @@ def negate(self) -> NumericValue: """ return ops.Negate(self).to_expr() - def __neg__(self) -> NumericValue: + def __neg__(self) -> Self: """Negate `self`. Returns @@ -52,6 +59,12 @@ def __neg__(self) -> NumericValue: """ return self.negate() + @overload + def round(self: NumericScalar, digits: int | IntegerScalar, /) -> NumericScalar: ... + @overload + def round(self: NumericScalar, digits: IntegerColumn, /) -> NumericColumn: ... + @overload + def round(self: NumericColumn, digits: int | IntegerValue, /) -> NumericColumn: ... def round(self, digits: int | IntegerValue = 0, /) -> NumericValue: """Round values to an indicated number of decimal places. @@ -114,7 +127,17 @@ def round(self, digits: int | IntegerValue = 0, /) -> NumericValue: """ return ops.Round(self, digits).to_expr() - def log(self, base: NumericValue | None = None, /) -> NumericValue: + @overload + def log( + self: NumericScalar, base: Number | NumericScalar | None, / + ) -> FloatingScalar: ... + @overload + def log(self: NumericScalar, base: NumericColumn, /) -> FloatingColumn: ... + @overload + def log( + self: NumericColumn, base: Number | NumericValue | None, / + ) -> FloatingColumn: ... + def log(self, base: Number | NumericValue | None = None, /) -> NumericValue: r"""Compute $\log_{\texttt{base}}\left(\texttt{self}\right)$. Parameters @@ -163,9 +186,9 @@ def log(self, base: NumericValue | None = None, /) -> NumericValue: def clip( self, - lower: NumericValue | None = None, - upper: NumericValue | None = None, - ) -> NumericValue: + lower: Number | NumericValue | None = None, + upper: Number | NumericValue | None = None, + ) -> Self: """Trim values outside of `lower` and `upper` bounds. `NULL` values are preserved and are not replaced with bounds. @@ -211,7 +234,7 @@ def clip( return ops.Clip(self, lower, upper).to_expr() - def abs(self) -> NumericValue: + def abs(self) -> Self: """Return the absolute value of `self`. Examples @@ -233,6 +256,14 @@ def abs(self) -> NumericValue: """ return ops.Abs(self).to_expr() + @overload + def ceil(self: DecimalScalar) -> DecimalScalar: ... + @overload + def ceil(self: DecimalColumn) -> DecimalColumn: ... + @overload + def ceil(self: NumericScalar) -> IntegerScalar: ... + @overload + def ceil(self: NumericColumn) -> IntegerColumn: ... def ceil(self) -> DecimalValue | IntegerValue: """Return the ceiling of `self`. @@ -256,6 +287,14 @@ def ceil(self) -> DecimalValue | IntegerValue: """ return ops.Ceil(self).to_expr() + @overload + def degrees(self: NumericScalar) -> FloatingScalar: ... + @overload + def degrees(self: NumericColumn) -> FloatingColumn: ... + @overload + def degrees(self: DecimalScalar) -> DecimalScalar: ... + @overload + def degrees(self: DecimalColumn) -> DecimalColumn: ... def degrees(self) -> NumericValue: """Compute the degrees of `self` radians. @@ -282,7 +321,15 @@ def degrees(self) -> NumericValue: rad2deg = degrees - def exp(self) -> NumericValue: + @overload + def exp(self: NumericScalar) -> FloatingScalar: ... + @overload + def exp(self: NumericColumn) -> FloatingColumn: ... + @overload + def exp(self: DecimalScalar) -> DecimalScalar: ... + @overload + def exp(self: DecimalColumn) -> DecimalColumn: ... + def exp(self) -> FloatingValue: r"""Compute $e^\texttt{self}$. Returns @@ -309,6 +356,14 @@ def exp(self) -> NumericValue: """ return ops.Exp(self).to_expr() + @overload + def floor(self: DecimalScalar) -> DecimalScalar: ... + @overload + def floor(self: DecimalColumn) -> DecimalColumn: ... + @overload + def floor(self: NumericScalar) -> IntegerScalar: ... + @overload + def floor(self: NumericColumn) -> IntegerColumn: ... def floor(self) -> DecimalValue | IntegerValue: """Return the floor of an expression. @@ -333,7 +388,11 @@ def floor(self) -> DecimalValue | IntegerValue: """ return ops.Floor(self).to_expr() - def log2(self) -> NumericValue: + @overload + def log2(self: NumericScalar) -> FloatingScalar: ... + @overload + def log2(self: NumericColumn) -> FloatingColumn: ... + def log2(self) -> FloatingValue: r"""Compute $\log_{2}\left(\texttt{self}\right)$. Examples @@ -355,7 +414,11 @@ def log2(self) -> NumericValue: """ return ops.Log2(self).to_expr() - def log10(self) -> NumericValue: + @overload + def log10(self: NumericScalar) -> FloatingScalar: ... + @overload + def log10(self: NumericColumn) -> FloatingColumn: ... + def log10(self) -> FloatingValue: r"""Compute $\log_{10}\left(\texttt{self}\right)$. Examples @@ -376,7 +439,11 @@ def log10(self) -> NumericValue: """ return ops.Log10(self).to_expr() - def ln(self) -> NumericValue: + @overload + def ln(self: NumericScalar) -> FloatingScalar: ... + @overload + def ln(self: NumericColumn) -> FloatingColumn: ... + def ln(self) -> FloatingValue: r"""Compute $\ln\left(\texttt{self}\right)$. Examples @@ -397,7 +464,11 @@ def ln(self) -> NumericValue: """ return ops.Ln(self).to_expr() - def radians(self) -> NumericValue: + @overload + def radians(self: NumericScalar) -> FloatingScalar: ... + @overload + def radians(self: NumericColumn) -> FloatingColumn: ... + def radians(self) -> FloatingValue: """Compute radians from `self` degrees. Examples @@ -422,7 +493,7 @@ def radians(self) -> NumericValue: deg2rad = radians - def sign(self) -> NumericValue: + def sign(self) -> Self: """Return the sign of the input. Examples @@ -444,7 +515,11 @@ def sign(self) -> NumericValue: """ return ops.Sign(self).to_expr() - def sqrt(self) -> NumericValue: + @overload + def sqrt(self: NumericScalar) -> FloatingScalar: ... + @overload + def sqrt(self: NumericColumn) -> FloatingColumn: ... + def sqrt(self) -> FloatingValue: """Compute the square root of `self`. Examples @@ -466,7 +541,11 @@ def sqrt(self) -> NumericValue: """ return ops.Sqrt(self).to_expr() - def acos(self) -> NumericValue: + @overload + def acos(self: NumericScalar) -> FloatingScalar: ... + @overload + def acos(self: NumericColumn) -> FloatingColumn: ... + def acos(self) -> FloatingValue: """Compute the arc cosine of `self`. Examples @@ -488,7 +567,11 @@ def acos(self) -> NumericValue: """ return ops.Acos(self).to_expr() - def asin(self) -> NumericValue: + @overload + def asin(self: NumericScalar) -> FloatingScalar: ... + @overload + def asin(self: NumericColumn) -> FloatingColumn: ... + def asin(self) -> FloatingValue: """Compute the arc sine of `self`. Examples @@ -509,7 +592,11 @@ def asin(self) -> NumericValue: """ return ops.Asin(self).to_expr() - def atan(self) -> NumericValue: + @overload + def atan(self: NumericScalar) -> FloatingScalar: ... + @overload + def atan(self: NumericColumn) -> FloatingColumn: ... + def atan(self) -> FloatingValue: """Compute the arc tangent of `self`. Examples @@ -530,7 +617,17 @@ def atan(self) -> NumericValue: """ return ops.Atan(self).to_expr() - def atan2(self, other: NumericValue, /) -> NumericValue: + @overload + def atan2( + self: NumericScalar, other: Number | NumericScalar, / + ) -> FloatingScalar: ... + @overload + def atan2(self: NumericScalar, other: NumericColumn, /) -> FloatingColumn: ... + @overload + def atan2( + self: NumericColumn, other: Number | NumericValue, / + ) -> FloatingColumn: ... + def atan2(self, other: NumericValue, /) -> FloatingValue: """Compute the two-argument version of arc tangent. Examples @@ -551,7 +648,11 @@ def atan2(self, other: NumericValue, /) -> NumericValue: """ return ops.Atan2(self, other).to_expr() - def cos(self) -> NumericValue: + @overload + def cos(self: NumericScalar) -> FloatingScalar: ... + @overload + def cos(self: NumericColumn) -> FloatingColumn: ... + def cos(self) -> FloatingValue: """Compute the cosine of `self`. Examples @@ -572,7 +673,11 @@ def cos(self) -> NumericValue: """ return ops.Cos(self).to_expr() - def cot(self) -> NumericValue: + @overload + def cot(self: NumericScalar) -> FloatingScalar: ... + @overload + def cot(self: NumericColumn) -> FloatingColumn: ... + def cot(self) -> FloatingValue: """Compute the cotangent of `self`. Examples @@ -593,7 +698,11 @@ def cot(self) -> NumericValue: """ return ops.Cot(self).to_expr() - def sin(self) -> NumericValue: + @overload + def sin(self: NumericScalar) -> FloatingScalar: ... + @overload + def sin(self: NumericColumn) -> FloatingColumn: ... + def sin(self) -> FloatingValue: """Compute the sine of `self`. Examples @@ -614,7 +723,11 @@ def sin(self) -> NumericValue: """ return ops.Sin(self).to_expr() - def tan(self) -> NumericValue: + @overload + def tan(self: NumericScalar) -> FloatingScalar: ... + @overload + def tan(self: NumericColumn) -> FloatingColumn: ... + def tan(self) -> FloatingValue: """Compute the tangent of `self`. Examples @@ -635,51 +748,112 @@ def tan(self) -> NumericValue: """ return ops.Tan(self).to_expr() - def __add__(self, other: NumericValue) -> NumericValue: + @overload + def __add__( + self: NumericScalar, other: Number | NumericScalar + ) -> NumericScalar: ... + @overload + def __add__(self: NumericScalar, other: NumericColumn) -> NumericColumn: ... + @overload + def __add__(self, other: Number | NumericValue) -> NumericColumn: ... + def __add__(self, other: Number | NumericValue) -> NumericValue: """Add `self` with `other`.""" return _binop(ops.Add, self, other) add = radd = __radd__ = __add__ - def __sub__(self, other: NumericValue) -> NumericValue: + @overload + def __sub__( + self: NumericScalar, other: Number | NumericScalar + ) -> NumericScalar: ... + @overload + def __sub__(self: NumericScalar, other: NumericColumn) -> NumericColumn: ... + @overload + def __sub__(self, other: Number | NumericValue) -> NumericColumn: ... + def __sub__(self, other: Number | NumericValue) -> NumericValue: """Subtract `other` from `self`.""" return _binop(ops.Subtract, self, other) sub = __sub__ - def __rsub__(self, other: NumericValue) -> NumericValue: + @overload + def __rsub__( + self: NumericScalar, other: Number | NumericScalar + ) -> NumericScalar: ... + @overload + def __rsub__(self: NumericScalar, other: NumericColumn) -> NumericColumn: ... + @overload + def __rsub__(self, other: Number | NumericValue) -> NumericColumn: ... + def __rsub__(self, other: Number | NumericValue) -> NumericValue: """Subtract `self` from `other`.""" return _binop(ops.Subtract, other, self) rsub = __rsub__ - def __mul__(self, other: NumericValue) -> NumericValue: + @overload + def __mul__( + self: NumericScalar, other: Number | NumericScalar + ) -> NumericScalar: ... + @overload + def __mul__(self: NumericScalar, other: NumericColumn) -> NumericColumn: ... + @overload + def __mul__(self, other: Number | NumericValue) -> NumericColumn: ... + def __mul__(self, other: Number | NumericValue) -> NumericValue: """Multiply `self` and `other`.""" return _binop(ops.Multiply, self, other) mul = rmul = __rmul__ = __mul__ - def __truediv__(self, other): + @overload + def __truediv__( + self: NumericScalar, other: Number | NumericScalar + ) -> FloatingScalar: ... + @overload + def __truediv__(self: NumericScalar, other: NumericColumn) -> FloatingColumn: ... + @overload + def __truediv__(self, other: Number | NumericValue) -> FloatingColumn: ... + def __truediv__(self, other: Number | NumericValue) -> FloatingValue: """Divide `self` by `other`.""" return _binop(ops.Divide, self, other) div = __div__ = __truediv__ - def __rtruediv__(self, other: NumericValue) -> NumericValue: + @overload + def __rtruediv__( + self: NumericScalar, other: Number | NumericScalar + ) -> FloatingScalar: ... + @overload + def __rtruediv__(self: NumericScalar, other: NumericColumn) -> FloatingColumn: ... + @overload + def __rtruediv__(self, other: Number | NumericValue) -> FloatingColumn: ... + def __rtruediv__(self, other: Number | NumericValue) -> NumericValue: """Divide `other` by `self`.""" return _binop(ops.Divide, other, self) rdiv = __rdiv__ = __rtruediv__ + @overload def __floordiv__( - self, - other: NumericValue, - ) -> NumericValue: + self: NumericScalar, other: Number | NumericScalar + ) -> IntegerScalar: ... + @overload + def __floordiv__(self: NumericScalar, other: NumericColumn) -> IntegerColumn: ... + @overload + def __floordiv__(self, other: Number | NumericValue) -> IntegerColumn: ... + def __floordiv__(self, other: Number | NumericValue) -> IntegerValue: """Floor divide `self` by `other`.""" return _binop(ops.FloorDivide, self, other) floordiv = __floordiv__ + @overload + def __rfloordiv__( + self: NumericScalar, other: Number | NumericScalar + ) -> IntegerScalar: ... + @overload + def __rfloordiv__(self: NumericScalar, other: NumericColumn) -> IntegerColumn: ... + @overload + def __rfloordiv__(self, other: Number | NumericValue) -> IntegerColumn: ... def __rfloordiv__( self, other: NumericValue, @@ -714,7 +888,17 @@ def __rmod__(self, other: NumericValue) -> NumericValue: rmod = __rmod__ - def point(self, right: int | float | NumericValue, /) -> ir.PointValue: + @overload + def point( + self: NumericScalar, right: Number | NumericScalar, / + ) -> ir.PointScalar: ... + @overload + def point(self: NumericScalar, right: NumericColumn, /) -> ir.PointColumn: ... + @overload + def point( + self: NumericColumn, right: Number | NumericValue, / + ) -> ir.PointColumn: ... + def point(self, right: Number | NumericValue, /) -> ir.PointValue: """Return a point constructed from the coordinate values. Constant coordinates result in construction of a `POINT` literal or @@ -764,12 +948,26 @@ class NumericScalar(Scalar, NumericValue): @public class NumericColumn(Column, NumericValue): + @overload + def kurtosis( + self: DecimalColumn, + *, + where: ir.BooleanValue | None = None, + how: Literal["sample", "pop"] = "sample", + ) -> DecimalScalar: ... + @overload + def kurtosis( + self: NumericColumn, + *, + where: ir.BooleanValue | None = None, + how: Literal["sample", "pop"] = "sample", + ) -> FloatingScalar: ... def kurtosis( self, *, where: ir.BooleanValue | None = None, how: Literal["sample", "pop"] = "sample", - ) -> NumericScalar: + ) -> FloatingScalar | DecimalScalar: """Return the kurtosis of a numeric column. Parameters @@ -813,12 +1011,26 @@ def kurtosis( self, how=how, where=self._bind_to_parent_table(where) ).to_expr() + @overload + def std( + self: DecimalColumn, + *, + where: ir.BooleanValue | None = None, + how: Literal["sample", "pop"] = "sample", + ) -> DecimalScalar: ... + @overload + def std( + self: NumericColumn, + *, + where: ir.BooleanValue | None = None, + how: Literal["sample", "pop"] = "sample", + ) -> FloatingScalar: ... def std( self, *, where: ir.BooleanValue | None = None, how: Literal["sample", "pop"] = "sample", - ) -> NumericScalar: + ) -> FloatingScalar | DecimalScalar: """Return the standard deviation of a numeric column. Parameters @@ -864,12 +1076,26 @@ def std( self, how=how, where=self._bind_to_parent_table(where) ).to_expr() + @overload + def var( + self: DecimalColumn, + *, + where: ir.BooleanValue | None = None, + how: Literal["sample", "pop"] = "sample", + ) -> DecimalScalar: ... + @overload + def var( + self: NumericColumn, + *, + where: ir.BooleanValue | None = None, + how: Literal["sample", "pop"] = "sample", + ) -> FloatingScalar: ... def var( self, *, where: ir.BooleanValue | None = None, how: Literal["sample", "pop"] = "sample", - ) -> NumericScalar: + ) -> FloatingScalar | DecimalScalar: """Return the variance of a numeric column. Parameters @@ -922,7 +1148,7 @@ def corr( *, where: ir.BooleanValue | None = None, how: Literal["sample", "pop"] = "sample", - ) -> NumericScalar: + ) -> FloatingScalar: """Return the correlation of two numeric columns. Parameters @@ -981,7 +1207,7 @@ def cov( *, where: ir.BooleanValue | None = None, how: Literal["sample", "pop"] = "sample", - ) -> NumericScalar: + ) -> FloatingScalar: """Return the covariance of two numeric columns. Parameters @@ -1037,7 +1263,7 @@ def cov( where=self._bind_to_parent_table(where), ).to_expr() - def mean(self, *, where: ir.BooleanValue | None = None) -> NumericScalar: + def mean(self, *, where: ir.BooleanValue | None = None) -> FloatingScalar: """Return the mean of a numeric column. Parameters @@ -1093,7 +1319,7 @@ def mean(self, *, where: ir.BooleanValue | None = None) -> NumericScalar: # of default name generated by ops.Value operations return ops.Mean(self, where=self._bind_to_parent_table(where)).to_expr() - def cummean(self, *, where=None, group_by=None, order_by=None) -> NumericColumn: + def cummean(self, *, where=None, group_by=None, order_by=None) -> FloatingColumn: """Return the cumulative mean of the input. Examples @@ -1322,7 +1548,7 @@ def histogram( binwidth: float | None = None, base: float | None = None, eps: float = 1e-13, - ): + ) -> IntegerColumn: """Compute a histogram with fixed width bins. Parameters @@ -1467,6 +1693,14 @@ def approx_quantile( @public class IntegerValue(NumericValue): + @overload + def as_timestamp( + self: IntegerScalar, unit: Literal["s", "ms", "us"], / + ) -> ir.TimestampScalar: ... + @overload + def as_timestamp( + self: IntegerColumn, unit: Literal["s", "ms", "us"], / + ) -> ir.TimestampColumn: ... def as_timestamp(self, unit: Literal["s", "ms", "us"], /) -> ir.TimestampValue: """Convert an integral UNIX timestamp to a timestamp expression. @@ -1498,6 +1732,18 @@ def as_timestamp(self, unit: Literal["s", "ms", "us"], /) -> ir.TimestampValue: """ return ops.TimestampFromUNIX(self, unit).to_expr() + @overload + def as_interval( + self: IntegerScalar, + unit: Literal["Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns"] = "s", + /, + ) -> ir.IntervalScalar: ... + @overload + def as_interval( + self: IntegerColumn, + unit: Literal["Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns"] = "s", + /, + ) -> ir.IntervalColumn: ... def as_interval( self, unit: Literal["Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns"] = "s", @@ -1554,11 +1800,36 @@ def as_interval( """ return ops.IntervalFromInteger(self, unit).to_expr() + @overload + def convert_base( + self: IntegerScalar, + from_base: int | IntegerScalar, + to_base: int | IntegerScalar, + ) -> ir.StringScalar: ... + + @overload + def convert_base( + self: IntegerScalar, + from_base: IntegerColumn, + to_base: int | IntegerScalar, + ) -> ir.StringColumn: ... + @overload + def convert_base( + self: IntegerScalar, + from_base: int | IntegerScalar, + to_base: IntegerColumn, + ) -> ir.StringColumn: ... + @overload + def convert_base( + self: IntegerColumn, + from_base: int | IntegerValue, + to_base: int | IntegerValue, + ) -> ir.StringColumn: ... def convert_base( self, - from_base: IntegerValue, - to_base: IntegerValue, - ) -> IntegerValue: + from_base: int | IntegerValue, + to_base: int | IntegerValue, + ) -> ir.StringColumn: """Convert an integer from one base to another. Parameters @@ -1570,7 +1841,7 @@ def convert_base( Returns ------- - IntegerValue + StringValue Converted expression """ return ops.BaseConvert(self, from_base, to_base).to_expr() @@ -1609,7 +1880,7 @@ def __rrshift__(self, other: IntegerValue) -> IntegerValue: """Bitwise right shift `self` with `other`.""" return _binop(ops.BitwiseRightShift, other, self) - def __invert__(self) -> IntegerValue: + def __invert__(self) -> Self: """Bitwise not of `self`. Returns @@ -1692,6 +1963,10 @@ def bit_xor(self, *, where: ir.BooleanValue | None = None) -> IntegerScalar: @public class FloatingValue(NumericValue): + @overload + def isnan(self: FloatingScalar) -> ir.BooleanScalar: ... + @overload + def isnan(self: FloatingColumn) -> ir.BooleanColumn: ... def isnan(self) -> ir.BooleanValue: """Return whether the value is NaN. Does NOT detect `NULL` and `inf` values. @@ -1725,6 +2000,10 @@ def isnan(self) -> ir.BooleanValue: """ return ops.IsNan(self).to_expr() + @overload + def isinf(self: FloatingScalar) -> ir.BooleanScalar: ... + @overload + def isinf(self: FloatingColumn) -> ir.BooleanColumn: ... def isinf(self) -> ir.BooleanValue: """Return whether the value is +/-inf. Does NOT detect `NULL` and `inf` values. diff --git a/ibis/expr/types/strings.py b/ibis/expr/types/strings.py index 3ad482b5c1ee..1f3d2e33a9e7 100644 --- a/ibis/expr/types/strings.py +++ b/ibis/expr/types/strings.py @@ -2,7 +2,7 @@ import functools import operator -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, overload from public import public @@ -13,12 +13,20 @@ if TYPE_CHECKING: from collections.abc import Iterable, Sequence + from typing_extensions import Self + import ibis.expr.types as ir @public class StringValue(Value): - def __getitem__(self, key: slice | int | ir.IntegerScalar) -> StringValue: + @overload + def __getitem__( + self: ir.StringScalar, key: slice | int | ir.IntegerScalar + ) -> ir.StringScalar: ... + @overload + def __getitem__(self, key: slice | int | ir.IntegerValue) -> ir.StringColumn: ... + def __getitem__(self, key: slice | int | ir.IntegerValue) -> StringValue: """Index or slice a string expression. Parameters @@ -109,6 +117,10 @@ def __getitem__(self, key: slice | int | ir.IntegerScalar) -> StringValue: return self.substr(key, 1) raise NotImplementedError(f"string __getitem__[{key.__class__.__name__}]") + @overload + def length(self: ir.StringScalar) -> ir.IntegerScalar: ... + @overload + def length(self: ir.StringColumn) -> ir.IntegerColumn: ... def length(self) -> ir.IntegerValue: """Compute the length of a string. @@ -135,7 +147,7 @@ def length(self) -> ir.IntegerValue: """ return ops.StringLength(self).to_expr() - def lower(self) -> StringValue: + def lower(self) -> Self: """Convert string to all lowercase. Returns @@ -171,7 +183,7 @@ def lower(self) -> StringValue: """ return ops.Lowercase(self).to_expr() - def upper(self) -> StringValue: + def upper(self) -> Self: """Convert string to all uppercase. Returns @@ -207,7 +219,7 @@ def upper(self) -> StringValue: """ return ops.Uppercase(self).to_expr() - def reverse(self) -> StringValue: + def reverse(self) -> Self: """Reverse the characters of a string. Returns @@ -243,6 +255,10 @@ def reverse(self) -> StringValue: """ return ops.Reverse(self).to_expr() + @overload + def ascii_str(self: ir.StringScalar) -> ir.IntegerScalar: ... + @overload + def ascii_str(self: ir.StringColumn) -> ir.IntegerColumn: ... def ascii_str(self) -> ir.IntegerValue: """Return the numeric ASCII code of the first character of a string. @@ -269,7 +285,7 @@ def ascii_str(self) -> ir.IntegerValue: """ return ops.StringAscii(self).to_expr() - def strip(self) -> StringValue: + def strip(self) -> Self: r"""Remove whitespace from left and right sides of a string. Returns @@ -305,7 +321,7 @@ def strip(self) -> StringValue: """ return ops.Strip(self).to_expr() - def lstrip(self) -> StringValue: + def lstrip(self) -> Self: r"""Remove whitespace from the left side of string. Returns @@ -341,7 +357,7 @@ def lstrip(self) -> StringValue: """ return ops.LStrip(self).to_expr() - def rstrip(self) -> StringValue: + def rstrip(self) -> Self: r"""Remove whitespace from the right side of string. Returns @@ -377,7 +393,7 @@ def rstrip(self) -> StringValue: """ return ops.RStrip(self).to_expr() - def capitalize(self) -> StringValue: + def capitalize(self) -> Self: """Uppercase the first letter, lowercase the rest. This API matches the semantics of the Python [](`str.capitalize`) @@ -410,6 +426,14 @@ def capitalize(self) -> StringValue: def __contains__(self, *_: Any) -> bool: raise TypeError("Use string_expr.contains(arg)") + @overload + def contains( + self: StringScalar, substr: str | StringScalar + ) -> ir.BooleanScalar: ... + @overload + def contains(self: StringScalar, substr: StringColumn) -> ir.BooleanColumn: ... + @overload + def contains(self: StringColumn, substr: str | StringValue) -> ir.BooleanColumn: ... def contains(self, substr: str | StringValue, /) -> ir.BooleanValue: """Return whether the expression contains `substr`. @@ -441,6 +465,18 @@ def contains(self, substr: str | StringValue, /) -> ir.BooleanValue: """ return ops.StringContains(self, substr).to_expr() + @overload + def hashbytes( + self: StringScalar, + how: Literal["md5", "sha1", "sha256", "sha512"] = "sha256", + /, + ) -> ir.BinaryScalar: ... + @overload + def hashbytes( + self: StringColumn, + how: Literal["md5", "sha1", "sha256", "sha512"] = "sha256", + /, + ) -> ir.BinaryColumn: ... def hashbytes( self, how: Literal["md5", "sha1", "sha256", "sha512"] = "sha256", / ) -> ir.BinaryValue: @@ -466,7 +502,7 @@ def hashbytes( def hexdigest( self, how: Literal["md5", "sha1", "sha256", "sha512"] = "sha256", / - ) -> ir.StringValue: + ) -> Self: """Return the hash digest of the input as a hex encoded string. Parameters @@ -497,6 +533,18 @@ def hexdigest( """ return ops.HexDigest(self, how.lower()).to_expr() + @overload + def substr( + self: StringScalar, + start: int | ir.IntegerScalar, + length: int | ir.IntegerScalar | None = None, + ) -> StringScalar: ... + @overload + def substr( + self, + start: int | ir.IntegerValue, + length: int | ir.IntegerValue | None = None, + ) -> StringColumn: ... def substr( self, start: int | ir.IntegerValue, length: int | ir.IntegerValue | None = None ) -> StringValue: @@ -533,6 +581,12 @@ def substr( """ return ops.Substring(self, start, length).to_expr() + @overload + def left(self: StringScalar, nchars: int | ir.IntegerScalar, /) -> StringScalar: ... + @overload + def left(self: StringScalar, nchars: ir.IntegerColumn, /) -> StringColumn: ... + @overload + def left(self: StringColumn, nchars: int | ir.IntegerValue, /) -> StringColumn: ... def left(self, nchars: int | ir.IntegerValue, /) -> StringValue: """Return the `nchars` left-most characters. @@ -564,6 +618,14 @@ def left(self, nchars: int | ir.IntegerValue, /) -> StringValue: """ return self.substr(0, length=nchars) + @overload + def right( + self: StringScalar, nchars: int | ir.IntegerScalar, / + ) -> StringScalar: ... + @overload + def right(self: StringScalar, nchars: ir.IntegerColumn, /) -> StringColumn: ... + @overload + def right(self: StringColumn, nchars: int | ir.IntegerValue, /) -> StringColumn: ... def right(self, nchars: int | ir.IntegerValue, /) -> StringValue: """Return up to `nchars` from the end of each string. @@ -595,6 +657,12 @@ def right(self, nchars: int | ir.IntegerValue, /) -> StringValue: """ return ops.StrRight(self, nchars).to_expr() + @overload + def repeat(self: StringScalar, n: int | ir.IntegerScalar, /) -> StringScalar: ... + @overload + def repeat(self: StringScalar, n: ir.IntegerColumn, /) -> StringColumn: ... + @overload + def repeat(self: StringColumn, n: int | ir.IntegerValue, /) -> StringColumn: ... def repeat(self, n: int | ir.IntegerValue, /) -> StringValue: """Repeat a string `n` times. @@ -626,7 +694,17 @@ def repeat(self, n: int | ir.IntegerValue, /) -> StringValue: """ return ops.Repeat(self, n).to_expr() - def translate(self, from_str: StringValue, to_str: StringValue) -> StringValue: + @overload + def translate( + self: StringScalar, from_str: str | StringScalar, to_str: str | StringScalar + ) -> StringScalar: ... + @overload + def translate( + self, from_str: str | StringValue, to_str: str | StringValue + ) -> StringColumn: ... + def translate( + self, from_str: str | StringValue, to_str: str | StringValue + ) -> StringValue: """Replace `from_str` characters in `self` characters in `to_str`. To avoid unexpected behavior, `from_str` should be shorter than @@ -857,6 +935,12 @@ def join( cls = ops.StringJoin return cls(strings, sep=self).to_expr() + @overload + def startswith( + self: StringScalar, start: str | StringScalar, / + ) -> ir.BooleanScalar: ... + @overload + def startswith(self, start: str | StringValue, /) -> ir.BooleanColumn: ... def startswith(self, start: str | StringValue, /) -> ir.BooleanValue: """Determine whether `self` starts with `start`. @@ -887,6 +971,12 @@ def startswith(self, start: str | StringValue, /) -> ir.BooleanValue: """ return ops.StartsWith(self, start).to_expr() + @overload + def endswith( + self: StringScalar, end: str | StringScalar, / + ) -> ir.BooleanScalar: ... + @overload + def endswith(self, end: str | StringValue, /) -> ir.BooleanColumn: ... def endswith(self, end: str | StringValue, /) -> ir.BooleanValue: """Determine if `self` ends with `end`. @@ -1259,6 +1349,10 @@ def replace(self, pattern: StringValue, replacement: StringValue) -> StringValue """ return ops.StringReplace(self, pattern, replacement).to_expr() + @overload + def as_timestamp(self: StringScalar, format_str: str, /) -> ir.TimestampScalar: ... + @overload + def as_timestamp(self: StringColumn, format_str: str, /) -> ir.TimestampColumn: ... def as_timestamp(self, format_str: str, /) -> ir.TimestampValue: """Parse a string and return a timestamp. @@ -1288,6 +1382,10 @@ def as_timestamp(self, format_str: str, /) -> ir.TimestampValue: """ return ops.StringToTimestamp(self, format_str).to_expr() + @overload + def as_date(self: StringScalar, format_str: str, /) -> ir.DateScalar: ... + @overload + def as_date(self: StringColumn, format_str: str, /) -> ir.DateColumn: ... def as_date(self, format_str: str, /) -> ir.DateValue: """Parse a string and return a date. @@ -1317,6 +1415,10 @@ def as_date(self, format_str: str, /) -> ir.DateValue: """ return ops.StringToDate(self, format_str).to_expr() + @overload + def as_time(self: StringScalar, format_str: str, /) -> ir.TimeScalar: ... + @overload + def as_time(self: StringColumn, format_str: str, /) -> ir.TimeColumn: ... def as_time(self, format_str: str, /) -> ir.TimeValue: """Parse a string and return a time. @@ -1346,7 +1448,7 @@ def as_time(self, format_str: str, /) -> ir.TimeValue: """ return ops.StringToTime(self, format_str).to_expr() - def protocol(self): + def protocol(self) -> Self: """Parse a URL and extract protocol. Examples @@ -1362,7 +1464,7 @@ def protocol(self): """ return ops.ExtractProtocol(self).to_expr() - def authority(self): + def authority(self) -> Self: """Parse a URL and extract authority. Examples @@ -1378,7 +1480,7 @@ def authority(self): """ return ops.ExtractAuthority(self).to_expr() - def userinfo(self): + def userinfo(self) -> Self: """Parse a URL and extract user info. Examples @@ -1394,7 +1496,7 @@ def userinfo(self): """ return ops.ExtractUserInfo(self).to_expr() - def host(self): + def host(self) -> Self: """Parse a URL and extract host. Examples @@ -1410,7 +1512,7 @@ def host(self): """ return ops.ExtractHost(self).to_expr() - def file(self): + def file(self) -> Self: """Parse a URL and extract file. Examples @@ -1428,7 +1530,7 @@ def file(self): """ return ops.ExtractFile(self).to_expr() - def path(self): + def path(self) -> Self: """Parse a URL and extract path. Examples @@ -1446,6 +1548,12 @@ def path(self): """ return ops.ExtractPath(self).to_expr() + @overload + def query(self: StringScalar, key: str | StringScalar, /) -> StringScalar: ... + @overload + def query(self, key: None, /) -> Self: ... + @overload + def query(self, key: str | StringValue, /) -> StringColumn: ... def query(self, key: str | StringValue | None = None, /): """Parse a URL and returns query string or query string parameter. @@ -1473,7 +1581,7 @@ def query(self, key: str | StringValue | None = None, /): """ return ops.ExtractQuery(self, key).to_expr() - def fragment(self): + def fragment(self) -> Self: """Parse a URL and extract fragment identifier. Examples @@ -1579,6 +1687,10 @@ def concat( """ return ops.StringConcat((self, other, *args)).to_expr() + @overload + def __add__(self: StringScalar, other: str | StringScalar) -> StringScalar: ... + @overload + def __add__(self, other: str | StringValue) -> StringColumn: ... def __add__(self, other: str | StringValue) -> StringValue: """Concatenate strings. @@ -1630,6 +1742,10 @@ def __add__(self, other: str | StringValue) -> StringValue: """ return self.concat(other) + @overload + def __radd__(self: StringScalar, other: str | StringScalar) -> StringScalar: ... + @overload + def __radd__(self, other: str | StringValue) -> StringColumn: ... def __radd__(self, other: str | StringValue) -> StringValue: """Concatenate strings. @@ -1690,12 +1806,22 @@ def convert_base( """ return ops.BaseConvert(self, from_base, to_base).to_expr() + @overload + def __mul__(self: StringScalar, n: int | ir.IntegerScalar) -> StringScalar: ... + @overload + def __mul__(self, n: int | ir.IntegerValue) -> StringColumn: ... def __mul__(self, n: int | ir.IntegerValue) -> StringValue: return _binop(ops.Repeat, self, n) __rmul__ = __mul__ - def levenshtein(self, other: StringValue, /) -> ir.IntegerValue: + @overload + def levenshtein( + self: StringScalar, other: str | StringScalar, / + ) -> ir.IntegerScalar: ... + @overload + def levenshtein(self, other: str | StringValue, /) -> ir.IntegerColumn: ... + def levenshtein(self, other: str | StringValue, /) -> ir.IntegerValue: """Return the Levenshtein distance between two strings. Parameters @@ -1728,5 +1854,5 @@ class StringScalar(Scalar, StringValue): @public class StringColumn(Column, StringValue): - def __getitem__(self, key: slice | int | ir.IntegerScalar) -> StringColumn: + def __getitem__(self, key: slice | int | ir.IntegerValue) -> StringColumn: return StringValue.__getitem__(self, key)