From 00881bf51151a955a6fff23d301b495746d899d3 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Wed, 25 Dec 2024 17:44:23 -0100 Subject: [PATCH] S3: API now also returns number of bytes processed --- .../_internal/clause_tokenizer.py | 9 ++ py_partiql_parser/_internal/from_parser.py | 76 +++-------- py_partiql_parser/_internal/json_parser.py | 16 ++- py_partiql_parser/_internal/parser.py | 46 ++++--- py_partiql_parser/_internal/select_parser.py | 127 ++++++++++++------ py_partiql_parser/_internal/where_parser.py | 19 +-- tests/test_json_encoder.py | 4 +- tests/test_s3_examples.py | 54 ++++---- tests/test_select_functions.py | 2 +- tests/test_select_parser.py | 12 +- tests/test_where_parser.py | 18 ++- 11 files changed, 215 insertions(+), 168 deletions(-) diff --git a/py_partiql_parser/_internal/clause_tokenizer.py b/py_partiql_parser/_internal/clause_tokenizer.py index 1835fc7..0e00927 100644 --- a/py_partiql_parser/_internal/clause_tokenizer.py +++ b/py_partiql_parser/_internal/clause_tokenizer.py @@ -5,6 +5,12 @@ class ClauseTokenizer: def __init__(self, from_clause: str): self.token_list = from_clause self.token_pos = 0 + self.tokens_parsed = 0 + + def get_tokens_parsed(self) -> int: + x = self.tokens_parsed + self.tokens_parsed = 0 + return x def current(self) -> Optional[str]: """ @@ -22,6 +28,7 @@ def next(self) -> Optional[str]: """ try: crnt_token = self.token_list[self.token_pos] + self.tokens_parsed += 1 self.token_pos += 1 return crnt_token except IndexError: @@ -34,11 +41,13 @@ def peek(self) -> Optional[str]: return None def revert(self) -> None: + self.tokens_parsed -= 1 self.token_pos -= 1 def skip_white_space(self) -> None: try: while self.token_list[self.token_pos] in [" ", "\n"]: + self.tokens_parsed += 1 self.token_pos += 1 except IndexError: pass diff --git a/py_partiql_parser/_internal/from_parser.py b/py_partiql_parser/_internal/from_parser.py index d2b271e..181c524 100644 --- a/py_partiql_parser/_internal/from_parser.py +++ b/py_partiql_parser/_internal/from_parser.py @@ -75,79 +75,43 @@ def __init__(self, from_clause: str): class S3FromParser(FromParser): - def get_source_data(self, documents: Dict[str, str]) -> Any: - from_alias = list(self.clauses.keys())[0].lower() + def get_source_data(self, document: CaseInsensitiveDict[str, str]) -> Any: from_query = list(self.clauses.values())[0].lower() if "." in from_query: - return self._get_nested_source_data(documents) + return self._get_nested_source_data(document) - key_has_asterix = from_query.endswith("[*]") - from_query = from_query[0:-3] if key_has_asterix else from_query - from_alias = from_alias[0:-3] if from_alias.endswith("[*]") else from_alias - doc_is_list = documents[from_query].startswith("[") and documents[ - from_query - ].endswith("]") - - source_data = list(JsonParser.parse(documents[from_query])) - - if doc_is_list: - return {"_1": source_data[0]} - elif from_alias: - return [CaseInsensitiveDict({from_alias: doc}) for doc in source_data] + if isinstance(document, list): + return {"_1": document} else: - return source_data + return document - def _get_nested_source_data(self, documents: Dict[str, Any]) -> Any: + def _get_nested_source_data(self, document: CaseInsensitiveDict[str, Any]) -> Any: """ Our FROM-clauses are nested, meaning we need to dig into the provided document to return the key that we need --> FROM s3object.name as name """ - root_doc = True - source_data = documents - iterate_over_docs = False entire_key = list(self.clauses.values())[0].lower().split(".") + if entire_key[0].lower() in ["s3object[*]"]: + entire_key = entire_key[1:] alias = list(self.clauses.keys())[0] if alias.endswith("[*]"): alias = alias[0:-3] key_so_far = [] for key in entire_key: key_so_far.append(key) - key_has_asterix = key.endswith("[*]") and key[0:-3] in source_data - new_key = key[0:-3] if key_has_asterix else key - if iterate_over_docs and isinstance(source_data, list): # type: ignore[unreachable] - # The previous key ended in [*] - # Iterate over all docs in the result, and only return the requested source key - if key_so_far == entire_key: # type: ignore[unreachable] - # If we have an alias, we have to use that instead of the original name - source_data = [{alias: doc.get(new_key, {})} for doc in source_data] - else: - source_data = [ - doc.get_original(new_key) or CaseInsensitiveDict({}) - for doc in source_data - ] + + if key in document: + document = document[key] + if isinstance(document, list): + # AWS behaviour when the root-document is a list + document = {"_1": document[0]} + elif key_so_far == entire_key: + if list(self.clauses.keys()) != list(self.clauses.values()): + document = CaseInsensitiveDict({alias: document}) else: - # The previous key was a regular key - # Assume that the result consists of a singular JSON document - if new_key in source_data: - doc_is_list = source_data[new_key].startswith("[") and source_data[ - new_key - ].endswith("]") - source_data = list(JsonParser.parse(source_data[new_key])) # type: ignore - if root_doc and doc_is_list: - # AWS behaviour when the root-document is a list - source_data = {"_1": source_data[0]} # type: ignore - elif key_so_far == entire_key: - if isinstance(source_data, list): # type: ignore[unreachable] - source_data = [{alias: doc} for doc in source_data] # type: ignore[unreachable] - else: - source_data = {alias: source_data} - else: - source_data = {} - - iterate_over_docs = key_has_asterix - root_doc = False - - return source_data + document = {} + + return document class DynamoDBFromParser(FromParser): diff --git a/py_partiql_parser/_internal/json_parser.py b/py_partiql_parser/_internal/json_parser.py index c5216be..5ea7b3b 100644 --- a/py_partiql_parser/_internal/json_parser.py +++ b/py_partiql_parser/_internal/json_parser.py @@ -1,5 +1,5 @@ from json import JSONEncoder -from typing import Any, List, Iterator, Optional +from typing import Any, List, Iterator, Optional, Tuple from .clause_tokenizer import ClauseTokenizer from .utils import CaseInsensitiveDict, Variable @@ -25,6 +25,20 @@ def parse(original: str) -> Iterator[Any]: # type: ignore[misc] if result is not None: yield result + @staticmethod + def parse_with_tokens(original: str) -> Tuple[Iterator[Any], int]: # type: ignore[misc] + """ + Parse JSON string. Returns a tuple of (json_doc, nr_of_bytes_processed) + """ + if not (original.startswith("{") or original.startswith("[")): + # Doesn't look like JSON - let's return as a variable + yield original if original.isnumeric() else Variable(original) + tokenizer = ClauseTokenizer(original) + while tokenizer.current() is not None: + result = JsonParser._get_next_document(original, tokenizer) + if result is not None: + yield result, tokenizer.get_tokens_parsed() + @staticmethod def _get_next_document( # type: ignore[misc] original: str, diff --git a/py_partiql_parser/_internal/parser.py b/py_partiql_parser/_internal/parser.py index b0034ed..e97fdf8 100644 --- a/py_partiql_parser/_internal/parser.py +++ b/py_partiql_parser/_internal/parser.py @@ -6,7 +6,8 @@ from .delete_parser import DeleteParser from .from_parser import DynamoDBFromParser, S3FromParser, FromParser from .insert_parser import InsertParser -from .select_parser import SelectParser +from .json_parser import JsonParser +from .select_parser import DynamoDBSelectParser, S3SelectClauseParser from .update_parser import UpdateParser from .where_parser import DynamoDBWhereParser, S3WhereParser, WhereParser from .utils import is_dict, QueryMetadata, CaseInsensitiveDict @@ -19,38 +20,43 @@ class S3SelectParser: - def __init__(self, source_data: Dict[str, str]): + def __init__(self, source_data: str): # Source data is in the format: {source: json} # Where 'json' is one or more json documents separated by a newline self.documents = source_data self.table_prefix = "s3object" + self.bytes_scanned = 0 def parse(self, query: str) -> List[Dict[str, Any]]: query = query.replace("\n", " ") clauses = re.split("SELECT | FROM | WHERE ", query, flags=re.IGNORECASE) # First clause is whatever comes in front of SELECT - which should be nothing _ = clauses[0] - # FROM - from_parser = S3FromParser(from_clause=clauses[2]) - - source_data = from_parser.get_source_data(self.documents) - if is_dict(source_data): - source_data = [source_data] - - # WHERE - if len(clauses) > 3: - where_clause = clauses[3] - source_data = S3WhereParser(source_data).parse(where_clause) - # SELECT - select_clause = clauses[1] + from_parser = S3FromParser(from_clause=clauses[2]) table_prefix = self.table_prefix for alias_key, alias_value in from_parser.clauses.items(): - if table_prefix == alias_value: + if table_prefix == alias_value or f"{table_prefix}[*]" == alias_value: table_prefix = alias_key - return SelectParser(table_prefix).parse( - select_clause, from_parser.clauses, source_data - ) + + results = [] + + for doc, tokens_parsed in JsonParser.parse_with_tokens(self.documents): + # get bytes scanned + doc = from_parser.get_source_data(doc) + self.bytes_scanned += tokens_parsed + + if len(clauses) > 3: + where_clause = clauses[3] + if not S3WhereParser.applies(doc, table_prefix, where_clause): + continue + + select_clause = clauses[1] + + S3SelectClauseParser(table_prefix).parse(select_clause, from_parser.clauses, doc, results) + # get bytes returned + + return results class DynamoDBStatementParser: @@ -116,7 +122,7 @@ def _parse_select( # SELECT select_clause = clauses[1] - queried_data = SelectParser().parse( + queried_data = DynamoDBSelectParser().parse( select_clause, from_parser.clauses, source_data ) updates: Dict[ diff --git a/py_partiql_parser/_internal/select_parser.py b/py_partiql_parser/_internal/select_parser.py index a834fb2..0aa5843 100644 --- a/py_partiql_parser/_internal/select_parser.py +++ b/py_partiql_parser/_internal/select_parser.py @@ -14,12 +14,13 @@ def __init__(self, value: str, table_prefix: Optional[str] = None): self.table_prefix = table_prefix self.value = value.strip() + if self.value == self.table_prefix: + self.value = "*" + def select(self, document: CaseInsensitiveDict) -> Any: if self.value == "*": - if self.table_prefix and self.table_prefix in document: - return document[self.table_prefix] - else: - return document + return document + self.value = self.value.removeprefix(f"{self.table_prefix}.") if "." in self.value: key, remaining = self.value.split(".", maxsplit=1) return find_nested_data_in_object( @@ -33,7 +34,7 @@ def select(self, document: CaseInsensitiveDict) -> Any: if is_dict(document[self.value]): return document[self.value] else: - return {self.value: document[self.value]} + return document.get_original(self.value) def __repr__(self) -> str: return f"" @@ -48,9 +49,12 @@ def __init__(self, value: str, function_name: str): self.function_name = function_name.strip() def execute( - self, aliases: Dict[str, str], documents: List[CaseInsensitiveDict] + self, aliases: Dict[str, str], document: CaseInsensitiveDict, results: List ) -> Dict[str, int]: - return {"_1": len(documents)} + if results: + results[0]["_1"] += 1 + else: + results.append({"_1": 1}) def __repr__(self) -> str: return f"" @@ -64,40 +68,6 @@ def __eq__(self, other: Any) -> bool: class SelectParser: - def __init__(self, table_prefix: Optional[str] = None): - self.table_prefix = table_prefix - - def parse( - self, - select_clause: str, - aliases: Dict[str, Any], - documents: List[CaseInsensitiveDict], - ) -> List[Dict[str, Any]]: - clauses = SelectParser.parse_clauses(select_clause, prefix=self.table_prefix) - - for clause in clauses: - if isinstance(clause, FunctionClause): - return [clause.execute(aliases, documents)] - - result: List[Dict[str, Any]] = [] - - for json_document in documents: - filtered_document = dict() - for clause in clauses: - attr = clause.select(json_document) - if attr is not None and not isinstance(attr, MissingVariable): - # Specific usecase: - # select * from s3object[*].Name my_n - if ( - "." in list(aliases.values())[0] - and list(aliases.keys())[0] in attr - and select_clause == "*" - ): - filtered_document.update({"_1": attr[list(aliases.keys())[0]]}) - else: - filtered_document.update(attr) - result.append(filtered_document) - return result @classmethod def parse_clauses( @@ -130,3 +100,78 @@ def parse_clauses( continue current_clause += c return results + + +class S3SelectClauseParser(SelectParser): + def __init__(self, table_prefix: Optional[str]): + self.table_prefix = table_prefix + + def parse( + self, + select_clause: str, + aliases: Dict[str, Any], + document: CaseInsensitiveDict, + results: List, + ): + clauses = SelectParser.parse_clauses(select_clause, prefix=self.table_prefix) + + has_fn_clause = False + + for clause in clauses: + if isinstance(clause, FunctionClause): + has_fn_clause = True + clause.execute(aliases, document, results) + + if has_fn_clause: + return + + filtered_document = dict() + for clause in clauses: + attr = clause.select(document) + if attr is not None and not isinstance(attr, MissingVariable): + # Specific usecase: + # select * from s3object[*].Name my_n + if ( + "." in list(aliases.values())[0] + and list(aliases.keys()) != list(aliases.values()) + and list(aliases.keys())[0] in attr + and select_clause == "*" + ): + filtered_document.update({"_1": attr[list(aliases.keys())[0]]}) + else: + filtered_document.update(attr) + results.append(filtered_document) + + +class DynamoDBSelectParser(SelectParser): + def parse( + self, + select_clause: str, + aliases: Dict[str, Any], + documents: List[CaseInsensitiveDict], + ) -> List[Dict[str, Any]]: + clauses = SelectParser.parse_clauses(select_clause) + + for clause in clauses: + if isinstance(clause, FunctionClause): + return [clause.execute(aliases, documents)] + + result: List[Dict[str, Any]] = [] + + for json_document in documents: + filtered_document = dict() + for clause in clauses: + attr = clause.select(json_document) + if attr is not None and not isinstance(attr, MissingVariable): + # Specific usecase: + # select * from s3object[*].Name my_n + if ( + "." in list(aliases.values())[0] + and list(aliases.keys())[0] in attr + and select_clause == "*" + ): + filtered_document.update({"_1": attr[list(aliases.keys())[0]]}) + else: + filtered_document.update(attr) + result.append(filtered_document) + return result diff --git a/py_partiql_parser/_internal/where_parser.py b/py_partiql_parser/_internal/where_parser.py index b95a413..9293e83 100644 --- a/py_partiql_parser/_internal/where_parser.py +++ b/py_partiql_parser/_internal/where_parser.py @@ -114,8 +114,6 @@ def __repr__(self) -> str: class WhereParser: - def __init__(self, source_data: List[CaseInsensitiveDict]): - self.source_data = source_data @classmethod def parse_where_clause( @@ -295,6 +293,9 @@ def _determine_current_clause( class DynamoDBWhereParser(WhereParser): + def __init__(self, source_data: List[CaseInsensitiveDict]): + self.source_data = source_data + def parse( self, _where_clause: str, parameters: Optional[List[Dict[str, Any]]] ) -> List[CaseInsensitiveDict]: @@ -321,12 +322,12 @@ def prep_value(val: str) -> Dict[str, Any]: class S3WhereParser(WhereParser): - def parse(self, _where_clause: str) -> Any: - # parameters argument is ignored - only relevant for DynamoDB + @classmethod + def applies(cls, doc, table_prefix, _where_clause: str) -> bool: where_clause = WhereParser.parse_where_clause(_where_clause) - return [ - row - for row in self.source_data - if where_clause.apply(find_value_in_document, row) - ] + # + if isinstance(where_clause, WhereClause): + where_clause.left.remove(table_prefix) + + return where_clause.apply(find_value_in_document, doc) diff --git a/tests/test_json_encoder.py b/tests/test_json_encoder.py index a3c5645..a131297 100644 --- a/tests/test_json_encoder.py +++ b/tests/test_json_encoder.py @@ -13,12 +13,12 @@ def test_json_output_can_be_dumped() -> None: "kids": None, } ) - result = S3SelectParser(source_data={"s3object": input_with_none}).parse(query) + result = S3SelectParser(source_data=input_with_none).parse(query) assert f"[{input_with_none}]" == json.dumps(result, cls=SelectEncoder) def test_json_with_lists_can_be_dumped() -> None: query = "select * from s3object s" input_with_none = json.dumps(input_with_lists[0]) - result = S3SelectParser(source_data={"s3object": input_with_none}).parse(query) + result = S3SelectParser(source_data=input_with_none).parse(query) assert f"[{input_with_none}]" == json.dumps(result, cls=SelectEncoder) diff --git a/tests/test_s3_examples.py b/tests/test_s3_examples.py index c42a1a3..d43f6b5 100644 --- a/tests/test_s3_examples.py +++ b/tests/test_s3_examples.py @@ -22,7 +22,7 @@ @pytest.mark.xfail(reason="CSV functionality not yet implemented") def test_aws_sample__csv() -> None: query = "SELECT * FROM s3object s where s.\"Name\" = 'Jane'" - x = S3SelectParser(source_data={"s3object": input_data_csv}).parse(query) + x = S3SelectParser(source_data=input_data_csv).parse(query) @pytest.mark.xfail( @@ -30,7 +30,7 @@ def test_aws_sample__csv() -> None: ) def test_aws_sample__json__search_by_name() -> None: query = "SELECT * FROM s3object s where s.\"Name\" = 'Jane'" - result = S3SelectParser(source_data={"s3object": input_json_list}).parse(query) # type: ignore + result = S3SelectParser(source_data=input_json_list).parse(query) # type: ignore assert result == [ { "Name": "Jane", @@ -49,7 +49,7 @@ def test_aws_sample__json__search_by_name() -> None: ], ) def test_aws_sample__json__search_by_city(query: str) -> None: - result = S3SelectParser(source_data={"s3object": json_as_lines}).parse(query) + result = S3SelectParser(source_data=json_as_lines).parse(query) assert len(result) == 4 assert { "Name": "Jane", @@ -79,7 +79,7 @@ def test_aws_sample__json__search_by_city(query: str) -> None: def test_aws_sample__json_select_multiple_attrs__search_by_city() -> None: query = "SELECT s.name, s.city FROM s3object s where s.\"City\" = 'Chicago'" - result = S3SelectParser(source_data={"s3object": json_as_lines}).parse(query) + result = S3SelectParser(source_data=json_as_lines).parse(query) assert len(result) == 4 assert { "Name": "Jane", @@ -101,33 +101,25 @@ def test_aws_sample__json_select_multiple_attrs__search_by_city() -> None: def test_aws_sample__object_select_all() -> None: query = "SELECT * FROM s3object" - result = S3SelectParser( - source_data={"s3object": json.dumps(input_json_object)} - ).parse(query) + result = S3SelectParser(source_data=json.dumps(input_json_object)).parse(query) assert result == [input_json_object] def test_aws_sample__s3object_is_case_insensitive() -> None: query = "SELECT * FROM s3obJEct" - result = S3SelectParser( - source_data={"s3object": json.dumps(input_json_object)} - ).parse(query) + result = S3SelectParser(source_data=json.dumps(input_json_object)).parse(query) assert result == [input_json_object] def test_aws_sample__object_select_everything() -> None: query = "SELECT s FROM s3object AS s" - result = S3SelectParser( - source_data={"s3object": json.dumps(input_json_object)} - ).parse(query) + result = S3SelectParser(source_data=json.dumps(input_json_object)).parse(query) assert result == [input_json_object] def test_aws_sample__object_select_attr() -> None: query = "SELECT s.a1 FROM s3object AS s" - result = S3SelectParser( - source_data={"s3object": json.dumps(input_json_object)} - ).parse(query) + result = S3SelectParser(source_data=json.dumps(input_json_object)).parse(query) assert result == [{"a1": "b1"}] @@ -140,13 +132,13 @@ def test_case_insensitivity() -> None: + "\n" + json.dumps({"Name": "Vinod", "City": "Los Angeles"}) ) - parser = S3SelectParser(source_data={"s3object": all_rows}) + parser = S3SelectParser(source_data=all_rows) assert parser.parse(query) == [{"Name": "Vinod", "City": "Los Angeles"}] def test_select_doc_using_asterisk() -> None: query = "select * from s3object[*]" - result = S3SelectParser(source_data={"s3object": json_as_lines}).parse(query) + result = S3SelectParser(source_data=json_as_lines).parse(query) assert len(result) == 7 assert { "Name": "Sam", @@ -164,7 +156,7 @@ def test_select_doc_using_asterisk() -> None: ], ) def test_select_specific_object_doc_using_named_asterisk(query: str) -> None: - result = S3SelectParser(source_data={"s3object": json_as_lines}).parse(query) + result = S3SelectParser(source_data=json_as_lines).parse(query) assert len(result) == 7 assert {"my_n": "Sam"} in result assert {"my_n": "Jeff"} in result @@ -175,7 +167,7 @@ def test_select_specific_object_doc_using_named_asterisk(query: str) -> None: ["select * from s3object[*].Name my_n", "select * from s3object[*].Name as my_n"], ) def test_select_nested_object_using_named_asterisk(query: str) -> None: - result = S3SelectParser(source_data={"s3object": json_as_lines}).parse(query) + result = S3SelectParser(source_data=json_as_lines).parse(query) assert len(result) == 7 assert {"_1": "Sam"} in result assert {"_1": "Jeff"} in result @@ -183,9 +175,7 @@ def test_select_nested_object_using_named_asterisk(query: str) -> None: def test_select_list_using_asterisk() -> None: query = "select * from s3object[*] s" - result = S3SelectParser( - source_data={"s3object": json.dumps(input_with_lists)} - ).parse(query) + result = S3SelectParser(source_data=json.dumps(input_with_lists)).parse(query) assert result == [{"_1": input_with_lists}] @@ -194,7 +184,19 @@ def test_select_list_using_asterisk() -> None: ["select * from s3object[*].staff s", "select * from s3object[*].staff[*] s"], ) def test_select_nested_list_using_asterisk(query: str) -> None: - result = S3SelectParser( - source_data={"s3object": json.dumps(input_with_lists)} - ).parse(query) + result = S3SelectParser(source_data=json.dumps(input_with_lists)).parse(query) assert result == [{}] + + +def test_nested_where_clause(): + query = "select * from s3object[*] s where s.c = 'd2'" + doc = "".join([json.dumps(x) for x in [{"a": {f"b{i}": "s"}, "c": f"d{i}"} for i in range(5)]]) + result = S3SelectParser(source_data=doc).parse(query) + assert result == [{'a': {'b2': 's'}, 'c': 'd2'}] + + +def test_nested_from_clause(): + query = "select * from s3object[*].a" + doc = "".join([json.dumps(x) for x in [{"a": {f"b{i}": "s"}, "c": f"d{i}"} for i in range(5)]]) + result = S3SelectParser(source_data=doc).parse(query) + assert {'b2': 's'} in result diff --git a/tests/test_select_functions.py b/tests/test_select_functions.py index b9f8b7b..5e39f2c 100644 --- a/tests/test_select_functions.py +++ b/tests/test_select_functions.py @@ -5,7 +5,7 @@ class TestCount: def setup_method(self) -> None: - self.parser = S3SelectParser(source_data={"s3object": json_as_lines}) + self.parser = S3SelectParser(source_data=json_as_lines) @pytest.mark.parametrize( "query,key,result", diff --git a/tests/test_select_parser.py b/tests/test_select_parser.py index 3c49d5f..2f0f570 100644 --- a/tests/test_select_parser.py +++ b/tests/test_select_parser.py @@ -4,33 +4,33 @@ def test_select_all_clause() -> None: - result = SelectParser(table_prefix=None).parse_clauses("*") + result = SelectParser().parse_clauses("*") assert result == [SelectClause("*")] def test_parse_simple_clause() -> None: - result = SelectParser(table_prefix=None).parse_clauses("s.name") + result = SelectParser().parse_clauses("s.name") assert result == [SelectClause("s.name")] def test_parse_multiple_clauses() -> None: - result = SelectParser(table_prefix=None).parse_clauses("s.name, s.city") + result = SelectParser().parse_clauses("s.name, s.city") assert result == [SelectClause("s.name"), SelectClause("s.city")] def test_parse_function_clause() -> None: - result = SelectParser(table_prefix=None).parse_clauses("count(*)") + result = SelectParser().parse_clauses("count(*)") assert result == [FunctionClause(function_name="count", value="*")] @pytest.mark.xfail(reason="Not yet implemented") def test_parse_function_alias_clause() -> None: - result = SelectParser(table_prefix=None).parse_clauses("count(*) as cnt") + result = SelectParser().parse_clauses("count(*) as cnt") assert result == [FunctionClause(function_name="count", value="*")] def test_parse_mix_of_function_and_regular_clauses() -> None: - result = SelectParser(table_prefix=None).parse_clauses( + result = SelectParser().parse_clauses( "count(*), s.city, max(s.citizens)" ) assert len(result) == 3 diff --git a/tests/test_where_parser.py b/tests/test_where_parser.py index 0502b50..4df97de 100644 --- a/tests/test_where_parser.py +++ b/tests/test_where_parser.py @@ -168,14 +168,20 @@ class TestFilter: ] def test_simple(self) -> None: - assert S3WhereParser(TestFilter.all_rows).parse( # type: ignore[arg-type] - _where_clause="city = 'Los Angeles'" - ) == [{"Name": "Vinod", "city": "Los Angeles"}] + query = "s3.city = 'Los Angeles'" + assert not S3WhereParser.applies(doc={"Name": "Sam", "city": "Irvine"}, table_prefix="s3", _where_clause=query) + assert not S3WhereParser.applies(doc={"Name": "Sam", "city": "Seattle"}, table_prefix="s3", _where_clause=query) + assert not S3WhereParser.applies(doc={"Name": "Sam", "city": "Chicago"}, table_prefix="s3", _where_clause=query) + assert not S3WhereParser.applies(doc={"Name": "Sam"}, table_prefix="s3", _where_clause=query) + + assert S3WhereParser.applies(doc={"Name": "Sam", "city": "Los Angeles"}, table_prefix="s3", _where_clause=query) def test_alias_nested_key(self) -> None: - assert S3WhereParser(TestFilter.all_rows).parse( # type: ignore[arg-type] - _where_clause="notes.extra = 'y'" - ) == [{"Name": "Mary", "city": "Chicago", "notes": {"extra": "y"}}] + query = "s3.notes.extra = 'y'" + assert not S3WhereParser.applies(doc={"Name": "Sam", "city": "Chicago"}, table_prefix="s3", _where_clause=query) + assert not S3WhereParser.applies(doc={"Name": "Sam"}, table_prefix="s3", _where_clause=query) + + assert S3WhereParser.applies(doc={"Name": "Sam", "notes": {"extra": "y"}}, table_prefix="s3", _where_clause=query) class TestDynamoDBParse: