Skip to content

Commit

Permalink
Make sure we can filter common fields
Browse files Browse the repository at this point in the history
  • Loading branch information
snejus committed Jun 20, 2024
1 parent 7fd90be commit 323f17c
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 27 deletions.
60 changes: 33 additions & 27 deletions beets/dbcore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,24 @@ class FieldQuery(Query, Generic[P]):
same matching functionality in SQLite.
"""

@property
def field(self) -> str:
return (
f"{self.table}.{self.field_name}" if self.table else self.field_name
)

@property
def field_names(self) -> Set[str]:
"""Return a set with field names that this query operates on."""
return {self.field}
return {self.field_name}

def __init__(self, field: str, pattern: P, fast: bool = True):
self.field = field
def __init__(self, field_name: str, pattern: P, fast: bool = True):
self.table, _, self.field_name = field_name.rpartition(".")
self.pattern = pattern
self.fast = fast

def col_clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
return None, ()
def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.field, ()

def clause(self) -> Tuple[Optional[str], Sequence[SQLiteType]]:
if self.fast:
Expand All @@ -160,23 +166,23 @@ def value_match(cls, pattern: P, value: Any):
raise NotImplementedError()

def match(self, obj: Model) -> bool:
return self.value_match(self.pattern, obj.get(self.field))
return self.value_match(self.pattern, obj.get(self.field_name))

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}({self.field!r}, {self.pattern!r}, "
f"{self.__class__.__name__}({self.field_name!r}, {self.pattern!r}, "
f"fast={self.fast})"
)

def __eq__(self, other) -> bool:
return (
super().__eq__(other)
and self.field == other.field
and self.field_name == other.field_name
and self.pattern == other.pattern
)

def __hash__(self) -> int:
return hash((self.field, hash(self.pattern)))
return hash((self.field_name, hash(self.pattern)))


class MatchQuery(FieldQuery[AnySQLiteType]):
Expand All @@ -200,10 +206,10 @@ def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.field + " IS NULL", ()

def match(self, obj: Model) -> bool:
return obj.get(self.field) is None
return obj.get(self.field_name) is None

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.field!r}, {self.fast})"
return f"{self.__class__.__name__}({self.field_name!r}, {self.fast})"


class StringFieldQuery(FieldQuery[P]):
Expand Down Expand Up @@ -274,7 +280,7 @@ class RegexpQuery(StringFieldQuery[Pattern[str]]):
expression.
"""

def __init__(self, field: str, pattern: str, fast: bool = True):
def __init__(self, field_name: str, pattern: str, fast: bool = True):
pattern = self._normalize(pattern)
try:
pattern_re = re.compile(pattern)
Expand All @@ -284,7 +290,7 @@ def __init__(self, field: str, pattern: str, fast: bool = True):
pattern, "a regular expression", format(exc)
)

super().__init__(field, pattern_re, fast)
super().__init__(field_name, pattern_re, fast)

def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return f" regexp({self.field}, ?)", [self.pattern.pattern]
Expand All @@ -308,7 +314,7 @@ class BooleanQuery(MatchQuery[int]):

def __init__(
self,
field: str,
field_name: str,
pattern: bool,
fast: bool = True,
):
Expand All @@ -317,7 +323,7 @@ def __init__(

pattern_int = int(pattern)

super().__init__(field, pattern_int, fast)
super().__init__(field_name, pattern_int, fast)


class BytesQuery(FieldQuery[bytes]):
Expand All @@ -327,7 +333,7 @@ class BytesQuery(FieldQuery[bytes]):
`MatchQuery` when matching on BLOB values.
"""

def __init__(self, field: str, pattern: Union[bytes, str, memoryview]):
def __init__(self, field_name: str, pattern: Union[bytes, str, memoryview]):
# Use a buffer/memoryview representation of the pattern for SQLite
# matching. This instructs SQLite to treat the blob as binary
# rather than encoded Unicode.
Expand All @@ -343,7 +349,7 @@ def __init__(self, field: str, pattern: Union[bytes, str, memoryview]):
else:
raise ValueError("pattern must be bytes, str, or memoryview")

super().__init__(field, bytes_pattern)
super().__init__(field_name, bytes_pattern)

def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
return self.field + " = ?", [self.buf_pattern]
Expand Down Expand Up @@ -379,8 +385,8 @@ def _convert(self, s: str) -> Union[float, int, None]:
except ValueError:
raise InvalidQueryArgumentValueError(s, "an int or a float")

def __init__(self, field: str, pattern: str, fast: bool = True):
super().__init__(field, pattern, fast)
def __init__(self, field_name: str, pattern: str, fast: bool = True):
super().__init__(field_name, pattern, fast)

parts = pattern.split("..", 1)
if len(parts) == 1:
Expand All @@ -395,9 +401,9 @@ def __init__(self, field: str, pattern: str, fast: bool = True):
self.rangemax = self._convert(parts[1])

def match(self, obj: Model) -> bool:
if self.field not in obj:
if self.field_name not in obj:
return False
value = obj[self.field]
value = obj[self.field_name]
if isinstance(value, str):
value = self._convert(value)

Expand Down Expand Up @@ -430,7 +436,7 @@ def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
class InQuery(Generic[AnySQLiteType], FieldQuery[Sequence[AnySQLiteType]]):
"""Query which matches values in the given set."""

field: str
field_name: str
pattern: Sequence[AnySQLiteType]
fast: bool = True

Expand All @@ -440,7 +446,7 @@ def subvals(self) -> Sequence[SQLiteType]:

def col_clause(self) -> Tuple[str, Sequence[SQLiteType]]:
placeholders = ", ".join(["?"] * len(self.subvals))
return f"{self.field} IN ({placeholders})", self.subvals
return f"{self.field_name} IN ({placeholders})", self.subvals

@classmethod
def value_match(
Expand Down Expand Up @@ -823,15 +829,15 @@ class DateQuery(FieldQuery[str]):
using an ellipsis interval syntax similar to that of NumericQuery.
"""

def __init__(self, field: str, pattern: str, fast: bool = True):
super().__init__(field, pattern, fast)
def __init__(self, field_name: str, pattern: str, fast: bool = True):
super().__init__(field_name, pattern, fast)
start, end = _parse_periods(pattern)
self.interval = DateInterval.from_periods(start, end)

def match(self, obj: Model) -> bool:
if self.field not in obj:
if self.field_name not in obj:
return False
timestamp = float(obj[self.field])
timestamp = float(obj[self.field_name])
date = datetime.fromtimestamp(timestamp)
return self.interval.contains(date)

Expand Down
6 changes: 6 additions & 0 deletions beets/dbcore/queryparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ def construct_query_part(
# they are querying.
else:
key = key.lower()
if key in model_cls.shared_db_fields:
# This field exists in both tables, so SQLite will encounter
# an OperationalError if we try to query it in a join.
# Using an explicit table name resolves this.
key = f"{model_cls._table}.{key}"

out_query = query_class(key, pattern, key in model_cls.all_db_fields)

# Apply negation.
Expand Down
6 changes: 6 additions & 0 deletions test/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,6 +1148,7 @@ def setUp(self):
album_items.append(item)
album = self.lib.add_album(album_items)
album.artpath = f"{album_name} Artpath"
album.catalognum = "ABC"
album.store()
albums.append(album)

Expand All @@ -1163,6 +1164,11 @@ def test_get_items_filter_by_album_field(self):
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])

def test_filter_by_common_field(self):
q = "catalognum:ABC Album1"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])


def suite():
return unittest.TestLoader().loadTestsFromName(__name__)
Expand Down

0 comments on commit 323f17c

Please sign in to comment.