Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the code to be compatible with the latest version of sqlglot. #183

Merged
merged 5 commits into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aidb/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,4 +391,4 @@ def bind_inference_service(self, bound_service: BoundInferenceService):
def add_user_defined_function(self, function_name: str, function: Callable):
self.clear_cached_properties()
logger.info(f'Added user defined function {function_name}')
self.user_defined_functions[str.lower(function_name)] = function
self.user_defined_functions[str.upper(function_name)] = function
4 changes: 2 additions & 2 deletions aidb/engine/base_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def add_filter_key_into_query(
value_list.append(f'({", ".join([str(value) for value in row])})')

filtered_key_str = f'({col_tuple}) IN ({", ".join(value_list)})'
new_query = query.add_where_condition('and', filtered_key_str)
new_query = query.add_where_condition(filtered_key_str)

return new_query, selected_column

Expand Down Expand Up @@ -371,7 +371,7 @@ def register_user_defined_function(self, function_name, function):


def _call_user_function(self, res_df: pd.DataFrame, function_name: str, args_list: List[str]):
function_name = str.lower(function_name)
function_name = str.upper(function_name)

if inspect.iscoroutinefunction(self._config.user_defined_functions[function_name]):
list_function_results = asyncio_run(self._config.user_defined_functions[function_name](res_df[args_list]))
Expand Down
170 changes: 72 additions & 98 deletions aidb/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Dict, List

import sqlglot.expressions as exp
from sqlglot.rewriter import Rewriter
from sqlglot import Parser, Tokenizer, parse_one
from sympy import sympify
from sympy.logic.boolalg import to_cnf
Expand Down Expand Up @@ -80,7 +79,7 @@ def all_queries_in_expressions(self):
This function is used to extract all queries in the expression, including the entire query and subqueries
'''
all_queries = []
for node, _, _ in self._expression.walk():
for node in self._expression.walk():
if isinstance(node, exp.Select):
depth = 0
parent = node.parent
Expand Down Expand Up @@ -130,23 +129,19 @@ def table_and_column_aliases_in_query(self):
"""
table_alias_to_name = {}
column_alias_to_name = {}
for node, _, _ in self._expression.walk():
for node in self._expression.walk():
if isinstance(node, exp.Expression) and self._check_in_subquery(node):
continue
if isinstance(node, exp.Alias) and 'alias' in node.args and 'this' in node.args:
if isinstance(node.args['this'], exp.Table):
tbl_alias = node.args['alias'].args['this']
if str.lower(tbl_alias) in table_alias_to_name:
raise Exception('Duplicated alias found in query, please use another alias')
tbl_name = node.args['this'].args['this'].args['this']
table_alias_to_name[str.lower(tbl_alias)] = str.lower(tbl_name)

elif isinstance(node.args['this'], exp.Column):
col_alias = node.args['alias'].args['this']
if str.lower(col_alias) in column_alias_to_name:
raise Exception('Duplicated alias found in query, please use another alias')
col_name = node.args['this'].args['this'].args['this']
column_alias_to_name[str.lower(col_alias)] = str.lower(col_name)
if isinstance(node, exp.Table) and node.args.get('alias'):
table_alias = node.alias
table_name = node.name
table_alias_to_name[str.lower(table_alias)] = str.lower(table_name)
if isinstance(node, exp.Alias) and isinstance(node.args['this'], exp.Column):
col_alias = node.alias
if str.lower(col_alias) in column_alias_to_name:
raise Exception('Duplicated alias found in query, please use another alias')
col_name = node.args['this'].args['this'].args['this']
column_alias_to_name[str.lower(col_alias)] = str.lower(col_name)

return table_alias_to_name, column_alias_to_name

Expand All @@ -160,11 +155,11 @@ def udf_outputs_aliases(self):
# Dictionary of aliases for user-defined functions defined in AIDB
alias_to_udf_mapping = {}
alias_index = 0
for node, _, _ in self._expression.walk(bfs=False):
for node in self._expression.walk(bfs=False):
if isinstance(node, exp.Expression) and self._check_in_subquery(node):
continue
if isinstance(node, exp.Alias) and 'alias' in node.args and 'this' in node.args:
if (isinstance(node.args['this'], exp.UserFunction) and node.args['alias'].args['this'] is not None
if (isinstance(node.args['this'], exp.Anonymous) and node.args['alias'].args['this'] is not None
and node.args['this'].args['this'] in self.config.user_defined_functions):
udf_output_alias_key = f"{node.args['this'].args['this']}__{alias_index}"
alias_index += 1
Expand All @@ -176,7 +171,7 @@ def udf_outputs_aliases(self):
else:
raise Exception('Duplicated alias is not allowed, please use another alias')
elif isinstance(node, exp.Aliases) and 'expressions' in node.args and 'this' in node.args:
if (isinstance(node.args['this'], exp.UserFunction) and len(node.args['expressions']) != 0
if (isinstance(node.args['this'], exp.Anonymous) and len(node.args['expressions']) != 0
and node.args['this'].args['this'] in self.config.user_defined_functions):
udf_output_alias_key = f"{node.args['this'].args['this']}__{alias_index}"
alias_index += 1
Expand Down Expand Up @@ -204,7 +199,7 @@ def columns_in_query(self):
@cached_property
def tables_in_query(self):
table_set = set()
for node, _, _ in self._expression.walk():
for node in self._expression.walk():
if isinstance(node, exp.Expression) and self._check_in_subquery(node):
continue
if isinstance(node, exp.Table):
Expand Down Expand Up @@ -291,7 +286,7 @@ def _replace_col_in_filter_predicate_with_root_col(self, expression):
so filtering predicate will be converted into 'blob.frame > 10000'
'''
copied_expression = expression.copy()
for node, _, _ in copied_expression.walk():
for node in copied_expression.walk():
if isinstance(node, exp.Expression) and self._check_in_subquery(node):
continue
if isinstance(node, exp.Column):
Expand Down Expand Up @@ -371,29 +366,16 @@ def query_after_normalizing(self):
e.g. for this query: SELECT frame FROM blobs b WHERE b.timestamp > 100
the expression will be converted into SELECT blobs.frame FROM blobs WHERE blobs.timestamp > 100
"""

def _remove_alias_in_expression(original_list):
removed_alias_list = []
for element in original_list:
if isinstance(element, exp.Alias):
removed_alias_list.append(element.args['this'])
elif isinstance(element, exp.Aliases):
removed_alias_list.append(element.args['this'])
else:
removed_alias_list.append(element)
return removed_alias_list


copied_expression = self._expression.copy()
table_alias_to_name, column_alias_to_name = self.table_and_column_aliases_in_query
udf_output_to_alias_mapping, alias_to_udf_mapping = self.udf_outputs_aliases

for node, parent, _ in copied_expression.walk():
for node in copied_expression.walk():
if isinstance(node, exp.Expression) and self._check_in_subquery(node):
continue
if isinstance(node, exp.Column):
if isinstance(node.args['this'], exp.Identifier):
col_name = str.lower(node.args['this'].args['this'])
col_name = str.lower(node.name)
if col_name in column_alias_to_name:
col_name = column_alias_to_name[col_name]
node.args['this'].set('this', col_name)
Expand All @@ -407,30 +389,25 @@ def _remove_alias_in_expression(original_list):
table_name = self._get_table_of_column(col_name)

node.set('table', exp.Identifier(this=table_name, quoted=False))

elif isinstance(node.args['this'], exp.Star):
if isinstance(parent, exp.AggFunc):
continue
select_exp_list = []
for table_name in self.tables_in_query:
for col_name, _ in self._tables[table_name].items():
new_table = exp.Identifier(this=table_name, quoted=False)
new_column = exp.Identifier(this=col_name, quoted=False)
select_exp_list.append(exp.Column(this=new_column, table=new_table))
copied_expression.set('expressions', select_exp_list)

# remove alias

copied_expression.set('expressions', _remove_alias_in_expression(copied_expression.args['expressions']))
copied_expression.args['from'].set(
'expressions',
_remove_alias_in_expression(copied_expression.args['from'].args['expressions'])
)

for join_element in copied_expression.args['joins']:
if isinstance(join_element.args['this'], exp.Alias):
join_element.set('this', join_element.args['this'].args['this'])

elif isinstance(node, exp.Star):
if isinstance(node.parent, exp.AggFunc):
continue
select_exp_list = []
for table_name in self.tables_in_query:
for col_name, _ in self._tables[table_name].items():
new_table = exp.Identifier(this=table_name, quoted=False)
new_column = exp.Identifier(this=col_name, quoted=False)
select_exp_list.append(exp.Column(this=new_column, table=new_table))
copied_expression.set('expressions', select_exp_list)

# remove alias
if isinstance(node, exp.Table):
node.set('alias', None)
if isinstance(node, exp.Alias):
parent = node.parent
if isinstance(parent, exp.Select):
idx = parent.args['expressions'].index(node)
parent.args['expressions'][idx] = node.this
return Query(copied_expression.sql(), self.config)


Expand All @@ -442,18 +419,18 @@ def _get_normalized_column_set(self, normalized_expression):
it will return [blobs.frame, blobs.timestamp],
"""
normalized_column_set = set()
for node, _, _ in normalized_expression.walk():
for node in normalized_expression.walk():
if isinstance(node, exp.Expression) and self._check_in_subquery(node):
continue
if isinstance(node, exp.Column):
if isinstance(node.args['this'], exp.Identifier):
col_name = node.args['this'].args['this']
table_name = node.args['table'].args['this']
normalized_column_set.add(f'{table_name}.{col_name}')
elif isinstance(node.args['this'], exp.Star):
for table_name in self.tables_in_query:
for col_name, _ in self._tables[table_name].items():
normalized_column_set.add(f'{table_name}.{col_name}')
elif isinstance(node, exp.Star):
for table_name in self.tables_in_query:
for col_name, _ in self._tables[table_name].items():
normalized_column_set.add(f'{table_name}.{col_name}')
return normalized_column_set


Expand Down Expand Up @@ -548,12 +525,15 @@ def blob_tables_required_for_query(self):

def _get_keyword_arg(self, exp_type):
value = None
for node, _, key in self._expression.walk():
for node in self._expression.walk():
if isinstance(node, exp_type):
if value is not None:
raise Exception(f'Multiple unexpected keywords found')
else:
value = float(node.args['this'].args['this'])
if node.args['this']:
value = float(node.args['this'].args['this'])
elif node.args['expression']:
value = float(node.args['expression'].args['this'])
return value


Expand Down Expand Up @@ -636,7 +616,7 @@ def is_valid_aqp_query(self):

@cached_property
def is_aqp_join_query(self):
return self._expression.find(exp.Join) is not None and self._expression.find(exp.UserFunction) is not None
return self._expression.find(exp.Join) is not None and self._expression.find(exp.Anonymous) is not None


@cached_property
Expand Down Expand Up @@ -693,39 +673,34 @@ def base_sql_no_where(self):


# FIXME: move it to sqlglot.rewriter
def add_where_condition(self, operator:str , where_condition):
def add_where_condition(self, where_condition):
expression = self._expression.copy()
re = Rewriter(expression)
new_sql = re.add_where(operator, where_condition)
return Query(new_sql.expression.sql(), self.config)
new_sql = expression.where(where_condition)
return Query(new_sql.sql(), self.config)


def add_select(self, selects):
def add_select(self, new_select):
expression = self._expression.copy()
re = Rewriter(expression)
new_sql = re.add_selects(selects)
return Query(new_sql.expression.sql(), self.config)
new_sql = expression.select(new_select)
return Query(new_sql.sql(), self.config)


def add_join(self, new_join):
expression = self._expression.copy()
re = Rewriter(expression)
new_sql = re.add_join(new_join)
return Query(new_sql.expression.sql(), self.config)
new_sql = expression.join(new_join)
return Query(new_sql.sql(), self.config)


def add_offset_keyword(self, offset):
expression = self._expression.copy()
if offset != 0:
offset_node = exp.Offset(this=exp.Literal(this=offset, is_string=False))
expression.set('offset', offset_node)
expression = expression.offset(offset)
return Query(expression.sql(), self.config)


def add_limit_keyword(self, limit):
expression = self._expression.copy()
limit_node = exp.Limit(this=exp.Literal(this=limit, is_string=False))
expression.set('limit', limit_node)
expression = expression.limit(limit)
return Query(expression.sql(), self.config)

@staticmethod
Expand Down Expand Up @@ -753,15 +728,15 @@ def _build_tree(elements, operator):

@cached_property
def is_udf_query(self):
for expression in self._expression.find_all(exp.UserFunction):
for expression in self._expression.find_all(exp.Anonymous):
if expression.args['this'] in self.config.user_defined_functions:
return True
return False


def check_udf_query_validity(self):
for expression in self._expression.find_all(exp.UserFunction):
if isinstance(expression.parent, exp.UserFunction):
for expression in self._expression.find_all(exp.Anonymous):
if isinstance(expression.parent, exp.Anonymous):
if (expression.parent.args['this'] in self.config.user_defined_functions
and expression.args['this'] in self.config.user_defined_functions):
raise Exception("AIDB does not support nested user defined function currently")
Expand All @@ -781,12 +756,12 @@ def check_udf_query_validity(self):

def _expression_contains_udf(self, expression):
# check if the expression contain user defined function
for expression in expression.find_all(exp.UserFunction):
for expression in expression.find_all(exp.Anonymous):
if expression.args['this'] in self.config.user_defined_functions:
return True

udf_output_to_alias_mapping, alias_to_udf_mapping = self.udf_outputs_aliases
for node, _, _ in expression.walk():
for node in expression.walk():
if isinstance(node, exp.Expression) and self._check_in_subquery(node):
continue
if isinstance(node, exp.Column):
Expand Down Expand Up @@ -877,10 +852,9 @@ def add_udf(self, user_defined_function, udf_output_to_alias_mapping, is_select_
modified_query = QueryModifier()
normalized_query = self.query_after_normalizing
expression = normalized_query.get_expression().copy()

udf_output_to_alias_mapping, alias_to_udf_mapping = self.udf_outputs_aliases
for select_exp in expression.args['expressions']:
user_function = select_exp.find(exp.UserFunction)
user_function = select_exp.find(exp.Anonymous)
if user_function and user_function.args['this'] in aidb_user_defined_functions:
_ = modified_query.add_udf(user_function, udf_output_to_alias_mapping, is_select_col=True)
else:
Expand All @@ -892,8 +866,8 @@ def add_udf(self, user_defined_function, udf_output_to_alias_mapping, is_select_
# then remove this join condition
if expression.find(exp.Join) is not None:
for join_exp in expression.args['joins']:
if join_exp.args['on'] is not None:
user_function = join_exp.args['on'].find(exp.UserFunction)
if 'on' in join_exp.args:
user_function = join_exp.args['on'].find(exp.Anonymous)
if user_function and user_function.args['this'] in aidb_user_defined_functions:
filter_predicates.extend(self._convert_logical_condition_to_cnf(join_exp.args['on']).copy())
join_exp.set('on', None)
Expand All @@ -919,10 +893,10 @@ def add_udf(self, user_defined_function, udf_output_to_alias_mapping, is_select_
new_or_connected_fp = []
for fp in or_connected:
fp_copy = fp.copy()
for node, _, key in fp_copy.walk(bfs=False):
for node in fp_copy.walk(bfs=False):
if isinstance(node, exp.Expression) and self._check_in_subquery(node):
continue
if isinstance(node, exp.UserFunction):
if isinstance(node, exp.Anonymous):
if node.args['this'] in aidb_user_defined_functions:
output_alias = modified_query.add_udf(node.copy(), udf_output_to_alias_mapping)
if len(output_alias) > 1:
Expand All @@ -932,11 +906,11 @@ def add_udf(self, user_defined_function, udf_output_to_alias_mapping, is_select_
'then specify the filter conditions accordingly.')
converted_fp = exp.Column(this=exp.Identifier(this=output_alias[0]))
# FIXME: for IN operator
node.parent.set(key, converted_fp)
node.parent.set(node.arg_key, converted_fp)
else:
node_copy = node.copy()
modified_query.add_column_with_alias(node_copy)
node.parent.set(key, modified_query.col_index_mapping[node_copy])
node.parent.set(node.arg_key, modified_query.col_index_mapping[node_copy])

elif isinstance(node, exp.Column):
# for the column that uses udf alias, we directly use it
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ six==1.16.0
sniffio==1.3.0
SQLAlchemy==1.4.39
SQLAlchemy-Utils==0.41.1
sqlglot-aidb==0.0.10
sqlglot-aidb==0.1.3
starlette==0.37.2
statsmodels==0.14.0
sympy==1.11.1
Expand Down
Loading
Loading