From 5fdc68dd61894663a7df61ba70ec54b1dca18943 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sun, 22 Oct 2023 22:15:52 +0000 Subject: [PATCH] Improve support for WHERE-clauses --- CHANGELOG.md | 7 + py_partiql_parser/__init__.py | 2 +- py_partiql_parser/_internal/parser.py | 2 +- py_partiql_parser/_internal/utils.py | 15 +- py_partiql_parser/_internal/where_parser.py | 313 ++++++++++++++++---- pyproject.toml | 2 +- tests/test_dynamodb_examples.py | 150 +++++++++- tests/test_tokenizer.py | 38 +++ tests/test_where_parser.py | 194 ++++++++---- 9 files changed, 608 insertions(+), 115 deletions(-) create mode 100644 tests/test_tokenizer.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 38c7bbe..f7fb61e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,13 @@ CHANGELOG ========= +0.4.1 +----- + - Increased support for WHERE-clauses: + 1. Nested clauses + 2. OR-clauses + 3. Functions: attribute_type, IS (NOT) MISSING, comparison operators (<, >) + 0.4.0 ----- - The DynamoDBStatementParser now expects a document in the DynamoDB format: diff --git a/py_partiql_parser/__init__.py b/py_partiql_parser/__init__.py index 6524108..c50f6e2 100644 --- a/py_partiql_parser/__init__.py +++ b/py_partiql_parser/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.0" +__version__ = "0.4.1" from ._internal.parser import DynamoDBStatementParser, S3SelectParser # noqa diff --git a/py_partiql_parser/_internal/parser.py b/py_partiql_parser/_internal/parser.py index 0377662..4f3678f 100644 --- a/py_partiql_parser/_internal/parser.py +++ b/py_partiql_parser/_internal/parser.py @@ -101,4 +101,4 @@ def get_query_metadata(cls, query: str): else: where = None - return QueryMetadata(tables=from_clauses, where_clauses=where) + return QueryMetadata(tables=from_clauses, where_clause=where) diff --git a/py_partiql_parser/_internal/utils.py b/py_partiql_parser/_internal/utils.py index fab217e..b58df0c 100644 --- a/py_partiql_parser/_internal/utils.py +++ b/py_partiql_parser/_internal/utils.py @@ -2,7
+2,10 @@ from .case_insensitive_dict import CaseInsensitiveDict from .json_parser import MissingVariable, Variable -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from .where_parser import AbstractWhereClause def is_dict(dct): @@ -134,13 +137,17 @@ def find_value_in_dynamodb_document(keys: List[str], json_doc): class QueryMetadata: def __init__( - self, tables: Dict[str, str], where_clauses: List[Tuple[List[str], str]] = None + self, + tables: Dict[str, str], + where_clause: Optional["AbstractWhereClause"] = None, ): self._tables = tables - self._where_clauses = where_clauses or [] + self._where_clause = where_clause def get_table_names(self) -> List[str]: return list(self._tables.values()) def get_filter_names(self) -> List[str]: - return [".".join(keys) for keys, _ in self._where_clauses] + if self._where_clause: + return self._where_clause.get_filter_names() + return [] diff --git a/py_partiql_parser/_internal/where_parser.py b/py_partiql_parser/_internal/where_parser.py index c2855b2..54ac7e7 100644 --- a/py_partiql_parser/_internal/where_parser.py +++ b/py_partiql_parser/_internal/where_parser.py @@ -1,4 +1,5 @@ -from typing import Any, List, Optional, Tuple +from decimal import Decimal +from typing import Any, List, Optional from .clause_tokenizer import ClauseTokenizer from .utils import find_value_in_document @@ -10,54 +11,253 @@ serializer = TypeSerializer() +class AbstractWhereClause: + def __init__(self, left: Any): + self.children = [left] + + def apply(self, find_value, row) -> bool: + return NotImplemented + + def get_filter_names(self) -> List[str]: + all_names: List[str] = [] + for child in self.children: + all_names.extend(child.get_filter_names()) + return all_names + + def process_value(self, fn) -> None: + """ + Transform all the values in this Where-clause, using a custom function + """ + for child in self.children: + child.process_value(fn) + + def 
__eq__(self, other: Any): + if isinstance(other, AbstractWhereClause): + return self.children == other.children + return NotImplemented + + def __str__(self): + return f"<{type(self)} {self.children}>" + + def __repr__(self): + return str(self) + + +class WhereAndClause(AbstractWhereClause): + def __init__(self, left: "AbstractWhereClause"): + super().__init__(left) + + def apply(self, find_value, row) -> bool: + return all([child.apply(find_value, row) for child in self.children]) + + +class WhereOrClause(AbstractWhereClause): + def __init__(self, left: "AbstractWhereClause"): + super().__init__(left) + + def apply(self, find_value, row) -> bool: + return any([child.apply(find_value, row) for child in self.children]) + + +class WhereClause(AbstractWhereClause): + def __init__(self, fn: str, left: List[str], right: str): + super().__init__([]) + self.fn = fn.lower() + self.left = left + self.right = right + + def apply(self, find_value, row) -> bool: + value = find_value(self.left, row) + if self.fn == "contains": + if "S" in self.right and "S" in value: + return self.right["S"] in value["S"] + if self.fn == "is": + if self.right == {"S": "MISSING"}: + return value is None + elif self.right == {"S": "NOT MISSING"}: + return value is not None + if self.fn in ["<=", "<", ">=", ">"]: + actual_value = Decimal(list(value.values())[0]) + expected = Decimal(self.right["S"]) + if self.fn == "<=": + return actual_value <= expected + if self.fn == "<": + return actual_value < expected + if self.fn == ">=": + return actual_value >= expected + if self.fn == ">": + return actual_value > expected + if self.fn == "attribute_type" and value is not None: + actual_value = list(value.keys())[0] + return actual_value == self.right["S"] + # Default - should we error instead if fn != '=='? 
+ return value == self.right + + def process_value(self, fn) -> None: + self.right = fn(self.right) + + def get_filter_names(self) -> List[str]: + return [".".join(self.left)] + + def __eq__(self, other: Any): + if isinstance(other, WhereClause): + return ( + self.fn == other.fn + and self.left == other.left + and self.right == other.right + ) + return NotImplemented + + def __str__(self): + return f"<{type(self)} {self.fn}({self.left}, {self.right})>" + + def __repr__(self): + return str(self) + + class WhereParser: def __init__(self, source_data: Any): self.source_data = source_data @classmethod - def parse_where_clause(cls, where_clause: str) -> Tuple[List[str], str]: - where_clause_parser = ClauseTokenizer(where_clause) - results = [] - keys: List[str] = [] + def parse_where_clause( + cls, where_clause: str, tokenizer: Optional[ClauseTokenizer] = None + ) -> AbstractWhereClause: + where_clause_parser = tokenizer or ClauseTokenizer(where_clause) + current_clause: Optional[AbstractWhereClause] = None + processing_function = False + left = [] + fn: str = "" section: Optional[str] = "KEY" current_phrase = "" while True: c = where_clause_parser.next() if c is None: if section == "KEY" and current_phrase != "": - keys.append(current_phrase) + left.append(current_phrase) + if section == "START_VALUE" and current_phrase != "": + current_clause = cls._determine_current_clause( + current_clause, left=left.copy(), fn=fn, right=current_phrase + ) break + if section == "KEY" and c == "(": + if current_phrase == "": + # Process a subsection of the WHERE-clause + # .. and (sub-clause) and ... 
+ next_clause = WhereParser.parse_where_clause( + where_clause="", tokenizer=where_clause_parser + ) + if current_clause is None: + current_clause = next_clause + else: + current_clause.children.append(next_clause) + section = "END_VALUE" + where_clause_parser.skip_white_space() + continue + else: + # Function + fn = current_phrase + current_phrase = "" + processing_function = True + continue + if c == ")": + if processing_function: + # | + # v + # fn("key", val) + processing_function = False + continue + else: + # Finished processing a subsection of the WHERE-clause + # .. and (sub-clause) and ... + return current_clause if c == ".": - if section == "KEY": + if section in ["KEY", "END_KEY"]: if current_phrase != "": - keys.append(current_phrase) + left.append(current_phrase) current_phrase = "" + section = "KEY" + continue + if c in [","]: + if section in ["END_KEY"]: + # | + # v + # fn("key", val) + section = "START_VALUE" + where_clause_parser.skip_white_space() continue if c in ['"', "'"]: if section == "KEY": # collect everything between these quotes - keys.append(where_clause_parser.next_until([c])) + left.append(where_clause_parser.next_until([c])) + # This could be the end of a key + # | + # v + # "key" = val + # fn("key", val) + # "key".subkey = val + # + # Note that in the last example, the key isn't actually finished + # When we encounter a '.' 
next, will we return to processing a KEY + section = "END_KEY" continue if section == "START_VALUE": - section = "VALUE" - continue - if section == "VALUE": + current_phrase = where_clause_parser.next_until([c]) section = "END_VALUE" - results.append((keys.copy(), current_phrase)) - keys.clear() + current_clause = cls._determine_current_clause( + current_clause, left=left.copy(), fn=fn, right=current_phrase + ) + left.clear() current_phrase = "" where_clause_parser.skip_white_space() continue - if c in [" "] and section == "KEY": + if c in [" "] and section in ["KEY", "END_KEY"]: + # | + # v + # "key" = val + # key >= val if current_phrase != "": - keys.append(current_phrase) + left.append(current_phrase) current_phrase = "" - where_clause_parser.skip_until(["="]) + fn = where_clause_parser.next_until([" "]) + if fn == "IS": + # Options: + # IS MISSING + # IS NOT MISSING + current_phrase = where_clause_parser.next_until([" "]) + if current_phrase == "NOT": + current_phrase = ( + f"{current_phrase} {where_clause_parser.next_until([' '])}" + ) + current_clause = cls._determine_current_clause( + current_clause, left=left.copy(), fn=fn, right=current_phrase + ) + left.clear() + current_phrase = "" + section = "END_VALUE" + continue where_clause_parser.skip_white_space() section = "START_VALUE" continue + if c in [" "] and section == "START_VALUE": + # | + # v + # "key" = 0 AND .. 
+ current_clause = cls._determine_current_clause( + current_clause, left=left.copy(), fn=fn, right=current_phrase + ) + left.clear() + current_phrase = "" + section = "END_VALUE" + continue if c in [" "] and section == "END_VALUE": if current_phrase.upper() == "AND": + current_clause = WhereAndClause(current_clause) + current_phrase = "" + section = "KEY" + where_clause_parser.skip_white_space() + elif current_phrase.upper() == "OR": + current_clause = WhereOrClause(current_clause) current_phrase = "" section = "KEY" where_clause_parser.skip_white_space() @@ -65,55 +265,62 @@ def parse_where_clause(cls, where_clause: str) -> Tuple[List[str], str]: if c in ["?"] and section == "START_VALUE": # Most values have to be surrounded by quotes # Question marks are parameters, and are valid values on their own - results.append((keys.copy(), "?")) - keys.clear() + current_clause = cls._determine_current_clause( + current_clause, left=left.copy(), fn=fn, right="?" + ) + left.clear() section = "END_VALUE" # Next step is to look for other key/value pairs continue if current_phrase == "" and section == "START_KEY": section = "KEY" - if section in ["KEY", "VALUE", "END_VALUE"]: + if section in ["KEY", "VALUE", "START_VALUE", "END_VALUE"]: current_phrase += c - return results + return current_clause + + @classmethod + def _determine_current_clause(cls, current_clause, left: str, fn: str, right: str): + if current_clause is not None and isinstance( + current_clause, (WhereAndClause, WhereOrClause) + ): + current_clause.children.append( + WhereClause(fn=fn, left=left.copy(), right=right) + ) + else: + current_clause = WhereClause(fn=fn, left=left.copy(), right=right) + return current_clause class DynamoDBWhereParser(WhereParser): - def parse(self, where_clause: str, parameters) -> Any: - _filters = WhereParser.parse_where_clause(where_clause) + def parse(self, _where_clause: str, parameters) -> Any: + where_clause = WhereParser.parse_where_clause(_where_clause) - _filters = [ - ( - 
key, - deserializer.deserialize(parameters.pop(0)) if value == "?" else value, - ) - for key, value in _filters - ] + def prep_value(val): + # WHERE key = ? + # ? should be parametrized + if val == "?": + return parameters.pop(0) + # WHERE key = val + # 'val' needs to be comparable with a DynamoDB document + # So we need to turn that into {'S': 'val'} + else: + return serializer.serialize(val) - return self.filter_rows(_filters) + where_clause.process_value(prep_value) - def filter_rows(self, _filters): - def _filter(row) -> bool: - return all( - [ - find_value_in_dynamodb_document(keys, row) - == serializer.serialize(value) - for keys, value in _filters - ] - ) - - return [row for row in self.source_data if _filter(row)] + return [ + row + for row in self.source_data + if where_clause.apply(find_value_in_dynamodb_document, row) + ] class S3WhereParser(WhereParser): - def parse(self, where_clause: str, parameters) -> Any: + def parse(self, _where_clause: str, parameters=None) -> Any: # parameters argument is ignored - only relevant for DynamoDB - _filters = WhereParser.parse_where_clause(where_clause) - - return self.filter_rows(_filters) + where_clause = WhereParser.parse_where_clause(_where_clause) - def filter_rows(self, _filters): - def _filter(row): - return all( - [find_value_in_document(keys, row) == value for keys, value in _filters] - ) - - return [row for row in self.source_data if _filter(row)] + return [ + row + for row in self.source_data + if where_clause.apply(find_value_in_document, row) + ] diff --git a/pyproject.toml b/pyproject.toml index 3e7d113..44bc4f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "py-partiql-parser" -version = "0.4.0" +version = "0.4.1" description = "Pure Python PartiQL Parser" readme = "README.md" keywords = ["pypartiql", "parser"] diff --git a/tests/test_dynamodb_examples.py b/tests/test_dynamodb_examples.py index d8f579b..6a707e5 100644 --- a/tests/test_dynamodb_examples.py +++ 
b/tests/test_dynamodb_examples.py @@ -103,7 +103,6 @@ def test_search_object_inside_a_list(): "a2": serializer.serialize("b2"), } ] - print(obj) query = "select * from table where a1[0].name = 'lvyan'" assert DynamoDBStatementParser(source_data={"table": obj}).parse(query) == obj @@ -145,3 +144,152 @@ def test_table_starting_with_number(): DynamoDBStatementParser(source_data={"0table": double_doc}).parse(query) == double_doc ) + + +def test_complex_where_clauses(): + items = [ + { + "Id": { + "S": "0", + }, + "Name": { + "S": "Lambda", + }, + "NameLower": {"S": "lambda"}, + "Description": { + "S": "Run code in under 15 minutes", + }, + "DescriptionLower": {"S": "run code in under 15 minutes"}, + "Price": { + "N": "2E-7", + }, + "Unit": { + "S": "invocation", + }, + "Category": { + "S": "free", + }, + "FreeTier": { + "N": "1E+6", + }, + }, + { + "Id": { + "S": "1", + }, + "Name": { + "S": "Auto Scaling", + }, + "NameLower": {"S": "auto scaling"}, + "Description": { + "S": "Automatically scale the number of EC2 instances with demand", + }, + "DescriptionLower": { + "S": "automatically scale the number of ec2 instances with demand" + }, + "Price": { + "N": "0", + }, + "Unit": { + "S": "group", + }, + "Category": { + "S": "free", + }, + "FreeTier": { + "NULL": True, + }, + }, + { + "Id": { + "S": "2", + }, + "Name": { + "S": "EC2", + }, + "NameLower": {"S": "ec2"}, + "Description": { + "S": "Servers in the cloud", + }, + "DescriptionLower": {"S": "servers in the cloud"}, + "Price": { + "N": "7.2", + }, + "Unit": { + "S": "instance", + }, + "Category": { + "S": "trial", + }, + }, + { + "Id": { + "S": "3", + }, + "Name": { + "S": "Config", + }, + "NameLower": {"S": "config"}, + "Description": { + "S": "Audit the configuration of AWS resources", + }, + "DescriptionLower": { + "S": "audit the configuration of aws resources", + }, + "Price": { + "N": "0.003", + }, + "Unit": { + "S": "configuration item", + }, + "Category": { + "S": "paid", + }, + }, + ] + + # IS MISSING + 
query = "SELECT Id from table where FreeTier IS MISSING" + assert DynamoDBStatementParser(source_data={"table": items}).parse(query) == [ + {"Id": {"S": "2"}}, + {"Id": {"S": "3"}}, + ] + + # IS NOT MISSING + query = "SELECT Id from table where FreeTier IS NOT MISSING" + assert DynamoDBStatementParser(source_data={"table": items}).parse(query) == [ + {"Id": {"S": "0"}}, + {"Id": {"S": "1"}}, + ] + + # CONTAINS + query = "SELECT Id from table where contains(\"DescriptionLower\", 'cloud')" + assert DynamoDBStatementParser(source_data={"table": items}).parse(query) == [ + {"Id": {"S": "2"}} + ] + + # < + query = "SELECT Id from table where Price < 1" + assert DynamoDBStatementParser(source_data={"table": items}).parse(query) == [ + {"Id": {"S": "0"}}, + {"Id": {"S": "1"}}, + {"Id": {"S": "3"}}, + ] + + # >= + query = "SELECT Id from table where Price >= 7.2" + assert DynamoDBStatementParser(source_data={"table": items}).parse(query) == [ + {"Id": {"S": "2"}} + ] + + # attribute_type + query = "SELECT Id from table where attribute_type(\"FreeTier\", 'N')" + assert DynamoDBStatementParser(source_data={"table": items}).parse(query) == [ + {"Id": {"S": "0"}} + ] + + # FULL + query = f"SELECT Id FROM \"table\" WHERE (contains(\"NameLower\", 'code') OR contains(\"DescriptionLower\", 'code')) AND Category = 'free' AND Price >= 0 AND Price <= 1 AND FreeTier IS NOT MISSING AND attribute_type(\"FreeTier\", 'N')" + assert DynamoDBStatementParser(source_data={"table": items}).parse(query) == [ + {"Id": {"S": "0"}} + ] diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py new file mode 100644 index 0000000..a7a0a6f --- /dev/null +++ b/tests/test_tokenizer.py @@ -0,0 +1,38 @@ +from py_partiql_parser._internal.clause_tokenizer import ClauseTokenizer + + +def test_immediate_overrun(): + assert ClauseTokenizer("").next() is None + + +def test_overrun(): + tokenizer = ClauseTokenizer("ab") + assert tokenizer.next() == "a" + assert tokenizer.next() == "b" + assert 
tokenizer.next() is None + + +def test_current(): + tokenizer = ClauseTokenizer("ab") + assert tokenizer.current() == "a" + assert tokenizer.current() == "a" + assert tokenizer.next() == "a" + assert tokenizer.current() == "b" + assert tokenizer.next() == "b" + assert tokenizer.current() is None + + +def test_peek(): + tokenizer = ClauseTokenizer("abc") + assert tokenizer.peek() == "b" + tokenizer.next() + assert tokenizer.peek() == "c" + assert tokenizer.next() == "b" + assert tokenizer.peek() is None + + +def test_next_until(): + tokenizer = ClauseTokenizer("sth (relevant data) else") + while tokenizer.next() != "(": + pass + assert tokenizer.next_until([")"]) == "relevant data" diff --git a/tests/test_where_parser.py b/tests/test_where_parser.py index cb04290..77b87a3 100644 --- a/tests/test_where_parser.py +++ b/tests/test_where_parser.py @@ -1,64 +1,161 @@ +from itertools import chain + from py_partiql_parser._internal.where_parser import WhereParser from py_partiql_parser._internal.where_parser import S3WhereParser from py_partiql_parser._internal.where_parser import DynamoDBWhereParser +from py_partiql_parser._internal.where_parser import ( + WhereClause, + WhereAndClause, + WhereOrClause, +) class TestWhereClause: def test_single_key(self): where_clause = "s3object.city = 'Chicago'" - assert WhereParser.parse_where_clause(where_clause) == [ - ( - ["s3object", "city"], - "Chicago", - ) - ] + assert WhereParser.parse_where_clause(where_clause) == WhereClause( + fn="=", left=["s3object", "city"], right="Chicago" + ) + + def test_single_key_surrounded_by_parentheses(self): + where_clause = "(city = 'Chicago')" + assert WhereParser.parse_where_clause(where_clause) == WhereClause( + fn="=", left=["city"], right="Chicago" + ) def test_nested_key(self): where_clause = "s3object.city.street = 'Chicago'" - assert WhereParser.parse_where_clause(where_clause) == [ - ( - ["s3object", "city", "street"], - "Chicago", - ) - ] + assert 
WhereParser.parse_where_clause(where_clause) == WhereClause( + fn="=", left=["s3object", "city", "street"], right="Chicago" + ) def test_quoted_key(self): where_clause = "s3object.\"city\" = 'Chicago'" - assert WhereParser.parse_where_clause(where_clause) == [ - ( - ["s3object", "city"], - "Chicago", - ) - ] + assert WhereParser.parse_where_clause(where_clause) == WhereClause( + fn="=", left=["s3object", "city"], right="Chicago" + ) def test_quoted_nested_key(self): where_clause = "s3object.\"city details\".street = 'Chicago'" - assert WhereParser.parse_where_clause(where_clause) == [ - ( - ["s3object", "city details", "street"], - "Chicago", - ) - ] + assert WhereParser.parse_where_clause(where_clause) == WhereClause( + fn="=", left=["s3object", "city details", "street"], right="Chicago" + ) - def test_multiple_keys(self): + def test_multiple_keys__and(self): where_clause = "s3.city = 'Chicago' AND s3.name = 'Tommy'" - assert WhereParser.parse_where_clause(where_clause) == [ - ( - ["s3", "city"], - "Chicago", - ), - (["s3", "name"], "Tommy"), - ] + expected = WhereAndClause( + WhereClause(fn="=", left=["s3", "city"], right="Chicago") + ) + expected.children.append( + WhereClause(fn="=", left=["s3", "name"], right="Tommy") + ) + assert WhereParser.parse_where_clause(where_clause) == expected def test_multiple_keys_with_question_marks(self): where_clause = "s3.city = ? AND s3.name = ?" 
- assert WhereParser.parse_where_clause(where_clause) == [ - ( - ["s3", "city"], - "?", - ), - (["s3", "name"], "?"), - ] + expected = WhereAndClause(WhereClause(fn="=", left=["s3", "city"], right="?")) + expected.children.append(WhereClause(fn="=", left=["s3", "name"], right="?")) + assert WhereParser.parse_where_clause(where_clause) == expected + + def test_multiple_keys__or(self): + where_clause = "s3.city = 'Chicago' OR s3.name = 'Tommy'" + expected = WhereOrClause( + WhereClause(fn="=", left=["s3", "city"], right="Chicago") + ) + expected.children.append( + WhereClause(fn="=", left=["s3", "name"], right="Tommy") + ) + assert WhereParser.parse_where_clause(where_clause) == expected + + def test_multiple_keys__and_and_or(self): + where_clause = "(city = 'Chicago' AND name = 'Tommy') OR stuff = 'sth'" + left = WhereAndClause(WhereClause(fn="=", left=["city"], right="Chicago")) + left.children.append(WhereClause(fn="=", left=["name"], right="Tommy")) + expected = WhereOrClause(left) + expected.children.append(WhereClause(fn="=", left=["stuff"], right="sth")) + assert WhereParser.parse_where_clause(where_clause) == expected + + def test_multiple_keys__or_and_and(self): + where_clause = "stuff = 'sth' or (city = 'Chicago' AND name = 'Tommy')" + right = WhereAndClause(WhereClause(fn="=", left=["city"], right="Chicago")) + right.children.append(WhereClause(fn="=", left=["name"], right="Tommy")) + expected = WhereAndClause(WhereClause(fn="=", left=["stuff"], right="sth")) + expected.children.append(right) + assert WhereParser.parse_where_clause(where_clause) == expected + + def test_number_operators_and_is_not_missing(self): + where_clause = "Price >= 0 AND Price <= 1 AND FreeTier IS NOT MISSING AND attribute_type(\"FreeTier\", 'N')" + result = WhereParser.parse_where_clause(where_clause) + + flat_results = [] + + def _get_root_clause(clause): + if isinstance(clause, WhereAndClause): + [ + _get_root_clause(c) + for c in clause.children + if isinstance(c, 
WhereAndClause) + ] + flat_results.extend( + [c for c in clause.children if not isinstance(c, WhereAndClause)] + ) + else: + flat_results.append(clause) + + _get_root_clause(result) + assert len(flat_results) == 4 + + assert WhereClause(left=["Price"], fn=">=", right="0") in flat_results + assert WhereClause(left=["Price"], fn="<=", right="1") in flat_results + assert ( + WhereClause(left=["FreeTier"], fn="is", right="NOT MISSING") in flat_results + ) + assert ( + WhereClause(left=["FreeTier"], fn="attribute_type", right="N") + in flat_results + ) + + def test_contains(self): + where_clause = "(contains(\"city\", 'dam'))" + expected = WhereClause(fn="contains", left=["city"], right="dam") + assert WhereParser.parse_where_clause(where_clause) == expected + + def test_multiple_contains(self): + where_clause = "(contains(\"city\", 'dam') and contains(\"find\", 'things'))" + left = WhereClause(fn="contains", left=["city"], right="dam") + right = WhereClause(fn="contains", left=["find"], right="things") + expected = WhereAndClause(left) + expected.children.append(right) + assert WhereParser.parse_where_clause(where_clause) == expected + + def test_comparisons(self): + where_clause = "size >= 20" + expected = WhereClause(fn=">=", left=["size"], right="20") + assert WhereParser.parse_where_clause(where_clause) == expected + + def test_missing(self): + where_clause = "size IS MISSING" + expected = WhereClause(fn="IS", left=["size"], right="MISSING") + assert WhereParser.parse_where_clause(where_clause) == expected + + def test_not_missing(self): + where_clause = "size IS NOT MISSING" + expected = WhereClause(fn="IS", left=["size"], right="NOT MISSING") + assert WhereParser.parse_where_clause(where_clause) == expected + + def test_attribute_type(self): + # This is only applicable for DynamoDB + # Parsing is the same though - it's just a function like any other + where_clause = "attribute_type(\"FreeTier\", 'N')" + expected = WhereClause(fn="attribute_type", 
left=["FreeTier"], right="N") + assert WhereParser.parse_where_clause(where_clause) == expected + + def test_where_values_contain_parentheses(self): + # Parentheses are a special case in case of nested clauses + # But should be processed correctly when they are part of a value + where_clause = "sth = 's(meth)ng'" + expected = WhereClause(fn="=", left=["sth"], right="s(meth)ng") + assert WhereParser.parse_where_clause(where_clause) == expected class TestFilter: @@ -73,24 +170,13 @@ class TestFilter: ] def test_simple(self): - filter_keys = ["city"] - filter_value = "Los Angeles" - assert S3WhereParser(TestFilter.all_rows).filter_rows( - _filters=[(filter_keys, filter_value)] - ) == [{"Name": "Vinod", "city": "Los Angeles"}] - - def test_alias(self): - filter_keys = ["city"] - filter_value = "Los Angeles" - assert S3WhereParser(TestFilter.all_rows).filter_rows( - _filters=[(filter_keys, filter_value)] + assert S3WhereParser(TestFilter.all_rows).parse( + _where_clause="city = 'Los Angeles'" ) == [{"Name": "Vinod", "city": "Los Angeles"}] def test_alias_nested_key(self): - filter_keys = ["notes", "extra"] - filter_value = "y" - assert S3WhereParser(TestFilter.all_rows).filter_rows( - _filters=[(filter_keys, filter_value)] + assert S3WhereParser(TestFilter.all_rows).parse( + _where_clause="notes.extra = 'y'" ) == [{"Name": "Mary", "city": "Chicago", "notes": {"extra": "y"}}]